1use std::sync::Arc;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::task::{Poll, ready};
4
5use bytes::buf::UninitSlice;
6use bytes::{BufMut, Bytes};
7
8use crate::{Error, Result};
9
/// Largest frame size, in bytes, the crate is meant to handle: 32 MiB.
///
/// NOTE(review): `FrameProducer::new` does not check this limit itself; presumably it is
/// enforced wherever frame sizes are decoded — confirm.
pub(crate) const MAX_FRAME_SIZE: u64 = 32 << 20;
26
/// Description of a single frame: the exact number of payload bytes it carries.
///
/// A producer created from it must write precisely `size` bytes before it can finish.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Frame {
    /// Exact payload length in bytes.
    pub size: u64,
}
37
38impl Frame {
39 pub fn produce(self) -> FrameProducer {
41 FrameProducer::new(self)
42 }
43}
44
45impl From<usize> for Frame {
46 fn from(size: usize) -> Self {
47 Self { size: size as u64 }
48 }
49}
50
51impl From<u64> for Frame {
52 fn from(size: u64) -> Self {
53 Self { size }
54 }
55}
56
57impl From<u32> for Frame {
58 fn from(size: u32) -> Self {
59 Self { size: size as u64 }
60 }
61}
62
63impl From<u16> for Frame {
64 fn from(size: u16) -> Self {
65 Self { size: size as u64 }
66 }
67}
68
/// Fixed-capacity byte buffer shared by a frame's producer(s) and consumers.
///
/// Cloning only bumps an `Arc`: every clone refers to the same allocation and the same
/// `written` cursor. Bytes in `[0, written)` are published and readable through
/// `AsRef<[u8]>`; bytes in `[written, capacity)` are reserved for the writer.
#[derive(Clone)]
struct FrameBuf(Arc<FrameBufInner>);

/// Storage behind [`FrameBuf`]: a leaked `Box<[u8]>` plus a publication cursor.
struct FrameBufInner {
    /// First byte of the allocation produced by `Box::<[u8]>::into_raw` in `FrameBuf::new`.
    data: *mut u8,
    /// Length of the allocation in bytes; fixed at construction.
    capacity: usize,
    /// Count of leading bytes that have been written and published to readers.
    written: AtomicUsize,
}

// SAFETY: the allocation is owned exclusively by this value and freed only in `Drop`,
// so moving it to another thread is sound.
unsafe impl Send for FrameBufInner {}
// SAFETY: shared readers only touch `[0, written)`, published with a Release store and
// observed with an Acquire load; writes go to `[written, capacity)`.
// NOTE(review): this assumes a single writer at a time — confirm against callers.
unsafe impl Sync for FrameBufInner {}

impl Drop for FrameBufInner {
    fn drop(&mut self) {
        let whole = std::ptr::slice_from_raw_parts_mut(self.data, self.capacity);
        // SAFETY: `data`/`capacity` describe exactly the `Box<[u8]>` leaked in
        // `FrameBuf::new`, and this is the only place that allocation is reclaimed.
        let owned: Box<[u8]> = unsafe { Box::from_raw(whole) };
        drop(owned);
    }
}

impl FrameBuf {
    /// Allocates a zero-filled buffer of exactly `size` bytes with nothing published yet.
    fn new(size: usize) -> Self {
        let storage: Box<[u8]> = vec![0u8; size].into_boxed_slice();
        let capacity = storage.len();
        let inner = FrameBufInner {
            data: Box::into_raw(storage).cast::<u8>(),
            capacity,
            written: AtomicUsize::new(0),
        };
        FrameBuf(Arc::new(inner))
    }

    /// Total size of the allocation in bytes.
    fn capacity(&self) -> usize {
        let inner = &self.0;
        inner.capacity
    }

    /// Current publication cursor, loaded with the requested memory ordering.
    fn written(&self, ord: Ordering) -> usize {
        let inner = &self.0;
        inner.written.load(ord)
    }

    /// Raw pointer to the first byte of the allocation.
    ///
    /// # Safety
    ///
    /// Callers may only write through the pointer into the unpublished region
    /// `[written, capacity)` and must not race another writer.
    unsafe fn data_ptr(&self) -> *mut u8 {
        let inner = &self.0;
        inner.data
    }

    /// Publishes `[0, new_written)` to readers with a Release store.
    ///
    /// # Safety
    ///
    /// Every byte below `new_written` must already be written, and `new_written` must not
    /// exceed `capacity`; otherwise `as_ref` would expose memory outside the allocation.
    unsafe fn store_written(&self, new_written: usize) {
        let inner = &self.0;
        inner.written.store(new_written, Ordering::Release);
    }
}

