use alloc::vec::Vec;
#[cfg(feature = "std")]
use std::io;
use crypto_bigint::{
Choice, CtEq as _, CtAssign as _, NonZero, ConcatenatingMul as _, ConcatenatingSquare as _,
Gcd as _, Resize as _, BoxedUint,
};
use super::Error;
mod varint;
use varint::{encode_varint, decode_varint};
mod bigint;
use bigint::{encode_bigint, decode_bigint};
mod partial_xgcd;
use partial_xgcd::t;
/// Encode the binary quadratic form with leading coefficient `a` and middle coefficient `b` in
/// compressed form.
///
/// `b` is supplied as a sign (`b_positive`) and an absolute value (`b_abs`). With
/// `(t_positive, t)` being the output of `partial_xgcd::t(a, |b|)` and `g = gcd(a, t)`, the output
/// is, in order:
/// - a varint of `g_bytes = ceil(bits(g) / 8)`;
/// - `g`, in `bits(g)` bits;
/// - `a' = a / g`, in `((bits(|D|) - 1) / 2) + 1 - (bits(g) - 1)` bits;
/// - one byte whose bit 1 is `t_positive` and whose bit 0 is `b_positive`;
/// - `t' = |t| / g`, in `((bits(|D|) - 1) / 4) + 1 - (bits(g) - 1)` bits;
/// - `b_0 = |b| / a'`, in `bits(g)` bits.
///
/// # Panics
///
/// The field widths use plain `usize` subtraction. With overflow checks enabled this panics if
/// `discriminant_abs` is zero or if `g` is wider than the budget for `a'` or `t'`. NOTE(review):
/// presumably callers only pass (reduced) forms for which this cannot happen — confirm.
pub(crate) fn encode_compressed_binary_quadratic_form(
  a: NonZero<BoxedUint>,
  b_positive: Choice,
  b_abs: BoxedUint,
  discriminant_abs: &BoxedUint,
) -> Vec<u8> {
  let (t_positive, t_abs) = t(a.clone(), b_abs.clone());
  let g = a.gcd_vartime(&t_abs);

  let a = a.get();
  // `a_apo` is consumed by `NonZero::new`, so no clone is needed.
  let a_apo = NonZero::new(&a / &g).expect("`a != 0` so `(a / gcd(a, t)) != 0`");
  let t_apo_abs = t_abs.get() / &g;
  let b_0 = b_abs / &a_apo;

  let discriminant_bits = usize::try_from(discriminant_abs.bits()).unwrap();
  let g_bits = usize::try_from(g.bits()).unwrap();
  let g_bytes = g_bits.div_ceil(8);
  // `g >= 2^(g_bits - 1)`, so dividing by `g` removes `g_bits - 1` bits from each quotient's
  // bit-length bound.
  let a_apo_bits = ((discriminant_bits - 1) / 2) + 1 - (g_bits - 1);
  let t_apo_bits = ((discriminant_bits - 1) / 4) + 1 - (g_bits - 1);

  let mut result = encode_varint(g_bytes);
  result.extend(&encode_bigint(g.as_ref(), g_bits));
  result.extend(&encode_bigint(a_apo.as_ref(), a_apo_bits));
  // Bit 1: sign of `t`; bit 0: sign of `b`.
  result.push((u8::from(t_positive) << 1) | u8::from(b_positive));
  result.extend(&encode_bigint(&t_apo_abs, t_apo_bits));
  // `b_0 <= g`, so it fits in `g_bits` bits.
  result.extend(&encode_bigint(&b_0, g_bits));
  result
}
/// Decode a binary quadratic form from its compressed encoding, for a discriminant whose absolute
/// value is `discriminant_abs`.
///
/// The encoding is read from `reader` in this order:
/// 1. a varint giving `g_bytes`, the byte length of `g`;
/// 2. `g`, in exactly `ceil(bits(g) / 8)` bytes (any longer encoding is rejected);
/// 3. `a' = a / g`;
/// 4. a sign byte: bit 0 is set iff `b` is positive, bit 1 is set iff `t` is positive, and no
///    other bit may be set;
/// 5. `t' = |t| / g`, which must be non-zero;
/// 6. `b_0`, the quotient of `|b|` by `a'`.
///
/// `|b|` is then reconstructed from a square root of `-(t^2 * |D|) mod a`. The result is
/// re-checked: `|b| <= a` must hold, and `t` recomputed from `(a, |b|)` must equal the decoded `t`.
/// The form is finally passed to `validate_binary_quadratic_form`, whose output is returned. The
/// third tuple element is whatever that function produces alongside `a` and `b` (presumably `c` —
/// confirm there).
///
/// # Errors
///
/// - `Error::Overflow` if the length varint does not fit in a `u32`.
/// - `Error::UnexpectedEof` if the sign byte cannot be read.
/// - `Error::NonCanonical` if `g` is not minimally encoded, unknown sign bits are set, or the
///   decoded `t` differs from the one recomputed from `(a, |b|)`.
/// - `Error::Incorrect` for any other malformed or invalid encoding. This includes a `g` so wide
///   that no room is left for `a'` or `t'`.
/// - Any error returned by `decode_varint` / `decode_bigint`.
#[cfg(feature = "std")]
#[expect(clippy::type_complexity)]
pub(crate) fn decode_compressed_binary_quadratic_form(
  mut reader: impl io::Read,
  discriminant_abs: &BoxedUint,
) -> Result<(NonZero<BoxedUint>, (Choice, BoxedUint), BoxedUint), Error> {
  let discriminant_bits = discriminant_abs.bits();
  // Bit-length budget for `a`; `a'` is encoded in this many bits minus `g_bits - 1`.
  let a_bound_bits = ((discriminant_bits - 1) / 2) + 1;
  // Bit-length budget for `t`; `t'` is encoded in this many bits minus `g_bits - 1`.
  let t_bound_bits = ((discriminant_bits - 1) / 4) + 1;
  debug_assert!(discriminant_abs.floor_sqrt_vartime().bits() <= a_bound_bits);
  debug_assert!(
    discriminant_abs.floor_sqrt_vartime().floor_sqrt_vartime().bits() <= t_bound_bits
  );

  let g_bytes = u32::try_from(decode_varint(&mut reader)?).map_err(|_| Error::Overflow)?;
  if g_bytes > a_bound_bits.div_ceil(8) {
    Err(Error::Incorrect)?;
  }
  let g = decode_bigint(&mut reader, g_bytes * 8)?;
  let g_bits = g.bits();
  // `g` must be written in the minimal number of bytes.
  if g_bytes != g_bits.div_ceil(8) {
    Err(Error::NonCanonical)?;
  }
  let g = Option::<NonZero<_>>::from(NonZero::new(g)).ok_or(Error::Incorrect)?;

  // `g` is non-zero, so `g_bits >= 1`. The byte-granular check on `g_bytes` above still lets `g`
  // exceed these budgets, so an attacker-chosen `g` must not be allowed to underflow the widths.
  let g_excess_bits = g_bits - 1;
  let a_apo_bits = a_bound_bits.checked_sub(g_excess_bits).ok_or(Error::Incorrect)?;
  let t_apo_bits = t_bound_bits.checked_sub(g_excess_bits).ok_or(Error::Incorrect)?;

  let a_apo = decode_bigint(&mut reader, a_apo_bits)?;
  let a_apo = NonZero::new(a_apo).ok_or(Error::Incorrect)?;
  let a = a_apo.concatenating_mul(g.as_ref());
  let a = NonZero::new(a).expect("the product of two non-zero values is itself non-zero");

  let (b_positive, b_abs) = {
    let mut sign_bits = [0xff];
    reader.read_exact(&mut sign_bits).map_err(|_| Error::UnexpectedEof)?;
    let sign_bits = sign_bits[0];
    // Only bit 0 (sign of `b`) and bit 1 (sign of `t`) may be set.
    if (sign_bits >> 2) != 0 {
      Err(Error::NonCanonical)?;
    }
    let b_positive = (sign_bits & 1).ct_eq(&1);
    let t_positive = (sign_bits >> 1).ct_eq(&1);

    let t_apo_abs = decode_bigint(&mut reader, t_apo_bits)?;
    if bool::from(t_apo_abs.is_zero()) {
      Err(Error::Incorrect)?;
    }
    let t_abs = t_apo_abs.concatenating_mul(g.as_ref());

    let b_abs = {
      let s_apo = {
        // `s` is the integer square root of `-(t^2 * |D|) mod a`, which must be a perfect square.
        let s = {
          let x = t_abs.square_mod(&a).mul_mod(discriminant_abs, &a).neg_mod(&a);
          let s = x.floor_sqrt_vartime();
          if s.concatenating_square() != x {
            Err(Error::Incorrect)?;
          }
          s
        };
        // `s` must be a multiple of `g`; `s' = s / g`.
        let (s_apo, zero) = s.div_rem(&g);
        if bool::from(!zero.is_zero()) {
          Err(Error::Incorrect)?;
        }
        s_apo
      };

      // `t'` must be invertible modulo `a'`.
      if bool::from(!t_apo_abs.gcd_vartime(&a_apo).is_one()) {
        Err(Error::Incorrect)?;
      }
      let u = t_apo_abs
        .resize(a_apo.bits_precision())
        .invert_mod(&a_apo)
        .expect("non-zero and coprime but no modular inverse?");
      // `b' = s' * t'^-1 mod a'`, negated modulo `a'` when `t` is negative.
      let mut b_apo = s_apo.mul_mod(&u, &a_apo);
      b_apo.ct_assign(&b_apo.neg_mod(&a_apo), !t_positive);

      let b_0 = decode_bigint(&mut reader, g_bits)?;
      if b_0 > *g {
        Err(Error::Incorrect)?;
      }
      // `|b| = b_0 * a' + b'`.
      b_0.concatenating_mul(a_apo.as_ref()).concatenating_add(&b_apo)
    };

    // `|b|` may not exceed `a`. The decoded `t` must also be exactly the one recomputed from
    // `(a, |b|)`; otherwise the encoding is non-canonical.
    {
      if b_abs > (*a.as_ref()) {
        Err(Error::Incorrect)?;
      }
      let (t_positive_recalculated, t_abs_recalculated) = t(a.clone(), b_abs.clone());
      if (bool::from(t_positive), t_abs) !=
        (bool::from(t_positive_recalculated), t_abs_recalculated.get())
      {
        Err(Error::NonCanonical)?;
      }
    }
    (b_positive, b_abs)
  };

  Option::from(super::validate_binary_quadratic_form(a, (b_positive, b_abs), discriminant_abs))
    .ok_or(Error::Incorrect)
}