class_groups/crypto_bigint/encoding/compressed/mod.rs
1//! Compression of binary quadratic forms
2//!
//! This implements compression of primitive reduced positive definite binary quadratic forms of
//! negative discriminants (even or odd) such that they can be represented in approximately
//! `2 + 1.5 (floor(log_2(sqrt(|discriminant|))) + 1)` bits instead of the naïvely required
//! `1 + 2 (floor(log_2(sqrt(|discriminant|))) + 1)` bits (the absolute values of the `a`, `b`
//! coefficients and the sign of the `b` coefficient).
8//!
9//! The methodology is as posited in
10//! [Trustless unknown-order groups](https://eprint.iacr.org/2020/196) by Samuel Dobson,
//! Steven Galbraith, and Benjamin Smith. They provide a rather complete description of
//! compression, and its result, which we appreciate but further specify here (and somewhat
//! differ from).
14//!
15//! Notably, instead of encoding the `b` coefficient as a pair of congruences, we encode the `b`
16//! coefficient as the result of a Euclidean division. This avoids having to find a solution for
17//! `f`, which lacked bounds (and presumably only had a statistical bound regarding the
18//! distribution of prime factors).
19//!
20//! We bound to primitive forms as we do not practically want to work with imprimitive forms and
//! MUST eventually validate all forms to be primitive (upon decode being the logical place to do
22//! so, ensuring imprimitive forms never even enter our context). We bound to reduced forms to
23//! ensure a canonical representation. We bound to positive definite forms (of negative
24//! discriminant) to ensure the `a` coefficient is positive and non-zero, as required for the
25//! compression algorithm (to avoid it as an exceptional case).
26//!
27//! The pseudocode denotes what we would discuss as `z'` as `z_apo`, instead of the more
//! traditional `z_prime` (for arbitrary "z"). This is to avoid confusion as to whether this
//! variable is notably considered (co)prime.
30//!
31//! We assume the existence of:
32//! - A `gcd` function, which for `gcd(x, y)` returns the greatest common divisor of `x, y`
33//! - An `xgcd` function, which for `xgcd(x, y)`, returns `(u, v, d)` where `u * x + v * y = d` and
34//! `d = gcd(x, y)`.
35//! - A `floor_sqrt` function, which for `floor_sqrt(x)`, returns `y` where $y^2 \le x < (y + 1)^2$
36//! - A `floor_log_2` function, which for `floor_log_2(x)`, returns `k` such that
37//! $2^k \le x < 2^{k + 1}$
38//!
39//! `//` is used to represent floor division.
40//!
41//! ```py
42//! # Note `discriminant` is a _signed_ big integer, bound to be negative
43//! fn encode_compressed_binary_quadratic_form(a, b_positive, b_abs, discriminant) {
44//! (t_positive, t_abs) = t(a, b_abs)
45//! g = gcd(a, t_abs)
46//! a_apo = a / g
47//! t_apo_abs = t_abs / g
48//! b_0 = b_abs // a_apo
49//!
50//! g_bits = floor_log_2(g) + 1
51//! g_bytes = (g_bits + 7) // 8
52//! result = encode_varint(g_bytes)
53//!
54//! result.extend(encode_bigint(g, g_bits))
55//! result.extend(encode_bigint(a_apo, (floor_log_2(-discriminant) // 2) + 1 - (g_bits - 1)))
56//!
57//! result.push((t_positive << 1) | b_positive)
58//!
59//! result.extend(encode_bigint(t_apo_abs, (floor_log_2(-discriminant) // 4) + 1 - (g_bits - 1)))
60//! result.extend(encode_bigint(b_0, g_bits))
61//!
62//! return result
63//! }
64//!
65//! # Note `discriminant` is a _signed_ big integer, bound to be negative
66//! fn decode_compressed_binary_quadratic_form(bytestream, discriminant) {
67//! g_bytes = decode_varint(bytestream)
68//! assert g_bytes <= ((((floor_log_2(-discriminant) // 2) + 1) + 7) // 8)
69//! g = decode_bigint(bytestream, g_bytes * 8)
70//! g_bits = floor_log_2(g) + 1
71//! assert g_bytes == ((g_bits + 7) // 8)
72//!
73//! a_apo = decode_bigint(bytestream, (floor_log_2(-discriminant) // 2) + 1 - (g_bits - 1))
74//!
75//! a = a_apo * g
76//! # For a negative discriminant, `a != 0`
77//! assert a != 0
78//!
79//! sign_bits = bytestream.next_byte()
80//! # Ensure `sign_bits` was canonically encoded
81//! assert (sign_bits >> 2) == 0
82//! b_positive = sign_bits & 1
83//! t_positive = sign_bits >> 1
84//!
85//! t_apo_abs = decode_bigint(bytestream, (floor_log_2(-discriminant) // 4) + 1 - (g_bits - 1))
86//!
87//! t_abs = t_apo_abs * g
88//! # We ignore the sign of `t` here as `-1 * -1 = 1`
89//! x = (t_abs * t_abs * discriminant) % a
90//!
91//! s = floor_sqrt(x)
92//! assert (s * s) == x
93//!
94//! s_apo = s // g
95//! assert (s_apo * g) == s
//! # `u t_apo_abs + v a_apo = d` where `d = gcd(t_apo_abs, a_apo)`
97//! (u, _v, one) = xgcd(t_apo_abs, a_apo)
98//! # This asserts the modular inverse exists and that `g = gcd(t, a)`
99//! assert one == 1
100//! b_apo = (s_apo * u) % a_apo
101//! # If `t` was negative, negate `b_apo % a_apo`
102//! if (b_apo != 0) && (!t_positive) {
103//! b_apo = a_apo - b_apo
104//! }
105//!
106//! b_0 = decode_bigint(bytestream, g_bits)
107//! assert b_0 <= g
108//! b_abs = (b_0 * a_apo) + b_apo
109//!
110//! # Assert `b_abs <= a`
111//! # This is a prerequisite for calling `t`, which so bounds its inputs
112//! assert b_abs <= a
113//! # Assert `t` was canonically chosen
114//! assert (t_positive, t_abs) == t(a, b_abs)
115//!
116//! return validate_binary_quadratic_form(a, b_positive, b_abs, discriminant)
117//! }
118//! ```
119//!
120//! When decoding `a_apo, g`, their bit bounds are such that their product is at most the square
121//! root of the discriminant. Note these bit bounds aren't strictly enforced, solely used to
122//! determine the lengths of the big integers' encodings, and we ensure they're canonical via
123//! validating `g_bytes` upon deserialization and ensuring the resulting form is in fact reduced.
124//! Similarly, when decoding `t_apo_abs, b_0`, their bit bounds are such that their product is at
125//! most the fourth root of the discriminant (if the bit bounds were enforced, where we do validate
126//! `b_0 <= g` and then validate `t` was canonically chosen and therefore `t_apo_abs` was correctly
127//! encoded). This causes the encoding, ignoring the alignment of the big integers to byte
//! boundaries, to be of length $v + \lfloor \log_2((-discriminant)^{3 / 4}) \rfloor + 11$ where
//! $v$ is the length of the VarInt encoding of `g_bytes` (experimentally, $1$ in the average
//! case). Note most of the $11$ is from using an entire byte to represent the two sign bits,
//! `b_positive` and `t_positive`.
132
133use alloc::vec::Vec;
134
135#[cfg(feature = "std")]
136use std::io;
137
138use crypto_bigint::{
139 Choice, CtEq as _, CtAssign as _, NonZero, ConcatenatingMul as _, ConcatenatingSquare as _,
140 Gcd as _, Resize as _, BoxedUint,
141};
142
143use super::Error;
144
145mod varint;
146use varint::{encode_varint, decode_varint};
147
148mod bigint;
149use bigint::{encode_bigint, decode_bigint};
150
151mod partial_xgcd;
152use partial_xgcd::t;
153
154/// This function runs in time variable to the input.
155pub(crate) fn encode_compressed_binary_quadratic_form(
156 a: NonZero<BoxedUint>,
157 b_positive: Choice,
158 b_abs: BoxedUint,
159 discriminant_abs: &BoxedUint,
160) -> Vec<u8> {
161 let (t_positive, t_abs) = t(a.clone(), b_abs.clone());
162 let g = a.gcd_vartime(&t_abs);
163 let a = a.get();
164 let a_apo = &a / &g;
165 let a_apo = NonZero::new(a_apo.clone()).expect("`a != 0` so `(a / gcd(a, t)) != 0`");
166 let t_apo_abs = t_abs.get() / &g;
167 let b_0 = b_abs / &a_apo;
168
169 let g_bits = usize::try_from(g.bits()).unwrap();
170 let g_bytes = g_bits.div_ceil(8);
171 let mut result = encode_varint(g_bytes);
172 result.extend(&encode_bigint(g.as_ref(), g_bits));
173 result.extend(&encode_bigint(
174 a_apo.as_ref(),
175 ((usize::try_from(discriminant_abs.bits()).unwrap() - 1) / 2) + 1 - (g_bits - 1),
176 ));
177 result.push((u8::from(t_positive) << 1) | u8::from(b_positive));
178 result.extend(&encode_bigint(
179 &t_apo_abs,
180 ((usize::try_from(discriminant_abs.bits()).unwrap() - 1) / 4) + 1 - (g_bits - 1),
181 ));
182 result.extend(&encode_bigint(&b_0, g_bits));
183 result
184}
185
/// Decode a compressed binary quadratic form from `reader`.
///
/// This is the inverse of `encode_compressed_binary_quadratic_form`, following the decoding
/// pseudocode in this module's documentation. `discriminant_abs` is the absolute value of the
/// (negative) discriminant and MUST be non-zero. The reconstructed `a` and `(b_positive, b_abs)`
/// are passed to `super::validate_binary_quadratic_form`, whose result is returned.
///
/// # Errors
///
/// - `Error::Overflow` if the encoded byte length of `g` doesn't fit in a `u32`.
/// - `Error::UnexpectedEof` if the byte carrying the sign bits couldn't be read.
/// - `Error::NonCanonical` if the byte length of `g`, the sign bits, or `t` weren't canonically
///   encoded.
/// - `Error::Incorrect` if the encoding doesn't describe a valid form of this discriminant,
///   including when `g` is too long for the bit lengths allotted to `a` or `t`.
/// - Any error returned by `decode_varint` or `decode_bigint`.
///
/// This function runs in time variable to the input.
#[cfg(feature = "std")]
#[expect(clippy::type_complexity)]
pub(crate) fn decode_compressed_binary_quadratic_form(
    mut reader: impl io::Read,
    discriminant_abs: &BoxedUint,
) -> Result<(NonZero<BoxedUint>, (Choice, BoxedUint), BoxedUint), Error> {
    // `floor(log_2(|discriminant|))`
    let discriminant_floor_log_2 = discriminant_abs.bits() - 1;
    // The bit length allotted to `a = a_apo * g`
    let sqrt_bits = (discriminant_floor_log_2 / 2) + 1;
    // The bit length allotted to `t_abs = t_apo_abs * g`
    let fourth_root_bits = (discriminant_floor_log_2 / 4) + 1;
    debug_assert!(discriminant_abs.floor_sqrt_vartime().bits() <= sqrt_bits);
    debug_assert!(
        discriminant_abs.floor_sqrt_vartime().floor_sqrt_vartime().bits() <= fourth_root_bits
    );

    let g_bytes = u32::try_from(decode_varint(&mut reader)?).map_err(|_| Error::Overflow)?;
    if g_bytes > sqrt_bits.div_ceil(8) {
        Err(Error::Incorrect)?;
    }
    let g = decode_bigint(&mut reader, g_bytes * 8)?;
    let g_bits = g.bits();
    if g_bytes != g_bits.div_ceil(8) {
        Err(Error::NonCanonical)?;
    }
    let g = Option::<NonZero<_>>::from(NonZero::new(g)).ok_or(Error::Incorrect)?;
    // `g != 0`, so `g_bits >= 1` and this can't underflow
    let g_bits_minus_one = g_bits - 1;

    // `g` was only bounded in bytes, so it may be up to seven bits longer than `sqrt_bits`.
    // Reject such attacker-controlled values instead of underflowing the bit length.
    let a_apo_bits = sqrt_bits.checked_sub(g_bits_minus_one).ok_or(Error::Incorrect)?;
    let a_apo = decode_bigint(&mut reader, a_apo_bits)?;
    let a_apo = NonZero::new(a_apo).ok_or(Error::Incorrect)?;
    let a = a_apo.concatenating_mul(g.as_ref());
    let a = NonZero::new(a).expect("the product of two non-zero values is itself non-zero");

    let (b_positive, b_abs) = {
        let mut sign_bits = [0xff];
        reader.read_exact(&mut sign_bits).map_err(|_| Error::UnexpectedEof)?;
        let sign_bits = sign_bits[0];
        // Only the two lowest bits carry information
        if (sign_bits >> 2) != 0 {
            Err(Error::NonCanonical)?;
        }
        let b_positive = (sign_bits & 1).ct_eq(&1);
        let t_positive = (sign_bits >> 1).ct_eq(&1);

        // A `g` within the bound for `a` may still exceed the (smaller) bound for `t`
        let t_apo_abs_bits =
            fourth_root_bits.checked_sub(g_bits_minus_one).ok_or(Error::Incorrect)?;
        let t_apo_abs = decode_bigint(&mut reader, t_apo_abs_bits)?;
        if bool::from(t_apo_abs.is_zero()) {
            Err(Error::Incorrect)?;
        }
        let t_abs = t_apo_abs.concatenating_mul(g.as_ref());

        let b_abs = {
            let s_apo = {
                let s = {
                    // `x = (t^2 * discriminant) mod a`, computed as the negation of
                    // `t^2 * |discriminant|` as the discriminant is negative
                    let x = t_abs.square_mod(&a).mul_mod(discriminant_abs, &a).neg_mod(&a);

                    let s = x.floor_sqrt_vartime();
                    if s.concatenating_square() != x {
                        Err(Error::Incorrect)?;
                    }
                    s
                };

                let (s_apo, zero) = s.div_rem(&g);
                if bool::from(!zero.is_zero()) {
                    Err(Error::Incorrect)?;
                }
                s_apo
            };

            // This ensures the modular inverse exists and that `g = gcd(t, a)`
            if bool::from(!t_apo_abs.gcd_vartime(&a_apo).is_one()) {
                Err(Error::Incorrect)?;
            }
            let u = t_apo_abs
                .resize(a_apo.bits_precision())
                .invert_mod(&a_apo)
                .expect("non-zero and coprime but no modular inverse?");
            let mut b_apo = s_apo.mul_mod(&u, &a_apo);
            // If `t` was negative, negate `b_apo` modulo `a_apo`
            b_apo.ct_assign(&b_apo.neg_mod(&a_apo), !t_positive);

            let b_0 = decode_bigint(&mut reader, g_bits)?;
            if b_0 > *g {
                Err(Error::Incorrect)?;
            }

            b_0.concatenating_mul(a_apo.as_ref()).concatenating_add(&b_apo)
        };

        {
            // `b_abs <= a` is a prerequisite for calling `t`
            if b_abs > (*a.as_ref()) {
                Err(Error::Incorrect)?;
            }
            // Ensure `t` was canonically chosen
            let (t_positive_recalculated, t_abs_recalculated) = t(a.clone(), b_abs.clone());
            if (bool::from(t_positive), t_abs) !=
                (bool::from(t_positive_recalculated), t_abs_recalculated.get())
            {
                Err(Error::NonCanonical)?;
            }
        }

        (b_positive, b_abs)
    };

    Option::from(super::validate_binary_quadratic_form(a, (b_positive, b_abs), discriminant_abs))
        .ok_or(Error::Incorrect)
}