nova-snark 0.71.0

High-speed recursive arguments from folding schemes
Documentation
use ff::PrimeField;
use serde::{
  de::{self, Deserializer, MapAccess, SeqAccess, Visitor},
  ser::{SerializeStruct, Serializer},
  Deserialize, Serialize,
};
use std::{fmt, marker::PhantomData};

use super::{
  hash_type::HashType,
  poseidon_inner::{Arity, PoseidonConstants},
};

impl<F, A> Serialize for PoseidonConstants<F, A>
where
  F: PrimeField + Serialize,
  A: Arity<F>,
{
  fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
  where
    S: Serializer,
  {
    let mut state = serializer.serialize_struct("PoseidonConstants", 8)?;
    state.serialize_field("mds", &self.mds_matrices)?;
    state.serialize_field("crc", &self.compressed_round_constants)?;
    state.serialize_field("psm", &self.pre_sparse_matrix)?;
    state.serialize_field("sm", &self.sparse_matrixes)?;
    state.serialize_field("s", &self.strength)?;
    state.serialize_field("rf", &self.full_rounds)?;
    state.serialize_field("rp", &self.partial_rounds)?;
    state.serialize_field("ht", &self.hash_type)?;
    state.end()
  }
}

impl<'de, F, A> Deserialize<'de> for PoseidonConstants<F, A>
where
  F: PrimeField + Deserialize<'de>,
  A: Arity<F>,
{
  fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
  where
    D: Deserializer<'de>,
  {
    #[derive(Deserialize)]
    #[serde(field_identifier, rename_all = "lowercase")]
    enum Field {
      Mds,
      Crc,
      Psm,
      Sm,
      S,
      Rf,
      Rp,
      Ht,
    }

    struct PoseidonConstantsVisitor<F, A>
    where
      F: PrimeField,
      A: Arity<F>,
    {
      _f: PhantomData<F>,
      _a: PhantomData<A>,
    }

    impl<'de, F, A> Visitor<'de> for PoseidonConstantsVisitor<F, A>
    where
      F: PrimeField + Deserialize<'de>,
      A: Arity<F>,
    {
      type Value = PoseidonConstants<F, A>;

      fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("struct PoseidonConstants")
      }

      fn visit_seq<V>(self, mut seq: V) -> Result<PoseidonConstants<F, A>, V::Error>
      where
        V: SeqAccess<'de>,
      {
        let mds_matrices = seq
          .next_element()?
          .ok_or_else(|| de::Error::invalid_length(0, &self))?;
        let compressed_round_constants = seq
          .next_element()?
          .ok_or_else(|| de::Error::invalid_length(1, &self))?;
        let pre_sparse_matrix = seq
          .next_element()?
          .ok_or_else(|| de::Error::invalid_length(2, &self))?;
        let sparse_matrixes = seq
          .next_element()?
          .ok_or_else(|| de::Error::invalid_length(3, &self))?;
        let strength = seq
          .next_element()?
          .ok_or_else(|| de::Error::invalid_length(4, &self))?;
        let full_rounds = seq
          .next_element()?
          .ok_or_else(|| de::Error::invalid_length(5, &self))?;
        let partial_rounds = seq
          .next_element()?
          .ok_or_else(|| de::Error::invalid_length(6, &self))?;
        let hash_type: HashType<F, A> = seq
          .next_element()?
          .ok_or_else(|| de::Error::invalid_length(7, &self))?;

        Ok(PoseidonConstants {
          mds_matrices,
          round_constants: None,
          compressed_round_constants,
          pre_sparse_matrix,
          sparse_matrixes,
          strength,
          domain_tag: hash_type.domain_tag(),
          full_rounds,
          half_full_rounds: full_rounds / 2,
          partial_rounds,
          hash_type,
          _a: PhantomData::<A>,
        })
      }

      fn visit_map<V>(self, mut map: V) -> Result<PoseidonConstants<F, A>, V::Error>
      where
        V: MapAccess<'de>,
      {
        let mut mds_matrices = None;
        let mut compressed_round_constants = None;
        let mut pre_sparse_matrix = None;
        let mut sparse_matrixes = None;
        let mut strength = None;
        let mut full_rounds = None;
        let mut partial_rounds = None;
        let mut hash_type = None;

        while let Some(key) = map.next_key()? {
          match key {
            Field::Mds => {
              if mds_matrices.is_some() {
                return Err(de::Error::duplicate_field("mds_matrices"));
              }
              mds_matrices = Some(map.next_value()?);
            }
            Field::Crc => {
              if compressed_round_constants.is_some() {
                return Err(de::Error::duplicate_field("compressed_round_constants"));
              }
              compressed_round_constants = Some(map.next_value()?);
            }
            Field::Psm => {
              if pre_sparse_matrix.is_some() {
                return Err(de::Error::duplicate_field("pre_sparse_matrix"));
              }
              pre_sparse_matrix = Some(map.next_value()?);
            }
            Field::Sm => {
              if sparse_matrixes.is_some() {
                return Err(de::Error::duplicate_field("sparse_matrixes"));
              }
              sparse_matrixes = Some(map.next_value()?);
            }
            Field::S => {
              if strength.is_some() {
                return Err(de::Error::duplicate_field("strength"));
              }
              strength = Some(map.next_value()?);
            }
            Field::Rf => {
              if full_rounds.is_some() {
                return Err(de::Error::duplicate_field("full_rounds"));
              }
              full_rounds = Some(map.next_value()?);
            }
            Field::Rp => {
              if partial_rounds.is_some() {
                return Err(de::Error::duplicate_field("partial_rounds"));
              }
              partial_rounds = Some(map.next_value()?);
            }
            Field::Ht => {
              if hash_type.is_some() {
                return Err(de::Error::duplicate_field("hash_type"));
              }
              hash_type = Some(map.next_value()?);
            }
          }
        }

        let mds_matrices = mds_matrices.ok_or_else(|| de::Error::missing_field("mds_matrices"))?;
        let compressed_round_constants = compressed_round_constants
          .ok_or_else(|| de::Error::missing_field("compressed_round_constants"))?;
        let pre_sparse_matrix =
          pre_sparse_matrix.ok_or_else(|| de::Error::missing_field("pre_sparse_matrix"))?;
        let sparse_matrixes =
          sparse_matrixes.ok_or_else(|| de::Error::missing_field("sparse_matrixes"))?;
        let strength = strength.ok_or_else(|| de::Error::missing_field("strength"))?;
        let full_rounds = full_rounds.ok_or_else(|| de::Error::missing_field("full_rounds"))?;
        let partial_rounds =
          partial_rounds.ok_or_else(|| de::Error::missing_field("partial_rounds"))?;
        let hash_type: HashType<F, A> =
          hash_type.ok_or_else(|| de::Error::missing_field("hash_type"))?;
        Ok(PoseidonConstants {
          mds_matrices,
          round_constants: None,
          compressed_round_constants,
          pre_sparse_matrix,
          sparse_matrixes,
          strength,
          domain_tag: hash_type.domain_tag(),
          full_rounds,
          half_full_rounds: full_rounds / 2,
          partial_rounds,
          hash_type,
          _a: PhantomData::<A>,
        })
      }
    }

    const FIELDS: &[&str] = &[
      "mds_matrices",
      "compressed_round_constants",
      "pre_sparse_matrix",
      "sparse_matrixes",
      "strength",
      "full_rounds",
      "partial_rounds",
      "hash_type",
    ];
    deserializer.deserialize_struct(
      "PoseidonConstants",
      FIELDS,
      PoseidonConstantsVisitor {
        _f: PhantomData,
        _a: PhantomData,
      },
    )
  }
}