use crate::{
Task, TaskExecutionError, TaskId, TaskStatus, registry::SerializedTask, result::ResultBackend,
result::TaskResultMetadata,
};
use async_trait::async_trait;
use aws_sdk_sqs::{Client, types::MessageAttributeValue};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Per-task bookkeeping held in `SqsBackend`'s in-memory metadata store.
///
/// A JSON copy of this struct (via the derived `Serialize`) is also attached to
/// each outgoing SQS message as the `metadata` message attribute in `enqueue`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TaskMetadata {
// Identifier of the task; also used as the key of the metadata store.
id: TaskId,
// Task name as returned by `Task::name()` at enqueue time.
name: String,
// Current lifecycle status; rewritten by `update_status`.
status: TaskStatus,
// Creation time as Unix seconds (`chrono::Utc::now().timestamp()`).
created_at: i64,
// Time of the most recent status change, Unix seconds.
updated_at: i64,
}
/// Connection and polling settings for the SQS task backend.
///
/// The `with_*` builders clamp their argument into the range accepted by the
/// SQS `ReceiveMessage` API. An out-of-range value therefore degrades to the
/// nearest legal value instead of making every receive call fail at runtime.
#[derive(Debug, Clone)]
pub struct SqsConfig {
    /// URL of the queue that tasks are sent to and received from.
    queue_url: String,
    /// Seconds a received message stays hidden from other consumers (0..=43200).
    visibility_timeout: i32,
    /// Maximum number of messages requested per receive call (1..=10).
    max_messages: i32,
    /// Long-poll wait in seconds; 0 disables long polling (0..=20).
    wait_time_seconds: i32,
}

impl SqsConfig {
    /// Largest batch size `ReceiveMessage` accepts.
    const MAX_MESSAGES_LIMIT: i32 = 10;
    /// Longest long-poll wait `ReceiveMessage` accepts, in seconds.
    const MAX_WAIT_TIME_SECONDS: i32 = 20;
    /// Longest visibility timeout SQS accepts (12 hours), in seconds.
    const MAX_VISIBILITY_TIMEOUT: i32 = 43_200;

    /// Creates a config for `queue_url` with defaults: 30 s visibility
    /// timeout, 1 message per receive and no long polling.
    pub fn new(queue_url: impl Into<String>) -> Self {
        Self {
            queue_url: queue_url.into(),
            visibility_timeout: 30,
            max_messages: 1,
            wait_time_seconds: 0,
        }
    }

    /// Sets the visibility timeout in seconds, clamped to `0..=43200`.
    pub fn with_visibility_timeout(mut self, timeout: i32) -> Self {
        self.visibility_timeout = timeout.clamp(0, Self::MAX_VISIBILITY_TIMEOUT);
        self
    }

    /// Sets the per-receive batch size, clamped to `1..=10`.
    ///
    /// The lower bound matters: SQS rejects `MaxNumberOfMessages` below 1.
    pub fn with_max_messages(mut self, max_messages: i32) -> Self {
        self.max_messages = max_messages.clamp(1, Self::MAX_MESSAGES_LIMIT);
        self
    }

    /// Sets the long-poll wait in seconds, clamped to `0..=20`.
    pub fn with_wait_time_seconds(mut self, wait_time: i32) -> Self {
        self.wait_time_seconds = wait_time.clamp(0, Self::MAX_WAIT_TIME_SECONDS);
        self
    }
}
/// Task backend that stores queued tasks in Amazon SQS.
///
/// Messages live in SQS. Task metadata and the receipt handles needed to
/// delete finished messages are kept in process-local maps.
pub struct SqsBackend {
    /// AWS SDK client used for every SQS call.
    client: Client,
    /// Queue URL and polling settings.
    config: SqsConfig,
    /// Task metadata keyed by task id (in memory only).
    metadata_store: Arc<RwLock<HashMap<TaskId, TaskMetadata>>>,
    /// Receipt handle of the received message for each in-flight task.
    receipt_handles: Arc<RwLock<HashMap<TaskId, String>>>,
}

