class-groups 0.0.2-alpha

A cryptographic library for working with binary quadratic forms (elements of a class group)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
//! A constant-time reduction algorithm.
//!
//! This is derived from Algorithm 1 of <https://eprint.iacr.org/2022/466>. Some typos in the paper
//! have been accounted for. Algorithm 2 is an iterative restatement of Algorithm 1 and accordingly
//! may more closely resemble the structure of the following, yet this work was independently
//! derived from Algorithm 1.
//!
//! In order to be efficient, this does not always operate over the full number of limbs of each
//! value. Instead, as the reduction occurs (and as the values shrink), the number of limbs
//! operated over also decreases. This requires extensively arguing the bounds for the inputs and
//! outputs of each function, which is done via brief proofs written in comments within each
//! function.

#![expect(clippy::needless_pass_by_value)] // Triggered by `(&mut Choice, &mut L)`
#![expect(clippy::inline_always)]

use crypto_bigint::{Choice, CtEq, CtSelect, CtLt as _, CtGt as _, Zero, BitOps, Limb, UintRef};

/// A collection of limbs and associated helper methods.
///
/// The provided reduction algorithm dances along the `Limb` boundaries, performance requiring
/// correct decision of when to terminate execution of a given function. This API unifies `Uint`
/// and `BoxedUint` (in a way `Integer` appeared ineligible for) while providing the niche methods
/// required for performance.
///
/// Implementations MAY iterate up to the `limbs` argument (for performance) or MAY ignore it.
/// Callers MUST NOT expect that if they specify a `limbs` argument, operations will only occur to
/// that subset of limbs, and any results are undefined when any non-included limbs are non-zero.
/// Callers MUST NOT specify more `limbs` than the value has.
///
/// Implementations MUST implement all functions in time constant to the value of the inputs,
/// except for the amount of limbs, unless otherwise stated. Implementations MUST NOT panic for any
/// input which the caller MAY pass.
///
/// A long-term goal is to replace this entirely for just `UintRef`. Currently, the reduction
/// algorithm still requires allocating one scratch variable however, making this non-immediate.
pub(crate) trait Limbs: AsRef<[Limb]> + AsMut<[Limb]> + CtEq + Zero + BitOps {
  /// The number but with precision equal to `self`.
  ///
  /// This is equivalent to [`crypto_bigint::Zero::zero_like`] but avoids requiring `Self: Clone`.
  /// We do not want to bound `Self: Clone` for performance reasons. Specifically, it's a goal of
  /// the reduction algorithm to not allocate at all (for performance reasons), this one function
  /// being necessary in one spot and the current sole exception.
  fn like_zero(&self) -> Self;

  /// Swap the values of `self` and `b` if `choice` is `true`.
  ///
  /// This is a basic helper as there is no `UintRef::ct_swap`.
  #[inline(always)]
  fn swap(&mut self, b: &mut Self, choice: Choice) {
    let a = &mut <_ as AsMut<[Limb]>>::as_mut(self);
    let b = &mut <_ as AsMut<[Limb]>>::as_mut(b);
    for (a, b) in a.iter_mut().zip(b.iter_mut()) {
      <_>::ct_swap(a, b, choice);
    }
  }
}

/// Swap `a` and `c` where needed, yielding an equivalent form `(a', b', c')` with
/// `floor(log_2(a')) <= floor(log_2(c'))`.
///
/// `a` and `c` are each passed as `(bit length, value)`. The bit lengths are what get compared,
/// and they are swapped alongside the values. Both values are assumed to fit within either
/// variable.
///
/// This corresponds to step 2 of Algorithm 1. However, it only ensures `a'` has a bit length no
/// greater than `c'`'s, rather than `a' <= c'`. That comparison is far cheaper, and `b` can still
/// shed at least one bit per iteration for as long as it has a greater logarithm than `a`. Once
/// their logarithms are equal, only the final iteration remains, and that iteration MUST use
/// `a_lte_c` instead.
#[inline(always)]
fn approximate_a_lte_c<L: Limbs>(
  a: (&mut u32, &mut L),
  b_sign: &mut Choice,
  c: (&mut u32, &mut L),
) {
  let (a_bits, a_value) = a;
  let (c_bits, c_value) = c;

  let should_swap = c_bits.ct_lt(a_bits);
  L::swap(a_value, c_value, should_swap);
  c_bits.ct_swap(a_bits, should_swap);

  /*
    Negating `b` here differs from the paper, whose stated algorithm contains some typos. This is
    further evidenced by its correctness proof, where line 6 appears both as
    "[C - epsilon m B + m^2 A]" and as "[C, - epsilon m B + A^2]".

    Swapping `a, c` maps the form `(a, b, c)` to the equivalent `(c, -b, a)`. So whenever a swap
    was performed, `b` MUST be negated as well.
  */
  *b_sign ^= should_swap;
}

/// Obtain the equivalent form `(a', b', c')` for which `a' <= c'`.
///
/// This assumes `a` and `c` have the same number of limbs, as `reduce_to_upper_bound` requires
/// of its inputs. The comparison covers all of `c`'s limbs against `a` sliced to that same length:
/// - If `a` had fewer limbs than `c`, the slice would panic.
/// - If `a` had more limbs, its upper limbs would be silently ignored.
/// - The conditional swap likewise only exchanges the limbs both values have.
///
/// This corresponds to step 2 of Algorithm 1.
#[inline(always)]
fn a_lte_c<L: Limbs>(a: &mut L, b_sign: &mut Choice, c: &mut L) {
  let c_limbs = <L as AsRef<[Limb]>>::as_ref(c);
  let a_limbs = <L as AsRef<[Limb]>>::as_ref(a);
  debug_assert_eq!(
    a_limbs.len(),
    c_limbs.len(),
    "`a_lte_c` requires `a` and `c` to have the same number of limbs"
  );
  let c_lt_a = UintRef::new(c_limbs).ct_lt(UintRef::new(&a_limbs[.. c_limbs.len()]));
  L::swap(a, c, c_lt_a);
  // As `(a, b, c)` is equivalent to `(c, -b, a)`, `b` MUST be negated if we swapped
  *b_sign ^= c_lt_a;
}

