1use num_bigint::{BigInt, BigUint, Sign};
8use std::fmt::{Display, Formatter};
9use std::ops::Shl;
10
11#[derive(Debug, Copy, Clone, Eq, PartialEq)]
13pub struct Mod {
14 bits: u32,
15 words: u32,
16 msb_mask: Word,
17}
18
19impl Mod {
20 #[inline]
21 pub const fn from_words(words: usize) -> Self {
22 Self::from_bits(words as u32 * Word::BITS)
23 }
24
25 #[inline]
26 pub fn from_factor(factor: &BigUint) -> Self {
27 let ones = factor.count_ones();
28 let is_power_of_two = ones == 1;
29 assert!(is_power_of_two);
30 let bits = factor.trailing_zeros().unwrap() as u32;
31 Self::from_bits(bits)
32 }
33
34 #[inline]
35 pub const fn from_bits(bits: u32) -> Self {
36 let words = bits.div_ceil(Word::BITS);
37 let msb_mask = if bits.is_multiple_of(Word::BITS) {
38 Word::MAX
39 } else {
40 ((1 as Word) << (bits % Word::BITS)) - 1
41 };
42 Self {
43 bits,
44 words,
45 msb_mask,
46 }
47 }
48
49 #[inline]
50 pub fn bits(&self) -> u32 {
51 self.bits
52 }
53
54 #[inline]
56 pub fn bytes(&self) -> u32 {
57 self.bits().div_ceil(8)
58 }
59
60 pub fn factor(&self) -> BigUint {
61 BigUint::from(1u32).shl(self.bits() as usize)
62 }
63
64 #[inline]
65 fn words(&self) -> usize {
66 self.words as usize
67 }
68
69 #[inline]
70 fn msb_mask(&self) -> Word {
71 self.msb_mask
72 }
73}
74
75pub trait Coef: Clone {
77 fn from_big(value: &BigInt, m: Mod) -> Self;
78 fn from_i64(v: i64, m: Mod) -> Self;
79 fn pow2(e: u32, m: Mod) -> Self;
80 fn zero() -> Self;
81 fn is_zero(&self) -> bool;
82 fn assign_zero(&mut self);
83 fn add_assign(&mut self, other: &Self, m: Mod);
84 fn mul_assign(&mut self, other: &Self, m: Mod);
85 const MAX_MOD: Mod;
86}
87
/// Limb type used for multi-word coefficients.
type Word = u64;
/// Integer twice as wide as [`Word`]; holds carries and full limb products.
type DoubleWord = u128;

