// modmath/add.rs

1/// # Modular Addition (Basic)
2/// Simple version that operates on values and copies them. Requires
3/// `WrappingAdd` and `WrappingSub` traits to be implemented.
4pub fn basic_mod_add<T>(a: T, b: T, m: T) -> T
5where
6    T: core::cmp::PartialOrd
7        + Copy
8        + num_traits::ops::wrapping::WrappingAdd
9        + num_traits::ops::wrapping::WrappingSub
10        + core::ops::Rem<Output = T>,
11{
12    let a = a % m;
13    let sum = a.wrapping_add(&(b % m));
14    if sum >= m || sum < a {
15        sum.wrapping_sub(&m)
16    } else {
17        sum
18    }
19}
20
21/// # Modular Addition (Constrained)
22/// Version that works with references, requires `WrappingAdd` and
23/// `WrappingSub` traits to be implemented.
24pub fn constrained_mod_add<T>(a: T, b: &T, m: &T) -> T
25where
26    T: core::cmp::PartialOrd
27        + num_traits::ops::wrapping::WrappingAdd
28        + num_traits::ops::wrapping::WrappingSub,
29    for<'a> &'a T: core::ops::Rem<&'a T, Output = T>,
30{
31    // maybe a should be `mut a`
32    let a_mod = &a % m;
33    let sum = a_mod.wrapping_add(&(b % m));
34    if &sum >= m || sum < a_mod {
35        sum.wrapping_sub(m)
36    } else {
37        sum
38    }
39}
40
41/// # Modular Addition (Strict)
42/// Most constrained version that works with references. Requires
43/// `OverflowingAdd` and `OverflowingSub` traits to be implemented.
44pub fn strict_mod_add<T>(mut a: T, b: &T, m: &T) -> T
45where
46    T: core::cmp::PartialOrd
47        + num_traits::ops::overflowing::OverflowingAdd
48        + num_traits::ops::overflowing::OverflowingSub,
49    for<'b> T: core::ops::RemAssign<&'b T>,
50    for<'a> &'a T: core::ops::Rem<&'a T, Output = T>,
51{
52    a.rem_assign(m);
53    let (sum, overflow) = a.overflowing_add(&(b % m));
54
55    if &sum >= m || overflow {
56        sum.overflowing_sub(m).0
57    } else {
58        sum
59    }
60}
61
// Defines a local `mod_add(a, &b, &m)` wrapper around `$mod_add`, so the
// shared test suite can call every implementation through one signature.
//
// * `by_ref`: the wrapped function already takes `b` and `m` by reference,
//   so the arguments are forwarded unchanged.
// * `by_val`: the wrapped function takes all arguments by value, so `b` and
//   `m` are dereferenced first (this needs `$t` to be `Copy`).
#[cfg(test)]
macro_rules! select_mod_add {
    ($mod_add:path, $t:ty, by_ref) => {
        fn mod_add(a: $t, b: &$t, m: &$t) -> $t {
            $mod_add(a, b, m)
        }
    };
    ($mod_add:path, $t:ty, by_val) => {
        fn mod_add(a: $t, b: &$t, m: &$t) -> $t {
            $mod_add(a, *b, *m)
        }
    };
}
75
// Generates the shared test suite for one modular-addition implementation.
//
// `$mod_add` is the function under test, `$t` its operand type and `$by_ref`
// is `by_ref` or `by_val`, forwarded to `select_mod_add!` to build the local
// `mod_add` wrapper every test calls.
//
// Every literal below is a `u8`, so the suite is only meaningful for
// `$t = u8`.
#[cfg(test)]
macro_rules! generate_mod_add_tests {
    ($mod_add:path, $t:ty, $by_ref:tt) => {
        select_mod_add!($mod_add, $t, $by_ref);

        #[test]
        fn test_mod_add_basic() {
            assert_eq!(mod_add(5u8, &10u8, &20u8), 15u8);
            assert_eq!(mod_add(7u8, &6u8, &14u8), 13u8);
            assert_eq!(mod_add(0u8, &0u8, &10u8), 0u8);
        }

        #[test]
        fn test_mod_add_sum_equals_modulus() {
            assert_eq!(mod_add(10u8, &10u8, &20u8), 0u8);
            assert_eq!(mod_add(15u8, &5u8, &20u8), 0u8);
        }

        #[test]
        fn test_mod_add_sum_exceeds_modulus() {
            assert_eq!(mod_add(15u8, &10u8, &20u8), 5u8);
            assert_eq!(mod_add(25u8, &10u8, &30u8), 5u8);
        }

        #[test]
        fn test_mod_add_overflow() {
            assert_eq!(mod_add(200u8, &100u8, &50u8), 0u8);
            assert_eq!(mod_add(255u8, &255u8, &100u8), 10u8);
        }

        #[test]
        fn test_mod_add_with_zero() {
            assert_eq!(mod_add(0u8, &25u8, &30u8), 25u8);
            assert_eq!(mod_add(25u8, &0u8, &30u8), 25u8);
            assert_eq!(mod_add(0u8, &0u8, &30u8), 0u8);
        }

        #[test]
        fn test_mod_add_with_max_values() {
            assert_eq!(mod_add(255u8, &1u8, &100u8), 56u8);
            assert_eq!(mod_add(254u8, &1u8, &255u8), 0u8);
            assert_eq!(mod_add(255u8, &255u8, &255u8), 0u8);
        }

        #[test]
        fn test_mod_add_modulus_is_one() {
            assert_eq!(mod_add(10u8, &20u8, &1u8), 0u8);
            assert_eq!(mod_add(255u8, &255u8, &1u8), 0u8);
        }

        // A zero modulus is expected to panic in the remainder operation.
        #[test]
        #[should_panic]
        fn test_mod_add_modulus_is_zero() {
            mod_add(10u8, &20u8, &0u8);
        }

        #[test]
        fn test_mod_add_operands_equal_modulus_minus_one() {
            assert_eq!(mod_add(19u8, &19u8, &20u8), 18u8);
            assert_eq!(mod_add(254u8, &254u8, &255u8), 253u8);
        }

        #[test]
        fn test_mod_add_large_modulus() {
            // `300u16 as u8` truncates to 44, so the effective modulus is 44:
            // (200 % 44) + (100 % 44) = 24 + 12 = 36.
            let large_modulus = 300u16;
            let result = mod_add(200u8, &100u8, &(large_modulus as u8));
            assert_eq!(result, 36u8);
        }

        #[test]
        fn test_mod_add_modulus_equals_u8_max() {
            assert_eq!(mod_add(100u8, &155u8, &255u8), 0u8);
            // 200 + 100 wraps around u8 before it is reduced.
            assert_eq!(mod_add(200u8, &100u8, &255u8), 45u8);
        }

        #[test]
        fn test_mod_add_overflow_edge_case() {
            assert_eq!(mod_add(255u8, &1u8, &255u8), 1u8);
        }

        #[test]
        fn test_mod_add_with_operands_exceeding_modulus() {
            assert_eq!(mod_add(200u8, &100u8, &50u8), 0u8);
            assert_eq!(mod_add(75u8, &80u8, &60u8), 35u8);
        }

        #[test]
        fn test_mod_add_with_modulus_exceeding_u8_max() {
            // `300u16 as u8` truncates to 44, so the effective modulus is 44:
            // (250 % 44) + (10 % 44) = 30 + 10 = 40.
            let modulus = 300u16;
            let result = mod_add(250u8, &10u8, &(modulus as u8));
            assert_eq!(result, 40u8);
        }
    };
}
170
// Runs the shared `u8` suite against `strict_mod_add` (reference arguments).
#[cfg(test)]
mod strict_mod_add_tests {
    use super::strict_mod_add;
    generate_mod_add_tests!(strict_mod_add, u8, by_ref);
}
176
// Runs the shared `u8` suite against `constrained_mod_add` (reference arguments).
#[cfg(test)]
mod constrained_mod_add_tests {
    use super::constrained_mod_add;
    generate_mod_add_tests!(constrained_mod_add, u8, by_ref);
}
182
// Runs the shared `u8` suite against `basic_mod_add` (by-value arguments).
#[cfg(test)]
mod basic_mod_add_tests {
    use super::basic_mod_add;
    generate_mod_add_tests!(basic_mod_add, u8, by_val);
}
188
// Generates a `<stem>_tests` module holding one smoke test that exercises the
// three modular-addition functions with a big-integer type.
//
// The test body refers to the type as `U256`, so `$type_path` must either
// import a type of that name or be accompanied by a `type U256 = ...;` alias.
// Each of the `strict` / `constrained` / `basic` flags is handed to
// `crate::maybe_test!` together with the matching assertion.
// NOTE(review): `maybe_test!` is defined elsewhere in the crate; presumably it
// emits the assertion only for `on` - confirm there.
#[cfg(test)]
macro_rules! add_test_module {
    (
        $stem:ident,           // Base name (e.g., "bnum")
        $type_path:path,       // Full path to the type
        $(type $type_def:ty = $type_expr:ty;)? // Optional type definition
        strict: $strict:expr,
        constrained: $constrained:expr,
        basic: $basic:expr,
    ) => {
        paste::paste! {
            mod [<$stem _tests>] {     // This creates e.g., mod bnum_tests
                // The import is unused when the alias below names the type by
                // its full path.
                #[allow(unused_imports)]
                use $type_path;
                $( type $type_def = $type_expr; )?

                #[test]
                // Some bindings go unused when the corresponding checks are
                // switched off.
                #[allow(unused_variables)]
                fn test_mod_add_basic() {
                    let a = U256::from(5u8);
                    let b = U256::from(10u8);
                    let m = U256::from(20u8);
                    let result = U256::from(15u8);

                    // `a` is passed by value to each function, so it is
                    // re-created before every further call.
                    crate::maybe_test!($strict, assert_eq!(super::strict_mod_add(a, &b, &m), result));
                    let a = U256::from(5u8);
                    crate::maybe_test!($constrained, assert_eq!(super::constrained_mod_add(a, &b, &m), result));
                    let a = U256::from(5u8);
                    crate::maybe_test!($basic, assert_eq!(super::basic_mod_add(a, b, m), result));
                }
            }
        }
    };
}
223
// Smoke tests of the three modular-addition functions against several
// big-integer crates. Each `add_test_module!` call creates one `<stem>_tests`
// submodule; a flag is `off` where the type lacks the traits that function
// needs (the reason is given next to the flag).
#[cfg(test)]
mod bnum_add_tests {
    // Reached from the generated submodules through `super::`.
    use super::basic_mod_add;
    use super::constrained_mod_add;
    use super::strict_mod_add;

