Skip to main content

modmath/
sub.rs

1#[cfg(feature = "nightly")]
2use fixed_bigint::const_numtraits::{ConstOverflowingAdd, ConstOverflowingSub};
3
#[cfg(feature = "nightly")]
c0nst::c0nst! {
    /// # Const Modular Subtraction
    /// Computes `(a - b) mod m` in a form that can be evaluated in const
    /// contexts. Both operands are first reduced modulo `m`. If subtracting the
    /// reduced values borrows, `m` is added back to the wrapped difference.
    /// The arithmetic goes through the const traits from
    /// `fixed_bigint::const_numtraits` rather than `num_traits`.
    ///
    /// # Panics
    /// Panics when `m` is zero for types whose `%` panics on a zero divisor
    /// (such as the primitive integers).
    pub c0nst fn const_mod_sub<T>(a: T, b: T, m: T) -> T
    where
        T: [c0nst] core::cmp::PartialOrd
            + Copy
            + [c0nst] ConstOverflowingAdd
            + [c0nst] ConstOverflowingSub
            + [c0nst] core::ops::Rem<Output = T>,
    {
        let lhs = a % m;
        let rhs = b % m;
        let (wrapped, borrowed) = lhs.overflowing_sub(&rhs);
        if !borrowed {
            return wrapped;
        }
        // The borrow put `wrapped` 2^N above the true (negative) difference;
        // adding `m` and discarding the carry lands back in `[0, m)`.
        m.overflowing_add(&wrapped).0
    }
}
26
27/// # Modular Subtraction (Basic)
28/// Simple version that operates on values and copies them. Requires
29/// `WrappingAdd` and `WrappingSub` traits to be implemented.
30pub fn basic_mod_sub<T>(a: T, b: T, m: T) -> T
31where
32    T: core::cmp::PartialOrd
33        + Copy
34        + num_traits::ops::wrapping::WrappingAdd
35        + num_traits::ops::wrapping::WrappingSub
36        + core::ops::Rem<Output = T>,
37{
38    let a = a % m;
39    let diff = a.wrapping_sub(&(b % m));
40    if diff > a {
41        // If we wrapped around (underflow)
42        diff.wrapping_add(&m)
43    } else {
44        diff
45    }
46}
47
48/// # Modular Subtraction (Constrained)
49/// Version that works with references, requires `WrappingAdd` and
50/// `WrappingSub` traits to be implemented.
51pub fn constrained_mod_sub<T>(a: T, b: &T, m: &T) -> T
52where
53    T: core::cmp::PartialOrd
54        + num_traits::ops::wrapping::WrappingAdd
55        + num_traits::ops::wrapping::WrappingSub,
56    for<'a> &'a T: core::ops::Rem<&'a T, Output = T>,
57{
58    let a_mod = &a % m;
59    let diff = a_mod.wrapping_sub(&(b % m));
60    if diff > a_mod {
61        // If we wrapped around (underflow)
62        diff.wrapping_add(m)
63    } else {
64        diff
65    }
66}
67
68/// # Modular Subtraction (Strict)
69/// Most constrained version that works with references. Requires
70/// `OverflowingAdd` and `OverflowingSub` traits to be implemented.
71pub fn strict_mod_sub<T>(mut a: T, b: &T, m: &T) -> T
72where
73    T: core::cmp::PartialOrd
74        + num_traits::ops::overflowing::OverflowingAdd
75        + num_traits::ops::overflowing::OverflowingSub,
76    for<'b> T: core::ops::RemAssign<&'b T>,
77    for<'a> &'a T: core::ops::Rem<&'a T, Output = T>,
78{
79    a.rem_assign(m);
80    let (diff, overflow) = a.overflowing_sub(&(b % m));
81
82    if overflow {
83        m.overflowing_add(&diff).0
84    } else {
85        diff
86    }
87}
88
/// Defines a local `mod_sub(a: $t, b: &$t, m: &$t) -> $t` adapter around the
/// implementation under test. This lets every implementation be driven with
/// the same by-reference call shape.
///
/// With `by_ref`, the references are forwarded unchanged. With `by_val`, `b`
/// and `m` are dereferenced first, which requires `$t: Copy`.
#[cfg(test)]
macro_rules! select_mod_sub {
    ($mod_sub:path, $t:ty, by_val) => {
        fn mod_sub(a: $t, b: &$t, m: &$t) -> $t {
            let (b, m) = (*b, *m);
            $mod_sub(a, b, m)
        }
    };
    ($mod_sub:path, $t:ty, by_ref) => {
        fn mod_sub(a: $t, b: &$t, m: &$t) -> $t {
            $mod_sub(a, b, m)
        }
    };
}
102
/// Expands to a suite of `#[test]` functions for one modular-subtraction
/// implementation.
///
/// * `$mod_sub` is the path of the function under test.
/// * `$t` is the operand type used in the adapter signature. The assertions
///   use `u8` literals, so this is expected to be `u8`.
/// * `$by_ref` is `by_ref` or `by_val`. It is forwarded to `select_mod_sub!`
///   to choose how `b` and `m` are passed.
#[cfg(test)]
macro_rules! generate_mod_sub_tests {
    ($mod_sub:path, $t:ty, $by_ref:tt) => {
        select_mod_sub!($mod_sub, $t, $by_ref);

        #[test]
        fn test_mod_sub_basic() {
            assert_eq!(mod_sub(5u8, &10u8, &20u8), 15u8); // 5 - 10 = -5 + 20 = 15
            assert_eq!(mod_sub(7u8, &3u8, &11u8), 4u8); // 7 - 3 = 4
            assert_eq!(mod_sub(0u8, &0u8, &10u8), 0u8); // 0 - 0 = 0
        }

        // The mathematical difference is a multiple of the modulus, so it
        // reduces to 0 (a reduced result can never equal `m` itself).
        #[test]
        fn test_mod_sub_res_equals_modulus() {
            assert_eq!(mod_sub(10u8, &10u8, &20u8), 0u8); // 10 - 10 = 0
            assert_eq!(mod_sub(5u8, &5u8, &7u8), 0u8); // 5 - 5 = 0
        }

        #[test]
        fn test_mod_sub_res_less_than_modulus() {
            assert_eq!(mod_sub(15u8, &5u8, &20u8), 10u8); // 15 - 5 = 10
            assert_eq!(mod_sub(8u8, &3u8, &10u8), 5u8); // 8 - 3 = 5
        }

        #[test]
        fn test_mod_sub_with_large_numbers() {
            assert_eq!(mod_sub(255u8, &254u8, &100u8), 1u8); // (255 % 100) - (254 % 100) = 55 - 54 = 1
            assert_eq!(mod_sub(200u8, &100u8, &50u8), 0u8); // (200 % 50) - (100 % 50) = 0 - 0 = 0
        }

        #[test]
        fn test_mod_sub_with_zero() {
            assert_eq!(mod_sub(0u8, &5u8, &10u8), 5u8); // 0 - 5 = -5 + 10 = 5
            assert_eq!(mod_sub(5u8, &0u8, &10u8), 5u8); // 5 - 0 = 5
        }

        #[test]
        fn test_mod_sub_with_max_values() {
            assert_eq!(mod_sub(255u8, &255u8, &100u8), 0u8); // (255 % 100) - (255 % 100) = 55 - 55 = 0
            assert_eq!(mod_sub(255u8, &1u8, &100u8), 54u8); // (255 % 100) - (1 % 100) = 55 - 1 = 54
        }

        #[test]
        fn test_mod_sub_modulus_is_one() {
            assert_eq!(mod_sub(10u8, &20u8, &1u8), 0u8); // Everything mod 1 is 0
            assert_eq!(mod_sub(255u8, &255u8, &1u8), 0u8);
        }

        #[test]
        #[should_panic]
        fn test_mod_sub_modulus_is_zero() {
            mod_sub(10u8, &20u8, &0u8); // Should panic - division by zero
        }

        #[test]
        fn test_mod_sub_operands_exceed_modulus() {
            assert_eq!(mod_sub(100u8, &50u8, &30u8), 20u8); // (100 % 30) - (50 % 30) = 10 - 20 + 30 = 20
            assert_eq!(mod_sub(75u8, &80u8, &20u8), 15u8); // (75 % 20) - (80 % 20) = 15 - 0 = 15
        }

        #[test]
        fn test_mod_sub_edge_cases() {
            assert_eq!(mod_sub(u8::MAX, &u8::MAX, &100u8), 0u8);
            assert_eq!(mod_sub(u8::MAX, &0u8, &100u8), 55u8); // 255 % 100 = 55
            assert_eq!(mod_sub(0u8, &u8::MAX, &100u8), 45u8); // -(255 % 100) + 100 = -55 + 100 = 45
        }

        #[test]
        fn test_mod_sub_result_equals_modulus() {
            assert_eq!(mod_sub(25u8, &5u8, &20u8), 0u8); // (25 % 20) - (5 % 20) = 5 - 5 = 0
            assert_eq!(mod_sub(45u8, &25u8, &20u8), 0u8); // (45 % 20) - (25 % 20) = 5 - 5 = 0
        }

        #[test]
        fn test_mod_sub_overflow() {
            // (200 % 50) - (100 % 50) = 0 - 0 = 0: no borrow.
            assert_eq!(mod_sub(200u8, &100u8, &50u8), 0u8);
            // 10 - 190 borrows and wraps to 76. Adding m = 200 back overflows
            // u8 as well and wraps to the correct 20 (-180 mod 200).
            assert_eq!(mod_sub(10u8, &190u8, &200u8), 20u8);
            // Modulus at the top of the u8 range: 1 - 254 wraps to 3, and
            // 3 + 255 wraps to 2, which is -253 mod 255.
            assert_eq!(mod_sub(1u8, &254u8, &255u8), 2u8);
            // 0 - 254 wraps to 2, and 2 + 255 wraps to 1, which is -254 mod 255.
            assert_eq!(mod_sub(0u8, &254u8, &255u8), 1u8);
        }
    };
}
183
// Shared u8 suite against `strict_mod_sub`, which takes `b` and `m` by reference.
#[cfg(test)]
mod strict_mod_sub_tests {
    use super::strict_mod_sub;

