1use std::sync::Arc;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::task::{Poll, ready};
4
5use bytes::buf::UninitSlice;
6use bytes::{BufMut, Bytes};
7
8use crate::{Error, Result};
9
/// Metadata describing a fixed-size frame of bytes.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Frame {
    /// Total payload size of the frame in bytes.
    pub size: u64,
}
20
21impl Frame {
22 pub fn produce(self) -> FrameProducer {
24 FrameProducer::new(self)
25 }
26}
27
28impl From<usize> for Frame {
29 fn from(size: usize) -> Self {
30 Self { size: size as u64 }
31 }
32}
33
34impl From<u64> for Frame {
35 fn from(size: u64) -> Self {
36 Self { size }
37 }
38}
39
40impl From<u32> for Frame {
41 fn from(size: u32) -> Self {
42 Self { size: size as u64 }
43 }
44}
45
46impl From<u16> for Frame {
47 fn from(size: u16) -> Self {
48 Self { size: size as u64 }
49 }
50}
51
/// Cheaply clonable handle (`Arc`) to the single fixed-capacity byte
/// allocation shared by the producer and all consumers of a frame.
#[derive(Clone)]
struct FrameBuf(Arc<FrameBufInner>);
63
struct FrameBufInner {
    // Raw pointer to the heap allocation created in `FrameBuf::new`
    // (a `Box<[u8]>` leaked via `Box::into_raw`; reclaimed in `Drop`).
    data: *mut u8,
    // Length in bytes of the allocation behind `data`.
    capacity: usize,
    // Write watermark: number of initialized bytes from the start of
    // `data`. Published with `Release` (see `store_written`) and read
    // with `Acquire` by readers.
    written: AtomicUsize,
}
70
// SAFETY: `data` points to a heap allocation owned exclusively by this
// struct (freed once in `Drop`); the raw pointer is what blocks the
// auto-impl, not any thread-affine state.
unsafe impl Send for FrameBufInner {}
// SAFETY: cross-thread access is coordinated through the `written`
// atomic; readers only view bytes below the published watermark
// (see `AsRef<[u8]> for FrameBuf`).
unsafe impl Sync for FrameBufInner {}
76
impl Drop for FrameBufInner {
    fn drop(&mut self) {
        // SAFETY: `data` and `capacity` come from `Box::into_raw` on a
        // `Box<[u8]>` of exactly `capacity` bytes (see `FrameBuf::new`),
        // so rebuilding the box here frees that allocation exactly once.
        unsafe {
            let slice = std::ptr::slice_from_raw_parts_mut(self.data, self.capacity);
            drop(Box::from_raw(slice));
        }
    }
}
87
impl FrameBuf {
    /// Allocates a zero-initialized buffer of exactly `size` bytes with
    /// no bytes marked as written yet.
    fn new(size: usize) -> Self {
        let boxed: Box<[u8]> = vec![0u8; size].into_boxed_slice();
        let capacity = boxed.len();
        // Ownership of the allocation moves into the raw pointer; it is
        // reclaimed by `FrameBufInner::drop`.
        let data = Box::into_raw(boxed) as *mut u8;
        Self(Arc::new(FrameBufInner {
            data,
            capacity,
            written: AtomicUsize::new(0),
        }))
    }

    /// Total size of the allocation in bytes.
    fn capacity(&self) -> usize {
        self.0.capacity
    }

    /// Current write watermark, loaded with the caller-chosen ordering.
    fn written(&self, ord: Ordering) -> usize {
        self.0.written.load(ord)
    }

    /// Raw base pointer of the buffer.
    ///
    /// # Safety
    /// The caller must stay within `capacity` bytes and must not create
    /// accesses that overlap the initialized prefix readers may view.
    unsafe fn data_ptr(&self) -> *mut u8 {
        self.0.data
    }

