// rivulet/circular_buffer.rs

1#![cfg(all(feature = "std"))]
2#![cfg_attr(docsrs, doc(cfg(all(feature = "std"))))]
3//! An asynchronous copy-free circular buffer.
4//!
5//! This buffer is optimized for contiguous memory segments and never copies data to other regions
6//! of the buffer.
7use crate::{
8    error::GrantOverflow,
9    splittable::{SplittableViewImpl, SplittableViewImplMut},
10    View, ViewMut,
11};
12use futures::task::AtomicWaker;
13use num_integer::{div_ceil, lcm};
14use std::{
15    convert::TryInto,
16    mem::{size_of, MaybeUninit},
17    pin::Pin,
18    sync::{
19        atomic::{AtomicBool, AtomicU64, Ordering},
20        Arc, Mutex,
21    },
22    task::{Context, Poll, Waker},
23};
24
/// Owner of the raw ring-mapped storage.
///
/// The memory is obtained from `vmap::os::map_ring`, which (given how `range`
/// slices `len` elements from an arbitrary offset without wrap handling —
/// presumably the pages are mapped twice back-to-back) lets a slice start at
/// any position and extend up to `size` elements.
struct UnsafeCircularBuffer<T> {
    ptr: *mut T, // start of the ring mapping; owned, freed in `Drop`
    size: usize, // capacity in elements (not bytes)
}
29
// SAFETY: the buffer owns its elements, so moving it across threads only needs
// `T: Send`. `Sync` also requires just `T: Send` because the buffer never hands
// out references on its own — `range`/`range_mut` are unsafe and the caller
// must guarantee exclusivity of each range (see their contracts below).
unsafe impl<T> Send for UnsafeCircularBuffer<T> where T: Send {}
unsafe impl<T> Sync for UnsafeCircularBuffer<T> where T: Send {}
32
impl<T> Drop for UnsafeCircularBuffer<T> {
    fn drop(&mut self) {
        // Safety: the underlying storage is always initialized upon construction, and is safe to
        // drop and unmap.
        unsafe {
            // Run each element's destructor in place before returning the
            // pages to the OS.
            for i in 0..self.size {
                std::ptr::drop_in_place(self.ptr.add(i));
            }
            // NOTE(review): `unwrap` in a destructor will panic (and abort if
            // already unwinding) should the OS unmap fail — confirm this is an
            // acceptable trade-off versus silently leaking the mapping.
            vmap::os::unmap_ring(self.ptr as *mut u8, self.size * size_of::<T>()).unwrap();
        }
    }
}
45
46impl<T: Default> UnsafeCircularBuffer<T> {
47    pub fn new(minimum_size: usize) -> Self {
48        // Determine the smallest buffer larger than minimum_size that is both a multiple of the
49        // allocation size and the type size.
50        let size_bytes = {
51            let granularity = lcm(vmap::allocation_size(), size_of::<T>());
52            div_ceil(minimum_size * size_of::<T>(), granularity)
53                .checked_mul(granularity)
54                .unwrap()
55        };
56        let size = size_bytes / size_of::<T>();
57
58        // Initialize the buffer memory
59        // Safety: `map_ring` returns an uninitialized slice.
60        let ptr = unsafe {
61            let ptr = vmap::os::map_ring(size_bytes).unwrap() as *mut T;
62            for v in std::slice::from_raw_parts_mut(ptr as *mut MaybeUninit<T>, size) {
63                v.as_mut_ptr().write(T::default());
64            }
65            ptr
66        };
67
68        Self { ptr, size }
69    }
70}
71
impl<T> UnsafeCircularBuffer<T> {
    /// Capacity in elements.
    pub fn len(&self) -> usize {
        self.size
    }

    // Only safe if you can guarantee no mutable references to this range
    //
    // `index` is a monotonically increasing stream position; it is reduced
    // modulo the capacity here. The slice may start at any offset and extend
    // `len` (<= capacity) elements thanks to the ring mapping.
    pub unsafe fn range(&self, index: u64, len: usize) -> &[T] {
        debug_assert!(len <= self.len());
        let buf_len: u64 = self.len().try_into().unwrap();
        let offset = index % buf_len;
        std::slice::from_raw_parts(self.ptr.add(offset.try_into().unwrap()), len)
    }

