celers_core/
executor.rs

1#![allow(clippy::missing_errors_doc)]
2use crate::{Result, SerializedTask, Task};
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6
7/// Type-erased task handler
8type TaskHandler = Arc<
9    dyn Fn(Vec<u8>) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<u8>>> + Send>>
10        + Send
11        + Sync,
12>;
13
14/// Registry for mapping task names to their implementations
15pub struct TaskRegistry {
16    handlers: Arc<RwLock<HashMap<String, TaskHandler>>>,
17}
18
19impl TaskRegistry {
20    #[inline]
21    #[must_use]
22    pub fn new() -> Self {
23        Self {
24            handlers: Arc::new(RwLock::new(HashMap::new())),
25        }
26    }
27
28    /// Register a task type with the registry
29    pub async fn register<T>(&self, task: T)
30    where
31        T: Task + 'static,
32    {
33        let task_name = task.name().to_string();
34        let task = Arc::new(task);
35
36        let handler: TaskHandler = Arc::new(move |payload: Vec<u8>| {
37            let task = Arc::clone(&task);
38            Box::pin(async move {
39                // Deserialize input
40                let input: T::Input = serde_json::from_slice(&payload)
41                    .map_err(|e| crate::CelersError::Deserialization(e.to_string()))?;
42
43                // Execute task
44                let output = task.execute(input).await?;
45
46                // Serialize output
47                let output_bytes = serde_json::to_vec(&output)
48                    .map_err(|e| crate::CelersError::Serialization(e.to_string()))?;
49
50                Ok(output_bytes)
51            })
52        });
53
54        self.handlers.write().await.insert(task_name, handler);
55    }
56
57    /// Execute a task by name
58    pub async fn execute(&self, task: &SerializedTask) -> Result<Vec<u8>> {
59        let handlers = self.handlers.read().await;
60
61        let handler = handlers.get(&task.metadata.name).ok_or_else(|| {
62            crate::CelersError::TaskExecution(format!(
63                "Task not found in registry: {}",
64                task.metadata.name
65            ))
66        })?;
67
68        handler(task.payload.clone()).await
69    }
70
71    /// Check if a task is registered
72    pub async fn has_task(&self, name: &str) -> bool {
73        self.handlers.read().await.contains_key(name)
74    }
75
76    /// List all registered task names
77    pub async fn list_tasks(&self) -> Vec<String> {
78        self.handlers.read().await.keys().cloned().collect()
79    }
80}
81
82impl Default for TaskRegistry {
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Task;
    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize)]
    struct AddInput {
        a: i32,
        b: i32,
    }

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct AddOutput {
        result: i32,
    }

    struct AddTask;

    #[async_trait::async_trait]
    impl Task for AddTask {
        type Input = AddInput;
        type Output = AddOutput;

        async fn execute(&self, input: Self::Input) -> Result<Self::Output> {
            Ok(AddOutput {
                result: input.a + input.b,
            })
        }

        fn name(&self) -> &'static str {
            "add"
        }
    }

    /// Build a `SerializedTask` named `name` that carries `payload` verbatim.
    fn serialized(name: &str, payload: Vec<u8>) -> SerializedTask {
        SerializedTask {
            metadata: crate::TaskMetadata::new(name.to_string()),
            payload,
        }
    }

    #[tokio::test]
    async fn test_registry() {
        let registry = TaskRegistry::new();

        // Register task
        registry.register(AddTask).await;

        // Check registration
        assert!(registry.has_task("add").await);

        // Create task
        let task = serialized(
            "add",
            serde_json::to_vec(&AddInput { a: 2, b: 3 }).unwrap(),
        );

        // Execute
        let result = registry.execute(&task).await.unwrap();
        let output: AddOutput = serde_json::from_slice(&result).unwrap();

        assert_eq!(output.result, 5);
    }

    #[tokio::test]
    async fn test_unknown_task_is_execution_error() {
        let registry = TaskRegistry::new();
        registry.register(AddTask).await;

        assert!(!registry.has_task("missing").await);

        let task = serialized("missing", b"{}".to_vec());
        let err = registry.execute(&task).await.unwrap_err();
        assert!(matches!(err, crate::CelersError::TaskExecution(_)));
    }

    #[tokio::test]
    async fn test_malformed_payload_is_deserialization_error() {
        let registry = TaskRegistry::new();
        registry.register(AddTask).await;

        let task = serialized("add", b"not json".to_vec());
        let err = registry.execute(&task).await.unwrap_err();
        assert!(matches!(err, crate::CelersError::Deserialization(_)));
    }

    #[tokio::test]
    async fn test_list_tasks() {
        let registry = TaskRegistry::default();
        assert!(registry.list_tasks().await.is_empty());

        registry.register(AddTask).await;
        assert_eq!(registry.list_tasks().await, vec!["add".to_string()]);
    }
}