// feanor_math/algorithms/convolution/ntt.rs

1use std::alloc::{Allocator, Global};
2
3use super::ConvolutionAlgorithm;
4use crate::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
5use crate::cow::*;
6use crate::homomorphism::*;
7use crate::integer::*;
8use crate::lazy::LazyVec;
9use crate::primitive_int::StaticRing;
10use crate::ring::*;
11use crate::rings::zn::*;
12use crate::seq::VectorView;
13
14/// Computes the convolution over a finite field that has suitable roots of unity
15/// using a power-of-two length FFT (sometimes called Number-Theoretic Transform,
16/// NTT in this context).
17#[stability::unstable(feature = "enable")]
18pub struct NTTConvolution<R_main, R_twiddle, H, A = Global>
19where
20    R_main: ?Sized + ZnRing,
21    R_twiddle: ?Sized + ZnRing,
22    H: Homomorphism<R_twiddle, R_main> + Clone,
23    A: Allocator + Clone,
24{
25    hom: H,
26    fft_algos: LazyVec<CooleyTuckeyFFT<R_main, R_twiddle, H>>,
27    allocator: A,
28}
29
30/// A prepared convolution operand for a [`NTTConvolution`].
31#[stability::unstable(feature = "enable")]
32pub struct PreparedConvolutionOperand<R, A = Global>
33where
34    R: ?Sized + ZnRing,
35    A: Allocator + Clone,
36{
37    significant_entries: usize,
38    ntt_data: Vec<R::Element, A>,
39}
40
41impl<R> NTTConvolution<R::Type, R::Type, Identity<R>>
42where
43    R: RingStore + Clone,
44    R::Type: ZnRing,
45{
46    /// Creates a new [`NTTConvolution`].
47    ///
48    /// Note that this convolution will be able to compute convolutions whose output is
49    /// of length `<= n`, where `n` is the largest power of two such that the given ring
50    /// has a primitive `n`-th root of unity.
51    #[stability::unstable(feature = "enable")]
52    pub fn new(ring: R) -> Self { Self::new_with_hom(ring.into_identity(), Global) }
53}
54
55impl<R_main, R_twiddle, H, A> NTTConvolution<R_main, R_twiddle, H, A>
56where
57    R_main: ?Sized + ZnRing,
58    R_twiddle: ?Sized + ZnRing,
59    H: Homomorphism<R_twiddle, R_main> + Clone,
60    A: Allocator + Clone,
61{
62    /// Creates a new [`NTTConvolution`].
63    ///
64    /// Note that this convolution will be able to compute convolutions whose output is
65    /// of length `<= n`, where `n` is the largest power of two such that the domain of
66    /// the given homomorphism has a primitive `n`-th root of unity.
67    ///
68    /// Internally, twiddle factors for the underlying NTT will be stored as elements of
69    /// the domain of the given homomorphism, while the convolutions are performed over the
70    /// codomain. This can be used for more efficient NTTs, see e.g. [`zn_64::ZnFastmul`].
71    #[stability::unstable(feature = "enable")]
72    pub fn new_with_hom(hom: H, allocator: A) -> Self {
73        Self {
74            fft_algos: LazyVec::new(),
75            hom,
76            allocator,
77        }
78    }
79
80    /// Returns the ring over which this object can compute convolutions.
81    #[stability::unstable(feature = "enable")]
82    pub fn ring(&self) -> RingRef<'_, R_main> { RingRef::new(self.hom.codomain().get_ring()) }
83
84    fn get_ntt_table(&self, log2_n: usize) -> &CooleyTuckeyFFT<R_main, R_twiddle, H> {
85        self.fft_algos.get_or_init(log2_n, || {
86            CooleyTuckeyFFT::for_zn_with_hom(self.hom.clone(), log2_n)
87                .expect("NTTConvolution was instantiated with parameters that don't support this length")
88        })
89    }
90
91    fn get_ntt_data<'a, V>(
92        &self,
93        data: V,
94        data_prep: Option<&'a PreparedConvolutionOperand<R_main, A>>,
95        significant_entries: usize,
96    ) -> MyCow<'a, Vec<R_main::Element, A>>
97    where
98        V: VectorView<R_main::Element>,
99    {
100        assert!(data.len() <= significant_entries);
101        let log2_len = StaticRing::<i64>::RING
102            .abs_log2_ceil(&significant_entries.try_into().unwrap())
103            .unwrap();
104
105        let compute_result = || {
106            let mut result = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
107            result.extend(data.as_iter().map(|x| self.ring().clone_el(x)));
108            result.resize_with(1 << log2_len, || self.ring().zero());
109            self.get_ntt_table(log2_len)
110                .unordered_truncated_fft(&mut result, significant_entries);
111            return result;
112        };
113
114        return if let Some(data_prep) = data_prep {
115            assert!(data_prep.significant_entries >= significant_entries);
116            MyCow::Borrowed(&data_prep.ntt_data)
117        } else {
118            MyCow::Owned(compute_result())
119        };
120    }
121
122    fn compute_convolution_ntt<'a, V1, V2>(
123        &self,
124        lhs: V1,
125        mut lhs_prep: Option<&'a PreparedConvolutionOperand<R_main, A>>,
126        rhs: V2,
127        mut rhs_prep: Option<&'a PreparedConvolutionOperand<R_main, A>>,
128        len: usize,
129    ) -> MyCow<'a, Vec<R_main::Element, A>>
130    where
131        V1: VectorView<R_main::Element>,
132        V2: VectorView<R_main::Element>,
133    {
134        if lhs.len() == 0 || rhs.len() == 0 {
135            return MyCow::Owned(Vec::new_in(self.allocator.clone()));
136        }
137        let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
138
139        if lhs_prep.is_some()
140            && (lhs_prep.unwrap().significant_entries < len || lhs_prep.unwrap().ntt_data.len() != 1 << log2_len)
141        {
142            lhs_prep = None;
143        }
144        if rhs_prep.is_some()
145            && (rhs_prep.unwrap().significant_entries < len || rhs_prep.unwrap().ntt_data.len() != 1 << log2_len)
146        {
147            rhs_prep = None;
148        }
149
150        let mut lhs_ntt = self.get_ntt_data(lhs, lhs_prep, len);
151        let mut rhs_ntt = self.get_ntt_data(rhs, rhs_prep, len);
152        if rhs_ntt.is_owned() {
153            std::mem::swap(&mut lhs_ntt, &mut rhs_ntt);
154        }
155        let lhs_ntt_data = lhs_ntt.to_mut_with(|data| {
156            let mut copied_data = Vec::with_capacity_in(data.len(), self.allocator.clone());
157            copied_data.extend(data.iter().map(|x| self.ring().clone_el(x)));
158            copied_data
159        });
160
161        for i in 0..len {
162            self.ring().mul_assign_ref(&mut lhs_ntt_data[i], &rhs_ntt[i]);
163        }
164        return lhs_ntt;
165    }
166
167    fn prepare_convolution_impl<V>(&self, data: V, len_hint: Option<usize>) -> PreparedConvolutionOperand<R_main, A>
168    where
169        V: VectorView<R_main::Element>,
170    {
171        let significant_entries = if let Some(out_len) = len_hint {
172            assert!(data.len() <= out_len);
173            out_len
174        } else {
175            2 * data.len()
176        };
177        let log2_len = StaticRing::<i64>::RING
178            .abs_log2_ceil(&significant_entries.try_into().unwrap())
179            .unwrap();
180
181        let mut result = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
182        result.extend(data.as_iter().map(|x| self.ring().clone_el(x)));
183        result.resize_with(1 << log2_len, || self.ring().zero());
184        self.get_ntt_table(log2_len)
185            .unordered_truncated_fft(&mut result, significant_entries);
186
187        return PreparedConvolutionOperand {
188            ntt_data: result,
189            significant_entries,
190        };
191    }
192
193    fn compute_convolution_impl<V1, V2>(
194        &self,
195        lhs: V1,
196        lhs_prep: Option<&PreparedConvolutionOperand<R_main, A>>,
197        rhs: V2,
198        rhs_prep: Option<&PreparedConvolutionOperand<R_main, A>>,
199        dst: &mut [R_main::Element],
200    ) where
201        V1: VectorView<R_main::Element>,
202        V2: VectorView<R_main::Element>,
203    {
204        assert!(lhs.len() + rhs.len() - 1 <= dst.len());
205        let len = lhs.len() + rhs.len() - 1;
206        let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
207
208        let mut lhs_ntt = self.compute_convolution_ntt(lhs, lhs_prep, rhs, rhs_prep, len);
209        let lhs_ntt = lhs_ntt.to_mut_with(|_| unreachable!());
210        self.get_ntt_table(log2_len)
211            .unordered_truncated_fft_inv(&mut lhs_ntt[..], len);
212        for (i, x) in lhs_ntt.drain(..).enumerate().take(len) {
213            self.ring().add_assign(&mut dst[i], x);
214        }
215    }
216
217    fn compute_convolution_sum_impl<'a, S, I, V1, V2>(&self, values: I, dst: &mut [R_main::Element], ring: S)
218    where
219        S: RingStore<Type = R_main> + Copy,
220        I: ExactSizeIterator<
221            Item = (
222                V1,
223                Option<&'a PreparedConvolutionOperand<R_main, A>>,
224                V2,
225                Option<&'a PreparedConvolutionOperand<R_main, A>>,
226            ),
227        >,
228        V1: VectorView<R_main::Element>,
229        V2: VectorView<R_main::Element>,
230        Self: 'a,
231        R_main: 'a,
232    {
233        let len = dst.len();
234        let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
235
236        let mut buffer = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
237        buffer.resize_with(1 << log2_len, || ring.zero());
238
239        for (lhs, lhs_prep, rhs, rhs_prep) in values {
240            assert!(lhs.len() + rhs.len() - 1 <= len);
241
242            let res_ntt = self.compute_convolution_ntt(lhs, lhs_prep, rhs, rhs_prep, len);
243            for i in 0..len {
244                self.ring().add_assign_ref(&mut buffer[i], &res_ntt[i]);
245            }
246        }
247        self.get_ntt_table(log2_len)
248            .unordered_truncated_fft_inv(&mut buffer, len);
249        for (i, x) in buffer.drain(..).enumerate().take(len) {
250            self.ring().add_assign(&mut dst[i], x);
251        }
252    }
253}
254
255impl<R_main, R_twiddle, H, A> ConvolutionAlgorithm<R_main> for NTTConvolution<R_main, R_twiddle, H, A>
256where
257    R_main: ?Sized + ZnRing,
258    R_twiddle: ?Sized + ZnRing,
259    H: Homomorphism<R_twiddle, R_main> + Clone,
260    A: Allocator + Clone,
261{
262    type PreparedConvolutionOperand = PreparedConvolutionOperand<R_main, A>;
263
264    fn supports_ring<S: RingStore<Type = R_main> + Copy>(&self, ring: S) -> bool {
265        ring.get_ring() == self.ring().get_ring()
266    }
267
268    fn compute_convolution<
269        S: RingStore<Type = R_main> + Copy,
270        V1: VectorView<<R_main as RingBase>::Element>,
271        V2: VectorView<<R_main as RingBase>::Element>,
272    >(
273        &self,
274        lhs: V1,
275        rhs: V2,
276        dst: &mut [R_main::Element],
277        ring: S,
278    ) {
279        assert!(self.supports_ring(ring));
280        self.compute_convolution_impl(lhs, None, rhs, None, dst)
281    }
282
283    fn prepare_convolution_operand<S, V>(
284        &self,
285        val: V,
286        length_hint: Option<usize>,
287        ring: S,
288    ) -> Self::PreparedConvolutionOperand
289    where
290        S: RingStore<Type = R_main> + Copy,
291        V: VectorView<R_main::Element>,
292    {
293        assert!(self.supports_ring(ring));
294        self.prepare_convolution_impl(val, length_hint)
295    }
296
297    fn compute_convolution_prepared<S, V1, V2>(
298        &self,
299        lhs: V1,
300        lhs_prep: Option<&Self::PreparedConvolutionOperand>,
301        rhs: V2,
302        rhs_prep: Option<&Self::PreparedConvolutionOperand>,
303        dst: &mut [R_main::Element],
304        ring: S,
305    ) where
306        S: RingStore<Type = R_main> + Copy,
307        V1: VectorView<R_main::Element>,
308        V2: VectorView<R_main::Element>,
309    {
310        assert!(self.supports_ring(ring));
311        self.compute_convolution_impl(lhs, lhs_prep, rhs, rhs_prep, dst)
312    }
313
314    fn compute_convolution_sum<'a, S, I, V1, V2>(&self, values: I, dst: &mut [R_main::Element], ring: S)
315    where
316        S: RingStore<Type = R_main> + Copy,
317        I: ExactSizeIterator<
318            Item = (
319                V1,
320                Option<&'a Self::PreparedConvolutionOperand>,
321                V2,
322                Option<&'a Self::PreparedConvolutionOperand>,
323            ),
324        >,
325        V1: VectorView<R_main::Element>,
326        V2: VectorView<R_main::Element>,
327        Self: 'a,
328        R_main: 'a,
329    {
330        assert!(self.supports_ring(ring));
331        self.compute_convolution_sum_impl(values, dst, ring)
332    }
333}
334
335#[cfg(test)]
336use test::Bencher;
337
338#[cfg(test)]
339use crate::algorithms::convolution::STANDARD_CONVOLUTION;
340#[cfg(test)]
341use crate::rings::zn::zn_64::{Zn, ZnBase, ZnEl};
342
/// Runs the generic convolution test suite against an [`NTTConvolution`] over `Z/65537Z`.
/// Since `65537 = 2^16 + 1` is prime, this ring has primitive `2^k`-th roots of unity for
/// all `k <= 16`.
#[test]
fn test_convolution() {
    let zn = zn_64::Zn::new(65537);
    let ntt_convolution = NTTConvolution::new(zn);
    super::generic_tests::test_convolution(&ntt_convolution, &zn, zn.one());
}
349
/// Benchmark driver. It computes a sum of 256 convolutions of length-256 vectors through `f`
/// and checks the result against a reference computed with [`STANDARD_CONVOLUTION`].
///
/// In round `r` (starting at 1), the `j`-th summand is the convolution of `(r * j * k)_k`
/// with itself. The sum is therefore `r^2 * (sum_j j^2)` times the self-convolution of
/// `(k)_k`, where `sum_{j < 256} j^2 = 128 * 511 * 85`.
#[cfg(test)]
fn run_benchmark<F>(ring: Zn, bencher: &mut Bencher, mut f: F)
where
    F: for<'a> FnMut(
        &mut dyn ExactSizeIterator<
            Item = (
                Vec<ZnEl>,
                Option<&'a PreparedConvolutionOperand<ZnBase>>,
                Vec<ZnEl>,
                Option<&'a PreparedConvolutionOperand<ZnBase>>,
            ),
        >,
        &mut [ZnEl],
        Zn,
    ),
{
    let mut reference = (0..512).map(|_| ring.zero()).collect::<Vec<_>>();
    let base = (0..256).map(|k| ring.int_hom().map(k)).collect::<Vec<_>>();
    STANDARD_CONVOLUTION.compute_convolution(&base, &base, &mut reference, ring);

    let from_int = ring.can_hom(&StaticRing::<i64>::RING).unwrap();
    let mut output = Vec::with_capacity(511);
    let mut round: i64 = 1;
    bencher.iter(|| {
        output.clear();
        output.resize_with(511, || ring.zero());
        f(
            &mut (0..256).map(|j| {
                let lhs = (0..256).map(|k| from_int.map(round * j as i64 * k)).collect::<Vec<_>>();
                let rhs = (0..256).map(|k| from_int.map(round * j as i64 * k)).collect::<Vec<_>>();
                (lhs, None, rhs, None)
            }),
            &mut output,
            ring,
        );
        let expected_factor = from_int.map(round * round * 128 * 511 * 85);
        for (expected, actual) in reference.iter().zip(output.iter()) {
            assert_el_eq!(ring, ring.mul_ref(expected, &expected_factor), actual);
        }
        round += 1;
    });
}
395
396#[bench]
397fn bench_convolution_sum(bencher: &mut Bencher) {
398    let ring = zn_64::Zn::new(65537);
399    let convolution = NTTConvolution::new(ring);
400
401    run_benchmark(ring, bencher, |values, dst, ring| {
402        convolution.compute_convolution_sum_impl(values, dst, ring)
403    });
404}
405
406#[bench]
407fn bench_convolution_sum_default(bencher: &mut Bencher) {
408    let ring = zn_64::Zn::new(65537);
409    let convolution = NTTConvolution::new(ring);
410
411    run_benchmark(ring, bencher, |values, dst, ring| {
412        for (lhs, lhs_prep, rhs, rhs_prep) in values {
413            convolution.compute_convolution_prepared(lhs, lhs_prep, rhs, rhs_prep, dst, ring);
414        }
415    });
416}