impl AsRef<[u8]> for FrameBuf {
    /// Returns the published prefix `[0, written)` of the buffer.
    fn as_ref(&self) -> &[u8] {
        let published = self.written(Ordering::Acquire);
        // SAFETY: `published <= capacity`, the prefix was fully written before the Release
        // store that the Acquire load above synchronises with, and the writer never touches
        // that prefix again. The allocation lives as long as `self`.
        unsafe { std::slice::from_raw_parts(self.0.data, published) }
    }
}
148
149#[derive(Default, Debug)]
150struct FrameState {
151 fin: bool,
153 abort: Option<Error>,
155}
156
157pub struct FrameProducer {
165 info: Frame,
166 state: kio::Producer<FrameState>,
167 buf: FrameBuf,
168}
169
170impl std::ops::Deref for FrameProducer {
171 type Target = Frame;
172
173 fn deref(&self) -> &Self::Target {
174 &self.info
175 }
176}
177
178impl FrameProducer {
179 pub fn new(info: Frame) -> Self {
181 let buf = FrameBuf::new(info.size as usize);
182 Self {
183 info,
184 state: kio::Producer::new(FrameState::default()),
185 buf,
186 }
187 }
188
189 pub fn write<B: Into<Bytes>>(&mut self, chunk: B) -> Result<()> {
193 let chunk = chunk.into();
194 if chunk.len() > self.remaining_mut() {
195 return Err(Error::WrongSize);
196 }
197 self.bail_if_aborted()?;
199 self.put_slice(&chunk);
200 Ok(())
201 }
202
203 pub fn finish(&mut self) -> Result<()> {
207 let written = self.buf.written(Ordering::Acquire);
208 if written != self.buf.capacity() {
209 return Err(Error::WrongSize);
210 }
211 let mut state = self.modify()?;
213 state.fin = true;
214 Ok(())
215 }
216
217 pub fn abort(&mut self, err: Error) -> Result<()> {
219 let mut guard = self.modify()?;
220 guard.abort = Some(err);
221 guard.close();
222 Ok(())
223 }
224
225 pub fn consume(&self) -> FrameConsumer {
227 FrameConsumer {
228 info: self.info.clone(),
229 state: self.state.consume(),
230 buf: self.buf.clone(),
231 read_idx: 0,
232 }
233 }
234
235 pub async fn unused(&self) -> Result<()> {
237 self.state
238 .unused()
239 .await
240 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
241 }
242
243 fn modify(&mut self) -> Result<kio::Mut<'_, FrameState>> {
244 self.state
245 .write()
246 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
247 }
248
249 fn bail_if_aborted(&self) -> Result<()> {
250 let state = self.state.read();
251 if let Some(err) = &state.abort {
252 return Err(err.clone());
253 }
254 Ok(())
255 }
256}
257
258unsafe impl BufMut for FrameProducer {
265 fn remaining_mut(&self) -> usize {
266 self.buf.capacity() - self.buf.written(Ordering::Acquire)
267 }
268
269 fn chunk_mut(&mut self) -> &mut UninitSlice {
270 let written = self.buf.written(Ordering::Acquire);
271 let cap = self.buf.capacity();
272 unsafe {
276 let ptr = self.buf.data_ptr().add(written);
277 UninitSlice::from_raw_parts_mut(ptr, cap - written)
278 }
279 }
280
281 unsafe fn advance_mut(&mut self, cnt: usize) {
282 let cap = self.buf.capacity();
283 let prev = self.buf.written(Ordering::Relaxed);
284 assert!(
285 prev + cnt <= cap,
286 "advance_mut past frame.size: prev={prev} cnt={cnt} cap={cap}"
287 );
288 unsafe { self.buf.store_written(prev + cnt) };
290
291 if let Ok(mut state) = self.state.write() {
294 if prev + cnt == cap {
295 state.fin = true;
296 }
297 }
298 }
299}
300
301impl Clone for FrameProducer {
302 fn clone(&self) -> Self {
303 Self {
304 info: self.info.clone(),
305 state: self.state.clone(),
306 buf: self.buf.clone(),
307 }
308 }
309}
310
311impl From<Frame> for FrameProducer {
312 fn from(info: Frame) -> Self {
313 FrameProducer::new(info)
314 }
315}
316
317#[derive(Clone)]
319pub struct FrameConsumer {
320 info: Frame,
321 state: kio::Consumer<FrameState>,
322 buf: FrameBuf,
323 read_idx: usize,
326}
327
328impl std::ops::Deref for FrameConsumer {
329 type Target = Frame;
330
331 fn deref(&self) -> &Self::Target {
332 &self.info
333 }
334}
335
336impl FrameConsumer {
337 fn poll<F, R>(&self, waiter: &kio::Waiter, f: F) -> Poll<Result<R>>
339 where
340 F: Fn(&kio::Ref<'_, FrameState>) -> Poll<Result<R>>,
341 {
342 Poll::Ready(match ready!(self.state.poll(waiter, f)) {
343 Ok(res) => res,
344 Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
345 })
346 }
347
348 fn snapshot(&self, read_idx: usize) -> Option<Bytes> {
349 let written = self.buf.written(Ordering::Acquire);
352 if written > read_idx {
353 Some(Bytes::from_owner(self.buf.clone()).slice(read_idx..written))
354 } else {
355 None
356 }
357 }
358
359 pub fn poll_read_all(&mut self, waiter: &kio::Waiter) -> Poll<Result<Bytes>> {
364 let read_idx = self.read_idx;
365 let res = ready!(self.poll(waiter, |state| {
366 if state.fin {
367 return Poll::Ready(Ok(()));
368 }
369 if let Some(err) = &state.abort {
370 return Poll::Ready(Err(err.clone()));
371 }
372 Poll::Pending
373 }));
374 match res {
375 Ok(()) => {
376 let bytes = self
378 .snapshot(read_idx)
379 .unwrap_or_else(|| Bytes::from_owner(self.buf.clone()).slice(read_idx..read_idx));
380 self.read_idx = self.buf.capacity();
381 Poll::Ready(Ok(bytes))
382 }
383 Err(e) => Poll::Ready(Err(e)),
384 }
385 }
386
387 pub async fn read_all(&mut self) -> Result<Bytes> {
389 kio::wait(|waiter| self.poll_read_all(waiter)).await
390 }
391
392 pub fn poll_read_all_chunks(&mut self, waiter: &kio::Waiter) -> Poll<Result<Vec<Bytes>>> {
395 let bytes = ready!(self.poll_read_all(waiter)?);
396 Poll::Ready(Ok(if bytes.is_empty() { Vec::new() } else { vec![bytes] }))
397 }
398
399 pub fn poll_read_chunk(&mut self, waiter: &kio::Waiter) -> Poll<Result<Option<Bytes>>> {
405 let read_idx = self.read_idx;
406 let res = ready!(self.poll(waiter, |state| {
407 let written = self.buf.written(Ordering::Acquire);
408 if written > read_idx {
409 return Poll::Ready(Ok(Some(written)));
410 }
411 if state.fin {
412 return Poll::Ready(Ok(None));
413 }
414 if let Some(err) = &state.abort {
415 return Poll::Ready(Err(err.clone()));
416 }
417 Poll::Pending
418 }));
419 match res {
420 Ok(Some(written)) => {
421 let bytes = Bytes::from_owner(self.buf.clone()).slice(read_idx..written);
422 self.read_idx = written;
423 Poll::Ready(Ok(Some(bytes)))
424 }
425 Ok(None) => Poll::Ready(Ok(None)),
426 Err(e) => Poll::Ready(Err(e)),
427 }
428 }
429
430 pub async fn read_chunk(&mut self) -> Result<Option<Bytes>> {
432 kio::wait(|waiter| self.poll_read_chunk(waiter)).await
433 }
434
435 pub fn poll_read_chunks(&mut self, waiter: &kio::Waiter) -> Poll<Result<Vec<Bytes>>> {
438 match ready!(self.poll_read_chunk(waiter)?) {
439 Some(b) => Poll::Ready(Ok(vec![b])),
440 None => Poll::Ready(Ok(Vec::new())),
441 }
442 }
443
444 pub async fn read_chunks(&mut self) -> Result<Vec<Bytes>> {
446 kio::wait(|waiter| self.poll_read_chunks(waiter)).await
447 }
448}
449
#[cfg(test)]
mod test {
    use super::*;
    use futures::FutureExt;

    /// Drives `read_all` once without an executor; panics if it is pending or fails.
    fn read_all_now(consumer: &mut FrameConsumer) -> Bytes {
        consumer
            .read_all()
            .now_or_never()
            .expect("read_all was still pending")
            .expect("read_all failed")
    }

    /// Drives `read_chunk` once without an executor; panics if it is pending or fails.
    fn read_chunk_now(consumer: &mut FrameConsumer) -> Option<Bytes> {
        consumer
            .read_chunk()
            .now_or_never()
            .expect("read_chunk was still pending")
            .expect("read_chunk failed")
    }

    /// Shorthand for a `'static` byte string as `Bytes`.
    fn b(data: &'static [u8]) -> Bytes {
        Bytes::from_static(data)
    }

    #[test]
    fn single_chunk_roundtrip() {
        let mut producer = Frame { size: 5 }.produce();
        producer.write(b(b"hello")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        assert_eq!(read_all_now(&mut consumer), b(b"hello"));
    }

    #[test]
    fn multi_chunk_read_all() {
        let mut producer = Frame { size: 10 }.produce();
        for part in [&b"hello"[..], &b"world"[..]] {
            producer.write(b(part)).unwrap();
        }
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        assert_eq!(read_all_now(&mut consumer), b(b"helloworld"));
    }

    #[test]
    fn read_chunk_sequential() {
        let mut producer = Frame { size: 10 }.produce();
        producer.write(b(b"hello")).unwrap();
        let mut consumer = producer.consume();
        assert_eq!(read_chunk_now(&mut consumer), Some(b(b"hello")));

        producer.write(b(b"world")).unwrap();
        producer.finish().unwrap();

        assert_eq!(read_chunk_now(&mut consumer), Some(b(b"world")));
        assert_eq!(read_chunk_now(&mut consumer), None);
    }

    #[test]
    fn read_all_chunks() {
        let mut producer = Frame { size: 10 }.produce();
        producer.write(b(b"hello")).unwrap();
        producer.write(b(b"world")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let chunks = consumer.read_chunks().now_or_never().unwrap().unwrap();
        // Both writes were published before the first read, so they arrive as one chunk.
        assert_eq!(chunks, vec![b(b"helloworld")]);
    }

    #[test]
    fn finish_checks_remaining() {
        let mut producer = Frame { size: 5 }.produce();
        producer.write(b(b"hi")).unwrap();
        assert!(matches!(producer.finish(), Err(Error::WrongSize)));
    }

    #[test]
    fn write_too_many_bytes() {
        let mut producer = Frame { size: 3 }.produce();
        assert!(matches!(producer.write(b(b"toolong")), Err(Error::WrongSize)));
    }

    #[test]
    fn abort_propagates() {
        let mut producer = Frame { size: 5 }.produce();
        let mut consumer = producer.consume();
        producer.abort(Error::Cancel).unwrap();

        let outcome = consumer
            .read_all()
            .now_or_never()
            .expect("abort should make the read ready");
        assert!(matches!(outcome, Err(Error::Cancel)));
    }

    #[test]
    fn empty_frame() {
        let mut producer = Frame { size: 0 }.produce();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        assert!(read_all_now(&mut consumer).is_empty());
    }

    #[tokio::test]
    async fn pending_then_ready() {
        let mut producer = Frame { size: 5 }.produce();
        let mut consumer = producer.consume();

        // Nothing has been written yet, so the read cannot complete.
        assert!(consumer.read_all().now_or_never().is_none());

        producer.write(b(b"hello")).unwrap();
        producer.finish().unwrap();

        assert_eq!(read_all_now(&mut consumer), b(b"hello"));
    }

    #[test]
    fn buf_mut_roundtrip() {
        let mut producer = Frame { size: 12 }.produce();
        assert_eq!(producer.remaining_mut(), 12);
        for (piece, left) in [(&b"hello"[..], 7usize), (&b" world!"[..], 0usize)] {
            producer.put_slice(piece);
            assert_eq!(producer.remaining_mut(), left);
        }
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        assert_eq!(read_all_now(&mut consumer), b(b"hello world!"));
    }

    #[test]
    #[should_panic(expected = "advance_mut past frame.size")]
    fn buf_mut_advance_past_capacity_panics() {
        let mut producer = Frame { size: 4 }.produce();
        unsafe { producer.advance_mut(5) };
    }

    #[test]
    fn read_chunk_streams_partial_writes() {
        let mut producer = Frame { size: 6 }.produce();
        let mut consumer = producer.consume();

        producer.write(b(b"foo")).unwrap();
        assert_eq!(read_chunk_now(&mut consumer), Some(b(b"foo")));

        // Everything published so far has been consumed, so the next read must wait.
        assert!(consumer.read_chunk().now_or_never().is_none());

        producer.write(b(b"bar")).unwrap();
        producer.finish().unwrap();
        assert_eq!(read_chunk_now(&mut consumer), Some(b(b"bar")));
        assert_eq!(read_chunk_now(&mut consumer), None);
    }

    #[test]
    fn cloned_consumer_independent_cursor() {
        let mut producer = Frame { size: 10 }.produce();
        let mut first = producer.consume();
        producer.write(b(b"hello")).unwrap();

        assert_eq!(read_chunk_now(&mut first), Some(b(b"hello")));
        // The clone inherits `first`'s cursor, which sits just past "hello".
        let mut second = first.clone();

        producer.write(b(b"world")).unwrap();
        producer.finish().unwrap();

        assert_eq!(read_chunk_now(&mut first), Some(b(b"world")));
        assert_eq!(read_chunk_now(&mut second), Some(b(b"world")));
    }
}