use crate::{
Task, TaskId, TaskStatus,
backend::{TaskBackend, TaskExecutionError},
backends::metadata_store::{InMemoryMetadataStore, MetadataStore, TaskMetadata},
registry::SerializedTask,
};
use async_trait::async_trait;
use lapin::{
BasicProperties, Channel, Connection, ConnectionProperties, Error as LapinError, options::*,
types::FieldTable,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
/// Wire format of a task message published to RabbitMQ.
///
/// Serialized to JSON by `enqueue` and deserialized by `dequeue`. Only the task's identity
/// travels through the queue; its status and stored task data are kept in the metadata store.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct QueueMessage {
/// Identifier of the enqueued task.
id: TaskId,
/// Task name as reported by `Task::name()` at enqueue time.
name: String,
/// Enqueue time as a Unix timestamp in seconds (from `chrono::Utc::now().timestamp()`).
created_at: i64,
}
/// Connection and routing settings for the RabbitMQ task backend.
///
/// With an empty `exchange_name` messages go through the AMQP default exchange, which routes
/// a message to the queue whose name equals its routing key; that is why `new` and
/// `with_queue_name` keep `queue_name` and `routing_key` identical.
#[derive(Clone, Debug)]
pub struct RabbitMQConfig {
    /// AMQP connection URL, e.g. `amqp://localhost:5672/%2f`.
    pub url: String,
    /// Name of the queue the backend declares and consumes from.
    pub queue_name: String,
    /// Exchange messages are published to; empty selects the default exchange.
    pub exchange_name: String,
    /// Routing key used when publishing.
    pub routing_key: String,
}

impl RabbitMQConfig {
    /// Creates a configuration for the broker at `url`.
    ///
    /// Uses the queue `reinhardt_tasks`, the default (empty) exchange, and a routing key
    /// equal to the queue name.
    pub fn new(url: &str) -> Self {
        let default_queue = "reinhardt_tasks";
        Self {
            url: url.to_owned(),
            queue_name: default_queue.to_owned(),
            exchange_name: String::default(),
            routing_key: default_queue.to_owned(),
        }
    }

    /// Sets the queue name and also resets the routing key to that same name.
    ///
    /// Because the routing key is overwritten, call this before `with_routing_key` when both
    /// are customized.
    pub fn with_queue_name(self, queue_name: &str) -> Self {
        Self {
            queue_name: queue_name.to_owned(),
            routing_key: queue_name.to_owned(),
            ..self
        }
    }

    /// Sets the exchange that messages are published to.
    pub fn with_exchange(self, exchange_name: &str) -> Self {
        Self {
            exchange_name: exchange_name.to_owned(),
            ..self
        }
    }

    /// Sets the routing key used when publishing, leaving the queue name untouched.
    pub fn with_routing_key(self, routing_key: &str) -> Self {
        Self {
            routing_key: routing_key.to_owned(),
            ..self
        }
    }
}

impl Default for RabbitMQConfig {
    /// Targets a broker on `localhost:5672` using the default vhost (`/`, URL-encoded `%2f`).
    fn default() -> Self {
        RabbitMQConfig::new("amqp://localhost:5672/%2f")
    }
}
#[cfg_attr(doc, aquamarine::aquamarine)]
/// Task backend that dispatches tasks through a RabbitMQ queue and tracks their state in a
/// [`MetadataStore`].
///
/// Queue messages carry only the task id, name and creation time; task status and stored
/// task data are read from and written to `metadata_store`.
pub struct RabbitMQBackend {
/// AMQP connection opened at construction; it is not re-established if it drops.
connection: Arc<Connection>,
/// Cached channel, replaced with a newly created one when it is no longer connected.
channel: Arc<RwLock<Channel>>,
/// Queue, exchange and routing settings used for declaring, publishing and consuming.
config: RabbitMQConfig,
/// Store for task metadata and status (an in-memory store when built via `new`).
metadata_store: Arc<dyn MetadataStore>,
}
impl RabbitMQBackend {
    /// Connects to RabbitMQ, tracking task metadata in an [`InMemoryMetadataStore`].
    ///
    /// Task state therefore lives only in this process's memory.
    ///
    /// # Errors
    ///
    /// Returns the underlying [`LapinError`] if connecting, opening a channel, or declaring
    /// the queue fails.
    pub async fn new(config: RabbitMQConfig) -> Result<Self, LapinError> {
        let metadata_store = Arc::new(InMemoryMetadataStore::new());
        Self::with_metadata_store(config, metadata_store).await
    }

