Skip to main content

bsv/primitives/
reduction_context.rs

//! Modular reduction context for field arithmetic.
//!
//! ReductionContext provides modular arithmetic operations (add, sub, neg, mul,
//! sqr, invm, pow, sqrt) over BigNumber values under a given modulus. This is
//! used for secp256k1 field and scalar arithmetic.
6
7use crate::primitives::big_number::BigNumber;
8use crate::primitives::k256::k256_reduce_limbs;
9use crate::primitives::montgomery::Montgomery;
10use std::sync::Arc;
11
12/// Context for performing modular reduction operations.
13///
14/// Mirrors the TS SDK's ReductionContext class. Can be constructed with an
15/// arbitrary modulus or with the string "k256" to use the secp256k1 prime.
16#[derive(Debug)]
17pub struct ReductionContext {
18    /// The modulus used for reduction.
19    pub m: BigNumber,
20    /// Optional Mersenne prime for fast reduction.
21    prime: Option<Box<dyn MersennePrime>>,
22    /// Optional Montgomery context for the modulus (available for K256).
23    pub mont: Option<Montgomery>,
24}
25
26/// Trait for Mersenne-like prime reduction (used by K256).
27pub trait MersennePrime: std::fmt::Debug + Send + Sync {
28    /// Reduce a BigNumber in-place using the Mersenne prime structure.
29    fn ireduce(&self, num: &mut BigNumber);
30    /// The prime value.
31    fn p(&self) -> &BigNumber;
32}
33
34impl ReductionContext {
35    /// Create a new ReductionContext with the given modulus.
36    pub fn new(m: BigNumber) -> Arc<Self> {
37        Arc::new(ReductionContext {
38            m,
39            prime: None,
40            mont: None,
41        })
42    }
43
44    /// Create a ReductionContext for the secp256k1 field prime (k256).
45    /// Includes a Montgomery context for use by callers needing Montgomery form.
46    pub fn k256() -> Arc<Self> {
47        let k = crate::primitives::k256::K256::new();
48        let m = k.p().clone();
49        let mont = Montgomery::new(&m);
50        Arc::new(ReductionContext {
51            m,
52            prime: Some(Box::new(k)),
53            mont: Some(mont),
54        })
55    }
56
57    /// Create a new ReductionContext with a Mersenne prime.
58    pub fn with_prime(prime: Box<dyn MersennePrime>) -> Arc<Self> {
59        let m = prime.p().clone();
60        Arc::new(ReductionContext {
61            m,
62            prime: Some(prime),
63            mont: None,
64        })
65    }
66
67    /// Reduce a BigNumber modulo m.
68    pub fn imod(&self, a: &BigNumber) -> BigNumber {
69        if let Some(ref prime) = self.prime {
70            let mut r = a.clone();
71            prime.ireduce(&mut r);
72            r
73        } else {
74            a.umod(&self.m).unwrap_or_else(|_| BigNumber::zero())
75        }
76    }
77
78    /// Convert a BigNumber into this reduction context (reduce mod m).
79    pub fn convert_to(&self, num: &BigNumber) -> BigNumber {
80        num.umod(&self.m).unwrap_or_else(|_| BigNumber::zero())
81    }
82
83    /// Convert a BigNumber from this reduction context (just clone).
84    pub fn convert_from(&self, num: &BigNumber) -> BigNumber {
85        let mut r = num.clone();
86        r.red = None;
87        r
88    }
89
90    /// Negate a in the context of modulus m.
91    pub fn neg(&self, a: &BigNumber) -> BigNumber {
92        if a.is_zero() {
93            return a.clone();
94        }
95        self.m.sub(a)
96    }
97
98    /// Add two BigNumbers mod m.
99    pub fn add(&self, a: &BigNumber, b: &BigNumber) -> BigNumber {
100        let mut res = a.add(b);
101        res = res.sub(&self.m);
102        if res.is_neg() {
103            res = res.add(&self.m);
104        }
105        // Preserve red context
106        res.red = a.red.clone();
107        res
108    }
109
110    /// Subtract b from a mod m.
111    pub fn sub(&self, a: &BigNumber, b: &BigNumber) -> BigNumber {
112        let mut res = a.sub(b);
113        if res.is_neg() {
114            res = res.add(&self.m);
115        }
116        res.red = a.red.clone();
117        res
118    }
119
120    /// Multiply two BigNumbers mod m.
121    /// For K256 with 4-limb operands, uses Karatsuba mul_4x4 followed by
122    /// limb-level K256 reduction, avoiding all BigNumber temporary allocations.
123    pub fn mul(&self, a: &BigNumber, b: &BigNumber) -> BigNumber {
124        // Fast path: 4-limb K256 -- bypass all BigNumber intermediates
125        let a_limbs = a.get_limbs();
126        let b_limbs = b.get_limbs();
127        if self.prime.is_some() && a_limbs.len() == 4 && b_limbs.len() == 4 {
128            let a4: [u64; 4] = [a_limbs[0], a_limbs[1], a_limbs[2], a_limbs[3]];
129            let b4: [u64; 4] = [b_limbs[0], b_limbs[1], b_limbs[2], b_limbs[3]];
130            let prod8 = crate::primitives::big_number::mul_4x4(&a4, &b4);
131            let reduced = k256_reduce_limbs(&prod8);
132            let mut result = BigNumber::from_raw_limbs(&reduced);
133            result.red = a.red.clone();
134            return result;
135        }
136
137        let prod = a.mul(b);
138        let mut result = self.imod(&prod);
139        result.red = a.red.clone();
140        result
141    }
142
143    /// Square a BigNumber mod m.
144    /// For K256 with 4-limb operands, uses sqr_4x4 followed by
145    /// limb-level K256 reduction.
146    pub fn sqr(&self, a: &BigNumber) -> BigNumber {
147        // Fast path: 4-limb K256 -- bypass all BigNumber intermediates
148        let a_limbs = a.get_limbs();
149        if self.prime.is_some() && a_limbs.len() == 4 {
150            let a4: [u64; 4] = [a_limbs[0], a_limbs[1], a_limbs[2], a_limbs[3]];
151            let prod8 = crate::primitives::big_number::mul_4x4(&a4, &a4);
152            let reduced = k256_reduce_limbs(&prod8);
153            let mut result = BigNumber::from_raw_limbs(&reduced);
154            result.red = a.red.clone();
155            return result;
156        }
157
158        let sq = a.sqr();
159        let mut result = self.imod(&sq);
160        result.red = a.red.clone();
161        result
162    }
163
164    /// Modular inverse in context.
165    pub fn invm(&self, a: &BigNumber) -> BigNumber {
166        let inv = a.invm(&self.m).unwrap_or_else(|_| BigNumber::zero());
167        let mut result = self.imod(&inv);
168        result.red = a.red.clone();
169        result
170    }
171
172    /// Modular exponentiation: a^exp mod m.
173    pub fn pow(&self, a: &BigNumber, exp: &BigNumber) -> BigNumber {
174        if exp.is_zero() {
175            let mut one = BigNumber::one();
176            one.red = a.red.clone();
177            return one;
178        }
179
180        let mut result = BigNumber::one();
181        result.red = a.red.clone();
182        let base = a.clone();
183        let bits = exp.bit_length();
184
185        for i in (0..bits).rev() {
186            result = self.sqr(&result);
187            if exp.testn(i) {
188                result = self.mul(&result, &base);
189            }
190        }
191
192        result
193    }
194
195    /// Modular square root (Tonelli-Shanks for p % 4 == 3).
196    pub fn sqrt(&self, a: &BigNumber) -> BigNumber {
197        if a.is_zero() {
198            return a.clone();
199        }
200
201        let mod4 = self.m.andln(2);
202        // p % 4 == 3 fast path: sqrt(a) = a^((p+1)/4) mod p
203        if mod4 != 0 {
204            let exp = self.m.addn(1);
205            let exp = exp.ushrn(2);
206            return self.pow(a, &exp);
207        }
208
209        // Tonelli-Shanks for general case
210        let mut q = self.m.subn(1);
211        let mut s = 0usize;
212        while !q.is_zero() && q.andln(1) == 0 {
213            s += 1;
214            q.iushrn(1);
215        }
216
217        let one = BigNumber::one();
218        let one_red = {
219            let mut o = one.clone();
220            o.red = a.red.clone();
221            o
222        };
223        let neg_one = self.neg(&one_red);
224
225        // Find quadratic non-residue z
226        let lpow = self.m.subn(1).ushrn(1);
227        let zl = self.m.bit_length();
228        let mut z = BigNumber::from_number(2 * (zl * zl) as i64);
229        z.red = a.red.clone();
230
231        while self.pow(&z, &lpow).cmp(&neg_one) != 0 {
232            let neg_one_clone = neg_one.clone();
233            z = self.add(&z, &neg_one_clone);
234        }
235
236        let mut c = self.pow(&z, &q);
237        let mut r = self.pow(a, &q.addn(1).ushrn(1));
238        let mut t = self.pow(a, &q);
239        let mut m = s;
240
241        while t.cmp(&one_red) != 0 {
242            let mut tmp = t.clone();
243            let mut i = 0;
244            while tmp.cmp(&one_red) != 0 {
245                tmp = self.sqr(&tmp);
246                i += 1;
247            }
248
249            let mut shift = BigNumber::one();
250            shift.iushln(m - i - 1);
251            let b = self.pow(&c, &shift);
252
253            r = self.mul(&r, &b);
254            c = self.sqr(&b);
255            t = self.mul(&t, &c);
256            m = i;
257        }
258
259        r
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a generic (non-Mersenne) context over the small modulus `m`.
    fn ctx_mod(m: i64) -> Arc<ReductionContext> {
        ReductionContext::new(BigNumber::from_number(m))
    }

    /// Shorthand for a small BigNumber literal.
    fn num(v: i64) -> BigNumber {
        BigNumber::from_number(v)
    }

    #[test]
    fn test_reduction_context_basic() {
        let ctx = ReductionContext::new(BigNumber::from_number(7));
        let a = BigNumber::from_number(10);
        let result = ctx.imod(&a);
        assert_eq!(result.to_number(), Some(3)); // 10 mod 7 = 3
    }

    #[test]
    fn test_reduction_context_add() {
        let ctx = ReductionContext::new(BigNumber::from_number(7));
        let a = BigNumber::from_number(5);
        let b = BigNumber::from_number(4);
        let result = ctx.add(&a, &b);
        assert_eq!(result.to_number(), Some(2)); // (5 + 4) mod 7 = 2
    }

    #[test]
    fn test_reduction_context_sub() {
        let ctx = ReductionContext::new(BigNumber::from_number(7));
        let a = BigNumber::from_number(3);
        let b = BigNumber::from_number(5);
        let result = ctx.sub(&a, &b);
        assert_eq!(result.to_number(), Some(5)); // (3 - 5) mod 7 = 5
    }

    #[test]
    fn test_reduction_context_mul() {
        let ctx = ReductionContext::new(BigNumber::from_number(7));
        let a = BigNumber::from_number(3);
        let b = BigNumber::from_number(4);
        let result = ctx.mul(&a, &b);
        assert_eq!(result.to_number(), Some(5)); // (3 * 4) mod 7 = 5
    }

    #[test]
    fn test_reduction_context_invm() {
        let ctx = ReductionContext::new(BigNumber::from_number(11));
        let a = BigNumber::from_number(3);
        let inv = ctx.invm(&a);
        // 3 * inv mod 11 should be 1
        let check = ctx.mul(&a, &inv);
        assert_eq!(check.to_number(), Some(1));
    }

    #[test]
    fn test_reduction_context_pow() {
        let ctx = ReductionContext::new(BigNumber::from_number(7));
        let a = BigNumber::from_number(3);
        let exp = BigNumber::from_number(2);
        let result = ctx.pow(&a, &exp);
        assert_eq!(result.to_number(), Some(2)); // 3^2 mod 7 = 2
    }

    #[test]
    fn test_reduction_context_pow_zero_exponent() {
        assert_eq!(ctx_mod(7).pow(&num(3), &num(0)).to_number(), Some(1));
    }

    #[test]
    fn test_reduction_context_neg() {
        let ctx = ctx_mod(7);
        assert_eq!(ctx.neg(&num(3)).to_number(), Some(4)); // -3 mod 7 = 4
        assert_eq!(ctx.neg(&num(0)).to_number(), Some(0));
    }

    #[test]
    fn test_sqrt_modulus_3_mod_4() {
        // 7 % 4 == 3 takes the direct a^((p+1)/4) branch.
        let ctx = ctx_mod(7);
        let r = ctx.sqrt(&num(2));
        assert_eq!(ctx.sqr(&r).to_number(), Some(2));
    }

    #[test]
    fn test_sqrt_tonelli_shanks() {
        // 13 = 3 * 2^2 + 1 and 17 = 1 * 2^4 + 1 both take the general branch.
        let c13 = ctx_mod(13);
        let r = c13.sqrt(&num(10));
        assert_eq!(c13.sqr(&r).to_number(), Some(10));

        let c17 = ctx_mod(17);
        let r = c17.sqrt(&num(2));
        assert_eq!(c17.sqr(&r).to_number(), Some(2));
    }

    #[test]
    fn test_k256_four_limb_fast_path() {
        let k = ReductionContext::k256();
        // m - 1 and m - 2 are full-width values, i.e. -1 and -2 mod m.
        let minus_one = k.m.subn(1);
        let minus_two = k.m.subn(2);
        assert_eq!(k.mul(&minus_one, &minus_one).to_number(), Some(1));
        assert_eq!(k.mul(&minus_one, &minus_two).to_number(), Some(2));
        assert_eq!(k.sqr(&minus_two).to_number(), Some(4));
    }
}
316        let exp = BigNumber::from_number(2);
317        let result = ctx.pow(&a, &exp);
318        assert_eq!(result.to_number(), Some(2)); // 3^2 mod 7 = 2
319    }
320}