poulpy_cpu_ref/reference/ntt120/
mat_vec.rs1use crate::reference::ntt120::primes::PrimeSet;
33
34pub struct BaaMeta<P: PrimeSet> {
42 pub h: u64,
43 pub h_pow_red: [u64; 4], _phantom: std::marker::PhantomData<P>,
45}
46
47impl<P: PrimeSet> BaaMeta<P> {
48 pub fn new() -> Self {
51 const MAX_ELL: f64 = 10_000.0;
52 let ell_bs = MAX_ELL.log2();
53
54 let mut min_res_bs = f64::MAX;
55 let mut min_h = 0u64;
56
57 for h in 1u64..64 {
58 let h_pow2_bs = compute_bit_size_red(h, P::Q);
59 let res_bs = log2_sum_two(h as f64 + ell_bs, (64.0 - h as f64) + ell_bs + h_pow2_bs);
61 if res_bs < min_res_bs {
62 min_res_bs = res_bs;
63 min_h = h;
64 }
65 }
66
67 let h_pow_red: [u64; 4] = std::array::from_fn(|k| {
68 let q = P::Q[k] as u64;
69 pow2_mod(min_h, q)
70 });
71
72 Self {
73 h: min_h,
74 h_pow_red,
75 _phantom: std::marker::PhantomData,
76 }
77 }
78}
79
80impl<P: PrimeSet> Default for BaaMeta<P> {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86pub struct BbbMeta<P: PrimeSet> {
88 pub h: u64,
89 pub s1h_pow_red: u64, pub s2l_pow_red: [u64; 4], pub s2h_pow_red: [u64; 4], pub s3l_pow_red: [u64; 4], pub s3h_pow_red: [u64; 4], pub s4l_pow_red: [u64; 4], pub s4h_pow_red: [u64; 4], _phantom: std::marker::PhantomData<P>,
97}
98
99impl<P: PrimeSet> BbbMeta<P> {
100 pub fn new() -> Self {
102 const MAX_ELL: f64 = 10_000.0;
103 let ell_bs = MAX_ELL.log2();
104 let pow2_32_bs = compute_bit_size_red(32, P::Q);
105
106 let s1_bs = 32.0 + ell_bs;
107 let s2_bs = 32.0 + ell_bs + 3.0f64.log2(); let s3_bs = 32.0 + ell_bs + 3.0f64.log2();
109 let s4_bs = 32.0 + ell_bs;
110
111 let mut min_res_bs = f64::MAX;
112 let mut min_h = 16u64;
113
114 for h in 16u64..32 {
115 let s1l_bs = h as f64;
116 let s1h_bs = (s1_bs - h as f64) + compute_bit_size_red(h, P::Q);
117 let s2l_bs = h as f64 + pow2_32_bs;
118 let s2h_bs = (s2_bs - h as f64) + compute_bit_size_red(32 + h, P::Q);
119 let s3l_bs = h as f64 + compute_bit_size_red(64, P::Q);
120 let s3h_bs = (s3_bs - h as f64) + compute_bit_size_red(64 + h, P::Q);
121 let s4l_bs = h as f64 + compute_bit_size_red(96, P::Q);
122 let s4h_bs = (s4_bs - h as f64) + compute_bit_size_red(96 + h, P::Q);
123
124 let res_bs = log2_sum_n(&[s1l_bs, s1h_bs, s2l_bs, s2h_bs, s3l_bs, s3h_bs, s4l_bs, s4h_bs]);
125 if res_bs < min_res_bs {
126 min_res_bs = res_bs;
127 min_h = h;
128 }
129 }
130
131 let s1h_pow_red: u64 = 1u64 << min_h; let s2l_pow_red: [u64; 4] = std::array::from_fn(|k| pow2_mod(32, P::Q[k] as u64));
133 let s2h_pow_red: [u64; 4] = std::array::from_fn(|k| {
134 let q = P::Q[k] as u64;
135 (s2l_pow_red[k] * s1h_pow_red) % q
136 });
137 let s3l_pow_red: [u64; 4] = std::array::from_fn(|k| {
138 let q = P::Q[k] as u64;
139 (s2l_pow_red[k] * s2l_pow_red[k]) % q
140 });
141 let s3h_pow_red: [u64; 4] = std::array::from_fn(|k| {
142 let q = P::Q[k] as u64;
143 (s3l_pow_red[k] * s1h_pow_red) % q
144 });
145 let s4l_pow_red: [u64; 4] = std::array::from_fn(|k| {
146 let q = P::Q[k] as u64;
147 (s3l_pow_red[k] * s2l_pow_red[k]) % q
148 });
149 let s4h_pow_red: [u64; 4] = std::array::from_fn(|k| {
150 let q = P::Q[k] as u64;
151 (s4l_pow_red[k] * s1h_pow_red) % q
152 });
153
154 Self {
155 h: min_h,
156 s1h_pow_red,
157 s2l_pow_red,
158 s2h_pow_red,
159 s3l_pow_red,
160 s3h_pow_red,
161 s4l_pow_red,
162 s4h_pow_red,
163 _phantom: std::marker::PhantomData,
164 }
165 }
166}
167
168impl<P: PrimeSet> Default for BbbMeta<P> {
169 fn default() -> Self {
170 Self::new()
171 }
172}
173
174pub struct BbcMeta<P: PrimeSet> {
176 pub h: u64,
177 pub s2l_pow_red: [u64; 4], pub s2h_pow_red: [u64; 4], _phantom: std::marker::PhantomData<P>,
180}
181
182impl<P: PrimeSet> BbcMeta<P> {
183 pub fn new() -> Self {
185 const MAX_ELL: f64 = 10_000.0;
186 let ell_bs = MAX_ELL.log2();
187 let pow2_32_bs = compute_bit_size_red(32, P::Q);
188 let s1_bs = 32.0 + ell_bs;
189
190 let mut min_res_bs = f64::MAX;
191 let mut min_h = 16u64;
192
193 for h in 16u64..32 {
194 let s2l_bs = pow2_32_bs + h as f64;
195 let s2h_bs = (s1_bs - h as f64) + compute_bit_size_red(32 + h, P::Q);
196 let res_bs = log2_sum_n(&[s1_bs, s2l_bs, s2h_bs]);
197 if res_bs < min_res_bs {
198 min_res_bs = res_bs;
199 min_h = h;
200 }
201 }
202
203 let s2l_pow_red: [u64; 4] = std::array::from_fn(|k| pow2_mod(32, P::Q[k] as u64));
204 let s2h_pow_red: [u64; 4] = std::array::from_fn(|k| pow2_mod(32 + min_h, P::Q[k] as u64));
205
206 Self {
207 h: min_h,
208 s2l_pow_red,
209 s2h_pow_red,
210 _phantom: std::marker::PhantomData,
211 }
212 }
213}
214
215impl<P: PrimeSet> Default for BbcMeta<P> {
216 fn default() -> Self {
217 Self::new()
218 }
219}
220
221pub fn vec_mat1col_product_baa_ref<P: PrimeSet>(meta: &BaaMeta<P>, ell: usize, res: &mut [u64], x: &[u32], y: &[u32]) {
239 assert!(res.len() >= 4);
240 assert!(x.len() >= 4 * ell);
241 assert!(y.len() >= 4 * ell);
242
243 let h = meta.h;
244 let mask = (1u64 << h) - 1;
245
246 let mut acc1 = [0u64; 4];
247 let mut acc2 = [0u64; 4];
248
249 for i in 0..ell {
250 for k in 0..4 {
251 let t = x[4 * i + k] as u64 * y[4 * i + k] as u64;
252 acc1[k] += t & mask;
253 acc2[k] += t >> h;
254 }
255 }
256
257 for k in 0..4 {
258 res[k] = acc1[k] + acc2[k] * meta.h_pow_red[k];
259 }
260}
261
262pub fn vec_mat1col_product_bbb_ref<P: PrimeSet>(meta: &BbbMeta<P>, ell: usize, res: &mut [u64], x: &[u64], y: &[u64]) {
269 assert!(res.len() >= 4);
270 assert!(x.len() >= 4 * ell);
271 assert!(y.len() >= 4 * ell);
272
273 const MASK1: u64 = u32::MAX as u64; let mut s1 = [0u64; 4];
276 let mut s2 = [0u64; 4];
277 let mut s3 = [0u64; 4];
278 let mut s4 = [0u64; 4];
279
280 for i in 0..ell {
281 for k in 0..4 {
282 let xv = x[4 * i + k];
283 let yv = y[4 * i + k];
284 let xl = xv & MASK1;
285 let xh = xv >> 32;
286 let yl = yv & MASK1;
287 let yh = yv >> 32;
288
289 let a = xl * yl;
290 let al = a & MASK1;
291 let ah = a >> 32;
292
293 let b = xl * yh;
294 let bl = b & MASK1;
295 let bh = b >> 32;
296
297 let c = xh * yl;
298 let cl = c & MASK1;
299 let ch = c >> 32;
300
301 let d = xh * yh;
302 let dl = d & MASK1;
303 let dh = d >> 32;
304
305 s1[k] += al;
306 s2[k] += ah + bl + cl;
307 s3[k] += bh + ch + dl;
308 s4[k] += dh;
309 }
310 }
311
312 let h2 = meta.h;
313 let mask2 = (1u64 << h2) - 1;
314
315 for k in 0..4 {
316 let s1l = s1[k] & mask2;
317 let s1h = s1[k] >> h2;
318 let s2l = s2[k] & mask2;
319 let s2h = s2[k] >> h2;
320 let s3l = s3[k] & mask2;
321 let s3h = s3[k] >> h2;
322 let s4l = s4[k] & mask2;
323 let s4h = s4[k] >> h2;
324
325 let mut t = s1l;
326 t += s1h * meta.s1h_pow_red;
327 t += s2l * meta.s2l_pow_red[k];
328 t += s2h * meta.s2h_pow_red[k];
329 t += s3l * meta.s3l_pow_red[k];
330 t += s3h * meta.s3h_pow_red[k];
331 t += s4l * meta.s4l_pow_red[k];
332 t += s4h * meta.s4h_pow_red[k];
333
334 res[k] = t;
335 }
336}
337
/// Adds one row's products into the running accumulator `s`.
///
/// `x` and `y` are read as four `(lo, hi)` pairs of `u32`. For pair `i`, the two
/// products `x_lo * y_lo` and `x_hi * y_hi` are formed in `u64`; the sum of their low
/// 32 bits is added to `s[2 * i]` and the sum of their high 32 bits to `s[2 * i + 1]`.
#[inline(always)]
pub(crate) fn accum_mul_q120_bc(s: &mut [u64; 8], x: &[u32; 8], y: &[u32; 8]) {
    const MASK32: u64 = u32::MAX as u64;
    let pairs = s.chunks_exact_mut(2).zip(x.chunks_exact(2)).zip(y.chunks_exact(2));
    for ((acc, xs), ys) in pairs {
        let prod_lo = u64::from(xs[0]) * u64::from(ys[0]);
        let prod_hi = u64::from(xs[1]) * u64::from(ys[1]);
        acc[0] += (prod_lo & MASK32) + (prod_hi & MASK32);
        acc[1] += (prod_lo >> 32) + (prod_hi >> 32);
    }
}
356
357#[inline(always)]
359pub(crate) fn accum_to_q120b<P: PrimeSet>(res: &mut [u64; 4], s: &[u64; 8], meta: &BbcMeta<P>) {
360 let h2 = meta.h;
361 let mask2 = (1u64 << h2) - 1;
362 for k in 0..4 {
363 let s2l = s[2 * k + 1] & mask2;
364 let s2h = s[2 * k + 1] >> h2;
365 let t = s[2 * k] + s2l * meta.s2l_pow_red[k] + s2h * meta.s2h_pow_red[k];
366 res[k] = t;
367 }
368}
369
370pub fn vec_mat1col_product_bbc_ref<P: PrimeSet>(meta: &BbcMeta<P>, ell: usize, res: &mut [u64], x: &[u32], y: &[u32]) {
375 assert!(res.len() >= 4);
376 assert!(x.len() >= 8 * ell);
377 assert!(y.len() >= 8 * ell);
378
379 let mut s = [0u64; 8];
380 for i in 0..ell {
381 let xi: &[u32; 8] = x[8 * i..8 * i + 8].try_into().unwrap();
382 let yi: &[u32; 8] = y[8 * i..8 * i + 8].try_into().unwrap();
383 accum_mul_q120_bc(&mut s, xi, yi);
384 }
385 let res4: &mut [u64; 4] = (&mut res[..4]).try_into().unwrap();
386 accum_to_q120b::<P>(res4, &s, meta);
387}
388
389pub fn vec_mat1col_product_x2_bbc_ref<P: PrimeSet>(meta: &BbcMeta<P>, ell: usize, res: &mut [u64], x: &[u32], y: &[u32]) {
395 assert!(res.len() >= 8);
396 assert!(x.len() >= 16 * ell);
397 assert!(y.len() >= 16 * ell);
398
399 let mut s = [[0u64; 8]; 2];
400
401 for i in 0..ell {
402 let x0: &[u32; 8] = x[16 * i..16 * i + 8].try_into().unwrap();
404 let x1: &[u32; 8] = x[16 * i + 8..16 * i + 16].try_into().unwrap();
405 let y0: &[u32; 8] = y[16 * i..16 * i + 8].try_into().unwrap();
406 let y1: &[u32; 8] = y[16 * i + 8..16 * i + 16].try_into().unwrap();
407 accum_mul_q120_bc(&mut s[0], x0, y0);
408 accum_mul_q120_bc(&mut s[1], x1, y1);
409 }
410
411 let (res0, res1) = res.split_at_mut(4);
412 let r0: &mut [u64; 4] = res0.try_into().unwrap();
413 accum_to_q120b::<P>(r0, &s[0], meta);
414 let r1: &mut [u64; 4] = (&mut res1[..4]).try_into().unwrap();
415 accum_to_q120b::<P>(r1, &s[1], meta);
416}
417
418pub fn vec_mat2cols_product_x2_bbc_ref<P: PrimeSet>(meta: &BbcMeta<P>, ell: usize, res: &mut [u64], x: &[u32], y: &[u32]) {
424 assert!(res.len() >= 16);
425 assert!(x.len() >= 16 * ell);
426 assert!(y.len() >= 32 * ell);
427
428 let mut s = [[0u64; 8]; 4];
429
430 for i in 0..ell {
431 let x0: &[u32; 8] = x[16 * i..16 * i + 8].try_into().unwrap();
432 let x1: &[u32; 8] = x[16 * i + 8..16 * i + 16].try_into().unwrap();
433 let y0: &[u32; 8] = y[32 * i..32 * i + 8].try_into().unwrap();
434 let y1: &[u32; 8] = y[32 * i + 8..32 * i + 16].try_into().unwrap();
435 let y2: &[u32; 8] = y[32 * i + 16..32 * i + 24].try_into().unwrap();
436 let y3: &[u32; 8] = y[32 * i + 24..32 * i + 32].try_into().unwrap();
437 accum_mul_q120_bc(&mut s[0], x0, y0);
438 accum_mul_q120_bc(&mut s[1], x1, y1);
439 accum_mul_q120_bc(&mut s[2], x0, y2);
440 accum_mul_q120_bc(&mut s[3], x1, y3);
441 }
442
443 for (out_idx, si) in s.iter().enumerate() {
444 let r: &mut [u64; 4] = (&mut res[4 * out_idx..4 * out_idx + 4]).try_into().unwrap();
445 accum_to_q120b::<P>(r, si, meta);
446 }
447}
448
/// Copies the eight words of block `blk` from `src` into `dst[..8]`.
///
/// Block `blk` occupies `src[8 * blk..8 * blk + 8]`. The preconditions
/// `blk < nn / 2`, `dst.len() >= 8` and `src.len() >= 4 * nn` are checked in debug
/// builds only.
///
/// # Panics
///
/// Panics if `dst` has fewer than 8 elements or `src` fewer than `8 * blk + 8`.
pub fn extract_1blk_from_q120b_ref(nn: usize, blk: usize, dst: &mut [u64], src: &[u64]) {
    debug_assert!(blk < nn / 2);
    debug_assert!(dst.len() >= 8);
    debug_assert!(src.len() >= 4 * nn);

    let out = &mut dst[..8];
    let start = 8 * blk;
    out.copy_from_slice(&src[start..start + 8]);
}
467
/// Copies block `blk` of each of `nrows` consecutive rows in `src` into `dst`.
///
/// Row `r` starts at `src[4 * nn * r]`; its block `blk` (eight words starting
/// `8 * blk` into the row) is written to `dst[8 * r..8 * r + 8]`. The preconditions
/// `blk < nn / 2`, `dst.len() >= 8 * nrows` and `src.len() >= 4 * nn * nrows` are
/// checked in debug builds only.
///
/// # Panics
///
/// Panics if any of the slices accessed for a row is out of bounds.
pub fn extract_1blk_from_contiguous_q120b_ref(nn: usize, nrows: usize, blk: usize, dst: &mut [u64], src: &[u64]) {
    debug_assert!(blk < nn / 2);
    debug_assert!(dst.len() >= 8 * nrows);
    debug_assert!(src.len() >= 4 * nn * nrows);

    for row in 0..nrows {
        let out_at = 8 * row;
        let out = &mut dst[out_at..out_at + 8];
        let from = 4 * nn * row + 8 * blk;
        out.copy_from_slice(&src[from..from + 8]);
    }
}
487
/// Writes `src[..8]` into block `blk` of `dst`.
///
/// Block `blk` occupies `dst[8 * blk..8 * blk + 8]`. The preconditions
/// `blk < nn / 2`, `src.len() >= 8` and `dst.len() >= 4 * nn` are checked in debug
/// builds only.
///
/// # Panics
///
/// Panics if `src` has fewer than 8 elements or `dst` fewer than `8 * blk + 8`.
pub fn save_1blk_to_q120b_ref(nn: usize, blk: usize, dst: &mut [u64], src: &[u64]) {
    debug_assert!(blk < nn / 2);
    debug_assert!(src.len() >= 8);
    debug_assert!(dst.len() >= 4 * nn);

    let start = 8 * blk;
    let slot = &mut dst[start..start + 8];
    slot.copy_from_slice(&src[..8]);
}
499
500use super::pow2_mod;
505
506fn compute_bit_size_red(exp: u64, q: [u32; 4]) -> f64 {
510 let mut max_bs = 0.0f64;
511 for &qi in &q {
512 let val = pow2_mod(exp, qi as u64);
513 if val > 1 {
514 let bs = (val as f64).log2();
515 if bs > max_bs {
516 max_bs = bs;
517 }
518 }
519 }
520 max_bs
521}
522
/// Returns `log2(2^a + 2^b)`: the bit size of a sum of two values whose bit sizes
/// are `a` and `b`.
fn log2_sum_two(a: f64, b: f64) -> f64 {
    let first = f64::powf(2.0, a);
    let second = f64::powf(2.0, b);
    f64::log2(first + second)
}
528
/// Returns `log2` of the sum of `2^b` over every entry `b` of `bs`.
///
/// An empty slice yields negative infinity.
fn log2_sum_n(bs: &[f64]) -> f64 {
    let mut total = 0.0f64;
    for &bits in bs {
        total += f64::powf(2.0, bits);
    }
    total.log2()
}