1use std::collections::VecDeque;
2use std::sync::{Condvar, Mutex};
3
/// A thread-safe queue that holds at most a fixed number of items.
///
/// `push` and `pop` are expected to block until they can complete, while
/// `try_push` and `try_pop` return immediately instead of waiting.
pub trait BoundedQueue<T: Send>: Send + Sync {
    /// Creates an empty queue that holds at most `capacity` items.
    fn new(capacity: usize) -> Self
    where
        Self: Sized;

    /// Adds `item` to the queue, waiting while the queue is full.
    fn push(&self, item: T);

    /// Removes and returns an item, waiting while the queue is empty.
    fn pop(&self) -> T;

    /// Adds `item` if there is room; otherwise hands it back as `Err(item)`.
    fn try_push(&self, item: T) -> Result<(), T>;

    /// Removes and returns an item, or returns `None` if the queue is empty.
    fn try_pop(&self) -> Option<T>;
}

/// Bounded multi-producer, multi-consumer FIFO queue protected by one mutex.
///
/// Producers wait on `not_full` while the buffer holds `capacity` items and
/// consumers wait on `not_empty` while it is empty. Keeping one condition
/// variable per predicate means the `notify_one` after a push only ever wakes a
/// waiting consumer (and the one after a pop only a waiting producer), never a
/// thread blocked on the opposite condition.
#[derive(Debug)]
pub struct MpmcQueue<T> {
    /// The queued items; every access goes through this lock.
    inner: Mutex<Inner<T>>,
    /// Signalled after an item is added; blocked consumers wait on it.
    not_empty: Condvar,
    /// Signalled after an item is removed; blocked producers wait on it.
    not_full: Condvar,
    /// Maximum number of buffered items; `new` guarantees it is non-zero.
    capacity: usize,
}

/// State guarded by the `MpmcQueue::inner` mutex.
#[derive(Debug)]
struct Inner<T> {
    /// Queued items, oldest at the front; never longer than the queue capacity.
    buffer: VecDeque<T>,
}

