use chrono::{DateTime, Duration, Utc};
use super::{Error, Result};
use std::io::{Read, Write};
/// A value that can be written in the Sia binary encoding.
pub trait SiaEncodable {
/// Returns the size, in bytes, of the encoded form of `self`.
///
/// The implementations accompanying this trait return exactly the number
/// of bytes their `encode` writes.
fn encoded_length(&self) -> usize;
/// Writes the encoded form of `self` to `w`.
///
/// # Errors
///
/// The implementations accompanying this trait propagate write failures
/// from `w`; some also reject values they cannot represent.
fn encode<W: Write>(&self, w: &mut W) -> Result<()>;
}
/// A value that can be read back from the Sia binary encoding.
pub trait SiaDecodable: Sized {
/// Reads one encoded value from `r`, consuming only the bytes it needs.
///
/// # Errors
///
/// The implementations accompanying this trait propagate read failures
/// from `r` and report malformed input as `Error::InvalidValue`.
fn decode<R: Read>(r: &mut R) -> Result<Self>;
}
/// A `u8` is encoded as a single raw byte.
impl SiaEncodable for u8 {
    /// Always `1`.
    fn encoded_length(&self) -> usize {
        1
    }

    /// Writes the byte itself to `w`, propagating any write error.
    fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
        // View the byte as a one-element slice instead of copying it into a
        // temporary array.
        let single = std::slice::from_ref(self);
        w.write_all(single)?;
        Ok(())
    }
}
/// A `u8` is decoded from a single raw byte.
impl SiaDecodable for u8 {
    /// Reads exactly one byte from `r` and returns it.
    ///
    /// Fails with the reader's error if no byte is available.
    fn decode<R: Read>(r: &mut R) -> Result<Self> {
        let mut byte = 0u8;
        // Fill the byte in place through a one-element mutable slice.
        r.read_exact(std::slice::from_mut(&mut byte))?;
        Ok(byte)
    }
}
/// A `bool` is encoded as one byte: `1` for `true`, `0` for `false`.
impl SiaEncodable for bool {
    /// Always `1`.
    fn encoded_length(&self) -> usize {
        1
    }

    /// Converts the flag to its byte value and writes it via the `u8` encoder.
    fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
        let flag = u8::from(*self);
        flag.encode(w)
    }
}
impl SiaDecodable for bool {
fn decode<R: Read>(r: &mut R) -> Result<Self> {
let v = u8::decode(r)?;
match v {
0 => Ok(false),
1 => Ok(true),
_ => Err(Error::InvalidValue("requires 0 or 1".into())),
}
}
}
/// A UTC timestamp is encoded as the `i64` returned by `timestamp()`.
impl SiaEncodable for DateTime<Utc> {
    /// Always `8`, the size of the encoded `i64`.
    fn encoded_length(&self) -> usize {
        8
    }

    /// Writes `self.timestamp()` using the `i64` encoder.
    ///
    /// Only the whole-second timestamp is written.
    fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
        let seconds: i64 = self.timestamp();
        seconds.encode(w)
    }
}
impl SiaDecodable for DateTime<Utc> {
fn decode<R: Read>(r: &mut R) -> Result<Self> {
let timestamp = i64::decode(r)?;
DateTime::from_timestamp_secs(timestamp)
.ok_or_else(|| Error::InvalidValue(format!("invalid timestamp: {timestamp}")))
}
}
impl SiaEncodable for Duration {
fn encoded_length(&self) -> usize {
8
}
fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
self.num_nanoseconds()
.ok_or_else(|| Error::InvalidValue("duration too large".into()))?
.encode(w)
}
}
impl SiaDecodable for Duration {
fn decode<R: Read>(r: &mut R) -> Result<Self> {
let ns = u64::decode(r)?;
if ns > i64::MAX as u64 {
return Err(Error::InvalidValue(format!(
"duration {ns} must be less than {}",
i64::MAX
)));
}
Ok(Duration::nanoseconds(ns as i64))
}
}
/// A slice is encoded as its element count followed by each element in order.
impl<T: SiaEncodable> SiaEncodable for [T] {
    /// Size of the length prefix plus the encoded size of every element.
    fn encoded_length(&self) -> usize {
        let prefix = self.len().encoded_length();
        self.iter()
            .fold(prefix, |total, item| total + item.encoded_length())
    }

    /// Writes the element count, then every element, stopping at the first
    /// error.
    fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
        self.len().encode(w)?;
        self.iter().try_for_each(|item| item.encode(w))
    }
}
/// An `Option` is encoded as a presence flag followed by the value, if any.
impl<T: SiaEncodable> SiaEncodable for Option<T> {
    /// One byte for the flag plus the encoded size of the contained value
    /// (zero when `None`).
    fn encoded_length(&self) -> usize {
        let payload = self.as_ref().map_or(0, |inner| inner.encoded_length());
        1 + payload
    }

