arcium-core-utils 0.4.1

Arcium core utils
Documentation
//! A serialization-optimized representation of a circuit.
use primitives::algebra::elliptic_curve::{Curve, Point, Scalar};
use serde::{Deserialize, Serialize};

use crate::circuit::{errors::CircuitError, Circuit, Gate, GateIndex};

/// Serialization/deserialization optimized representation of a circuit.
///
/// Carries only the gate list and the output indices of a [`Circuit`]. The serde bounds
/// are stated on `Scalar<C>` / `Point<C>` rather than on `C` itself, and the struct is
/// `#[repr(C)]` so its fields are laid out in declaration order.
#[derive(Serialize, Deserialize, Default)]
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
#[repr(C)]
struct CompressedCircuit<C: Curve> {
    /// The circuit operations.
    pub ops: Vec<Gate<C>>,
    /// The output gates, in order of definition.
    pub output_gates: Vec<GateIndex>,
}

impl<C: Curve> From<&Circuit<C>> for CompressedCircuit<C> {
    fn from(value: &Circuit<C>) -> Self {
        CompressedCircuit {
            ops: value.iter_gates().cloned().collect(),
            output_gates: value.iter_output_indices().copied().collect(),
        }
    }
}

impl<C: Curve> TryFrom<CompressedCircuit<C>> for Circuit<C> {
    type Error = CircuitError<C>;

    fn try_from(circuit: CompressedCircuit<C>) -> Result<Self, Self::Error> {
        let mut res = Self {
            gates: Vec::with_capacity(circuit.ops.len()),
            inputs: Vec::new(),
            outputs: Vec::with_capacity(circuit.output_gates.len()),
        };

        for gate in circuit.ops.into_iter() {
            res.add_gate(gate)?;
        }

        for index in circuit.output_gates.into_iter() {
            res.add_output(index)?;
        }

        Ok(res)
    }
}

mod bincode {
    use primitives::algebra::elliptic_curve::Curve;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    use crate::circuit::{compressed_circuit::CompressedCircuit, Circuit};

    impl<C: Curve> Serialize for Circuit<C> {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let circuit_serde: CompressedCircuit<C> = self.into();
            circuit_serde.serialize(serializer)
        }
    }

    impl<'de, C: Curve> Deserialize<'de> for Circuit<C> {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let circuit_serde = CompressedCircuit::<C>::deserialize(deserializer)
                .map_err(serde::de::Error::custom)?;
            let circuit = circuit_serde.try_into();
            circuit.map_err(serde::de::Error::custom)
        }
    }
}

mod wincode {
    use core::{
        mem::{self, MaybeUninit},
        ptr,
    };

    use ::wincode::{
        containers,
        io::{Reader, Writer},
        len::BincodeLen,
        ReadResult,
        SchemaRead,
        SchemaWrite,
        TypeMeta,
        WriteResult,
    };

    use super::*;
    /// `BincodeLen` with its const parameter set to `2 << 32` (i.e. 2^33); used as the
    /// length schema of the `ops` vector.
    // NOTE(review): the alias name suggests a 2^32 bound, but `2 << 32` is twice that, and
    // presumably does not fit a 32-bit `usize` — confirm the intended value.
    pub type BincodeLenU32 = BincodeLen<{ 2 << 32 }>;

    // TODO: optimize so that we dont need to cast into CircuitSerde twice
    impl<C: Curve> SchemaWrite for Circuit<C> {
        type Src = Self;

        const TYPE_META: TypeMeta = <CompressedCircuit<C> as SchemaWrite>::TYPE_META;

        fn size_of(src: &Self::Src) -> WriteResult<usize> {
            let circuit_serde: CompressedCircuit<C> = src.into();
            <CompressedCircuit<C> as SchemaWrite>::size_of(&circuit_serde)
        }

        fn write(writer: &mut impl Writer, src: &Self::Src) -> WriteResult<()> {
            let circuit_serde: CompressedCircuit<C> = src.into();
            <CompressedCircuit<C> as SchemaWrite>::write(writer, &circuit_serde)
        }
    }

