Skip to main content

moq_lite/model/
frame.rs

1use std::task::{Poll, ready};
2
3use bytes::{Bytes, BytesMut};
4
5use crate::{Error, Result};
6
7/// A chunk of data with an upfront size.
8///
9/// Note that this is just the header.
10/// You use [FrameProducer] and [FrameConsumer] to deal with the frame payload, potentially chunked.
11#[derive(Clone, Debug)]
12#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
13pub struct Frame {
14	pub size: u64,
15}
16
17impl Frame {
18	/// Create a new producer for the frame.
19	pub fn produce(self) -> FrameProducer {
20		FrameProducer::new(self)
21	}
22}
23
24impl From<usize> for Frame {
25	fn from(size: usize) -> Self {
26		Self { size: size as u64 }
27	}
28}
29
30impl From<u64> for Frame {
31	fn from(size: u64) -> Self {
32		Self { size }
33	}
34}
35
36impl From<u32> for Frame {
37	fn from(size: u32) -> Self {
38		Self { size: size as u64 }
39	}
40}
41
42impl From<u16> for Frame {
43	fn from(size: u16) -> Self {
44		Self { size: size as u64 }
45	}
46}
47
48#[derive(Default, Debug)]
49struct FrameState {
50	// The chunks that have been written thus far
51	chunks: Vec<Bytes>,
52
53	// The number of bytes remaining to be written.
54	remaining: u64,
55
56	// The error that caused the frame to be aborted, if any.
57	abort: Option<Error>,
58}
59
60impl FrameState {
61	fn write_chunk(&mut self, chunk: Bytes) -> Result<()> {
62		if let Some(err) = &self.abort {
63			return Err(err.clone());
64		}
65
66		self.remaining = self.remaining.checked_sub(chunk.len() as u64).ok_or(Error::WrongSize)?;
67		self.chunks.push(chunk);
68		Ok(())
69	}
70
71	fn poll_read_chunk(&self, index: usize) -> Poll<Result<Option<Bytes>>> {
72		if let Some(chunk) = self.chunks.get(index).cloned() {
73			Poll::Ready(Ok(Some(chunk)))
74		} else if self.remaining == 0 {
75			Poll::Ready(Ok(None))
76		} else if let Some(err) = &self.abort {
77			Poll::Ready(Err(err.clone()))
78		} else {
79			Poll::Pending
80		}
81	}
82
83	fn poll_read_chunks(&self, index: usize) -> Poll<Result<Vec<Bytes>>> {
84		if index >= self.chunks.len() && self.remaining == 0 {
85			Poll::Ready(Ok(Vec::new()))
86		} else if self.remaining == 0 {
87			Poll::Ready(Ok(self.chunks[index..].to_vec()))
88		} else if let Some(err) = &self.abort {
89			Poll::Ready(Err(err.clone()))
90		} else {
91			Poll::Pending
92		}
93	}
94
95	fn poll_read_all(&self, index: usize) -> Poll<Result<Bytes>> {
96		let chunks = ready!(self.poll_read_all_chunks(index)?);
97
98		Poll::Ready(Ok(match chunks.len() {
99			0 => Bytes::new(),
100			1 => chunks[0].clone(),
101			_ => {
102				let size = chunks.iter().map(Bytes::len).sum();
103				let mut buf = BytesMut::with_capacity(size);
104				for chunk in chunks {
105					buf.extend_from_slice(chunk.as_ref());
106				}
107				buf.freeze()
108			}
109		}))
110	}
111
112	fn poll_read_all_chunks(&self, index: usize) -> Poll<Result<&[Bytes]>> {
113		if self.remaining > 0 {
114			Poll::Pending
115		} else if let Some(err) = &self.abort {
116			Poll::Ready(Err(err.clone()))
117		} else if index < self.chunks.len() {
118			Poll::Ready(Ok(&self.chunks[index..]))
119		} else {
120			Poll::Ready(Ok(&[]))
121		}
122	}
123}
124
125/// Writes a frame's payload in one or more chunks.
126///
127/// The total bytes written must exactly match [Frame::size].
128/// Call [Self::finish] after writing all bytes to verify correctness.
129pub struct FrameProducer {
130	/// The frame header containing the expected size.
131	pub info: Frame,
132
133	// Mutable stream state.
134	state: conducer::Producer<FrameState>,
135}
136
137impl FrameProducer {
138	/// Create a new frame producer for the given frame header.
139	pub fn new(info: Frame) -> Self {
140		let state = FrameState {
141			chunks: Vec::new(),
142			remaining: info.size,
143			abort: None,
144		};
145		Self {
146			info,
147			state: conducer::Producer::new(state),
148		}
149	}
150
151	/// Write a chunk of data to the frame.
152	///
153	/// Returns [Error::WrongSize] if the total bytes written would exceed [Frame::size].
154	pub fn write<B: Into<Bytes>>(&mut self, chunk: B) -> Result<()> {
155		let chunk = chunk.into();
156		let mut state = self.modify()?;
157		state.write_chunk(chunk)
158	}
159
160	/// Write a chunk of data to the frame.
161	///
162	/// Deprecated: use [`Self::write`] instead.
163	#[deprecated(note = "use write(chunk) instead")]
164	pub fn write_chunk<B: Into<Bytes>>(&mut self, chunk: B) -> Result<()> {
165		self.write(chunk)
166	}
167
168	/// Verify that all bytes have been written.
169	///
170	/// Returns [Error::WrongSize] if the bytes written don't match [Frame::size].
171	pub fn finish(&mut self) -> Result<()> {
172		let state = self.modify()?;
173		if state.remaining != 0 {
174			return Err(Error::WrongSize);
175		}
176		Ok(())
177	}
178
179	/// Abort the frame with the given error.
180	pub fn abort(&mut self, err: Error) -> Result<()> {
181		let mut guard = self.modify()?;
182		guard.abort = Some(err);
183		guard.close();
184		Ok(())
185	}
186
187	/// Create a new consumer for the frame.
188	pub fn consume(&self) -> FrameConsumer {
189		FrameConsumer {
190			info: self.info.clone(),
191			state: self.state.consume(),
192			index: 0,
193		}
194	}
195
196	/// Block until there are no active consumers.
197	pub async fn unused(&self) -> Result<()> {
198		self.state
199			.unused()
200			.await
201			.map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
202	}
203
204	fn modify(&mut self) -> Result<conducer::Mut<'_, FrameState>> {
205		self.state
206			.write()
207			.map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
208	}
209}
210
211impl Clone for FrameProducer {
212	fn clone(&self) -> Self {
213		Self {
214			info: self.info.clone(),
215			state: self.state.clone(),
216		}
217	}
218}
219
220impl From<Frame> for FrameProducer {
221	fn from(info: Frame) -> Self {
222		FrameProducer::new(info)
223	}
224}
225
226/// Used to consume a frame's worth of data in chunks.
227#[derive(Clone)]
228pub struct FrameConsumer {
229	// Immutable stream state.
230	pub info: Frame,
231
232	// Shared state with the producer.
233	state: conducer::Consumer<FrameState>,
234
235	// The number of chunks we've read.
236	// NOTE: Cloned readers inherit this offset, but then run in parallel.
237	index: usize,
238}
239
240impl FrameConsumer {
241	// A helper to automatically apply Dropped if the state is closed without an error.
242	fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R>>
243	where
244		F: Fn(&conducer::Ref<'_, FrameState>) -> Poll<Result<R>>,
245	{
246		Poll::Ready(match ready!(self.state.poll(waiter, f)) {
247			Ok(res) => res,
248			// We try to clone abort just in case the function forgot to check for terminal state.
249			Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
250		})
251	}
252
253	/// Poll for all remaining data without blocking.
254	pub fn poll_read_all(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Bytes>> {
255		let data = ready!(self.poll(waiter, |state| state.poll_read_all(self.index))?);
256		self.index = usize::MAX;
257		Poll::Ready(Ok(data))
258	}
259
260	/// Return all of the remaining chunks concatenated together.
261	pub async fn read_all(&mut self) -> Result<Bytes> {
262		conducer::wait(|waiter| self.poll_read_all(waiter)).await
263	}
264
265	/// Return all of the remaining chunks of the frame, without blocking.
266	pub fn poll_read_all_chunks(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Vec<Bytes>>> {
267		let chunks = ready!(self.poll(waiter, |state| {
268			// This is more complicated because we need to make a copy of the chunks while holding the lock..
269			state
270				.poll_read_all_chunks(self.index)
271				.map(|res| res.map(|chunks| chunks.to_vec()))
272		})?);
273		self.index += chunks.len();
274
275		Poll::Ready(Ok(chunks))
276	}
277
278	/// Poll for the next chunk, without blocking.
279	pub fn poll_read_chunk(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<Bytes>>> {
280		let Some(chunk) = ready!(self.poll(waiter, |state| state.poll_read_chunk(self.index))?) else {
281			return Poll::Ready(Ok(None));
282		};
283		self.index += 1;
284		Poll::Ready(Ok(Some(chunk)))
285	}
286
287	/// Return the next chunk.
288	pub async fn read_chunk(&mut self) -> Result<Option<Bytes>> {
289		conducer::wait(|waiter| self.poll_read_chunk(waiter)).await
290	}
291
292	/// Poll for the next chunks, without blocking.
293	pub fn poll_read_chunks(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Vec<Bytes>>> {
294		let chunks = ready!(self.poll(waiter, |state| state.poll_read_chunks(self.index))?);
295		self.index += chunks.len();
296		Poll::Ready(Ok(chunks))
297	}
298
299	/// Read all of the remaining chunks into a vector.
300	pub async fn read_chunks(&mut self) -> Result<Vec<Bytes>> {
301		conducer::wait(|waiter| self.poll_read_chunks(waiter)).await
302	}
303}
304
#[cfg(test)]
mod test {
	use super::*;
	use futures::FutureExt;

	#[test]
	fn single_chunk_roundtrip() {
		let mut producer = Frame { size: 5 }.produce();
		producer.write(Bytes::from_static(b"hello")).unwrap();
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let data = consumer.read_all().now_or_never().unwrap().unwrap();
		assert_eq!(data, Bytes::from_static(b"hello"));
	}

	#[test]
	fn multi_chunk_read_all() {
		let mut producer = Frame { size: 10 }.produce();
		producer.write(Bytes::from_static(b"hello")).unwrap();
		producer.write(Bytes::from_static(b"world")).unwrap();
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let data = consumer.read_all().now_or_never().unwrap().unwrap();
		assert_eq!(data, Bytes::from_static(b"helloworld"));
	}

	#[test]
	fn read_chunk_sequential() {
		let mut producer = Frame { size: 10 }.produce();
		producer.write(Bytes::from_static(b"hello")).unwrap();
		producer.write(Bytes::from_static(b"world")).unwrap();
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let c1 = consumer.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(c1, Some(Bytes::from_static(b"hello")));
		let c2 = consumer.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(c2, Some(Bytes::from_static(b"world")));
		let c3 = consumer.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(c3, None);
	}

	#[test]
	fn read_all_chunks() {
		let mut producer = Frame { size: 10 }.produce();
		producer.write(Bytes::from_static(b"hello")).unwrap();
		producer.write(Bytes::from_static(b"world")).unwrap();
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let chunks = consumer.read_chunks().now_or_never().unwrap().unwrap();
		assert_eq!(chunks.len(), 2);
		assert_eq!(chunks[0], Bytes::from_static(b"hello"));
		assert_eq!(chunks[1], Bytes::from_static(b"world"));
	}

	#[test]
	fn finish_checks_remaining() {
		let mut producer = Frame { size: 5 }.produce();
		producer.write(Bytes::from_static(b"hi")).unwrap();
		let err = producer.finish().unwrap_err();
		assert!(matches!(err, Error::WrongSize));
	}

	#[test]
	fn write_too_many_bytes() {
		let mut producer = Frame { size: 3 }.produce();
		let err = producer.write(Bytes::from_static(b"toolong")).unwrap_err();
		assert!(matches!(err, Error::WrongSize));
	}

	#[test]
	fn oversized_write_leaves_frame_writable() {
		// A rejected write must not consume any of the remaining size.
		let mut producer = Frame { size: 3 }.produce();
		assert!(producer.write(Bytes::from_static(b"toolong")).is_err());
		producer.write(Bytes::from_static(b"abc")).unwrap();
		producer.finish().unwrap();
	}

	#[test]
	fn abort_propagates() {
		let mut producer = Frame { size: 5 }.produce();
		let mut consumer = producer.consume();
		producer.abort(Error::Cancel).unwrap();

		let err = consumer.read_all().now_or_never().unwrap().unwrap_err();
		assert!(matches!(err, Error::Cancel));
	}

	#[test]
	fn abort_after_partial_read() {
		let mut producer = Frame { size: 5 }.produce();
		let mut consumer = producer.consume();
		producer.write(Bytes::from_static(b"hi")).unwrap();

		let chunk = consumer.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(chunk, Some(Bytes::from_static(b"hi")));

		producer.abort(Error::Cancel).unwrap();
		let err = consumer.read_chunk().now_or_never().unwrap().unwrap_err();
		assert!(matches!(err, Error::Cancel));
	}

	#[test]
	fn write_after_abort_fails() {
		let mut producer = Frame { size: 5 }.produce();
		producer.abort(Error::Cancel).unwrap();

		let err = producer.write(Bytes::from_static(b"hi")).unwrap_err();
		assert!(matches!(err, Error::Cancel));
	}

	#[test]
	fn empty_frame() {
		let mut producer = Frame { size: 0 }.produce();
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let data = consumer.read_all().now_or_never().unwrap().unwrap();
		assert_eq!(data, Bytes::new());
	}

	#[test]
	fn read_all_after_read_chunk() {
		let mut producer = Frame { size: 10 }.produce();
		producer.write(Bytes::from_static(b"hello")).unwrap();
		producer.write(Bytes::from_static(b"world")).unwrap();
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let first = consumer.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(first, Some(Bytes::from_static(b"hello")));

		// read_all only returns what the cursor has not handed out yet.
		let rest = consumer.read_all().now_or_never().unwrap().unwrap();
		assert_eq!(rest, Bytes::from_static(b"world"));
	}

	#[test]
	fn cloned_consumer_keeps_offset() {
		let mut producer = Frame { size: 10 }.produce();
		producer.write(Bytes::from_static(b"hello")).unwrap();
		producer.write(Bytes::from_static(b"world")).unwrap();
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		consumer.read_chunk().now_or_never().unwrap().unwrap();

		let mut clone = consumer.clone();
		let a = consumer.read_chunk().now_or_never().unwrap().unwrap();
		let b = clone.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(a, Some(Bytes::from_static(b"world")));
		assert_eq!(b, Some(Bytes::from_static(b"world")));
	}

	#[test]
	fn frame_from_integers() {
		assert_eq!(Frame::from(7u16).size, 7);
		assert_eq!(Frame::from(7u32).size, 7);
		assert_eq!(Frame::from(7u64).size, 7);
		assert_eq!(Frame::from(7usize).size, 7);
	}

	// Only polls with now_or_never, so no async runtime is needed.
	#[test]
	fn pending_then_ready() {
		let mut producer = Frame { size: 5 }.produce();
		let mut consumer = producer.consume();

		// Consumer blocks because no data yet.
		assert!(consumer.read_all().now_or_never().is_none());

		producer.write(Bytes::from_static(b"hello")).unwrap();
		producer.finish().unwrap();

		let data = consumer.read_all().now_or_never().unwrap().unwrap();
		assert_eq!(data, Bytes::from_static(b"hello"));
	}
}