/// Fixed-width coefficient stored as `W` [`Word`] limbs, least significant first.
#[derive(Debug, Clone, PartialEq)]
pub struct ArrayCoef<const W: usize> {
    words: [Word; W],
}
97
98#[inline]
100fn adc(carry: u8, a: &mut Word, b: Word) -> u8 {
101 let sum = carry as DoubleWord + *a as DoubleWord + b as DoubleWord;
102 *a = sum as Word;
103 (sum >> Word::BITS) as u8
104}
105
106#[inline]
107fn mul<const W: usize>(a: &mut [Word; W], b: &[Word; W]) {
108 debug_assert_eq!(a.len(), b.len());
109 let mut acc = [0 as Word; W];
110 for (i, bi) in b.iter().enumerate() {
111 mac_word(&mut acc[i..], a, *bi);
112 }
113
114 a.copy_from_slice(&acc);
116}
117
118#[inline]
119fn mac_word(acc: &mut [Word], b: &[Word], word: Word) {
120 let mut carry = 0;
121 for (a, b) in acc.iter_mut().zip(b) {
122 *a = mac_with_carry(*a, *b, word, &mut carry);
123 }
124}
125
126#[inline]
127fn mac_with_carry(a: Word, b: Word, c: Word, acc: &mut DoubleWord) -> Word {
128 *acc += a as DoubleWord;
129 *acc += (b as DoubleWord) * (c as DoubleWord);
130 let lo = *acc as Word;
131 *acc >>= Word::BITS;
132 lo
133}
134
135fn words_to_u32(words: &[Word]) -> Vec<u32> {
136 debug_assert_eq!(u32::BITS * 2, Word::BITS);
137 let mut words32 = Vec::with_capacity(words.len() * 2);
138 let mask32 = u32::MAX as Word;
139 for w in words.iter() {
140 let word = *w;
141 let lsb = (word & mask32) as u32;
142 let msb = ((word >> 32) & mask32) as u32;
143 words32.push(lsb);
144 words32.push(msb);
145 }
146 words32
147}
148
149impl<const W: usize> ArrayCoef<W> {
150 const MAX_BYTES: u32 = W as u32 * Word::BITS / 8;
151 fn to_ubig(&self) -> BigUint {
152 BigUint::from_slice(&words_to_u32(&self.words))
153 }
154
155 #[cfg(test)]
156 fn from_words(words_in: &[Word], m: Mod) -> Self {
157 debug_assert_eq!(words_in.len(), W);
158 let mut words = [0 as Word; W];
159 words.as_mut_slice().copy_from_slice(words_in);
160 let mut r = Self { words };
161 r.do_mask(m);
162 r
163 }
164
165 #[inline]
166 fn do_mask(&mut self, m: Mod) {
167 for w in self.words.iter_mut().skip(m.words()) {
169 *w = 0;
170 }
171 self.words[m.words() - 1] &= m.msb_mask();
172 }
173
174 fn negate(&mut self, m: Mod) {
175 for ii in 0..W {
177 self.words[ii] = !self.words[ii];
178 }
179 let mut carry = adc(0, &mut self.words[0], 1);
181 for ii in 1..W {
182 carry = adc(carry, &mut self.words[ii], 0);
183 }
184 self.do_mask(m);
185 }
186
187 fn from_ubig(value: &BigUint, m: Mod) -> Self {
188 let digits: Vec<Word> = value.iter_u64_digits().collect();
190 let mut words = [0; W];
191 words[0..digits.len()].copy_from_slice(&digits);
192 let mut r = Self { words };
193 r.do_mask(m);
194 r
195 }
196
197 #[inline]
198 fn from_u64(v: u64, m: Mod) -> Self {
199 debug_assert!(m.bytes() <= Self::MAX_BYTES);
200 debug_assert!(Self::MAX_BYTES * 8 >= u64::BITS);
201 let mut r = Self::zero();
202 r.words[0] = v as Word;
203 r.do_mask(m);
204 r
205 }
206}
207
208impl<const W: usize> Coef for ArrayCoef<W> {
209 fn from_big(value: &BigInt, m: Mod) -> Self {
210 let is_negative = value.sign() == Sign::Minus;
211 let mut r = Self::from_ubig(value.magnitude(), m);
212 if is_negative {
213 r.negate(m);
214 }
215 r
216 }
217
218 #[inline]
219 fn from_i64(v: i64, m: Mod) -> Self {
220 Self::from_u64(v as u64, m)
221 }
222
223 fn pow2(e: u32, m: Mod) -> Self {
224 let mut r = Self::zero();
225 if m.bits() > e {
226 let word_ii = e / Word::BITS;
227 r.words[word_ii as usize] = (1 as Word) << (e % Word::BITS);
228 }
229 r
230 }
231
232 #[inline]
233 fn zero() -> Self {
234 Self { words: [0; W] }
235 }
236
237 fn is_zero(&self) -> bool {
238 self.words.iter().all(|w| *w == 0)
239 }
240
241 fn assign_zero(&mut self) {
242 for w in self.words.iter_mut() {
243 *w = 0;
244 }
245 }
246
247 fn add_assign(&mut self, other: &Self, m: Mod) {
248 debug_assert!(m.bytes() <= Self::MAX_BYTES);
249 let mut carry = 0;
250 for ii in 0..W {
251 carry = adc(carry, &mut self.words[ii], other.words[ii]);
252 }
253 self.do_mask(m);
254 }
255
256 fn mul_assign(&mut self, other: &Self, m: Mod) {
257 debug_assert!(m.bytes() <= Self::MAX_BYTES);
258 mul(&mut self.words, &other.words);
259 self.do_mask(m);
260 }
261
262 const MAX_MOD: Mod = Mod::from_words(W);
263}
264
265impl<const W: usize> Display for ArrayCoef<W> {
266 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
267 write!(f, "{}", self.to_ubig())
268 }
269}
270
271impl Coef for Word {
272 fn from_big(value: &BigInt, m: Mod) -> Self {
273 Self::from_i64(value.try_into().unwrap(), m)
274 }
275
276 fn from_i64(v: i64, m: Mod) -> Self {
277 v as Word & m.msb_mask
278 }
279
280 fn pow2(e: u32, m: Mod) -> Self {
281 debug_assert!(m.words == 1);
282 if e < m.bits() { (1 as Word) << e } else { 0 }
283 }
284
285 fn zero() -> Self {
286 0
287 }
288
289 fn is_zero(&self) -> bool {
290 *self == 0
291 }
292
293 fn assign_zero(&mut self) {
294 *self = 0;
295 }
296
297 fn add_assign(&mut self, other: &Self, m: Mod) {
298 *self = self.overflowing_add(*other).0 & m.msb_mask;
299 }
300
301 fn mul_assign(&mut self, other: &Self, m: Mod) {
302 *self = self.overflowing_mul(*other).0 & m.msb_mask;
303 }
304
305 const MAX_MOD: Mod = Mod::from_words(1);
306}
307
308#[inline]
309fn mask_double_word(v: DoubleWord, m: Mod) -> DoubleWord {
310 match m.words {
311 0 => 0,
312 1 => v & m.msb_mask as DoubleWord,
313 2 => v & (((m.msb_mask as DoubleWord) << Word::BITS) | Word::MAX as DoubleWord),
314 _ => unreachable!("u128 can only be used to represent up to two words!"),
315 }
316}
317
318type SignedDoubleWord = i128;
319
320impl Coef for DoubleWord {
321 fn from_big(value: &BigInt, m: Mod) -> Self {
322 let r: SignedDoubleWord = value.try_into().unwrap();
323 mask_double_word(r as DoubleWord, m)
324 }
325
326 fn from_i64(v: i64, m: Mod) -> Self {
327 mask_double_word(v as DoubleWord, m)
328 }
329
330 fn pow2(e: u32, m: Mod) -> Self {
331 debug_assert!(m.words <= 2);
332 if e < m.bits() {
333 (1 as DoubleWord) << e
334 } else {
335 0
336 }
337 }
338
339 fn zero() -> Self {
340 0
341 }
342
343 fn is_zero(&self) -> bool {
344 *self == 0
345 }
346
347 fn assign_zero(&mut self) {
348 *self = 0;
349 }
350
351 fn add_assign(&mut self, other: &Self, m: Mod) {
352 *self = mask_double_word(self.overflowing_add(*other).0, m);
353 }
354
355 fn mul_assign(&mut self, other: &Self, m: Mod) {
356 *self = mask_double_word(self.overflowing_mul(*other).0, m);
357 }
358
359 const MAX_MOD: Mod = Mod::from_words(2);
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use num_traits::Num;

    /// Parses `factor` (decimal), builds a `Mod` from it and checks the
    /// reported bit/byte sizes and that `factor()` returns the same value.
    fn do_test_mod(factor: &str, bits: u32, bytes: u32) {
        let factor = BigUint::from_str_radix(factor, 10).unwrap();
        let m = Mod::from_factor(&factor);
        assert_eq!(m.bytes(), bytes);
        assert_eq!(m.bits(), bits);
        assert_eq!(m.factor(), factor);
    }

    #[test]
    fn test_mod() {
        // The factors are 2^1, 2^32, 2^64, 2^128, 2^256 and 2^512.
        do_test_mod("2", 1, 1);
        do_test_mod("4294967296", 32, 4);
        do_test_mod("18446744073709551616", 64, 8);
        do_test_mod("340282366920938463463374607431768211456", 128, 16);
        do_test_mod(
            "115792089237316195423570985008687907853269984665640564039457584007913129639936",
            256,
            32,
        );
        do_test_mod(
            "13407807929942597099574024998205846127479365820592393377723561443721764030073546976801874298166903427690031858186486050853753882811946569946433649006084096",
            512,
            64,
        );
    }

    // A one-limb `ArrayCoef` carries no overhead beyond a bare `Word`.
    #[test]
    fn test_sizes() {
        assert_eq!(
            std::mem::size_of::<ArrayCoef::<1>>(),
            std::mem::size_of::<Word>()
        )
    }

    // 2 * 2^63 = 2^64, which is 0 modulo 2^64.
    #[test]
    fn test_simple_coef_mod_64_bits_1_word() {
        let m = Mod::from_bits(64);
        let mut a = ArrayCoef::<1>::from_u64(2, m);
        let b = ArrayCoef::<1>::from_u64(1u64 << 63, m);
        a.mul_assign(&b, m);
        assert!(a.is_zero(), "{a:?}");
    }

    // Same product as above, using the plain `u64` implementation.
    #[test]
    fn test_simple_coef_mod_64_u64() {
        let m = Mod::from_bits(64);
        let mut a = u64::from_i64(2, m);
        let b = u64::from_i64((1u64 << 63) as i64, m);
        a.mul_assign(&b, m);
        assert!(a.is_zero(), "{a:?}");
    }

    // Same product with a two-limb array: the carry into limb 1 must be masked off.
    #[test]
    fn test_simple_coef_mod_64_bits_2_word() {
        let m = Mod::from_bits(64);
        let mut a = ArrayCoef::<2>::from_u64(2, m);
        let b = ArrayCoef::<2>::from_u64(1u64 << 63, m);
        a.mul_assign(&b, m);
        assert!(a.is_zero());
    }

    // -1 + 1 = 0 modulo 2^128; exercises negation in `from_big` and carry
    // propagation across both limbs.
    #[test]
    fn test_simple_coef_mod_128_bits_2_word() {
        let m = Mod::from_bits(128);
        let mut a = ArrayCoef::<2>::from_big(&BigInt::from_str_radix("-1", 10).unwrap(), m);
        let old_a = a.clone();
        let one = ArrayCoef::<2>::from_u64(1, m);
        a.add_assign(&one, m);
        assert!(a.is_zero(), "{old_a} + {one} = {a}");
    }

    // Same as above, using the plain `u128` implementation.
    #[test]
    fn test_simple_coef_mod_128_bits_u128() {
        let m = Mod::from_bits(128);
        let mut a = u128::from_big(&BigInt::from_str_radix("-1", 10).unwrap(), m);
        let old_a = a;
        let one = u128::from_i64(1, m);
        a.add_assign(&one, m);
        assert!(a.is_zero(), "{old_a} + {one} = {a}");
    }

    // (0xff << 120) * 2 carries one bit from limb 1 into limb 2; checked in
    // both operand orders.
    #[test]
    fn test_mul_256() {
        let m = Mod::from_bits(256);
        let a = ArrayCoef::<4>::from_words(&[0, 0xff << (Word::BITS - 8), 0, 0], m);
        let b = ArrayCoef::<4>::from_u64(2, m);
        let expect = ArrayCoef::<4>::from_words(&[0, 0xfe << (Word::BITS - 8), 1, 0], m);
        let mut res = a.clone();
        res.mul_assign(&b, m);
        assert_eq!(res, expect);

        let mut res2 = b.clone();
        res2.mul_assign(&a, m);
        assert_eq!(res2, expect);
    }
}