modmath/
sub.rs

1/// # Modular Subtraction (Basic)
2/// Simple version that operates on values and copies them. Requires
3/// `WrappingAdd` and `WrappingSub` traits to be implemented.
4pub fn basic_mod_sub<T>(a: T, b: T, m: T) -> T
5where
6    T: core::cmp::PartialOrd
7        + Copy
8        + num_traits::ops::wrapping::WrappingAdd
9        + num_traits::ops::wrapping::WrappingSub
10        + core::ops::Rem<Output = T>,
11{
12    let a = a % m;
13    let diff = a.wrapping_sub(&(b % m));
14    if diff > a {
15        // If we wrapped around (underflow)
16        diff.wrapping_add(&m)
17    } else {
18        diff
19    }
20}
21
22/// # Modular Subtraction (Constrained)
23/// Version that works with references, requires `WrappingAdd` and
24/// `WrappingSub` traits to be implemented.
25pub fn constrained_mod_sub<T>(a: T, b: &T, m: &T) -> T
26where
27    T: core::cmp::PartialOrd
28        + num_traits::ops::wrapping::WrappingAdd
29        + num_traits::ops::wrapping::WrappingSub,
30    for<'a> &'a T: core::ops::Rem<&'a T, Output = T>,
31{
32    let a_mod = &a % m;
33    let diff = a_mod.wrapping_sub(&(b % m));
34    if diff > a_mod {
35        // If we wrapped around (underflow)
36        diff.wrapping_add(m)
37    } else {
38        diff
39    }
40}
41
42/// # Modular Subtraction (Strict)
43/// Most constrained version that works with references. Requires
44/// `OverflowingAdd` and `OverflowingSub` traits to be implemented.
45pub fn strict_mod_sub<T>(mut a: T, b: &T, m: &T) -> T
46where
47    T: core::cmp::PartialOrd
48        + num_traits::ops::overflowing::OverflowingAdd
49        + num_traits::ops::overflowing::OverflowingSub,
50    for<'b> T: core::ops::RemAssign<&'b T>,
51    for<'a> &'a T: core::ops::Rem<&'a T, Output = T>,
52{
53    a.rem_assign(m);
54    let (diff, overflow) = a.overflowing_sub(&(b % m));
55
56    if overflow {
57        m.overflowing_add(&diff).0
58    } else {
59        diff
60    }
61}
62
// Test helper: defines a local adapter `mod_sub(a, &b, &m)` around the function
// under test, so every variant can be driven by the same test bodies.
// `by_ref` forwards the subtrahend and modulus as references; `by_val`
// dereferences (copies) them, which requires `$elem: Copy`.
#[cfg(test)]
macro_rules! select_mod_sub {
    ($target:path, $elem:ty, by_val) => {
        fn mod_sub(lhs: $elem, rhs: &$elem, modulus: &$elem) -> $elem {
            $target(lhs, *rhs, *modulus)
        }
    };
    ($target:path, $elem:ty, by_ref) => {
        fn mod_sub(lhs: $elem, rhs: &$elem, modulus: &$elem) -> $elem {
            $target(lhs, rhs, modulus)
        }
    };
}
76
// Generates the shared `u8` test suite for one modular-subtraction variant.
//
// * `$mod_sub` - the function under test.
// * `$t` - the operand type handed to `select_mod_sub!`.
// * `$by_ref` - `by_ref` or `by_val`; selects how `select_mod_sub!` forwards the
//   subtrahend and modulus.
//
// Every generated test calls the adapter `mod_sub(a, &b, &m)` with `u8` values.
#[cfg(test)]
macro_rules! generate_mod_sub_tests {
    ($mod_sub:path, $t:ty, $by_ref:tt) => {
        select_mod_sub!($mod_sub, $t, $by_ref);

        #[test]
        fn test_mod_sub_basic() {
            assert_eq!(mod_sub(5u8, &10u8, &20u8), 15u8); // 5 - 10 = -5, + 20 = 15
            assert_eq!(mod_sub(7u8, &3u8, &11u8), 4u8); // 7 - 3 = 4
            assert_eq!(mod_sub(0u8, &0u8, &10u8), 0u8); // 0 - 0 = 0
        }

        #[test]
        fn test_mod_sub_equal_operands() {
            assert_eq!(mod_sub(10u8, &10u8, &20u8), 0u8); // 10 - 10 = 0
            assert_eq!(mod_sub(5u8, &5u8, &7u8), 0u8); // 5 - 5 = 0
        }

        #[test]
        fn test_mod_sub_res_less_than_modulus() {
            assert_eq!(mod_sub(15u8, &5u8, &20u8), 10u8); // 15 - 5 = 10
            assert_eq!(mod_sub(8u8, &3u8, &10u8), 5u8); // 8 - 3 = 5
        }

        #[test]
        fn test_mod_sub_with_large_numbers() {
            assert_eq!(mod_sub(255u8, &254u8, &100u8), 1u8); // (255 % 100) - (254 % 100) = 55 - 54 = 1
            assert_eq!(mod_sub(200u8, &100u8, &50u8), 0u8); // (200 % 50) - (100 % 50) = 0 - 0 = 0
        }

        #[test]
        fn test_mod_sub_with_zero() {
            assert_eq!(mod_sub(0u8, &5u8, &10u8), 5u8); // 0 - 5 = -5, + 10 = 5
            assert_eq!(mod_sub(5u8, &0u8, &10u8), 5u8); // 5 - 0 = 5
        }

        #[test]
        fn test_mod_sub_with_max_values() {
            assert_eq!(mod_sub(255u8, &255u8, &100u8), 0u8); // (255 % 100) - (255 % 100) = 55 - 55 = 0
            assert_eq!(mod_sub(255u8, &1u8, &100u8), 54u8); // (255 % 100) - (1 % 100) = 55 - 1 = 54
        }

        #[test]
        fn test_mod_sub_modulus_is_one() {
            assert_eq!(mod_sub(10u8, &20u8, &1u8), 0u8); // Everything mod 1 is 0
            assert_eq!(mod_sub(255u8, &255u8, &1u8), 0u8);
        }

        #[test]
        #[should_panic]
        fn test_mod_sub_modulus_is_zero() {
            mod_sub(10u8, &20u8, &0u8); // Remainder by zero panics for u8
        }

        #[test]
        fn test_mod_sub_operands_exceed_modulus() {
            assert_eq!(mod_sub(100u8, &50u8, &30u8), 20u8); // (100 % 30) - (50 % 30) = 10 - 20, + 30 = 20
            assert_eq!(mod_sub(75u8, &80u8, &20u8), 15u8); // (75 % 20) - (80 % 20) = 15 - 0 = 15
        }

        #[test]
        fn test_mod_sub_edge_cases() {
            assert_eq!(mod_sub(u8::MAX, &u8::MAX, &100u8), 0u8);
            assert_eq!(mod_sub(u8::MAX, &0u8, &100u8), 55u8); // 255 % 100 = 55
            assert_eq!(mod_sub(0u8, &u8::MAX, &100u8), 45u8); // -(255 % 100) + 100 = -55 + 100 = 45
        }

        #[test]
        fn test_mod_sub_result_equals_modulus() {
            // Operands congruent modulo m: a naive `a - b + m` would give exactly
            // m here, but the result must be 0.
            assert_eq!(mod_sub(25u8, &5u8, &20u8), 0u8); // (25 % 20) - (5 % 20) = 5 - 5 = 0
            assert_eq!(mod_sub(45u8, &25u8, &20u8), 0u8); // (45 % 20) - (25 % 20) = 5 - 5 = 0
        }

        #[test]
        fn test_mod_sub_overflow() {
            // With m close to u8::MAX the wrap-around path overflows twice: the
            // reduced subtraction underflows, and adding m back overflows again.
            assert_eq!(mod_sub(0u8, &1u8, &255u8), 254u8); // 0 - 1 = -1, + 255 = 254
            assert_eq!(mod_sub(1u8, &254u8, &255u8), 2u8); // 1 - 254 = -253, + 255 = 2
            assert_eq!(mod_sub(0u8, &254u8, &255u8), 1u8); // 0 - 254 = -254, + 255 = 1
            assert_eq!(mod_sub(u8::MAX, &u8::MAX, &u8::MAX), 0u8); // both sides reduce to 0
        }
    };
}
157
// Shared `u8` suite for `strict_mod_sub`, which takes `b` and `m` by reference.
#[cfg(test)]
mod strict_mod_sub_tests {
    use super::strict_mod_sub;

