//! moq_lite/model/frame.rs

use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Poll, ready};

use bytes::buf::UninitSlice;
use bytes::{BufMut, Bytes};

use crate::{Error, Result};
9
/// A chunk of data with an upfront size.
///
/// Note that this is just the header.
/// You use [FrameProducer] and [FrameConsumer] to deal with the frame payload, potentially chunked.
/// The producer must write exactly `size` bytes before calling [FrameProducer::finish].
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Frame {
	/// Total payload size in bytes. Declared up front so consumers can preallocate.
	pub size: u64,
}
20
21impl Frame {
22	/// Create a new producer for the frame.
23	pub fn produce(self) -> FrameProducer {
24		FrameProducer::new(self)
25	}
26}
27
28impl From<usize> for Frame {
29	fn from(size: usize) -> Self {
30		Self { size: size as u64 }
31	}
32}
33
34impl From<u64> for Frame {
35	fn from(size: u64) -> Self {
36		Self { size }
37	}
38}
39
40impl From<u32> for Frame {
41	fn from(size: u32) -> Self {
42		Self { size: size as u64 }
43	}
44}
45
46impl From<u16> for Frame {
47	fn from(size: u16) -> Self {
48		Self { size: size as u64 }
49	}
50}
51
/// Single-allocation buffer shared between a [FrameProducer] and many [FrameConsumer]s.
///
/// Internally an [Arc] over a thin pointer + length owning a heap allocation. The
/// data pointer is stable for the life of any clone, so [Bytes] views taken via
/// [Bytes::from_owner] remain valid. [Clone] is cheap (one atomic increment).
///
/// The producer writes through the raw pointer (sole writer); `written` provides
/// happens-before for cross-thread reads. Implements [AsRef]<[u8]> directly so it
/// can be passed to [Bytes::from_owner] without an extra wrapper newtype.
#[derive(Clone)]
struct FrameBuf(Arc<FrameBufInner>);
63
struct FrameBufInner {
	// Owned heap allocation of `capacity` bytes (zero-initialized).
	data: *mut u8,
	// Allocation length in bytes; fixed at construction, never changes.
	capacity: usize,
	// Bytes initialized by the producer so far. Monotonically increasing:
	// Release-stored by the producer, Acquire-loaded by consumers.
	written: AtomicUsize,
}

// Safety: `data` is owned (Box-allocated, freed in Drop); the producer is the
// sole writer; consumers only read bytes `< written`, which was set via Release
// after the corresponding writes completed (Acquire pairs on the consumer side).
unsafe impl Send for FrameBufInner {}
unsafe impl Sync for FrameBufInner {}

impl Drop for FrameBufInner {
	fn drop(&mut self) {
		// Safety: data was obtained from `Box::into_raw` of a `Box<[u8]>` of
		// length `capacity` and is not aliased at drop (Arc refcount hit 0).
		unsafe {
			let slice = std::ptr::slice_from_raw_parts_mut(self.data, self.capacity);
			drop(Box::from_raw(slice));
		}
	}
}
87
88impl FrameBuf {
89	fn new(size: usize) -> Self {
90		let boxed: Box<[u8]> = vec![0u8; size].into_boxed_slice();
91		let capacity = boxed.len();
92		let data = Box::into_raw(boxed) as *mut u8;
93		Self(Arc::new(FrameBufInner {
94			data,
95			capacity,
96			written: AtomicUsize::new(0),
97		}))
98	}
99
100	fn capacity(&self) -> usize {
101		self.0.capacity
102	}
103
104	fn written(&self, ord: Ordering) -> usize {
105		self.0.written.load(ord)
106	}
107
108	/// Safety: caller must be the sole producer (FrameProducer-as-BufMut invariant).
109	unsafe fn data_ptr(&self) -> *mut u8 {
110		self.0.data
111	}
112
113	/// Safety: caller must be the sole producer; `new_written` must be `<= capacity`.
114	unsafe fn store_written(&self, new_written: usize) {
115		// Release pairs with consumers' Acquire load to publish prior writes.
116		self.0.written.store(new_written, Ordering::Release);
117	}
118}
119
impl AsRef<[u8]> for FrameBuf {
	fn as_ref(&self) -> &[u8] {
		// Snapshot the initialized region (bytes the producer has written so far).
		// Acquire pairs with the producer's Release on `written`. Since `written`
		// only grows, any snapshot is a valid prefix of the final payload.
		let written = self.0.written.load(Ordering::Acquire);
		// Safety: data..data+written is initialized (zero-init at alloc + producer
		// writes up to `written`). The Arc keeps the allocation alive while any
		// reference to the slice lives.
		unsafe { std::slice::from_raw_parts(self.0.data, written) }
	}
}
131
// Shared producer/consumer control state, guarded by `conducer`. The payload
// bytes themselves live in `FrameBuf`; this only tracks completion/abort.
#[derive(Default, Debug)]
struct FrameState {
	// Whether the producer signaled a clean finish (written == capacity).
	fin: bool,
	// The error that aborted the frame, if any.
	abort: Option<Error>,
}
139
/// Writes a frame's payload in one or more chunks.
///
/// The total bytes written must exactly match [Frame::size].
/// Call [Self::finish] after writing all bytes to verify correctness.
///
/// Implements [BufMut] so the receive path can write directly into the
/// pre-allocated buffer (e.g. via `tokio::io::AsyncReadExt::read_buf`).
pub struct FrameProducer {
	// The immutable frame header (exposed via Deref).
	info: Frame,
	// Completion/abort signaling shared with consumers.
	state: conducer::Producer<FrameState>,
	// The shared payload allocation; consumers hold clones of this.
	buf: FrameBuf,
}
152
// Expose the frame header fields directly (e.g. `producer.size`).
impl std::ops::Deref for FrameProducer {
	type Target = Frame;

