use std::{borrow::Cow, sync::Arc};
use bytes::{Bytes, BytesMut};
use super::BoundsExceeded;
/// Errors that can be reported while encoding a value into a buffer.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[derive(thiserror::Error, Debug, Clone)]
#[non_exhaustive]
pub enum EncodeError {
/// Produced when a `BoundsExceeded` error is converted into an `EncodeError`.
#[error("bounds exceeded")]
BoundsExceeded,
/// A value is too large to be encoded (for example, `Some(u64::MAX)` when
/// encoding an `Option<u64>`, which is written as `value + 1`).
#[error("too large")]
TooLarge,
/// The destination buffer has fewer writable bytes remaining than the
/// encoder needs.
#[error("short buffer")]
Short,
/// NOTE(review): not produced in this part of the file; presumably signals
/// that the value being encoded is in an inconsistent state. Confirm at use sites.
#[error("invalid state")]
InvalidState,
/// NOTE(review): not produced in this part of the file; presumably signals
/// that a collection has more entries than the format allows. Confirm at use sites.
#[error("too many")]
TooMany,
/// NOTE(review): not produced in this part of the file; presumably signals
/// that the requested version cannot encode this value. Confirm at use sites.
#[error("unsupported version")]
Version,
}
impl From<BoundsExceeded> for EncodeError {
fn from(_: BoundsExceeded) -> Self {
Self::BoundsExceeded
}
}
/// Verifies that `w` can still accept at least `needed` bytes.
///
/// # Errors
///
/// Returns [`EncodeError::Short`] when `w.remaining_mut()` is smaller than
/// `needed`.
fn check_remaining(w: &impl bytes::BufMut, needed: usize) -> Result<(), EncodeError> {
    if w.remaining_mut() >= needed {
        Ok(())
    } else {
        Err(EncodeError::Short)
    }
}
/// Serializes a value into a byte buffer, parameterized by a version `V`.
pub trait Encode<V>: Sized {
    /// Writes `self` into `w` using the given `version`.
    ///
    /// # Errors
    ///
    /// Returns an [`EncodeError`] if the value cannot be written.
    fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError>;

