selium_messaging/lib.rs

#![deny(missing_docs)]
//! In-memory asynchronous channel implementation with configurable backpressure.
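//!
//! A minimal round-trip sketch (illustrative only; it assumes a Tokio runtime and that
//! the crate is consumed as `selium_messaging`):
//!
//! ```ignore
//! use selium_messaging::Channel;
//! use tokio::io::{AsyncReadExt, AsyncWriteExt};
//!
//! # async fn demo() -> std::io::Result<()> {
//! let channel = Channel::new(1024);
//! let mut writer = channel.new_writer();
//! let mut reader = channel.new_strong_reader();
//!
//! writer.write_all(b"hello").await?;
//!
//! let mut buf = [0u8; 5];
//! reader.read_exact(&mut buf).await?;
//! assert_eq!(&buf, b"hello");
//! # Ok(())
//! # }
//! ```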

use std::{cell::UnsafeCell, collections::VecDeque, io, sync::Arc, task::Waker};

#[cfg(feature = "loom")]
use loom::sync::{
    Mutex, RwLock,
    atomic::{AtomicBool, AtomicU64, Ordering},
};
#[cfg(not(feature = "loom"))]
use std::sync::{
    Mutex, RwLock,
    atomic::{AtomicBool, AtomicU64, Ordering},
};

use stable_vec::StableVec;
use thiserror::Error;
use tracing::{Span, debug, field::Empty, instrument};

mod driver;
mod id_factory;
mod reader;
mod writer;

pub use driver::{ChannelDriver, ChannelStrongIoDriver, ChannelWeakIoDriver};
pub use reader::{Reader, StrongReader, WeakReader};
pub use writer::{StrongWriter, WeakWriter, Writer};

use crate::id_factory::{Id, IdFactory};

type Result<T> = std::result::Result<T, ChannelError>;

/// Backpressure policy for writers when the buffer is full.
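///
/// A sketch of selecting a policy (illustrative only; [`Channel::new`] defaults to `Park`):
///
/// ```ignore
/// use selium_messaging::{Backpressure, Channel};
///
/// // Writers on this channel report `Ok(0)` instead of parking when the buffer is full.
/// let channel = Channel::with_parameters(1024, Backpressure::Drop);
/// ```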
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Backpressure {
    /// Writers park (return `Poll::Pending`) until space is available.
    Park,
    /// Writers drop data when no space is available (return `Ok(0)`).
    Drop,
}

/// Intermediate storage backing every [`Channel`].
///
/// # Safety
/// The ring buffer maintains monotonically increasing cursors and is written to
/// exclusively through the channel’s slice reservation helpers. Those helpers
/// ensure writers and readers never alias; the raw pointer copies inside
/// [`RingBuffer::read`] and [`RingBuffer::write`] therefore rely on the caller
/// upholding those preconditions. All public API entry points funnel through
/// the safe [`Channel`] methods, so additional unsafe usage must preserve the
/// same guarantees.
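///
/// # Index mapping (illustrative)
///
/// [`RingBuffer::new`] rounds the requested size up to a power of two, so
/// `mask = size - 1` turns a monotonically increasing position into a buffer
/// index with a single AND. For example, with `size = 8` (`mask = 7`),
/// position `13` maps to index `13 & 7 = 5`.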
struct RingBuffer {
    /// Internal storage for the buffer
    buf: UnsafeCell<Box<[u8]>>,
    /// Size of `buf`
    size: usize,
    /// Used to convert an incremental position into a valid index of `buf`
    mask: u64,
}

#[derive(Clone)]
struct FrameMeta {
    start: u64,
    len: u64,
    writer_id: u16,
}

/// In-memory, asynchronous, many-to-many byte channel with explicit backpressure semantics.
///
/// Writers and readers operate on a shared ring buffer while maintaining their own
/// monotonic cursors. Reserving slices (writers) or releasing read positions (readers)
/// is the only way to advance those cursors, ensuring deterministic progress and
/// preventing aliasing of the underlying buffer.
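///
/// A minimal wiring sketch (illustrative only; assumes a Tokio runtime and the
/// `selium_messaging` crate name):
///
/// ```ignore
/// use tokio::io::{AsyncReadExt, AsyncWriteExt};
///
/// # async fn demo() -> std::io::Result<()> {
/// let channel = selium_messaging::Channel::new(4096);
///
/// // Any number of writers and readers may share the same channel; each keeps
/// // its own cursor into the ring buffer.
/// let mut writer = channel.new_writer();
/// let mut strong = channel.new_strong_reader();
/// let mut weak = channel.new_weak_reader();
///
/// writer.write_all(b"payload").await?;
///
/// let mut buf = [0u8; 7];
/// strong.read_exact(&mut buf).await?;
/// weak.read_exact(&mut buf).await?;
/// # Ok(())
/// # }
/// ```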
pub struct Channel {
    /// Underlying data store
    buf: RingBuffer,
    /// Queue of Wakers to be woken once the given position is reached
    queue: Mutex<Vec<(u64, Waker)>>,
    /// Position of each `StrongReader`
    heads: RwLock<StableVec<Arc<AtomicU64>>>,
    /// Position of each `Writer`
    tails: RwLock<StableVec<Arc<AtomicU64>>>,
    /// Writer Id generator
    idf: Arc<IdFactory>,
    /// Frame metadata keyed by start position.
    frames: Mutex<VecDeque<FrameMeta>>,
    /// Maximum position for removed tails.
    ///
    /// This is needed when `tails` is empty but we still need to know where the tail is.
    tail_cache: AtomicU64,
    /// Next position that a `Writer` can write to
    next_tail: AtomicU64,
    /// Whether the channel is terminated
    terminated: AtomicBool,
    /// Whether the channel is draining
    draining: AtomicBool,
    /// Backpressure behavior when writers encounter a full buffer
    backpressure: Backpressure,
    /// Tracing span for instrumentation
    span: Span,
}

/// Channel operation errors.
#[derive(Error, Debug, PartialEq)]
pub enum ChannelError {
    /// Reader was too slow and data was dropped; the value is the oldest position still available in the buffer.
108    #[error("Reader was too slow and got left behind")]
109    ReaderBehind(u64),
110    /// User asked channel to terminate whilst draining.
111    #[error("Cannot terminate channel that is draining")]
112    TerminateDraining,
113    /// User asked channel to drain after terminating.
114    #[error("Cannot drain a terminated channel")]
115    DrainTerminated,
116    /// Generic IO error surfaced by channel operations.
117    #[error("io error: {0}")]
118    Io(String),
119}
120
121impl From<io::Error> for ChannelError {
122    fn from(value: io::Error) -> Self {
123        Self::Io(value.to_string())
124    }
125}
126
127impl RingBuffer {
128    fn new(mut size: usize) -> Self {
129        size = size.next_power_of_two();
130        // Allocate a buffer with an initialized length of `size` bytes.
131        // Using `with_capacity` would create a zero-length boxed slice,
132        // leading to out-of-bounds pointer arithmetic and UB during reads/writes.
133        let buf: Vec<u8> = vec![0u8; size];
134
135        Self {
136            buf: UnsafeCell::new(buf.into_boxed_slice()),
137            size,
138            mask: (size - 1) as u64,
139        }
140    }
141
    /// Reads from `self.buf` into `dst` until `dst` is full.
    ///
    /// If `dst` is larger than the buffer, the copy is clamped to `self.size`
    /// bytes and the remainder of `dst` is left untouched.
    ///
    /// # Safety
    /// The caller must ensure no concurrent writer is touching the overlapping region.
    unsafe fn read(&self, dst: &mut [u8], pos: u64) {
        let dstlen = dst.len().min(self.size);

        // Convert position to array index
        let idx: usize = (pos & self.mask)
            .try_into()
            .expect("pointer size less than 64b");

        let taillen = dstlen.min(self.size - idx);
        let headlen = idx.min(dstlen - taillen);

        // Get a raw pointer to the start of the underlying byte slice
        let src = unsafe { (&mut *self.buf.get()).as_mut_ptr() };
        let dst = dst.as_mut_ptr();

        // Write tail to dst
        unsafe {
            src.add(idx).copy_to_nonoverlapping(dst, taillen);
        }

        if headlen > 0 {
            // Write head to dst
            unsafe {
                src.copy_to_nonoverlapping(dst.add(taillen), headlen);
            }
        }
    }

    /// Writes `src` into `self.buf` until `src` is exhausted.
    ///
    /// If `src` is larger than the buffer, the copy is clamped to `self.size`
    /// bytes; the excess input is ignored.
    ///
    /// # Safety
    /// The caller must ensure no concurrent reader consumes the overlapping region.
    unsafe fn write(&self, src: &[u8], pos: u64) {
        let srclen = src.len().min(self.size);

        // Convert position to array index
        let idx: usize = (pos & self.mask)
            .try_into()
            .expect("pointer size less than 64b");

        let taillen = srclen.min(self.size - idx);
        let headlen = idx.min(srclen - taillen);

        let src = src.as_ptr();
        // Get a raw pointer to the start of the underlying byte slice
        let dst = unsafe { (&mut *self.buf.get()).as_mut_ptr() };

        // Write tail to dst
        unsafe {
            src.copy_to_nonoverlapping(dst.add(idx), taillen);
        }

        if headlen > 0 {
            // Write head to dst
            unsafe {
                src.add(taillen).copy_to_nonoverlapping(dst, headlen);
            }
        }
    }
}

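// SAFETY: the interior `UnsafeCell` is only accessed through `RingBuffer::read` and
// `RingBuffer::write`, whose callers uphold the non-aliasing contract documented on
// [`RingBuffer`], so sharing the buffer between threads is sound.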
unsafe impl Sync for RingBuffer {}

impl Channel {
    /// Create a channel with the provided capacity in bytes.
    pub fn new(size: usize) -> Arc<Self> {
        Self::with_parameters(size, Backpressure::Park)
    }

    /// Create a channel with the given parameters.
    #[instrument(name = "Channel", skip_all, fields(ptr = Empty))]
    pub fn with_parameters(size: usize, backpressure: Backpressure) -> Arc<Self> {
        let this = Arc::new(Self {
            buf: RingBuffer::new(size),
            queue: Mutex::new(Vec::new()),
            heads: RwLock::new(StableVec::new()),
            tails: RwLock::new(StableVec::new()),
            idf: IdFactory::new(),
            frames: Mutex::new(VecDeque::new()),
            tail_cache: AtomicU64::new(0),
            next_tail: AtomicU64::new(0),
            terminated: AtomicBool::new(false),
            draining: AtomicBool::new(false),
            backpressure,
            span: Span::current(),
        });

        this.span
            .record("ptr", format_args!("{:p}", this.as_ref() as *const _));
        debug!("create channel");

        this
    }

    fn get_head(&self) -> Option<u64> {
        let heads = self
            .heads
            .read()
            .unwrap_or_else(|poison| poison.into_inner());
        self.get_head_locked(&heads)
    }

    fn get_head_locked(&self, heads: &StableVec<Arc<AtomicU64>>) -> Option<u64> {
        heads.iter().map(|(_, t)| t.load(Ordering::Acquire)).min()
    }

    fn get_tail(&self) -> u64 {
        let tails = self
            .tails
            .read()
            .unwrap_or_else(|poison| poison.into_inner());
        self.get_tail_locked(&tails)
    }

    fn get_tail_locked(&self, tails: &StableVec<Arc<AtomicU64>>) -> u64 {
        tails
            .iter()
            .map(|(_, t)| t.load(Ordering::Acquire))
            .min()
            .unwrap_or(self.tail_cache.load(Ordering::Acquire))
    }

    fn reader_start_pos(&self, head: Option<u64>) -> u64 {
        if let Some(head) = head {
            return head;
        }

        let tail = self.get_tail();
        let floor = tail.saturating_sub(self.buf.size as u64);
        let frames = self
            .frames
            .lock()
            .unwrap_or_else(|poison| poison.into_inner());

        frames
            .iter()
            .find(|frame| frame.start >= floor)
            .map(|frame| frame.start)
            .unwrap_or(tail)
    }

    fn remove_head(&self, idx: usize) {
        self.heads
            .write()
            .unwrap_or_else(|poison| poison.into_inner())
            .remove(idx);
        self.prune_frames();
    }

    fn remove_tail(&self, idx: usize) {
        let mut tails = self
            .tails
            .write()
            .unwrap_or_else(|poison| poison.into_inner());
        if let Some(tail) = tails.remove(idx) {
            self.tail_cache
                .fetch_max(tail.load(Ordering::Acquire), Ordering::AcqRel);
        }
    }

    fn register_frame(&self, start: u64, len: u64, writer_id: u16) {
        let mut frames = self
            .frames
            .lock()
            .unwrap_or_else(|poison| poison.into_inner());
        frames.push_back(FrameMeta {
            start,
            len,
            writer_id,
        });
    }

    fn frame_for(&self, pos: u64) -> Option<FrameMeta> {
        let frames = self
            .frames
            .lock()
            .unwrap_or_else(|poison| poison.into_inner());
        frames.iter().find(|frame| frame.start == pos).cloned()
    }

    fn frame_from(&self, pos: u64) -> Option<FrameMeta> {
        let frames = self
            .frames
            .lock()
            .unwrap_or_else(|poison| poison.into_inner());
        frames.iter().find(|frame| frame.start >= pos).cloned()
    }

    fn prune_frames(&self) {
        let head = match self.get_head() {
            Some(head) => head,
            None => self.get_tail().saturating_sub(self.buf.size as u64),
        };

        let mut frames = self
            .frames
            .lock()
            .unwrap_or_else(|poison| poison.into_inner());
        while let Some(frame) = frames.front() {
            if frame.start + frame.len <= head {
                frames.pop_front();
            } else {
                break;
            }
        }
    }

    /// Remaining writable capacity before the tail would overrun the head.
    fn writable_size(&self, pos: u64) -> u64 {
        (self.buf.size as u64).saturating_sub(pos - self.get_head().unwrap_or(pos))
    }

    /// Copy `buf` into the ring buffer at logical position `pos` without advancing cursors.
    fn write(&self, pos: u64, buf: &[u8]) -> usize {
        // Calculate max buf length to prevent overrunning the head
        let len = (buf.len() as u64).min(self.writable_size(pos));

        // No point proceeding if no space
        if len == 0 {
            return 0;
        }

        let ulen: usize = len.try_into().expect("pointer size less than 64b");

        // Safety: we prevent overrunning the head position, however the caller must
        // ensure that they abide by the position and length provided by `reserve_slice()`.
        unsafe { self.buf.write(&buf[..ulen], pos) };

        ulen
    }

    /// Read into `buf`, returning an error if the cursor has been lapped by writers.
    fn read(&self, pos: u64, buf: &mut [u8]) -> Result<usize> {
        // Ensure the given head position hasn't been overwritten
        let tail = self.get_tail();
        if pos + (self.buf.size as u64) < tail {
            return Err(ChannelError::ReaderBehind(tail - self.buf.size as u64));
        }

        Ok(unsafe { self.read_unsafe(pos, buf) })
    }

    /// Read into `buf` without validating that `pos` has not been overwritten.
    ///
    /// # Safety
    /// The caller must ensure writers have not lapped `pos`, i.e. the bytes at `pos` are still present in the buffer.
    unsafe fn read_unsafe(&self, pos: u64, buf: &mut [u8]) -> usize {
        // Calculate max buf length to prevent overrunning the tail
        let len = (buf.len() as u64).min(self.get_tail().saturating_sub(pos));

        // No point proceeding if there is nothing to read
        if len == 0 {
            return 0;
        }

        let ulen: usize = len.try_into().expect("pointer size less than 64b");

        // Safety: we never allow `buf` to overrun the tail position, so races are
        // impossible unless the head position has been overrun. Head overruns are
        // impossible for strong readers, though this function is inappropriate for
        // weak readers.
        unsafe { self.buf.read(&mut buf[..ulen], pos) };

        ulen
    }

    /// Queue a `Waker` to be woken once `pos` has been reached by either the head or tail.
    ///
    /// Note: If the given position is less than the head position, it will never be woken.
    #[instrument(parent = &self.span, skip(self, waker))]
    fn enqueue(&self, pos: u64, waker: Waker) {
        debug!(pos, "channel enqueue");
        self.queue
            .lock()
            .unwrap_or_else(|poison| poison.into_inner())
            .push((pos, waker));
    }

    /// Wake any writers whose reserved spans are now safe to fill.
    #[instrument(parent = &self.span, skip(self))]
    fn schedule_writers(&self) {
        // If nothing to schedule, exit early
        if self
            .tails
            .read()
            .unwrap_or_else(|poison| poison.into_inner())
            .is_empty()
        {
            return;
        }

        let tail_pos = self.get_tail();
        let head_pos = self.get_head().unwrap_or(tail_pos);
        let mut queue = self
            .queue
            .lock()
            .unwrap_or_else(|poison| poison.into_inner());

        debug!(
            queued = queue.len(),
            head_pos, tail_pos, "channel schedule_writers"
        );

        // Dequeue and wake each Waker in a writable position
        queue
            .extract_if(.., |(pos, _)| {
                let wake = *pos < (head_pos + self.buf.size as u64) && *pos >= tail_pos;
                if wake {
                    debug!(pos, "channel wake writer");
                }
                wake
            })
            .for_each(|(_, waker)| waker.wake());
    }

    /// Wake any readers that now have data available.
    #[instrument(parent = &self.span, skip(self))]
    fn schedule_readers(&self) {
        let tail_pos = self.get_tail();
        let mut queue = self
            .queue
            .lock()
            .unwrap_or_else(|poison| poison.into_inner());

        debug!(queued = queue.len(), tail_pos, "channel schedule_readers");
        // Dequeue and wake each Waker in a readable position
        queue
            .extract_if(.., |(pos, _)| {
                let wake = *pos < tail_pos;
                if wake {
                    debug!(pos, "channel wake reader");
                }
                wake
            })
            .for_each(|(_, waker)| waker.wake());
    }

    /// Create a new writer to push bytes to the channel.
    ///
    /// # Concurrency
    /// Writers can work somewhat concurrently, given enough buffer space. This works by
    /// allowing multiple writers to reserve and write contiguous buffer "slices",
    /// provided that those slices do not overwrite any part of the buffer being consumed
    /// by a strong reader. Readers are not aware of slices and will read the buffer in
    /// the order that the slices were reserved. If a writer hangs while writing a slice,
    /// no subsequent slices will be read.
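    ///
    /// A sketch of two writers sharing one channel (illustrative only):
    ///
    /// ```ignore
    /// use tokio::io::AsyncWriteExt;
    ///
    /// # async fn demo(channel: std::sync::Arc<selium_messaging::Channel>) -> std::io::Result<()> {
    /// let mut first = channel.new_writer();
    /// let mut second = channel.new_writer();
    ///
    /// // Each write reserves its own contiguous slice; readers consume the
    /// // slices in the order they were reserved.
    /// first.write_all(b"one").await?;
    /// second.write_all(b"two").await?;
    /// # Ok(())
    /// # }
    /// ```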
    pub fn new_writer(self: &Arc<Channel>) -> Writer {
        Writer::Strong(self.new_strong_writer())
    }

    /// Create a new strong writer.
    pub fn new_strong_writer(self: &Arc<Channel>) -> StrongWriter {
        self.new_strong_writer_with_id(self.idf.generate())
    }

    fn new_strong_writer_with_id(self: &Arc<Channel>, id: Id) -> StrongWriter {
        let mut tails = self
            .tails
            .write()
            .unwrap_or_else(|poison| poison.into_inner());
        let pos = Arc::new(AtomicU64::new(self.get_tail_locked(&tails)));
        let pos_id = tails.push(pos.clone());
        drop(tails);

        StrongWriter::new(id, self.clone(), pos, Some(pos_id))
    }

    /// Create a new weak writer that acquires tail slots on demand.
    pub fn new_weak_writer(self: &Arc<Channel>) -> WeakWriter {
        WeakWriter::new(self.idf.generate(), self.clone())
    }

    /// Create a new reader that applies backpressure to writers if they are trying to
    /// overwrite the reader's log position.
    ///
    /// Note that the entire channel will be as slow as the *slowest* strong reader.
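    ///
    /// A sketch of lossless consumption (illustrative only):
    ///
    /// ```ignore
    /// use tokio::io::AsyncReadExt;
    ///
    /// # async fn demo(channel: std::sync::Arc<selium_messaging::Channel>) -> std::io::Result<()> {
    /// // Creating the reader before any data is written means nothing is missed;
    /// // writers park (or drop, depending on the backpressure policy) rather than
    /// // overwrite bytes this reader has not consumed yet.
    /// let mut reader = channel.new_strong_reader();
    ///
    /// let mut frame = [0u8; 16];
    /// reader.read_exact(&mut frame).await?;
    /// # Ok(())
    /// # }
    /// ```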
    pub fn new_strong_reader(self: &Arc<Channel>) -> StrongReader {
        let mut heads = self
            .heads
            .write()
            .unwrap_or_else(|poison| poison.into_inner());
        let pos = Arc::new(AtomicU64::new(
            self.reader_start_pos(self.get_head_locked(&heads)),
        ));
        let id = heads.push(pos.clone());
        drop(heads);

        StrongReader::new(self.clone(), pos, id)
    }

    /// Create a new reader that can be overtaken by writers if the reader is too slow.
    ///
    /// If you require a reader that cannot lose any data, create a `StrongReader` instead.
    pub fn new_weak_reader(self: &Arc<Channel>) -> WeakReader {
        WeakReader::new(self.clone(), self.reader_start_pos(self.get_head()))
    }

    /// Reserve a tail position of given length to write to.
    ///
    /// # Caution!
    /// Failing to write the entire slice will result in *permanent backpressure.*
    pub fn reserve_slice(&self, len: u64) -> u64 {
        self.next_tail.fetch_add(len, Ordering::SeqCst)
    }

    /// Safely terminate this channel.
    ///
    /// This will cause any readers or writers to return `io::ErrorKind::ConnectionAborted`.
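    ///
    /// A sketch of the interaction with [`Channel::drain`] (illustrative only):
    ///
    /// ```ignore
    /// let channel = selium_messaging::Channel::new(1024);
    /// channel.terminate().unwrap();
    ///
    /// // A terminated channel can no longer start draining.
    /// assert!(channel.drain().is_err());
    /// ```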
    #[instrument(parent = &self.span, skip(self))]
    pub fn terminate(&self) -> Result<()> {
        debug!("terminate channel");

        if self.draining.load(Ordering::Acquire) {
            return Err(ChannelError::TerminateDraining);
        }
        self.terminated.store(true, Ordering::Release);

        // Notify all queued wakers so they don't hang indefinitely
        self.queue
            .lock()
            .unwrap_or_else(|poison| poison.into_inner())
            .drain(..)
            .for_each(|(_, waker)| waker.wake());

        Ok(())
    }

    /// Start draining the channel.
    ///
    /// When draining, a channel doesn't accept any new frames from `Writer`s, and rejects reads
    /// once they catch up to the buffer tail.
    #[instrument(parent = &self.span, skip(self))]
    pub fn drain(&self) -> Result<()> {
        debug!("start draining channel");

        if self.terminated.load(Ordering::Acquire) {
            Err(ChannelError::DrainTerminated)
        } else {
            self.draining.store(true, Ordering::Release);
            Ok(())
        }
    }
}

impl Drop for Channel {
    fn drop(&mut self) {
        let _ = self.terminate();
    }
}

#[cfg(all(test, not(feature = "loom")))]
mod tests {
    use std::{
        pin::pin,
        task::{Context, Poll},
    };

    use futures::task::{noop_waker, noop_waker_ref};
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    use super::*;

    #[test]
    fn drop_backpressure_returns_zero_when_full() {
        use std::task::Poll;

        use futures::task::{Context, noop_waker};

        let channel = Channel::with_parameters(2, Backpressure::Drop);
        let _reader_guard = channel.new_strong_reader();
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let writer = channel.new_writer();
        let mut pinned = pin!(writer);

        match pinned.as_mut().poll_write(&mut cx, &[1, 2]) {
            Poll::Ready(Ok(2)) => {}
            other => panic!("unexpected poll result: {:?}", other),
        }
        match pinned.as_mut().poll_write(&mut cx, &[3, 4]) {
            Poll::Ready(Ok(0)) => {}
            other => panic!("unexpected poll result: {:?}", other),
        }
    }

    #[test]
    fn drop_backpressure_leaves_existing_bytes_intact() {
        use std::task::Poll;

        use futures::task::{Context, noop_waker};

        let channel = Channel::with_parameters(2, Backpressure::Drop);
        let reader = channel.new_strong_reader();
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let writer = channel.new_writer();
        let mut pinned_writer = pin!(writer);

        match pinned_writer.as_mut().poll_write(&mut cx, &[7, 8]) {
            Poll::Ready(Ok(2)) => {}
            other => panic!("unexpected poll result: {:?}", other),
        }
        match pinned_writer.as_mut().poll_write(&mut cx, &[9, 10, 11]) {
            Poll::Ready(Ok(0)) => {}
            other => panic!("unexpected poll result: {:?}", other),
        }

        let mut pinned_reader = pin!(reader);
        let mut buf = [0u8; 2];
        let mut rb = ReadBuf::new(&mut buf);
        match pinned_reader.as_mut().poll_read(&mut cx, &mut rb) {
            Poll::Ready(Ok(())) if rb.filled().len() == 2 => {}
            other => panic!("unexpected poll result: {:?}", other),
        }
        assert_eq!(&buf, &[7, 8]);
    }

    #[test]
    fn ring_buffer_write_clamps_to_capacity() {
        let ring = RingBuffer::new(8);
        let data = vec![42u8; 32];
        unsafe { ring.write(&data, 0) };
        let mut buf = [0u8; 8];
        unsafe { ring.read(&mut buf, 0) };
        assert!(buf.iter().all(|b| *b == 42));
    }

    #[test]
    fn ring_buffer_read_large_destination_stays_within_bounds() {
        let ring = RingBuffer::new(8);
        let data: Vec<u8> = (0u8..8).collect();
        unsafe { ring.write(&data, 0) };
        let mut dst = [0u8; 16];
        unsafe { ring.read(&mut dst, 0) };
        assert_eq!(&dst[..8], &data[..]);
        assert!(dst[8..].iter().all(|b| *b == 0));
    }

    fn new_buf(size: u8) -> RingBuffer {
        let buf = RingBuffer::new(size as usize);
        unsafe {
            // Write test data directly into the underlying byte slice
            let dst = (&mut *buf.buf.get()).as_mut_ptr();
            dst.copy_from_nonoverlapping((0..size).collect::<Vec<_>>().as_ptr(), size as usize);
        }
        buf
    }

    #[test]
    fn test_ring_read_small() {
        let mut dst = [0; 2];

        let buf = new_buf(4);
        unsafe { buf.read(&mut dst, 1) };

        assert_eq!(dst, [1, 2]);
    }

    #[test]
    fn test_ring_read_large() {
        let mut dst = [0; 4];

        let buf = new_buf(4);
        unsafe { buf.read(&mut dst, 1) };

        assert_eq!(dst, [1, 2, 3, 0]);
    }

    #[test]
    fn test_ring_read_too_large() {
        let mut dst = [0; 5];

        let buf = new_buf(4);
        unsafe { buf.read(&mut dst, 1) };

        assert_eq!(&dst[..4], [1, 2, 3, 0].as_ref());
        assert_eq!(dst[4], 0);
    }

    #[test]
    fn test_ring_write_small() {
        let src = [4; 2];

        let buf = new_buf(4);
        unsafe { buf.write(&src, 1) };

        let mut dst = [0; 4];
        unsafe {
            let src = (&mut *buf.buf.get()).as_mut_ptr();
            src.copy_to_nonoverlapping(dst.as_mut_ptr(), 4)
        };
        assert_eq!(dst, [0, 4, 4, 3]);
    }

    #[test]
    fn test_ring_write_large() {
        let src = [4; 4];

        let buf = new_buf(4);
        unsafe { buf.write(&src, 1) };

        let mut dst = [0; 4];
        unsafe {
            let src = (&mut *buf.buf.get()).as_mut_ptr();
            src.copy_to_nonoverlapping(dst.as_mut_ptr(), 4)
        };
        assert_eq!(dst, [4, 4, 4, 4]);
    }

    #[test]
    fn test_ring_write_too_large() {
        let src = [4; 5];

        let buf = new_buf(4);
        unsafe { buf.write(&src, 1) };

        let mut dst = [0; 4];
        unsafe {
            let src = (&mut *buf.buf.get()).as_mut_ptr();
            src.copy_to_nonoverlapping(dst.as_mut_ptr(), 4)
        };
        assert_eq!(dst, [4, 4, 4, 4]);
    }

    #[test]
    fn test_channel_write() {
        let channel = Channel::new(4);
        assert_eq!(channel.write(0, &[]), 0);
        assert_eq!(channel.write(4, &[0; 3]), 3);
        assert_eq!(channel.write(1, &[0; 4]), 4);
        assert_eq!(channel.write(1, &[0; 5]), 4);

        channel
            .heads
            .write()
            .unwrap()
            .push(Arc::new(AtomicU64::new(1)));

        assert_eq!(channel.write(5, &[0; 3]), 0);
        assert_eq!(channel.write(3, &[0; 3]), 2);
    }

    #[test]
    fn test_channel_write_returns_zero_when_full() {
        let channel = Channel::new(4);
        channel
            .heads
            .write()
            .unwrap()
            .push(Arc::new(AtomicU64::new(0)));
        assert_eq!(channel.write(0, &[1, 2, 3, 4]), 4);
        assert_eq!(channel.write(4, &[9]), 0);
    }

    #[test]
    fn test_channel_read() {
        let channel = Channel::new(4);
        let mut buf = [0; 3];

        let tail = Arc::new(AtomicU64::new(2));
        channel.tails.write().unwrap().push(tail.clone());

        assert_eq!(channel.read(0, &mut buf).unwrap(), 2);

        tail.store(5, Ordering::Release);

        assert_eq!(
            channel.read(0, &mut buf).unwrap_err(),
            ChannelError::ReaderBehind(1)
        );
    }

    #[test]
    fn test_channel_read_unsafe() {
        let channel = Channel::new(4);
        let mut buf = [0; 3];

        assert_eq!(unsafe { channel.read_unsafe(0, &mut buf) }, 0);

        let tail = Arc::new(AtomicU64::new(2));
        channel.tails.write().unwrap().push(tail.clone());

        assert_eq!(channel.read(0, &mut buf).unwrap(), 2);

        tail.store(5, Ordering::Release);

        assert_eq!(channel.read(1, &mut buf).unwrap(), 3);
        assert_eq!(channel.read(3, &mut buf).unwrap(), 2);
    }

    #[test]
    fn test_channel_schedule_writers() {
        let channel = Channel::new(4);
        channel
            .tails
            .write()
            .unwrap()
            .push(Arc::new(AtomicU64::new(0)));

        let mut queue = channel.queue.lock().unwrap();
        queue.push((0, noop_waker()));
        queue.push((1, noop_waker()));
        queue.push((4, noop_waker()));
        drop(queue);

        channel.schedule_writers();

        let queue = channel.queue.lock().unwrap();
        assert_eq!(queue.len(), 1);
        assert_eq!(queue.first().unwrap().0, 4);
    }

    #[test]
    fn test_channel_schedule_readers() {
        let channel = Channel::new(4);
        channel
            .tails
            .write()
            .unwrap()
            .push(Arc::new(AtomicU64::new(2)));
        channel
            .heads
            .write()
            .unwrap()
            .push(Arc::new(AtomicU64::new(1)));

        let mut queue = channel.queue.lock().unwrap();
        queue.push((1, noop_waker()));
        queue.push((2, noop_waker()));
        drop(queue);

        channel.schedule_readers();

        let queue = channel.queue.lock().unwrap();
        assert_eq!(queue.len(), 1);
        assert_eq!(queue.first().unwrap().0, 2);
    }

    #[test]
    fn test_channel_new_writer() {
        let channel = Arc::new(Channel::new(4));
        let writer = channel.new_writer();
        assert_eq!(channel.tails.read().unwrap().num_elements(), 1);
        drop(writer);
        assert!(channel.tails.read().unwrap().is_empty());
    }

    #[test]
    fn test_channel_new_strong_reader() {
        let channel = Arc::new(Channel::new(4));
        channel
            .tails
            .write()
            .unwrap()
            .push(Arc::new(AtomicU64::new(5)));
        let reader = channel.new_strong_reader();
        assert_eq!(channel.heads.read().unwrap().num_elements(), 1);
        assert_eq!(reader.pos.load(Ordering::Acquire), 5);
        drop(reader);
        assert!(channel.heads.read().unwrap().is_empty());
    }

    #[test]
    fn test_writer_poll_write() {
        let mut cx = Context::from_waker(noop_waker_ref());
        let channel = Arc::new(Channel::new(4));
        channel
            .heads
            .write()
            .unwrap()
            .push(Arc::new(AtomicU64::new(0)));
        let mut writer = pin!(channel.new_strong_writer());

        assert!(matches!(
            writer.as_mut().poll_write(&mut cx, &[1, 2, 3]),
            Poll::Ready(Ok(3))
        ));
        assert_eq!(channel.next_tail.load(Ordering::Acquire), 3);
        assert_eq!(writer.pos.load(Ordering::Acquire), 3);
        assert_eq!(writer.rem, 0);

        assert!(matches!(
            writer.as_mut().poll_write(&mut cx, &[1, 2, 3]),
            Poll::Ready(Ok(1))
        ));
        assert_eq!(channel.next_tail.load(Ordering::Acquire), 6);
        assert_eq!(writer.pos.load(Ordering::Acquire), 4);
        assert_eq!(writer.rem, 2);

        assert!(writer.as_mut().poll_write(&mut cx, &[1, 2, 3]).is_pending());
    }

    #[test]
    fn test_reader_poll_strong_read() {
        let mut cx = Context::from_waker(noop_waker_ref());
        let channel = Arc::new(Channel::new(4));
        channel
            .tails
            .write()
            .unwrap()
            .push(Arc::new(AtomicU64::new(4)));
        unsafe {
            let dst = (&mut *channel.buf.buf.get()).as_mut_ptr();
            dst.copy_from_nonoverlapping((1..=4).collect::<Vec<_>>().as_ptr(), 4);
        }
        let mut reader = pin!(channel.new_strong_reader());
        reader.pos.store(0, Ordering::Release);

        let mut buf = [0; 3];
        let mut rb = ReadBuf::new(&mut buf);
        assert!(matches!(
            reader.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Ready(Ok(()))
        ));
        assert_eq!(rb.filled().len(), 3);
        assert_eq!(reader.pos.load(Ordering::Acquire), 3);

        let mut buf = [0; 3];
        let mut rb = ReadBuf::new(&mut buf);
        assert!(matches!(
            reader.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Ready(Ok(()))
        ));
        assert_eq!(rb.filled().len(), 1);
        assert_eq!(reader.pos.load(Ordering::Acquire), 4);

        let mut buf = [0; 3];
        let mut rb = ReadBuf::new(&mut buf);
        assert!(matches!(
            reader.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Pending
        ));
    }

    #[test]
    fn test_reader_poll_weak_read() {
        let mut cx = Context::from_waker(noop_waker_ref());
        let channel = Arc::new(Channel::new(4));
        let tail_pos = Arc::new(AtomicU64::new(4));
        channel.tails.write().unwrap().push(tail_pos.clone());
        unsafe {
            let dst = (&mut *channel.buf.buf.get()).as_mut_ptr();
            dst.copy_from_nonoverlapping((1..=4).collect::<Vec<_>>().as_ptr(), 4);
        }
        let mut reader = pin!(channel.new_weak_reader());
        reader.pos = 0;

        let mut buf = [0; 4];
        let mut rb = ReadBuf::new(&mut buf);
        assert!(matches!(
            reader.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Ready(Ok(()))
        ));
        assert_eq!(rb.filled().len(), 4);
        assert_eq!(reader.pos, 4);

        let mut buf = [0; 1];
        let mut rb = ReadBuf::new(&mut buf);
        assert!(matches!(
            reader.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Pending
        ));

        tail_pos.store(9, Ordering::Release);

        let mut buf = [0; 1];
        let mut rb = ReadBuf::new(&mut buf);
        assert!(matches!(
            reader.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Ready(Err(_))
        ));
    }

    #[test]
    fn test_writer_wraparound_preserves_order() {
        let mut cx = Context::from_waker(noop_waker_ref());
        let channel = Arc::new(Channel::new(4));
        let mut writer = pin!(channel.new_strong_writer());
        let mut reader = pin!(channel.new_strong_reader());

        let mut buf = [0u8; 3];
        {
            let mut rb = ReadBuf::new(&mut buf);
            assert!(matches!(
                writer.as_mut().poll_write(&mut cx, &[1, 2, 3]),
                Poll::Ready(Ok(3))
            ));
            assert!(matches!(
                reader.as_mut().poll_read(&mut cx, &mut rb),
                Poll::Ready(Ok(()))
            ));
            assert_eq!(rb.filled().len(), 3);
        }
        assert_eq!(buf, [1, 2, 3]);

        let mut buf2 = [0u8; 2];
        let mut rb2 = ReadBuf::new(&mut buf2);
        assert!(matches!(
            writer.as_mut().poll_write(&mut cx, &[4, 5]),
            Poll::Ready(Ok(2))
        ));
        assert!(matches!(
            reader.as_mut().poll_read(&mut cx, &mut rb2),
            Poll::Ready(Ok(()))
        ));
        assert_eq!(rb2.filled().len(), 2);
        assert_eq!(buf2, [4, 5]);
    }

    #[test]
    fn test_terminate() {
        let mut cx = Context::from_waker(noop_waker_ref());
        let channel = Arc::new(Channel::new(4));
        let mut writer = channel.new_writer();
        let mut strong_reader = pin!(channel.new_strong_reader());
        let mut weak_reader = pin!(channel.new_weak_reader());
        let mut weak_reader2 = pin!(channel.new_weak_reader());
        let mut buf = [0; 4];
        let mut rb = ReadBuf::new(&mut buf);

        writer.terminate();
        let mut writer = pin!(writer);
        assert!(check_poll_aborted(
            writer.as_mut().poll_write(&mut cx, &[]).map_ok(|_| ())
        ));

        assert!(matches!(
            strong_reader.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Pending
        ));
        assert!(matches!(
            weak_reader.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Pending
        ));
        assert!(matches!(
            weak_reader2.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Pending
        ));

        strong_reader.terminate();
        assert!(check_poll_aborted(
            strong_reader.poll_read(&mut cx, &mut rb)
        ));

        assert!(matches!(
            weak_reader.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Pending
        ));
        assert!(matches!(
            weak_reader2.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Pending
        ));

        weak_reader.terminate();
        assert!(check_poll_aborted(weak_reader.poll_read(&mut cx, &mut rb)));

        assert!(matches!(
            weak_reader2.as_mut().poll_read(&mut cx, &mut rb),
            Poll::Pending
        ));

        channel.terminate().unwrap();
        assert!(check_poll_aborted(weak_reader2.poll_read(&mut cx, &mut rb)));
    }

    #[test]
    fn strong_reader_terminate_aborts_read() {
        let mut cx = Context::from_waker(noop_waker_ref());
        let channel = Arc::new(Channel::new(16));
        let reader = channel.new_strong_reader();
        reader.terminate();
        let mut reader = pin!(reader);
        let mut buf = [0u8; 1];
        let mut rb = ReadBuf::new(&mut buf);
        assert!(check_poll_aborted(
            reader.as_mut().poll_read(&mut cx, &mut rb)
        ));
    }

    #[test]
    fn writer_terminate_aborts_write() {
        let mut cx = Context::from_waker(noop_waker_ref());
        let channel = Arc::new(Channel::new(16));
        let mut writer = channel.new_writer();
        writer.terminate();
        let mut writer = pin!(writer);
        assert!(check_poll_aborted(
            writer.as_mut().poll_write(&mut cx, &[]).map_ok(|_| ())
        ));
    }

    fn check_poll_aborted(poll: Poll<std::io::Result<()>>) -> bool {
        match poll {
            Poll::Ready(r) => r.is_err_and(|e| e.kind() == std::io::ErrorKind::ConnectionAborted),
            _ => false,
        }
    }

    #[test]
    fn test_drop_backpressure_writer_returns_zero_when_full() {
        let mut cx = Context::from_waker(noop_waker_ref());
        // Small buffer to hit full condition quickly
        let channel = Arc::new(Channel::with_parameters(4, Backpressure::Drop));
        let mut writer = pin!(channel.new_writer());
        // Add a strong reader pinned at start (0) so head stays at 0
        let _reader = pin!(channel.new_strong_reader());

        // First write fills the buffer
        assert!(matches!(
            writer.as_mut().poll_write(&mut cx, &[1, 2, 3, 4]),
            Poll::Ready(Ok(4))
        ));
        // Next write finds no space and, in Drop mode, returns Ok(0)
        assert!(matches!(
            writer.as_mut().poll_write(&mut cx, &[9]),
            Poll::Ready(Ok(0))
        ));
    }
}

#[cfg(all(test, feature = "loom"))]
mod loom_tests {
    use futures::future;
    use loom::future::block_on;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::*;

    #[test]
    fn strong_reader_and_writer_progress() {
        loom::model(|| {
            let channel = Channel::new(4);
            let mut reader = channel.new_strong_reader();
            let mut writer = channel.new_writer();

            block_on(async move {
                let mut buf = [0u8; 1];
                let read = reader.read_exact(&mut buf);
                let write = writer.write_all(&[42]);
                let (_w, r) = future::join(write, read).await;
                r.unwrap();
                assert_eq!(buf[0], 42);
            });
        });
    }
}