use std::marker::PhantomData;
use educe::Educe;
use crate::EncodeError;
use crate::EncodeResult;
use crate::Writeable;
use crate::WriteableOnce;
/// A sink for encoded bytes.
///
/// Only [`Writer::write_all`] has to be implemented; every other method is
/// provided on top of it. All multi-byte integer helpers emit their value in
/// big-endian byte order.
pub trait Writer {
    /// Append every byte of `b` to this writer.
    fn write_all(&mut self, b: &[u8]);

    /// Append the single byte `x`.
    fn write_u8(&mut self, x: u8) {
        self.write_all(std::slice::from_ref(&x));
    }

    /// Append `x` as 2 big-endian bytes.
    fn write_u16(&mut self, x: u16) {
        let encoded: [u8; 2] = x.to_be_bytes();
        self.write_all(&encoded);
    }

    /// Append `x` as 4 big-endian bytes.
    fn write_u32(&mut self, x: u32) {
        let encoded: [u8; 4] = x.to_be_bytes();
        self.write_all(&encoded);
    }

    /// Append `x` as 8 big-endian bytes.
    fn write_u64(&mut self, x: u64) {
        let encoded: [u8; 8] = x.to_be_bytes();
        self.write_all(&encoded);
    }

    /// Append `x` as 16 big-endian bytes.
    fn write_u128(&mut self, x: u128) {
        let encoded: [u8; 16] = x.to_be_bytes();
        self.write_all(&encoded);
    }

    /// Append `n` zero bytes.
    ///
    /// The zeros are gathered into one temporary buffer and handed to
    /// `write_all` in a single call.
    fn write_zeros(&mut self, n: usize) {
        let zeros = vec![0_u8; n];
        self.write_all(zeros.as_slice());
    }

    /// Encode the borrowed value `e` onto this writer.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `e`'s `Writeable` implementation returns.
    fn write<E: Writeable + ?Sized>(&mut self, e: &E) -> EncodeResult<()> {
        <E as Writeable>::write_onto(e, self)
    }

    /// Encode `e` onto this writer, consuming it.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `e`'s `WriteableOnce` implementation returns.
    fn write_and_consume<E: WriteableOnce>(&mut self, e: E) -> EncodeResult<()> {
        <E as WriteableOnce>::write_into(e, self)
    }

