1use std::cell::UnsafeCell;
51use std::cmp::max;
52use std::mem::MaybeUninit;
53use std::sync::atomic::{AtomicI8, AtomicUsize, Ordering};
54
// NOTE(review): this attribute applies only to the private enum below, where
// `missing_docs` never fires; it was probably meant as the file-level
// `#![warn(missing_docs)]` (which cannot be placed here, after the imports).
#[warn(missing_docs)]
/// Lifecycle of a single ring-buffer cell.
///
/// Cells cycle `Empty -> Storing -> Stored -> Loading -> Empty`; the numeric
/// values are what gets stored in the `AtomicI8` state array.
#[repr(i8)]
enum CellState {
    /// No value present; a producer may claim the cell.
    Empty = 0,
    /// A producer has claimed the cell and is writing the value.
    Storing = 1,
    /// A value is present; a consumer may claim the cell.
    Stored = 2,
    /// A consumer has claimed the cell and is moving the value out.
    Loading = 3,
}

impl From<CellState> for i8 {
    /// Converts a state to the raw value kept in the `AtomicI8` slots.
    fn from(value: CellState) -> Self {
        // The discriminants are declared explicitly and `#[repr(i8)]`
        // guarantees the cast yields exactly them, so the previous
        // variant-by-variant `match` was redundant.
        value as i8
    }
}
75
/// A bounded queue backed by a fixed-size ring buffer, with per-cell atomic
/// state tracking so producers and consumers coordinate without locks.
pub struct Queue<T> {
    // Monotonically increasing index of the next slot to push into
    // (wrapped into the buffer with `% elements.len()` at access time).
    head: AtomicUsize,
    // Monotonically increasing index of the next slot to pop from.
    tail: AtomicUsize,
    // Ring-buffer storage; whether a cell holds a live value is recorded
    // in the matching entry of `states`, not in the cell itself.
    elements: Vec<UnsafeCell<MaybeUninit<T>>>,
    // Per-cell `CellState` value (Empty/Storing/Stored/Loading), same length
    // and indexing as `elements`.
    states: Vec<AtomicI8>,
}
91
// SAFETY: all access to the element cells is mediated by the per-cell atomic
// state machine (claim via CAS, publish via Release store), so the queue may
// be sent to and shared between threads whenever the element type is `Send`.
unsafe impl<T: Send> Send for Queue<T> {}
unsafe impl<T: Send> Sync for Queue<T> {}
94
95pub fn bounded<T>(capacity: usize) -> Queue<T> {
99 Queue::new(capacity)
100}
101
102impl<T> Queue<T> {
103 pub fn new(capacity: usize) -> Self {
105 let mut elements = Vec::with_capacity(capacity);
106 for _ in 0..capacity {
107 elements.push(UnsafeCell::new(MaybeUninit::uninit()));
108 }
109 let mut states = Vec::with_capacity(capacity);
110 for _ in 0..capacity {
111 states.push(AtomicI8::new(CellState::Empty.into()));
112 }
113 let head = AtomicUsize::new(0);
114 let tail = AtomicUsize::new(0);
115 Queue {
116 head,
117 tail,
118 elements,
119 states,
120 }
121 }
122
123 pub fn push(&self, element: T) -> bool {
130 let mut head = self.head.load(Ordering::Relaxed);
131 let elements_len = self.elements.len();
132 loop {
133 let length = head as i64 - self.tail.load(Ordering::Relaxed) as i64;
134 if length >= elements_len as i64 {
135 return false;
136 }
137
138 if self
139 .head
140 .compare_exchange(head, head + 1, Ordering::Acquire, Ordering::Relaxed)
141 .is_ok()
142 {
143 self.do_push(element, head);
144 return true;
145 }
146
147 head = self.head.load(Ordering::Relaxed);
148 }
149 }
150
151 pub fn pop(&self) -> Option<T> {
159 let mut tail = self.tail.load(Ordering::Relaxed);
160 loop {
161 let length = self.head.load(Ordering::Relaxed) as i64 - tail as i64;
162 if length <= 0 {
163 return None;
164 }
165
166 if self
167 .tail
168 .compare_exchange(tail, tail + 1, Ordering::Acquire, Ordering::Relaxed)
169 .is_ok()
170 {
171 break;
172 }
173
174 tail = self.tail.load(Ordering::Relaxed);
175 }
176 Some(self.do_pop(tail))
177 }
178
179 pub unsafe fn force_pop(&self) -> T {
184 let tail = self.tail.fetch_add(1, Ordering::Acquire);
185 self.do_pop(tail)
186 }
187
188 pub unsafe fn force_push(&self, element: T) {
193 let head = self.head.fetch_add(1, Ordering::Acquire);
194 self.do_push(element, head);
195 }
196
197 pub fn is_empty(&self) -> bool {
199 self.len() == 0
200 }
201
202 pub fn len(&self) -> usize {
204 max(
205 self.head.load(Ordering::Relaxed) - self.tail.load(Ordering::Relaxed),
206 0,
207 )
208 }
209}
210
211impl<T> Queue<T> {
212 fn do_pop(&self, tail: usize) -> T {
213 let state = &self.states[tail % self.states.len()];
214 loop {
215 let expected = CellState::Stored;
216 if state
217 .compare_exchange(
218 expected.into(),
219 CellState::Loading.into(),
220 Ordering::Acquire,
221 Ordering::Relaxed,
222 )
223 .is_ok()
224 {
225 let element = unsafe {
226 self.elements[tail % self.elements.len()]
227 .get()
228 .replace(MaybeUninit::uninit())
229 .assume_init()
230 };
231
232 state.store(CellState::Empty.into(), Ordering::Release);
233
234 return element;
235 }
236 }
237 }
238
239 fn do_push(&self, element: T, head: usize) {
240 self.do_push_any(element, head);
241 }
242
243 fn do_push_any(&self, element: T, head: usize) {
244 let state = &self.states[head % self.states.len()];
245 loop {
246 let expected = CellState::Empty;
247 if state
248 .compare_exchange(
249 expected.into(),
250 CellState::Storing.into(),
251 Ordering::Acquire,
252 Ordering::Relaxed,
253 )
254 .is_ok()
255 {
256 unsafe {
257 self.elements[head % self.elements.len()]
260 .get()
261 .write(MaybeUninit::new(element));
262 }
263 state.store(CellState::Stored.into(), Ordering::Release);
264 return;
265 }
266 }
267 }
268}
269
270impl<T> Drop for Queue<T> {
271 fn drop(&mut self) {
272 if std::mem::needs_drop::<T>() {
273 while let Some(element) = self.pop() {
276 drop(element);
277 }
278 }
279 }
280}
281
#[cfg(test)]
mod test {
    use std::ffi::c_void;
    use std::sync::{Arc, Mutex};
    use std::thread;
    use std::thread::JoinHandle;
    use std::time::Duration;

