use crate::{Error, checked::CheckedSum, writer::Writer};
use core::str;
#[cfg(feature = "alloc")]
use alloc::{string::String, vec::Vec};
#[cfg(feature = "bytes")]
use bytes::{Bytes, BytesMut};
/// Serialization into this crate's binary wire format.
///
/// Implementors supply [`Encode::encoded_len`] and [`Encode::encode`]; every other
/// method is provided on top of those two.
pub trait Encode {
    /// Number of bytes that [`Encode::encode`] will write for `self`.
    ///
    /// # Errors
    /// Propagates any error the implementation reports while computing the size.
    fn encoded_len(&self) -> Result<usize, Error>;

    /// Writes the encoding of `self` into `writer`.
    ///
    /// # Errors
    /// Propagates any error from the implementation or from `writer`.
    fn encode(&self, writer: &mut impl Writer) -> Result<(), Error>;

    /// Size of the encoding when it is preceded by a 4-byte length prefix.
    ///
    /// # Errors
    /// Fails if `encoded_len` fails or if adding the prefix overflows
    /// (as reported by `CheckedSum`).
    fn encoded_len_prefixed(&self) -> Result<usize, Error> {
        let body = self.encoded_len()?;
        [4, body].checked_sum()
    }

    /// Writes `self.encoded_len()` (through the `usize` impl) and then the
    /// encoding of `self`.
    ///
    /// # Errors
    /// Propagates errors from computing the length, writing it, or writing `self`.
    fn encode_prefixed(&self, writer: &mut impl Writer) -> Result<(), Error> {
        let body_len = self.encoded_len()?;
        body_len.encode(writer)?;
        self.encode(writer)
    }

    /// Encodes `self` into a freshly allocated `Vec<u8>`. The capacity is
    /// reserved up front from `encoded_len`.
    #[cfg(feature = "alloc")]
    fn encode_vec(&self) -> Result<Vec<u8>, Error> {
        let capacity = self.encoded_len()?;
        let mut buf = Vec::with_capacity(capacity);
        self.encode(&mut buf)?;
        Ok(buf)
    }

    /// Encodes `self` into a freshly allocated `BytesMut`. The capacity is
    /// reserved up front from `encoded_len`.
    #[cfg(feature = "bytes")]
    fn encode_bytes(&self) -> Result<BytesMut, Error> {
        let capacity = self.encoded_len()?;
        let mut buf = BytesMut::with_capacity(capacity);
        self.encode(&mut buf)?;
        Ok(buf)
    }
}
/// A single byte, written as-is.
impl Encode for u8 {
    fn encoded_len(&self) -> Result<usize, Error> {
        Ok(core::mem::size_of::<u8>())
    }

    fn encode(&self, writer: &mut impl Writer) -> Result<(), Error> {
        // View the byte as a one-element slice without copying it into an array.
        writer.write(core::slice::from_ref(self))
    }
}
/// A boolean, written as one byte: `1` for `true` and `0` for `false`.
impl Encode for bool {
    fn encoded_len(&self) -> Result<usize, Error> {
        u8::from(*self).encoded_len()
    }

    fn encode(&self, writer: &mut impl Writer) -> Result<(), Error> {
        // `u8::from(bool)` gives exactly 1 / 0.
        u8::from(*self).encode(writer)
    }
}
/// A 32-bit unsigned integer, written as 4 big-endian bytes.
impl Encode for u32 {
    fn encoded_len(&self) -> Result<usize, Error> {
        Ok(core::mem::size_of::<u32>())
    }

    fn encode(&self, writer: &mut impl Writer) -> Result<(), Error> {
        let be = self.to_be_bytes();
        writer.write(&be)
    }
}
/// A 64-bit unsigned integer, written as 8 big-endian bytes.
impl Encode for u64 {
    fn encoded_len(&self) -> Result<usize, Error> {
        Ok(core::mem::size_of::<u64>())
    }

    fn encode(&self, writer: &mut impl Writer) -> Result<(), Error> {
        let be = self.to_be_bytes();
        writer.write(&be)
    }
}
/// A `usize` is always written as a 4-byte `u32`, whatever the platform's
/// pointer width.
///
/// `encoded_len` always reports 4. `encode` fails, with the `Error` produced
/// from `TryFromIntError` via `?`, when the value does not fit in a `u32`.
impl Encode for usize {
    fn encoded_len(&self) -> Result<usize, Error> {
        Ok(core::mem::size_of::<u32>())
    }

    fn encode(&self, writer: &mut impl Writer) -> Result<(), Error> {
        let narrowed = u32::try_from(*self)?;
        narrowed.encode(writer)
    }
}
/// A variable-length byte string: a 4-byte length (through the `usize` impl)
/// followed by the raw bytes.
impl Encode for [u8] {
    fn encoded_len(&self) -> Result<usize, Error> {
        let prefix = 4;
        [prefix, self.len()].checked_sum()
    }