/// Whether `reduce_to_next_bit` should actually apply, except possibly incorrect for the final
/// iteration.
///
/// This MAY yield `false` even when one final iteration should be run, hence `except_final`. It is
/// correct for all but the final iteration.
///
/// This simplifies `|b| > a` to `floor(log_2(|b|)) > floor(log_2(a))`, which is true whenever
/// `|b| >= 2a`. See `should_reduce_to_next_bit_final` for why this is sufficient for all but the
/// last iteration.
///
/// This is optimized by only indicating the reduction algorithm should be run when
/// `(floor(log_2(|b|)) + 1) == b_bits_bound` _not_ always if `|b| >= 2a`. This is done as:
///
/// 1) It is still correct. `reduce_to_next_bit`, if called correctly, must be called from the
///    current bound to the minimal bound, the current bound decrementing by one bit with each call,
///    as `reduce_to_next_bit` is only guaranteed to reduce `|b|` by a single bit (until `b` is
///    reduced). Assuming the bound is properly decremented with each call, then we know `|b|` is
///    within the bound for each call, as the iteration is either unnecessary (the current bound
///    exceeding the actual `floor(log_2(|b|)) + 1`) or will be reduced by at least one bit (and
///    therefore within the next iteration's bound) if not already reduced.
///
/// 2) It is faster to check if `(floor(log_2(|b|)) + 1) == b_bits_bound` than to calculate
///    `floor(log_2(|b|))`, even with how cheap that operation is. Finding the leading bit is an
///    operation of linear complexity, while checking if the highest bit within the bound is set is
///    of constant complexity (assuming the bound is public, allowing us to perform the retrieval
///    with a variable memory-access pattern). This optimization decreased the time of point
///    doubling by ~3.5%, implying this to optimize ~5% of reduction.
///
/// 3) If we always did the iterations which need to happen at the start, then being off-by-one
///    would be quite hard to detect as the last iterations are unlikely to actually be necessary.
///    This means `|b|` is likely below the bound for the entire duration of the algorithm, and the
///    bound being inaccurate may not be noticed. As this methodology intersperses the necessary
///    iterations with the unnecessary, `|b|` is worked with at the bound itself, which more
///    aggressively requires the accuracy of the bounds. This itself helps to ensure the bounds are
///    accurate.
///
/// This function assumes $|delta| \cong 1 \mod 2$, `a_bits = floor(log_2(a)) + 1`, and
/// `b_bits_bound >= floor(log_2(|b|)) + 1` when `b_lte_a == false`. This function will update
/// `b_lte_a`, but may inaccurately set it to `true` upon reaching the final necessary iteration.
/// This function also assumes `b_bits_bound > 0` and may be incorrect or panic in that case.
///
/// Only the magnitude `b.1` is inspected; the sign `b.0` is not read.
#[inline(always)]
fn should_reduce_to_next_bit_except_final(
  b: (&mut Choice, &mut UintRef),
  b_needs_negation: Choice,
  b_lte_a: &mut Choice,
  a_bits: u32,
  b_bits_bound: u32,
) -> Choice {
  // This check is only valid while `b_lte_a == false`, so short-circuit (yielding `false`) once
  // `b_lte_a == true`
  let b_gt_a = (!*b_lte_a) & b_bits_bound.ct_gt(&a_bits);

  /*
    Update `b_lte_a`.

    If `b_lte_a` is set, this will never unset it, as `b_gt_a` won't be set if `b_lte_a` was.
  */
  *b_lte_a = !b_gt_a;

  /*
    Only run this iteration if this specific bit of `b` is in fact set.

    If `b` needs to be negated, we check if the bit is _not_ set, which is immediately an
    _incorrect optimization_. The negation process is defined as the logical NOT and a carrying
    addition of `1`.

    As an immediate counterexample, `1` has negative representation `0xFF`, which _should_ be
    considered as having its trailing bit set (as when negated, it's `1`) but would not be
    considered so here.

    We bound `|delta|` to be odd, so as `b^2 - 4ac = delta`, taking the expression modulo `4`, it's
    obvious `b` must be odd for any valid form (even if unreduced). This means we know the addition
    of `1` will always set the trailing bit, making this optimization correct _for all but the
    trailing bit_.

    That makes this correct _except_ when `b_bits_bound = 1`, but any form with a 1-bit `b`
    coefficient is already reduced (potentially after normalizing its sign). Accordingly, the fact
    we incorrectly consider `-1` as not having its trailing bit set, still leads to a correct
    result that no further reduction should occur.
  */
  b_gt_a &
    Choice::from(u8::from(b.1.bit_vartime(b_bits_bound.wrapping_sub(1))))
      .ct_eq(&!b_needs_negation)
}

