use primitives::algebra::elliptic_curve::{Curve, Point, Scalar};
use serde::{Deserialize, Serialize};
use crate::circuit::{errors::CircuitError, Circuit, Gate, GateIndex};
/// Flat, serializable form of a [`Circuit`]: its gate list plus the indices of
/// its output gates.
///
/// It is produced from a `&Circuit` by cloning the gates and copying the output
/// indices, and converted back by replaying `add_gate` / `add_output` on a fresh
/// `Circuit`, so the conversion back can fail with a `CircuitError`.
#[derive(Serialize, Deserialize, Default)]
// The serde bounds are spelled out on `Scalar<C>` / `Point<C>` instead of the
// `C: Serialize` / `C: Deserialize` bounds the derive would otherwise add.
#[serde(bound(
serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
#[repr(C)]
struct CompressedCircuit<C: Curve> {
/// Gates, in the order yielded by `Circuit::iter_gates`.
pub ops: Vec<Gate<C>>,
/// Output gate indices, in the order yielded by `Circuit::iter_output_indices`.
pub output_gates: Vec<GateIndex>,
}
impl<C: Curve> From<&Circuit<C>> for CompressedCircuit<C> {
fn from(value: &Circuit<C>) -> Self {
CompressedCircuit {
ops: value.iter_gates().cloned().collect(),
output_gates: value.iter_output_indices().copied().collect(),
}
}
}
impl<C: Curve> TryFrom<CompressedCircuit<C>> for Circuit<C> {
type Error = CircuitError<C>;
fn try_from(circuit: CompressedCircuit<C>) -> Result<Self, Self::Error> {
let mut res = Self {
gates: Vec::with_capacity(circuit.ops.len()),
inputs: Vec::new(),
outputs: Vec::with_capacity(circuit.output_gates.len()),
};
for gate in circuit.ops.into_iter() {
res.add_gate(gate)?;
}
for index in circuit.output_gates.into_iter() {
res.add_output(index)?;
}
Ok(res)
}
}
mod bincode {
use primitives::algebra::elliptic_curve::Curve;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::circuit::{compressed_circuit::CompressedCircuit, Circuit};
impl<C: Curve> Serialize for Circuit<C> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let circuit_serde: CompressedCircuit<C> = self.into();
circuit_serde.serialize(serializer)
}
}
impl<'de, C: Curve> Deserialize<'de> for Circuit<C> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let circuit_serde = CompressedCircuit::<C>::deserialize(deserializer)
.map_err(serde::de::Error::custom)?;
let circuit = circuit_serde.try_into();
circuit.map_err(serde::de::Error::custom)
}
}
}
mod wincode {
use core::{
mem::{self, MaybeUninit},
ptr,
};
use ::wincode::{
containers,
io::{Reader, Writer},
len::BincodeLen,
ReadResult,
SchemaRead,
SchemaWrite,
TypeMeta,
WriteResult,
};
use super::*;
/// `BincodeLen` configuration applied to the serialized gate vector (`ops`);
/// the `output_gates` vector is written with the plain `Vec` schema instead.
///
/// NOTE(review): `2 << 32` evaluates to 2^33 (8_589_934_592), not 2^32. The
/// `U32` in the name suggests `1 << 32` or `u32::MAX` may have been intended,
/// and the literal does not fit a 32-bit integer. Presumably the const
/// parameter is wincode's length/preallocation limit; confirm against the
/// wincode `BincodeLen` documentation before changing it.
pub type BincodeLenU32 = BincodeLen<{ 2 << 32 }>;
impl<C: Curve> SchemaWrite for Circuit<C> {
type Src = Self;
const TYPE_META: TypeMeta = <CompressedCircuit<C> as SchemaWrite>::TYPE_META;
fn size_of(src: &Self::Src) -> WriteResult<usize> {
let circuit_serde: CompressedCircuit<C> = src.into();
<CompressedCircuit<C> as SchemaWrite>::size_of(&circuit_serde)
}
fn write(writer: &mut impl Writer, src: &Self::Src) -> WriteResult<()> {
let circuit_serde: CompressedCircuit<C> = src.into();
<CompressedCircuit<C> as SchemaWrite>::write(writer, &circuit_serde)
}
}
impl<C: Curve> ::wincode::SchemaWrite for CompressedCircuit<C> {
type Src = Self;
#[allow(clippy::arithmetic_side_effects)]
const TYPE_META: TypeMeta = if let (
TypeMeta::Static {
size: a,
zero_copy: zc_a,
},
TypeMeta::Static {
size: b,
zero_copy: zc_b,
},
) = (
<containers::Vec<Gate<C>, BincodeLenU32> as SchemaWrite>::TYPE_META,
<Vec<GateIndex> as SchemaWrite>::TYPE_META,
) {
let serialized_size = a + b;
let no_padding = serialized_size == size_of::<Self>();
TypeMeta::Static {
size: serialized_size,
zero_copy: no_padding && zc_a && zc_b,
}
} else {
TypeMeta::Dynamic
};
#[inline]
fn size_of(src: &Self::Src) -> WriteResult<usize> {
if let TypeMeta::Static { size, .. } = <Self as SchemaWrite>::TYPE_META {
return Ok(size);
}
let mut total = 0usize;
total += <containers::Vec<Gate<C>, BincodeLenU32> as SchemaWrite>::size_of(&src.ops)?;
total += <Vec<GateIndex> as SchemaWrite>::size_of(&src.output_gates)?;
Ok(total)
}
#[inline]
fn write(writer: &mut impl Writer, src: &Self::Src) -> WriteResult<()> {
macro_rules! write_fields {
($writer:expr) => {{
<containers::Vec<Gate<C>, BincodeLenU32> as SchemaWrite>::write(
$writer, &src.ops,
)?;
<Vec<GateIndex> as SchemaWrite>::write($writer, &src.output_gates)?;
}};
}
match <Self as SchemaWrite>::TYPE_META {
TypeMeta::Static { size, .. } => {
let writer = &mut unsafe { writer.as_trusted_for(size) }?;
write_fields!(writer);
writer.finish()?;
}
TypeMeta::Dynamic => {
write_fields!(writer);
}
}
Ok(())
}
}
/// Reads a `Circuit` by decoding its flat `CompressedCircuit` form and then
/// rebuilding the circuit from it.
impl<'de, C: Curve> SchemaRead<'de> for Circuit<C> {
    type Dst = Self;