    fn encode(&self, writer: &mut impl Writer) -> Result<(), Error> {
        let len = self.len();
        len.encode(writer)?;
        writer.write(self)
    }
}
/// A fixed-size byte array, written verbatim with no length prefix.
/// The size is implied by the type.
impl<const N: usize> Encode for [u8; N] {
    fn encoded_len(&self) -> Result<usize, Error> {
        Ok(self.len())
    }

    fn encode(&self, writer: &mut impl Writer) -> Result<(), Error> {
        writer.write(&self[..])
    }
}
/// Generates `Encode` impls that forward both `encoded_len` and `encode` to
/// another expression.
///
/// Each entry has the shape
/// `[#[attr]...] impl [(generics)] Encode for Type where self -> delegate_expr;`.
/// `delegate_expr` is evaluated inside each generated method, and the method is
/// called on its result, so that expression must resolve to something that
/// implements `Encode`.
macro_rules! impl_by_delegation {
(
$(
// Outer attributes (e.g. `#[cfg(feature = "...")]`) are copied onto the generated impl.
$(#[$attr:meta])*
// Generic parameters are written in parentheses, e.g. `impl (T: Bound) Encode for ...`,
// because `<`/`>` are not token-tree delimiters. The receiver name (`self`) is taken from
// the invocation, presumably so that it and the `self` inside the delegate expression
// share the same macro hygiene context.
impl $( ($($generics:tt)+) )? Encode for $type:ty where $self:ident -> $delegate:expr;
)+
) => {
$(
$(#[$attr])*
impl $(< $($generics)* >)? Encode for $type {
fn encoded_len(&$self) -> Result<usize, Error> {
// An `expr` fragment is substituted as one grouped expression, so `**self` here
// behaves like `(**self).encoded_len()`.
$delegate.encoded_len()
}
fn encode(&$self, writer: &mut impl Writer) -> Result<(), Error> {
$delegate.encode(writer)
}
}
)+
};
}
// Owned string-like types forward to their underlying bytes, and each reference
// type forwards to the value it points at. All of them therefore end up in the
// length-prefixed `[u8]` encoding.
impl_by_delegation!(
    impl Encode for str where self -> self.as_bytes();
    impl Encode for &str where self -> **self;
    impl Encode for &[u8] where self -> **self;

    #[cfg(feature = "alloc")]
    impl Encode for String where self -> self.as_bytes();
    #[cfg(feature = "alloc")]
    impl Encode for &String where self -> **self;
    #[cfg(feature = "alloc")]
    impl Encode for Vec<u8> where self -> self.as_slice();
    #[cfg(feature = "alloc")]
    impl Encode for &Vec<u8> where self -> **self;

    #[cfg(feature = "bytes")]
    impl Encode for Bytes where self -> self.as_ref();
    #[cfg(feature = "bytes")]
    impl Encode for &Bytes where self -> **self;
);
/// Marker for types whose `Encode` impl produces a length-prefixed byte string,
/// i.e. a 4-byte length followed by the bytes, in the style of the RFC 4251
/// `string` type.
///
/// Only types with that encoding should implement it. `[T]` for
/// `T: Rfc4251String` relies on it to encode a sequence of such strings.
pub trait Rfc4251String: Encode {}

impl Rfc4251String for [u8] {}
impl Rfc4251String for str {}

#[cfg(feature = "alloc")]
impl Rfc4251String for Vec<u8> {}
#[cfg(feature = "alloc")]
impl Rfc4251String for String {}

#[cfg(feature = "bytes")]
impl Rfc4251String for Bytes {}

/// A reference to a string-like type is itself string-like, provided the
/// reference type implements `Encode` (required by the supertrait).
impl<'r, S> Rfc4251String for &'r S
where
    &'r S: Encode,
    S: Rfc4251String + ?Sized,
{
}
impl<T: Rfc4251String> Encode for [T] {
fn encoded_len(&self) -> Result<usize, Error> {
self.iter().try_fold(4usize, |acc, string| {
acc.checked_add(string.encoded_len()?).ok_or(Error::Length)
})
}
fn encode(&self, writer: &mut impl Writer) -> Result<(), Error> {
self.encoded_len()?
.checked_sub(4)
.ok_or(Error::Length)?
.encode(writer)?;
self.iter().try_fold((), |(), entry| entry.encode(writer))
}
}
// A `Vec` of string-like values encodes exactly like the slice it derefs to.
impl_by_delegation!(
    #[cfg(feature = "alloc")]
    impl (T: Rfc4251String) Encode for Vec<T> where self -> **self;
);