use crate::mir::constant::Constant;
use crate::serialization::sigma_byte_reader::SigmaByteRead;
use crate::serialization::sigma_byte_writer::SigmaByteWrite;
use crate::serialization::SigmaParsingError;
use crate::serialization::SigmaSerializable;
use crate::serialization::SigmaSerializationError;
use crate::serialization::SigmaSerializeResult;
use derive_more::From;
use ergo_chain_types::Base16EncodedBytes;
use std::convert::TryInto;
use std::{collections::HashMap, convert::TryFrom};
use thiserror::Error;
/// Identifier of any box register, R0-R9.
///
/// R0-R3 are represented by [`MandatoryRegisterId`], R4-R9 by [`NonMandatoryRegisterId`].
/// Either id type converts into a `RegisterId` via the derived `From` impls.
#[derive(PartialEq, Eq, Debug, Clone, Copy, From)]
pub enum RegisterId {
/// One of the mandatory registers R0-R3.
MandatoryRegisterId(MandatoryRegisterId),
/// One of the non-mandatory registers R4-R9.
NonMandatoryRegisterId(NonMandatoryRegisterId),
}
impl RegisterId {
    /// Shorthand for mandatory register R0.
    pub const R0: RegisterId = Self::MandatoryRegisterId(MandatoryRegisterId::R0);
    /// Shorthand for mandatory register R1.
    pub const R1: RegisterId = Self::MandatoryRegisterId(MandatoryRegisterId::R1);
    /// Shorthand for mandatory register R2.
    pub const R2: RegisterId = Self::MandatoryRegisterId(MandatoryRegisterId::R2);
    /// Shorthand for mandatory register R3.
    pub const R3: RegisterId = Self::MandatoryRegisterId(MandatoryRegisterId::R3);
}
/// Error returned when a numeric register id does not name a register (valid ids are 0-9).
///
/// Carries the rejected id.
#[derive(Error, PartialEq, Eq, Debug, Clone)]
#[error("register id {0} is out of bounds (0 - 9)")]
pub struct RegisterIdOutOfBounds(pub i8);
impl TryFrom<i8> for RegisterId {
type Error = RegisterIdOutOfBounds;
fn try_from(value: i8) -> Result<Self, Self::Error> {
if value < 0 {
return Err(RegisterIdOutOfBounds(value));
}
let v = value as usize;
if v < NonMandatoryRegisterId::START_INDEX {
Ok(RegisterId::MandatoryRegisterId(value.try_into()?))
} else if v <= NonMandatoryRegisterId::END_INDEX {
Ok(RegisterId::NonMandatoryRegisterId(value.try_into()?))
} else {
Err(RegisterIdOutOfBounds(value))
}
}
}
impl TryFrom<u8> for RegisterId {
type Error = RegisterIdOutOfBounds;
fn try_from(value: u8) -> Result<Self, Self::Error> {
RegisterId::try_from(value as i8)
}
}
/// Identifier of a non-mandatory box register, R4-R9.
///
/// Each discriminant equals its register number (e.g. `R4 as u8 == 4`). With the `json`
/// feature the id (de)serializes through its `String` form, `"R4"` ... `"R9"`.
#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
#[cfg_attr(feature = "json", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "json", serde(into = "String", try_from = "String"))]
#[repr(u8)]
pub enum NonMandatoryRegisterId {
/// Register R4.
R4 = 4,
/// Register R5.
R5 = 5,
/// Register R6.
R6 = 6,
/// Register R7.
R7 = 7,
/// Register R8.
R8 = 8,
/// Register R9.
R9 = 9,
}
impl NonMandatoryRegisterId {
    /// Register number of the first non-mandatory register (R4).
    pub const START_INDEX: usize = NonMandatoryRegisterId::R4 as usize;
    /// Register number of the last non-mandatory register (R9).
    pub const END_INDEX: usize = NonMandatoryRegisterId::R9 as usize;
    /// Count of non-mandatory registers, R4 through R9 inclusive.
    pub const NUM_REGS: usize = Self::END_INDEX - Self::START_INDEX + 1;
    /// Every non-mandatory register id, in ascending register order.
    pub const REG_IDS: [NonMandatoryRegisterId; NonMandatoryRegisterId::NUM_REGS] = [
        Self::R4,
        Self::R5,
        Self::R6,
        Self::R7,
        Self::R8,
        Self::R9,
    ];