    /// Publishes a new write watermark with `Release` ordering, making
    /// the bytes below it visible to `Acquire` readers.
    ///
    /// # Safety
    /// All bytes in `0..new_written` must already be initialized and
    /// `new_written` must not exceed `capacity`.
    unsafe fn store_written(&self, new_written: usize) {
        self.0.written.store(new_written, Ordering::Release);
    }
}
119
impl AsRef<[u8]> for FrameBuf {
    /// Views the initialized prefix of the buffer as a byte slice.
    fn as_ref(&self) -> &[u8] {
        // Acquire pairs with the Release in `store_written`, so the
        // contents of all bytes below `written` are visible here.
        let written = self.0.written.load(Ordering::Acquire);
        // SAFETY: writers keep `written <= capacity`, and the Acquire
        // load above synchronizes with the publication of those bytes.
        unsafe { std::slice::from_raw_parts(self.0.data, written) }
    }
}
131
/// Shared producer/consumer control state carried by `conducer`.
#[derive(Default, Debug)]
struct FrameState {
    // True once the frame is complete: set by `finish`, or automatically
    // when a write fills the buffer to capacity (see `advance_mut`).
    fin: bool,
    // Set when the producer aborts; consumers surface this error.
    abort: Option<Error>,
}
139
/// Write half of a frame: fills the shared buffer and signals completion
/// or abort through the shared [`FrameState`].
pub struct FrameProducer {
    // Immutable frame metadata (also exposed via `Deref`).
    info: Frame,
    // Producer side of the shared control state.
    state: conducer::Producer<FrameState>,
    // Shared backing storage, also held by every consumer.
    buf: FrameBuf,
}
152
153impl std::ops::Deref for FrameProducer {
154 type Target = Frame;
155
156 fn deref(&self) -> &Self::Target {
157 &self.info
158 }
159}
160
impl FrameProducer {
    /// Creates a producer with a fresh zeroed buffer of `info.size` bytes.
    pub fn new(info: Frame) -> Self {
        let buf = FrameBuf::new(info.size as usize);
        Self {
            info,
            state: conducer::Producer::new(FrameState::default()),
            buf,
        }
    }

    /// Appends `chunk` to the frame.
    ///
    /// Returns `Error::WrongSize` when the chunk does not fit in the
    /// remaining capacity, or the recorded abort error if the frame was
    /// already aborted.
    pub fn write<B: Into<Bytes>>(&mut self, chunk: B) -> Result<()> {
        let chunk = chunk.into();
        if chunk.len() > self.remaining_mut() {
            return Err(Error::WrongSize);
        }
        self.bail_if_aborted()?;
        // Goes through the `BufMut` impl below, which advances the
        // shared `written` watermark.
        self.put_slice(&chunk);
        Ok(())
    }

    /// Marks the frame complete.
    ///
    /// Fails with `Error::WrongSize` unless exactly `size` bytes have
    /// been written, and with the abort/drop error when the shared state
    /// can no longer be modified.
    pub fn finish(&mut self) -> Result<()> {
        let written = self.buf.written(Ordering::Acquire);
        if written != self.buf.capacity() {
            return Err(Error::WrongSize);
        }
        let mut state = self.modify()?;
        state.fin = true;
        Ok(())
    }

    /// Aborts the frame: records `err` in the shared state and closes it
    /// so waiting consumers observe the error.
    pub fn abort(&mut self, err: Error) -> Result<()> {
        let mut guard = self.modify()?;
        guard.abort = Some(err);
        guard.close();
        Ok(())
    }

    /// Creates a consumer that reads this frame from the beginning.
    pub fn consume(&self) -> FrameConsumer {
        FrameConsumer {
            info: self.info.clone(),
            state: self.state.consume(),
            buf: self.buf.clone(),
            read_idx: 0,
        }
    }

    /// Waits on `conducer::Producer::unused` — presumably resolving once
    /// no consumers remain (NOTE(review): confirm against `conducer`
    /// docs). A closed state yields its abort error, or `Error::Dropped`
    /// when none was recorded.
    pub async fn unused(&self) -> Result<()> {
        self.state
            .unused()
            .await
            .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
    }

