Skip to main content

class_groups/crypto_bigint/encoding/compressed/
mod.rs

1//! Compression of binary quadratic forms
2//!
3//! This implements compression of primitive reduced positive definite binary quadratic forms of
//! negative discriminants (even or odd) such that they can be represented in approximately
5//! `2 + 1.5 (floor(log_2(sqrt(|discriminant|))) + 1)` bits instead of the naïvely required
6//! `1 + 2 (floor(log_2(sqrt(|discriminant|))) + 1)` (the absolute value of the `a`, `b`
7//! coefficients and the sign of the `b` coefficient).
8//!
//! The methodology is as posited in
//! [Trustless unknown-order groups](https://eprint.iacr.org/2020/196) by Samuel Dobson,
//! Steven Galbraith, and Benjamin Smith. They give a rather complete description of
//! compression, and of its result, which we appreciate and further specify (while somewhat
//! differing from it).
14//!
15//! Notably, instead of encoding the `b` coefficient as a pair of congruences, we encode the `b`
16//! coefficient as the result of a Euclidean division. This avoids having to find a solution for
17//! `f`, which lacked bounds (and presumably only had a statistical bound regarding the
18//! distribution of prime factors).
19//!
20//! We bound to primitive forms as we do not practically want to work with imprimitive forms and
//! MUST eventually validate all forms to be primitive (upon decode being the logical place to do
22//! so, ensuring imprimitive forms never even enter our context). We bound to reduced forms to
23//! ensure a canonical representation. We bound to positive definite forms (of negative
24//! discriminant) to ensure the `a` coefficient is positive and non-zero, as required for the
25//! compression algorithm (to avoid it as an exceptional case).
26//!
//! The pseudocode denotes what we would discuss as `z'` as `z_apo`, instead of the more
//! traditional `z_prime` (for arbitrary "z"). This avoids suggesting that the variable is in any
//! way (co)prime.
30//!
31//! We assume the existence of:
32//! - A `gcd` function, which for `gcd(x, y)` returns the greatest common divisor of `x, y`
33//! - An `xgcd` function, which for `xgcd(x, y)`, returns `(u, v, d)` where `u * x + v * y = d` and
34//!   `d = gcd(x, y)`.
35//! - A `floor_sqrt` function, which for `floor_sqrt(x)`, returns `y` where $y^2 \le x < (y + 1)^2$
36//! - A `floor_log_2` function, which for `floor_log_2(x)`, returns `k` such that
37//!   $2^k \le x < 2^{k + 1}$
38//!
39//! `//` is used to represent floor division.
40//!
41//! ```py
42//! # Note `discriminant` is a _signed_ big integer, bound to be negative
43//! fn encode_compressed_binary_quadratic_form(a, b_positive, b_abs, discriminant) {
44//!   (t_positive, t_abs) = t(a, b_abs)
45//!   g = gcd(a, t_abs)
46//!   a_apo = a / g
47//!   t_apo_abs = t_abs / g
48//!   b_0 = b_abs // a_apo
49//!
50//!   g_bits = floor_log_2(g) + 1
51//!   g_bytes = (g_bits + 7) // 8
52//!   result = encode_varint(g_bytes)
53//!
54//!   result.extend(encode_bigint(g, g_bits))
55//!   result.extend(encode_bigint(a_apo, (floor_log_2(-discriminant) // 2) + 1 - (g_bits - 1)))
56//!
57//!   result.push((t_positive << 1) | b_positive)
58//!
59//!   result.extend(encode_bigint(t_apo_abs, (floor_log_2(-discriminant) // 4) + 1 - (g_bits - 1)))
60//!   result.extend(encode_bigint(b_0, g_bits))
61//!
62//!   return result
63//! }
64//!
65//! # Note `discriminant` is a _signed_ big integer, bound to be negative
66//! fn decode_compressed_binary_quadratic_form(bytestream, discriminant) {
67//!   g_bytes = decode_varint(bytestream)
68//!   assert g_bytes <= ((((floor_log_2(-discriminant) // 2) + 1) + 7) // 8)
69//!   g = decode_bigint(bytestream, g_bytes * 8)
70//!   g_bits = floor_log_2(g) + 1
71//!   assert g_bytes == ((g_bits + 7) // 8)
72//!
73//!   a_apo = decode_bigint(bytestream, (floor_log_2(-discriminant) // 2) + 1 - (g_bits - 1))
74//!
75//!   a = a_apo * g
76//!   # For a negative discriminant, `a != 0`
77//!   assert a != 0
78//!
79//!   sign_bits = bytestream.next_byte()
80//!   # Ensure `sign_bits` was canonically encoded
81//!   assert (sign_bits >> 2) == 0
82//!   b_positive = sign_bits & 1
83//!   t_positive = sign_bits >> 1
84//!
85//!   t_apo_abs = decode_bigint(bytestream, (floor_log_2(-discriminant) // 4) + 1 - (g_bits - 1))
86//!
87//!   t_abs = t_apo_abs * g
88//!   # We ignore the sign of `t` here as `-1 * -1 = 1`
89//!   x = (t_abs * t_abs * discriminant) % a
90//!
91//!   s = floor_sqrt(x)
92//!   assert (s * s) == x
93//!
94//!   s_apo = s // g
95//!   assert (s_apo * g) == s
//!   # `u * t_apo_abs + v * a_apo = d` where `d = gcd(t_apo_abs, a_apo)`
97//!   (u, _v, one) = xgcd(t_apo_abs, a_apo)
98//!   # This asserts the modular inverse exists and that `g = gcd(t, a)`
99//!   assert one == 1
100//!   b_apo = (s_apo * u) % a_apo
101//!   # If `t` was negative, negate `b_apo % a_apo`
102//!   if (b_apo != 0) && (!t_positive) {
103//!     b_apo = a_apo - b_apo
104//!   }
105//!
106//!   b_0 = decode_bigint(bytestream, g_bits)
107//!   assert b_0 <= g
108//!   b_abs = (b_0 * a_apo) + b_apo
109//!
110//!   # Assert `b_abs <= a`
111//!   # This is a prerequisite for calling `t`, which so bounds its inputs
112//!   assert b_abs <= a
113//!   # Assert `t` was canonically chosen
114//!   assert (t_positive, t_abs) == t(a, b_abs)
115//!
116//!   return validate_binary_quadratic_form(a, b_positive, b_abs, discriminant)
117//! }
118//! ```
119//!
120//! When decoding `a_apo, g`, their bit bounds are such that their product is at most the square
121//! root of the discriminant. Note these bit bounds aren't strictly enforced, solely used to
122//! determine the lengths of the big integers' encodings, and we ensure they're canonical via
123//! validating `g_bytes` upon deserialization and ensuring the resulting form is in fact reduced.
124//! Similarly, when decoding `t_apo_abs, b_0`, their bit bounds are such that their product is at
125//! most the fourth root of the discriminant (if the bit bounds were enforced, where we do validate
126//! `b_0 <= g` and then validate `t` was canonically chosen and therefore `t_apo_abs` was correctly
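//! encoded). Writing $L = \lfloor \log_2(-discriminant) \rfloor$, this causes the encoding,
//! ignoring the alignment of the big integers to byte boundaries, to be of length
//! $v + \lfloor L / 2 \rfloor + \lfloor L / 4 \rfloor + 12 \approx v + \frac{3}{4} L + 12$ bits,
//! where $v$ is the length (in bits) of the VarInt encoding of `g_bytes` (experimentally, a single
//! byte in the average case). The `g_bits` terms cancel: `g` and `b_0` each take `g_bits` bits,
//! while the budgets for `a_apo` and `t_apo_abs` are each reduced by `g_bits - 1`. Note most of the
//! $12$ is from using an entire byte to represent the two sign bits, `b_positive` and `t_positive`.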
132
133use alloc::vec::Vec;
134
135#[cfg(feature = "std")]
136use std::io;
137
138use crypto_bigint::{
139  Choice, CtEq as _, CtAssign as _, NonZero, ConcatenatingMul as _, ConcatenatingSquare as _,
140  Gcd as _, Resize as _, BoxedUint,
141};
142
143use super::Error;
144
145mod varint;
146use varint::{encode_varint, decode_varint};
147
148mod bigint;
149use bigint::{encode_bigint, decode_bigint};
150
151mod partial_xgcd;
152use partial_xgcd::t;
153
154/// This function runs in time variable to the input.
155pub(crate) fn encode_compressed_binary_quadratic_form(
156  a: NonZero<BoxedUint>,
157  b_positive: Choice,
158  b_abs: BoxedUint,
159  discriminant_abs: &BoxedUint,
160) -> Vec<u8> {
161  let (t_positive, t_abs) = t(a.clone(), b_abs.clone());
162  let g = a.gcd_vartime(&t_abs);
163  let a = a.get();
164  let a_apo = &a / &g;
165  let a_apo = NonZero::new(a_apo.clone()).expect("`a != 0` so `(a / gcd(a, t)) != 0`");
166  let t_apo_abs = t_abs.get() / &g;
167  let b_0 = b_abs / &a_apo;
168
169  let g_bits = usize::try_from(g.bits()).unwrap();
170  let g_bytes = g_bits.div_ceil(8);
171  let mut result = encode_varint(g_bytes);
172  result.extend(&encode_bigint(g.as_ref(), g_bits));
173  result.extend(&encode_bigint(
174    a_apo.as_ref(),
175    ((usize::try_from(discriminant_abs.bits()).unwrap() - 1) / 2) + 1 - (g_bits - 1),
176  ));
177  result.push((u8::from(t_positive) << 1) | u8::from(b_positive));
178  result.extend(&encode_bigint(
179    &t_apo_abs,
180    ((usize::try_from(discriminant_abs.bits()).unwrap() - 1) / 4) + 1 - (g_bits - 1),
181  ));
182  result.extend(&encode_bigint(&b_0, g_bits));
183  result
184}
185
/// Decode a binary quadratic form from its compressed representation.
///
/// `discriminant_abs` is the absolute value of the (negative) discriminant the form must have.
/// The encoding is read from `reader` in the order written by
/// `encode_compressed_binary_quadratic_form`:
/// 1. the VarInt byte length of `g`,
/// 2. `g`,
/// 3. `a'`,
/// 4. a sign byte (bit 1: `t` is positive, bit 0: `b` is positive),
/// 5. `|t'|`,
/// 6. `b_0`.
///
/// From these, `a = a' g`, `|t| = |t'| g`, and `|b| = b_0 a' + b'` are reconstructed, where `b'`
/// is recovered from the square root of `t^2 discriminant mod a`.
///
/// The result is whatever `validate_binary_quadratic_form` yields for the decoded
/// `(a, (b_positive, |b|))`. NOTE(review): the trailing `BoxedUint` is presumably the `c`
/// coefficient; confirm against `validate_binary_quadratic_form`.
///
/// # Errors
///
/// - `Error::Overflow` if the byte length of `g` does not fit in a `u32`.
/// - `Error::UnexpectedEof` if the sign byte cannot be read. Errors from decoding the VarInt and
///   the big integers are propagated.
/// - `Error::NonCanonical` if `g`'s byte length, the sign byte, or `t` was not canonically
///   encoded.
/// - `Error::Incorrect` for any other inconsistency. This includes `g` too large for the bit
///   budgets, a zero `g`/`a'`/`|t'|`, a non-square `t^2 discriminant mod a`, and a form which
///   fails validation.
///
/// This function runs in time variable to the input.
#[cfg(feature = "std")]
#[expect(clippy::type_complexity)]
pub(crate) fn decode_compressed_binary_quadratic_form(
  mut reader: impl io::Read,
  discriminant_abs: &BoxedUint,
) -> Result<(NonZero<BoxedUint>, (Choice, BoxedUint), BoxedUint), Error> {
  // `floor(log_2(|discriminant|))` and the bit bounds derived from it, as used by the encoder:
  // `a` is bounded by (roughly) the square root of the discriminant, `|t|` by the fourth root
  let discriminant_floor_log_2 = discriminant_abs.bits() - 1;
  let a_bits_bound = (discriminant_floor_log_2 / 2) + 1;
  let t_bits_bound = (discriminant_floor_log_2 / 4) + 1;
  debug_assert!(discriminant_abs.floor_sqrt_vartime().bits() <= a_bits_bound);
  debug_assert!(
    discriminant_abs.floor_sqrt_vartime().floor_sqrt_vartime().bits() <= t_bits_bound
  );

  let g_bytes = u32::try_from(decode_varint(&mut reader)?).map_err(|_| Error::Overflow)?;
  if g_bytes > a_bits_bound.div_ceil(8) {
    Err(Error::Incorrect)?;
  }
  let g = decode_bigint(&mut reader, g_bytes * 8)?;
  let g_bits = g.bits();
  if g_bytes != g_bits.div_ceil(8) {
    Err(Error::NonCanonical)?;
  }
  // Rejecting `g = 0` here also guarantees `g_bits >= 1`, so `g_bits - 1` below cannot underflow
  let g = Option::<NonZero<_>>::from(NonZero::new(g)).ok_or(Error::Incorrect)?;

  // An honest encoder has `g | t`, so `g_bits - 1` never exceeds either bound. An adversarial `g`
  // may: its byte length is only bounded by `a_bits_bound` rounded up to whole bytes. Reject
  // rather than let these subtractions underflow (panicking in debug, wrapping in release).
  let a_apo_bits = a_bits_bound.checked_sub(g_bits - 1).ok_or(Error::Incorrect)?;
  let t_apo_bits = t_bits_bound.checked_sub(g_bits - 1).ok_or(Error::Incorrect)?;

  let a_apo = decode_bigint(&mut reader, a_apo_bits)?;
  let a_apo = Option::<NonZero<_>>::from(NonZero::new(a_apo)).ok_or(Error::Incorrect)?;
  let a = a_apo.concatenating_mul(g.as_ref());
  let a = NonZero::new(a).expect("the product of two non-zero values is itself non-zero");

  let (b_positive, b_abs) = {
    let mut sign_bits = [0xff];
    reader.read_exact(&mut sign_bits).map_err(|_| Error::UnexpectedEof)?;
    let sign_bits = sign_bits[0];
    // Only the two low bits may be set
    if (sign_bits >> 2) != 0 {
      Err(Error::NonCanonical)?;
    }
    let b_positive = (sign_bits & 1).ct_eq(&1);
    let t_positive = (sign_bits >> 1).ct_eq(&1);

    let t_apo_abs = decode_bigint(&mut reader, t_apo_bits)?;
    if bool::from(t_apo_abs.is_zero()) {
      Err(Error::Incorrect)?;
    }
    let t_abs = t_apo_abs.concatenating_mul(g.as_ref());

    let b_abs = {
      let s_apo = {
        let s = {
          // `x = t^2 discriminant mod a`. The discriminant is negative, so multiply by its
          // absolute value and negate. The sign of `t` is irrelevant as it is squared.
          let x = t_abs.square_mod(&a).mul_mod(discriminant_abs, &a).neg_mod(&a);

          let s = x.floor_sqrt_vartime();
          if s.concatenating_square() != x {
            Err(Error::Incorrect)?;
          }
          s
        };

        // `s` must be a multiple of `g` for `s' = s / g` to be exact
        let (s_apo, zero) = s.div_rem(&g);
        if bool::from(!zero.is_zero()) {
          Err(Error::Incorrect)?;
        }
        s_apo
      };

      // `t'` must be invertible modulo `a'`, which also asserts `g = gcd(a, t)`
      if bool::from(!t_apo_abs.gcd_vartime(&a_apo).is_one()) {
        Err(Error::Incorrect)?;
      }
      let u = t_apo_abs
        .resize(a_apo.bits_precision())
        .invert_mod(&a_apo)
        .expect("non-zero and coprime but no modular inverse?");
      let mut b_apo = s_apo.mul_mod(&u, &a_apo);
      // If `t` was negative, negate `b'` modulo `a'`
      b_apo.ct_assign(&b_apo.neg_mod(&a_apo), !t_positive);

      let b_0 = decode_bigint(&mut reader, g_bits)?;
      // `|b| <= a = a' g` implies `b_0 <= g`
      if b_0 > *g {
        Err(Error::Incorrect)?;
      }

      b_0.concatenating_mul(a_apo.as_ref()).concatenating_add(&b_apo)
    };

    {
      // `t` requires `|b| <= a`
      if b_abs > (*a.as_ref()) {
        Err(Error::Incorrect)?;
      }
      // Ensure `t` was canonically chosen
      let (t_positive_recalculated, t_abs_recalculated) = t(a.clone(), b_abs.clone());
      if (bool::from(t_positive), t_abs) !=
        (bool::from(t_positive_recalculated), t_abs_recalculated.get())
      {
        Err(Error::NonCanonical)?;
      }
    }

    (b_positive, b_abs)
  };

  Option::from(super::validate_binary_quadratic_form(a, (b_positive, b_abs), discriminant_abs))
    .ok_or(Error::Incorrect)
}