use crate::proto::convert::ProtoSchema;
use acir_field::AcirField;
use noir_protobuf::ProtoCodec;
use num_enum::{IntoPrimitive, TryFromPrimitive};
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use strum_macros::EnumString;
/// Environment variable consulted by [`Format::from_env`] to pick a serialization format.
const FORMAT_ENV_VAR: &str = "NOIR_SERIALIZATION_FORMAT";

/// Serialization formats supported by this module.
///
/// The `u8` discriminant of each variant is the format byte that `serialize_with_format`
/// prepends to its output (except for `BincodeLegacy`, which is written without one) and that
/// `deserialize_any_format` inspects, so these values must never change.
///
/// `FromStr` accepts the kebab-case variant names, e.g. `"msgpack-compact"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, IntoPrimitive, TryFromPrimitive, EnumString)]
#[strum(serialize_all = "kebab-case")]
#[repr(u8)]
pub(crate) enum Format {
    /// bincode (legacy config) with no leading format byte.
    BincodeLegacy = 0,
    /// bincode (legacy config) preceded by the format byte.
    Bincode = 1,
    /// MessagePack with struct fields encoded by name.
    Msgpack = 2,
    /// MessagePack with struct fields encoded positionally.
    MsgpackCompact = 3,
    /// Protocol Buffers via `ProtoSchema`.
    Protobuf = 4,
}
impl Format {
    /// Reads the requested serialization format from the `NOIR_SERIALIZATION_FORMAT`
    /// environment variable.
    ///
    /// Returns `Ok(None)` when the variable is not set. Returns `Ok(Some(format))` when it holds
    /// a kebab-case format name such as `msgpack-compact`.
    ///
    /// # Errors
    /// Returns a descriptive message when the variable is set to an unknown format name or to a
    /// value that is not valid Unicode.
    pub(crate) fn from_env() -> Result<Option<Self>, String> {
        let format = match std::env::var(FORMAT_ENV_VAR) {
            Ok(format) => format,
            Err(std::env::VarError::NotPresent) => return Ok(None),
            // The variable is set but unreadable. That is almost certainly a user mistake,
            // so report it instead of silently falling back to the default format.
            Err(std::env::VarError::NotUnicode(value)) => {
                return Err(format!("non-unicode value {value:?} in {FORMAT_ENV_VAR}"));
            }
        };
        Self::from_str(&format)
            .map(Some)
            .map_err(|e| format!("unknown format '{format}' in {FORMAT_ENV_VAR}: {e}"))
    }
}
/// Encodes `value` with bincode's legacy configuration. No format byte is added.
///
/// # Errors
/// Encoding failures are wrapped with `std::io::Error::other`.
pub(crate) fn bincode_serialize<T: Serialize>(value: &T) -> std::io::Result<Vec<u8>> {
    let config = bincode::config::legacy();
    match bincode::serde::encode_to_vec(value, config) {
        Ok(bytes) => Ok(bytes),
        Err(err) => Err(std::io::Error::other(err)),
    }
}
/// Decodes a `T` from `buf` using bincode's legacy configuration.
///
/// The consumed-byte count reported by bincode is discarded, so any bytes that follow the
/// decoded value are not rejected here.
///
/// # Errors
/// Decoding failures are reported with `std::io::ErrorKind::InvalidInput`.
pub(crate) fn bincode_deserialize<T: for<'a> Deserialize<'a>>(buf: &[u8]) -> std::io::Result<T> {
    let config = bincode::config::legacy();
    let (value, _consumed) = bincode::serde::borrow_decode_from_slice(buf, config)
        .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
    Ok(value)
}
// NOTE(review): the reason for allowing dead code is not visible in this module; confirm it
// is still needed.
#[allow(dead_code)]
/// Encodes `value` as MessagePack.
///
/// With `compact == false`, structs are written as maps keyed by field name. With
/// `compact == true`, struct fields are written positionally, which is smaller but less
/// tolerant of schema changes.
///
/// # Errors
/// Encoding failures are wrapped with `std::io::Error::other`.
pub(crate) fn msgpack_serialize<T: Serialize>(
    value: &T,
    compact: bool,
) -> std::io::Result<Vec<u8>> {
    let encoded = if compact { rmp_serde::to_vec(value) } else { rmp_serde::to_vec_named(value) };
    encoded.map_err(std::io::Error::other)
}
// NOTE(review): the reason for allowing dead code is not visible in this module; confirm it
// is still needed.
#[allow(dead_code)]
/// Decodes a `T` from MessagePack bytes.
///
/// Used for both the named and compact encodings produced by `msgpack_serialize`.
///
/// # Errors
/// Decoding failures are reported with `std::io::ErrorKind::InvalidInput`.
pub(crate) fn msgpack_deserialize<T: for<'a> Deserialize<'a>>(buf: &[u8]) -> std::io::Result<T> {
    let decoded: Result<T, _> = rmp_serde::from_slice(buf);
    decoded.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))
}
// NOTE(review): the reason for allowing dead code is not visible in this module; confirm it
// is still needed.
#[allow(dead_code)]
/// Encodes `value` into protobuf bytes through the `ProtoSchema<F>` codec for `(T, R)`.
///
/// No format byte is added; callers that need one prepend it themselves.
pub(crate) fn proto_serialize<F, T, R>(value: &T) -> Vec<u8>
where
    R: prost::Message,
    F: AcirField,
    ProtoSchema<F>: ProtoCodec<T, R>,
{
    ProtoSchema::<F>::serialize_to_vec(value)
}
// NOTE(review): the reason for allowing dead code is not visible in this module; confirm it
// is still needed.
#[allow(dead_code)]
/// Decodes a `T` from protobuf bytes through the `ProtoSchema<F>` codec for `(T, R)`.
///
/// # Errors
/// Codec failures are reported with `std::io::ErrorKind::InvalidInput`.
pub(crate) fn proto_deserialize<F, T, R>(buf: &[u8]) -> std::io::Result<T>
where
    R: prost::Message + Default,
    F: AcirField,
    ProtoSchema<F>: ProtoCodec<T, R>,
{
    match ProtoSchema::<F>::deserialize_from_slice(buf) {
        Ok(value) => Ok(value),
        Err(report) => Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, report)),
    }
}
/// Deserializes `buf`, accepting any format that `serialize_with_format` can produce.
///
/// Untagged legacy bincode is tried first on the whole buffer. If that fails, the first byte is
/// interpreted as a [`Format`] tag and the remaining bytes are decoded with the matching decoder.
///
/// # Errors
/// Returns an `InvalidInput` error when no decoder succeeds. If the first byte named a format
/// whose decoder also failed, the error message includes both that decoder's error and the
/// legacy bincode error. Otherwise the legacy bincode error is returned as-is.
pub(crate) fn deserialize_any_format<F, T, R>(buf: &[u8]) -> std::io::Result<T>
where
    T: for<'a> Deserialize<'a>,
    F: AcirField,
    R: prost::Message + Default,
    ProtoSchema<F>: ProtoCodec<T, R>,
{
    // Legacy bincode has no format byte and may start with any value, so it cannot be ruled
    // out by looking at the first byte; it is attempted first.
    let bincode_err = match bincode_deserialize(buf) {
        Ok(value) => return Ok(value),
        Err(err) => err,
    };
    let Some((&tag, payload)) = buf.split_first() else {
        return Err(bincode_err);
    };
    let Ok(format) = Format::try_from(tag) else {
        return Err(bincode_err);
    };
    let result = match format {
        // Legacy bincode is never written with a tag, and it already failed above.
        Format::BincodeLegacy => return Err(bincode_err),
        Format::Bincode => bincode_deserialize(payload),
        Format::Msgpack | Format::MsgpackCompact => msgpack_deserialize(payload),
        Format::Protobuf => proto_deserialize::<F, T, R>(payload),
    };
    // Report the tagged decoder's error as well, instead of discarding it; on its own the
    // legacy bincode error is usually irrelevant for tagged data.
    result.map_err(|format_err| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!(
                "failed to deserialize as {format:?}: {format_err}; \
                 legacy bincode also failed: {bincode_err}"
            ),
        )
    })
}
/// Serializes `value` in the requested `format`.
///
/// For every format except [`Format::BincodeLegacy`], the output is the format's `u8` tag
/// followed by the encoded payload; that leading tag is what `deserialize_any_format` inspects.
/// `BincodeLegacy` output carries no tag.
///
/// # Errors
/// Propagates bincode or msgpack encoding failures.
pub(crate) fn serialize_with_format<F, T, R>(value: &T, format: Format) -> std::io::Result<Vec<u8>>
where
    T: Serialize,
    F: AcirField,
    R: prost::Message,
    ProtoSchema<F>: ProtoCodec<T, R>,
{
    let payload = match format {
        // Legacy bincode predates the tag byte, so its bytes are returned untouched.
        Format::BincodeLegacy => return bincode_serialize(value),
        Format::Bincode => bincode_serialize(value)?,
        Format::Msgpack => msgpack_serialize(value, false)?,
        Format::MsgpackCompact => msgpack_serialize(value, true)?,
        Format::Protobuf => proto_serialize::<F, T, R>(value),
    };
    let tag: u8 = format.into();
    let mut framed = Vec::with_capacity(1 + payload.len());
    framed.push(tag);
    framed.extend_from_slice(&payload);
    Ok(framed)
}
pub(crate) fn serialize_with_format_from_env<F, T, R>(value: &T) -> std::io::Result<Vec<u8>>
where
F: AcirField,
T: Serialize,
R: prost::Message,
ProtoSchema<F>: ProtoCodec<T, R>,
{
match Format::from_env() {
Ok(Some(format)) => {
serialize_with_format(value, format)
}
Ok(None) => {
bincode_serialize(value)
}
Err(e) => Err(std::io::Error::other(e)),
}
}
#[cfg(test)]
mod tests {
    use acir_field::FieldElement;
    use brillig::{BitSize, HeapArray, IntegerBitSize, ValueOrArray};
    use std::str::FromStr;