/// Whether `reduce_to_next_bit` should apply, determined exactly as `|b| > a`.
///
/// An exact test is only needed once `floor(log_2(|b|)) == floor(log_2(a))`, at which point
/// `|b| <= 2 a`. The paper bounds the remaining work in this case at two more iterations, but one
/// suffices. If `|b| == 2a`, the next step yields `b' = 0`. If `a < |b| < 2a`, the next step yields
/// `|b'| < a`.
///
/// So while this test is correct at any point, it is meant for that single final iteration.
///
/// `a` must have at least as many limbs as `b.1`, since it is sliced to `b.1`'s length, and any
/// limbs of `a` beyond that length are assumed zero. The sign `b.0` is not read.
///
/// The rationale shared with `should_reduce_to_next_bit_except_final` is not repeated here.
#[inline(always)]
fn should_reduce_to_next_bit_final<L: Limbs>(a: &L, b: (&Choice, &UintRef)) -> Choice {
  let (_b_sign, b_abs) = b;
  let b_limbs = b_abs.as_limbs();
  let a_limbs = &<_ as AsRef<[Limb]>>::as_ref(a)[.. b_limbs.len()];

  // Evaluate `a - |b|`, discarding the difference and keeping only the running borrow. A borrow
  // out of the most significant limb means `|b| > a`.
  let final_borrow = a_limbs
    .iter()
    .zip(b_limbs)
    .fold(Limb::ZERO, |borrow, (a_limb, b_limb)| a_limb.borrowing_sub(*b_limb, borrow).1);

  !final_borrow.is_zero()
}