    const TYPE_META: TypeMeta = <CompressedCircuit<C> as SchemaRead>::TYPE_META;

    /// # Errors
    ///
    /// Propagates any error from decoding the `CompressedCircuit`, and returns
    /// `ReadError::Custom` when the decoded data cannot be turned back into a
    /// `Circuit`.
    fn read(reader: &mut impl Reader<'de>, dst: &mut MaybeUninit<Self::Dst>) -> ReadResult<()> {
        // Start from genuinely uninitialized storage: the `CompressedCircuit`
        // reader writes each field in place without dropping what was there,
        // so pre-filling the slot with a constructed value would be wasted work
        // (and a leak if that value ever owned an allocation).
        let mut compressed = MaybeUninit::<CompressedCircuit<C>>::uninit();
        <CompressedCircuit<C> as SchemaRead>::read(reader, &mut compressed)?;

        // SAFETY: the `CompressedCircuit` reader returned `Ok`, which it only
        // does after both `ops` and `output_gates` have been written.
        let compressed = unsafe { compressed.assume_init() };

        let circuit: Self::Dst = compressed.try_into().map_err(|_| {
            ::wincode::ReadError::Custom("Invalid cast from CircuitSerde to Circuit struct")
        })?;
        dst.write(circuit);
        Ok(())
    }
}
/// Field-by-field reader for `CompressedCircuit`: `ops` first (using the
/// `BincodeLenU32` length configuration), then `output_gates`.
impl<'de, C: Curve> SchemaRead<'de> for CompressedCircuit<C> {
    type Dst = Self;

    /// `Static` only when both field schemas are `Static`; the size is then the
    /// sum of the two field sizes, and `zero_copy` additionally requires that
    /// sum to equal the in-memory size of the struct. Otherwise `Dynamic`.
    #[allow(clippy::arithmetic_side_effects)]
    const TYPE_META: TypeMeta = if let (
        TypeMeta::Static {
            size: a,
            zero_copy: zc_a,
        },
        TypeMeta::Static {
            size: b,
            zero_copy: zc_b,
        },
    ) = (
        <containers::Vec<Gate<C>, BincodeLenU32> as SchemaRead<'de>>::TYPE_META,
        <Vec<GateIndex> as SchemaRead<'de>>::TYPE_META,
    ) {
        let serialized_size = a + b;
        let no_padding = serialized_size == size_of::<Self>();
        TypeMeta::Static {
            size: serialized_size,
            zero_copy: no_padding && zc_a && zc_b,
        }
    } else {
        TypeMeta::Dynamic
    };

