// poulpy_cpu_ref/reference/ntt120/vec_znx_dft.rs

//! NTT-domain vector polynomial operations for the NTT120 backend.
//!
//! This module provides:
//!
//! - The [`NttModuleHandle`] trait, which exposes precomputed NTT/iNTT
//!   tables and multiply–accumulate metadata from a module handle, together
//!   with [`NttHandleProvider`] and [`NttHandleFactory`], which backend handle
//!   types implement.
//! - Forward (`ntt120_vec_znx_dft_apply`) and inverse
//!   (`ntt120_vec_znx_idft_apply`, `ntt120_vec_znx_idft_apply_tmpa`) DFT
//!   operations.
//! - Component-wise DFT-domain arithmetic (add, scaled add, sub, negated sub,
//!   copy, zero).
//!
//! # Scalar layout
//!
//! `VecZnxDft<_, NTT120Ref>` stores [`Q120bScalar`] values (32 bytes each).
//! Each `Q120bScalar` holds four `u64` CRT residues for one ring coefficient.
//! A `bytemuck::cast_slice` converts a `&[Q120bScalar]` limb slice to
//! `&[u64]` for use with the primitive NTT arithmetic functions.
//!
//! # Prime set
//!
//! All arithmetic is hardcoded to [`Primes30`] (the spqlios-arithmetic
//! default, Q ≈ 2^120). Generalisation to `Primes29` / `Primes31`
//! is future work.

25use bytemuck::{cast_slice, cast_slice_mut};
26
27use crate::{
28    layouts::{
29        Backend, HostDataMut, HostDataRef, Module, VecZnxBackendRef, VecZnxBigBackendMut, VecZnxDft, VecZnxDftBackendMut,
30        VecZnxDftBackendRef, ZnxView, ZnxViewMut,
31    },
32    reference::ntt120::{
33        NttAdd, NttAddAssign, NttCopy, NttDFTExecute, NttFromZnx64, NttNegate, NttNegateAssign, NttSub, NttSubAssign,
34        NttSubNegateAssign, NttToZnx128, NttZero,
35        mat_vec::{BbbMeta, BbcMeta},
36        ntt::{NttTable, NttTableInv, intt_ref},
37        primes::{PrimeSet, Primes30},
38        types::Q120bScalar,
39    },
40};
41
// ──────────────────────────────────────────────────────────────────────────────
// NttModuleHandle / NttHandleProvider / NttHandleFactory traits + blanket impl
// ──────────────────────────────────────────────────────────────────────────────