/// Reduce the `b` coefficient by at least one bit or until reduced.
///
/// For a positive definite binary quadratic form `(a, b, c)` such that:
/// - `b^2 - 4ac = delta` where `delta < 0` (the form is well-defined for a negative discriminant)
/// - $delta \cong 1 \mod 2$
/// - `0 <= a, c` (`a` and `c` aren't negative, as enforced by the type system)
/// - `floor(log_2(a)) <= floor(log_2(c))` (such as forms output by `(approximate_)a_lte_c`)
/// - `limbs <= <_ as AsRef<[Limb]>>(a).len()`
/// - `limbs <= <_ as AsRef<[Limb]>>(b.1).len()`
/// - The variable `c` does in fact contain the full representation of the `c` coefficient.
///
/// We require the following bounds when `should_reduce == true`:
/// - `a < |b|`
/// - `a_bits = floor(log_2(a)) + 1`
/// - `b_bits = floor(log_2(|b|)) + 1`
/// - `1 + b_bits <= (limbs * Limb::BITS)`
///
/// Yield an equivalent form `(a', b', c')` such that:
/// - `a' = a`
/// - `floor(log_2(|b'|)) <= floor(log_2(|b|)) - 1` if `|b| > 2 a` and `should_reduce == true`
/// - `|b'| <= a'` if `a < |b| <= 2 a` and `should_reduce == true`
/// - `(a', b', c') = (a, b, c)` if `should_reduce == false` (though any pending negation of `b`,
///   per `b_needs_negation`, is still applied to its representation)
///
/// `b_needs_negation` is the deferred-negation flag.
/// - On input, if set, `b.1` holds the two's-complement negation of `b`'s actual magnitude. This
///   call applies that negation (flipping `b.0`) within its main loop.
/// - On output, it is set if computing `b'` underflowed (`2 m a > |b|`), meaning `b.1` again holds
///   a negated magnitude.
///
/// The caller must resolve the flag (such as via `negate_b`) before changing `b`'s limb boundaries.
///
/// This corresponds to steps 3, 4, and 6 of Algorithm 1, except as a NOP if
/// `should_reduce == false` (`|b| <= a`) (in which case the form is reduced, or reduced after
/// normalizing the sign of `b`).
#[expect(clippy::too_many_arguments)]
#[inline(always)]
fn reduce_to_next_bit<L: Limbs>(
  a: &L,
  b: (&mut Choice, &mut UintRef),
  b_needs_negation: &mut Choice,
  c: &mut L,
  should_reduce: Choice,
  limbs: usize,
  a_bits: u32,
  b_bits: u32,
) {
  #[cfg(debug_assertions)]
  {
    debug_assert!(bool::from(a.bits().ct_lt(&c.bits()) | a.bits().ct_eq(&c.bits())));
    debug_assert!(limbs <= <_ as AsRef::<[Limb]>>::as_ref(a).len());
    debug_assert!(limbs <= <_ as AsRef::<[Limb]>>::as_ref(&b.1).len());
    debug_assert!(bool::from((!should_reduce) | UintRef::new(a.as_ref()).ct_lt(b.1)));
    debug_assert!(bool::from((!should_reduce) | a_bits.ct_eq(&a.bits())));
    debug_assert!(bool::from((*b_needs_negation) | (!should_reduce) | b_bits.ct_eq(&b.1.bits())));
    debug_assert!(bool::from((!should_reduce) | b_bits.ct_lt(&b.1.bits_precision())));
  }

  // Calculate `m` (the body of step 3's branch, step 4). `m` is always a power of two, so only
  // `log_2(m)` is computed and multiplication by `m` is performed as a left shift.
  let log_2_m = {
    // This is only well-defined if `a_bits < b_bits`
    let log_2_m = b_bits.wrapping_sub(a_bits).wrapping_sub(1);
    // Set `log_2_m = 0` (so `m = 1`) if `a` and `b` have equal bit lengths, or if `should_reduce`
    // is `false` (where the above subtraction may not be well-defined)
    <_ as CtSelect>::ct_select(&0, &log_2_m, (!a_bits.ct_eq(&b_bits)) & should_reduce)
  };

  // Step 6

  /*
    This is a container the size of `c`, as we later operate on it under the bound that the value
    derived from it is `<= c`, which means this has to be large enough to contain a number `<= c`.

    TODO: Can we remove the requirement for this scratch variable? Presumably not, as we have a
    constant-time shift of unbounded bit-length (where the shift may exceed a limb), so this can't
    trivially be done as we iterate over limbs. We then need this to calculate `b` before we again
    use it to calculate `c`, so we can't write it directly into one of those. We could directly
    modify `a`, or at least, require we be passed in a copy of `a` we then use as scratch, but that
    really just defers allocating this scratch variable to the caller.

    This is currently the only explicit `clone` (or equivalent) in this entire file.
  */
  let mut m_a = c.like_zero();
  // When `should_reduce = true`, `((1 << log_2_m) * a) < b`, so this will fit in `limbs` limbs
  {
    let m_a = UintRef::new_mut(&mut <_ as AsMut<[Limb]>>::as_mut(&mut m_a)[.. limbs]);
    m_a.copy_from_slice(&a.as_ref()[.. limbs]);
    m_a.shl_assign(log_2_m);
  }
  let m_a = UintRef::new_mut(<_ as AsMut<[Limb]>>::as_mut(&mut m_a));

  /*
    The following does _not_ swap `c, a`, as we always perform any necessary swap during the next
    iteration's step 2 regardless. This means our `c` is updated to the paper's output `a`, and our
    `a` is left as-is.

    Note that in terms of this original paper which outputs `(a, b, c)`, this outputs `(c, b, a)`,
    which is not an equivalent form. The equivalent form would be `(c, -b, a)`. The paper is
    missing a negation on the output of its form, and once that's considered, this is equivalent.
  */

  // $\epsilon b == |b|$ since $\epsilon = \mathsf{sgn}(b)$
  /*
    Instead of calculating `- epsilon m b + m m a`, we calculate `m (-|b| + m a)` to reduce the
    bit-length of the addition within the parentheses. As `m a < b` when `a < b`, the evaluation
    of the parentheses is negative and has an absolute value `< |b|` (which fits in `limbs` limbs)
    whenever `should_reduce == true`.

    When `should_reduce == false`, `b_diff_m_a` is set to `0` so we may unconditionally calculate
    the new `c` coefficient as `c - m b_diff_m_a`.

    We simultaneously calculate $|b| - m a$ and $b - \epsilon 2 m a$ as we can merge their loops.
    While the resulting code is of non-trivially greater complexity, it saves ~8% of the time to
    execute.

    Note $b - \epsilon 2 m a$ may underflow, causing `b'` to have a distinct sign from `b`. In
    order to ensure the variable `b.1` remains the absolute value, this would require the logical
    NOT operator combined with a carrying addition of `1` _after_ calculating $b - \epsilon 2 m a$.
    To avoid another loop, we instead defer performing the negation to the next iteration's
    instance of _this_ loop, reducing the amount of times we introduce flow control/branching, via
    the `b_needs_negation` variable. Note the caller must handle `b_needs_negation` when shifting
    `b`'s limb boundaries.
  */
  {
    // Apply any negation deferred from the previous call: flip the sign now, and fold the
    // magnitude's NOT (`b_negation_mask`) and `+ 1` (`b_negation_carry`) into the loop below
    *b.0 ^= *b_needs_negation;
    let b_negation_carry = Limb::from(u8::from(*b_needs_negation));
    let b_negation_mask = Limb::ZERO.wrapping_sub(b_negation_carry);

    let mut b_diff_m_a_carry = Limb::ZERO;

    let mut two_m_a_carry = Limb::ZERO;
    /*
      We express `b - 2 m a` as `b + -(2 m a)`, where the negation requires a carry of `1`
      (hence why this is initialized to `1` when `should_reduce == true`). We simultaneously
      apply the deferred negation to `b`, hence why we also sum `b_negation_carry`.
    */
    let mut b_diff_two_m_a_carry =
      Limb::from(u8::from(should_reduce)).wrapping_add(b_negation_carry);

    for (b_limb, m_a_limb) in b.1.iter_mut().zip(m_a.iter_mut()) {
      /*
        Set `b' = b - 2 m a`, while simultaneously negating `b_limb` (if necessary).

        Negating `b'` is expressed as the logical NOT combined with a carrying addition of `1`.
        This is incompatible with needing to perform a borrowing subtraction of `2 m a`. Instead,
        we rewrite it as `b + -(2 m a)`, where `2 m a`'s negation can be expressed with a
        carrying addition. This means all three aspects (negating `b` if necessary, negating
        `2 m a`, and summing `b, -2 m a`) can be so expressed and done simultaneously.
      */
      {
        // Doubling `m a` is a one-bit left shift, carrying the top bit into the next limb
        let two_m_a_limb: Limb = ((*m_a_limb) << 1) | two_m_a_carry;
        two_m_a_carry = (*m_a_limb) >> const { Limb::BITS - 1 };

        /*
          `floor(log_2(|b|)) = floor(log_2(2 m a))`, so their difference `|b'|` has the property
          `floor(log_2(|b'|)) < floor(log_2(|b|))`. Therefore, `|b'|` will fit in any container
          which fits `|b|`.
        */
        let new_b_limb;
        (new_b_limb, b_diff_two_m_a_carry) = ((*b_limb) ^ b_negation_mask).carrying_add(
          Limb::ct_select(&Limb::ZERO, &!two_m_a_limb, should_reduce),
          b_diff_two_m_a_carry,
        );
        *b_limb = new_b_limb;
      }

      /*
        Calculate `b_diff_m_a` as `b' + m a` (where `b' = b - 2 m a`).

        This writes `b_diff_m_a` directly into `m_a`, as we have no further use for `m_a`.
      */
      {
        let new_b_diff_m_a_limb;
        (new_b_diff_m_a_limb, b_diff_m_a_carry) =
          (*b_limb).carrying_add(*m_a_limb, b_diff_m_a_carry);
        *m_a_limb = Limb::ct_select(&Limb::ZERO, &new_b_diff_m_a_limb, should_reduce);
      }
    }

    /*
      Finish calculating `b'`, handling if `b < 2 m a`.

      Because we expressed `b'` as `b + -(2 m a)`, there is _no_ carry if `2 m a > b`. This means
      `b'` needs negation if there was _no_ carry.
    */
    *b_needs_negation = should_reduce & b_diff_two_m_a_carry.ct_eq(&Limb::ZERO);
  }
  let b_diff_m_a = m_a;

  // Calculate `c'`
  {
    /*
      We need to prove that `c >= (m b - m^2 a)`. We do so with the claim the output `c'` will be a
      positive integer, and therefore `c` MUST be greater than or equal to `m b - m^2 a` (when
      `should_reduce == true`), as else `c'` would be negative.

      We know each intermediate form is equivalent to the input form, and therefore as for input
      `(a, b, c)` satisfying `b^2 - (4 a c) = delta`, we have `b'^2 - (4 a' c') = delta`. As
      `delta < 0`, and `b^2 >= 0`, `4 a' c'` MUST be a positive number. As our algorithm sets
      `a' = a` where `a` is positive, `c'` must be positive as well.

      Because `c` is greater than or equal to `m b - m^2 a`, it will fit within a container which
      fits `c`, where `b_diff_m_a` is a container of size equal to `c`'s container (making this
      `shl` call well-defined).
    */
    let m_b_diff_m_square_a = b_diff_m_a;
    m_b_diff_m_square_a.shl_assign(log_2_m);

    // This subtraction is well-defined as `c >= m_b_diff_m_square_a` when `should_reduce == true`
    let mut borrow = Limb::ZERO;
    for (c_limb, m_b_diff_m_square_a_limb) in
      <_ as AsMut<[Limb]>>::as_mut(c).iter_mut().zip(m_b_diff_m_square_a.as_limbs())
    {
      // When `should_reduce == false`, `m_b_diff_m_square_a_limb = 0`, effecting a NOP
      let new_limb;
      (new_limb, borrow) = c_limb.borrowing_sub(*m_b_diff_m_square_a_limb, borrow);
      *c_limb = new_limb;
    }
  }
}

