feanor_math/algorithms/convolution/ntt.rs

use std::alloc::{Allocator, Global};
use std::cmp::min;

use crate::{algorithms::fft::cooley_tuckey::CooleyTuckeyFFT, lazy::LazyVec};
use crate::homomorphism::*;
use crate::primitive_int::StaticRing;
use crate::ring::*;
use crate::algorithms::fft::*;
use crate::rings::zn::*;
use crate::integer::*;
use crate::seq::VectorView;

use super::{ConvolutionAlgorithm, PreparedConvolutionAlgorithm, PreparedConvolutionOperation};

///
/// Computes the convolution over a finite field that has suitable roots of
/// unity, using a power-of-two length FFT (in this context often called a
/// Number-Theoretic Transform, or NTT).
///
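/// # Example
///
/// A sketch of typical usage, with the prime modulus 65537 (which has `2^k`-th
/// roots of unity for all `k <= 16`); exact import paths are assumed from this
/// file's location. Note that the convolution result is *added* to `dst`,
/// which must have length at least `lhs.len() + rhs.len()`.
/// ```
/// # #![feature(allocator_api)]
/// # use std::alloc::Global;
/// # use feanor_math::ring::*;
/// # use feanor_math::rings::zn::*;
/// # use feanor_math::homomorphism::*;
/// # use feanor_math::algorithms::convolution::*;
/// # use feanor_math::algorithms::convolution::ntt::NTTConvolution;
/// let ring = zn_64::Zn::new(65537);
/// let convolution = NTTConvolution::new_with(ring, Global);
/// let lhs = [ring.int_hom().map(1), ring.int_hom().map(2)];
/// let rhs = [ring.int_hom().map(3), ring.int_hom().map(4)];
/// let mut result = [ring.zero(); 4];
/// convolution.compute_convolution(lhs, rhs, &mut result, &ring);
/// // (1 + 2X) * (3 + 4X) = 3 + 10X + 8X^2
/// assert!(ring.eq_el(&ring.int_hom().map(3), &result[0]));
/// assert!(ring.eq_el(&ring.int_hom().map(10), &result[1]));
/// assert!(ring.eq_el(&ring.int_hom().map(8), &result[2]));
/// ```
///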
#[stability::unstable(feature = "enable")]
pub struct NTTConvolution<R, A = Global>
    where R: RingStore + Clone,
        R::Type: ZnRing,
        A: Allocator + Clone
{
    ring: R,
    // one FFT object per power-of-two length, lazily initialized on first use
    fft_algos: LazyVec<CooleyTuckeyFFT<R::Type, R::Type, Identity<R>>>,
    allocator: A
}

///
/// A vector of ring elements, prepared for convolution: it is stored zero-padded
/// to a power-of-two length, in (unordered) NTT representation.
///
#[stability::unstable(feature = "enable")]
pub struct PreparedConvolutionOperand<R, A = Global>
    where R: RingStore + Clone,
        R::Type: ZnRing,
        A: Allocator + Clone
{
    // the length of the original, un-padded vector
    len: usize,
    // the zero-padded vector, in (unordered) NTT representation
    data: Vec<El<R>, A>
}

impl<R, A> NTTConvolution<R, A>
    where R: RingStore + Clone,
        R::Type: ZnRing,
        A: Allocator + Clone
{
    #[stability::unstable(feature = "enable")]
    pub fn new_with(ring: R, allocator: A) -> Self {
        Self {
            fft_algos: LazyVec::new(),
            ring,
            allocator
        }
    }

    #[stability::unstable(feature = "enable")]
    pub fn ring(&self) -> &R {
        &self.ring
    }

    // Computes `dst[i] += lhs[i] * rhs[i]` for each index `i`; all three
    // slices must have the same length
    fn add_assign_elementwise_product(lhs: &[El<R>], rhs: &[El<R>], dst: &mut [El<R>], ring: RingRef<R::Type>) {
        assert_eq!(lhs.len(), rhs.len());
        assert_eq!(lhs.len(), dst.len());
        for i in 0..lhs.len() {
            ring.add_assign(&mut dst[i], ring.mul_ref(&lhs[i], &rhs[i]));
        }
    }

    // Multiplies the two NTT representations pointwise (reusing `lhs` as scratch
    // space), transforms the product back and adds it to `out`
    fn compute_convolution_impl(&self, mut lhs: PreparedConvolutionOperand<R, A>, rhs: &PreparedConvolutionOperand<R, A>, out: &mut [El<R>]) {
        let log2_n = ZZ.abs_log2_ceil(&(lhs.data.len() as i64)).unwrap();
        assert_eq!(lhs.data.len(), 1 << log2_n);
        assert_eq!(rhs.data.len(), 1 << log2_n);
        assert!(lhs.len + rhs.len <= 1 << log2_n);
        assert!(out.len() >= lhs.len + rhs.len);
        for i in 0..(1 << log2_n) {
            self.ring.mul_assign_ref(&mut lhs.data[i], &rhs.data[i]);
        }
        self.get_fft(log2_n).unordered_inv_fft(&mut lhs.data[..], self.ring());
        for i in 0..min(out.len(), 1 << log2_n) {
            self.ring.add_assign_ref(&mut out[i], &lhs.data[i]);
        }
    }

    // Converts an (unordered) NTT representation of length `2^log2_in_len` into
    // the corresponding representation of the larger length `2^log2_n`, by
    // inverse-transforming, zero-padding and re-transforming
    fn un_and_redo_fft(&self, input: &[El<R>], log2_n: usize) -> Vec<El<R>, A> {
        let log2_in_len = ZZ.abs_log2_ceil(&(input.len() as i64)).unwrap();
        assert_eq!(input.len(), 1 << log2_in_len);
        assert!(log2_in_len < log2_n);

        let mut tmp = Vec::with_capacity_in(1 << log2_n, self.allocator.clone());
        tmp.extend(input.iter().map(|x| self.ring.clone_el(x)));
        self.get_fft(log2_in_len).unordered_inv_fft(&mut tmp[..], self.ring());

        tmp.resize_with(1 << log2_n, || self.ring.zero());
        self.get_fft(log2_n).unordered_fft(&mut tmp[..], self.ring());
        return tmp;
    }

    // Returns the FFT object for length `2^log2_n`, creating and caching it on first use
    fn get_fft<'a>(&'a self, log2_n: usize) -> &'a CooleyTuckeyFFT<R::Type, R::Type, Identity<R>> {
        self.fft_algos.get_or_init(log2_n, || CooleyTuckeyFFT::for_zn(self.ring().clone(), log2_n).unwrap())
    }

    fn clone_prepared_operand(&self, operand: &PreparedConvolutionOperand<R, A>) -> PreparedConvolutionOperand<R, A> {
        let mut result = Vec::with_capacity_in(operand.data.len(), self.allocator.clone());
        result.extend(operand.data.iter().map(|x| self.ring.clone_el(x)));
        return PreparedConvolutionOperand {
            len: operand.len,
            data: result
        };
    }

    // Zero-pads `val` to length `2^log2_n` and computes its (unordered) NTT
    fn prepare_convolution_impl<V: VectorView<El<R>>>(&self, val: V, log2_n: usize) -> PreparedConvolutionOperand<R, A> {
        let mut result = Vec::with_capacity_in(1 << log2_n, self.allocator.clone());
        result.extend(val.as_iter().map(|x| self.ring.clone_el(x)));
        result.resize_with(1 << log2_n, || self.ring.zero());
        let fft = self.get_fft(log2_n);
        fft.unordered_fft(&mut result[..], self.ring());
        return PreparedConvolutionOperand {
            len: val.len(),
            data: result
        };
    }
}

impl<R, A> ConvolutionAlgorithm<R::Type> for NTTConvolution<R, A>
    where R: RingStore + Clone,
        R::Type: ZnRing,
        A: Allocator + Clone
{
    fn supports_ring<S: RingStore<Type = R::Type> + Copy>(&self, ring: S) -> bool {
        ring.get_ring() == self.ring.get_ring()
    }

    fn compute_convolution<S: RingStore<Type = R::Type> + Copy, V1: VectorView<<R::Type as RingBase>::Element>, V2: VectorView<<R::Type as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [El<R>], ring: S) {
        assert!(self.supports_ring(&ring));
        let log2_n = ZZ.abs_log2_ceil(&((lhs.len() + rhs.len()) as i64)).unwrap();
        let lhs_prep = self.prepare_convolution_impl(lhs, log2_n);
        let rhs_prep = self.prepare_convolution_impl(rhs, log2_n);
        self.compute_convolution_impl(lhs_prep, &rhs_prep, dst);
    }

    fn specialize_prepared_convolution<F>(function: F) -> Result<F::Output, F>
        where F: PreparedConvolutionOperation<Self, R::Type>
    {
        Ok(function.execute())
    }
}

const ZZ: StaticRing<i64> = StaticRing::<i64>::RING;

impl<R, A> PreparedConvolutionAlgorithm<R::Type> for NTTConvolution<R, A>
    where R: RingStore + Clone,
        R::Type: ZnRing,
        A: Allocator + Clone
{
    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, A>;

    fn prepare_convolution_operand<S: RingStore<Type = R::Type> + Copy, V: VectorView<El<R>>>(&self, val: V, ring: S) -> Self::PreparedConvolutionOperand {
        assert!(ring.get_ring() == self.ring.get_ring());
        let log2_n_in = ZZ.abs_log2_ceil(&(val.len() as i64)).unwrap();
        // use one extra power of two, so that the operand can later be convolved
        // with an operand of (at most) the same length without redoing its NTT
        let log2_n_out = log2_n_in + 1;
        return self.prepare_convolution_impl(val, log2_n_out);
    }

    fn compute_convolution_lhs_prepared<S: RingStore<Type = R::Type> + Copy, V: VectorView<El<R>>>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: V, dst: &mut [El<R>], ring: S) {
        assert!(ring.is_commutative());
        assert!(ring.get_ring() == self.ring.get_ring());
        let log2_lhs = ZZ.abs_log2_ceil(&(lhs.data.len() as i64)).unwrap();
        assert_eq!(lhs.data.len(), 1 << log2_lhs);
        let log2_n = ZZ.abs_log2_ceil(&((lhs.len + rhs.len()) as i64)).unwrap().max(log2_lhs);
        assert!(log2_lhs <= log2_n);
        self.compute_convolution_prepared(lhs, &self.prepare_convolution_impl(rhs, log2_n), dst, ring);
    }

    fn compute_convolution_prepared<S: RingStore<Type = R::Type> + Copy>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: &Self::PreparedConvolutionOperand, dst: &mut [El<R>], ring: S) {
        assert!(ring.is_commutative());
        assert!(ring.get_ring() == self.ring.get_ring());
        let log2_lhs = ZZ.abs_log2_ceil(&(lhs.data.len() as i64)).unwrap();
        assert_eq!(1 << log2_lhs, lhs.data.len());
        let log2_rhs = ZZ.abs_log2_ceil(&(rhs.data.len() as i64)).unwrap();
        assert_eq!(1 << log2_rhs, rhs.data.len());
        // if the operands are stored as NTTs of different lengths, first bring
        // the shorter one to the length of the longer one
        match log2_lhs.cmp(&log2_rhs) {
            std::cmp::Ordering::Equal => self.compute_convolution_impl(self.clone_prepared_operand(lhs), rhs, dst),
            std::cmp::Ordering::Greater => self.compute_convolution_impl(PreparedConvolutionOperand { data: self.un_and_redo_fft(&rhs.data, log2_lhs), len: rhs.len }, lhs, dst),
            std::cmp::Ordering::Less => self.compute_convolution_impl(PreparedConvolutionOperand { data: self.un_and_redo_fft(&lhs.data, log2_rhs), len: lhs.len }, rhs, dst)
        }
    }

    fn compute_convolution_inner_product_prepared<'a, S, I>(&self, values: I, dst: &mut [El<R>], ring: S)
        where S: RingStore<Type = R::Type> + Copy,
            I: Iterator<Item = (&'a Self::PreparedConvolutionOperand, &'a Self::PreparedConvolutionOperand)>,
            Self: 'a,
            R: 'a,
            PreparedConvolutionOperand<R, A>: 'a
    {
        assert!(ring.get_ring() == self.ring.get_ring());
        let mut values_it = values.peekable();
        if values_it.peek().is_none() {
            return;
        }
        let expected_len = values_it.peek().unwrap().0.data.len().max(values_it.peek().unwrap().1.data.len());
        let mut current_log2_len = ZZ.abs_log2_ceil(&(expected_len as i64)).unwrap();
        assert_eq!(expected_len, 1 << current_log2_len);
        // accumulate the sum of the pairwise products in NTT representation, so
        // that a single inverse NTT at the end suffices
        let mut tmp = Vec::with_capacity_in(1 << current_log2_len, self.allocator.clone());
        tmp.resize_with(1 << current_log2_len, || ring.zero());
        for (lhs, rhs) in values_it {
            assert!(dst.len() >= lhs.len + rhs.len);
            let lhs_log2_len = ZZ.abs_log2_ceil(&(lhs.data.len() as i64)).unwrap();
            let rhs_log2_len = ZZ.abs_log2_ceil(&(rhs.data.len() as i64)).unwrap();
            let new_log2_len = current_log2_len.max(lhs_log2_len).max(rhs_log2_len);

            // if necessary, enlarge the accumulator and/or the operands to a common NTT length
            if current_log2_len < new_log2_len {
                tmp = self.un_and_redo_fft(&tmp, new_log2_len);
                current_log2_len = new_log2_len;
            }
            match (lhs_log2_len < current_log2_len, rhs_log2_len < current_log2_len) {
                (false, false) => Self::add_assign_elementwise_product(&lhs.data, &rhs.data, &mut tmp, RingRef::new(ring.get_ring())),
                (true, false) => Self::add_assign_elementwise_product(&self.un_and_redo_fft(&lhs.data, new_log2_len), &rhs.data, &mut tmp, RingRef::new(ring.get_ring())),
                (false, true) => Self::add_assign_elementwise_product(&lhs.data, &self.un_and_redo_fft(&rhs.data, new_log2_len), &mut tmp, RingRef::new(ring.get_ring())),
                (true, true) => Self::add_assign_elementwise_product(&self.un_and_redo_fft(&lhs.data, new_log2_len), &self.un_and_redo_fft(&rhs.data, new_log2_len), &mut tmp, RingRef::new(ring.get_ring())),
            }
        }
        self.get_fft(current_log2_len).unordered_inv_fft(&mut tmp[..], self.ring());
        for i in 0..min(dst.len(), 1 << current_log2_len) {
            self.ring.add_assign_ref(&mut dst[i], &tmp[i]);
        }
    }
}

#[test]
fn test_convolution() {
    let ring = zn_64::Zn::new(65537);
    let convolution = NTTConvolution::new_with(ring, Global);
    super::generic_tests::test_convolution(&convolution, &ring, ring.one());
    super::generic_tests::test_prepared_convolution(&convolution, &ring, ring.one());
}
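
// An additional smoke test for the prepared-convolution path with concrete
// values; a sketch along the lines of the generic tests above, assuming (as
// before) that 65537 is a suitable NTT prime.
#[test]
fn test_convolution_prepared_smoke() {
    let ring = zn_64::Zn::new(65537);
    let convolution = NTTConvolution::new_with(ring, Global);
    let lhs = convolution.prepare_convolution_operand([ring.int_hom().map(1), ring.int_hom().map(2)], &ring);
    let rhs = convolution.prepare_convolution_operand([ring.int_hom().map(3), ring.int_hom().map(4)], &ring);
    // `dst` must have length at least `lhs.len + rhs.len`; the result is added to it
    let mut result = [ring.zero(); 4];
    convolution.compute_convolution_prepared(&lhs, &rhs, &mut result, &ring);
    // (1 + 2X) * (3 + 4X) = 3 + 10X + 8X^2
    assert!(ring.eq_el(&ring.int_hom().map(3), &result[0]));
    assert!(ring.eq_el(&ring.int_hom().map(10), &result[1]));
    assert!(ring.eq_el(&ring.int_hom().map(8), &result[2]));
    assert!(ring.is_zero(&result[3]));
}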