feanor_math/algorithms/convolution/
ntt.rs1use std::alloc::{Allocator, Global};
2
3use crate::cow::*;
4use crate::{algorithms::fft::cooley_tuckey::CooleyTuckeyFFT, lazy::LazyVec};
5use crate::homomorphism::*;
6use crate::primitive_int::StaticRing;
7use crate::ring::*;
8use crate::rings::zn::*;
9use crate::integer::*;
10use crate::seq::VectorView;
11
12use super::ConvolutionAlgorithm;
13
14#[stability::unstable(feature = "enable")]
20pub struct NTTConvolution<R_main, R_twiddle, H, A = Global>
21 where R_main: ?Sized + ZnRing,
22 R_twiddle: ?Sized + ZnRing,
23 H: Homomorphism<R_twiddle, R_main> + Clone,
24 A: Allocator + Clone
25{
26 hom: H,
27 fft_algos: LazyVec<CooleyTuckeyFFT<R_main, R_twiddle, H>>,
28 allocator: A
29}
30
31#[stability::unstable(feature = "enable")]
32pub struct PreparedConvolutionOperand<R, A = Global>
33 where R: ?Sized + ZnRing,
34 A: Allocator + Clone
35{
36 significant_entries: usize,
37 ntt_data: Vec<R::Element, A>
38}
39
40impl<R> NTTConvolution<R::Type, R::Type, Identity<R>>
41 where R: RingStore + Clone,
42 R::Type: ZnRing
43{
44 #[stability::unstable(feature = "enable")]
45 pub fn new(ring: R) -> Self {
46 Self::new_with(ring.into_identity(), Global)
47 }
48}
49
50impl<R_main, R_twiddle, H, A> NTTConvolution<R_main, R_twiddle, H, A>
51 where R_main: ?Sized + ZnRing,
52 R_twiddle: ?Sized + ZnRing,
53 H: Homomorphism<R_twiddle, R_main> + Clone,
54 A: Allocator + Clone
55{
56 #[stability::unstable(feature = "enable")]
57 pub fn new_with(hom: H, allocator: A) -> Self {
58 Self {
59 fft_algos: LazyVec::new(),
60 hom: hom,
61 allocator: allocator
62 }
63 }
64
65 #[stability::unstable(feature = "enable")]
66 pub fn ring(&self) -> RingRef<R_main> {
67 RingRef::new(self.hom.codomain().get_ring())
68 }
69
70 fn get_ntt_table<'a>(&'a self, log2_n: usize) -> &'a CooleyTuckeyFFT<R_main, R_twiddle, H> {
71 self.fft_algos.get_or_init(log2_n, || CooleyTuckeyFFT::for_zn_with_hom(self.hom.clone(), log2_n).expect("NTTConvolution was instantiated with parameters that don't support this length"))
72 }
73
74 fn get_ntt_data<'a, V>(
75 &self,
76 data: V,
77 data_prep: Option<&'a PreparedConvolutionOperand<R_main, A>>,
78 significant_entries: usize,
79 ) -> MyCow<'a, Vec<R_main::Element, A>>
80 where V: VectorView<R_main::Element>
81 {
82 assert!(data.len() <= significant_entries);
83 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&significant_entries.try_into().unwrap()).unwrap();
84
85 let compute_result = || {
86 let mut result = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
87 result.extend(data.as_iter().map(|x| self.ring().clone_el(x)));
88 result.resize_with(1 << log2_len, || self.ring().zero());
89 self.get_ntt_table(log2_len).unordered_truncated_fft(&mut result, significant_entries);
90 return result;
91 };
92
93 return if let Some(data_prep) = data_prep {
94 assert!(data_prep.significant_entries >= significant_entries);
95 MyCow::Borrowed(&data_prep.ntt_data)
96 } else {
97 MyCow::Owned(compute_result())
98 }
99 }
100
101 fn prepare_convolution_impl<V>(
102 &self,
103 data: V,
104 len_hint: Option<usize>
105 ) -> PreparedConvolutionOperand<R_main, A>
106 where V: VectorView<R_main::Element>
107 {
108 let significant_entries = if let Some(out_len) = len_hint {
109 assert!(data.len() <= out_len);
110 out_len
111 } else {
112 2 * data.len()
113 };
114 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&significant_entries.try_into().unwrap()).unwrap();
115
116 let mut result = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
117 result.extend(data.as_iter().map(|x| self.ring().clone_el(x)));
118 result.resize_with(1 << log2_len, || self.ring().zero());
119 self.get_ntt_table(log2_len).unordered_truncated_fft(&mut result, significant_entries);
120
121 return PreparedConvolutionOperand {
122 ntt_data: result,
123 significant_entries: significant_entries
124 };
125 }
126
127 fn compute_convolution_impl<V1, V2>(
128 &self,
129 lhs: V1,
130 mut lhs_prep: Option<&PreparedConvolutionOperand<R_main, A>>,
131 rhs: V2,
132 mut rhs_prep: Option<&PreparedConvolutionOperand<R_main, A>>,
133 dst: &mut [R_main::Element]
134 )
135 where V1: VectorView<R_main::Element>,
136 V2: VectorView<R_main::Element>
137 {
138 if lhs.len() == 0 || rhs.len() == 0 {
139 return;
140 }
141 let len = lhs.len() + rhs.len() - 1;
142 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
143
144 if lhs_prep.is_some() && (lhs_prep.unwrap().significant_entries < len || lhs_prep.unwrap().ntt_data.len() != 1 << log2_len) {
145 lhs_prep = None;
146 }
147 if rhs_prep.is_some() && (rhs_prep.unwrap().significant_entries < len || rhs_prep.unwrap().ntt_data.len() != 1 << log2_len) {
148 rhs_prep = None;
149 }
150
151 let mut lhs_ntt = self.get_ntt_data(lhs, lhs_prep, len);
152 let mut rhs_ntt = self.get_ntt_data(rhs, rhs_prep, len);
153 if rhs_ntt.is_owned() {
154 std::mem::swap(&mut lhs_ntt, &mut rhs_ntt);
155 }
156 let mut lhs_ntt = lhs_ntt.to_mut_with(|data| {
157 let mut copied_data = Vec::with_capacity_in(data.len(), self.allocator.clone());
158 copied_data.extend(data.iter().map(|x| self.ring().clone_el(x)));
159 copied_data
160 });
161
162 for i in 0..len {
163 self.ring().mul_assign_ref(&mut lhs_ntt[i], &rhs_ntt[i]);
164 }
165
166 self.get_ntt_table(log2_len).unordered_truncated_fft_inv(&mut lhs_ntt, len);
167
168 for (i, x) in lhs_ntt.drain(..).enumerate().take(len) {
169 self.ring().add_assign(&mut dst[i], x);
170 }
171 }
172}
173
174impl<R_main, R_twiddle, H, A> ConvolutionAlgorithm<R_main> for NTTConvolution<R_main, R_twiddle, H, A>
175 where R_main: ?Sized + ZnRing,
176 R_twiddle: ?Sized + ZnRing,
177 H: Homomorphism<R_twiddle, R_main> + Clone,
178 A: Allocator + Clone
179{
180 type PreparedConvolutionOperand = PreparedConvolutionOperand<R_main, A>;
181
182 fn supports_ring<S: RingStore<Type = R_main> + Copy>(&self, ring: S) -> bool {
183 ring.get_ring() == self.ring().get_ring()
184 }
185
186 fn compute_convolution<S: RingStore<Type = R_main> + Copy, V1: VectorView<<R_main as RingBase>::Element>, V2: VectorView<<R_main as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R_main::Element], ring: S) {
187 assert!(self.supports_ring(ring));
188 self.compute_convolution_impl(
189 lhs,
190 None,
191 rhs,
192 None,
193 dst
194 )
195 }
196
197 fn prepare_convolution_operand<S, V>(&self, val: V, length_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
198 where S: RingStore<Type = R_main> + Copy, V: VectorView<R_main::Element>
199 {
200 assert!(self.supports_ring(ring));
201 self.prepare_convolution_impl(
202 val,
203 length_hint
204 )
205 }
206
207 fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R_main::Element], ring: S)
208 where S: RingStore<Type = R_main> + Copy, V1: VectorView<R_main::Element>, V2: VectorView<R_main::Element>
209 {
210 assert!(self.supports_ring(ring));
211 self.compute_convolution_impl(
212 lhs,
213 lhs_prep,
214 rhs,
215 rhs_prep,
216 dst
217 )
218 }
219}
220
221#[test]
222fn test_convolution() {
223 let ring = zn_64::Zn::new(65537);
224 let convolution = NTTConvolution::new_with(ring.identity(), Global);
225 super::generic_tests::test_convolution(&convolution, &ring, ring.one());
226}