1use crate::{
2 layouts::{
3 Backend, CnvPVecLBackendMut, CnvPVecLBackendRef, CnvPVecRBackendMut, CnvPVecRBackendRef, HostDataRef, VecZnxBackendRef,
4 VecZnxBigBackendMut, VecZnxDftBackendMut, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero,
5 },
6 reference::fft64::{
7 reim::{ReimArith, ReimFFTExecute, ReimFFTTable},
8 reim4::{Reim4BlkMatVec, Reim4Convolution},
9 vec_znx_dft::vec_znx_dft_apply,
10 },
11};
12
13pub fn convolution_prepare_left<BE>(
14 table: &ReimFFTTable<f64>,
15 res: &mut CnvPVecLBackendMut<'_, BE>,
16 a: &VecZnxBackendRef<'_, BE>,
17 mask: i64,
18 tmp: &mut VecZnxDftBackendMut<'_, BE>,
19) where
20 BE: Backend<ScalarPrep = f64> + ReimArith + Reim4BlkMatVec + ReimFFTExecute<ReimFFTTable<f64>, f64> + 'static,
21 for<'x> BE: Backend<BufRef<'x> = &'x [u8], BufMut<'x> = &'x mut [u8]>,
22{
23 convolution_prepare(table, res, a, mask, tmp)
24}
25
26pub fn convolution_prepare_right<BE>(
27 table: &ReimFFTTable<f64>,
28 res: &mut CnvPVecRBackendMut<'_, BE>,
29 a: &VecZnxBackendRef<'_, BE>,
30 mask: i64,
31 tmp: &mut VecZnxDftBackendMut<'_, BE>,
32) where
33 BE: Backend<ScalarPrep = f64> + ReimArith + Reim4BlkMatVec + ReimFFTExecute<ReimFFTTable<f64>, f64> + 'static,
34 for<'x> BE: Backend<BufRef<'x> = &'x [u8], BufMut<'x> = &'x mut [u8]>,
35{
36 convolution_prepare(table, res, a, mask, tmp)
37}
38
39fn convolution_prepare<R, BE>(
40 table: &ReimFFTTable<f64>,
41 res: &mut R,
42 a: &VecZnxBackendRef<'_, BE>,
43 mask: i64,
44 tmp: &mut VecZnxDftBackendMut<'_, BE>,
45) where
46 BE: Backend<ScalarPrep = f64> + ReimArith + Reim4BlkMatVec + ReimFFTExecute<ReimFFTTable<f64>, f64> + 'static,
47 for<'x> BE: Backend<BufRef<'x> = &'x [u8], BufMut<'x> = &'x mut [u8]>,
48 R: ZnxInfos + ZnxViewMut<Scalar = BE::ScalarPrep>,
49{
50 let cols: usize = res.cols();
51 assert_eq!(a.cols(), cols, "a.cols():{} != res.cols():{cols}", a.cols());
52
53 let res_size: usize = res.size();
54 let min_size: usize = res_size.min(a.size());
55
56 let m: usize = a.n() >> 1;
57
58 let n: usize = table.m() << 1;
59
60 let res_raw: &mut [f64] = res.raw_mut();
61
62 for i in 0..cols {
63 vec_znx_dft_apply(table, 1, 0, tmp, 0, a, i);
65
66 if min_size > 0 {
68 let last = min_size - 1;
69 BE::reim_from_znx_masked(tmp.at_mut(0, last), a.at(i, last), mask);
70 BE::reim_dft_execute(table, tmp.at_mut(0, last));
71 }
72
73 let tmp_raw: &[f64] = tmp.raw();
74 let res_col: &mut [f64] = &mut res_raw[i * n * res_size..];
75
76 for blk_i in 0..m / 4 {
77 BE::reim4_extract_1blk_contiguous(m, min_size, blk_i, &mut res_col[blk_i * res_size * 8..], tmp_raw);
78 BE::reim_zero(&mut res_col[blk_i * res_size * 8 + min_size * 8..(blk_i + 1) * res_size * 8]);
79 }
80 }
81}
82
83pub fn convolution_prepare_self<BE>(
84 table: &ReimFFTTable<f64>,
85 left: &mut CnvPVecLBackendMut<'_, BE>,
86 right: &mut CnvPVecRBackendMut<'_, BE>,
87 a: &VecZnxBackendRef<'_, BE>,
88 mask: i64,
89 tmp: &mut VecZnxDftBackendMut<'_, BE>,
90) where
91 BE: Backend<ScalarPrep = f64> + ReimArith + Reim4BlkMatVec + ReimFFTExecute<ReimFFTTable<f64>, f64> + 'static,
92 for<'x> BE: Backend<BufRef<'x> = &'x [u8], BufMut<'x> = &'x mut [u8]>,
93{
94 let cols: usize = left.cols();
95 assert_eq!(a.cols(), cols, "a.cols():{} != left.cols():{cols}", a.cols());
96 assert_eq!(right.cols(), cols, "right.cols():{} != left.cols():{cols}", right.cols());
97
98 let left_size: usize = left.size();
99 let right_size: usize = right.size();
100 assert_eq!(
101 left_size, right_size,
102 "left.size():{} != right.size():{right_size}",
103 left_size
104 );
105 let res_size: usize = left_size;
106 let min_size: usize = res_size.min(a.size());
107
108 let m: usize = a.n() >> 1;
109 let n: usize = table.m() << 1;
110
111 let left_raw: &mut [f64] = left.raw_mut();
112 let right_raw: &mut [f64] = right.raw_mut();
113
114 for i in 0..cols {
115 vec_znx_dft_apply(table, 1, 0, tmp, 0, a, i);
117
118 if min_size > 0 {
120 let last = min_size - 1;
121 BE::reim_from_znx_masked(tmp.at_mut(0, last), a.at(i, last), mask);
122 BE::reim_dft_execute(table, tmp.at_mut(0, last));
123 }
124
125 let tmp_raw: &[f64] = tmp.raw();
126 let left_col: &mut [f64] = &mut left_raw[i * n * res_size..];
127
128 for blk_i in 0..m / 4 {
129 BE::reim4_extract_1blk_contiguous(m, min_size, blk_i, &mut left_col[blk_i * res_size * 8..], tmp_raw);
130 BE::reim_zero(&mut left_col[blk_i * res_size * 8 + min_size * 8..(blk_i + 1) * res_size * 8]);
131 }
132
133 let col_bytes: usize = n * res_size;
135 let right_col: &mut [f64] = &mut right_raw[i * n * res_size..];
136 right_col[..col_bytes].copy_from_slice(&left_col[..col_bytes]);
137 }
138}
139
/// Returns the scratch size, in bytes, for a by-constant convolution with the given sizes.
///
/// The result is `size_of::<i64>() * (min_size + a_size) * 8`, where
/// `min_size = min(res_size, a_size + b_size - 1)`. The `- 1` saturates at zero, so the
/// degenerate case `a_size == 0 && b_size == 0` yields `0` instead of underflowing.
///
/// * `res_size` - `size()` of the result.
/// * `a_size`   - `size()` of the left operand.
/// * `b_size`   - `size()` of the constant operand.
pub fn convolution_by_const_apply_tmp_bytes(res_size: usize, a_size: usize, b_size: usize) -> usize {
    let bound: usize = (a_size + b_size).saturating_sub(1);
    let min_size: usize = res_size.min(bound);
    size_of::<i64>() * (min_size + a_size) * 8
}
144
145pub fn convolution_by_const_apply<BE>(
146 cnv_offset: usize,
147 res: &mut VecZnxBigBackendMut<'_, BE>,
148 res_col: usize,
149 a: &VecZnxBackendRef<'_, BE>,
150 a_col: usize,
151 b: &VecZnxBackendRef<'_, BE>,
152 b_col: usize,
153 b_coeff: usize,
154 tmp: &mut [i64],
155) where
156 BE: Backend<ScalarBig = i64> + I64Ops + 'static,
157 for<'x> BE: Backend<BufRef<'x> = &'x [u8]>,
158 for<'x> <BE as Backend>::BufMut<'x>: crate::layouts::HostDataMut,
159{
160 let n: usize = res.n();
161 assert_eq!(a.n(), n);
162
163 let res_size: usize = res.size();
164 let a_size: usize = a.size();
165 let b_size: usize = b.size();
166
167 let bound: usize = a_size + b_size - 1;
168 let min_size: usize = res_size.min(bound);
169 let offset: usize = cnv_offset.min(bound);
170
171 let a_sl: usize = n * a.cols();
172 let res_sl: usize = n * res.cols();
173
174 let res_raw: &mut [i64] = res.raw_mut();
175 let a_raw: &[i64] = a.raw();
176
177 let a_idx: usize = n * a_col;
178 let res_idx: usize = n * res_col;
179
180 let (res_blk, a_blk) = tmp[..(min_size + a_size) * 8].split_at_mut(min_size * 8);
181 let mut b_digits = vec![0i64; b_size];
182 for (j, digit) in b_digits.iter_mut().enumerate() {
183 *digit = b.at(b_col, j)[b_coeff];
184 }
185
186 for blk_i in 0..n / 8 {
187 BE::i64_extract_1blk_contiguous(a_sl, a_idx, a_size, blk_i, a_blk, a_raw);
188 BE::i64_convolution_by_const(res_blk, min_size, offset, a_blk, a_size, &b_digits);
189 BE::i64_save_1blk_contiguous(res_sl, res_idx, min_size, blk_i, res_raw, res_blk);
190 }
191
192 for j in min_size..res_size {
193 res.zero_at(res_col, j);
194 }
195}
196
/// Returns the scratch size, in bytes, for `convolution_apply_dft` with the given sizes.
///
/// The result is `size_of::<f64>() * 8 * min_size`, where
/// `min_size = min(res_size, a_size + b_size - 1)`. The `- 1` saturates at zero, so the
/// degenerate case `a_size == 0 && b_size == 0` yields `0` instead of underflowing.
///
/// * `res_size` - `size()` of the result.
/// * `a_size`   - `size()` of the left operand.
/// * `b_size`   - `size()` of the right operand.
pub fn convolution_apply_dft_tmp_bytes(res_size: usize, a_size: usize, b_size: usize) -> usize {
    let bound: usize = (a_size + b_size).saturating_sub(1);
    let min_size: usize = res_size.min(bound);
    size_of::<f64>() * 8 * min_size
}
201
202#[allow(clippy::too_many_arguments)]
203pub fn convolution_apply_dft<BE>(
204 cnv_offset: usize,
205 res: &mut VecZnxDftBackendMut<'_, BE>,
206 res_col: usize,
207 a: &CnvPVecLBackendRef<'_, BE>,
208 a_col: usize,
209 b: &CnvPVecRBackendRef<'_, BE>,
210 b_col: usize,
211 tmp: &mut [f64],
212) where
213 BE: Backend<ScalarPrep = f64> + Reim4BlkMatVec + Reim4Convolution,
214 for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
215 for<'x> <BE as Backend>::BufMut<'x>: crate::layouts::HostDataMut,
216{
217 let n: usize = res.n();
218 assert_eq!(a.n(), n);
219 assert_eq!(b.n(), n);
220 let m: usize = n >> 1;
221
222 let res_size: usize = res.size();
223 let a_size: usize = a.size();
224 let b_size: usize = b.size();
225
226 let bound: usize = a_size + b_size - 1;
227 let min_size: usize = res_size.min(bound);
228 let offset: usize = cnv_offset.min(bound);
229
230 let dst: &mut [f64] = res.raw_mut();
231 let a_raw: &[f64] = a.raw();
232 let b_raw: &[f64] = b.raw();
233
234 let mut a_idx: usize = a_col * n * a_size;
235 let mut b_idx: usize = b_col * n * b_size;
236 let a_offset: usize = a_size * 8;
237 let b_offset: usize = b_size * 8;
238 for blk_i in 0..m / 4 {
239 BE::reim4_convolution(tmp, min_size, offset, &a_raw[a_idx..], a_size, &b_raw[b_idx..], b_size);
240 BE::reim4_save_1blk_contiguous(m, min_size, blk_i, dst, tmp);
241 a_idx += a_offset;
242 b_idx += b_offset;
243 }
244
245 for j in min_size..res_size {
246 res.zero_at(res_col, j);
247 }
248}
249
250pub fn convolution_pairwise_apply_dft_tmp_bytes(res_size: usize, a_size: usize, b_size: usize) -> usize {
251 convolution_apply_dft_tmp_bytes(res_size, a_size, b_size) + (a_size + b_size) * size_of::<f64>() * 8
252}
253
254#[allow(clippy::too_many_arguments)]
255pub fn convolution_pairwise_apply_dft<BE>(
256 cnv_offset: usize,
257 res: &mut VecZnxDftBackendMut<'_, BE>,
258 res_col: usize,
259 a: &CnvPVecLBackendRef<'_, BE>,
260 b: &CnvPVecRBackendRef<'_, BE>,
261 col_i: usize,
262 col_j: usize,
263 tmp: &mut [f64],
264) where
265 BE: Backend<ScalarPrep = f64> + ReimArith + Reim4BlkMatVec + Reim4Convolution,
266 for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
267 for<'x> <BE as Backend>::BufMut<'x>: crate::layouts::HostDataMut,
268{
269 if col_i == col_j {
270 convolution_apply_dft(cnv_offset, res, res_col, a, col_i, b, col_j, tmp);
271 return;
272 }
273
274 let n: usize = res.n();
275 let m: usize = n >> 1;
276
277 assert_eq!(a.n(), n);
278 assert_eq!(b.n(), n);
279
280 let res_size: usize = res.size();
281 let a_size: usize = a.size();
282 let b_size: usize = b.size();
283
284 assert_eq!(
285 tmp.len(),
286 convolution_pairwise_apply_dft_tmp_bytes(res_size, a_size, b_size) / size_of::<f64>()
287 );
288
289 let bound: usize = a_size + b_size - 1;
290 let min_size: usize = res_size.min(bound);
291 let offset: usize = cnv_offset.min(bound);
292
293 let res_raw: &mut [f64] = res.raw_mut();
294 let a_raw: &[f64] = a.raw();
295 let b_raw: &[f64] = b.raw();
296
297 let a_row_size: usize = a_size * 8;
298 let b_row_size: usize = b_size * 8;
299
300 let mut a0_idx: usize = col_i * n * a_size;
301 let mut a1_idx: usize = col_j * n * a_size;
302 let mut b0_idx: usize = col_i * n * b_size;
303 let mut b1_idx: usize = col_j * n * b_size;
304
305 let (tmp_a, tmp) = tmp.split_at_mut(a_row_size);
306 let (tmp_b, tmp_res) = tmp.split_at_mut(b_row_size);
307
308 for blk_i in 0..m / 4 {
309 let a0: &[f64] = &a_raw[a0_idx..];
310 let a1: &[f64] = &a_raw[a1_idx..];
311 let b0: &[f64] = &b_raw[b0_idx..];
312 let b1: &[f64] = &b_raw[b1_idx..];
313
314 BE::reim_add(tmp_a, &a0[..a_row_size], &a1[..a_row_size]);
315 BE::reim_add(tmp_b, &b0[..b_row_size], &b1[..b_row_size]);
316
317 BE::reim4_convolution(tmp_res, min_size, offset, tmp_a, a_size, tmp_b, b_size);
318 BE::reim4_save_1blk_contiguous(m, min_size, blk_i, res_raw, tmp_res);
319
320 a0_idx += a_row_size;
321 a1_idx += a_row_size;
322 b0_idx += b_row_size;
323 b1_idx += b_row_size;
324 }
325
326 for j in min_size..res_size {
327 res.zero_at(res_col, j);
328 }
329}
330
331pub trait I64Ops {
332 fn i64_hadamard_product(res: &mut [i64], a: &[i64], b: &[i64]) {
333 debug_assert_eq!(res.len(), a.len());
334 debug_assert_eq!(res.len(), b.len());
335
336 res.iter_mut()
337 .zip(a.iter())
338 .zip(b.iter())
339 .for_each(|((r, &ai), &bi)| *r = ai.wrapping_mul(bi));
340 }
341
342 fn i64_extract_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
343 i64_extract_1blk_contiguous_ref(n, offset, rows, blk, dst, src)
344 }
345
346 fn i64_save_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
347 i64_save_1blk_contiguous_ref(n, offset, rows, blk, dst, src)
348 }
349
350 fn i64_convolution_by_const_1coeff(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) {
351 i64_convolution_by_const_1coeff_ref(k, dst, a, a_size, b)
352 }
353
354 fn i64_convolution_by_const_2coeffs(k: usize, dst: &mut [i64; 16], a: &[i64], a_size: usize, b: &[i64]) {
355 i64_convolution_by_const_2coeffs_ref(k, dst, a, a_size, b)
356 }
357
358 fn i64_convolution_by_const(dst: &mut [i64], dst_size: usize, offset: usize, a: &[i64], a_size: usize, b: &[i64]) {
359 assert!(a_size > 0);
360
361 for k in (0..dst_size - 1).step_by(2) {
362 Self::i64_convolution_by_const_2coeffs(k + offset, as_arr_i64_mut(&mut dst[8 * k..]), a, a_size, b);
363 }
364
365 if !dst_size.is_multiple_of(2) {
366 let k: usize = dst_size - 1;
367 Self::i64_convolution_by_const_1coeff(k + offset, as_arr_i64_mut(&mut dst[8 * k..]), a, a_size, b);
368 }
369 }
370}
371
/// Gathers one 8-value block from consecutive rows of `src` into `dst`.
///
/// `src` is read as rows of `n` values. From each row the 8 values starting at
/// `offset + 8 * blk` are copied into the next 8-value chunk of `dst`. At most `rows`
/// rows are copied; iteration also stops when `src` runs out of full rows or `dst`
/// runs out of full chunks.
///
/// # Panics
///
/// Panics if `n == 0` or if `offset + 8 * blk + 8` exceeds `n`. Debug builds
/// additionally check `blk < n / 8` and `dst.len() >= rows * 8`.
#[inline(always)]
pub fn i64_extract_1blk_contiguous_ref(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
    debug_assert!(blk < (n >> 3));
    debug_assert!(dst.len() >= rows * 8, "dst.len(): {} < rows*8: {}", dst.len(), 8 * rows);

    let start: usize = offset + (blk << 3);

    src.chunks_exact(n)
        .zip(dst.chunks_exact_mut(8))
        .take(rows)
        .for_each(|(row, out)| out.copy_from_slice(&row[start..start + 8]));
}
387
/// Scatters 8-value chunks of `src` into one block of consecutive rows of `dst`.
///
/// `dst` is written as rows of `n` values. The next 8-value chunk of `src` is copied
/// into each row at positions `offset + 8 * blk .. offset + 8 * blk + 8`. At most `rows`
/// rows are written; iteration also stops when `dst` runs out of full rows or `src`
/// runs out of full chunks. Values outside the selected block are left untouched.
///
/// # Panics
///
/// Panics if `n == 0` or if `offset + 8 * blk + 8` exceeds `n`. Debug builds
/// additionally check `blk < n / 8` and `src.len() >= rows * 8`.
#[inline(always)]
pub fn i64_save_1blk_contiguous_ref(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
    debug_assert!(blk < (n >> 3));
    debug_assert!(src.len() >= rows * 8);

    let start: usize = offset + (blk << 3);

    dst.chunks_exact_mut(n)
        .zip(src.chunks_exact(8))
        .take(rows)
        .for_each(|(row, chunk)| row[start..start + 8].copy_from_slice(chunk));
}
403
/// Computes output row `k` of the convolution of `a` by the constants `b`.
///
/// `a` holds `a_size` rows of 8 contiguous values and `b` holds one constant per row.
/// For each lane `c` in `0..8`:
///
/// `dst[c] = sum over j of a[8 * (k - j) + c] * b[j]`
///
/// with `j` restricted so that `k - j < a_size` and `j < b.len()`. All arithmetic is
/// wrapping. `dst` is zero when no `j` qualifies, in particular when `a_size == 0`,
/// `b` is empty, or `k` is past the last output row.
///
/// # Panics
///
/// Panics if `a` holds fewer than `8 * a_size` values and a missing row is needed.
#[inline(always)]
pub fn i64_convolution_by_const_1coeff_ref(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) {
    dst.fill(0);

    let b_size: usize = b.len();

    // `a_size == 0` must be rejected before `a_size - 1` below, which would underflow.
    if a_size == 0 || k >= a_size + b_size {
        return;
    }

    // Keep `k - j` inside `0..a_size` and `j` inside `0..b_size`.
    let j_min: usize = k.saturating_sub(a_size - 1);
    let j_max: usize = (k + 1).min(b_size);

    for j in j_min..j_max {
        let row_start: usize = 8 * (k - j);
        let a_row: &[i64] = &a[row_start..row_start + 8];
        let b_j: i64 = b[j];

        for (acc, &a_c) in dst.iter_mut().zip(a_row) {
            *acc = acc.wrapping_add(a_c.wrapping_mul(b_j));
        }
    }
}
430
/// Borrows the first `SIZE` values of `x` as a fixed-size array.
///
/// The length check is performed in every build profile, so the function cannot hand
/// out a reference that extends past the end of `x`.
///
/// # Panics
///
/// Panics if `x.len() < SIZE`.
#[allow(dead_code)]
#[inline(always)]
pub(crate) fn as_arr_i64<const SIZE: usize>(x: &[i64]) -> &[i64; SIZE] {
    match x.first_chunk::<SIZE>() {
        Some(head) => head,
        None => panic!("x.len():{} < size:{}", x.len(), SIZE),
    }
}
437
/// Mutably borrows the first `SIZE` values of `x` as a fixed-size array.
///
/// The length check is performed in every build profile, so the function cannot hand
/// out a reference that extends past the end of `x`.
///
/// # Panics
///
/// Panics if `x.len() < SIZE`.
#[allow(dead_code)]
#[inline(always)]
pub(crate) fn as_arr_i64_mut<const SIZE: usize>(x: &mut [i64]) -> &mut [i64; SIZE] {
    // Read the length first: the panic arm must not touch `x` while it is mutably borrowed.
    let len: usize = x.len();
    match x.first_chunk_mut::<SIZE>() {
        Some(head) => head,
        None => panic!("x.len():{} < size:{}", len, SIZE),
    }
}
444
445#[inline(always)]
446pub fn i64_convolution_by_const_2coeffs_ref(k: usize, dst: &mut [i64; 16], a: &[i64], a_size: usize, b: &[i64]) {
447 i64_convolution_by_const_1coeff_ref(k, as_arr_i64_mut(&mut dst[..8]), a, a_size, b);
448 i64_convolution_by_const_1coeff_ref(k + 1, as_arr_i64_mut(&mut dst[8..]), a, a_size, b);
449}