/// Conditionally negate the `b` coefficient of a binary quadratic form with odd discriminant.
///
/// When `b_needs_negation` is set:
/// - the sign bit `b.0` is flipped;
/// - the magnitude `b.1`, viewed as an element of the integers modulo `2^k` (for some `k`), is
///   replaced by its additive inverse, i.e. its logical NOT plus `1`.
///
/// When `b_needs_negation` is unset, the value is left as-is. The selection is made via masking,
/// not branching.
fn negate_b(b: (&mut Choice, &mut UintRef), b_needs_negation: Choice) {
  let (sign, magnitude) = b;
  *sign ^= b_needs_negation;

  // All ones when negating (performing the logical NOT), all zeroes otherwise (a NOP)
  let not_mask = Limb::ZERO.wrapping_sub(Limb::from(u8::from(b_needs_negation)));
  magnitude.iter_mut().for_each(|limb| *limb ^= not_mask);

  /*
    Complete any negation by adding `1`.

    The discriminant is odd, so `b` is odd. After the logical NOT above, a negated value's trailing
    bit is therefore clear, so adding `1` cannot carry and reduces to setting that bit.

    If no negation was requested, `b` is still odd, so its trailing bit is already set. Hence both
    cases collapse to unconditionally setting the trailing bit.
  */
  magnitude[0] |= Limb::ONE;
}

/// Normalize an almost-reduced element.
///
/// Given a positive definite binary quadratic form `(a, b, c)` where:
/// - `b^2 - 4ac = delta` with `delta < 0` (the form is well-defined for a negative discriminant)
/// - $delta \cong 1 \mod 2$
/// - `0 <= a, c` (guaranteed by the unsigned representation)
/// - `|b| <= a <= c`
///
/// Return the equivalent reduced form `(a', b', c')` where:
/// - `|b'| <= a' <= c'`
/// - `b' >= 0` whenever `|b'| == a'` or `a' == c'`
///
/// Only the sign of `b` may change. `a`, `|b|`, and `c` are returned as given. `b.0` is `true`
/// when `b` is _positive_.
///
/// This is intended to correspond to steps 2 and 5 of Algorithm 1.
#[inline(always)]
fn normalize<L: Limbs>(a: L, b: (Choice, L), c: L) -> (L, (Choice, L), L) {
  let (b_is_positive, b_abs) = b;
  /*
    Force `b` positive in either boundary case (`|b| == a` or `a == c`).

    `b == 0` needs no special handling: the discriminant is bound to be odd, so `b` is odd and
    therefore non-zero.
  */
  let force_positive = b_abs.ct_eq(&a) | a.ct_eq(&c);
  (a, (b_is_positive | force_positive, b_abs), c)
}

