use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::fmt::Debug;
use std::hash::Hash;
use std::io;
use std::io::{Read, Write};
use std::ops::{
Deref, Range, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive,
};
use amplify::num::u24;
use crate::{Error, StrictDecode, StrictEncode};
impl<T> StrictEncode for Range<T>
where
    T: StrictEncode,
{
    /// Writes the `start` bound followed by the (exclusive) `end` bound and
    /// returns the total number of bytes produced.
    fn strict_encode<E: Write>(&self, mut e: E) -> Result<usize, Error> {
        let start_len = self.start.strict_encode(&mut e)?;
        let end_len = self.end.strict_encode(&mut e)?;
        Ok(start_len + end_len)
    }
}
impl<T> StrictDecode for Range<T>
where
    T: StrictDecode,
{
    /// Reads the `start` bound and then the `end` bound, in the same order
    /// they are written by the encoder.
    fn strict_decode<D: Read>(mut d: D) -> Result<Self, Error> {
        let start = T::strict_decode(&mut d)?;
        let end = T::strict_decode(&mut d)?;
        Ok(start..end)
    }
}
impl<T> StrictEncode for RangeInclusive<T>
where
    T: StrictEncode,
{
    /// Writes the lower bound followed by the (inclusive) upper bound and
    /// returns the total number of bytes produced.
    fn strict_encode<E: Write>(&self, mut e: E) -> Result<usize, Error> {
        let (lower, upper) = (self.start(), self.end());
        let lower_len = lower.strict_encode(&mut e)?;
        let upper_len = upper.strict_encode(&mut e)?;
        Ok(lower_len + upper_len)
    }
}
impl<T> StrictDecode for RangeInclusive<T>
where
    T: StrictDecode,
{
    /// Reads the lower bound and then the upper bound and rebuilds the
    /// inclusive range `lower..=upper`.
    fn strict_decode<D: Read>(mut d: D) -> Result<Self, Error> {
        let lower = T::strict_decode(&mut d)?;
        let upper = T::strict_decode(&mut d)?;
        Ok(lower..=upper)
    }
}
impl<T> StrictEncode for RangeFrom<T>
where
    T: StrictEncode,
{
    /// An unbounded-above range is represented solely by its `start` bound.
    #[inline]
    fn strict_encode<E: Write>(&self, mut e: E) -> Result<usize, Error> {
        let RangeFrom { start } = self;
        start.strict_encode(&mut e)
    }
}
impl<T> StrictDecode for RangeFrom<T>
where
    T: StrictDecode,
{
    /// Reads a single `start` bound and produces `start..`.
    #[inline]
    fn strict_decode<D: Read>(mut d: D) -> Result<Self, Error> {
        let start = T::strict_decode(&mut d)?;
        Ok(start..)
    }
}
impl<T> StrictEncode for RangeTo<T>
where
    T: StrictEncode,
{
    /// An unbounded-below range is represented solely by its exclusive `end`.
    #[inline]
    fn strict_encode<E: Write>(&self, mut e: E) -> Result<usize, Error> {
        let RangeTo { end } = self;
        end.strict_encode(&mut e)
    }
}
impl<T> StrictDecode for RangeTo<T>
where
    T: StrictDecode,
{
    /// Reads a single exclusive `end` bound and produces `..end`.
    #[inline]
    fn strict_decode<D: Read>(mut d: D) -> Result<Self, Error> {
        let end = T::strict_decode(&mut d)?;
        Ok(..end)
    }
}
impl<T> StrictEncode for RangeToInclusive<T>
where
    T: StrictEncode,
{
    /// Represented solely by its inclusive `end` bound.
    #[inline]
    fn strict_encode<E: Write>(&self, mut e: E) -> Result<usize, Error> {
        let RangeToInclusive { end } = self;
        end.strict_encode(&mut e)
    }
}
impl<T> StrictDecode for RangeToInclusive<T>
where
    T: StrictDecode,
{
    /// Reads a single inclusive `end` bound and produces `..=end`.
    #[inline]
    fn strict_decode<D: Read>(mut d: D) -> Result<Self, Error> {
        let end = T::strict_decode(&mut d)?;
        Ok(..=end)
    }
}
impl<T> StrictEncode for Option<T>
where
    T: StrictEncode,
{
    /// Encodes a one-byte tag (`1` for `Some`, `0` for `None`), followed by
    /// the wrapped value when present.
    fn strict_encode<E: io::Write>(&self, mut e: E) -> Result<usize, Error> {
        let written = if let Some(val) = self {
            strict_encode_list!(e; 1u8, val)
        } else {
            strict_encode_list!(e; 0u8)
        };
        Ok(written)
    }
}
impl<T> StrictDecode for Option<T>
where
T: StrictDecode,
{
fn strict_decode<D: io::Read>(mut d: D) -> Result<Self, Error> {
let len = u8::strict_decode(&mut d)?;
match len {
0 => Ok(None),
1 => Ok(Some(T::strict_decode(&mut d)?)),
invalid => Err(Error::WrongOptionalEncoding(invalid)),
}
}
}
impl<T> StrictEncode for [T]
where
    T: StrictEncode,
{
    /// Writes the element count (as a `usize`) followed by every element in
    /// slice order; returns the total number of bytes written.
    fn strict_encode<E: io::Write>(&self, mut e: E) -> Result<usize, Error> {
        let header = self.len().strict_encode(&mut e)?;
        self.iter().try_fold(header, |total, item| {
            Ok(total + item.strict_encode(&mut e)?)
        })
    }
}
/// Vector whose strict encoding uses a `u32` item-count prefix instead of
/// the `usize` prefix used by plain `Vec<T>`.
///
/// Encoding fails with `Error::ExceedMaxItems` if it holds more than
/// `u32::MAX` items. The inner `Vec` is private, so it can only be built
/// through the checked constructors and mutators of this type.
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(transparent)
)]
pub struct LargeVec<T>(Vec<T>)
where
T: StrictEncode + StrictDecode;
impl<T> Default for LargeVec<T>
where
T: StrictEncode + StrictDecode,
{
fn default() -> Self { Self(vec![]) }
}
impl<T> Deref for LargeVec<T>
where
T: StrictEncode + StrictDecode,
{
type Target = Vec<T>;
fn deref(&self) -> &Self::Target { &self.0 }
}
impl<T> TryFrom<Vec<T>> for LargeVec<T>
where
T: StrictEncode + StrictDecode,
{
type Error = Error;
fn try_from(value: Vec<T>) -> Result<Self, Self::Error> {
let len = value.len();
if len > u32::MAX as usize {
return Err(Error::ExceedMaxItems(len));
}
Ok(Self(value))
}
}
impl<'me, T> IntoIterator for &'me LargeVec<T>
where
    T: StrictEncode + StrictDecode,
{
    type Item = &'me T;
    type IntoIter = std::slice::Iter<'me, T>;

    /// Iterates over shared references to the items, in order.
    fn into_iter(self) -> Self::IntoIter { self.0.as_slice().iter() }
}
impl<T> IntoIterator for LargeVec<T>
where
T: StrictEncode + StrictDecode,
{
type Item = T;
type IntoIter = std::vec::IntoIter<T>;
fn into_iter(self) -> Self::IntoIter { self.0.into_iter() }
}
impl<T> StrictEncode for LargeVec<T>
where
T: StrictEncode + StrictDecode,
{
fn strict_encode<E: io::Write>(&self, mut e: E) -> Result<usize, Error> {
let len = self.0.len();
if len > u32::MAX as usize {
return Err(Error::ExceedMaxItems(len));
}
let mut count = (len as u32).strict_encode(&mut e)?;
for el in &self.0 {
count += el.strict_encode(&mut e)?;
}
Ok(count)
}
}
impl<T> StrictDecode for LargeVec<T>
where
    T: StrictDecode + StrictEncode,
{
    /// Reads a `u32` item count followed by that many items.
    ///
    /// The count comes straight from the (possibly untrusted) input, before
    /// any item has been read. Preallocating `count` slots would let a 4-byte
    /// input request up to `u32::MAX * size_of::<T>()` bytes and abort the
    /// process. So the initial capacity is capped, and the vector grows
    /// normally only as items are actually decoded.
    fn strict_decode<D: io::Read>(mut d: D) -> Result<Self, Error> {
        const MAX_PREALLOC: usize = 0x1000;
        let len = u32::strict_decode(&mut d)?;
        let mut data =
            Vec::<T>::with_capacity((len as usize).min(MAX_PREALLOC));
        for _ in 0..len {
            data.push(T::strict_decode(&mut d)?);
        }
        Ok(Self(data))
    }
}
impl<T> LargeVec<T>
where
T: StrictEncode + StrictDecode,
{
pub fn new() -> Self { Self(vec![]) }
pub fn len_u32(&self) -> u32 { self.0.len() as u32 }
pub fn iter_mut(&mut self) -> std::slice::IterMut<T> { self.0.iter_mut() }
pub fn push(&mut self, item: T) -> Result<usize, Error> {
let len = self.0.len();
if len > u32::MAX as usize {
return Err(Error::ExceedMaxItems(len));
}
self.0.push(item);
Ok(len)
}
pub fn remove(&mut self, index: usize) -> T { self.0.remove(index) }
}
/// Vector whose strict encoding uses a 24-bit (`u24`) item-count prefix.
///
/// Encoding fails with `Error::ExceedMaxItems` if it holds more than
/// `u24::MAX` items. The inner `Vec` is private, so it can only be built
/// through the checked constructors and mutators of this type.
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(transparent)
)]
pub struct MediumVec<T>(Vec<T>)
where
T: StrictEncode + StrictDecode;
impl<T> Default for MediumVec<T>
where
T: StrictEncode + StrictDecode,
{
fn default() -> Self { Self(vec![]) }
}
impl<T> Deref for MediumVec<T>
where
T: StrictEncode + StrictDecode,
{
type Target = Vec<T>;
fn deref(&self) -> &Self::Target { &self.0 }
}
impl<T> TryFrom<Vec<T>> for MediumVec<T>
where
T: StrictEncode + StrictDecode,
{
type Error = Error;
fn try_from(value: Vec<T>) -> Result<Self, Self::Error> {
let len = value.len();
if len > u32::MAX as usize {
return Err(Error::ExceedMaxItems(len));
}
Ok(Self(value))
}
}
impl<'me, T> IntoIterator for &'me MediumVec<T>
where
    T: StrictEncode + StrictDecode,
{
    type Item = &'me T;
    type IntoIter = std::slice::Iter<'me, T>;

    /// Iterates over shared references to the items, in order.
    fn into_iter(self) -> Self::IntoIter { self.0.as_slice().iter() }
}
impl<T> IntoIterator for MediumVec<T>
where
T: StrictEncode + StrictDecode,
{
type Item = T;
type IntoIter = std::vec::IntoIter<T>;
fn into_iter(self) -> Self::IntoIter { self.0.into_iter() }
}
impl<T> StrictEncode for MediumVec<T>
where
T: StrictEncode + StrictDecode,
{
fn strict_encode<E: io::Write>(&self, mut e: E) -> Result<usize, Error> {
let len = self.0.len();
if len > u24::MAX.as_u32() as usize {
return Err(Error::ExceedMaxItems(len));
}
let mut count = u24::try_from(len as u32)
.expect("u32 Cmp is broken")
.strict_encode(&mut e)?;
for el in &self.0 {
count += el.strict_encode(&mut e)?;
}
Ok(count)
}
}
impl<T> StrictDecode for MediumVec<T>
where
    T: StrictDecode + StrictEncode,
{
    /// Reads a 24-bit item count followed by that many items.
    ///
    /// The count is untrusted input read before any item is validated.
    /// Preallocating it directly would let a 3-byte input request up to
    /// `u24::MAX * size_of::<T>()` bytes. So the initial capacity is capped,
    /// and the vector grows only as items are actually decoded.
    fn strict_decode<D: io::Read>(mut d: D) -> Result<Self, Error> {
        const MAX_PREALLOC: usize = 0x1000;
        let len = u24::strict_decode(&mut d)?.as_u32() as usize;
        let mut data = Vec::<T>::with_capacity(len.min(MAX_PREALLOC));
        for _ in 0..len {
            data.push(T::strict_decode(&mut d)?);
        }
        Ok(Self(data))
    }
}
impl<T> MediumVec<T>
where
    T: StrictEncode + StrictDecode,
{
    /// Creates an empty vector.
    pub fn new() -> Self { Self(vec![]) }

    /// Returns the number of items as a `u24`.
    ///
    /// # Panics
    ///
    /// Only if the "at most `u24::MAX` items" invariant maintained by this
    /// type's constructors and `push` has been broken.
    pub fn len_u24(&self) -> u24 {
        u24::try_from(self.0.len() as u32)
            .expect("MediumVec inner size guarantees are broken")
    }

    /// Returns an iterator over mutable references to the items.
    pub fn iter_mut(&mut self) -> std::slice::IterMut<T> { self.0.iter_mut() }

    /// Appends `item` and returns the index it was stored at.
    ///
    /// # Errors
    ///
    /// `Error::ExceedMaxItems` if the vector already holds `u24::MAX` items,
    /// since one more item would make it unencodable (and `len_u24` panic).
    pub fn push(&mut self, item: T) -> Result<usize, Error> {
        let len = self.0.len();
        // `>=` rather than `>`: after the push the length is `len + 1`,
        // and that new length must still fit in a `u24`.
        if len >= u24::MAX.as_u32() as usize {
            return Err(Error::ExceedMaxItems(len));
        }
        self.0.push(item);
        Ok(len)
    }

    /// Removes and returns the item at `index`, shifting later items left.
    ///
    /// # Panics
    ///
    /// If `index` is out of bounds (as `Vec::remove`).
    pub fn remove(&mut self, index: usize) -> T { self.0.remove(index) }
}
impl<T> StrictEncode for Vec<T>
where
    T: StrictEncode,
{
    /// Same wire format as the slice encoding: delegates to `[T]`.
    fn strict_encode<E: io::Write>(&self, e: E) -> Result<usize, Error> {
        <[T] as StrictEncode>::strict_encode(self, e)
    }
}
impl<T> StrictDecode for Vec<T>
where
    T: StrictDecode,
{
    /// Reads a `usize` element count, then that many elements in order.
    fn strict_decode<D: io::Read>(mut d: D) -> Result<Self, Error> {
        let count = usize::strict_decode(&mut d)?;
        let mut items = Vec::with_capacity(count);
        while items.len() < count {
            let next = T::strict_decode(&mut d)?;
            items.push(next);
        }
        Ok(items)
    }
}
impl<T> StrictEncode for HashSet<T>
where
    T: StrictEncode + Eq + Ord + Hash + Debug,
{
    /// Writes the element count followed by the elements in ascending `Ord`
    /// order. Hash iteration order is unspecified, so the elements are sorted
    /// to make the encoding deterministic.
    fn strict_encode<E: io::Write>(&self, mut e: E) -> Result<usize, Error> {
        let mut ordered = self.iter().collect::<Vec<&T>>();
        ordered.sort();
        let header = self.len().strict_encode(&mut e)?;
        ordered.into_iter().try_fold(header, |acc, item| {
            Ok(acc + item.strict_encode(&mut e)?)
        })
    }
}
impl<T> StrictDecode for HashSet<T>
where
T: StrictDecode + Eq + Ord + Hash + Debug,
{
fn strict_decode<D: io::Read>(mut d: D) -> Result<Self, Error> {
let len = usize::strict_decode(&mut d)?;
let mut data = HashSet::<T>::with_capacity(len);
for _ in 0..len {
let val = T::strict_decode(&mut d)?;
if data.contains(&val) {
return Err(Error::RepeatedValue(format!("{:?}", val)));
} else {
data.insert(val);
}
}
Ok(data)
}
}
impl<T> StrictEncode for BTreeSet<T>
where
    T: StrictEncode + Eq + Ord + Debug,
{
    /// Writes the element count followed by the elements in ascending `Ord`
    /// order.
    ///
    /// A `BTreeSet` already iterates in ascending order, so the elements are
    /// streamed directly. This produces the same bytes as before without
    /// collecting them into a temporary `Vec` and re-sorting it.
    fn strict_encode<E: io::Write>(&self, mut e: E) -> Result<usize, Error> {
        let mut encoded = self.len().strict_encode(&mut e)?;
        for item in self {
            encoded += item.strict_encode(&mut e)?;
        }
        Ok(encoded)
    }
}
impl<T> StrictDecode for BTreeSet<T>
where
T: StrictDecode + Eq + Ord + Debug,
{
fn strict_decode<D: io::Read>(mut d: D) -> Result<Self, Error> {
let len = usize::strict_decode(&mut d)?;
let mut data = BTreeSet::<T>::new();
for _ in 0..len {
let val = T::strict_decode(&mut d)?;
if let Some(max) = data.iter().max() {
if max > &val {
return Err(Error::DataIntegrityError(format!(
"encoded values are not deterministically ordered: \
value `{:?}` should go before `{:?}`",
val, max
)));
}
}
if data.contains(&val) {
return Err(Error::RepeatedValue(format!("{:?}", val)));
}
data.insert(val);
}
Ok(data)
}
}
impl<T> StrictEncode for HashMap<usize, T>
where
    T: StrictEncode + Clone,
{
    /// Writes the entry count followed by each `(key, value)` pair in
    /// ascending key order. This is the same byte layout as
    /// `BTreeMap<usize, T>`, so the encoding is deterministic despite hash
    /// ordering.
    ///
    /// The entries are sorted by borrowing them, not by cloning every value
    /// into a temporary `BTreeMap`.
    fn strict_encode<E: io::Write>(&self, mut e: E) -> Result<usize, Error> {
        let mut entries: Vec<(&usize, &T)> = self.iter().collect();
        // Keys are unique in a map, so an unstable sort is fully determined.
        entries.sort_unstable_by_key(|(key, _)| **key);
        let mut encoded = self.len().strict_encode(&mut e)?;
        for (key, val) in entries {
            encoded += key.strict_encode(&mut e)?;
            encoded += val.strict_encode(&mut e)?;
        }
        Ok(encoded)
    }
}
impl<T> StrictDecode for HashMap<usize, T>
where
    T: StrictDecode + Clone,
{
    /// Decodes via `BTreeMap<usize, T>`, so the same ordering and
    /// duplicate-key checks apply, then moves the entries into a `HashMap`.
    ///
    /// `into_iter` moves the decoded values instead of cloning each one out
    /// of a temporary map that is immediately dropped.
    fn strict_decode<D: io::Read>(mut d: D) -> Result<Self, Error> {
        let ordered = BTreeMap::<usize, T>::strict_decode(&mut d)?;
        Ok(ordered.into_iter().collect())
    }
}
impl<K, V> StrictEncode for BTreeMap<K, V>
where
    K: StrictEncode + Ord + Clone,
    V: StrictEncode + Clone,
{
    /// Writes the entry count, then each key immediately followed by its
    /// value, in the map's ascending key order.
    fn strict_encode<E: io::Write>(&self, mut e: E) -> Result<usize, Error> {
        let mut written = self.len().strict_encode(&mut e)?;
        for (key, val) in self {
            written += key.strict_encode(&mut e)?;
            written += val.strict_encode(&mut e)?;
        }
        Ok(written)
    }
}
impl<K, V> StrictDecode for BTreeMap<K, V>
where
    K: StrictDecode + Ord + Clone + Debug,
    V: StrictDecode + Clone,
{
    /// Reads a `usize` entry count, then that many key/value pairs, whose
    /// keys must arrive in strictly ascending order.
    ///
    /// A key smaller than one already read is rejected with
    /// `Error::DataIntegrityError`. A key equal to one already read is
    /// rejected with `Error::RepeatedValue`. Each pair's value is read before
    /// the key is checked.
    fn strict_decode<D: io::Read>(mut d: D) -> Result<Self, Error> {
        let count = usize::strict_decode(&mut d)?;
        let mut entries = BTreeMap::<K, V>::new();
        for _ in 0..count {
            let key = K::strict_decode(&mut d)?;
            let val = V::strict_decode(&mut d)?;
            // Keys iterate in ascending order; the last one is the largest.
            let largest = entries.keys().next_back();
            if let Some(prev) = largest.filter(|prev| *prev > &key) {
                return Err(Error::DataIntegrityError(format!(
                    "encoded values are not deterministically ordered: \
                     value `{:?}` should go before `{:?}`",
                    key, prev
                )));
            }
            if entries.contains_key(&key) {
                return Err(Error::RepeatedValue(format!("{:?}", key)));
            }
            entries.insert(key, val);
        }
        Ok(entries)
    }
}
impl<K, V> StrictEncode for (K, V)
where
    K: StrictEncode + Clone,
    V: StrictEncode + Clone,
{
    /// Writes the first element followed by the second.
    fn strict_encode<E: io::Write>(&self, mut e: E) -> Result<usize, Error> {
        let (first, second) = self;
        let mut written = first.strict_encode(&mut e)?;
        written += second.strict_encode(&mut e)?;
        Ok(written)
    }
}
impl<K, V> StrictDecode for (K, V)
where
    K: StrictDecode + Clone,
    V: StrictDecode + Clone,
{
    /// Reads the first element and then the second. Tuple operands are
    /// evaluated left to right, which preserves the wire order.
    fn strict_decode<D: io::Read>(mut d: D) -> Result<Self, Error> {
        Ok((K::strict_decode(&mut d)?, V::strict_decode(&mut d)?))
    }
}
#[cfg(test)]
pub mod test {
    use super::*;
    use crate::strict_serialize;

