Skip to main content

moq_lite/model/
frame.rs

1use std::task::Poll;
2
3use bytes::{Bytes, BytesMut};
4
5use crate::{Error, Result};
6
7use super::state::{Consumer, Producer};
8use super::waiter::waiter_fn;
9
/// Header describing a frame: a payload whose total size is declared upfront.
///
/// This type carries only the declared size, not the payload itself.
/// The payload (which may arrive in several chunks) is written through a
/// [FrameProducer] and read through a [FrameConsumer].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Frame {
	/// Total payload size in bytes.
	pub size: u64,
}
19
20impl Frame {
21	/// Create a new producer for the frame.
22	pub fn produce(self) -> FrameProducer {
23		FrameProducer::new(self)
24	}
25}
26
27impl From<usize> for Frame {
28	fn from(size: usize) -> Self {
29		Self { size: size as u64 }
30	}
31}
32
33impl From<u64> for Frame {
34	fn from(size: u64) -> Self {
35		Self { size }
36	}
37}
38
39impl From<u32> for Frame {
40	fn from(size: u32) -> Self {
41		Self { size: size as u64 }
42	}
43}
44
45impl From<u16> for Frame {
46	fn from(size: u16) -> Self {
47		Self { size: size as u64 }
48	}
49}
50
51#[derive(Default, Debug)]
52struct FrameState {
53	// The chunks that have been written thus far
54	chunks: Vec<Bytes>,
55
56	// The number of bytes remaining to be written.
57	remaining: u64,
58}
59
60impl FrameState {
61	fn write_chunk(&mut self, chunk: Bytes) -> Result<()> {
62		self.remaining = self.remaining.checked_sub(chunk.len() as u64).ok_or(Error::WrongSize)?;
63		self.chunks.push(chunk);
64		Ok(())
65	}
66
67	fn poll_read_chunk(&self, index: usize) -> Poll<Option<Bytes>> {
68		if let Some(chunk) = self.chunks.get(index).cloned() {
69			Poll::Ready(Some(chunk))
70		} else if self.remaining == 0 {
71			Poll::Ready(None)
72		} else {
73			Poll::Pending
74		}
75	}
76
77	fn poll_read_chunks(&self, index: usize) -> Poll<Vec<Bytes>> {
78		if index >= self.chunks.len() && self.remaining == 0 {
79			return Poll::Ready(Vec::new());
80		}
81		if self.remaining == 0 {
82			Poll::Ready(self.chunks[index..].to_vec())
83		} else {
84			Poll::Pending
85		}
86	}
87
88	fn poll_read_all(&self, index: usize) -> Poll<Bytes> {
89		if self.remaining > 0 {
90			return Poll::Pending;
91		}
92
93		if index >= self.chunks.len() {
94			return Poll::Ready(Bytes::new());
95		}
96
97		let chunks = &self.chunks[index..];
98		let size = chunks.iter().map(Bytes::len).sum();
99		let mut buf = BytesMut::with_capacity(size);
100		for chunk in chunks {
101			buf.extend_from_slice(chunk);
102		}
103		Poll::Ready(buf.freeze())
104	}
105}
106
107/// Writes a frame's payload in one or more chunks.
108///
109/// The total bytes written must exactly match [Frame::size].
110/// Call [Self::finish] after writing all bytes to verify correctness.
111pub struct FrameProducer {
112	/// The frame header containing the expected size.
113	pub info: Frame,
114
115	// Mutable stream state.
116	state: Producer<FrameState>,
117}
118
119impl FrameProducer {
120	/// Create a new frame producer for the given frame header.
121	pub fn new(info: Frame) -> Self {
122		let state = FrameState {
123			chunks: Vec::new(),
124			remaining: info.size,
125		};
126		Self {
127			info,
128			state: Producer::new(state),
129		}
130	}
131
132	/// Write a chunk of data to the frame.
133	///
134	/// Returns [Error::WrongSize] if the total bytes written would exceed [Frame::size].
135	pub fn write<B: Into<Bytes>>(&mut self, chunk: B) -> Result<()> {
136		let chunk = chunk.into();
137		let mut state = self.state.modify()?;
138		state.write_chunk(chunk)
139	}
140
141	/// Write a chunk of data to the frame.
142	///
143	/// Deprecated: use [`Self::write`] instead.
144	#[deprecated(note = "use write(chunk) instead")]
145	pub fn write_chunk<B: Into<Bytes>>(&mut self, chunk: B) -> Result<()> {
146		self.write(chunk)
147	}
148
149	/// Verify that all bytes have been written.
150	///
151	/// Returns [Error::WrongSize] if the bytes written don't match [Frame::size].
152	pub fn finish(&mut self) -> Result<()> {
153		let state = self.state.modify()?;
154		if state.remaining != 0 {
155			return Err(Error::WrongSize);
156		}
157		Ok(())
158	}
159
160	/// Abort the frame with the given error.
161	pub fn abort(&mut self, err: Error) -> Result<()> {
162		self.state.abort(err)
163	}
164
165	/// Create a new consumer for the frame.
166	pub fn consume(&self) -> FrameConsumer {
167		FrameConsumer {
168			info: self.info.clone(),
169			state: self.state.consume(),
170			index: 0,
171		}
172	}
173
174	/// Block until there are no active consumers.
175	pub async fn unused(&self) -> Result<()> {
176		self.state.unused().await
177	}
178}
179
180impl Clone for FrameProducer {
181	fn clone(&self) -> Self {
182		Self {
183			info: self.info.clone(),
184			state: self.state.clone(),
185		}
186	}
187}
188
189impl From<Frame> for FrameProducer {
190	fn from(info: Frame) -> Self {
191		FrameProducer::new(info)
192	}
193}
194
195/// Used to consume a frame's worth of data in chunks.
196#[derive(Clone)]
197pub struct FrameConsumer {
198	// Immutable stream state.
199	pub info: Frame,
200
201	// Shared state with the producer.
202	state: Consumer<FrameState>,
203
204	// The number of chunks we've read.
205	// NOTE: Cloned readers inherit this offset, but then run in parallel.
206	index: usize,
207}
208
209impl FrameConsumer {
210	/// Return the next chunk.
211	pub async fn read_chunk(&mut self) -> Result<Option<Bytes>> {
212		let index = self.index;
213		let res = waiter_fn(|waiter| self.state.poll(waiter, |state| state.poll_read_chunk(index))).await?;
214		if res.is_some() {
215			self.index += 1;
216		}
217		Ok(res)
218	}
219
220	/// Read all of the remaining chunks into a vector.
221	/// Cancel-safe: returns all or nothing.
222	pub async fn read_chunks(&mut self) -> Result<Vec<Bytes>> {
223		let index = self.index;
224		let chunks = waiter_fn(|waiter| self.state.poll(waiter, |state| state.poll_read_chunks(index))).await?;
225		self.index += chunks.len();
226		Ok(chunks)
227	}
228
229	/// Return all of the remaining chunks concatenated together.
230	/// Cancel-safe: returns all or nothing.
231	pub async fn read_all(&mut self) -> Result<Bytes> {
232		let index = self.index;
233		let data = waiter_fn(|waiter| self.state.poll(waiter, |state| state.poll_read_all(index))).await?;
234		self.index = usize::MAX; // consumed everything
235		Ok(data)
236	}
237}