disruptor_rs/
ringbuffer.rs

1use std::cell::UnsafeCell;
2
3use crate::{sequence::Sequence, traits::DataProvider};
4
/// Fixed-capacity ring buffer whose slots are pre-filled with `T::default()`.
///
/// The constructor only accepts capacities that are a power of 2, which lets a sequence be
/// turned into a slot index with a single bitwise AND. The buffer does not track a current
/// position: anything reading from or writing to it is expected to carry its own index.
///
/// # Types
/// - `T`: The element type. Construction requires `T: Default + Send`: `Default` supplies the
///   initial value of every slot, and `Send` allows the elements to move between threads.
///
/// # Safety
/// Slots are wrapped in [`UnsafeCell`] because interior mutability is required: several readers
/// and writers are meant to reach the buffer concurrently through shared references.
pub struct RingBuffer<T> {
    /// Number of slots in the buffer.
    capacity: usize,
    /// Bit mask applied to a sequence to obtain its slot index.
    _mask: usize,
    /// Slot storage; each element can be mutated through a shared reference to the buffer.
    _data: Vec<UnsafeCell<T>>,
}
20
/// Returns `true` when `x` is a power of two.
///
/// Zero is not considered a power of two. This delegates to the standard library's
/// `usize::is_power_of_two` instead of hand-rolling the bit trick; the wrapper is kept so
/// existing call sites (including `const` contexts) continue to work unchanged.
const fn is_power_of_two(x: usize) -> bool {
    x.is_power_of_two()
}
24
25unsafe impl<T: Send> Send for RingBuffer<T> {}
26unsafe impl<T: Sync> Sync for RingBuffer<T> {}
27
28impl<T: Default + Send> RingBuffer<T> {
29    pub fn new(capacity: usize) -> Self {
30        assert!(is_power_of_two(capacity), "Capacity must be a power of 2");
31        Self {
32            capacity,
33            _mask: capacity - 1,
34            _data: (0..capacity)
35                .map(|_| UnsafeCell::new(T::default()))
36                .collect(), // Initialize buffer with default values
37        }
38    }
39
40    pub fn get_capacity(&self) -> usize {
41        self.capacity
42    }
43}
44
45impl<T: Send + Sync> DataProvider<T> for RingBuffer<T> {
46    fn get_capacity(&self) -> usize {
47        self.capacity
48    }
49
50    /// Get a reference to the element at the given sequence.
51    /// # Safety
52    /// This method is unsafe because it allows for multiple readers to access the buffer concurrently.
53    /// The caller must ensure that the sequence is within the bounds of the buffer.
54    /// # Arguments
55    /// - `sequence`: The sequence of the element to get.
56    /// # Returns
57    /// A reference to the element at the given sequence.
58    unsafe fn get(&self, sequence: Sequence) -> &T {
59        let index = sequence as usize & self._mask;
60        &*self._data[index].get()
61    }
62
63    /// Get a mutable reference to the element at the given sequence.
64    /// # Safety
65    /// This method is unsafe because it allows for multiple writers to access the buffer concurrently.
66    /// The caller must ensure that the sequence is within the bounds of the buffer.
67    /// # Arguments
68    /// - `sequence`: The sequence of the element to get.
69    /// # Returns
70    /// A mutable reference to the element at the given sequence.
71    unsafe fn get_mut(&self, sequence: Sequence) -> &mut T {
72        let index = sequence as usize & self._mask;
73        &mut *self._data[index].get()
74    }
75}
76
#[cfg(test)]
mod tests {

    use std::{
        sync::{
            atomic::{AtomicI64, Ordering},
            Arc,
        },
        thread,
    };

    use super::*;

    /// Capacity of every buffer under test, and the number of sequences each test walks.
    const ITERATIONS: i64 = 256;
    /// Number of worker threads spawned by the multithreaded test.
    const THREADS: usize = 4;

    /// A new buffer reports its capacity and holds the default value in every slot.
    #[test]
    fn test_initialization() {
        let buffer = RingBuffer::<i64>::new(ITERATIONS as usize);

        assert_eq!(buffer.get_capacity(), 256);

        for i in 0..ITERATIONS {
            // SAFETY: the buffer is local to this test and nothing is writing to it.
            unsafe {
                assert_eq!(*buffer.get(i), 0);
            }
        }
    }

    /// Values written through `get_mut` are visible to later `get_mut` and `get` calls.
    #[test]
    fn test_ring_buffer() {
        let buffer = RingBuffer::<i64>::new(ITERATIONS as usize);
        assert_eq!(buffer.get_capacity(), 256);

        for i in 0..ITERATIONS {
            // SAFETY: single-threaded; the reference is dropped at the end of the statement.
            unsafe {
                *buffer.get_mut(i) = i;
            }
        }

        for i in 0..ITERATIONS {
            // SAFETY: single-threaded; the reference is dropped at the end of the statement.
            unsafe {
                *buffer.get_mut(i) *= 2;
            }
        }

        for i in 0..ITERATIONS {
            // SAFETY: single-threaded; no write is in progress.
            unsafe {
                assert_eq!(*buffer.get(i), i * 2);
            }
        }
    }

    /// Sequences past the capacity wrap around onto the same slots.
    #[test]
    fn test_ring_buffer_wraps_sequences() {
        let buffer = RingBuffer::<i64>::new(ITERATIONS as usize);

        // SAFETY: single-threaded; each reference is dropped before the next one is created.
        unsafe {
            *buffer.get_mut(ITERATIONS + 3) = 42;
            assert_eq!(*buffer.get(3), 42);
            assert_eq!(*buffer.get(2 * ITERATIONS + 3), 42);
        }
    }

    /// Several threads can update the same slots concurrently through shared references.
    ///
    /// The elements are atomics so that the concurrent read-modify-write on a shared slot is
    /// well defined. Doing `*buffer.get_mut(i) += i` on a plain `i64` from several threads at
    /// once would be a data race (undefined behaviour) and could also lose updates.
    #[test]
    fn test_ring_buffer_multithreaded() {
        let buffer = Arc::new(RingBuffer::<AtomicI64>::new(ITERATIONS as usize));

        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let buffer = Arc::clone(&buffer);
                thread::spawn(move || {
                    for i in 0..ITERATIONS {
                        // SAFETY: only shared references to the slots are ever created, and
                        // all mutation goes through the atomic itself.
                        unsafe {
                            buffer.get(i).fetch_add(i, Ordering::Relaxed);
                        }
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        // Joining the workers makes all of their updates visible to this thread.
        for i in 0..ITERATIONS {
            // SAFETY: all workers have finished; only shared references are created.
            unsafe {
                assert_eq!(buffer.get(i).load(Ordering::Relaxed), i * THREADS as i64);
            }
        }
    }
}