    generate_mod_sub_tests!(strict_mod_sub, u8, by_ref);
}
188
// Shared u8 suite against `constrained_mod_sub`, which takes `b` and `m` by reference.
#[cfg(test)]
mod constrained_mod_sub_tests {
    use super::constrained_mod_sub;

    generate_mod_sub_tests!(constrained_mod_sub, u8, by_ref);
}
193
// Shared u8 suite against `basic_mod_sub`, which takes every operand by value.
#[cfg(test)]
mod basic_mod_sub_tests {
    use super::basic_mod_sub;

    generate_mod_sub_tests!(basic_mod_sub, u8, by_val);
}
198
// Compile-time checks that `const_mod_sub` really evaluates in a const context.
// They cover both the borrowing and the non-borrowing branch. A failure here is
// a build error rather than a test failure.
#[cfg(test)]
#[cfg(feature = "nightly")]
const _: () = {
    // Borrow: 5 - 10 wraps, then 20 is added back.
    assert!(const_mod_sub(5u8, 10u8, 20u8) == 15u8);
    // No borrow: 7 - 3 = 4.
    assert!(const_mod_sub(7u8, 3u8, 11u8) == 4u8);
    // Borrow where adding m back also overflows u8: 10 - 190 wraps to 76,
    // and 76 + 200 wraps to 20.
    assert!(const_mod_sub(10u8, 190u8, 200u8) == 20u8);
};
205
// Shared u8 suite against the const-evaluable `const_mod_sub`, called at
// runtime with every operand passed by value.
#[cfg(test)]
#[cfg(feature = "nightly")]
mod const_mod_sub_tests {
    use super::const_mod_sub;

