// moq_lite/coding/reader.rs

1use std::{cmp, fmt::Debug, io, sync::Arc};
2
3use bytes::{Buf, Bytes, BytesMut};
4
5use crate::{coding::*, Error};
6
7pub struct Reader<S: web_transport_trait::RecvStream, V> {
8	stream: S,
9	buffer: BytesMut,
10	version: V,
11}
12
13impl<S: web_transport_trait::RecvStream, V> Reader<S, V> {
14	pub fn new(stream: S, version: V) -> Self {
15		Self {
16			stream,
17			buffer: Default::default(),
18			version,
19		}
20	}
21
22	pub async fn decode<T: Decode<V> + Debug>(&mut self) -> Result<T, Error>
23	where
24		V: Clone,
25	{
26		loop {
27			let mut cursor = io::Cursor::new(&self.buffer);
28			match T::decode(&mut cursor, self.version.clone()) {
29				Ok(msg) => {
30					self.buffer.advance(cursor.position() as usize);
31					return Ok(msg);
32				}
33				Err(DecodeError::Short) => {
34					// Try to read more data
35					if self
36						.stream
37						.read_buf(&mut self.buffer)
38						.await
39						.map_err(|e| Error::Transport(Arc::new(e)))?
40						.is_none()
41					{
42						// Stream closed while we still need more data
43						return Err(Error::Decode(DecodeError::Short));
44					}
45				}
46				Err(e) => return Err(Error::Decode(e)),
47			}
48		}
49	}
50
51	// Decode optional messages at the end of a stream
52	pub async fn decode_maybe<T: Decode<V> + Debug>(&mut self) -> Result<Option<T>, Error>
53	where
54		V: Clone,
55	{
56		match self.closed().await {
57			Ok(()) => Ok(None),
58			Err(Error::Decode(DecodeError::ExpectedEnd)) => Ok(Some(self.decode().await?)),
59			Err(e) => Err(e),
60		}
61	}
62
63	pub async fn decode_peek<T: Decode<V> + Debug>(&mut self) -> Result<T, Error>
64	where
65		V: Clone,
66	{
67		loop {
68			let mut cursor = io::Cursor::new(&self.buffer);
69			match T::decode(&mut cursor, self.version.clone()) {
70				Ok(msg) => return Ok(msg),
71				Err(DecodeError::Short) => {
72					// Try to read more data
73					if self
74						.stream
75						.read_buf(&mut self.buffer)
76						.await
77						.map_err(|e| Error::Transport(Arc::new(e)))?
78						.is_none()
79					{
80						// Stream closed while we still need more data
81						return Err(Error::Decode(DecodeError::Short));
82					}
83				}
84				Err(e) => return Err(Error::Decode(e)),
85			}
86		}
87	}
88
89	// Returns a non-zero chunk of data, or None if the stream is closed
90	pub async fn read(&mut self, max: usize) -> Result<Option<Bytes>, Error> {
91		if !self.buffer.is_empty() {
92			let size = cmp::min(max, self.buffer.len());
93			let data = self.buffer.split_to(size).freeze();
94			return Ok(Some(data));
95		}
96
97		self.stream
98			.read_chunk(max)
99			.await
100			.map_err(|e| Error::Transport(Arc::new(e)))
101	}
102
103	pub async fn read_exact(&mut self, size: usize) -> Result<Bytes, Error> {
104		// An optimization to avoid a copy if we have enough data in the buffer
105		if self.buffer.len() >= size {
106			return Ok(self.buffer.split_to(size).freeze());
107		}
108
109		let data = BytesMut::with_capacity(size.min(u16::MAX as usize));
110		let mut buf = data.limit(size);
111
112		let size = cmp::min(buf.remaining_mut(), self.buffer.len());
113		let data = self.buffer.split_to(size);
114		buf.put(data);
115
116		while buf.has_remaining_mut() {
117			self.stream
118				.read_buf(&mut buf)
119				.await
120				.map_err(|e| Error::Transport(Arc::new(e)))?;
121		}
122
123		Ok(buf.into_inner().freeze())
124	}
125
126	pub async fn skip(&mut self, mut size: usize) -> Result<(), Error> {
127		let buffered = self.buffer.len().min(size);
128		self.buffer.advance(buffered);
129		size -= buffered;
130
131		while size > 0 {
132			let chunk = self
133				.stream
134				.read_chunk(size)
135				.await
136				.map_err(|e| Error::Transport(Arc::new(e)))?
137				.ok_or(Error::Decode(DecodeError::Short))?;
138			size -= chunk.len();
139		}
140
141		Ok(())
142	}
143
144	/// Wait until the stream is closed, erroring if there are any additional bytes.
145	pub async fn closed(&mut self) -> Result<(), Error> {
146		if self.buffer.is_empty()
147			&& self
148				.stream
149				.read_buf(&mut self.buffer)
150				.await
151				.map_err(|e| Error::Transport(Arc::new(e)))?
152				.is_none()
153		{
154			return Ok(());
155		}
156
157		Err(DecodeError::ExpectedEnd.into())
158	}
159
160	pub fn abort(&mut self, err: &Error) {
161		self.stream.stop(err.to_code());
162	}
163
164	pub fn with_version<O>(self, version: O) -> Reader<S, O> {
165		Reader {
166			stream: self.stream,
167			buffer: self.buffer,
168			version,
169		}
170	}
171}