moq_lite/model/frame.rs

1use std::future::Future;
2
3use bytes::{Bytes, BytesMut};
4use tokio::sync::watch;
5
6use crate::{Error, Produce, Result};
7
/// Metadata describing a single frame.
///
/// The payload itself is transferred separately via a [`FrameProducer`] /
/// [`FrameConsumer`] pair; this type only carries the declared length.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Frame {
	/// Total length of the frame's payload, in bytes.
	pub size: u64,
}
13
14impl Frame {
15	pub fn produce(self) -> Produce<FrameProducer, FrameConsumer> {
16		let producer = FrameProducer::new(self);
17		let consumer = producer.consume();
18		Produce { producer, consumer }
19	}
20}
21
22impl From<usize> for Frame {
23	fn from(size: usize) -> Self {
24		Self { size: size as u64 }
25	}
26}
27
28impl From<u64> for Frame {
29	fn from(size: u64) -> Self {
30		Self { size }
31	}
32}
33
34impl From<u32> for Frame {
35	fn from(size: u32) -> Self {
36		Self { size: size as u64 }
37	}
38}
39
40impl From<u16> for Frame {
41	fn from(size: u16) -> Self {
42		Self { size: size as u64 }
43	}
44}
45
46#[derive(Default)]
47struct FrameState {
48	// The chunks that has been written thus far
49	chunks: Vec<Bytes>,
50
51	// Set when the writer or all readers are dropped.
52	closed: Option<Result<()>>,
53}
54
55/// Used to write a frame's worth of data in chunks.
56#[derive(Clone)]
57pub struct FrameProducer {
58	// Immutable stream state.
59	pub info: Frame,
60
61	// Mutable stream state.
62	state: watch::Sender<FrameState>,
63
64	// Sanity check to ensure we don't write more than the frame size.
65	written: usize,
66}
67
68impl FrameProducer {
69	fn new(info: Frame) -> Self {
70		Self {
71			info,
72			state: Default::default(),
73			written: 0,
74		}
75	}
76
77	pub fn write_chunk<B: Into<Bytes>>(&mut self, chunk: B) {
78		let chunk = chunk.into();
79		self.written += chunk.len();
80		assert!(self.written <= self.info.size as usize);
81
82		self.state.send_modify(|state| {
83			assert!(state.closed.is_none());
84			state.chunks.push(chunk);
85		});
86	}
87
88	pub fn close(self) {
89		assert!(self.written == self.info.size as usize);
90		self.state.send_modify(|state| state.closed = Some(Ok(())));
91	}
92
93	pub fn abort(self, err: Error) {
94		self.state.send_modify(|state| state.closed = Some(Err(err)));
95	}
96
97	/// Create a new consumer for the frame.
98	pub fn consume(&self) -> FrameConsumer {
99		FrameConsumer {
100			info: self.info.clone(),
101			state: self.state.subscribe(),
102			index: 0,
103		}
104	}
105
106	// Returns a Future so &self is not borrowed during the future.
107	pub fn unused(&self) -> impl Future<Output = ()> {
108		let state = self.state.clone();
109		async move {
110			state.closed().await;
111		}
112	}
113}
114
115impl From<Frame> for FrameProducer {
116	fn from(info: Frame) -> Self {
117		FrameProducer::new(info)
118	}
119}
120
121/// Used to consume a frame's worth of data in chunks.
122#[derive(Clone)]
123pub struct FrameConsumer {
124	// Immutable stream state.
125	pub info: Frame,
126
127	// Modify the stream state.
128	state: watch::Receiver<FrameState>,
129
130	// The number of frames we've read.
131	// NOTE: Cloned readers inherit this offset, but then run in parallel.
132	index: usize,
133}
134
135impl FrameConsumer {
136	// Return the next chunk.
137	pub async fn read_chunk(&mut self) -> Result<Option<Bytes>> {
138		loop {
139			{
140				let state = self.state.borrow_and_update();
141
142				if let Some(chunk) = state.chunks.get(self.index).cloned() {
143					self.index += 1;
144					return Ok(Some(chunk));
145				}
146
147				match &state.closed {
148					Some(Ok(_)) => return Ok(None),
149					Some(Err(err)) => return Err(err.clone()),
150					_ => {}
151				}
152			}
153
154			if self.state.changed().await.is_err() {
155				return Err(Error::Cancel);
156			}
157		}
158	}
159
160	// Return all of the remaining chunks concatenated together.
161	pub async fn read_all(&mut self) -> Result<Bytes> {
162		// Wait until the writer is done before even attempting to read.
163		// That way this function can be cancelled without consuming half of the frame.
164		let state = match self.state.wait_for(|state| state.closed.is_some()).await {
165			Ok(state) => {
166				if let Some(Err(err)) = &state.closed {
167					return Err(err.clone());
168				}
169				state
170			}
171			Err(_) => return Err(Error::Cancel),
172		};
173
174		// Get all of the remaining chunks.
175		let chunks = &state.chunks[self.index..];
176		self.index = state.chunks.len();
177
178		// We know the final size so we can allocate the buffer upfront.
179		let size = chunks.iter().map(Bytes::len).sum();
180
181		// We know the final size so we can allocate the buffer upfront.
182		let mut buf = BytesMut::with_capacity(size);
183
184		// Copy the chunks into the buffer.
185		for chunk in chunks {
186			buf.extend_from_slice(chunk);
187		}
188
189		Ok(buf.freeze())
190	}
191}