1use bytemuck::{cast_slice, cast_slice_mut};
26
27use crate::{
28 layouts::{
29 Backend, HostDataMut, HostDataRef, Module, VecZnxBackendRef, VecZnxBigBackendMut, VecZnxDft, VecZnxDftBackendMut,
30 VecZnxDftBackendRef, ZnxView, ZnxViewMut,
31 },
32 reference::ntt120::{
33 NttAdd, NttAddAssign, NttCopy, NttDFTExecute, NttFromZnx64, NttNegate, NttNegateAssign, NttSub, NttSubAssign,
34 NttSubNegateAssign, NttToZnx128, NttZero,
35 mat_vec::{BbbMeta, BbcMeta},
36 ntt::{NttTable, NttTableInv, intt_ref},
37 primes::{PrimeSet, Primes30},
38 types::Q120bScalar,
39 },
40};
41
/// Read-only access to the precomputed NTT-120 data attached to a module.
///
/// The functions in this file take `&impl NttModuleHandle` to obtain the
/// transform tables they pass to the backend kernels.
pub trait NttModuleHandle {
    /// Returns the forward NTT table over [`Primes30`].
    fn get_ntt_table(&self) -> &NttTable<Primes30>;
    /// Returns the inverse NTT table over [`Primes30`].
    fn get_intt_table(&self) -> &NttTableInv<Primes30>;
    /// Returns the [`BbcMeta`] metadata (declared in the `mat_vec` module).
    fn get_bbc_meta(&self) -> &BbcMeta<Primes30>;
    /// Returns the [`BbbMeta`] metadata (declared in the `mat_vec` module).
    fn get_bbb_meta(&self) -> &BbbMeta<Primes30>;
}
70
/// Source of the precomputed NTT-120 data, implemented by backend handle types.
///
/// # Safety
///
/// NOTE(review): the exact contract is not visible in this chunk. Presumably
/// implementors must return references to fully initialised tables that stay
/// valid for as long as the handle is borrowed — confirm against the
/// implementors before relying on this.
pub unsafe trait NttHandleProvider {
    /// Returns the forward NTT table over [`Primes30`].
    fn get_ntt_table(&self) -> &NttTable<Primes30>;
    /// Returns the inverse NTT table over [`Primes30`].
    fn get_intt_table(&self) -> &NttTableInv<Primes30>;
    /// Returns the [`BbcMeta`] metadata (declared in the `mat_vec` module).
    fn get_bbc_meta(&self) -> &BbcMeta<Primes30>;
    /// Returns the [`BbbMeta`] metadata (declared in the `mat_vec` module).
    fn get_bbb_meta(&self) -> &BbbMeta<Primes30>;
}
97
/// Constructor side of an NTT-120 handle type.
///
/// # Safety
///
/// NOTE(review): the safety contract is not visible in this chunk; presumably
/// the returned handle must be fully initialised for the requested `n` —
/// confirm against the implementors.
pub unsafe trait NttHandleFactory: Sized {
    /// Builds a handle for size parameter `n`
    /// (presumably the ring degree — TODO confirm).
    fn create_ntt_handle(n: usize) -> Self;