    /// `None` encodes as a single `0` tag byte whatever the wrapped type.
    #[test]
    fn test_option_encode_none() {
        let o1: Option<u8> = None;
        let o2: Option<u64> = None;
        let zero_byte = &[0u8][..];
        assert_eq!(strict_serialize(&o1).unwrap(), zero_byte);
        assert_eq!(strict_serialize(&o2).unwrap(), zero_byte);
        assert_eq!(Option::<u8>::strict_decode(zero_byte).unwrap(), None);
        assert_eq!(Option::<u64>::strict_decode(zero_byte).unwrap(), None);
    }

    /// `Some(v)` encodes as a `1` tag byte followed by `v` in little-endian.
    /// `usize` is written as 16 bits.
    #[test]
    fn test_option_encode_some() {
        let o1: Option<u8> = Some(0);
        let o2: Option<u8> = Some(13);
        let o3: Option<u8> = Some(0xFF);
        let o4: Option<u64> = Some(13);
        let o5: Option<u64> = Some(0x1FF);
        let o6: Option<u64> = Some(0xFFFFFFFFFFFFFFFF);
        let o7: Option<usize> = Some(13);
        // `usize::MAX` instead of a 64-bit literal, which does not compile on
        // 32-bit targets. Either way the value overflows the 16-bit encoding.
        let o8: Option<usize> = Some(usize::MAX);
        let byte_0 = &[1u8, 0u8][..];
        let byte_13 = &[1u8, 13u8][..];
        let byte_255 = &[1u8, 0xFFu8][..];
        let word_13 = &[1u8, 13u8, 0u8][..];
        let qword_13 = &[1u8, 13u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8][..];
        let qword_1ff =
            &[1u8, 0xFFu8, 0x01u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8][..];
        let qword_max = &[
            1u8, 0xFFu8, 0xFFu8, 0xFFu8, 0xFFu8, 0xFFu8, 0xFFu8, 0xFFu8, 0xFFu8,
        ][..];
        assert_eq!(strict_serialize(&o1).unwrap(), byte_0);
        assert_eq!(strict_serialize(&o2).unwrap(), byte_13);
        assert_eq!(strict_serialize(&o3).unwrap(), byte_255);
        assert_eq!(strict_serialize(&o4).unwrap(), qword_13);
        assert_eq!(strict_serialize(&o5).unwrap(), qword_1ff);
        assert_eq!(strict_serialize(&o6).unwrap(), qword_max);
        assert_eq!(strict_serialize(&o7).unwrap(), word_13);
        assert!(strict_serialize(&o8).err().is_some());
        assert_eq!(Option::<u8>::strict_decode(byte_0).unwrap(), Some(0));
        assert_eq!(Option::<u8>::strict_decode(byte_13).unwrap(), Some(13));
        assert_eq!(Option::<u8>::strict_decode(byte_255).unwrap(), Some(0xFF));
        assert_eq!(Option::<u64>::strict_decode(qword_13).unwrap(), Some(13));
        assert_eq!(
            Option::<u64>::strict_decode(qword_1ff).unwrap(),
            Some(0x1FF)
        );
        assert_eq!(
            Option::<u64>::strict_decode(qword_max).unwrap(),
            Some(0xFFFFFFFFFFFFFFFF)
        );
        assert_eq!(Option::<usize>::strict_decode(word_13).unwrap(), Some(13));
        // Only the first two payload bytes are consumed for a `usize`.
        assert_eq!(
            Option::<usize>::strict_decode(qword_max).unwrap(),
            Some(0xFFFF)
        );
    }

