1use crate::{Limb, NonZero, SubMod, Uint};
4
5impl<const LIMBS: usize> Uint<LIMBS> {
6 #[must_use]
10 pub const fn sub_mod(&self, rhs: &Self, p: &NonZero<Self>) -> Self {
11 let (out, mask) = self.borrowing_sub(rhs, Limb::ZERO);
12
13 out.wrapping_add(&p.as_ref().bitand_limb(mask))
16 }
17
18 #[inline(always)]
21 pub(crate) const fn try_sub_with_carry(&self, mut carry: Limb, rhs: &Self) -> (Self, Limb) {
22 let (out, borrow) = self.borrowing_sub(rhs, Limb::ZERO);
23 let revert = borrow.lsb_to_choice().and(carry.is_zero());
24 (_, carry) = carry.borrowing_sub(Limb::ZERO, Limb::select(borrow, Limb::ZERO, revert));
25 (Uint::select(&out, self, revert), carry)
26 }
27
28 #[must_use]
33 pub const fn sub_mod_special(&self, rhs: &Self, c: Limb) -> Self {
34 let (out, borrow) = self.borrowing_sub(rhs, Limb::ZERO);
35
36 let l = borrow.0 & c.0;
40 out.wrapping_sub(&Self::from_word(l))
41 }
42}
43
44impl<const LIMBS: usize> SubMod for Uint<LIMBS> {
45 type Output = Self;
46
47 fn sub_mod(&self, rhs: &Self, p: &NonZero<Self>) -> Self {
48 debug_assert!(self < p.as_ref());
49 debug_assert!(rhs < p.as_ref());
50 self.sub_mod(rhs, p)
51 }
52}
53
#[cfg(test)]
mod tests {
    use crate::U256;

    #[cfg(feature = "rand_core")]
    use crate::{Limb, NonZero, Random, RandomMod, Uint};
    #[cfg(feature = "rand_core")]
    use rand_core::SeedableRng;

    /// Known-answer test against the NIST P-256 group order.
    #[test]
    fn sub_mod_nist_p256() {
        let a =
            U256::from_be_hex("1a2472fde50286541d97ca6a3592dd75beb9c9646e40c511b82496cfc3926956");
        let b =
            U256::from_be_hex("d5777c45019673125ad240f83094d4252d829516fac8601ed01979ec1ec1a251");
        let n =
            U256::from_be_hex("ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551")
                .to_nz()
                .unwrap();

        let actual = a.sub_mod(&b, &n);
        let expected =
            U256::from_be_hex("44acf6b7e36c1342c2c5897204fe09504e1e2efb1a900377dbc4e7a6a133ec56");

        assert_eq!(expected, actual);
    }

    /// Base cases plus randomized checks of `sub_mod` for several limb counts.
    #[test]
    #[cfg(feature = "rand_core")]
    fn sub_mod() {
        fn test_size<const LIMBS: usize>() {
            let mut rng = chacha20::ChaCha8Rng::seed_from_u64(1);
            let moduli = [
                NonZero::<Uint<LIMBS>>::random_from_rng(&mut rng),
                NonZero::<Uint<LIMBS>>::random_from_rng(&mut rng),
            ];

            for p in &moduli {
                let base_cases = [
                    (1u64, 0u64, 1u64.into()),
                    (0, 1, p.wrapping_sub(&1u64.into())),
                    (0, 0, 0u64.into()),
                ];
                for (a, b, c) in &base_cases {
                    let a: Uint<LIMBS> = (*a).into();
                    let b: Uint<LIMBS> = (*b).into();

                    let x = a.sub_mod(&b, p);
                    assert_eq!(*c, x, "{} - {} mod {} = {} != {}", a, b, p, x, c);
                }

                if LIMBS > 1 {
                    for _i in 0..100 {
                        let a: Uint<LIMBS> = Limb::random_from_rng(&mut rng).into();
                        let b: Uint<LIMBS> = Limb::random_from_rng(&mut rng).into();
                        let (a, b) = if a < b { (b, a) } else { (a, b) };

                        let c = a.sub_mod(&b, p);
                        assert!(c < **p, "not reduced");
                        assert_eq!(c, a.wrapping_sub(&b), "result incorrect");
                    }
                }

                for _i in 0..100 {
                    let a = Uint::<LIMBS>::random_mod_vartime(&mut rng, p);
                    let b = Uint::<LIMBS>::random_mod_vartime(&mut rng, p);

                    let c = a.sub_mod(&b, p);
                    assert!(c < **p, "not reduced: {} >= {} ", c, p);

                    // Both operands are reduced, so `a - b` lies in `[-p, p)`. The result is
                    // the plain wrapping difference when `a >= b`, and that difference plus
                    // `p` (with wrap-around) when `a < b`. Check both branches exactly.
                    let x = a.wrapping_sub(&b);
                    let expected = if a >= b { x } else { x.wrapping_add(p.as_ref()) };
                    assert_eq!(c, expected, "incorrect result");
                }
            }
        }

        // Single-limb sizes are only exercised on 64-bit targets. This is presumably
        // because the `u64`-based base cases do not fit in one 32-bit limb.
        cpubits::cpubits! {
            64 => { test_size::<1>(); }
        }

        test_size::<2>();
        test_size::<3>();
        if cfg!(not(miri)) {
            test_size::<4>();
            test_size::<8>();
            test_size::<16>();
        }
    }

    /// Checks `sub_mod_special` against the generic `sub_mod` for moduli `2^BITS - c`.
    #[cfg(feature = "rand_core")]
    #[test]
    fn sub_mod_special() {
        fn test_size<const LIMBS: usize>() {
            let mut rng = chacha20::ChaCha8Rng::seed_from_u64(1);
            let moduli = [
                NonZero::<Limb>::random_from_rng(&mut rng),
                NonZero::<Limb>::random_from_rng(&mut rng),
            ];

            for special in &moduli {
                let p = &NonZero::new(Uint::ZERO.wrapping_sub(&Uint::from(special.get()))).unwrap();

                let minus_one = p.wrapping_sub(&Uint::ONE);

                let base_cases = [
                    (Uint::ZERO, Uint::ZERO, Uint::ZERO),
                    (Uint::ONE, Uint::ZERO, Uint::ONE),
                    (Uint::ZERO, Uint::ONE, minus_one),
                    (minus_one, minus_one, Uint::ZERO),
                    (Uint::ZERO, minus_one, Uint::ONE),
                ];
                for (a, b, c) in &base_cases {
                    let x = a.sub_mod_special(b, *special.as_ref());
                    assert_eq!(*c, x, "{} - {} mod {} = {} != {}", a, b, p, x, c);
                }

                for _i in 0..100 {
                    let a = Uint::<LIMBS>::random_mod_vartime(&mut rng, p);
                    let b = Uint::<LIMBS>::random_mod_vartime(&mut rng, p);

                    let c = a.sub_mod_special(&b, *special.as_ref());
                    assert!(c < **p, "not reduced: {} >= {} ", c, p);

                    let expected = a.sub_mod(&b, p);
                    assert_eq!(c, expected, "incorrect result");
                }
            }
        }

        test_size::<1>();
        test_size::<2>();
        test_size::<3>();
        if cfg!(not(miri)) {
            test_size::<4>();
            test_size::<8>();
            test_size::<16>();
        }
    }
}