1use std::sync::Arc;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::task::{Poll, ready};
4
5use bytes::buf::UninitSlice;
6use bytes::{BufMut, Bytes};
7
8use crate::{Error, Result};
9
/// Frame size limit of 16 MiB (`16 * 1024 * 1024` bytes).
///
/// NOTE(review): not referenced anywhere in this part of the file; in particular
/// `FrameProducer::new` does not check it. Presumably frame sizes are validated
/// against it elsewhere — confirm.
pub(crate) const MAX_FRAME_SIZE: u64 = 16 * 1024 * 1024;
20
/// Metadata describing a single frame: currently just its total length.
///
/// [`Frame::produce`] allocates a buffer of exactly `size` bytes. Consumers
/// treat the frame as complete once all of those bytes have been written.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Frame {
    /// Total number of payload bytes in the frame.
    pub size: u64,
}
31
32impl Frame {
33 pub fn produce(self) -> FrameProducer {
35 FrameProducer::new(self)
36 }
37}
38
39impl From<usize> for Frame {
40 fn from(size: usize) -> Self {
41 Self { size: size as u64 }
42 }
43}
44
45impl From<u64> for Frame {
46 fn from(size: u64) -> Self {
47 Self { size }
48 }
49}
50
51impl From<u32> for Frame {
52 fn from(size: u32) -> Self {
53 Self { size: size as u64 }
54 }
55}
56
57impl From<u16> for Frame {
58 fn from(size: u16) -> Self {
59 Self { size: size as u64 }
60 }
61}
62
/// Cheaply clonable handle to one fixed-size, zero-initialised byte buffer.
///
/// The writer fills the buffer front to back and publishes progress through the
/// `written` counter. Readers (via `AsRef<[u8]>`) only ever see the published
/// prefix `[0, written)`.
#[derive(Clone)]
struct FrameBuf(Arc<FrameBufInner>);

/// Owner of the raw allocation behind a [`FrameBuf`].
struct FrameBufInner {
    /// Start of a `Box<[u8]>` of length `capacity`, leaked in `FrameBuf::new`
    /// and reclaimed in `Drop`.
    data: *mut u8,
    /// Length of the allocation in bytes.
    capacity: usize,
    /// Number of leading bytes that have been published to readers.
    written: AtomicUsize,
}

// SAFETY: the `data` allocation is owned exclusively by this value (created in
// `FrameBuf::new`, freed only in `Drop`), so sending it to another thread moves
// that ownership along with it.
unsafe impl Send for FrameBufInner {}
// SAFETY: shared readers only build slices over `[0, written)` (see `AsRef`).
// The writer only stores into `[written, capacity)` before publishing the new
// length with a `Release` store, which readers pair with an `Acquire` load.
// NOTE(review): this relies on a single writer at a time — confirm.
unsafe impl Sync for FrameBufInner {}

impl Drop for FrameBufInner {
    fn drop(&mut self) {
        let raw: *mut [u8] = std::ptr::slice_from_raw_parts_mut(self.data, self.capacity);
        // SAFETY: `raw` is exactly the pointer/length pair that `Box::into_raw`
        // produced in `FrameBuf::new`, and this is the only place it is turned
        // back into a `Box`, so it is freed exactly once.
        drop(unsafe { Box::from_raw(raw) });
    }
}

impl FrameBuf {
    /// Allocates a zero-filled buffer of `size` bytes with nothing published yet.
    fn new(size: usize) -> Self {
        let storage: Box<[u8]> = vec![0u8; size].into_boxed_slice();
        let inner = FrameBufInner {
            capacity: storage.len(),
            data: Box::into_raw(storage).cast::<u8>(),
            written: AtomicUsize::new(0),
        };
        FrameBuf(Arc::new(inner))
    }

    /// Total size of the buffer in bytes.
    fn capacity(&self) -> usize {
        self.0.capacity
    }

    /// Number of bytes published so far, loaded with the given ordering.
    fn written(&self, ord: Ordering) -> usize {
        self.0.written.load(ord)
    }