    generate_mod_sub_tests!(strict_mod_sub, u8, by_ref);
}
162
// Shared `u8` suite for `constrained_mod_sub`, which takes `b` and `m` by reference.
#[cfg(test)]
mod constrained_mod_sub_tests {
    use super::constrained_mod_sub;

    generate_mod_sub_tests!(constrained_mod_sub, u8, by_ref);
}
167
// Shared `u8` suite for `basic_mod_sub`, which takes every operand by value.
#[cfg(test)]
mod basic_mod_sub_tests {
    use super::basic_mod_sub;

    generate_mod_sub_tests!(basic_mod_sub, u8, by_val);
}
172
// Generates a `<stem>_tests` module that runs a modular-subtraction check
// against one big-integer type for each of the three variants.
//
// * `$stem` - base name; the generated module is `<stem>_tests` (via `paste`).
// * `$type_path` - path brought into scope with `use`.
// * `type $type_def = $type_expr;` - optional alias. The test body refers to the
//   type as `U256`, so either `$type_path` must end in `U256` or the alias must
//   be named `U256`. `$type_def` is captured as `ident` because a type-alias
//   name must be an identifier; an opaque `ty` fragment is rejected there.
// * `strict` / `constrained` / `basic` - `on`/`off` flags forwarded to
//   `crate::maybe_test!`. They are captured as `tt` so the forwarded token stays
//   matchable as a literal `on`/`off`; a forwarded `expr` fragment is opaque to
//   the receiving macro. NOTE(review): `maybe_test!` is defined outside this
//   file section - confirm how it matches these flags.
#[cfg(test)]
macro_rules! sub_test_module {
    (
        $stem:ident,
        $type_path:path,
        $(type $type_def:ident = $type_expr:ty;)?
        strict: $strict:tt,
        constrained: $constrained:tt,
        basic: $basic:tt,
    ) => {
        paste::paste! {
            mod [<$stem _tests>] {
                #[allow(unused_imports)]
                use $type_path;
                $( type $type_def = $type_expr; )?

                #[test]
                #[allow(unused_variables)]
                fn test_mod_sub_basic() {
                    // 10 - 5 = 5 (mod 20). NOTE(review): this case does not
                    // underflow, so only the non-wrapping path is exercised.
                    let a = U256::from(10u8);
                    let b = U256::from(5u8);
                    let m = U256::from(20u8);
                    let result = U256::from(5u8);

                    // Each variant takes `a` by value, so it is rebuilt before
                    // the next call.
                    crate::maybe_test!($strict, assert_eq!(super::strict_mod_sub(a, &b, &m), result));
                    let a = U256::from(10u8);
                    crate::maybe_test!($constrained, assert_eq!(super::constrained_mod_sub(a, &b, &m), result));
                    let a = U256::from(10u8);
                    crate::maybe_test!($basic, assert_eq!(super::basic_mod_sub(a, b, m), result));
                }
            }
        }
    };
}
207
// Runs `sub_test_module!` against each big-integer crate. An `off` flag disables
// a variant whose trait bounds the type does not satisfy; the reason is noted
// above each invocation. The `*_patched` entries enable variants that the
// corresponding upstream entry disables.
#[cfg(test)]
mod bnum_sub_tests {
    use super::{basic_mod_sub, constrained_mod_sub, strict_mod_sub};