	fn deref(&self) -> &Self::Target {
		&self.info
	}
}
160
161impl FrameProducer {
162	/// Create a new frame producer for the given frame header.
163	pub fn new(info: Frame) -> Self {
164		let buf = FrameBuf::new(info.size as usize);
165		Self {
166			info,
167			state: conducer::Producer::new(FrameState::default()),
168			buf,
169		}
170	}
171
172	/// Write a chunk of data to the frame.
173	///
174	/// Returns [Error::WrongSize] if the chunk would exceed the remaining bytes.
175	pub fn write<B: Into<Bytes>>(&mut self, chunk: B) -> Result<()> {
176		let chunk = chunk.into();
177		if chunk.len() > self.remaining_mut() {
178			return Err(Error::WrongSize);
179		}
180		// Surface aborts before writing.
181		self.bail_if_aborted()?;
182		self.put_slice(&chunk);
183		Ok(())
184	}
185
186	/// Verify that all bytes have been written.
187	///
188	/// Returns [Error::WrongSize] if the bytes written don't match [Frame::size].
189	pub fn finish(&mut self) -> Result<()> {
190		let written = self.buf.written(Ordering::Acquire);
191		if written != self.buf.capacity() {
192			return Err(Error::WrongSize);
193		}
194		// Mark fin (idempotent if `advance_mut` already set it on the last byte).
195		let mut state = self.modify()?;
196		state.fin = true;
197		Ok(())
198	}
199
200	/// Abort the frame with the given error.
201	pub fn abort(&mut self, err: Error) -> Result<()> {
202		let mut guard = self.modify()?;
203		guard.abort = Some(err);
204		guard.close();
205		Ok(())
206	}
207
208	/// Create a new consumer for the frame.
209	pub fn consume(&self) -> FrameConsumer {
210		FrameConsumer {
211			info: self.info.clone(),
212			state: self.state.consume(),
213			buf: self.buf.clone(),
214			read_idx: 0,
215		}
216	}
217
218	/// Block until there are no active consumers.
219	pub async fn unused(&self) -> Result<()> {
220		self.state
221			.unused()
222			.await
223			.map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
224	}
225
226	fn modify(&mut self) -> Result<conducer::Mut<'_, FrameState>> {
227		self.state
228			.write()
229			.map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
230	}
231
232	fn bail_if_aborted(&self) -> Result<()> {
233		let state = self.state.read();
234		if let Some(err) = &state.abort {
235			return Err(err.clone());
236		}
237		Ok(())
238	}
239}
240
241// Safety: `chunk_mut` returns a slice into the producer-private region of the
242// buffer (`[written..capacity]`). Sole-writer invariant: even though
243// `FrameProducer` is `Clone`, the API exposes BufMut only via `&mut self`,
244// and existing callers never share a single producer between concurrent writers
245// (group.rs clones a handle for `abort` / `consume` only). The defensive
246// `assert!` in `advance_mut` panics loudly if that invariant is ever violated.
247unsafe impl BufMut for FrameProducer {
248	fn remaining_mut(&self) -> usize {
249		self.buf.capacity() - self.buf.written(Ordering::Acquire)
250	}
251
252	fn chunk_mut(&mut self) -> &mut UninitSlice {
253		let written = self.buf.written(Ordering::Acquire);
254		let cap = self.buf.capacity();
255		// Safety: writes to `[written..cap]` are unaliased — consumers only ever
256		// read `[..written]`, and we hold `&mut self`. The slice's lifetime is
257		// tied to `&mut self` by the function signature.
258		unsafe {
259			let ptr = self.buf.data_ptr().add(written);
260			UninitSlice::from_raw_parts_mut(ptr, cap - written)
261		}
262	}
263
264	unsafe fn advance_mut(&mut self, cnt: usize) {
265		let cap = self.buf.capacity();
266		let prev = self.buf.written(Ordering::Relaxed);
267		assert!(
268			prev + cnt <= cap,
269			"advance_mut past frame.size: prev={prev} cnt={cnt} cap={cap}"
270		);
271		// Safety: sole-writer invariant + bounds-checked above.
272		unsafe { self.buf.store_written(prev + cnt) };
273
274		// Briefly take the conducer write lock to wake waiters; drop of `Mut`
275		// triggers conducer's notify. Also flip `fin` if we just filled the buffer.
276		if let Ok(mut state) = self.state.write() {
277			if prev + cnt == cap {
278				state.fin = true;
279			}
280		}
281	}
282}
283
284impl Clone for FrameProducer {
285	fn clone(&self) -> Self {
286		Self {
287			info: self.info.clone(),
288			state: self.state.clone(),
289			buf: self.buf.clone(),
290		}
291	}
292}
293
294impl From<Frame> for FrameProducer {
295	fn from(info: Frame) -> Self {
296		FrameProducer::new(info)
297	}
298}
299
/// Used to consume a frame's worth of data, streaming as bytes arrive.
#[derive(Clone)]
pub struct FrameConsumer {
	// The immutable frame header (exposed via Deref).
	info: Frame,
	// Completion/abort signaling shared with the producer.
	state: conducer::Consumer<FrameState>,
	// The shared payload allocation (same one the producer writes into).
	buf: FrameBuf,
	// Byte offset into the buffer; cloned consumers inherit this offset and
	// read independently from there.
	read_idx: usize,
}
310
// Expose the frame header fields directly (e.g. `consumer.size`).
impl std::ops::Deref for FrameConsumer {
	type Target = Frame;