    /// Hook for checking runtime support; the default implementation does nothing.
    fn assert_ntt_runtime_support() {}
}
112
113impl<B> NttModuleHandle for Module<B>
116where
117 B: Backend,
118 B::Handle: NttHandleProvider,
119{
120 fn get_ntt_table(&self) -> &NttTable<Primes30> {
121 unsafe { (&*self.ptr()).get_ntt_table() }
125 }
126
127 fn get_intt_table(&self) -> &NttTableInv<Primes30> {
128 unsafe { (&*self.ptr()).get_intt_table() }
129 }
130
131 fn get_bbc_meta(&self) -> &BbcMeta<Primes30> {
132 unsafe { (&*self.ptr()).get_bbc_meta() }
133 }
134
135 fn get_bbb_meta(&self) -> &BbbMeta<Primes30> {
136 unsafe { (&*self.ptr()).get_bbb_meta() }
137 }
138}
139
140#[inline(always)]
149fn limb_u64<D: crate::layouts::HostDataRef, BE: Backend<ScalarPrep = Q120bScalar>>(
150 v: &VecZnxDft<D, BE>,
151 col: usize,
152 limb: usize,
153) -> &[u64] {
154 cast_slice(v.at(col, limb))
155}
156
157#[inline(always)]
158fn limb_u64_mut<D: crate::layouts::HostDataMut, BE: Backend<ScalarPrep = Q120bScalar>>(
159 v: &mut VecZnxDft<D, BE>,
160 col: usize,
161 limb: usize,
162) -> &mut [u64] {
163 cast_slice_mut(v.at_mut(col, limb))
164}
165
166pub fn ntt120_vec_znx_dft_apply<BE>(
178 module: &impl NttModuleHandle,
179 step: usize,
180 offset: usize,
181 res: &mut VecZnxDftBackendMut<'_, BE>,
182 res_col: usize,
183 a: &VecZnxBackendRef<'_, BE>,
184 a_col: usize,
185) where
186 BE: Backend<ScalarPrep = Q120bScalar> + NttDFTExecute<NttTable<Primes30>> + NttFromZnx64 + NttZero + 'static,
187 for<'x> BE: Backend<BufRef<'x> = &'x [u8], BufMut<'x> = &'x mut [u8]>,
188{
189 let a_size = a.size();
190 let res_size = res.size();
191
192 let table = module.get_ntt_table();
193
194 let steps = a_size.div_ceil(step);
195 let min_steps = res_size.min(steps);
196
197 for j in 0..min_steps {
198 let limb = offset + j * step;
199 if limb < a_size {
200 let res_slice: &mut [u64] = limb_u64_mut(res, res_col, j);
201 BE::ntt_from_znx64(res_slice, a.at(a_col, limb));
202 BE::ntt_dft_execute(table, res_slice);
203 } else {
204 BE::ntt_zero(limb_u64_mut(res, res_col, j));
205 }
206 }
207
208 for j in min_steps..res_size {
209 BE::ntt_zero(limb_u64_mut(res, res_col, j));
210 }
211}
212
/// Returns the scratch size in bytes for an `n`-sized inverse transform:
/// `4 * n` values of type `u64`.
///
/// This matches the `tmp[..4 * n]` scratch slice used by
/// `ntt120_vec_znx_idft_apply`.
pub fn ntt120_vec_znx_idft_apply_tmp_bytes(n: usize) -> usize {
    // Number of `u64` words of scratch needed per coefficient.
    const WORDS_PER_COEFF: usize = 4;
    WORDS_PER_COEFF * n * size_of::<u64>()
}
223
224pub fn ntt120_vec_znx_idft_apply<BE>(
233 module: &impl NttModuleHandle,
234 res: &mut VecZnxBigBackendMut<'_, BE>,
235 res_col: usize,
236 a: &VecZnxDftBackendRef<'_, BE>,
237 a_col: usize,
238 tmp: &mut [u64],
239) where
240 BE: Backend<ScalarPrep = Q120bScalar, ScalarBig = i128> + NttDFTExecute<NttTableInv<Primes30>> + NttToZnx128 + NttCopy,
241 for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
242 for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
243{
244 let n = res.n();
245 let res_size = res.size();
246 let min_size = res_size.min(a.size());
247
248 let table = module.get_intt_table();
249
250 for j in 0..min_size {
251 let a_slice: &[u64] = limb_u64(a, a_col, j);
252 let tmp_n: &mut [u64] = &mut tmp[..4 * n];
253 BE::ntt_copy(tmp_n, a_slice);
254 BE::ntt_dft_execute(table, tmp_n);
255 BE::ntt_to_znx128(res.at_mut(res_col, j), n, tmp_n);
256 }
257
258 for j in min_size..res_size {
259 res.at_mut(res_col, j).fill(0i128);
260 }
261}
262
263pub fn ntt120_vec_znx_idft_apply_tmpa<BE>(
268 module: &impl NttModuleHandle,
269 res: &mut VecZnxBigBackendMut<'_, BE>,
270 res_col: usize,
271 a: &mut VecZnxDftBackendMut<'_, BE>,
272 a_col: usize,
273) where
274 BE: Backend<ScalarPrep = Q120bScalar, ScalarBig = i128> + NttDFTExecute<NttTableInv<Primes30>> + NttToZnx128,
275 for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
276{
277 let n = res.n();
278 let res_size = res.size();
279 let min_size = res_size.min(a.size());
280
281 let table = module.get_intt_table();
282
283 for j in 0..min_size {
284 BE::ntt_dft_execute(table, limb_u64_mut(a, a_col, j));
285 let a_slice: &[u64] = limb_u64(a, a_col, j);
286 BE::ntt_to_znx128(res.at_mut(res_col, j), n, a_slice);
287 }
288
289 for j in min_size..res_size {
290 res.at_mut(res_col, j).fill(0i128);
291 }
292}
293
294#[allow(dead_code)]
299pub fn ntt120_vec_znx_idft_apply_consume<'a, BE>(
300 module: &impl NttModuleHandle,
301 mut a: VecZnxDftBackendMut<'a, BE>,
302) -> VecZnxBigBackendMut<'a, BE>
303where
304 BE: Backend<ScalarPrep = Q120bScalar, ScalarBig = i128>,
305 for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
306{
307 let table = module.get_intt_table();
308
309 let (n, n_blocks, u64_ptr) = {
310 let n = a.n();
311 let n_blocks = a.cols() * a.size();
312 let ptr: *mut u64 = {
313 let s = a.raw_mut();
314 cast_slice_mut::<_, u64>(s).as_mut_ptr()
315 };
316 (n, n_blocks, ptr)
317 };
318
319 unsafe { compact_all_blocks_scalar(n, n_blocks, u64_ptr, table) };
320
321 a.into_big()
322}
323
/// Barrett-style reduction of `x` modulo `q`, with `mu = floor(2^61 / q)`.
///
/// The quotient estimate `floor(x * mu / 2^61)` never exceeds `floor(x / q)`,
/// so the first remainder cannot underflow; it is then corrected by at most
/// two conditional subtractions of `q`. For `x < 2^61` the estimate is off by
/// at most one and the result equals `x % q`.
///
/// NOTE(review): for `x >= 2^61` two subtractions may not suffice; the name
/// suggests callers stay below that bound — confirm.
#[allow(dead_code)]
#[inline(always)]
fn barrett_u61(x: u64, q: u64, mu: u64) -> u64 {
    let quotient_estimate = ((x as u128 * mu as u128) >> 61) as u64;
    let mut remainder = x - quotient_estimate * q;
    // Two correction rounds, exactly as many as the estimate may need.
    for _ in 0..2 {
        if remainder >= q {
            remainder -= q;
        }
    }
    remainder
}
332
333#[allow(dead_code)]
334#[inline(always)]
335fn reduce_q120b_crt(x: u64, q: u64, mu: u64, pow32_crt: u64, pow16_crt: u64, crt: u64) -> u64 {
336 let x_hi = x >> 32;
337 let x_hi_r = if x_hi >= q { x_hi - q } else { x_hi };
338 let x_lo = x & 0xFFFF_FFFF;
339 let x_lo_hi = x_lo >> 16;
340 let x_lo_lo = x_lo & 0xFFFF;
341 let tmp = x_hi_r
342 .wrapping_mul(pow32_crt)
343 .wrapping_add(x_lo_hi.wrapping_mul(pow16_crt))
344 .wrapping_add(x_lo_lo.wrapping_mul(crt));
345 barrett_u61(tmp, q, mu)
346}
347
/// Runs the reference inverse NTT on `n_blocks` consecutive blocks and
/// compacts each block in place from `4 * n` `u64` residues to `n` `i128`
/// coefficients.
///
/// Block `k` is read from word offset `4 * n * k` and its result is written at
/// word offset `2 * n * k` (two `u64` words per `i128`). Each coefficient is
/// rebuilt from its four residues as a weighted sum reduced modulo the product
/// of the four primes, then mapped to a signed representative.
///
/// # Safety
///
/// * `u64_ptr` must be valid for reads and writes of `4 * n * n_blocks`
///   `u64` values and suitably aligned for `u64`.
/// * No other reference to that memory may be used during the call.
/// * NOTE(review): `table` is presumably required to have been built for the
///   same `n` — confirm against `intt_ref`.
#[allow(dead_code)]
unsafe fn compact_all_blocks_scalar(n: usize, n_blocks: usize, u64_ptr: *mut u64, table: &NttTableInv<Primes30>) {
    // The four primes, and their Barrett constants floor(2^61 / q).
    let q_u64: [u64; 4] = Primes30::Q.map(|qi| qi as u64);
    let mu: [u64; 4] = q_u64.map(|qi| (1u64 << 61) / qi);
    // Per-prime constants from `Primes30::CRT_CST`
    // (presumably the CRT reconstruction factors — confirm against `PrimeSet`).
    let crt: [u64; 4] = Primes30::CRT_CST.map(|c| c as u64);

    // (2^32 mod q) * crt and 2^16 * crt, each reduced with `barrett_u61`;
    // these are the weights consumed by `reduce_q120b_crt`.
    let pow32_crt: [u64; 4] = std::array::from_fn(|k| {
        let pow32 = ((1u128 << 32) % q_u64[k] as u128) as u64;
        barrett_u61(pow32 * crt[k], q_u64[k], mu[k])
    });
    let pow16_crt: [u64; 4] = std::array::from_fn(|k| barrett_u61((1u64 << 16) * crt[k], q_u64[k], mu[k]));

    // total_q: product of all four primes; qm[i]: product of the three primes
    // other than q[i]; half_q: ceil(total_q / 2), threshold for the signed
    // representative; total_q_mult: the multiples 0..=3 of total_q.
    let q: [u128; 4] = q_u64.map(|qi| qi as u128);
    let total_q: u128 = q[0] * q[1] * q[2] * q[3];
    let qm: [u128; 4] = [q[1] * q[2] * q[3], q[0] * q[2] * q[3], q[0] * q[1] * q[3], q[0] * q[1] * q[2]];
    let half_q: u128 = total_q.div_ceil(2);
    let total_q_mult: [u128; 4] = [0, total_q, total_q * 2, total_q * 3];

    for k in 0..n_blocks {
        // Source blocks are 4n words wide, destination blocks 2n words wide,
        // so the output of block k ends at or before the start of source block k.
        let src_start = 4 * n * k;
        let dst_start = 2 * n * k;

        {
            // Inverse-transform the whole source block in place first.
            let blk: &mut [u64] = unsafe { std::slice::from_raw_parts_mut(u64_ptr.add(src_start), 4 * n) };
            intt_ref::<Primes30>(table, blk);
        }

        for c in 0..n {
            // Load all four residues of coefficient c before anything is
            // written: the write below targets words 2c and 2c+1 of the same
            // region, which never lie past the words just read.
            let (x0, x1, x2, x3) = unsafe {
                (
                    *u64_ptr.add(src_start + 4 * c),
                    *u64_ptr.add(src_start + 4 * c + 1),
                    *u64_ptr.add(src_start + 4 * c + 2),
                    *u64_ptr.add(src_start + 4 * c + 3),
                )
            };

            // Reduce each residue modulo its prime with the `crt` factor folded in.
            let t0 = reduce_q120b_crt(x0, q_u64[0], mu[0], pow32_crt[0], pow16_crt[0], crt[0]);
            let t1 = reduce_q120b_crt(x1, q_u64[1], mu[1], pow32_crt[1], pow16_crt[1], crt[1]);
            let t2 = reduce_q120b_crt(x2, q_u64[2], mu[2], pow32_crt[2], pow16_crt[2], crt[2]);
            let t3 = reduce_q120b_crt(x3, q_u64[3], mu[3], pow32_crt[3], pow16_crt[3], crt[3]);

            let mut v: u128 = t0 as u128 * qm[0] + t1 as u128 * qm[1] + t2 as u128 * qm[2] + t3 as u128 * qm[3];

            // Estimate how many multiples of total_q to remove from the top bits.
            // Indexing the 4-entry table assumes v < 2^122, i.e. every t_i and
            // prime is below 2^30 — TODO confirm for `Primes30`.
            let q_approx = (v >> 120) as usize;
            v -= total_q_mult[q_approx];
            if v >= total_q {
                v -= total_q;
            }

            // Map to the signed representative: values >= half_q become negative.
            let val: i128 = if v >= half_q { v as i128 - total_q as i128 } else { v as i128 };

            // The destination is only guaranteed u64-aligned, hence the unaligned write.
            unsafe { (u64_ptr.add(dst_start + 2 * c) as *mut i128).write_unaligned(val) };
        }
    }
}
404
405pub fn ntt120_vec_znx_dft_add_into<BE>(
413 res: &mut VecZnxDftBackendMut<'_, BE>,
414 res_col: usize,
415 a: &VecZnxDftBackendRef<'_, BE>,
416 a_col: usize,
417 b: &VecZnxDftBackendRef<'_, BE>,
418 b_col: usize,
419) where
420 BE: Backend<ScalarPrep = Q120bScalar> + NttAdd + NttCopy + NttZero,
421 for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
422 for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
423{
424 let res_size = res.size();
425 let a_size = a.size();
426 let b_size = b.size();
427
428 if a_size <= b_size {
429 let sum_size = a_size.min(res_size);
430 let cpy_size = b_size.min(res_size);
431 for j in 0..sum_size {
432 BE::ntt_add(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j), limb_u64(b, b_col, j));
433 }
434 for j in sum_size..cpy_size {
435 BE::ntt_copy(limb_u64_mut(res, res_col, j), limb_u64(b, b_col, j));
436 }
437 for j in cpy_size..res_size {
438 BE::ntt_zero(limb_u64_mut(res, res_col, j));
439 }
440 } else {
441 let sum_size = b_size.min(res_size);
442 let cpy_size = a_size.min(res_size);
443 for j in 0..sum_size {
444 BE::ntt_add(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j), limb_u64(b, b_col, j));
445 }
446 for j in sum_size..cpy_size {
447 BE::ntt_copy(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j));
448 }
449 for j in cpy_size..res_size {
450 BE::ntt_zero(limb_u64_mut(res, res_col, j));
451 }
452 }
453}
454
455pub fn ntt120_vec_znx_dft_add_assign<BE>(
457 res: &mut VecZnxDftBackendMut<'_, BE>,
458 res_col: usize,
459 a: &VecZnxDftBackendRef<'_, BE>,
460 a_col: usize,
461) where
462 BE: Backend<ScalarPrep = Q120bScalar> + NttAddAssign,
463 for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
464 for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
465{
466 let sum_size = res.size().min(a.size());
467 for j in 0..sum_size {
468 BE::ntt_add_assign(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j));
469 }
470}
471
472pub fn ntt120_vec_znx_dft_add_scaled_assign<BE>(
477 res: &mut VecZnxDftBackendMut<'_, BE>,
478 res_col: usize,
479 a: &VecZnxDftBackendRef<'_, BE>,
480 a_col: usize,
481 a_scale: i64,
482) where
483 BE: Backend<ScalarPrep = Q120bScalar> + NttAddAssign,
484 for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
485 for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
486{
487 let res_size = res.size();
488 let a_size = a.size();
489
490 if a_scale > 0 {
491 let shift = (a_scale as usize).min(a_size);
492 let sum_size = a_size.min(res_size).saturating_sub(shift);
493 for j in 0..sum_size {
494 BE::ntt_add_assign(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j + shift));
495 }
496 } else if a_scale < 0 {
497 let shift = (a_scale.unsigned_abs() as usize).min(res_size);
498 let sum_size = a_size.min(res_size.saturating_sub(shift));
499 for j in 0..sum_size {
500 BE::ntt_add_assign(limb_u64_mut(res, res_col, j + shift), limb_u64(a, a_col, j));
501 }
502 } else {
503 let sum_size = a_size.min(res_size);
504 for j in 0..sum_size {
505 BE::ntt_add_assign(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j));
506 }
507 }
508}
509
510pub fn ntt120_vec_znx_dft_sub<BE>(
512 res: &mut VecZnxDftBackendMut<'_, BE>,
513 res_col: usize,
514 a: &VecZnxDftBackendRef<'_, BE>,
515 a_col: usize,
516 b: &VecZnxDftBackendRef<'_, BE>,
517 b_col: usize,
518) where
519 BE: Backend<ScalarPrep = Q120bScalar> + NttSub + NttNegate + NttCopy + NttZero,
520 for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
521 for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
522{
523 let res_size = res.size();
524 let a_size = a.size();
525 let b_size = b.size();
526
527 if a_size <= b_size {
528 let sum_size = a_size.min(res_size);
529 let cpy_size = b_size.min(res_size);
530 for j in 0..sum_size {
531 BE::ntt_sub(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j), limb_u64(b, b_col, j));
532 }
533 for j in sum_size..cpy_size {
534 BE::ntt_negate(limb_u64_mut(res, res_col, j), limb_u64(b, b_col, j));
535 }
536 for j in cpy_size..res_size {
537 BE::ntt_zero(limb_u64_mut(res, res_col, j));
538 }
539 } else {
540 let sum_size = b_size.min(res_size);
541 let cpy_size = a_size.min(res_size);
542 for j in 0..sum_size {
543 BE::ntt_sub(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j), limb_u64(b, b_col, j));
544 }
545 for j in sum_size..cpy_size {
546 BE::ntt_copy(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j));
547 }
548 for j in cpy_size..res_size {
549 BE::ntt_zero(limb_u64_mut(res, res_col, j));
550 }
551 }
552}
553
554pub fn ntt120_vec_znx_dft_sub_assign<BE>(
556 res: &mut VecZnxDftBackendMut<'_, BE>,
557 res_col: usize,
558 a: &VecZnxDftBackendRef<'_, BE>,
559 a_col: usize,
560) where
561 BE: Backend<ScalarPrep = Q120bScalar> + NttSubAssign,
562 for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
563 for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
564{
565 let sum_size = res.size().min(a.size());
566 for j in 0..sum_size {
567 BE::ntt_sub_assign(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j));
568 }
569}
570
571pub fn ntt120_vec_znx_dft_sub_negate_assign<BE>(
575 res: &mut VecZnxDftBackendMut<'_, BE>,
576 res_col: usize,
577 a: &VecZnxDftBackendRef<'_, BE>,
578 a_col: usize,
579) where
580 BE: Backend<ScalarPrep = Q120bScalar> + NttSubNegateAssign + NttNegateAssign,
581 for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
582 for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
583{
584 let res_size = res.size();
585 let sum_size = res_size.min(a.size());
586 for j in 0..sum_size {
587 BE::ntt_sub_negate_assign(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j));
588 }
589 for j in sum_size..res_size {
590 BE::ntt_negate_assign(limb_u64_mut(res, res_col, j));
591 }
592}
593
594pub fn ntt120_vec_znx_dft_copy<BE>(
598 step: usize,
599 offset: usize,
600 res: &mut VecZnxDftBackendMut<'_, BE>,
601 res_col: usize,
602 a: &VecZnxDftBackendRef<'_, BE>,
603 a_col: usize,
604) where
605 BE: Backend<ScalarPrep = Q120bScalar> + NttCopy + NttZero,
606 for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
607 for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
608{
609 #[cfg(debug_assertions)]
610 {
611 assert_eq!(res.n(), a.n())
612 }
613
614 let steps: usize = a.size().div_ceil(step);
615 let min_steps: usize = res.size().min(steps);
616
617 for j in 0..min_steps {
618 let limb = offset + j * step;
619 if limb < a.size() {
620 BE::ntt_copy(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, limb));
621 } else {
622 BE::ntt_zero(limb_u64_mut(res, res_col, j));
623 }
624 }
625 for j in min_steps..res.size() {
626 BE::ntt_zero(limb_u64_mut(res, res_col, j));
627 }
628}
629
630pub fn ntt120_vec_znx_dft_zero<BE>(res: &mut VecZnxDftBackendMut<'_, BE>, res_col: usize)
632where
633 BE: Backend<ScalarPrep = Q120bScalar> + NttZero,
634 for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
635{
636 for j in 0..res.size() {
637 BE::ntt_zero(limb_u64_mut(res, res_col, j));
638 }
639}