use std::alloc::{Allocator, Global};

use tracing::instrument;

use feanor_math::algorithms::convolution::*;
use feanor_math::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
use feanor_math::algorithms::fft::FFTAlgorithm;
use feanor_math::homomorphism::Identity;
use feanor_math::primitive_int::StaticRing;
use feanor_math::ring::*;
use feanor_math::integer::*;
use feanor_math::rings::zn::*;
use feanor_math::seq::*;

use super::HERingConvolution;

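/// A [`ConvolutionAlgorithm`] for `Z/qZ` that computes convolutions via power-of-two
/// length NTTs: both operands are transformed, multiplied pointwise, and the product
/// is transformed back. Requires `q = 1 mod 2^max_log2_n`, so that the necessary
/// roots of unity exist; all buffers are allocated through `A`.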
pub struct NTTConv<R, A = Global>
    where R: RingStore + Clone,
        R::Type: ZnRing,
        A: Allocator + Clone
{
    ring: R,
    max_log2_n: usize,
    fft_algos: Vec<CooleyTuckeyFFT<R::Type, R::Type, Identity<R>>>,
    allocator: A
}

impl<R> HERingConvolution<R> for NTTConv<R>
    where R: RingStore + Clone,
        R::Type: ZnRing
{
    fn new(ring: R, max_log2_n: usize) -> Self {
        Self::new_with(ring, max_log2_n, Global)
    }

    fn ring(&self) -> &R {
        &self.ring
    }
}

impl<R, A> NTTConv<R, A>
    where R: RingStore + Clone,
        R::Type: ZnRing,
        A: Allocator + Clone
{
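    /// Creates an [`NTTConv`] supporting output lengths up to `2^max_log2_n`, with
    /// all buffers allocated through the given allocator.
    ///
    /// Panics if `q - 1` is not divisible by `2^max_log2_n`, since then the required
    /// roots of unity do not exist.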
    pub fn new_with(ring: R, max_log2_n: usize, allocator: A) -> Self {
        // check that 2^max_log2_n divides q - 1, which is required for the
        // roots of unity underlying the power-of-two NTTs to exist
        assert!(max_log2_n <= ring.integer_ring().get_ring().abs_lowest_set_bit(&ring.integer_ring().sub_ref_fst(ring.modulus(), ring.integer_ring().one())).unwrap());
        Self {
            fft_algos: (0..=max_log2_n).map(|log2_n| CooleyTuckeyFFT::for_zn(ring.clone(), log2_n).unwrap()).collect(),
            ring,
            allocator,
            max_log2_n
        }
    }

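    /// Returns the largest supported output length, namely `2^max_log2_n`;
    /// requesting a longer convolution will panic.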
    pub fn max_supported_output_len(&self) -> usize {
        1 << self.max_log2_n
    }

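    /// Computes the convolution of two prepared operands of the same NTT length and
    /// adds the result to `out`. The buffer of `lhs` is consumed as scratch space.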
    #[instrument(skip_all)]
    fn compute_convolution_base(&self, mut lhs: PreparedConvolutionOperand<R, A>, rhs: &PreparedConvolutionOperand<R, A>, out: &mut [El<R>]) {
        let log2_n = ZZ.abs_log2_ceil(&(lhs.data.len() as i64)).unwrap();
        assert_eq!(lhs.data.len(), 1 << log2_n);
        assert_eq!(rhs.data.len(), 1 << log2_n);
        assert!(lhs.len + rhs.len <= 1 << log2_n);
        assert!(out.len() >= lhs.len + rhs.len);
        // pointwise multiplication in the NTT domain corresponds to cyclic convolution
        for i in 0..(1 << log2_n) {
            self.ring.mul_assign_ref(&mut lhs.data[i], &rhs.data[i]);
        }
        self.get_fft(log2_n).unordered_inv_fft(&mut lhs.data[..], &self.ring);
        // since lhs.len + rhs.len <= 2^log2_n, the cyclic convolution does not wrap
        // around, i.e. it agrees with the standard convolution of the zero-padded inputs
        for i in 0..(lhs.len + rhs.len) {
            self.ring.add_assign_ref(&mut out[i], &lhs.data[i]);
        }
    }

    fn get_fft<'a>(&'a self, log2_n: usize) -> &'a CooleyTuckeyFFT<R::Type, R::Type, Identity<R>> {
        &self.fft_algos[log2_n]
    }

    fn clone_prepared_operand(&self, operand: &PreparedConvolutionOperand<R, A>) -> PreparedConvolutionOperand<R, A> {
        let mut result = Vec::with_capacity_in(operand.data.len(), self.allocator.clone());
        result.extend(operand.data.iter().map(|x| self.ring.clone_el(x)));
        return PreparedConvolutionOperand {
            len: operand.len,
            data: result
        };
    }

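    /// Zero-pads `val` to length `2^log2_n` and applies the forward NTT, remembering
    /// the original length of `val`.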
    #[instrument(skip_all)]
    fn prepare_convolution_base<V: VectorView<El<R>>>(&self, val: V, log2_n: usize) -> PreparedConvolutionOperand<R, A> {
        let mut result = Vec::with_capacity_in(1 << log2_n, self.allocator.clone());
        result.extend(val.as_iter().map(|x| self.ring.clone_el(x)));
        result.resize_with(1 << log2_n, || self.ring.zero());
        let fft = self.get_fft(log2_n);
        fft.unordered_fft(&mut result[..], &self.ring);
        return PreparedConvolutionOperand {
            len: val.len(),
            data: result
        };
    }
}

impl<R, A> ConvolutionAlgorithm<R::Type> for NTTConv<R, A>
    where R: RingStore + Clone,
        R::Type: ZnRing,
        A: Allocator + Clone
{
    fn supports_ring<S: RingStore<Type = R::Type> + Copy>(&self, ring: S) -> bool {
        ring.get_ring() == self.ring.get_ring()
    }

    fn compute_convolution<S: RingStore<Type = R::Type> + Copy, V1: VectorView<<R::Type as RingBase>::Element>, V2: VectorView<<R::Type as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [<R::Type as RingBase>::Element], ring: S) {
        assert!(ring.get_ring() == self.ring.get_ring());
        let log2_n = ZZ.abs_log2_ceil(&((lhs.len() + rhs.len()) as i64)).unwrap();
        let lhs_prep = self.prepare_convolution_base(lhs, log2_n);
        let rhs_prep = self.prepare_convolution_base(rhs, log2_n);
        self.compute_convolution_base(lhs_prep, &rhs_prep, dst);
    }
}

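/// A convolution operand, zero-padded to a power-of-two length and stored in NTT
/// form together with its original length; preparing an operand once makes repeated
/// convolutions with it cheaper, since the forward transform can be skipped.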
pub struct PreparedConvolutionOperand<R, A = Global>
    where R: RingStore + Clone,
        R::Type: ZnRing,
        A: Allocator + Clone
{
    len: usize,
    data: Vec<El<R>, A>
}

const ZZ: StaticRing<i64> = StaticRing::<i64>::RING;

impl<R, A> PreparedConvolutionAlgorithm<R::Type> for NTTConv<R, A>
    where R: RingStore + Clone,
        R::Type: ZnRing,
        A: Allocator + Clone
{
    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, A>;

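    /// Note that the operand is prepared w.r.t. an NTT of twice the padded input
    /// length, so that a subsequent convolution with an operand of at most the same
    /// length can use the stored transform directly.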
    fn prepare_convolution_operand<S: RingStore<Type = R::Type> + Copy, V: VectorView<El<R>>>(&self, val: V, ring: S) -> Self::PreparedConvolutionOperand {
        assert!(ring.get_ring() == self.ring.get_ring());
        let log2_n_in = ZZ.abs_log2_ceil(&(val.len() as i64)).unwrap();
        let log2_n_out = log2_n_in + 1;
        return self.prepare_convolution_base(val, log2_n_out);
    }

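    /// Prepares `rhs` on the fly, reusing the NTT length of `lhs` whenever it is
    /// large enough for the result; otherwise the length mismatch is resolved by
    /// [`NTTConv::compute_convolution_prepared`].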
    fn compute_convolution_lhs_prepared<S: RingStore<Type = R::Type> + Copy, V: VectorView<El<R>>>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: V, dst: &mut [El<R>], ring: S) {
        assert!(ring.get_ring() == self.ring.get_ring());
        let log2_n = ZZ.abs_log2_ceil(&((lhs.len + rhs.len()) as i64)).unwrap();
        if lhs.data.len() >= (1 << log2_n) {
            // lhs is already prepared w.r.t. a sufficient NTT length, so prepare
            // rhs at exactly that length and multiply directly
            let log2_n = ZZ.abs_log2_ceil(&(lhs.data.len() as i64)).unwrap();
            assert!(lhs.data.len() == 1 << log2_n);
            self.compute_convolution_base(self.prepare_convolution_base(rhs, log2_n), lhs, dst);
        } else {
            self.compute_convolution_prepared(lhs, &self.prepare_convolution_base(rhs, log2_n), dst, ring)
        }
    }

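    /// If the operands were prepared w.r.t. different NTT lengths, the shorter one
    /// is transformed back, zero-padded and re-transformed at the larger length.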
    fn compute_convolution_prepared<S: RingStore<Type = R::Type> + Copy>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: &Self::PreparedConvolutionOperand, dst: &mut [El<R>], ring: S) {
        assert!(ring.get_ring() == self.ring.get_ring());
        let log2_lhs = ZZ.abs_log2_ceil(&(lhs.data.len() as i64)).unwrap();
        assert_eq!(1 << log2_lhs, lhs.data.len());
        let log2_rhs = ZZ.abs_log2_ceil(&(rhs.data.len() as i64)).unwrap();
        assert_eq!(1 << log2_rhs, rhs.data.len());
        match log2_lhs.cmp(&log2_rhs) {
            std::cmp::Ordering::Equal => self.compute_convolution_base(self.clone_prepared_operand(lhs), rhs, dst),
            std::cmp::Ordering::Greater => self.compute_convolution_prepared(rhs, lhs, dst, ring),
            std::cmp::Ordering::Less => {
                // bring lhs to the NTT length of rhs: inverse transform, zero-pad, re-transform;
                // allocate the full final length up front to avoid a reallocation
                let mut lhs_new = Vec::with_capacity_in(1 << log2_rhs, self.allocator.clone());
                lhs_new.extend(lhs.data.iter().map(|x| self.ring.clone_el(x)));
                self.get_fft(log2_lhs).unordered_inv_fft(&mut lhs_new[..], ring);
                lhs_new.resize_with(1 << log2_rhs, || ring.zero());
                self.get_fft(log2_rhs).unordered_fft(&mut lhs_new[..], ring);
                self.compute_convolution_base(PreparedConvolutionOperand { data: lhs_new, len: lhs.len }, rhs, dst);
            }
        }
    }
}

#[cfg(test)]
use feanor_math::assert_el_eq;
#[cfg(test)]
use feanor_math::rings::zn::zn_64::*;
#[cfg(test)]
use feanor_math::homomorphism::Homomorphism;

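// Checks NTTConv against the generic STANDARD_CONVOLUTION from feanor_math for
// various combinations of operand lengths.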
#[test]
fn test_convolution() {
    let ring = Zn::new(65537);
    let convolutor = NTTConv::new_with(ring, 16, Global);

    let check = |lhs: &[ZnEl], rhs: &[ZnEl], add: &[ZnEl]| {
        // output buffers are initialized with `add` to test that convolution results
        // are added to the existing content rather than overwriting it
        let new_out = || (0..(lhs.len() + rhs.len())).map(|i| if i < add.len() { add[i] } else { ring.zero() }).collect::<Vec<_>>();

        let mut expected = new_out();
        STANDARD_CONVOLUTION.compute_convolution(lhs, rhs, &mut expected, &ring);

        let assert_matches_expected = |actual: &[ZnEl]| {
            for i in 0..(lhs.len() + rhs.len()) {
                assert_el_eq!(&ring, &expected[i], &actual[i]);
            }
        };

        let mut actual1 = new_out();
        convolutor.compute_convolution(lhs, rhs, &mut actual1, &ring);
        assert_matches_expected(&actual1);

        let lhs_prepared = convolutor.prepare_convolution_operand(lhs, &ring);
        let rhs_prepared = convolutor.prepare_convolution_operand(rhs, &ring);

        let mut actual2 = new_out();
        convolutor.compute_convolution_lhs_prepared(&lhs_prepared, rhs, &mut actual2, &ring);
        assert_matches_expected(&actual2);

        let mut actual3 = new_out();
        convolutor.compute_convolution_rhs_prepared(lhs, &rhs_prepared, &mut actual3, &ring);
        assert_matches_expected(&actual3);

        let mut actual4 = new_out();
        convolutor.compute_convolution_prepared(&lhs_prepared, &rhs_prepared, &mut actual4, &ring);
        assert_matches_expected(&actual4);
    };

    for lhs_len in [1, 2, 3, 4, 7, 8, 9] {
        for rhs_len in [1, 5, 8, 16, 17] {
            let lhs = (0..lhs_len).map(|i| ring.int_hom().map(i)).collect::<Vec<_>>();
            let rhs = (0..rhs_len).map(|i| ring.int_hom().map(16 * i)).collect::<Vec<_>>();
            let add = (0..(lhs_len + rhs_len)).map(|i| ring.int_hom().map(32768 * i)).collect::<Vec<_>>();
            check(&lhs, &rhs, &add);
        }
    }
}