    /// Returns the register at position `i`, counting from R4 (0 -> R4, ..., 5 -> R9).
    ///
    /// # Panics
    /// Panics if `i >= NUM_REGS`.
    pub fn get_by_zero_index(i: usize) -> NonMandatoryRegisterId {
        assert!(i < NonMandatoryRegisterId::NUM_REGS);
        Self::REG_IDS[i]
    }
}
impl From<NonMandatoryRegisterId> for String {
    /// Renders the register id as its name: `R` followed by the register number, e.g. `R4`.
    fn from(v: NonMandatoryRegisterId) -> Self {
        let mut name = String::from("R");
        name.push_str(&(v as u8).to_string());
        name
    }
}
impl TryFrom<String> for NonMandatoryRegisterId {
    type Error = NonMandatoryRegisterIdParsingError;

    /// Parses a register name of the form `R<digit>` (`"R4"` ... `"R9"`).
    ///
    /// # Errors
    /// Returns [`NonMandatoryRegisterIdParsingError`] unless the string is exactly `R`
    /// followed by a single digit in the range 4-9.
    fn try_from(name: String) -> Result<Self, Self::Error> {
        // `strip_prefix` matches on chars, never slicing inside a multi-byte code point.
        // The previous `&str[..1]` panicked on 2-byte non-ASCII input such as "é", and
        // this conversion is reachable from JSON deserialization.
        let index = name
            .strip_prefix('R')
            .filter(|digit| digit.len() == 1)
            .and_then(|digit| digit.parse::<usize>().ok())
            .ok_or(NonMandatoryRegisterIdParsingError())?;
        if (Self::START_INDEX..=Self::END_INDEX).contains(&index) {
            Ok(Self::get_by_zero_index(index - Self::START_INDEX))
        } else {
            Err(NonMandatoryRegisterIdParsingError())
        }
    }
}
impl TryFrom<i8> for NonMandatoryRegisterId {
    type Error = RegisterIdOutOfBounds;

    /// Converts a register number 4-9 into the matching non-mandatory register id; any
    /// other value (including negative ones) yields [`RegisterIdOutOfBounds`].
    fn try_from(value: i8) -> Result<Self, Self::Error> {
        // Negative ids sign-extend to very large `usize` values and fail the upper-bound guard.
        let index = value as usize;
        match index.checked_sub(NonMandatoryRegisterId::START_INDEX) {
            Some(offset) if index <= NonMandatoryRegisterId::END_INDEX => {
                Ok(Self::get_by_zero_index(offset))
            }
            _ => Err(RegisterIdOutOfBounds(value)),
        }
    }
}
/// Error returned when a string is not a valid non-mandatory register name (`"R4"` ... `"R9"`).
#[derive(Error, PartialEq, Eq, Debug, Clone)]
#[error("failed to parse register id")]
pub struct NonMandatoryRegisterIdParsingError();
/// The non-mandatory registers (R4-R9) of a box.
///
/// Values are stored densely packed in register order: position 0 holds R4, position 1
/// holds R5, and so on. A set of `n` registers therefore always occupies R4 through
/// R(4 + n - 1). The conversions in this module enforce this and cap the count at
/// `NonMandatoryRegisters::MAX_SIZE`.
///
/// With the `json` feature, registers serialize as a map from register id to the
/// Base16-encoded serialized value. They deserialize from a map of register id to
/// `ConstantHolder`.
#[derive(PartialEq, Eq, Debug, Clone)]
#[cfg_attr(feature = "json", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
feature = "json",
serde(
into = "HashMap<NonMandatoryRegisterId, ergo_chain_types::Base16EncodedBytes>",
try_from = "HashMap<NonMandatoryRegisterId, crate::chain::json::ergo_box::ConstantHolder>"
)
)]
pub struct NonMandatoryRegisters(Vec<RegisterValue>);
impl NonMandatoryRegisters {
    /// Maximum number of non-mandatory registers (one per id R4-R9).
    pub const MAX_SIZE: usize = NonMandatoryRegisterId::NUM_REGS;

