feanor_math/algorithms/convolution/ntt.rs

use std::alloc::{Allocator, Global};

use crate::cow::*;
use crate::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
use crate::lazy::LazyVec;
use crate::homomorphism::*;
use crate::primitive_int::StaticRing;
use crate::ring::*;
use crate::rings::zn::*;
use crate::integer::*;
use crate::seq::VectorView;

use super::ConvolutionAlgorithm;

///
/// Computes the convolution over a finite field that has suitable roots of unity,
/// using a power-of-two length FFT (in this context often called a Number-Theoretic
/// Transform, or NTT).
/// 
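/// 
/// A minimal usage sketch, modeled on the test at the bottom of this file; the
/// crate-level import paths are assumptions of this sketch.
/// ```ignore
/// use feanor_math::assert_el_eq;
/// use feanor_math::ring::*;
/// use feanor_math::homomorphism::*;
/// use feanor_math::rings::zn::zn_64::Zn;
/// use feanor_math::algorithms::convolution::ConvolutionAlgorithm;
/// use feanor_math::algorithms::convolution::ntt::NTTConvolution;
/// 
/// // 65537 = 2^16 + 1 is prime, so this ring has primitive 2^16-th roots of unity
/// let ring = Zn::new(65537);
/// let convolution = NTTConvolution::new(ring);
/// let lhs = (0..4).map(|x| ring.int_hom().map(x)).collect::<Vec<_>>();
/// let rhs = (0..4).map(|x| ring.int_hom().map(x)).collect::<Vec<_>>();
/// // the result is added to `dst`, so `dst` must be zero-initialized
/// let mut dst = (0..7).map(|_| ring.zero()).collect::<Vec<_>>();
/// convolution.compute_convolution(&lhs, &rhs, &mut dst, ring);
/// assert_el_eq!(ring, ring.int_hom().map(1 * 3 + 2 * 2 + 3 * 1), &dst[4]);
/// ```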
#[stability::unstable(feature = "enable")]
pub struct NTTConvolution<R_main, R_twiddle, H, A = Global>
    where R_main: ?Sized + ZnRing,
        R_twiddle: ?Sized + ZnRing,
        H: Homomorphism<R_twiddle, R_main> + Clone,
        A: Allocator + Clone
{
    hom: H,
    fft_algos: LazyVec<CooleyTuckeyFFT<R_main, R_twiddle, H>>,
    allocator: A
}

///
/// A prepared convolution operand for a [`NTTConvolution`].
/// 
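/// 
/// Preparing an operand precomputes its (truncated) NTT, which pays off when the same
/// operand is used in multiple convolutions. A minimal sketch of its use, based on the
/// trait methods implemented below (import paths and the chosen lengths are assumptions
/// of this sketch):
/// ```ignore
/// use feanor_math::ring::*;
/// use feanor_math::homomorphism::*;
/// use feanor_math::rings::zn::zn_64::Zn;
/// use feanor_math::algorithms::convolution::ConvolutionAlgorithm;
/// use feanor_math::algorithms::convolution::ntt::NTTConvolution;
/// 
/// let ring = Zn::new(65537);
/// let convolution = NTTConvolution::new(ring);
/// let lhs = (0..8).map(|x| ring.int_hom().map(x)).collect::<Vec<_>>();
/// // compute the NTT of `lhs` once, for convolutions with output length <= 16
/// let lhs_prep = convolution.prepare_convolution_operand(&lhs, Some(16), ring);
/// let rhs = (0..8).map(|x| ring.int_hom().map(x)).collect::<Vec<_>>();
/// let mut dst = (0..15).map(|_| ring.zero()).collect::<Vec<_>>();
/// convolution.compute_convolution_prepared(&lhs, Some(&lhs_prep), &rhs, None, &mut dst, ring);
/// ```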
#[stability::unstable(feature = "enable")]
pub struct PreparedConvolutionOperand<R, A = Global>
    where R: ?Sized + ZnRing,
        A: Allocator + Clone
{
    significant_entries: usize,
    ntt_data: Vec<R::Element, A>
}

impl<R> NTTConvolution<R::Type, R::Type, Identity<R>>
    where R: RingStore + Clone,
        R::Type: ZnRing
{
    ///
    /// Creates a new [`NTTConvolution`].
    /// 
    /// Note that this object can only compute convolutions whose output has length
    /// `<= n`, where `n` is the largest power of two such that the given ring has
    /// a primitive `n`-th root of unity.
    /// 
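    /// 
    /// For example, with the prime `65537 = 2^16 + 1` also used in the tests below
    /// (a sketch, imports omitted):
    /// ```ignore
    /// // (Z/65537Z)* has order 2^16, so there are primitive 2^16-th roots of unity
    /// // and convolutions with output length up to 2^16 are supported
    /// let convolution = NTTConvolution::new(Zn::new(65537));
    /// ```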
    #[stability::unstable(feature = "enable")]
    pub fn new(ring: R) -> Self {
        Self::new_with_hom(ring.into_identity(), Global)
    }
}

impl<R_main, R_twiddle, H, A> NTTConvolution<R_main, R_twiddle, H, A>
    where R_main: ?Sized + ZnRing,
        R_twiddle: ?Sized + ZnRing,
        H: Homomorphism<R_twiddle, R_main> + Clone,
        A: Allocator + Clone
{
    ///
    /// Creates a new [`NTTConvolution`].
    /// 
    /// Note that this object can only compute convolutions whose output has length
    /// `<= n`, where `n` is the largest power of two such that the domain of the
    /// given homomorphism has a primitive `n`-th root of unity.
    /// 
    /// Internally, the twiddle factors of the underlying NTT are stored as elements
    /// of the domain of the given homomorphism, while the convolutions themselves are
    /// performed over its codomain. This can be used for more efficient NTTs, see
    /// e.g. [`zn_64::ZnFastmul`].
    /// 
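    /// 
    /// A sketch of the [`zn_64::ZnFastmul`] pattern mentioned above; the exact
    /// `ZnFastmul` constructor and homomorphism calls are assumptions of this sketch
    /// and may differ:
    /// ```ignore
    /// let ring = Zn::new(65537);
    /// // twiddle factors are stored in the faster-multiplication ring ZnFastmul,
    /// // while the convolution itself is computed over `ring`
    /// let twiddle_ring = ZnFastmul::new(ring).unwrap();
    /// let convolution = NTTConvolution::new_with_hom(ring.into_can_hom(twiddle_ring).ok().unwrap(), Global);
    /// ```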
    #[stability::unstable(feature = "enable")]
    pub fn new_with_hom(hom: H, allocator: A) -> Self {
        Self {
            fft_algos: LazyVec::new(),
            hom: hom,
            allocator: allocator
        }
    }

    ///
    /// Returns the ring over which this object can compute convolutions.
    /// 
    #[stability::unstable(feature = "enable")]
    pub fn ring(&self) -> RingRef<'_, R_main> {
        RingRef::new(self.hom.codomain().get_ring())
    }

    // NTT tables are created lazily and cached per power-of-two length
    fn get_ntt_table<'a>(&'a self, log2_n: usize) -> &'a CooleyTuckeyFFT<R_main, R_twiddle, H> {
        self.fft_algos.get_or_init(log2_n, || CooleyTuckeyFFT::for_zn_with_hom(self.hom.clone(), log2_n).expect("NTTConvolution was instantiated with parameters that don't support this length"))
    }

    // Returns the truncated NTT of `data`, either borrowed from the given prepared
    // operand (if available) or freshly computed
    fn get_ntt_data<'a, V>(
        &self,
        data: V,
        data_prep: Option<&'a PreparedConvolutionOperand<R_main, A>>,
        significant_entries: usize,
    ) -> MyCow<'a, Vec<R_main::Element, A>>
        where V: VectorView<R_main::Element>
    {
        assert!(data.len() <= significant_entries);
        let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&significant_entries.try_into().unwrap()).unwrap();

        let compute_result = || {
            let mut result = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
            result.extend(data.as_iter().map(|x| self.ring().clone_el(x)));
            result.resize_with(1 << log2_len, || self.ring().zero());
            self.get_ntt_table(log2_len).unordered_truncated_fft(&mut result, significant_entries);
            return result;
        };

        return if let Some(data_prep) = data_prep {
            assert!(data_prep.significant_entries >= significant_entries);
            MyCow::Borrowed(&data_prep.ntt_data)
        } else {
            MyCow::Owned(compute_result())
        }
    }

    // Computes the NTT of the convolution, i.e. the pointwise product of the (truncated)
    // NTTs of both operands; the returned value is always a `MyCow::Owned`
    fn compute_convolution_ntt<'a, V1, V2>(&self,
        lhs: V1,
        mut lhs_prep: Option<&'a PreparedConvolutionOperand<R_main, A>>,
        rhs: V2,
        mut rhs_prep: Option<&'a PreparedConvolutionOperand<R_main, A>>,
        len: usize
    ) -> MyCow<'a, Vec<R_main::Element, A>>
        where V1: VectorView<R_main::Element>,
            V2: VectorView<R_main::Element>
    {
        if lhs.len() == 0 || rhs.len() == 0 {
            return MyCow::Owned(Vec::new_in(self.allocator.clone()));
        }
        let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();

        // fall back to recomputing the NTT if a prepared operand does not match the required length
        if lhs_prep.is_some() && (lhs_prep.unwrap().significant_entries < len || lhs_prep.unwrap().ntt_data.len() != 1 << log2_len) {
            lhs_prep = None;
        }
        if rhs_prep.is_some() && (rhs_prep.unwrap().significant_entries < len || rhs_prep.unwrap().ntt_data.len() != 1 << log2_len) {
            rhs_prep = None;
        }

        let mut lhs_ntt = self.get_ntt_data(lhs, lhs_prep, len);
        let mut rhs_ntt = self.get_ntt_data(rhs, rhs_prep, len);
        // if either side is already owned, multiply into that buffer in place ...
        if rhs_ntt.is_owned() {
            std::mem::swap(&mut lhs_ntt, &mut rhs_ntt);
        }
        // ... otherwise (both sides are borrowed prepared operands), clone the left one
        let lhs_ntt_data = lhs_ntt.to_mut_with(|data| {
            let mut copied_data = Vec::with_capacity_in(data.len(), self.allocator.clone());
            copied_data.extend(data.iter().map(|x| self.ring().clone_el(x)));
            copied_data
        });

        for i in 0..len {
            self.ring().mul_assign_ref(&mut lhs_ntt_data[i], &rhs_ntt[i]);
        }
        return lhs_ntt;
    }

    fn prepare_convolution_impl<V>(
        &self,
        data: V,
        len_hint: Option<usize>
    ) -> PreparedConvolutionOperand<R_main, A>
        where V: VectorView<R_main::Element>
    {
        // if no output length hint is given, assume the output is at most twice as long as the input
        let significant_entries = if let Some(out_len) = len_hint {
            assert!(data.len() <= out_len);
            out_len
        } else {
            2 * data.len()
        };
        let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&significant_entries.try_into().unwrap()).unwrap();

        let mut result = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
        result.extend(data.as_iter().map(|x| self.ring().clone_el(x)));
        result.resize_with(1 << log2_len, || self.ring().zero());
        self.get_ntt_table(log2_len).unordered_truncated_fft(&mut result, significant_entries);

        return PreparedConvolutionOperand {
            ntt_data: result,
            significant_entries: significant_entries
        };
    }

    fn compute_convolution_impl<V1, V2>(
        &self,
        lhs: V1,
        lhs_prep: Option<&PreparedConvolutionOperand<R_main, A>>,
        rhs: V2,
        rhs_prep: Option<&PreparedConvolutionOperand<R_main, A>>,
        dst: &mut [R_main::Element]
    )
        where V1: VectorView<R_main::Element>,
            V2: VectorView<R_main::Element>
    {
        assert!(lhs.len() + rhs.len() - 1 <= dst.len());
        let len = lhs.len() + rhs.len() - 1;
        let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();

        let mut lhs_ntt = self.compute_convolution_ntt(lhs, lhs_prep, rhs, rhs_prep, len);
        // `compute_convolution_ntt()` always returns an owned buffer, so the closure is never called
        let lhs_ntt = lhs_ntt.to_mut_with(|_| unreachable!());
        self.get_ntt_table(log2_len).unordered_truncated_fft_inv(&mut lhs_ntt[..], len);
        // note that the convolution result is added to `dst`
        for (i, x) in lhs_ntt.drain(..).enumerate().take(len) {
            self.ring().add_assign(&mut dst[i], x);
        }
    }

    fn compute_convolution_sum_impl<'a, S, I, V1, V2>(&self, values: I, dst: &mut [R_main::Element], ring: S)
        where S: RingStore<Type = R_main> + Copy,
            I: ExactSizeIterator<Item = (V1, Option<&'a PreparedConvolutionOperand<R_main, A>>, V2, Option<&'a PreparedConvolutionOperand<R_main, A>>)>,
            V1: VectorView<R_main::Element>,
            V2: VectorView<R_main::Element>,
            Self: 'a,
            R_main: 'a
    {
        let len = dst.len();
        let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();

        let mut buffer = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
        buffer.resize_with(1 << log2_len, || ring.zero());

        // accumulate all pointwise products in the NTT domain, so that only a single
        // inverse NTT is needed at the end
        for (lhs, lhs_prep, rhs, rhs_prep) in values {
            assert!(lhs.len() + rhs.len() - 1 <= len);

            let res_ntt = self.compute_convolution_ntt(lhs, lhs_prep, rhs, rhs_prep, len);
            for i in 0..len {
                self.ring().add_assign_ref(&mut buffer[i], &res_ntt[i]);
            }
        }
        self.get_ntt_table(log2_len).unordered_truncated_fft_inv(&mut buffer, len);
        for (i, x) in buffer.drain(..).enumerate().take(len) {
            self.ring().add_assign(&mut dst[i], x);
        }
    }
}

impl<R_main, R_twiddle, H, A> ConvolutionAlgorithm<R_main> for NTTConvolution<R_main, R_twiddle, H, A>
    where R_main: ?Sized + ZnRing,
        R_twiddle: ?Sized + ZnRing,
        H: Homomorphism<R_twiddle, R_main> + Clone,
        A: Allocator + Clone
{
    type PreparedConvolutionOperand = PreparedConvolutionOperand<R_main, A>;

    fn supports_ring<S: RingStore<Type = R_main> + Copy>(&self, ring: S) -> bool {
        ring.get_ring() == self.ring().get_ring()
    }

    fn compute_convolution<S: RingStore<Type = R_main> + Copy, V1: VectorView<<R_main as RingBase>::Element>, V2: VectorView<<R_main as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R_main::Element], ring: S) {
        assert!(self.supports_ring(ring));
        self.compute_convolution_impl(
            lhs,
            None,
            rhs,
            None,
            dst
        )
    }

    fn prepare_convolution_operand<S, V>(&self, val: V, length_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
        where S: RingStore<Type = R_main> + Copy, V: VectorView<R_main::Element>
    {
        assert!(self.supports_ring(ring));
        self.prepare_convolution_impl(
            val,
            length_hint
        )
    }

    fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R_main::Element], ring: S)
        where S: RingStore<Type = R_main> + Copy, V1: VectorView<R_main::Element>, V2: VectorView<R_main::Element>
    {
        assert!(self.supports_ring(ring));
        self.compute_convolution_impl(
            lhs,
            lhs_prep,
            rhs,
            rhs_prep,
            dst
        )
    }

    fn compute_convolution_sum<'a, S, I, V1, V2>(&self, values: I, dst: &mut [R_main::Element], ring: S)
        where S: RingStore<Type = R_main> + Copy,
            I: ExactSizeIterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
            V1: VectorView<R_main::Element>,
            V2: VectorView<R_main::Element>,
            Self: 'a,
            R_main: 'a
    {
        assert!(self.supports_ring(ring));
        self.compute_convolution_sum_impl(
            values,
            dst,
            ring
        )
    }
}

#[cfg(test)]
use test::Bencher;
#[cfg(test)]
use crate::algorithms::convolution::STANDARD_CONVOLUTION;
#[cfg(test)]
use crate::rings::zn::zn_64::{Zn, ZnBase, ZnEl};

#[test]
fn test_convolution() {
    let ring = zn_64::Zn::new(65537);
    let convolution = NTTConvolution::new(ring);
    super::generic_tests::test_convolution(&convolution, &ring, ring.one());
}

#[cfg(test)]
fn run_benchmark<F>(ring: Zn, bencher: &mut Bencher, mut f: F)
    where F: for<'a> FnMut(&mut dyn ExactSizeIterator<Item = (Vec<ZnEl>, Option<&'a PreparedConvolutionOperand<ZnBase>>, Vec<ZnEl>, Option<&'a PreparedConvolutionOperand<ZnBase>>)>, &mut [ZnEl], Zn)
{
    let mut expected = (0..512).map(|_| ring.zero()).collect::<Vec<_>>();
    let value = (0..256).map(|i| ring.int_hom().map(i)).collect::<Vec<_>>();
    STANDARD_CONVOLUTION.compute_convolution(
        &value,
        &value,
        &mut expected,
        ring
    );

    let mut i = 1;
    let mut actual = Vec::with_capacity(511);
    let hom = ring.can_hom(&StaticRing::<i64>::RING).unwrap();
    bencher.iter(|| {
        actual.clear();
        actual.resize_with(511, || ring.zero());
        f(
            &mut (0..256).map(|j| (
                (0..256).map(|k| hom.map(i * j as i64 * k)).collect::<Vec<_>>(),
                None,
                (0..256).map(|k| hom.map(i * j as i64 * k)).collect::<Vec<_>>(),
                None
            )),
            &mut actual,
            ring
        );
        let factor = hom.map(i * i * 128 * 511 * 85);
        for (l, r) in expected.iter().zip(actual.iter()) {
            assert_el_eq!(ring, ring.mul_ref(l, &factor), r);
        }
        i += 1;
    });
}

#[bench]
fn bench_convolution_sum(bencher: &mut Bencher) {
    let ring = zn_64::Zn::new(65537);
    let convolution = NTTConvolution::new(ring);

    run_benchmark(ring, bencher, |values, dst, ring| convolution.compute_convolution_sum_impl(values, dst, ring));

}

#[bench]
fn bench_convolution_sum_default(bencher: &mut Bencher) {
    let ring = zn_64::Zn::new(65537);
    let convolution = NTTConvolution::new(ring);

    run_benchmark(ring, bencher, |values, dst, ring| {
        for (lhs, lhs_prep, rhs, rhs_prep) in values {
            convolution.compute_convolution_prepared(lhs, lhs_prep, rhs, rhs_prep, dst, ring);
        }
    });
}