// he_ring/ntt/ntt_convolution.rs

1use std::alloc::{Allocator, Global};
2
3use tracing::instrument;
4
5use feanor_math::algorithms::convolution::*;
6use feanor_math::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
7use feanor_math::homomorphism::Identity;
8use feanor_math::primitive_int::StaticRing;
9use feanor_math::ring::*;
10use feanor_math::integer::*;
11use feanor_math::rings::zn::*;
12use feanor_math::seq::*;
13use feanor_math::algorithms::fft::FFTAlgorithm;
14
15use super::HERingConvolution;
16
17///
18/// A [`ConvolutionAlgorithm`] based on NTTs.
19/// 
20pub struct NTTConv<R, A = Global>
21    where R: RingStore + Clone,
22        R::Type: ZnRing,
23        A: Allocator + Clone
24{
25    ring: R,
26    max_log2_n: usize,
27    fft_algos: Vec<CooleyTuckeyFFT<R::Type, R::Type, Identity<R>>>,
28    allocator: A
29}
30
31impl<R> HERingConvolution<R> for NTTConv<R>
32    where R: RingStore + Clone,
33        R::Type: ZnRing
34{
35    fn new(ring: R, max_log2_n: usize) -> Self {
36        Self::new_with(ring, max_log2_n, Global)
37    }
38
39    fn ring(&self) -> &R {
40        &self.ring
41    }
42}
43
44impl<R, A> NTTConv<R, A>
45    where R: RingStore + Clone,
46        R::Type: ZnRing,
47        A: Allocator + Clone
48{
49    pub fn new_with(ring: R, max_log2_n: usize, allocator: A) -> Self {
50        assert!(max_log2_n <= ring.integer_ring().get_ring().abs_lowest_set_bit(&ring.integer_ring().sub_ref_fst(ring.modulus(), ring.integer_ring().one())).unwrap());
51        Self {
52            fft_algos: (0..=max_log2_n).map(|log2_n| CooleyTuckeyFFT::for_zn(ring.clone(), log2_n).unwrap()).collect(),
53            ring: ring,
54            allocator: allocator,
55            max_log2_n: max_log2_n,
56        }
57    }
58
59    pub fn max_supported_output_len(&self) -> usize {
60        1 << self.max_log2_n
61    }
62
63    #[instrument(skip_all)]
64    fn compute_convolution_base(&self, mut lhs: PreparedConvolutionOperand<R, A>, rhs: &PreparedConvolutionOperand<R, A>, out: &mut [El<R>]) {
65        let log2_n = ZZ.abs_log2_ceil(&(lhs.data.len() as i64)).unwrap();
66        assert_eq!(lhs.data.len(), 1 << log2_n);
67        assert_eq!(rhs.data.len(), 1 << log2_n);
68        assert!(lhs.len + rhs.len <= 1 << log2_n);
69        assert!(out.len() >= lhs.len + rhs.len);
70        for i in 0..(1 << log2_n) {
71            self.ring.mul_assign_ref(&mut lhs.data[i], &rhs.data[i]);
72        }
73        self.get_fft(log2_n).unordered_inv_fft(&mut lhs.data[..], &self.ring);
74        for i in 0..(lhs.len + rhs.len) {
75            self.ring.add_assign_ref(&mut out[i], &lhs.data[i]);
76        }
77    }
78
79    fn get_fft<'a>(&'a self, log2_n: usize) -> &'a CooleyTuckeyFFT<R::Type, R::Type, Identity<R>> {
80        &self.fft_algos[log2_n]
81    }
82
83    fn clone_prepared_operand(&self, operand: &PreparedConvolutionOperand<R, A>) -> PreparedConvolutionOperand<R, A> {
84        let mut result = Vec::with_capacity_in(operand.data.len(), self.allocator.clone());
85        result.extend(operand.data.iter().map(|x| self.ring.clone_el(x)));
86        return PreparedConvolutionOperand {
87            len: operand.len,
88            data: result
89        };
90    }
91    
92    #[instrument(skip_all)]
93    fn prepare_convolution_base<V: VectorView<El<R>>>(&self, val: V, log2_n: usize) -> PreparedConvolutionOperand<R, A> {
94        let mut result = Vec::with_capacity_in(1 << log2_n, self.allocator.clone());
95        result.extend(val.as_iter().map(|x| self.ring.clone_el(x)));
96        result.resize_with(1 << log2_n, || self.ring.zero());
97        let fft = self.get_fft(log2_n);
98        fft.unordered_fft(&mut result[..], &self.ring);
99        return PreparedConvolutionOperand {
100            len: val.len(),
101            data: result
102        };
103    }
104}
105
106impl<R, A> ConvolutionAlgorithm<R::Type> for NTTConv<R, A>
107    where R: RingStore + Clone,
108        R::Type: ZnRing ,
109        A: Allocator + Clone
110{
111    fn supports_ring<S: RingStore<Type = R::Type> + Copy>(&self, ring: S) -> bool {
112        ring.get_ring() == self.ring.get_ring()
113    }
114
115    fn compute_convolution<S: RingStore<Type = R::Type> + Copy, V1: VectorView<<R::Type as RingBase>::Element>, V2: VectorView<<R::Type as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [<R::Type as RingBase>::Element], ring: S) {
116        assert!(ring.get_ring() == self.ring.get_ring());
117        let log2_n = ZZ.abs_log2_ceil(&((lhs.len() + rhs.len()) as i64)).unwrap();
118        let lhs_prep = self.prepare_convolution_base(lhs, log2_n);
119        let rhs_prep = self.prepare_convolution_base(rhs, log2_n);
120        self.compute_convolution_base(lhs_prep, &rhs_prep, dst);
121    }
122}
123
///
/// The NTT-transformed representation of a convolution operand, used by
/// [`NTTConv`] to avoid repeated forward transforms of the same operand.
/// 
pub struct PreparedConvolutionOperand<R, A = Global>
    where R: RingStore + Clone,
        R::Type: ZnRing,
        A: Allocator + Clone
{
    // the original (unpadded) length of the operand
    len: usize,
    // NTT of the operand, zero-padded to a power-of-two length
    data: Vec<El<R>, A>
}
132
133const ZZ: StaticRing<i64> = StaticRing::<i64>::RING;
134
135impl<R, A> PreparedConvolutionAlgorithm<R::Type> for NTTConv<R, A>
136    where R: RingStore + Clone,
137        R::Type: ZnRing ,
138        A: Allocator + Clone
139{
140    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, A>;
141
142    fn prepare_convolution_operand<S: RingStore<Type = R::Type> + Copy, V: VectorView<El<R>>>(&self, val: V, ring: S) -> Self::PreparedConvolutionOperand {
143        assert!(ring.get_ring() == self.ring.get_ring());
144        let log2_n_in = ZZ.abs_log2_ceil(&(val.len() as i64)).unwrap();
145        let log2_n_out = log2_n_in + 1;
146        return self.prepare_convolution_base(val, log2_n_out);
147    }
148
149    fn compute_convolution_lhs_prepared<S: RingStore<Type = R::Type> + Copy, V: VectorView<El<R>>>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: V, dst: &mut [El<R>], ring: S) {
150        assert!(ring.get_ring() == self.ring.get_ring());
151        let log2_n = ZZ.abs_log2_ceil(&((lhs.len + rhs.len()) as i64)).unwrap();
152        if lhs.data.len() >= (1 << log2_n) {
153            let log2_n = ZZ.abs_log2_ceil(&(lhs.data.len() as i64)).unwrap();
154            assert!(lhs.data.len() == 1 << log2_n);
155            self.compute_convolution_base(self.prepare_convolution_base(rhs, log2_n), lhs, dst);
156        } else {
157            self.compute_convolution_prepared(lhs, &self.prepare_convolution_base(rhs, log2_n), dst, ring)
158        }
159    }
160
161    fn compute_convolution_prepared<S: RingStore<Type = R::Type> + Copy>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: &Self::PreparedConvolutionOperand, dst: &mut [El<R>], ring: S) {
162        assert!(ring.get_ring() == self.ring.get_ring());
163        let log2_lhs = ZZ.abs_log2_ceil(&(lhs.data.len() as i64)).unwrap();
164        assert_eq!(1 << log2_lhs, lhs.data.len());
165        let log2_rhs = ZZ.abs_log2_ceil(&(rhs.data.len() as i64)).unwrap();
166        assert_eq!(1 << log2_rhs, rhs.data.len());
167        match log2_lhs.cmp(&log2_rhs) {
168            std::cmp::Ordering::Equal => self.compute_convolution_base(self.clone_prepared_operand(lhs), rhs, dst),
169            std::cmp::Ordering::Greater => self.compute_convolution_prepared(rhs, lhs, dst, ring),
170            std::cmp::Ordering::Less => {
171                let mut lhs_new = Vec::with_capacity_in(lhs.data.len(), self.allocator.clone());
172                lhs_new.extend(lhs.data.iter().map(|x| self.ring.clone_el(x)));
173                self.get_fft(log2_lhs).unordered_inv_fft(&mut lhs_new[..], ring);
174                lhs_new.resize_with(1 << log2_rhs, || ring.zero());
175                self.get_fft(log2_rhs).unordered_fft(&mut lhs_new[..], ring);
176                self.compute_convolution_base(PreparedConvolutionOperand { data: lhs_new, len: lhs.len }, rhs, dst);
177            }
178        }
179    }
180}
181
182#[cfg(test)]
183use feanor_math::assert_el_eq;
184#[cfg(test)]
185use feanor_math::rings::zn::zn_64::*;
186#[cfg(test)]
187use feanor_math::homomorphism::Homomorphism;
188
189#[test]
190fn test_convolution() {
191    let ring = Zn::new(65537);
192    let convolutor = NTTConv::new_with(ring, 16, Global);
193
194    let check = |lhs: &[ZnEl], rhs: &[ZnEl], add: &[ZnEl]| {
195        let mut expected = (0..(lhs.len() + rhs.len())).map(|i| if i < add.len() { add[i] } else { ring.zero() }).collect::<Vec<_>>();
196        STANDARD_CONVOLUTION.compute_convolution(lhs, rhs, &mut expected, &ring);
197
198        let mut actual1 = (0..(lhs.len() + rhs.len())).map(|i| if i < add.len() { add[i] } else { ring.zero() }).collect::<Vec<_>>();
199        convolutor.compute_convolution(lhs, rhs, &mut actual1, &ring);
200        for i in 0..(lhs.len() + rhs.len()) {
201            assert_el_eq!(&ring, &expected[i], &actual1[i]);
202        }
203        
204        let lhs_prepared = convolutor.prepare_convolution_operand(lhs, &ring);
205        let rhs_prepared = convolutor.prepare_convolution_operand(rhs, &ring);
206
207        let mut actual2 = (0..(lhs.len() + rhs.len())).map(|i| if i < add.len() { add[i] } else { ring.zero() }).collect::<Vec<_>>();
208        convolutor.compute_convolution_lhs_prepared(&lhs_prepared, rhs, &mut actual2, &ring);
209        for i in 0..(lhs.len() + rhs.len()) {
210            assert_el_eq!(&ring, &expected[i], &actual2[i]);
211        }
212        
213        let mut actual3 = (0..(lhs.len() + rhs.len())).map(|i| if i < add.len() { add[i] } else { ring.zero() }).collect::<Vec<_>>();
214        convolutor.compute_convolution_rhs_prepared(lhs, &rhs_prepared, &mut actual3, &ring);
215        for i in 0..(lhs.len() + rhs.len()) {
216            assert_el_eq!(&ring, &expected[i], &actual3[i]);
217        }
218        
219        let mut actual4 = (0..(lhs.len() + rhs.len())).map(|i| if i < add.len() { add[i] } else { ring.zero() }).collect::<Vec<_>>();
220        convolutor.compute_convolution_prepared(&lhs_prepared, &rhs_prepared, &mut actual4, &ring);
221        for i in 0..(lhs.len() + rhs.len()) {
222            assert_el_eq!(&ring, &expected[i], &actual4[i]);
223        }
224    };
225
226    for lhs_len in [1, 2, 3, 4, 7, 8, 9] {
227        for rhs_len in [1, 5, 8, 16, 17] {
228            let lhs = (0..lhs_len).map(|i| ring.int_hom().map(i)).collect::<Vec<_>>();
229            let rhs = (0..rhs_len).map(|i| ring.int_hom().map(16 * i)).collect::<Vec<_>>();
230            let add = (0..(lhs_len + rhs_len)).map(|i| ring.int_hom().map(32768 * i)).collect::<Vec<_>>();
231            check(&lhs, &rhs, &add);
232        }
233    }
234}