    use super::*;

    // Thin wrapper around a raw pointer so pointer-sized payloads can be
    // moved between test threads.
    #[derive(Eq, PartialEq, Debug, Copy, Clone)]
    struct MockPtr(*mut c_void);

    // SAFETY: tests never dereference the pointer; it only carries an
    // integer-derived value across threads.
    unsafe impl Send for MockPtr {}

    // Builds a MockPtr whose address encodes `value`, for equality checks.
    fn mock_ptr(value: i32) -> MockPtr {
        MockPtr(value as *mut c_void)
    }

    #[test]
    fn test_create_bounded_queue() {
        let _queue = Queue::<MockPtr>::new(10);
    }

    #[test]
    fn test_get_empty_queue_len() {
        let queue = Queue::<MockPtr>::new(10);
        assert_eq!(queue.len(), 0);
    }

    // Dropping a non-empty queue must run each remaining element's destructor
    // exactly once.
    #[test]
    fn test_queue_drops_items() {
        struct Item {
            drop_count: Arc<AtomicUsize>,
        }
        impl Drop for Item {
            fn drop(&mut self) {
                self.drop_count.fetch_add(1, Ordering::Relaxed);
            }
        }
        let drop_count = Arc::new(AtomicUsize::new(0));
        let queue: Queue<Item> = Queue::new(10);
        queue.push(Item {
            drop_count: drop_count.clone(),
        });
        queue.push(Item {
            drop_count: drop_count.clone(),
        });
        queue.push(Item {
            drop_count: drop_count.clone(),
        });
        drop(queue);

        assert_eq!(drop_count.load(Ordering::Relaxed), 3);
    }

    #[test]
    fn test_push_element_to_queue_increments_length() {
        let queue = Queue::<MockPtr>::new(10);
        assert_eq!(queue.len(), 0);
        let ptr = mock_ptr(1);
        assert!(queue.push(ptr));
        assert_eq!(queue.len(), 1);
        let value = queue.pop();
        assert_eq!(value.unwrap(), ptr);
        assert_eq!(queue.len(), 0);
    }