    /// Returns a register set with no registers.
    pub fn empty() -> NonMandatoryRegisters {
        NonMandatoryRegisters(vec![])
    }

    /// Builds registers from a map of register id to constant.
    ///
    /// # Errors
    /// Returns [`NonMandatoryRegistersError::InvalidSize`] if more than [`Self::MAX_SIZE`]
    /// registers are given, or [`NonMandatoryRegistersError::NonDenselyPacked`] if the ids
    /// do not form a contiguous run starting at R4.
    pub fn new(
        regs: HashMap<NonMandatoryRegisterId, Constant>,
    ) -> Result<NonMandatoryRegisters, NonMandatoryRegistersError> {
        NonMandatoryRegisters::try_from(
            regs.into_iter()
                .map(|(k, v)| (k, v.into()))
                .collect::<HashMap<NonMandatoryRegisterId, RegisterValue>>(),
        )
    }

    /// Number of registers that are set.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` if no register is set.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Returns the value stored in register `reg_id`, or `None` if it is not set.
    pub fn get(&self, reg_id: NonMandatoryRegisterId) -> Option<&RegisterValue> {
        // Storage is packed from R4, so the register number must be shifted down to a
        // vector position. Indexing by the raw number returned R(n + 4) or nothing.
        self.0.get(Self::position(reg_id))
    }

    /// Returns the constant stored in register `reg_id`, or `None` if the register is not
    /// set or holds an unparseable value.
    pub fn get_constant(&self, reg_id: NonMandatoryRegisterId) -> Option<&Constant> {
        self.get(reg_id).and_then(|rv| rv.as_option_constant())
    }

    /// Position of `reg_id` in the densely packed storage (R4 -> 0, ..., R9 -> 5).
    fn position(reg_id: NonMandatoryRegisterId) -> usize {
        reg_id as usize - NonMandatoryRegisterId::START_INDEX
    }
}
/// Value held in a non-mandatory register.
///
/// Converts from either payload type via the derived `From` impls.
#[derive(PartialEq, Eq, Debug, Clone, From)]
pub enum RegisterValue {
/// A constant held in parsed form.
Parsed(Constant),
/// Raw register bytes kept verbatim instead of a parsed constant. Such values are rejected
/// with an error by `NonMandatoryRegisters::sigma_serialize`.
Unparseable(Vec<u8>),
}
impl RegisterValue {
    /// Returns the contained constant for a parsed value, or `None` for raw bytes.
    pub fn as_option_constant(&self) -> Option<&Constant> {
        match self {
            RegisterValue::Unparseable(_) => None,
            RegisterValue::Parsed(constant) => Some(constant),
        }
    }

    /// Returns the serialized form of the value. This is the sigma-serialized constant for a
    /// parsed value, or a copy of the stored bytes for an unparseable one.
    ///
    /// # Panics
    /// Panics if sigma-serializing a parsed constant fails.
    // NOTE(review): the unwrap presumes serializing an in-memory constant cannot fail — confirm.
    #[allow(clippy::unwrap_used)]
    fn sigma_serialize_bytes(&self) -> Vec<u8> {
        match self {
            RegisterValue::Unparseable(raw) => raw.to_vec(),
            RegisterValue::Parsed(constant) => constant.sigma_serialize_bytes().unwrap(),
        }
    }
}
impl TryFrom<Vec<RegisterValue>> for NonMandatoryRegisters {
    type Error = NonMandatoryRegistersError;

    /// Wraps already densely packed register values (position 0 = R4).
    ///
    /// # Errors
    /// Returns [`NonMandatoryRegistersError::InvalidSize`] when more than
    /// [`NonMandatoryRegisters::MAX_SIZE`] values are supplied.
    fn try_from(values: Vec<RegisterValue>) -> Result<Self, Self::Error> {
        match values.len() {
            count if count > NonMandatoryRegisters::MAX_SIZE => {
                Err(NonMandatoryRegistersError::InvalidSize(count))
            }
            _ => Ok(NonMandatoryRegisters(values)),
        }
    }
}
impl TryFrom<Vec<Constant>> for NonMandatoryRegisters {
    type Error = NonMandatoryRegistersError;