    /// Writes the `bool` presence flag and, for `Some`, the contained value.
    fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
        self.is_some().encode(w)?;
        if let Some(inner) = self {
            inner.encode(w)?;
        }
        Ok(())
    }
}
/// An `Option` is decoded from a presence flag followed by the value, if any.
impl<T: SiaDecodable> SiaDecodable for Option<T> {
    /// Reads the `bool` presence flag; when it is set, decodes a `T` and
    /// wraps it in `Some`, otherwise returns `None` without reading further.
    fn decode<R: Read>(r: &mut R) -> Result<Self> {
        let present = bool::decode(r)?;
        if present {
            T::decode(r).map(Some)
        } else {
            Ok(None)
        }
    }
}
macro_rules! impl_sia_numeric {
($($t:ty),*) => {
$(
impl SiaEncodable for $t {
fn encoded_length(&self) -> usize {
8
}
fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
w.write_all(&(*self as u64).to_le_bytes())?;
Ok(())
}
}
impl SiaDecodable for $t {
fn decode<R: Read>(r: &mut R) -> Result<Self> {
let mut buf = [0u8; 8];
r.read_exact(&mut buf)?;
Ok(u64::from_le_bytes(buf) as Self)
}
}
)*
}
}
impl_sia_numeric!(u16, u32, usize, i16, i32, i64, u64);
/// A `Vec` is encoded exactly like the slice of its elements: the element
/// count followed by each element in order.
impl<T> SiaEncodable for Vec<T>
where
    T: SiaEncodable,
{
    /// Size of the length prefix plus the encoded size of every element.
    fn encoded_length(&self) -> usize {
        // Delegate to the slice implementation so both stay in lockstep.
        <[T] as SiaEncodable>::encoded_length(self.as_slice())
    }

    /// Writes the element count, then every element, stopping at the first
    /// error.
    fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
        <[T] as SiaEncodable>::encode(self.as_slice(), w)
    }
}
/// A `Vec` is decoded from an element count followed by that many elements.
impl<T> SiaDecodable for Vec<T>
where
    T: SiaDecodable,
{
    /// Reads the element count, then decodes elements one at a time.
    ///
    /// Decoding stops at the first element that fails and that error is
    /// returned.
    fn decode<R: Read>(r: &mut R) -> Result<Self> {
        let count = usize::decode(r)?;
        (0..count).map(|_| T::decode(r)).collect()
    }
}
/// A `String` is encoded as its UTF-8 byte length followed by the bytes.
impl SiaEncodable for String {
    /// Size of the length prefix plus the number of UTF-8 bytes.
    ///
    /// Computed directly from `len()` instead of visiting every byte.
    fn encoded_length(&self) -> usize {
        self.len().encoded_length() + self.len()
    }

    /// Writes the byte length, then the UTF-8 bytes.
    ///
    /// The payload goes out in a single `write_all` call rather than one
    /// write per byte; the bytes produced are unchanged.
    fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
        self.len().encode(w)?;
        w.write_all(self.as_bytes())?;
        Ok(())
    }
}
impl SiaDecodable for String {
fn decode<R: Read>(r: &mut R) -> Result<Self> {
let buf = Vec::<u8>::decode(r)?;
String::from_utf8(buf).map_err(|e| Error::InvalidValue(e.to_string()))
}
}
/// A `Bytes` buffer is encoded as its length followed by its raw contents.
impl SiaEncodable for bytes::Bytes {
    /// Eight bytes for the length prefix plus the number of payload bytes.
    fn encoded_length(&self) -> usize {
        self.len() + 8
    }

    /// Writes the length as a `u64`, then the payload in a single write.
    fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
        let payload_len = self.len() as u64;
        payload_len.encode(w)?;
        w.write_all(&self[..])?;
        Ok(())
    }
}
impl SiaDecodable for bytes::Bytes {
fn decode<R: Read>(r: &mut R) -> Result<Self> {
let len = u64::decode(r)? as usize;
let mut buf = vec![0u8; len];
r.read_exact(&mut buf)?;
Ok(bytes::Bytes::from(buf))
}
}
/// A fixed-size byte array is encoded as its raw bytes with no length prefix.
impl<const N: usize> SiaEncodable for [u8; N] {
    /// Always `N`.
    fn encoded_length(&self) -> usize {
        N
    }

