Skip to main content

dusk_plonk/composer/
permutation.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at http://mozilla.org/MPL/2.0/.
4//
5// Copyright (c) DUSK NETWORK. All rights reserved.
6
7use crate::composer::{WireData, Witness};
8use crate::fft::{EvaluationDomain, Polynomial};
9use alloc::vec::Vec;
10use constants::{K1, K2, K3};
11use dusk_bls12_381::BlsScalar;
12use hashbrown::HashMap;
13use itertools::izip;
14
15pub(crate) mod constants;
16
/// Permutation provides the necessary state information and functions
/// to create the permutation polynomial. In the literature, Z(X) is the
/// "accumulator", this is what this codebase calls the permutation polynomial.
///
/// The only state kept is `witness_map`: every [`Witness`] allocated through
/// `new_witness` owns an entry listing the wires it has been assigned to.
#[derive(Debug, Clone)]
pub(crate) struct Permutation {
    /// Maps a witness to the wires that it is associated to. Wires are
    /// appended in the order in which they are registered.
    pub(crate) witness_map: HashMap<Witness, Vec<WireData>>,
}
25
26impl Permutation {
27    /// Creates a Permutation struct with an expected capacity of zero.
28    pub(crate) fn new() -> Permutation {
29        Permutation::with_capacity(0)
30    }
31
32    /// Creates a Permutation struct with an expected capacity of `n`.
33    pub(crate) fn with_capacity(size: usize) -> Permutation {
34        Permutation {
35            witness_map: HashMap::with_capacity(size),
36        }
37    }
38
39    /// Creates a new [`Witness`] by incrementing the index of the
40    /// `witness_map`.
41    ///
42    /// This is correct as whenever we add a new [`Witness`] into the system It
43    /// is always allocated in the `witness_map`.
44    pub(crate) fn new_witness(&mut self) -> Witness {
45        // Generate the Witness
46        let var = Witness::new(self.witness_map.keys().len());
47
48        // Allocate space for the Witness on the witness_map
49        // Each vector is initialized with a capacity of 16.
50        // This number is a best guess estimate.
51        self.witness_map.insert(var, Vec::with_capacity(16usize));
52
53        var
54    }
55
56    /// Checks that the [`Witness`]s are valid by determining if they have been
57    /// added to the system
58    fn valid_witnesses(&self, witnesses: &[Witness]) -> bool {
59        witnesses
60            .iter()
61            .all(|var| self.witness_map.contains_key(var))
62    }
63
64    /// Maps a set of [`Witness`]s (a,b,c,d) to a set of [`Wire`](WireData)s
65    /// (left, right, out, fourth) with the corresponding gate index
66    pub fn add_witnesses_to_map<T: Into<Witness>>(
67        &mut self,
68        a: T,
69        b: T,
70        c: T,
71        d: T,
72        gate_index: usize,
73    ) {
74        let left: WireData = WireData::Left(gate_index);
75        let right: WireData = WireData::Right(gate_index);
76        let output: WireData = WireData::Output(gate_index);
77        let fourth: WireData = WireData::Fourth(gate_index);
78
79        // Map each witness to the wire it is associated with
80        // This essentially tells us that:
81        self.add_witness_to_map(a.into(), left);
82        self.add_witness_to_map(b.into(), right);
83        self.add_witness_to_map(c.into(), output);
84        self.add_witness_to_map(d.into(), fourth);
85    }
86
87    pub(crate) fn add_witness_to_map<T: Into<Witness> + Copy>(
88        &mut self,
89        var: T,
90        wire_data: WireData,
91    ) {
92        assert!(self.valid_witnesses(&[var.into()]));
93
94        // Since we always allocate space for the Vec of WireData when a
95        // Witness is added to the witness_map, this should never fail
96        let vec_wire_data = self.witness_map.get_mut(&var.into()).unwrap();
97        vec_wire_data.push(wire_data);
98    }
99
100    // Performs shift by one permutation and computes sigma_1, sigma_2 and
101    // sigma_3, sigma_4 permutations from the witness maps
102    pub(super) fn compute_sigma_permutations(
103        &mut self,
104        n: usize,
105    ) -> [Vec<WireData>; 4] {
106        let sigma_1: Vec<_> = (0..n).map(WireData::Left).collect();
107        let sigma_2: Vec<_> = (0..n).map(WireData::Right).collect();
108        let sigma_3: Vec<_> = (0..n).map(WireData::Output).collect();
109        let sigma_4: Vec<_> = (0..n).map(WireData::Fourth).collect();
110
111        let mut sigmas = [sigma_1, sigma_2, sigma_3, sigma_4];
112
113        for (_, wire_data) in self.witness_map.iter() {
114            // Gets the data for each wire associated with this witness
115            for (wire_index, current_wire) in wire_data.iter().enumerate() {
116                // Fetch index of the next wire, if it is the last element
117                // We loop back around to the beginning
118                let next_index = match wire_index == wire_data.len() - 1 {
119                    true => 0,
120                    false => wire_index + 1,
121                };
122
123                // Fetch the next wire
124                let next_wire = &wire_data[next_index];
125
126                // Map current wire to next wire
127                match current_wire {
128                    WireData::Left(index) => sigmas[0][*index] = *next_wire,
129                    WireData::Right(index) => sigmas[1][*index] = *next_wire,
130                    WireData::Output(index) => sigmas[2][*index] = *next_wire,
131                    WireData::Fourth(index) => sigmas[3][*index] = *next_wire,
132                };
133            }
134        }
135
136        sigmas
137    }
138
139    fn compute_permutation_lagrange(
140        &self,
141        sigma_mapping: &[WireData],
142        domain: &EvaluationDomain,
143    ) -> Vec<BlsScalar> {
144        let roots: Vec<_> = domain.elements().collect();
145
146        let lagrange_poly: Vec<BlsScalar> = sigma_mapping
147            .iter()
148            .map(|x| match x {
149                WireData::Left(index) => {
150                    let root = &roots[*index];
151                    *root
152                }
153                WireData::Right(index) => {
154                    let root = &roots[*index];
155                    K1 * root
156                }
157                WireData::Output(index) => {
158                    let root = &roots[*index];
159                    K2 * root
160                }
161                WireData::Fourth(index) => {
162                    let root = &roots[*index];
163                    K3 * root
164                }
165            })
166            .collect();
167
168        lagrange_poly
169    }
170
171    /// Computes the sigma polynomials which are used to build the permutation
172    /// polynomial
173    pub(crate) fn compute_sigma_polynomials(
174        &mut self,
175        n: usize,
176        domain: &EvaluationDomain,
177    ) -> [Polynomial; 4] {
178        // Compute sigma mappings
179        let sigmas = self.compute_sigma_permutations(n);
180
181        assert_eq!(sigmas[0].len(), n);
182        assert_eq!(sigmas[1].len(), n);
183        assert_eq!(sigmas[2].len(), n);
184        assert_eq!(sigmas[3].len(), n);
185
186        // Define sigma permutations over disjoint cosets generated by K1/K2/K3.
187        let s_sigma_1 = self.compute_permutation_lagrange(&sigmas[0], domain);
188        let s_sigma_2 = self.compute_permutation_lagrange(&sigmas[1], domain);
189        let s_sigma_3 = self.compute_permutation_lagrange(&sigmas[2], domain);
190        let s_sigma_4 = self.compute_permutation_lagrange(&sigmas[3], domain);
191
192        let s_sigma_1_poly =
193            Polynomial::from_coefficients_vec(domain.ifft(&s_sigma_1));
194        let s_sigma_2_poly =
195            Polynomial::from_coefficients_vec(domain.ifft(&s_sigma_2));
196        let s_sigma_3_poly =
197            Polynomial::from_coefficients_vec(domain.ifft(&s_sigma_3));
198        let s_sigma_4_poly =
199            Polynomial::from_coefficients_vec(domain.ifft(&s_sigma_4));
200
201        [
202            s_sigma_1_poly,
203            s_sigma_2_poly,
204            s_sigma_3_poly,
205            s_sigma_4_poly,
206        ]
207    }
208
209    // Uses a rayon multizip to allow more code flexibility while remaining
210    // parallelizable. This can be adapted into a general product argument
211    // for any number of wires.
212    pub(crate) fn compute_permutation_vec(
213        &self,
214        domain: &EvaluationDomain,
215        wires: [&[BlsScalar]; 4],
216        beta: &BlsScalar,
217        gamma: &BlsScalar,
218        sigma_polys: [&Polynomial; 4],
219    ) -> Vec<BlsScalar> {
220        let n = domain.size();
221
222        // Constants defining cosets H, k1H, k2H, etc
223        let ks = vec![BlsScalar::one(), K1, K2, K3];
224
225        // Transpose wires and sigma values to get "rows" in the form [a_i,
226        // b_i, c_i, d_i] where each row contains the wire and sigma
227        // values for a single gate
228        let gatewise_wires = izip!(wires[0], wires[1], wires[2], wires[3])
229            .map(|(w0, w1, w2, w3)| vec![w0, w1, w2, w3]);
230
231        let gatewise_sigmas: Vec<Vec<BlsScalar>> =
232            sigma_polys.iter().map(|sigma| domain.fft(sigma)).collect();
233        let gatewise_sigmas = izip!(
234            &gatewise_sigmas[0],
235            &gatewise_sigmas[1],
236            &gatewise_sigmas[2],
237            &gatewise_sigmas[3]
238        )
239        .map(|(s0, s1, s2, s3)| vec![s0, s1, s2, s3]);
240
241        // Compute all roots
242        // Non-parallelizable?
243        let roots: Vec<BlsScalar> = domain.elements().collect();
244
245        let product_argument = izip!(roots, gatewise_sigmas, gatewise_wires)
246            // Associate each wire value in a gate with the k defining its coset
247            .map(|(gate_root, gate_sigmas, gate_wires)| {
248                (gate_root, izip!(gate_sigmas, gate_wires, &ks))
249            })
250            // Now the ith element represents gate i and will have the form:
251            //   (root_i, ((w0_i, s0_i, k0), (w1_i, s1_i, k1), ..., (wm_i, sm_i,
252            // km)))   for m different wires, which is all the
253            // information   needed for a single product coefficient
254            // for a single gate Multiply up the numerator and
255            // denominator irreducibles for each gate   and pair the
256            // results
257            .map(|(gate_root, wire_params)| {
258                (
259                    // Numerator product
260                    wire_params
261                        .clone()
262                        .map(|(_sigma, wire, k)| {
263                            wire + beta * k * gate_root + gamma
264                        })
265                        .product::<BlsScalar>(),
266                    // Denominator product
267                    wire_params
268                        .map(|(sigma, wire, _k)| wire + beta * sigma + gamma)
269                        .product::<BlsScalar>(),
270                )
271            })
272            // Divide each pair to get the single scalar representing each gate
273            .map(|(n, d)| n * d.invert().unwrap())
274            // Collect into vector intermediary since rayon does not support
275            // `scan`
276            .collect::<Vec<BlsScalar>>();
277
278        let mut z = Vec::with_capacity(n);
279
280        // First element is one
281        let mut state = BlsScalar::one();
282        z.push(state);
283
284        // Accumulate by successively multiplying the scalars
285        // Non-parallelizable?
286        for s in product_argument {
287            state *= s;
288            z.push(state);
289        }
290
291        // Remove the last(n+1'th) element
292        z.remove(n);
293
294        assert_eq!(n, z.len());
295
296        z
297    }
298}
299
300#[cfg(feature = "std")]
301#[cfg(test)]
302mod test {
303    use super::*;
304    use crate::fft::Polynomial;
305    use dusk_bls12_381::BlsScalar;
306    use ff::Field;
307    use rand_core::OsRng;
308
309    #[allow(dead_code)]
310    fn compute_fast_permutation_poly(
311        domain: &EvaluationDomain,
312        a: &[BlsScalar],
313        b: &[BlsScalar],
314        c: &[BlsScalar],
315        d: &[BlsScalar],
316        beta: &BlsScalar,
317        gamma: &BlsScalar,
318        (s_sigma_1_poly, s_sigma_2_poly, s_sigma_3_poly, s_sigma_4_poly): (
319            &Polynomial,
320            &Polynomial,
321            &Polynomial,
322            &Polynomial,
323        ),
324    ) -> Vec<BlsScalar> {
325        let n = domain.size();
326
327        // Compute beta * roots
328        let common_roots: Vec<BlsScalar> =
329            domain.elements().map(|root| root * beta).collect();
330
331        let s_sigma_1_mapping = domain.fft(s_sigma_1_poly);
332        let s_sigma_2_mapping = domain.fft(s_sigma_2_poly);
333        let s_sigma_3_mapping = domain.fft(s_sigma_3_poly);
334        let s_sigma_4_mapping = domain.fft(s_sigma_4_poly);
335
336        // Compute beta * sigma polynomials
337        let beta_s_sigma_1: Vec<_> =
338            s_sigma_1_mapping.iter().map(|sigma| sigma * beta).collect();
339        let beta_s_sigma_2: Vec<_> =
340            s_sigma_2_mapping.iter().map(|sigma| sigma * beta).collect();
341        let beta_s_sigma_3: Vec<_> =
342            s_sigma_3_mapping.iter().map(|sigma| sigma * beta).collect();
343        let beta_s_sigma_4: Vec<_> =
344            s_sigma_4_mapping.iter().map(|sigma| sigma * beta).collect();
345
346        // Compute beta * roots * K1
347        let beta_roots_k1: Vec<_> =
348            common_roots.iter().map(|x| x * K1).collect();
349
350        // Compute beta * roots * K2
351        let beta_roots_k2: Vec<_> =
352            common_roots.iter().map(|x| x * K2).collect();
353
354        // Compute beta * roots * K3
355        let beta_roots_k3: Vec<_> =
356            common_roots.iter().map(|x| x * K3).collect();
357
358        // Compute left_wire + gamma
359        let a_gamma: Vec<_> = a.iter().map(|a| a + gamma).collect();
360
361        // Compute right_wire + gamma
362        let b_gamma: Vec<_> = b.iter().map(|b| b + gamma).collect();
363
364        // Compute out_wire + gamma
365        let c_gamma: Vec<_> = c.iter().map(|c| c + gamma).collect();
366
367        // Compute fourth_wire + gamma
368        let d_gamma: Vec<_> = d.iter().map(|d| d + gamma).collect();
369
370        // Compute 6 accumulator components
371        // Parallelizable
372        let accumulator_components_without_l1: Vec<_> = izip!(
373            a_gamma,
374            b_gamma,
375            c_gamma,
376            d_gamma,
377            common_roots,
378            beta_roots_k1,
379            beta_roots_k2,
380            beta_roots_k3,
381            beta_s_sigma_1,
382            beta_s_sigma_2,
383            beta_s_sigma_3,
384            beta_s_sigma_4,
385        )
386        .map(
387            |(
388                a_gamma,
389                b_gamma,
390                c_gamma,
391                d_gamma,
392                beta_root,
393                beta_root_k1,
394                beta_root_k2,
395                beta_root_k3,
396                beta_s_sigma_1,
397                beta_s_sigma_2,
398                beta_s_sigma_3,
399                beta_s_sigma_4,
400            )| {
401                // w_j + beta * root^j-1 + gamma
402                let ac1 = a_gamma + beta_root;
403
404                // w_{n+j} + beta * K1 * root^j-1 + gamma
405                let ac2 = b_gamma + beta_root_k1;
406
407                // w_{2n+j} + beta * K2 * root^j-1 + gamma
408                let ac3 = c_gamma + beta_root_k2;
409
410                // w_{3n+j} + beta * K3 * root^j-1 + gamma
411                let ac4 = d_gamma + beta_root_k3;
412
413                // 1 / w_j + beta * sigma(j) + gamma
414                let ac5 = (a_gamma + beta_s_sigma_1).invert().unwrap();
415
416                // 1 / w_{n+j} + beta * sigma(n+j) + gamma
417                let ac6 = (b_gamma + beta_s_sigma_2).invert().unwrap();
418
419                // 1 / w_{2n+j} + beta * sigma(2n+j) + gamma
420                let ac7 = (c_gamma + beta_s_sigma_3).invert().unwrap();
421
422                // 1 / w_{3n+j} + beta * sigma(3n+j) + gamma
423                let ac8 = (d_gamma + beta_s_sigma_4).invert().unwrap();
424
425                [ac1, ac2, ac3, ac4, ac5, ac6, ac7, ac8]
426            },
427        )
428        .collect();
429
430        // Prepend ones to the beginning of each accumulator to signify L_1(x)
431        let accumulator_components = core::iter::once([BlsScalar::one(); 8])
432            .chain(accumulator_components_without_l1);
433
434        // Multiply each component of the accumulators
435        // A simplified example is the following:
436        // A1 = [1,2,3,4]
437        // result = [1, 1*2, 1*2*3, 1*2*3*4]
438        // Non Parallelizable
439        let mut prev = [BlsScalar::one(); 8];
440
441        let product_accumulated_components: Vec<_> = accumulator_components
442            .map(|current_component| {
443                current_component
444                    .iter()
445                    .zip(prev.iter_mut())
446                    .for_each(|(curr, prev)| *prev *= curr);
447                prev
448            })
449            .collect();
450
451        // Right now we basically have 6 accumulators of the form:
452        // A1 = [a1, a1 * a2, a1*a2*a3,...]
453        // A2 = [b1, b1 * b2, b1*b2*b3,...]
454        // A3 = [c1, c1 * c2, c1*c2*c3,...]
455        // ... and so on
456        // We want:
457        // [a1*b1*c1, a1 * a2 *b1 * b2 * c1 * c2,...]
458        // Parallelizable
459        let mut z: Vec<_> = product_accumulated_components
460            .iter()
461            .map(move |current_component| current_component.iter().product())
462            .collect();
463        // Remove the last(n+1'th) element
464        z.remove(n);
465
466        assert_eq!(n, z.len());
467
468        z
469    }
470
471    fn compute_slow_permutation_poly<I>(
472        domain: &EvaluationDomain,
473        a: I,
474        b: I,
475        c: I,
476        d: I,
477        beta: &BlsScalar,
478        gamma: &BlsScalar,
479        (s_sigma_1_poly, s_sigma_2_poly, s_sigma_3_poly, s_sigma_4_poly): (
480            &Polynomial,
481            &Polynomial,
482            &Polynomial,
483            &Polynomial,
484        ),
485    ) -> (Vec<BlsScalar>, Vec<BlsScalar>, Vec<BlsScalar>)
486    where
487        I: Iterator<Item = BlsScalar>,
488    {
489        let n = domain.size();
490
491        let s_sigma_1_mapping = domain.fft(s_sigma_1_poly);
492        let s_sigma_2_mapping = domain.fft(s_sigma_2_poly);
493        let s_sigma_3_mapping = domain.fft(s_sigma_3_poly);
494        let s_sigma_4_mapping = domain.fft(s_sigma_4_poly);
495
496        // Compute beta * sigma polynomials
497        let beta_s_sigma_1_iter =
498            s_sigma_1_mapping.iter().map(|sigma| *sigma * beta);
499        let beta_s_sigma_2_iter =
500            s_sigma_2_mapping.iter().map(|sigma| *sigma * beta);
501        let beta_s_sigma_3_iter =
502            s_sigma_3_mapping.iter().map(|sigma| *sigma * beta);
503        let beta_s_sigma_4_iter =
504            s_sigma_4_mapping.iter().map(|sigma| *sigma * beta);
505
506        // Compute beta * roots
507        let beta_roots_iter = domain.elements().map(|root| root * beta);
508
509        // Compute beta * roots * K1
510        let beta_roots_k1_iter = domain.elements().map(|root| K1 * beta * root);
511
512        // Compute beta * roots * K2
513        let beta_roots_k2_iter = domain.elements().map(|root| K2 * beta * root);
514
515        // Compute beta * roots * K3
516        let beta_roots_k3_iter = domain.elements().map(|root| K3 * beta * root);
517
518        // Compute left_wire + gamma
519        let a_gamma: Vec<_> = a.map(|w| w + gamma).collect();
520
521        // Compute right_wire + gamma
522        let b_gamma: Vec<_> = b.map(|w| w + gamma).collect();
523
524        // Compute out_wire + gamma
525        let c_gamma: Vec<_> = c.map(|w| w + gamma).collect();
526
527        // Compute fourth_wire + gamma
528        let d_gamma: Vec<_> = d.map(|w| w + gamma).collect();
529
530        let mut numerator_partial_components: Vec<BlsScalar> =
531            Vec::with_capacity(n);
532        let mut denominator_partial_components: Vec<BlsScalar> =
533            Vec::with_capacity(n);
534
535        let mut numerator_coefficients: Vec<BlsScalar> = Vec::with_capacity(n);
536        let mut denominator_coefficients: Vec<BlsScalar> =
537            Vec::with_capacity(n);
538
539        // First element in both of them is one
540        numerator_coefficients.push(BlsScalar::one());
541        denominator_coefficients.push(BlsScalar::one());
542
543        // Compute numerator coefficients
544        for (
545            a_gamma,
546            b_gamma,
547            c_gamma,
548            d_gamma,
549            beta_root,
550            beta_root_k1,
551            beta_root_k2,
552            beta_root_k3,
553        ) in izip!(
554            a_gamma.iter(),
555            b_gamma.iter(),
556            c_gamma.iter(),
557            d_gamma.iter(),
558            beta_roots_iter,
559            beta_roots_k1_iter,
560            beta_roots_k2_iter,
561            beta_roots_k3_iter,
562        ) {
563            // (a + beta * root + gamma)
564            let prod_a = beta_root + a_gamma;
565
566            // (b + beta * root * k_1 + gamma)
567            let prod_b = beta_root_k1 + b_gamma;
568
569            // (c + beta * root * k_2 + gamma)
570            let prod_c = beta_root_k2 + c_gamma;
571
572            // (d + beta * root * k_3 + gamma)
573            let prod_d = beta_root_k3 + d_gamma;
574
575            let mut prod = prod_a * prod_b * prod_c * prod_d;
576
577            numerator_partial_components.push(prod);
578
579            prod *= numerator_coefficients.last().unwrap();
580
581            numerator_coefficients.push(prod);
582        }
583
584        // Compute denominator coefficients
585        for (
586            a_gamma,
587            b_gamma,
588            c_gamma,
589            d_gamma,
590            beta_s_sigma_1,
591            beta_s_sigma_2,
592            beta_s_sigma_3,
593            beta_s_sigma_4,
594        ) in izip!(
595            a_gamma,
596            b_gamma,
597            c_gamma,
598            d_gamma,
599            beta_s_sigma_1_iter,
600            beta_s_sigma_2_iter,
601            beta_s_sigma_3_iter,
602            beta_s_sigma_4_iter,
603        ) {
604            // (a + beta * s_sigma_1 + gamma)
605            let prod_a = beta_s_sigma_1 + a_gamma;
606
607            // (b + beta * s_sigma_2 + gamma)
608            let prod_b = beta_s_sigma_2 + b_gamma;
609
610            // (c + beta * s_sigma_3 + gamma)
611            let prod_c = beta_s_sigma_3 + c_gamma;
612
613            // (d + beta * s_sigma_4 + gamma)
614            let prod_d = beta_s_sigma_4 + d_gamma;
615
616            let mut prod = prod_a * prod_b * prod_c * prod_d;
617
618            denominator_partial_components.push(prod);
619
620            let last_element = denominator_coefficients.last().unwrap();
621
622            prod *= last_element;
623
624            denominator_coefficients.push(prod);
625        }
626
627        assert_eq!(denominator_coefficients.len(), n + 1);
628        assert_eq!(numerator_coefficients.len(), n + 1);
629
630        // Check that n+1'th elements are equal (taken from proof)
631        let a = numerator_coefficients.pop().unwrap();
632        assert_ne!(a, BlsScalar::zero());
633        let b = denominator_coefficients.pop().unwrap();
634        assert_ne!(b, BlsScalar::zero());
635        assert_eq!(a * b.invert().unwrap(), BlsScalar::one());
636
637        // Combine numerator and denominator
638
639        let mut z_coefficients: Vec<BlsScalar> = Vec::with_capacity(n);
640        for (numerator, denominator) in numerator_coefficients
641            .iter()
642            .zip(denominator_coefficients.iter())
643        {
644            z_coefficients.push(*numerator * denominator.invert().unwrap());
645        }
646        assert_eq!(z_coefficients.len(), n);
647
648        (
649            z_coefficients,
650            numerator_partial_components,
651            denominator_partial_components,
652        )
653    }
654
655    #[test]
656    fn test_permutation_format() {
657        let mut perm: Permutation = Permutation::new();
658
659        let num_witnesses = 10u8;
660        for i in 0..num_witnesses {
661            let var = perm.new_witness();
662            assert_eq!(var.index(), i as usize);
663            assert_eq!(perm.witness_map.len(), (i as usize) + 1);
664        }
665
666        let var_one = perm.new_witness();
667        let var_two = perm.new_witness();
668        let var_three = perm.new_witness();
669
670        let gate_size = 100;
671        for i in 0..gate_size {
672            perm.add_witnesses_to_map(var_one, var_one, var_two, var_three, i);
673        }
674
675        // Check all gate_indices are valid
676        for (_, wire_data) in perm.witness_map.iter() {
677            for wire in wire_data.iter() {
678                match wire {
679                    WireData::Left(index)
680                    | WireData::Right(index)
681                    | WireData::Output(index)
682                    | WireData::Fourth(index) => assert!(*index < gate_size),
683                };
684            }
685        }
686    }
687
688    #[test]
689    fn test_permutation_compute_sigmas_only_left_wires() {
690        let mut perm = Permutation::new();
691
692        let var_zero = perm.new_witness();
693        let var_two = perm.new_witness();
694        let var_three = perm.new_witness();
695        let var_four = perm.new_witness();
696        let var_five = perm.new_witness();
697        let var_six = perm.new_witness();
698        let var_seven = perm.new_witness();
699        let var_eight = perm.new_witness();
700        let var_nine = perm.new_witness();
701
702        let num_wire_mappings = 4;
703
704        // Add four wire mappings
705        perm.add_witnesses_to_map(var_zero, var_zero, var_five, var_nine, 0);
706        perm.add_witnesses_to_map(var_zero, var_two, var_six, var_nine, 1);
707        perm.add_witnesses_to_map(var_zero, var_three, var_seven, var_nine, 2);
708        perm.add_witnesses_to_map(var_zero, var_four, var_eight, var_nine, 3);
709
710        /*
711        var_zero = {L0, R0, L1, L2, L3}
712        var_two = {R1}
713        var_three = {R2}
714        var_four = {R3}
715        var_five = {O0}
716        var_six = {O1}
717        var_seven = {O2}
718        var_eight = {O3}
719        var_nine = {F0, F1, F2, F3}
720        s_sigma_1 = {R0, L2, L3, L0}
721        s_sigma_2 = {L1, R1, R2, R3}
722        s_sigma_3 = {O0, O1, O2, O3}
723        s_sigma_4 = {F1, F2, F3, F0}
724        */
725        let sigmas = perm.compute_sigma_permutations(num_wire_mappings);
726        let s_sigma_1 = &sigmas[0];
727        let s_sigma_2 = &sigmas[1];
728        let s_sigma_3 = &sigmas[2];
729        let s_sigma_4 = &sigmas[3];
730
731        // Check the left sigma polynomial
732        assert_eq!(s_sigma_1[0], WireData::Right(0));
733        assert_eq!(s_sigma_1[1], WireData::Left(2));
734        assert_eq!(s_sigma_1[2], WireData::Left(3));
735        assert_eq!(s_sigma_1[3], WireData::Left(0));
736
737        // Check the right sigma polynomial
738        assert_eq!(s_sigma_2[0], WireData::Left(1));
739        assert_eq!(s_sigma_2[1], WireData::Right(1));
740        assert_eq!(s_sigma_2[2], WireData::Right(2));
741        assert_eq!(s_sigma_2[3], WireData::Right(3));
742
743        // Check the output sigma polynomial
744        assert_eq!(s_sigma_3[0], WireData::Output(0));
745        assert_eq!(s_sigma_3[1], WireData::Output(1));
746        assert_eq!(s_sigma_3[2], WireData::Output(2));
747        assert_eq!(s_sigma_3[3], WireData::Output(3));
748
749        // Check the output sigma polynomial
750        assert_eq!(s_sigma_4[0], WireData::Fourth(1));
751        assert_eq!(s_sigma_4[1], WireData::Fourth(2));
752        assert_eq!(s_sigma_4[2], WireData::Fourth(3));
753        assert_eq!(s_sigma_4[3], WireData::Fourth(0));
754
755        let domain = EvaluationDomain::new(num_wire_mappings).unwrap();
756        let w = domain.group_gen;
757        let w_squared = w.pow(&[2, 0, 0, 0]);
758        let w_cubed = w.pow(&[3, 0, 0, 0]);
759
760        // Check the left sigmas have been encoded properly
761        // s_sigma_1 = {R0, L2, L3, L0}
762        // Should turn into {1 * K1, w^2, w^3, 1}
763        let encoded_s_sigma_1 =
764            perm.compute_permutation_lagrange(s_sigma_1, &domain);
765        assert_eq!(encoded_s_sigma_1[0], BlsScalar::one() * K1);
766        assert_eq!(encoded_s_sigma_1[1], w_squared);
767        assert_eq!(encoded_s_sigma_1[2], w_cubed);
768        assert_eq!(encoded_s_sigma_1[3], BlsScalar::one());
769
770        // Check the right sigmas have been encoded properly
771        // s_sigma_2 = {L1, R1, R2, R3}
772        // Should turn into {w, w * K1, w^2 * K1, w^3 * K1}
773        let encoded_s_sigma_2 =
774            perm.compute_permutation_lagrange(s_sigma_2, &domain);
775        assert_eq!(encoded_s_sigma_2[0], w);
776        assert_eq!(encoded_s_sigma_2[1], w * K1);
777        assert_eq!(encoded_s_sigma_2[2], w_squared * K1);
778        assert_eq!(encoded_s_sigma_2[3], w_cubed * K1);
779
780        // Check the output sigmas have been encoded properly
781        // s_sigma_3 = {O0, O1, O2, O3}
782        // Should turn into {1 * K2, w * K2, w^2 * K2, w^3 * K2}
783        let encoded_s_sigma_3 =
784            perm.compute_permutation_lagrange(s_sigma_3, &domain);
785        assert_eq!(encoded_s_sigma_3[0], BlsScalar::one() * K2);
786        assert_eq!(encoded_s_sigma_3[1], w * K2);
787        assert_eq!(encoded_s_sigma_3[2], w_squared * K2);
788        assert_eq!(encoded_s_sigma_3[3], w_cubed * K2);
789
790        // Check the fourth sigmas have been encoded properly
791        // s_sigma_4 = {F1, F2, F3, F0}
792        // Should turn into {w * K3, w^2 * K3, w^3 * K3, 1 * K3}
793        let encoded_s_sigma_4 =
794            perm.compute_permutation_lagrange(s_sigma_4, &domain);
795        assert_eq!(encoded_s_sigma_4[0], w * K3);
796        assert_eq!(encoded_s_sigma_4[1], w_squared * K3);
797        assert_eq!(encoded_s_sigma_4[2], w_cubed * K3);
798        assert_eq!(encoded_s_sigma_4[3], K3);
799
800        let a = vec![
801            BlsScalar::from(2),
802            BlsScalar::from(2),
803            BlsScalar::from(2),
804            BlsScalar::from(2),
805        ];
806        let b = vec![
807            BlsScalar::from(2),
808            BlsScalar::one(),
809            BlsScalar::one(),
810            BlsScalar::one(),
811        ];
812        let c = vec![
813            BlsScalar::one(),
814            BlsScalar::one(),
815            BlsScalar::one(),
816            BlsScalar::one(),
817        ];
818        let d = vec![
819            BlsScalar::one(),
820            BlsScalar::one(),
821            BlsScalar::one(),
822            BlsScalar::one(),
823        ];
824
825        test_correct_permutation_poly(
826            num_wire_mappings,
827            perm,
828            &domain,
829            a,
830            b,
831            c,
832            d,
833        );
834    }
835
836    #[test]
837    fn test_permutation_compute_sigmas() {
838        let mut perm: Permutation = Permutation::new();
839
840        let var_one = perm.new_witness();
841        let var_two = perm.new_witness();
842        let var_three = perm.new_witness();
843        let var_four = perm.new_witness();
844
845        let num_wire_mappings = 4;
846
847        // Add four wire mappings
848        perm.add_witnesses_to_map(var_one, var_one, var_two, var_four, 0);
849        perm.add_witnesses_to_map(var_two, var_one, var_two, var_four, 1);
850        perm.add_witnesses_to_map(var_three, var_three, var_one, var_four, 2);
851        perm.add_witnesses_to_map(var_two, var_one, var_three, var_four, 3);
852
853        /*
854        Below is a sketch of the map created by adding the specific witnesses into the map
855        var_one : {L0,R0, R1, O2, R3 }
856        var_two : {O0, L1, O1, L3}
857        var_three : {L2, R2, O3}
858        var_four : {F0, F1, F2, F3}
859        s_sigma_1 : {0,1,2,3} -> {R0,O1,R2,O0}
860        s_sigma_2 : {0,1,2,3} -> {R1, O2, O3, L0}
861        s_sigma_3 : {0,1,2,3} -> {L1, L3, R3, L2}
862        s_sigma_4 : {0,1,2,3} -> {F1, F2, F3, F0}
863        */
864        let sigmas = perm.compute_sigma_permutations(num_wire_mappings);
865        let s_sigma_1 = &sigmas[0];
866        let s_sigma_2 = &sigmas[1];
867        let s_sigma_3 = &sigmas[2];
868        let s_sigma_4 = &sigmas[3];
869
870        // Check the left sigma polynomial
871        assert_eq!(s_sigma_1[0], WireData::Right(0));
872        assert_eq!(s_sigma_1[1], WireData::Output(1));
873        assert_eq!(s_sigma_1[2], WireData::Right(2));
874        assert_eq!(s_sigma_1[3], WireData::Output(0));
875
876        // Check the right sigma polynomial
877        assert_eq!(s_sigma_2[0], WireData::Right(1));
878        assert_eq!(s_sigma_2[1], WireData::Output(2));
879        assert_eq!(s_sigma_2[2], WireData::Output(3));
880        assert_eq!(s_sigma_2[3], WireData::Left(0));
881
882        // Check the output sigma polynomial
883        assert_eq!(s_sigma_3[0], WireData::Left(1));
884        assert_eq!(s_sigma_3[1], WireData::Left(3));
885        assert_eq!(s_sigma_3[2], WireData::Right(3));
886        assert_eq!(s_sigma_3[3], WireData::Left(2));
887
888        // Check the fourth sigma polynomial
889        assert_eq!(s_sigma_4[0], WireData::Fourth(1));
890        assert_eq!(s_sigma_4[1], WireData::Fourth(2));
891        assert_eq!(s_sigma_4[2], WireData::Fourth(3));
892        assert_eq!(s_sigma_4[3], WireData::Fourth(0));
893
894        /*
895        Check that the unique encodings of the sigma polynomials have been computed properly
896        s_sigma_1 : {R0,O1,R2,O0}
897            When encoded using w, K1,K2,K3 we have {1 * K1, w * K2, w^2 * K1, 1 * K2}
898        s_sigma_2 : {R1, O2, O3, L0}
899            When encoded using w, K1,K2,K3 we have {w * K1, w^2 * K2, w^3 * K2, 1}
900        s_sigma_3 : {L1, L3, R3, L2}
901            When encoded using w, K1, K2,K3 we have {w, w^3 , w^3 * K1, w^2}
902        s_sigma_4 : {0,1,2,3} -> {F1, F2, F3, F0}
903            When encoded using w, K1, K2,K3 we have {w * K3, w^2 * K3, w^3 * K3, 1 * K3}
904        */
905        let domain = EvaluationDomain::new(num_wire_mappings).unwrap();
906        let w = domain.group_gen;
907        let w_squared = w.pow(&[2, 0, 0, 0]);
908        let w_cubed = w.pow(&[3, 0, 0, 0]);
909        // check the left sigmas have been encoded properly
910        let encoded_s_sigma_1 =
911            perm.compute_permutation_lagrange(s_sigma_1, &domain);
912        assert_eq!(encoded_s_sigma_1[0], K1);
913        assert_eq!(encoded_s_sigma_1[1], w * K2);
914        assert_eq!(encoded_s_sigma_1[2], w_squared * K1);
915        assert_eq!(encoded_s_sigma_1[3], BlsScalar::one() * K2);
916
917        // check the right sigmas have been encoded properly
918        let encoded_s_sigma_2 =
919            perm.compute_permutation_lagrange(s_sigma_2, &domain);
920        assert_eq!(encoded_s_sigma_2[0], w * K1);
921        assert_eq!(encoded_s_sigma_2[1], w_squared * K2);
922        assert_eq!(encoded_s_sigma_2[2], w_cubed * K2);
923        assert_eq!(encoded_s_sigma_2[3], BlsScalar::one());
924
925        // check the output sigmas have been encoded properly
926        let encoded_s_sigma_3 =
927            perm.compute_permutation_lagrange(s_sigma_3, &domain);
928        assert_eq!(encoded_s_sigma_3[0], w);
929        assert_eq!(encoded_s_sigma_3[1], w_cubed);
930        assert_eq!(encoded_s_sigma_3[2], w_cubed * K1);
931        assert_eq!(encoded_s_sigma_3[3], w_squared);
932
933        // check the fourth sigmas have been encoded properly
934        let encoded_s_sigma_4 =
935            perm.compute_permutation_lagrange(s_sigma_4, &domain);
936        assert_eq!(encoded_s_sigma_4[0], w * K3);
937        assert_eq!(encoded_s_sigma_4[1], w_squared * K3);
938        assert_eq!(encoded_s_sigma_4[2], w_cubed * K3);
939        assert_eq!(encoded_s_sigma_4[3], K3);
940    }
941
942    #[test]
943    fn test_basic_slow_permutation_poly() {
944        let num_wire_mappings = 2;
945        let mut perm = Permutation::new();
946        let domain = EvaluationDomain::new(num_wire_mappings).unwrap();
947
948        let var_one = perm.new_witness();
949        let var_two = perm.new_witness();
950        let var_three = perm.new_witness();
951        let var_four = perm.new_witness();
952
953        perm.add_witnesses_to_map(var_one, var_two, var_three, var_four, 0);
954        perm.add_witnesses_to_map(var_three, var_two, var_one, var_four, 1);
955
956        let a: Vec<_> = vec![BlsScalar::one(), BlsScalar::from(3)];
957        let b: Vec<_> = vec![BlsScalar::from(2), BlsScalar::from(2)];
958        let c: Vec<_> = vec![BlsScalar::from(3), BlsScalar::one()];
959        let d: Vec<_> = vec![BlsScalar::one(), BlsScalar::one()];
960
961        test_correct_permutation_poly(
962            num_wire_mappings,
963            perm,
964            &domain,
965            a,
966            b,
967            c,
968            d,
969        );
970    }
971
972    // shifts the polynomials by one root of unity
973    fn shift_poly_by_one(z_coefficients: Vec<BlsScalar>) -> Vec<BlsScalar> {
974        let mut shifted_z_coefficients = z_coefficients;
975        shifted_z_coefficients.push(shifted_z_coefficients[0]);
976        shifted_z_coefficients.remove(0);
977        shifted_z_coefficients
978    }
979
980    fn test_correct_permutation_poly(
981        n: usize,
982        mut perm: Permutation,
983        domain: &EvaluationDomain,
984        a: Vec<BlsScalar>,
985        b: Vec<BlsScalar>,
986        c: Vec<BlsScalar>,
987        d: Vec<BlsScalar>,
988    ) {
989        // 0. Generate beta and gamma challenges
990        //
991        let beta = BlsScalar::random(&mut OsRng);
992        let gamma = BlsScalar::random(&mut OsRng);
993        assert_ne!(gamma, beta);
994
995        // 1. Compute the permutation polynomial using both methods
996        let [
997            s_sigma_1_poly,
998            s_sigma_2_poly,
999            s_sigma_3_poly,
1000            s_sigma_4_poly,
1001        ] = perm.compute_sigma_polynomials(n, domain);
1002        let (z_vec, numerator_components, denominator_components) =
1003            compute_slow_permutation_poly(
1004                domain,
1005                a.clone().into_iter(),
1006                b.clone().into_iter(),
1007                c.clone().into_iter(),
1008                d.clone().into_iter(),
1009                &beta,
1010                &gamma,
1011                (
1012                    &s_sigma_1_poly,
1013                    &s_sigma_2_poly,
1014                    &s_sigma_3_poly,
1015                    &s_sigma_4_poly,
1016                ),
1017            );
1018
1019        let fast_z_vec = compute_fast_permutation_poly(
1020            domain,
1021            &a,
1022            &b,
1023            &c,
1024            &d,
1025            &beta,
1026            &gamma,
1027            (
1028                &s_sigma_1_poly,
1029                &s_sigma_2_poly,
1030                &s_sigma_3_poly,
1031                &s_sigma_4_poly,
1032            ),
1033        );
1034        assert_eq!(fast_z_vec, z_vec);
1035
1036        // 2. First we perform basic tests on the permutation vector
1037        //
1038        // Check that the vector has length `n` and that the first element is
1039        // `1`
1040        assert_eq!(z_vec.len(), n);
1041        assert_eq!(&z_vec[0], &BlsScalar::one());
1042        //
1043        // Check that the \prod{f_i} / \prod{g_i} = 1
1044        // Where f_i and g_i are the numerator and denominator components in the
1045        // permutation polynomial
1046        let (mut a_0, mut b_0) = (BlsScalar::one(), BlsScalar::one());
1047        for n in numerator_components.iter() {
1048            a_0 *= n;
1049        }
1050        for n in denominator_components.iter() {
1051            b_0 *= n;
1052        }
1053        assert_eq!(a_0 * b_0.invert().unwrap(), BlsScalar::one());
1054
1055        // 3. Now we perform the two checks that need to be done on the
1056        // permutation polynomial (z)
1057        let z_poly = Polynomial::from_coefficients_vec(domain.ifft(&z_vec));
1058        //
1059        // Check that z(w^{n+1}) == z(1) == 1
1060        // This is the first check in the protocol
1061        assert_eq!(z_poly.evaluate(&BlsScalar::one()), BlsScalar::one());
1062        let n_plus_one = domain.elements().last().unwrap() * domain.group_gen;
1063        assert_eq!(z_poly.evaluate(&n_plus_one), BlsScalar::one());
1064        //
1065        // Check that when z is unblinded, it has the correct degree
1066        assert_eq!(z_poly.degree(), n - 1);
1067        //
1068        // Check relationship between z(X) and z(Xw)
1069        // This is the second check in the protocol
1070        let roots: Vec<_> = domain.elements().collect();
1071
1072        for i in 1..roots.len() {
1073            let current_root = roots[i];
1074            let next_root = current_root * domain.group_gen;
1075
1076            let current_identity_perm_product = &numerator_components[i];
1077            assert_ne!(current_identity_perm_product, &BlsScalar::zero());
1078
1079            let current_copy_perm_product = &denominator_components[i];
1080            assert_ne!(current_copy_perm_product, &BlsScalar::zero());
1081
1082            assert_ne!(
1083                current_copy_perm_product,
1084                current_identity_perm_product
1085            );
1086
1087            let z_eval = z_poly.evaluate(&current_root);
1088            assert_ne!(z_eval, BlsScalar::zero());
1089
1090            let z_eval_shifted = z_poly.evaluate(&next_root);
1091            assert_ne!(z_eval_shifted, BlsScalar::zero());
1092
1093            // Z(Xw) * copy_perm
1094            let lhs = z_eval_shifted * current_copy_perm_product;
1095            // Z(X) * iden_perm
1096            let rhs = z_eval * current_identity_perm_product;
1097            assert_eq!(
1098                lhs, rhs,
1099                "check failed at index: {}\'n lhs is : {:?} \n rhs is :{:?}",
1100                i, lhs, rhs
1101            );
1102        }
1103
1104        // Test that the shifted polynomial is correct
1105        let shifted_z = shift_poly_by_one(fast_z_vec);
1106        let shifted_z_poly =
1107            Polynomial::from_coefficients_vec(domain.ifft(&shifted_z));
1108        for element in domain.elements() {
1109            let z_eval = z_poly.evaluate(&(element * domain.group_gen));
1110            let shifted_z_eval = shifted_z_poly.evaluate(&element);
1111
1112            assert_eq!(z_eval, shifted_z_eval)
1113        }
1114    }
1115}