    add_test_module!(
        bnum,
        bnum::types::U256,
        strict: off, // OverflowingAdd + OverflowingSub is not implemented
        constrained: on,
        basic: on,
    );

    add_test_module!(
        bnum_patched,
        bnum_patched::types::U256,
        strict: on,
        constrained: on,
        basic: on,
    );

    add_test_module!(
        crypto_bigint,
        crypto_bigint::U256,
        strict: off, // OverflowingAdd + OverflowingSub is not implemented
        constrained: off, // Rem<'a> is not implemented for U256
        basic: on,
    );

    add_test_module!(
        crypto_bigint_patched,
        crypto_bigint_patched::U256,
        strict: on,
        constrained: on,
        basic: on,
    );

    add_test_module!(
        num_bigint,
        num_bigint::BigUint,
        type U256 = num_bigint::BigUint;
        strict: off, // OverflowingAdd + OverflowingSub is not implemented
        constrained: off, // WrappingAdd + WrappingSub is not implemented
        basic: off, // Copy cannot be implemented, heap allocation
    );

    add_test_module!(
        num_bigint_patched,
        num_bigint_patched::BigUint,
        type U256 = num_bigint_patched::BigUint;
        strict: on,
        constrained: on,
        basic: off, // Copy cannot be implemented, heap allocation
    );

    add_test_module!(
        ibig,
        ibig::UBig,
        type U256 = ibig::UBig;
        strict: off, // OverflowingAdd + OverflowingSub is not implemented
        constrained: off, // WrappingAdd + WrappingSub is not implemented
        basic: off, // Copy cannot be implemented, heap allocation
    );

    add_test_module!(
        ibig_patched,
        ibig_patched::UBig,
        type U256 = ibig_patched::UBig;
        strict: on,
        constrained: on,
        basic: off, // Copy cannot be implemented, heap allocation
    );

    // NOTE(review): the alias only exists so that the macro's hard-coded
    // `U256` name resolves; `FixedUInt<u32, 4>` looks like 4 x 32-bit words
    // (128 bits) rather than 256 - confirm whether that is intended.
    add_test_module!(
        fixed_bigint,
        fixed_bigint::FixedUInt,
        type U256 = FixedUInt<u32, 4>;
        strict: on,
        constrained: on,
        basic: on,
    );
}