use crate::error::{IgtlError, Result};
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use tracing::{debug, info, trace, warn};
/// Configuration for a [`MessageQueue`]: an optional capacity bound and
/// the policy applied when the queue is full.
#[derive(Debug, Clone)]
pub struct QueueConfig {
    /// Maximum number of queued messages; `None` means unbounded.
    pub capacity: Option<usize>,
    /// When the queue is full: `true` evicts the oldest message to make
    /// room; `false` rejects the new enqueue with an error.
    pub drop_on_full: bool,
}
impl Default for QueueConfig {
    /// Default configuration: bounded to 1000 messages, rejecting new
    /// enqueues when full (same as [`QueueConfig::bounded`] with 1000).
    fn default() -> Self {
        Self::bounded(1000)
    }
}
impl QueueConfig {
    /// A queue with no capacity limit.
    pub fn unbounded() -> Self {
        Self { capacity: None, drop_on_full: false }
    }

    /// A queue holding at most `capacity` messages that rejects new
    /// enqueues once full.
    pub fn bounded(capacity: usize) -> Self {
        Self { capacity: Some(capacity), drop_on_full: false }
    }

    /// A queue holding at most `capacity` messages that evicts the
    /// oldest message to make room when full.
    pub fn bounded_drop_old(capacity: usize) -> Self {
        Self { capacity: Some(capacity), drop_on_full: true }
    }
}
/// An async FIFO queue of byte messages built on an unbounded tokio
/// mpsc channel.
///
/// Any capacity limit from [`QueueConfig`] is enforced in software via
/// the `current_size` counter in `stats`, not by the channel itself.
pub struct MessageQueue {
    /// Sender half of the underlying (always unbounded) channel.
    tx: mpsc::UnboundedSender<Vec<u8>>,
    /// Receiver half, behind a mutex so multiple tasks may dequeue.
    rx: Arc<Mutex<mpsc::UnboundedReceiver<Vec<u8>>>>,
    /// Capacity / overflow policy captured at construction.
    config: QueueConfig,
    /// Counters updated on every enqueue, dequeue, and drop.
    stats: Arc<Mutex<QueueStats>>,
}
/// Running counters for a [`MessageQueue`], snapshot via `stats()`.
#[derive(Debug, Clone, Default)]
pub struct QueueStats {
    /// Total messages successfully enqueued.
    pub enqueued: u64,
    /// Total messages removed via `dequeue`/`try_dequeue`.
    pub dequeued: u64,
    /// Total messages evicted because the queue was full.
    pub dropped: u64,
    /// Messages currently in the queue.
    pub current_size: usize,
    /// High-water mark of `current_size`.
    pub peak_size: usize,
}
impl MessageQueue {
    /// Creates a queue with the default configuration
    /// (bounded to 1000 messages, rejecting enqueues when full).
    pub fn new() -> Self {
        Self::with_config(QueueConfig::default())
    }

    /// Creates a queue with the given configuration.
    pub fn with_config(config: QueueConfig) -> Self {
        info!(
            capacity = ?config.capacity,
            drop_on_full = config.drop_on_full,
            "Creating message queue"
        );
        // The underlying channel is always unbounded; any configured
        // capacity is enforced in `enqueue` via the stats counter.
        let (tx, rx) = mpsc::unbounded_channel();
        Self {
            tx,
            rx: Arc::new(Mutex::new(rx)),
            config,
            stats: Arc::new(Mutex::new(QueueStats::default())),
        }
    }

