use std::collections::VecDeque;
use std::sync::{Condvar, Mutex};
/// A fixed-capacity FIFO queue that can be shared between threads.
///
/// The `Send + Sync` supertraits let a single queue instance be used from
/// several producer and consumer threads at once (for example behind an `Arc`).
/// Each operation comes in a blocking form and a non-blocking `try_` form.
pub trait BoundedQueue<T: Send>: Send + Sync {
    /// Creates an empty queue able to hold up to `capacity` items.
    fn new(capacity: usize) -> Self
    where
        Self: Sized;

    /// Appends `item` to the back of the queue; implementations are expected
    /// to block while the queue is full.
    fn push(&self, item: T);

    /// Appends `item` without waiting, handing it back as `Err(item)` when
    /// there is no room.
    fn try_push(&self, item: T) -> Result<(), T>;

    /// Removes the front item; implementations are expected to block while
    /// the queue is empty.
    fn pop(&self) -> T;

    /// Removes the front item without waiting, or returns `None` when the
    /// queue is empty.
    fn try_pop(&self) -> Option<T>;
}
/// Bounded multi-producer, multi-consumer FIFO queue built from one `Mutex`
/// guarding the buffer and two `Condvar`s used for waiting.
pub struct MpmcQueue<T: Send> {
    /// Maximum number of items the buffer may hold at once.
    capacity: usize,
    /// Notified after an item is added; consumers wait on it while the buffer is empty.
    not_empty: Condvar,
    /// Notified after an item is removed; producers wait on it while the buffer is full.
    not_full: Condvar,
    /// The item buffer, behind the single lock every operation takes.
    inner: Mutex<Inner<T>>,
}
/// State protected by `MpmcQueue`'s mutex.
struct Inner<T> {
    /// Queued items, oldest at the front.
    buffer: VecDeque<T>,
}
impl<T: Send> BoundedQueue<T> for MpmcQueue<T> {
fn new(capacity: usize) -> Self {
Self {
inner: Mutex::new(Inner {
buffer: VecDeque::with_capacity(capacity),
}),
not_empty: Condvar::new(),
not_full: Condvar::new(),
capacity,
}
}
fn push(&self, item: T) {
let mut guard = self.inner.lock().unwrap();
while guard.buffer.len() == self.capacity {
guard = self.not_full.wait(guard).unwrap();
}
guard.buffer.push_back(item);
self.not_empty.notify_one();
}
fn pop(&self) -> T {
let mut guard = self.inner.lock().unwrap();
while guard.buffer.is_empty() {
guard = self.not_empty.wait(guard).unwrap();
}
let item = guard.buffer.pop_front().unwrap();
self.not_full.notify_one();
item
}
fn try_push(&self, item: T) -> Result<(), T> {
let mut guard = self.inner.lock().unwrap();
if guard.buffer.len() == self.capacity {
return Err(item);
}
guard.buffer.push_back(item);
self.not_empty.notify_one();
Ok(())
}
fn try_pop(&self) -> Option<T> {
let mut guard = self.inner.lock().unwrap();
let item = guard.buffer.pop_front();
if item.is_some() {
self.not_full.notify_one();
}
item
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Runs `producers` threads, each pushing the disjoint range
    /// `p * items_per_producer .. (p + 1) * items_per_producer`, and
    /// `consumers` threads that each pop an equal share. Returns the queue
    /// together with every popped value (in arbitrary order).
    fn run_mpmc(
        capacity: usize,
        producers: usize,
        consumers: usize,
        items_per_producer: usize,
    ) -> (Arc<MpmcQueue<usize>>, Vec<usize>) {
        let total_items = producers * items_per_producer;
        // Consumers must pop exactly what producers push; otherwise some
        // thread blocks forever and the test hangs instead of failing.
        assert_eq!(
            total_items % consumers,
            0,
            "items must split evenly across consumers"
        );
        let items_per_consumer = total_items / consumers;
        let q: Arc<MpmcQueue<usize>> = Arc::new(MpmcQueue::new(capacity));

        let producer_handles: Vec<_> = (0..producers)
            .map(|p| {
                let q = Arc::clone(&q);
                thread::spawn(move || {
                    for i in 0..items_per_producer {
                        q.push(p * items_per_producer + i);
                    }
                })
            })
            .collect();

        // Each consumer collects locally, so the test stresses the queue
        // rather than a shared results lock.
        let consumer_handles: Vec<_> = (0..consumers)
            .map(|_| {
                let q = Arc::clone(&q);
                thread::spawn(move || {
                    (0..items_per_consumer)
                        .map(|_| q.pop())
                        .collect::<Vec<usize>>()
                })
            })
            .collect();

        for h in producer_handles {
            h.join().unwrap();
        }
        let mut popped = Vec::with_capacity(total_items);
        for h in consumer_handles {
            popped.extend(h.join().unwrap());
        }
        (q, popped)
    }

    #[test]
    fn stress_count_correctness() {
        let (producers, items_per_producer) = (4, 10_000);
        let total_items = producers * items_per_producer;
        let (q, popped) = run_mpmc(64, producers, 4, items_per_producer);
        assert_eq!(popped.len(), total_items);
        // The sum of 0..total_items catches lost or duplicated values cheaply.
        assert_eq!(
            popped.iter().sum::<usize>(),
            total_items * (total_items - 1) / 2
        );
        assert!(q.try_pop().is_none(), "queue should be drained");
    }

    #[test]
    fn no_duplicates_no_loss() {
        let (producers, items_per_producer) = (4, 5000);
        let (q, mut popped) = run_mpmc(64, producers, producers, items_per_producer);
        popped.sort_unstable();
        let expected: Vec<usize> = (0..producers * items_per_producer).collect();
        assert_eq!(popped, expected);
        assert!(q.try_pop().is_none(), "queue should be drained");
    }

    #[test]
    fn try_ops_respect_capacity_and_fifo() {
        let q: MpmcQueue<i32> = MpmcQueue::new(2);
        assert_eq!(q.try_pop(), None);
        assert_eq!(q.try_push(1), Ok(()));
        assert_eq!(q.try_push(2), Ok(()));
        assert_eq!(q.try_push(3), Err(3));
        assert_eq!(q.try_pop(), Some(1));
        assert_eq!(q.try_push(3), Ok(()));
        assert_eq!(q.pop(), 2);
        assert_eq!(q.pop(), 3);
        assert_eq!(q.try_pop(), None);
    }

    #[test]
    fn push_blocks_until_space_is_freed() {
        let q = Arc::new(MpmcQueue::<u32>::new(1));
        q.push(1);
        let producer = {
            let q = Arc::clone(&q);
            thread::spawn(move || q.push(2))
        };
        // The producer can only finish after this pop frees the single slot.
        assert_eq!(q.pop(), 1);
        producer.join().unwrap();
        assert_eq!(q.pop(), 2);
        assert_eq!(q.try_pop(), None);
    }
}