    /// Connects to RabbitMQ and uses `metadata_store` for task metadata and status.
    ///
    /// Opens a connection to `config.url`, creates a channel, and declares the durable queue
    /// `config.queue_name`. Only the queue is declared. When `config.exchange_name` is
    /// non-empty, the exchange and its binding to the queue must already exist
    /// (NOTE(review): confirm this is intended).
    ///
    /// # Errors
    ///
    /// Returns the underlying [`LapinError`] if connecting, opening a channel, or declaring
    /// the queue fails.
    pub async fn with_metadata_store(
        config: RabbitMQConfig,
        metadata_store: Arc<dyn MetadataStore>,
    ) -> Result<Self, LapinError> {
        let connection = Connection::connect(&config.url, ConnectionProperties::default()).await?;
        let channel = connection.create_channel().await?;
        Self::declare_queue(&channel, &config.queue_name).await?;
        Ok(Self {
            connection: Arc::new(connection),
            channel: Arc::new(RwLock::new(channel)),
            config,
            metadata_store,
        })
    }

    /// Declares `queue_name` on `channel` as a durable queue with default arguments.
    ///
    /// Construction and channel recovery both call this, so the queue is always declared
    /// with identical options.
    async fn declare_queue(channel: &Channel, queue_name: &str) -> Result<(), LapinError> {
        channel
            .queue_declare(
                queue_name,
                QueueDeclareOptions {
                    // Durable so the queue definition survives a broker restart.
                    durable: true,
                    ..Default::default()
                },
                FieldTable::default(),
            )
            .await
            .map(|_queue| ())
    }

    /// Fails fast with `BackendError` if the shared AMQP connection is no longer connected.
    ///
    /// No reconnection is attempted.
    async fn ensure_connection(&self) -> Result<(), TaskExecutionError> {
        if !self.connection.status().connected() {
            return Err(TaskExecutionError::BackendError(
                "RabbitMQ connection lost".to_string(),
            ));
        }
        Ok(())
    }