    impl<C: Curve> ::wincode::SchemaWrite for CompressedCircuit<C> {
        type Src = Self;
        #[allow(clippy::arithmetic_side_effects)]
        const TYPE_META: TypeMeta = if let (
            TypeMeta::Static {
                size: a,
                zero_copy: zc_a,
            },
            TypeMeta::Static {
                size: b,
                zero_copy: zc_b,
            },
        ) = (
            <containers::Vec<Gate<C>, BincodeLenU32> as SchemaWrite>::TYPE_META,
            <Vec<GateIndex> as SchemaWrite>::TYPE_META,
        ) {
            let serialized_size = a + b;
            let no_padding = serialized_size == size_of::<Self>();
            TypeMeta::Static {
                size: serialized_size,
                zero_copy: no_padding && zc_a && zc_b,
            }
        } else {
            TypeMeta::Dynamic
        };
        #[inline]
        fn size_of(src: &Self::Src) -> WriteResult<usize> {
            if let TypeMeta::Static { size, .. } = <Self as SchemaWrite>::TYPE_META {
                return Ok(size);
            }
            let mut total = 0usize;
            total += <containers::Vec<Gate<C>, BincodeLenU32> as SchemaWrite>::size_of(&src.ops)?;
            total += <Vec<GateIndex> as SchemaWrite>::size_of(&src.output_gates)?;
            Ok(total)
        }
        #[inline]
        fn write(writer: &mut impl Writer, src: &Self::Src) -> WriteResult<()> {
            // Macro to reduce duplication in field writing
            macro_rules! write_fields {
                ($writer:expr) => {{
                    <containers::Vec<Gate<C>, BincodeLenU32> as SchemaWrite>::write(
                        $writer, &src.ops,
                    )?;
                    <Vec<GateIndex> as SchemaWrite>::write($writer, &src.output_gates)?;
                }};
            }

            match <Self as SchemaWrite>::TYPE_META {
                TypeMeta::Static { size, .. } => {
                    let writer = &mut unsafe { writer.as_trusted_for(size) }?;
                    write_fields!(writer);
                    writer.finish()?;
                }
                TypeMeta::Dynamic => {
                    write_fields!(writer);
                }
            }
            Ok(())
        }
    }

    impl<'de, C: Curve> SchemaRead<'de> for Circuit<C> {
        type Dst = Self;
        const TYPE_META: TypeMeta = <CompressedCircuit<C> as SchemaRead>::TYPE_META;

        /// Reads a `CompressedCircuit` and rebuilds the circuit from it into `dst`.
        ///
        /// # Errors
        ///
        /// Propagates any error from reading the compressed form, and returns
        /// `ReadError::Custom` if the compressed form cannot be converted into a circuit.
        /// `dst` is written only on success.
        fn read(reader: &mut impl Reader<'de>, dst: &mut MaybeUninit<Self::Dst>) -> ReadResult<()> {
            // The reader below fills both fields through raw pointers, so the slot starts
            // uninitialized rather than holding a throwaway default value.
            let mut compressed = MaybeUninit::<CompressedCircuit<C>>::uninit();
            <CompressedCircuit<C> as SchemaRead>::read(reader, &mut compressed)?;
            // SAFETY: `read` returned `Ok`, which it only does after writing both `ops`
            // and `output_gates`, so the value is fully initialized.
            let compressed = unsafe { compressed.assume_init() };
            let circuit: Self::Dst = compressed.try_into().map_err(|_| {
                ::wincode::ReadError::Custom("Invalid cast from CircuitSerde to Circuit struct")
            })?;
            dst.write(circuit);
            Ok(())
        }
    }

    impl<'de, C: Curve> SchemaRead<'de> for CompressedCircuit<C> {
        type Dst = Self;

        /// Static only when both field schemas are static; the combined size is then the
        /// sum of the two, and zero-copy additionally requires that sum to equal
        /// `size_of::<Self>()`. Dynamic otherwise.
        #[allow(clippy::arithmetic_side_effects)]
        const TYPE_META: TypeMeta = if let (
            TypeMeta::Static {
                size: a,
                zero_copy: zc_a,
            },
            TypeMeta::Static {
                size: b,
                zero_copy: zc_b,
            },
        ) = (
            <containers::Vec<Gate<C>, BincodeLenU32> as SchemaRead<'de>>::TYPE_META,
            <Vec<GateIndex> as SchemaRead<'de>>::TYPE_META,
        ) {
            let serialized_size = a + b;
            let no_padding = serialized_size == size_of::<Self>();
            TypeMeta::Static {
                size: serialized_size,
                zero_copy: no_padding && zc_a && zc_b,
            }
        } else {
            TypeMeta::Dynamic
        };