// TODO(ntt120): Associate PrimeSet with NttModuleHandle (add associated type)
//               to enable Primes29/Primes31 dispatch through the public API.
48
49/// Access to the precomputed NTT/iNTT tables and lazy-accumulation metadata
50/// stored inside a `Module<B>` handle.
51///
52/// Automatically implemented for any `Module<B>` whose `B::Handle` implements
53/// [`NttHandleProvider`].  Backend crates (e.g. `poulpy-cpu-ref`) implement
54/// `NttHandleProvider` for their concrete handle type; they do *not* implement
55/// this trait directly (which would violate the orphan rule).
56///
57/// <!-- DOCUMENTED EXCEPTION: Primes30 hardcoded for spqlios compatibility.
58///   Generalisation path: add `type PrimeSet: PrimeSet` as an associated type here,
59///   then parameterise NttTable/NttTableInv/BbcMeta accordingly. -->
60pub trait NttModuleHandle {
61    /// Precomputed forward NTT twiddle table (Primes30, size `n`).
62    fn get_ntt_table(&self) -> &NttTable<Primes30>;
63    /// Precomputed inverse NTT twiddle table (Primes30, size `n`).
64    fn get_intt_table(&self) -> &NttTableInv<Primes30>;
65    /// Precomputed metadata for `q120b × q120c` lazy multiply–accumulate.
66    fn get_bbc_meta(&self) -> &BbcMeta<Primes30>;
67    /// Precomputed metadata for `q120b × q120b` lazy multiply–accumulate.
68    fn get_bbb_meta(&self) -> &BbbMeta<Primes30>;
69}
70
71/// Implemented by backend `Handle` types that store NTT/iNTT tables and BBC
72/// metadata.
73///
74/// Implement this trait for your concrete handle struct (e.g. `NTT120RefHandle`)
75/// in the backend crate.  A blanket `impl NttModuleHandle for Module<B>` is
76/// provided here in `poulpy-hal`, so no orphan-rule violation occurs.
77///
78/// # Safety
79///
80/// Implementors must ensure the returned references are valid for the lifetime
81/// of `&self` and that the tables were fully initialised before first use.
82///
83/// The blanket `impl<B> NttModuleHandle for Module<B>` assumes the handle is
84/// fully initialised before `Module::new()` returns.  This invariant is
85/// established by the module defaults (or a backend override).  There is no
86/// runtime check in release builds.
87pub unsafe trait NttHandleProvider {
88    /// Returns a reference to the forward NTT twiddle table.
89    fn get_ntt_table(&self) -> &NttTable<Primes30>;
90    /// Returns a reference to the inverse NTT twiddle table.
91    fn get_intt_table(&self) -> &NttTableInv<Primes30>;
92    /// Returns a reference to the `q120b × q120c` lazy multiply–accumulate metadata.
93    fn get_bbc_meta(&self) -> &BbcMeta<Primes30>;
94    /// Returns a reference to the `q120b × q120b` lazy multiply–accumulate metadata.
95    fn get_bbb_meta(&self) -> &BbbMeta<Primes30>;
96}
97
/// Factory for NTT120 backend handles, used by
/// [`Module::new`](crate::api::ModuleNew::new).
///
/// # Safety
///
/// `create_ntt_handle(n)` must return a handle that is fully initialised for
/// ring dimension `n`. The handle is boxed and stored inside the `Module`,
/// so dropping it through [`crate::layouts::Backend::destroy`] must be sound.
pub unsafe trait NttHandleFactory: Sized {
    /// Returns a fully initialised handle for ring dimension `n`.
    fn create_ntt_handle(n: usize) -> Self;

