#[cfg(feature = "json")]
use crate::chain::json::ergo_box::ConstantHolder;
use crate::serialization::sigma_byte_reader::SigmaByteRead;
use crate::serialization::sigma_byte_writer::SigmaByteWrite;
use crate::serialization::SigmaSerializable;
use crate::{ast::constant::Constant, serialization::SerializationError};
#[cfg(feature = "json")]
use serde::{Deserialize, Serialize};
use std::convert::TryInto;
use std::{collections::HashMap, convert::TryFrom};
use thiserror::Error;
/// A box register id: either one of the mandatory registers (R0..=R3) or one of the
/// non-mandatory registers (R4..=R9).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RegisterId {
    /// One of the mandatory registers R0..=R3.
    MandatoryRegisterId(MandatoryRegisterId),
    /// One of the non-mandatory registers R4..=R9.
    NonMandatoryRegisterId(NonMandatoryRegisterId),
}
impl RegisterId {
    /// Register R0 wrapped as a [`RegisterId`].
    pub const R0: Self = Self::MandatoryRegisterId(MandatoryRegisterId::R0);
}
#[derive(Error, PartialEq, Eq, Debug, Clone)]
#[error("register id {0} is out of bounds (0 - 9)")]
pub struct RegisterIdOutOfBounds(pub i8);
impl TryFrom<i8> for RegisterId {
type Error = RegisterIdOutOfBounds;
fn try_from(value: i8) -> Result<Self, Self::Error> {
if value < 0 {
return Err(RegisterIdOutOfBounds(value));
}
let v = value as usize;
if v < NonMandatoryRegisterId::START_INDEX {
Ok(RegisterId::MandatoryRegisterId(value.try_into()?))
} else if v <= NonMandatoryRegisterId::END_INDEX {
Ok(RegisterId::NonMandatoryRegisterId(value.try_into()?))
} else {
Err(RegisterIdOutOfBounds(value))
}
}
}
impl SigmaSerializable for RegisterId {
    /// Writes the register id as a single signed byte holding its numeric index.
    fn sigma_serialize<W: SigmaByteWrite>(&self, w: &mut W) -> Result<(), std::io::Error> {
        w.put_i8(match *self {
            RegisterId::MandatoryRegisterId(mandatory) => mandatory as i8,
            RegisterId::NonMandatoryRegisterId(non_mandatory) => non_mandatory as i8,
        })
    }

    /// Reads one signed byte and converts it into a [`RegisterId`].
    ///
    /// # Errors
    /// Propagates read errors. Ids that do not name a register (outside 0..=9) yield
    /// [`SerializationError::ValueOutOfBounds`].
    fn sigma_parse<R: SigmaByteRead>(r: &mut R) -> Result<Self, SerializationError> {
        let id_byte = r.get_i8()?;
        match RegisterId::try_from(id_byte) {
            Ok(reg_id) => Ok(reg_id),
            Err(_) => Err(SerializationError::ValueOutOfBounds(format!(
                "Register id out of bounds: {}",
                id_byte
            ))),
        }
    }
}
/// Identifier of a non-mandatory box register, R4 through R9.
///
/// Each variant's discriminant equals the register's numeric index. With the `json` feature the
/// id is (de)serialized through its string form (e.g. `"R4"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "json", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "json", serde(into = "String", try_from = "String"))]
#[repr(u8)]
pub enum NonMandatoryRegisterId {
    /// Register R4.
    R4 = 4,
    /// Register R5.
    R5 = 5,
    /// Register R6.
    R6 = 6,
    /// Register R7.
    R7 = 7,
    /// Register R8.
    R8 = 8,
    /// Register R9.
    R9 = 9,
}
impl NonMandatoryRegisterId {
    /// Numeric index of the first non-mandatory register (R4).
    pub const START_INDEX: usize = 4;
    /// Numeric index of the last non-mandatory register (R9).
    pub const END_INDEX: usize = 9;
    /// Number of non-mandatory registers, derived from the inclusive index range R4..=R9 (6).
    pub const NUM_REGS: usize =
        NonMandatoryRegisterId::END_INDEX - NonMandatoryRegisterId::START_INDEX + 1;
    /// Every non-mandatory register id in ascending index order.
    pub const REG_IDS: [NonMandatoryRegisterId; NonMandatoryRegisterId::NUM_REGS] = [
        NonMandatoryRegisterId::R4,
        NonMandatoryRegisterId::R5,
        NonMandatoryRegisterId::R6,
        NonMandatoryRegisterId::R7,
        NonMandatoryRegisterId::R8,
        NonMandatoryRegisterId::R9,
    ];

