triton_vm/
arithmetic_domain.rs

1use std::ops::Mul;
2use std::ops::MulAssign;
3
4use num_traits::ConstOne;
5use num_traits::One;
6use num_traits::Zero;
7use rayon::prelude::*;
8use twenty_first::math::traits::FiniteField;
9use twenty_first::math::traits::PrimitiveRootOfUnity;
10use twenty_first::prelude::*;
11
12use crate::error::ArithmeticDomainError;
13
14type Result<T> = std::result::Result<T, ArithmeticDomainError>;
15
16#[derive(Debug, Copy, Clone, Eq, PartialEq)]
17pub struct ArithmeticDomain {
18    pub offset: BFieldElement,
19    pub generator: BFieldElement,
20    pub length: usize,
21}
22
23impl ArithmeticDomain {
24    /// Create a new domain with the given length.
25    /// No offset is applied, but can be added through
26    /// [`with_offset()`](Self::with_offset).
27    ///
28    /// # Errors
29    ///
30    /// Errors if the domain length is not a power of 2.
31    pub fn of_length(length: usize) -> Result<Self> {
32        let domain = Self {
33            offset: BFieldElement::ONE,
34            generator: Self::generator_for_length(length as u64)?,
35            length,
36        };
37        Ok(domain)
38    }
39
40    /// Set the offset of the domain.
41    #[must_use]
42    pub fn with_offset(mut self, offset: BFieldElement) -> Self {
43        self.offset = offset;
44        self
45    }
46
47    /// Derive a generator for a domain of the given length.
48    ///
49    /// # Errors
50    ///
51    /// Errors if the domain length is not a power of 2.
52    pub fn generator_for_length(domain_length: u64) -> Result<BFieldElement> {
53        let error = ArithmeticDomainError::PrimitiveRootNotSupported(domain_length);
54        BFieldElement::primitive_root_of_unity(domain_length).ok_or(error)
55    }
56
57    pub fn evaluate<FF>(&self, polynomial: &Polynomial<FF>) -> Vec<FF>
58    where
59        FF: FiniteField
60            + MulAssign<BFieldElement>
61            + Mul<BFieldElement, Output = FF>
62            + From<BFieldElement>
63            + 'static,
64    {
65        let (offset, length) = (self.offset, self.length);
66        let evaluate_from = |chunk| Polynomial::from(chunk).fast_coset_evaluate(offset, length);
67
68        // avoid `enumerate` to directly get index of the right type
69        let mut indexed_chunks = (0..).zip(polynomial.coefficients().chunks(length));
70
71        // only allocate a bunch of zeros if there are no chunks
72        let mut values = indexed_chunks.next().map_or_else(
73            || vec![FF::ZERO; length],
74            |(_, first_chunk)| evaluate_from(first_chunk),
75        );
76        for (chunk_index, chunk) in indexed_chunks {
77            let coefficient_index = chunk_index * u64::try_from(length).unwrap();
78            let scaled_offset = offset.mod_pow(coefficient_index);
79            values
80                .par_iter_mut()
81                .zip(evaluate_from(chunk))
82                .for_each(|(value, evaluation)| *value += evaluation * scaled_offset);
83        }
84
85        values
86    }
87
88    /// # Panics
89    ///
90    /// Panics if the length of the argument does not match the length of
91    /// `self`.
92    pub fn interpolate<FF>(&self, values: &[FF]) -> Polynomial<'static, FF>
93    where
94        FF: FiniteField + MulAssign<BFieldElement> + Mul<BFieldElement, Output = FF>,
95    {
96        debug_assert_eq!(self.length, values.len()); // required by `fast_coset_interpolate`
97
98        Polynomial::fast_coset_interpolate(self.offset, values)
99    }
100
101    pub fn low_degree_extension<FF>(&self, codeword: &[FF], target_domain: Self) -> Vec<FF>
102    where
103        FF: FiniteField
104            + MulAssign<BFieldElement>
105            + Mul<BFieldElement, Output = FF>
106            + From<BFieldElement>
107            + 'static,
108    {
109        target_domain.evaluate(&self.interpolate(codeword))
110    }
111
112    /// Compute the `n`th element of the domain.
113    pub fn domain_value(&self, n: u32) -> BFieldElement {
114        self.generator.mod_pow_u32(n) * self.offset
115    }
116
117    pub fn domain_values(&self) -> Vec<BFieldElement> {
118        let mut accumulator = bfe!(1);
119        let mut domain_values = Vec::with_capacity(self.length);
120
121        for _ in 0..self.length {
122            domain_values.push(accumulator * self.offset);
123            accumulator *= self.generator;
124        }
125        assert!(
126            accumulator.is_one(),
127            "length must be the order of the generator"
128        );
129        domain_values
130    }
131
132    /// A polynomial that evaluates to 0 on (and only on)
133    /// a [domain value][Self::domain_values].
134    pub fn zerofier(&self) -> Polynomial<'_, BFieldElement> {
135        if self.offset.is_zero() {
136            return Polynomial::x_to_the(1);
137        }
138
139        Polynomial::x_to_the(self.length)
140            - Polynomial::from_constant(self.offset.mod_pow(self.length as u64))
141    }
142
143    /// [`Self::zerofier`] times the argument.
144    /// More performant than polynomial multiplication.
145    /// See [`Self::zerofier`] for details.
146    pub fn mul_zerofier_with<FF>(&self, polynomial: Polynomial<FF>) -> Polynomial<'static, FF>
147    where
148        FF: FiniteField + Mul<BFieldElement, Output = FF>,
149    {
150        // use knowledge of zerofier's shape for faster multiplication
151        polynomial.clone().shift_coefficients(self.length)
152            - polynomial.scalar_mul(self.offset.mod_pow(self.length as u64))
153    }
154
155    pub(crate) fn halve(&self) -> Result<Self> {
156        if self.length < 2 {
157            return Err(ArithmeticDomainError::TooSmallForHalving(self.length));
158        }
159        let domain = Self {
160            offset: self.offset.square(),
161            generator: self.generator.square(),
162            length: self.length / 2,
163        };
164        Ok(domain)
165    }
166}
167
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use assert2::let_assert;
    use itertools::Itertools;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest_arbitrary_interop::arb;
    use test_strategy::proptest;

    use crate::shared_tests::arbitrary_polynomial;
    use crate::shared_tests::arbitrary_polynomial_of_degree;

    use super::*;

    prop_compose! {
        fn arbitrary_domain()(
            length in (0_usize..17).prop_map(|x| 1 << x),
        )(
            domain in arbitrary_domain_of_length(length),
        ) -> ArithmeticDomain {
            domain
        }
    }

    prop_compose! {
        fn arbitrary_halveable_domain()(
            length in (2_usize..17).prop_map(|x| 1 << x),
        )(
            domain in arbitrary_domain_of_length(length),
        ) -> ArithmeticDomain {
            domain
        }
    }

    prop_compose! {
        fn arbitrary_domain_of_length(length: usize)(
            offset in arb(),
        ) -> ArithmeticDomain {
            ArithmeticDomain::of_length(length).unwrap().with_offset(offset)
        }
    }

    #[proptest]
    fn evaluate_empty_polynomial(
        #[strategy(arbitrary_domain())] domain: ArithmeticDomain,
        #[strategy(arbitrary_polynomial_of_degree(-1))] poly: Polynomial<'static, XFieldElement>,
    ) {
        let values = domain.evaluate(&poly);
        prop_assert_eq!(domain.length, values.len());
        prop_assert!(values.iter().all(|value| value.is_zero()));
    }

    #[proptest]
    fn evaluate_constant_polynomial(
        #[strategy(arbitrary_domain())] domain: ArithmeticDomain,
        #[strategy(arbitrary_polynomial_of_degree(0))] poly: Polynomial<'static, XFieldElement>,
    ) {
        let values = domain.evaluate(&poly);
        prop_assert_eq!(domain.length, values.len());
        prop_assert!(values.iter().all_equal());
    }

    #[proptest]
    fn evaluate_linear_polynomial(
        #[strategy(arbitrary_domain())] domain: ArithmeticDomain,
        #[strategy(arbitrary_polynomial_of_degree(1))] poly: Polynomial<'static, XFieldElement>,
    ) {
        let values = domain.evaluate(&poly);
        prop_assert_eq!(domain.length, values.len());
    }

    #[proptest]
    fn evaluate_polynomial(
        #[strategy(arbitrary_domain())] domain: ArithmeticDomain,
        #[strategy(arbitrary_polynomial())] polynomial: Polynomial<'static, XFieldElement>,
    ) {
        let values = domain.evaluate(&polynomial);
        prop_assert_eq!(domain.length, values.len());
    }

    #[test]
    fn domain_values() {
        let poly = Polynomial::<BFieldElement>::x_to_the(3);
        let x_cubed_coefficients = poly.clone().into_coefficients();

        for order in [4, 8, 32] {
            let generator = BFieldElement::primitive_root_of_unity(order).unwrap();
            let offset = BFieldElement::generator();
            let b_domain = ArithmeticDomain::of_length(order as usize)
                .unwrap()
                .with_offset(offset);

            let expected_b_values = (0..order)
                .map(|i| offset * generator.mod_pow(i))
                .collect_vec();
            let actual_b_values_1 = b_domain.domain_values();
            let actual_b_values_2 = (0..order as u32)
                .map(|i| b_domain.domain_value(i))
                .collect_vec();
            assert_eq!(expected_b_values, actual_b_values_1);
            assert_eq!(expected_b_values, actual_b_values_2);

            let values = b_domain.evaluate(&poly);
            assert_ne!(values, x_cubed_coefficients);

            let interpolant = b_domain.interpolate(&values);
            assert_eq!(poly, interpolant);

            // Verify that batch-evaluated values match a manual evaluation
            for i in 0..order {
                let indeterminate = b_domain.domain_value(i as u32);
                let evaluation: BFieldElement = poly.evaluate(indeterminate);
                assert_eq!(evaluation, values[i as usize]);
            }
        }
    }

    #[test]
    fn low_degree_extension() {
        let short_domain_len = 32;
        let long_domain_len = 128;
        let unit_distance = long_domain_len / short_domain_len;

        let short_domain = ArithmeticDomain::of_length(short_domain_len).unwrap();
        let long_domain = ArithmeticDomain::of_length(long_domain_len).unwrap();

        let polynomial = Polynomial::new(bfe_vec![1, 2, 3, 4]);
        let short_codeword = short_domain.evaluate(&polynomial);
        let long_codeword = short_domain.low_degree_extension(&short_codeword, long_domain);

        assert_eq!(long_codeword.len(), long_domain_len);

        let long_codeword_sub_view = long_codeword
            .into_iter()
            .step_by(unit_distance)
            .collect_vec();
        assert_eq!(short_codeword, long_codeword_sub_view);
    }

    #[proptest]
    fn halving_domain_squares_all_points(
        #[strategy(arbitrary_halveable_domain())] domain: ArithmeticDomain,
    ) {
        let half_domain = domain.halve()?;
        prop_assert_eq!(domain.length / 2, half_domain.length);

        let domain_points = domain.domain_values();
        let half_domain_points = half_domain.domain_values();

        // The i-th and the (i + length/2)-th point square to the same point of the
        // halved domain. Cycling the halved domain covers every original point.
        for (domain_point, &halved_domain_point) in domain_points
            .into_iter()
            .zip(half_domain_points.iter().cycle())
        {
            prop_assert_eq!(domain_point.square(), halved_domain_point);
        }
    }

    #[test]
    fn too_small_domains_cannot_be_halved() {
        for i in [0, 1] {
            let domain = ArithmeticDomain::of_length(i).unwrap();
            let_assert!(Err(err) = domain.halve());
            assert_eq!(ArithmeticDomainError::TooSmallForHalving(i), err);
        }
    }

    #[proptest]
    fn can_evaluate_polynomial_larger_than_domain(
        #[strategy(1_usize..10)] _log_domain_length: usize,
        #[strategy(1_usize..5)] _expansion_factor: usize,
        #[strategy(Just(1 << #_log_domain_length))] domain_length: usize,
        #[strategy(vec(arb(),#domain_length*#_expansion_factor))] coefficients: Vec<BFieldElement>,
        #[strategy(arb())] offset: BFieldElement,
    ) {
        let domain = ArithmeticDomain::of_length(domain_length)
            .unwrap()
            .with_offset(offset);
        let polynomial = Polynomial::new(coefficients);

        let values0 = domain.evaluate(&polynomial);
        let values1 = polynomial.batch_evaluate(&domain.domain_values());
        prop_assert_eq!(values0, values1);
    }

    #[proptest]
    fn zerofier_is_actually_zerofier(#[strategy(arbitrary_domain())] domain: ArithmeticDomain) {
        let actual_zerofier = Polynomial::zerofier(&domain.domain_values());
        prop_assert_eq!(actual_zerofier, domain.zerofier());
    }

    #[proptest]
    fn multiplication_with_zerofier_is_identical_to_method_mul_with_zerofier(
        #[strategy(arbitrary_domain())] domain: ArithmeticDomain,
        #[strategy(arbitrary_polynomial())] polynomial: Polynomial<'static, XFieldElement>,
    ) {
        let mul = domain.zerofier() * polynomial.clone();
        let mul_with = domain.mul_zerofier_with(polynomial);
        prop_assert_eq!(mul, mul_with);
    }
}