// memlink_shm/buffer.rs
1//! Lock-free single-producer single-consumer (SPSC) ring buffer.
2//! Provides atomic slot state management with cache-line aligned slots.
3
4use std::sync::atomic::{AtomicU32, AtomicU64, AtomicU8, Ordering};
5use std::{fmt, ptr};
6
/// Maximum payload size (in bytes) a single slot's `data` array can carry.
pub const MAX_SLOT_SIZE: usize = 4096;
8
/// Lifecycle of a slot, stored as a raw byte in `Slot::state`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SlotState {
    /// Free for the producer to write.
    Empty = 0,
    /// Published by the producer; payload is complete and readable.
    Ready = 1,
    /// Being consumed by the reader.
    Reading = 2,
}
16
/// Shared control block for the ring: producer-advanced `head`,
/// consumer-advanced `tail`, plus monotonic operation counters.
/// `#[repr(C, align(64))]` pins the layout for cross-process shared-memory
/// use and aligns the header to a cache line.
#[repr(C, align(64))]
pub struct RingHeader {
    /// Next index the producer will write (monotonic, not masked).
    pub head: AtomicU64,
    /// Next index the consumer will read (monotonic, not masked).
    pub tail: AtomicU64,
    /// Slot count; expected to be a power of two (checked in `RingBuffer::new`).
    pub capacity: u64,
    /// Count of completed writes.
    pub write_seq: AtomicU64,
    /// Count of completed reads.
    pub read_seq: AtomicU64,
    // NOTE(review): `8 * 6` assumes six 8-byte fields but there are five
    // (40 bytes), so this pads to 56; `align(64)` rounds the size to 64
    // anyway. Harmless, but the intent reads as `64 - 8 * 5`. Changing it
    // would also require changing the matching initializer in
    // `RingBuffer::with_capacity`.
    _padding: [u8; 64 - 8 * 6],
}
26
/// One cache-line-aligned message slot. The `state` byte drives the SPSC
/// handoff; the first `len` bytes of `data` are valid once `state` is
/// `Ready`.
#[repr(C, align(64))]
pub struct Slot {
    /// A `SlotState` discriminant; Release-stored by the writer,
    /// Acquire-loaded by the reader.
    pub state: AtomicU8,
    /// Raw `Priority` discriminant of the queued message.
    pub priority: u8,
    /// Number of valid bytes in `data`.
    pub len: AtomicU32,
    /// Message payload; only the first `len` bytes are meaningful.
    pub data: [u8; MAX_SLOT_SIZE],
    // NOTE(review): this expression ignores the 2 alignment-padding bytes
    // the compiler inserts before `len` (offset 2 -> 4), so the manual
    // padding overshoots and `align(64)` rounds the struct up one further
    // cache line. Benign, but worth re-deriving if exact size matters; a
    // fix must also update the initializer in `RingBuffer::with_capacity`.
    _padding: [u8; 64 - ((1 + 1 + 4 + MAX_SLOT_SIZE) % 64)],
}
35
/// Errors returned by `RingBuffer` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RingBufferError {
    /// No free slot. (Also returned by `RingBuffer::new` for a zero or
    /// non-power-of-two capacity — there is no dedicated variant for that.)
    Full,
    /// No message available.
    Empty,
    /// A slot's state byte was not what the operation expected.
    InvalidState,
    /// Payload larger than `MAX_SLOT_SIZE`.
    DataTooLarge,
}
43
44impl fmt::Display for RingBufferError {
45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
46        match self {
47            RingBufferError::Full => write!(f, "Ring buffer is full"),
48            RingBufferError::Empty => write!(f, "Ring buffer is empty"),
49            RingBufferError::InvalidState => write!(f, "Invalid slot state"),
50            RingBufferError::DataTooLarge => write!(f, "Data exceeds maximum slot size"),
51        }
52    }
53}
54
// Marker impl so `RingBufferError` works with `Box<dyn Error>` and `?`.
impl std::error::Error for RingBufferError {}
56
/// Receipt for a successful write.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SlotId {
    /// Monotonic (unmasked) head value the message was written at;
    /// mask by capacity to recover the physical slot index.
    pub index: u64,
    /// Value of the global `write_seq` counter before this write
    /// incremented it.
    pub seq: u64,
}
62
/// Message priority carried alongside each payload; stored in a slot as
/// its `u8` discriminant (see `Priority::from_u8` for decoding).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum Priority {
    Low = 0,
    /// The default priority.
    #[default]
    Normal = 1,
    High = 2,
}
71
/// Lock-free SPSC ring buffer over a header block and a slot array.
///
/// The pointers may target either Rust heap allocations (`new`) or
/// external shared memory (`from_ptr`); note the type does not record
/// which, yet `Drop` frees them unconditionally — see the `Drop` impl.
pub struct RingBuffer {
    // Control block (head/tail/sequence counters).
    header: *mut RingHeader,
    // Start of a `capacity`-element `Slot` array.
    slots: *mut Slot,
    // Slot count; must match the allocation and be a power of two.
    capacity: usize,
}
77
// SAFETY(review): the raw-pointer fields suppress auto Send/Sync.
// Soundness of these impls relies on the SPSC discipline — exactly one
// producer thread calling `write_slot` and one consumer thread calling
// `read_slot`/`peek_slot` at a time; nothing in the type enforces that.
// TODO confirm callers uphold this.
unsafe impl Send for RingBuffer {}
unsafe impl Sync for RingBuffer {}
80
81impl RingBuffer {
82    unsafe fn with_capacity(capacity: usize) -> Result<Self, RingBufferError> {
83        if capacity == 0 || !capacity.is_power_of_two() {
84            return Err(RingBufferError::Full);
85        }
86
87        let header = Box::into_raw(Box::new(RingHeader {
88            head: AtomicU64::new(0),
89            tail: AtomicU64::new(0),
90            capacity: capacity as u64,
91            write_seq: AtomicU64::new(0),
92            read_seq: AtomicU64::new(0),
93            _padding: [0u8; 64 - 48],
94        }));
95
96        let mut slots = Vec::with_capacity(capacity);
97        for _ in 0..capacity {
98            slots.push(Slot {
99                state: AtomicU8::new(SlotState::Empty as u8),
100                priority: 0,
101                len: AtomicU32::new(0),
102                data: [0u8; MAX_SLOT_SIZE],
103                _padding: [0u8; 64 - ((1 + 1 + 4 + MAX_SLOT_SIZE) % 64)],
104            });
105        }
106        let slots = Box::into_raw(slots.into_boxed_slice()) as *mut Slot;
107
108        Ok(Self {
109            header,
110            slots,
111            capacity,
112        })
113    }
114
115    pub fn new(capacity: usize) -> Result<Self, RingBufferError> {
116        unsafe { Self::with_capacity(capacity) }
117    }
118
119    /// Create a ring buffer from raw pointers (for shared memory)
120    ///
121    /// # Safety
122    ///
123    /// The caller must ensure:
124    /// - `header` points to a valid RingHeader
125    /// - `slots` points to an array of `capacity` Slot structures
126    /// - Memory is properly aligned
127    pub unsafe fn from_ptr(header: *mut RingHeader, slots: *mut Slot, capacity: usize) -> Self {
128        Self {
129            header,
130            slots,
131            capacity,
132        }
133    }
134
135    pub fn capacity(&self) -> usize {
136        self.capacity
137    }
138
139    pub fn is_empty(&self) -> bool {
140        let head = unsafe { (*self.header).head.load(Ordering::Acquire) };
141        let tail = unsafe { (*self.header).tail.load(Ordering::Acquire) };
142        head == tail
143    }
144
145    pub fn is_full(&self) -> bool {
146        let head = unsafe { (*self.header).head.load(Ordering::Acquire) };
147        let tail = unsafe { (*self.header).tail.load(Ordering::Acquire) };
148        let capacity = self.capacity as u64;
149        head.wrapping_sub(tail) >= capacity
150    }
151
152    pub fn len(&self) -> u64 {
153        let head = unsafe { (*self.header).head.load(Ordering::Acquire) };
154        let tail = unsafe { (*self.header).tail.load(Ordering::Acquire) };
155        head.wrapping_sub(tail)
156    }
157
158    #[inline]
159    fn next_slot_index(index: u64, capacity: u64) -> u64 {
160        index & (capacity - 1)
161    }
162
163    #[inline]
164    #[cfg(test)]
165    #[allow(dead_code)]
166    unsafe fn get_slot(&self, index: u64) -> &Slot {
167        let slot_idx = Self::next_slot_index(index, self.capacity as u64);
168        &*self.slots.add(slot_idx as usize)
169    }
170
171    pub fn write_slot(&self, priority: Priority, data: &[u8]) -> Result<SlotId, RingBufferError> {
172        if data.len() > MAX_SLOT_SIZE {
173            return Err(RingBufferError::DataTooLarge);
174        }
175
176        let head = unsafe { (*self.header).head.load(Ordering::Acquire) };
177        let tail = unsafe { (*self.header).tail.load(Ordering::Acquire) };
178
179        if head.wrapping_sub(tail) >= self.capacity as u64 {
180            return Err(RingBufferError::Full);
181        }
182
183        let slot_idx = Self::next_slot_index(head, self.capacity as u64);
184
185        unsafe {
186            let slot = &mut *self.slots.add(slot_idx as usize);
187
188            let current_state = slot.state.load(Ordering::Acquire);
189            if current_state != SlotState::Empty as u8 {
190                return Err(RingBufferError::InvalidState);
191            }
192
193            ptr::copy_nonoverlapping(data.as_ptr(), slot.data.as_mut_ptr(), data.len());
194            slot.len.store(data.len() as u32, Ordering::Relaxed);
195            slot.priority = priority as u8;
196
197            let write_seq = (*self.header).write_seq.fetch_add(1, Ordering::Relaxed);
198
199            std::sync::atomic::fence(Ordering::Release);
200
201            slot.state.store(SlotState::Ready as u8, Ordering::Release);
202
203            (*self.header).head.store(head.wrapping_add(1), Ordering::Release);
204
205            Ok(SlotId {
206                index: head,
207                seq: write_seq,
208            })
209        }
210    }
211
212    pub fn read_slot(&self) -> Option<(Priority, Vec<u8>)> {
213        let tail = unsafe { (*self.header).tail.load(Ordering::Acquire) };
214        let head = unsafe { (*self.header).head.load(Ordering::Acquire) };
215
216        if tail >= head {
217            return None;
218        }
219
220        let slot_idx = Self::next_slot_index(tail, self.capacity as u64);
221
222        unsafe {
223            let slot = &*self.slots.add(slot_idx as usize);
224
225            let current_state = slot.state.load(Ordering::Acquire);
226            if current_state != SlotState::Ready as u8 {
227                return None;
228            }
229
230            let len = slot.len.load(Ordering::Acquire) as usize;
231
232            let mut data = Vec::with_capacity(len);
233            ptr::copy_nonoverlapping(slot.data.as_ptr(), data.as_mut_ptr(), len);
234            data.set_len(len);
235
236            let priority = Priority::from_u8(slot.priority).unwrap_or(Priority::Normal);
237
238            slot.state.store(SlotState::Reading as u8, Ordering::Relaxed);
239
240            std::sync::atomic::fence(Ordering::Acquire);
241
242            (*self.header).tail.store(tail.wrapping_add(1), Ordering::Release);
243
244            slot.state.store(SlotState::Empty as u8, Ordering::Release);
245
246            (*self.header).read_seq.fetch_add(1, Ordering::Relaxed);
247
248            Some((priority, data))
249        }
250    }
251
252    pub fn peek_slot(&self) -> Option<(Priority, Vec<u8>)> {
253        let tail = unsafe { (*self.header).tail.load(Ordering::Acquire) };
254        let head = unsafe { (*self.header).head.load(Ordering::Acquire) };
255
256        if tail >= head {
257            return None;
258        }
259
260        let slot_idx = Self::next_slot_index(tail, self.capacity as u64);
261
262        unsafe {
263            let slot = &*self.slots.add(slot_idx as usize);
264
265            let current_state = slot.state.load(Ordering::Acquire);
266            if current_state != SlotState::Ready as u8 {
267                return None;
268            }
269
270            let len = slot.len.load(Ordering::Acquire) as usize;
271            let mut data = Vec::with_capacity(len);
272            ptr::copy_nonoverlapping(slot.data.as_ptr(), data.as_mut_ptr(), len);
273            data.set_len(len);
274
275            let priority = Priority::from_u8(slot.priority).unwrap_or(Priority::Normal);
276
277            Some((priority, data))
278        }
279    }
280}
281
282impl Priority {
283    fn from_u8(val: u8) -> Option<Self> {
284        match val {
285            0 => Some(Priority::Low),
286            1 => Some(Priority::Normal),
287            2 => Some(Priority::High),
288            _ => None,
289        }
290    }
291}
292
impl Drop for RingBuffer {
    fn drop(&mut self) {
        unsafe {
            // Reclaim the allocations made in `with_capacity`.
            // NOTE(review): this also runs for `from_ptr` instances, whose
            // pointers may come from shared memory rather than Box/Vec —
            // freeing those here is undefined behavior. Fixing this needs
            // an ownership flag on `RingBuffer` (or a separate borrowed
            // type); TODO confirm callers never drop a `from_ptr` buffer.
            let _ = Box::from_raw(self.header);
            // `with_capacity` leaked a boxed slice with len == capacity, so
            // reconstituting a Vec with capacity == len == self.capacity
            // matches the original allocation exactly.
            let _ = Vec::from_raw_parts(self.slots, self.capacity, self.capacity);
        }
    }
}
300}