1pub use crate::domain::utils::Elements;
10use crate::domain::{
11 DomainCoeff, EvaluationDomain, MixedRadixEvaluationDomain, Radix2EvaluationDomain,
12};
13use ark_ff::{FftField, Field};
14use ark_serialize::{
15 CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate,
16};
17use ark_std::{
18 io::{Read, Write},
19 vec::*,
20};
21
/// An evaluation domain that is either a radix-2 domain or a mixed-radix domain.
///
/// `GeneralEvaluationDomain::new` first tries to build a [`Radix2EvaluationDomain`].
/// Only if that fails, and the field defines `SMALL_SUBGROUP_BASE`, does it try a
/// [`MixedRadixEvaluationDomain`]. Instance methods of `EvaluationDomain` are
/// forwarded to whichever domain is wrapped.
///
/// When serialized, the value is written as a one-byte tag (`0` = `Radix2`,
/// `1` = `MixedRadix`) followed by the wrapped domain.
#[derive(Copy, Clone, Hash, Eq, PartialEq, Debug)]
pub enum GeneralEvaluationDomain<F: Field> {
    /// Wraps a [`Radix2EvaluationDomain`].
    Radix2(Radix2EvaluationDomain<F>),
    /// Wraps a [`MixedRadixEvaluationDomain`].
    MixedRadix(MixedRadixEvaluationDomain<F>),
}
56
// Dispatch a method call to whichever concrete domain `$self` wraps.
//
// `map!(self, name, a, b, ...)` expands to a `match` on `Self::Radix2` /
// `Self::MixedRadix` that calls `EvaluationDomain::name(domain, a, b, ...)` on the
// inner domain. Extra arguments are re-emitted comma-separated after `domain`, so
// any number of them (including zero) produces a well-formed call. The previous
// transcriber concatenated them with no separator. Each argument expression is
// pasted into both arms, but only one arm runs, so it is evaluated once.
macro_rules! map {
    ($self:expr, $f1:ident $(, $x:expr)*) => {
        match $self {
            Self::Radix2(domain) => EvaluationDomain::$f1(domain $(, $x)*),
            Self::MixedRadix(domain) => EvaluationDomain::$f1(domain $(, $x)*),
        }
    };
}
65
66impl<F: FftField> CanonicalSerialize for GeneralEvaluationDomain<F> {
67 fn serialize_with_mode<W: Write>(
68 &self,
69 mut writer: W,
70 compress: Compress,
71 ) -> Result<(), SerializationError> {
72 let variant = match self {
73 Self::Radix2(_) => 0u8,
74 Self::MixedRadix(_) => 1u8,
75 };
76 variant.serialize_with_mode(&mut writer, compress)?;
77
78 match self {
79 Self::Radix2(domain) => domain.serialize_with_mode(&mut writer, compress),
80 Self::MixedRadix(domain) => domain.serialize_with_mode(&mut writer, compress),
81 }
82 }
83
84 fn serialized_size(&self, compress: Compress) -> usize {
85 let type_id = match self {
86 Self::Radix2(_) => 0u8,
87 Self::MixedRadix(_) => 1u8,
88 };
89
90 type_id.serialized_size(compress)
91 + match self {
92 Self::Radix2(domain) => domain.serialized_size(compress),
93 Self::MixedRadix(domain) => domain.serialized_size(compress),
94 }
95 }
96}
97
98impl<F: FftField> Valid for GeneralEvaluationDomain<F> {
99 fn check(&self) -> Result<(), SerializationError> {
100 Ok(())
101 }
102}
103
104impl<F: FftField> CanonicalDeserialize for GeneralEvaluationDomain<F> {
105 fn deserialize_with_mode<R: Read>(
106 mut reader: R,
107 compress: Compress,
108 validate: Validate,
109 ) -> Result<Self, SerializationError> {
110 match u8::deserialize_with_mode(&mut reader, compress, validate)? {
111 0 => Radix2EvaluationDomain::deserialize_with_mode(&mut reader, compress, validate)
112 .map(Self::Radix2),
113 1 => MixedRadixEvaluationDomain::deserialize_with_mode(&mut reader, compress, validate)
114 .map(Self::MixedRadix),
115 _ => Err(SerializationError::InvalidData),
116 }
117 }
118}
119
120impl<F: FftField> EvaluationDomain<F> for GeneralEvaluationDomain<F> {
121 type Elements = GeneralElements<F>;
122
123 fn new(num_coeffs: usize) -> Option<Self> {
130 Radix2EvaluationDomain::new(num_coeffs)
131 .map(Self::Radix2)
132 .or_else(|| {
133 F::SMALL_SUBGROUP_BASE
134 .is_some()
135 .then(|| MixedRadixEvaluationDomain::new(num_coeffs).map(Self::MixedRadix))
136 .flatten()
137 })
138 }
139
140 fn get_coset(&self, offset: F) -> Option<Self> {
141 Some(match self {
142 Self::Radix2(domain) => Self::Radix2(domain.get_coset(offset)?),
143 Self::MixedRadix(domain) => Self::MixedRadix(domain.get_coset(offset)?),
144 })
145 }
146
147 fn compute_size_of_domain(num_coeffs: usize) -> Option<usize> {
148 Radix2EvaluationDomain::<F>::compute_size_of_domain(num_coeffs).or_else(|| {
149 F::SMALL_SUBGROUP_BASE
150 .is_some()
151 .then(|| MixedRadixEvaluationDomain::<F>::compute_size_of_domain(num_coeffs))
152 .flatten()
153 })
154 }
155
156 #[inline]
157 fn size(&self) -> usize {
158 map!(self, size)
159 }
160
161 #[inline]
162 fn log_size_of_group(&self) -> u64 {
163 map!(self, log_size_of_group)
164 }
165
166 #[inline]
167 fn size_inv(&self) -> F {
168 map!(self, size_inv)
169 }
170
171 #[inline]
172 fn group_gen(&self) -> F {
173 map!(self, group_gen)
174 }
175
176 #[inline]
177 fn group_gen_inv(&self) -> F {
178 map!(self, group_gen_inv)
179 }
180
181 #[inline]
182 fn coset_offset(&self) -> F {
183 map!(self, coset_offset)
184 }
185
186 #[inline]
187 fn coset_offset_inv(&self) -> F {
188 map!(self, coset_offset_inv)
189 }
190
191 fn coset_offset_pow_size(&self) -> F {
192 map!(self, coset_offset_pow_size)
193 }
194
195 #[inline]
196 fn fft_in_place<T: DomainCoeff<F>>(&self, coeffs: &mut Vec<T>) {
197 map!(self, fft_in_place, coeffs)
198 }
199
200 #[inline]
201 fn ifft_in_place<T: DomainCoeff<F>>(&self, evals: &mut Vec<T>) {
202 map!(self, ifft_in_place, evals)
203 }
204
205 #[inline]
206 fn evaluate_all_lagrange_coefficients(&self, tau: F) -> Vec<F> {
207 map!(self, evaluate_all_lagrange_coefficients, tau)
208 }
209
210 #[inline]
211 fn vanishing_polynomial(&self) -> crate::univariate::SparsePolynomial<F> {
212 map!(self, vanishing_polynomial)
213 }
214
215 #[inline]
216 fn evaluate_vanishing_polynomial(&self, tau: F) -> F {
217 map!(self, evaluate_vanishing_polynomial, tau)
218 }
219
220 fn elements(&self) -> GeneralElements<F> {
222 GeneralElements(map!(self, elements))
223 }
224}
225
226pub struct GeneralElements<F: FftField>(Elements<F>);
228
229impl<F: FftField> Iterator for GeneralElements<F> {
230 type Item = F;
231
232 #[inline]
233 fn next(&mut self) -> Option<F> {
234 self.0.next()
235 }
236}
237
#[cfg(test)]
mod tests {
    use crate::{polynomial::Polynomial, EvaluationDomain, GeneralEvaluationDomain};
    use ark_ff::Zero;
    use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
    use ark_std::{rand::Rng, test_rng, vec::Vec};
    use ark_test_curves::{bls12_381::Fr, bn384_small_two_adicity::Fr as BNFr};

