use crate::ack::{AckHandle, AckSubscriber};
use crate::memory::{InMemoryBroker, InMemoryError};
use crate::Message;
use async_trait::async_trait;
use std::sync::{Arc, Weak};
use std::time::SystemTime;
use tokio::sync::{mpsc, Mutex};
use uuid::Uuid;
/// Acknowledgment handle handed out together with a message received from the
/// in-memory broker.
///
/// Each handle carries its own randomly generated `handle_id`, which is the key
/// the subscriber uses to track the delivery until it is acked or nacked.
#[derive(Clone, Debug)]
pub struct InMemoryAckHandleFixed {
    /// Identifier of the message this handle refers to.
    message_id: String,
    /// Topic the message was received from.
    topic: String,
    /// Timestamp supplied when the handle was created.
    timestamp: SystemTime,
    /// Delivery attempt number recorded for this handle.
    delivery_count: u32,
    /// Weak reference to the broker; stored at construction and not read by
    /// any code in this file.
    broker: Weak<InMemoryBroker>,
    /// Unique identifier of this particular handle (a UUID v4 string).
    handle_id: String,
}
impl InMemoryAckHandleFixed {
    /// Builds a handle for a single delivery of a message.
    ///
    /// All arguments are stored as given. In addition, a fresh UUID v4 is
    /// generated and stored as the handle's unique id.
    pub fn new(
        message_id: String,
        topic: String,
        timestamp: SystemTime,
        delivery_count: u32,
        broker: Weak<InMemoryBroker>,
    ) -> Self {
        // Every handle gets its own id, independent of the message id.
        let handle_id = Uuid::new_v4().to_string();

        Self {
            message_id,
            topic,
            timestamp,
            delivery_count,
            broker,
            handle_id,
        }
    }
}
/// Read-only accessors exposing the metadata captured when the handle was built.
impl AckHandle for InMemoryAckHandleFixed {
    /// Identifier of the message this handle refers to.
    fn message_id(&self) -> &str {
        self.message_id.as_str()
    }

    /// Topic the message was received from.
    fn topic(&self) -> &str {
        self.topic.as_str()
    }

    /// Timestamp stored in the handle at construction.
    fn timestamp(&self) -> SystemTime {
        self.timestamp
    }

    /// Delivery attempt number stored in the handle at construction.
    fn delivery_count(&self) -> u32 {
        self.delivery_count
    }
}
/// Subscriber for the in-memory broker that tracks received messages until
/// they are acknowledged or negatively acknowledged.
pub struct InMemoryAckSubscriberFixed {
    /// Broker this subscriber receives messages from.
    broker: Arc<InMemoryBroker>,
    /// Subscription and pending-ack bookkeeping, guarded by an async mutex so
    /// it can be updated from `&self` methods.
    state: Mutex<SubscriberState>,
}
/// Mutable bookkeeping owned by an [`InMemoryAckSubscriberFixed`].
struct SubscriberState {
    /// Topic of the current subscription, or `None` before `subscribe` succeeds.
    subscribed_topic: Option<String>,
    /// Channel end on which the broker delivers messages for the subscription.
    receiver: Option<mpsc::UnboundedReceiver<Message>>,
    /// Deliveries awaiting ack/nack, keyed by handle id. The value is a copy of
    /// the message together with its delivery count.
    pending_acks: std::collections::HashMap<String, (Message, u32)>,
}
impl InMemoryAckSubscriberFixed {
    /// Creates a subscriber attached to `broker` with no active subscription
    /// and no pending acknowledgments.
    pub fn new(broker: Arc<InMemoryBroker>) -> Self {
        let initial_state = SubscriberState {
            subscribed_topic: None,
            receiver: None,
            pending_acks: std::collections::HashMap::new(),
        };

        Self {
            broker,
            state: Mutex::new(initial_state),
        }
    }

