acir/circuit/
mod.rs

1pub mod black_box_functions;
2pub mod brillig;
3pub mod directives;
4pub mod opcodes;
5
6use crate::native_types::{Expression, Witness};
7use acir_field::FieldElement;
8pub use opcodes::Opcode;
9use thiserror::Error;
10
11use std::{io::prelude::*, num::ParseIntError, str::FromStr};
12
13use base64::Engine;
14use flate2::Compression;
15use serde::{de::Error as DeserializationError, Deserialize, Deserializer, Serialize, Serializer};
16
17use std::collections::BTreeSet;
18
19use self::{brillig::BrilligBytecode, opcodes::BlockId};
20
21/// Specifies the maximum width of the expressions which will be constrained.
22///
23/// Unbounded Expressions are useful if you are eventually going to pass the ACIR
24/// into a proving system which supports R1CS.
25///
26/// Bounded Expressions are useful if you are eventually going to pass the ACIR
27/// into a proving system which supports PLONK, where arithmetic expressions have a
28/// finite fan-in.
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
30pub enum ExpressionWidth {
31    #[default]
32    Unbounded,
33    Bounded {
34        width: usize,
35    },
36}
37
38/// A program represented by multiple ACIR circuits. The execution trace of these
39/// circuits is dictated by construction of the [crate::native_types::WitnessStack].
40#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
41pub struct Program {
42    pub functions: Vec<Circuit>,
43    pub unconstrained_functions: Vec<BrilligBytecode>,
44}
45
46#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
47pub struct Circuit {
48    // current_witness_index is the highest witness index in the circuit. The next witness to be added to this circuit
49    // will take on this value. (The value is cached here as an optimization.)
50    pub current_witness_index: u32,
51    pub opcodes: Vec<Opcode>,
52    pub expression_width: ExpressionWidth,
53
54    /// The set of private inputs to the circuit.
55    pub private_parameters: BTreeSet<Witness>,
56    // ACIR distinguishes between the public inputs which are provided externally or calculated within the circuit and returned.
57    // The elements of these sets may not be mutually exclusive, i.e. a parameter may be returned from the circuit.
58    // All public inputs (parameters and return values) must be provided to the verifier at verification time.
59    /// The set of public inputs provided by the prover.
60    pub public_parameters: PublicInputs,
61    /// The set of public inputs calculated within the circuit.
62    pub return_values: PublicInputs,
63    /// Maps opcode locations to failed assertion payloads.
64    /// The data in the payload is embedded in the circuit to provide useful feedback to users
65    /// when a constraint in the circuit is not satisfied.
66    ///
67    // Note: This should be a BTreeMap, but serde-reflect is creating invalid
68    // c++ code at the moment when it is, due to OpcodeLocation needing a comparison
69    // implementation which is never generated.
70    pub assert_messages: Vec<(OpcodeLocation, AssertionPayload)>,
71
72    /// States whether the backend should use a SNARK recursion friendly prover.
73    /// If implemented by a backend, this means that proofs generated with this circuit
74    /// will be friendly for recursively verifying inside of another SNARK.
75    pub recursive: bool,
76}
77
78#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
79pub enum ExpressionOrMemory {
80    Expression(Expression),
81    Memory(BlockId),
82}
83
84#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
85pub enum AssertionPayload {
86    StaticString(String),
87    Dynamic(/* error_selector */ u64, Vec<ExpressionOrMemory>),
88}
89
/// Opaque integer tag identifying which kind of error an assertion payload carries.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ErrorSelector(u64);

impl ErrorSelector {
    /// Wraps a raw integer as a selector.
    pub fn new(integer: u64) -> Self {
        Self(integer)
    }