    /// Optional runtime-capability assertion; the default implementation does nothing.
    fn assert_ntt_runtime_support() {}
}
112
113/// Blanket impl: any `Module<B>` whose handle implements `NttHandleProvider`
114/// automatically satisfies `NttModuleHandle`.
115impl<B> NttModuleHandle for Module<B>
116where
117    B: Backend,
118    B::Handle: NttHandleProvider,
119{
120    fn get_ntt_table(&self) -> &NttTable<Primes30> {
121        // SAFETY: `ptr()` returns a valid, non-null pointer to `B::Handle`
122        // that was initialised by the module defaults and is kept alive by
123        // the `Module`.
124        unsafe { (&*self.ptr()).get_ntt_table() }
125    }
126
127    fn get_intt_table(&self) -> &NttTableInv<Primes30> {
128        unsafe { (&*self.ptr()).get_intt_table() }
129    }
130
131    fn get_bbc_meta(&self) -> &BbcMeta<Primes30> {
132        unsafe { (&*self.ptr()).get_bbc_meta() }
133    }
134
135    fn get_bbb_meta(&self) -> &BbbMeta<Primes30> {
136        unsafe { (&*self.ptr()).get_bbb_meta() }
137    }
138}
139
// ──────────────────────────────────────────────────────────────────────────────
// Helpers: view a VecZnxDft limb as `&[u64]` / `&mut [u64]`
// ──────────────────────────────────────────────────────────────────────────────
143
144/// Returns the q120b u64 slice for limb `(col, limb)` of a VecZnxDft.
145///
146/// `at(col, limb)` returns `&[Q120bScalar]` of length `n`; we cast to
147/// `&[u64]` of length `4*n`.
148#[inline(always)]
149fn limb_u64<D: crate::layouts::HostDataRef, BE: Backend<ScalarPrep = Q120bScalar>>(
150    v: &VecZnxDft<D, BE>,
151    col: usize,
152    limb: usize,
153) -> &[u64] {
154    cast_slice(v.at(col, limb))
155}
156
157#[inline(always)]
158fn limb_u64_mut<D: crate::layouts::HostDataMut, BE: Backend<ScalarPrep = Q120bScalar>>(
159    v: &mut VecZnxDft<D, BE>,
160    col: usize,
161    limb: usize,
162) -> &mut [u64] {
163    cast_slice_mut(v.at_mut(col, limb))
164}
165
// ──────────────────────────────────────────────────────────────────────────────
// Forward DFT
// ──────────────────────────────────────────────────────────────────────────────
169
170/// Forward NTT: encode `a[a_col]` into `res[res_col]`.
171///
172/// For each output limb `j`:
173/// - Input limb index `= offset + j * step` from `a[a_col]`.
174/// - Converts i64 coefficients to q120b with [`NttFromZnx64`],
175///   then applies the forward NTT in-place via [`NttDFTExecute`].
176/// - Missing input limbs (out of range) are zeroed in `res`.
177pub fn ntt120_vec_znx_dft_apply<BE>(
178    module: &impl NttModuleHandle,
179    step: usize,
180    offset: usize,
181    res: &mut VecZnxDftBackendMut<'_, BE>,
182    res_col: usize,
183    a: &VecZnxBackendRef<'_, BE>,
184    a_col: usize,
185) where
186    BE: Backend<ScalarPrep = Q120bScalar> + NttDFTExecute<NttTable<Primes30>> + NttFromZnx64 + NttZero + 'static,
187    for<'x> BE: Backend<BufRef<'x> = &'x [u8], BufMut<'x> = &'x mut [u8]>,
188{
189    let a_size = a.size();
190    let res_size = res.size();
191
192    let table = module.get_ntt_table();
193
194    let steps = a_size.div_ceil(step);
195    let min_steps = res_size.min(steps);
196
197    for j in 0..min_steps {
198        let limb = offset + j * step;
199        if limb < a_size {
200            let res_slice: &mut [u64] = limb_u64_mut(res, res_col, j);
201            BE::ntt_from_znx64(res_slice, a.at(a_col, limb));
202            BE::ntt_dft_execute(table, res_slice);
203        } else {
204            BE::ntt_zero(limb_u64_mut(res, res_col, j));
205        }
206    }
207
208    for j in min_steps..res_size {
209        BE::ntt_zero(limb_u64_mut(res, res_col, j));
210    }
211}
212
// ──────────────────────────────────────────────────────────────────────────────
// Inverse DFT
// ──────────────────────────────────────────────────────────────────────────────
216
/// Size in bytes of the scratch buffer required by [`ntt120_vec_znx_idft_apply`].
///
/// The buffer holds a single q120b limb of ring dimension `n`: four `u64`
/// residues per coefficient, i.e. `4 * n` words.
pub fn ntt120_vec_znx_idft_apply_tmp_bytes(n: usize) -> usize {
    const WORDS_PER_COEFF: usize = 4;
    n * WORDS_PER_COEFF * std::mem::size_of::<u64>()
}
223
224/// Inverse NTT (non-destructive): decode `a[a_col]` into `res[res_col]`.
225///
226/// For each output limb `j`:
227/// 1. Copies `a.at(a_col, j)` into `tmp` via [`NttCopy`].
228/// 2. Applies the inverse NTT to `tmp` in place via [`NttDFTExecute`].
229/// 3. CRT-reconstructs the `i128` coefficients via [`NttToZnx128`].
230///
231/// `tmp` must hold at least `4 * n` `u64` values.
232pub fn ntt120_vec_znx_idft_apply<BE>(
233    module: &impl NttModuleHandle,
234    res: &mut VecZnxBigBackendMut<'_, BE>,
235    res_col: usize,
236    a: &VecZnxDftBackendRef<'_, BE>,
237    a_col: usize,
238    tmp: &mut [u64],
239) where
240    BE: Backend<ScalarPrep = Q120bScalar, ScalarBig = i128> + NttDFTExecute<NttTableInv<Primes30>> + NttToZnx128 + NttCopy,
241    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
242    for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
243{
244    let n = res.n();
245    let res_size = res.size();
246    let min_size = res_size.min(a.size());
247
248    let table = module.get_intt_table();
249
250    for j in 0..min_size {
251        let a_slice: &[u64] = limb_u64(a, a_col, j);
252        let tmp_n: &mut [u64] = &mut tmp[..4 * n];
253        BE::ntt_copy(tmp_n, a_slice);
254        BE::ntt_dft_execute(table, tmp_n);
255        BE::ntt_to_znx128(res.at_mut(res_col, j), n, tmp_n);
256    }
257
258    for j in min_size..res_size {
259        res.at_mut(res_col, j).fill(0i128);
260    }
261}
262
263/// Inverse NTT (destructive): decode `a[a_col]` into `res[res_col]`.
264///
265/// Like [`ntt120_vec_znx_idft_apply`] but applies the inverse NTT
266/// **in place** to `a`, modifying it.  Requires no scratch space.
267pub fn ntt120_vec_znx_idft_apply_tmpa<BE>(
268    module: &impl NttModuleHandle,
269    res: &mut VecZnxBigBackendMut<'_, BE>,
270    res_col: usize,
271    a: &mut VecZnxDftBackendMut<'_, BE>,
272    a_col: usize,
273) where
274    BE: Backend<ScalarPrep = Q120bScalar, ScalarBig = i128> + NttDFTExecute<NttTableInv<Primes30>> + NttToZnx128,
275    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
276{
277    let n = res.n();
278    let res_size = res.size();
279    let min_size = res_size.min(a.size());
280
281    let table = module.get_intt_table();
282
283    for j in 0..min_size {
284        BE::ntt_dft_execute(table, limb_u64_mut(a, a_col, j));
285        let a_slice: &[u64] = limb_u64(a, a_col, j);
286        BE::ntt_to_znx128(res.at_mut(res_col, j), n, a_slice);
287    }
288
289    for j in min_size..res_size {
290        res.at_mut(res_col, j).fill(0i128);
291    }
292}
293
// Dormant internal helpers for the removed consume path
// (`ntt120_vec_znx_idft_apply_consume`, `barrett_u61`, `reduce_q120b_crt`,
// `compact_all_blocks_scalar`). They are intentionally retained because the in-place
// q120b -> big compaction logic may still be useful as a future optimization, even
// though the current public API applies IDFT into a separately allocated VecZnxBig.
298#[allow(dead_code)]
299pub fn ntt120_vec_znx_idft_apply_consume<'a, BE>(
300    module: &impl NttModuleHandle,
301    mut a: VecZnxDftBackendMut<'a, BE>,
302) -> VecZnxBigBackendMut<'a, BE>
303where
304    BE: Backend<ScalarPrep = Q120bScalar, ScalarBig = i128>,
305    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
306{
307    let table = module.get_intt_table();
308
309    let (n, n_blocks, u64_ptr) = {
310        let n = a.n();
311        let n_blocks = a.cols() * a.size();
312        let ptr: *mut u64 = {
313            let s = a.raw_mut();
314            cast_slice_mut::<_, u64>(s).as_mut_ptr()
315        };
316        (n, n_blocks, ptr)
317    };
318
319    unsafe { compact_all_blocks_scalar(n, n_blocks, u64_ptr, table) };
320
321    a.into_big()
322}
323
/// Barrett reduction of `x` modulo `q` using a 61-bit precomputed factor.
///
/// `mu` is expected to be `floor(2^61 / q)`, which is how
/// `compact_all_blocks_scalar` computes it. The quotient estimate never
/// exceeds `floor(x / q)`, and at most two conditional subtractions of `q`
/// are applied afterwards. With that `mu`, the result equals `x % q` for
/// every `x < 2^61`.
#[allow(dead_code)]
#[inline(always)]
fn barrett_u61(x: u64, q: u64, mu: u64) -> u64 {
    let quot_estimate = ((u128::from(x) * u128::from(mu)) >> 61) as u64;
    let mut rem = x - quot_estimate * q;
    for _ in 0..2 {
        if rem >= q {
            rem -= q;
        }
    }
    rem
}
332
333#[allow(dead_code)]
334#[inline(always)]
335fn reduce_q120b_crt(x: u64, q: u64, mu: u64, pow32_crt: u64, pow16_crt: u64, crt: u64) -> u64 {
336    let x_hi = x >> 32;
337    let x_hi_r = if x_hi >= q { x_hi - q } else { x_hi };
338    let x_lo = x & 0xFFFF_FFFF;
339    let x_lo_hi = x_lo >> 16;
340    let x_lo_lo = x_lo & 0xFFFF;
341    let tmp = x_hi_r
342        .wrapping_mul(pow32_crt)
343        .wrapping_add(x_lo_hi.wrapping_mul(pow16_crt))
344        .wrapping_add(x_lo_lo.wrapping_mul(crt));
345    barrett_u61(tmp, q, mu)
346}
347
348#[allow(dead_code)]
349unsafe fn compact_all_blocks_scalar(n: usize, n_blocks: usize, u64_ptr: *mut u64, table: &NttTableInv<Primes30>) {
350    let q_u64: [u64; 4] = Primes30::Q.map(|qi| qi as u64);
351    let mu: [u64; 4] = q_u64.map(|qi| (1u64 << 61) / qi);
352    let crt: [u64; 4] = Primes30::CRT_CST.map(|c| c as u64);
353
354    let pow32_crt: [u64; 4] = std::array::from_fn(|k| {
355        let pow32 = ((1u128 << 32) % q_u64[k] as u128) as u64;
356        barrett_u61(pow32 * crt[k], q_u64[k], mu[k])
357    });
358    let pow16_crt: [u64; 4] = std::array::from_fn(|k| barrett_u61((1u64 << 16) * crt[k], q_u64[k], mu[k]));
359
360    let q: [u128; 4] = q_u64.map(|qi| qi as u128);
361    let total_q: u128 = q[0] * q[1] * q[2] * q[3];
362    let qm: [u128; 4] = [q[1] * q[2] * q[3], q[0] * q[2] * q[3], q[0] * q[1] * q[3], q[0] * q[1] * q[2]];
363    let half_q: u128 = total_q.div_ceil(2);
364    let total_q_mult: [u128; 4] = [0, total_q, total_q * 2, total_q * 3];
365
366    for k in 0..n_blocks {
367        let src_start = 4 * n * k;
368        let dst_start = 2 * n * k;
369
370        {
371            let blk: &mut [u64] = unsafe { std::slice::from_raw_parts_mut(u64_ptr.add(src_start), 4 * n) };
372            intt_ref::<Primes30>(table, blk);
373        }
374
375        for c in 0..n {
376            let (x0, x1, x2, x3) = unsafe {
377                (
378                    *u64_ptr.add(src_start + 4 * c),
379                    *u64_ptr.add(src_start + 4 * c + 1),
380                    *u64_ptr.add(src_start + 4 * c + 2),
381                    *u64_ptr.add(src_start + 4 * c + 3),
382                )
383            };
384
385            let t0 = reduce_q120b_crt(x0, q_u64[0], mu[0], pow32_crt[0], pow16_crt[0], crt[0]);
386            let t1 = reduce_q120b_crt(x1, q_u64[1], mu[1], pow32_crt[1], pow16_crt[1], crt[1]);
387            let t2 = reduce_q120b_crt(x2, q_u64[2], mu[2], pow32_crt[2], pow16_crt[2], crt[2]);
388            let t3 = reduce_q120b_crt(x3, q_u64[3], mu[3], pow32_crt[3], pow16_crt[3], crt[3]);
389
390            let mut v: u128 = t0 as u128 * qm[0] + t1 as u128 * qm[1] + t2 as u128 * qm[2] + t3 as u128 * qm[3];
391
392            let q_approx = (v >> 120) as usize;
393            v -= total_q_mult[q_approx];
394            if v >= total_q {
395                v -= total_q;
396            }
397
398            let val: i128 = if v >= half_q { v as i128 - total_q as i128 } else { v as i128 };
399
400            unsafe { (u64_ptr.add(dst_start + 2 * c) as *mut i128).write_unaligned(val) };
401        }
402    }
403}
404
// ──────────────────────────────────────────────────────────────────────────────
// DFT-domain arithmetic
// ──────────────────────────────────────────────────────────────────────────────
408
409/// DFT-domain add: `res[res_col] = a[a_col] + b[b_col]`.
410///
411/// Uses lazy q120b addition; out-of-range limbs are copied or zeroed.
412pub fn ntt120_vec_znx_dft_add_into<BE>(
413    res: &mut VecZnxDftBackendMut<'_, BE>,
414    res_col: usize,
415    a: &VecZnxDftBackendRef<'_, BE>,
416    a_col: usize,
417    b: &VecZnxDftBackendRef<'_, BE>,
418    b_col: usize,
419) where
420    BE: Backend<ScalarPrep = Q120bScalar> + NttAdd + NttCopy + NttZero,
421    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
422    for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
423{
424    let res_size = res.size();
425    let a_size = a.size();
426    let b_size = b.size();
427
428    if a_size <= b_size {
429        let sum_size = a_size.min(res_size);
430        let cpy_size = b_size.min(res_size);
431        for j in 0..sum_size {
432            BE::ntt_add(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j), limb_u64(b, b_col, j));
433        }
434        for j in sum_size..cpy_size {
435            BE::ntt_copy(limb_u64_mut(res, res_col, j), limb_u64(b, b_col, j));
436        }
437        for j in cpy_size..res_size {
438            BE::ntt_zero(limb_u64_mut(res, res_col, j));
439        }
440    } else {
441        let sum_size = b_size.min(res_size);
442        let cpy_size = a_size.min(res_size);
443        for j in 0..sum_size {
444            BE::ntt_add(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j), limb_u64(b, b_col, j));
445        }
446        for j in sum_size..cpy_size {
447            BE::ntt_copy(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j));
448        }
449        for j in cpy_size..res_size {
450            BE::ntt_zero(limb_u64_mut(res, res_col, j));
451        }
452    }
453}
454
455/// DFT-domain in-place add: `res[res_col] += a[a_col]`.
456pub fn ntt120_vec_znx_dft_add_assign<BE>(
457    res: &mut VecZnxDftBackendMut<'_, BE>,
458    res_col: usize,
459    a: &VecZnxDftBackendRef<'_, BE>,
460    a_col: usize,
461) where
462    BE: Backend<ScalarPrep = Q120bScalar> + NttAddAssign,
463    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
464    for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
465{
466    let sum_size = res.size().min(a.size());
467    for j in 0..sum_size {
468        BE::ntt_add_assign(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j));
469    }
470}
471
472/// DFT-domain scaled in-place add: `res[res_col] += a[a_col] >> (a_scale * base2k)`.
473///
474/// `a_scale > 0` shifts `a` down by `a_scale` limbs (drops low limbs);
475/// `a_scale < 0` shifts `a` up by `|a_scale|` limbs (adds into higher limbs).
476pub fn ntt120_vec_znx_dft_add_scaled_assign<BE>(
477    res: &mut VecZnxDftBackendMut<'_, BE>,
478    res_col: usize,
479    a: &VecZnxDftBackendRef<'_, BE>,
480    a_col: usize,
481    a_scale: i64,
482) where
483    BE: Backend<ScalarPrep = Q120bScalar> + NttAddAssign,
484    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
485    for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
486{
487    let res_size = res.size();
488    let a_size = a.size();
489
490    if a_scale > 0 {
491        let shift = (a_scale as usize).min(a_size);
492        let sum_size = a_size.min(res_size).saturating_sub(shift);
493        for j in 0..sum_size {
494            BE::ntt_add_assign(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j + shift));
495        }
496    } else if a_scale < 0 {
497        let shift = (a_scale.unsigned_abs() as usize).min(res_size);
498        let sum_size = a_size.min(res_size.saturating_sub(shift));
499        for j in 0..sum_size {
500            BE::ntt_add_assign(limb_u64_mut(res, res_col, j + shift), limb_u64(a, a_col, j));
501        }
502    } else {
503        let sum_size = a_size.min(res_size);
504        for j in 0..sum_size {
505            BE::ntt_add_assign(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j));
506        }
507    }
508}
509
510/// DFT-domain sub: `res[res_col] = a[a_col] - b[b_col]`.
511pub fn ntt120_vec_znx_dft_sub<BE>(
512    res: &mut VecZnxDftBackendMut<'_, BE>,
513    res_col: usize,
514    a: &VecZnxDftBackendRef<'_, BE>,
515    a_col: usize,
516    b: &VecZnxDftBackendRef<'_, BE>,
517    b_col: usize,
518) where
519    BE: Backend<ScalarPrep = Q120bScalar> + NttSub + NttNegate + NttCopy + NttZero,
520    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
521    for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
522{
523    let res_size = res.size();
524    let a_size = a.size();
525    let b_size = b.size();
526
527    if a_size <= b_size {
528        let sum_size = a_size.min(res_size);
529        let cpy_size = b_size.min(res_size);
530        for j in 0..sum_size {
531            BE::ntt_sub(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j), limb_u64(b, b_col, j));
532        }
533        for j in sum_size..cpy_size {
534            BE::ntt_negate(limb_u64_mut(res, res_col, j), limb_u64(b, b_col, j));
535        }
536        for j in cpy_size..res_size {
537            BE::ntt_zero(limb_u64_mut(res, res_col, j));
538        }
539    } else {
540        let sum_size = b_size.min(res_size);
541        let cpy_size = a_size.min(res_size);
542        for j in 0..sum_size {
543            BE::ntt_sub(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j), limb_u64(b, b_col, j));
544        }
545        for j in sum_size..cpy_size {
546            BE::ntt_copy(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j));
547        }
548        for j in cpy_size..res_size {
549            BE::ntt_zero(limb_u64_mut(res, res_col, j));
550        }
551    }
552}
553
554/// DFT-domain in-place sub: `res[res_col] -= a[a_col]`.
555pub fn ntt120_vec_znx_dft_sub_assign<BE>(
556    res: &mut VecZnxDftBackendMut<'_, BE>,
557    res_col: usize,
558    a: &VecZnxDftBackendRef<'_, BE>,
559    a_col: usize,
560) where
561    BE: Backend<ScalarPrep = Q120bScalar> + NttSubAssign,
562    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
563    for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
564{
565    let sum_size = res.size().min(a.size());
566    for j in 0..sum_size {
567        BE::ntt_sub_assign(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j));
568    }
569}
570
571/// DFT-domain in-place swap-sub: `res[res_col] = a[a_col] - res[res_col]`.
572///
573/// Extra `res` limbs beyond `a.size()` are negated.
574pub fn ntt120_vec_znx_dft_sub_negate_assign<BE>(
575    res: &mut VecZnxDftBackendMut<'_, BE>,
576    res_col: usize,
577    a: &VecZnxDftBackendRef<'_, BE>,
578    a_col: usize,
579) where
580    BE: Backend<ScalarPrep = Q120bScalar> + NttSubNegateAssign + NttNegateAssign,
581    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
582    for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
583{
584    let res_size = res.size();
585    let sum_size = res_size.min(a.size());
586    for j in 0..sum_size {
587        BE::ntt_sub_negate_assign(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, j));
588    }
589    for j in sum_size..res_size {
590        BE::ntt_negate_assign(limb_u64_mut(res, res_col, j));
591    }
592}
593
594/// DFT-domain copy with stride: `res[res_col][j] = a[a_col][offset + j*step]`.
595///
596/// Mirrors `vec_znx_dft_copy` from the FFT64 backend.
597pub fn ntt120_vec_znx_dft_copy<BE>(
598    step: usize,
599    offset: usize,
600    res: &mut VecZnxDftBackendMut<'_, BE>,
601    res_col: usize,
602    a: &VecZnxDftBackendRef<'_, BE>,
603    a_col: usize,
604) where
605    BE: Backend<ScalarPrep = Q120bScalar> + NttCopy + NttZero,
606    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
607    for<'x> <BE as Backend>::BufRef<'x>: HostDataRef,
608{
609    #[cfg(debug_assertions)]
610    {
611        assert_eq!(res.n(), a.n())
612    }
613
614    let steps: usize = a.size().div_ceil(step);
615    let min_steps: usize = res.size().min(steps);
616
617    for j in 0..min_steps {
618        let limb = offset + j * step;
619        if limb < a.size() {
620            BE::ntt_copy(limb_u64_mut(res, res_col, j), limb_u64(a, a_col, limb));
621        } else {
622            BE::ntt_zero(limb_u64_mut(res, res_col, j));
623        }
624    }
625    for j in min_steps..res.size() {
626        BE::ntt_zero(limb_u64_mut(res, res_col, j));
627    }
628}
629
630/// Zero all limbs of `res[res_col]`.
631pub fn ntt120_vec_znx_dft_zero<BE>(res: &mut VecZnxDftBackendMut<'_, BE>, res_col: usize)
632where
633    BE: Backend<ScalarPrep = Q120bScalar> + NttZero,
634    for<'x> <BE as Backend>::BufMut<'x>: HostDataMut,
635{
636    for j in 0..res.size() {
637        BE::ntt_zero(limb_u64_mut(res, res_col, j));
638    }
639}