impl SqsBackend {
    /// Builds a backend whose AWS settings are loaded from the environment.
    ///
    /// # Errors
    /// The current implementation never fails; the `Result` keeps the
    /// signature open for fallible configuration loading.
    pub async fn new(config: SqsConfig) -> Result<Self, TaskExecutionError> {
        let sdk_config = aws_config::load_from_env().await;
        Ok(Self::with_config(config, &sdk_config))
    }

    /// Builds a backend from an already-loaded AWS SDK configuration.
    pub fn with_config(config: SqsConfig, aws_config: &aws_config::SdkConfig) -> Self {
        let metadata_store = Arc::new(RwLock::new(HashMap::new()));
        let receipt_handles = Arc::new(RwLock::new(HashMap::new()));
        Self {
            client: Client::new(aws_config),
            config,
            metadata_store,
            receipt_handles,
        }
    }

    /// Deletes the message identified by `receipt_handle` from the queue.
    ///
    /// # Errors
    /// Returns `TaskExecutionError::BackendError` if the SQS call fails.
    async fn delete_message(&self, receipt_handle: &str) -> Result<(), TaskExecutionError> {
        let request = self
            .client
            .delete_message()
            .queue_url(&self.config.queue_url)
            .receipt_handle(receipt_handle);
        match request.send().await {
            Ok(_) => Ok(()),
            Err(e) => Err(TaskExecutionError::BackendError(format!(
                "SQS delete error: {}",
                e
            ))),
        }
    }
}
#[async_trait]
impl crate::backend::TaskBackend for SqsBackend {
async fn enqueue(&self, task: Box<dyn Task>) -> Result<TaskId, TaskExecutionError> {
let task_id = task.id();
let task_name = task.name().to_string();
let metadata = TaskMetadata {
id: task_id,
name: task_name.clone(),
status: TaskStatus::Pending,
created_at: chrono::Utc::now().timestamp(),
updated_at: chrono::Utc::now().timestamp(),
};
{
let mut store = self.metadata_store.write().await;
store.insert(task_id, metadata.clone());
}
let serialized_task = SerializedTask::new(task_name, "{}".to_string());
let message_body = serialized_task
.to_json()
.map_err(|e| TaskExecutionError::BackendError(e.to_string()))?;
let metadata_json = serde_json::to_string(&metadata)
.map_err(|e| TaskExecutionError::BackendError(e.to_string()))?;
self.client
.send_message()
.queue_url(&self.config.queue_url)
.message_body(message_body)
.message_attributes(
"task_id",
MessageAttributeValue::builder()
.data_type("String")
.string_value(task_id.to_string())
.build()
.map_err(|e| {
TaskExecutionError::BackendError(format!("Message attribute error: {}", e))
})?,
)
.message_attributes(
"metadata",
MessageAttributeValue::builder()
.data_type("String")
.string_value(metadata_json)
.build()
.map_err(|e| {
TaskExecutionError::BackendError(format!("Message attribute error: {}", e))
})?,
)
.send()
.await
.map_err(|e| TaskExecutionError::BackendError(format!("SQS send error: {}", e)))?;
Ok(task_id)
}
async fn dequeue(&self) -> Result<Option<TaskId>, TaskExecutionError> {
let result = self
.client
.receive_message()
.queue_url(&self.config.queue_url)
.max_number_of_messages(self.config.max_messages)
.visibility_timeout(self.config.visibility_timeout)
.wait_time_seconds(self.config.wait_time_seconds)
.message_attribute_names("All")
.send()
.await
.map_err(|e| TaskExecutionError::BackendError(format!("SQS receive error: {}", e)))?;
if let Some(messages) = result.messages
&& let Some(message) = messages.into_iter().next()
{
if let Some(attributes) = message.message_attributes
&& let Some(task_id_attr) = attributes.get("task_id")
&& let Some(task_id_str) = task_id_attr.string_value()
{
let task_id = task_id_str
.parse()
.map_err(|e: uuid::Error| TaskExecutionError::BackendError(e.to_string()))?;
if let Some(receipt_handle) = message.receipt_handle {
let mut handles = self.receipt_handles.write().await;
handles.insert(task_id, receipt_handle);
}
self.update_status(task_id, TaskStatus::Running).await?;
return Ok(Some(task_id));
}
}
Ok(None)
}
async fn get_status(&self, task_id: TaskId) -> Result<TaskStatus, TaskExecutionError> {
let store = self.metadata_store.read().await;
let metadata = store
.get(&task_id)
.ok_or(TaskExecutionError::NotFound(task_id))?;
Ok(metadata.status)
}
async fn update_status(
&self,
task_id: TaskId,
status: TaskStatus,
) -> Result<(), TaskExecutionError> {
{
let mut store = self.metadata_store.write().await;
let metadata = store
.get_mut(&task_id)
.ok_or(TaskExecutionError::NotFound(task_id))?;
metadata.status = status;
metadata.updated_at = chrono::Utc::now().timestamp();
}
if matches!(status, TaskStatus::Success | TaskStatus::Failure) {
let receipt_handle = {
let mut handles = self.receipt_handles.write().await;
handles.remove(&task_id)
};
if let Some(receipt_handle) = receipt_handle {
self.delete_message(&receipt_handle).await?;
}
}
Ok(())
}
async fn get_task_data(
&self,
task_id: TaskId,
) -> Result<Option<SerializedTask>, TaskExecutionError> {
let store = self.metadata_store.read().await;
if let Some(metadata) = store.get(&task_id) {
Ok(Some(SerializedTask::new(
metadata.name.clone(),
"{}".to_string(),
)))
} else {
Ok(None)
}
}
fn backend_name(&self) -> &str {
"sqs"
}
}
/// Result backend that keeps task results in memory alongside the SQS task
/// backend.
///
/// Results are held in a process-local map behind an async `RwLock`; nothing
/// is written to SQS.
pub struct SqsResultBackend {
    /// Stored results keyed by task id.
    results: Arc<RwLock<HashMap<TaskId, TaskResultMetadata>>>,
}

