use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use crate::graph::persistence::error::PersistenceError;
/// An asynchronous unit of work that consumes and produces JSON values.
///
/// The `Send + Sync` supertraits allow implementors to be shared across
/// threads, for example as `Arc<dyn Task>`.
#[async_trait]
pub trait Task: Send + Sync {
    /// Returns the identifier of this task.
    fn task_id(&self) -> &str;

    /// Computes the key under which the result of running this task on
    /// `input` may be cached.
    fn cache_key(&self, input: &Value) -> String;

    /// Runs the task on `input`, yielding its JSON output or a [`TaskError`].
    async fn execute(&self, input: Value) -> Result<Value, TaskError>;
}
#[derive(thiserror::Error, Debug)]
pub enum TaskError {
#[error("Task execution error: {0}")]
ExecutionError(String),
#[error("Task cache error: {0}")]
CacheError(String),
#[error("Serialization error: {0}")]
SerializationError(#[from] serde_json::Error),
#[error("Persistence error: {0}")]
PersistenceError(#[from] PersistenceError),
}
pub type TaskResult<T> = Result<T, TaskError>;
/// A task implemented by a closure that maps a JSON input to a boxed,
/// `Send` future of `Result<Value, TaskError>`.
pub struct FunctionTask<F> {
    /// The wrapped closure, held behind an `Arc`.
    func: Arc<F>,
    /// Identifier returned by `task_id` and used when building cache keys.
    task_id: String,
}

impl<F> FunctionTask<F>
where
    F: Fn(Value) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Value, TaskError>> + Send>>
        + Send
        + Sync
        + 'static,
{
    /// Creates a task named `task_id` that delegates execution to `func`.
    ///
    /// `task_id` accepts anything convertible into a `String`.
    pub fn new(task_id: impl Into<String>, func: F) -> Self {
        let shared_func = Arc::new(func);
        let id: String = task_id.into();
        Self {
            func: shared_func,
            task_id: id,
        }
    }
}
#[async_trait]
impl<F> Task for FunctionTask<F>
where
    F: Fn(Value) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Value, TaskError>> + Send>>
        + Send
        + Sync
        + 'static,
{
    /// Returns the identifier supplied to `FunctionTask::new`.
    fn task_id(&self) -> &str {
        self.task_id.as_str()
    }

    /// Invokes the wrapped closure with `input` and awaits the future it returns.
    async fn execute(&self, input: Value) -> Result<Value, TaskError> {
        let pending = (self.func)(input);
        pending.await
    }

    /// Builds a cache key of the form `task:<task_id>:<hash>`.
    ///
    /// `<hash>` is the decimal `u64` obtained by hashing the task id followed
    /// by the compact JSON serialization of `input` (an empty string if
    /// serialization fails) with `DefaultHasher`.
    ///
    /// NOTE(review): `DefaultHasher`'s algorithm is unspecified and may change
    /// between Rust releases, so these keys should not be assumed stable across
    /// toolchain upgrades if they are ever persisted. Distinct inputs can also
    /// collide on the 64-bit hash.
    fn cache_key(&self, input: &Value) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let serialized = serde_json::to_string(input).unwrap_or_default();

        let mut state = DefaultHasher::new();
        // Order matters: the id is hashed first, then the serialized input.
        self.task_id.hash(&mut state);
        serialized.hash(&mut state);
        let digest = state.finish();

        format!("task:{}:{}", self.task_id, digest)
    }
}
/// A shared, dynamically dispatched task handle.
///
/// The explicit `'static` matches the default object lifetime that the elided
/// form `Arc<dyn Task>` would get in this position.
pub type TaskBox = Arc<dyn Task + 'static>;