    /// Writes all `N` bytes to `w` in one call.
    fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
        let raw: &[u8] = &self[..];
        w.write_all(raw)?;
        Ok(())
    }
}
/// A fixed-size byte array is decoded from exactly `N` raw bytes.
impl<const N: usize> SiaDecodable for [u8; N] {
    /// Fills a zeroed `[u8; N]` from `r` and returns it.
    ///
    /// Fails with the reader's error if fewer than `N` bytes are available.
    fn decode<R: Read>(r: &mut R) -> Result<Self> {
        let mut filled = [0u8; N];
        r.read_exact(&mut filled[..])?;
        Ok(filled)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `value` encodes to exactly `expected_bytes`, and that
    /// decoding `expected_bytes` gives back `value` with no bytes left over.
    fn test_roundtrip<T>(value: T, expected_bytes: Vec<u8>)
    where
        T: SiaEncodable + SiaDecodable + std::fmt::Debug + PartialEq,
    {
        let mut encoded_bytes = Vec::new();
        if let Err(e) = value.encode(&mut encoded_bytes) {
            panic!("failed to encode: {e:?}");
        }
        assert_eq!(
            encoded_bytes, expected_bytes,
            "encoding mismatch for {value:?}"
        );

        // Decode from the golden bytes, not from what was just produced.
        let mut remaining = expected_bytes.as_slice();
        let decoded = match T::decode(&mut remaining) {
            Ok(decoded) => decoded,
            Err(e) => panic!("failed to decode: {e:?}"),
        };
        assert_eq!(decoded, value, "decoding mismatch for {value:?}");
        assert_eq!(remaining.len(), 0, "leftover bytes for {value:?}");
    }

    #[test]
    fn test_numerics() {
        // `u8` is a single byte; every other integer occupies 8 LE bytes.
        test_roundtrip(1u8, vec![1]);
        test_roundtrip(2u16, vec![2, 0, 0, 0, 0, 0, 0, 0]);
        test_roundtrip(3u32, vec![3, 0, 0, 0, 0, 0, 0, 0]);
        test_roundtrip(4u64, vec![4, 0, 0, 0, 0, 0, 0, 0]);
        test_roundtrip(5usize, vec![5, 0, 0, 0, 0, 0, 0, 0]);
        // Negative values are sign-extended to 64 bits.
        test_roundtrip(-1i16, vec![255, 255, 255, 255, 255, 255, 255, 255]);
        test_roundtrip(-2i32, vec![254, 255, 255, 255, 255, 255, 255, 255]);
        test_roundtrip(-3i64, vec![253, 255, 255, 255, 255, 255, 255, 255]);
    }

    #[test]
    fn test_strings() {
        test_roundtrip(
            "hello".to_string(),
            vec![
                5, 0, 0, 0, 0, 0, 0, 0, // length prefix
                104, 101, 108, 108, 111, // "hello"
            ],
        );
        test_roundtrip(
            "".to_string(),
            vec![0, 0, 0, 0, 0, 0, 0, 0], // length prefix only
        );
    }

    #[test]
    fn test_fixed_arrays() {
        // Fixed-size arrays carry no length prefix.
        test_roundtrip([1u8, 2u8, 3u8], vec![1, 2, 3]);
        test_roundtrip([0u8; 4], vec![0, 0, 0, 0]);
    }

    #[test]
    fn test_vectors() {
        test_roundtrip(
            vec![1u8, 2u8, 3u8],
            vec![
                3, 0, 0, 0, 0, 0, 0, 0, // element count
                1, 2, 3, // elements
            ],
        );
        test_roundtrip(
            vec![100u64, 200u64],
            vec![
                2, 0, 0, 0, 0, 0, 0, 0, // element count
                100, 0, 0, 0, 0, 0, 0, 0, // 100
                200, 0, 0, 0, 0, 0, 0, 0, // 200
            ],
        );
        test_roundtrip(
            vec!["a".to_string(), "bc".to_string()],
            vec![
                2, 0, 0, 0, 0, 0, 0, 0, // element count
                1, 0, 0, 0, 0, 0, 0, 0, 97, // "a"
                2, 0, 0, 0, 0, 0, 0, 0, 98, 99, // "bc"
            ],
        );
    }

    #[test]
    fn test_nested() {
        test_roundtrip(
            vec![vec![1u8, 2u8], vec![3u8, 4u8]],
            vec![
                2, 0, 0, 0, 0, 0, 0, 0, // outer element count
                2, 0, 0, 0, 0, 0, 0, 0, 1, 2, // first inner vector
                2, 0, 0, 0, 0, 0, 0, 0, 3, 4, // second inner vector
            ],
        );
    }
}