use crate::{TaskError, TaskExecutor, TaskResult};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// A named task payload that can be round-tripped through JSON.
///
/// The payload is stored as a plain `String`; this type never parses or validates it.
/// NOTE(review): the tests store JSON text in `data`, and `name` presumably matches a
/// [`TaskRegistry`] key — confirm both conventions with callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedTask {
    /// Task name (serialized under the `name` key).
    name: String,
    /// Raw task payload, kept verbatim (serialized under the `data` key).
    data: String,
}
impl SerializedTask {
pub fn new(name: String, data: String) -> Self {
Self { name, data }
}
pub fn name(&self) -> &str {
&self.name
}
pub fn data(&self) -> &str {
&self.data
}
pub fn to_json(&self) -> Result<String, serde_json::Error> {
serde_json::to_string(self)
}
pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
serde_json::from_str(json)
}
}
/// Constructs task executors from a payload string.
///
/// The `Send + Sync` bound lets implementations be shared as `Arc<dyn TaskFactory>`,
/// which is how [`TaskRegistry`] stores them.
#[async_trait]
pub trait TaskFactory: Send + Sync {
    /// Builds a new executor from `data`.
    ///
    /// # Errors
    /// Implementations should return an error when `data` cannot be turned into a task;
    /// the payload format is implementation-defined.
    async fn create(&self, data: &str) -> TaskResult<Box<dyn TaskExecutor>>;
}
/// Name-keyed collection of [`TaskFactory`] implementations.
///
/// Every method takes `&self`; the map lives behind a `tokio::sync::RwLock`, so lookups
/// share read access while registration and removal take exclusive write access.
pub struct TaskRegistry {
    /// Registered factories keyed by task name.
    factories: Arc<RwLock<HashMap<String, Arc<dyn TaskFactory>>>>,
}
impl TaskRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            factories: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Registers `factory` under `name`, replacing any factory already registered
    /// under the same name.
    pub async fn register(&self, name: String, factory: Arc<dyn TaskFactory>) {
        let mut factories = self.factories.write().await;
        factories.insert(name, factory);
    }

    /// Removes the factory registered under `name`. Unknown names are ignored.
    pub async fn unregister(&self, name: &str) {
        let mut factories = self.factories.write().await;
        factories.remove(name);
    }

    /// Returns `true` if a factory is currently registered under `name`.
    pub async fn has(&self, name: &str) -> bool {
        let factories = self.factories.read().await;
        factories.contains_key(name)
    }

    /// Builds a new executor by handing `data` to the factory registered under `name`.
    ///
    /// # Errors
    /// Returns [`TaskError::ExecutionFailed`] if no factory is registered under `name`;
    /// otherwise propagates any error returned by the factory's `create`.
    pub async fn create(&self, name: &str, data: &str) -> TaskResult<Box<dyn TaskExecutor>> {
        // Clone the `Arc` out and release the read guard *before* awaiting the factory.
        // Holding the guard across `.await` would block `register`/`unregister`/`clear`
        // for the whole construction. Because tokio's `RwLock` is write-preferring, a
        // factory that calls back into this registry while a writer is queued would
        // also deadlock.
        let factory = {
            let factories = self.factories.read().await;
            factories.get(name).cloned()
        };
        let factory = factory
            .ok_or_else(|| TaskError::ExecutionFailed(format!("Task not registered: {}", name)))?;
        factory.create(data).await
    }

    /// Returns the names of all registered factories in ascending order.
    pub async fn list(&self) -> Vec<String> {
        let mut names: Vec<String> = {
            let factories = self.factories.read().await;
            factories.keys().cloned().collect()
        };
        // `HashMap` iteration order is unspecified; sort so callers get a stable listing.
        names.sort_unstable();
        names
    }

    /// Removes every registered factory.
    pub async fn clear(&self) {
        let mut factories = self.factories.write().await;
        factories.clear();
    }
}
impl Default for TaskRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Task, TaskId, TaskPriority};

    /// Trivial task whose execution always succeeds.
    struct TestTask {
        id: TaskId,
    }

    impl Task for TestTask {
        fn id(&self) -> TaskId {
            self.id
        }

        fn name(&self) -> &str {
            "test_task"
        }

        fn priority(&self) -> TaskPriority {
            TaskPriority::default()
        }
    }

    #[async_trait]
    impl TaskExecutor for TestTask {
        async fn execute(&self) -> TaskResult<()> {
            Ok(())
        }
    }

    /// Factory that ignores its payload and always yields a fresh `TestTask`.
    struct TestTaskFactory;

    #[async_trait]
    impl TaskFactory for TestTaskFactory {
        async fn create(&self, _data: &str) -> TaskResult<Box<dyn TaskExecutor>> {
            Ok(Box::new(TestTask { id: TaskId::new() }))
        }
    }

    #[test]
    fn test_serialized_task() {
        let task = SerializedTask::new("test".to_string(), r#"{"key":"value"}"#.to_string());
        assert_eq!(task.name(), "test");
        assert_eq!(task.data(), r#"{"key":"value"}"#);
    }

    #[test]
    fn test_serialized_task_json() {
        // A payload containing quotes checks that JSON escaping survives the round trip.
        let payload = r#"{"key":"value"}"#;
        let task = SerializedTask::new("test".to_string(), payload.to_string());
        let json = task.to_json().unwrap();
        let restored = SerializedTask::from_json(&json).unwrap();
        assert_eq!(restored.name(), "test");
        assert_eq!(restored.data(), payload);
    }

    #[test]
    fn test_serialized_task_from_invalid_json() {
        assert!(SerializedTask::from_json("not json").is_err());
        // Neither field has a serde default, so a missing `data` key must be rejected.
        assert!(SerializedTask::from_json(r#"{"name":"test"}"#).is_err());
    }

    #[tokio::test]
    async fn test_registry_register_and_has() {
        let registry = TaskRegistry::new();
        let factory = Arc::new(TestTaskFactory);
        assert!(!registry.has("test_task").await);
        registry.register("test_task".to_string(), factory).await;
        assert!(registry.has("test_task").await);
    }

    #[tokio::test]
    async fn test_registry_register_replaces_existing() {
        let registry = TaskRegistry::new();
        registry
            .register("test_task".to_string(), Arc::new(TestTaskFactory))
            .await;
        registry
            .register("test_task".to_string(), Arc::new(TestTaskFactory))
            .await;
        assert_eq!(registry.list().await, vec!["test_task".to_string()]);
    }

    #[tokio::test]
    async fn test_registry_unregister() {
        let registry = TaskRegistry::new();
        let factory = Arc::new(TestTaskFactory);
        registry.register("test_task".to_string(), factory).await;
        assert!(registry.has("test_task").await);
        registry.unregister("test_task").await;
        assert!(!registry.has("test_task").await);
    }

    #[tokio::test]
    async fn test_registry_create() {
        let registry = TaskRegistry::new();
        let factory = Arc::new(TestTaskFactory);
        registry.register("test_task".to_string(), factory).await;
        let executor = match registry.create("test_task", "{}").await {
            Ok(executor) => executor,
            Err(_) => panic!("a registered task should be creatable"),
        };
        // The produced executor should also be runnable.
        assert!(executor.execute().await.is_ok());
    }

    #[tokio::test]
    async fn test_registry_create_not_found() {
        let registry = TaskRegistry::new();
        match registry.create("nonexistent", "{}").await {
            Err(TaskError::ExecutionFailed(msg)) => assert!(msg.contains("nonexistent")),
            _ => panic!("expected ExecutionFailed for an unregistered task"),
        }
    }

    #[tokio::test]
    async fn test_registry_list() {
        let registry = TaskRegistry::new();
        let factory = Arc::new(TestTaskFactory);
        registry
            .register("task1".to_string(), factory.clone())
            .await;
        registry.register("task2".to_string(), factory).await;
        let names = registry.list().await;
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"task1".to_string()));
        assert!(names.contains(&"task2".to_string()));
    }

    #[tokio::test]
    async fn test_registry_clear() {
        let registry = TaskRegistry::new();
        let factory = Arc::new(TestTaskFactory);
        registry.register("task1".to_string(), factory).await;
        assert!(registry.has("task1").await);
        registry.clear().await;
        assert!(!registry.has("task1").await);
    }
}