Skip to main content

rill_core/queues/
mpsc.rs

//! # Multiple-Producer Single-Consumer queue
//!
//! Allows multiple producers to send data to a single consumer.
//! Uses atomic operations for producer synchronization.
5#![allow(unsafe_code)]
6
7use super::{QueueError, QueueResult, QueueStats};
8use std::ptr;
9use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
10
/// A singly linked list cell used by [`MpscQueue`].
///
/// `value` is `None` for the sentinel ("stub") node and for any node whose
/// payload has already been taken out by `pop`; `next` is null until a
/// successor is linked in.
struct Node<T> {
    value: Option<T>,
    next: AtomicPtr<Node<T>>,
}

impl<T> Node<T> {
    /// Heap-allocates an unlinked node carrying `value` and leaks it as a raw
    /// pointer (reclaim it with `Box::from_raw`).
    fn new(value: T) -> *mut Node<T> {
        Self::alloc(Some(value))
    }

    /// Heap-allocates an unlinked, empty sentinel node and leaks it as a raw
    /// pointer (reclaim it with `Box::from_raw`).
    fn stub() -> *mut Node<T> {
        Self::alloc(None)
    }

    /// Shared allocation path for both constructors.
    fn alloc(value: Option<T>) -> *mut Node<T> {
        let cell = Node {
            value,
            next: AtomicPtr::new(ptr::null_mut()),
        };
        Box::into_raw(Box::new(cell))
    }
}
32
33/// Multiple-Producer Single-Consumer queue
34///
35/// Implemented as a Michael-Scott lock-free queue.
36/// Producers never block, the consumer can wait for data.
37pub struct MpscQueue<T> {
38    /// Queue head (first element to read)
39    head: AtomicPtr<Node<T>>,
40    /// Queue tail (last element to write)
41    tail: AtomicPtr<Node<T>>,
42    /// Counter for statistics
43    stats: QueueStats,
44    /// Maximum capacity (0 = unlimited)
45    max_capacity: usize,
46    /// Current size (approximate)
47    size: AtomicUsize,
48}
49
50impl<T> Default for MpscQueue<T> {
51    fn default() -> Self {
52        Self::new()
53    }
54}
55
56impl<T> MpscQueue<T> {
57    /// Create a new queue
58    pub fn new() -> Self {
59        let stub = Node::<T>::stub();
60        Self {
61            head: AtomicPtr::new(stub),
62            tail: AtomicPtr::new(stub),
63            stats: QueueStats::new(),
64            max_capacity: 0,
65            size: AtomicUsize::new(0),
66        }
67    }
68
69    /// Create a queue with a limited capacity
70    pub fn with_capacity(capacity: usize) -> Self {
71        let mut queue = Self::new();
72        queue.max_capacity = capacity;
73        queue
74    }
75
76    /// Push an element (can be called from multiple threads)
77    pub fn push(&self, value: T) -> QueueResult<()> {
78        // Check for overflow
79        if self.max_capacity > 0 {
80            let size = self.size.load(Ordering::Relaxed);
81            if size >= self.max_capacity {
82                self.stats.record_overflow();
83                return Err(QueueError::QueueFull);
84            }
85        }
86
87        let node = Node::new(value);
88        let mut tail = self.tail.load(Ordering::Acquire);
89
90        loop {
91            let next = unsafe { (*tail).next.load(Ordering::Acquire) };
92
93            if next.is_null() {
94                // Try to add a new node
95                match unsafe {
96                    (*tail).next.compare_exchange_weak(
97                        ptr::null_mut(),
98                        node,
99                        Ordering::Release,
100                        Ordering::Relaxed,
101                    )
102                } {
103                    Ok(_) => {
104                        // Update tail
105                        let _ = self.tail.compare_exchange(
106                            tail,
107                            node,
108                            Ordering::Release,
109                            Ordering::Relaxed,
110                        );
111                        self.size.fetch_add(1, Ordering::Relaxed);
112                        self.stats.record_push(self.size());
113                        return Ok(());
114                    }
115                    Err(new_next) => {
116                        // Another thread already added a node, update tail
117                        let _ = self.tail.compare_exchange(
118                            tail,
119                            new_next,
120                            Ordering::Release,
121                            Ordering::Relaxed,
122                        );
123                        tail = new_next;
124                    }
125                }
126            } else {
127                // Advance tail
128                let _ =
129                    self.tail
130                        .compare_exchange(tail, next, Ordering::Release, Ordering::Relaxed);
131                tail = next;
132            }
133        }
134    }
135
136    /// Pop an element (consumer only)
137    pub fn pop(&self) -> Option<T> {
138        loop {
139            let head = self.head.load(Ordering::Acquire);
140            let tail = self.tail.load(Ordering::Acquire);
141            let next = unsafe { (*head).next.load(Ordering::Acquire) };
142
143            if head == tail {
144                if next.is_null() {
145                    return None;
146                }
147                let _ =
148                    self.tail
149                        .compare_exchange(tail, next, Ordering::Release, Ordering::Relaxed);
150            } else {
151                if next.is_null() {
152                    continue;
153                }
154
155                if self
156                    .head
157                    .compare_exchange(head, next, Ordering::Release, Ordering::Relaxed)
158                    .is_ok()
159                {
160                    let value = unsafe { (*next).value.take() };
161                    unsafe {
162                        drop(Box::from_raw(head));
163                    }
164                    self.size.fetch_sub(1, Ordering::Relaxed);
165                    self.stats.record_pop();
166                    return value;
167                }
168            }
169        }
170    }
171
172    /// Current size (approximate)
173    pub fn size(&self) -> usize {
174        self.size.load(Ordering::Relaxed)
175    }
176
177    /// Capacity (0 = unlimited)
178    pub fn capacity(&self) -> usize {
179        self.max_capacity
180    }
181
182    /// Check if the queue is empty
183    pub fn is_empty(&self) -> bool {
184        let head = self.head.load(Ordering::Acquire);
185        let tail = self.tail.load(Ordering::Acquire);
186        let next = unsafe { (*head).next.load(Ordering::Acquire) };
187
188        head == tail && next.is_null()
189    }
190}
191
192impl<T> Drop for MpscQueue<T> {
193    fn drop(&mut self) {
194        while self.pop().is_some() {}
195
196        let head = self.head.load(Ordering::Relaxed);
197        if !head.is_null() {
198            unsafe {
199                drop(Box::from_raw(head));
200            }
201        }
202    }
203}
204
205unsafe impl<T: Send> Send for MpscQueue<T> {}
206unsafe impl<T: Send> Sync for MpscQueue<T> {}
207
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_mpsc_basic() {
        let queue = MpscQueue::new();

        queue.push(1).unwrap();
        queue.push(2).unwrap();
        queue.push(3).unwrap();

        assert_eq!(queue.pop(), Some(1));
        assert_eq!(queue.pop(), Some(2));
        assert_eq!(queue.pop(), Some(3));
        assert_eq!(queue.pop(), None);
        assert!(queue.is_empty());
    }

    #[test]
    fn test_mpsc_multiple_producers() {
        let queue = Arc::new(MpscQueue::new());
        let mut handles = vec![];

        for i in 0..4 {
            let queue = Arc::clone(&queue);
            handles.push(thread::spawn(move || {
                for j in 0..250 {
                    queue.push(i * 1000 + j).unwrap();
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let mut count = 0;
        while queue.pop().is_some() {
            count += 1;
        }

        assert_eq!(count, 1000);
    }

    /// A bounded queue rejects pushes once full and accepts them again after a pop.
    #[test]
    fn test_mpsc_capacity_limit() {
        let queue = MpscQueue::with_capacity(2);
        assert_eq!(queue.capacity(), 2);

        queue.push(1).unwrap();
        queue.push(2).unwrap();
        assert!(matches!(queue.push(3), Err(QueueError::QueueFull)));
        assert_eq!(queue.size(), 2);

        assert_eq!(queue.pop(), Some(1));
        queue.push(3).unwrap();
        assert_eq!(queue.pop(), Some(2));
        assert_eq!(queue.pop(), Some(3));
        assert_eq!(queue.pop(), None);
    }

    /// Racing producers must never push more than `capacity` elements in total.
    #[test]
    fn test_mpsc_capacity_not_exceeded_by_concurrent_producers() {
        let queue = Arc::new(MpscQueue::<u32>::with_capacity(64));

        let handles: Vec<_> = (0..8)
            .map(|_| {
                let queue = Arc::clone(&queue);
                thread::spawn(move || (0..100u32).filter(|&i| queue.push(i).is_ok()).count())
            })
            .collect();
        let accepted: usize = handles.into_iter().map(|h| h.join().unwrap()).sum();

        assert_eq!(accepted, 64);
        assert_eq!(queue.size(), 64);
    }

    /// The consumer pops while producers are still pushing. Each producer's
    /// values must arrive in the order that producer pushed them.
    #[test]
    fn test_mpsc_concurrent_push_and_pop_preserves_per_producer_order() {
        const PRODUCERS: usize = 4;
        const PER_PRODUCER: usize = 5_000;

        let queue = Arc::new(MpscQueue::<usize>::new());
        let producers: Vec<_> = (0..PRODUCERS)
            .map(|p| {
                let queue = Arc::clone(&queue);
                thread::spawn(move || {
                    for i in 0..PER_PRODUCER {
                        queue.push(p * PER_PRODUCER + i).unwrap();
                    }
                })
            })
            .collect();

        let mut next_expected = vec![0; PRODUCERS];
        let mut received = 0;
        while received < PRODUCERS * PER_PRODUCER {
            match queue.pop() {
                Some(v) => {
                    let (p, i) = (v / PER_PRODUCER, v % PER_PRODUCER);
                    assert_eq!(i, next_expected[p], "producer {} out of order", p);
                    next_expected[p] += 1;
                    received += 1;
                }
                None => thread::yield_now(),
            }
        }

        for handle in producers {
            handle.join().unwrap();
        }
        assert_eq!(queue.pop(), None);
        assert!(queue.is_empty());
        assert_eq!(queue.size(), 0);
    }

    /// Values still queued when the queue is dropped must be dropped too.
    #[test]
    fn test_mpsc_drop_releases_pending_values() {
        let marker = Arc::new(());
        {
            let queue = MpscQueue::new();
            queue.push(Arc::clone(&marker)).unwrap();
            queue.push(Arc::clone(&marker)).unwrap();
            assert_eq!(Arc::strong_count(&marker), 3);
        }
        assert_eq!(Arc::strong_count(&marker), 1);
    }
}