use std::time::Duration;
use thiserror::Error;
#[cfg(feature = "sync")]
pub use sync::BatchedQueue;
#[derive(Error, Debug, Clone)]
pub enum BatchedQueueError {
#[error("Channel is full (backpressure limit reached)")]
ChannelFull,
#[error("Channel is disconnected (all receivers dropped)")]
Disconnected,
#[error("Operation timed out after {0:?}")]
Timeout(Duration),
#[error("Queue capacity exceeded: tried to add more than {max_capacity} items")]
CapacityExceeded { max_capacity: usize },
#[error("Invalid batch size: {0}")]
InvalidBatchSize(usize),
#[error("Failed to send batch: {0}")]
SendError(String),
#[error("Failed to receive batch: {0}")]
ReceiveError(String),
}
/// Descriptive context that can accompany an error report: which operation was running
/// and a human-readable description of the queue involved.
///
/// NOTE(review): not constructed by the code visible in this file; confirm its callers.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ErrorContext {
    /// Name of the operation that was being performed (free-form text).
    pub operation: String,
    /// Free-form description of the queue (e.g. its configuration).
    pub queue_info: String,
}
impl BatchedQueueError {
pub fn timeout(duration: Duration) -> Self {
BatchedQueueError::Timeout(duration)
}
pub fn capacity_exceeded(max_capacity: usize) -> Self {
BatchedQueueError::CapacityExceeded { max_capacity }
}
}
pub trait BatchedQueueTrait<T> {
fn len(&self) -> usize;
fn capacity(&self) -> usize;
fn is_empty(&self) -> bool;
fn push(&self, item: T) -> Result<(), BatchedQueueError>;
fn try_next_batch(&self) -> Result<Option<Vec<T>>, BatchedQueueError>;
fn next_batch(&self) -> Result<Vec<T>, BatchedQueueError>;
fn next_batch_timeout(&self, timeout: std::time::Duration)
-> Result<Vec<T>, BatchedQueueError>;
fn flush(&self) -> Result<(), BatchedQueueError>;
}
#[cfg(feature = "sync")]
pub mod sync {
use super::*;
use crossbeam_channel as channel;
use parking_lot::Mutex;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Multi-producer queue that accumulates pushed items into batches of `batch_size` and
/// delivers completed batches through a crossbeam channel.
///
/// The in-progress batch and the push counter sit behind `Arc`s so that every
/// `BatchedQueueSender` created with `create_sender` shares them with this queue.
#[derive(Debug)]
pub struct BatchedQueue<T> {
    /// Number of items that makes a batch full; the constructors reject zero.
    batch_size: usize,
    /// Items pushed since the last batch was taken out; shared with every sender.
    current_batch: Arc<Mutex<Vec<T>>>,
    /// Consumer end of the batch channel.
    batch_receiver: channel::Receiver<Vec<T>>,
    /// Producer end of the batch channel; cloned into each sender.
    batch_sender: channel::Sender<Vec<T>>,
    /// Counter incremented on pushes; shared with every sender.
    item_count: Arc<AtomicUsize>,
}
impl<T: Send + 'static> BatchedQueue<T> {
    /// Creates a queue that emits batches of `batch_size` items over an unbounded channel,
    /// so completed batches never wait for buffer space.
    ///
    /// # Errors
    /// [`BatchedQueueError::InvalidBatchSize`] if `batch_size` is zero.
    pub fn new(batch_size: usize) -> Result<Self, BatchedQueueError> {
        Self::check_nonzero(batch_size)?;
        Ok(Self::assemble(batch_size, channel::unbounded()))
    }

    /// Creates a queue whose channel buffers at most `max_batches` completed batches.
    ///
    /// While that many batches are waiting, blocking sends wait for a consumer, and the
    /// non-blocking sender operations report [`BatchedQueueError::ChannelFull`].
    ///
    /// # Errors
    /// [`BatchedQueueError::InvalidBatchSize`] if `batch_size` is zero (checked first) or if
    /// `max_batches` is zero.
    pub fn new_bounded(
        batch_size: usize,
        max_batches: usize,
    ) -> Result<Self, BatchedQueueError> {
        Self::check_nonzero(batch_size)?;
        Self::check_nonzero(max_batches)?;
        Ok(Self::assemble(batch_size, channel::bounded(max_batches)))
    }

