use rand::CryptoRng;
use ::crypto_bigint::{
Choice, CtOption, CtEq, CtGt, CtSelect, CtAssign, Zero, One, NonZero, Odd, Limb, CheckedAdd,
CheckedSub, Mul, ConcatenatingMul, ConcatenatingSquare, Div, Rem, Gcd, NegMod, MulMod, SquareMod,
InvertMod, BitOps, Encoding, RandomBits, RandomMod, UnsignedWithMontyForm,
};
use crate::Element;
/// Map the binary quadratic form `(a, b, c)` to an equivalent form whose leading coefficient is
/// coprime to `p`.
///
/// `b` is passed in sign-magnitude form as `(b_positive, b_abs)`. The limb loops below zip `a`,
/// `b_abs`, and `c` together, so all three are expected to share one limb length (the callers in
/// this file resize all three to the same precision before calling).
///
/// The case selection uses `Choice`, `ct_select`, and `ct_swap` rather than branches:
/// - if `gcd(a, p) = 1`, the form is returned unchanged;
/// - otherwise, if `gcd(c, p) = 1`, `a` and `c` are swapped and the sign of `b` is flipped;
/// - if neither is coprime to `p`, the form is first mapped to `(a, b + 2a, a + b + c)` (the
///   substitution `x -> x + y`). The new `c` is then swapped into the leading position as in the
///   previous case.
///
/// The returned `CtOption` is `None` if a tracked limb addition carried out of the available
/// precision, or if the final leading coefficient is still not coprime to `p`.
#[must_use]
fn coprime_form<P, U: AsRef<[Limb]> + AsMut<[Limb]> + CtSelect + Gcd<P, Output: One>>(
mut a: U,
(mut b_positive, mut b_abs): (Choice, U),
mut c: U,
p: &P,
) -> CtOption<(U, (Choice, U))> {
let a_is_coprime = a.gcd(p).is_one();
let c_is_coprime = c.gcd(p).is_one();
// Only when this is set is the `x -> x + y` transform below committed.
let neither_a_c_coprime = !(a_is_coprime | c_is_coprime);
let mut correct = Choice::TRUE;
{
{
// `b_abs <- b + a` (signed). For negative `b`, `mask` is all ones and `carry` starts at one,
// so `(!b_abs) + a + 1` computes `a - |b|` in two's complement. For positive `b`, this is
// plain `b_abs + a`.
let mut carry = Limb::from(u8::from(!b_positive));
let mask = Limb::ZERO.wrapping_sub(carry);
for (b_limb, a_limb) in b_abs.as_mut().iter_mut().zip(a.as_ref()) {
let new_limb;
(new_limb, carry) = ((*b_limb) ^ mask).carrying_add(*a_limb, carry);
*b_limb = Limb::ct_select(b_limb, &new_limb, neither_a_c_coprime);
}
// For positive `b`, a final carry means `b + a` overflowed the precision.
// NOTE(review): a negative `b` with `|b| > a` wraps here without being flagged. For a
// positive-definite form this appears to surface as a carry in the `c` update below
// (`a + b + c > 0`), which does clear `correct`. Confirm.
correct &= (!neither_a_c_coprime) | (!b_positive) | carry.is_zero();
// `b + a` is treated as non-negative from here on.
b_positive |= neither_a_c_coprime;
}
{
// `c <- c + (b + a)` and `b <- (b + a) + a`, again only when neither coefficient is
// coprime. Each `c_limb` is updated from `b_limb` before `b_limb` itself is overwritten,
// giving `(a, b + 2a, a + b + c)`.
let mut c_carry = Limb::ZERO;
let mut b_carry = Limb::ZERO;
for ((a_limb, b_limb), c_limb) in a.as_ref().iter().zip(b_abs.as_mut()).zip(c.as_mut()) {
let new_c_limb;
(new_c_limb, c_carry) =
c_limb.carrying_add(Limb::ct_select(&Limb::ZERO, b_limb, neither_a_c_coprime), c_carry);
*c_limb = new_c_limb;
let new_b_limb;
(new_b_limb, b_carry) =
b_limb.carrying_add(Limb::ct_select(&Limb::ZERO, a_limb, neither_a_c_coprime), b_carry);
*b_limb = new_b_limb;
}
correct &= c_carry.is_zero() & b_carry.is_zero();
}
}
{
// If `a` is not coprime to `p`, swap in `c` (the original `c`, or the transformed one) to
// obtain the equivalent form `(c, -b, a)`.
let swap = !a_is_coprime;
U::ct_swap(&mut a, &mut c, swap);
b_positive ^= swap;
}
// Final check on the chosen leading coefficient. This is what rejects inputs for which no
// coprime form was found.
correct &= a.gcd(p).is_one();
CtOption::new((a, (b_positive, b_abs)), correct)
}
/// Compare two little-endian byte encodings for numeric equality, treating absent high-order bytes
/// as zero (so `[1]` equals `[1, 0, 0]`).
///
/// Every byte of both inputs is inspected: the shared prefix byte-by-byte, and whichever input is
/// longer has its excess bytes compared against zero. The per-byte `ct_eq` results are combined
/// with `&`.
#[must_use]
fn le_malleable_eq(a: &[u8], b: &[u8]) -> Choice {
  let common = a.len().min(b.len());
  let (a_low, a_high) = a.split_at(common);
  let (b_low, b_high) = b.split_at(common);
  let low_eq = a_low
    .iter()
    .zip(b_low)
    .fold(Choice::TRUE, |acc, (a_byte, b_byte)| acc & a_byte.ct_eq(b_byte));
  // At most one of these tails is non-empty.
  a_high.iter().chain(b_high).fold(low_eq, |acc, byte| acc & byte.ct_eq(&0))
}
/// Marker trait for a discriminant of binary quadratic forms.
pub trait Discriminant {}

