plonky2/gates/
arithmetic_extension.rs1#[cfg(not(feature = "std"))]
2use alloc::{
3 format,
4 string::{String, ToString},
5 vec::Vec,
6};
7use core::ops::Range;
8
9use anyhow::Result;
10
11use crate::field::extension::{Extendable, FieldExtension};
12use crate::gates::gate::Gate;
13use crate::gates::util::StridedConstraintConsumer;
14use crate::hash::hash_types::RichField;
15use crate::iop::ext_target::ExtensionTarget;
16use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGeneratorRef};
17use crate::iop::target::Target;
18use crate::iop::witness::{PartitionWitness, Witness, WitnessWrite};
19use crate::plonk::circuit_builder::CircuitBuilder;
20use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData};
21use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
22use crate::util::serialization::{Buffer, IoResult, Read, Write};
23
/// A gate performing one or more weighted multiply-adds over extension elements:
/// for each operation, `output = const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`,
/// where `const_0` and `const_1` are the gate's two local constants.
#[derive(Clone, Debug)]
pub struct ArithmeticExtensionGate<const D: usize> {
    /// How many independent multiply-add operations this gate instance packs.
    pub num_ops: usize,
}
31
32impl<const D: usize> ArithmeticExtensionGate<D> {
33 pub const fn new_from_config(config: &CircuitConfig) -> Self {
34 Self {
35 num_ops: Self::num_ops(config),
36 }
37 }
38
39 pub(crate) const fn num_ops(config: &CircuitConfig) -> usize {
41 let wires_per_op = 4 * D;
42 config.num_routed_wires / wires_per_op
43 }
44
45 pub(crate) const fn wires_ith_multiplicand_0(i: usize) -> Range<usize> {
46 4 * D * i..4 * D * i + D
47 }
48 pub(crate) const fn wires_ith_multiplicand_1(i: usize) -> Range<usize> {
49 4 * D * i + D..4 * D * i + 2 * D
50 }
51 pub(crate) const fn wires_ith_addend(i: usize) -> Range<usize> {
52 4 * D * i + 2 * D..4 * D * i + 3 * D
53 }
54 pub(crate) const fn wires_ith_output(i: usize) -> Range<usize> {
55 4 * D * i + 3 * D..4 * D * i + 4 * D
56 }
57}
58
59impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D> {
60 fn id(&self) -> String {
61 format!("{self:?}")
62 }
63
64 fn serialize(&self, dst: &mut Vec<u8>, _common_data: &CommonCircuitData<F, D>) -> IoResult<()> {
65 dst.write_usize(self.num_ops)
66 }
67
68 fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData<F, D>) -> IoResult<Self> {
69 let num_ops = src.read_usize()?;
70 Ok(Self { num_ops })
71 }
72
73 fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
74 let const_0 = vars.local_constants[0];
75 let const_1 = vars.local_constants[1];
76
77 let mut constraints = Vec::with_capacity(self.num_ops * D);
78 for i in 0..self.num_ops {
79 let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i));
80 let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i));
81 let addend = vars.get_local_ext_algebra(Self::wires_ith_addend(i));
82 let output = vars.get_local_ext_algebra(Self::wires_ith_output(i));
83 let computed_output =
84 (multiplicand_0 * multiplicand_1).scalar_mul(const_0) + addend.scalar_mul(const_1);
85
86 constraints.extend((output - computed_output).to_basefield_array());
87 }
88
89 constraints
90 }
91
92 fn eval_unfiltered_base_one(
93 &self,
94 vars: EvaluationVarsBase<F>,
95 mut yield_constr: StridedConstraintConsumer<F>,
96 ) {
97 let const_0 = vars.local_constants[0];
98 let const_1 = vars.local_constants[1];
99
100 for i in 0..self.num_ops {
101 let multiplicand_0 = vars.get_local_ext(Self::wires_ith_multiplicand_0(i));
102 let multiplicand_1 = vars.get_local_ext(Self::wires_ith_multiplicand_1(i));
103 let addend = vars.get_local_ext(Self::wires_ith_addend(i));
104 let output = vars.get_local_ext(Self::wires_ith_output(i));
105 let computed_output =
106 (multiplicand_0 * multiplicand_1).scalar_mul(const_0) + addend.scalar_mul(const_1);
107
108 yield_constr.many((output - computed_output).to_basefield_array());
109 }
110 }
111
112 fn eval_unfiltered_circuit(
113 &self,
114 builder: &mut CircuitBuilder<F, D>,
115 vars: EvaluationTargets<D>,
116 ) -> Vec<ExtensionTarget<D>> {
117 let const_0 = vars.local_constants[0];
118 let const_1 = vars.local_constants[1];
119
120 let mut constraints = Vec::with_capacity(self.num_ops * D);
121 for i in 0..self.num_ops {
122 let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i));
123 let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i));
124 let addend = vars.get_local_ext_algebra(Self::wires_ith_addend(i));
125 let output = vars.get_local_ext_algebra(Self::wires_ith_output(i));
126 let computed_output = {
127 let mul = builder.mul_ext_algebra(multiplicand_0, multiplicand_1);
128 let scaled_mul = builder.scalar_mul_ext_algebra(const_0, mul);
129 builder.scalar_mul_add_ext_algebra(const_1, addend, scaled_mul)
130 };
131
132 let diff = builder.sub_ext_algebra(output, computed_output);
133 constraints.extend(diff.to_ext_target_array());
134 }
135
136 constraints
137 }
138
139 fn generators(&self, row: usize, local_constants: &[F]) -> Vec<WitnessGeneratorRef<F, D>> {
140 (0..self.num_ops)
141 .map(|i| {
142 WitnessGeneratorRef::new(
143 ArithmeticExtensionGenerator {
144 row,
145 const_0: local_constants[0],
146 const_1: local_constants[1],
147 i,
148 }
149 .adapter(),
150 )
151 })
152 .collect()
153 }
154
155 fn num_wires(&self) -> usize {
156 self.num_ops * 4 * D
157 }
158
159 fn num_constants(&self) -> usize {
160 2
161 }
162
163 fn degree(&self) -> usize {
164 3
165 }
166
167 fn num_constraints(&self) -> usize {
168 self.num_ops * D
169 }
170}
171
172#[derive(Clone, Debug, Default)]
173pub struct ArithmeticExtensionGenerator<F: RichField + Extendable<D>, const D: usize> {
174 row: usize,
175 const_0: F,
176 const_1: F,
177 i: usize,
178}
179
180impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D>
181 for ArithmeticExtensionGenerator<F, D>
182{
183 fn id(&self) -> String {
184 "ArithmeticExtensionGenerator".to_string()
185 }
186
187 fn dependencies(&self) -> Vec<Target> {
188 ArithmeticExtensionGate::<D>::wires_ith_multiplicand_0(self.i)
189 .chain(ArithmeticExtensionGate::<D>::wires_ith_multiplicand_1(
190 self.i,
191 ))
192 .chain(ArithmeticExtensionGate::<D>::wires_ith_addend(self.i))
193 .map(|i| Target::wire(self.row, i))
194 .collect()
195 }
196
197 fn run_once(
198 &self,
199 witness: &PartitionWitness<F>,
200 out_buffer: &mut GeneratedValues<F>,
201 ) -> Result<()> {
202 let extract_extension = |range: Range<usize>| -> F::Extension {
203 let t = ExtensionTarget::from_range(self.row, range);
204 witness.get_extension_target(t)
205 };
206
207 let multiplicand_0 = extract_extension(
208 ArithmeticExtensionGate::<D>::wires_ith_multiplicand_0(self.i),
209 );
210 let multiplicand_1 = extract_extension(
211 ArithmeticExtensionGate::<D>::wires_ith_multiplicand_1(self.i),
212 );
213 let addend = extract_extension(ArithmeticExtensionGate::<D>::wires_ith_addend(self.i));
214
215 let output_target = ExtensionTarget::from_range(
216 self.row,
217 ArithmeticExtensionGate::<D>::wires_ith_output(self.i),
218 );
219
220 let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(self.const_0)
221 + addend.scalar_mul(self.const_1);
222
223 out_buffer.set_extension_target(output_target, computed_output)
224 }
225
226 fn serialize(&self, dst: &mut Vec<u8>, _common_data: &CommonCircuitData<F, D>) -> IoResult<()> {
227 dst.write_usize(self.row)?;
228 dst.write_field(self.const_0)?;
229 dst.write_field(self.const_1)?;
230 dst.write_usize(self.i)
231 }
232
233 fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData<F, D>) -> IoResult<Self> {
234 let row = src.read_usize()?;
235 let const_0 = src.read_field()?;
236 let const_1 = src.read_field()?;
237 let i = src.read_usize()?;
238 Ok(Self {
239 row,
240 const_0,
241 const_1,
242 i,
243 })
244 }
245}
246
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use crate::field::goldilocks_field::GoldilocksField;
    use crate::gates::arithmetic_extension::ArithmeticExtensionGate;
    use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
    use crate::plonk::circuit_data::CircuitConfig;
    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};

    /// Gate sized for the standard recursion config, shared by both tests.
    fn standard_gate<const D: usize>() -> ArithmeticExtensionGate<D> {
        ArithmeticExtensionGate::new_from_config(&CircuitConfig::standard_recursion_config())
    }

    #[test]
    fn low_degree() {
        test_low_degree::<GoldilocksField, _, 4>(standard_gate::<4>());
    }

    #[test]
    fn eval_fns() -> Result<()> {
        const D: usize = 2;
        type C = PoseidonGoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;
        test_eval_fns::<F, C, _, D>(standard_gate::<D>())
    }
}
273}