    // Acquires a write guard on the shared state, translating a closed
    // state into the recorded abort error (or `Error::Dropped`).
    fn modify(&mut self) -> Result<conducer::Mut<'_, FrameState>> {
        self.state
            .write()
            .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
    }

    // Returns the recorded abort error, if any, without taking a write
    // guard.
    fn bail_if_aborted(&self) -> Result<()> {
        let state = self.state.read();
        if let Some(err) = &state.abort {
            return Err(err.clone());
        }
        Ok(())
    }
}
240
// SAFETY: `chunk_mut` hands out only the unwritten tail of the single
// backing allocation, and `advance_mut` asserts the watermark never
// exceeds `capacity`, upholding the `BufMut` contract.
unsafe impl BufMut for FrameProducer {
    fn remaining_mut(&self) -> usize {
        self.buf.capacity() - self.buf.written(Ordering::Acquire)
    }

    fn chunk_mut(&mut self) -> &mut UninitSlice {
        let written = self.buf.written(Ordering::Acquire);
        let cap = self.buf.capacity();
        // SAFETY: `written <= cap`, so `ptr .. ptr + (cap - written)`
        // stays inside the allocation, and it covers only bytes above
        // the published watermark, which readers never touch.
        unsafe {
            let ptr = self.buf.data_ptr().add(written);
            UninitSlice::from_raw_parts_mut(ptr, cap - written)
        }
    }

