1#![cfg_attr(not(test), forbid(unsafe_code))] #![cfg_attr(not(test), warn(unused_crate_dependencies, unused_extern_crates))]
9
10#[doc = include_str!("../README.md")]
11pub mod circuit;
12pub mod native_types;
13#[cfg(test)]
14mod parser;
15mod proto;
16mod serialization;
17
18pub use acir_field;
19pub use acir_field::{AcirField, FieldElement};
20pub use brillig;
21pub use circuit::black_box_functions::BlackBoxFunc;
22pub use circuit::opcodes::InvalidInputBitSize;
23
24#[cfg(test)]
25mod reflection {
26 use std::{
37 collections::BTreeMap,
38 fs::File,
39 io::Write,
40 path::{Path, PathBuf},
41 };
42
43 use acir_field::{AcirField, FieldElement};
44 use brillig::{
45 BinaryFieldOp, BinaryIntOp, BitSize, BlackBoxOp, HeapValueType, IntegerBitSize,
46 MemoryAddress, Opcode as BrilligOpcode, ValueOrArray,
47 };
48 use regex::Regex;
49 use serde::{Deserialize, Serialize};
50 use serde_generate::CustomCode;
51 use serde_reflection::{
52 ContainerFormat, Format, Named, Registry, Tracer, TracerConfig, VariantFormat,
53 };
54
55 use crate::{
56 circuit::{
57 AssertionPayload, Circuit, ExpressionOrMemory, ExpressionWidth, Opcode, OpcodeLocation,
58 Program,
59 brillig::{BrilligInputs, BrilligOutputs},
60 opcodes::{BlackBoxFuncCall, BlockType, ConstantOrWitnessEnum, FunctionInput},
61 },
62 native_types::{Witness, WitnessMap, WitnessStack},
63 };
64
    /// Companion to [`Program`] that holds only a list of ACIR [`Circuit`]s.
    ///
    /// It is traced into the C++ codegen registry alongside `Program` in
    /// `serde_acir_cpp_codegen`, so the generated C++ gets a matching type.
    /// NOTE(review): presumably this lets C++ consumers (de)serialize a program without its
    /// Brillig (unconstrained) bytecode — confirm against `circuit::Program`.
    #[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default, Hash)]
    struct ProgramWithoutBrillig<F: AcirField> {
        /// The ACIR circuits making up the program.
        pub functions: Vec<Circuit<F>>,
    }
79
80 #[test]
81 fn serde_acir_cpp_codegen() {
82 let mut tracer = Tracer::new(TracerConfig::default());
83 tracer.trace_simple_type::<BlockType>().unwrap();
84 tracer.trace_simple_type::<Program<FieldElement>>().unwrap();
85 tracer.trace_simple_type::<ProgramWithoutBrillig<FieldElement>>().unwrap();
86 tracer.trace_simple_type::<Circuit<FieldElement>>().unwrap();
87 tracer.trace_simple_type::<ExpressionWidth>().unwrap();
88 tracer.trace_simple_type::<Opcode<FieldElement>>().unwrap();
89 tracer.trace_simple_type::<OpcodeLocation>().unwrap();
90 tracer.trace_simple_type::<BinaryFieldOp>().unwrap();
91 tracer.trace_simple_type::<ConstantOrWitnessEnum<FieldElement>>().unwrap();
92 tracer.trace_simple_type::<FunctionInput<FieldElement>>().unwrap();
93 tracer.trace_simple_type::<BlackBoxFuncCall<FieldElement>>().unwrap();
94 tracer.trace_simple_type::<BrilligInputs<FieldElement>>().unwrap();
95 tracer.trace_simple_type::<BrilligOutputs>().unwrap();
96 tracer.trace_simple_type::<BrilligOpcode<FieldElement>>().unwrap();
97 tracer.trace_simple_type::<BinaryIntOp>().unwrap();
98 tracer.trace_simple_type::<BlackBoxOp>().unwrap();
99 tracer.trace_simple_type::<ValueOrArray>().unwrap();
100 tracer.trace_simple_type::<HeapValueType>().unwrap();
101 tracer.trace_simple_type::<AssertionPayload<FieldElement>>().unwrap();
102 tracer.trace_simple_type::<ExpressionOrMemory<FieldElement>>().unwrap();
103 tracer.trace_simple_type::<BitSize>().unwrap();
104 tracer.trace_simple_type::<IntegerBitSize>().unwrap();
105 tracer.trace_simple_type::<MemoryAddress>().unwrap();
106
107 serde_cpp_codegen(
108 "Acir",
109 PathBuf::from("./codegen/acir.cpp").as_path(),
110 &tracer.registry().unwrap(),
111 CustomCode::default(),
112 );
113 }
114
115 #[test]
116 fn serde_witness_map_cpp_codegen() {
117 let mut tracer = Tracer::new(TracerConfig::default());
118 tracer.trace_simple_type::<Witness>().unwrap();
119 tracer.trace_simple_type::<WitnessMap<FieldElement>>().unwrap();
120 tracer.trace_simple_type::<WitnessStack<FieldElement>>().unwrap();
121
122 let namespace = "Witnesses";
123 let mut code = CustomCode::default();
124 code.insert(
127 vec![namespace.to_string(), "Witness".to_string()],
128 "bool operator<(Witness const& rhs) const { return value < rhs.value; }".to_string(),
129 );
130
131 serde_cpp_codegen(
132 namespace,
133 PathBuf::from("./codegen/witness.cpp").as_path(),
134 &tracer.registry().unwrap(),
135 code,
136 );
137 }
138
139 fn serde_cpp_codegen(namespace: &str, path: &Path, registry: &Registry, code: CustomCode) {
145 let old_hash = if path.is_file() {
146 let old_source = std::fs::read(path).expect("failed to read existing code");
147 let old_source = String::from_utf8(old_source).expect("old source not UTF-8");
148 Some(fxhash::hash64(&old_source))
149 } else {
150 None
151 };
152 let msgpack_code = MsgPackCodeGenerator::generate(namespace, registry, code);
153
154 let mut source = Vec::new();
156 let config = serde_generate::CodeGeneratorConfig::new(namespace.to_string())
157 .with_encodings(vec![serde_generate::Encoding::Bincode])
158 .with_custom_code(msgpack_code);
159 let generator = serde_generate::cpp::CodeGenerator::new(&config);
160 generator.output(&mut source, registry).expect("failed to generate C++ code");
161
162 let mut source = String::from_utf8(source).expect("not a UTF-8 string");
164 replace_throw(&mut source);
165 MsgPackCodeGenerator::add_preamble(&mut source);
166 MsgPackCodeGenerator::add_helpers(&mut source, namespace);
167 MsgPackCodeGenerator::replace_array_with_shared_ptr(&mut source);
168
169 if !should_overwrite() {
170 if let Some(old_hash) = old_hash {
171 let new_hash = fxhash::hash64(&source);
172 assert_eq!(new_hash, old_hash, "Serialization format has changed",);
173 }
174 }
175
176 write_to_file(source.as_bytes(), path);
177 }
178
179 fn should_overwrite() -> bool {
182 std::env::var("NOIR_CODEGEN_OVERWRITE")
183 .ok()
184 .map(|v| v == "1" || v == "true")
185 .unwrap_or_default()
186 }
187
188 fn write_to_file(bytes: &[u8], path: &Path) -> String {
189 let display = path.display();
190
191 let parent_dir = path.parent().unwrap();
192 if !parent_dir.is_dir() {
193 std::fs::create_dir_all(parent_dir).unwrap();
194 }
195
196 let mut file = match File::create(path) {
197 Err(why) => panic!("couldn't create {display}: {why}"),
198 Ok(file) => file,
199 };
200
201 match file.write_all(bytes) {
202 Err(why) => panic!("couldn't write to {display}: {why}"),
203 Ok(_) => display.to_string(),
204 }
205 }
206
207 fn replace_throw(source: &mut String) {
213 *source = source.replace("throw serde::deserialization_error", "throw_or_abort");
214 }
215
    /// Builds msgpack `msgpack_pack`/`msgpack_unpack` C++ methods for every container in a
    /// serde-reflection registry. The methods are collected as `serde_generate` custom code.
    struct MsgPackCodeGenerator {
        /// Current C++ scope path (root namespace, plus the enum name while its variants are
        /// generated). It is used as the key prefix for entries in `code`.
        namespace: Vec<String>,
        /// Accumulated custom code, keyed by the full scope path of each container.
        code: CustomCode,
    }
222
223 impl MsgPackCodeGenerator {
224 fn add_preamble(source: &mut String) {
226 let inc = r#"#include "serde.hpp""#;
227 let pos = source.find(inc).expect("serde.hpp missing");
228 source.insert_str(pos + inc.len(), "\n#include \"msgpack.hpp\"");
229 }
230
231 fn add_helpers(source: &mut String, namespace: &str) {
233 let helpers = r#"
237 struct Helpers {
238 static std::map<std::string, msgpack::object const*> make_kvmap(
239 msgpack::object const& o,
240 std::string const& name
241 ) {
242 if(o.type != msgpack::type::MAP) {
243 std::cerr << o << std::endl;
244 throw_or_abort("expected MAP for " + name);
245 }
246 std::map<std::string, msgpack::object const*> kvmap;
247 for (uint32_t i = 0; i < o.via.map.size; ++i) {
248 if (o.via.map.ptr[i].key.type != msgpack::type::STR) {
249 std::cerr << o << std::endl;
250 throw_or_abort("expected STR for keys of " + name);
251 }
252 kvmap.emplace(
253 std::string(
254 o.via.map.ptr[i].key.via.str.ptr,
255 o.via.map.ptr[i].key.via.str.size),
256 &o.via.map.ptr[i].val);
257 }
258 return kvmap;
259 }
260 template<typename T>
261 static void conv_fld_from_kvmap(
262 std::map<std::string, msgpack::object const*> const& kvmap,
263 std::string const& struct_name,
264 std::string const& field_name,
265 T& field,
266 bool is_optional
267 ) {
268 auto it = kvmap.find(field_name);
269 if (it != kvmap.end()) {
270 try {
271 it->second->convert(field);
272 } catch (const msgpack::type_error&) {
273 std::cerr << *it->second << std::endl;
274 throw_or_abort("error converting into field " + struct_name + "::" + field_name);
275 }
276 } else if (!is_optional) {
277 throw_or_abort("missing field: " + struct_name + "::" + field_name);
278 }
279 }
280 };
281 "#;
282 let pos = source.find(&format!("namespace {namespace}")).expect("namespace");
284 source.insert_str(pos, &format!("namespace {namespace} {{{helpers}}}\n\n"));
285 }
286
287 fn replace_array_with_shared_ptr(source: &mut String) {
289 let re = Regex::new(r#"std::array<\s*([^,<>]+?)\s*,\s*([0-9]+)\s*>"#)
291 .expect("failed to create regex");
292
293 let fixed =
294 re.replace_all(source, "std::shared_ptr<std::array<${1}, ${2}>>").into_owned();
295
296 *source = fixed;
297 }
298
299 fn generate(namespace: &str, registry: &Registry, code: CustomCode) -> CustomCode {
300 let mut g = Self { namespace: vec![namespace.to_string()], code };
301 for (name, container) in registry {
302 g.generate_container(name, container);
303 }
304 g.code
305 }
306
307 fn add_code(&mut self, name: &str, code: &str) {
309 let mut ns = self.namespace.clone();
310 ns.push(name.to_string());
311 let c = self.code.entry(ns).or_default();
312 if !c.is_empty() && code.contains('\n') {
313 c.push('\n');
314 }
315 c.push_str(code);
316 c.push('\n');
317 }
318
319 fn generate_container(&mut self, name: &str, container: &ContainerFormat) {
320 use serde_reflection::ContainerFormat::*;
321 match container {
322 UnitStruct => {
323 self.generate_unit_struct(name);
324 }
325 NewTypeStruct(_format) => {
326 self.generate_newtype(name);
327 }
328 TupleStruct(formats) => {
329 self.generate_tuple(name, formats);
330 }
331 Struct(fields) => {
332 self.generate_struct(name, fields);
333 }
334 Enum(variants) => {
335 self.generate_enum(name, variants);
336 }
337 }
338 }
339
340 fn generate_unit_struct(&mut self, name: &str) {
342 self.msgpack_pack(name, "");
346 self.msgpack_unpack(name, "");
347 }
348
349 fn generate_struct(&mut self, name: &str, fields: &[Named<Format>]) {
351 self.msgpack_pack(name, &{
365 let mut body = format!(
366 "
367 packer.pack_map({});",
368 fields.len()
369 );
370 for field in fields {
371 let field_name = &field.name;
372 body.push_str(&format!(
373 r#"
374 packer.pack(std::make_pair("{field_name}", {field_name}));"#
375 ));
376 }
377 body
378 });
379
380 self.msgpack_unpack(name, &{
381 let mut body = format!(
385 r#"
386 auto name = "{name}";
387 auto kvmap = Helpers::make_kvmap(o, name);"#
388 );
389 for field in fields {
391 let field_name = &field.name;
392 let is_optional = matches!(field.value, Format::Option(_));
393 body.push_str(&format!(
395 r#"
396 Helpers::conv_fld_from_kvmap(kvmap, name, "{field_name}", {field_name}, {is_optional});"#
397 ));
398 }
400 body
401 });
402 }
403
404 fn generate_newtype(&mut self, name: &str) {
406 self.msgpack_pack(name, "packer.pack(value);");
407 self.msgpack_unpack(
408 name,
409 &format!(
411 r#"
412 try {{
413 o.convert(value);
414 }} catch (const msgpack::type_error&) {{
415 std::cerr << o << std::endl;
416 throw_or_abort("error converting into newtype '{name}'");
417 }}
418 "#
419 ),
420 );
422 }
423
424 fn generate_tuple(&mut self, _name: &str, _formats: &[Format]) {
426 unimplemented!("Until we have a tuple enum in our schema we don't need this.");
427 }
428
429 fn generate_enum(&mut self, name: &str, variants: &BTreeMap<u32, Named<VariantFormat>>) {
431 self.namespace.push(name.to_string());
433 for variant in variants.values() {
434 self.generate_variant(&variant.name, &variant.value);
435 }
436 self.namespace.pop();
437
438 self.msgpack_pack(name, &{
440 let cases = variants
441 .iter()
442 .map(|(i, v)| {
443 format!(
444 r#"
445 case {i}:
446 tag = "{}";
447 is_unit = {};
448 break;"#,
449 v.name,
450 matches!(v.value, VariantFormat::Unit)
451 )
452 })
453 .collect::<Vec<_>>()
454 .join("");
455
456 format!(
457 r#"
458 std::string tag;
459 bool is_unit;
460 switch (value.index()) {{
461 {cases}
462 default:
463 throw_or_abort("unknown enum '{name}' variant index: " + std::to_string(value.index()));
464 }}
465 if (is_unit) {{
466 packer.pack(tag);
467 }} else {{
468 std::visit([&packer, tag](const auto& arg) {{
469 std::map<std::string, msgpack::object> data;
470 data[tag] = msgpack::object(arg);
471 packer.pack(data);
472 }}, value);
473 }}"#
474 )
475 });
476
477 self.msgpack_unpack(name, &{
481 let mut body = format!(
483 r#"
484
485 if (o.type != msgpack::type::object_type::MAP && o.type != msgpack::type::object_type::STR) {{
486 std::cerr << o << std::endl;
487 throw_or_abort("expected MAP or STR for enum '{name}'; got type " + std::to_string(o.type));
488 }}
489 if (o.type == msgpack::type::object_type::MAP && o.via.map.size != 1) {{
490 throw_or_abort("expected 1 entry for enum '{name}'; got " + std::to_string(o.via.map.size));
491 }}
492 std::string tag;
493 try {{
494 if (o.type == msgpack::type::object_type::MAP) {{
495 o.via.map.ptr[0].key.convert(tag);
496 }} else {{
497 o.convert(tag);
498 }}
499 }} catch(const msgpack::type_error&) {{
500 std::cerr << o << std::endl;
501 throw_or_abort("error converting tag to string for enum '{name}'");
502 }}"#
503 );
504 for (i, v) in variants.iter() {
507 let variant = &v.name;
508 body.push_str(&format!(
509 r#"
510 {}if (tag == "{variant}") {{
511 {variant} v;"#,
512 if *i == 0 { "" } else { "else " }
513 ));
514
515 if !matches!(v.value, VariantFormat::Unit) {
516 body.push_str(&format!(
518 r#"
519 try {{
520 o.via.map.ptr[0].val.convert(v);
521 }} catch (const msgpack::type_error&) {{
522 std::cerr << o << std::endl;
523 throw_or_abort("error converting into enum variant '{name}::{variant}'");
524 }}
525 "#
526 ));
527 }
529 body.push_str(
531 r#"
532 value = v;
533 }"#,
534 );
535 }
536 body.push_str(&format!(
538 r#"
539 else {{
540 std::cerr << o << std::endl;
541 throw_or_abort("unknown '{name}' enum variant: " + tag);
542 }}"#
543 ));
544 body
547 });
548 }
549
550 fn generate_variant(&mut self, name: &str, variant: &VariantFormat) {
552 match variant {
553 VariantFormat::Variable(_) => {
554 unreachable!("internal construct")
555 }
556 VariantFormat::Unit => self.generate_unit_struct(name),
557 VariantFormat::NewType(_format) => self.generate_newtype(name),
558 VariantFormat::Tuple(formats) => self.generate_tuple(name, formats),
559 VariantFormat::Struct(fields) => self.generate_struct(name, fields),
560 }
561 }
562
563 #[allow(dead_code)]
568 fn msgpack_fields(&mut self, name: &str, fields: impl Iterator<Item = String>) {
569 let fields = fields.collect::<Vec<_>>().join(", ");
570 let code = format!("MSGPACK_FIELDS({fields});");
571 self.add_code(name, &code);
572 }
573
574 fn msgpack_pack(&mut self, name: &str, body: &str) {
576 let code = Self::make_fn("void msgpack_pack(auto& packer) const", body);
577 self.add_code(name, &code);
578 }
579
580 fn msgpack_unpack(&mut self, name: &str, body: &str) {
582 let code = Self::make_fn("void msgpack_unpack(msgpack::object const& o)", body);
620 self.add_code(name, &code);
621 }
622
623 fn make_fn(header: &str, body: &str) -> String {
624 let body = body.trim_end();
625 if body.is_empty() {
626 format!("{header} {{}}")
627 } else if !body.contains('\n') {
628 format!("{header} {{ {body} }}")
629 } else if body.starts_with('\n') {
630 format!("{header} {{{body}\n}}")
631 } else {
632 format!("{header} {{\n{body}\n}}")
633 }
634 }
635 }
636}