    /// Returns `true` once a topic has been recorded by a successful `subscribe`.
    pub async fn is_subscribed(&self) -> bool {
        self.state.lock().await.subscribed_topic.is_some()
    }

    /// Returns a copy of the currently subscribed topic name, if any.
    pub async fn subscribed_topic(&self) -> Option<String> {
        self.state.lock().await.subscribed_topic.clone()
    }
}
#[async_trait]
impl AckSubscriber for InMemoryAckSubscriberFixed {
    type Error = InMemoryError;
    type AckHandle = InMemoryAckHandleFixed;

    /// Subscribes to `topic` on the broker and stores the resulting receiver.
    ///
    /// A later successful call replaces the stored receiver and topic; entries
    /// already present in the pending-ack map are left untouched.
    ///
    /// # Errors
    ///
    /// * `InMemoryError::BrokerShutdown` if the broker reports it is shut down.
    /// * The error built by `InMemoryError::invalid_topic_name` if `topic` is
    ///   empty or contains a NUL character.
    /// * Any error returned by the broker's own `subscribe`.
    async fn subscribe(&self, topic: &str) -> Result<(), Self::Error> {
        if self.broker.is_shutdown() {
            return Err(InMemoryError::BrokerShutdown);
        }
        if topic.is_empty() || topic.contains('\0') {
            return Err(InMemoryError::invalid_topic_name(topic));
        }

        // Ask the broker first so a failed subscription leaves our state as is.
        let receiver = self.broker.subscribe(topic)?;

        let mut state = self.state.lock().await;
        state.receiver = Some(receiver);
        state.subscribed_topic = Some(topic.to_string());
        Ok(())
    }

    /// Waits for the next message and returns it together with a fresh handle.
    ///
    /// The delivery is recorded in the pending-ack map under the handle's id
    /// with a delivery count of 1. The state lock is held while waiting for
    /// the message.
    ///
    /// # Errors
    ///
    /// * `InMemoryError::BrokerShutdown` if the broker reports it is shut down.
    /// * `InMemoryError::ChannelReceiveError` if there is no receiver, no
    ///   subscribed topic, or the channel has been closed.
    async fn receive_with_ack(&mut self) -> Result<(Message, Self::AckHandle), Self::Error> {
        if self.broker.is_shutdown() {
            return Err(InMemoryError::BrokerShutdown);
        }

        let mut guard = self.state.lock().await;
        // Reborrow once so the receiver and the pending map can be borrowed
        // as separate fields.
        let state = &mut *guard;

        let receiver = match state.receiver.as_mut() {
            Some(receiver) => receiver,
            None => {
                return Err(InMemoryError::ChannelReceiveError {
                    message: "No receiver available".to_string(),
                })
            }
        };
        let topic_name = match state.subscribed_topic.clone() {
            Some(topic_name) => topic_name,
            None => {
                return Err(InMemoryError::ChannelReceiveError {
                    message: "Not subscribed to any topic".to_string(),
                })
            }
        };

        let message = match receiver.recv().await {
            Some(message) => message,
            None => {
                return Err(InMemoryError::ChannelReceiveError {
                    message: "Channel closed".to_string(),
                })
            }
        };

        let handle = InMemoryAckHandleFixed::new(
            message.uuid.clone(),
            topic_name,
            SystemTime::now(),
            1,
            Arc::downgrade(&self.broker),
        );
        // Keep a copy of the message until the caller acks or nacks it.
        state
            .pending_acks
            .insert(handle.handle_id.clone(), (message.clone(), 1));
        Ok((message, handle))
    }

