use std::alloc::{Allocator, Global};
use std::cmp::min;

use crate::{algorithms::fft::cooley_tuckey::CooleyTuckeyFFT, lazy::LazyVec};
use crate::homomorphism::*;
use crate::primitive_int::StaticRing;
use crate::ring::*;
use crate::algorithms::fft::*;
use crate::rings::zn::*;
use crate::integer::*;
use crate::seq::VectorView;

use super::{ConvolutionAlgorithm, PreparedConvolutionAlgorithm, PreparedConvolutionOperation};

#[stability::unstable(feature = "enable")]
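///
/// Computes the convolution of sequences over `Z/nZ` using power-of-two length
/// NTTs, i.e. Cooley-Tukey FFTs performed over the ring itself. The FFT tables
/// for each length are created on demand and cached. This requires the ring to
/// contain roots of unity of sufficiently large power-of-two order.
///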
pub struct NTTConvolution<R, A = Global>
    where R: RingStore + Clone,
        R::Type: ZnRing,
        A: Allocator + Clone
{
    ring: R,
    fft_algos: LazyVec<CooleyTuckeyFFT<R::Type, R::Type, Identity<R>>>,
    allocator: A
}

#[stability::unstable(feature = "enable")]
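///
/// The NTT of an operand, zero-padded to a power-of-two length, together with
/// the original length of the operand. Preparing an operand once allows reusing
/// its forward transform across multiple convolutions.
///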
pub struct PreparedConvolutionOperand<R, A = Global>
    where R: RingStore + Clone,
        R::Type: ZnRing,
        A: Allocator + Clone
{
    len: usize,
    data: Vec<El<R>, A>
}

impl<R, A> NTTConvolution<R, A>
    where R: RingStore + Clone,
        R::Type: ZnRing,
        A: Allocator + Clone
{
    #[stability::unstable(feature = "enable")]
    pub fn new_with(ring: R, allocator: A) -> Self {
        Self {
            fft_algos: LazyVec::new(),
            ring: ring,
            allocator: allocator
        }
    }

    #[stability::unstable(feature = "enable")]
    pub fn ring(&self) -> &R {
        &self.ring
    }

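    /// Computes `dst[i] += lhs[i] * rhs[i]` for all `i`, i.e. the pointwise
    /// product of two NTT-transformed operands, accumulated into `dst`.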
    fn add_assign_elementwise_product(lhs: &[El<R>], rhs: &[El<R>], dst: &mut [El<R>], ring: RingRef<R::Type>) {
        assert_eq!(lhs.len(), rhs.len());
        assert_eq!(lhs.len(), dst.len());
        for i in 0..lhs.len() {
            ring.add_assign(&mut dst[i], ring.mul_ref(&lhs[i], &rhs[i]));
        }
    }

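    /// Computes the convolution of two operands that are already NTT-transformed
    /// to the same power-of-two length, and adds the result to `out`. Consumes
    /// `lhs`, since the pointwise products are accumulated in its buffer.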
    fn compute_convolution_impl(&self, mut lhs: PreparedConvolutionOperand<R, A>, rhs: &PreparedConvolutionOperand<R, A>, out: &mut [El<R>]) {
        let log2_n = ZZ.abs_log2_ceil(&(lhs.data.len() as i64)).unwrap();
        assert_eq!(lhs.data.len(), 1 << log2_n);
        assert_eq!(rhs.data.len(), 1 << log2_n);
        assert!(lhs.len + rhs.len <= 1 << log2_n);
        assert!(out.len() >= lhs.len + rhs.len);
        for i in 0..(1 << log2_n) {
            self.ring.mul_assign_ref(&mut lhs.data[i], &rhs.data[i]);
        }
        self.get_fft(log2_n).unordered_inv_fft(&mut lhs.data[..], self.ring());
        for i in 0..min(out.len(), 1 << log2_n) {
            self.ring.add_assign_ref(&mut out[i], &lhs.data[i]);
        }
    }

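    /// Takes an NTT-transformed operand of length `2^log2_in_len` and returns its
    /// NTT of the larger length `2^log2_n`, by undoing the transform, zero-padding
    /// and transforming again.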
    fn un_and_redo_fft(&self, input: &[El<R>], log2_n: usize) -> Vec<El<R>, A> {
        let log2_in_len = ZZ.abs_log2_ceil(&(input.len() as i64)).unwrap();
        assert_eq!(input.len(), 1 << log2_in_len);
        assert!(log2_in_len < log2_n);

        let mut tmp = Vec::with_capacity_in(input.len(), self.allocator.clone());
        tmp.extend(input.iter().map(|x| self.ring.clone_el(x)));
        self.get_fft(log2_in_len).unordered_inv_fft(&mut tmp[..], self.ring());

        tmp.resize_with(1 << log2_n, || self.ring.zero());
        self.get_fft(log2_n).unordered_fft(&mut tmp[..], self.ring());
        return tmp;
    }

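    /// Returns the Cooley-Tukey FFT object for length `2^log2_n`, creating and
    /// caching it on first use. Panics if the ring does not have a primitive root
    /// of unity of order `2^log2_n`.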
    fn get_fft<'a>(&'a self, log2_n: usize) -> &'a CooleyTuckeyFFT<R::Type, R::Type, Identity<R>> {
        self.fft_algos.get_or_init(log2_n, || CooleyTuckeyFFT::for_zn(self.ring().clone(), log2_n).unwrap())
    }

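    /// Creates a deep copy of the given prepared operand, using this convolution's
    /// allocator for the new buffer.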
    fn clone_prepared_operand(&self, operand: &PreparedConvolutionOperand<R, A>) -> PreparedConvolutionOperand<R, A> {
        let mut result = Vec::with_capacity_in(operand.data.len(), self.allocator.clone());
        result.extend(operand.data.iter().map(|x| self.ring.clone_el(x)));
        return PreparedConvolutionOperand {
            len: operand.len,
            data: result
        };
    }

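    /// Zero-pads the given values to length `2^log2_n` and computes their NTT,
    /// storing the original length alongside the transformed data.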
    fn prepare_convolution_impl<V: VectorView<El<R>>>(&self, val: V, log2_n: usize) -> PreparedConvolutionOperand<R, A> {
        let mut result = Vec::with_capacity_in(1 << log2_n, self.allocator.clone());
        result.extend(val.as_iter().map(|x| self.ring.clone_el(x)));
        result.resize_with(1 << log2_n, || self.ring.zero());
        let fft = self.get_fft(log2_n);
        fft.unordered_fft(&mut result[..], self.ring());
        return PreparedConvolutionOperand {
            len: val.len(),
            data: result
        };
    }
}

impl<R, A> ConvolutionAlgorithm<R::Type> for NTTConvolution<R, A>
    where R: RingStore + Clone,
        R::Type: ZnRing,
        A: Allocator + Clone
{
    fn supports_ring<S: RingStore<Type = R::Type> + Copy>(&self, ring: S) -> bool {
        ring.get_ring() == self.ring.get_ring()
    }

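    // Transforms both operands to the smallest power-of-two length that can hold
    // the result, multiplies them pointwise and adds the result to `dst`.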
    fn compute_convolution<S: RingStore<Type = R::Type> + Copy, V1: VectorView<<R::Type as RingBase>::Element>, V2: VectorView<<R::Type as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [El<R>], ring: S) {
        assert!(self.supports_ring(&ring));
        let log2_n = ZZ.abs_log2_ceil(&((lhs.len() + rhs.len()) as i64)).unwrap();
        let lhs_prep = self.prepare_convolution_impl(lhs, log2_n);
        let rhs_prep = self.prepare_convolution_impl(rhs, log2_n);
        self.compute_convolution_impl(lhs_prep, &rhs_prep, dst);
    }

    fn specialize_prepared_convolution<F>(function: F) -> Result<F::Output, F>
        where F: PreparedConvolutionOperation<Self, R::Type>
    {
        Ok(function.execute())
    }
}

const ZZ: StaticRing<i64> = StaticRing::<i64>::RING;

impl<R, A> PreparedConvolutionAlgorithm<R::Type> for NTTConvolution<R, A>
    where R: RingStore + Clone,
        R::Type: ZnRing,
        A: Allocator + Clone
{
    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, A>;

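    // The operand is prepared at twice the smallest power-of-two length that can
    // hold it; this is sufficient for convolving it with any operand of at most
    // that (padded) length without redoing the forward FFT.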
    fn prepare_convolution_operand<S: RingStore<Type = R::Type> + Copy, V: VectorView<El<R>>>(&self, val: V, ring: S) -> Self::PreparedConvolutionOperand {
        assert!(ring.get_ring() == self.ring.get_ring());
        let log2_n_in = ZZ.abs_log2_ceil(&(val.len() as i64)).unwrap();
        let log2_n_out = log2_n_in + 1;
        return self.prepare_convolution_impl(val, log2_n_out);
    }

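    // Prepares `rhs` directly at a transform length compatible with `lhs` (or
    // larger, if required by the output length) and delegates to
    // `compute_convolution_prepared()`.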
    fn compute_convolution_lhs_prepared<S: RingStore<Type = R::Type> + Copy, V: VectorView<El<R>>>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: V, dst: &mut [El<R>], ring: S) {
        assert!(ring.is_commutative());
        assert!(ring.get_ring() == self.ring.get_ring());
        let log2_lhs = ZZ.abs_log2_ceil(&(lhs.data.len() as i64)).unwrap();
        assert_eq!(lhs.data.len(), 1 << log2_lhs);
        let log2_n = ZZ.abs_log2_ceil(&((lhs.len + rhs.len()) as i64)).unwrap().max(log2_lhs);
        assert!(log2_lhs <= log2_n);
        self.compute_convolution_prepared(lhs, &self.prepare_convolution_impl(rhs, log2_n), dst, ring);
    }

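    // If the two transforms have different lengths, the shorter one is re-computed
    // at the length of the longer one via `un_and_redo_fft()` before the pointwise
    // multiplication.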
    fn compute_convolution_prepared<S: RingStore<Type = R::Type> + Copy>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: &Self::PreparedConvolutionOperand, dst: &mut [El<R>], ring: S) {
        assert!(ring.is_commutative());
        assert!(ring.get_ring() == self.ring.get_ring());
        let log2_lhs = ZZ.abs_log2_ceil(&(lhs.data.len() as i64)).unwrap();
        assert_eq!(1 << log2_lhs, lhs.data.len());
        let log2_rhs = ZZ.abs_log2_ceil(&(rhs.data.len() as i64)).unwrap();
        assert_eq!(1 << log2_rhs, rhs.data.len());
        match log2_lhs.cmp(&log2_rhs) {
            std::cmp::Ordering::Equal => self.compute_convolution_impl(self.clone_prepared_operand(lhs), rhs, dst),
            std::cmp::Ordering::Greater => self.compute_convolution_impl(PreparedConvolutionOperand { data: self.un_and_redo_fft(&rhs.data, log2_lhs), len: rhs.len }, lhs, dst),
            std::cmp::Ordering::Less => self.compute_convolution_impl(PreparedConvolutionOperand { data: self.un_and_redo_fft(&lhs.data, log2_rhs), len: lhs.len }, rhs, dst)
        }
    }

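    // Accumulates the sum of all pairwise products in the NTT domain and performs
    // only a single inverse transform at the end; whenever a pair requires a longer
    // transform, the accumulator (and, if necessary, the operands) are re-computed
    // at the larger length.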
    fn compute_convolution_inner_product_prepared<'a, S, I>(&self, values: I, dst: &mut [El<R>], ring: S)
        where S: RingStore<Type = R::Type> + Copy,
            I: Iterator<Item = (&'a Self::PreparedConvolutionOperand, &'a Self::PreparedConvolutionOperand)>,
            Self: 'a,
            R: 'a,
            PreparedConvolutionOperand<R, A>: 'a
    {
        assert!(ring.get_ring() == self.ring.get_ring());
        let mut values_it = values.peekable();
        if values_it.peek().is_none() {
            return;
        }
        let expected_len = values_it.peek().unwrap().0.data.len().max(values_it.peek().unwrap().1.data.len());
        let mut current_log2_len = ZZ.abs_log2_ceil(&(expected_len as i64)).unwrap();
        assert_eq!(expected_len, 1 << current_log2_len);
        let mut tmp = Vec::with_capacity_in(1 << current_log2_len, self.allocator.clone());
        tmp.resize_with(1 << current_log2_len, || ring.zero());
        for (lhs, rhs) in values_it {
            assert!(dst.len() >= lhs.len + rhs.len);
            let lhs_log2_len = ZZ.abs_log2_ceil(&(lhs.data.len() as i64)).unwrap();
            let rhs_log2_len = ZZ.abs_log2_ceil(&(rhs.data.len() as i64)).unwrap();
            let new_log2_len = current_log2_len.max(lhs_log2_len).max(rhs_log2_len);

            if current_log2_len < new_log2_len {
                tmp = self.un_and_redo_fft(&tmp, new_log2_len);
                current_log2_len = new_log2_len;
            }
            match (lhs_log2_len < current_log2_len, rhs_log2_len < current_log2_len) {
                (false, false) => Self::add_assign_elementwise_product(&lhs.data, &rhs.data, &mut tmp, RingRef::new(ring.get_ring())),
                (true, false) => Self::add_assign_elementwise_product(&self.un_and_redo_fft(&lhs.data, new_log2_len), &rhs.data, &mut tmp, RingRef::new(ring.get_ring())),
                (false, true) => Self::add_assign_elementwise_product(&lhs.data, &self.un_and_redo_fft(&rhs.data, new_log2_len), &mut tmp, RingRef::new(ring.get_ring())),
                (true, true) => Self::add_assign_elementwise_product(&self.un_and_redo_fft(&lhs.data, new_log2_len), &self.un_and_redo_fft(&rhs.data, new_log2_len), &mut tmp, RingRef::new(ring.get_ring())),
            }
        }
        self.get_fft(current_log2_len).unordered_inv_fft(&mut tmp[..], self.ring());
        for i in 0..min(dst.len(), 1 << current_log2_len) {
            self.ring.add_assign_ref(&mut dst[i], &tmp[i]);
        }
    }
}

#[test]
fn test_convolution() {
    let ring = zn_64::Zn::new(65537);
    let convolution = NTTConvolution::new_with(ring, Global);
    super::generic_tests::test_convolution(&convolution, &ring, ring.one());
    super::generic_tests::test_prepared_convolution(&convolution, &ring, ring.one());
}
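
// A further sanity check, sketched under the assumption that fixed-size arrays of
// ring elements implement `VectorView`: it convolves two prepared operands whose
// transforms have different power-of-two lengths, so that the `un_and_redo_fft()`
// path in `compute_convolution_prepared()` is exercised, and compares the result
// to a naive quadratic convolution.
#[test]
fn test_convolution_prepared_mismatched_lengths() {
    let ring = zn_64::Zn::new(65537);
    let convolution = NTTConvolution::new_with(ring, Global);
    let hom = ring.int_hom();
    // lengths 7 and 2 are padded to transform lengths 16 and 4, respectively
    let lhs = [1, 2, 3, 4, 5, 6, 7].map(|x| hom.map(x));
    let rhs = [8, 9].map(|x| hom.map(x));
    let lhs_prep = convolution.prepare_convolution_operand(lhs, ring);
    let rhs_prep = convolution.prepare_convolution_operand(rhs, ring);

    let mut actual = (0..(lhs.len() + rhs.len())).map(|_| ring.zero()).collect::<Vec<_>>();
    convolution.compute_convolution_prepared(&lhs_prep, &rhs_prep, &mut actual, ring);

    // naive O(n * m) convolution as reference
    let mut expected = (0..(lhs.len() + rhs.len())).map(|_| ring.zero()).collect::<Vec<_>>();
    for i in 0..lhs.len() {
        for j in 0..rhs.len() {
            ring.add_assign(&mut expected[i + j], ring.mul_ref(&lhs[i], &rhs[j]));
        }
    }
    for k in 0..expected.len() {
        assert!(ring.eq_el(&expected[k], &actual[k]));
    }
}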