Skip to main content

class_groups/crypto_bigint/
reduction.rs

//! A constant-time reduction algorithm.
//!
//! This is derived from Algorithm 1 of <https://eprint.iacr.org/2022/466>. Some typos have been
//! accounted for. Algorithm 2 is a restatement of Algorithm 1 in an iterative fashion and
//! accordingly may be of more approximate structure to the following, yet this work was
//! independently derived from Algorithm 1.
//!
//! In order to be efficient, this does not always operate over the full amount of limbs of each
//! number. Instead, as the reduction occurs (and as the numbers shrink), the amount of limbs
//! operated over also reduces. This requires extensively arguing the bounds for inputs and the
//! outputs of each function. This is done via brief proofs written in comments within each
//! function.
13
14#![expect(clippy::needless_pass_by_value)] // Triggered by `(&mut Choice, &mut L)`
15#![expect(clippy::inline_always)]
16
17use crypto_bigint::{Choice, CtEq, CtSelect, CtLt as _, CtGt as _, Zero, BitOps, Limb, UintRef};
18
19/// A collection of limbs and associated helper methods.
20///
21/// The provided reduction algorithm dances along the `Limb` boundaries, performance requiring
22/// correct decision of when to terminate execution of a given function. This API unifies `Uint`
23/// and `BoxedUint` (in a way `Integer` appeared ineligible for) while providing the niche methods
24/// required for performance.
25///
26/// Implementations MAY iterate up to the `limbs` argument (for performance) or MAY ignore it.
27/// Callers MUST NOT expect that if they specify a `limbs` argument, operations will only occur to
28/// that subset of limbs, and any results are undefined when any non-included limbs are non-zero.
29/// Callers MUST NOT specify more `limbs` than the value has.
30///
31/// Implementations MUST implement all functions in time constant to the value of the inputs,
32/// except for the amount of limbs, unless otherwise stated. Implementations MUST NOT panic for any
33/// input which the caller MAY pass.
34///
35/// A long-term goal is to replace this entirely for just `UintRef`. Currently, the reduction
36/// algorithm still requires allocating one scratch variable however, making this non-immediate.
37pub(crate) trait Limbs: AsRef<[Limb]> + AsMut<[Limb]> + CtEq + Zero + BitOps {
38  /// The number but with precision equal to `self`.
39  ///
40  /// This is equivalent to [`crypto_bigint::Zero::zero_like`] but avoids requiring `Self: Clone`.
41  /// We do not want to bound `Self: Clone` for performance reasons. Specifically, it's a goal of
42  /// the reduction algorithm to not allocate at all (for performance reasons), this one function
43  /// being necessary in one spot and the current sole exception.
44  fn like_zero(&self) -> Self;
45
46  /// Swap the values of `self` and `b` if `choice` is `true`.
47  ///
48  /// This is a basic helper as there is no `UintRef::ct_swap`.
49  #[inline(always)]
50  fn swap(&mut self, b: &mut Self, choice: Choice) {
51    let a = &mut <_ as AsMut<[Limb]>>::as_mut(self);
52    let b = &mut <_ as AsMut<[Limb]>>::as_mut(b);
53    for (a, b) in a.iter_mut().zip(b.iter_mut()) {
54      <_>::ct_swap(a, b, choice);
55    }
56  }
57}
58
59/// Obtain the equivalent form `(a', b', c')` for which `floor(log_2(a')) <= floor(log_2(c'))`.
60///
61/// This assumes the values for `a` and `c` would each fit in each others' variables.
62///
63/// This corresponds to step 2 of Algorithm 1, albeit solely confirming the `floor(log_2(_))` is
64/// smaller, not the value itself. This is much cheaper to evaluate and `b'` is still able to
65/// reduce by at least one bit for as long it has a greater logarithm. If `b'` has an equal
66/// logarithm, then we are at the final iteration, for which `a_lte_c` MUST be used.
67#[inline(always)]
68fn approximate_a_lte_c<L: Limbs>(
69  a: (&mut u32, &mut L),
70  b_sign: &mut Choice,
71  c: (&mut u32, &mut L),
72) {
73  let c_lt_a = c.0.ct_lt(a.0);
74  c.0.ct_swap(a.0, c_lt_a);
75  L::swap(a.1, c.1, c_lt_a);
76
77  /*
78    This line differs from the paper, whose described algorithm has a some typos (as further
79    evidenced by the correctness proof transcribing line 6 as
80    "[C - epsilon m B + m^2 A]" as "[C, - epsilon m B + A^2]").
81
82    As this swaps `a, c`, and as it's known `(a, b, c) == (c, -b, a)`, we MUST negate `b` here
83    if we performed a swap.
84  */
85  *b_sign ^= c_lt_a;
86}
87
88/// Obtain the equivalent form `(a', b', c')` for which `a' <= c'`.
89///
90/// This assumes `c` has at least as many limbs as `a`.
91///
92/// This corresponds to step 2 of Algorithm 1.
93#[inline(always)]
94fn a_lte_c<L: Limbs>(a: &mut L, b_sign: &mut Choice, c: &mut L) {
95  let limbs = <_ as AsRef<[Limb]>>::as_ref(c).len();
96  let c_lt_a = UintRef::new(&c.as_ref()[.. limbs]).ct_lt(UintRef::new(&a.as_ref()[.. limbs]));
97  L::swap(a, c, c_lt_a);
98  *b_sign ^= c_lt_a;
99}
100
101/// If `reduce_to_next_bit` should actually apply, except possibly incorrect.
102///
103/// This MAY yield `false` even when one final iteration should be ran, hence `except_final`. It is
104/// correct for all but the final iteration.
105///
106/// This simplifies `|b| > a` to `floor(log_2(|b|)) > floor(log_2(a))`, which is true whenever
107/// `|b| >= 2a`. See `should_reduce_to_next_bit_final` for why this is sufficient for all but the
108/// last iteration.
109///
110/// This is optimized by only indicating the reduction algorithm should be run when
111/// `(floor(log_2(|b|)) + 1) == b_bits_bound` _not_ always if `|b| >= 2a`. This is done as:
112///
113/// 1) It is still correct. `reduce_to_next_bit`, if called correctly, must be called from the
114///    current bound to the minimal bound, the current bound decrementing by one bit with each call,
115///    as `reduce_to_next_bit` is only guaranteed to reduce `|b|` by a single bit (until `b` is
116///    reduced). Assuming the bound is properly decremented with each call, then we know `|b|` is
117///    within the bound for each call, as the iteration is either unnecessary (the current bound
118///    exceeding the actual `floor(log_2(|b|)) + 1`) or will be reduced by at least one bit (and
119///    therefore within the next iteration's bound) if not already reduced.
120///
121/// 2) It is faster to check if `(floor(log_2(|b|)) + 1) == b_bits_bound` than to calculate
122///    `floor(log_2(|b|))`, even with how cheap that operation is. Finding the leading bit is an
123///    operation of linear complexity, while checking if the highest bit within the bound is set is
124///    of constant complexity (assuming the bound is public, allowing us to perform the retrieval
125///    with a variable memory-access pattern). This optimization decreased the time of point
126///    doubling by ~3.5%, implying this to optimize ~5% of reduction.
127///
128/// 3) If we always did the iterations which need to happen at the start, than being off-by-one
129///    would be quite hard to detect as the last iterations are unlikely to actually be necessary.
130///    This means `|b|` is likely below the bound for the entire duration of the algorithm, and the
131///    bound being inaccurate may not be noticed. As this methodology intersperses the necessary
132///    iterations with the unnecessary, `|b|` is worked with at the bound itself, which more
133///    aggressively requires the accuracy of the bounds. This itself helps to ensure the bounds are
134///    accurate.
135///
136/// This function assumes $|delta| \cong 1 \mod 2$, `a_bits = floor(log_2(a)) + 1`, and
137/// `b_bits_bound >= floor(log_2(|b|)) + 1` when `b_lte_a == false`. This function will update
138/// `b_lte_a`, but may inaccurately set it to `true` upon reaching the final necessary iteration.
139/// This function also assumes `b_bits_bound > 0` and may be incorrect or panic in that case.
140#[inline(always)]
141fn should_reduce_to_next_bit_except_final(
142  b: (&mut Choice, &mut UintRef),
143  b_needs_negation: Choice,
144  b_lte_a: &mut Choice,
145  a_bits: u32,
146  b_bits_bound: u32,
147) -> Choice {
148  // Because this check is only valid `b_lte_a == false`, short-circuit if `b_lte_a == false`
149  let b_gt_a = (!*b_lte_a) & b_bits_bound.ct_gt(&a_bits);
150
151  /*
152    Update `b_lte_a`.
153
154    If `b_lte_a` is set, this will never unset it, as `b_gt_a` won't be set if `b_lte_a` was.
155  */
156  *b_lte_a = !b_gt_a;
157
158  /*
159    Only run this iteration if this specific bit of `b` is in fact set.
160
161    If `b` needs to be negated, we check if the bit is _not_ set, which is immediately an
162    _incorrect optimization_. The negation process is defined as the logical NOT and a carrying
163    addition of `1`.
164
165    As an immediate counterexample, `1` has negative representation `0xFF`, which _should_ be
166    considered as having its trailing bit set (as when negated, it's `1`) but would not be
167    considered so here.
168
169    We bound `|delta|` to be odd, so as `b^2 - 4ac = delta`, taking the expression modulo `4`, it's
170    obvious `b` must be odd for any valid form (even if unreduced). This means we know the addition
171    of `1` will always set the trailing bit, making this optimization correct _for all but the
172    trailing bit_.
173
174    That makes this correct _except_ when `b_bits_bound = 1`, but any form with a 1-bit `b`
175    coefficient is already reduced (potentially after normalizing its sign). Accordingly, the fact
176    we incorrectly consider `-1` as not having its trailing bit set, still leads to a correct
177    result that no further reduction should occur.
178  */
179  b_gt_a &
180    Choice::from(u8::from(b.1.bit_vartime(b_bits_bound.wrapping_sub(1))))
181      .ct_eq(&!b_needs_negation)
182}
183
184/// If `reduce_to_next_bit` should actually apply.
185///
186/// This uses an exact determination for if the function should apply, which is only necessary when
187/// `floor(log_2(|b|)) == floor(log_2(a))`. However, in this case, `b <= 2 a`. The paper itself
188/// establishes the bound the algorithm will run for at most two more iterations in this case, yet
189/// a tighter bound of just one more iteration is possible. This is as the `b == 2a` case will
190/// resolved with `b' = 0`, while the case `a < b < 2a` will result in `b' < a`.
191///
192/// That's why this is 'final', as while it works all of the time, it should only be used for one
193/// final iteration.
194///
195/// This function assumes `<_ as AsRef<[Limb]>>(a).len() == <_ as AsRef<[Limb]>>(b.1).len()`.
196///
197/// Comments within this function which would be duplicated with
198/// `should_reduce_to_next_bit_except_final` are omitted.
199#[inline(always)]
200fn should_reduce_to_next_bit_final<L: Limbs>(a: &L, b: (&Choice, &UintRef)) -> Choice {
201  // If `a - b.1` has a borrow afterwards, then `b.1 > a`
202  let b_abs = b.1.as_limbs();
203  let mut borrow = Limb::ZERO;
204  for (a_limb, b_limb) in <_ as AsRef<[Limb]>>::as_ref(a)[.. b_abs.len()].iter().zip(b_abs) {
205    let _a_diff_b_abs_limb;
206    (_a_diff_b_abs_limb, borrow) = a_limb.borrowing_sub(*b_limb, borrow);
207  }
208  !borrow.is_zero()
209}
210
/// Reduce the `b` coefficient by at least one bit or until reduced.
///
/// For a positive definite binary quadratic form `(a, b, c)` such that:
/// - `b^2 - 4ac = delta` where `delta < 0` (the form is well-defined for a negative discriminant)
/// - $delta \cong 1 \mod 2$
/// - `0 <= a, c` (`a` and `c` aren't negative, as enforced by the type system)
/// - `floor(log_2(a)) <= floor(log_2(c))` (such as forms output of `(approximate_)a_lte_c`)
/// - `limbs <= <_ as AsRef<[Limb]>>(a).len()`
/// - `limbs <= <_ as AsRef<[Limb]>>(b.1).len()`
/// - The variable `c` does in fact contain the full representation of the `c` coefficient.
///
/// We require the following bounds when `should_reduce == true`:
/// - `a < |b|`
/// - `a_bits = floor(log_2(a)) + 1`
/// - `b_bits = floor(log_2(|b|)) + 1`
/// - `1 + b_bits <= (limbs * Limb::BITS)`
///
/// Yield an equivalent form `(a', b', c')` such that:
/// - `a' = a`
/// - `floor(log_2(|b'|)) <= floor(log_2(|b|)) - 1` if `|b| > 2 a` and `should_reduce == true`,
///   else `(a', b', c') = (a, b, c)`
/// - `|b'| <= a'` if `a < |b| <= 2 a` and `should_reduce == true`.
///
/// This corresponds to steps 3, 4, and 6 of Algorithm 1, except as a NOP if
/// `should_reduce == false` (`|b| <= a`) (in which case the form is reduced, or reduced after
/// normalizing the sign of `b`).
///
/// Arguments:
/// - `a`: the `a` coefficient, which is only read.
/// - `b`: the sign flag and the limbs of the `b` coefficient, both updated in place. The limbs
///   may be left pending a negation, as signalled via `b_needs_negation`.
/// - `b_needs_negation`: on entry, if the limbs of `b` are pending a negation deferred by the
///   prior call (which is applied here). On exit, if the limbs now written are pending one.
/// - `c`: the `c` coefficient, updated in place.
/// - `should_reduce`: if the reduction should occur at all. The same operations are executed
///   regardless, with the values worked with being selected so this is a NOP when `false`.
/// - `limbs`: the amount of low limbs of `a` copied into the scratch variable.
/// - `a_bits`, `b_bits`: the bit-lengths used to derive `log_2(m)`.
#[expect(clippy::too_many_arguments)]
#[inline(always)]
fn reduce_to_next_bit<L: Limbs>(
  a: &L,
  b: (&mut Choice, &mut UintRef),
  b_needs_negation: &mut Choice,
  c: &mut L,
  should_reduce: Choice,
  limbs: usize,
  a_bits: u32,
  b_bits: u32,
) {
  // Check the documented preconditions. The latter four only apply when `should_reduce == true`,
  // and the check on `b_bits` is additionally skipped while `b` is pending a negation.
  #[cfg(debug_assertions)]
  {
    debug_assert!(bool::from(a.bits().ct_lt(&c.bits()) | a.bits().ct_eq(&c.bits())));
    debug_assert!(limbs <= <_ as AsRef::<[Limb]>>::as_ref(a).len());
    debug_assert!(limbs <= <_ as AsRef::<[Limb]>>::as_ref(&b.1).len());
    debug_assert!(bool::from((!should_reduce) | UintRef::new(a.as_ref()).ct_lt(b.1)));
    debug_assert!(bool::from((!should_reduce) | a_bits.ct_eq(&a.bits())));
    debug_assert!(bool::from((*b_needs_negation) | (!should_reduce) | b_bits.ct_eq(&b.1.bits())));
    debug_assert!(bool::from((!should_reduce) | b_bits.ct_lt(&b.1.bits_precision())));
  }

  // Calculate `m` (the body of step 3's branch, step 4), represented as `log_2(m)`
  let log_2_m = {
    // This is only well-defined if `a_bits < b_bits`
    let log_2_m = b_bits.wrapping_sub(a_bits).wrapping_sub(1);
    // Set `log_2_m = 0` (`m = 1`) if they have equal bit lengths or if `log_2_m` wouldn't be
    // well-defined otherwise
    <_ as CtSelect>::ct_select(&0, &log_2_m, (!a_bits.ct_eq(&b_bits)) & should_reduce)
  };

  // Step 6

  /*
    This is a container of size `c` as we later operate on it with the bound the derivative is
    `<= c`, which means this has to be large enough to contain a number `<= c`.

    TODO: Can we remove the requirement for this scratch variable? Presumably not, as we have a
    constant-time shift of unbounded bit-length (where the shift may exceed a limb), so this can't
    trivially be done as we iterate over limbs. We then need this to calculate `b` before we again
    use it to calculate `c`, so we can't write it directly into one of those. We could directly
    modify `a`, or at least, require we be passed in a copy of `a` we then use as scratch, but that
    really just defers allocating this scratch variable to the caller.

    This is currently the only explicit `clone` (or equivalent) in this entire file.
  */
  let mut m_a = c.like_zero();
  // When `should_reduce = true`, `((1 << log_2_m) * a) < |b|`, so this will fit in `limbs` limbs
  {
    // Calculate `m a` within solely the low `limbs` limbs of the scratch variable
    let m_a = UintRef::new_mut(&mut <_ as AsMut<[Limb]>>::as_mut(&mut m_a)[.. limbs]);
    m_a.copy_from_slice(&a.as_ref()[.. limbs]);
    m_a.shl_assign(log_2_m);
  }
  // View the scratch variable at its full width from here on
  let m_a = UintRef::new_mut(<_ as AsMut<[Limb]>>::as_mut(&mut m_a));

  /*
    The following does _not_ swap `c, a`, as we always perform any necessary swap during the next
    iteration's step 2 regardless. This means our `c` is updated to the paper's output `a`, and our
    `a` is left as-is.

    Note that in terms of this original paper which outputs `(a, b, c)`, this outputs `(c, b, a)`,
    which is not an equivalent form. The equivalent form would be `(c, -b, a)`. The paper is
    missing a negation on the output of its form, and once that's considered, this is equivalent.
  */

  // $\epsilon b == |b|$ since $\epsilon = \mathsf{sgn}(b)$
  /*
    Instead of calculating `- epsilon m b + m m a`, we calculate `m (-|b| + m a)` to reduce the
    bit-length of the addition within the parentheses. As `m a < |b|` when `a < |b|`, the
    evaluation of the parentheses is negative and has an absolute value `< |b|` (which fits in
    `limbs` limbs) whenever `should_reduce == true`.

    When `should_reduce == false`, `b_diff_m_a` is set to `0` so we may unconditionally calculate
    the new `c` coefficient as `c - m b_diff_m_a`.

    We simultaneously calculate $|b| - m a$ and $b - \epsilon 2 m a$ as we can merge their loops.
    While the resulting code is of non-trivially greater complexity, it saves ~8% of the time to
    execute.

    Note $b - \epsilon 2 m a$ may underflow, causing `b'` to have a distinct sign from `b`. In
    order to ensure the variable `b.1` remains the absolute value, this would require the logical
    NOT operator combined with a carrying addition of `1` _after_ calculating $b - \epsilon 2 m a$.
    To avoid another loop, we instead defer performing the negation to the next iteration's
    instance of _this_ loop, reducing the amount of times we introduce flow control/branching, via
    the `b_needs_negation` variable. Note the caller must handle `b_needs_negation` when shifting
    `b`'s limb boundaries.
  */
  {
    // Apply the deferred negation: flip the sign here, with the limbs being negated in the loop
    *b.0 ^= *b_needs_negation;
    // `1` if the deferred negation applies, completing the negation after the logical NOT
    let b_negation_carry = Limb::from(u8::from(*b_needs_negation));
    // All bits set if the deferred negation applies, such that XOR'ing by this is a logical NOT
    let b_negation_mask = Limb::ZERO.wrapping_sub(b_negation_carry);

    let mut b_diff_m_a_carry = Limb::ZERO;

    // The bit shifted out of the prior limb of `m a` when doubling it
    let mut two_m_a_carry = Limb::ZERO;
    /*
      We express `b - 2 m a` as `b + -(2 m a)`, where the negation requires a carry of `1`
      (hence why this is initialized to `1` when `should_reduce == true`). We simultaneously
      apply the deferred negation to `b`, hence why we also sum `b_negation_carry`.
    */
    let mut b_diff_two_m_a_carry =
      Limb::from(u8::from(should_reduce)).wrapping_add(b_negation_carry);

    for (b_limb, m_a_limb) in b.1.iter_mut().zip(m_a.iter_mut()) {
      /*
        Set `b' = b - 2 m a`, while simultaneously negating `b_limb` (if necessary).

        Negating `b'` is expressed as the logical NOT combined with a carrying addition of `1`.
        This is incompatible with needing to perform a borrowing subtraction of `2 m a`. Instead,
        we rewrite it as `b + -(2 m a)`, where `2 m a`'s negation can be expressed with a
        carrying addition. This means all three aspects (negating `b` if necessary, negating
        `2 m a`, and summing `b, -2 m a`) can be so expressed and done simultaneously.
      */
      {
        // Double `m a` on the fly, a limb at a time, carrying the top bit into the next limb
        let two_m_a_limb: Limb = ((*m_a_limb) << 1) | two_m_a_carry;
        two_m_a_carry = (*m_a_limb) >> const { Limb::BITS - 1 };

        /*
          `floor(log_2(|b|)) = floor(log_2(2 m a))`, so their difference `|b'|` has the property
          `floor(log_2(|b'|)) < floor(log_2(|b|))`. Therefore, `|b'|` will fit in any container
          which fits `|b|`.
        */
        let new_b_limb;
        (new_b_limb, b_diff_two_m_a_carry) = ((*b_limb) ^ b_negation_mask).carrying_add(
          Limb::ct_select(&Limb::ZERO, &!two_m_a_limb, should_reduce),
          b_diff_two_m_a_carry,
        );
        *b_limb = new_b_limb;
      }

      /*
        Calculate `b_diff_m_a` as `b' + m a` (where `b' = b - 2 m a`).

        This writes `b_diff_m_a` directly into `m_a`, as we have no further use for `m_a`.
      */
      {
        let new_b_diff_m_a_limb;
        (new_b_diff_m_a_limb, b_diff_m_a_carry) =
          (*b_limb).carrying_add(*m_a_limb, b_diff_m_a_carry);
        *m_a_limb = Limb::ct_select(&Limb::ZERO, &new_b_diff_m_a_limb, should_reduce);
      }
    }

    /*
      Finish calculating `b'`, handling if `b < 2 m a`.

      Because we expressed `b'` as `b + -(2 m a)`, there is _no_ carry if `2 m a > b`. This means
      `b'` needs negation if there was _no_ carry.
    */
    *b_needs_negation = should_reduce & b_diff_two_m_a_carry.ct_eq(&Limb::ZERO);
  }
  let b_diff_m_a = m_a;

  // Calculate `c'`
  {
    /*
      We need to prove that `c >= (m b - m^2 a)`. We do so with the claim the output `c'` will be a
      positive integer, and therefore `c` MUST be greater than or equal to `m b - m^2 a` (when
      `should_reduce == true`), as else `c'` would be negative.

      We know each intermediate form is equivalent to the input form, and therefore as for input
      `(a, b, c)` satisfying `b^2 - (4 a c) = delta`, we have `b'^2 - (4 a' c') = delta`. As
      `delta < 0`, and `b^2 >= 0`, `4 a' c'` MUST be a positive number. As our algorithm sets
      `a' = a` where `a` is positive, `c'` must be positive as well.

      Because `c` is greater than or equal to `m b - m^2 a`, it will fit within a container which
      fits `c`, where `b_diff_m_a` is a container of size equal to `c`'s container (making this
      `shl` call well-defined).
    */
    let m_b_diff_m_square_a = b_diff_m_a;
    m_b_diff_m_square_a.shl_assign(log_2_m);

    // This subtraction is well-defined as `c >= m_b_diff_m_square_a` when `should_reduce == true`
    let mut borrow = Limb::ZERO;
    for (c_limb, m_b_diff_m_square_a_limb) in
      <_ as AsMut<[Limb]>>::as_mut(c).iter_mut().zip(m_b_diff_m_square_a.as_limbs())
    {
      // When `should_reduce == false`, `m_b_diff_m_square_a_limb = 0`, effecting a NOP
      let new_limb;
      (new_limb, borrow) = c_limb.borrowing_sub(*m_b_diff_m_square_a_limb, borrow);
      *c_limb = new_limb;
    }
  }
}
421
422/// Conditionally negate the `b` coefficient for a binary quadratic form of odd discriminant.
423///
424/// Negation is defined as flipping the sign bit, before taking the negative of the `UintRef`
425/// (considered a ring of 2^`k`, for some `k`). The latter process is via taking the logical NOT
426/// before applying a carrying addition of `1`.
427fn negate_b(b: (&mut Choice, &mut UintRef), b_needs_negation: Choice) {
428  *b.0 ^= b_needs_negation;
429  // If this needs negation, apply the logical NOT
430  let mask = Limb::ZERO.wrapping_sub(Limb::from(u8::from(b_needs_negation)));
431  for b_limb in b.1.iter_mut() {
432    *b_limb ^= mask;
433  }
434  /*
435    If this needs negation, complete the process by adding 1.
436
437    As the discriminant is odd, we know `b` is odd. This means, when negated, its trailing bit will
438    will be set, and after the above logical NOT, its trailing bit _will not_ be set. This means
439    the carrying addition is actually solely a regular addition, due to observing there won't be a
440    carry.
441
442    In the case this should not be negated, its trailing bit should be set regardless, so we can
443    simplify this to unilaterally ensuring the trailing bit is set.
444  */
445  b.1[0] |= Limb::ONE;
446}
447
448/// Normalize an almost-reduced element.
449///
450/// For a positive definite binary quadratic form `(a, b, c)` such that:
451/// - `b^2 - 4ac = delta` where `delta < 0` (the form is well-defined for a negative discriminant)
452/// - $delta \cong 1 \mod 2$
453/// - `0 <= a, c` (`a` and `c` aren't negative, as enforced by the type system)
454/// - `|b| <= a <= c`
455///
456/// Yield the reduced equivalent form `(a', b', c')` such that:
457/// - `|b'| <= a' <= c'`
458/// - `b' >= 0` if `(|b'| == a') || (a' == c')`
459///
460/// This is intended to correspond to steps 2 and 5 of Algorithm 1.
461#[inline(always)]
462fn normalize<L: Limbs>(a: L, mut b: (Choice, L), c: L) -> (L, (Choice, L), L) {
463  /*
464    Set `b` to be positive if `|b| == a` or `a == c`.
465
466    We do not consider normalizing if `b == 0` to positive as we bound an odd discriminant, meaning
467    `b` will be odd and therefore non-zero.
468  */
469  b.0 |= b.1.ct_eq(&a) | a.ct_eq(&c);
470  (a, b, c)
471}
472
473/// Reduce an element until either its reduced or $|b| < 2^{upper_bound}$.
474///
475/// For a positive definite binary quadratic form `(a, b, c)` such that:
476/// - `b^2 - 4ac = delta` where `delta < 0` (the form is well-defined for a negative discriminant)
477/// - $delta \cong 1 \mod 2$
478/// - `0 <= a, c` (`a, c` aren't negative, as enforced by the type system)
479/// - `floor(log_2(a)) + 1 <= log_2_bound`
480/// - `floor(log_2(|b|)) + 1 <= log_2_bound`
481/// - `ceil(log_2_bound / Limb::BITS) <= <L as AsRef::<[Limb]>>::as_ref(&a).len()`
482/// - `<L as AsRef::<[Limb]>>::as_ref(&a).len() <= <L as AsRef::<[Limb]>>::as_ref(&b.1).len()`
483/// - `<L as AsRef::<[Limb]>>::as_ref(&a).len() == <L as AsRef::<[Limb]>>::as_ref(&c).len()`
484///
485/// Yield an equivalent form `(a', b', c')` such that:
486/// - `(a', b', c')` is reduced or $|b| < 2^{upper_bound}$.
487/// - `(a', b', c')` is reduced or `b' > a'`
488///
489/// `b.0, b'.0` are `true` if the value is _positive_.
490#[inline(always)]
491pub(crate) fn reduce_to_upper_bound<L: Limbs>(
492  log_2_bound: u32,
493  mut a: L,
494  mut b: (Choice, L),
495  mut c: L,
496  upper_bound: u32,
497) -> (L, (Choice, L), L) {
498  #[cfg(debug_assertions)]
499  {
500    debug_assert!(bool::from(a.bits().ct_lt(&log_2_bound) | a.bits().ct_eq(&log_2_bound)));
501    debug_assert!(bool::from(b.1.bits().ct_lt(&log_2_bound) | b.1.bits().ct_eq(&log_2_bound)));
502    debug_assert!(
503      usize::try_from(log_2_bound.div_ceil(Limb::BITS)).unwrap() <=
504        <_ as AsRef::<[Limb]>>::as_ref(&a).len()
505    );
506    debug_assert!(
507      <_ as AsRef::<[Limb]>>::as_ref(&a).len() <= <_ as AsRef::<[Limb]>>::as_ref(&b.1).len()
508    );
509  }
510
511  let original_limbs = usize::try_from(log_2_bound.div_ceil(Limb::BITS)).unwrap();
512
513  /*
514    Iterate from our bound on `b` to a `b'` which by bit-length, would satisfy `upper_bound`.
515    Each iteration will reduce the bit length of `b` by at least `1`, until `b <= a` and it is
516    reduced (if given sufficient iterations to reach that point).
517  */
518  {
519    let (b_sign, mut b_value) =
520      (&mut b.0, UintRef::new_mut(&mut <_ as AsMut<[Limb]>>::as_mut(&mut b.1)[.. original_limbs]));
521    let mut b_lte_a = Choice::FALSE;
522    let mut b_needs_negation = Choice::FALSE;
523
524    let mut limbs = original_limbs;
525
526    // `RangeInclusive` doesn't implement `FixedSizeIterator`, so we use a `Range` instead
527    #[expect(clippy::range_plus_one)]
528    let mut bits = ((upper_bound + 1) .. (log_2_bound + 1)).rev();
529
530    let mut a_bits = a.bits();
531    let mut c_bits = c.bits();
532
533    /*
534      Handle the partial limb we may inherently have by the bound not necessarily perfectly
535      aligning to limbs, and two more bits.
536
537      `reduce_to_next_bit` is documented to need limbs corresponding to one extra bit, which is
538      as `floor(log_2(|b|)) + 1 == floor(log_2(a)) + 1` is a possible input and the function must
539      then calculate `2 m a`.
540
541      We provide one additional bit here as for a value `|b| <= a`, this will only be noticed on
542      the iteration _after_ the condition becomes true, so we need to defer when we move to the
543      smaller amount of limbs until after this later iteration.
544    */
545    {
546      let progress_in_partial_limb = usize::try_from(log_2_bound % Limb::BITS).unwrap();
547      for bits in (&mut bits).take(2 + progress_in_partial_limb) {
548        approximate_a_lte_c((&mut a_bits, &mut a), b_sign, (&mut c_bits, &mut c));
549        let should_reduce = should_reduce_to_next_bit_except_final(
550          (b_sign, b_value),
551          b_needs_negation,
552          &mut b_lte_a,
553          a_bits,
554          bits,
555        );
556        reduce_to_next_bit(
557          &a,
558          (b_sign, b_value),
559          &mut b_needs_negation,
560          &mut c,
561          should_reduce,
562          limbs,
563          a_bits,
564          bits,
565        );
566        debug_assert!(bool::from(
567          b_needs_negation | (!should_reduce) | b_value.bits().ct_lt(&bits)
568        ));
569        c_bits = c.bits();
570      }
571
572      // Negate `b` if necessary, before crossing the limb boundary
573      negate_b((b_sign, b_value), b_needs_negation);
574      b_needs_negation = Choice::FALSE;
575
576      // Only decrement the amount of `limbs` if we did actually have a partial limb
577      if progress_in_partial_limb != 0 {
578        limbs -= 1;
579        b_value = b_value.leading_mut(limbs);
580      }
581    }
582
583    /*
584      Handle each remaining limb.
585
586      While we could use a single loop for both the partial limb and the full limbs, that would
587      have structure approximate to:
588
589      ```
590      for bit {
591        reduce_to_next_bit();
592        if limb {
593          limbs -= 1;
594        }
595      }
596      ```
597
598      and place a branch within every single loop body. This achieves a straight-line, other than
599      the loops' conditionals themselves (which the compiler appears to handle better, possibly as
600      we may use the constant `Limb::BITS` for how many steps this inner loop takes).
601
602      `bits.len() != 0` is used as `bits.is_empty()` (`FixedSizeIterator::is_empty`) is
603      experimental.
604    */
605    while bits.len() != 0 {
606      debug_assert_ne!(limbs, 0);
607
608      #[expect(clippy::as_conversions)]
609      for bits in (&mut bits).take(const { Limb::BITS as usize }) {
610        approximate_a_lte_c((&mut a_bits, &mut a), b_sign, (&mut c_bits, &mut c));
611        let should_reduce = should_reduce_to_next_bit_except_final(
612          (b_sign, b_value),
613          b_needs_negation,
614          &mut b_lte_a,
615          a_bits,
616          bits,
617        );
618        reduce_to_next_bit(
619          &a,
620          (b_sign, b_value),
621          &mut b_needs_negation,
622          &mut c,
623          should_reduce,
624          limbs,
625          a_bits,
626          bits,
627        );
628        debug_assert!(bool::from(
629          b_needs_negation | (!should_reduce) | b_value.bits().ct_lt(&bits)
630        ));
631        c_bits = c.bits();
632      }
633
634      negate_b((b_sign, b_value), b_needs_negation);
635      b_needs_negation = Choice::FALSE;
636
637      limbs -= 1;
638      b_value = b_value.leading_mut(limbs);
639    }
640
641    /*
642      We apply the final reduction with a full width as we don't know when the above iterations
643      stopped, nor how far the number has been truncated since.
644    */
645    {
646      let (b_sign, b_value) = (
647        &mut b.0,
648        UintRef::new_mut(&mut <_ as AsMut<[Limb]>>::as_mut(&mut b.1)[.. original_limbs]),
649      );
650
651      a_lte_c(&mut a, b_sign, &mut c);
652      let a_bits = a.bits();
653      let b_bits = b_value.bits();
654      let should_reduce = should_reduce_to_next_bit_final(&a, (b_sign, b_value));
655      reduce_to_next_bit(
656        &a,
657        (b_sign, b_value),
658        &mut b_needs_negation,
659        &mut c,
660        should_reduce,
661        original_limbs,
662        a_bits,
663        b_bits,
664      );
665
666      negate_b((b_sign, b_value), b_needs_negation);
667    }
668  }
669
670  (a, b, c)
671}
672
673/*
674  We wish to prove that for `(a, b, c)` input to the reduction algorithm, the output `(a', b', c')`
675  satisfies `gcd(a, b, c) = gcd(a', b', c')`.
676
677  The reduction algorithm solely repeatedly performs one of the following two actions:
678  `(a, b, c)` -> `(c, -b, a)`
679  `(a, b, c)` -> `(a, b - 2ma, c - m|b| + m^2 a)`
680
681  It is immediate that `gcd(a, b, c) = gcd(c, -b, a)` as `0 <= a, c`.
682
683  For the second action, we require an identity (which we state and assume but do not prove here):
684  - `gcd(x + z * y, y) = gcd(x, y)` for any integer `z` (positive or negative), which we refer to
685    as the modular identity due to its corollary `gcd(x % y, y) = gcd(x, y)`
686  and the following definition of a three-argument GCD call:
687  - `gcd(x, y, z) = gcd(gcd(x, y), z)`
688
689  Via the modular identity, `gcd(a, b) = gcd(a, b - 2ma)` is immediate. In order to now prove
690  `gcd(a, b, c) = gcd(a, b - 2ma, c - m|b| + m^2 a)`, we rewrite the right-hand side using our
691  definition of a three-argument GCD call as:
692
693    `gcd(gcd(a, b), c - m|b| + m^2 a)`
694
695  (simplifying `gcd(a, b - 2ma)` to just `gcd(a, b)`, as we've proven them equivalent)
696
697  The second argument to the outer-GCD call expands as
698  `c - m(|b| / gcd(a, b)) gcd(a, b) + m^2 (a / gcd(a, b)) gcd(a, b)`, and is able to be rewritten
699  as `c - (m(|b| / gcd(a, b)) - m^2 (a / gcd(a, b))) gcd(a, b)`, from which it's clear the modular
700  identity proves our desired result as when `z = m(|b| / gcd(a, b)) - m^2 (a / gcd(a, b))`, we
701  have:
702
703    `gcd(gcd(a, b), c - z gcd(a, b))`
704
705  Accordingly, for an element `(a, b, c)` input to the reduction algorithm, the output
706  `(a', b', c')` satisfies `gcd(a, b, c) = gcd(a', b', c')`.
707*/
708
709/// Partially reduce a positive definite binary quadratic form.
710///
711/// For a positive definite binary quadratic form `(a, b, c)` such that:
712/// - `b^2 - 4ac = delta` where `delta < 0` (the form is well-defined for a negative discriminant)
713/// - $delta \cong 1 \mod 2$
714/// - `0 <= a` (`a` isn't negative, as enforced by the type system)
715/// - `floor(log_2(a)) + 1 <= log_2_bound`
716/// - `floor(log_2(|b|)) + 1 <= log_2_bound`
717/// - There is an integer solution for `c` in `b^2 - 4 a c = delta`.
718/// - `ceil(log_2_bound / Limb::BITS) <= <L as AsRef::<[Limb]>>::as_ref(&a).len()`
719/// - `<L as AsRef::<[Limb]>>::as_ref(&a).len() == <L as AsRef::<[Limb]>>::as_ref(&b.1).len()`
720/// - `<L as AsRef::<[Limb]>>::as_ref(&negative_discriminant_abs).len() <=
721///      2 * <L as AsRef::<[Limb]>>::as_ref(&b.1).len()`
722/// - $floor(log_2(|delta|)) + 1 < <_ as AsRef<[Limb]>>::as_ref(a).len() * Limb::BITS$
723/// - $floor(log_2(a)) + 1 < <_ as AsRef<[Limb]>>::as_ref(a).len() * Limb::BITS$
724///
725/// Yield an equivalent form `(a', b', c')` such that:
726/// - `b'^2 <= |delta|`
727/// - `(a', b', c')` is reduced or `b' > a'`
728/// - `gcd(a, b, c) = gcd(a', b', c')`
729///
730/// As composition is presumably programmed to compose `b`-bit-length numbers, where composition
731/// outputs `2 * b`-bit-length numbers, this function intends to solely perform the necessary
732/// reduction such that the numbers are once again of `b`-bit-length (and able to be composed
733/// again). While these forms are not reduced, they may still usable for composition _without_
734/// performing a full reduction (which would take roughly twice as long). This allows deferring a
735/// full reduction until one _needs_ a reduced form.
736///
737/// This second bound on the output, `(a', b', c')` is reduced or `b' > a'`, is critical as it
738/// enables the following corollary: `a'^2 < |delta|`.
739///
740/// `b.0, b'.0` are `true` if the value is _positive_.
741///
742/// `delta` is bound to be negative and specified via its absolute value in
743/// `negative_discriminant_abs`.
744#[expect(private_bounds)]
745#[inline(always)]
746pub(crate) fn partial_reduce<L: super::c::Limbs + Limbs>(
747  log_2_bound: u32,
748  a: L,
749  mut b: (Choice, L),
750  negative_discriminant_abs: &L,
751) -> (L, (Choice, L), L) {
752  let discriminant_bits = negative_discriminant_abs.bits_vartime();
753  let sqrt_discriminant_bits = discriminant_bits.div_ceil(2);
754
755  #[cfg(debug_assertions)]
756  {
757    debug_assert!(
758      negative_discriminant_abs.bits_vartime() <
759        (u32::try_from(a.as_ref().len()).unwrap() * Limb::BITS)
760    );
761    debug_assert!(bool::from(
762      a.bits().ct_lt(&(u32::try_from(a.as_ref().len()).unwrap() * Limb::BITS))
763    ));
764    debug_assert_eq!(
765      <L as AsRef::<[Limb]>>::as_ref(&a).len(),
766      <L as AsRef::<[Limb]>>::as_ref(&b.1).len()
767    );
768  }
769
770  b.1 = {
771    /*
772      This is safe as `a` is the same bit-length as `|delta|`, at most, and `|delta|` has a spare
773      bit of capacity. `a` is bounded to be in a container of size equal to `b.1`, and `|delta|`
774      is bound to be in a container of size less than or equal in size.
775
776      TODO: `clone` :/
777    */
778    let mut two_a = a.clone();
779    UintRef::new_mut(two_a.as_mut()).shl1_assign();
780
781    // Ensure $|b| < 2a$, as required to calculate `c`
782    L::rem(b.1, &two_a)
783  };
784
785  let c = super::c(&a, &b, negative_discriminant_abs);
786  let (mut a, mut b, mut c) =
787    reduce_to_upper_bound(log_2_bound, a, b, c, sqrt_discriminant_bits - 1);
788
789  // This is needed to ensure our second bound, "`(a', b', c')` is reduced or `b' > a'`"
790  a_lte_c(&mut a, &mut b.0, &mut c);
791
792  #[cfg(debug_assertions)]
793  {
794    debug_assert!(bool::from(
795      a.bits().ct_lt(&discriminant_bits.div_ceil(2)) |
796        a.bits().ct_eq(&discriminant_bits.div_ceil(2))
797    ));
798    debug_assert!(bool::from(
799      b.1.bits().ct_lt(&discriminant_bits.div_ceil(2)) |
800        b.1.bits().ct_eq(&discriminant_bits.div_ceil(2))
801    ));
802  }
803
804  (a, b, c)
805}
806
807/// Reduce an element.
808///
809/// For a positive definite binary quadratic form `(a, b, c)` such that:
810/// - `b^2 - 4ac = delta` where `delta < 0` (the form is well-defined for a negative discriminant)
811/// - $delta \cong 1 \mod 2$
812/// - `0 <= a, c` (`a, c` aren't negative, as enforced by the type system)
813/// - `floor(log_2(a)) + 1 <= log_2_bound`
814/// - `floor(log_2(|b|)) + 1 <= log_2_bound`
815/// - `ceil(log_2_bound / Limb::BITS) <= <L as AsRef::<[Limb]>>::as_ref(&a).len()`
816/// - `<L as AsRef::<[Limb]>>::as_ref(&a).len() <= <L as AsRef::<[Limb]>>::as_ref(&b.1).len()`
817/// - `<L as AsRef::<[Limb]>>::as_ref(&a).len() == <L as AsRef::<[Limb]>>::as_ref(&c).len()`
818///
819/// Yield the reduced equivalent form `(a', b', c')` such that:
820/// - `|b'| <= a' <= c'`
821/// - `b' >= 0` if `(|b'| == a') || (a' == c')`
822/// - `gcd(a, b, c) = gcd(a', b', c')`
823///
824/// `b.0, b'.0` are `true` if the value is _positive_.
825///
826/// `delta` is bound to be negative and specified via its absolute value in
827/// `negative_discriminant_abs`.
828#[inline(always)]
829pub(crate) fn reduce<L: Limbs>(
830  log_2_bound: u32,
831  a: L,
832  b: (Choice, L),
833  c: L,
834) -> (L, (Choice, L), L) {
835  debug_assert_eq!(
836    <L as AsRef::<[Limb]>>::as_ref(&a).len(),
837    <L as AsRef::<[Limb]>>::as_ref(&c).len()
838  );
839
840  let (mut a, mut b, mut c) = reduce_to_upper_bound(log_2_bound, a, b, c, 0);
841
842  a_lte_c(&mut a, &mut b.0, &mut c);
843  let (a, b, c) = normalize(a, b, c);
844
845  #[cfg(debug_assertions)]
846  {
847    let a = UintRef::new(AsRef::<[Limb]>::as_ref(&a));
848    let b_abs = UintRef::new(AsRef::<[Limb]>::as_ref(&b.1));
849    let c = UintRef::new(AsRef::<[Limb]>::as_ref(&c));
850    debug_assert!(bool::from(b_abs.ct_lt(a) | b_abs.ct_eq(&a)));
851    debug_assert!(bool::from(a.ct_lt(c) | a.ct_eq(&c)));
852    let b_eq_a_or_a_eq_c = a.ct_eq(&b_abs) | a.ct_eq(&c);
853    debug_assert!(bool::from((!b_eq_a_or_a_eq_c) | b.0));
854  }
855
856  (a, b, c)
857}