    use crate::{
        circuit::{Opcode, brillig::BrilligFunctionId},
        native_types::Witness,
        serialization::{Format, msgpack_deserialize, msgpack_serialize},
    };

    /// The older schema a reader was compiled against.
    mod version1 {
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
        pub(crate) enum Foo {
            Case0 { d: u32 },
            Case1 { a: u64, b: bool },
            Case2 { a: i32 },
            Case3 { a: bool },
            Case4 { a: Box<Foo> },
            Case5 { a: u32, b: Option<u32> },
        }
    }

    /// A newer schema derived from `version1`: `Case0` removed, a field prepended to `Case2`,
    /// a field appended to `Case3`, `Case5` loses its optional field, `Case4`'s field is renamed
    /// in Rust but serialized as `a`, `Case4`/`Case5` are reordered, and `Case6`/`Case7` are new.
    mod version2 {
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
        pub(crate) enum Foo {
            Case1 {
                a: u64,
                b: bool,
            },
            Case2 {
                b: String,
                a: i32,
            },
            Case3 {
                a: bool,
                b: String,
            },
            Case5 {
                a: u32,
            },
            Case4 {
                #[serde(rename = "a")]
                c: Box<Foo>,
            },
            Case6 {
                b: i64,
            },
            Case7 {
                c: bool,
            },
        }
    }

