use std::collections::HashMap;
use std::sync::Arc;

use lapin::options::{
    BasicAckOptions, BasicConsumeOptions, BasicPublishOptions, BasicRejectOptions,
    QueueDeclareOptions,
};
use lapin::types::FieldTable;
use lapin::{BasicProperties, Channel, Connection, ConnectionProperties, Consumer};
use tokio::sync::Mutex;

use super::broker::TaskBroker;
use super::types::{AgentTask, TaskResult, TaskStatus};
use crate::error::{DaimonError, Result};
/// [`TaskBroker`] backed by a single durable AMQP queue (via `lapin`).
///
/// Task payloads travel through the queue, while per-task status is kept only
/// in this process's memory; other processes sharing the queue do not see it.
pub struct AmqpBroker {
    /// Channel used to declare the queue, publish, consume and ack.
    channel: Channel,
    /// Queue that tasks are published to and consumed from.
    queue_name: String,
    /// In-process map from task id to its most recently recorded status.
    statuses: Arc<Mutex<HashMap<String, TaskStatus>>>,
    /// Consumer created on first use and then shared by every `receive` call.
    consumer: Arc<Mutex<Option<Consumer>>>,
}
impl AmqpBroker {
pub async fn connect(url: &str, queue_name: impl Into<String>) -> Result<Self> {
let queue_name = queue_name.into();
let conn = Connection::connect(url, ConnectionProperties::default())
.await
.map_err(|e| DaimonError::Other(format!("amqp connect: {e}")))?;
let channel = conn
.create_channel()
.await
.map_err(|e| DaimonError::Other(format!("amqp channel: {e}")))?;
channel
.queue_declare(
&queue_name,
QueueDeclareOptions {
durable: true,
..Default::default()
},
FieldTable::default(),
)
.await
.map_err(|e| DaimonError::Other(format!("amqp declare queue: {e}")))?;
Ok(Self {
channel,
queue_name,
statuses: Arc::new(Mutex::new(HashMap::new())),
consumer: Arc::new(Mutex::new(None)),
})
}
pub fn channel(&self) -> &Channel {
&self.channel
}
async fn ensure_consumer(&self) -> Result<Consumer> {
let mut guard = self.consumer.lock().await;
if let Some(ref consumer) = *guard {
return Ok(consumer.clone());
}
let consumer = self
.channel
.basic_consume(
&self.queue_name,
"daimon-worker",
BasicConsumeOptions {
no_ack: false,
..Default::default()
},
FieldTable::default(),
)
.await
.map_err(|e| DaimonError::Other(format!("amqp consume: {e}")))?;
*guard = Some(consumer.clone());
Ok(consumer)
}
}
impl TaskBroker for AmqpBroker {
async fn submit(&self, task: AgentTask) -> Result<String> {
let id = task.task_id.clone();
let json = serde_json::to_string(&task)
.map_err(|e| DaimonError::Other(format!("serialize task: {e}")))?;
{
let mut statuses = self.statuses.lock().await;
statuses.insert(id.clone(), TaskStatus::Pending);
}
self.channel
.basic_publish(
"",
&self.queue_name,
BasicPublishOptions::default(),
json.as_bytes(),
BasicProperties::default()
.with_delivery_mode(2)
.with_content_type("application/json".into()),
)
.await
.map_err(|e| DaimonError::Other(format!("amqp publish: {e}")))?
.await
.map_err(|e| DaimonError::Other(format!("amqp publish confirm: {e}")))?;
Ok(id)
}
async fn status(&self, task_id: &str) -> Result<TaskStatus> {
let statuses = self.statuses.lock().await;
Ok(statuses
.get(task_id)
.cloned()
.unwrap_or(TaskStatus::Pending))
}
async fn receive(&self) -> Result<Option<AgentTask>> {
use futures::StreamExt;
let consumer = self.ensure_consumer().await?;
let mut stream = consumer;
match stream.next().await {
Some(Ok(delivery)) => {
let task: AgentTask = serde_json::from_slice(&delivery.data)
.map_err(|e| DaimonError::Other(format!("deserialize task: {e}")))?;
delivery
.ack(BasicAckOptions::default())
.await
.map_err(|e| DaimonError::Other(format!("amqp ack: {e}")))?;
{
let mut statuses = self.statuses.lock().await;
statuses.insert(task.task_id.clone(), TaskStatus::Running);
}
Ok(Some(task))
}
Some(Err(e)) => Err(DaimonError::Other(format!("amqp delivery error: {e}"))),
None => Ok(None),
}
}
async fn complete(&self, task_id: &str, result: TaskResult) -> Result<()> {
let mut statuses = self.statuses.lock().await;
statuses.insert(task_id.to_string(), TaskStatus::Completed(result));
Ok(())
}
async fn fail(&self, task_id: &str, error: String) -> Result<()> {
let mut statuses = self.statuses.lock().await;
statuses.insert(task_id.to_string(), TaskStatus::Failed(error));
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A task with metadata survives the JSON byte encoding used as the AMQP payload.
    #[test]
    fn test_task_serialization_for_amqp() {
        let original = AgentTask::new("amqp test")
            .with_metadata("routing", serde_json::json!("high-priority"));
        let payload: Vec<u8> = serde_json::to_string(&original).unwrap().into_bytes();
        let decoded: AgentTask = serde_json::from_slice(&payload).unwrap();

        assert_eq!(decoded.input, "amqp test");
        assert_eq!(decoded.metadata["routing"], "high-priority");
    }

    /// A `TaskResult` round-trips through its JSON string form.
    #[test]
    fn test_result_serialization_for_amqp() {
        let original = TaskResult {
            task_id: "t-amqp".into(),
            output: "amqp result".into(),
            iterations: 1,
            cost: 0.002,
            error: None,
        };
        let decoded: TaskResult = serde_json::to_string(&original)
            .and_then(|text| serde_json::from_str::<TaskResult>(&text))
            .unwrap();

        assert_eq!(decoded.output, "amqp result");
    }

    /// An empty status map has no entry for an unknown id.
    #[test]
    fn test_status_tracking_in_memory() {
        let statuses = HashMap::<String, TaskStatus>::new();
        assert!(!statuses.contains_key("unknown"));
    }
}