    // strict: OverflowingAdd + OverflowingSub are not implemented.
    sub_test_module!(
        bnum,
        bnum::types::U256,
        strict: off,
        constrained: on,
        basic: on,
    );

    sub_test_module!(
        bnum_patched,
        bnum_patched::types::U256,
        strict: on,
        constrained: on,
        basic: on,
    );

    // strict: OverflowingAdd + OverflowingSub are missing.
    // constrained: `Rem<&'a U256>` is not implemented for `&'a U256`.
    sub_test_module!(
        crypto_bigint,
        crypto_bigint::U256,
        strict: off,
        constrained: off,
        basic: on,
    );

    sub_test_module!(
        crypto_bigint_patched,
        crypto_bigint_patched::U256,
        strict: on,
        constrained: on,
        basic: on,
    );

    // strict: OverflowingAdd + OverflowingSub are not implemented.
    // constrained: WrappingAdd + WrappingSub are not implemented.
    // basic: BigUint is heap-allocated, so it cannot be Copy.
    sub_test_module!(
        num_bigint,
        num_bigint::BigUint,
        type U256 = num_bigint::BigUint;
        strict: off,
        constrained: off,
        basic: off,
    );

    // basic: BigUint is heap-allocated, so it cannot be Copy.
    sub_test_module!(
        num_bigint_patched,
        num_bigint_patched::BigUint,
        type U256 = num_bigint_patched::BigUint;
        strict: on,
        constrained: on,
        basic: off,
    );

    // strict: OverflowingAdd + OverflowingSub are not implemented.
    // constrained: WrappingAdd + WrappingSub are not implemented.
    // basic: UBig is heap-allocated, so it cannot be Copy.
    sub_test_module!(
        ibig,
        ibig::UBig,
        type U256 = ibig::UBig;
        strict: off,
        constrained: off,
        basic: off,
    );

    // basic: UBig is heap-allocated, so it cannot be Copy.
    sub_test_module!(
        ibig_patched,
        ibig_patched::UBig,
        type U256 = ibig_patched::UBig;
        strict: on,
        constrained: on,
        basic: off,
    );

    sub_test_module!(
        fixed_bigint,
        fixed_bigint::FixedUInt,
        type U256 = FixedUInt<u32, 4>;
        strict: on,
        constrained: on,
        basic: on,
    );
}