    // Verifies a slot can be reused after a push/pop cycle.
    #[test]
    fn test_push_pop_push_pop() {
        let queue = Queue::<MockPtr>::new(10);
        assert_eq!(queue.len(), 0);
        {
            let ptr = mock_ptr(1);
            assert!(queue.push(ptr));
            assert_eq!(queue.len(), 1);
            let value = queue.pop();
            assert_eq!(value.unwrap(), ptr);
            assert_eq!(queue.len(), 0);
        }
        {
            let ptr = mock_ptr(2);
            assert!(queue.push(ptr));
            assert_eq!(queue.len(), 1);
            let value = queue.pop();
            assert_eq!(value.unwrap(), ptr);
            assert_eq!(queue.len(), 0);
        }
    }

    // Pushing past capacity must fail (return false) without corrupting the
    // queue's length.
    #[test]
    fn test_overflow_will_not_break_things() {
        let queue = Queue::<MockPtr>::new(3);
        assert_eq!(queue.len(), 0);

        assert!(queue.push(mock_ptr(1)));
        assert_eq!(queue.len(), 1);

        assert!(queue.push(mock_ptr(2)));
        assert_eq!(queue.len(), 2);

        assert!(queue.push(mock_ptr(3)));
        assert_eq!(queue.len(), 3);

        assert_eq!(queue.len(), 3);
        let result = queue.push(mock_ptr(4));
        assert!(!result);
        assert_eq!(queue.len(), 3);
    }

    // Three concurrent writers; the queue is large enough that no push fails.
    #[test]
    fn test_multithread_push() {
        wisual_logger::init_from_env();

        let queue = Arc::new(Queue::new(50000));

        // NOTE(review): `(0.0 * rand::random::<f64>()) as u64` is always 0,
        // so these writers sleep 0ms — the randomness is vestigial.
        let writer_thread_1 = spawn_writer_thread(
            10,
            queue.clone(),
            Duration::from_millis((0.0 * rand::random::<f64>()) as u64),
        );
        let writer_thread_2 = spawn_writer_thread(
            10,
            queue.clone(),
            Duration::from_millis((0.0 * rand::random::<f64>()) as u64),
        );
        let writer_thread_3 = spawn_writer_thread(
            10,
            queue.clone(),
            Duration::from_millis((0.0 * rand::random::<f64>()) as u64),
        );

        writer_thread_1.join().unwrap();
        writer_thread_2.join().unwrap();
        writer_thread_3.join().unwrap();
        assert_eq!(queue.len(), 30);
    }

    // Multiple writers against one reader draining into an output queue; the
    // input queue is deliberately smaller than the total element count so
    // writers hit the full-queue retry path.
    #[test]
    fn test_multithread_push_pop() {
        wisual_logger::init_from_env();

        let size = 10000;
        let num_threads = 5;

        let queue: Arc<Queue<MockPtr>> = Arc::new(Queue::new(size * num_threads / 3));
        let output_queue: Arc<Queue<MockPtr>> = Arc::new(Queue::new(size * num_threads));

        let is_running = Arc::new(Mutex::new(true));
        let reader_thread = {
            let is_running = is_running.clone();
            let queue = queue.clone();
            let output_queue = output_queue.clone();
            thread::spawn(move || {
                // Keep draining until the writers have stopped AND the queue
                // is observed empty.
                while *is_running.lock().unwrap() || queue.len() > 0 {
                    loop {
                        match queue.pop() {
                            None => break,
                            Some(value) => {
                                output_queue.push(value);
                            }
                        }
                    }
                }
                log::info!("Reader thread done reading");
            })
        };

        // NOTE(review): `rand::random::<f64>()` is in [0, 1), so the cast to
        // u64 truncates to 0 — writers effectively do not sleep.
        let threads: Vec<JoinHandle<()>> = (0..num_threads)
            .into_iter()
            .map(|_| {
                spawn_writer_thread(
                    size,
                    queue.clone(),
                    Duration::from_millis((rand::random::<f64>()) as u64),
                )
            })
            .collect();

        for thread in threads {
            thread.join().unwrap();
        }

        {
            let mut is_running = is_running.lock().unwrap();
            *is_running = false;
        }
        reader_thread.join().unwrap();

        assert_eq!(queue.len(), 0);
        assert_eq!(output_queue.len(), size * num_threads);
    }

    // Spawns a thread that pushes `size` elements, retrying each push until
    // it succeeds, sleeping `duration` between successful pushes.
    fn spawn_writer_thread(
        size: usize,
        queue: Arc<Queue<MockPtr>>,
        duration: Duration,
    ) -> JoinHandle<()> {
        thread::spawn(move || {
            for i in 0..size {
                loop {
                    let pushed = queue.push(mock_ptr(i as i32));
                    if pushed {
                        break;
                    }
                }
                thread::sleep(duration);
            }
            log::info!("Thread done writing");
        })
    }
}