        /// Reads `ops` and then `output_gates` directly into `dst`.
        ///
        /// If reading `output_gates` fails after `ops` was written, the already-written
        /// `ops` is dropped in place before the error is returned.
        #[inline]
        fn read(reader: &mut impl Reader<'de>, dst: &mut MaybeUninit<Self::Dst>) -> ReadResult<()> {
            // Tracks how many leading fields of `*dst_ptr` have been initialized so that an
            // early return (via `?`) drops exactly those fields.
            struct DropGuard<C: Curve> {
                init_count: u8,
                dst_ptr: *mut CompressedCircuit<C>,
            }
            impl<C: Curve> Drop for DropGuard<C> {
                #[cold]
                fn drop(&mut self) {
                    let dst_ptr = self.dst_ptr;
                    let init_count = self.init_count;
                    match init_count {
                        0 => {}
                        // SAFETY: `init_count` is bumped to 1 only after `ops` was
                        // successfully read into `*dst_ptr`, so the field is initialized.
                        1u8 => unsafe {
                            ptr::drop_in_place(&raw mut (*dst_ptr).ops);
                        },
                        _ => unreachable!("init_count out of bounds"),
                    }
                }
            }
            // Macro to reduce duplication in field reading. `$reader` is expanded once per
            // field, so callers must pass an already-bound reader, not a constructing
            // expression.
            macro_rules! read_fields {
                ($reader:expr, $dst_ptr:expr, $guard:expr) => {{
                    let init_count = &mut $guard.init_count;
                    // SAFETY: `$dst_ptr` comes from `MaybeUninit::as_mut_ptr`, and a field
                    // of a possibly-uninitialized struct may be viewed as `MaybeUninit`.
                    <containers::Vec<Gate<C>, BincodeLenU32> as SchemaRead<'de>>::read(
                        $reader,
                        unsafe { &mut *(&raw mut (*$dst_ptr).ops).cast::<MaybeUninit<_>>() },
                    )?;
                    *init_count += 1;
                    // SAFETY: same as above, for the `output_gates` field.
                    <Vec<GateIndex> as SchemaRead<'de>>::read($reader, unsafe {
                        &mut *(&raw mut (*$dst_ptr).output_gates).cast::<MaybeUninit<_>>()
                    })?;
                    // Both fields are initialized: disarm the guard so nothing is dropped.
                    mem::forget($guard);
                }};
            }

            let dst_ptr = dst.as_mut_ptr();
            let mut guard = DropGuard {
                init_count: 0,
                dst_ptr,
            };

            match <Self as SchemaRead<'de>>::TYPE_META {
                TypeMeta::Static { size, .. } => {
                    // Bind the trusted reader once so both fields share the same `size`-byte
                    // window; passing the constructing expression to the macro would create
                    // a fresh trusted reader for each field.
                    let reader = &mut unsafe { reader.as_trusted_for(size) }?;
                    read_fields!(reader, dst_ptr, guard);
                }
                TypeMeta::Dynamic => {
                    read_fields!(reader, dst_ptr, guard);
                }
            }
            Ok(())
        }
    }
}

#[cfg(test)]
mod tests {
    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    use super::*;
    use crate::circuit::{
        tests::create_add_tree_circuit,
        AlgebraicType,
        FieldShareBinaryOp,
        Input,
    };

    /// Builds `input + input` by hand and checks the gate indices handed back along the way.
    #[test]
    fn valid_circuit() {
        let mut circuit = Circuit::<C>::new();

        let first_input = Gate::Input(Input::SecretPlaintext {
            inputer: 0,
            algebraic_type: AlgebraicType::ScalarField,
            batch_size: 1,
        });
        let lhs = circuit.add_gate(first_input).unwrap();
        assert_eq!(lhs, 0);

        let second_input = Gate::Input(Input::SecretPlaintext {
            inputer: 1,
            algebraic_type: AlgebraicType::ScalarField,
            batch_size: 1,
        });
        let rhs = circuit.add_gate(second_input).unwrap();
        assert_eq!(rhs, 1);

        let sum = circuit
            .add_gate(Gate::FieldShareBinaryOp {
                x: lhs,
                y: rhs,
                op: FieldShareBinaryOp::Add,
            })
            .unwrap();
        assert_eq!(sum, 2);

        circuit.add_output(sum).unwrap();
        let outputs: Vec<_> = circuit.iter_output_indices().copied().collect();
        assert_eq!(outputs, vec![2]);
    }

    /// A circuit must survive a bincode serialize/deserialize round trip unchanged.
    #[test]
    fn test_ser_circuit_bincode() {
        let original = create_add_tree_circuit(18);

        let bytes = ::bincode::serialize(&original).unwrap();
        let restored: Circuit<C> = ::bincode::deserialize(&bytes).unwrap();

        assert_eq!(original, restored);
    }

    /// A circuit must survive a wincode serialize/deserialize round trip unchanged.
    #[test]
    fn test_ser_circuit_wincode() {
        let original = create_add_tree_circuit(18);

        let bytes = ::wincode::serialize(&original).expect("Serialization failed");
        let restored: Circuit<C> = ::wincode::deserialize(&bytes).expect("Deserialization failed");

        assert_eq!(original, restored);
    }
}