Skip to main content

poulpy_cpu_ref/reference/ntt120/
mat_vec.rs

1// ----------------------------------------------------------------------
2// DISCLAIMER
3//
4// This module contains code that has been directly ported from the
5// spqlios-arithmetic library
6// (https://github.com/tfhe/spqlios-arithmetic), which is licensed
7// under the Apache License, Version 2.0.
8//
9// The porting process from C to Rust was done with minimal changes
10// in order to preserve the semantics and performance characteristics
11// of the original implementation.
12//
13// Both Poulpy and spqlios-arithmetic are distributed under the terms
14// of the Apache License, Version 2.0. See the LICENSE file for details.
15//
16// ----------------------------------------------------------------------
17
18//! Lazy-accumulation matrix–vector dot products in Q120 arithmetic.
19//!
20//! These functions are direct Rust ports of `q120_arithmetic_ref.c`
21//! in spqlios-arithmetic.  They implement:
22//!
23//! - **baa**: inner product of q120a × q120a → q120b (inputs in `[0, 2^32)`)
24//! - **bbb**: inner product of q120b × q120b → q120b (inputs in `[0, 2^64)`)
25//! - **bbc**: inner product of q120b × q120c → q120b (NTT × prepared-const)
26//! - **x2** / **2cols** variants that process two output elements at once.
27//! - Block extract/save helpers.
28//!
29//! The accumulation is designed so that `ell < 10 000` inner products
30//! fit without overflow (all intermediate sums stay below 64 bits).
31
32use crate::reference::ntt120::primes::PrimeSet;
33
34// ──────────────────────────────────────────────────────────────────────────────
35// Precomputed metadata
36// ──────────────────────────────────────────────────────────────────────────────
37
38/// Precomputed metadata for the q120a × q120a → q120b dot product.
39///
40/// Constructed once (per prime set) and reused for any `ell < 10 000`.
41pub struct BaaMeta<P: PrimeSet> {
42    pub h: u64,
43    pub h_pow_red: [u64; 4], // (2^h) % Q[k]
44    _phantom: std::marker::PhantomData<P>,
45}
46
47impl<P: PrimeSet> BaaMeta<P> {
48    /// Computes the optimal split point `h` that minimises the output
49    /// bit-width for an accumulation of up to `MAX_ELL = 10 000` terms.
50    pub fn new() -> Self {
51        const MAX_ELL: f64 = 10_000.0;
52        let ell_bs = MAX_ELL.log2();
53
54        let mut min_res_bs = f64::MAX;
55        let mut min_h = 0u64;
56
57        for h in 1u64..64 {
58            let h_pow2_bs = compute_bit_size_red(h, P::Q);
59            // S1 has (h + ell_bs) bits, S2 has (64-h + ell_bs + h_pow2_bs) bits
60            let res_bs = log2_sum_two(h as f64 + ell_bs, (64.0 - h as f64) + ell_bs + h_pow2_bs);
61            if res_bs < min_res_bs {
62                min_res_bs = res_bs;
63                min_h = h;
64            }
65        }
66
67        let h_pow_red: [u64; 4] = std::array::from_fn(|k| {
68            let q = P::Q[k] as u64;
69            pow2_mod(min_h, q)
70        });
71
72        Self {
73            h: min_h,
74            h_pow_red,
75            _phantom: std::marker::PhantomData,
76        }
77    }
78}
79
80impl<P: PrimeSet> Default for BaaMeta<P> {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86/// Precomputed metadata for the q120b × q120b → q120b dot product.
87pub struct BbbMeta<P: PrimeSet> {
88    pub h: u64,
89    pub s1h_pow_red: u64,      // 2^h (prime-independent)
90    pub s2l_pow_red: [u64; 4], // 2^32 mod Q[k]
91    pub s2h_pow_red: [u64; 4], // 2^(32+h) mod Q[k]
92    pub s3l_pow_red: [u64; 4], // 2^64 mod Q[k]
93    pub s3h_pow_red: [u64; 4], // 2^(64+h) mod Q[k]
94    pub s4l_pow_red: [u64; 4], // 2^96 mod Q[k]
95    pub s4h_pow_red: [u64; 4], // 2^(96+h) mod Q[k]
96    _phantom: std::marker::PhantomData<P>,
97}
98
99impl<P: PrimeSet> BbbMeta<P> {
100    /// Computes the optimal `h` for the four-term accumulation scheme.
101    pub fn new() -> Self {
102        const MAX_ELL: f64 = 10_000.0;
103        let ell_bs = MAX_ELL.log2();
104        let pow2_32_bs = compute_bit_size_red(32, P::Q);
105
106        let s1_bs = 32.0 + ell_bs;
107        let s2_bs = 32.0 + ell_bs + 3.0f64.log2(); // +log2(3) from Ah+Bl+Cl
108        let s3_bs = 32.0 + ell_bs + 3.0f64.log2();
109        let s4_bs = 32.0 + ell_bs;
110
111        let mut min_res_bs = f64::MAX;
112        let mut min_h = 16u64;
113
114        for h in 16u64..32 {
115            let s1l_bs = h as f64;
116            let s1h_bs = (s1_bs - h as f64) + compute_bit_size_red(h, P::Q);
117            let s2l_bs = h as f64 + pow2_32_bs;
118            let s2h_bs = (s2_bs - h as f64) + compute_bit_size_red(32 + h, P::Q);
119            let s3l_bs = h as f64 + compute_bit_size_red(64, P::Q);
120            let s3h_bs = (s3_bs - h as f64) + compute_bit_size_red(64 + h, P::Q);
121            let s4l_bs = h as f64 + compute_bit_size_red(96, P::Q);
122            let s4h_bs = (s4_bs - h as f64) + compute_bit_size_red(96 + h, P::Q);
123
124            let res_bs = log2_sum_n(&[s1l_bs, s1h_bs, s2l_bs, s2h_bs, s3l_bs, s3h_bs, s4l_bs, s4h_bs]);
125            if res_bs < min_res_bs {
126                min_res_bs = res_bs;
127                min_h = h;
128            }
129        }
130
131        let s1h_pow_red: u64 = 1u64 << min_h; // prime-independent: 2^h is the same for all primes
132        let s2l_pow_red: [u64; 4] = std::array::from_fn(|k| pow2_mod(32, P::Q[k] as u64));
133        let s2h_pow_red: [u64; 4] = std::array::from_fn(|k| {
134            let q = P::Q[k] as u64;
135            (s2l_pow_red[k] * s1h_pow_red) % q
136        });
137        let s3l_pow_red: [u64; 4] = std::array::from_fn(|k| {
138            let q = P::Q[k] as u64;
139            (s2l_pow_red[k] * s2l_pow_red[k]) % q
140        });
141        let s3h_pow_red: [u64; 4] = std::array::from_fn(|k| {
142            let q = P::Q[k] as u64;
143            (s3l_pow_red[k] * s1h_pow_red) % q
144        });
145        let s4l_pow_red: [u64; 4] = std::array::from_fn(|k| {
146            let q = P::Q[k] as u64;
147            (s3l_pow_red[k] * s2l_pow_red[k]) % q
148        });
149        let s4h_pow_red: [u64; 4] = std::array::from_fn(|k| {
150            let q = P::Q[k] as u64;
151            (s4l_pow_red[k] * s1h_pow_red) % q
152        });
153
154        Self {
155            h: min_h,
156            s1h_pow_red,
157            s2l_pow_red,
158            s2h_pow_red,
159            s3l_pow_red,
160            s3h_pow_red,
161            s4l_pow_red,
162            s4h_pow_red,
163            _phantom: std::marker::PhantomData,
164        }
165    }
166}
167
168impl<P: PrimeSet> Default for BbbMeta<P> {
169    fn default() -> Self {
170        Self::new()
171    }
172}
173
174/// Precomputed metadata for the q120b × q120c → q120b dot product.
175pub struct BbcMeta<P: PrimeSet> {
176    pub h: u64,
177    pub s2l_pow_red: [u64; 4], // 2^32 mod Q[k]
178    pub s2h_pow_red: [u64; 4], // 2^(32+h) mod Q[k]
179    _phantom: std::marker::PhantomData<P>,
180}
181
182impl<P: PrimeSet> BbcMeta<P> {
183    /// Computes the optimal `h` for the two-term accumulation scheme.
184    pub fn new() -> Self {
185        const MAX_ELL: f64 = 10_000.0;
186        let ell_bs = MAX_ELL.log2();
187        let pow2_32_bs = compute_bit_size_red(32, P::Q);
188        let s1_bs = 32.0 + ell_bs;
189
190        let mut min_res_bs = f64::MAX;
191        let mut min_h = 16u64;
192
193        for h in 16u64..32 {
194            let s2l_bs = pow2_32_bs + h as f64;
195            let s2h_bs = (s1_bs - h as f64) + compute_bit_size_red(32 + h, P::Q);
196            let res_bs = log2_sum_n(&[s1_bs, s2l_bs, s2h_bs]);
197            if res_bs < min_res_bs {
198                min_res_bs = res_bs;
199                min_h = h;
200            }
201        }
202
203        let s2l_pow_red: [u64; 4] = std::array::from_fn(|k| pow2_mod(32, P::Q[k] as u64));
204        let s2h_pow_red: [u64; 4] = std::array::from_fn(|k| pow2_mod(32 + min_h, P::Q[k] as u64));
205
206        Self {
207            h: min_h,
208            s2l_pow_red,
209            s2h_pow_red,
210            _phantom: std::marker::PhantomData,
211        }
212    }
213}
214
215impl<P: PrimeSet> Default for BbcMeta<P> {
216    fn default() -> Self {
217        Self::new()
218    }
219}
220
// ──────────────────────────────────────────────────────────────────────────────
// Dot-product kernels
//
// PERF POLICY: public kernels put `assert!` guards on their slice lengths at
// entry. A wrong length then fails immediately with a clear message, and the
// checked lengths may let the optimiser elide per-iteration bounds checks in
// the loops. Safe slice indexing never reads out of bounds, so without these
// guards a too-short slice would still panic, only later and less clearly.
// Helpers whose slice accesses are already bounds-checked state their
// preconditions with `debug_assert!`, avoiding redundant work in release
// builds.
// ──────────────────────────────────────────────────────────────────────────────
230
231/// Computes `res = sum_{i=0}^{ell-1} x[i] * y[i]` in Q120b format,
232/// where `x` and `y` are in q120a (values in `[0, 2^32)`).
233///
234/// `ell` must be < 10 000.
235///
236/// Inputs: `x` and `y` as flat `u32` slices with stride 4 (one group
237/// of 4 per ring element), `res` as a `u64` slice of length 4.
238pub fn vec_mat1col_product_baa_ref<P: PrimeSet>(meta: &BaaMeta<P>, ell: usize, res: &mut [u64], x: &[u32], y: &[u32]) {
239    assert!(res.len() >= 4);
240    assert!(x.len() >= 4 * ell);
241    assert!(y.len() >= 4 * ell);
242
243    let h = meta.h;
244    let mask = (1u64 << h) - 1;
245
246    let mut acc1 = [0u64; 4];
247    let mut acc2 = [0u64; 4];
248
249    for i in 0..ell {
250        for k in 0..4 {
251            let t = x[4 * i + k] as u64 * y[4 * i + k] as u64;
252            acc1[k] += t & mask;
253            acc2[k] += t >> h;
254        }
255    }
256
257    for k in 0..4 {
258        res[k] = acc1[k] + acc2[k] * meta.h_pow_red[k];
259    }
260}
261
262/// Computes `res = sum_{i=0}^{ell-1} x[i] * y[i]` in Q120b format,
263/// where `x` and `y` are in q120b (values in `[0, 2^64)`).
264///
265/// `ell` must be < 10 000.
266///
267/// Both inputs and output are flat `u64` slices with stride 4.
268pub fn vec_mat1col_product_bbb_ref<P: PrimeSet>(meta: &BbbMeta<P>, ell: usize, res: &mut [u64], x: &[u64], y: &[u64]) {
269    assert!(res.len() >= 4);
270    assert!(x.len() >= 4 * ell);
271    assert!(y.len() >= 4 * ell);
272
273    const MASK1: u64 = u32::MAX as u64; // lower 32 bits
274
275    let mut s1 = [0u64; 4];
276    let mut s2 = [0u64; 4];
277    let mut s3 = [0u64; 4];
278    let mut s4 = [0u64; 4];
279
280    for i in 0..ell {
281        for k in 0..4 {
282            let xv = x[4 * i + k];
283            let yv = y[4 * i + k];
284            let xl = xv & MASK1;
285            let xh = xv >> 32;
286            let yl = yv & MASK1;
287            let yh = yv >> 32;
288
289            let a = xl * yl;
290            let al = a & MASK1;
291            let ah = a >> 32;
292
293            let b = xl * yh;
294            let bl = b & MASK1;
295            let bh = b >> 32;
296
297            let c = xh * yl;
298            let cl = c & MASK1;
299            let ch = c >> 32;
300
301            let d = xh * yh;
302            let dl = d & MASK1;
303            let dh = d >> 32;
304
305            s1[k] += al;
306            s2[k] += ah + bl + cl;
307            s3[k] += bh + ch + dl;
308            s4[k] += dh;
309        }
310    }
311
312    let h2 = meta.h;
313    let mask2 = (1u64 << h2) - 1;
314
315    for k in 0..4 {
316        let s1l = s1[k] & mask2;
317        let s1h = s1[k] >> h2;
318        let s2l = s2[k] & mask2;
319        let s2h = s2[k] >> h2;
320        let s3l = s3[k] & mask2;
321        let s3h = s3[k] >> h2;
322        let s4l = s4[k] & mask2;
323        let s4h = s4[k] >> h2;
324
325        let mut t = s1l;
326        t += s1h * meta.s1h_pow_red;
327        t += s2l * meta.s2l_pow_red[k];
328        t += s2h * meta.s2h_pow_red[k];
329        t += s3l * meta.s3l_pow_red[k];
330        t += s3h * meta.s3h_pow_red[k];
331        t += s4l * meta.s4l_pow_red[k];
332        t += s4h * meta.s4h_pow_red[k];
333
334        res[k] = t;
335    }
336}
337
/// Adds one q120b × q120c product into the 8-wide accumulator `s`.
///
/// For each prime index `k`, the operands occupy `x[2k..2k+2]` and
/// `y[2k..2k+2]`. The two lane products `x[2k]·y[2k]` and `x[2k+1]·y[2k+1]`
/// are computed in 64 bits. Their low 32 bits are added into `s[2k]` and
/// their high 32 bits into `s[2k+1]`.
#[inline(always)]
pub(crate) fn accum_mul_q120_bc(s: &mut [u64; 8], x: &[u32; 8], y: &[u32; 8]) {
    const LOW32: u64 = u32::MAX as u64;
    let lanes = s.chunks_exact_mut(2).zip(x.chunks_exact(2)).zip(y.chunks_exact(2));
    for ((acc, xp), yp) in lanes {
        let prod_lo = u64::from(xp[0]) * u64::from(yp[0]);
        let prod_hi = u64::from(xp[1]) * u64::from(yp[1]);
        acc[0] += (prod_lo & LOW32) + (prod_hi & LOW32);
        acc[1] += (prod_lo >> 32) + (prod_hi >> 32);
    }
}
356
357/// Collapses the 8-wide accumulator `s` into a 4-wide q120b result.
358#[inline(always)]
359pub(crate) fn accum_to_q120b<P: PrimeSet>(res: &mut [u64; 4], s: &[u64; 8], meta: &BbcMeta<P>) {
360    let h2 = meta.h;
361    let mask2 = (1u64 << h2) - 1;
362    for k in 0..4 {
363        let s2l = s[2 * k + 1] & mask2;
364        let s2h = s[2 * k + 1] >> h2;
365        let t = s[2 * k] + s2l * meta.s2l_pow_red[k] + s2h * meta.s2h_pow_red[k];
366        res[k] = t;
367    }
368}
369
370/// Computes `res = sum_{i=0}^{ell-1} x[i] * y[i]` in Q120b format,
371/// where `x` is in q120b and `y` is in q120c.
372///
373/// `ell` must be < 10 000.
374pub fn vec_mat1col_product_bbc_ref<P: PrimeSet>(meta: &BbcMeta<P>, ell: usize, res: &mut [u64], x: &[u32], y: &[u32]) {
375    assert!(res.len() >= 4);
376    assert!(x.len() >= 8 * ell);
377    assert!(y.len() >= 8 * ell);
378
379    let mut s = [0u64; 8];
380    for i in 0..ell {
381        let xi: &[u32; 8] = x[8 * i..8 * i + 8].try_into().unwrap();
382        let yi: &[u32; 8] = y[8 * i..8 * i + 8].try_into().unwrap();
383        accum_mul_q120_bc(&mut s, xi, yi);
384    }
385    let res4: &mut [u64; 4] = (&mut res[..4]).try_into().unwrap();
386    accum_to_q120b::<P>(res4, &s, meta);
387}
388
389/// Computes two q120b dot products simultaneously (x2 variant).
390///
391/// `x` contains two interleaved q120b vectors (each of length `ell`),
392/// and `y` contains two interleaved q120c vectors.
393/// Both output q120b values are written into `res` (8 contiguous u64s).
394pub fn vec_mat1col_product_x2_bbc_ref<P: PrimeSet>(meta: &BbcMeta<P>, ell: usize, res: &mut [u64], x: &[u32], y: &[u32]) {
395    assert!(res.len() >= 8);
396    assert!(x.len() >= 16 * ell);
397    assert!(y.len() >= 16 * ell);
398
399    let mut s = [[0u64; 8]; 2];
400
401    for i in 0..ell {
402        // Each element: 2 × q120b (16 u32) in x, 2 × q120c (16 u32) in y
403        let x0: &[u32; 8] = x[16 * i..16 * i + 8].try_into().unwrap();
404        let x1: &[u32; 8] = x[16 * i + 8..16 * i + 16].try_into().unwrap();
405        let y0: &[u32; 8] = y[16 * i..16 * i + 8].try_into().unwrap();
406        let y1: &[u32; 8] = y[16 * i + 8..16 * i + 16].try_into().unwrap();
407        accum_mul_q120_bc(&mut s[0], x0, y0);
408        accum_mul_q120_bc(&mut s[1], x1, y1);
409    }
410
411    let (res0, res1) = res.split_at_mut(4);
412    let r0: &mut [u64; 4] = res0.try_into().unwrap();
413    accum_to_q120b::<P>(r0, &s[0], meta);
414    let r1: &mut [u64; 4] = (&mut res1[..4]).try_into().unwrap();
415    accum_to_q120b::<P>(r1, &s[1], meta);
416}
417
418/// Computes four q120b dot products (two output, two columns).
419///
420/// Equivalent to calling `vec_mat1col_product_x2_bbc_ref` twice with
421/// two different column slices of `y`, accumulating into `res[0..8]`
422/// and `res[8..16]` respectively.
423pub fn vec_mat2cols_product_x2_bbc_ref<P: PrimeSet>(meta: &BbcMeta<P>, ell: usize, res: &mut [u64], x: &[u32], y: &[u32]) {
424    assert!(res.len() >= 16);
425    assert!(x.len() >= 16 * ell);
426    assert!(y.len() >= 32 * ell);
427
428    let mut s = [[0u64; 8]; 4];
429
430    for i in 0..ell {
431        let x0: &[u32; 8] = x[16 * i..16 * i + 8].try_into().unwrap();
432        let x1: &[u32; 8] = x[16 * i + 8..16 * i + 16].try_into().unwrap();
433        let y0: &[u32; 8] = y[32 * i..32 * i + 8].try_into().unwrap();
434        let y1: &[u32; 8] = y[32 * i + 8..32 * i + 16].try_into().unwrap();
435        let y2: &[u32; 8] = y[32 * i + 16..32 * i + 24].try_into().unwrap();
436        let y3: &[u32; 8] = y[32 * i + 24..32 * i + 32].try_into().unwrap();
437        accum_mul_q120_bc(&mut s[0], x0, y0);
438        accum_mul_q120_bc(&mut s[1], x1, y1);
439        accum_mul_q120_bc(&mut s[2], x0, y2);
440        accum_mul_q120_bc(&mut s[3], x1, y3);
441    }
442
443    for (out_idx, si) in s.iter().enumerate() {
444        let r: &mut [u64; 4] = (&mut res[4 * out_idx..4 * out_idx + 4]).try_into().unwrap();
445        accum_to_q120b::<P>(r, si, meta);
446    }
447}
448
449// ──────────────────────────────────────────────────────────────────────────────
450// Block extract / save helpers
451// ──────────────────────────────────────────────────────────────────────────────
452
/// Copies one block out of a q120b NTT vector of length `nn`.
///
/// A block is 2 consecutive q120b coefficients (4 `u64` each, so 8 `u64`
/// values in total). Block `blk` holds coefficients `2*blk` and `2*blk + 1`,
/// i.e. the words `src[8*blk..8*blk + 8]`, so `blk < nn / 2`. The words are
/// copied into `dst[..8]`.
///
/// This is the Rust port of `q120x2_extract_1blk_from_q120b_ref`.
///
/// # Panics
///
/// Panics if `dst` is shorter than 8 or `src` does not contain the block.
/// The `blk` and length preconditions are also checked with
/// `debug_assert!`.
pub fn extract_1blk_from_q120b_ref(nn: usize, blk: usize, dst: &mut [u64], src: &[u64]) {
    debug_assert!(blk < nn / 2);
    debug_assert!(dst.len() >= 8);
    debug_assert!(src.len() >= 4 * nn);

    let start = 8 * blk;
    dst[..8].copy_from_slice(&src[start..start + 8]);
}
467
/// Copies block `blk` out of each of `nrows` contiguous q120b NTT vectors.
///
/// `src` is `[row_0 || row_1 || ... || row_{nrows-1}]`, with `4*nn` `u64`
/// values per row. For every row, the 8 words starting at offset `8*blk` of
/// that row are written to `dst[8*row..8*row + 8]`. Rows are processed in
/// order.
///
/// Port of `q120x2_extract_1blk_from_contiguous_q120b_ref`.
///
/// # Panics
///
/// Panics if a row's block lies outside `src` or its destination lies
/// outside `dst`. The preconditions are also checked with `debug_assert!`.
pub fn extract_1blk_from_contiguous_q120b_ref(nn: usize, nrows: usize, blk: usize, dst: &mut [u64], src: &[u64]) {
    debug_assert!(blk < nn / 2);
    debug_assert!(dst.len() >= 8 * nrows);
    debug_assert!(src.len() >= 4 * nn * nrows);

    let row_stride = 4 * nn;
    let block_offset = 8 * blk;
    (0..nrows).for_each(|row| {
        let from = row_stride * row + block_offset;
        let to = 8 * row;
        dst[to..to + 8].copy_from_slice(&src[from..from + 8]);
    });
}
487
/// Writes one q120x2b block (8 `u64` values taken from `src[..8]`) back into
/// a q120b NTT vector of length `nn`, at words `dst[8*blk..8*blk + 8]`.
///
/// This is the inverse of extracting block `blk`. All other words of `dst`
/// are left untouched.
///
/// Port of `q120x2b_save_1blk_to_q120b_ref`.
///
/// # Panics
///
/// Panics if the target range lies outside `dst` or `src` is shorter than 8.
/// The preconditions are also checked with `debug_assert!`.
pub fn save_1blk_to_q120b_ref(nn: usize, blk: usize, dst: &mut [u64], src: &[u64]) {
    debug_assert!(blk < nn / 2);
    debug_assert!(src.len() >= 8);
    debug_assert!(dst.len() >= 4 * nn);

    let start = 8 * blk;
    let target = &mut dst[start..start + 8];
    target.copy_from_slice(&src[..8]);
}
499
500// ──────────────────────────────────────────────────────────────────────────────
501// Internal helpers
502// ──────────────────────────────────────────────────────────────────────────────
503
504use super::pow2_mod;
505
506/// `ceil(log2(x))` for x ≥ 1 encoded as bit-size estimate.
507///
508/// Returns the maximum, over all four primes, of `ceil(log2((2^exp) % Q[k]))`.
509fn compute_bit_size_red(exp: u64, q: [u32; 4]) -> f64 {
510    let mut max_bs = 0.0f64;
511    for &qi in &q {
512        let val = pow2_mod(exp, qi as u64);
513        if val > 1 {
514            let bs = (val as f64).log2();
515            if bs > max_bs {
516                max_bs = bs;
517            }
518        }
519    }
520    max_bs
521}
522
/// Returns `log2(2^a + 2^b)`: the bit-size of the sum of two quantities
/// whose bit-sizes are `a` and `b`.
fn log2_sum_two(a: f64, b: f64) -> f64 {
    let total = f64::powf(2.0, a) + f64::powf(2.0, b);
    f64::log2(total)
}
528
/// Returns `log2(Σ_i 2^{bs[i]})`: the bit-size of the sum of quantities
/// whose individual bit-sizes are listed in `bs`. An empty slice yields
/// negative infinity.
fn log2_sum_n(bs: &[f64]) -> f64 {
    let mut total = 0.0f64;
    for &bits in bs {
        total += f64::powf(2.0, bits);
    }
    total.log2()
}