	fn deref(&self) -> &Self::Target {
		&self.info
	}
}
318
319impl FrameConsumer {
320	// A helper to automatically apply Dropped if the state is closed without an error.
321	fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R>>
322	where
323		F: Fn(&conducer::Ref<'_, FrameState>) -> Poll<Result<R>>,
324	{
325		Poll::Ready(match ready!(self.state.poll(waiter, f)) {
326			Ok(res) => res,
327			Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
328		})
329	}
330
331	fn snapshot(&self, read_idx: usize) -> Option<Bytes> {
332		// Acquire pairs with the producer's Release on `written`, making the
333		// bytes in `[..written]` visible to this thread.
334		let written = self.buf.written(Ordering::Acquire);
335		if written > read_idx {
336			Some(Bytes::from_owner(self.buf.clone()).slice(read_idx..written))
337		} else {
338			None
339		}
340	}
341
342	/// Poll for all remaining data without blocking.
343	///
344	/// Waits until the frame is finished (written == size); then returns the
345	/// remaining bytes from `read_idx` to the end as a single zero-copy slice.
346	pub fn poll_read_all(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Bytes>> {
347		let read_idx = self.read_idx;
348		let res = ready!(self.poll(waiter, |state| {
349			if state.fin {
350				return Poll::Ready(Ok(()));
351			}
352			if let Some(err) = &state.abort {
353				return Poll::Ready(Err(err.clone()));
354			}
355			Poll::Pending
356		}));
357		match res {
358			Ok(()) => {
359				// Frame is finished: written == capacity.
360				let bytes = self
361					.snapshot(read_idx)
362					.unwrap_or_else(|| Bytes::from_owner(self.buf.clone()).slice(read_idx..read_idx));
363				self.read_idx = self.buf.capacity();
364				Poll::Ready(Ok(bytes))
365			}
366			Err(e) => Poll::Ready(Err(e)),
367		}
368	}
369
370	/// Return all of the remaining bytes, blocking until the frame is finished.
371	pub async fn read_all(&mut self) -> Result<Bytes> {
372		conducer::wait(|waiter| self.poll_read_all(waiter)).await
373	}
374
375	/// Poll for all remaining bytes (split into a single-element vec for backwards
376	/// compatibility with the previous chunk-based API).
377	pub fn poll_read_all_chunks(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Vec<Bytes>>> {
378		let bytes = ready!(self.poll_read_all(waiter)?);
379		Poll::Ready(Ok(if bytes.is_empty() { Vec::new() } else { vec![bytes] }))
380	}
381
382	/// Poll for the next chunk of bytes since the last read.
383	///
384	/// Returns whatever bytes have been written since the consumer's `read_idx` —
385	/// could span multiple producer writes. Returns `None` once the frame is
386	/// finished and all bytes have been consumed.
387	pub fn poll_read_chunk(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<Bytes>>> {
388		let read_idx = self.read_idx;
389		let res = ready!(self.poll(waiter, |state| {
390			let written = self.buf.written(Ordering::Acquire);
391			if written > read_idx {
392				return Poll::Ready(Ok(Some(written)));
393			}
394			if state.fin {
395				return Poll::Ready(Ok(None));
396			}
397			if let Some(err) = &state.abort {
398				return Poll::Ready(Err(err.clone()));
399			}
400			Poll::Pending
401		}));
402		match res {
403			Ok(Some(written)) => {
404				let bytes = Bytes::from_owner(self.buf.clone()).slice(read_idx..written);
405				self.read_idx = written;
406				Poll::Ready(Ok(Some(bytes)))
407			}
408			Ok(None) => Poll::Ready(Ok(None)),
409			Err(e) => Poll::Ready(Err(e)),
410		}
411	}
412
413	/// Return the next chunk of bytes since the last read.
414	pub async fn read_chunk(&mut self) -> Result<Option<Bytes>> {
415		conducer::wait(|waiter| self.poll_read_chunk(waiter)).await
416	}
417
418	/// Poll for the next chunk; for backwards compatibility, wraps
419	/// [Self::poll_read_chunk] in a vec (single element if any data is available).
420	pub fn poll_read_chunks(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Vec<Bytes>>> {
421		match ready!(self.poll_read_chunk(waiter)?) {
422			Some(b) => Poll::Ready(Ok(vec![b])),
423			None => Poll::Ready(Ok(Vec::new())),
424		}
425	}
426
427	/// Read the next chunk into a vector (single element if available, empty on eof).
428	pub async fn read_chunks(&mut self) -> Result<Vec<Bytes>> {
429		conducer::wait(|waiter| self.poll_read_chunks(waiter)).await
430	}
431}
432
#[cfg(test)]
mod test {
	use super::*;
	use futures::FutureExt;

	// Happy path: one write covering the whole frame.
	#[test]
	fn single_chunk_roundtrip() {
		let mut producer = Frame { size: 5 }.produce();
		producer.write(Bytes::from_static(b"hello")).unwrap();
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let data = consumer.read_all().now_or_never().unwrap().unwrap();
		assert_eq!(data, Bytes::from_static(b"hello"));
	}

	// Multiple writes coalesce into one contiguous read_all result.
	#[test]
	fn multi_chunk_read_all() {
		let mut producer = Frame { size: 10 }.produce();
		producer.write(Bytes::from_static(b"hello")).unwrap();
		producer.write(Bytes::from_static(b"world")).unwrap();
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let data = consumer.read_all().now_or_never().unwrap().unwrap();
		assert_eq!(data, Bytes::from_static(b"helloworld"));
	}

	#[test]
	fn read_chunk_sequential() {
		let mut producer = Frame { size: 10 }.produce();
		producer.write(Bytes::from_static(b"hello")).unwrap();
		// Each read_chunk returns whatever is new since the last call,
		// which may span multiple writes.
		let mut consumer = producer.consume();
		let c1 = consumer.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(c1, Some(Bytes::from_static(b"hello")));

		producer.write(Bytes::from_static(b"world")).unwrap();
		producer.finish().unwrap();

		let c2 = consumer.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(c2, Some(Bytes::from_static(b"world")));
		let c3 = consumer.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(c3, None);
	}

	// Backwards-compat API: read_chunks wraps everything in a single-element vec.
	#[test]
	fn read_all_chunks() {
		let mut producer = Frame { size: 10 }.produce();
		producer.write(Bytes::from_static(b"hello")).unwrap();
		producer.write(Bytes::from_static(b"world")).unwrap();
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let chunks = consumer.read_chunks().now_or_never().unwrap().unwrap();
		assert_eq!(chunks.len(), 1);
		assert_eq!(chunks[0], Bytes::from_static(b"helloworld"));
	}

	// finish() must fail if fewer than `size` bytes were written.
	#[test]
	fn finish_checks_remaining() {
		let mut producer = Frame { size: 5 }.produce();
		producer.write(Bytes::from_static(b"hi")).unwrap();
		let err = producer.finish().unwrap_err();
		assert!(matches!(err, Error::WrongSize));
	}

	// write() must reject chunks that exceed the declared frame size.
	#[test]
	fn write_too_many_bytes() {
		let mut producer = Frame { size: 3 }.produce();
		let err = producer.write(Bytes::from_static(b"toolong")).unwrap_err();
		assert!(matches!(err, Error::WrongSize));
	}

	// A producer abort surfaces as the same error on the consumer side.
	#[test]
	fn abort_propagates() {
		let mut producer = Frame { size: 5 }.produce();
		let mut consumer = producer.consume();
		producer.abort(Error::Cancel).unwrap();

		let err = consumer.read_all().now_or_never().unwrap().unwrap_err();
		assert!(matches!(err, Error::Cancel));
	}

	// Zero-byte frames finish immediately and read back as empty.
	#[test]
	fn empty_frame() {
		let mut producer = Frame { size: 0 }.produce();
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let data = consumer.read_all().now_or_never().unwrap().unwrap();
		assert_eq!(data, Bytes::new());
	}

	// read_all is Pending until the producer finishes, then yields the payload.
	#[tokio::test]
	async fn pending_then_ready() {
		let mut producer = Frame { size: 5 }.produce();
		let mut consumer = producer.consume();

		// Consumer blocks because no data yet.
		assert!(consumer.read_all().now_or_never().is_none());

		producer.write(Bytes::from_static(b"hello")).unwrap();
		producer.finish().unwrap();

		let data = consumer.read_all().now_or_never().unwrap().unwrap();
		assert_eq!(data, Bytes::from_static(b"hello"));
	}

	#[test]
	fn buf_mut_roundtrip() {
		// Exercise the BufMut path that the receive loop uses via `read_buf`.
		let mut producer = Frame { size: 12 }.produce();
		assert_eq!(producer.remaining_mut(), 12);
		producer.put_slice(b"hello");
		assert_eq!(producer.remaining_mut(), 7);
		producer.put_slice(b" world!");
		assert_eq!(producer.remaining_mut(), 0);
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let data = consumer.read_all().now_or_never().unwrap().unwrap();
		assert_eq!(data, Bytes::from_static(b"hello world!"));
	}

	// The defensive assert in advance_mut must catch contract violations.
	#[test]
	#[should_panic(expected = "advance_mut past frame.size")]
	fn buf_mut_advance_past_capacity_panics() {
		let mut producer = Frame { size: 4 }.produce();
		// Safety violation on purpose: cnt > remaining_mut().
		unsafe { producer.advance_mut(5) };
	}

	#[test]
	fn read_chunk_streams_partial_writes() {
		let mut producer = Frame { size: 6 }.produce();
		let mut consumer = producer.consume();

		producer.write(Bytes::from_static(b"foo")).unwrap();
		let c1 = consumer.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(c1, Some(Bytes::from_static(b"foo")));

		// No new data → pending.
		assert!(consumer.read_chunk().now_or_never().is_none());

		producer.write(Bytes::from_static(b"bar")).unwrap();
		producer.finish().unwrap();
		let c2 = consumer.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(c2, Some(Bytes::from_static(b"bar")));
		let c3 = consumer.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(c3, None);
	}

	#[test]
	fn cloned_consumer_independent_cursor() {
		let mut producer = Frame { size: 10 }.produce();
		let mut c1 = producer.consume();
		producer.write(Bytes::from_static(b"hello")).unwrap();

		// c1 reads the first 5 bytes, then we clone — c2 inherits c1's cursor.
		let chunk = c1.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(chunk, Some(Bytes::from_static(b"hello")));
		let mut c2 = c1.clone();

		producer.write(Bytes::from_static(b"world")).unwrap();
		producer.finish().unwrap();

		// Both consumers now see "world" as their next chunk.
		let chunk = c1.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(chunk, Some(Bytes::from_static(b"world")));
		let chunk = c2.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(chunk, Some(Bytes::from_static(b"world")));
	}
}