    /// Reads both fields directly into `dst`.
    ///
    /// On success both fields of `dst` are initialized. On failure, any field
    /// that was already written is dropped again before the error is returned,
    /// so `dst` must still be treated as uninitialized.
    #[inline]
    fn read(reader: &mut impl Reader<'de>, dst: &mut MaybeUninit<Self::Dst>) -> ReadResult<()> {
        /// Drops the fields that were already initialized if `read` exits
        /// early. `init_count` is the number of leading fields written so far.
        struct DropGuard<C: Curve> {
            init_count: u8,
            dst_ptr: *mut CompressedCircuit<C>,
        }

        impl<C: Curve> Drop for DropGuard<C> {
            #[cold]
            fn drop(&mut self) {
                let dst_ptr = self.dst_ptr;
                let init_count = self.init_count;
                match init_count {
                    0 => {}
                    // SAFETY: `init_count` becomes 1 only after the `ops` read
                    // returned `Ok`, so that field holds an initialized value.
                    1u8 => unsafe {
                        ptr::drop_in_place(&raw mut (*dst_ptr).ops);
                    },
                    // The guard is forgotten right after the second field is
                    // read, so it is never dropped with a higher count.
                    _ => unreachable!("init_count out of bounds"),
                }
            }
        }

        // A macro rather than a closure because the two call sites pass
        // readers of different concrete types. `$reader` is expanded once per
        // field, so callers must pass a place expression, not a constructor.
        macro_rules! read_fields {
            ($reader:expr, $dst_ptr:expr, $guard:expr) => {{
                let init_count = &mut $guard.init_count;
                <containers::Vec<Gate<C>, BincodeLenU32> as SchemaRead<'de>>::read(
                    $reader,
                    // SAFETY: `$dst_ptr` comes from `dst.as_mut_ptr()`, so the
                    // field pointer is valid for writes; it is viewed as
                    // `MaybeUninit<_>` because the field is not initialized yet.
                    unsafe { &mut *(&raw mut (*$dst_ptr).ops).cast::<MaybeUninit<_>>() },
                )?;
                *init_count += 1;
                // SAFETY: same reasoning as for `ops` above.
                <Vec<GateIndex> as SchemaRead<'de>>::read($reader, unsafe {
                    &mut *(&raw mut (*$dst_ptr).output_gates).cast::<MaybeUninit<_>>()
                })?;
                // Both fields are initialized: ownership now belongs to `dst`,
                // so the guard must not run.
                mem::forget($guard);
            }};
        }

        let dst_ptr = dst.as_mut_ptr();
        let mut guard = DropGuard {
            init_count: 0,
            dst_ptr,
        };

        match <Self as SchemaRead<'de>>::TYPE_META {
            TypeMeta::Static { size, .. } => {
                // Create the trusted reader exactly once. Passing the
                // `as_trusted_for` expression straight into the macro would
                // evaluate it once per field and open two separate windows.
                let trusted = &mut unsafe { reader.as_trusted_for(size) }?;
                read_fields!(trusted, dst_ptr, guard);
            }
            TypeMeta::Dynamic => {
                read_fields!(reader, dst_ptr, guard);
            }
        }

        Ok(())
    }
}
}
#[cfg(test)]
mod tests {
use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;
use super::*;
use crate::circuit::{
tests::create_add_tree_circuit,
AlgebraicType,
FieldShareBinaryOp,
Input,
};
/// Builds a two-input addition circuit by hand and checks the gate indices
/// handed out by `add_gate` as well as the recorded output index.
#[test]
fn valid_circuit() {
    // Both inputs differ only in who provides them.
    let secret_scalar_input = |inputer| {
        Gate::Input(Input::SecretPlaintext {
            inputer,
            algebraic_type: AlgebraicType::ScalarField,
            batch_size: 1,
        })
    };

    let mut circuit = Circuit::<C>::new();

    let lhs = circuit.add_gate(secret_scalar_input(0)).unwrap();
    assert_eq!(lhs, 0);

    let rhs = circuit.add_gate(secret_scalar_input(1)).unwrap();
    assert_eq!(rhs, 1);

    let sum = circuit
        .add_gate(Gate::FieldShareBinaryOp {
            x: lhs,
            y: rhs,
            op: FieldShareBinaryOp::Add,
        })
        .unwrap();
    assert_eq!(sum, 2);

    circuit.add_output(sum).unwrap();
    let outputs: Vec<_> = circuit.iter_output_indices().copied().collect();
    assert_eq!(outputs, vec![2]);
}
/// A circuit survives a bincode serialize/deserialize round trip unchanged.
#[test]
fn test_ser_circuit_bincode() {
    let original = create_add_tree_circuit(18);
    let bytes = ::bincode::serialize(&original).unwrap();
    let restored: Circuit<C> = ::bincode::deserialize(&bytes).unwrap();
    assert_eq!(original, restored);
}
/// A circuit survives a wincode serialize/deserialize round trip unchanged.
#[test]
fn test_ser_circuit_wincode() {
    let original = create_add_tree_circuit(18);
    let bytes = ::wincode::serialize(&original).expect("Serialization failed");
    let restored: Circuit<C> = ::wincode::deserialize(&bytes).expect("Deserialization failed");
    assert_eq!(original, restored);
}
}