Skip to main content

moq_net/model/
frame.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::task::{Poll, ready};
4
5use bytes::buf::UninitSlice;
6use bytes::{BufMut, Bytes};
7
8use crate::{Error, Result};
9
/// Maximum payload size (16 MiB) accepted for a single frame on the wire.
///
/// The receive path preallocates a buffer from the declared frame size, so an
/// untrusted peer could otherwise request a multi-gigabyte allocation with a
/// single varint. Subscribers reject frames whose declared size exceeds this.
//
// TODO enforce this in [Frame::produce] / [FrameProducer::new] so the limit is
// guaranteed for every caller, not just the wire decode paths. Blocked on
// making the constructor fallible (returning [Result]), which is an API break.
pub(crate) const MAX_FRAME_SIZE: u64 = 16 * 1024 * 1024;
20
/// Header for a chunk of data whose total size is declared up front.
///
/// Note that this is just the header; it carries no payload bytes itself.
/// You use [FrameProducer] and [FrameConsumer] to deal with the frame payload, potentially chunked.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Frame {
	/// Total payload size in bytes. Declared up front so the payload buffer
	/// can be allocated once, before any bytes arrive.
	pub size: u64,
}
31
32impl Frame {
33	/// Create a new producer for the frame.
34	pub fn produce(self) -> FrameProducer {
35		FrameProducer::new(self)
36	}
37}
38
39impl From<usize> for Frame {
40	fn from(size: usize) -> Self {
41		Self { size: size as u64 }
42	}
43}
44
45impl From<u64> for Frame {
46	fn from(size: u64) -> Self {
47		Self { size }
48	}
49}
50
51impl From<u32> for Frame {
52	fn from(size: u32) -> Self {
53		Self { size: size as u64 }
54	}
55}
56
57impl From<u16> for Frame {
58	fn from(size: u16) -> Self {
59		Self { size: size as u64 }
60	}
61}
62
/// Fixed-capacity byte buffer shared by one [FrameProducer] and many [FrameConsumer]s.
///
/// This is a handle: an [Arc] around a raw pointer + length that own a single
/// heap allocation. Cloning only bumps the reference count, and the allocation
/// is never moved or resized, so [Bytes] views created with
/// [Bytes::from_owner] stay valid for as long as any handle is alive.
///
/// All writes go through the raw pointer and are performed by the producer
/// (sole writer). The `written` counter publishes them: the producer stores it
/// with Release after writing, readers load it with Acquire before reading.
/// [AsRef]<[u8]> is implemented on the handle itself so it can be given to
/// [Bytes::from_owner] without a wrapper newtype.
#[derive(Clone)]
struct FrameBuf(Arc<FrameBufInner>);

struct FrameBufInner {
	// Start of the owned heap allocation: `capacity` bytes, zero-initialized.
	data: *mut u8,
	capacity: usize,
	// Length of the prefix the producer has published so far.
	written: AtomicUsize,
}

// Safety: `data` is an owned allocation (leaked from a Box, reclaimed in Drop).
// Only the producer writes, and readers only look at bytes below `written`,
// which is stored with Release after the writes it covers and loaded with
// Acquire on the reading side.
unsafe impl Send for FrameBufInner {}
unsafe impl Sync for FrameBufInner {}

impl Drop for FrameBufInner {
	fn drop(&mut self) {
		// Rebuild the fat pointer that `FrameBuf::new` took apart.
		let raw: *mut [u8] = std::ptr::slice_from_raw_parts_mut(self.data, self.capacity);
		// Safety: `data` came from `Box::into_raw` on a `Box<[u8]>` of exactly
		// `capacity` bytes, and nothing else can reference it once the last
		// Arc handle is gone.
		drop(unsafe { Box::from_raw(raw) });
	}
}

impl FrameBuf {
	/// Allocate a zero-filled buffer of `size` bytes with nothing published yet.
	fn new(size: usize) -> Self {
		let storage = vec![0u8; size].into_boxed_slice();
		let inner = FrameBufInner {
			// Read the length before the box is turned into a raw pointer.
			capacity: storage.len(),
			data: Box::into_raw(storage).cast::<u8>(),
			written: AtomicUsize::new(0),
		};
		Self(Arc::new(inner))
	}

