// async_taskex/task_executor.rs

1use async_std::sync::{Arc, Mutex};
2use async_trait::async_trait;
3use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6
/// Marker trait for the state a `TaskExecutor` shares with all of its handlers.
///
/// The trait has no methods. Its bounds exist because the executor hands the
/// state to handlers as `Arc<S>`, including inside tasks it spawns. So the
/// type must be thread-safe (`Send + Sync`) and must not borrow short-lived
/// data (`'static`).
pub trait SharedData
where
    Self: Send + Sync + 'static,
{
}

9#[derive(Serialize, Deserialize)]
10pub struct TaskMessage {
11    pub task_name: String,
12    pub payload: String,
13}
14
15#[derive(Serialize, Deserialize)]
16pub struct Response {
17    pub success: bool,
18    pub result: Option<String>,
19    pub error: Option<String>,
20}
21
22// definition of the interface for handling tasks.
23#[async_trait]
24pub trait TaskHandler<S: SharedData>: Send + Sync {
25    async fn handle(
26        &self,
27        task_message: TaskMessage,
28        shared_data: Arc<S>,
29    ) -> Result<Response, String>;
30
31    async fn authorize(&self, _shared_data: &Arc<S>) -> bool {
32        true
33    }
34}
35
36// The TaskExecutor struct manages a collection of TaskHandler implementations.
37pub struct TaskExecutor<S: SharedData> {
38    task_handlers: Arc<Mutex<HashMap<String, Arc<dyn TaskHandler<S>>>>>,
39    shared_data: Arc<S>,
40}
41
42impl<S: SharedData> TaskExecutor<S> {
43    pub fn new(shared_data: Arc<S>) -> Self {
44        TaskExecutor {
45            task_handlers: Arc::new(Mutex::new(HashMap::new())),
46            shared_data,
47        }
48    }
49
50    pub async fn register(&self, task_name: String, handler: Arc<dyn TaskHandler<S>>) {
51        self.task_handlers.lock().await.insert(task_name, handler);
52    }
53
54    pub async fn execute(&self, task_message: TaskMessage) -> Result<Response, String> {
55        let task_name = &task_message.task_name;
56        let handlers = self.task_handlers.lock().await;
57
58        let handler = match handlers.get(task_name) {
59            Some(handler) => {
60                if handler.authorize(&self.shared_data).await {
61                    handler.clone()
62                } else {
63                    return Err("Unauthorized".to_string());
64                }
65            }
66            None => return Err("task not found".to_string()),
67        };
68
69        // explicitly drop the MutexGuard.
70        drop(handlers);
71
72        let shared_data = self.shared_data.clone();
73        let handle_futere = async move { handler.handle(task_message, shared_data).await };
74        async_std::task::spawn(handle_futere).await
75    }
76}