// polysub/coef.rs
// Copyright 2025-2026 Cornell University
// released under MIT license
// author: Kevin Laeufer <laeufer@cornell.edu>

//! Coefficient Library
6
7use num_bigint::{BigInt, BigUint, Sign};
8use std::fmt::{Display, Formatter};
9use std::ops::Shl;
10
11/// The mod factor that we are using.
12#[derive(Debug, Copy, Clone, Eq, PartialEq)]
13pub struct Mod {
14    bits: u32,
15    words: u32,
16    msb_mask: Word,
17}
18
19impl Mod {
20    #[inline]
21    pub const fn from_words(words: usize) -> Self {
22        Self::from_bits(words as u32 * Word::BITS)
23    }
24
25    #[inline]
26    pub fn from_factor(factor: &BigUint) -> Self {
27        let ones = factor.count_ones();
28        let is_power_of_two = ones == 1;
29        assert!(is_power_of_two);
30        let bits = factor.trailing_zeros().unwrap() as u32;
31        Self::from_bits(bits)
32    }
33
34    #[inline]
35    pub const fn from_bits(bits: u32) -> Self {
36        let words = bits.div_ceil(Word::BITS);
37        let msb_mask = if bits.is_multiple_of(Word::BITS) {
38            Word::MAX
39        } else {
40            ((1 as Word) << (bits % Word::BITS)) - 1
41        };
42        Self {
43            bits,
44            words,
45            msb_mask,
46        }
47    }
48
49    #[inline]
50    pub fn bits(&self) -> u32 {
51        self.bits
52    }
53
54    /// number of bytes required to represent all bits
55    #[inline]
56    pub fn bytes(&self) -> u32 {
57        self.bits().div_ceil(8)
58    }
59
60    pub fn factor(&self) -> BigUint {
61        BigUint::from(1u32).shl(self.bits() as usize)
62    }
63
64    #[inline]
65    fn words(&self) -> usize {
66        self.words as usize
67    }
68
69    #[inline]
70    fn msb_mask(&self) -> Word {
71        self.msb_mask
72    }
73}
74
75/// A coefficient representation that can be used to perform variable substitution.
76pub trait Coef: Clone {
77    fn from_big(value: &BigInt, m: Mod) -> Self;
78    fn from_i64(v: i64, m: Mod) -> Self;
79    fn pow2(e: u32, m: Mod) -> Self;
80    fn zero() -> Self;
81    fn is_zero(&self) -> bool;
82    fn assign_zero(&mut self);
83    fn add_assign(&mut self, other: &Self, m: Mod);
84    fn mul_assign(&mut self, other: &Self, m: Mod);
85    const MAX_MOD: Mod;
86}
87
/// Machine word used as the limb of multi-word coefficients.
type Word = u64;
/// Integer twice as wide as [`Word`], used for carries and word-by-word products.
type DoubleWord = u128;
90
91/// Custom implementation of a fixed size coefficient using a static sized array.
92#[derive(Debug, Clone, PartialEq)]
93pub struct ArrayCoef<const W: usize> {
94    /// Contain the actual number in big endian
95    words: [Word; W],
96}
97
98// src: https://github.com/cucapra/baa/blob/1840f13d9a6bd9fa60de5b7a6326e96059d56b70/src/bv/arithmetic.rs#L204C1-L210C2
99#[inline]
100fn adc(carry: u8, a: &mut Word, b: Word) -> u8 {
101    let sum = carry as DoubleWord + *a as DoubleWord + b as DoubleWord;
102    *a = sum as Word;
103    (sum >> Word::BITS) as u8
104}
105
106#[inline]
107fn mul<const W: usize>(a: &mut [Word; W], b: &[Word; W]) {
108    debug_assert_eq!(a.len(), b.len());
109    let mut acc = [0 as Word; W];
110    for (i, bi) in b.iter().enumerate() {
111        mac_word(&mut acc[i..], a, *bi);
112    }
113
114    // acc contains the result
115    a.copy_from_slice(&acc);
116}
117
118#[inline]
119fn mac_word(acc: &mut [Word], b: &[Word], word: Word) {
120    let mut carry = 0;
121    for (a, b) in acc.iter_mut().zip(b) {
122        *a = mac_with_carry(*a, *b, word, &mut carry);
123    }
124}
125
126#[inline]
127fn mac_with_carry(a: Word, b: Word, c: Word, acc: &mut DoubleWord) -> Word {
128    *acc += a as DoubleWord;
129    *acc += (b as DoubleWord) * (c as DoubleWord);
130    let lo = *acc as Word;
131    *acc >>= Word::BITS;
132    lo
133}
134
135fn words_to_u32(words: &[Word]) -> Vec<u32> {
136    debug_assert_eq!(u32::BITS * 2, Word::BITS);
137    let mut words32 = Vec::with_capacity(words.len() * 2);
138    let mask32 = u32::MAX as Word;
139    for w in words.iter() {
140        let word = *w;
141        let lsb = (word & mask32) as u32;
142        let msb = ((word >> 32) & mask32) as u32;
143        words32.push(lsb);
144        words32.push(msb);
145    }
146    words32
147}
148
149impl<const W: usize> ArrayCoef<W> {
150    const MAX_BYTES: u32 = W as u32 * Word::BITS / 8;
151    fn to_ubig(&self) -> BigUint {
152        BigUint::from_slice(&words_to_u32(&self.words))
153    }
154
155    #[cfg(test)]
156    fn from_words(words_in: &[Word], m: Mod) -> Self {
157        debug_assert_eq!(words_in.len(), W);
158        let mut words = [0 as Word; W];
159        words.as_mut_slice().copy_from_slice(words_in);
160        let mut r = Self { words };
161        r.do_mask(m);
162        r
163    }
164
165    #[inline]
166    fn do_mask(&mut self, m: Mod) {
167        // mask out upper words
168        for w in self.words.iter_mut().skip(m.words()) {
169            *w = 0;
170        }
171        self.words[m.words() - 1] &= m.msb_mask();
172    }
173
174    fn negate(&mut self, m: Mod) {
175        // invert all
176        for ii in 0..W {
177            self.words[ii] = !self.words[ii];
178        }
179        // add one
180        let mut carry = adc(0, &mut self.words[0], 1);
181        for ii in 1..W {
182            carry = adc(carry, &mut self.words[ii], 0);
183        }
184        self.do_mask(m);
185    }
186
187    fn from_ubig(value: &BigUint, m: Mod) -> Self {
188        // iter_u64 returns lsb first which matches our convention
189        let digits: Vec<Word> = value.iter_u64_digits().collect();
190        let mut words = [0; W];
191        words[0..digits.len()].copy_from_slice(&digits);
192        let mut r = Self { words };
193        r.do_mask(m);
194        r
195    }
196
197    #[inline]
198    fn from_u64(v: u64, m: Mod) -> Self {
199        debug_assert!(m.bytes() <= Self::MAX_BYTES);
200        debug_assert!(Self::MAX_BYTES * 8 >= u64::BITS);
201        let mut r = Self::zero();
202        r.words[0] = v as Word;
203        r.do_mask(m);
204        r
205    }
206}
207
208impl<const W: usize> Coef for ArrayCoef<W> {
209    fn from_big(value: &BigInt, m: Mod) -> Self {
210        let is_negative = value.sign() == Sign::Minus;
211        let mut r = Self::from_ubig(value.magnitude(), m);
212        if is_negative {
213            r.negate(m);
214        }
215        r
216    }
217
218    #[inline]
219    fn from_i64(v: i64, m: Mod) -> Self {
220        Self::from_u64(v as u64, m)
221    }
222
223    fn pow2(e: u32, m: Mod) -> Self {
224        let mut r = Self::zero();
225        if m.bits() > e {
226            let word_ii = e / Word::BITS;
227            r.words[word_ii as usize] = (1 as Word) << (e % Word::BITS);
228        }
229        r
230    }
231
232    #[inline]
233    fn zero() -> Self {
234        Self { words: [0; W] }
235    }
236
237    fn is_zero(&self) -> bool {
238        self.words.iter().all(|w| *w == 0)
239    }
240
241    fn assign_zero(&mut self) {
242        for w in self.words.iter_mut() {
243            *w = 0;
244        }
245    }
246
247    fn add_assign(&mut self, other: &Self, m: Mod) {
248        debug_assert!(m.bytes() <= Self::MAX_BYTES);
249        let mut carry = 0;
250        for ii in 0..W {
251            carry = adc(carry, &mut self.words[ii], other.words[ii]);
252        }
253        self.do_mask(m);
254    }
255
256    fn mul_assign(&mut self, other: &Self, m: Mod) {
257        debug_assert!(m.bytes() <= Self::MAX_BYTES);
258        mul(&mut self.words, &other.words);
259        self.do_mask(m);
260    }
261
262    const MAX_MOD: Mod = Mod::from_words(W);
263}
264
265impl<const W: usize> Display for ArrayCoef<W> {
266    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
267        write!(f, "{}", self.to_ubig())
268    }
269}
270
271impl Coef for Word {
272    fn from_big(value: &BigInt, m: Mod) -> Self {
273        Self::from_i64(value.try_into().unwrap(), m)
274    }
275
276    fn from_i64(v: i64, m: Mod) -> Self {
277        v as Word & m.msb_mask
278    }
279
280    fn pow2(e: u32, m: Mod) -> Self {
281        debug_assert!(m.words == 1);
282        if e < m.bits() { (1 as Word) << e } else { 0 }
283    }
284
285    fn zero() -> Self {
286        0
287    }
288
289    fn is_zero(&self) -> bool {
290        *self == 0
291    }
292
293    fn assign_zero(&mut self) {
294        *self = 0;
295    }
296
297    fn add_assign(&mut self, other: &Self, m: Mod) {
298        *self = self.overflowing_add(*other).0 & m.msb_mask;
299    }
300
301    fn mul_assign(&mut self, other: &Self, m: Mod) {
302        *self = self.overflowing_mul(*other).0 & m.msb_mask;
303    }
304
305    const MAX_MOD: Mod = Mod::from_words(1);
306}
307
308#[inline]
309fn mask_double_word(v: DoubleWord, m: Mod) -> DoubleWord {
310    match m.words {
311        0 => 0,
312        1 => v & m.msb_mask as DoubleWord,
313        2 => v & (((m.msb_mask as DoubleWord) << Word::BITS) | Word::MAX as DoubleWord),
314        _ => unreachable!("u128 can only be used to represent up to two words!"),
315    }
316}
317
318type SignedDoubleWord = i128;
319
320impl Coef for DoubleWord {
321    fn from_big(value: &BigInt, m: Mod) -> Self {
322        let r: SignedDoubleWord = value.try_into().unwrap();
323        mask_double_word(r as DoubleWord, m)
324    }
325
326    fn from_i64(v: i64, m: Mod) -> Self {
327        mask_double_word(v as DoubleWord, m)
328    }
329
330    fn pow2(e: u32, m: Mod) -> Self {
331        debug_assert!(m.words <= 2);
332        if e < m.bits() {
333            (1 as DoubleWord) << e
334        } else {
335            0
336        }
337    }
338
339    fn zero() -> Self {
340        0
341    }
342
343    fn is_zero(&self) -> bool {
344        *self == 0
345    }
346
347    fn assign_zero(&mut self) {
348        *self = 0;
349    }
350
351    fn add_assign(&mut self, other: &Self, m: Mod) {
352        *self = mask_double_word(self.overflowing_add(*other).0, m);
353    }
354
355    fn mul_assign(&mut self, other: &Self, m: Mod) {
356        *self = mask_double_word(self.overflowing_mul(*other).0, m);
357    }
358
359    const MAX_MOD: Mod = Mod::from_words(2);
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use num_traits::Num;

    /// Parses a decimal modulus and checks the `Mod` derived from it.
    fn do_test_mod(factor: &str, bits: u32, bytes: u32) {
        let expected = BigUint::from_str_radix(factor, 10).unwrap();
        let m = Mod::from_factor(&expected);
        assert_eq!(m.bytes(), bytes);
        assert_eq!(m.bits(), bits);
        assert_eq!(m.factor(), expected);
    }

    #[test]
    fn test_mod() {
        // (decimal factor, expected bits, expected bytes)
        let cases = [
            ("2", 1, 1),
            ("4294967296", 32, 4),
            ("18446744073709551616", 64, 8),
            ("340282366920938463463374607431768211456", 128, 16),
            (
                "115792089237316195423570985008687907853269984665640564039457584007913129639936",
                256,
                32,
            ),
            (
                "13407807929942597099574024998205846127479365820592393377723561443721764030073546976801874298166903427690031858186486050853753882811946569946433649006084096",
                512,
                64,
            ),
        ];
        for &(factor, bits, bytes) in &cases {
            do_test_mod(factor, bits, bytes);
        }
    }

    #[test]
    fn test_sizes() {
        let coef_bytes = std::mem::size_of::<ArrayCoef<1>>();
        assert_eq!(coef_bytes, std::mem::size_of::<Word>());
    }

    #[test]
    fn test_simple_coef_mod_64_bits_1_word() {
        let m = Mod::from_bits(64);
        let b = ArrayCoef::<1>::from_u64(1u64 << 63, m);
        let mut a = ArrayCoef::<1>::from_u64(2, m);
        a.mul_assign(&b, m);
        assert!(a.is_zero(), "{a:?}");
    }

    #[test]
    fn test_simple_coef_mod_64_u64() {
        let m = Mod::from_bits(64);
        let b = u64::from_i64((1u64 << 63) as i64, m);
        let mut a = u64::from_i64(2, m);
        a.mul_assign(&b, m);
        assert!(a.is_zero(), "{a:?}");
    }

    #[test]
    fn test_simple_coef_mod_64_bits_2_word() {
        let m = Mod::from_bits(64);
        let b = ArrayCoef::<2>::from_u64(1u64 << 63, m);
        let mut a = ArrayCoef::<2>::from_u64(2, m);
        a.mul_assign(&b, m);
        assert!(a.is_zero());
    }

    #[test]
    fn test_simple_coef_mod_128_bits_2_word() {
        let m = Mod::from_bits(128);
        let minus_one = BigInt::from(-1i64);
        let mut a = ArrayCoef::<2>::from_big(&minus_one, m);
        let old_a = a.clone();
        let one = ArrayCoef::<2>::from_u64(1, m);
        a.add_assign(&one, m);
        assert!(a.is_zero(), "{old_a} + {one} = {a}");
    }

    #[test]
    fn test_simple_coef_mod_128_bits_u128() {
        let m = Mod::from_bits(128);
        let minus_one = BigInt::from(-1i64);
        let mut a = u128::from_big(&minus_one, m);
        let old_a = a;
        let one = u128::from_i64(1, m);
        a.add_assign(&one, m);
        assert!(a.is_zero(), "{old_a} + {one} = {a}");
    }

    #[test]
    fn test_mul_256() {
        let m = Mod::from_bits(256);
        let top_byte = Word::BITS - 8;
        let a = ArrayCoef::<4>::from_words(&[0, 0xff << top_byte, 0, 0], m);
        let b = ArrayCoef::<4>::from_u64(2, m);
        let expect = ArrayCoef::<4>::from_words(&[0, 0xfe << top_byte, 1, 0], m);

        let mut res = a.clone();
        res.mul_assign(&b, m);
        assert_eq!(res, expect);

        // multiplication must commute
        let mut res2 = b.clone();
        res2.mul_assign(&a, m);
        assert_eq!(res2, expect);
    }
}