impl Default for SqsResultBackend {
    /// Builds a backend with an empty result map.
    fn default() -> Self {
        let empty: HashMap<TaskId, TaskResultMetadata> = HashMap::new();
        Self {
            results: Arc::new(RwLock::new(empty)),
        }
    }
}

impl SqsResultBackend {
    /// Creates an empty result backend; equivalent to `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
#[async_trait]
impl ResultBackend for SqsResultBackend {
    /// Stores `metadata` under its task id, replacing any earlier result.
    async fn store_result(&self, metadata: TaskResultMetadata) -> Result<(), TaskExecutionError> {
        let key = metadata.task_id();
        self.results.write().await.insert(key, metadata);
        Ok(())
    }

    /// Returns a clone of the stored result for `task_id`, or `None`.
    async fn get_result(
        &self,
        task_id: TaskId,
    ) -> Result<Option<TaskResultMetadata>, TaskExecutionError> {
        let guard = self.results.read().await;
        let found = guard.get(&task_id).cloned();
        Ok(found)
    }

    /// Removes the stored result for `task_id`; an unknown id is not an error.
    async fn delete_result(&self, task_id: TaskId) -> Result<(), TaskExecutionError> {
        self.results.write().await.remove(&task_id);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// `SqsConfig::new` stores the URL and applies the default settings.
    #[rstest]
    fn test_sqs_config_creation() {
        let config = SqsConfig::new("https://sqs.us-east-1.amazonaws.com/123456789012/my-queue");
        assert_eq!(
            config.queue_url,
            "https://sqs.us-east-1.amazonaws.com/123456789012/my-queue"
        );
        assert_eq!(config.visibility_timeout, 30);
        assert_eq!(config.max_messages, 1);
        assert_eq!(config.wait_time_seconds, 0);
    }

    /// In-range builder values are stored unchanged.
    #[rstest]
    fn test_sqs_config_with_options() {
        let config = SqsConfig::new("https://sqs.us-east-1.amazonaws.com/123456789012/my-queue")
            .with_visibility_timeout(60)
            .with_max_messages(5)
            .with_wait_time_seconds(20);
        assert_eq!(config.visibility_timeout, 60);
        assert_eq!(config.max_messages, 5);
        assert_eq!(config.wait_time_seconds, 20);
    }

    /// Batch sizes above the SQS maximum of 10 are capped.
    #[rstest]
    fn test_sqs_config_max_messages_limit() {
        let config = SqsConfig::new("https://sqs.us-east-1.amazonaws.com/123456789012/my-queue")
            .with_max_messages(15);
        assert_eq!(config.max_messages, 10);
    }

    /// A stored result can be read back by its task id.
    #[rstest]
    #[tokio::test]
    async fn test_sqs_result_backend_store_and_retrieve() {
        let backend = SqsResultBackend::new();
        let task_id = TaskId::new();
        let metadata = TaskResultMetadata::new(
            task_id,
            TaskStatus::Success,
            Some("Test result".to_string()),
        );
        backend
            .store_result(metadata)
            .await
            .expect("Failed to store result");
        let retrieved = backend
            .get_result(task_id)
            .await
            .expect("Failed to get result");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().result(), Some("Test result"));
    }