    /// Returns the register at zero-based position `i` within R4..=R9 (`0` gives R4).
    ///
    /// # Panics
    /// Panics if `i >= NUM_REGS`.
    pub fn get_by_zero_index(i: usize) -> NonMandatoryRegisterId {
        assert!(i < NonMandatoryRegisterId::NUM_REGS);
        let all_ids = NonMandatoryRegisterId::REG_IDS;
        all_ids[i]
    }
}
impl Into<String> for NonMandatoryRegisterId {
fn into(self) -> String {
format!("R{}", self as u8)
}
}
impl TryFrom<String> for NonMandatoryRegisterId {
    type Error = NonMandatoryRegisterIdParsingError;

    /// Parses a register name of the form `R<digit>` (e.g. `"R4"`) into a register id.
    ///
    /// # Errors
    /// Returns [`NonMandatoryRegisterIdParsingError`] if the string is not exactly `R` followed
    /// by a single ASCII digit, or if that digit is outside `START_INDEX..=END_INDEX` (4..=9).
    fn try_from(s: String) -> Result<Self, Self::Error> {
        // `strip_prefix` avoids byte-slicing the input. The previous `&s[..1]` panicked when the
        // first character is multi-byte (e.g. "é" is 2 bytes, so it passed the length check),
        // and this conversion runs on deserialized JSON keys.
        let index = s
            .strip_prefix('R')
            .filter(|digit| digit.len() == 1)
            .and_then(|digit| digit.parse::<usize>().ok())
            .ok_or(NonMandatoryRegisterIdParsingError())?;
        if (NonMandatoryRegisterId::START_INDEX..=NonMandatoryRegisterId::END_INDEX)
            .contains(&index)
        {
            Ok(NonMandatoryRegisterId::get_by_zero_index(
                index - NonMandatoryRegisterId::START_INDEX,
            ))
        } else {
            Err(NonMandatoryRegisterIdParsingError())
        }
    }
}
impl TryFrom<i8> for NonMandatoryRegisterId {
    type Error = RegisterIdOutOfBounds;

    /// Converts a numeric register id in `START_INDEX..=END_INDEX` (4..=9) to its register.
    ///
    /// # Errors
    /// Returns [`RegisterIdOutOfBounds`] carrying `value` for any other id, including negatives.
    fn try_from(value: i8) -> Result<Self, Self::Error> {
        // Negative ids fail the usize conversion and fall through to the error arm.
        let in_range_index = usize::try_from(value).ok().filter(|idx| {
            (NonMandatoryRegisterId::START_INDEX..=NonMandatoryRegisterId::END_INDEX).contains(idx)
        });
        match in_range_index {
            Some(idx) => Ok(Self::get_by_zero_index(
                idx - NonMandatoryRegisterId::START_INDEX,
            )),
            None => Err(RegisterIdOutOfBounds(value)),
        }
    }
}
#[derive(Error, PartialEq, Eq, Debug, Clone)]
#[error("failed to parse register id")]
pub struct NonMandatoryRegisterIdParsingError();
/// Values of the non-mandatory registers, stored densely in register order.
///
/// The value at position `i` belongs to register `R(4 + i)`, so only a contiguous run of
/// registers starting at R4 can be represented.
///
/// With the `json` feature enabled this is serialized as a map from register id to value, and
/// deserialized from a map of register id to `ConstantHolder`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "json", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "json",
    serde(
        into = "HashMap<NonMandatoryRegisterId, Constant>",
        try_from = "HashMap<NonMandatoryRegisterId, crate::chain::json::ergo_box::ConstantHolder>"
    )
)]
pub struct NonMandatoryRegisters(Vec<Constant>);
impl NonMandatoryRegisters {
pub const MAX_SIZE: usize = NonMandatoryRegisterId::NUM_REGS;
pub fn empty() -> NonMandatoryRegisters {
NonMandatoryRegisters(vec![])
}
pub fn new(
regs: HashMap<NonMandatoryRegisterId, Constant>,
) -> Result<NonMandatoryRegisters, NonMandatoryRegistersError> {
NonMandatoryRegisters::try_from(regs)
}
pub fn from_ordered_values(
values: Vec<Constant>,
) -> Result<NonMandatoryRegisters, NonMandatoryRegistersError> {
if values.len() > NonMandatoryRegisters::MAX_SIZE {
Err(NonMandatoryRegistersError::InvalidSize(values.len()))
} else {
Ok(NonMandatoryRegisters(values))
}
}
pub fn len(&self) -> usize {
self.0.len()
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
pub fn get(&self, reg_id: NonMandatoryRegisterId) -> Option<&Constant> {
self.0
.get(reg_id as usize - NonMandatoryRegisterId::START_INDEX)
}
pub fn get_ordered_values(&self) -> &Vec<Constant> {
&self.0
}
}
#[derive(Error, PartialEq, Eq, Clone, Debug)]
pub enum NonMandatoryRegistersError {
#[error("invalid non-mandatory registers size ({0})")]
InvalidSize(usize),
#[error("registers are not densely packed (register R{0} is missing)")]
NonDenselyPacked(u8),
}
/// Expands the dense register list into a map keyed by register id.
///
/// The value at position `i` is keyed by `R(4 + i)`. Implemented as `From` rather than `Into`,
/// so `HashMap::from(regs)` is available as well. The standard blanket impl still provides
/// `regs.into()`, which the serde `into = "HashMap<..>"` attribute uses.
impl From<NonMandatoryRegisters> for HashMap<NonMandatoryRegisterId, Constant> {
    fn from(regs: NonMandatoryRegisters) -> Self {
        regs.0
            .into_iter()
            .enumerate()
            .map(|(i, c)| (NonMandatoryRegisterId::get_by_zero_index(i), c))
            .collect()
    }
}
impl TryFrom<HashMap<NonMandatoryRegisterId, Constant>> for NonMandatoryRegisters {
    type Error = NonMandatoryRegistersError;