impl<T: Send> BoundedQueue<T> for MpmcQueue<T> {
    /// Creates an empty queue that holds at most `capacity` items.
    ///
    /// The buffer is allocated up front, so the queue never reallocates
    /// while it is in use.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero: such a queue could never accept an item,
    /// so every `push` would block forever. Preallocating may also panic (or
    /// abort on allocation failure) if `capacity` items cannot be reserved.
    fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "MpmcQueue capacity must be greater than zero");
        Self {
            inner: Mutex::new(Inner {
                buffer: VecDeque::with_capacity(capacity),
            }),
            not_empty: Condvar::new(),
            not_full: Condvar::new(),
            capacity,
        }
    }

    /// Appends `item` to the back of the queue, blocking while the queue is full.
    fn push(&self, item: T) {
        let guard = self.inner.lock().unwrap();
        // `wait_while` re-checks the predicate after every wakeup, covering both
        // spurious wakeups and another producer taking the freed slot first.
        let mut guard = self
            .not_full
            .wait_while(guard, |inner| inner.buffer.len() >= self.capacity)
            .unwrap();
        guard.buffer.push_back(item);
        // Release the lock before notifying so the woken consumer does not
        // immediately block on a mutex we still hold.
        drop(guard);
        self.not_empty.notify_one();
    }

    /// Removes and returns the oldest item, blocking while the queue is empty.
    fn pop(&self) -> T {
        let guard = self.inner.lock().unwrap();
        let mut guard = self
            .not_empty
            .wait_while(guard, |inner| inner.buffer.is_empty())
            .unwrap();
        let item = guard
            .buffer
            .pop_front()
            .expect("wait_while returned while the buffer was still empty");
        drop(guard);
        self.not_full.notify_one();
        item
    }

    /// Appends `item` without blocking; returns `Err(item)` if the queue is full.
    fn try_push(&self, item: T) -> Result<(), T> {
        let mut guard = self.inner.lock().unwrap();
        if guard.buffer.len() >= self.capacity {
            return Err(item);
        }
        guard.buffer.push_back(item);
        drop(guard);
        self.not_empty.notify_one();
        Ok(())
    }

    /// Removes the oldest item without blocking; returns `None` if the queue is empty.
    fn try_pop(&self) -> Option<T> {
        let item = {
            let mut guard = self.inner.lock().unwrap();
            guard.buffer.pop_front()
        };
        // Only a successful pop frees a slot worth waking a producer for.
        if item.is_some() {
            self.not_full.notify_one();
        }
        item
    }
}
85
86
87
88
89
90
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use std::thread;

    /// Runs `producers` threads that together push every value in
    /// `0..producers * items_per_producer` (each thread a disjoint contiguous
    /// range) and `consumers` threads that each pop an equal share.
    ///
    /// Returns everything the consumers received, grouped per consumer, and
    /// asserts the queue is empty afterwards.
    ///
    /// Panics if the items do not split evenly across consumers, since a
    /// consumer would otherwise block forever waiting for a missing item.
    fn produce_and_consume(
        capacity: usize,
        producers: usize,
        consumers: usize,
        items_per_producer: usize,
    ) -> Vec<usize> {
        let total_items = producers * items_per_producer;
        assert_eq!(
            total_items % consumers,
            0,
            "items must split evenly across consumers"
        );
        let items_per_consumer = total_items / consumers;

        let q = Arc::new(MpmcQueue::new(capacity));
        let results = Arc::new(Mutex::new(Vec::with_capacity(total_items)));

        let producer_handles: Vec<_> = (0..producers)
            .map(|p| {
                let q = Arc::clone(&q);
                thread::spawn(move || {
                    for i in 0..items_per_producer {
                        q.push(p * items_per_producer + i);
                    }
                })
            })
            .collect();

        let consumer_handles: Vec<_> = (0..consumers)
            .map(|_| {
                let q = Arc::clone(&q);
                let results = Arc::clone(&results);
                thread::spawn(move || {
                    // Collect locally and publish once so the results mutex does
                    // not serialize the consumers on every pop.
                    let received: Vec<usize> =
                        (0..items_per_consumer).map(|_| q.pop()).collect();
                    results.lock().unwrap().extend(received);
                })
            })
            .collect();

        for handle in producer_handles.into_iter().chain(consumer_handles) {
            handle.join().unwrap();
        }

        // Every pushed item was popped, so nothing may be left behind.
        assert_eq!(q.try_pop(), None);

        let received = std::mem::take(&mut *results.lock().unwrap());
        received
    }

    #[test]
    fn stress_count_correctness() {
        let producers = 4;
        let items_per_producer = 10_000;
        let total_items = producers * items_per_producer;

        let received = produce_and_consume(64, producers, 4, items_per_producer);

        assert_eq!(received.len(), total_items);
        let expected_sum: usize = (0..total_items).sum();
        assert_eq!(received.iter().sum::<usize>(), expected_sum);
    }

    #[test]
    fn no_duplicates_no_loss() {
        let producers = 4;
        let consumers = 4;
        let items_per_producer = 5000;

        let mut received =
            produce_and_consume(64, producers, consumers, items_per_producer);
        received.sort_unstable();

        let expected: Vec<usize> = (0..producers * items_per_producer).collect();
        assert_eq!(received, expected);
    }

    #[test]
    fn capacity_one_hands_off_in_order() {
        // With a single slot every push must wait for the matching pop, which
        // exercises both condition variables on every item; with one consumer
        // the FIFO order must be preserved exactly.
        let received = produce_and_consume(1, 1, 1, 1_000);
        assert_eq!(received, (0..1_000).collect::<Vec<usize>>());
    }

    #[test]
    fn try_ops_respect_capacity_and_fifo_order() {
        let q = MpmcQueue::new(2);
        assert_eq!(q.try_pop(), None);
        assert_eq!(q.try_push(1), Ok(()));
        assert_eq!(q.try_push(2), Ok(()));
        assert_eq!(q.try_push(3), Err(3));
        assert_eq!(q.pop(), 1);
        assert_eq!(q.try_push(3), Ok(()));
        assert_eq!(q.try_pop(), Some(2));
        assert_eq!(q.pop(), 3);
        assert_eq!(q.try_pop(), None);
    }
}