    /// Raw pointer to the first byte of the buffer.
    ///
    /// # Safety
    ///
    /// Writes through the returned pointer must stay within `capacity` and must
    /// not touch the already-published prefix `[0, written)`.
    unsafe fn data_ptr(&self) -> *mut u8 {
        self.0.data
    }

    /// Publishes `new_written` bytes to readers (`Release` store).
    ///
    /// # Safety
    ///
    /// `new_written` must not exceed `capacity`: `AsRef` builds a slice of
    /// exactly that length over the allocation.
    unsafe fn store_written(&self, new_written: usize) {
        self.0.written.store(new_written, Ordering::Release);
    }
}

impl AsRef<[u8]> for FrameBuf {
    /// The published prefix `[0, written)` of the buffer.
    fn as_ref(&self) -> &[u8] {
        let inner = &*self.0;
        let published = inner.written.load(Ordering::Acquire);
        // SAFETY: `published <= capacity` (guaranteed by `store_written`'s
        // contract), the bytes were zero-initialised at allocation, and the
        // writer never modifies the published prefix.
        unsafe { std::slice::from_raw_parts(inner.data, published) }
    }
}
142
/// Flags shared between a frame's producers and consumers through `kio`.
#[derive(Default, Debug)]
struct FrameState {
    /// Set once every byte of the frame has been written: by `advance_mut` when
    /// the last byte lands, or by `FrameProducer::finish`.
    fin: bool,
    /// Error recorded by `FrameProducer::abort`; reported to later writers and readers.
    abort: Option<Error>,
}
150
/// Writing half of a frame.
///
/// Bytes are copied into a buffer of exactly `size` bytes that is allocated up
/// front. Consumers created with [`FrameProducer::consume`] can read the written
/// prefix while writing is still in progress.
///
/// NOTE(review): `FrameProducer` is `Clone`, and clones share the same buffer
/// and write position. `chunk_mut`/`advance_mut` assume a single writer: the
/// position is loaded and then stored, not atomically updated, and two clones
/// can each hold a `&mut` slice over the same unwritten tail. Writing through
/// more than one clone therefore looks unsound — confirm clones are never used
/// to write concurrently.
pub struct FrameProducer {
    /// The frame's metadata, also exposed through `Deref`.
    info: Frame,
    /// Shared `fin`/`abort` flags.
    state: kio::Producer<FrameState>,
    /// The frame's bytes, shared with every consumer.
    buf: FrameBuf,
}

impl std::ops::Deref for FrameProducer {
    type Target = Frame;