    /// Acknowledges a delivery: removes it from the pending-ack map and, when
    /// the broker exposes stats, counts one consumed message.
    ///
    /// # Errors
    ///
    /// `InMemoryError::BrokerShutdown` if the broker reports it is shut down.
    async fn ack(&self, handle: Self::AckHandle) -> Result<(), Self::Error> {
        if self.broker.is_shutdown() {
            return Err(InMemoryError::BrokerShutdown);
        }

        // The guard is a temporary, so the lock is released before stats are touched.
        self.state.lock().await.pending_acks.remove(&handle.handle_id);

        // NOTE(review): the counter is bumped even when the handle was not
        // pending (e.g. a cloned handle acked twice) — confirm this is intended.
        if let Some(stats) = self.broker.stats() {
            stats.increment_messages_consumed(1);
        }
        Ok(())
    }

    /// Negatively acknowledges a delivery.
    ///
    /// With `requeue == true` the stored delivery count for the handle is
    /// incremented (saturating at `u32::MAX`) and the entry stays pending;
    /// unknown handles are ignored. With `requeue == false` the entry is
    /// removed.
    ///
    /// NOTE(review): nothing in this method sends the message back to the
    /// broker, so a requeued message is not redelivered by this code path.
    ///
    /// # Errors
    ///
    /// `InMemoryError::BrokerShutdown` if the broker reports it is shut down.
    async fn nack(&self, handle: Self::AckHandle, requeue: bool) -> Result<(), Self::Error> {
        if self.broker.is_shutdown() {
            return Err(InMemoryError::BrokerShutdown);
        }

        let mut state = self.state.lock().await;
        if requeue {
            // Bump the counter in place; no need to clone and re-insert the message.
            if let Some((_, delivery_count)) = state.pending_acks.get_mut(&handle.handle_id) {
                *delivery_count = delivery_count.saturating_add(1);
            }
        } else {
            state.pending_acks.remove(&handle.handle_id);
        }
        Ok(())
    }

    /// Acknowledges every handle in order, stopping at the first error.
    async fn ack_batch(&self, handles: Vec<Self::AckHandle>) -> Result<(), Self::Error> {
        for handle in handles {
            self.ack(handle).await?;
        }
        Ok(())
    }

