// moq_lite/coding/reader.rs

use std::{cmp, fmt::Debug, io, sync::Arc};

use bytes::{Buf, BufMut, Bytes, BytesMut};

use crate::{Error, coding::*};

7/// A reader for decoding messages from a stream.
8pub struct Reader<S: web_transport_trait::RecvStream, V> {
9	stream: S,
10	buffer: BytesMut,
11	version: V,
12}
13
14impl<S: web_transport_trait::RecvStream, V> Reader<S, V> {
15	pub fn new(stream: S, version: V) -> Self {
16		Self {
17			stream,
18			buffer: Default::default(),
19			version,
20		}
21	}
22
23	/// Decode the next message from the stream.
24	pub async fn decode<T: Decode<V> + Debug>(&mut self) -> Result<T, Error>
25	where
26		V: Clone,
27	{
28		loop {
29			let mut cursor = io::Cursor::new(&self.buffer);
30			match T::decode(&mut cursor, self.version.clone()) {
31				Ok(msg) => {
32					self.buffer.advance(cursor.position() as usize);
33					return Ok(msg);
34				}
35				Err(DecodeError::Short) => {
36					// Try to read more data
37					if self
38						.stream
39						.read_buf(&mut self.buffer)
40						.await
41						.map_err(|e| Error::Transport(Arc::new(e)))?
42						.is_none()
43					{
44						// Stream closed while we still need more data
45						return Err(Error::Decode(DecodeError::Short));
46					}
47				}
48				Err(e) => return Err(Error::Decode(e)),
49			}
50		}
51	}
52
53	/// Decode the next message unless the stream is closed.
54	pub async fn decode_maybe<T: Decode<V> + Debug>(&mut self) -> Result<Option<T>, Error>
55	where
56		V: Clone,
57	{
58		match self.closed().await {
59			Ok(()) => Ok(None),
60			Err(Error::Decode(DecodeError::ExpectedEnd)) => Ok(Some(self.decode().await?)),
61			Err(e) => Err(e),
62		}
63	}
64
65	/// Decode the next message from the stream without consuming it.
66	pub async fn decode_peek<T: Decode<V> + Debug>(&mut self) -> Result<T, Error>
67	where
68		V: Clone,
69	{
70		loop {
71			let mut cursor = io::Cursor::new(&self.buffer);
72			match T::decode(&mut cursor, self.version.clone()) {
73				Ok(msg) => return Ok(msg),
74				Err(DecodeError::Short) => {
75					// Try to read more data
76					if self
77						.stream
78						.read_buf(&mut self.buffer)
79						.await
80						.map_err(|e| Error::Transport(Arc::new(e)))?
81						.is_none()
82					{
83						// Stream closed while we still need more data
84						return Err(Error::Decode(DecodeError::Short));
85					}
86				}
87				Err(e) => return Err(Error::Decode(e)),
88			}
89		}
90	}
91
92	/// Returns a non-zero chunk of data, or None if the stream is closed
93	pub async fn read(&mut self, max: usize) -> Result<Option<Bytes>, Error> {
94		if !self.buffer.is_empty() {
95			let size = cmp::min(max, self.buffer.len());
96			let data = self.buffer.split_to(size).freeze();
97			return Ok(Some(data));
98		}
99
100		self.stream
101			.read_chunk(max)
102			.await
103			.map_err(|e| Error::Transport(Arc::new(e)))
104	}
105
106	/// Read exactly the given number of bytes from the stream.
107	pub async fn read_exact(&mut self, size: usize) -> Result<Bytes, Error> {
108		// An optimization to avoid a copy if we have enough data in the buffer
109		if self.buffer.len() >= size {
110			return Ok(self.buffer.split_to(size).freeze());
111		}
112
113		let data = BytesMut::with_capacity(size.min(u16::MAX as usize));
114		let mut buf = data.limit(size);
115
116		let size = cmp::min(buf.remaining_mut(), self.buffer.len());
117		let data = self.buffer.split_to(size);
118		buf.put(data);
119
120		while buf.has_remaining_mut() {
121			self.stream
122				.read_buf(&mut buf)
123				.await
124				.map_err(|e| Error::Transport(Arc::new(e)))?;
125		}
126
127		Ok(buf.into_inner().freeze())
128	}
129
130	/// Skip the given number of bytes from the stream.
131	pub async fn skip(&mut self, mut size: usize) -> Result<(), Error> {
132		let buffered = self.buffer.len().min(size);
133		self.buffer.advance(buffered);
134		size -= buffered;
135
136		while size > 0 {
137			let chunk = self
138				.stream
139				.read_chunk(size)
140				.await
141				.map_err(|e| Error::Transport(Arc::new(e)))?
142				.ok_or(Error::Decode(DecodeError::Short))?;
143			size -= chunk.len();
144		}
145
146		Ok(())
147	}
148
149	/// Wait until the stream is closed, erroring if there are any additional bytes.
150	pub async fn closed(&mut self) -> Result<(), Error> {
151		if self.buffer.is_empty()
152			&& self
153				.stream
154				.read_buf(&mut self.buffer)
155				.await
156				.map_err(|e| Error::Transport(Arc::new(e)))?
157				.is_none()
158		{
159			return Ok(());
160		}
161
162		Err(DecodeError::ExpectedEnd.into())
163	}
164
165	/// Abort the stream with the given error.
166	pub fn abort(&mut self, err: &Error) {
167		self.stream.stop(err.to_code());
168	}
169
170	/// Cast the reader to a different version, used during version negotiation.
171	pub fn with_version<O>(self, version: O) -> Reader<S, O> {
172		Reader {
173			stream: self.stream,
174			buffer: self.buffer,
175			version,
176		}
177	}
178}