/// A negative discriminant, exposing its absolute value.
pub trait NegativeDiscriminant: Discriminant {
  /// A bound, expressed as a bit length, on the order associated with this discriminant
  /// (presumably the class number, given the `sqrt(|Δ|) · log|Δ|` shape of the computation).
  ///
  /// With `bits` the bit length of `|Δ|`, this returns
  /// `ceil(bits / 2) + (floor(log2(bits)) + 1) - 1`.
  ///
  /// # Panics
  ///
  /// If the absolute value is zero, or if its bit length does not fit in a `u32`.
  #[must_use]
  fn upper_bound_on_order(&self) -> u32 {
    let encoded = self.absolute_value();
    let encoded = encoded.as_ref();
    // The encoding is little-endian, so trailing zero bytes are insignificant high-order bytes.
    let significant_len = encoded.iter().rposition(|byte| *byte != 0).map_or(0, |index| index + 1);
    let significant = &encoded[.. significant_len];
    let top_byte = *significant.last().expect("negative discriminant's absolute value was zero");
    let discriminant_bits =
      u32::try_from(8 * significant.len()).expect("4 GB discriminant?") - top_byte.leading_zeros();

    let sqrt_bits = discriminant_bits.div_ceil(2);
    let log_bits = discriminant_bits.ilog2() + 1;
    sqrt_bits + log_bits - 1
  }

  /// The absolute value of this discriminant, as little-endian bytes.
  #[must_use]
  fn absolute_value(&self) -> impl AsRef<[u8]>;
}