    /// Appends a message to the back of the queue.
    ///
    /// When a capacity limit is configured and reached:
    /// - `drop_on_full = true`: the oldest message is evicted to make
    ///   room and counted in `stats.dropped`;
    /// - `drop_on_full = false`: the enqueue is rejected.
    ///
    /// # Errors
    /// Returns a `WouldBlock` I/O error if the queue is full (and not
    /// dropping), or a `BrokenPipe` I/O error if the channel is closed.
    pub async fn enqueue(&self, data: Vec<u8>) -> Result<()> {
        let mut stats = self.stats.lock().await;
        if let Some(capacity) = self.config.capacity {
            if stats.current_size >= capacity {
                if self.config.drop_on_full {
                    warn!(
                        capacity = capacity,
                        current_size = stats.current_size,
                        "Queue full, dropping oldest message"
                    );
                    // Release the stats lock: `try_dequeue` acquires it
                    // internally and would deadlock otherwise.
                    drop(stats);
                    if self.try_dequeue().await.is_ok() {
                        stats = self.stats.lock().await;
                        // BUGFIX: `try_dequeue` counted the evicted
                        // message as a consumer dequeue; reclassify it as
                        // a drop so `dequeued` only reflects messages
                        // actually delivered to consumers.
                        stats.dequeued = stats.dequeued.saturating_sub(1);
                        stats.dropped += 1;
                    } else {
                        return Err(IgtlError::Io(std::io::Error::new(
                            std::io::ErrorKind::WouldBlock,
                            "Queue full and cannot drop oldest",
                        )));
                    }
                } else {
                    debug!(
                        capacity = capacity,
                        current_size = stats.current_size,
                        "Queue full, rejecting enqueue"
                    );
                    return Err(IgtlError::Io(std::io::Error::new(
                        std::io::ErrorKind::WouldBlock,
                        "Queue full",
                    )));
                }
            }
        }
        let size = data.len();
        self.tx.send(data).map_err(|_| {
            warn!("Failed to enqueue: queue closed");
            IgtlError::Io(std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                "Queue closed",
            ))
        })?;
        stats.enqueued += 1;
        stats.current_size += 1;
        if stats.current_size > stats.peak_size {
            stats.peak_size = stats.current_size;
        }
        trace!(
            size = size,
            queue_size = stats.current_size,
            "Message enqueued"
        );
        Ok(())
    }

    /// Removes and returns the front message, waiting while the queue
    /// is empty. Concurrent dequeuers serialize on the receiver mutex
    /// (a tokio mutex, so holding it across the await is safe).
    ///
    /// # Errors
    /// Returns a `BrokenPipe` I/O error once the channel is closed and
    /// fully drained.
    pub async fn dequeue(&self) -> Result<Vec<u8>> {
        let mut rx = self.rx.lock().await;
        match rx.recv().await {
            Some(data) => {
                let size = data.len();
                // Release the receiver before updating stats so other
                // dequeuers are not blocked on the bookkeeping.
                drop(rx);
                let mut stats = self.stats.lock().await;
                stats.dequeued += 1;
                stats.current_size = stats.current_size.saturating_sub(1);
                trace!(
                    size = size,
                    queue_size = stats.current_size,
                    "Message dequeued"
                );
                Ok(data)
            }
            None => {
                warn!("Dequeue failed: queue closed");
                Err(IgtlError::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    "Queue closed",
                )))
            }
        }
    }

    /// Non-blocking dequeue.
    ///
    /// # Errors
    /// `WouldBlock` if the queue is empty, `BrokenPipe` if the channel
    /// has been closed.
    pub async fn try_dequeue(&self) -> Result<Vec<u8>> {
        let mut rx = self.rx.lock().await;
        match rx.try_recv() {
            Ok(data) => {
                drop(rx);
                let mut stats = self.stats.lock().await;
                stats.dequeued += 1;
                stats.current_size = stats.current_size.saturating_sub(1);
                Ok(data)
            }
            Err(mpsc::error::TryRecvError::Empty) => Err(IgtlError::Io(std::io::Error::new(
                std::io::ErrorKind::WouldBlock,
                "Queue empty",
            ))),
            Err(mpsc::error::TryRecvError::Disconnected) => Err(IgtlError::Io(
                std::io::Error::new(std::io::ErrorKind::BrokenPipe, "Queue closed"),
            )),
        }
    }

    /// Current number of queued messages (tracked in stats).
    pub async fn size(&self) -> usize {
        self.stats.lock().await.current_size
    }

    /// Snapshot of the queue counters.
    pub async fn stats(&self) -> QueueStats {
        self.stats.lock().await.clone()
    }

    /// Whether the queue currently holds no messages.
    pub async fn is_empty(&self) -> bool {
        self.stats.lock().await.current_size == 0
    }

    /// The configuration this queue was built with.
    pub fn config(&self) -> &QueueConfig {
        &self.config
    }
}
impl Default for MessageQueue {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_unbounded_queue() {
        let q = MessageQueue::with_config(QueueConfig::unbounded());
        // Fill with 100 one-byte payloads.
        for b in 0u8..100 {
            q.enqueue(vec![b]).await.unwrap();
        }
        assert_eq!(q.size().await, 100);
        // Drain and verify strict FIFO ordering.
        for b in 0u8..100 {
            assert_eq!(q.dequeue().await.unwrap(), vec![b]);
        }
        assert!(q.is_empty().await);
    }

    #[tokio::test]
    async fn test_bounded_queue() {
        let q = MessageQueue::with_config(QueueConfig::bounded(10));
        for b in 0u8..10 {
            q.enqueue(vec![b]).await.unwrap();
        }
        // At capacity: the next enqueue must be rejected.
        assert!(q.enqueue(vec![100]).await.is_err());
        // Freeing one slot makes enqueue succeed again.
        q.dequeue().await.unwrap();
        assert!(q.enqueue(vec![100]).await.is_ok());
    }

    #[tokio::test]
    async fn test_bounded_drop_old() {
        let q = MessageQueue::with_config(QueueConfig::bounded_drop_old(5));
        // Ten enqueues into a five-slot queue: the first five get evicted.
        for b in 0u8..10 {
            q.enqueue(vec![b]).await.unwrap();
        }
        assert_eq!(q.size().await, 5);
        // Oldest surviving message is 5.
        assert_eq!(q.dequeue().await.unwrap(), vec![5]);
        let stats = q.stats().await;
        assert_eq!(stats.enqueued, 10);
        assert_eq!(stats.dropped, 5);
    }

    #[tokio::test]
    async fn test_try_dequeue_empty() {
        let q = MessageQueue::new();
        assert!(q.try_dequeue().await.is_err());
    }

    #[tokio::test]
    async fn test_queue_stats() {
        let q = MessageQueue::new();
        for b in 0u8..10 {
            q.enqueue(vec![b]).await.unwrap();
        }
        for _ in 0..5 {
            q.dequeue().await.unwrap();
        }
        let stats = q.stats().await;
        assert_eq!(stats.enqueued, 10);
        assert_eq!(stats.dequeued, 5);
        assert_eq!(stats.current_size, 5);
        assert_eq!(stats.peak_size, 10);
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        let q = Arc::new(MessageQueue::with_config(QueueConfig::bounded(100)));

        let producer_q = q.clone();
        let producer = tokio::spawn(async move {
            for b in 0u8..50 {
                producer_q.enqueue(vec![b]).await.unwrap();
                tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
            }
        });

        let consumer_q = q.clone();
        let consumer = tokio::spawn(async move {
            for _ in 0..50 {
                consumer_q.dequeue().await.unwrap();
                tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
            }
        });

        producer.await.unwrap();
        consumer.await.unwrap();
        assert!(q.is_empty().await);
    }
}