    #[test]
    fn vanishing_polynomial_evaluation() {
        let rng = &mut test_rng();
        for coeffs in 0..10 {
            let domain = GeneralEvaluationDomain::<Fr>::new(coeffs).unwrap();
            let z = domain.vanishing_polynomial();
            for _ in 0..100 {
                let point = rng.gen();
                assert_eq!(
                    z.evaluate(&point),
                    domain.evaluate_vanishing_polynomial(point)
                )
            }
        }

        for coeffs in 15..17 {
            let domain = GeneralEvaluationDomain::<BNFr>::new(coeffs).unwrap();
            let z = domain.vanishing_polynomial();
            for _ in 0..100 {
                let point = rng.gen();
                assert_eq!(
                    z.evaluate(&point),
                    domain.evaluate_vanishing_polynomial(point)
                )
            }
        }
    }

    #[test]
    fn vanishing_polynomial_vanishes_on_domain() {
        for coeffs in 0..1000 {
            let domain = GeneralEvaluationDomain::<Fr>::new(coeffs).unwrap();
            let z = domain.vanishing_polynomial();
            for point in domain.elements() {
                assert!(z.evaluate(&point).is_zero())
            }
        }
    }

    #[test]
    fn size_of_elements() {
        for coeffs in 1..10 {
            let size = 1 << coeffs;
            let domain = GeneralEvaluationDomain::<Fr>::new(size).unwrap();
            let domain_size = domain.size();
            assert_eq!(domain_size, domain.elements().count());
        }
    }

    /// The hand-written (de)serialization must round-trip, report an accurate
    /// size, and reject an unknown variant tag.
    #[test]
    fn serialization_round_trip() {
        fn check<F: ark_ff::FftField>(domain: GeneralEvaluationDomain<F>) {
            let mut bytes = Vec::new();
            domain.serialize_compressed(&mut bytes).unwrap();
            assert_eq!(bytes.len(), domain.compressed_size());

            let decoded = GeneralEvaluationDomain::<F>::deserialize_compressed(&bytes[..]).unwrap();
            assert_eq!(decoded, domain);

            // The first byte is the variant tag; only 0 and 1 are valid.
            bytes[0] = 2;
            assert!(GeneralEvaluationDomain::<F>::deserialize_compressed(&bytes[..]).is_err());
        }

        for coeffs in [1usize, 7, 64].iter().copied() {
            check(GeneralEvaluationDomain::<Fr>::new(coeffs).unwrap());
        }
        for coeffs in 15..17 {
            check(GeneralEvaluationDomain::<BNFr>::new(coeffs).unwrap());
        }
    }
}
293}