// fr_batch_inv_ref/lib.rs
1//! BN254 scalar-field batch-inverse reference implementation
2//! (`SIMD-XXXX-alt-bn128-fr-batch-inverse`).
3//!
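//! Computes `(s₀⁻¹, …, s_{n-1}⁻¹)` in the BN254 scalar field `Fr` using
//! Montgomery's batch-inverse trick. Cost: **1 modular inverse + 3·(n−1)
//! multiplications** — only a single inversion, however large n is. Serves
//! the same purpose as the `ff` crate's `BatchInvert::batch_invert` (used by
//! halo2curves) and arkworks's `ark_ff::batch_inversion`, re-implemented here
//! as a no_std reference for the agave native syscall to mirror. One
//! difference: those helpers leave zero inputs as zero, whereas this crate
//! rejects any zero input with `Error::ZeroInverse`.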
9//!
10//! ## Algorithm
11//!
12//! Given `s = (s₀, s₁, …, s_{n-1})`:
13//!
14//! 1. **Forward pass**: compute prefix products `p[i] = s₀ · s₁ · … · sᵢ`.
15//! 2. **One inverse**: `inv_total = (s₀ · … · s_{n-1})⁻¹`.
16//! 3. **Backward pass**: at step `i` (from n-1 down to 1):
17//!    * `sᵢ⁻¹ = inv_total · p[i-1]`
18//!    * `inv_total *= sᵢ`   (so `inv_total` becomes `(s₀ · … · sᵢ₋₁)⁻¹`)
19//! 4. At step 0: `s₀⁻¹ = inv_total`.
20//!
21//! Output is the inverses `(s₀⁻¹, …, s_{n-1}⁻¹)` in input order.
22
23#![cfg_attr(not(feature = "std"), no_std)]
24
25extern crate alloc;
26
27use alloc::vec::Vec;
28use ark_bn254::Fr;
29use ark_ff::{Field, Zero};
30
/// Errors returned by [`batch_inverse`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// One of the input scalars is zero — no inverse exists.
    ZeroInverse,
}

// `core::fmt` keeps this available in no_std builds.
impl core::fmt::Display for Error {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Error::ZeroInverse => f.write_str("scalar is zero and has no inverse"),
        }
    }
}

// Only with the `std` feature: lets callers propagate this error through
// `?` into `Box<dyn std::error::Error>` and similar.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
36
37/// Compute the inverses of `scalars` in batch via Montgomery's trick.
38/// Returns inverses in the same order as the input.
39///
40/// Edge cases:
41/// - `scalars.is_empty()` returns an empty Vec.
42/// - `scalars.len() == 1` falls back to a direct `Fr::inverse()` call.
43/// - Any scalar = 0 ⇒ `Err(Error::ZeroInverse)` (no inverse defined).
44pub fn batch_inverse(scalars: &[Fr]) -> Result<Vec<Fr>, Error> {
45    if scalars.is_empty() {
46        return Ok(Vec::new());
47    }
48    if scalars.iter().any(|s| s.is_zero()) {
49        return Err(Error::ZeroInverse);
50    }
51
52    let n = scalars.len();
53    if n == 1 {
54        return Ok(alloc::vec![scalars[0].inverse().expect("non-zero")]);
55    }
56
57    // Forward pass: prefix products. products[i] = s_0 · s_1 · ... · s_i.
58    let mut products: Vec<Fr> = Vec::with_capacity(n);
59    products.push(scalars[0]);
60    for i in 1..n {
61        let prev = products[i - 1];
62        products.push(prev * scalars[i]);
63    }
64
65    // One inverse: (s_0 · s_1 · ... · s_{n-1})⁻¹.
66    let mut inv_total = products[n - 1].inverse().expect("non-zero product");
67
68    // Backward pass.
69    let mut inverses: Vec<Fr> = alloc::vec![Fr::zero(); n];
70    for i in (1..n).rev() {
71        // s_i⁻¹ = inv_total · (s_0 · ... · s_{i-1}) = inv_total · products[i-1]
72        inverses[i] = inv_total * products[i - 1];
73        // Update inv_total to be (s_0 · ... · s_{i-1})⁻¹
74        inv_total *= scalars[i];
75    }
76    inverses[0] = inv_total;
77
78    Ok(inverses)
79}
80
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use ark_ff::UniformRand;

    /// Asserts that `s_inv` is the multiplicative inverse of `s`.
    fn assert_inv_pair(s: Fr, s_inv: Fr) {
        assert_eq!(s * s_inv, Fr::from(1u64), "s * s⁻¹ != 1");
    }

    #[test]
    fn empty_input_returns_empty() {
        assert!(batch_inverse(&[]).unwrap().is_empty());
    }

    #[test]
    fn single_input_uses_direct_inverse() {
        let s = Fr::from(7u64);
        let r = batch_inverse(&[s]).unwrap();
        assert_eq!(r.len(), 1);
        assert_inv_pair(s, r[0]);
    }

    #[test]
    fn small_batch_matches_naive_inverses() {
        let scalars: Vec<Fr> = (1u64..=8).map(Fr::from).collect();
        let batch = batch_inverse(&scalars).unwrap();
        // `zip` silently truncates, so pin the length explicitly.
        assert_eq!(batch.len(), scalars.len());
        for (s, s_inv) in scalars.iter().zip(batch.iter()) {
            assert_inv_pair(*s, *s_inv);
        }
    }

    #[test]
    fn large_random_batch_correct() {
        let mut rng = ark_std::test_rng();
        let n = 64;
        let scalars: Vec<Fr> = (0..n).map(|_| Fr::rand(&mut rng)).collect();
        let batch = batch_inverse(&scalars).unwrap();
        assert_eq!(batch.len(), n);
        for (s, s_inv) in scalars.iter().zip(batch.iter()) {
            assert_inv_pair(*s, *s_inv);
        }
    }

    #[test]
    fn zero_in_batch_rejected() {
        let scalars = alloc::vec![Fr::from(3u64), Fr::from(0u64), Fr::from(5u64)];
        assert_eq!(batch_inverse(&scalars), Err(Error::ZeroInverse));
    }

    /// A lone zero must be rejected (single-element path).
    #[test]
    fn single_zero_rejected() {
        assert_eq!(batch_inverse(&[Fr::from(0u64)]), Err(Error::ZeroInverse));
    }

    /// Zeros at the first/last position — the edges of the forward and
    /// backward passes — must be rejected as well.
    #[test]
    fn zero_at_batch_boundaries_rejected() {
        let (zero, three, five) = (Fr::from(0u64), Fr::from(3u64), Fr::from(5u64));
        assert_eq!(batch_inverse(&[zero, three, five]), Err(Error::ZeroInverse));
        assert_eq!(batch_inverse(&[three, five, zero]), Err(Error::ZeroInverse));
    }

    /// Batch inverse output order matches input order (positional, not
    /// sorted). Comparing whole vectors also pins the length.
    #[test]
    fn output_preserves_input_order() {
        let scalars: Vec<Fr> = alloc::vec![
            Fr::from(13u64),
            Fr::from(2u64),
            Fr::from(99u64),
            Fr::from(7u64),
        ];
        let batch = batch_inverse(&scalars).unwrap();
        let naive: Vec<Fr> = scalars.iter().map(|s| s.inverse().unwrap()).collect();
        assert_eq!(batch, naive, "batch differs from naive inverses");
    }

    /// Repeated inputs each receive their own (identical) inverse.
    #[test]
    fn repeated_scalars_each_inverted() {
        let s = Fr::from(2u64);
        let batch = batch_inverse(&[s, s, s]).unwrap();
        let s_inv = s.inverse().unwrap();
        assert_eq!(batch, alloc::vec![s_inv; 3]);
    }

    /// Two inverses: simplest non-trivial case for the prefix-products path.
    #[test]
    fn two_inputs_correct() {
        let a = Fr::from(11u64);
        let b = Fr::from(13u64);
        let r = batch_inverse(&[a, b]).unwrap();
        assert_eq!(r.len(), 2);
        assert_inv_pair(a, r[0]);
        assert_inv_pair(b, r[1]);
    }
}