    /// Returns the raw integer behind this selector.
    pub fn as_u64(&self) -> u64 {
        let Self(raw) = *self;
        raw
    }
}
102
103impl Serialize for ErrorSelector {
104    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
105    where
106        S: serde::Serializer,
107    {
108        self.0.to_string().serialize(serializer)
109    }
110}
111
112impl<'de> Deserialize<'de> for ErrorSelector {
113    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
114    where
115        D: serde::Deserializer<'de>,
116    {
117        let s: String = Deserialize::deserialize(deserializer)?;
118        let as_u64 = s.parse().map_err(serde::de::Error::custom)?;
119        Ok(ErrorSelector(as_u64))
120    }
121}
122
123/// This selector indicates that the payload is a string.
124/// This is used to parse any error with a string payload directly,
125/// to avoid users having to parse the error externally to the ACVM.
126/// Only non-string errors need to be parsed externally to the ACVM using the circuit ABI.
127pub const STRING_ERROR_SELECTOR: ErrorSelector = ErrorSelector(0);
128
129#[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
130pub struct RawAssertionPayload {
131    pub selector: ErrorSelector,
132    pub data: Vec<FieldElement>,
133}
134
135#[derive(Clone, PartialEq, Eq, Debug)]
136pub enum ResolvedAssertionPayload {
137    String(String),
138    Raw(RawAssertionPayload),
139}
140
141#[derive(Debug, Copy, Clone)]
142/// The opcode location for a call to a separate ACIR circuit
143/// This includes the function index of the caller within a [program][Program]
144/// and the index in the callers ACIR to the specific call opcode.
145/// This is only resolved and set during circuit execution.
146pub struct ResolvedOpcodeLocation {
147    pub acir_function_index: usize,
148    pub opcode_location: OpcodeLocation,
149}
150
151#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
152/// Opcodes are locatable so that callers can
153/// map opcodes to debug information related to their context.
154pub enum OpcodeLocation {
155    Acir(usize),
156    Brillig { acir_index: usize, brillig_index: usize },
157}
158
159impl std::fmt::Display for OpcodeLocation {
160    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161        match self {
162            OpcodeLocation::Acir(index) => write!(f, "{index}"),
163            OpcodeLocation::Brillig { acir_index, brillig_index } => {
164                write!(f, "{acir_index}.{brillig_index}")
165            }
166        }
167    }
168}
169
170#[derive(Error, Debug)]
171pub enum OpcodeLocationFromStrError {
172    #[error("Invalid opcode location string: {0}")]
173    InvalidOpcodeLocationString(String),
174}
175
176/// The implementation of display and FromStr allows serializing and deserializing a OpcodeLocation to a string.
177/// This is useful when used as key in a map that has to be serialized to JSON/TOML, for example when mapping an opcode to its metadata.
178impl FromStr for OpcodeLocation {
179    type Err = OpcodeLocationFromStrError;
180    fn from_str(s: &str) -> Result<Self, Self::Err> {
181        let parts: Vec<_> = s.split('.').collect();
182
183        if parts.is_empty() || parts.len() > 2 {
184            return Err(OpcodeLocationFromStrError::InvalidOpcodeLocationString(s.to_string()));
185        }
186
187        fn parse_components(parts: Vec<&str>) -> Result<OpcodeLocation, ParseIntError> {
188            match parts.len() {
189                1 => {
190                    let index = parts[0].parse()?;
191                    Ok(OpcodeLocation::Acir(index))
192                }
193                2 => {
194                    let acir_index = parts[0].parse()?;
195                    let brillig_index = parts[1].parse()?;
196                    Ok(OpcodeLocation::Brillig { acir_index, brillig_index })
197                }
198                _ => unreachable!("`OpcodeLocation` has too many components"),
199            }
200        }
201
202        parse_components(parts)
203            .map_err(|_| OpcodeLocationFromStrError::InvalidOpcodeLocationString(s.to_string()))
204    }
205}
206
207impl Circuit {
208    pub fn num_vars(&self) -> u32 {
209        self.current_witness_index + 1
210    }
211
212    /// Returns all witnesses which are required to execute the circuit successfully.
213    pub fn circuit_arguments(&self) -> BTreeSet<Witness> {
214        self.private_parameters.union(&self.public_parameters.0).cloned().collect()
215    }
216
217    /// Returns all public inputs. This includes those provided as parameters to the circuit and those
218    /// computed as return values.
219    pub fn public_inputs(&self) -> PublicInputs {
220        let public_inputs =
221            self.public_parameters.0.union(&self.return_values.0).cloned().collect();
222        PublicInputs(public_inputs)
223    }
224}
225
226impl Program {
227    fn write<W: std::io::Write>(&self, writer: W) -> std::io::Result<()> {
228        let buf = bincode::serialize(self).unwrap();
229        let mut encoder = flate2::write::GzEncoder::new(writer, Compression::default());
230        encoder.write_all(&buf)?;
231        encoder.finish()?;
232        Ok(())
233    }
234
235    fn read<R: std::io::Read>(reader: R) -> std::io::Result<Self> {
236        let mut gz_decoder = flate2::read::GzDecoder::new(reader);
237        let mut buf_d = Vec::new();
238        gz_decoder.read_to_end(&mut buf_d)?;
239        bincode::deserialize(&buf_d)
240            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))
241    }
242
243    pub fn serialize_program(program: &Program) -> Vec<u8> {
244        let mut program_bytes: Vec<u8> = Vec::new();
245        program.write(&mut program_bytes).expect("expected circuit to be serializable");
246        program_bytes
247    }
248
249    pub fn deserialize_program(serialized_circuit: &[u8]) -> std::io::Result<Self> {
250        Program::read(serialized_circuit)
251    }
252
253    // Serialize and base64 encode program
254    pub fn serialize_program_base64<S>(program: &Program, s: S) -> Result<S::Ok, S::Error>
255    where
256        S: Serializer,
257    {
258        let program_bytes = Program::serialize_program(program);
259        let encoded_b64 = base64::engine::general_purpose::STANDARD.encode(program_bytes);
260        s.serialize_str(&encoded_b64)
261    }
262
263    // Deserialize and base64 decode program
264    pub fn deserialize_program_base64<'de, D>(deserializer: D) -> Result<Program, D::Error>
265    where
266        D: Deserializer<'de>,
267    {
268        let bytecode_b64: String = serde::Deserialize::deserialize(deserializer)?;
269        let program_bytes = base64::engine::general_purpose::STANDARD
270            .decode(bytecode_b64)
271            .map_err(D::Error::custom)?;
272        let circuit = Self::deserialize_program(&program_bytes).map_err(D::Error::custom)?;
273        Ok(circuit)
274    }
275}
276
277impl std::fmt::Display for Circuit {
278    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279        writeln!(f, "current witness index : {}", self.current_witness_index)?;
280
281        let write_witness_indices =
282            |f: &mut std::fmt::Formatter<'_>, indices: &[u32]| -> Result<(), std::fmt::Error> {
283                write!(f, "[")?;
284                for (index, witness_index) in indices.iter().enumerate() {
285                    write!(f, "{witness_index}")?;
286                    if index != indices.len() - 1 {
287                        write!(f, ", ")?;
288                    }
289                }
290                writeln!(f, "]")
291            };
292
293        write!(f, "private parameters indices : ")?;
294        write_witness_indices(
295            f,
296            &self
297                .private_parameters
298                .iter()
299                .map(|witness| witness.witness_index())
300                .collect::<Vec<_>>(),
301        )?;
302
303        write!(f, "public parameters indices : ")?;
304        write_witness_indices(f, &self.public_parameters.indices())?;
305
306        write!(f, "return value indices : ")?;
307        write_witness_indices(f, &self.return_values.indices())?;
308
309        for opcode in &self.opcodes {
310            writeln!(f, "{opcode}")?;
311        }
312        Ok(())
313    }
314}
315
316impl std::fmt::Debug for Circuit {
317    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318        std::fmt::Display::fmt(self, f)
319    }
320}
321
322impl std::fmt::Display for Program {
323    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
324        for (func_index, function) in self.functions.iter().enumerate() {
325            writeln!(f, "func {}", func_index)?;
326            writeln!(f, "{}", function)?;
327        }
328        for (func_index, function) in self.unconstrained_functions.iter().enumerate() {
329            writeln!(f, "unconstrained func {}", func_index)?;
330            writeln!(f, "{:?}", function.bytecode)?;
331        }
332        Ok(())
333    }
334}
335
336impl std::fmt::Debug for Program {
337    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
338        std::fmt::Display::fmt(self, f)
339    }
340}
341
342#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
343pub struct PublicInputs(pub BTreeSet<Witness>);
344
345impl PublicInputs {
346    /// Returns the witness index of each public input
347    pub fn indices(&self) -> Vec<u32> {
348        self.0.iter().map(|witness| witness.witness_index()).collect()
349    }
350
351    pub fn contains(&self, index: usize) -> bool {
352        self.0.contains(&Witness(index as u32))
353    }
354}
355
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use super::{
        opcodes::{BlackBoxFuncCall, FunctionInput},
        Circuit, Compression, ErrorSelector, Opcode, OpcodeLocation, PublicInputs,
    };
    use crate::{
        circuit::{ExpressionWidth, Program},
        native_types::Witness,
    };
    use acir_field::FieldElement;

    fn and_opcode() -> Opcode {
        Opcode::BlackBoxFuncCall(BlackBoxFuncCall::AND {
            lhs: FunctionInput { witness: Witness(1), num_bits: 4 },
            rhs: FunctionInput { witness: Witness(2), num_bits: 4 },
            output: Witness(3),
        })
    }
    fn range_opcode() -> Opcode {
        Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE {
            input: FunctionInput { witness: Witness(1), num_bits: 8 },
        })
    }
    fn keccakf1600_opcode() -> Opcode {
        let inputs: Box<[FunctionInput; 25]> = Box::new(std::array::from_fn(|i| FunctionInput {
            witness: Witness(i as u32 + 1),
            num_bits: 8,
        }));
        let outputs: Box<[Witness; 25]> = Box::new(std::array::from_fn(|i| Witness(i as u32 + 26)));

        Opcode::BlackBoxFuncCall(BlackBoxFuncCall::Keccakf1600 { inputs, outputs })
    }
    fn schnorr_verify_opcode() -> Opcode {
        let public_key_x =
            FunctionInput { witness: Witness(1), num_bits: FieldElement::max_num_bits() };
        let public_key_y =
            FunctionInput { witness: Witness(2), num_bits: FieldElement::max_num_bits() };
        let signature: Box<[FunctionInput; 64]> = Box::new(std::array::from_fn(|i| {
            FunctionInput { witness: Witness(i as u32 + 3), num_bits: 8 }
        }));
        let message: Vec<FunctionInput> = vec![FunctionInput { witness: Witness(67), num_bits: 8 }];
        let output = Witness(68);

        Opcode::BlackBoxFuncCall(BlackBoxFuncCall::SchnorrVerify {
            public_key_x,
            public_key_y,
            signature,
            message,
            output,
        })
    }

    #[test]
    fn serialization_roundtrip() {
        let circuit = Circuit {
            current_witness_index: 5,
            expression_width: ExpressionWidth::Unbounded,
            opcodes: vec![and_opcode(), range_opcode(), schnorr_verify_opcode()],
            private_parameters: BTreeSet::new(),
            public_parameters: PublicInputs(BTreeSet::from_iter(vec![Witness(2), Witness(12)])),
            return_values: PublicInputs(BTreeSet::from_iter(vec![Witness(4), Witness(12)])),
            assert_messages: Default::default(),
            recursive: false,
        };
        let program = Program { functions: vec![circuit], unconstrained_functions: Vec::new() };

        fn read_write(program: Program) -> (Program, Program) {
            let bytes = Program::serialize_program(&program);
            let got_program = Program::deserialize_program(&bytes).unwrap();
            (program, got_program)
        }

        let (circ, got_circ) = read_write(program);
        assert_eq!(circ, got_circ);
    }

    #[test]
    fn test_serialize() {
        let circuit = Circuit {
            current_witness_index: 0,
            expression_width: ExpressionWidth::Unbounded,
            opcodes: vec![
                Opcode::AssertZero(crate::native_types::Expression {
                    mul_terms: vec![],
                    linear_combinations: vec![],
                    q_c: FieldElement::from(8u128),
                }),
                range_opcode(),
                and_opcode(),
                keccakf1600_opcode(),
                schnorr_verify_opcode(),
            ],
            private_parameters: BTreeSet::new(),
            public_parameters: PublicInputs(BTreeSet::from_iter(vec![Witness(2)])),
            return_values: PublicInputs(BTreeSet::from_iter(vec![Witness(2)])),
            assert_messages: Default::default(),
            recursive: false,
        };
        let program = Program { functions: vec![circuit], unconstrained_functions: Vec::new() };

        let json = serde_json::to_string_pretty(&program).unwrap();

        let deserialized = serde_json::from_str(&json).unwrap();
        assert_eq!(program, deserialized);
    }

    #[test]
    fn does_not_panic_on_invalid_circuit() {
        use std::io::Write;

        let bad_circuit = "I'm not an ACIR circuit".as_bytes();

        // We expect to load circuits as compressed artifacts so we compress the junk circuit.
        let mut zipped_bad_circuit = Vec::new();
        let mut encoder =
            flate2::write::GzEncoder::new(&mut zipped_bad_circuit, Compression::default());
        encoder.write_all(bad_circuit).unwrap();
        encoder.finish().unwrap();

        let deserialization_result = Program::deserialize_program(&zipped_bad_circuit);
        assert!(deserialization_result.is_err());
    }

    /// `Display` and `FromStr` for `OpcodeLocation` must round-trip, since the string form is
    /// used as a map key in serialized artifacts.
    #[test]
    fn opcode_location_string_roundtrip() {
        let locations = [
            OpcodeLocation::Acir(0),
            OpcodeLocation::Acir(42),
            OpcodeLocation::Brillig { acir_index: 3, brillig_index: 11 },
        ];
        for location in locations {
            let rendered = location.to_string();
            let parsed: OpcodeLocation = rendered.parse().unwrap();
            assert_eq!(parsed, location, "round trip through {rendered:?}");
        }
        assert_eq!(
            OpcodeLocation::Brillig { acir_index: 3, brillig_index: 11 }.to_string(),
            "3.11"
        );
    }

    #[test]
    fn opcode_location_rejects_malformed_strings() {
        for input in ["", ".", "1.", ".1", "1.2.3", "a", "1.b", "-1"] {
            assert!(input.parse::<OpcodeLocation>().is_err(), "{input:?} should not parse");
        }
    }

    /// `ErrorSelector` is encoded as a decimal string so large values survive JSON intact.
    #[test]
    fn error_selector_serializes_as_decimal_string() {
        let selector = ErrorSelector::new(u64::MAX);
        let json = serde_json::to_string(&selector).unwrap();
        assert_eq!(json, format!("\"{}\"", u64::MAX));
        let decoded: ErrorSelector = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, selector);
    }
}