/// Marker trait, per its name, for odd discriminants.
pub trait OddDiscriminant: Discriminant {}
/// A fundamental discriminant (per the trait's name), supporting the injection of its forms into
/// forms of discriminant `Δ p^2`.
pub trait FundamentalDiscriminant: Discriminant {
/// Map `element`, a form of this discriminant `Δ`, to a form of discriminant `Δ p^2`.
///
/// Steps:
/// 1. Move to an equivalent form whose leading coefficient is coprime to `p` (`coprime_form`).
/// 2. Multiply `b` by `p`.
/// 3. Call `crate::crypto_bigint::partial_reduce` with `a`, `b`, and the new discriminant. The old
///    `c` is not passed; it returns `(a, b, c)`.
/// 4. Reduce with `crate::crypto_bigint::reduce` and re-encode as an `E`.
///
/// # Panics
///
/// If `element`'s discriminant differs from `self`'s, or if no coprime form could be found.
#[cfg(feature = "alloc")] #[must_use]
fn inject<E: Element>(&self, element: impl Element, p: &impl Encoding) -> E
where
Self: NegativeDiscriminant,
{
use crypto_bigint::{Resize as _, BoxedUint};
let (a, (b_positive, b_abs), c, discriminant_abs) = element.a_b_c_discriminant();
assert!(bool::from(le_malleable_eq(self.absolute_value().as_ref(), discriminant_abs.as_ref())));
let a = BoxedUint::from_le_slice_vartime(a.as_ref());
let b_abs = BoxedUint::from_le_slice_vartime(b_abs.as_ref());
let c = BoxedUint::from_le_slice_vartime(c.as_ref());
let discriminant_abs = BoxedUint::from_le_slice_vartime(discriminant_abs.as_ref());
let p = {
let p = p.to_le_bytes();
BoxedUint::from_le_slice_vartime(p.as_ref())
};
// The target discriminant's absolute value, `|Δ| p^2`.
let discriminant_abs = discriminant_abs.concatenating_mul(p.concatenating_square());
// Two extra bits of headroom, presumably for the additions inside `coprime_form`
// (`a + b + c`, `b + 2a`). TODO confirm.
let bits_precision = 2 + a.bits_precision().max(b_abs.bits_precision()).max(c.bits_precision());
// NOTE(review): if `p`'s precision exceeds `bits_precision`, this resize shrinks its storage.
// Confirm `p` always fits.
let p = p.resize(bits_precision);
let (a, (b_positive, b_abs)) = coprime_form(
a.resize(bits_precision),
(b_positive, b_abs.resize(bits_precision)),
c.resize(bits_precision),
&p,
)
.expect("could not find a coprime form (non-primitive or unreduced?)");
let b_abs = b_abs.concatenating_mul(&p);
// NOTE(review): unlike `surject`, `b_abs.bits_precision()` is not part of this maximum, so
// `b_abs.resize(log_2_bound)` below may shrink its storage. Confirm the value always fits.
let log_2_bound = 8 + bits_precision.max(discriminant_abs.bits_precision());
let discriminant_abs = discriminant_abs.resize(log_2_bound);
let (a, (b_positive, b_abs), c) = crate::crypto_bigint::partial_reduce(
log_2_bound,
a.resize(log_2_bound),
(b_positive, b_abs.resize(log_2_bound)),
&discriminant_abs,
);
let discriminant_bits = discriminant_abs.bits_vartime();
let sqrt_discriminant_bits = discriminant_bits.div_ceil(2);
let (a, (b_positive, b_abs), c) =
crate::crypto_bigint::reduce(sqrt_discriminant_bits, a, (b_positive, b_abs), c);
let discriminant_bits = usize::try_from(discriminant_bits).unwrap();
let sqrt_discriminant_bits = usize::try_from(sqrt_discriminant_bits).unwrap();
// SAFETY: NOTE(review): the contract of `Element::from_coefficients` is not visible here. This
// passes a reduced form of `discriminant_abs`, with `a`/`b` truncated to
// `ceil(sqrt_bits / 8)` bytes and `c`/discriminant to `ceil(bits / 8)` bytes. Confirm that
// satisfies it.
unsafe {
E::from_coefficients(
&a.to_le_bytes().as_ref()[.. sqrt_discriminant_bits.div_ceil(8)],
(b_positive, &b_abs.to_le_bytes().as_ref()[.. sqrt_discriminant_bits.div_ceil(8)]),
&c.to_le_bytes()[.. discriminant_bits.div_ceil(8)],
&discriminant_abs.to_le_bytes()[.. discriminant_bits.div_ceil(8)],
)
}
}
}
/// A byte buffer viewed without its trailing zero bytes.
///
/// Used on little-endian encodings, where trailing zeros are insignificant high-order bytes. An
/// all-zero (or empty) buffer is viewed as the empty slice.
struct WithoutTrailingZeroBytes<B: AsRef<[u8]>>(B);
impl<B: AsRef<[u8]>> AsRef<[u8]> for WithoutTrailingZeroBytes<B> {
  fn as_ref(&self) -> &[u8] {
    let full = self.0.as_ref();
    let significant_len = full.iter().rposition(|&byte| byte != 0).map_or(0, |last| last + 1);
    &full[.. significant_len]
  }
}
/// The fundamental discriminant `Δ_K` of the CL15 construction, stored as `p` and `|Δ_K|`.
///
/// As constructed by `Cl15p::sample`, `|Δ_K| = q p` with `q` found via `primes::next_prime`. The
/// discriminant itself is negative (see the `NegativeDiscriminant` impl).
#[derive(Clone)]
pub struct Cl15k<Up, Udk> {
// The odd modulus `p`. Presumably prime, given the Legendre-symbol check in `Cl15p::sample`.
p: Odd<Up>,
// `|Δ_K|`.
absolute_value: Udk,
}
impl<Up, Udk> Discriminant for Cl15k<Up, Udk> {}
impl<Up, Udk: Encoding> NegativeDiscriminant for Cl15k<Up, Udk> {
// Little-endian encoding of `|Δ_K|`, with insignificant trailing zero bytes removed.
fn absolute_value(&self) -> impl AsRef<[u8]> {
WithoutTrailingZeroBytes(self.absolute_value.to_le_bytes())
}
}
impl<Up, Udk> OddDiscriminant for Cl15k<Up, Udk> {}
impl<Up, Udk> FundamentalDiscriminant for Cl15k<Up, Udk> {}
impl<Up, Udk> Cl15k<Up, Udk> {
/// The modulus `p`.
#[must_use]
pub fn p(&self) -> &Up {
&self.p
}
}
/// The non-fundamental discriminant `Δ_p` of the CL15 construction, with `|Δ_p| = |Δ_K| p^2`.
#[derive(Clone)]
pub struct Cl15p<Up, Up2, Udk, Udp> {
// The fundamental discriminant `Δ_K` (which also holds `p`).
fundamental: Cl15k<Up, Udk>,
// `p^2`.
p_square: Up2,
// `|Δ_p| = |Δ_K| p^2`.
absolute_value: Udp,
}
impl<Up, Up2, Udk, Udp> Discriminant for Cl15p<Up, Up2, Udk, Udp> {}
impl<Up: BitOps, Up2, Udk: Encoding, Udp: Encoding> NegativeDiscriminant
for Cl15p<Up, Up2, Udk, Udp>
{
// The fundamental discriminant's bound, plus the bit length of `p`.
fn upper_bound_on_order(&self) -> u32 {
self.fundamental.upper_bound_on_order() + self.fundamental.p.as_ref().bits_vartime()
}
// Little-endian encoding of `|Δ_p|`, with insignificant trailing zero bytes removed.
fn absolute_value(&self) -> impl AsRef<[u8]> {
WithoutTrailingZeroBytes(self.absolute_value.to_le_bytes())
}
}
impl<Up, Up2, Udk, Udp> OddDiscriminant for Cl15p<Up, Up2, Udk, Udp> {}
/// Errors from `Cl15p::sample`.
#[derive(Debug)]
pub enum Cl15Error {
/// `p` had fewer than `bits_of_security + 1` bits.
SmallP,
/// No suitable prime `q` was found for the requested parameters.
NoQ,
}
impl<
  Up: Clone
    + AsRef<[Limb]>
    + AsMut<[Limb]>
    + CtAssign
    + Zero
    + One
    + NegMod<Output = Up>
    + MulMod<Output = Up>
    + SquareMod<Output = Up>
    + ConcatenatingSquare
    + BitOps,
  Udk: Clone
    + AsRef<[Limb]>
    + AsMut<[Limb]>
    + CtGt
    + One
    + CheckedAdd
    + CheckedSub<Udk>
    + for<'a> Mul<&'a Up, Output = Udk>
    + ConcatenatingMul<<Up as ConcatenatingSquare>::Output>
    + for<'a> Div<&'a NonZero<Up>, Output = Udk>
    + for<'a> Rem<&'a NonZero<Up>, Output = Up>
    + BitOps
    + Encoding
    + RandomBits
    + RandomMod
    + UnsignedWithMontyForm,