    /// Wraps densely packed constants (position 0 = R4) as parsed register values.
    ///
    /// # Errors
    /// Returns [`NonMandatoryRegistersError::InvalidSize`] when more than
    /// [`NonMandatoryRegisters::MAX_SIZE`] constants are supplied.
    fn try_from(values: Vec<Constant>) -> Result<Self, Self::Error> {
        let register_values: Vec<RegisterValue> =
            values.into_iter().map(RegisterValue::from).collect();
        Self::try_from(register_values)
    }
}
impl SigmaSerializable for NonMandatoryRegisters {
fn sigma_serialize<W: SigmaByteWrite>(&self, w: &mut W) -> SigmaSerializeResult {
let regs_num = self.len();
w.put_u8(regs_num as u8)?;
for reg_value in self.0.iter() {
match reg_value {
RegisterValue::Parsed(c) => c.sigma_serialize(w)?,
RegisterValue::Unparseable(_) => {
return Err(SigmaSerializationError::NotSupported(
"unparseable register value cannot be serialized, because it cannot be parsed later"
))
}
};
}
Ok(())
}
fn sigma_parse<R: SigmaByteRead>(r: &mut R) -> Result<Self, SigmaParsingError> {
let regs_num = r.get_u8()?;
let mut additional_regs = Vec::with_capacity(regs_num as usize);
for _ in 0..regs_num {
let v = Constant::sigma_parse(r)?;
additional_regs.push(v);
}
Ok(additional_regs.try_into()?)
}
}
/// Error raised when building [`NonMandatoryRegisters`] from invalid input.
#[derive(Error, PartialEq, Eq, Clone, Debug)]
pub enum NonMandatoryRegistersError {
/// More than `NonMandatoryRegisters::MAX_SIZE` registers were supplied; carries the
/// supplied count.
#[error("invalid non-mandatory registers size ({0})")]
InvalidSize(usize),
/// The supplied register ids do not form a contiguous run starting at R4; carries the
/// number of the first missing register.
#[error("registers are not densely packed (register R{0} is missing)")]
NonDenselyPacked(u8),
}
impl From<NonMandatoryRegisters>
    for HashMap<NonMandatoryRegisterId, ergo_chain_types::Base16EncodedBytes>
{
    /// Converts registers into a map from register id to the Base16-encoded serialized
    /// value. This is the form used for JSON serialization.
    ///
    /// # Panics
    /// Panics if sigma-serializing a parsed constant fails.
    fn from(v: NonMandatoryRegisters) -> Self {
        v.0.into_iter()
            .enumerate()
            .map(|(i, reg_value)| {
                (
                    NonMandatoryRegisterId::get_by_zero_index(i),
                    // `&reg_value` had been mis-encoded as `®_value`, which does not compile.
                    Base16EncodedBytes::new(&reg_value.sigma_serialize_bytes()),
                )
            })
            .collect()
    }
}
impl From<NonMandatoryRegisters> for HashMap<NonMandatoryRegisterId, RegisterValue> {
    /// Converts the densely packed registers into a map keyed by register id
    /// (position 0 becomes R4, position 1 becomes R5, ...).
    fn from(v: NonMandatoryRegisters) -> Self {
        let mut by_id = HashMap::with_capacity(v.len());
        for (position, reg_val) in v.0.into_iter().enumerate() {
            by_id.insert(NonMandatoryRegisterId::get_by_zero_index(position), reg_val);
        }
        by_id
    }
}
impl TryFrom<HashMap<NonMandatoryRegisterId, RegisterValue>> for NonMandatoryRegisters {
    type Error = NonMandatoryRegistersError;