    /// Begin a nested field whose length prefix is a `u8`.
    ///
    /// Write the body into the returned [`NestedWriter`], then call its
    /// `finish` method to emit the length and body onto `self`.
    fn write_nested_u8len(&mut self) -> NestedWriter<'_, Self, u8> {
        write_nested_generic::<Self, u8>(self)
    }

    /// Begin a nested field whose length prefix is a `u16`.
    ///
    /// Write the body into the returned [`NestedWriter`], then call its
    /// `finish` method to emit the length and body onto `self`.
    fn write_nested_u16len(&mut self) -> NestedWriter<'_, Self, u16> {
        write_nested_generic::<Self, u16>(self)
    }

    /// Begin a nested field whose length prefix is a `u32`.
    ///
    /// Write the body into the returned [`NestedWriter`], then call its
    /// `finish` method to emit the length and body onto `self`.
    fn write_nested_u32len(&mut self) -> NestedWriter<'_, Self, u32> {
        write_nested_generic::<Self, u32>(self)
    }
}
/// Buffer for the body of a length-prefixed nested field.
///
/// Created by the `Writer::write_nested_*len` methods. It dereferences
/// (mutably) to its inner `Vec<u8>`, so body bytes are written straight into
/// that buffer. The only code here that touches `outer` is `finish`, which
/// emits the length prefix (of type `L`) followed by the buffered body.
///
/// NOTE(review): no `Drop` impl is visible in this chunk, so dropping a
/// `NestedWriter` without calling `finish` presumably discards the body.
#[derive(Educe)]
#[educe(Deref, DerefMut)]
pub struct NestedWriter<'w, W, L>
where
    W: ?Sized,
{
    /// Destination for the length prefix and body once `finish` runs.
    outer: &'w mut W,

    /// Body bytes collected so far; this is the `Deref`/`DerefMut` target.
    #[educe(Deref, DerefMut)]
    inner: Vec<u8>,

    /// Marks the length-prefix type `L` without storing a value of it.
    ///
    /// Using `*mut L` makes the struct invariant in `L` and, as with any raw
    /// pointer marker, neither `Send` nor `Sync`.
    length_type: PhantomData<*mut L>,
}
/// Start a length-prefixed nested field that will eventually land on `outer`.
///
/// The returned [`NestedWriter`] starts with an empty body buffer. Nothing is
/// written to `outer` until [`NestedWriter::finish`] is called.
fn write_nested_generic<'w, W, L>(outer: &'w mut W) -> NestedWriter<'w, W, L>
where
    W: Writer + ?Sized,
    L: Default + Copy + Sized + Writeable + TryFrom<usize>,
{
    let inner = Vec::new();
    NestedWriter {
        outer,
        inner,
        length_type: PhantomData,
    }
}
// The bounds on `L` match those of `write_nested_generic`; `finish` itself only
// needs `Writeable + TryFrom<usize>`, and no arithmetic/bitwise traits.
impl<'w, W, L> NestedWriter<'w, W, L>
where
    W: Writer + ?Sized,
    L: Default + Copy + Sized + Writeable + TryFrom<usize>,
{
    /// Complete the nested field, writing it onto the outer writer.
    ///
    /// Converts the number of buffered body bytes into the length type `L`,
    /// then writes that length followed by the body bytes to the outer writer.
    ///
    /// # Errors
    ///
    /// Returns [`EncodeError::BadLengthValue`] if the body length does not fit
    /// in `L`; in that case nothing is written to the outer writer. Any error
    /// from writing the length or the body is propagated as-is.
    pub fn finish(self) -> Result<(), EncodeError> {
        let length = self.inner.len();
        // Convert before touching `outer`, so an oversized body leaves the
        // outer writer unmodified.
        let length: L = length.try_into().map_err(|_| EncodeError::BadLengthValue)?;
        self.outer.write(&length)?;
        self.outer.write(&self.inner)?;
        Ok(())
    }
}
#[cfg(test)]
// Tests are allowed to panic on unexpected failures.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn write_ints() {
        let mut b = bytes::BytesMut::new();
        b.write_u8(1);
        b.write_u16(2);
        b.write_u32(3);
        b.write_u64(4);
        b.write_u128(5);
        assert_eq!(
            &b[..],
            &[
                1, 0, 2, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 5
            ]
        );
    }

    #[test]
    fn write_slice() {
        let mut v = Vec::new();
        v.write_u16(0x5468);
        v.write(&b"ey're good dogs, Bront"[..]).unwrap();
        assert_eq!(&v[..], &b"They're good dogs, Bront"[..]);
    }

    #[test]
    fn zeros() {
        let mut v: Vec<u8> = vec![7];
        v.write_zeros(0);
        assert_eq!(&v[..], &[7]);
        v.write_zeros(3);
        assert_eq!(&v[..], &[7, 0, 0, 0]);
    }

    #[test]
    fn writeable() -> EncodeResult<()> {
        struct Sequence(u8);
        impl Writeable for Sequence {
            fn write_onto<B: Writer + ?Sized>(&self, b: &mut B) -> EncodeResult<()> {
                for i in 0..self.0 {
                    b.write_u8(i);
                }
                Ok(())
            }
        }
        let mut v = Vec::new();
        v.write(&Sequence(6))?;
        assert_eq!(&v[..], &[0, 1, 2, 3, 4, 5]);
        v.write_and_consume(Sequence(3))?;
        assert_eq!(&v[..], &[0, 1, 2, 3, 4, 5, 0, 1, 2]);
        Ok(())
    }

    #[test]
    fn nested() {
        let mut v: Vec<u8> = b"abc".to_vec();
        let mut w = v.write_nested_u8len();
        w.write_u8(b'x');
        w.finish().unwrap();
        let mut w = v.write_nested_u16len();
        w.write_u8(b'y');
        w.finish().unwrap();
        let mut w = v.write_nested_u32len();
        w.write_u8(b'z');
        w.finish().unwrap();
        assert_eq!(&v, b"abc\x01x\0\x01y\0\0\0\x01z");

        // A body of exactly u8::MAX bytes still fits in a u8 length prefix.
        let before = v.len();
        let mut w = v.write_nested_u8len();
        w.write_zeros(255);
        w.finish().unwrap();
        assert_eq!(v.len(), before + 1 + 255);
        assert_eq!(v[before], 0xff);
        assert!(v[before + 1..].iter().all(|&byte| byte == 0));

        // One byte more overflows the prefix: `finish` must fail and must not
        // leave a partial write behind on the outer writer.
        let snapshot = v.clone();
        let mut w = v.write_nested_u8len();
        w.write_zeros(256);
        assert!(matches!(w.finish(), Err(EncodeError::BadLengthValue)));
        assert_eq!(v, snapshot);

        // Same for a u16 prefix one past its maximum.
        let mut w = v.write_nested_u16len();
        w.write_zeros(usize::from(u16::MAX) + 1);
        assert!(matches!(w.finish(), Err(EncodeError::BadLengthValue)));
        assert_eq!(v, snapshot);
    }
}