//! `phantom_zone/backend/modulus_u64.rs` — modular arithmetic backend over a
//! non-native `u64` modulus, using Barrett reduction and Shoup multiplication.

1use itertools::izip;
2use num_traits::WrappingMul;
3
4use super::{
5    ArithmeticLazyOps, ArithmeticOps, GetModulus, ModInit, Modulus, ShoupMatrixFMA, VectorOps,
6};
7use crate::RowMut;
8
/// Modular arithmetic backend over a non-native `u64` modulus `q`,
/// with precomputed Barrett-reduction constants (see `ModInit::new`).
pub struct ModularOpsU64<T> {
    // The modulus q itself (largest representable value of `modulus` plus 1).
    q: u64,
    // Precomputed 2q, used by the lazy ([0, 2q)-range) operations.
    q_twice: u64,
    // Bit length used for the Barrett shifts; computed as
    // 64 - (q + 1).leading_zeros() in `ModInit::new`.
    logq: usize,
    // Barrett constant mu = floor(2^(2*logq + 3) / q).
    barrett_mu: u128,
    // Barrett shift exponent alpha = logq + 3.
    barrett_alpha: usize,
    // The modulus descriptor this backend was constructed from.
    modulus: T,
}
17
18impl<T> ModInit for ModularOpsU64<T>
19where
20    T: Modulus<Element = u64>,
21{
22    type M = T;
23    fn new(modulus: Self::M) -> ModularOpsU64<T> {
24        assert!(!modulus.is_native());
25
26        // largest unsigned value modulus fits is modulus-1
27        let q = modulus.largest_unsigned_value() + 1;
28        let logq = 64 - (q + 1u64).leading_zeros();
29
30        // barrett calculation
31        let mu = (1u128 << (logq * 2 + 3)) / (q as u128);
32        let alpha = logq + 3;
33
34        ModularOpsU64 {
35            q,
36            q_twice: q << 1,
37            logq: logq as usize,
38            barrett_alpha: alpha as usize,
39            barrett_mu: mu,
40            modulus,
41        }
42    }
43}
44
45impl<T> ModularOpsU64<T> {
46    fn add_mod_fast(&self, a: u64, b: u64) -> u64 {
47        debug_assert!(a < self.q);
48        debug_assert!(b < self.q);
49
50        let mut o = a + b;
51        if o >= self.q {
52            o -= self.q;
53        }
54        o
55    }
56
57    fn add_mod_fast_lazy(&self, a: u64, b: u64) -> u64 {
58        debug_assert!(a < self.q_twice);
59        debug_assert!(b < self.q_twice);
60
61        let mut o = a + b;
62        if o >= self.q_twice {
63            o -= self.q_twice;
64        }
65        o
66    }
67
68    fn sub_mod_fast(&self, a: u64, b: u64) -> u64 {
69        debug_assert!(a < self.q);
70        debug_assert!(b < self.q);
71
72        if a >= b {
73            a - b
74        } else {
75            (self.q + a) - b
76        }
77    }
78
79    // returns (a * b)  % q
80    ///
81    /// - both a and b must be in range [0, 2q)
82    /// - output is in range [0 , 2q)
83    fn mul_mod_fast_lazy(&self, a: u64, b: u64) -> u64 {
84        debug_assert!(a < 2 * self.q);
85        debug_assert!(b < 2 * self.q);
86
87        let ab = a as u128 * b as u128;
88
89        // ab / (2^{n + \beta})
90        // note: \beta is assumed to -2
91        let tmp = ab >> (self.logq - 2);
92
93        // k = ((ab / (2^{n + \beta})) * \mu) / 2^{\alpha - (-2)}
94        let k = (tmp * self.barrett_mu) >> (self.barrett_alpha + 2);
95
96        // ab - k*p
97        let tmp = k * (self.q as u128);
98
99        (ab - tmp) as u64
100    }
101
102    /// returns (a * b)  % q
103    ///
104    /// - both a and b must be in range [0, 2q)
105    /// - output is in range [0 , q)
106    fn mul_mod_fast(&self, a: u64, b: u64) -> u64 {
107        debug_assert!(a < 2 * self.q);
108        debug_assert!(b < 2 * self.q);
109
110        let ab = a as u128 * b as u128;
111
112        // ab / (2^{n + \beta})
113        // note: \beta is assumed to -2
114        let tmp = ab >> (self.logq - 2);
115
116        // k = ((ab / (2^{n + \beta})) * \mu) / 2^{\alpha - (-2)}
117        let k = (tmp * self.barrett_mu) >> (self.barrett_alpha + 2);
118
119        // ab - k*p
120        let tmp = k * (self.q as u128);
121
122        let mut out = (ab - tmp) as u64;
123
124        if out >= self.q {
125            out -= self.q;
126        }
127
128        return out;
129    }
130}
131
132impl<T> ArithmeticOps for ModularOpsU64<T> {
133    type Element = u64;
134
135    fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
136        self.add_mod_fast(*a, *b)
137    }
138
139    fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
140        self.mul_mod_fast(*a, *b)
141    }
142
143    fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
144        self.sub_mod_fast(*a, *b)
145    }
146
147    fn neg(&self, a: &Self::Element) -> Self::Element {
148        self.q - *a
149    }
150
151    // fn modulus(&self) -> Self::Element {
152    //     self.q
153    // }
154}
155
/// Lazy arithmetic: operands and results stay in the extended range `[0, 2q)`,
/// deferring the final reduction to the caller.
impl<T> ArithmeticLazyOps for ModularOpsU64<T> {
    type Element = u64;
    // Lazy addition: inputs and output in [0, 2q).
    fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
        self.add_mod_fast_lazy(*a, *b)
    }
    // Lazy Barrett multiplication: inputs and output in [0, 2q).
    fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
        self.mul_mod_fast_lazy(*a, *b)
    }
}
165
166impl<T> VectorOps for ModularOpsU64<T> {
167    type Element = u64;
168
169    fn elwise_add_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) {
170        izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
171            *ai = self.add_mod_fast(*ai, *bi);
172        });
173    }
174
175    fn elwise_sub_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) {
176        izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
177            *ai = self.sub_mod_fast(*ai, *bi);
178        });
179    }
180
181    fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) {
182        izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
183            *ai = self.mul_mod_fast(*ai, *bi);
184        });
185    }
186
187    fn elwise_neg_mut(&self, a: &mut [Self::Element]) {
188        a.iter_mut().for_each(|ai| *ai = self.q - *ai);
189    }
190
191    fn elwise_scalar_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &Self::Element) {
192        izip!(out.iter_mut(), a.iter()).for_each(|(oi, ai)| {
193            *oi = self.mul_mod_fast(*ai, *b);
194        });
195    }
196
197    fn elwise_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &[Self::Element]) {
198        izip!(out.iter_mut(), a.iter(), b.iter()).for_each(|(oi, ai, bi)| {
199            *oi = self.mul_mod_fast(*ai, *bi);
200        });
201    }
202
203    fn elwise_scalar_mul_mut(&self, a: &mut [Self::Element], b: &Self::Element) {
204        a.iter_mut().for_each(|ai| {
205            *ai = self.mul_mod_fast(*ai, *b);
206        });
207    }
208
209    fn elwise_fma_mut(&self, a: &mut [Self::Element], b: &[Self::Element], c: &[Self::Element]) {
210        izip!(a.iter_mut(), b.iter(), c.iter()).for_each(|(ai, bi, ci)| {
211            *ai = self.add_mod_fast(*ai, self.mul_mod_fast(*bi, *ci));
212        });
213    }
214
215    fn elwise_fma_scalar_mut(
216        &self,
217        a: &mut [Self::Element],
218        b: &[Self::Element],
219        c: &Self::Element,
220    ) {
221        izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
222            *ai = self.add_mod_fast(*ai, self.mul_mod_fast(*bi, *c));
223        });
224    }
225
226    // fn modulus(&self) -> Self::Element {
227    //     self.q
228    // }
229}
230
231impl<R: RowMut<Element = u64>, T> ShoupMatrixFMA<R> for ModularOpsU64<T> {
232    fn shoup_matrix_fma(&self, out: &mut [R::Element], a: &[R], a_shoup: &[R], b: &[R]) {
233        assert!(a.len() == a_shoup.len());
234        assert!(
235            a.len() == b.len(),
236            "Unequal length {}!={}",
237            a.len(),
238            b.len()
239        );
240
241        let q = self.q;
242        let q_twice = self.q << 1;
243
244        izip!(a.iter(), a_shoup.iter(), b.iter()).for_each(|(a_row, a_shoup_row, b_row)| {
245            izip!(
246                out.as_mut().iter_mut(),
247                a_row.as_ref().iter(),
248                a_shoup_row.as_ref().iter(),
249                b_row.as_ref().iter()
250            )
251            .for_each(|(o, a0, a0_shoup, b0)| {
252                let quotient = ((*a0_shoup as u128 * *b0 as u128) >> 64) as u64;
253                let mut v = (a0.wrapping_mul(b0)).wrapping_add(*o);
254                v = v.wrapping_sub(q.wrapping_mul(quotient));
255
256                if v >= q_twice {
257                    v -= q_twice;
258                }
259
260                *o = v;
261            });
262        });
263    }
264}
265
impl<T> GetModulus for ModularOpsU64<T>
where
    T: Modulus,
{
    type Element = T::Element;
    type M = T;
    /// Returns the modulus descriptor this backend was constructed with.
    fn modulus(&self) -> &Self::M {
        &self.modulus
    }
}
276
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::Itertools;
    use rand::{thread_rng, Rng};
    use rand_distr::Uniform;

    /// Checks `shoup_matrix_fma` against a naive u128 multiply-accumulate
    /// modulo `prime` on random matrices.
    #[test]
    fn fma() {
        let mut rng = thread_rng();
        // 55-bit prime modulus.
        let prime = 36028797017456641;
        let ring_size = 1 << 3;

        let dist = Uniform::new(0, prime);
        // Number of rows per matrix.
        let d = 2;
        let a0_matrix = (0..d)
            .into_iter()
            .map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
            .collect_vec();
        // a0 in shoup representation
        let a0_shoup_matrix = a0_matrix
            .iter()
            .map(|r| {
                r.iter()
                    .map(|v| {
                        // $(v * 2^{\beta}) / p$
                        ((*v as u128 * (1u128 << 64)) / prime as u128) as u64
                    })
                    .collect_vec()
            })
            .collect_vec();
        let a1_matrix = (0..d)
            .into_iter()
            .map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
            .collect_vec();

        let modop = ModularOpsU64::new(prime);

        let mut out_shoup_fma_lazy = vec![0u64; ring_size];
        modop.shoup_matrix_fma(
            &mut out_shoup_fma_lazy,
            &a0_matrix,
            &a0_shoup_matrix,
            &a1_matrix,
        );
        // Lazy output is in [0, 2*prime); reduce once into [0, prime).
        let out_shoup_fma = out_shoup_fma_lazy
            .iter()
            .map(|v| if *v >= prime { v - prime } else { *v })
            .collect_vec();

        // expected
        let mut out_expected = vec![0u64; ring_size];
        izip!(a0_matrix.iter(), a1_matrix.iter()).for_each(|(a_r, b_r)| {
            izip!(out_expected.iter_mut(), a_r.iter(), b_r.iter()).for_each(|(o, a0, a1)| {
                *o = (*o + ((*a0 as u128 * *a1 as u128) % prime as u128) as u64) % prime;
            });
        });

        assert_eq!(out_expected, out_shoup_fma);
    }
}