    // Only safe if you can guarantee no other references to the same range
    #[allow(clippy::mut_from_ref)] // &self -> &mut [T]: exclusivity is the caller's contract
    pub unsafe fn range_mut(&self, index: u64, len: usize) -> &mut [T] {
        debug_assert!(len <= self.len());
        let buf_len: u64 = self.len().try_into().unwrap();
        let offset = index % buf_len;
        std::slice::from_raw_parts_mut(self.ptr.add(offset.try_into().unwrap()), len)
    }
}
94
/// Shared state between the `Sink` (writer) and `Source` (reader).
///
/// `head` and `tail` are monotonically increasing stream positions, never
/// reduced modulo the capacity here; storage offsets are derived from them in
/// `UnsafeCircularBuffer::range`, so `tail - head` is always the readable
/// length and the counters never need explicit wrap handling.
struct State<T> {
    buffer: UnsafeCircularBuffer<T>,
    closed: AtomicBool,       // true if the stream is closed
    head: AtomicU64,          // start index of written data
    tail: AtomicU64,          // start index of unwritten data
    write_waker: AtomicWaker, // waker waited on by the writer
    read_waker: Mutex<Option<Box<dyn Fn() + Send + Sync>>>, // wake readers when new data is available
}
104
105impl<T: Default> State<T> {
106    fn new(minimum_size: usize) -> Self {
107        // The +1 ensures there's room for a marker element (to indicate the difference between
108        // empty and full
109        Self {
110            buffer: UnsafeCircularBuffer::new(minimum_size + 1),
111            closed: AtomicBool::new(false),
112            head: AtomicU64::new(0),
113            tail: AtomicU64::new(0),
114            write_waker: AtomicWaker::new(),
115            read_waker: Mutex::new(None),
116        }
117    }
118}
119
120impl<T> State<T> {
121    fn readable_len(&self, start: u64) -> usize {
122        (self.tail.load(Ordering::Relaxed) - start)
123            .try_into()
124            .unwrap()
125    }
126
127    fn writeable_len(&self) -> usize {
128        self.buffer.len() - self.readable_len(self.head.load(Ordering::Relaxed))
129    }
130}
131
/// The writer of a circular buffer.
///
/// Writes made to this become available at the associated [`Source`].
pub struct Sink<T> {
    state: Arc<State<T>>, // shared with the Source
    tail: u64,            // local copy of the write index (monotonic)
    available: usize,     // size of the currently granted writable region
    read_waker: Option<Box<dyn Fn() + Send + Sync>>, // reader wake callback, cached from `state`
}
141
142impl<T> Sink<T> {
143    fn new(state: Arc<State<T>>) -> Self {
144        Self {
145            state,
146            tail: 0,
147            available: 0,
148            read_waker: None,
149        }
150    }
151
152    fn wake_readers(&mut self) {
153        if self.read_waker.is_none() {
154            let mut lock = self
155                .state
156                .read_waker
157                .lock()
158                .expect("another thread panicked");
159            std::mem::swap(&mut *lock, &mut self.read_waker);
160        }
161        if let Some(read_waker) = self.read_waker.as_ref() {
162            read_waker()
163        }
164    }
165}
166
impl<T> Drop for Sink<T> {
    fn drop(&mut self) {
        // Mark the stream closed, then wake readers so a task blocked in
        // poll_available can observe the close and resolve.
        self.state.closed.store(true, Ordering::Relaxed);
        self.wake_readers(); // waiting readers can exit without sufficient data
    }
}
173
impl<T> View for Sink<T> {
    type Item = T;
    type Error = GrantOverflow;

    /// The writable region granted by the last successful `poll_grant`.
    fn view(&self) -> &[T] {
        // Safety: this region is owned exclusively by the writer.
        unsafe { self.state.buffer.range(self.tail, self.available) }
    }

    /// Wait until at least `count` elements are writable.
    ///
    /// Resolves with `GrantOverflow` if `count` exceeds the buffer capacity
    /// (such a request could never be satisfied). If the stream is closed this
    /// resolves `Ok` even when fewer than `count` elements are available.
    fn poll_grant(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        count: usize,
    ) -> Poll<Result<(), GrantOverflow>> {
        if count > self.state.buffer.len() {
            return Poll::Ready(Err(GrantOverflow(self.state.buffer.len())));
        }

        // Fast path: the previous grant already covers this request.
        if self.available >= count {
            return Poll::Ready(Ok(()));
        }

        // Perform double-checking on the amount of available data
        // The first check is efficient, but may spuriously fail.
        // The second check occurs after the `acquire` produced by registering the waker.
        let available = self.state.writeable_len();
        if available >= count {
            self.available = available;
            Poll::Ready(Ok(()))
        } else {
            self.state.write_waker.register(cx.waker());
            let available = self.state.writeable_len();
            if available >= count || self.state.closed.load(Ordering::Relaxed) {
                self.available = available;
                Poll::Ready(Ok(()))
            } else {
                Poll::Pending
            }
        }
    }

