Skip to main content

moq_lite/model/
group.rs

//! A group is a stream of frames, split into a [GroupProducer] and [GroupConsumer] handle.
//!
//! A [GroupProducer] writes an ordered stream of frames.
//! Frames can be written all at once, or in chunks.
//!
//! A [GroupConsumer] reads an ordered stream of frames.
//! The reader can be cloned (fanout): every clone receives its own copy of each frame.
//!
//! Readers see an [Error] if the group is aborted or if all writers are dropped before the
//! group is finished; writers can wait for all readers to go away with [GroupProducer::unused].
10use std::task::{Poll, ready};
11
12use bytes::Bytes;
13
14use crate::{Error, Result};
15
16use super::{Frame, FrameConsumer, FrameProducer};
17
/// Header identifying one group of frames.
///
/// Groups carry an explicit sequence number because they can arrive out of order.
/// If you only need the next number (previous + 1), use [crate::TrackProducer::append_group].
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Group {
	/// The group's sequence number.
	pub sequence: u64,
}
26
27impl Group {
28	pub fn produce(self) -> GroupProducer {
29		GroupProducer::new(self)
30	}
31}
32
33impl From<usize> for Group {
34	fn from(sequence: usize) -> Self {
35		Self {
36			sequence: sequence as u64,
37		}
38	}
39}
40
41impl From<u64> for Group {
42	fn from(sequence: u64) -> Self {
43		Self { sequence }
44	}
45}
46
47impl From<u32> for Group {
48	fn from(sequence: u32) -> Self {
49		Self {
50			sequence: sequence as u64,
51		}
52	}
53}
54
55impl From<u16> for Group {
56	fn from(sequence: u16) -> Self {
57		Self {
58			sequence: sequence as u64,
59		}
60	}
61}
62
63#[derive(Default)]
64struct GroupState {
65	// The frames that have been written thus far.
66	// We store producers so consumers can be created on-demand.
67	frames: Vec<FrameProducer>,
68
69	// Whether the group has been finalized (no more frames).
70	fin: bool,
71
72	// The error that caused the group to be aborted, if any.
73	abort: Option<Error>,
74}
75
76impl GroupState {
77	fn poll_get_frame(&self, index: usize) -> Poll<Result<Option<FrameConsumer>>> {
78		if let Some(frame) = self.frames.get(index) {
79			Poll::Ready(Ok(Some(frame.consume())))
80		} else if self.fin {
81			Poll::Ready(Ok(None))
82		} else if let Some(err) = &self.abort {
83			Poll::Ready(Err(err.clone()))
84		} else {
85			Poll::Pending
86		}
87	}
88
89	fn poll_finished(&self) -> Poll<Result<u64>> {
90		if self.fin {
91			Poll::Ready(Ok(self.frames.len() as u64))
92		} else if let Some(err) = &self.abort {
93			Poll::Ready(Err(err.clone()))
94		} else {
95			Poll::Pending
96		}
97	}
98}
99
100fn modify(state: &conducer::Producer<GroupState>) -> Result<conducer::Mut<'_, GroupState>> {
101	state.write().map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
102}
103
104/// Writes frames to a group in order.
105///
106/// Each group is delivered independently over a QUIC stream.
107/// Use [Self::write_frame] for simple single-buffer frames,
108/// or [Self::create_frame] for multi-chunk streaming writes.
109pub struct GroupProducer {
110	// Mutable stream state.
111	state: conducer::Producer<GroupState>,
112
113	/// The group header containing the sequence number.
114	pub info: Group,
115}
116
117impl GroupProducer {
118	/// Create a new group producer.
119	pub fn new(info: Group) -> Self {
120		Self {
121			info,
122			state: conducer::Producer::default(),
123		}
124	}
125
126	/// A helper method to write a frame from a single byte buffer.
127	///
128	/// If you want to write multiple chunks, use [Self::create_frame] to get a frame producer.
129	/// But an upfront size is required.
130	pub fn write_frame<B: Into<Bytes>>(&mut self, frame: B) -> Result<()> {
131		let data = frame.into();
132		let frame = Frame {
133			size: data.len() as u64,
134		};
135		let mut frame = self.create_frame(frame)?;
136		frame.write(data)?;
137		frame.finish()?;
138		Ok(())
139	}
140
141	/// Create a frame with an upfront size
142	pub fn create_frame(&mut self, info: Frame) -> Result<FrameProducer> {
143		let frame = info.produce();
144		self.append_frame(frame.clone())?;
145		Ok(frame)
146	}
147
148	/// Append a frame producer to the group.
149	pub fn append_frame(&mut self, frame: FrameProducer) -> Result<()> {
150		let mut state = modify(&self.state)?;
151		if state.fin {
152			return Err(Error::Closed);
153		}
154		state.frames.push(frame);
155		Ok(())
156	}
157
158	/// Return the number of frames written so far.
159	pub fn frame_count(&self) -> usize {
160		self.state.read().frames.len()
161	}
162
163	/// Mark the group as complete; no more frames will be written.
164	pub fn finish(&mut self) -> Result<()> {
165		let mut state = modify(&self.state)?;
166		state.fin = true;
167		Ok(())
168	}
169
170	/// Abort the group with the given error.
171	///
172	/// No updates can be made after this point.
173	pub fn abort(&mut self, err: Error) -> Result<()> {
174		let mut guard = modify(&self.state)?;
175
176		// Abort all frames still in progress.
177		for frame in guard.frames.iter_mut() {
178			// Ignore errors, we don't care if the frame was already closed.
179			frame.abort(err.clone()).ok();
180		}
181
182		guard.abort = Some(err);
183		guard.close();
184		Ok(())
185	}
186
187	/// Create a new consumer for the group.
188	pub fn consume(&self) -> GroupConsumer {
189		GroupConsumer {
190			info: self.info.clone(),
191			state: self.state.consume(),
192			index: 0,
193		}
194	}
195
196	/// Block until the group is closed or aborted.
197	pub async fn closed(&self) -> Error {
198		self.state.closed().await;
199		self.state.read().abort.clone().unwrap_or(Error::Dropped)
200	}
201
202	/// Block until there are no active consumers.
203	pub async fn unused(&self) -> Result<()> {
204		self.state
205			.unused()
206			.await
207			.map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
208	}
209}
210
211impl Clone for GroupProducer {
212	fn clone(&self) -> Self {
213		Self {
214			info: self.info.clone(),
215			state: self.state.clone(),
216		}
217	}
218}
219
220impl From<Group> for GroupProducer {
221	fn from(info: Group) -> Self {
222		GroupProducer::new(info)
223	}
224}
225
226/// Consume a group, frame-by-frame.
227#[derive(Clone)]
228pub struct GroupConsumer {
229	// Shared state with the producer.
230	state: conducer::Consumer<GroupState>,
231
232	// Immutable stream state.
233	pub info: Group,
234
235	// The number of frames we've read.
236	// NOTE: Cloned readers inherit this offset, but then run in parallel.
237	index: usize,
238}
239
240impl GroupConsumer {
241	// A helper to automatically apply Dropped if the state is closed without an error.
242	fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R>>
243	where
244		F: Fn(&conducer::Ref<'_, GroupState>) -> Poll<Result<R>>,
245	{
246		Poll::Ready(match ready!(self.state.poll(waiter, f)) {
247			Ok(res) => res,
248			// We try to clone abort just in case the function forgot to check for terminal state.
249			Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
250		})
251	}
252
253	/// Block until the frame at the given index is available.
254	///
255	/// Returns None if the group is finished and the index is out of range.
256	pub async fn get_frame(&self, index: usize) -> Result<Option<FrameConsumer>> {
257		conducer::wait(|waiter| self.poll_get_frame(waiter, index)).await
258	}
259
260	/// Poll for the frame at the given index, without blocking.
261	///
262	/// Returns None if the group is finished and the index is out of range.
263	pub fn poll_get_frame(&self, waiter: &conducer::Waiter, index: usize) -> Poll<Result<Option<FrameConsumer>>> {
264		self.poll(waiter, |state| state.poll_get_frame(index))
265	}
266
267	/// Return a consumer for the next frame for chunked reading.
268	pub async fn next_frame(&mut self) -> Result<Option<FrameConsumer>> {
269		conducer::wait(|waiter| self.poll_next_frame(waiter)).await
270	}
271
272	/// Poll for the next frame, without blocking.
273	///
274	/// Returns None if the group is finished and the index is out of range.
275	pub fn poll_next_frame(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<FrameConsumer>>> {
276		let Some(frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
277			return Poll::Ready(Ok(None));
278		};
279
280		self.index += 1;
281		Poll::Ready(Ok(Some(frame)))
282	}
283
284	/// Read the next frame's data all at once, without blocking.
285	pub fn poll_read_frame(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<Bytes>>> {
286		let Some(mut frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
287			return Poll::Ready(Ok(None));
288		};
289
290		let data = ready!(frame.poll_read_all(waiter))?;
291		self.index += 1;
292
293		Poll::Ready(Ok(Some(data)))
294	}
295
296	/// Read the next frame's data all at once.
297	pub async fn read_frame(&mut self) -> Result<Option<Bytes>> {
298		conducer::wait(|waiter| self.poll_read_frame(waiter)).await
299	}
300
301	/// Read all of the chunks of the next frame, without blocking.
302	pub fn poll_read_frame_chunks(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<Vec<Bytes>>>> {
303		let Some(mut frame) = ready!(self.poll(waiter, |state| state.poll_get_frame(self.index))?) else {
304			return Poll::Ready(Ok(None));
305		};
306
307		let data = ready!(frame.poll_read_all_chunks(waiter))?;
308		self.index += 1;
309
310		Poll::Ready(Ok(Some(data)))
311	}
312
313	/// Read all of the chunks of the next frame.
314	pub async fn read_frame_chunks(&mut self) -> Result<Option<Vec<Bytes>>> {
315		conducer::wait(|waiter| self.poll_read_frame_chunks(waiter)).await
316	}
317
318	/// Poll for the final number of frames in the group.
319	pub fn poll_finished(&mut self, waiter: &conducer::Waiter) -> Poll<Result<u64>> {
320		self.poll(waiter, |state| state.poll_finished())
321	}
322
323	/// Block until the group is finished, returning the number of frames in the group.
324	pub async fn finished(&mut self) -> Result<u64> {
325		conducer::wait(|waiter| self.poll_finished(waiter)).await
326	}
327}
328
#[cfg(test)]
mod test {
	use super::*;
	use futures::FutureExt;

	#[test]
	fn basic_frame_reading() {
		let mut producer = Group { sequence: 0 }.produce();
		producer.write_frame(Bytes::from_static(b"frame0")).unwrap();
		producer.write_frame(Bytes::from_static(b"frame1")).unwrap();
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let f0 = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
		assert_eq!(f0.info.size, 6);
		let f1 = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
		assert_eq!(f1.info.size, 6);
		let end = consumer.next_frame().now_or_never().unwrap().unwrap();
		assert!(end.is_none());
	}

	#[test]
	fn read_frame_all_at_once() {
		let mut producer = Group { sequence: 0 }.produce();
		producer.write_frame(Bytes::from_static(b"hello")).unwrap();
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let data = consumer.read_frame().now_or_never().unwrap().unwrap().unwrap();
		assert_eq!(data, Bytes::from_static(b"hello"));
	}

	#[test]
	fn read_frame_chunks() {
		let mut producer = Group { sequence: 0 }.produce();
		let mut frame = producer.create_frame(Frame { size: 10 }).unwrap();
		frame.write(Bytes::from_static(b"hello")).unwrap();
		frame.write(Bytes::from_static(b"world")).unwrap();
		frame.finish().unwrap();
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let chunks = consumer.read_frame_chunks().now_or_never().unwrap().unwrap().unwrap();
		assert_eq!(chunks.len(), 2);
		assert_eq!(chunks[0], Bytes::from_static(b"hello"));
		assert_eq!(chunks[1], Bytes::from_static(b"world"));
	}

	#[test]
	fn get_frame_by_index() {
		let mut producer = Group { sequence: 0 }.produce();
		producer.write_frame(Bytes::from_static(b"a")).unwrap();
		producer.write_frame(Bytes::from_static(b"bb")).unwrap();
		producer.finish().unwrap();

		let consumer = producer.consume();
		let f0 = consumer.get_frame(0).now_or_never().unwrap().unwrap().unwrap();
		assert_eq!(f0.info.size, 1);
		let f1 = consumer.get_frame(1).now_or_never().unwrap().unwrap().unwrap();
		assert_eq!(f1.info.size, 2);
		let f2 = consumer.get_frame(2).now_or_never().unwrap().unwrap();
		assert!(f2.is_none());
	}

	#[test]
	fn group_finish_returns_none() {
		let mut producer = Group { sequence: 0 }.produce();
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let end = consumer.next_frame().now_or_never().unwrap().unwrap();
		assert!(end.is_none());
	}

	// Appending after `finish` must be rejected with `Error::Closed` and leave the group unchanged.
	#[test]
	fn write_after_finish_fails() {
		let mut producer = Group { sequence: 0 }.produce();
		producer.finish().unwrap();

		let result = producer.write_frame(Bytes::from_static(b"late"));
		assert!(matches!(result, Err(crate::Error::Closed)));
		assert_eq!(producer.frame_count(), 0);
	}

	#[test]
	fn abort_propagates() {
		let mut producer = Group { sequence: 0 }.produce();
		let mut consumer = producer.consume();
		producer.abort(crate::Error::Cancel).unwrap();

		let result = consumer.next_frame().now_or_never().unwrap();
		assert!(matches!(result, Err(crate::Error::Cancel)));
	}

	#[tokio::test]
	async fn pending_then_ready() {
		let mut producer = Group { sequence: 0 }.produce();
		let mut consumer = producer.consume();

		// Consumer blocks because no frames yet.
		assert!(consumer.next_frame().now_or_never().is_none());

		producer.write_frame(Bytes::from_static(b"data")).unwrap();
		producer.finish().unwrap();

		let frame = consumer.next_frame().now_or_never().unwrap().unwrap().unwrap();
		assert_eq!(frame.info.size, 4);
	}

	#[test]
	fn clone_consumer_independent() {
		let mut producer = Group { sequence: 0 }.produce();
		producer.write_frame(Bytes::from_static(b"a")).unwrap();

		let mut c1 = producer.consume();
		// Read one frame from c1
		let _ = c1.next_frame().now_or_never().unwrap().unwrap().unwrap();

		// Clone c1 — inherits index (past first frame)
		let mut c2 = c1.clone();

		producer.write_frame(Bytes::from_static(b"b")).unwrap();
		producer.finish().unwrap();

		// c2 should get the second frame (inherited index)
		let f = c2.read_frame().now_or_never().unwrap().unwrap().unwrap();
		assert_eq!(f, Bytes::from_static(b"b"));

		let end = c2.next_frame().now_or_never().unwrap().unwrap();
		assert!(end.is_none());

		// c2's reads must not advance c1: c1 still yields the second frame itself.
		let f = c1.read_frame().now_or_never().unwrap().unwrap().unwrap();
		assert_eq!(f, Bytes::from_static(b"b"));
		assert!(c1.next_frame().now_or_never().unwrap().unwrap().is_none());
	}
}