	/// Total size of the allocation in bytes.
	fn capacity(&self) -> usize {
		let inner: &FrameBufInner = &self.0;
		inner.capacity
	}

	/// Number of bytes published so far, loaded with the given ordering.
	fn written(&self, ord: Ordering) -> usize {
		let inner: &FrameBufInner = &self.0;
		inner.written.load(ord)
	}

	/// Raw pointer to the start of the allocation.
	///
	/// Safety: caller must be the sole producer (FrameProducer-as-BufMut invariant).
	unsafe fn data_ptr(&self) -> *mut u8 {
		let inner: &FrameBufInner = &self.0;
		inner.data
	}

	/// Publish a new `written` count.
	///
	/// Safety: caller must be the sole producer; `new_written` must be `<= capacity`.
	unsafe fn store_written(&self, new_written: usize) {
		let inner: &FrameBufInner = &self.0;
		// Release: everything written before this store becomes visible to a
		// reader that observes the new count with Acquire.
		inner.written.store(new_written, Ordering::Release);
	}
}

impl AsRef<[u8]> for FrameBuf {
	/// View of the bytes published so far (`[..written]` at the time of the call).
	fn as_ref(&self) -> &[u8] {
		let inner: &FrameBufInner = &self.0;
		// Acquire pairs with the producer's Release store on `written`.
		let len = inner.written.load(Ordering::Acquire);
		// Safety: the first `len` bytes are initialized (zeroed at allocation,
		// then overwritten by the producer before `written` was advanced), and
		// the Arc keeps the allocation alive for as long as `self` is borrowed.
		unsafe { std::slice::from_raw_parts(inner.data, len) }
	}
}
142
// Shared frame status held in the kio producer/consumer state. The payload
// bytes themselves are not stored here.
#[derive(Default, Debug)]
struct FrameState {
	// Whether the producer signaled a clean finish (written == capacity).
	fin: bool,
	// The error that aborted the frame, if any.
	abort: Option<Error>,
}
150
/// Writes a frame's payload in one or more chunks.
///
/// The total bytes written must exactly match [Frame::size].
/// Call [Self::finish] after writing all bytes to verify correctness.
///
/// Implements [BufMut] so the receive path can write directly into the
/// pre-allocated buffer (e.g. via `tokio::io::AsyncReadExt::read_buf`).
pub struct FrameProducer {
	// The frame header; exposed to callers through `Deref`.
	info: Frame,
	// Fin/abort status shared with consumers.
	state: kio::Producer<FrameState>,
	// The payload buffer, allocated up front at the frame's declared size.
	buf: FrameBuf,
}
163
164impl std::ops::Deref for FrameProducer {
165	type Target = Frame;
166
167	fn deref(&self) -> &Self::Target {
168		&self.info
169	}
170}
171
172impl FrameProducer {
173	/// Create a new frame producer for the given frame header.
174	pub fn new(info: Frame) -> Self {
175		let buf = FrameBuf::new(info.size as usize);
176		Self {
177			info,
178			state: kio::Producer::new(FrameState::default()),
179			buf,
180		}
181	}
182
183	/// Write a chunk of data to the frame.
184	///
185	/// Returns [Error::WrongSize] if the chunk would exceed the remaining bytes.
186	pub fn write<B: Into<Bytes>>(&mut self, chunk: B) -> Result<()> {
187		let chunk = chunk.into();
188		if chunk.len() > self.remaining_mut() {
189			return Err(Error::WrongSize);
190		}
191		// Surface aborts before writing.
192		self.bail_if_aborted()?;
193		self.put_slice(&chunk);
194		Ok(())
195	}
196
197	/// Verify that all bytes have been written.
198	///
199	/// Returns [Error::WrongSize] if the bytes written don't match [Frame::size].
200	pub fn finish(&mut self) -> Result<()> {
201		let written = self.buf.written(Ordering::Acquire);
202		if written != self.buf.capacity() {
203			return Err(Error::WrongSize);
204		}
205		// Mark fin (idempotent if `advance_mut` already set it on the last byte).
206		let mut state = self.modify()?;
207		state.fin = true;
208		Ok(())
209	}
210
211	/// Abort the frame with the given error.
212	pub fn abort(&mut self, err: Error) -> Result<()> {
213		let mut guard = self.modify()?;
214		guard.abort = Some(err);
215		guard.close();
216		Ok(())
217	}
218
219	/// Create a new consumer for the frame.
220	pub fn consume(&self) -> FrameConsumer {
221		FrameConsumer {
222			info: self.info.clone(),
223			state: self.state.consume(),
224			buf: self.buf.clone(),
225			read_idx: 0,
226		}
227	}
228
229	/// Block until there are no active consumers.
230	pub async fn unused(&self) -> Result<()> {
231		self.state
232			.unused()
233			.await
234			.map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
235	}
236
237	fn modify(&mut self) -> Result<kio::Mut<'_, FrameState>> {
238		self.state
239			.write()
240			.map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
241	}
242
243	fn bail_if_aborted(&self) -> Result<()> {
244		let state = self.state.read();
245		if let Some(err) = &state.abort {
246			return Err(err.clone());
247		}
248		Ok(())
249	}
250}
251
252// Safety: `chunk_mut` returns a slice into the producer-private region of the
253// buffer (`[written..capacity]`). Sole-writer invariant: even though
254// `FrameProducer` is `Clone`, the API exposes BufMut only via `&mut self`,
255// and existing callers never share a single producer between concurrent writers
256// (group.rs clones a handle for `abort` / `consume` only). The defensive
257// `assert!` in `advance_mut` panics loudly if that invariant is ever violated.
258unsafe impl BufMut for FrameProducer {
259	fn remaining_mut(&self) -> usize {
260		self.buf.capacity() - self.buf.written(Ordering::Acquire)
261	}
262
263	fn chunk_mut(&mut self) -> &mut UninitSlice {
264		let written = self.buf.written(Ordering::Acquire);
265		let cap = self.buf.capacity();
266		// Safety: writes to `[written..cap]` are unaliased — consumers only ever
267		// read `[..written]`, and we hold `&mut self`. The slice's lifetime is
268		// tied to `&mut self` by the function signature.
269		unsafe {
270			let ptr = self.buf.data_ptr().add(written);
271			UninitSlice::from_raw_parts_mut(ptr, cap - written)
272		}
273	}
274
275	unsafe fn advance_mut(&mut self, cnt: usize) {
276		let cap = self.buf.capacity();
277		let prev = self.buf.written(Ordering::Relaxed);
278		assert!(
279			prev + cnt <= cap,
280			"advance_mut past frame.size: prev={prev} cnt={cnt} cap={cap}"
281		);
282		// Safety: sole-writer invariant + bounds-checked above.
283		unsafe { self.buf.store_written(prev + cnt) };
284
285		// Briefly take the kio write lock to wake waiters; drop of `Mut`
286		// triggers kio's notify. Also flip `fin` if we just filled the buffer.
287		if let Ok(mut state) = self.state.write() {
288			if prev + cnt == cap {
289				state.fin = true;
290			}
291		}
292	}
293}
294
295impl Clone for FrameProducer {
296	fn clone(&self) -> Self {
297		Self {
298			info: self.info.clone(),
299			state: self.state.clone(),
300			buf: self.buf.clone(),
301		}
302	}
303}
304
305impl From<Frame> for FrameProducer {
306	fn from(info: Frame) -> Self {
307		FrameProducer::new(info)
308	}
309}
310
/// Used to consume a frame's worth of data, streaming as bytes arrive.
///
/// Each consumer keeps its own read offset, so clones read independently.
#[derive(Clone)]
pub struct FrameConsumer {
	// The frame header; exposed to callers through `Deref`.
	info: Frame,
	// Fin/abort status shared with the producer.
	state: kio::Consumer<FrameState>,
	// The payload buffer shared with the producer.
	buf: FrameBuf,
	// Byte offset into the buffer; cloned consumers inherit this offset and
	// read independently from there.
	read_idx: usize,
}
321
322impl std::ops::Deref for FrameConsumer {
323	type Target = Frame;
324
325	fn deref(&self) -> &Self::Target {
326		&self.info
327	}
328}
329
330impl FrameConsumer {
331	// A helper to automatically apply Dropped if the state is closed without an error.
332	fn poll<F, R>(&self, waiter: &kio::Waiter, f: F) -> Poll<Result<R>>
333	where
334		F: Fn(&kio::Ref<'_, FrameState>) -> Poll<Result<R>>,
335	{
336		Poll::Ready(match ready!(self.state.poll(waiter, f)) {
337			Ok(res) => res,
338			Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
339		})
340	}
341
342	fn snapshot(&self, read_idx: usize) -> Option<Bytes> {
343		// Acquire pairs with the producer's Release on `written`, making the
344		// bytes in `[..written]` visible to this thread.
345		let written = self.buf.written(Ordering::Acquire);
346		if written > read_idx {
347			Some(Bytes::from_owner(self.buf.clone()).slice(read_idx..written))
348		} else {
349			None
350		}
351	}
352
353	/// Poll for all remaining data without blocking.
354	///
355	/// Waits until the frame is finished (written == size); then returns the
356	/// remaining bytes from `read_idx` to the end as a single zero-copy slice.
357	pub fn poll_read_all(&mut self, waiter: &kio::Waiter) -> Poll<Result<Bytes>> {
358		let read_idx = self.read_idx;
359		let res = ready!(self.poll(waiter, |state| {
360			if state.fin {
361				return Poll::Ready(Ok(()));
362			}
363			if let Some(err) = &state.abort {
364				return Poll::Ready(Err(err.clone()));
365			}
366			Poll::Pending
367		}));
368		match res {
369			Ok(()) => {
370				// Frame is finished: written == capacity.
371				let bytes = self
372					.snapshot(read_idx)
373					.unwrap_or_else(|| Bytes::from_owner(self.buf.clone()).slice(read_idx..read_idx));
374				self.read_idx = self.buf.capacity();
375				Poll::Ready(Ok(bytes))
376			}
377			Err(e) => Poll::Ready(Err(e)),
378		}
379	}
380
381	/// Return all of the remaining bytes, blocking until the frame is finished.
382	pub async fn read_all(&mut self) -> Result<Bytes> {
383		kio::wait(|waiter| self.poll_read_all(waiter)).await
384	}
385
386	/// Poll for all remaining bytes (split into a single-element vec for backwards
387	/// compatibility with the previous chunk-based API).
388	pub fn poll_read_all_chunks(&mut self, waiter: &kio::Waiter) -> Poll<Result<Vec<Bytes>>> {
389		let bytes = ready!(self.poll_read_all(waiter)?);
390		Poll::Ready(Ok(if bytes.is_empty() { Vec::new() } else { vec![bytes] }))
391	}
392
393	/// Poll for the next chunk of bytes since the last read.
394	///
395	/// Returns whatever bytes have been written since the consumer's `read_idx` —
396	/// could span multiple producer writes. Returns `None` once the frame is
397	/// finished and all bytes have been consumed.
398	pub fn poll_read_chunk(&mut self, waiter: &kio::Waiter) -> Poll<Result<Option<Bytes>>> {
399		let read_idx = self.read_idx;
400		let res = ready!(self.poll(waiter, |state| {
401			let written = self.buf.written(Ordering::Acquire);
402			if written > read_idx {
403				return Poll::Ready(Ok(Some(written)));
404			}
405			if state.fin {
406				return Poll::Ready(Ok(None));
407			}
408			if let Some(err) = &state.abort {
409				return Poll::Ready(Err(err.clone()));
410			}
411			Poll::Pending
412		}));
413		match res {
414			Ok(Some(written)) => {
415				let bytes = Bytes::from_owner(self.buf.clone()).slice(read_idx..written);
416				self.read_idx = written;
417				Poll::Ready(Ok(Some(bytes)))
418			}
419			Ok(None) => Poll::Ready(Ok(None)),
420			Err(e) => Poll::Ready(Err(e)),
421		}
422	}
423
424	/// Return the next chunk of bytes since the last read.
425	pub async fn read_chunk(&mut self) -> Result<Option<Bytes>> {
426		kio::wait(|waiter| self.poll_read_chunk(waiter)).await
427	}
428
429	/// Poll for the next chunk; for backwards compatibility, wraps
430	/// [Self::poll_read_chunk] in a vec (single element if any data is available).
431	pub fn poll_read_chunks(&mut self, waiter: &kio::Waiter) -> Poll<Result<Vec<Bytes>>> {
432		match ready!(self.poll_read_chunk(waiter)?) {
433			Some(b) => Poll::Ready(Ok(vec![b])),
434			None => Poll::Ready(Ok(Vec::new())),
435		}
436	}
437
438	/// Read the next chunk into a vector (single element if available, empty on eof).
439	pub async fn read_chunks(&mut self) -> Result<Vec<Bytes>> {
440		kio::wait(|waiter| self.poll_read_chunks(waiter)).await
441	}
442}
443
// Unit tests for the frame producer/consumer pair. They drive the futures
// synchronously with `now_or_never`, so a `None` result means "still pending".
#[cfg(test)]
mod test {
	use super::*;
	use futures::FutureExt;

	// One write that fills the frame, then a consumer created afterwards
	// reads everything back.
	#[test]
	fn single_chunk_roundtrip() {
		let mut producer = Frame { size: 5 }.produce();
		producer.write(Bytes::from_static(b"hello")).unwrap();
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let data = consumer.read_all().now_or_never().unwrap().unwrap();
		assert_eq!(data, Bytes::from_static(b"hello"));
	}

	// Two writes come back from `read_all` as one contiguous slice.
	#[test]
	fn multi_chunk_read_all() {
		let mut producer = Frame { size: 10 }.produce();
		producer.write(Bytes::from_static(b"hello")).unwrap();
		producer.write(Bytes::from_static(b"world")).unwrap();
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let data = consumer.read_all().now_or_never().unwrap().unwrap();
		assert_eq!(data, Bytes::from_static(b"helloworld"));
	}

	#[test]
	fn read_chunk_sequential() {
		let mut producer = Frame { size: 10 }.produce();
		producer.write(Bytes::from_static(b"hello")).unwrap();
		// Each read_chunk returns whatever is new since the last call,
		// which may span multiple writes.
		let mut consumer = producer.consume();
		let c1 = consumer.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(c1, Some(Bytes::from_static(b"hello")));

		producer.write(Bytes::from_static(b"world")).unwrap();
		producer.finish().unwrap();

		let c2 = consumer.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(c2, Some(Bytes::from_static(b"world")));
		// Finished and fully consumed → end of frame.
		let c3 = consumer.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(c3, None);
	}

	// `read_chunks` on a fully written frame yields a single element that
	// spans both writes.
	#[test]
	fn read_all_chunks() {
		let mut producer = Frame { size: 10 }.produce();
		producer.write(Bytes::from_static(b"hello")).unwrap();
		producer.write(Bytes::from_static(b"world")).unwrap();
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let chunks = consumer.read_chunks().now_or_never().unwrap().unwrap();
		assert_eq!(chunks.len(), 1);
		assert_eq!(chunks[0], Bytes::from_static(b"helloworld"));
	}

	// Finishing before the declared size has been written is an error.
	#[test]
	fn finish_checks_remaining() {
		let mut producer = Frame { size: 5 }.produce();
		producer.write(Bytes::from_static(b"hi")).unwrap();
		let err = producer.finish().unwrap_err();
		assert!(matches!(err, Error::WrongSize));
	}

	// A chunk larger than the remaining space is rejected.
	#[test]
	fn write_too_many_bytes() {
		let mut producer = Frame { size: 3 }.produce();
		let err = producer.write(Bytes::from_static(b"toolong")).unwrap_err();
		assert!(matches!(err, Error::WrongSize));
	}

	// The producer's abort error is what the consumer's read returns.
	#[test]
	fn abort_propagates() {
		let mut producer = Frame { size: 5 }.produce();
		let mut consumer = producer.consume();
		producer.abort(Error::Cancel).unwrap();

		let err = consumer.read_all().now_or_never().unwrap().unwrap_err();
		assert!(matches!(err, Error::Cancel));
	}

	// A zero-size frame can be finished immediately and reads back empty.
	#[test]
	fn empty_frame() {
		let mut producer = Frame { size: 0 }.produce();
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let data = consumer.read_all().now_or_never().unwrap().unwrap();
		assert_eq!(data, Bytes::new());
	}

	#[tokio::test]
	async fn pending_then_ready() {
		let mut producer = Frame { size: 5 }.produce();
		let mut consumer = producer.consume();

		// Consumer blocks because no data yet.
		assert!(consumer.read_all().now_or_never().is_none());

		producer.write(Bytes::from_static(b"hello")).unwrap();
		producer.finish().unwrap();

		let data = consumer.read_all().now_or_never().unwrap().unwrap();
		assert_eq!(data, Bytes::from_static(b"hello"));
	}

	#[test]
	fn buf_mut_roundtrip() {
		// Exercise the BufMut path that the receive loop uses via `read_buf`.
		let mut producer = Frame { size: 12 }.produce();
		assert_eq!(producer.remaining_mut(), 12);
		producer.put_slice(b"hello");
		assert_eq!(producer.remaining_mut(), 7);
		producer.put_slice(b" world!");
		assert_eq!(producer.remaining_mut(), 0);
		producer.finish().unwrap();

		let mut consumer = producer.consume();
		let data = consumer.read_all().now_or_never().unwrap().unwrap();
		assert_eq!(data, Bytes::from_static(b"hello world!"));
	}

	// The defensive bounds check in `advance_mut` must panic on overrun.
	#[test]
	#[should_panic(expected = "advance_mut past frame.size")]
	fn buf_mut_advance_past_capacity_panics() {
		let mut producer = Frame { size: 4 }.produce();
		// Safety violation on purpose: cnt > remaining_mut().
		unsafe { producer.advance_mut(5) };
	}

	// A consumer created before any write sees each partial write as it lands.
	#[test]
	fn read_chunk_streams_partial_writes() {
		let mut producer = Frame { size: 6 }.produce();
		let mut consumer = producer.consume();

		producer.write(Bytes::from_static(b"foo")).unwrap();
		let c1 = consumer.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(c1, Some(Bytes::from_static(b"foo")));

		// No new data → pending.
		assert!(consumer.read_chunk().now_or_never().is_none());

		producer.write(Bytes::from_static(b"bar")).unwrap();
		producer.finish().unwrap();
		let c2 = consumer.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(c2, Some(Bytes::from_static(b"bar")));
		let c3 = consumer.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(c3, None);
	}

	#[test]
	fn cloned_consumer_independent_cursor() {
		let mut producer = Frame { size: 10 }.produce();
		let mut c1 = producer.consume();
		producer.write(Bytes::from_static(b"hello")).unwrap();

		// c1 reads the first 5 bytes, then we clone — c2 inherits c1's cursor.
		let chunk = c1.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(chunk, Some(Bytes::from_static(b"hello")));
		let mut c2 = c1.clone();

		producer.write(Bytes::from_static(b"world")).unwrap();
		producer.finish().unwrap();

		// Both consumers now see "world" as their next chunk.
		let chunk = c1.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(chunk, Some(Bytes::from_static(b"world")));
		let chunk = c2.read_chunk().now_or_never().unwrap().unwrap();
		assert_eq!(chunk, Some(Bytes::from_static(b"world")));
	}
}