    /// A deleted result is no longer returned.
    #[rstest]
    #[tokio::test]
    async fn test_sqs_result_backend_delete() {
        let backend = SqsResultBackend::new();
        let task_id = TaskId::new();
        let metadata = TaskResultMetadata::new(task_id, TaskStatus::Success, None);
        backend
            .store_result(metadata)
            .await
            .expect("Failed to store result");
        backend
            .delete_result(task_id)
            .await
            .expect("Failed to delete result");
        let retrieved = backend
            .get_result(task_id)
            .await
            .expect("Failed to get result");
        assert!(retrieved.is_none());
    }

    /// Readers spawned alongside a writer all see the pre-existing results,
    /// and the writer's new result is visible afterwards.
    #[rstest]
    #[tokio::test]
    async fn test_sqs_result_backend_concurrent_reads_during_write() {
        let backend = Arc::new(SqsResultBackend::new());
        let task_ids: Vec<TaskId> = (0..10).map(|_| TaskId::new()).collect();
        for &id in &task_ids {
            let meta = TaskResultMetadata::new(id, TaskStatus::Success, Some("ok".to_string()));
            backend.store_result(meta).await.unwrap();
        }
        let mut read_handles = vec![];
        for &id in &task_ids {
            let backend_clone = Arc::clone(&backend);
            read_handles.push(tokio::spawn(async move {
                backend_clone.get_result(id).await.unwrap()
            }));
        }
        let new_id = TaskId::new();
        let writer = {
            let backend_clone = Arc::clone(&backend);
            tokio::spawn(async move {
                let meta = TaskResultMetadata::new(
                    new_id,
                    TaskStatus::Success,
                    Some("concurrent".to_string()),
                );
                backend_clone.store_result(meta).await.unwrap();
            })
        };
        for handle in read_handles {
            assert!(handle.await.unwrap().is_some());
        }
        writer.await.unwrap();
        let result = backend.get_result(new_id).await.unwrap();
        assert_eq!(result.unwrap().result(), Some("concurrent"));
    }

    /// Deleting one result while another is stored and read concurrently
    /// must neither deadlock nor disturb the unrelated entry.
    ///
    /// This exercises `SqsResultBackend` only; it does not touch
    /// `SqsBackend::update_status`, which needs a live SQS client.
    #[rstest]
    #[tokio::test]
    async fn test_sqs_result_backend_delete_does_not_block_other_reads() {
        let backend = Arc::new(SqsResultBackend::new());
        let task_id = TaskId::new();
        backend
            .store_result(TaskResultMetadata::new(
                task_id,
                TaskStatus::Success,
                Some("data".to_string()),
            ))
            .await
            .unwrap();
        let backend_write = Arc::clone(&backend);
        let backend_read = Arc::clone(&backend);
        let write =
            tokio::spawn(async move { backend_write.delete_result(task_id).await.unwrap() });
        let other_id = TaskId::new();
        backend
            .store_result(TaskResultMetadata::new(
                other_id,
                TaskStatus::Success,
                Some("other".to_string()),
            ))
            .await
            .unwrap();
        let read = tokio::spawn(async move { backend_read.get_result(other_id).await.unwrap() });
        write.await.unwrap();
        let read_result = read.await.unwrap();
        assert!(read_result.is_some());
        assert!(backend.get_result(task_id).await.unwrap().is_none());
    }
}