../../.cargo/katex-header.html

plonky2/gates/
arithmetic_extension.rs

1#[cfg(not(feature = "std"))]
2use alloc::{
3    format,
4    string::{String, ToString},
5    vec::Vec,
6};
7use core::ops::Range;
8
9use anyhow::Result;
10
11use crate::field::extension::{Extendable, FieldExtension};
12use crate::gates::gate::Gate;
13use crate::gates::util::StridedConstraintConsumer;
14use crate::hash::hash_types::RichField;
15use crate::iop::ext_target::ExtensionTarget;
16use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGeneratorRef};
17use crate::iop::target::Target;
18use crate::iop::witness::{PartitionWitness, Witness, WitnessWrite};
19use crate::plonk::circuit_builder::CircuitBuilder;
20use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData};
21use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
22use crate::util::serialization::{Buffer, IoResult, Read, Write};
23
24/// A gate which can perform a weighted multiply-add, i.e. `result = c0.x.y + c1.z`. If the config
25/// has enough routed wires, it can support several such operations in one gate.
26#[derive(Debug, Clone)]
27pub struct ArithmeticExtensionGate<const D: usize> {
28    /// Number of arithmetic operations performed by an arithmetic gate.
29    pub num_ops: usize,
30}
31
32impl<const D: usize> ArithmeticExtensionGate<D> {
33    pub const fn new_from_config(config: &CircuitConfig) -> Self {
34        Self {
35            num_ops: Self::num_ops(config),
36        }
37    }
38
39    /// Determine the maximum number of operations that can fit in one gate for the given config.
40    pub(crate) const fn num_ops(config: &CircuitConfig) -> usize {
41        let wires_per_op = 4 * D;
42        config.num_routed_wires / wires_per_op
43    }
44
45    pub(crate) const fn wires_ith_multiplicand_0(i: usize) -> Range<usize> {
46        4 * D * i..4 * D * i + D
47    }
48    pub(crate) const fn wires_ith_multiplicand_1(i: usize) -> Range<usize> {
49        4 * D * i + D..4 * D * i + 2 * D
50    }
51    pub(crate) const fn wires_ith_addend(i: usize) -> Range<usize> {
52        4 * D * i + 2 * D..4 * D * i + 3 * D
53    }
54    pub(crate) const fn wires_ith_output(i: usize) -> Range<usize> {
55        4 * D * i + 3 * D..4 * D * i + 4 * D
56    }
57}
58
59impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D> {
60    fn id(&self) -> String {
61        format!("{self:?}")
62    }
63
64    fn serialize(&self, dst: &mut Vec<u8>, _common_data: &CommonCircuitData<F, D>) -> IoResult<()> {
65        dst.write_usize(self.num_ops)
66    }
67
68    fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData<F, D>) -> IoResult<Self> {
69        let num_ops = src.read_usize()?;
70        Ok(Self { num_ops })
71    }
72
73    fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
74        let const_0 = vars.local_constants[0];
75        let const_1 = vars.local_constants[1];
76
77        let mut constraints = Vec::with_capacity(self.num_ops * D);
78        for i in 0..self.num_ops {
79            let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i));
80            let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i));
81            let addend = vars.get_local_ext_algebra(Self::wires_ith_addend(i));
82            let output = vars.get_local_ext_algebra(Self::wires_ith_output(i));
83            let computed_output =
84                (multiplicand_0 * multiplicand_1).scalar_mul(const_0) + addend.scalar_mul(const_1);
85
86            constraints.extend((output - computed_output).to_basefield_array());
87        }
88
89        constraints
90    }
91
92    fn eval_unfiltered_base_one(
93        &self,
94        vars: EvaluationVarsBase<F>,
95        mut yield_constr: StridedConstraintConsumer<F>,
96    ) {
97        let const_0 = vars.local_constants[0];
98        let const_1 = vars.local_constants[1];
99
100        for i in 0..self.num_ops {
101            let multiplicand_0 = vars.get_local_ext(Self::wires_ith_multiplicand_0(i));
102            let multiplicand_1 = vars.get_local_ext(Self::wires_ith_multiplicand_1(i));
103            let addend = vars.get_local_ext(Self::wires_ith_addend(i));
104            let output = vars.get_local_ext(Self::wires_ith_output(i));
105            let computed_output =
106                (multiplicand_0 * multiplicand_1).scalar_mul(const_0) + addend.scalar_mul(const_1);
107
108            yield_constr.many((output - computed_output).to_basefield_array());
109        }
110    }
111
112    fn eval_unfiltered_circuit(
113        &self,
114        builder: &mut CircuitBuilder<F, D>,
115        vars: EvaluationTargets<D>,
116    ) -> Vec<ExtensionTarget<D>> {
117        let const_0 = vars.local_constants[0];
118        let const_1 = vars.local_constants[1];
119
120        let mut constraints = Vec::with_capacity(self.num_ops * D);
121        for i in 0..self.num_ops {
122            let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i));
123            let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i));
124            let addend = vars.get_local_ext_algebra(Self::wires_ith_addend(i));
125            let output = vars.get_local_ext_algebra(Self::wires_ith_output(i));
126            let computed_output = {
127                let mul = builder.mul_ext_algebra(multiplicand_0, multiplicand_1);
128                let scaled_mul = builder.scalar_mul_ext_algebra(const_0, mul);
129                builder.scalar_mul_add_ext_algebra(const_1, addend, scaled_mul)
130            };
131
132            let diff = builder.sub_ext_algebra(output, computed_output);
133            constraints.extend(diff.to_ext_target_array());
134        }
135
136        constraints
137    }
138
139    fn generators(&self, row: usize, local_constants: &[F]) -> Vec<WitnessGeneratorRef<F, D>> {
140        (0..self.num_ops)
141            .map(|i| {
142                WitnessGeneratorRef::new(
143                    ArithmeticExtensionGenerator {
144                        row,
145                        const_0: local_constants[0],
146                        const_1: local_constants[1],
147                        i,
148                    }
149                    .adapter(),
150                )
151            })
152            .collect()
153    }
154
155    fn num_wires(&self) -> usize {
156        self.num_ops * 4 * D
157    }
158
159    fn num_constants(&self) -> usize {
160        2
161    }
162
163    fn degree(&self) -> usize {
164        3
165    }
166
167    fn num_constraints(&self) -> usize {
168        self.num_ops * D
169    }
170}
171
172#[derive(Clone, Debug, Default)]
173pub struct ArithmeticExtensionGenerator<F: RichField + Extendable<D>, const D: usize> {
174    row: usize,
175    const_0: F,
176    const_1: F,
177    i: usize,
178}
179
180impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D>
181    for ArithmeticExtensionGenerator<F, D>
182{
183    fn id(&self) -> String {
184        "ArithmeticExtensionGenerator".to_string()
185    }
186
187    fn dependencies(&self) -> Vec<Target> {
188        ArithmeticExtensionGate::<D>::wires_ith_multiplicand_0(self.i)
189            .chain(ArithmeticExtensionGate::<D>::wires_ith_multiplicand_1(
190                self.i,
191            ))
192            .chain(ArithmeticExtensionGate::<D>::wires_ith_addend(self.i))
193            .map(|i| Target::wire(self.row, i))
194            .collect()
195    }
196
197    fn run_once(
198        &self,
199        witness: &PartitionWitness<F>,
200        out_buffer: &mut GeneratedValues<F>,
201    ) -> Result<()> {
202        let extract_extension = |range: Range<usize>| -> F::Extension {
203            let t = ExtensionTarget::from_range(self.row, range);
204            witness.get_extension_target(t)
205        };
206
207        let multiplicand_0 = extract_extension(
208            ArithmeticExtensionGate::<D>::wires_ith_multiplicand_0(self.i),
209        );
210        let multiplicand_1 = extract_extension(
211            ArithmeticExtensionGate::<D>::wires_ith_multiplicand_1(self.i),
212        );
213        let addend = extract_extension(ArithmeticExtensionGate::<D>::wires_ith_addend(self.i));
214
215        let output_target = ExtensionTarget::from_range(
216            self.row,
217            ArithmeticExtensionGate::<D>::wires_ith_output(self.i),
218        );
219
220        let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(self.const_0)
221            + addend.scalar_mul(self.const_1);
222
223        out_buffer.set_extension_target(output_target, computed_output)
224    }
225
226    fn serialize(&self, dst: &mut Vec<u8>, _common_data: &CommonCircuitData<F, D>) -> IoResult<()> {
227        dst.write_usize(self.row)?;
228        dst.write_field(self.const_0)?;
229        dst.write_field(self.const_1)?;
230        dst.write_usize(self.i)
231    }
232
233    fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData<F, D>) -> IoResult<Self> {
234        let row = src.read_usize()?;
235        let const_0 = src.read_field()?;
236        let const_1 = src.read_field()?;
237        let i = src.read_usize()?;
238        Ok(Self {
239            row,
240            const_0,
241            const_1,
242            i,
243        })
244    }
245}
246
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use crate::field::goldilocks_field::GoldilocksField;
    use crate::gates::arithmetic_extension::ArithmeticExtensionGate;
    use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
    use crate::plonk::circuit_data::CircuitConfig;
    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};

    /// A gate sized for the standard recursion config at extension degree `D`.
    fn recursion_gate<const D: usize>() -> ArithmeticExtensionGate<D> {
        let config = CircuitConfig::standard_recursion_config();
        ArithmeticExtensionGate::new_from_config(&config)
    }

    #[test]
    fn low_degree() {
        test_low_degree::<GoldilocksField, _, 4>(recursion_gate::<4>());
    }

    #[test]
    fn eval_fns() -> Result<()> {
        const D: usize = 2;
        type C = PoseidonGoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;
        test_eval_fns::<F, C, _, D>(recursion_gate::<D>())
    }
}