    /// Encodes `self` into a freshly created buffer and returns it frozen as
    /// immutable [`Bytes`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Encode::encode`].
    fn encode_bytes(&self, v: V) -> Result<Bytes, EncodeError> {
        let mut scratch = BytesMut::new();
        match self.encode(&mut scratch, v) {
            Ok(()) => Ok(scratch.freeze()),
            Err(err) => Err(err),
        }
    }
}
/// A `bool` is written as a single byte: `1` for `true`, `0` for `false`.
/// The version argument is ignored.
impl<V> Encode<V> for bool {
    fn encode<W: bytes::BufMut>(&self, w: &mut W, _: V) -> Result<(), EncodeError> {
        let byte = u8::from(*self);
        check_remaining(&*w, std::mem::size_of::<u8>())?;
        w.put_u8(byte);
        Ok(())
    }
}
/// A `u8` is written verbatim as one byte. The version argument is ignored.
impl<V> Encode<V> for u8 {
    fn encode<W: bytes::BufMut>(&self, w: &mut W, _: V) -> Result<(), EncodeError> {
        let byte = *self;
        check_remaining(&*w, std::mem::size_of::<Self>())?;
        w.put_u8(byte);
        Ok(())
    }
}
/// A `u16` is written as two bytes via `BufMut::put_u16`.
/// The version argument is ignored.
impl<V> Encode<V> for u16 {
    fn encode<W: bytes::BufMut>(&self, w: &mut W, _: V) -> Result<(), EncodeError> {
        let value = *self;
        check_remaining(&*w, std::mem::size_of::<Self>())?;
        w.put_u16(value);
        Ok(())
    }
}
/// A `String` is encoded exactly like the `&str` it borrows as.
impl<V: Copy> Encode<V> for String
where
    usize: Encode<V>,
{
    fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
        let text: &str = self.as_str();
        <&str as Encode<V>>::encode(&text, w, version)
    }
}
/// A string slice is written as its byte length (encoded as a `usize`)
/// followed by its raw UTF-8 bytes.
impl<V: Copy> Encode<V> for &str
where
    usize: Encode<V>,
{
    fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
        let raw = self.as_bytes();
        let size = raw.len();

        // Length prefix first, then make sure the payload fits before copying.
        size.encode(w, version)?;
        check_remaining(&*w, size)?;
        w.put(raw);
        Ok(())
    }
}
/// An `i8` is written as one byte holding the value shifted up by 128, so
/// `-128` becomes `0`, `0` becomes `128` and `127` becomes `255`.
/// The version argument is ignored.
impl<V> Encode<V> for i8 {
    fn encode<W: bytes::BufMut>(&self, w: &mut W, _: V) -> Result<(), EncodeError> {
        // Reinterpreting the bits as `u8` and adding 128 modulo 256 yields the
        // same result as widening, adding 128 and narrowing.
        let shifted = (*self as u8).wrapping_add(0x80);
        check_remaining(&*w, std::mem::size_of::<u8>())?;
        w.put_u8(shifted);
        Ok(())
    }
}
/// A slice is written as its element count (encoded as a `usize`) followed by
/// every element in order. Encoding stops at the first element that fails.
impl<V: Copy, T: Encode<V>> Encode<V> for &[T]
where
    usize: Encode<V>,
{
    fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
        let count = self.len();
        count.encode(w, version)?;
        self.iter().try_for_each(|element| element.encode(w, version))
    }
}
/// A byte vector is written as its length (encoded as a `usize`) followed by
/// its contents.
impl<V: Copy> Encode<V> for Vec<u8>
where
    usize: Encode<V>,
{
    fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
        let payload: &[u8] = self.as_slice();
        let size = payload.len();

        size.encode(w, version)?;
        // Refuse to start copying unless the whole payload fits.
        check_remaining(&*w, size)?;
        w.put_slice(payload);
        Ok(())
    }
}
/// A `Bytes` value is written as its length (encoded as a `usize`) followed by
/// its contents.
impl<V: Copy> Encode<V> for bytes::Bytes
where
    usize: Encode<V>,
{
    fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
        let payload: &[u8] = self;
        let size = payload.len();

        size.encode(w, version)?;
        // Refuse to start copying unless the whole payload fits.
        check_remaining(&*w, size)?;
        w.put_slice(payload);
        Ok(())
    }
}
/// An `Arc<T>` is encoded exactly like the `T` it points to.
impl<T: Encode<V>, V> Encode<V> for Arc<T> {
    fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
        let inner: &T = self;
        inner.encode(w, version)
    }
}
/// A `Cow<str>` is written the same way as a `&str`: its byte length (encoded
/// as a `usize`) followed by its raw UTF-8 bytes, whether borrowed or owned.
impl<V: Copy> Encode<V> for Cow<'_, str>
where
    usize: Encode<V>,
{
    fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
        // The `&str` impl performs the same length-prefix, capacity check and
        // copy, so hand the borrowed view to it.
        let text: &str = self;
        <&str as Encode<V>>::encode(&text, w, version)
    }
}
/// An `Option<u64>` is written as a single `u64`: `None` becomes `0` and
/// `Some(v)` becomes `v + 1`.
///
/// # Errors
///
/// Returns [`EncodeError::TooLarge`] for `Some(u64::MAX)`, because `v + 1`
/// would overflow.
impl<V: Copy> Encode<V> for Option<u64>
where
    u64: Encode<V>,
{
    fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
        let wire = match *self {
            None => 0u64,
            Some(inner) => inner.checked_add(1).ok_or(EncodeError::TooLarge)?,
        };
        wire.encode(w, version)
    }
}
/// A `Duration` is written as its whole number of milliseconds, encoded as a
/// `VarInt`.
///
/// # Errors
///
/// Fails if the millisecond count cannot be converted into a `VarInt`, or if
/// writing the `VarInt` fails.
impl<V: Copy> Encode<V> for std::time::Duration
where
    super::VarInt: Encode<V>,
{
    fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
        let millis: u128 = self.as_millis();
        super::VarInt::try_from(millis)?.encode(w, version)
    }
}