1pub fn basic_mod_add<T>(a: T, b: T, m: T) -> T
5where
6 T: core::cmp::PartialOrd
7 + Copy
8 + num_traits::ops::wrapping::WrappingAdd
9 + num_traits::ops::wrapping::WrappingSub
10 + core::ops::Rem<Output = T>,
11{
12 let a = a % m;
13 let sum = a.wrapping_add(&(b % m));
14 if sum >= m || sum < a {
15 sum.wrapping_sub(&m)
16 } else {
17 sum
18 }
19}
20
21pub fn constrained_mod_add<T>(a: T, b: &T, m: &T) -> T
25where
26 T: core::cmp::PartialOrd
27 + num_traits::ops::wrapping::WrappingAdd
28 + num_traits::ops::wrapping::WrappingSub,
29 for<'a> &'a T: core::ops::Rem<&'a T, Output = T>,
30{
31 let a_mod = &a % m;
33 let sum = a_mod.wrapping_add(&(b % m));
34 if &sum >= m || sum < a_mod {
35 sum.wrapping_sub(m)
36 } else {
37 sum
38 }
39}
40
41pub fn strict_mod_add<T>(mut a: T, b: &T, m: &T) -> T
45where
46 T: core::cmp::PartialOrd
47 + num_traits::ops::overflowing::OverflowingAdd
48 + num_traits::ops::overflowing::OverflowingSub,
49 for<'b> T: core::ops::RemAssign<&'b T>,
50 for<'a> &'a T: core::ops::Rem<&'a T, Output = T>,
51{
52 a.rem_assign(m);
53 let (sum, overflow) = a.overflowing_add(&(b % m));
54
55 if &sum >= m || overflow {
56 sum.overflowing_sub(m).0
57 } else {
58 sum
59 }
60}
61
/// Test-only adapter: defines a local `fn mod_add(a: $t, b: &$t, m: &$t) -> $t`
/// that forwards to `$mod_add`. This lets one test suite drive both calling
/// conventions.
///
/// * `by_val` — copies `b` and `m` out of their references (needs `$t: Copy`).
/// * `by_ref` — hands `b` and `m` through unchanged.
#[cfg(test)]
macro_rules! select_mod_add {
    ($mod_add:path, $t:ty, by_val) => {
        fn mod_add(a: $t, b: &$t, m: &$t) -> $t {
            let (b, m) = (*b, *m);
            $mod_add(a, b, m)
        }
    };
    ($mod_add:path, $t:ty, by_ref) => {
        fn mod_add(a: $t, b: &$t, m: &$t) -> $t {
            $mod_add(a, b, m)
        }
    };
}
75
/// Test-only generator: expands to `fn mod_add` (via `select_mod_add!`) plus a
/// suite of `#[test]` functions exercising a modular-addition implementation.
///
/// * `$mod_add` — path of the implementation under test.
/// * `$t` — operand type forwarded to `select_mod_add!`. The assertions below
///   use `u8` literals, so this must currently be `u8`.
/// * `$by_ref` — `by_ref` or `by_val`, selecting how `b` and `m` are passed.
///
/// Every expected value is the exact `(a + b) mod m`.
#[cfg(test)]
macro_rules! generate_mod_add_tests {
    ($mod_add:path, $t:ty, $by_ref:tt) => {
        select_mod_add!($mod_add, $t, $by_ref);

        #[test]
        fn test_mod_add_basic() {
            assert_eq!(mod_add(5u8, &10u8, &20u8), 15u8);
            assert_eq!(mod_add(7u8, &6u8, &14u8), 13u8);
            assert_eq!(mod_add(0u8, &0u8, &10u8), 0u8);
        }

        #[test]
        fn test_mod_add_sum_equals_modulus() {
            assert_eq!(mod_add(10u8, &10u8, &20u8), 0u8);
            assert_eq!(mod_add(15u8, &5u8, &20u8), 0u8);
        }

        #[test]
        fn test_mod_add_sum_exceeds_modulus() {
            assert_eq!(mod_add(15u8, &10u8, &20u8), 5u8);
            assert_eq!(mod_add(25u8, &10u8, &30u8), 5u8);
        }

        #[test]
        fn test_mod_add_overflow() {
            // The raw sums (300 and 510) do not fit in `u8`.
            assert_eq!(mod_add(200u8, &100u8, &50u8), 0u8);
            assert_eq!(mod_add(255u8, &255u8, &100u8), 10u8);
        }

        #[test]
        fn test_mod_add_with_zero() {
            assert_eq!(mod_add(0u8, &25u8, &30u8), 25u8);
            assert_eq!(mod_add(25u8, &0u8, &30u8), 25u8);
            assert_eq!(mod_add(0u8, &0u8, &30u8), 0u8);
        }

        #[test]
        fn test_mod_add_with_max_values() {
            assert_eq!(mod_add(255u8, &1u8, &100u8), 56u8);
            assert_eq!(mod_add(254u8, &1u8, &255u8), 0u8);
            assert_eq!(mod_add(255u8, &255u8, &255u8), 0u8);
        }

        #[test]
        fn test_mod_add_modulus_is_one() {
            assert_eq!(mod_add(10u8, &20u8, &1u8), 0u8);
            assert_eq!(mod_add(255u8, &255u8, &1u8), 0u8);
        }

        #[test]
        #[should_panic]
        fn test_mod_add_modulus_is_zero() {
            mod_add(10u8, &20u8, &0u8);
        }

        #[test]
        fn test_mod_add_operands_equal_modulus_minus_one() {
            assert_eq!(mod_add(19u8, &19u8, &20u8), 18u8);
            assert_eq!(mod_add(254u8, &254u8, &255u8), 253u8);
        }

        #[test]
        fn test_mod_add_large_modulus() {
            // A `u8` API cannot receive a modulus above `u8::MAX`. The `as` cast
            // keeps only the low byte (300 - 256 = 44), so this really exercises
            // m = 44. Pin that explicitly so the test does not mislead.
            let large_modulus = 300u16;
            let m = large_modulus as u8;
            assert_eq!(m, 44u8);
            assert_eq!(mod_add(200u8, &100u8, &m), 36u8);
        }

        #[test]
        fn test_mod_add_modulus_equals_u8_max() {
            assert_eq!(mod_add(100u8, &155u8, &255u8), 0u8);
            assert_eq!(mod_add(200u8, &100u8, &255u8), 45u8);
        }

        #[test]
        fn test_mod_add_overflow_edge_case() {
            assert_eq!(mod_add(255u8, &1u8, &255u8), 1u8);
        }

        #[test]
        fn test_mod_add_with_operands_exceeding_modulus() {
            assert_eq!(mod_add(200u8, &100u8, &50u8), 0u8);
            assert_eq!(mod_add(75u8, &80u8, &60u8), 35u8);
        }

        #[test]
        fn test_mod_add_with_modulus_exceeding_u8_max() {
            // As above: 300 truncates to 44 when narrowed to `u8`, so the
            // effective modulus is 44 and 250 reduces to 30 before the add.
            let modulus = 300u16;
            let m = modulus as u8;
            assert_eq!(m, 44u8);
            assert_eq!(mod_add(250u8, &10u8, &m), 40u8);
        }
    };
}
170
/// Runs the shared `u8` suite from `generate_mod_add_tests!` against
/// `strict_mod_add`, passing `b` and `m` by reference.
#[cfg(test)]
mod strict_mod_add_tests {
    generate_mod_add_tests!(super::strict_mod_add, u8, by_ref);
}
176
/// Runs the shared `u8` suite from `generate_mod_add_tests!` against
/// `constrained_mod_add`, passing `b` and `m` by reference.
#[cfg(test)]
mod constrained_mod_add_tests {
    generate_mod_add_tests!(super::constrained_mod_add, u8, by_ref);
}
182
/// Runs the shared `u8` suite from `generate_mod_add_tests!` against
/// `basic_mod_add`, which takes all three operands by value.
#[cfg(test)]
mod basic_mod_add_tests {
    generate_mod_add_tests!(super::basic_mod_add, u8, by_val);
}
188
/// Test-only generator for the big-integer backends.
///
/// Expands, through `paste::paste!`, to a module named `<stem>_tests` holding
/// one `#[test]`. That test checks `5 + 10 (mod 20) == 15` with
/// `strict_mod_add`, `constrained_mod_add` and `basic_mod_add`, each reached
/// through `super::`.
///
/// * `$stem` — prefix of the generated module name.
/// * `$type_path` — imported with `use`. The test body names the operand type
///   `U256`, so this path must end in `U256` unless an alias is supplied.
/// * `type U256 = <type>;` — optional alias, for crates whose type has another
///   name or needs generic arguments. The alias name is captured as an
///   `ident`, because a type alias must be named by an identifier and a `ty`
///   fragment cannot be spliced into that position.
/// * `strict`, `constrained`, `basic` — per-implementation flags, forwarded
///   unchanged as the first argument of `crate::maybe_test!`. Callers pass
///   `on` / `off`.
#[cfg(test)]
macro_rules! add_test_module {
    (
        $stem:ident,
        $type_path:path,
        $(type $type_def:ident = $type_expr:ty;)?
        strict: $strict:expr,
        constrained: $constrained:expr,
        basic: $basic:expr,
    ) => {
        paste::paste! {
            mod [<$stem _tests>] {
                // Unused when the optional alias spells out the full type path.
                #[allow(unused_imports)]
                use $type_path;
                $( type $type_def = $type_expr; )?

                #[test]
                // Presumably `maybe_test!` drops its assertion for `off`, which
                // leaves some of these locals unused — confirm against its definition.
                #[allow(unused_variables)]
                fn test_mod_add_basic() {
                    let a = U256::from(5u8);
                    let b = U256::from(10u8);
                    let m = U256::from(20u8);
                    let result = U256::from(15u8);

                    crate::maybe_test!($strict, assert_eq!(super::strict_mod_add(a, &b, &m), result));
                    // `a` is consumed by value, so rebuild it before each call.
                    let a = U256::from(5u8);
                    crate::maybe_test!($constrained, assert_eq!(super::constrained_mod_add(a, &b, &m), result));
                    let a = U256::from(5u8);
                    crate::maybe_test!($basic, assert_eq!(super::basic_mod_add(a, b, m), result));
                }
            }
        }
    };
}
223
/// Smoke tests for `strict_mod_add`, `constrained_mod_add` and `basic_mod_add`
/// across several third-party big-integer types. Each crate appears in a
/// plain and a `_patched` flavour.
///
/// Every `add_test_module!` call generates a `<stem>_tests` submodule. Its
/// `on`/`off` flags are forwarded to `crate::maybe_test!`, presumably to
/// select which implementations are exercised for that type.
#[cfg(test)]
mod bnum_add_tests {
    // The generated submodules reach these through `super::`.
    use super::{basic_mod_add, constrained_mod_add, strict_mod_add};