    /// Publish `count` written elements to the reader side.
    ///
    /// # Panics
    /// Panics if `count` exceeds the current grant.
    fn release(&mut self, count: usize) {
        if count == 0 {
            return;
        }

        assert!(
            count <= self.available,
            "attempted to release more than current grant"
        );

        // Advance the buffer
        self.available -= count;
        let count: u64 = count.try_into().unwrap();
        self.tail += count;
        // Publish the new tail, then notify any reader waiting for data.
        self.state.tail.store(self.tail, Ordering::Relaxed);
        self.wake_readers();
    }
}
233
impl<T> ViewMut for Sink<T> {
    /// Mutable access to the currently granted writable region.
    fn view_mut(&mut self) -> &mut [T] {
        // Safety: this region is owned exclusively by the writer.
        unsafe { self.state.buffer.range_mut(self.tail, self.available) }
    }
}
240
/// The reader of a circular buffer.
///
/// Writes made to the associated [`Sink`] are made available to this.
pub struct Source<T> {
    state: Arc<State<T>>, // shared with the Sink
}
247
248impl<T> Source<T> {
249    fn new(state: Arc<State<T>>) -> Self {
250        Self { state }
251    }
252}
253
impl<T> Drop for Source<T> {
    fn drop(&mut self) {
        // Mark the stream closed and wake the writer so a task blocked in
        // poll_grant can observe the close and resolve.
        self.state.closed.store(true, Ordering::Relaxed);
        self.state.write_waker.wake();
    }
}
260
unsafe impl<T> SplittableViewImpl for Source<T> {
    type Item = T;
    type Error = GrantOverflow;

    /// Install the callback the writer uses to wake readers.
    ///
    /// # Panics
    /// Panics if a reader waker has already been installed.
    unsafe fn set_reader_waker(&self, waker: impl Fn() + Send + Sync + 'static) {
        let mut lock = self
            .state
            .read_waker
            .lock()
            .expect("another thread panicked");
        assert!(lock.is_none(), "reader waker already set!");
        *lock = Some(Box::new(waker));
    }

    /// Unconditionally advance the consumed-data index and notify the writer.
    unsafe fn set_head(&self, index: u64) {
        self.state.head.store(index, Ordering::Relaxed);
        self.state.write_waker.wake();
    }

    /// Advance the head only if `index` is ahead of the current head — a
    /// monotonic-max update (presumably for racing split readers; see the
    /// `splittable` module for the caller's contract).
    unsafe fn compare_set_head(&self, index: u64) {
        // only set the head if it's greater than the current head
        let mut current = self.state.head.load(Ordering::Relaxed);
        if index > current {
            // CAS-max loop: retry with the freshly observed head while it is
            // still behind `index`; stop once someone else has advanced past.
            while let Err(previous) = self.state.head.compare_exchange_weak(
                current,
                index,
                Ordering::Relaxed,
                Ordering::Relaxed,
            ) {
                if index > previous {
                    current = previous
                } else {
                    break;
                }
            }
        }
        self.state.write_waker.wake();
    }

    /// Wait until at least `len` elements are readable starting at `index`.
    ///
    /// Resolves with `GrantOverflow` if `len` exceeds the buffer capacity.
    /// When the stream is closed this resolves `Ok` with however many
    /// elements remain (possibly fewer than `len`).
    fn poll_available(
        self: Pin<&Self>,
        cx: &mut Context,
        register_wakeup: impl Fn(&Waker),
        index: u64,
        len: usize,
    ) -> Poll<Result<usize, Self::Error>> {
        let max_len = self.state.buffer.len();
        if len > max_len {
            return Poll::Ready(Err(GrantOverflow(max_len)));
        }

        // Perform double-checking on the amount of available data
        // The first check is efficient, but may spuriously fail.
        // The second check occurs after the `acquire` produced by registering the waker.
        let available = self.state.readable_len(index);
        if available >= len {
            Poll::Ready(Ok(available))
        } else {
            register_wakeup(cx.waker());
            let available = self.state.readable_len(index);
            if available >= len || self.state.closed.load(Ordering::Relaxed) {
                Poll::Ready(Ok(available))
            } else {
                Poll::Pending
            }
        }
    }

    unsafe fn view(&self, index: u64, len: usize) -> &[Self::Item] {
        // Safety: the caller guarantees no mutable references overlap
        // [index, index + len) — see `UnsafeCircularBuffer::range`.
        self.state.buffer.range(index, len)
    }
}
333
unsafe impl<T> SplittableViewImplMut for Source<T> {
    unsafe fn view_mut(&self, index: u64, len: usize) -> &mut [Self::Item] {
        // Safety: the caller guarantees exclusive access to
        // [index, index + len) — see `UnsafeCircularBuffer::range_mut`.
        self.state.buffer.range_mut(index, len)
    }
}
339
340/// Create a circular buffer that can hold at least `min_size` elements.
341///
342/// # Panics
343/// Panics if `min_size` is 0.
344pub fn circular_buffer<T: Send + Sync + Default + 'static>(
345    min_size: usize,
346) -> (Sink<T>, Source<T>) {
347    assert!(min_size > 0, "`min_size` must be greater than 0");
348
349    let state = Arc::new(State::new(min_size));
350
351    (Sink::new(state.clone()), Source::new(state))
352}