    /// Builds registers from a map keyed by register id.
    ///
    /// The map may hold at most [`NonMandatoryRegisters::MAX_SIZE`] entries, and its keys
    /// must form a contiguous run starting at R4 (e.g. R4, R5, R6).
    ///
    /// # Errors
    /// * [`NonMandatoryRegistersError::InvalidSize`] if the map has too many entries.
    /// * [`NonMandatoryRegistersError::NonDenselyPacked`], naming the first missing register,
    ///   if the keys leave a gap.
    fn try_from(
        mut reg_map: HashMap<NonMandatoryRegisterId, RegisterValue>,
    ) -> Result<Self, Self::Error> {
        let regs_num = reg_map.len();
        if regs_num > NonMandatoryRegisters::MAX_SIZE {
            return Err(NonMandatoryRegistersError::InvalidSize(regs_num));
        }
        // A map with `regs_num` distinct keys is densely packed exactly when
        // R4..R(4 + regs_num - 1) are all present. The map is owned, so each value is moved
        // out instead of cloned.
        let packed = NonMandatoryRegisterId::REG_IDS
            .iter()
            .take(regs_num)
            .map(|reg_id| {
                reg_map
                    .remove(reg_id)
                    .ok_or(NonMandatoryRegistersError::NonDenselyPacked(*reg_id as u8))
            })
            .collect::<Result<Vec<RegisterValue>, _>>()?;
        Ok(NonMandatoryRegisters(packed))
    }
}
#[cfg(feature = "json")]
impl TryFrom<HashMap<NonMandatoryRegisterId, crate::chain::json::ergo_box::ConstantHolder>>
    for NonMandatoryRegisters
{
    type Error = NonMandatoryRegistersError;

    /// Converts the JSON register map (values held as `ConstantHolder`) into registers.
    /// It applies the same size and dense-packing checks as the `RegisterValue` map
    /// conversion.
    fn try_from(
        value: HashMap<NonMandatoryRegisterId, crate::chain::json::ergo_box::ConstantHolder>,
    ) -> Result<Self, Self::Error> {
        let mut register_values: HashMap<NonMandatoryRegisterId, RegisterValue> =
            HashMap::with_capacity(value.len());
        for (reg_id, holder) in value {
            register_values.insert(reg_id, holder.into());
        }
        NonMandatoryRegisters::try_from(register_values)
    }
}
impl From<NonMandatoryRegistersError> for SigmaParsingError {
fn from(error: NonMandatoryRegistersError) -> Self {
SigmaParsingError::Misc(error.to_string())
}
}
/// Identifier of a mandatory box register, R0-R3.
///
/// Each discriminant equals its register number (e.g. `R2 as i8 == 2`).
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub enum MandatoryRegisterId {
/// Register R0.
R0 = 0,
/// Register R1.
R1 = 1,
/// Register R2.
R2 = 2,
/// Register R3.
R3 = 3,
}
impl TryFrom<i8> for MandatoryRegisterId {
    type Error = RegisterIdOutOfBounds;

    /// Converts a register number 0-3 into the matching mandatory register id; any other
    /// value yields [`RegisterIdOutOfBounds`].
    fn try_from(value: i8) -> Result<Self, Self::Error> {
        const ALL: [MandatoryRegisterId; 4] = [
            MandatoryRegisterId::R0,
            MandatoryRegisterId::R1,
            MandatoryRegisterId::R2,
            MandatoryRegisterId::R3,
        ];
        ALL.iter()
            .copied()
            .find(|candidate| *candidate as i8 == value)
            .ok_or(RegisterIdOutOfBounds(value))
    }
}
#[allow(clippy::unwrap_used)]
#[cfg(feature = "arbitrary")]
pub(crate) mod arbitrary {
    use super::*;
    use proptest::{arbitrary::Arbitrary, collection::vec, prelude::*};

    /// Parameters controlling how arbitrary [`NonMandatoryRegisters`] are generated.
    #[derive(Default)]
    pub struct ArbNonMandatoryRegistersParams {
        /// When `true`, register values may also be raw byte vectors (0-99 bytes) held as
        /// `RegisterValue::Unparseable`; otherwise every value is a parsed constant.
        pub allow_unparseable: bool,
    }

