stwo-cairo-serialize 1.2.0

Serialization utilities for Stwo Cairo proofs
Documentation
use std::array;

use starknet_ff::FieldElement;
use stwo::core::fields::m31::BaseField;
use stwo::core::fields::qm31::SecureField;
use stwo::core::fri::{FriConfig, FriLayerProof, FriProof};
use stwo::core::pcs::quotients::CommitmentSchemeProof;
use stwo::core::pcs::{PcsConfig, TreeVec};
use stwo::core::poly::line::LinePoly;
use stwo::core::proof::StarkProof;
use stwo::core::vcs::blake2_hash::Blake2sHash;
use stwo::core::vcs_lifted::verifier::MerkleDecommitmentLifted;
use stwo::core::vcs_lifted::MerkleHasherLifted;
use stwo::core::ColumnVec;
pub use stwo_cairo_serialize_derive::CairoDeserialize;

/// Deserializes types from a format serialized by corresponding `CairoSerialize` implementations.
pub trait CairoDeserialize: Sized {
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self;
}

impl CairoDeserialize for u32 {
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        let field_elem = data.next().unwrap();
        let bytes = field_elem.to_bytes_be();
        if bytes[0..28] != [0; 28] {
            panic!("Invalid value for u32");
        }

        u32::from_be_bytes(bytes[28..32].try_into().unwrap())
    }
}

impl CairoDeserialize for u64 {
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        let field_elem = data.next().unwrap();
        let bytes = field_elem.to_bytes_be();
        if bytes[0..24] != [0; 24] {
            panic!("Invalid value for u64");
        }

        u64::from_be_bytes(bytes[24..32].try_into().unwrap())
    }
}

impl CairoDeserialize for usize {
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        <u64 as CairoDeserialize>::deserialize(data) as usize
    }
}

impl CairoDeserialize for BaseField {
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        BaseField::from(u32::deserialize(data))
    }
}

impl CairoDeserialize for SecureField {
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        let mut m31_values = [BaseField::from(0); 4];
        #[allow(clippy::needless_range_loop)]
        for i in 0..4 {
            m31_values[i] = BaseField::deserialize(data);
        }

        SecureField::from_m31_array(m31_values)
    }
}

impl<H: MerkleHasherLifted> CairoDeserialize for MerkleDecommitmentLifted<H>
where
    H::Hash: CairoDeserialize,
{
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        let hash_witness = Vec::<H::Hash>::deserialize(data);
        MerkleDecommitmentLifted { hash_witness }
    }
}

impl CairoDeserialize for LinePoly {
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        let coeffs = Vec::<SecureField>::deserialize(data);
        let log_size: u32 = u32::deserialize(data);

        let expected_len = 1usize << log_size;
        if coeffs.len() != expected_len {
            panic!("Invalid length for LinePoly");
        }

        LinePoly::new(coeffs)
    }
}

impl<H: MerkleHasherLifted> CairoDeserialize for FriLayerProof<H>
where
    H::Hash: CairoDeserialize,
{
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        let fri_witness = Vec::deserialize(data);
        let decommitment = MerkleDecommitmentLifted::deserialize(data);
        let commitment = H::Hash::deserialize(data);
        FriLayerProof {
            fri_witness,
            decommitment,
            commitment,
        }
    }
}

impl<H: MerkleHasherLifted> CairoDeserialize for FriProof<H>
where
    H::Hash: CairoDeserialize,
{
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        let first_layer = FriLayerProof::deserialize(data);
        let inner_layers = Vec::deserialize(data);
        let last_layer_poly = LinePoly::deserialize(data);
        FriProof {
            first_layer,
            inner_layers,
            last_layer_poly,
        }
    }
}

impl CairoDeserialize for FieldElement {
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        data.next().copied().unwrap()
    }
}

impl CairoDeserialize for FriConfig {
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        let log_blowup_factor = u32::deserialize(data);
        let log_last_layer_degree_bound = u32::deserialize(data);
        let n_queries = usize::deserialize(data);
        FriConfig {
            log_blowup_factor,
            log_last_layer_degree_bound,
            n_queries,
            fold_step: 1,
        }
    }
}

impl CairoDeserialize for PcsConfig {
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        let pow_bits = u32::deserialize(data);
        let fri_config = FriConfig::deserialize(data);

        // The cairo1 does not support fixed lifting log size in the PCS config.
        let lifting_log_size = None;
        PcsConfig {
            pow_bits,
            fri_config,
            lifting_log_size,
        }
    }
}

impl<H: MerkleHasherLifted> CairoDeserialize for CommitmentSchemeProof<H>
where
    H::Hash: CairoDeserialize,
{
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        let config = PcsConfig::deserialize(data);
        let commitments = TreeVec(Vec::<H::Hash>::deserialize(data));
        let sampled_values = TreeVec(Vec::<ColumnVec<Vec<SecureField>>>::deserialize(data));
        let decommitments = TreeVec(Vec::<MerkleDecommitmentLifted<H>>::deserialize(data));
        let queried_values = TreeVec(Vec::<Vec<Vec<BaseField>>>::deserialize(data));
        let proof_of_work: u64 = u64::deserialize(data);
        let fri_proof = FriProof::deserialize(data);
        CommitmentSchemeProof {
            config,
            commitments,
            sampled_values,
            decommitments,
            queried_values,
            proof_of_work,
            fri_proof,
        }
    }
}

impl<H: MerkleHasherLifted> CairoDeserialize for StarkProof<H>
where
    H::Hash: CairoDeserialize,
{
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        let commitment_scheme_proof = CommitmentSchemeProof::deserialize(data);
        StarkProof(commitment_scheme_proof)
    }
}

impl<T: CairoDeserialize> CairoDeserialize for Option<T> {
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        let discriminant = data.next().unwrap();
        if *discriminant == FieldElement::ZERO {
            let value = T::deserialize(data);
            Some(value)
        } else if *discriminant == FieldElement::ONE {
            None
        } else {
            panic!("Invalid discriminant for Option<T>");
        }
    }
}

impl<T: CairoDeserialize, const N: usize> CairoDeserialize for [T; N] {
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        array::from_fn(|_| T::deserialize(data))
    }
}

impl<T: CairoDeserialize> CairoDeserialize for Vec<T> {
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        let len: usize = usize::deserialize(data);

        (0..len).map(|_| T::deserialize(data)).collect()
    }
}

impl<T0: CairoDeserialize, T1: CairoDeserialize> CairoDeserialize for (T0, T1) {
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        let v0 = T0::deserialize(data);
        let v1 = T1::deserialize(data);
        (v0, v1)
    }
}

impl<T0: CairoDeserialize, T1: CairoDeserialize, T2: CairoDeserialize> CairoDeserialize
    for (T0, T1, T2)
{
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        let v0 = T0::deserialize(data);
        let v1 = T1::deserialize(data);
        let v2 = T2::deserialize(data);
        (v0, v1, v2)
    }
}

impl CairoDeserialize for Blake2sHash {
    fn deserialize<'a>(data: &mut impl Iterator<Item = &'a FieldElement>) -> Self {
        let mut bytes = [0u8; 32];
        for byte_chunk in bytes.chunks_exact_mut(4) {
            let byte_chunk: &mut [u8; 4] = byte_chunk.try_into().unwrap();
            let v: u32 = u32::deserialize(data);
            *byte_chunk = v.to_le_bytes();
        }
        Blake2sHash(bytes)
    }
}