moq_lite/model/
frame.rs

1use std::future::Future;
2
3use bytes::{Bytes, BytesMut};
4use tokio::sync::watch;
5
6use crate::{Error, Result};
7
/// Describes a single frame: currently just the total length of its payload.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Frame {
	/// Total number of payload bytes that will be written for this frame.
	pub size: u64,
}
13
14impl Frame {
15	pub fn produce(self) -> FrameProducer {
16		FrameProducer::new(self)
17	}
18}
19
20impl From<usize> for Frame {
21	fn from(size: usize) -> Self {
22		Self { size: size as u64 }
23	}
24}
25
26impl From<u64> for Frame {
27	fn from(size: u64) -> Self {
28		Self { size }
29	}
30}
31
32impl From<u32> for Frame {
33	fn from(size: u32) -> Self {
34		Self { size: size as u64 }
35	}
36}
37
38impl From<u16> for Frame {
39	fn from(size: u16) -> Self {
40		Self { size: size as u64 }
41	}
42}
43
44#[derive(Default)]
45struct FrameState {
46	// The chunks that has been written thus far
47	chunks: Vec<Bytes>,
48
49	// Set when the writer or all readers are dropped.
50	closed: Option<Result<()>>,
51}
52
53/// Used to write a frame's worth of data in chunks.
54#[derive(Clone)]
55pub struct FrameProducer {
56	// Immutable stream state.
57	pub info: Frame,
58
59	// Mutable stream state.
60	state: watch::Sender<FrameState>,
61
62	// Sanity check to ensure we don't write more than the frame size.
63	written: usize,
64}
65
66impl FrameProducer {
67	pub fn new(info: Frame) -> Self {
68		Self {
69			info,
70			state: Default::default(),
71			written: 0,
72		}
73	}
74
75	pub fn write<B: Into<Bytes>>(&mut self, chunk: B) {
76		let chunk = chunk.into();
77		self.written += chunk.len();
78		assert!(self.written <= self.info.size as usize);
79
80		self.state.send_modify(|state| {
81			assert!(state.closed.is_none());
82			state.chunks.push(chunk);
83		});
84	}
85
86	pub fn finish(self) {
87		assert!(self.written == self.info.size as usize);
88		self.state.send_modify(|state| state.closed = Some(Ok(())));
89	}
90
91	pub fn abort(self, err: Error) {
92		self.state.send_modify(|state| state.closed = Some(Err(err)));
93	}
94
95	/// Create a new consumer for the frame.
96	pub fn consume(&self) -> FrameConsumer {
97		FrameConsumer {
98			info: self.info.clone(),
99			state: self.state.subscribe(),
100			index: 0,
101		}
102	}
103
104	// Returns a Future so &self is not borrowed during the future.
105	pub fn unused(&self) -> impl Future<Output = ()> {
106		let state = self.state.clone();
107		async move {
108			state.closed().await;
109		}
110	}
111}
112
113impl From<Frame> for FrameProducer {
114	fn from(info: Frame) -> Self {
115		FrameProducer::new(info)
116	}
117}
118
119/// Used to consume a frame's worth of data in chunks.
120#[derive(Clone)]
121pub struct FrameConsumer {
122	// Immutable stream state.
123	pub info: Frame,
124
125	// Modify the stream state.
126	state: watch::Receiver<FrameState>,
127
128	// The number of frames we've read.
129	// NOTE: Cloned readers inherit this offset, but then run in parallel.
130	index: usize,
131}
132
133impl FrameConsumer {
134	// Return the next chunk.
135	pub async fn read(&mut self) -> Result<Option<Bytes>> {
136		loop {
137			{
138				let state = self.state.borrow_and_update();
139
140				if let Some(chunk) = state.chunks.get(self.index).cloned() {
141					self.index += 1;
142					return Ok(Some(chunk));
143				}
144
145				match &state.closed {
146					Some(Ok(_)) => return Ok(None),
147					Some(Err(err)) => return Err(err.clone()),
148					_ => {}
149				}
150			}
151
152			if self.state.changed().await.is_err() {
153				return Err(Error::Cancel);
154			}
155		}
156	}
157
158	// Return all of the remaining chunks concatenated together.
159	pub async fn read_all(&mut self) -> Result<Bytes> {
160		// Wait until the writer is done before even attempting to read.
161		// That way this function can be cancelled without consuming half of the frame.
162		let state = match self.state.wait_for(|state| state.closed.is_some()).await {
163			Ok(state) => {
164				if let Some(Err(err)) = &state.closed {
165					return Err(err.clone());
166				}
167				state
168			}
169			Err(_) => return Err(Error::Cancel),
170		};
171
172		// Get all of the remaining chunks.
173		let chunks = &state.chunks[self.index..];
174		self.index = state.chunks.len();
175
176		// We know the final size so we can allocate the buffer upfront.
177		let size = chunks.iter().map(Bytes::len).sum();
178
179		// We know the final size so we can allocate the buffer upfront.
180		let mut buf = BytesMut::with_capacity(size);
181
182		// Copy the chunks into the buffer.
183		for chunk in chunks {
184			buf.extend_from_slice(chunk);
185		}
186
187		Ok(buf.freeze())
188	}
189}