    /// Builds the dense register list from a map keyed by register id.
    ///
    /// A map with `n` entries must contain exactly R4 up to `R(4 + n - 1)`. The values are stored
    /// in register order.
    ///
    /// # Errors
    /// - [`NonMandatoryRegistersError::InvalidSize`] if the map has more than
    ///   `NonMandatoryRegisters::MAX_SIZE` entries.
    /// - [`NonMandatoryRegistersError::NonDenselyPacked`] if the keys are not a contiguous run
    ///   starting at R4. It carries the number of the first missing register.
    fn try_from(
        mut reg_map: HashMap<NonMandatoryRegisterId, Constant>,
    ) -> Result<Self, Self::Error> {
        let regs_num = reg_map.len();
        if regs_num > NonMandatoryRegisters::MAX_SIZE {
            return Err(NonMandatoryRegistersError::InvalidSize(regs_num));
        }
        // Keys are unique, so if the first `regs_num` ids are all present the map holds exactly
        // those. The map is owned, so each value is moved out instead of cloned.
        let values = NonMandatoryRegisterId::REG_IDS
            .iter()
            .take(regs_num)
            .map(|reg_id| {
                reg_map
                    .remove(reg_id)
                    .ok_or(NonMandatoryRegistersError::NonDenselyPacked(*reg_id as u8))
            })
            .collect::<Result<Vec<Constant>, _>>()?;
        Ok(NonMandatoryRegisters(values))
    }
}
#[cfg(feature = "json")]
impl TryFrom<HashMap<NonMandatoryRegisterId, ConstantHolder>> for NonMandatoryRegisters {
    type Error = NonMandatoryRegistersError;

    /// JSON deserialization path (see the `try_from` serde attribute on the struct).
    ///
    /// It unwraps every `ConstantHolder` into a `Constant`, then applies the same size and
    /// density checks as the `HashMap<NonMandatoryRegisterId, Constant>` conversion.
    fn try_from(
        value: HashMap<NonMandatoryRegisterId, ConstantHolder>,
    ) -> Result<Self, Self::Error> {
        let mut constants: HashMap<NonMandatoryRegisterId, Constant> = HashMap::new();
        for (reg_id, holder) in value {
            constants.insert(reg_id, holder.into());
        }
        Self::try_from(constants)
    }
}
impl From<NonMandatoryRegistersError> for SerializationError {
fn from(error: NonMandatoryRegistersError) -> Self {
SerializationError::Misc(error.to_string())
}
}
/// Identifier of a mandatory box register, R0 through R3.
///
/// Each variant's discriminant equals the register's numeric index.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MandatoryRegisterId {
    /// Register R0.
    R0 = 0,
    /// Register R1.
    R1 = 1,
    /// Register R2.
    R2 = 2,
    /// Register R3.
    R3 = 3,
}
impl TryFrom<i8> for MandatoryRegisterId {
    type Error = RegisterIdOutOfBounds;

