feanor_math/algorithms/convolution/
ntt.rs

1use std::alloc::{Allocator, Global};
2
3use crate::cow::*;
4use crate::{algorithms::fft::cooley_tuckey::CooleyTuckeyFFT, lazy::LazyVec};
5use crate::homomorphism::*;
6use crate::primitive_int::StaticRing;
7use crate::ring::*;
8use crate::rings::zn::*;
9use crate::integer::*;
10use crate::seq::VectorView;
11
12use super::ConvolutionAlgorithm;
13
14///
15/// Computes the convolution over a finite field that has suitable roots of unity
16/// using a power-of-two length FFT (sometimes called Number-Theoretic Transform,
17/// NTT in this context).
18/// 
19#[stability::unstable(feature = "enable")]
20pub struct NTTConvolution<R_main, R_twiddle, H, A = Global>
21    where R_main: ?Sized + ZnRing,
22        R_twiddle: ?Sized + ZnRing,
23        H: Homomorphism<R_twiddle, R_main> + Clone,
24        A: Allocator + Clone
25{
26    hom: H,
27    fft_algos: LazyVec<CooleyTuckeyFFT<R_main, R_twiddle, H>>,
28    allocator: A
29}
30
31#[stability::unstable(feature = "enable")]
32pub struct PreparedConvolutionOperand<R, A = Global>
33    where R: ?Sized + ZnRing,
34        A: Allocator + Clone
35{
36    significant_entries: usize,
37    ntt_data: Vec<R::Element, A>
38}
39
40impl<R> NTTConvolution<R::Type, R::Type, Identity<R>>
41    where R: RingStore + Clone,
42        R::Type: ZnRing
43{
44    #[stability::unstable(feature = "enable")]
45    pub fn new(ring: R) -> Self {
46        Self::new_with(ring.into_identity(), Global)
47    }
48}
49
50impl<R_main, R_twiddle, H, A> NTTConvolution<R_main, R_twiddle, H, A>
51    where R_main: ?Sized + ZnRing,
52        R_twiddle: ?Sized + ZnRing,
53        H: Homomorphism<R_twiddle, R_main> + Clone,
54        A: Allocator + Clone
55{
56    #[stability::unstable(feature = "enable")]
57    pub fn new_with(hom: H, allocator: A) -> Self {
58        Self {
59            fft_algos: LazyVec::new(),
60            hom: hom,
61            allocator: allocator
62        }
63    }
64
65    #[stability::unstable(feature = "enable")]
66    pub fn ring(&self) -> RingRef<R_main> {
67        RingRef::new(self.hom.codomain().get_ring())
68    }
69
70    fn get_ntt_table<'a>(&'a self, log2_n: usize) -> &'a CooleyTuckeyFFT<R_main, R_twiddle, H> {
71        self.fft_algos.get_or_init(log2_n, || CooleyTuckeyFFT::for_zn_with_hom(self.hom.clone(), log2_n).expect("NTTConvolution was instantiated with parameters that don't support this length"))
72    }
73
74    fn get_ntt_data<'a, V>(
75        &self,
76        data: V,
77        data_prep: Option<&'a PreparedConvolutionOperand<R_main, A>>,
78        significant_entries: usize,
79    ) -> MyCow<'a, Vec<R_main::Element, A>>
80        where V: VectorView<R_main::Element>
81    {
82        assert!(data.len() <= significant_entries);
83        let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&significant_entries.try_into().unwrap()).unwrap();
84
85        let compute_result = || {
86            let mut result = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
87            result.extend(data.as_iter().map(|x| self.ring().clone_el(x)));
88            result.resize_with(1 << log2_len, || self.ring().zero());
89            self.get_ntt_table(log2_len).unordered_truncated_fft(&mut result, significant_entries);
90            return result;
91        };
92
93        return if let Some(data_prep) = data_prep {
94            assert!(data_prep.significant_entries >= significant_entries);
95            MyCow::Borrowed(&data_prep.ntt_data)
96        } else {
97            MyCow::Owned(compute_result())
98        }
99    }
100
101    fn prepare_convolution_impl<V>(
102        &self,
103        data: V,
104        len_hint: Option<usize>
105    ) -> PreparedConvolutionOperand<R_main, A>
106        where V: VectorView<R_main::Element>
107    {
108        let significant_entries = if let Some(out_len) = len_hint {
109            assert!(data.len() <= out_len);
110            out_len
111        } else {
112            2 * data.len()
113        };
114        let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&significant_entries.try_into().unwrap()).unwrap();
115
116        let mut result = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
117        result.extend(data.as_iter().map(|x| self.ring().clone_el(x)));
118        result.resize_with(1 << log2_len, || self.ring().zero());
119        self.get_ntt_table(log2_len).unordered_truncated_fft(&mut result, significant_entries);
120
121        return PreparedConvolutionOperand {
122            ntt_data: result,
123            significant_entries: significant_entries
124        };
125    }
126
127    fn compute_convolution_impl<V1, V2>(
128        &self,
129        lhs: V1,
130        mut lhs_prep: Option<&PreparedConvolutionOperand<R_main, A>>,
131        rhs: V2,
132        mut rhs_prep: Option<&PreparedConvolutionOperand<R_main, A>>,
133        dst: &mut [R_main::Element]
134    )
135        where V1: VectorView<R_main::Element>,
136            V2: VectorView<R_main::Element>
137    {
138        if lhs.len() == 0 || rhs.len() == 0 {
139            return;
140        }
141        let len = lhs.len() + rhs.len() - 1;
142        let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
143
144        if lhs_prep.is_some() && (lhs_prep.unwrap().significant_entries < len || lhs_prep.unwrap().ntt_data.len() != 1 << log2_len) {
145            lhs_prep = None;
146        }
147        if rhs_prep.is_some() && (rhs_prep.unwrap().significant_entries < len || rhs_prep.unwrap().ntt_data.len() != 1 << log2_len) {
148            rhs_prep = None;
149        }
150
151        let mut lhs_ntt = self.get_ntt_data(lhs, lhs_prep, len);
152        let mut rhs_ntt = self.get_ntt_data(rhs, rhs_prep, len);
153        if rhs_ntt.is_owned() {
154            std::mem::swap(&mut lhs_ntt, &mut rhs_ntt);
155        }
156        let mut lhs_ntt = lhs_ntt.to_mut_with(|data| {
157            let mut copied_data = Vec::with_capacity_in(data.len(), self.allocator.clone());
158            copied_data.extend(data.iter().map(|x| self.ring().clone_el(x)));
159            copied_data
160        });
161
162        for i in 0..len {
163            self.ring().mul_assign_ref(&mut lhs_ntt[i], &rhs_ntt[i]);
164        }
165
166        self.get_ntt_table(log2_len).unordered_truncated_fft_inv(&mut lhs_ntt, len);
167
168        for (i, x) in lhs_ntt.drain(..).enumerate().take(len) {
169            self.ring().add_assign(&mut dst[i], x);
170        }
171    }
172}
173
174impl<R_main, R_twiddle, H, A> ConvolutionAlgorithm<R_main> for NTTConvolution<R_main, R_twiddle, H, A>
175    where R_main: ?Sized + ZnRing,
176        R_twiddle: ?Sized + ZnRing,
177        H: Homomorphism<R_twiddle, R_main> + Clone,
178        A: Allocator + Clone
179{
180    type PreparedConvolutionOperand = PreparedConvolutionOperand<R_main, A>;
181
182    fn supports_ring<S: RingStore<Type = R_main> + Copy>(&self, ring: S) -> bool {
183        ring.get_ring() == self.ring().get_ring()
184    }
185
186    fn compute_convolution<S: RingStore<Type = R_main> + Copy, V1: VectorView<<R_main as RingBase>::Element>, V2: VectorView<<R_main as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R_main::Element], ring: S) {
187        assert!(self.supports_ring(ring));
188        self.compute_convolution_impl(
189            lhs,
190            None,
191            rhs,
192            None,
193            dst
194        )
195    }
196
197    fn prepare_convolution_operand<S, V>(&self, val: V, length_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
198        where S: RingStore<Type = R_main> + Copy, V: VectorView<R_main::Element>
199    {
200        assert!(self.supports_ring(ring));
201        self.prepare_convolution_impl(
202            val,
203            length_hint
204        )
205    }
206
207    fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R_main::Element], ring: S)
208        where S: RingStore<Type = R_main> + Copy, V1: VectorView<R_main::Element>, V2: VectorView<R_main::Element>
209    {
210        assert!(self.supports_ring(ring));
211        self.compute_convolution_impl(
212            lhs,
213            lhs_prep,
214            rhs,
215            rhs_prep,
216            dst
217        )
218    }
219}
220
221#[test]
222fn test_convolution() {
223    let ring = zn_64::Zn::new(65537);
224    let convolution = NTTConvolution::new_with(ring.identity(), Global);
225    super::generic_tests::test_convolution(&convolution, &ring, ring.one());
226}