    /// Named msgpack written with the newer schema can be read by the older one.
    #[test]
    fn msgpack_serialize_backwards_compatibility() {
        let cases = vec![
            (version2::Foo::Case1 { b: true, a: 1 }, version1::Foo::Case1 { b: true, a: 1 }),
            (version2::Foo::Case2 { b: "prefix".into(), a: 2 }, version1::Foo::Case2 { a: 2 }),
            (
                version2::Foo::Case3 { a: true, b: "suffix".into() },
                version1::Foo::Case3 { a: true },
            ),
            (
                version2::Foo::Case4 { c: Box::new(version2::Foo::Case1 { a: 4, b: false }) },
                version1::Foo::Case4 { a: Box::new(version1::Foo::Case1 { a: 4, b: false }) },
            ),
            (version2::Foo::Case5 { a: 5 }, version1::Foo::Case5 { a: 5, b: None }),
        ];
        for (i, (v2, v1)) in cases.into_iter().enumerate() {
            let bz = msgpack_serialize(&v2, false).unwrap();
            let v = msgpack_deserialize::<version1::Foo>(&bz)
                .unwrap_or_else(|e| panic!("case {i} failed: {e}"));
            assert_eq!(v, v1);
        }
    }

    /// Compact (positional) msgpack only survives schema changes that keep the field layout;
    /// the expected error fragments document exactly which changes break it.
    #[test]
    fn msgpack_serialize_compact_backwards_compatibility() {
        let cases = vec![
            (version2::Foo::Case1 { b: true, a: 1 }, version1::Foo::Case1 { b: true, a: 1 }, None),
            (
                version2::Foo::Case2 { b: "prefix".into(), a: 2 },
                version1::Foo::Case2 { a: 2 },
                Some("wrong msgpack marker FixStr(6)"),
            ),
            (
                version2::Foo::Case3 { a: true, b: "suffix".into() },
                version1::Foo::Case3 { a: true },
                Some("array had incorrect length, expected 1"),
            ),
            (
                version2::Foo::Case4 { c: Box::new(version2::Foo::Case1 { a: 4, b: false }) },
                version1::Foo::Case4 { a: Box::new(version1::Foo::Case1 { a: 4, b: false }) },
                None,
            ),
            (
                version2::Foo::Case5 { a: 5 },
                version1::Foo::Case5 { a: 5, b: None },
                Some("invalid length 1, expected struct variant Foo::Case5 with 2 elements"),
            ),
        ];
        for (i, (v2, v1, ex)) in cases.into_iter().enumerate() {
            let bz = msgpack_serialize(&v2, true).unwrap();
            let res = msgpack_deserialize::<version1::Foo>(&bz);
            match (res, ex) {
                (Ok(v), None) => {
                    assert_eq!(v, v1);
                }
                (Ok(_), Some(ex)) => panic!("case {i} expected to fail with {ex}"),
                (Err(e), None) => panic!("case {i} expected to pass; got {e}"),
                (Err(e), Some(ex)) => {
                    let e = e.to_string();
                    if !e.contains(ex) {
                        panic!("case {i} error expected to contain {ex}; got {e}")
                    }
                }
            }
        }
    }

