Skip to main content

moq_lite/coding/
writer.rs

1use std::fmt::Debug;
2
3use crate::{Error, coding::*, ietf};
4
5/// A wrapper around a [web_transport_trait::SendStream] that will reset on Drop.
6pub struct Writer<S: web_transport_trait::SendStream, V> {
7	stream: Option<S>,
8	buffer: bytes::BytesMut,
9	version: V,
10}
11
12impl<S: web_transport_trait::SendStream, V> Writer<S, V> {
13	/// Create a new writer for the given stream and version.
14	pub fn new(stream: S, version: V) -> Self {
15		Self {
16			stream: Some(stream),
17			buffer: Default::default(),
18			version,
19		}
20	}
21
22	/// Encode the given message to the stream.
23	pub async fn encode<T: Encode<V> + Debug>(&mut self, msg: &T) -> Result<(), Error>
24	where
25		V: Clone,
26	{
27		self.buffer.clear();
28		msg.encode(&mut self.buffer, self.version.clone())?;
29
30		while !self.buffer.is_empty() {
31			self.stream
32				.as_mut()
33				.unwrap()
34				.write_buf(&mut self.buffer)
35				.await
36				.map_err(Error::from_transport)?;
37		}
38
39		Ok(())
40	}
41
42	// Not public to avoid accidental partial writes.
43	async fn write<Buf: bytes::Buf + Send>(&mut self, buf: &mut Buf) -> Result<usize, Error> {
44		self.stream
45			.as_mut()
46			.unwrap()
47			.write_buf(buf)
48			.await
49			.map_err(Error::from_transport)
50	}
51
52	/// Write the entire `Buf` to the stream.
53	///
54	/// NOTE: This can avoid performing a copy when using `Bytes`.
55	pub async fn write_all<Buf: bytes::Buf + Send>(&mut self, buf: &mut Buf) -> Result<(), Error> {
56		while buf.has_remaining() {
57			self.write(buf).await?;
58		}
59		Ok(())
60	}
61
62	/// Mark the stream as finished.
63	pub fn finish(&mut self) -> Result<(), Error> {
64		self.stream.as_mut().unwrap().finish().map_err(Error::from_transport)
65	}
66
67	/// Abort the stream with the given error.
68	pub fn abort(&mut self, err: &Error) {
69		self.stream.as_mut().unwrap().reset(err.to_code());
70	}
71
72	/// Wait for the stream to be closed, or the [Self::finish] to be acknowledged by the peer.
73	pub async fn closed(&mut self) -> Result<(), Error> {
74		self.stream
75			.as_mut()
76			.unwrap()
77			.closed()
78			.await
79			.map_err(Error::from_transport)?;
80		Ok(())
81	}
82
83	/// Set the priority of the stream.
84	pub fn set_priority(&mut self, priority: u8) {
85		self.stream.as_mut().unwrap().set_priority(priority);
86	}
87
88	/// Cast the writer to a different version, used during version negotiation.
89	pub fn with_version<O>(mut self, version: O) -> Writer<S, O> {
90		Writer {
91			// We need to use an Option so Drop doesn't reset the stream.
92			stream: self.stream.take(),
93			buffer: std::mem::take(&mut self.buffer),
94			version,
95		}
96	}
97}
98
99impl<S: web_transport_trait::SendStream> Writer<S, ietf::Version> {
100	/// Encode an IETF `Message` to the stream, writing `[type_id][size][body]`.
101	pub async fn encode_message<T: ietf::Message>(&mut self, msg: &T) -> Result<(), Error> {
102		self.encode(&T::ID).await?;
103		self.encode(msg).await
104	}
105}
106
107impl<S: web_transport_trait::SendStream, V> Drop for Writer<S, V> {
108	fn drop(&mut self) {
109		if let Some(mut stream) = self.stream.take() {
110			// Unlike the Quinn default, we abort the stream on drop.
111			stream.reset(Error::Cancel.to_code());
112		}
113	}
114}