/// Reduce an element until either it's reduced or $|b| < 2^{upper_bound}$.
///
/// For a positive definite binary quadratic form `(a, b, c)` such that:
/// - `b^2 - 4ac = delta` where `delta < 0` (the form is well-defined for a negative discriminant)
/// - $delta \cong 1 \mod 2$
/// - `0 <= a, c` (`a, c` aren't negative, as enforced by the type system)
/// - `floor(log_2(a)) + 1 <= log_2_bound`
/// - `floor(log_2(|b|)) + 1 <= log_2_bound`
/// - `ceil(log_2_bound / Limb::BITS) <= <L as AsRef::<[Limb]>>::as_ref(&a).len()`
/// - `<L as AsRef::<[Limb]>>::as_ref(&a).len() <= <L as AsRef::<[Limb]>>::as_ref(&b.1).len()`
/// - `<L as AsRef::<[Limb]>>::as_ref(&a).len() == <L as AsRef::<[Limb]>>::as_ref(&c).len()`
///
/// Yield an equivalent form `(a', b', c')` such that:
/// - `(a', b', c')` is reduced or $|b'| < 2^{upper_bound}$.
/// - `(a', b', c')` is reduced or `b' > a'`
///
/// `b.0, b'.0` are `true` if the value is _positive_.
///
/// NOTE(review): the final step calls `reduce_to_next_bit` without a subsequent `a_lte_c` (and
/// without `normalize`). So if that step reduces `b`, the resulting `c'` may be smaller than `a'`.
/// Confirm callers perform (or do not require) that final swap before treating the output as
/// reduced.
#[inline(always)]
pub(crate) fn reduce_to_upper_bound<L: Limbs>(
  log_2_bound: u32,
  mut a: L,
  mut b: (Choice, L),
  mut c: L,
  upper_bound: u32,
) -> (L, (Choice, L), L) {
  #[cfg(debug_assertions)]
  {
    debug_assert!(bool::from(a.bits().ct_lt(&log_2_bound) | a.bits().ct_eq(&log_2_bound)));
    debug_assert!(bool::from(b.1.bits().ct_lt(&log_2_bound) | b.1.bits().ct_eq(&log_2_bound)));
    debug_assert!(
      usize::try_from(log_2_bound.div_ceil(Limb::BITS)).unwrap() <=
        <_ as AsRef::<[Limb]>>::as_ref(&a).len()
    );
    debug_assert!(
      <_ as AsRef::<[Limb]>>::as_ref(&a).len() <= <_ as AsRef::<[Limb]>>::as_ref(&b.1).len()
    );
  }

  let original_limbs = usize::try_from(log_2_bound.div_ceil(Limb::BITS)).unwrap();

  /*
    Iterate from our bound on `b` to a `b'` which by bit-length, would satisfy `upper_bound`.
    Each iteration will reduce the bit length of `b` by at least `1`, until `b <= a` and it is
    reduced (if given sufficient iterations to reach that point).
  */
  {
    let (b_sign, mut b_value) =
      (&mut b.0, UintRef::new_mut(&mut <_ as AsMut<[Limb]>>::as_mut(&mut b.1)[.. original_limbs]));
    let mut b_lte_a = Choice::FALSE;
    let mut b_needs_negation = Choice::FALSE;

    // The number of limbs of `b` currently operated over, shrinking as `b` shrinks
    let mut limbs = original_limbs;

    // `RangeInclusive<u32>` doesn't implement `ExactSizeIterator`, so we use a `Range` instead
    #[expect(clippy::range_plus_one)]
    let mut bits = ((upper_bound + 1) .. (log_2_bound + 1)).rev();

    let mut a_bits = a.bits();
    let mut c_bits = c.bits();

    /*
      Handle the partial limb we may inherently have by the bound not necessarily perfectly
      aligning to limbs, and two more bits.

      `reduce_to_next_bit` is documented to need limbs corresponding to one extra bit, which is
      as `floor(log_2(|b|)) + 1 == floor(log_2(a)) + 1` is a possible input and the function must
      then calculate `2 m a`.

      We provide one additional bit here as for a value `|b| <= a`, this will only be noticed on
      the iteration _after_ the condition becomes true, so we need to defer when we move to the
      smaller amount of limbs until after this later iteration.
    */
    {
      let progress_in_partial_limb = usize::try_from(log_2_bound % Limb::BITS).unwrap();
      for bits in (&mut bits).take(2 + progress_in_partial_limb) {
        approximate_a_lte_c((&mut a_bits, &mut a), b_sign, (&mut c_bits, &mut c));
        let should_reduce = should_reduce_to_next_bit_except_final(
          (b_sign, b_value),
          b_needs_negation,
          &mut b_lte_a,
          a_bits,
          bits,
        );
        reduce_to_next_bit(
          &a,
          (b_sign, b_value),
          &mut b_needs_negation,
          &mut c,
          should_reduce,
          limbs,
          a_bits,
          bits,
        );
        debug_assert!(bool::from(
          b_needs_negation | (!should_reduce) | b_value.bits().ct_lt(&bits)
        ));
        c_bits = c.bits();
      }

      // Negate `b` if necessary, before crossing the limb boundary
      negate_b((b_sign, b_value), b_needs_negation);
      b_needs_negation = Choice::FALSE;

      // Only decrement the amount of `limbs` if we did actually have a partial limb
      if progress_in_partial_limb != 0 {
        limbs -= 1;
        b_value = b_value.leading_mut(limbs);
      }
    }

    /*
      Handle each remaining limb.

      While we could use a single loop for both the partial limb and the full limbs, that would
      have structure approximate to:

      ```
      for bit {
        reduce_to_next_bit();
        if limb {
          limbs -= 1;
        }
      }
      ```

      and place a branch within every single loop body. This achieves a straight-line, other than
      the loops' conditionals themselves (which the compiler appears to handle better, possibly as
      we may use the constant `Limb::BITS` for how many steps this inner loop takes).

      `bits.len() != 0` is used as `bits.is_empty()` (`ExactSizeIterator::is_empty`) is
      experimental.
    */
    while bits.len() != 0 {
      debug_assert_ne!(limbs, 0);

      #[expect(clippy::as_conversions)]
      for bits in (&mut bits).take(const { Limb::BITS as usize }) {
        approximate_a_lte_c((&mut a_bits, &mut a), b_sign, (&mut c_bits, &mut c));
        let should_reduce = should_reduce_to_next_bit_except_final(
          (b_sign, b_value),
          b_needs_negation,
          &mut b_lte_a,
          a_bits,
          bits,
        );
        reduce_to_next_bit(
          &a,
          (b_sign, b_value),
          &mut b_needs_negation,
          &mut c,
          should_reduce,
          limbs,
          a_bits,
          bits,
        );
        debug_assert!(bool::from(
          b_needs_negation | (!should_reduce) | b_value.bits().ct_lt(&bits)
        ));
        c_bits = c.bits();
      }

      // Resolve any deferred negation before shrinking `b_value` past this limb
      negate_b((b_sign, b_value), b_needs_negation);
      b_needs_negation = Choice::FALSE;

      limbs -= 1;
      b_value = b_value.leading_mut(limbs);
    }

    /*
      We apply the final reduction with a full width as we don't know when the above iterations
      stopped, nor how far the number has been truncated since.
    */
    {
      let (b_sign, b_value) = (
        &mut b.0,
        UintRef::new_mut(&mut <_ as AsMut<[Limb]>>::as_mut(&mut b.1)[.. original_limbs]),
      );

      a_lte_c(&mut a, b_sign, &mut c);
      let a_bits = a.bits();
      let b_bits = b_value.bits();
      let should_reduce = should_reduce_to_next_bit_final(&a, (b_sign, b_value));
      reduce_to_next_bit(
        &a,
        (b_sign, b_value),
        &mut b_needs_negation,
        &mut c,
        should_reduce,
        original_limbs,
        a_bits,
        b_bits,
      );

      negate_b((b_sign, b_value), b_needs_negation);
    }
  }

  (a, b, c)
}

