Skip to main content

moq_lite/coding/
encode.rs

1use std::{borrow::Cow, sync::Arc};
2
3use bytes::{Bytes, BytesMut};
4
5use super::BoundsExceeded;
6
7/// An error that occurs during encoding.
8#[derive(thiserror::Error, Debug, Clone)]
9#[non_exhaustive]
10pub enum EncodeError {
11	#[error("bounds exceeded")]
12	BoundsExceeded,
13	#[error("too large")]
14	TooLarge,
15	#[error("short buffer")]
16	Short,
17	#[error("invalid state")]
18	InvalidState,
19	#[error("too many")]
20	TooMany,
21	#[error("unsupported version")]
22	Version,
23}
24
25impl From<BoundsExceeded> for EncodeError {
26	fn from(_: BoundsExceeded) -> Self {
27		Self::BoundsExceeded
28	}
29}
30
31/// Check that the writer has enough remaining capacity.
32fn check_remaining(w: &impl bytes::BufMut, needed: usize) -> Result<(), EncodeError> {
33	if w.remaining_mut() < needed {
34		return Err(EncodeError::Short);
35	}
36	Ok(())
37}
38
39/// Write the value to the buffer using the given version.
40pub trait Encode<V>: Sized {
41	/// Encode the value to the given writer.
42	fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError>;
43
44	/// Encode the value into a [Bytes] buffer.
45	///
46	/// NOTE: This will allocate.
47	fn encode_bytes(&self, v: V) -> Result<Bytes, EncodeError> {
48		let mut buf = BytesMut::new();
49		self.encode(&mut buf, v)?;
50		Ok(buf.freeze())
51	}
52}
53
54impl<V> Encode<V> for bool {
55	fn encode<W: bytes::BufMut>(&self, w: &mut W, _: V) -> Result<(), EncodeError> {
56		check_remaining(&*w, 1)?;
57		w.put_u8(*self as u8);
58		Ok(())
59	}
60}
61
62impl<V> Encode<V> for u8 {
63	fn encode<W: bytes::BufMut>(&self, w: &mut W, _: V) -> Result<(), EncodeError> {
64		check_remaining(&*w, 1)?;
65		w.put_u8(*self);
66		Ok(())
67	}
68}
69
70impl<V> Encode<V> for u16 {
71	fn encode<W: bytes::BufMut>(&self, w: &mut W, _: V) -> Result<(), EncodeError> {
72		check_remaining(&*w, 2)?;
73		w.put_u16(*self);
74		Ok(())
75	}
76}
77
78impl<V: Copy> Encode<V> for String
79where
80	usize: Encode<V>,
81{
82	fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
83		self.as_str().encode(w, version)
84	}
85}
86
87impl<V: Copy> Encode<V> for &str
88where
89	usize: Encode<V>,
90{
91	fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
92		self.len().encode(w, version)?;
93		check_remaining(&*w, self.len())?;
94		w.put(self.as_bytes());
95		Ok(())
96	}
97}
98
99impl<V> Encode<V> for i8 {
100	fn encode<W: bytes::BufMut>(&self, w: &mut W, _: V) -> Result<(), EncodeError> {
101		// This is not the usual way of encoding negative numbers.
102		// i8 doesn't exist in the draft, but we use it instead of u8 for priority.
103		// A default of 0 is more ergonomic for the user than a default of 128.
104		check_remaining(&*w, 1)?;
105		w.put_u8(((*self as i16) + 128) as u8);
106		Ok(())
107	}
108}
109
110impl<V: Copy, T: Encode<V>> Encode<V> for &[T]
111where
112	usize: Encode<V>,
113{
114	fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
115		self.len().encode(w, version)?;
116		for item in self.iter() {
117			item.encode(w, version)?;
118		}
119		Ok(())
120	}
121}
122
123impl<V: Copy> Encode<V> for Vec<u8>
124where
125	usize: Encode<V>,
126{
127	fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
128		self.len().encode(w, version)?;
129		check_remaining(&*w, self.len())?;
130		w.put_slice(self);
131		Ok(())
132	}
133}
134
135impl<V: Copy> Encode<V> for bytes::Bytes
136where
137	usize: Encode<V>,
138{
139	fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
140		self.len().encode(w, version)?;
141		check_remaining(&*w, self.len())?;
142		w.put_slice(self);
143		Ok(())
144	}
145}
146
147impl<T: Encode<V>, V> Encode<V> for Arc<T> {
148	fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
149		(**self).encode(w, version)
150	}
151}
152
153impl<V: Copy> Encode<V> for Cow<'_, str>
154where
155	usize: Encode<V>,
156{
157	fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
158		self.len().encode(w, version)?;
159		check_remaining(&*w, self.len())?;
160		w.put(self.as_bytes());
161		Ok(())
162	}
163}
164
165impl<V: Copy> Encode<V> for Option<u64>
166where
167	u64: Encode<V>,
168{
169	fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
170		match self {
171			Some(value) => value.checked_add(1).ok_or(EncodeError::TooLarge)?.encode(w, version),
172			None => 0u64.encode(w, version),
173		}
174	}
175}
176
177impl<V: Copy> Encode<V> for std::time::Duration
178where
179	super::VarInt: Encode<V>,
180{
181	fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
182		let ms = super::VarInt::try_from(self.as_millis())?;
183		ms.encode(w, version)
184	}
185}