    unsafe fn advance_mut(&mut self, cnt: usize) {
        let cap = self.buf.capacity();
        // NOTE(review): this load + later store is not an atomic
        // read-modify-write; two cloned producers advancing concurrently
        // could lose an update — confirm single-writer use is intended.
        let prev = self.buf.written(Ordering::Relaxed);
        assert!(
            prev + cnt <= cap,
            "advance_mut past frame.size: prev={prev} cnt={cnt} cap={cap}"
        );
        // SAFETY: `BufMut::advance_mut`'s caller contract guarantees the
        // first `cnt` bytes of `chunk_mut` were initialized, and the
        // assert above keeps the watermark within `capacity`.
        unsafe { self.buf.store_written(prev + cnt) };

        // Publish `fin` as soon as the buffer is completely full.
        if let Ok(mut state) = self.state.write() {
            if prev + cnt == cap {
                state.fin = true;
            }
        }
    }
}
283
284impl Clone for FrameProducer {
285 fn clone(&self) -> Self {
286 Self {
287 info: self.info.clone(),
288 state: self.state.clone(),
289 buf: self.buf.clone(),
290 }
291 }
292}
293
294impl From<Frame> for FrameProducer {
295 fn from(info: Frame) -> Self {
296 FrameProducer::new(info)
297 }
298}
299
/// Read half of a frame; each clone reads independently via its own
/// cursor over the shared buffer.
#[derive(Clone)]
pub struct FrameConsumer {
    // Frame metadata (also exposed via `Deref`).
    info: Frame,
    // Consumer side of the shared control state.
    state: conducer::Consumer<FrameState>,
    // Shared backing storage written by the producer.
    buf: FrameBuf,
    // Index of the next unread byte; cloned consumers keep separate
    // copies of this cursor.
    read_idx: usize,
}
310
311impl std::ops::Deref for FrameConsumer {
312 type Target = Frame;
313
314 fn deref(&self) -> &Self::Target {
315 &self.info
316 }
317}
318
impl FrameConsumer {
    // Polls the shared state with `f`, mapping a closed state to the
    // recorded abort error (or `Error::Dropped`).
    fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R>>
    where
        F: Fn(&conducer::Ref<'_, FrameState>) -> Poll<Result<R>>,
    {
        Poll::Ready(match ready!(self.state.poll(waiter, f)) {
            Ok(res) => res,
            Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
        })
    }

    // Returns the bytes published since `read_idx`, or `None` if there
    // are none yet. The returned `Bytes` keeps the allocation alive by
    // owning a `FrameBuf` clone.
    fn snapshot(&self, read_idx: usize) -> Option<Bytes> {
        let written = self.buf.written(Ordering::Acquire);
        if written > read_idx {
            Some(Bytes::from_owner(self.buf.clone()).slice(read_idx..written))
        } else {
            None
        }
    }

    /// Polls until the frame is finished, then yields every byte this
    /// consumer has not yet read (possibly empty). Pending until `fin`;
    /// fails with the abort error if the producer aborted.
    pub fn poll_read_all(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Bytes>> {
        let read_idx = self.read_idx;
        let res = ready!(self.poll(waiter, |state| {
            if state.fin {
                return Poll::Ready(Ok(()));
            }
            if let Some(err) = &state.abort {
                return Poll::Ready(Err(err.clone()));
            }
            Poll::Pending
        }));
        match res {
            Ok(()) => {
                // Empty frames (or an already-drained cursor) still yield
                // an empty `Bytes` rather than an error.
                let bytes = self
                    .snapshot(read_idx)
                    .unwrap_or_else(|| Bytes::from_owner(self.buf.clone()).slice(read_idx..read_idx));
                self.read_idx = self.buf.capacity();
                Poll::Ready(Ok(bytes))
            }
            Err(e) => Poll::Ready(Err(e)),
        }
    }

    /// Async wrapper over [`Self::poll_read_all`].
    pub async fn read_all(&mut self) -> Result<Bytes> {
        conducer::wait(|waiter| self.poll_read_all(waiter)).await
    }

    /// Like [`Self::poll_read_all`], but wraps the result in a `Vec`
    /// (empty when no bytes remain).
    pub fn poll_read_all_chunks(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Vec<Bytes>>> {
        let bytes = ready!(self.poll_read_all(waiter)?);
        Poll::Ready(Ok(if bytes.is_empty() { Vec::new() } else { vec![bytes] }))
    }

    /// Polls for the next unread span of bytes.
    ///
    /// Yields `Some(bytes)` when new data is available, `None` once the
    /// frame is finished and fully consumed, or the abort error.
    pub fn poll_read_chunk(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<Bytes>>> {
        let read_idx = self.read_idx;
        let res = ready!(self.poll(waiter, |state| {
            // Check for data before `fin`: a finished frame with unread
            // bytes still yields those bytes before reporting `None`.
            let written = self.buf.written(Ordering::Acquire);
            if written > read_idx {
                return Poll::Ready(Ok(Some(written)));
            }
            if state.fin {
                return Poll::Ready(Ok(None));
            }
            if let Some(err) = &state.abort {
                return Poll::Ready(Err(err.clone()));
            }
            Poll::Pending
        }));
        match res {
            Ok(Some(written)) => {
                let bytes = Bytes::from_owner(self.buf.clone()).slice(read_idx..written);
                self.read_idx = written;
                Poll::Ready(Ok(Some(bytes)))
            }
            Ok(None) => Poll::Ready(Ok(None)),
            Err(e) => Poll::Ready(Err(e)),
        }
    }

    /// Async wrapper over [`Self::poll_read_chunk`].
    pub async fn read_chunk(&mut self) -> Result<Option<Bytes>> {
        conducer::wait(|waiter| self.poll_read_chunk(waiter)).await
    }

    /// Polls for the next chunk as a `Vec`: one element while data keeps
    /// arriving, empty once the frame is done.
    pub fn poll_read_chunks(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Vec<Bytes>>> {
        match ready!(self.poll_read_chunk(waiter)?) {
            Some(b) => Poll::Ready(Ok(vec![b])),
            None => Poll::Ready(Ok(Vec::new())),
        }
    }

    /// Async wrapper over [`Self::poll_read_chunks`].
    pub async fn read_chunks(&mut self) -> Result<Vec<Bytes>> {
        conducer::wait(|waiter| self.poll_read_chunks(waiter)).await
    }
}
432
#[cfg(test)]
mod test {
    //! Unit tests; `now_or_never` drives futures synchronously, so any
    //! `Pending` poll shows up as `None`.
    use super::*;
    use futures::FutureExt;

    // One write that exactly fills the frame, then read it all back.
    #[test]
    fn single_chunk_roundtrip() {
        let mut producer = Frame { size: 5 }.produce();
        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"hello"));
    }

    // Two writes are concatenated when read back in one shot.
    #[test]
    fn multi_chunk_read_all() {
        let mut producer = Frame { size: 10 }.produce();
        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.write(Bytes::from_static(b"world")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"helloworld"));
    }

    // Chunked reads track the producer's writes, ending with `None`.
    #[test]
    fn read_chunk_sequential() {
        let mut producer = Frame { size: 10 }.produce();
        producer.write(Bytes::from_static(b"hello")).unwrap();
        let mut consumer = producer.consume();
        let c1 = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c1, Some(Bytes::from_static(b"hello")));

        producer.write(Bytes::from_static(b"world")).unwrap();
        producer.finish().unwrap();

        let c2 = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c2, Some(Bytes::from_static(b"world")));
        let c3 = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c3, None);
    }

    // `read_chunks` on a finished frame returns a single combined chunk.
    #[test]
    fn read_all_chunks() {
        let mut producer = Frame { size: 10 }.produce();
        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.write(Bytes::from_static(b"world")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let chunks = consumer.read_chunks().now_or_never().unwrap().unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], Bytes::from_static(b"helloworld"));
    }

    // `finish` rejects a frame that is not completely filled.
    #[test]
    fn finish_checks_remaining() {
        let mut producer = Frame { size: 5 }.produce();
        producer.write(Bytes::from_static(b"hi")).unwrap();
        let err = producer.finish().unwrap_err();
        assert!(matches!(err, Error::WrongSize));
    }

    // `write` rejects a chunk larger than the remaining capacity.
    #[test]
    fn write_too_many_bytes() {
        let mut producer = Frame { size: 3 }.produce();
        let err = producer.write(Bytes::from_static(b"toolong")).unwrap_err();
        assert!(matches!(err, Error::WrongSize));
    }

    // An abort on the producer surfaces as the same error on the consumer.
    #[test]
    fn abort_propagates() {
        let mut producer = Frame { size: 5 }.produce();
        let mut consumer = producer.consume();
        producer.abort(Error::Cancel).unwrap();

        let err = consumer.read_all().now_or_never().unwrap().unwrap_err();
        assert!(matches!(err, Error::Cancel));
    }

    // A zero-size frame finishes immediately and reads back empty.
    #[test]
    fn empty_frame() {
        let mut producer = Frame { size: 0 }.produce();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::new());
    }

    // `read_all` is Pending until the producer finishes the frame.
    #[tokio::test]
    async fn pending_then_ready() {
        let mut producer = Frame { size: 5 }.produce();
        let mut consumer = producer.consume();

        assert!(consumer.read_all().now_or_never().is_none());

        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.finish().unwrap();

        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"hello"));
    }

    // Writing through the `BufMut` interface decrements `remaining_mut`.
    #[test]
    fn buf_mut_roundtrip() {
        let mut producer = Frame { size: 12 }.produce();
        assert_eq!(producer.remaining_mut(), 12);
        producer.put_slice(b"hello");
        assert_eq!(producer.remaining_mut(), 7);
        producer.put_slice(b" world!");
        assert_eq!(producer.remaining_mut(), 0);
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"hello world!"));
    }

    // `advance_mut` past capacity must hit the assert, not write OOB.
    #[test]
    #[should_panic(expected = "advance_mut past frame.size")]
    fn buf_mut_advance_past_capacity_panics() {
        let mut producer = Frame { size: 4 }.produce();
        unsafe { producer.advance_mut(5) };
    }

    // Chunked reads interleave with writes and pend when drained.
    #[test]
    fn read_chunk_streams_partial_writes() {
        let mut producer = Frame { size: 6 }.produce();
        let mut consumer = producer.consume();

        producer.write(Bytes::from_static(b"foo")).unwrap();
        let c1 = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c1, Some(Bytes::from_static(b"foo")));

        assert!(consumer.read_chunk().now_or_never().is_none());

        producer.write(Bytes::from_static(b"bar")).unwrap();
        producer.finish().unwrap();
        let c2 = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c2, Some(Bytes::from_static(b"bar")));
        let c3 = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c3, None);
    }

    // Cloning a consumer copies its cursor; both see subsequent data.
    #[test]
    fn cloned_consumer_independent_cursor() {
        let mut producer = Frame { size: 10 }.produce();
        let mut c1 = producer.consume();
        producer.write(Bytes::from_static(b"hello")).unwrap();

        let chunk = c1.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(chunk, Some(Bytes::from_static(b"hello")));
        let mut c2 = c1.clone();

        producer.write(Bytes::from_static(b"world")).unwrap();
        producer.finish().unwrap();

        let chunk = c1.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(chunk, Some(Bytes::from_static(b"world")));
        let chunk = c2.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(chunk, Some(Bytes::from_static(b"world")));
    }
}