Skip to main content

poulpy_cpu_ref/reference/fft64/
convolution.rs

1use crate::{
2    layouts::{
3        Backend, CnvPVecLBackendMut, CnvPVecLBackendRef, CnvPVecRBackendMut, CnvPVecRBackendRef, HostDataRef, VecZnxBackendRef,
4        VecZnxBigBackendMut, VecZnxDftBackendMut, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero,
5    },
6    reference::fft64::{
7        reim::{ReimArith, ReimFFTExecute, ReimFFTTable},
8        reim4::{Reim4BlkMatVec, Reim4Convolution},
9        vec_znx_dft::vec_znx_dft_apply,
10    },
11};
12
13pub fn convolution_prepare_left<BE>(
14    table: &ReimFFTTable<f64>,
15    res: &mut CnvPVecLBackendMut<'_, BE>,
16    a: &VecZnxBackendRef<'_, BE>,
17    mask: i64,
18    tmp: &mut VecZnxDftBackendMut<'_, BE>,
19) where
20    BE: Backend<ScalarPrep = f64> + ReimArith + Reim4BlkMatVec + ReimFFTExecute<ReimFFTTable<f64>, f64> + 'static,
21    for<'x> BE: Backend<BufRef<'x> = &'x [u8], BufMut<'x> = &'x mut [u8]>,
22{
23    convolution_prepare(table, res, a, mask, tmp)
24}
25
26pub fn convolution_prepare_right<BE>(
27    table: &ReimFFTTable<f64>,
28    res: &mut CnvPVecRBackendMut<'_, BE>,
29    a: &VecZnxBackendRef<'_, BE>,
30    mask: i64,
31    tmp: &mut VecZnxDftBackendMut<'_, BE>,
32) where
33    BE: Backend<ScalarPrep = f64> + ReimArith + Reim4BlkMatVec + ReimFFTExecute<ReimFFTTable<f64>, f64> + 'static,
34    for<'x> BE: Backend<BufRef<'x> = &'x [u8], BufMut<'x> = &'x mut [u8]>,
35{
36    convolution_prepare(table, res, a, mask, tmp)
37}
38
39fn convolution_prepare<R, BE>(
40    table: &ReimFFTTable<f64>,
41    res: &mut R,
42    a: &VecZnxBackendRef<'_, BE>,
43    mask: i64,
44    tmp: &mut VecZnxDftBackendMut<'_, BE>,
45) where
46    BE: Backend<ScalarPrep = f64> + ReimArith + Reim4BlkMatVec + ReimFFTExecute<ReimFFTTable<f64>, f64> + 'static,
47    for<'x> BE: Backend<BufRef<'x> = &'x [u8], BufMut<'x> = &'x mut [u8]>,
48    R: ZnxInfos + ZnxViewMut<Scalar = BE::ScalarPrep>,
49{
50    let cols: usize = res.cols();
51    assert_eq!(a.cols(), cols, "a.cols():{} != res.cols():{cols}", a.cols());
52
53    let res_size: usize = res.size();
54    let min_size: usize = res_size.min(a.size());
55
56    let m: usize = a.n() >> 1;
57
58    let n: usize = table.m() << 1;
59
60    let res_raw: &mut [f64] = res.raw_mut();
61
62    for i in 0..cols {
63        // FFT all limbs (unmasked); the last active limb will be overwritten below.
64        vec_znx_dft_apply(table, 1, 0, tmp, 0, a, i);
65
66        // Re-compute only the last active limb with the mask applied.
67        if min_size > 0 {
68            let last = min_size - 1;
69            BE::reim_from_znx_masked(tmp.at_mut(0, last), a.at(i, last), mask);
70            BE::reim_dft_execute(table, tmp.at_mut(0, last));
71        }
72
73        let tmp_raw: &[f64] = tmp.raw();
74        let res_col: &mut [f64] = &mut res_raw[i * n * res_size..];
75
76        for blk_i in 0..m / 4 {
77            BE::reim4_extract_1blk_contiguous(m, min_size, blk_i, &mut res_col[blk_i * res_size * 8..], tmp_raw);
78            BE::reim_zero(&mut res_col[blk_i * res_size * 8 + min_size * 8..(blk_i + 1) * res_size * 8]);
79        }
80    }
81}
82
83pub fn convolution_prepare_self<BE>(
84    table: &ReimFFTTable<f64>,
85    left: &mut CnvPVecLBackendMut<'_, BE>,
86    right: &mut CnvPVecRBackendMut<'_, BE>,
87    a: &VecZnxBackendRef<'_, BE>,
88    mask: i64,
89    tmp: &mut VecZnxDftBackendMut<'_, BE>,
90) where
91    BE: Backend<ScalarPrep = f64> + ReimArith + Reim4BlkMatVec + ReimFFTExecute<ReimFFTTable<f64>, f64> + 'static,
92    for<'x> BE: Backend<BufRef<'x> = &'x [u8], BufMut<'x> = &'x mut [u8]>,
93{
94    let cols: usize = left.cols();
95    assert_eq!(a.cols(), cols, "a.cols():{} != left.cols():{cols}", a.cols());
96    assert_eq!(right.cols(), cols, "right.cols():{} != left.cols():{cols}", right.cols());
97
98    let left_size: usize = left.size();
99    let right_size: usize = right.size();
100    assert_eq!(
101        left_size, right_size,
102        "left.size():{} != right.size():{right_size}",
103        left_size
104    );
105    let res_size: usize = left_size;
106    let min_size: usize = res_size.min(a.size());
107
108    let m: usize = a.n() >> 1;
109    let n: usize = table.m() << 1;
110
111    let left_raw: &mut [f64] = left.raw_mut();
112    let right_raw: &mut [f64] = right.raw_mut();
113
114    for i in 0..cols {
115        // FFT all limbs (unmasked); the last active limb will be overwritten below.
116        vec_znx_dft_apply(table, 1, 0, tmp, 0, a, i);
117
118        // Re-compute only the last active limb with the mask applied.
119        if min_size > 0 {
120            let last = min_size - 1;
121            BE::reim_from_znx_masked(tmp.at_mut(0, last), a.at(i, last), mask);
122            BE::reim_dft_execute(table, tmp.at_mut(0, last));
123        }
124
125        let tmp_raw: &[f64] = tmp.raw();
126        let left_col: &mut [f64] = &mut left_raw[i * n * res_size..];
127
128        for blk_i in 0..m / 4 {
129            BE::reim4_extract_1blk_contiguous(m, min_size, blk_i, &mut left_col[blk_i * res_size * 8..], tmp_raw);
130            BE::reim_zero(&mut left_col[blk_i * res_size * 8 + min_size * 8..(blk_i + 1) * res_size * 8]);
131        }
132
133        // Copy from left to right (identical data for FFT64)
134        let col_bytes: usize = n * res_size;
135        let right_col: &mut [f64] = &mut right_raw[i * n * res_size..];
136        right_col[..col_bytes].copy_from_slice(&left_col[..col_bytes]);
137    }
138}
139
/// Returns the scratch size, in bytes, needed by `convolution_by_const_apply`.
///
/// The scratch holds one 8-value block per produced limb plus one 8-value
/// block per limb of `a`, all stored as `i64`. The number of produced limbs is
/// `min(res_size, a_size + b_size - 1)`.
///
/// When both `a_size` and `b_size` are zero the convolution has no limbs; the
/// bound saturates at zero instead of underflowing, so the function returns 0.
pub fn convolution_by_const_apply_tmp_bytes(res_size: usize, a_size: usize, b_size: usize) -> usize {
    // `a_size + b_size - 1` would underflow for two empty operands.
    let bound: usize = (a_size + b_size).saturating_sub(1);
    let min_size: usize = res_size.min(bound);
    size_of::<i64>() * (min_size + a_size) * 8
}
144
145pub fn convolution_by_const_apply<BE>(
146    cnv_offset: usize,
147    res: &mut VecZnxBigBackendMut<'_, BE>,
148    res_col: usize,
149    a: &VecZnxBackendRef<'_, BE>,
150    a_col: usize,
151    b: &VecZnxBackendRef<'_, BE>,
152    b_col: usize,
153    b_coeff: usize,
154    tmp: &mut [i64],
155) where
156    BE: Backend<ScalarBig = i64> + I64Ops + 'static,
157    for<'x> BE: Backend<BufRef<'x> = &'x [u8]>,
158    for<'x> <BE as Backend>::BufMut<'x>: crate::layouts::HostDataMut,
159{
160    let n: usize = res.n();
161    assert_eq!(a.n(), n);
162
163    let res_size: usize = res.size();
164    let a_size: usize = a.size();
165    let b_size: usize = b.size();
166
167    let bound: usize = a_size + b_size - 1;
168    let min_size: usize = res_size.min(bound);
169    let offset: usize = cnv_offset.min(bound);
170
171    let a_sl: usize = n * a.cols();
172    let res_sl: usize = n * res.cols();
173
174    let res_raw: &mut [i64] = res.raw_mut();
175    let a_raw: &[i64] = a.raw();
176
177    let a_idx: usize = n * a_col;
178    let res_idx: usize = n * res_col;
179
180    let (res_blk, a_blk) = tmp[..(min_size + a_size) * 8].split_at_mut(min_size * 8);
181    let mut b_digits = vec![0i64; b_size];
182    for (j, digit) in b_digits.iter_mut().enumerate() {
183        *digit = b.at(b_col, j)[b_coeff];
184    }
185
186    for blk_i in 0..n / 8 {
187        BE::i64_extract_1blk_contiguous(a_sl, a_idx, a_size, blk_i, a_blk, a_raw);
188        BE::i64_convolution_by_const(res_blk, min_size, offset, a_blk, a_size, &b_digits);
189        BE::i64_save_1blk_contiguous(res_sl, res_idx, min_size, blk_i, res_raw, res_blk);
190    }
191
192    for j in min_size..res_size {
193        res.zero_at(res_col, j);
194    }
195}
196
/// Returns the scratch size, in bytes, needed by `convolution_apply_dft`.
///
/// The scratch holds one block of 8 `f64` values per produced limb, where the
/// number of produced limbs is `min(res_size, a_size + b_size - 1)`.
pub fn convolution_apply_dft_tmp_bytes(res_size: usize, a_size: usize, b_size: usize) -> usize {
    let produced_limbs: usize = (a_size + b_size - 1).min(res_size);
    produced_limbs * 8 * size_of::<f64>()
}
201
202#[allow(clippy::too_many_arguments)]
203pub fn convolution_apply_dft<BE>(
204    cnv_offset: usize,
205    res: &mut VecZnxDftBackendMut<'_, BE>,
206    res_col: usize,
207    a: &CnvPVecLBackendRef<'_, BE>,
208    a_col: usize,
209    b: &CnvPVecRBackendRef<'_, BE>,
210    b_col: usize,
211    tmp: &mut [f64],
212) where
213    BE: Backend<ScalarPrep = f64> + Reim4BlkMatVec + Reim4Convolution,
214    for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
215    for<'x> <BE as Backend>::BufMut<'x>: crate::layouts::HostDataMut,
216{
217    let n: usize = res.n();
218    assert_eq!(a.n(), n);
219    assert_eq!(b.n(), n);
220    let m: usize = n >> 1;
221
222    let res_size: usize = res.size();
223    let a_size: usize = a.size();
224    let b_size: usize = b.size();
225
226    let bound: usize = a_size + b_size - 1;
227    let min_size: usize = res_size.min(bound);
228    let offset: usize = cnv_offset.min(bound);
229
230    let dst: &mut [f64] = res.raw_mut();
231    let a_raw: &[f64] = a.raw();
232    let b_raw: &[f64] = b.raw();
233
234    let mut a_idx: usize = a_col * n * a_size;
235    let mut b_idx: usize = b_col * n * b_size;
236    let a_offset: usize = a_size * 8;
237    let b_offset: usize = b_size * 8;
238    for blk_i in 0..m / 4 {
239        BE::reim4_convolution(tmp, min_size, offset, &a_raw[a_idx..], a_size, &b_raw[b_idx..], b_size);
240        BE::reim4_save_1blk_contiguous(m, min_size, blk_i, dst, tmp);
241        a_idx += a_offset;
242        b_idx += b_offset;
243    }
244
245    for j in min_size..res_size {
246        res.zero_at(res_col, j);
247    }
248}
249
250pub fn convolution_pairwise_apply_dft_tmp_bytes(res_size: usize, a_size: usize, b_size: usize) -> usize {
251    convolution_apply_dft_tmp_bytes(res_size, a_size, b_size) + (a_size + b_size) * size_of::<f64>() * 8
252}
253
254#[allow(clippy::too_many_arguments)]
255pub fn convolution_pairwise_apply_dft<BE>(
256    cnv_offset: usize,
257    res: &mut VecZnxDftBackendMut<'_, BE>,
258    res_col: usize,
259    a: &CnvPVecLBackendRef<'_, BE>,
260    b: &CnvPVecRBackendRef<'_, BE>,
261    col_i: usize,
262    col_j: usize,
263    tmp: &mut [f64],
264) where
265    BE: Backend<ScalarPrep = f64> + ReimArith + Reim4BlkMatVec + Reim4Convolution,
266    for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
267    for<'x> <BE as Backend>::BufMut<'x>: crate::layouts::HostDataMut,
268{
269    if col_i == col_j {
270        convolution_apply_dft(cnv_offset, res, res_col, a, col_i, b, col_j, tmp);
271        return;
272    }
273
274    let n: usize = res.n();
275    let m: usize = n >> 1;
276
277    assert_eq!(a.n(), n);
278    assert_eq!(b.n(), n);
279
280    let res_size: usize = res.size();
281    let a_size: usize = a.size();
282    let b_size: usize = b.size();
283
284    assert_eq!(
285        tmp.len(),
286        convolution_pairwise_apply_dft_tmp_bytes(res_size, a_size, b_size) / size_of::<f64>()
287    );
288
289    let bound: usize = a_size + b_size - 1;
290    let min_size: usize = res_size.min(bound);
291    let offset: usize = cnv_offset.min(bound);
292
293    let res_raw: &mut [f64] = res.raw_mut();
294    let a_raw: &[f64] = a.raw();
295    let b_raw: &[f64] = b.raw();
296
297    let a_row_size: usize = a_size * 8;
298    let b_row_size: usize = b_size * 8;
299
300    let mut a0_idx: usize = col_i * n * a_size;
301    let mut a1_idx: usize = col_j * n * a_size;
302    let mut b0_idx: usize = col_i * n * b_size;
303    let mut b1_idx: usize = col_j * n * b_size;
304
305    let (tmp_a, tmp) = tmp.split_at_mut(a_row_size);
306    let (tmp_b, tmp_res) = tmp.split_at_mut(b_row_size);
307
308    for blk_i in 0..m / 4 {
309        let a0: &[f64] = &a_raw[a0_idx..];
310        let a1: &[f64] = &a_raw[a1_idx..];
311        let b0: &[f64] = &b_raw[b0_idx..];
312        let b1: &[f64] = &b_raw[b1_idx..];
313
314        BE::reim_add(tmp_a, &a0[..a_row_size], &a1[..a_row_size]);
315        BE::reim_add(tmp_b, &b0[..b_row_size], &b1[..b_row_size]);
316
317        BE::reim4_convolution(tmp_res, min_size, offset, tmp_a, a_size, tmp_b, b_size);
318        BE::reim4_save_1blk_contiguous(m, min_size, blk_i, res_raw, tmp_res);
319
320        a0_idx += a_row_size;
321        a1_idx += a_row_size;
322        b0_idx += b_row_size;
323        b1_idx += b_row_size;
324    }
325
326    for j in min_size..res_size {
327        res.zero_at(res_col, j);
328    }
329}
330
331pub trait I64Ops {
332    fn i64_hadamard_product(res: &mut [i64], a: &[i64], b: &[i64]) {
333        debug_assert_eq!(res.len(), a.len());
334        debug_assert_eq!(res.len(), b.len());
335
336        res.iter_mut()
337            .zip(a.iter())
338            .zip(b.iter())
339            .for_each(|((r, &ai), &bi)| *r = ai.wrapping_mul(bi));
340    }
341
342    fn i64_extract_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
343        i64_extract_1blk_contiguous_ref(n, offset, rows, blk, dst, src)
344    }
345
346    fn i64_save_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
347        i64_save_1blk_contiguous_ref(n, offset, rows, blk, dst, src)
348    }
349
350    fn i64_convolution_by_const_1coeff(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) {
351        i64_convolution_by_const_1coeff_ref(k, dst, a, a_size, b)
352    }
353
354    fn i64_convolution_by_const_2coeffs(k: usize, dst: &mut [i64; 16], a: &[i64], a_size: usize, b: &[i64]) {
355        i64_convolution_by_const_2coeffs_ref(k, dst, a, a_size, b)
356    }
357
358    fn i64_convolution_by_const(dst: &mut [i64], dst_size: usize, offset: usize, a: &[i64], a_size: usize, b: &[i64]) {
359        assert!(a_size > 0);
360
361        for k in (0..dst_size - 1).step_by(2) {
362            Self::i64_convolution_by_const_2coeffs(k + offset, as_arr_i64_mut(&mut dst[8 * k..]), a, a_size, b);
363        }
364
365        if !dst_size.is_multiple_of(2) {
366            let k: usize = dst_size - 1;
367            Self::i64_convolution_by_const_1coeff(k + offset, as_arr_i64_mut(&mut dst[8 * k..]), a, a_size, b);
368        }
369    }
370}
371
/// Gathers one block of 8 values from each of the first `rows` rows of `src`
/// into consecutive 8-value chunks of `dst`.
///
/// `src` is read as rows of `n` values; inside each row the block starts at
/// `offset + 8 * blk`. At most `rows` rows are copied, fewer if `src` holds
/// fewer complete rows or `dst` fewer complete chunks.
#[inline(always)]
pub fn i64_extract_1blk_contiguous_ref(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
    debug_assert!(blk < (n >> 3));
    debug_assert!(dst.len() >= rows * 8, "dst.len(): {} < rows*8: {}", dst.len(), 8 * rows);

    let start: usize = offset + (blk << 3);
    let end: usize = start + 8;

    // Row `r` of `src` (stride `n`) feeds chunk `r` of `dst` (stride 8).
    dst.chunks_exact_mut(8)
        .zip(src.chunks_exact(n))
        .take(rows)
        .for_each(|(out, row)| out.copy_from_slice(&row[start..end]));
}
387
/// Scatters consecutive 8-value chunks of `src` into one block of each of the
/// first `rows` rows of `dst`.
///
/// `dst` is written as rows of `n` values; inside each row the block starts
/// at `offset + 8 * blk`. At most `rows` rows are written, fewer if `dst`
/// holds fewer complete rows or `src` fewer complete chunks. Values of `dst`
/// outside the addressed blocks are left untouched.
#[inline(always)]
pub fn i64_save_1blk_contiguous_ref(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
    debug_assert!(blk < (n >> 3));
    debug_assert!(src.len() >= rows * 8);

    let start: usize = offset + (blk << 3);
    let end: usize = start + 8;

    // Chunk `r` of `src` (stride 8) lands in row `r` of `dst` (stride `n`).
    dst.chunks_exact_mut(n)
        .zip(src.chunks_exact(8))
        .take(rows)
        .for_each(|(row, chunk)| row[start..end].copy_from_slice(chunk));
}
403
/// Computes limb `k` of the convolution of a block `a` by the digits `b`.
///
/// `a` holds `a_size` consecutive limbs of 8 values each and `b` holds one
/// scalar per limb. For every lane `c` in `0..8` the result is
///
/// `dst[c] = sum over j of a[8 * (k - j) + c] * b[j]`
///
/// taken over all `j` with `j < b.len()` and `0 <= k - j < a_size`, using
/// wrapping arithmetic. `dst` is set to zero when no such `j` exists, which
/// includes an empty `a` (`a_size == 0`) and an empty `b`.
///
/// # Panics
/// Panics if `a` is shorter than the limbs addressed (at most `8 * a_size` values).
#[inline(always)]
pub fn i64_convolution_by_const_1coeff_ref(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) {
    dst.fill(0);

    let b_size: usize = b.len();

    // An empty `a` contributes nothing; bailing out here also protects the
    // `a_size - 1` below from underflowing.
    if a_size == 0 || k >= a_size + b_size {
        return;
    }

    // Digits `j` for which limb `k - j` exists in `a`.
    let j_min: usize = k.saturating_sub(a_size - 1);
    let j_max: usize = (k + 1).min(b_size);

    for j in j_min..j_max {
        let row: usize = 8 * (k - j);
        let ai: &[i64] = &a[row..row + 8];
        let bi: i64 = b[j];

        for (acc, &x) in dst.iter_mut().zip(ai) {
            *acc = acc.wrapping_add(x.wrapping_mul(bi));
        }
    }
}
430
/// Reinterprets the first `SIZE` values of `x` as a fixed-size array reference.
///
/// The length is checked in every build profile, so this safe function can
/// never hand out a reference that reaches past the end of `x`.
///
/// # Panics
/// Panics if `x` holds fewer than `SIZE` values.
#[allow(dead_code)]
#[inline(always)]
pub(crate) fn as_arr_i64<const SIZE: usize>(x: &[i64]) -> &[i64; SIZE] {
    match x.first_chunk::<SIZE>() {
        Some(head) => head,
        None => panic!("x.len():{} < size:{}", x.len(), SIZE),
    }
}
437
/// Reinterprets the first `SIZE` values of `x` as a mutable fixed-size array reference.
///
/// The length is checked in every build profile, so this safe function can
/// never hand out a reference that reaches past the end of `x`.
///
/// # Panics
/// Panics if `x` holds fewer than `SIZE` values.
#[allow(dead_code)]
#[inline(always)]
pub(crate) fn as_arr_i64_mut<const SIZE: usize>(x: &mut [i64]) -> &mut [i64; SIZE] {
    // Read the length up front: `x` is mutably borrowed by the match below.
    let len: usize = x.len();
    match x.first_chunk_mut::<SIZE>() {
        Some(head) => head,
        None => panic!("x.len():{} < size:{}", len, SIZE),
    }
}
444
445#[inline(always)]
446pub fn i64_convolution_by_const_2coeffs_ref(k: usize, dst: &mut [i64; 16], a: &[i64], a_size: usize, b: &[i64]) {
447    i64_convolution_by_const_1coeff_ref(k, as_arr_i64_mut(&mut dst[..8]), a, a_size, b);
448    i64_convolution_by_const_1coeff_ref(k + 1, as_arr_i64_mut(&mut dst[8..]), a, a_size, b);
449}