/*
  We wish to prove that for `(a, b, c)` input to the reduction algorithm, the output `(a', b', c')`
  satisfies `gcd(a, b, c) = gcd(a', b', c')`.

  The reduction algorithm solely repeatedly performs one of the following two actions:
  `(a, b, c)` -> `(c, -b, a)`
  `(a, b, c)` -> `(a, b - 2ma, c - m|b| + m^2 a)`

  It is immediate that `gcd(a, b, c) = gcd(c, -b, a)` as `0 <= a, c`.

  For the second action, we require an identity (which we state and assume but do not prove here):
  - `gcd(x + z * y, y) = gcd(x, y)` for any integer `z` (positive or negative), which we refer to
    as the modular identity due to its corollary `gcd(x % y, y) = gcd(x, y)`
  and the following definition of a three-argument GCD call:
  - `gcd(x, y, z) = gcd(gcd(x, y), z)`

  Via the modular identity, `gcd(a, b) = gcd(a, b - 2ma)` is immediate. In order to now prove
  `gcd(a, b, c) = gcd(a, b - 2ma, c - m|b| + m^2 a)`, we rewrite the right-hand side using our
  definition of a three-argument GCD call as:

    `gcd(gcd(a, b), c - m|b| + m^2 a)`

  (simplifying `gcd(a, b - 2 ma)` to just `gcd(a, b)`, as we've proven them equivalent)

  The second argument to the outer-GCD call expands as
  `c - m(|b| / gcd(a, b)) gcd(a, b) + m^2 (a / gcd(a, b)) gcd(a, b)`, and is able to be rewritten
  as `c - (m(|b| / gcd(a, b)) - m^2 (a / gcd(a, b))) gcd(a, b)`. Let
  `z = m(|b| / gcd(a, b)) - m^2 (a / gcd(a, b))`, which is an integer as `gcd(a, b)` divides both
  `|b|` and `a`. The modular identity (applied with `-z`) then proves our desired result:

    `gcd(gcd(a, b), c - z gcd(a, b)) = gcd(gcd(a, b), c) = gcd(a, b, c)`

  Accordingly, for an element `(a, b, c)` input to the reduction algorithm, the output
  `(a', b', c')` satisfies `gcd(a, b, c) = gcd(a', b', c')`.
*/

/// Partially reduce a positive definite binary quadratic form.
///
/// For a positive definite binary quadratic form `(a, b, c)` such that:
/// - `b^2 - 4ac = delta` where `delta < 0` (the form is well-defined for a negative discriminant)
/// - `delta ≡ 1 (mod 2)`
/// - `0 <= a` (`a` isn't negative, as enforced by the type system)
/// - `floor(log_2(a)) + 1 <= log_2_bound`
/// - `floor(log_2(|b|)) + 1 <= log_2_bound`
/// - There is an integer solution for `c` in `b^2 - 4 a c = delta`.
/// - `ceil(log_2_bound / Limb::BITS) <= <L as AsRef::<[Limb]>>::as_ref(&a).len()`
/// - `<L as AsRef::<[Limb]>>::as_ref(&a).len() == <L as AsRef::<[Limb]>>::as_ref(&b.1).len()`
/// - `<L as AsRef::<[Limb]>>::as_ref(&negative_discriminant_abs).len() <=
///      2 * <L as AsRef::<[Limb]>>::as_ref(&b.1).len()`
/// - `floor(log_2(|delta|)) + 1 < <L as AsRef::<[Limb]>>::as_ref(&a).len() * Limb::BITS`
/// - `floor(log_2(a)) + 1 < <L as AsRef::<[Limb]>>::as_ref(&a).len() * Limb::BITS`
///
/// Yield an equivalent form `(a', b', c')` such that:
/// - `b'^2 <= |delta|`
/// - `(a', b', c')` is reduced or `b' > a'`
/// - `gcd(a, b, c) = gcd(a', b', c')`
///
/// As composition is presumably programmed to compose `b`-bit-length numbers, where composition
/// outputs `2 * b`-bit-length numbers, this function intends to solely perform the necessary
/// reduction such that the numbers are once again of `b`-bit-length (and able to be composed
/// again). While these forms are not reduced, they may still be usable for composition _without_
/// performing a full reduction (which would take roughly twice as long). This allows deferring a
/// full reduction until one _needs_ a reduced form.
///
/// This second bound on the output, `(a', b', c')` is reduced or `b' > a'`, is critical as it
/// enables the following corollary: `a'^2 < |delta|`.
///
/// `b.0, b'.0` are `true` if the value is _positive_.
///
/// `delta` is bound to be negative and specified via its absolute value in
/// `negative_discriminant_abs`.
#[expect(private_bounds)]
#[inline(always)]
pub(crate) fn partial_reduce<L: super::c::Limbs + Limbs>(
  log_2_bound: u32,
  a: L,
  mut b: (Choice, L),
  negative_discriminant_abs: &L,
) -> (L, (Choice, L), L) {
  let discriminant_bits = negative_discriminant_abs.bits_vartime();
  // `ceil(bits(|delta|) / 2)`: the bit-length the output's `a'` and `|b'|` are checked against
  // below.
  let sqrt_discriminant_bits = discriminant_bits.div_ceil(2);

  #[cfg(debug_assertions)]
  {
    // The number of bits the limb container for `a` (and `b.1`) can hold.
    let capacity_bits = u32::try_from(a.as_ref().len()).unwrap() * Limb::BITS;
    debug_assert!(discriminant_bits < capacity_bits);
    debug_assert!(bool::from(a.bits().ct_lt(&capacity_bits)));
    debug_assert_eq!(
      <L as AsRef::<[Limb]>>::as_ref(&a).len(),
      <L as AsRef::<[Limb]>>::as_ref(&b.1).len()
    );
  }

  b.1 = {
    /*
      Doubling `a` cannot overflow: the preconditions require
      `floor(log_2(a)) + 1 < <L as AsRef<[Limb]>>::as_ref(&a).len() * Limb::BITS`, leaving the top
      bit of `a`'s container unused (as checked by the debug assertion above). `b.1` is bound to a
      container of the same length as `a`.

      TODO: `clone` :/
    */
    let mut two_a = a.clone();
    UintRef::new_mut(two_a.as_mut()).shl1_assign();

    // Ensure `|b| < 2a`, as required to calculate `c`. Only the magnitude `b.1` is reduced; the
    // sign `b.0` is left as-is.
    L::rem(b.1, &two_a)
  };

  let c = super::c(&a, &b, negative_discriminant_abs);
  // `delta` is odd, so `|delta| >= 1`, `discriminant_bits >= 1` and `sqrt_discriminant_bits >= 1`:
  // this subtraction cannot underflow for inputs satisfying the preconditions.
  let (mut a, mut b, mut c) =
    reduce_to_upper_bound(log_2_bound, a, b, c, sqrt_discriminant_bits - 1);

  // This is needed to ensure our second bound, "`(a', b', c')` is reduced or `b' > a'`"
  a_lte_c(&mut a, &mut b.0, &mut c);

  #[cfg(debug_assertions)]
  {
    // `bits(a') <= ceil(bits(|delta|) / 2)`
    debug_assert!(bool::from(
      a.bits().ct_lt(&sqrt_discriminant_bits) | a.bits().ct_eq(&sqrt_discriminant_bits)
    ));
    // `bits(|b'|) <= ceil(bits(|delta|) / 2)`
    debug_assert!(bool::from(
      b.1.bits().ct_lt(&sqrt_discriminant_bits) | b.1.bits().ct_eq(&sqrt_discriminant_bits)
    ));
  }

  (a, b, c)
}