    /// Negatively acknowledges every handle in order with the same `requeue`
    /// flag, stopping at the first error.
    async fn nack_batch(
        &self,
        handles: Vec<Self::AckHandle>,
        requeue: bool,
    ) -> Result<(), Self::Error> {
        for handle in handles {
            self.nack(handle, requeue).await?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::{InMemoryBroker, InMemoryPublisher};
    use crate::Publisher;
    use std::time::Duration;
    use tokio::time::timeout;

    /// Builds a message whose payload is `content`, tagged with test metadata.
    fn create_test_message(content: &str) -> Message {
        let mut msg = Message::new(content.as_bytes().to_vec());
        msg = msg.with_metadata("test", "true");
        msg = msg.with_metadata("content", content);
        msg
    }

    /// Number of deliveries the subscriber is still tracking.
    async fn pending_count(subscriber: &InMemoryAckSubscriberFixed) -> usize {
        subscriber.state.lock().await.pending_acks.len()
    }

    /// A received message can be acked, and the ack clears the pending entry.
    #[tokio::test]
    async fn test_fixed_basic_acknowledgment() {
        let broker = Arc::new(InMemoryBroker::with_default_config());
        let publisher = InMemoryPublisher::new(broker.clone());
        let mut subscriber = InMemoryAckSubscriberFixed::new(broker.clone());
        let topic = "test_topic";
        let test_message = create_test_message("Test message");

        subscriber.subscribe(topic).await.expect("Failed to subscribe");
        publisher
            .publish(topic, vec![test_message.clone()])
            .await
            .expect("Failed to publish");

        let (received, handle) = timeout(Duration::from_secs(2), subscriber.receive_with_ack())
            .await
            .expect("Timeout waiting for message")
            .expect("Failed to receive message with ack");

        assert_eq!(received.payload, test_message.payload);
        assert_eq!(handle.topic(), topic);
        assert_eq!(handle.delivery_count(), 1);
        assert!(!handle.is_retry());
        assert_eq!(pending_count(&subscriber).await, 1);

        subscriber.ack(handle).await.expect("Failed to acknowledge message");
        assert_eq!(pending_count(&subscriber).await, 0);
    }

    /// A nack with requeue keeps the delivery pending and bumps its count.
    #[tokio::test]
    async fn test_fixed_negative_acknowledgment() {
        let broker = Arc::new(InMemoryBroker::with_default_config());
        let publisher = InMemoryPublisher::new(broker.clone());
        let mut subscriber = InMemoryAckSubscriberFixed::new(broker.clone());
        let topic = "nack_topic";
        let test_message = create_test_message("Nack test message");

        subscriber.subscribe(topic).await.expect("Failed to subscribe");
        publisher
            .publish(topic, vec![test_message.clone()])
            .await
            .expect("Failed to publish");

        let (received, handle) = timeout(Duration::from_secs(2), subscriber.receive_with_ack())
            .await
            .expect("Timeout waiting for message")
            .expect("Failed to receive message with ack");
        assert_eq!(received.payload, test_message.payload);

        // The handle is consumed by `nack`, so remember its id for the check below.
        let handle_id = handle.handle_id.clone();
        subscriber.nack(handle, true).await.expect("Failed to nack message");

        let state = subscriber.state.lock().await;
        let (_, delivery_count) = state
            .pending_acks
            .get(&handle_id)
            .expect("Requeued message should stay pending");
        assert_eq!(*delivery_count, 2);
    }

    /// A nack without requeue drops the pending entry.
    #[tokio::test]
    async fn test_fixed_negative_acknowledgment_without_requeue() {
        let broker = Arc::new(InMemoryBroker::with_default_config());
        let publisher = InMemoryPublisher::new(broker.clone());
        let mut subscriber = InMemoryAckSubscriberFixed::new(broker.clone());
        let topic = "nack_drop_topic";

        subscriber.subscribe(topic).await.expect("Failed to subscribe");
        publisher
            .publish(topic, vec![create_test_message("Nack drop message")])
            .await
            .expect("Failed to publish");

        let (_received, handle) = timeout(Duration::from_secs(2), subscriber.receive_with_ack())
            .await
            .expect("Timeout waiting for message")
            .expect("Failed to receive message with ack");
        assert_eq!(pending_count(&subscriber).await, 1);

        subscriber.nack(handle, false).await.expect("Failed to nack message");
        assert_eq!(pending_count(&subscriber).await, 0);
    }

    /// Several deliveries can be acked in one batch, clearing all pending entries.
    #[tokio::test]
    async fn test_fixed_batch_acknowledgment() {
        let broker = Arc::new(InMemoryBroker::with_default_config());
        let publisher = InMemoryPublisher::new(broker.clone());
        let mut subscriber = InMemoryAckSubscriberFixed::new(broker.clone());
        let topic = "batch_topic";
        let test_messages = vec![
            create_test_message("Batch message 1"),
            create_test_message("Batch message 2"),
            create_test_message("Batch message 3"),
        ];

        subscriber.subscribe(topic).await.expect("Failed to subscribe");
        publisher
            .publish(topic, test_messages.clone())
            .await
            .expect("Failed to publish");

        let mut handles = Vec::new();
        for i in 0..3 {
            // Build the failure message lazily, only when the step actually fails.
            let (received, handle) =
                timeout(Duration::from_secs(2), subscriber.receive_with_ack())
                    .await
                    .unwrap_or_else(|_| panic!("Timeout waiting for message {}", i + 1))
                    .unwrap_or_else(|err| {
                        panic!("Failed to receive message {}: {:?}", i + 1, err)
                    });
            assert!(String::from_utf8_lossy(&received.payload).contains("Batch message"));
            handles.push(handle);
        }
        assert_eq!(pending_count(&subscriber).await, 3);

        subscriber
            .ack_batch(handles)
            .await
            .expect("Failed to batch acknowledge");
        assert_eq!(pending_count(&subscriber).await, 0);
    }
}