use crate::{TaskExecutionError, TaskId, TaskStatus};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Output of a task: the task's id paired with its result text.
///
/// This is the `Ok` side of the `TaskResult` alias defined in this file.
/// Both fields are private and are read through the `task_id()` and
/// `result()` accessors. Because the type derives serde's `Serialize` and
/// `Deserialize`, renaming or reordering fields can change its serialized
/// representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskOutput {
/// Id of the task this output belongs to.
task_id: TaskId,
/// Result payload, stored as an owned string.
result: String,
}
impl TaskOutput {
pub fn new(task_id: TaskId, result: String) -> Self {
Self { task_id, result }
}
pub fn task_id(&self) -> TaskId {
self.task_id
}
pub fn result(&self) -> &str {
&self.result
}
}
/// Alias for `Result<TaskOutput, String>`: a [`TaskOutput`] on success, or a
/// plain error-message string on failure.
///
/// NOTE(review): presumably the return type of task execution — confirm
/// against callers, none are visible in this file.
pub type TaskResult = Result<TaskOutput, String>;
/// Record describing the outcome of a task, as stored by a `ResultBackend`.
///
/// Holds the task id, a status, an optional result string, an optional error
/// string and a creation timestamp. All fields are private; use the accessor
/// methods to read them. Because the type derives serde's `Serialize` and
/// `Deserialize`, renaming or reordering fields can change its serialized
/// representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskResultMetadata {
/// Id of the task this record describes.
task_id: TaskId,
/// Status recorded for the task.
status: TaskStatus,
/// Result text, if one was supplied.
result: Option<String>,
/// Error message, if one was supplied.
error: Option<String>,
/// Value of `chrono::Utc::now().timestamp()` (Unix seconds) taken when the
/// record was built by `new` or `with_error`.
created_at: i64,
}
impl TaskResultMetadata {
pub fn new(task_id: TaskId, status: TaskStatus, result: Option<String>) -> Self {
Self {
task_id,
status,
result,
error: None,
created_at: chrono::Utc::now().timestamp(),
}
}
pub fn with_error(task_id: TaskId, error: String) -> Self {
Self {
task_id,
status: TaskStatus::Failure,
result: None,
error: Some(error),
created_at: chrono::Utc::now().timestamp(),
}
}
pub fn set_error(&mut self, error: String) {
self.error = Some(error);
}
pub fn task_id(&self) -> TaskId {
self.task_id
}
pub fn status(&self) -> TaskStatus {
self.status
}
pub fn result(&self) -> Option<&str> {
self.result.as_deref()
}
pub fn error(&self) -> Option<&str> {
self.error.as_deref()
}
pub fn created_at(&self) -> i64 {
self.created_at
}
}
/// Asynchronous storage interface for [`TaskResultMetadata`] records, looked
/// up and deleted by [`TaskId`].
///
/// Implementors must be `Send + Sync`. Every operation reports failure as a
/// `TaskExecutionError`. `MemoryResultBackend` in this file is one
/// implementation.
#[async_trait]
pub trait ResultBackend: Send + Sync {
/// Stores `metadata`, taking ownership of it.
///
/// NOTE(review): `MemoryResultBackend` keys the record by
/// `metadata.task_id()` and overwrites any earlier record for that id;
/// whether every backend must behave that way is not stated here.
async fn store_result(&self, metadata: TaskResultMetadata) -> Result<(), TaskExecutionError>;
/// Looks up the record for `task_id`.
///
/// The `Option` in the return type lets a backend report "no record" as
/// `Ok(None)` rather than as an error; `MemoryResultBackend` does so.
async fn get_result(
&self,
task_id: TaskId,
) -> Result<Option<TaskResultMetadata>, TaskExecutionError>;
/// Deletes the record for `task_id`.
///
/// NOTE(review): `MemoryResultBackend` returns `Ok(())` even when no record
/// exists for the id — confirm other backends are expected to match.
async fn delete_result(&self, task_id: TaskId) -> Result<(), TaskExecutionError>;
}
/// In-memory [`ResultBackend`] that keeps records in a `HashMap` guarded by a
/// Tokio `RwLock`.
///
/// The map lives inside an `Arc`, so cloning the backend yields another handle
/// to the *same* storage rather than an independent copy. Nothing is written
/// anywhere else: the records exist only in this map.
#[derive(Debug, Clone)]
pub struct MemoryResultBackend {
    /// Shared, lock-protected map from task id to its stored record.
    results:
        std::sync::Arc<tokio::sync::RwLock<std::collections::HashMap<TaskId, TaskResultMetadata>>>,
}
impl MemoryResultBackend {
    /// Creates a backend whose result map starts out empty.
    pub fn new() -> Self {
        let empty_map = std::collections::HashMap::new();
        let guarded = tokio::sync::RwLock::new(empty_map);
        let results = std::sync::Arc::new(guarded);
        Self { results }
    }
}
impl Default for MemoryResultBackend {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl ResultBackend for MemoryResultBackend {
    /// Inserts `metadata` under its own task id, replacing any record already
    /// stored for that id. Takes the write lock; always returns `Ok(())`.
    async fn store_result(&self, metadata: TaskResultMetadata) -> Result<(), TaskExecutionError> {
        // Read the key first: `metadata` is moved into the map by `insert`.
        let key = metadata.task_id();
        self.results.write().await.insert(key, metadata);
        Ok(())
    }