/// Reduce an element.
///
/// Input: a positive definite binary quadratic form `(a, b, c)` satisfying all of:
/// - `b^2 - 4ac = delta` with `delta < 0` (the form is well-defined for a negative discriminant)
/// - `delta ≡ 1 (mod 2)`
/// - `0 <= a, c` (the type system rules out negative `a`, `c`)
/// - `floor(log_2(a)) + 1 <= log_2_bound`
/// - `floor(log_2(|b|)) + 1 <= log_2_bound`
/// - `ceil(log_2_bound / Limb::BITS) <= <L as AsRef::<[Limb]>>::as_ref(&a).len()`
/// - `<L as AsRef::<[Limb]>>::as_ref(&a).len() <= <L as AsRef::<[Limb]>>::as_ref(&b.1).len()`
/// - `<L as AsRef::<[Limb]>>::as_ref(&a).len() == <L as AsRef::<[Limb]>>::as_ref(&c).len()`
///
/// Output: the equivalent reduced form `(a', b', c')`, for which:
/// - `|b'| <= a' <= c'`
/// - `b' >= 0` whenever `|b'| == a'` or `a' == c'`
/// - `gcd(a, b, c) = gcd(a', b', c')`
///
/// The sign flags `b.0` and `b'.0` are `true` when the respective value is _positive_.
///
/// `delta` is always negative; where needed, it is passed around as its absolute value
/// (`negative_discriminant_abs`).
#[inline(always)]
pub(crate) fn reduce<L: Limbs>(
  log_2_bound: u32,
  a: L,
  b: (Choice, L),
  c: L,
) -> (L, (Choice, L), L) {
  debug_assert_eq!(
    <L as AsRef::<[Limb]>>::as_ref(&a).len(),
    <L as AsRef::<[Limb]>>::as_ref(&c).len()
  );

  // `partial_reduce` passes a discriminant-derived value as the final argument; here it is `0`.
  let (mut a, mut b, mut c) = reduce_to_upper_bound(log_2_bound, a, b, c, 0);
  // May reorder `a`/`c` and update `b`'s sign flag so that `a <= c`.
  a_lte_c(&mut a, &mut b.0, &mut c);
  let reduced = normalize(a, b, c);

  #[cfg(debug_assertions)]
  {
    let (red_a, red_b, red_c) = &reduced;
    let a_uint = UintRef::new(AsRef::<[Limb]>::as_ref(red_a));
    let b_abs_uint = UintRef::new(AsRef::<[Limb]>::as_ref(&red_b.1));
    let c_uint = UintRef::new(AsRef::<[Limb]>::as_ref(red_c));
    // `|b'| <= a'`
    debug_assert!(bool::from(b_abs_uint.ct_lt(a_uint) | b_abs_uint.ct_eq(&a_uint)));
    // `a' <= c'`
    debug_assert!(bool::from(a_uint.ct_lt(c_uint) | a_uint.ct_eq(&c_uint)));
    // On either boundary (`|b'| == a'` or `a' == c'`), `b'` must be non-negative.
    let on_boundary = a_uint.ct_eq(&b_abs_uint) | a_uint.ct_eq(&c_uint);
    debug_assert!(bool::from((!on_boundary) | red_b.0));
  }

  reduced
}