    /// An enum variant holding a struct is encoded as a single-entry map keyed by variant name.
    #[test]
    fn msgpack_repr_enum_of_structs() {
        use rmpv::Value;
        let value = ValueOrArray::HeapArray(HeapArray {
            pointer: brillig::MemoryAddress::Relative(0),
            size: 3,
        });
        let bz = msgpack_serialize(&value, false).unwrap();
        let msg = rmpv::decode::read_value::<&[u8]>(&mut bz.as_ref()).unwrap();
        let Value::Map(fields) = msg else {
            panic!("expected Map: {msg:?}");
        };
        assert_eq!(fields.len(), 1);
        let Value::String(key) = &fields[0].0 else {
            panic!("expected String key: {fields:?}");
        };
        assert_eq!(key.as_str(), Some("HeapArray"));
    }

    /// A unit variant is encoded as a bare string.
    #[test]
    fn msgpack_repr_enum_of_unit_structs() {
        let value = IntegerBitSize::U1;
        let bz = msgpack_serialize(&value, false).unwrap();
        let msg = rmpv::decode::read_value::<&[u8]>(&mut bz.as_ref()).unwrap();
        assert_eq!(msg.as_str(), Some("U1"));
    }

    /// Unit and newtype variants of the same enum use their respective representations.
    #[test]
    fn msgpack_repr_enum_of_mixed() {
        let value = vec![BitSize::Field, BitSize::Integer(IntegerBitSize::U64)];
        let bz = msgpack_serialize(&value, false).unwrap();
        let msg = rmpv::decode::read_value::<&[u8]>(&mut bz.as_ref()).unwrap();
        assert_eq!(format!("{msg}"), r#"["Field", {"Integer": "U64"}]"#);
    }

    /// A newtype struct is encoded as its inner value.
    #[test]
    fn msgpack_repr_newtype() {
        use rmpv::Value;
        let value = Witness(1);
        let bz = msgpack_serialize(&value, false).unwrap();
        let msg = rmpv::decode::read_value::<&[u8]>(&mut bz.as_ref()).unwrap();
        assert!(matches!(msg, Value::Integer(_)));
    }

    /// A `None` field is still present in the named encoding, with a `Nil` value.
    #[test]
    fn msgpack_optional() {
        use rmpv::Value;
        let value: Opcode<FieldElement> = Opcode::BrilligCall {
            id: BrilligFunctionId(1),
            inputs: Vec::new(),
            outputs: Vec::new(),
            predicate: None,
        };
        let bz = msgpack_serialize(&value, false).unwrap();
        let msg = rmpv::decode::read_value::<&[u8]>(&mut bz.as_ref()).unwrap();
        let fields = msg.as_map().expect("enum is a map");
        let fields = &fields.first().expect("enum is non-empty").1;
        let fields = fields.as_map().expect("fields are map");
        let (k, v) = fields.last().expect("fields are not empty");
        assert_eq!(k.as_str().expect("names are str"), "predicate");
        assert!(matches!(v, Value::Nil));
    }

    #[test]
    fn format_from_str() {
        assert_eq!(Format::from_str("msgpack-compact").unwrap(), Format::MsgpackCompact);
    }

    /// Every variant parses from its kebab-case name, and unknown names are rejected.
    #[test]
    fn format_from_str_all_variants() {
        let cases = [
            ("bincode-legacy", Format::BincodeLegacy),
            ("bincode", Format::Bincode),
            ("msgpack", Format::Msgpack),
            ("msgpack-compact", Format::MsgpackCompact),
            ("protobuf", Format::Protobuf),
        ];
        for (name, expected) in cases {
            assert_eq!(Format::from_str(name).unwrap(), expected, "parsing {name}");
        }
        assert!(Format::from_str("json").is_err());
        assert!(Format::from_str("").is_err());
    }

    /// The discriminants are the format byte written into serialized data, so they must stay
    /// stable across releases; bytes outside the known range are not a format.
    #[test]
    fn format_byte_values_are_stable() {
        let cases = [
            (0u8, Format::BincodeLegacy),
            (1u8, Format::Bincode),
            (2u8, Format::Msgpack),
            (3u8, Format::MsgpackCompact),
            (4u8, Format::Protobuf),
        ];
        for (byte, format) in cases {
            assert_eq!(u8::from(format), byte);
            assert_eq!(Format::try_from(byte).unwrap(), format);
        }
        assert!(Format::try_from(5u8).is_err());
    }
}