use ergotree_ir::mir::constant::Constant;
use ergotree_ir::serialization::sigma_byte_reader::SigmaByteRead;
use ergotree_ir::serialization::sigma_byte_writer::SigmaByteWrite;
use ergotree_ir::serialization::SigmaParsingError;
use ergotree_ir::serialization::SigmaSerializable;
use ergotree_ir::serialization::SigmaSerializeResult;
use indexmap::IndexMap;
use std::convert::TryFrom;
use std::fmt;
use thiserror::Error;
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct ContextExtension {
pub values: IndexMap<u8, Constant>,
}
impl ContextExtension {
pub fn empty() -> Self {
Self {
values: IndexMap::new(),
}
}
}
impl fmt::Display for ContextExtension {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_map().entries(self.values.iter()).finish()
}
}
impl SigmaSerializable for ContextExtension {
fn sigma_serialize<W: SigmaByteWrite>(&self, w: &mut W) -> SigmaSerializeResult {
w.put_u8(self.values.len() as u8)?;
let values: Vec<(&u8, &Constant)> = self.values.iter().collect();
values.iter().try_for_each(|(idx, c)| {
w.put_u8(**idx)?;
c.sigma_serialize(w)
})?;
Ok(())
}
fn sigma_parse<R: SigmaByteRead>(r: &mut R) -> Result<Self, SigmaParsingError> {
let values_count = r.get_u8()?;
let mut values: IndexMap<u8, Constant> = IndexMap::with_capacity(values_count as usize);
for _ in 0..values_count {
let idx = r.get_u8()?;
values.insert(idx, Constant::sigma_parse(r)?);
}
Ok(ContextExtension { values })
}
}
#[derive(Error, Eq, PartialEq, Debug, Clone)]
#[error("Error parsing constant: {0}")]
pub struct ConstantParsingError(pub String);
impl TryFrom<IndexMap<String, String>> for ContextExtension {
type Error = ConstantParsingError;
fn try_from(values_str: IndexMap<String, String>) -> Result<Self, Self::Error> {
let values = values_str.iter().try_fold(
IndexMap::with_capacity(values_str.len()),
|mut acc, pair| {
let idx: u8 = pair.0.parse().map_err(|_| {
ConstantParsingError(format!("cannot parse index from {0:?}", pair.0))
})?;
let constant_bytes = base16::decode(pair.1).map_err(|_| {
ConstantParsingError(format!(
"cannot decode base16 constant bytes from {0:?}",
pair.1
))
})?;
acc.insert(
idx,
Constant::sigma_parse_bytes(&constant_bytes).map_err(|_| {
ConstantParsingError(format!(
"cannot deserialize constant bytes from {0:?}",
pair.1
))
})?,
);
Ok(acc)
},
)?;
Ok(ContextExtension { values })
}
}
#[cfg(feature = "arbitrary")]
mod arbitrary {
use super::*;
use proptest::{arbitrary::Arbitrary, collection::vec, prelude::*};
impl Arbitrary for ContextExtension {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
vec(any::<Constant>(), 0..10)
.prop_map(|constants| {
let pairs = constants
.into_iter()
.enumerate()
.map(|(idx, c)| (idx as u8, c))
.collect();
Self { values: pairs }
})
.boxed()
}
}
}
#[cfg(test)]
#[cfg(feature = "arbitrary")]
#[allow(clippy::panic)]
mod tests {
use super::*;
use ergotree_ir::serialization::sigma_serialize_roundtrip;
use proptest::prelude::*;
proptest! {
#[test]
fn ser_roundtrip(v in any::<ContextExtension>()) {
prop_assert_eq![sigma_serialize_roundtrip(&v), v];
}
}
}