    /// Any presence tag other than `0` or `1` must be rejected.
    #[test]
    fn test_option_decode_invalid_tag() {
        assert!(Option::<u8>::strict_decode(&[2u8, 0u8, 0u8, 0u8][..])
            .err()
            .is_some());
        assert!(Option::<u8>::strict_decode(&[3u8, 0u8, 0u8, 0u8][..])
            .err()
            .is_some());
        assert!(Option::<u8>::strict_decode(&[0xFFu8, 0u8, 0u8, 0u8][..])
            .err()
            .is_some());
    }

    /// Vectors are a 16-bit length prefix followed by the items. A vector
    /// longer than the prefix can express must fail to encode.
    #[test]
    fn test_vec_encode() {
        let v1: Vec<u8> = vec![0, 13, 0xFF];
        let v2: Vec<u8> = vec![13];
        let v3: Vec<u64> = vec![0, 13, 13, 0x1FF, 0xFFFFFFFFFFFFFFFF];
        let v4: Vec<u8> =
            (0..0x1FFFF).map(|item| (item % 0xFF) as u8).collect();
        let s1 = &[3u8, 0u8, 0u8, 13u8, 0xFFu8][..];
        let s2 = &[1u8, 0u8, 13u8][..];
        let s3 = &[
            5u8, 0u8, 0, 0, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, 0, 0, 13, 0,
            0, 0, 0, 0, 0, 0, 0xFF, 1, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xFF,
            0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        ][..];
        assert_eq!(strict_serialize(&v1).unwrap(), s1);
        assert_eq!(strict_serialize(&v2).unwrap(), s2);
        assert_eq!(strict_serialize(&v3).unwrap(), s3);
        assert!(strict_serialize(&v4).err().is_some());
        assert_eq!(Vec::<u8>::strict_decode(s1).unwrap(), v1);
        assert_eq!(Vec::<u8>::strict_decode(s2).unwrap(), v2);
        assert_eq!(Vec::<u64>::strict_decode(s3).unwrap(), v3);
    }
}