    /// Returns a clone of the record stored for `task_id`, or `Ok(None)` when
    /// the map has no entry for it. Takes the read lock.
    async fn get_result(
        &self,
        task_id: TaskId,
    ) -> Result<Option<TaskResultMetadata>, TaskExecutionError> {
        let map = self.results.read().await;
        let found = map.get(&task_id).cloned();
        Ok(found)
    }

    /// Removes the record for `task_id` if there is one. Takes the write lock;
    /// returns `Ok(())` whether or not an entry existed.
    async fn delete_result(&self, task_id: TaskId) -> Result<(), TaskExecutionError> {
        self.results.write().await.remove(&task_id);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `TaskOutput` hands back exactly the id and text it was built with.
    #[test]
    fn test_task_output() {
        let id = TaskId::new();
        let out = TaskOutput::new(id, String::from("test result"));
        assert_eq!(out.task_id(), id);
        assert_eq!(out.result(), "test result");
    }

    /// `TaskResultMetadata::new` keeps id, status and result, and leaves the error empty.
    #[test]
    fn test_task_result_metadata() {
        let id = TaskId::new();
        let payload = Some(String::from("success"));
        let record = TaskResultMetadata::new(id, TaskStatus::Success, payload);
        assert_eq!(record.task_id(), id);
        assert_eq!(record.status(), TaskStatus::Success);
        assert_eq!(record.result(), Some("success"));
        assert_eq!(record.error(), None);
    }

    /// `with_error` yields a failure record with a message and no result.
    #[test]
    fn test_task_result_metadata_with_error() {
        let id = TaskId::new();
        let record = TaskResultMetadata::with_error(id, String::from("error occurred"));
        assert_eq!(record.status(), TaskStatus::Failure);
        assert_eq!(record.error(), Some("error occurred"));
        assert_eq!(record.result(), None);
    }

    /// `set_error` only fills the error slot; status and result are untouched.
    #[test]
    fn test_set_error_preserves_result_and_status() {
        let id = TaskId::new();
        let payload = Some(String::from("my result"));
        let mut record = TaskResultMetadata::new(id, TaskStatus::Success, payload);
        record.set_error(String::from("something went wrong"));
        assert_eq!(record.status(), TaskStatus::Success);
        assert_eq!(record.result(), Some("my result"));
        assert_eq!(record.error(), Some("something went wrong"));
    }

    /// Round trip through the in-memory backend: store, fetch, delete, fetch again.
    #[tokio::test]
    async fn test_memory_result_backend() {
        let store = MemoryResultBackend::new();
        let id = TaskId::new();
        let record = TaskResultMetadata::new(id, TaskStatus::Success, Some(String::from("result")));

        store.store_result(record.clone()).await.unwrap();
        let fetched = store.get_result(id).await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().result(), Some("result"));

        store.delete_result(id).await.unwrap();
        let after_delete = store.get_result(id).await.unwrap();
        assert!(after_delete.is_none());
    }
}