1#![cfg_attr(not(feature = "std"), no_std)]
24
25extern crate alloc;
26
27use alloc::vec::Vec;
28use ark_bn254::Fr;
29use ark_ff::{Field, Zero};
30
/// Errors returned by [`batch_inverse`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// At least one input scalar was zero, which has no multiplicative inverse.
    ZeroInverse,
}

// Library error types should be printable so callers can report them and
// convert them into `Box<dyn Error>` (Rust API Guidelines, C-GOOD-ERR).
// `core::fmt` keeps this usable in `no_std` builds.
impl core::fmt::Display for Error {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Error::ZeroInverse => f.write_str("cannot invert a zero scalar"),
        }
    }
}

// `std::error::Error` only exists when the crate is built with `std`.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
36
37pub fn batch_inverse(scalars: &[Fr]) -> Result<Vec<Fr>, Error> {
45 if scalars.is_empty() {
46 return Ok(Vec::new());
47 }
48 if scalars.iter().any(|s| s.is_zero()) {
49 return Err(Error::ZeroInverse);
50 }
51
52 let n = scalars.len();
53 if n == 1 {
54 return Ok(alloc::vec![scalars[0].inverse().expect("non-zero")]);
55 }
56
57 let mut products: Vec<Fr> = Vec::with_capacity(n);
59 products.push(scalars[0]);
60 for i in 1..n {
61 let prev = products[i - 1];
62 products.push(prev * scalars[i]);
63 }
64
65 let mut inv_total = products[n - 1].inverse().expect("non-zero product");
67
68 let mut inverses: Vec<Fr> = alloc::vec![Fr::zero(); n];
70 for i in (1..n).rev() {
71 inverses[i] = inv_total * products[i - 1];
73 inv_total *= scalars[i];
75 }
76 inverses[0] = inv_total;
77
78 Ok(inverses)
79}
80
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use ark_ff::UniformRand;

    /// Asserts that `s_inv` is the multiplicative inverse of `s`.
    fn assert_inv_pair(s: Fr, s_inv: Fr) {
        assert_eq!(s * s_inv, Fr::from(1u64), "s * s⁻¹ != 1");
    }

    /// Asserts that `batch` has one entry per scalar and that each entry
    /// inverts the scalar at the same index.
    ///
    /// The explicit length check matters: `zip` stops at the shorter
    /// iterator, so a truncated (or empty) result would otherwise pass.
    fn assert_batch_inverts(scalars: &[Fr], batch: &[Fr]) {
        assert_eq!(batch.len(), scalars.len(), "output length differs from input length");
        for (s, s_inv) in scalars.iter().zip(batch.iter()) {
            assert_inv_pair(*s, *s_inv);
        }
    }

    #[test]
    fn empty_input_returns_empty() {
        assert!(batch_inverse(&[]).unwrap().is_empty());
    }

    #[test]
    fn single_input_uses_direct_inverse() {
        let s = Fr::from(7u64);
        let r = batch_inverse(&[s]).unwrap();
        assert_eq!(r.len(), 1);
        assert_inv_pair(s, r[0]);
    }

    #[test]
    fn small_batch_matches_naive_inverses() {
        let scalars: Vec<Fr> = (1u64..=8).map(Fr::from).collect();
        let batch = batch_inverse(&scalars).unwrap();
        assert_batch_inverts(&scalars, &batch);
    }

    #[test]
    fn large_random_batch_correct() {
        let mut rng = ark_std::test_rng();
        let n = 64;
        let scalars: Vec<Fr> = (0..n).map(|_| Fr::rand(&mut rng)).collect();
        let batch = batch_inverse(&scalars).unwrap();
        assert_batch_inverts(&scalars, &batch);
    }

    #[test]
    fn zero_in_batch_rejected() {
        let scalars = alloc::vec![Fr::from(3u64), Fr::from(0u64), Fr::from(5u64)];
        assert_eq!(batch_inverse(&scalars), Err(Error::ZeroInverse));
    }

    #[test]
    fn zero_at_either_end_or_alone_rejected() {
        // A lone zero must be caught before the single-element fast path.
        assert_eq!(batch_inverse(&[Fr::from(0u64)]), Err(Error::ZeroInverse));
        let leading = alloc::vec![Fr::from(0u64), Fr::from(2u64), Fr::from(3u64)];
        assert_eq!(batch_inverse(&leading), Err(Error::ZeroInverse));
        let trailing = alloc::vec![Fr::from(2u64), Fr::from(3u64), Fr::from(0u64)];
        assert_eq!(batch_inverse(&trailing), Err(Error::ZeroInverse));
    }

    #[test]
    fn output_preserves_input_order() {
        let scalars: Vec<Fr> = alloc::vec![
            Fr::from(13u64), Fr::from(2u64), Fr::from(99u64), Fr::from(7u64),
        ];
        let batch = batch_inverse(&scalars).unwrap();
        let naive: Vec<Fr> = scalars.iter().map(|s| s.inverse().unwrap()).collect();
        // Without this, `zip` below would silently ignore a short result.
        assert_eq!(batch.len(), naive.len(), "output length differs from input length");
        for (b, n) in batch.iter().zip(naive.iter()) {
            assert_eq!(b, n, "batch differs from naive at same index");
        }
    }

    #[test]
    fn two_inputs_correct() {
        let a = Fr::from(11u64);
        let b = Fr::from(13u64);
        let r = batch_inverse(&[a, b]).unwrap();
        assert_eq!(r.len(), 2);
        assert_inv_pair(a, r[0]);
        assert_inv_pair(b, r[1]);
    }
}