    generate_mod_sub_tests!(const_mod_sub, u8, by_val);
}
211
/// Generates a `<stem>_tests` module containing a smoke test of the three
/// `num_traits`-based implementations on a big-integer type.
///
/// * `$stem` is the prefix of the generated module name (`bnum` gives `bnum_tests`).
/// * `$type_path` is imported into the module. When it does not itself name
///   `U256`, pass `type U256 = ...;` so the test body can refer to `U256`.
/// * `strict` / `constrained` / `basic` take an `on`/`off` switch. Each is
///   forwarded to `crate::maybe_test!` together with that implementation's
///   assertion.
#[cfg(test)]
macro_rules! sub_test_module {
    (
        $stem:ident,
        $type_path:path,
        $(type $type_def:ty = $type_expr:ty;)?
        strict: $strict:expr,
        constrained: $constrained:expr,
        basic: $basic:expr,
    ) => {
        paste::paste! {
            mod [<$stem _tests>] {
                #[allow(unused_imports)]
                use $type_path;
                $( type $type_def = $type_expr; )?

                // Bindings go unused when a switch is `off`.
                #[test]
                #[allow(unused_variables)]
                fn test_mod_sub_basic() {
                    // 10 - 5 mod 20 = 5, with no borrow.
                    let subtrahend = U256::from(5u8);
                    let modulus = U256::from(20u8);
                    let expected = U256::from(5u8);

                    // Each implementation consumes its minuend, so build a fresh one per call.
                    let minuend = U256::from(10u8);
                    crate::maybe_test!($strict, assert_eq!(super::strict_mod_sub(minuend, &subtrahend, &modulus), expected));

                    let minuend = U256::from(10u8);
                    crate::maybe_test!($constrained, assert_eq!(super::constrained_mod_sub(minuend, &subtrahend, &modulus), expected));

                    let minuend = U256::from(10u8);
                    crate::maybe_test!($basic, assert_eq!(super::basic_mod_sub(minuend, subtrahend, modulus), expected));
                }
            }
        }
    };
}
246
// Smoke tests of the `num_traits`-based implementations against several
// big-integer crates, each in an upstream and a patched flavour. The `on`/`off`
// switches record which implementations each type supports; trailing comments
// give the reason an implementation is switched off.
#[cfg(test)]
mod bnum_sub_tests {
    // The generated `<stem>_tests` modules call these through `super::`.
    use super::{basic_mod_sub, constrained_mod_sub, strict_mod_sub};