    /// Converts a numeric register id (0..=3) to the mandatory register with that discriminant.
    ///
    /// # Errors
    /// Returns [`RegisterIdOutOfBounds`] carrying `value` for any other id.
    fn try_from(value: i8) -> Result<Self, Self::Error> {
        let candidates = [
            MandatoryRegisterId::R0,
            MandatoryRegisterId::R1,
            MandatoryRegisterId::R2,
            MandatoryRegisterId::R3,
        ];
        // Match against the variants' discriminants so this cannot drift from the enum.
        candidates
            .iter()
            .copied()
            .find(|reg| *reg as i8 == value)
            .ok_or(RegisterIdOutOfBounds(value))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::{arbitrary::Arbitrary, collection::vec, prelude::*};

    impl Arbitrary for NonMandatoryRegisters {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        /// Generates densely packed register sets of every valid size (0 to `NUM_REGS` values).
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            vec(any::<Constant>(), 0..=NonMandatoryRegisterId::NUM_REGS)
                .prop_map(|constants| {
                    NonMandatoryRegisters::from_ordered_values(constants)
                        .expect("error building registers")
                })
                .boxed()
        }
    }

    proptest! {

        #[test]
        fn hash_map_roundtrip(regs in any::<NonMandatoryRegisters>()) {
            let hash_map: HashMap<NonMandatoryRegisterId, Constant> = regs.clone().into();
            let regs_from_map = NonMandatoryRegisters::try_from(hash_map);
            prop_assert_eq![regs_from_map, Ok(regs)];
        }

        #[test]
        fn get(regs in any::<NonMandatoryRegisters>()) {
            let hash_map: HashMap<NonMandatoryRegisterId, Constant> = regs.clone().into();
            // Check every register id, not only the present ones, so unset registers are
            // verified to return None as well.
            NonMandatoryRegisterId::REG_IDS.iter().try_for_each(|reg_id| {
                prop_assert_eq![regs.get(*reg_id), hash_map.get(reg_id)];
                Ok(())
            })?;
        }
    }

    #[test]
    fn test_empty() {
        let regs = NonMandatoryRegisters::empty();
        assert!(regs.is_empty());
        assert_eq!(regs.len(), 0);
        assert_eq!(regs.get(NonMandatoryRegisterId::R4), None);
    }

    #[test]
    fn test_non_densely_packed_error() {
        let mut hash_map: HashMap<NonMandatoryRegisterId, Constant> = HashMap::new();
        hash_map.insert(NonMandatoryRegisterId::R4, 1i32.into());
        hash_map.insert(NonMandatoryRegisterId::R6, 1i32.into());
        // Two entries must occupy R4 and R5; R5 is the first gap.
        assert_eq!(
            NonMandatoryRegisters::try_from(hash_map),
            Err(NonMandatoryRegistersError::NonDenselyPacked(5))
        );
    }

    #[test]
    fn test_too_many_values_error() {
        let c: Constant = 1i32.into();
        let values = vec![c; NonMandatoryRegisters::MAX_SIZE + 1];
        assert_eq!(
            NonMandatoryRegisters::from_ordered_values(values),
            Err(NonMandatoryRegistersError::InvalidSize(
                NonMandatoryRegisters::MAX_SIZE + 1
            ))
        );
    }

    #[test]
    fn test_register_id_string_roundtrip() {
        for reg_id in NonMandatoryRegisterId::REG_IDS.iter() {
            let name: String = (*reg_id).into();
            assert_eq!(NonMandatoryRegisterId::try_from(name), Ok(*reg_id));
        }
        assert!(NonMandatoryRegisterId::try_from("R3".to_string()).is_err());
        assert!(NonMandatoryRegisterId::try_from("R10".to_string()).is_err());
        assert!(NonMandatoryRegisterId::try_from("x4".to_string()).is_err());
    }

    #[test]
    fn test_register_id_from_i8() {
        assert_eq!(RegisterId::try_from(0i8), Ok(RegisterId::R0));
        assert_eq!(
            RegisterId::try_from(4i8),
            Ok(RegisterId::NonMandatoryRegisterId(NonMandatoryRegisterId::R4))
        );
        assert_eq!(
            RegisterId::try_from(9i8),
            Ok(RegisterId::NonMandatoryRegisterId(NonMandatoryRegisterId::R9))
        );
        assert_eq!(RegisterId::try_from(10i8), Err(RegisterIdOutOfBounds(10)));
        assert_eq!(RegisterId::try_from(-1i8), Err(RegisterIdOutOfBounds(-1)));
    }
}