    impl Arbitrary for NonMandatoryRegisters {
        type Parameters = ArbNonMandatoryRegistersParams;
        type Strategy = BoxedStrategy<Self>;

        /// Generates between 0 and `NUM_REGS` densely packed register values.
        fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
            let value_strategy = if params.allow_unparseable {
                prop_oneof![
                    any::<Constant>().prop_map(RegisterValue::Parsed),
                    vec(any::<u8>(), 0..100).prop_map(RegisterValue::Unparseable)
                ]
                .boxed()
            } else {
                any::<Constant>().prop_map(RegisterValue::Parsed).boxed()
            };
            // At most NUM_REGS values are produced, so the conversion below cannot fail.
            let register_count = 0..=NonMandatoryRegisterId::NUM_REGS;
            vec(value_strategy, register_count)
                .prop_map(|values| NonMandatoryRegisters::try_from(values).unwrap())
                .boxed()
        }
    }
}
#[allow(clippy::panic)]
#[allow(clippy::unwrap_used)]
#[allow(clippy::expect_used)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::serialization::sigma_serialize_roundtrip;
    use proptest::prelude::*;

    proptest! {
        // Converting to a map and back must reproduce the original registers.
        #[test]
        fn hash_map_roundtrip(regs in any::<NonMandatoryRegisters>()) {
            let hash_map: HashMap<NonMandatoryRegisterId, RegisterValue> = regs.clone().into();
            let regs_from_map = NonMandatoryRegisters::try_from(hash_map);
            prop_assert![regs_from_map.is_ok()];
            prop_assert_eq![regs_from_map.unwrap(), regs];
        }

        // `get_constant` must agree with a map lookup for every set register.
        #[test]
        fn get(regs in any::<NonMandatoryRegisters>()) {
            let hash_map: HashMap<NonMandatoryRegisterId, RegisterValue> = regs.clone().into();
            hash_map.keys().try_for_each(|reg_id| {
                prop_assert_eq![regs.get_constant(*reg_id), hash_map.get(reg_id).unwrap().as_option_constant()];
                Ok(())
            })?;
        }

        // Every id 0..=9 (inclusive, so R9 is covered) names a register.
        #[test]
        fn reg_id_from_byte(reg_id_byte in 0i8..=NonMandatoryRegisterId::END_INDEX as i8) {
            prop_assert!(RegisterId::try_from(reg_id_byte).is_ok());
        }

        #[test]
        fn ser_roundtrip(regs in any::<NonMandatoryRegisters>()) {
            // `&regs` had been mis-encoded as `®s`, which does not compile.
            prop_assert_eq![sigma_serialize_roundtrip(&regs), regs];
        }
    }

    #[test]
    fn test_empty() {
        assert!(NonMandatoryRegisters::empty().is_empty());
    }

    #[test]
    fn test_reg_id_out_of_bounds() {
        assert_eq!(RegisterId::try_from(10i8), Err(RegisterIdOutOfBounds(10)));
        assert_eq!(RegisterId::try_from(-1i8), Err(RegisterIdOutOfBounds(-1)));
    }

    #[test]
    fn test_non_densely_packed_error() {
        let mut hash_map: HashMap<NonMandatoryRegisterId, RegisterValue> = HashMap::new();
        let c: Constant = 1i32.into();
        hash_map.insert(NonMandatoryRegisterId::R4, c.clone().into());
        hash_map.insert(NonMandatoryRegisterId::R6, c.into());
        // Two entries must occupy R4 and R5; R5 is the first gap.
        assert_eq!(
            NonMandatoryRegisters::try_from(hash_map),
            Err(NonMandatoryRegistersError::NonDenselyPacked(5))
        );
    }

    #[test]
    fn test_too_many_registers_error() {
        let c: Constant = 1i32.into();
        let values = vec![c; NonMandatoryRegisters::MAX_SIZE + 1];
        assert_eq!(
            NonMandatoryRegisters::try_from(values),
            Err(NonMandatoryRegistersError::InvalidSize(NonMandatoryRegisters::MAX_SIZE + 1))
        );
    }
}