    // ---- bnum ----
    sub_test_module!(
        bnum,
        bnum::types::U256,
        strict: off, // OverflowingAdd + OverflowingSub is not implemented
        constrained: on,
        basic: on,
    );

    sub_test_module!(
        bnum_patched,
        bnum_patched::types::U256,
        strict: on,
        constrained: on,
        basic: on,
    );

    // ---- crypto-bigint ----
    sub_test_module!(
        crypto_bigint,
        crypto_bigint::U256,
        strict: off, // Missing OverflowingAdd + OverflowingSub
        constrained: off, // Rem<'a> is not implemented for U256
        basic: on,
    );

    sub_test_module!(
        crypto_bigint_patched,
        crypto_bigint_patched::U256,
        strict: on,
        constrained: on,
        basic: on,
    );

    // ---- num-bigint (heap-allocated, so never Copy) ----
    sub_test_module!(
        num_bigint,
        num_bigint::BigUint,
        type U256 = num_bigint::BigUint;
        strict: off, // OverflowingAdd + OverflowingSub is not implemented
        constrained: off, // WrappingAdd + WrappingSub is not implemented
        basic: off, // Copy cannot be implemented, heap allocation
    );

    sub_test_module!(
        num_bigint_patched,
        num_bigint_patched::BigUint,
        type U256 = num_bigint_patched::BigUint;
        strict: on,
        constrained: on,
        basic: off, // Copy cannot be implemented, heap allocation
    );

    // ---- ibig (heap-allocated, so never Copy) ----
    sub_test_module!(
        ibig,
        ibig::UBig,
        type U256 = ibig::UBig;
        strict: off, // OverflowingAdd + OverflowingSub is not implemented
        constrained: off, // WrappingAdd + WrappingSub is not implemented
        basic: off, // Copy cannot be implemented, heap allocation
    );

    sub_test_module!(
        ibig_patched,
        ibig_patched::UBig,
        type U256 = ibig_patched::UBig;
        strict: on,
        constrained: on,
        basic: off, // Copy cannot be implemented, heap allocation
    );

    // ---- fixed-bigint ----
    sub_test_module!(
        fixed_bigint,
        fixed_bigint::FixedUInt,
        type U256 = FixedUInt<u32, 4>;
        strict: on,
        constrained: on,
        basic: on,
    );
}