    fn deref(&self) -> &Self::Target {
        &self.info
    }
}
171
172impl FrameProducer {
173 pub fn new(info: Frame) -> Self {
175 let buf = FrameBuf::new(info.size as usize);
176 Self {
177 info,
178 state: kio::Producer::new(FrameState::default()),
179 buf,
180 }
181 }
182
183 pub fn write<B: Into<Bytes>>(&mut self, chunk: B) -> Result<()> {
187 let chunk = chunk.into();
188 if chunk.len() > self.remaining_mut() {
189 return Err(Error::WrongSize);
190 }
191 self.bail_if_aborted()?;
193 self.put_slice(&chunk);
194 Ok(())
195 }
196
197 pub fn finish(&mut self) -> Result<()> {
201 let written = self.buf.written(Ordering::Acquire);
202 if written != self.buf.capacity() {
203 return Err(Error::WrongSize);
204 }
205 let mut state = self.modify()?;
207 state.fin = true;
208 Ok(())
209 }
210
211 pub fn abort(&mut self, err: Error) -> Result<()> {
213 let mut guard = self.modify()?;
214 guard.abort = Some(err);
215 guard.close();
216 Ok(())
217 }
218
219 pub fn consume(&self) -> FrameConsumer {
221 FrameConsumer {
222 info: self.info.clone(),
223 state: self.state.consume(),
224 buf: self.buf.clone(),
225 read_idx: 0,
226 }
227 }
228
229 pub async fn unused(&self) -> Result<()> {
231 self.state
232 .unused()
233 .await
234 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
235 }
236
237 fn modify(&mut self) -> Result<kio::Mut<'_, FrameState>> {
238 self.state
239 .write()
240 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
241 }
242
243 fn bail_if_aborted(&self) -> Result<()> {
244 let state = self.state.read();
245 if let Some(err) = &state.abort {
246 return Err(err.clone());
247 }
248 Ok(())
249 }
250}
251
252unsafe impl BufMut for FrameProducer {
259 fn remaining_mut(&self) -> usize {
260 self.buf.capacity() - self.buf.written(Ordering::Acquire)
261 }
262
263 fn chunk_mut(&mut self) -> &mut UninitSlice {
264 let written = self.buf.written(Ordering::Acquire);
265 let cap = self.buf.capacity();
266 unsafe {
270 let ptr = self.buf.data_ptr().add(written);
271 UninitSlice::from_raw_parts_mut(ptr, cap - written)
272 }
273 }
274
275 unsafe fn advance_mut(&mut self, cnt: usize) {
276 let cap = self.buf.capacity();
277 let prev = self.buf.written(Ordering::Relaxed);
278 assert!(
279 prev + cnt <= cap,
280 "advance_mut past frame.size: prev={prev} cnt={cnt} cap={cap}"
281 );
282 unsafe { self.buf.store_written(prev + cnt) };
284
285 if let Ok(mut state) = self.state.write() {
288 if prev + cnt == cap {
289 state.fin = true;
290 }
291 }
292 }
293}
294
295impl Clone for FrameProducer {
296 fn clone(&self) -> Self {
297 Self {
298 info: self.info.clone(),
299 state: self.state.clone(),
300 buf: self.buf.clone(),
301 }
302 }
303}
304
305impl From<Frame> for FrameProducer {
306 fn from(info: Frame) -> Self {
307 FrameProducer::new(info)
308 }
309}
310
/// Reading half of a frame, created by [`FrameProducer::consume`].
///
/// All consumers share the producer's buffer, but each keeps its own read cursor.
#[derive(Clone)]
pub struct FrameConsumer {
    /// The frame's metadata, also exposed through `Deref`.
    info: Frame,
    /// Shared `fin`/`abort` flags.
    state: kio::Consumer<FrameState>,
    /// The frame's bytes, shared with the producer.
    buf: FrameBuf,
    /// Offset of the first byte this consumer has not returned yet. Cloning a
    /// consumer copies the cursor, so the clone continues from the same point
    /// and then advances independently.
    read_idx: usize,
}

impl std::ops::Deref for FrameConsumer {
    type Target = Frame;

