// disruptor_rs/ringbuffer.rs
use std::cell::UnsafeCell;

use crate::{sequence::Sequence, traits::DataProvider};

/// Fixed-size circular buffer of pre-allocated slots addressed by sequence number.
///
/// Every slot is wrapped in an [`UnsafeCell`] so that the unsafe accessors can hand
/// out `&mut T` from a shared `&self`; callers of those accessors are responsible for
/// ensuring a slot is never aliased while it is mutably borrowed.
///
/// Invariant (established by [`RingBuffer::new`]): `capacity` is a non-zero power of
/// two, `_data.len() == capacity`, and `_mask == capacity - 1`.
pub struct RingBuffer<T> {
    /// Number of slots; always a non-zero power of two.
    capacity: usize,
    /// `capacity - 1`; for any `usize` index, `index & _mask == index % capacity`.
    _mask: usize,
    /// Slot storage, one `UnsafeCell` per element.
    _data: Vec<UnsafeCell<T>>,
}

/// Returns `true` when `x` is a non-zero power of two (`0` is rejected).
///
/// Delegates to the standard library's `const fn` [`usize::is_power_of_two`] rather
/// than re-deriving the `x & (x - 1)` bit trick by hand.
const fn is_power_of_two(x: usize) -> bool {
    x.is_power_of_two()
}

25unsafe impl<T: Send> Send for RingBuffer<T> {}
26unsafe impl<T: Sync> Sync for RingBuffer<T> {}
27
28impl<T: Default + Send> RingBuffer<T> {
29 pub fn new(capacity: usize) -> Self {
30 assert!(is_power_of_two(capacity), "Capacity must be a power of 2");
31 Self {
32 capacity,
33 _mask: capacity - 1,
34 _data: (0..capacity)
35 .map(|_| UnsafeCell::new(T::default()))
36 .collect(), }
38 }
39
40 pub fn get_capacity(&self) -> usize {
41 self.capacity
42 }
43}
44
45impl<T: Send + Sync> DataProvider<T> for RingBuffer<T> {
46 fn get_capacity(&self) -> usize {
47 self.capacity
48 }
49
50 unsafe fn get(&self, sequence: Sequence) -> &T {
59 let index = sequence as usize & self._mask;
60 &*self._data[index].get()
61 }
62
63 unsafe fn get_mut(&self, sequence: Sequence) -> &mut T {
72 let index = sequence as usize & self._mask;
73 &mut *self._data[index].get()
74 }
75}
76
#[cfg(test)]
mod tests {
    use std::{sync::Arc, thread};

    use super::*;

    const ITERATIONS: i64 = 256;
    const THREADS: usize = 4;

    #[test]
    fn test_initialization() {
        let buffer = RingBuffer::<i64>::new(ITERATIONS as usize);

        assert_eq!(buffer.get_capacity(), 256);

        for i in 0..ITERATIONS {
            // SAFETY: single-threaded, and no `&mut` to any slot is alive.
            unsafe {
                assert_eq!(*buffer.get(i), 0);
            }
        }
    }

    #[test]
    fn test_ring_buffer() {
        let buffer = RingBuffer::<i64>::new(ITERATIONS as usize);
        assert_eq!(buffer.get_capacity(), 256);

        // SAFETY (all blocks below): single-threaded, and each reference is dropped
        // before the next accessor call, so slots are never aliased.
        for i in 0..ITERATIONS {
            unsafe {
                *buffer.get_mut(i) = i;
            }
        }

        for i in 0..ITERATIONS {
            unsafe {
                *buffer.get_mut(i) *= 2;
            }
        }

        for i in 0..ITERATIONS {
            unsafe {
                assert_eq!(*buffer.get(i), i * 2);
            }
        }
    }

    #[test]
    fn test_sequence_wraps_around() {
        let buffer = RingBuffer::<i64>::new(ITERATIONS as usize);

        // Sequences one full lap ahead must land on the same slots as `0..ITERATIONS`.
        for i in 0..ITERATIONS {
            unsafe {
                *buffer.get_mut(i + ITERATIONS) = i;
            }
        }

        for i in 0..ITERATIONS {
            unsafe {
                assert_eq!(*buffer.get(i), i);
            }
        }
    }

    #[test]
    #[should_panic(expected = "Capacity must be a power of 2")]
    fn test_rejects_zero_capacity() {
        let _ = RingBuffer::<i64>::new(0);
    }

    #[test]
    #[should_panic(expected = "Capacity must be a power of 2")]
    fn test_rejects_non_power_of_two_capacity() {
        let _ = RingBuffer::<i64>::new(3);
    }

    #[test]
    fn test_ring_buffer_is_send_and_sync() {
        fn assert_send_sync<S: Send + Sync>() {}
        assert_send_sync::<RingBuffer<i64>>();
    }

    #[test]
    fn test_ring_buffer_multithreaded() {
        let buffer = Arc::new(RingBuffer::<i64>::new(ITERATIONS as usize));

        let handles: Vec<_> = (0..THREADS)
            .map(|t| {
                let buffer = Arc::clone(&buffer);
                thread::spawn(move || {
                    // Each thread owns a disjoint stripe of slots (i % THREADS == t), so
                    // no two threads ever access the same slot concurrently. Writing the
                    // same slot from several threads would be a data race (UB).
                    for i in (t as i64..ITERATIONS).step_by(THREADS) {
                        unsafe {
                            *buffer.get_mut(i) += i + 1;
                        }
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        // `join` synchronises with each worker's completion, so all writes are visible.
        // `+=` makes any accidental double write to a slot show up as a wrong value.
        for i in 0..ITERATIONS {
            unsafe {
                assert_eq!(*buffer.get(i), i + 1);
            }
        }
    }
}