use async_trait::async_trait;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use tokio::sync::{Mutex, Notify};
use crate::broker::Broker;
use crate::error::TaskResult;
use crate::message::TaskMessage;
use crate::task_id::TaskId;
#[derive(Clone)]
pub struct MemoryBroker {
inner: Arc<MemoryBrokerInner>,
}
struct MemoryBrokerInner {
queues: Mutex<HashMap<String, VecDeque<TaskMessage>>>,
dlq: Mutex<HashMap<String, VecDeque<TaskMessage>>>,
processing: Mutex<HashMap<TaskId, TaskMessage>>,
notify: Notify,
}
impl MemoryBroker {
pub fn new() -> Self {
Self {
inner: Arc::new(MemoryBrokerInner {
queues: Mutex::new(HashMap::new()),
dlq: Mutex::new(HashMap::new()),
processing: Mutex::new(HashMap::new()),
notify: Notify::new(),
}),
}
}
}
impl Default for MemoryBroker {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Broker for MemoryBroker {
async fn enqueue(&self, message: TaskMessage) -> TaskResult<()> {
let mut queues = self.inner.queues.lock().await;
queues
.entry(message.queue.clone())
.or_default()
.push_back(message);
self.inner.notify.notify_one();
Ok(())
}
async fn dequeue(
&self,
queues: &[String],
timeout: std::time::Duration,
) -> TaskResult<Option<TaskMessage>> {
let deadline = tokio::time::Instant::now() + timeout;
loop {
{
let mut q = self.inner.queues.lock().await;
for queue_name in queues {
if let Some(queue) = q.get_mut(queue_name) {
if let Some(msg) = queue.pop_front() {
self.inner
.processing
.lock()
.await
.insert(msg.id, msg.clone());
return Ok(Some(msg));
}
}
}
}
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
if remaining.is_zero() {
return Ok(None);
}
tokio::select! {
_ = self.inner.notify.notified() => continue,
_ = tokio::time::sleep(remaining) => return Ok(None),
}
}
}
async fn ack(&self, id: &TaskId) -> TaskResult<()> {
self.inner.processing.lock().await.remove(id);
Ok(())
}
async fn nack(&self, message: TaskMessage) -> TaskResult<()> {
self.inner.processing.lock().await.remove(&message.id);
self.enqueue(message).await
}
async fn dead_letter(&self, message: TaskMessage) -> TaskResult<()> {
self.inner.processing.lock().await.remove(&message.id);
let dlq_name = message.queue.clone();
let mut dlq = self.inner.dlq.lock().await;
dlq.entry(dlq_name).or_default().push_back(message);
Ok(())
}
async fn schedule(
&self,
message: TaskMessage,
_eta: chrono::DateTime<chrono::Utc>,
) -> TaskResult<()> {
self.enqueue(message).await
}
async fn queue_len(&self, queue: &str) -> TaskResult<usize> {
let queues = self.inner.queues.lock().await;
Ok(queues.get(queue).map_or(0, |q| q.len()))
}
async fn dlq_len(&self, queue: &str) -> TaskResult<usize> {
let dlq = self.inner.dlq.lock().await;
Ok(dlq.get(queue).map_or(0, |q| q.len()))
}
async fn list_queues(&self) -> TaskResult<Vec<String>> {
let queues = self.inner.queues.lock().await;
Ok(queues.keys().cloned().collect())
}
async fn dlq_messages(
&self,
queue: &str,
offset: usize,
limit: usize,
) -> TaskResult<Vec<TaskMessage>> {
let dlq = self.inner.dlq.lock().await;
Ok(dlq
.get(queue)
.map(|q| q.iter().skip(offset).take(limit).cloned().collect())
.unwrap_or_default())
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
#[tokio::test]
async fn enqueue_dequeue_fifo() {
let broker = MemoryBroker::new();
let msg1 = TaskMessage::new("task1", "default", serde_json::json!(1));
let msg2 = TaskMessage::new("task2", "default", serde_json::json!(2));
broker.enqueue(msg1.clone()).await.unwrap();
broker.enqueue(msg2.clone()).await.unwrap();
let queues = vec!["default".to_string()];
let out1 = broker
.dequeue(&queues, Duration::from_secs(1))
.await
.unwrap()
.unwrap();
let out2 = broker
.dequeue(&queues, Duration::from_secs(1))
.await
.unwrap()
.unwrap();
assert_eq!(out1.task_name, "task1");
assert_eq!(out2.task_name, "task2");
}
#[tokio::test]
async fn dequeue_timeout_returns_none() {
let broker = MemoryBroker::new();
let queues = vec!["default".to_string()];
let result = broker
.dequeue(&queues, Duration::from_millis(50))
.await
.unwrap();
assert!(result.is_none());
}
#[tokio::test]
async fn ack_removes_from_processing() {
let broker = MemoryBroker::new();
let msg = TaskMessage::new("task1", "default", serde_json::json!(1));
let id = msg.id;
broker.enqueue(msg).await.unwrap();
let queues = vec!["default".to_string()];
let _out = broker
.dequeue(&queues, Duration::from_secs(1))
.await
.unwrap()
.unwrap();
assert!(broker.inner.processing.lock().await.contains_key(&id));
broker.ack(&id).await.unwrap();
assert!(!broker.inner.processing.lock().await.contains_key(&id));
}
#[tokio::test]
async fn nack_requeues() {
let broker = MemoryBroker::new();
let msg = TaskMessage::new("task1", "default", serde_json::json!(1));
broker.enqueue(msg).await.unwrap();
let queues = vec!["default".to_string()];
let out = broker
.dequeue(&queues, Duration::from_secs(1))
.await
.unwrap()
.unwrap();
broker.nack(out).await.unwrap();
assert_eq!(broker.queue_len("default").await.unwrap(), 1);
}
#[tokio::test]
async fn dead_letter() {
let broker = MemoryBroker::new();
let msg = TaskMessage::new("task1", "default", serde_json::json!(1));
broker.enqueue(msg).await.unwrap();
let queues = vec!["default".to_string()];
let out = broker
.dequeue(&queues, Duration::from_secs(1))
.await
.unwrap()
.unwrap();
broker.dead_letter(out).await.unwrap();
assert_eq!(broker.queue_len("default").await.unwrap(), 0);
assert_eq!(broker.dlq_len("default").await.unwrap(), 1);
}
#[tokio::test]
async fn queue_len() {
let broker = MemoryBroker::new();
assert_eq!(broker.queue_len("default").await.unwrap(), 0);
broker
.enqueue(TaskMessage::new("t", "default", serde_json::json!(1)))
.await
.unwrap();
broker
.enqueue(TaskMessage::new("t", "default", serde_json::json!(2)))
.await
.unwrap();
assert_eq!(broker.queue_len("default").await.unwrap(), 2);
}
#[tokio::test]
async fn list_queues_returns_known_queues() {
let broker = MemoryBroker::new();
broker
.enqueue(TaskMessage::new("t", "emails", serde_json::json!(1)))
.await
.unwrap();
broker
.enqueue(TaskMessage::new("t", "notifications", serde_json::json!(2)))
.await
.unwrap();
let mut queues = broker.list_queues().await.unwrap();
queues.sort();
assert_eq!(queues, vec!["emails", "notifications"]);
}
#[tokio::test]
async fn dlq_len_via_trait() {
let broker = MemoryBroker::new();
let msg = TaskMessage::new("task1", "default", serde_json::json!(1));
broker.enqueue(msg).await.unwrap();
let queues = vec!["default".to_string()];
let out = broker
.dequeue(&queues, Duration::from_secs(1))
.await
.unwrap()
.unwrap();
broker.dead_letter(out).await.unwrap();
assert_eq!(broker.dlq_len("default").await.unwrap(), 1);
}
#[tokio::test]
async fn dlq_messages_pagination() {
let broker = MemoryBroker::new();
let queues = vec!["default".to_string()];
for i in 0..5 {
let msg = TaskMessage::new("task", "default", serde_json::json!(i));
broker.enqueue(msg).await.unwrap();
let out = broker
.dequeue(&queues, Duration::from_secs(1))
.await
.unwrap()
.unwrap();
broker.dead_letter(out).await.unwrap();
}
let page1 = broker.dlq_messages("default", 0, 3).await.unwrap();
assert_eq!(page1.len(), 3);
let page2 = broker.dlq_messages("default", 3, 3).await.unwrap();
assert_eq!(page2.len(), 2);
}
#[tokio::test]
async fn dlq_messages_empty() {
let broker = MemoryBroker::new();
let messages = broker.dlq_messages("nonexistent", 0, 10).await.unwrap();
assert!(messages.is_empty());
}
}