>
  Cl15p<
    Up,
    <Up as ConcatenatingSquare>::Output,
    Udk,
    <Udk as ConcatenatingMul<<Up as ConcatenatingSquare>::Output>>::Output,
  >
{
  /// Sample a CL15 discriminant for the modulus `p`.
  ///
  /// A prime `q` is searched for such that:
  /// - `q p` has exactly `fundamental_discriminant_bit_length` bits;
  /// - `p q ≡ 3 (mod 4)`;
  /// - `q mod p` has Legendre symbol `-1` modulo `p`.
  ///
  /// Candidates are seeded from `[max(ceil(2^(bits - 1) / p), 4 p), floor((2^bits - 1) / p)]`.
  /// The result has `|Δ_K| = q p` and `|Δ_p| = |Δ_K| p^2`.
  ///
  /// # Errors
  ///
  /// - `Cl15Error::SmallP` if `p` has fewer than `bits_of_security + 1` bits.
  /// - `Cl15Error::NoQ` if the candidate range is empty or too small relative to `p`, or if the
  ///   prime search fails or exhausts the range.
  pub fn sample(
    mut rng: impl CryptoRng,
    bits_of_security: u32,
    fundamental_discriminant_bit_length: u32,
    p: Odd<Up>,
  ) -> Result<Self, Cl15Error> {
    #[allow(non_snake_case)]
    let Udk_zero_with_precision = |bits_precision| -> Udk {
      // An RNG which only yields zeroes, so `random_bits_with_precision` produces a zero of the
      // requested precision.
      struct Zero;
      impl crypto_bigint::rand_core::TryRng for Zero {
        type Error = crypto_bigint::rand_core::Infallible;
        fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
          Ok(0)
        }
        fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
          Ok(0)
        }
        fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Self::Error> {
          for b in dst {
            *b = 0;
          }
          Ok(())
        }
      }
      let result = Udk::random_bits_with_precision(&mut Zero, 0, bits_precision);
      debug_assert!(bool::from(result.is_zero()));
      result
    };

    let mu = p.as_ref().bits_vartime();
    if (mu - 1) < bits_of_security {
      return Err(Cl15Error::SmallP);
    }

    let q = {
      // `q p` has exactly `bits` bits iff `2^(bits - 1) <= q p <= 2^bits - 1`, i.e.
      // `ceil(2^(bits - 1) / p) <= q <= floor((2^bits - 1) / p)`.
      let mut lower_bound_inclusive = Udk_zero_with_precision(fundamental_discriminant_bit_length);
      lower_bound_inclusive.set_bit(fundamental_discriminant_bit_length - 1, Choice::TRUE);
      debug_assert_eq!(lower_bound_inclusive.bits_vartime(), fundamental_discriminant_bit_length);
      // Round the quotient up. A floored quotient would admit a `q` with `q p < 2^(bits - 1)`,
      // i.e. a product one bit too short.
      let lower_bound_divisible = lower_bound_inclusive.clone().rem(p.as_nz_ref()).is_zero();
      lower_bound_inclusive = lower_bound_inclusive / p.as_nz_ref();
      if !bool::from(lower_bound_divisible) {
        lower_bound_inclusive =
          Option::<Udk>::from(lower_bound_inclusive.checked_add(&Udk::one())).expect(
            "`2^(bits - 1) / p` is less than `2^(bits - 1)`, which fits within this capacity",
          );
      }

      let mut upper_bound_inclusive = Udk_zero_with_precision(fundamental_discriminant_bit_length);
      for bit in 0 .. fundamental_discriminant_bit_length {
        upper_bound_inclusive.set_bit(bit, Choice::TRUE);
      }
      debug_assert_eq!(upper_bound_inclusive.bits_vartime(), fundamental_discriminant_bit_length);
      upper_bound_inclusive = upper_bound_inclusive / p.as_nz_ref();

      // Raise the lower bound to `4 p` if it's below it.
      {
        // Compute `lower_bound_inclusive - 4 p` limb-wise, keeping only the final borrow.
        let lower_bound_inclusive_lt_4p = {
          let lower_bound_inclusive = <_ as AsRef<[Limb]>>::as_ref(&lower_bound_inclusive);
          let p = <_ as AsRef<[Limb]>>::as_ref(&p);
          let mut borrow = Limb::ZERO;
          let mut carry = Limb::ZERO;
          // One extra limb beyond `p` holds the bits shifted out of `p`'s top limb.
          let virtual_len = lower_bound_inclusive.len().max(1 + p.len());
          for (lower_bound_inclusive, p) in lower_bound_inclusive
            .iter()
            .chain(core::iter::repeat(&Limb::ZERO))
            .take(virtual_len)
            .zip(p.iter().chain(core::iter::repeat(&Limb::ZERO)).take(virtual_len))
          {
            let four_p = ((*p) << 2) | carry;
            carry = (*p) >> (Limb::BITS - 2);
            let _diff_limb;
            (_diff_limb, borrow) = lower_bound_inclusive.borrowing_sub(four_p, borrow);
          }
          debug_assert!(bool::from(carry.is_zero()));
          !borrow.is_zero()
        };
        if bool::from(lower_bound_inclusive_lt_4p) {
          if (2 + p.bits_vartime()) > fundamental_discriminant_bit_length {
            return Err(Cl15Error::NoQ);
          }
          // Overwrite the lower bound with `4 p`. As the lower bound was less than `4 p`, any of
          // its limbs beyond those written here are already zero.
          let mut lower_bound_limbs =
            <_ as AsMut<[Limb]>>::as_mut(&mut lower_bound_inclusive).iter_mut();
          let p = <_ as AsRef<[Limb]>>::as_ref(&p);
          let mut carry = Limb::ZERO;
          // `p` drives the zip. `Zip` polls its first iterator before its second, so putting
          // `lower_bound_limbs` first would consume (and skip) one extra limb once `p` ran out.
          // The overflow `carry` would then land one limb too high.
          for (p_limb, lower_bound_limb) in p.iter().zip(&mut lower_bound_limbs) {
            let four_p = ((*p_limb) << 2) | carry;
            carry = (*p_limb) >> (Limb::BITS - 2);
            *lower_bound_limb = four_p;
          }
          if bool::from(!carry.is_zero()) {
            *lower_bound_limbs.next().unwrap() = carry;
          }
        }
      }

      // Sample a uniform starting point in `[lower_bound_inclusive, upper_bound_inclusive]` by
      // rejection sampling over whole bytes.
      let mut seed = {
        let sample_range =
          Option::<Udk>::from(upper_bound_inclusive.checked_sub(&lower_bound_inclusive))
            .ok_or(Cl15Error::NoQ)?;
        let mut starting_point_in_range = sample_range.to_le_bytes();
        for b in starting_point_in_range.as_mut() {
          *b = 0;
        }
        while {
          rng.fill_bytes(
            &mut starting_point_in_range.as_mut()
              [.. usize::try_from(sample_range.bits_vartime().div_ceil(8)).unwrap()],
          );
          bool::from(Udk::from_le_bytes(starting_point_in_range.clone()).ct_gt(&sample_range))
        } {}
        let starting_point_in_range = Udk::from_le_bytes(starting_point_in_range);
        lower_bound_inclusive.checked_add(&starting_point_in_range).expect(
          "result is less than or equal to `upper_bound_inclusive`, which has the same capacity",
        )
      };

      // Whenever the search runs past the range, it restarts once from the lower bound. A second
      // exhaustion is an error.
      let mut lower_bound_inclusive = Some(lower_bound_inclusive);
      loop {
        let q = match super::primes::next_prime(&mut rng, seed.clone(), bits_of_security) {
          Ok(q) => q,
          Err(super::primes::Error::Capacity) => {
            seed = lower_bound_inclusive.take().ok_or(Cl15Error::NoQ)?;
            continue;
          }
          Err(super::primes::Error::NoMillerRabin) => return Err(Cl15Error::NoQ),
        };
        if bool::from(q.ct_gt(&upper_bound_inclusive)) {
          seed = lower_bound_inclusive.take().ok_or(Cl15Error::NoQ)?;
          continue;
        }
        seed = match Option::<Udk>::from(q.checked_add(&Udk::one())) {
          Some(q_plus_one) => q_plus_one,
          None => lower_bound_inclusive.take().ok_or(Cl15Error::NoQ)?,
        };

        // Require `p q ≡ 3 (mod 4)`, read off the low two bits of the low limbs.
        const {
          assert!(Limb::BITS >= 2);
        }
        if {
          let product =
            <_ as AsRef<[Limb]>>::as_ref(&p)[0].wrapping_mul(<_ as AsRef<[Limb]>>::as_ref(&q)[0]);
          (product.0 & 0b11) != 0b11
        } {
          continue;
        }

        // Require `q` to be a quadratic non-residue modulo `p`.
        if {
          let q_mod_p = q.clone().rem(p.as_nz_ref());
          crate::crypto_bigint::legendre_symbol(q_mod_p, &p) !=
            ::crypto_bigint::JacobiSymbol::MinusOne
        } {
          continue;
        }

        break q;
      }
    };

    let fundamental_discriminant_absolute_value = q.mul(p.as_ref());
    // Holds for every `q` in `[ceil(2^(bits - 1) / p), floor((2^bits - 1) / p)]`.
    debug_assert_eq!(
      fundamental_discriminant_absolute_value.bits_vartime(),
      fundamental_discriminant_bit_length
    );
    let p_square = p.as_ref().concatenating_square();
    let non_fundamental_discriminant_absolute_value =
      fundamental_discriminant_absolute_value.concatenating_mul(p_square.clone());
    Ok(Cl15p {
      fundamental: Cl15k { p, absolute_value: fundamental_discriminant_absolute_value },
      p_square,
      absolute_value: non_fundamental_discriminant_absolute_value,
    })
  }
}
impl<Up, Up2, Udk, Udp> Cl15p<Up, Up2, Udk, Udp> {
/// The fundamental discriminant `Δ_K` underlying this discriminant `Δ_p`.
#[must_use]
pub fn fundamental_discriminant(&self) -> &Cl15k<Up, Udk> {
&self.fundamental
}
}
impl<Up: Encoding, Up2: Encoding, Udk: Clone + AsMut<[Limb]> + Encoding, Udp: Encoding>
Cl15p<Up, Up2, Udk, Udp>
{
/// The form `(p^2, p, (|Δ_K| + 1) / 4)` of discriminant `Δ_p`.
///
/// Presumably the generator `f` of the order-`p` subgroup in the CL15 construction. Its
/// discriminant is `p^2 - p^2 (|Δ_K| + 1) = -|Δ_K| p^2 = -|Δ_p|`.
#[must_use]
pub fn f<E: Element>(&self) -> E {
// `c = (|Δ_K| + 1) / 4`, computed in place on a copy of `|Δ_K|`.
let c = {
let mut c = self.fundamental.absolute_value.clone();
{
let c = <_ as AsMut<[Limb]>>::as_mut(&mut c);
// Add one, tracking the carry out of the top limb.
let mut carry = Limb::ONE;
for c_limb in c.iter_mut() {
let new_limb;
(new_limb, carry) = c_limb.carrying_add(Limb::ZERO, carry);
*c_limb = new_limb;
}
// Shift right by two, from the top limb down. A carry out of the top limb (`2^(64n)`
// for `n` limbs) becomes bit `Limb::BITS - 2` of the top limb.
carry <<= Limb::BITS - 2;
for c_limb in c.iter_mut().rev() {
let new_limb = carry | ((*c_limb) >> 2);
carry = (*c_limb) << (Limb::BITS - 2);
*c_limb = new_limb;
}
// `|Δ_K| + 1` is expected to be divisible by four, so no bits are shifted out.
debug_assert!(bool::from(carry.is_zero()));
}
c
};
let discriminant_abs = WithoutTrailingZeroBytes(self.absolute_value.to_le_bytes());
let discriminant_abs = discriminant_abs.as_ref();
let discriminant_bytes = discriminant_abs.len();
// `a` and `b` are encoded in `ceil(len(|Δ_p|) / 2)` bytes and `c` in `len(|Δ_p|)` bytes.
// NOTE(review): bytes beyond these lengths are silently dropped. Confirm this matches the
// encoding `E::from_coefficients` expects.
let sqrt_discriminant_bytes = discriminant_bytes.div_ceil(2);
let a = self.p_square.to_le_bytes();
let a = a.as_ref();
let a = &a[.. sqrt_discriminant_bytes.min(a.len())];
let b_positive = Choice::TRUE;
let b_abs = self.fundamental.p.to_le_bytes();
let b_abs = b_abs.as_ref();
let b_abs = &b_abs[.. sqrt_discriminant_bytes.min(b_abs.len())];
let c = c.to_le_bytes();
let c = c.as_ref();
let c = &c[.. discriminant_bytes.min(c.len())];
// SAFETY: NOTE(review): the contract of `Element::from_coefficients` is not visible here.
// This relies on `(p^2, p, c)` being a valid form of `discriminant_abs` in the expected
// encoding. Confirm.
unsafe { E::from_coefficients(a, (b_positive, b_abs), c, discriminant_abs) }
}
}
impl<Up: BitOps + Encoding, Up2, Udk: Clone + AsMut<[Limb]> + Encoding, Udp: Encoding>
Cl15p<Up, Up2, Udk, Udp>
{
/// Map `element`, a form of discriminant `Δ_p`, to a form of the fundamental discriminant `Δ_K`.
///
/// After moving to a form `(a, b, c)` with `a` coprime to `p`, with `μ = p^{-1} mod a` and
/// `λ = (μ p - 1) / a`, the new `b` is `(±b) μ - (λ odd ? a : 0) mod 2a`. A new `c` is then
/// produced by `partial_reduce` for `Δ_K`, and the form is reduced.
///
/// # Panics
///
/// If `element`'s discriminant differs from `Δ_p`, or if no coprime form could be found.
#[cfg(feature = "alloc")] #[must_use]
pub fn surject<E: Element>(&self, element: impl Element) -> E {
use crypto_bigint::{Resize as _, BoxedUint};
let (a, (b_positive, b_abs), c, discriminant_abs) = element.a_b_c_discriminant();
assert!(bool::from(le_malleable_eq(self.absolute_value().as_ref(), discriminant_abs.as_ref())));
let a = BoxedUint::from_le_slice_vartime(a.as_ref());
let b_abs = BoxedUint::from_le_slice_vartime(b_abs.as_ref());
let c = BoxedUint::from_le_slice_vartime(c.as_ref());
let p = self.fundamental.p.to_le_bytes();
let p = BoxedUint::from_le_slice_vartime(p.as_ref());
// Two extra bits of headroom, presumably for the additions inside `coprime_form`. TODO confirm.
let bits_precision = 2 + a.bits_precision().max(b_abs.bits_precision()).max(c.bits_precision());
let p = p.resize(bits_precision);
let (a, (mut b_positive, b_abs)) = coprime_form(
a.resize(bits_precision),
(b_positive, b_abs.resize(bits_precision)),
c.resize(bits_precision),
&p,
)
.expect("could not find a coprime form (non-primitive or unreduced?)");
let b_abs = {
let a = NonZero::new(a.clone())
.expect("`a` is non-zero for a positive definite form of negative discriminant");
// NOTE(review): for the identity form (`a = 1`), confirm `invert_mod` modulo 1 returns
// `Some`. If it returns `Some(0)`, also confirm `0 * p - 1` below does not underflow/panic.
let mu = p.invert_mod(&a).expect("`a` is coprime to `p`");
// `μ p - 1 = λ a`, so `v2(μ p - 1) > v2(a)` iff `λ` is even.
let lambda_is_even =
(mu.concatenating_mul(&p) - BoxedUint::one()).trailing_zeros().ct_gt(&a.trailing_zeros());
let a = a.get();
let two_a = NonZero::new(a.clone().concatenating_add(&a))
.expect("`2 a` is non-zero as `a` is non-zero");
// `(±b) μ mod 2a`, applying `b`'s sign via `neg_mod`.
let b_mu = b_abs.mul_mod(&mu, &two_a);
let b_mu = <_>::ct_select(&b_mu.clone().neg_mod(&two_a), &b_mu, b_positive);
// The result of the `mod 2a` reduction is non-negative.
b_positive = Choice::TRUE;
// Subtract `λ a mod 2a`, which is `a` when `λ` is odd and zero otherwise.
b_mu.sub_mod(
&<_>::ct_select(&BoxedUint::zero_like(&a), &a, !lambda_is_even)
.resize(b_mu.bits_precision()),
&two_a,
)
};
let discriminant_abs =
BoxedUint::from_le_slice_vartime(self.fundamental_discriminant().absolute_value().as_ref());
let log_2_bound =
8 + bits_precision.max(b_abs.bits_precision()).max(discriminant_abs.bits_precision());
let discriminant_abs = discriminant_abs.resize(log_2_bound);
let (a, (b_positive, b_abs), c) = crate::crypto_bigint::partial_reduce(
log_2_bound,
a.resize(log_2_bound),
(b_positive, b_abs.resize(log_2_bound)),
&discriminant_abs,
);
let discriminant_bits = discriminant_abs.bits_vartime();
let sqrt_discriminant_bits = discriminant_bits.div_ceil(2);
let (a, (b_positive, b_abs), c) =
crate::crypto_bigint::reduce(sqrt_discriminant_bits, a, (b_positive, b_abs), c);
let discriminant_bits = usize::try_from(discriminant_bits).unwrap();
let sqrt_discriminant_bits = usize::try_from(sqrt_discriminant_bits).unwrap();
// SAFETY: NOTE(review): the contract of `Element::from_coefficients` is not visible here. This
// passes a reduced form of `discriminant_abs` (`|Δ_K|`), truncated to the byte lengths used
// below. Confirm that satisfies it.
unsafe {
E::from_coefficients(
&a.to_le_bytes().as_ref()[.. sqrt_discriminant_bits.div_ceil(8)],
(b_positive, &b_abs.to_le_bytes().as_ref()[.. sqrt_discriminant_bits.div_ceil(8)]),
&c.to_le_bytes()[.. discriminant_bits.div_ceil(8)],
&discriminant_abs.to_le_bytes()[.. discriminant_bits.div_ceil(8)],
)
}
}
/// Surject `element` into `Cl(Δ_K)`, then inject the result back into forms of discriminant
/// `Δ_K p^2`.
#[cfg(feature = "alloc")] #[must_use]
pub fn coset_labeling_function<E: Element>(&self, element: impl Element) -> E {
self.fundamental_discriminant().inject(self.surject::<E>(element), self.fundamental.p.as_ref())
}
}
impl<
Up: Clone + CtSelect + Zero + NegMod<Output = Up> + InvertMod<Output = Up> + BitOps + Encoding,
Up2: Encoding,
Udk,
Udp: Clone
+ CtEq
+ for<'a> Mul<&'a Up, Output = Udp>
+ for<'a> Div<&'a NonZero<Up>, Output = Udp>
+ Encoding,
> Cl15p<Up, Up2, Udk, Udp>
{
/// Recover `m mod p` from an element expected to have the shape `(p^2, ±x̃ p, ·)`.
///
/// Checks that the discriminant is `Δ_p`, that `a = p^2`, and that `p` divides `b`. It then
/// returns `(±x̃)^{-1} mod p`. An identity element with the correct discriminant yields zero.
/// Any other input yields a `CtOption` which is `None`.
#[must_use]
pub fn discrete_logarithm(&self, element: impl Element) -> CtOption<Up> {
let identity = element.is_identity();
let (a, (b_positive, b_abs), _c, discriminant_abs) = element.a_b_c_discriminant();
let correct_discriminant =
le_malleable_eq(self.absolute_value.to_le_bytes().as_ref(), discriminant_abs.as_ref());
let correct_a_coefficient = le_malleable_eq(self.p_square.to_le_bytes().as_ref(), a.as_ref());
let b_abs = b_abs.as_ref();
// Load `|b|` into a `Udp`, zero-extending into `Udp`'s width.
// NOTE(review): bytes of `b_abs` beyond that width are silently dropped. This could let an
// out-of-range `b` pass the divisibility check below. Confirm elements are always reduced.
let b_abs = {
let mut repr = self.absolute_value.to_le_bytes();
{
let repr = repr.as_mut();
for b in repr.iter_mut() {
*b = 0;
}
let mutual_len = repr.len().min(b_abs.len());
repr[.. mutual_len].copy_from_slice(&b_abs[.. mutual_len]);
}
Udp::from_le_bytes(repr)
};
// `x̃ = |b| / p`, valid only if `p` divides `|b|` exactly.
let x_tilde = b_abs.clone().div(self.fundamental.p.as_nz_ref());
let correct_b_coefficient = x_tilde.clone().mul(self.fundamental.p.as_ref()).ct_eq(&b_abs);
let x_tilde = x_tilde.to_le_bytes();
// Narrow `x̃` to `Up`'s width (bytes beyond it are dropped).
let x_tilde = {
let x_tilde = x_tilde.as_ref();
let mut p_repr = self.fundamental.p.as_ref().to_le_bytes();
{
let p_repr = p_repr.as_mut();
for b in p_repr.iter_mut() {
*b = 0;
}
let mutual_len = p_repr.len().min(x_tilde.len());
p_repr[.. mutual_len].copy_from_slice(&x_tilde[.. mutual_len]);
}
Up::from_le_bytes(p_repr)
};
// Apply `b`'s sign modulo `p`, then invert. Only kept if every check above passed.
let inverse =
Up::ct_select(&x_tilde.neg_mod(self.fundamental.p.as_nz_ref()), &x_tilde, b_positive)
.invert_mod(self.fundamental.p.as_nz_ref())
.filter_by(correct_discriminant & correct_a_coefficient & correct_b_coefficient);
// Fall back to zero for the identity element of the correct discriminant.
inverse.or(CtOption::new(
Up::zero_like(self.fundamental.p.as_ref()),
correct_discriminant & identity,
))
}
}