    add_test_module!(
        bnum,
        bnum::types::U256,
        strict: off,
        constrained: on,
        basic: on,
    );

    add_test_module!(
        bnum_patched,
        bnum_patched::types::U256,
        strict: on,
        constrained: on,
        basic: on,
    );

    add_test_module!(
        crypto_bigint,
        crypto_bigint::U256,
        strict: off,
        constrained: off,
        basic: on,
    );

    add_test_module!(
        crypto_bigint_patched,
        crypto_bigint_patched::U256,
        strict: on,
        constrained: on,
        basic: on,
    );

    add_test_module!(
        num_bigint,
        num_bigint::BigUint,
        type U256 = num_bigint::BigUint;
        strict: off,
        constrained: off,
        basic: off,
    );

    add_test_module!(
        num_bigint_patched,
        num_bigint_patched::BigUint,
        type U256 = num_bigint_patched::BigUint;
        strict: on,
        constrained: on,
        basic: off,
    );

    add_test_module!(
        ibig,
        ibig::UBig,
        type U256 = ibig::UBig;
        strict: off,
        constrained: off,
        basic: off,
    );

    add_test_module!(
        ibig_patched,
        ibig_patched::UBig,
        type U256 = ibig_patched::UBig;
        strict: on,
        constrained: on,
        basic: off,
    );

    add_test_module!(
        fixed_bigint,
        fixed_bigint::FixedUInt,
        type U256 = FixedUInt<u32, 4>;
        strict: on,
        constrained: on,
        basic: on,
    );
}