1extern crate alloc;
2#[cfg(feature = "alloc")]
3use super::{
4 cfft::{cfft, icfft, order_cfft_result_naive, order_icfft_input_naive},
5 cosets::Coset,
6 twiddles::{get_twiddles, TwiddlesConfig},
7};
8use crate::{
9 fft::cpu::bit_reversing::in_place_bit_reverse_permute,
10 field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field},
11};
12#[cfg(feature = "alloc")]
13use alloc::vec::Vec;
14
#[cfg(feature = "alloc")]
/// Evaluates a polynomial, given by its coefficients in the circle-FFT basis, on the
/// standard coset whose size equals the number of coefficients.
///
/// For four coefficients the polynomial is `c0 + c1·y + c2·x + c3·x·y`. Larger sizes
/// extend the basis with iterates of `x ↦ 2x² − 1`.
///
/// The result holds one evaluation per point of `Coset::new_standard(log2(coeff.len()))`.
/// The evaluations are in the same order as `Coset::get_coset_points` returns those points.
///
/// An empty input yields an empty output.
///
/// # Panics
///
/// Panics if `coeff` is non-empty and its length is not a power of two. Such a length
/// would otherwise be silently rounded down to the largest power of two dividing it,
/// which selects the wrong domain.
pub fn evaluate_cfft(
    mut coeff: Vec<FieldElement<Mersenne31Field>>,
) -> Vec<FieldElement<Mersenne31Field>> {
    // Mirror `interpolate_cfft`: with no coefficients there is nothing to evaluate.
    // Without this guard, `trailing_zeros` of 0 is `usize::BITS`, which is not a
    // meaningful domain size.
    if coeff.is_empty() {
        return Vec::new();
    }
    assert!(
        coeff.len().is_power_of_two(),
        "evaluate_cfft: the number of coefficients must be a power of two"
    );

    let domain_log_2_size: u32 = coeff.len().trailing_zeros();
    let twiddles = get_twiddles(
        Coset::new_standard(domain_log_2_size),
        TwiddlesConfig::Evaluation,
    );

    // The butterfly pass expects its input in bit-reversed order.
    in_place_bit_reverse_permute::<FieldElement<Mersenne31Field>>(&mut coeff);
    cfft(&mut coeff, twiddles);

    // Reorder the raw butterfly output into coset-point order.
    order_cfft_result_naive(&coeff)
}
37
#[cfg(feature = "alloc")]
/// Recovers circle-FFT-basis coefficients from evaluations on the standard coset.
///
/// This is the inverse of `evaluate_cfft`. `eval` must list the values at the points of
/// `Coset::new_standard(log2(eval.len()))`, in the order `evaluate_cfft` produces them.
/// The returned coefficients use the same basis that `evaluate_cfft` consumes.
///
/// An empty input yields an empty output.
///
/// # Panics
///
/// Panics if `eval` is non-empty and its length is not a power of two. Such a length
/// would otherwise be silently rounded down to the largest power of two dividing it.
pub fn interpolate_cfft(
    mut eval: Vec<FieldElement<Mersenne31Field>>,
) -> Vec<FieldElement<Mersenne31Field>> {
    if eval.is_empty() {
        return Vec::new();
    }

    // Capture the size before handing `eval` to the `&mut` reordering helper.
    let n = eval.len();
    assert!(
        n.is_power_of_two(),
        "interpolate_cfft: the number of evaluations must be a power of two"
    );

    let domain_log_2_size: u32 = n.trailing_zeros();
    let twiddles = get_twiddles(
        Coset::new_standard(domain_log_2_size),
        TwiddlesConfig::Interpolation,
    );

    // Rearrange the coset-ordered values into the layout the inverse butterflies expect.
    let mut eval_ordered = order_icfft_input_naive(&mut eval);
    icfft(&mut eval_ordered, twiddles);

    // Undo the bit-reversal so coefficients come out in natural order.
    in_place_bit_reverse_permute::<FieldElement<Mersenne31Field>>(&mut eval_ordered);

    // The unnormalised inverse transform is scaled by `n`, so divide it back out.
    // `n` is a nonzero power of two, and the field modulus (2^31 - 1 for Mersenne-31)
    // is odd, so `n` can never be zero in the field.
    let factor = FieldElement::<Mersenne31Field>::from(n as u64)
        .inv()
        .expect("a power of two is invertible modulo the odd Mersenne-31 prime");

    // Scale in place instead of collecting into a second vector.
    eval_ordered.iter_mut().for_each(|coef| *coef = *coef * factor);
    eval_ordered
}
74
#[cfg(test)]
mod tests {
    use super::*;
    use crate::circle::cosets::Coset;
    use alloc::vec;