    /// Returns a usable channel. If the cached channel has closed, it is recreated and the
    /// queue is re-declared on the new channel.
    ///
    /// # Errors
    ///
    /// Returns `BackendError` if the connection is down, or if creating the new channel or
    /// declaring the queue on it fails.
    async fn get_channel(&self) -> Result<Channel, TaskExecutionError> {
        self.ensure_connection().await?;

        // Fast path: the cached channel is still open, so a shared read lock is enough.
        {
            let cached = self.channel.read().await;
            if cached.status().connected() {
                return Ok(cached.clone());
            }
        }

        // Slow path: hold the write lock for the whole rebuild. Otherwise every caller that
        // saw the dead channel would open its own replacement and overwrite the others'.
        let mut cached = self.channel.write().await;
        if cached.status().connected() {
            // Another task replaced the channel while we waited for the write lock.
            return Ok(cached.clone());
        }
        let new_channel = self
            .connection
            .create_channel()
            .await
            .map_err(|e| TaskExecutionError::BackendError(e.to_string()))?;
        Self::declare_queue(&new_channel, &self.config.queue_name)
            .await
            .map_err(|e| TaskExecutionError::BackendError(e.to_string()))?;
        *cached = new_channel.clone();
        Ok(new_channel)
    }
}
#[async_trait]
impl TaskBackend for RabbitMQBackend {
async fn enqueue(&self, task: Box<dyn Task>) -> Result<TaskId, TaskExecutionError> {
let task_id = task.id();
let task_name = task.name().to_string();
let metadata = TaskMetadata::new(task_id, task_name.clone());
self.metadata_store
.store(metadata)
.await
.map_err(|e| TaskExecutionError::BackendError(e.to_string()))?;
let queue_message = QueueMessage {
id: task_id,
name: task_name,
created_at: chrono::Utc::now().timestamp(),
};
let message_json = serde_json::to_string(&queue_message)
.map_err(|e| TaskExecutionError::BackendError(e.to_string()))?;
let channel = self.get_channel().await?;
channel
.basic_publish(
&self.config.exchange_name,
&self.config.routing_key,
BasicPublishOptions::default(),
message_json.as_bytes(),
BasicProperties::default().with_delivery_mode(2), )
.await
.map_err(|e| TaskExecutionError::BackendError(e.to_string()))?
.await
.map_err(|e| TaskExecutionError::BackendError(e.to_string()))?;
Ok(task_id)
}
async fn dequeue(&self) -> Result<Option<TaskId>, TaskExecutionError> {
let channel = self.get_channel().await?;
let delivery = channel
.basic_get(&self.config.queue_name, BasicGetOptions { no_ack: false })
.await
.map_err(|e| TaskExecutionError::BackendError(e.to_string()))?;
match delivery {
Some(delivery) => {
let queue_message: QueueMessage = serde_json::from_slice(&delivery.data)
.map_err(|e| TaskExecutionError::BackendError(e.to_string()))?;
self.metadata_store
.update_status(queue_message.id, TaskStatus::Running)
.await
.map_err(|e| TaskExecutionError::BackendError(e.to_string()))?;
delivery
.ack(BasicAckOptions::default())
.await
.map_err(|e| TaskExecutionError::BackendError(e.to_string()))?;
Ok(Some(queue_message.id))
}
None => Ok(None),
}
}
async fn get_status(&self, task_id: TaskId) -> Result<TaskStatus, TaskExecutionError> {
let metadata = self
.metadata_store
.get(task_id)
.await
.map_err(|e| TaskExecutionError::BackendError(e.to_string()))?;
match metadata {
Some(m) => Ok(m.status),
None => Err(TaskExecutionError::NotFound(task_id)),
}
}
async fn update_status(
&self,
task_id: TaskId,
status: TaskStatus,
) -> Result<(), TaskExecutionError> {
self.metadata_store
.update_status(task_id, status)
.await
.map_err(|e| TaskExecutionError::BackendError(e.to_string()))?;
Ok(())
}
async fn get_task_data(
&self,
task_id: TaskId,
) -> Result<Option<SerializedTask>, TaskExecutionError> {
let metadata = self
.metadata_store
.get(task_id)
.await
.map_err(|e| TaskExecutionError::BackendError(e.to_string()))?;
match metadata {
Some(m) => {
if let Some(task_data) = m.task_data {
Ok(Some(task_data))
} else {
Ok(Some(SerializedTask::new(m.name, "{}".to_string())))
}
}
None => Ok(None),
}
}
fn backend_name(&self) -> &str {
"rabbitmq"
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::metadata_store::MetadataStoreError;
    use rstest::rstest;

    /// Broker URL shared by the configuration tests (also the `Default` URL).
    const LOCAL_URL: &str = "amqp://localhost:5672/%2f";

    /// `new` keeps the URL and fills in the default queue, exchange and routing key.
    #[rstest]
    fn test_rabbitmq_config_new() {
        let cfg = RabbitMQConfig::new(LOCAL_URL);
        assert_eq!(cfg.url, LOCAL_URL);
        assert_eq!(cfg.queue_name, "reinhardt_tasks");
        assert_eq!(cfg.exchange_name, "");
        assert_eq!(cfg.routing_key, "reinhardt_tasks");
    }

    /// `with_queue_name` sets the queue and keeps the routing key in sync with it.
    #[rstest]
    fn test_rabbitmq_config_with_queue_name() {
        let cfg = RabbitMQConfig::new(LOCAL_URL).with_queue_name("custom_queue");
        assert_eq!(cfg.queue_name, "custom_queue");
        assert_eq!(cfg.routing_key, "custom_queue");
    }

    /// `with_exchange` sets the publish exchange.
    #[rstest]
    fn test_rabbitmq_config_with_exchange() {
        let cfg = RabbitMQConfig::new(LOCAL_URL).with_exchange("my_exchange");
        assert_eq!(cfg.exchange_name, "my_exchange");
    }

    /// `with_routing_key` sets the publish routing key.
    #[rstest]
    fn test_rabbitmq_config_with_routing_key() {
        let cfg = RabbitMQConfig::new(LOCAL_URL).with_routing_key("my_routing_key");
        assert_eq!(cfg.routing_key, "my_routing_key");
    }

    /// `Default` targets the local broker.
    #[rstest]
    fn test_rabbitmq_config_default() {
        assert_eq!(RabbitMQConfig::default().url, LOCAL_URL);
    }

    /// A `QueueMessage` survives a JSON round trip unchanged.
    #[rstest]
    fn test_queue_message_serialization() {
        let original = QueueMessage {
            id: TaskId::new(),
            name: "test_task".to_string(),
            created_at: 1234567890,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: QueueMessage = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.id, original.id);
        assert_eq!(decoded.name, original.name);
        assert_eq!(decoded.created_at, original.created_at);
    }

    /// Wrapping a metadata-store error message in `BackendError` keeps the original text.
    #[rstest]
    #[case::not_found_error(
        MetadataStoreError::NotFound(TaskId::new()),
        "not found in metadata store"
    )]
    #[case::storage_error(
        MetadataStoreError::StorageError("connection refused".to_string()),
        "connection refused"
    )]
    #[case::serialization_error(
        MetadataStoreError::SerializationError("invalid JSON".to_string()),
        "invalid JSON"
    )]
    fn test_metadata_store_error_converts_to_backend_error(
        #[case] store_error: MetadataStoreError,
        #[case] expected_fragment: &str,
    ) {
        let wrapped = TaskExecutionError::BackendError(store_error.to_string());
        let rendered = wrapped.to_string();
        assert!(
            rendered.contains(expected_fragment),
            "Expected error string '{}' to contain '{}'",
            rendered,
            expected_fragment,
        );
        assert!(matches!(wrapped, TaskExecutionError::BackendError(_)));
    }

    /// Updating the status of an unknown task fails, and the wrapped message says "not found".
    #[rstest]
    #[tokio::test]
    async fn test_metadata_update_status_error_propagation_path() {
        let store = InMemoryMetadataStore::new();
        let unknown_id = TaskId::new();
        let outcome = store.update_status(unknown_id, TaskStatus::Running).await;
        assert!(outcome.is_err());
        let wrapped = TaskExecutionError::BackendError(outcome.unwrap_err().to_string());
        assert!(
            matches!(wrapped, TaskExecutionError::BackendError(msg) if msg.contains("not found"))
        );
    }
}