1use itertools::izip;
2use num_traits::WrappingMul;
3
4use super::{
5 ArithmeticLazyOps, ArithmeticOps, GetModulus, ModInit, Modulus, ShoupMatrixFMA, VectorOps,
6};
7use crate::RowMut;
8
/// Modular arithmetic over `Z_q` for `u64` elements, with precomputed
/// constants for Barrett reduction.
pub struct ModularOpsU64<T> {
    /// The modulus `q` (asserted non-native, i.e. not 2^64, in `new`).
    q: u64,
    /// Precomputed `2 * q`; bound used by the lazy `[0, 2q)` operations.
    q_twice: u64,
    /// Bit-width parameter derived from `q` in `new`; drives the Barrett shifts.
    logq: usize,
    /// Barrett constant `mu = floor(2^(2*logq + 3) / q)`.
    barrett_mu: u128,
    /// Barrett shift exponent `alpha = logq + 3`.
    barrett_alpha: usize,
    /// The modulus descriptor this operator was constructed from.
    modulus: T,
}
17
impl<T> ModInit for ModularOpsU64<T>
where
    T: Modulus<Element = u64>,
{
    type M = T;
    /// Builds the operator for a non-native modulus, precomputing the Barrett
    /// constants `mu = floor(2^(2*logq + 3) / q)` and `alpha = logq + 3`.
    ///
    /// # Panics
    /// Panics if `modulus` is the native (2^64) modulus.
    // NOTE(review): `1u128 << (logq * 2 + 3)` panics in debug if logq > 62;
    // presumably q is assumed well below 2^62 — confirm at call sites.
    fn new(modulus: Self::M) -> ModularOpsU64<T> {
        assert!(!modulus.is_native());

        // The modulus is one past the largest representable value.
        let q = modulus.largest_unsigned_value() + 1;
        // Bit-length of q+1; using q+1 keeps logq >= ceil(log2(q)) even for
        // q of the form 2^k - 1.
        let logq = 64 - (q + 1u64).leading_zeros();

        // Barrett precomputation: mu = floor(2^(2*logq + 3) / q).
        let mu = (1u128 << (logq * 2 + 3)) / (q as u128);
        let alpha = logq + 3;

        ModularOpsU64 {
            q,
            q_twice: q << 1,
            logq: logq as usize,
            barrett_alpha: alpha as usize,
            barrett_mu: mu,
            modulus,
        }
    }
}
44
45impl<T> ModularOpsU64<T> {
46 fn add_mod_fast(&self, a: u64, b: u64) -> u64 {
47 debug_assert!(a < self.q);
48 debug_assert!(b < self.q);
49
50 let mut o = a + b;
51 if o >= self.q {
52 o -= self.q;
53 }
54 o
55 }
56
57 fn add_mod_fast_lazy(&self, a: u64, b: u64) -> u64 {
58 debug_assert!(a < self.q_twice);
59 debug_assert!(b < self.q_twice);
60
61 let mut o = a + b;
62 if o >= self.q_twice {
63 o -= self.q_twice;
64 }
65 o
66 }
67
68 fn sub_mod_fast(&self, a: u64, b: u64) -> u64 {
69 debug_assert!(a < self.q);
70 debug_assert!(b < self.q);
71
72 if a >= b {
73 a - b
74 } else {
75 (self.q + a) - b
76 }
77 }
78
79 fn mul_mod_fast_lazy(&self, a: u64, b: u64) -> u64 {
84 debug_assert!(a < 2 * self.q);
85 debug_assert!(b < 2 * self.q);
86
87 let ab = a as u128 * b as u128;
88
89 let tmp = ab >> (self.logq - 2);
92
93 let k = (tmp * self.barrett_mu) >> (self.barrett_alpha + 2);
95
96 let tmp = k * (self.q as u128);
98
99 (ab - tmp) as u64
100 }
101
102 fn mul_mod_fast(&self, a: u64, b: u64) -> u64 {
107 debug_assert!(a < 2 * self.q);
108 debug_assert!(b < 2 * self.q);
109
110 let ab = a as u128 * b as u128;
111
112 let tmp = ab >> (self.logq - 2);
115
116 let k = (tmp * self.barrett_mu) >> (self.barrett_alpha + 2);
118
119 let tmp = k * (self.q as u128);
121
122 let mut out = (ab - tmp) as u64;
123
124 if out >= self.q {
125 out -= self.q;
126 }
127
128 return out;
129 }
130}
131
132impl<T> ArithmeticOps for ModularOpsU64<T> {
133 type Element = u64;
134
135 fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
136 self.add_mod_fast(*a, *b)
137 }
138
139 fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
140 self.mul_mod_fast(*a, *b)
141 }
142
143 fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
144 self.sub_mod_fast(*a, *b)
145 }
146
147 fn neg(&self, a: &Self::Element) -> Self::Element {
148 self.q - *a
149 }
150
151 }
155
156impl<T> ArithmeticLazyOps for ModularOpsU64<T> {
157 type Element = u64;
158 fn add_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
159 self.add_mod_fast_lazy(*a, *b)
160 }
161 fn mul_lazy(&self, a: &Self::Element, b: &Self::Element) -> Self::Element {
162 self.mul_mod_fast_lazy(*a, *b)
163 }
164}
165
166impl<T> VectorOps for ModularOpsU64<T> {
167 type Element = u64;
168
169 fn elwise_add_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) {
170 izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
171 *ai = self.add_mod_fast(*ai, *bi);
172 });
173 }
174
175 fn elwise_sub_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) {
176 izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
177 *ai = self.sub_mod_fast(*ai, *bi);
178 });
179 }
180
181 fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) {
182 izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
183 *ai = self.mul_mod_fast(*ai, *bi);
184 });
185 }
186
187 fn elwise_neg_mut(&self, a: &mut [Self::Element]) {
188 a.iter_mut().for_each(|ai| *ai = self.q - *ai);
189 }
190
191 fn elwise_scalar_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &Self::Element) {
192 izip!(out.iter_mut(), a.iter()).for_each(|(oi, ai)| {
193 *oi = self.mul_mod_fast(*ai, *b);
194 });
195 }
196
197 fn elwise_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &[Self::Element]) {
198 izip!(out.iter_mut(), a.iter(), b.iter()).for_each(|(oi, ai, bi)| {
199 *oi = self.mul_mod_fast(*ai, *bi);
200 });
201 }
202
203 fn elwise_scalar_mul_mut(&self, a: &mut [Self::Element], b: &Self::Element) {
204 a.iter_mut().for_each(|ai| {
205 *ai = self.mul_mod_fast(*ai, *b);
206 });
207 }
208
209 fn elwise_fma_mut(&self, a: &mut [Self::Element], b: &[Self::Element], c: &[Self::Element]) {
210 izip!(a.iter_mut(), b.iter(), c.iter()).for_each(|(ai, bi, ci)| {
211 *ai = self.add_mod_fast(*ai, self.mul_mod_fast(*bi, *ci));
212 });
213 }
214
215 fn elwise_fma_scalar_mut(
216 &self,
217 a: &mut [Self::Element],
218 b: &[Self::Element],
219 c: &Self::Element,
220 ) {
221 izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| {
222 *ai = self.add_mod_fast(*ai, self.mul_mod_fast(*bi, *c));
223 });
224 }
225
226 }
230
impl<R: RowMut<Element = u64>, T> ShoupMatrixFMA<R> for ModularOpsU64<T> {
    /// Accumulates `out[j] += sum_i a[i][j] * b[i][j]` using Shoup
    /// multiplication, leaving `out` lazily reduced in `[0, 2q)`.
    ///
    /// `a_shoup` must be the Shoup representation of `a`, i.e.
    /// `a_shoup[i][j] = floor(a[i][j] * 2^64 / q)` (see the `fma` test below).
    ///
    /// # Panics
    /// Panics when `a`, `a_shoup`, and `b` differ in row count.
    fn shoup_matrix_fma(&self, out: &mut [R::Element], a: &[R], a_shoup: &[R], b: &[R]) {
        assert!(a.len() == a_shoup.len());
        assert!(
            a.len() == b.len(),
            "Unequal length {}!={}",
            a.len(),
            b.len()
        );

        let q = self.q;
        let q_twice = self.q << 1;

        izip!(a.iter(), a_shoup.iter(), b.iter()).for_each(|(a_row, a_shoup_row, b_row)| {
            izip!(
                out.as_mut().iter_mut(),
                a_row.as_ref().iter(),
                a_shoup_row.as_ref().iter(),
                b_row.as_ref().iter()
            )
            .for_each(|(o, a0, a0_shoup, b0)| {
                // Shoup trick: high 64 bits of a0_shoup * b0 approximate
                // floor(a0 * b0 / q) with a single wide multiply.
                let quotient = ((*a0_shoup as u128 * *b0 as u128) >> 64) as u64;
                // a0*b0 + o - q*quotient evaluated mod 2^64; wrapping ops are
                // intentional — the low 64 bits carry the lazily reduced value.
                let mut v = (a0.wrapping_mul(b0)).wrapping_add(*o);
                v = v.wrapping_sub(q.wrapping_mul(quotient));

                // One conditional subtraction keeps the accumulator in [0, 2q).
                if v >= q_twice {
                    v -= q_twice;
                }

                *o = v;
            });
        });
    }
}
265
impl<T> GetModulus for ModularOpsU64<T>
where
    T: Modulus,
{
    type Element = T::Element;
    type M = T;
    /// Borrows the modulus descriptor this operator was constructed from.
    fn modulus(&self) -> &Self::M {
        &self.modulus
    }
}
276
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::Itertools;
    use rand::{thread_rng, Rng};
    use rand_distr::Uniform;

    /// Checks `shoup_matrix_fma` against a naive 128-bit reference
    /// implementation on random matrices.
    #[test]
    fn fma() {
        let mut rng = thread_rng();
        let prime = 36028797017456641u64;
        let ring_size = 1 << 3;

        let dist = Uniform::new(0, prime);
        let d = 2;
        // Random d x ring_size operand matrices. `(0..d)` is already an
        // iterator, so no `.into_iter()` is needed.
        let a0_matrix = (0..d)
            .map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
            .collect_vec();
        // Shoup representation: floor(v * 2^64 / prime), computed in u128.
        let a0_shoup_matrix = a0_matrix
            .iter()
            .map(|r| {
                r.iter()
                    .map(|&v| (((v as u128) << 64) / prime as u128) as u64)
                    .collect_vec()
            })
            .collect_vec();
        let a1_matrix = (0..d)
            .map(|_| (&mut rng).sample_iter(dist).take(ring_size).collect_vec())
            .collect_vec();

        let modop = ModularOpsU64::new(prime);

        // The Shoup FMA leaves its output lazily reduced in [0, 2*prime);
        // fold back into [0, prime) before comparing.
        let mut out_shoup_fma_lazy = vec![0u64; ring_size];
        modop.shoup_matrix_fma(
            &mut out_shoup_fma_lazy,
            &a0_matrix,
            &a0_shoup_matrix,
            &a1_matrix,
        );
        let out_shoup_fma = out_shoup_fma_lazy
            .iter()
            .map(|&v| if v >= prime { v - prime } else { v })
            .collect_vec();

        // Naive reference: accumulate (a0 * a1) mod prime via u128 products.
        let mut out_expected = vec![0u64; ring_size];
        izip!(a0_matrix.iter(), a1_matrix.iter()).for_each(|(a_r, b_r)| {
            izip!(out_expected.iter_mut(), a_r.iter(), b_r.iter()).for_each(|(o, a0, a1)| {
                *o = (*o + ((*a0 as u128 * *a1 as u128) % prime as u128) as u64) % prime;
            });
        });

        assert_eq!(out_expected, out_shoup_fma);
    }
}
337}