    type FE = FieldElement<Mersenne31Field>;

    /// Returns the field elements `1, 2, ..., n`, in that order.
    fn one_to(n: u64) -> Vec<FE> {
        (1..=n).map(FE::from).collect()
    }

    /// Directly evaluates, at `(x, y)`, a polynomial whose `coef` are written in the
    /// circle-FFT basis.
    ///
    /// Let `π(t) = 2t² − 1` and `p_j = π^j(x)`, so that `p_0 = x` and `p_1 = π(x)`.
    /// For a pair index `k`, let `b_k` be the product of `p_j` over every set bit `j`
    /// of `k`. Then `coef[2k]` multiplies `b_k` and `coef[2k + 1]` multiplies `y · b_k`.
    /// With four coefficients this reduces to `c0 + c1·y + c2·x + c3·x·y`.
    fn evaluate_in_cfft_basis(coef: &[FE], x: FE, y: FE) -> FE {
        let pairs = coef.len() / 2;

        // Collect p_0, p_1, ...: one iterate for each bit a pair index can have.
        let mut iterates: Vec<FE> = Vec::new();
        let mut current = x;
        while (1usize << iterates.len()) < pairs {
            iterates.push(current);
            current = current.square().double() - FE::one();
        }

        (0..pairs).fold(FE::zero(), |acc, k| {
            let b_k = iterates
                .iter()
                .enumerate()
                .filter(|&(bit, _)| (k >> bit) & 1 == 1)
                .fold(FE::one(), |prod, (_, p)| prod * *p);
            acc + coef[2 * k] * b_k + coef[2 * k + 1] * y * b_k
        })
    }

    /// Checks `evaluate_cfft` against direct evaluation.
    ///
    /// The input is `1, 2, ..., 2^log_2_size`. Each point of the standard coset is
    /// visited in coset order.
    fn assert_cfft_matches_direct_evaluation(log_2_size: u32) {
        let coeff = one_to(1 << log_2_size);
        let coset = Coset::new_standard(log_2_size);
        let expected: Vec<FE> = Coset::get_coset_points(&coset)
            .into_iter()
            .map(|point| evaluate_in_cfft_basis(&coeff, point.x, point.y))
            .collect();

        assert_eq!(evaluate_cfft(coeff), expected);
    }

    /// Checks that interpolating the evaluations of `coeff` gives back `coeff`.
    fn assert_round_trip(coeff: Vec<FE>) {
        let evals = evaluate_cfft(coeff.clone());
        assert_eq!(interpolate_cfft(evals), coeff);
    }

    #[test]
    fn cfft_evaluation_4_points() {
        assert_cfft_matches_direct_evaluation(2);
    }

    #[test]
    fn cfft_evaluation_8_points() {
        assert_cfft_matches_direct_evaluation(3);
    }

    #[test]
    fn cfft_evaluation_16_points() {
        assert_cfft_matches_direct_evaluation(4);
    }

    #[test]
    fn evaluate_and_interpolate_8_points_is_identity() {
        assert_round_trip(one_to(8));
    }

    #[test]
    fn evaluate_and_interpolate_8_other_points() {
        assert_round_trip(vec![
            FE::from(2147483650),
            FE::from(147483647),
            FE::from(2147483700),
            FE::from(2147483647),
            FE::from(3147483647),
            FE::from(4147483647),
            FE::from(2147483640),
            FE::from(5147483647),
        ]);
    }

    #[test]
    fn evaluate_and_interpolate_32_points() {
        assert_round_trip(one_to(32));
    }

    #[test]
    fn evaluate_and_interpolate_2_pow_20_other_points() {
        assert_round_trip((0..(1_u32 << 20)).map(|i| FE::from(&i)).collect());
    }
}