// mpmc_queue/queue.rs

1use std::collections::VecDeque;
2use std::sync::{Condvar, Mutex};
3
/// A queue with a fixed maximum size that can be shared between threads.
///
/// Implementors must be `Send + Sync` so one queue (for example behind an
/// `Arc`) can be used by several producer and consumer threads at once.
pub trait BoundedQueue<T>
where
    Self: Send + Sync,
    T: Send,
{
    /// Creates an empty queue able to hold at most `capacity` items.
    fn new(capacity: usize) -> Self
    where
        Self: Sized;

    /// Adds `item` to the queue; intended to wait while the queue is full.
    fn push(&self, item: T);

    /// Removes and returns an item; intended to wait while the queue is empty.
    fn pop(&self) -> T;

    /// Adds `item` without waiting.
    ///
    /// On failure (queue full) the item is handed back as `Err(item)`.
    fn try_push(&self, item: T) -> Result<(), T>;

    /// Removes an item without waiting, or returns `None` if there is none.
    fn try_pop(&self) -> Option<T>;
}
13
14pub struct MpmcQueue<T: Send> {
15    inner: Mutex<Inner<T>>,
16    not_empty: Condvar,
17    not_full: Condvar,
18    capacity: usize,
19}
20
/// State protected by the `MpmcQueue` mutex.
struct Inner<T> {
    /// Items in arrival order: pushed at the back, popped from the front.
    buffer: VecDeque<T>,
}
24
25impl<T: Send> BoundedQueue<T> for MpmcQueue<T> {
26    fn new(capacity: usize) -> Self {
27        Self {
28            inner: Mutex::new(Inner {
29                buffer: VecDeque::with_capacity(capacity),
30            }),
31            not_empty: Condvar::new(),
32            not_full: Condvar::new(),
33            capacity,
34        }
35    }
36
37    fn push(&self, item: T) {
38        let mut guard = self.inner.lock().unwrap();
39
40        while guard.buffer.len() == self.capacity {
41            guard = self.not_full.wait(guard).unwrap();
42        }
43
44        guard.buffer.push_back(item);
45        self.not_empty.notify_one();
46    }
47
48    fn pop(&self) -> T {
49        let mut guard = self.inner.lock().unwrap();
50
51        while guard.buffer.is_empty() {
52            guard = self.not_empty.wait(guard).unwrap();
53        }
54
55        let item = guard.buffer.pop_front().unwrap();
56        self.not_full.notify_one();
57
58        item
59    }
60
61    fn try_push(&self, item: T) -> Result<(), T> {
62        let mut guard = self.inner.lock().unwrap();
63
64        if guard.buffer.len() == self.capacity {
65            return Err(item);
66        }
67
68        guard.buffer.push_back(item);
69        self.not_empty.notify_one();
70        Ok(())
71    }
72
73    fn try_pop(&self) -> Option<T> {
74        let mut guard = self.inner.lock().unwrap();
75
76        let item = guard.buffer.pop_front();
77
78        if item.is_some() {
79            self.not_full.notify_one();
80        }
81
82        item
83    }
84}
85
86
87
88
89
90
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Several producers and consumers share a small queue; every thread must
    /// finish and the queue must be empty afterwards.
    #[test]
    fn stress_count_correctness() {
        let q = Arc::new(MpmcQueue::new(64));

        let producers = 4;
        let consumers = 4;
        let items_per_producer = 10_000;

        let total_items = producers * items_per_producer;
        // Each consumer pops a fixed share. An uneven split would leave items
        // behind or block a producer forever, so insist on an exact division.
        assert_eq!(total_items % consumers, 0);
        let items_per_consumer = total_items / consumers;

        let mut handles = Vec::with_capacity(producers + consumers);

        for p in 0..producers {
            let q = Arc::clone(&q);
            handles.push(thread::spawn(move || {
                for i in 0..items_per_producer {
                    q.push(p * items_per_producer + i);
                }
            }));
        }

        // Consumers pop a fixed number of items instead of polling for
        // completion, so no shutdown signal is needed.
        for _ in 0..consumers {
            let q = Arc::clone(&q);
            handles.push(thread::spawn(move || {
                for _ in 0..items_per_consumer {
                    q.pop();
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        // Every produced item has been consumed, so nothing may remain.
        assert_eq!(q.try_pop(), None);
    }

    /// Every produced value is received by exactly one consumer: none lost,
    /// none duplicated.
    #[test]
    fn no_duplicates_no_loss() {
        let q = Arc::new(MpmcQueue::new(64));

        let producers: usize = 4;
        let consumers = producers;
        let items_per_producer = 5000;
        let total_items = producers * items_per_producer;
        let items_per_consumer = total_items / consumers;

        let producer_handles: Vec<_> = (0..producers)
            .map(|p| {
                let q = Arc::clone(&q);
                thread::spawn(move || {
                    for i in 0..items_per_producer {
                        q.push(p * items_per_producer + i);
                    }
                })
            })
            .collect();

        // Each consumer returns what it received, so no shared result lock
        // is needed.
        let consumer_handles: Vec<_> = (0..consumers)
            .map(|_| {
                let q = Arc::clone(&q);
                thread::spawn(move || {
                    (0..items_per_consumer)
                        .map(|_| q.pop())
                        .collect::<Vec<_>>()
                })
            })
            .collect();

        for h in producer_handles {
            h.join().unwrap();
        }
        let mut data: Vec<usize> = consumer_handles
            .into_iter()
            .flat_map(|h| h.join().unwrap())
            .collect();

        data.sort_unstable();
        // Comparing whole vectors also checks that the length is exact.
        assert_eq!(data, (0..total_items).collect::<Vec<_>>());
    }

    /// `try_push` hands the item back once the queue is at capacity.
    #[test]
    fn try_push_returns_item_when_full() {
        let q = MpmcQueue::new(2);
        assert_eq!(q.try_push(1), Ok(()));
        assert_eq!(q.try_push(2), Ok(()));
        assert_eq!(q.try_push(3), Err(3));
        // Freeing a slot makes room again.
        assert_eq!(q.try_pop(), Some(1));
        assert_eq!(q.try_push(3), Ok(()));
    }

    /// `try_pop` yields items in FIFO order and `None` once drained.
    #[test]
    fn try_pop_is_fifo_and_none_when_empty() {
        let q = MpmcQueue::new(4);
        assert_eq!(q.try_pop(), None::<u32>);
        for v in [10, 20, 30] {
            q.push(v);
        }
        assert_eq!(q.try_pop(), Some(10));
        assert_eq!(q.try_pop(), Some(20));
        assert_eq!(q.try_pop(), Some(30));
        assert_eq!(q.try_pop(), None);
    }

    /// A `push` on a full queue completes once a consumer frees a slot.
    #[test]
    fn push_waits_for_space() {
        let q = Arc::new(MpmcQueue::new(1));
        q.push(1);
        let producer = {
            let q = Arc::clone(&q);
            thread::spawn(move || q.push(2))
        };
        assert_eq!(q.pop(), 1);
        producer.join().unwrap();
        assert_eq!(q.pop(), 2);
    }
}