    fn deref(&self) -> &Self::Target {
        &self.info
    }
}
329
330impl FrameConsumer {
331 fn poll<F, R>(&self, waiter: &kio::Waiter, f: F) -> Poll<Result<R>>
333 where
334 F: Fn(&kio::Ref<'_, FrameState>) -> Poll<Result<R>>,
335 {
336 Poll::Ready(match ready!(self.state.poll(waiter, f)) {
337 Ok(res) => res,
338 Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
339 })
340 }
341
342 fn snapshot(&self, read_idx: usize) -> Option<Bytes> {
343 let written = self.buf.written(Ordering::Acquire);
346 if written > read_idx {
347 Some(Bytes::from_owner(self.buf.clone()).slice(read_idx..written))
348 } else {
349 None
350 }
351 }
352
353 pub fn poll_read_all(&mut self, waiter: &kio::Waiter) -> Poll<Result<Bytes>> {
358 let read_idx = self.read_idx;
359 let res = ready!(self.poll(waiter, |state| {
360 if state.fin {
361 return Poll::Ready(Ok(()));
362 }
363 if let Some(err) = &state.abort {
364 return Poll::Ready(Err(err.clone()));
365 }
366 Poll::Pending
367 }));
368 match res {
369 Ok(()) => {
370 let bytes = self
372 .snapshot(read_idx)
373 .unwrap_or_else(|| Bytes::from_owner(self.buf.clone()).slice(read_idx..read_idx));
374 self.read_idx = self.buf.capacity();
375 Poll::Ready(Ok(bytes))
376 }
377 Err(e) => Poll::Ready(Err(e)),
378 }
379 }
380
381 pub async fn read_all(&mut self) -> Result<Bytes> {
383 kio::wait(|waiter| self.poll_read_all(waiter)).await
384 }
385
386 pub fn poll_read_all_chunks(&mut self, waiter: &kio::Waiter) -> Poll<Result<Vec<Bytes>>> {
389 let bytes = ready!(self.poll_read_all(waiter)?);
390 Poll::Ready(Ok(if bytes.is_empty() { Vec::new() } else { vec![bytes] }))
391 }
392
393 pub fn poll_read_chunk(&mut self, waiter: &kio::Waiter) -> Poll<Result<Option<Bytes>>> {
399 let read_idx = self.read_idx;
400 let res = ready!(self.poll(waiter, |state| {
401 let written = self.buf.written(Ordering::Acquire);
402 if written > read_idx {
403 return Poll::Ready(Ok(Some(written)));
404 }
405 if state.fin {
406 return Poll::Ready(Ok(None));
407 }
408 if let Some(err) = &state.abort {
409 return Poll::Ready(Err(err.clone()));
410 }
411 Poll::Pending
412 }));
413 match res {
414 Ok(Some(written)) => {
415 let bytes = Bytes::from_owner(self.buf.clone()).slice(read_idx..written);
416 self.read_idx = written;
417 Poll::Ready(Ok(Some(bytes)))
418 }
419 Ok(None) => Poll::Ready(Ok(None)),
420 Err(e) => Poll::Ready(Err(e)),
421 }
422 }
423
424 pub async fn read_chunk(&mut self) -> Result<Option<Bytes>> {
426 kio::wait(|waiter| self.poll_read_chunk(waiter)).await
427 }
428
429 pub fn poll_read_chunks(&mut self, waiter: &kio::Waiter) -> Poll<Result<Vec<Bytes>>> {
432 match ready!(self.poll_read_chunk(waiter)?) {
433 Some(b) => Poll::Ready(Ok(vec![b])),
434 None => Poll::Ready(Ok(Vec::new())),
435 }
436 }
437
438 pub async fn read_chunks(&mut self) -> Result<Vec<Bytes>> {
440 kio::wait(|waiter| self.poll_read_chunks(waiter)).await
441 }
442}
443
// Unit tests for the frame producer/consumer pair. The async readers are driven
// with `FutureExt::now_or_never`, so a `None` from `now_or_never()` means the
// read was still pending at that point.
#[cfg(test)]
mod test {
    use super::*;
    use futures::FutureExt;

    #[test]
    fn single_chunk_roundtrip() {
        let mut producer = Frame { size: 5 }.produce();
        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"hello"));
    }

    #[test]
    fn multi_chunk_read_all() {
        let mut producer = Frame { size: 10 }.produce();
        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.write(Bytes::from_static(b"world")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        // Separate writes come back as one contiguous buffer.
        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"helloworld"));
    }

    #[test]
    fn read_chunk_sequential() {
        let mut producer = Frame { size: 10 }.produce();
        producer.write(Bytes::from_static(b"hello")).unwrap();
        let mut consumer = producer.consume();
        // Only half the frame is written; read_chunk returns what is available
        // instead of waiting for the rest.
        let c1 = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c1, Some(Bytes::from_static(b"hello")));

        producer.write(Bytes::from_static(b"world")).unwrap();
        producer.finish().unwrap();

        let c2 = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c2, Some(Bytes::from_static(b"world")));
        // Cursor at the end of a finished frame: end of stream.
        let c3 = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c3, None);
    }

    #[test]
    fn read_all_chunks() {
        let mut producer = Frame { size: 10 }.produce();
        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.write(Bytes::from_static(b"world")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        // Everything past the cursor comes back as a single chunk.
        let chunks = consumer.read_chunks().now_or_never().unwrap().unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], Bytes::from_static(b"helloworld"));
    }

    #[test]
    fn finish_checks_remaining() {
        let mut producer = Frame { size: 5 }.produce();
        producer.write(Bytes::from_static(b"hi")).unwrap();
        // 2 of 5 bytes written: finishing early is rejected.
        let err = producer.finish().unwrap_err();
        assert!(matches!(err, Error::WrongSize));
    }

    #[test]
    fn write_too_many_bytes() {
        let mut producer = Frame { size: 3 }.produce();
        let err = producer.write(Bytes::from_static(b"toolong")).unwrap_err();
        assert!(matches!(err, Error::WrongSize));
    }

    #[test]
    fn abort_propagates() {
        let mut producer = Frame { size: 5 }.produce();
        let mut consumer = producer.consume();
        producer.abort(Error::Cancel).unwrap();

        // The consumer sees the exact error passed to abort().
        let err = consumer.read_all().now_or_never().unwrap().unwrap_err();
        assert!(matches!(err, Error::Cancel));
    }

    #[test]
    fn empty_frame() {
        let mut producer = Frame { size: 0 }.produce();
        // Nothing is written to a zero-size frame, so finish() is what sets fin.
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::new());
    }

    #[tokio::test]
    async fn pending_then_ready() {
        let mut producer = Frame { size: 5 }.produce();
        let mut consumer = producer.consume();

        // Nothing written yet: read_all must stay pending.
        assert!(consumer.read_all().now_or_never().is_none());

        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.finish().unwrap();

        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"hello"));
    }

    #[test]
    fn buf_mut_roundtrip() {
        let mut producer = Frame { size: 12 }.produce();
        // Drive the producer through the bytes::BufMut interface directly.
        assert_eq!(producer.remaining_mut(), 12);
        producer.put_slice(b"hello");
        assert_eq!(producer.remaining_mut(), 7);
        producer.put_slice(b" world!");
        assert_eq!(producer.remaining_mut(), 0);
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"hello world!"));
    }

    #[test]
    #[should_panic(expected = "advance_mut past frame.size")]
    fn buf_mut_advance_past_capacity_panics() {
        let mut producer = Frame { size: 4 }.produce();
        // Deliberately violates advance_mut's contract; the internal assert must catch it.
        unsafe { producer.advance_mut(5) };
    }

    #[test]
    fn read_chunk_streams_partial_writes() {
        let mut producer = Frame { size: 6 }.produce();
        let mut consumer = producer.consume();

        producer.write(Bytes::from_static(b"foo")).unwrap();
        let c1 = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c1, Some(Bytes::from_static(b"foo")));

        // Caught up with the writer and the frame isn't finished: stays pending.
        assert!(consumer.read_chunk().now_or_never().is_none());

        producer.write(Bytes::from_static(b"bar")).unwrap();
        producer.finish().unwrap();
        let c2 = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c2, Some(Bytes::from_static(b"bar")));
        let c3 = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c3, None);
    }

    #[test]
    fn cloned_consumer_independent_cursor() {
        let mut producer = Frame { size: 10 }.produce();
        let mut c1 = producer.consume();
        producer.write(Bytes::from_static(b"hello")).unwrap();

        let chunk = c1.read_chunk().now_or_never().unwrap().unwrap();
        // c1's cursor is now past "hello"; c2 is cloned from that position.
        assert_eq!(chunk, Some(Bytes::from_static(b"hello")));
        let mut c2 = c1.clone();

        producer.write(Bytes::from_static(b"world")).unwrap();
        producer.finish().unwrap();

        let chunk = c1.read_chunk().now_or_never().unwrap().unwrap();
        // Advancing c1 did not move c2: both see the same next chunk.
        assert_eq!(chunk, Some(Bytes::from_static(b"world")));
        let chunk = c2.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(chunk, Some(Bytes::from_static(b"world")));
    }
}