    /// Rejects a zero size with `InvalidBatchSize` carrying the offending value.
    fn check_nonzero(size: usize) -> Result<(), BatchedQueueError> {
        match size {
            0 => Err(BatchedQueueError::InvalidBatchSize(size)),
            _ => Ok(()),
        }
    }

    /// Wires a fresh, pre-sized batch buffer and a zeroed counter to the given channel halves.
    fn assemble(
        batch_size: usize,
        (batch_sender, batch_receiver): (channel::Sender<Vec<T>>, channel::Receiver<Vec<T>>),
    ) -> Self {
        Self {
            batch_size,
            current_batch: Arc::new(Mutex::new(Vec::with_capacity(batch_size))),
            batch_receiver,
            batch_sender,
            item_count: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Returns a producer handle sharing this queue's in-progress batch, channel, and counter.
    pub fn create_sender(&self) -> BatchedQueueSender<T> {
        BatchedQueueSender {
            batch_size: self.batch_size,
            current_batch: Arc::clone(&self.current_batch),
            batch_sender: self.batch_sender.clone(),
            item_count: Arc::clone(&self.item_count),
        }
    }

    /// Removes and returns the items of the partially filled batch without sending them.
    ///
    /// Despite the name, nothing is closed: the channel and all senders remain usable, and
    /// the push counter is left unchanged.
    pub fn close_queue(&self) -> Vec<T> {
        std::mem::take(&mut *self.current_batch.lock())
    }
}
impl<T: Send + 'static> BatchedQueueTrait<T> for BatchedQueue<T> {
    /// Appends `item` to the shared in-progress batch and sends the batch once it holds
    /// `batch_size` items.
    ///
    /// A batch is full when its own length reaches `batch_size`. Using the lifetime push
    /// counter instead would misplace every boundary after a partial `flush` or
    /// `close_queue`, producing short batches.
    ///
    /// The full batch is swapped out under the lock, but the send happens after the guard
    /// is dropped. This is the same approach as `BatchedQueueSender::push`. On a bounded
    /// channel the send can block. Holding the lock meanwhile would stall other producers
    /// and `is_empty()`, and would deadlock a consumer thread that calls `is_empty()`.
    /// Batches completed concurrently may therefore be enqueued in either order.
    ///
    /// # Errors
    /// [`BatchedQueueError::Disconnected`] if the send fails; the batch is dropped.
    fn push(&self, item: T) -> Result<(), BatchedQueueError> {
        let full_batch = {
            let mut batch = self.current_batch.lock();
            batch.push(item);
            self.item_count.fetch_add(1, Ordering::SeqCst);
            if batch.len() >= self.batch_size {
                Some(std::mem::replace(
                    &mut *batch,
                    Vec::with_capacity(self.batch_size),
                ))
            } else {
                None
            }
        };
        if let Some(batch) = full_batch {
            self.batch_sender
                .send(batch)
                .map_err(|_| BatchedQueueError::Disconnected)?;
        }
        Ok(())
    }

    /// Returns the next completed batch without waiting, or `Ok(None)` if none is queued.
    ///
    /// # Errors
    /// [`BatchedQueueError::Disconnected`] if the channel reports disconnection.
    fn try_next_batch(&self) -> Result<Option<Vec<T>>, BatchedQueueError> {
        match self.batch_receiver.try_recv() {
            Ok(batch) => Ok(Some(batch)),
            Err(channel::TryRecvError::Empty) => Ok(None),
            Err(channel::TryRecvError::Disconnected) => Err(BatchedQueueError::Disconnected),
        }
    }

    /// Blocks until a completed batch is available.
    ///
    /// This queue itself owns a sender for its channel, so with no producer activity this
    /// waits indefinitely rather than reporting disconnection.
    fn next_batch(&self) -> Result<Vec<T>, BatchedQueueError> {
        self.batch_receiver
            .recv()
            .map_err(|_| BatchedQueueError::Disconnected)
    }

    /// Waits up to `timeout` for a completed batch.
    ///
    /// # Errors
    /// [`BatchedQueueError::Timeout`] carrying `timeout` if nothing arrives in time, or
    /// [`BatchedQueueError::Disconnected`] if the channel reports disconnection.
    fn next_batch_timeout(
        &self,
        timeout: std::time::Duration,
    ) -> Result<Vec<T>, BatchedQueueError> {
        match self.batch_receiver.recv_timeout(timeout) {
            Ok(batch) => Ok(batch),
            Err(channel::RecvTimeoutError::Timeout) => Err(BatchedQueueError::Timeout(timeout)),
            Err(channel::RecvTimeoutError::Disconnected) => {
                Err(BatchedQueueError::Disconnected)
            }
        }
    }

    /// Returns the shared push counter.
    ///
    /// Every push through this queue or its senders increments it, and receiving batches
    /// never decreases it. It is therefore a running total, not the number of items
    /// currently queued; use `is_empty` for that.
    fn len(&self) -> usize {
        self.item_count.load(Ordering::SeqCst)
    }

    /// Returns the configured batch size (items per full batch), not a storage limit.
    fn capacity(&self) -> usize {
        self.batch_size
    }

    /// Sends the partially filled batch, if any, to consumers.
    ///
    /// The batch is taken out under the lock and sent after the guard is dropped, for the
    /// same reason as in `push`.
    ///
    /// # Errors
    /// [`BatchedQueueError::Disconnected`] if the send fails; the batch is dropped.
    fn flush(&self) -> Result<(), BatchedQueueError> {
        let partial_batch = {
            let mut batch = self.current_batch.lock();
            if batch.is_empty() {
                return Ok(());
            }
            std::mem::replace(&mut *batch, Vec::with_capacity(self.batch_size))
        };
        self.batch_sender
            .send(partial_batch)
            .map_err(|_| BatchedQueueError::Disconnected)
    }

    /// True when no completed batch is waiting in the channel and the in-progress batch is
    /// empty.
    ///
    /// A batch that a producer has swapped out but not yet sent is visible in neither
    /// place, so this can briefly report `true` while such a send is in flight.
    fn is_empty(&self) -> bool {
        self.batch_receiver.is_empty() && self.current_batch.lock().is_empty()
    }
}
/// Cloneable producer handle for a `BatchedQueue`, obtained from `create_sender`.
///
/// All handles created from the same queue append to one shared in-progress batch and send
/// into one channel, so a batch may mix items pushed by different handles.
#[derive(Debug)]
pub struct BatchedQueueSender<T> {
    /// Number of items that makes a batch full (copied from the queue).
    batch_size: usize,
    /// The queue's in-progress batch, shared with the queue and all other senders.
    current_batch: Arc<Mutex<Vec<T>>>,
    /// Producer end of the queue's batch channel.
    batch_sender: channel::Sender<Vec<T>>,
    /// The queue's push counter, shared with the queue and all other senders.
    item_count: Arc<AtomicUsize>,
}
impl<T: Send + 'static> Clone for BatchedQueueSender<T> {
fn clone(&self) -> Self {
Self {
batch_size: self.batch_size,
current_batch: self.current_batch.clone(),
batch_sender: self.batch_sender.clone(),
item_count: self.item_count.clone(),
}
}
}
impl<T: Send + 'static> BatchedQueueSender<T> {
    /// Appends `item` to the shared batch under the lock and, if that fills the batch,
    /// swaps it out for a fresh pre-sized buffer and returns it for sending.
    ///
    /// Fullness is judged by the batch's own length, not by the shared push counter, so an
    /// earlier flush of a partial batch does not shorten later batches.
    fn append_and_take_full(&self, item: T) -> Option<Vec<T>> {
        let mut batch = self.current_batch.lock();
        batch.push(item);
        self.item_count.fetch_add(1, Ordering::SeqCst);
        if batch.len() >= self.batch_size {
            Some(std::mem::replace(
                &mut *batch,
                Vec::with_capacity(self.batch_size),
            ))
        } else {
            None
        }
    }

    /// Splits a failed `try_send` into the unsent batch and the matching queue error.
    fn reclaim(err: channel::TrySendError<Vec<T>>) -> (Vec<T>, BatchedQueueError) {
        match err {
            channel::TrySendError::Full(batch) => (batch, BatchedQueueError::ChannelFull),
            channel::TrySendError::Disconnected(batch) => {
                (batch, BatchedQueueError::Disconnected)
            }
        }
    }

    /// Appends `item` to the shared batch, sending the batch once it holds `batch_size`
    /// items.
    ///
    /// The lock is released before sending, so waiting on a full bounded channel does not
    /// block other producers from appending.
    ///
    /// # Errors
    /// [`BatchedQueueError::Disconnected`] if the send fails; the batch is dropped.
    pub fn push(&self, item: T) -> Result<(), BatchedQueueError> {
        if let Some(full_batch) = self.append_and_take_full(item) {
            self.batch_sender
                .send(full_batch)
                .map_err(|_| BatchedQueueError::Disconnected)?;
        }
        Ok(())
    }

    /// Non-blocking push. If `item` completes a batch and the channel cannot take it, the
    /// call is undone.
    ///
    /// On error, `item` is **not** enqueued: it is dropped, because the error type cannot
    /// return it. The in-progress batch and the push counter are left exactly as before the
    /// call, so a caller may retry with an equal item without creating duplicates.
    ///
    /// # Errors
    /// [`BatchedQueueError::ChannelFull`] if the bounded channel is full, or
    /// [`BatchedQueueError::Disconnected`] if it has no receiver left.
    pub fn try_push(&self, item: T) -> Result<(), BatchedQueueError> {
        // The lock is held for the whole call. `try_send` never blocks, and holding the
        // lock guarantees that restoring a rejected batch cannot overwrite items other
        // producers appended in the meantime.
        let mut batch = self.current_batch.lock();
        batch.push(item);
        if batch.len() < self.batch_size {
            self.item_count.fetch_add(1, Ordering::SeqCst);
            return Ok(());
        }
        let full_batch = std::mem::replace(&mut *batch, Vec::with_capacity(self.batch_size));
        match self.batch_sender.try_send(full_batch) {
            Ok(()) => {
                self.item_count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
            Err(err) => {
                let (mut rejected, error) = Self::reclaim(err);
                // `item` was appended last under this same lock, so it is the last element.
                let _rejected_item = rejected.pop();
                *batch = rejected;
                Err(error)
            }
        }
    }

    /// Sends the partially filled batch, if any, blocking while a bounded channel is full.
    ///
    /// The batch is swapped out under the lock and sent after the guard is dropped, so a
    /// blocked send does not stall other producers.
    ///
    /// # Errors
    /// [`BatchedQueueError::Disconnected`] if the send fails; the batch is dropped.
    pub fn flush(&self) -> Result<(), BatchedQueueError> {
        let partial_batch = {
            let mut batch = self.current_batch.lock();
            if batch.is_empty() {
                return Ok(());
            }
            std::mem::replace(&mut *batch, Vec::with_capacity(self.batch_size))
        };
        self.batch_sender
            .send(partial_batch)
            .map_err(|_| BatchedQueueError::Disconnected)
    }

    /// Non-blocking flush of the partially filled batch, if any.
    ///
    /// If the channel rejects the batch, its items are put back into the in-progress batch
    /// rather than dropped.
    ///
    /// # Errors
    /// [`BatchedQueueError::ChannelFull`] if the bounded channel is full, or
    /// [`BatchedQueueError::Disconnected`] if it has no receiver left.
    pub fn try_flush(&self) -> Result<(), BatchedQueueError> {
        // Held throughout so the rejected batch can be restored without racing other pushes.
        let mut batch = self.current_batch.lock();
        if batch.is_empty() {
            return Ok(());
        }
        let partial_batch = std::mem::replace(&mut *batch, Vec::with_capacity(self.batch_size));
        match self.batch_sender.try_send(partial_batch) {
            Ok(()) => Ok(()),
            Err(err) => {
                let (rejected, error) = Self::reclaim(err);
                *batch = rejected;
                Err(error)
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
use std::time::Duration;
#[test]
fn multithreaded() {
    // Two producers share one in-progress batch through separate sender handles while a
    // third thread consumes. Each of the 100 items must arrive exactly once.
    let queue = BatchedQueue::<i32>::new(10).expect("Failed to create queue");
    let sender1 = queue.create_sender();
    let sender2 = queue.create_sender();
    let t1 = thread::spawn(move || {
        for i in 0..50 {
            sender1.push(i).expect("Failed to push item");
            thread::sleep(Duration::from_millis(1));
        }
        sender1.flush().expect("Failed to flush");
    });
    let t2 = thread::spawn(move || {
        for i in 100..150 {
            sender2.push(i).expect("Failed to push item");
            thread::sleep(Duration::from_millis(1));
        }
        sender2.flush().expect("Failed to flush");
    });
    // Consume until all 100 items have been seen, instead of for a fixed polling window.
    // A fixed window can close before slowly scheduled producers have flushed. The generous
    // timeout only turns a genuine hang into a failure.
    let t3 = thread::spawn(move || {
        let mut all_items = Vec::new();
        while all_items.len() < 100 {
            let batch = queue
                .next_batch_timeout(Duration::from_secs(10))
                .expect("Failed to get batch");
            all_items.extend(batch);
        }
        all_items
    });
    t1.join().unwrap();
    t2.join().unwrap();
    let result = t3.join().unwrap();
    assert_eq!(result.len(), 100);
    // Comparing against the exact expected multiset also rules out duplicates.
    let mut result_sorted = result;
    result_sorted.sort_unstable();
    let expected: Vec<i32> = (0..50).chain(100..150).collect();
    assert_eq!(result_sorted, expected);
}
#[test]
fn timeout() {
    // A partial batch stays invisible to consumers until it is flushed.
    let queue = BatchedQueue::<i32>::new(5).expect("Failed to create queue");
    let producer = queue.create_sender();
    (1..4).for_each(|value| producer.push(value).unwrap());
    assert!(queue.next_batch_timeout(Duration::from_millis(10)).is_err());
    producer.flush().unwrap();
    let flushed = queue.next_batch_timeout(Duration::from_millis(10)).unwrap();
    assert_eq!(flushed, vec![1, 2, 3]);
}
#[test]
fn bounded_channel() {
    // Two batch slots and twenty items: the producer has to wait on the bounded channel
    // until the consumer makes room, and nothing may be lost.
    let queue = BatchedQueue::new_bounded(5, 2).expect("Failed to create queue");
    let producer = queue.create_sender();
    let handle = thread::spawn(move || {
        let mut pushed = 0;
        for value in 0..20 {
            producer.push(value).expect("Failed to push item");
            pushed += 1;
            // Pause after every fifth item, i.e. after each completed batch.
            if value % 5 == 4 {
                thread::sleep(Duration::from_millis(5));
            }
        }
        producer.flush().expect("Failed to flush");
        pushed
    });
    let mut received = Vec::new();
    let mut batch_count = 0;
    while batch_count < 4 {
        if let Some(batch) = queue.try_next_batch().expect("Failed to get batch") {
            batch_count += 1;
            received.extend(batch);
        }
        thread::sleep(Duration::from_millis(5));
    }
    let pushed = handle.join().unwrap();
    // Drain anything left over (e.g. a flushed partial batch).
    while let Some(batch) = queue.try_next_batch().expect("Failed to get batch") {
        received.extend(batch);
    }
    assert_eq!(received.len(), 20);
    assert_eq!(pushed, 20);
    received.sort();
    assert!((0..20).all(|value| received.contains(&value)));
}
#[test]
fn backpressure() {
    // One batch slot: the first full batch occupies the channel while the remaining items
    // accumulate in the in-progress batch.
    let queue = BatchedQueue::new_bounded(5, 1).expect("Failed to create queue");
    let producer = queue.create_sender();
    for value in 0..8 {
        producer.push(value).expect("Failed to push item");
    }
    let first = queue.next_batch().expect("Failed to get batch");
    assert_eq!(first, vec![0, 1, 2, 3, 4]);
    // With the slot free again, the non-blocking flush can hand over the partial batch.
    assert!(producer.try_flush().is_ok());
    let remainder = queue
        .next_batch_timeout(Duration::from_millis(50))
        .expect("Failed to get batch");
    assert_eq!(remainder, vec![5, 6, 7]);
}
#[test]
fn error_handling() {
    // A zero batch size is rejected up front.
    assert!(matches!(
        BatchedQueue::<i32>::new(0),
        Err(BatchedQueueError::InvalidBatchSize(0))
    ));
    // With a single batch slot, fill the channel with one batch and build up a partial one.
    let limited_queue = BatchedQueue::new_bounded(5, 1).expect("Failed to create queue");
    let limited_sender = limited_queue.create_sender();
    for value in 0..5 {
        limited_sender
            .push(value)
            .expect("Should succeed for first batch");
    }
    for value in 5..9 {
        limited_sender
            .push(value)
            .expect("Should succeed as we're building a partial batch");
    }
    // Completing the second batch needs a free slot, so the non-blocking push reports
    // backpressure.
    assert!(matches!(
        limited_sender.try_push(9),
        Err(BatchedQueueError::ChannelFull)
    ));
    limited_queue
        .next_batch()
        .expect("Should get the first batch");
    // The second batch was never sent, so nothing else is waiting in the channel.
    assert!(matches!(
        limited_queue.next_batch_timeout(Duration::from_millis(1)),
        Err(BatchedQueueError::Timeout(_))
    ));
}
// High-contention stress test: many producer threads share one in-progress batch through
// sender handles while a few consumer threads drain a bounded channel. It verifies that
// every produced item is consumed; the timing figures are printed only, never asserted.
#[cfg(test)]
mod stress_tests {
use super::*;
use std::collections::HashSet;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Barrier};
use std::thread;
use std::time::{Duration, Instant};
#[test]
fn batched_queue() {
const BATCH_SIZE: usize = 100;
const CHANNEL_CAPACITY: usize = 10;
const PRODUCER_COUNT: usize = 64;
const ITEMS_PER_PRODUCER: usize = 10_000;
const CONSUMER_COUNT: usize = 4;
let queue = Arc::new(
BatchedQueue::new_bounded(BATCH_SIZE, CHANNEL_CAPACITY)
.expect("Failed to create queue"),
);
// The extra party (+1) is this main thread, which releases all workers at once below.
let start_barrier = Arc::new(Barrier::new(PRODUCER_COUNT + CONSUMER_COUNT + 1));
let total_expected_items = PRODUCER_COUNT * ITEMS_PER_PRODUCER;
// Shared count of consumed items; consumers use it to decide when to stop.
let processed_items = Arc::new(AtomicUsize::new(0));
let all_produced_items = Arc::new(parking_lot::Mutex::new(HashSet::new()));
let all_consumed_items = Arc::new(parking_lot::Mutex::new(HashSet::new()));
let producer_times = Arc::new(parking_lot::Mutex::new(Vec::new()));
let consumer_times = Arc::new(parking_lot::Mutex::new(Vec::new()));
let producer_handles: Vec<_> = (0..PRODUCER_COUNT)
.map(|producer_id| {
let queue_sender = queue.create_sender();
let start = start_barrier.clone();
let produced = all_produced_items.clone();
let producer_timing = producer_times.clone();
thread::spawn(move || {
start.wait();
let start_time = Instant::now();
// Each producer emits a disjoint range of values, so every item is unique.
let offset = producer_id * ITEMS_PER_PRODUCER;
let mut local_produced = HashSet::new();
for i in 0..ITEMS_PER_PRODUCER {
let item = offset + i;
queue_sender.push(item).expect("Failed to push item");
local_produced.insert(item);
// Brief pause every 1000 items to vary interleaving between producers.
if i % 1000 == 0 {
thread::sleep(Duration::from_micros(10));
}
}
// Push out whatever remains in the shared in-progress batch.
queue_sender.flush().expect("Failed to flush");
// Merge into the global set once at the end to keep lock traffic out of the hot loop.
let mut global_produced = produced.lock();
for item in local_produced {
global_produced.insert(item);
}
let elapsed = start_time.elapsed();
producer_timing.lock().push(elapsed);
println!("Producer {}: Finished in {:?}", producer_id, elapsed);
})
})
.collect();
let consumer_handles: Vec<_> = (0..CONSUMER_COUNT)
.map(|consumer_id| {
let queue = queue.clone(); let start = start_barrier.clone();
let processed = processed_items.clone();
let consumed = all_consumed_items.clone();
let consumer_timing = consumer_times.clone();
thread::spawn(move || {
start.wait();
let start_time = Instant::now();
let mut local_consumed = HashSet::new();
let mut batches_processed = 0;
// Stop once the shared processed count reaches the expected total. It is checked
// after each received batch and after each 100 ms timeout, so idle consumers notice
// when another consumer took the last batch.
loop {
if let Ok(batch) =
queue.next_batch_timeout(Duration::from_millis(100))
{
batches_processed += 1;
let batch_size = batch.len();
for item in batch {
local_consumed.insert(item);
}
let current = processed.fetch_add(batch_size, Ordering::SeqCst);
if current + batch_size >= total_expected_items {
break;
}
} else if processed.load(Ordering::SeqCst) >= total_expected_items {
break;
}
if processed.load(Ordering::SeqCst) >= total_expected_items {
break;
}
}
let mut global_consumed = consumed.lock();
for item in local_consumed {
global_consumed.insert(item);
}
let elapsed = start_time.elapsed();
consumer_timing.lock().push(elapsed);
println!(
"Consumer {}: Processed {} batches in {:?}",
consumer_id, batches_processed, elapsed
);
})
})
.collect();
println!(
"Starting stress test with {} producers and {} consumers",
PRODUCER_COUNT, CONSUMER_COUNT
);
println!(
"Each producer will generate {} items with batch size {}",
ITEMS_PER_PRODUCER, BATCH_SIZE
);
let overall_start = Instant::now();
// Release all producers and consumers together.
start_barrier.wait();
for handle in producer_handles {
handle.join().unwrap();
}
println!("All producers finished");
for handle in consumer_handles {
handle.join().unwrap();
}
let overall_elapsed = overall_start.elapsed();
println!("All consumers finished");
println!("Overall test time: {:?}", overall_elapsed);
// All worker threads are joined, so these locks are uncontended.
let produced = all_produced_items.lock();
let consumed = all_consumed_items.lock();
println!("Items produced: {}", produced.len());
println!("Items consumed: {}", consumed.len());
assert_eq!(
produced.len(),
total_expected_items,
"Number of produced items doesn't match expected"
);
assert_eq!(
consumed.len(),
total_expected_items,
"Number of consumed items doesn't match expected"
);
for item in produced.iter() {
assert!(
consumed.contains(item),
"Item {} was produced but not consumed",
item
);
}
// Reporting only: averages are in whole milliseconds, and throughput divides by the
// elapsed milliseconds (a 0 ms run would print an infinite value).
let producer_times = producer_times.lock();
let consumer_times = consumer_times.lock();
let avg_producer_time = producer_times.iter().map(|d| d.as_millis()).sum::<u128>()
/ producer_times.len() as u128;
let avg_consumer_time = consumer_times.iter().map(|d| d.as_millis()).sum::<u128>()
/ consumer_times.len() as u128;
let throughput =
total_expected_items as f64 / (overall_elapsed.as_millis() as f64 / 1000.0);
println!("Average producer time: {}ms", avg_producer_time);
println!("Average consumer time: {}ms", avg_consumer_time);
println!("Throughput: {:.2} items/second", throughput);
}
}
}
}