use crate::error::WorkflowResult;
use serde_json::Value;
use std::sync::Arc;
/// Handler for one kind of call, selected by the string returned from
/// [`CallHandler::call_type`].
///
/// Implementations are stored as `Arc<dyn CallHandler>` in [`HandlerRegistry`],
/// which is why the trait requires `Send + Sync`.
#[async_trait::async_trait]
pub trait CallHandler: Send + Sync {
    /// Key under which this handler is registered and later looked up.
    fn call_type(&self) -> &str;

    /// Handles a call on behalf of the task named `task_name`.
    ///
    /// `call_config` and `input` are JSON values supplied by the caller
    /// (presumably the call's configuration block and the task input — confirm
    /// against the dispatch site). Returns the resulting JSON value, or an error
    /// wrapped in `WorkflowResult`.
    async fn handle(
        &self,
        task_name: &str,
        call_config: &Value,
        input: &Value,
    ) -> WorkflowResult<Value>;
}
/// Handler for one kind of run, selected by the string returned from
/// [`RunHandler::run_type`].
///
/// Implementations are stored as `Arc<dyn RunHandler>` in [`HandlerRegistry`],
/// which is why the trait requires `Send + Sync`.
#[async_trait::async_trait]
pub trait RunHandler: Send + Sync {
    /// Key under which this handler is registered and later looked up.
    fn run_type(&self) -> &str;

    /// Handles a run on behalf of the task named `task_name`.
    ///
    /// `run_config` and `input` are JSON values supplied by the caller
    /// (presumably the run's configuration block and the task input — confirm
    /// against the dispatch site). Returns the resulting JSON value, or an error
    /// wrapped in `WorkflowResult`.
    async fn handle(
        &self,
        task_name: &str,
        run_config: &Value,
        input: &Value,
    ) -> WorkflowResult<Value>;
}
/// Handler for a custom task type, selected by the string returned from
/// [`CustomTaskHandler::task_type`].
///
/// Implementations are stored as `Arc<dyn CustomTaskHandler>` in
/// [`HandlerRegistry`], which is why the trait requires `Send + Sync`.
#[async_trait::async_trait]
pub trait CustomTaskHandler: Send + Sync {
    /// Key under which this handler is registered and later looked up.
    fn task_type(&self) -> &str;

    /// Handles the custom task named `task_name`.
    ///
    /// `task_type` is passed in by the caller. It is presumably the same key
    /// this handler was looked up by, which lets one implementation serve
    /// several aliases — NOTE(review): confirm at the call site. `task_config`
    /// and `input` are JSON values supplied by the caller. Returns the
    /// resulting JSON value, or an error wrapped in `WorkflowResult`.
    async fn handle(
        &self,
        task_name: &str,
        task_type: &str,
        task_config: &Value,
        input: &Value,
    ) -> WorkflowResult<Value>;
}
#[derive(Default, Clone)]
pub struct HandlerRegistry {
call_handlers:
std::sync::Arc<std::collections::HashMap<String, std::sync::Arc<dyn CallHandler>>>,
run_handlers: std::sync::Arc<std::collections::HashMap<String, std::sync::Arc<dyn RunHandler>>>,
custom_task_handlers:
std::sync::Arc<std::collections::HashMap<String, std::sync::Arc<dyn CustomTaskHandler>>>,
}
impl HandlerRegistry {
    /// Creates a registry with no handlers of any kind.
    pub fn new() -> Self {
        Default::default()
    }

    /// Registers `handler` under the key returned by its `call_type()`.
    ///
    /// Any handler already registered under that key is replaced. If the
    /// call-handler map is shared with a clone of this registry, the map is
    /// copied first (`Arc::make_mut`), so the clone is left untouched.
    pub fn register_call_handler(&mut self, handler: Box<dyn CallHandler>) {
        let shared: Arc<dyn CallHandler> = Arc::from(handler);
        let call_type = shared.call_type().to_owned();
        Arc::make_mut(&mut self.call_handlers).insert(call_type, shared);
    }

    /// Registers `handler` under the key returned by its `run_type()`.
    ///
    /// Any handler already registered under that key is replaced. If the
    /// run-handler map is shared with a clone of this registry, the map is
    /// copied first (`Arc::make_mut`), so the clone is left untouched.
    pub fn register_run_handler(&mut self, handler: Box<dyn RunHandler>) {
        let shared: Arc<dyn RunHandler> = Arc::from(handler);
        let run_type = shared.run_type().to_owned();
        Arc::make_mut(&mut self.run_handlers).insert(run_type, shared);
    }

    /// Registers `handler` under the key returned by its `task_type()`.
    ///
    /// Any handler already registered under that key is replaced. If the
    /// custom-task map is shared with a clone of this registry, the map is
    /// copied first (`Arc::make_mut`), so the clone is left untouched.
    pub fn register_custom_task_handler(&mut self, handler: Box<dyn CustomTaskHandler>) {
        let shared: Arc<dyn CustomTaskHandler> = Arc::from(handler);
        let task_type = shared.task_type().to_owned();
        Arc::make_mut(&mut self.custom_task_handlers).insert(task_type, shared);
    }

    /// Returns a shared handle to the call handler registered as `call_type`,
    /// or `None` if there is none.
    pub fn get_call_handler(&self, call_type: &str) -> Option<Arc<dyn CallHandler>> {
        self.call_handlers.get(call_type).map(Arc::clone)
    }

    /// Returns a shared handle to the run handler registered as `run_type`,
    /// or `None` if there is none.
    pub fn get_run_handler(&self, run_type: &str) -> Option<Arc<dyn RunHandler>> {
        self.run_handlers.get(run_type).map(Arc::clone)
    }

    /// Returns a shared handle to the custom task handler registered as
    /// `task_type`, or `None` if there is none.
    pub fn get_custom_task_handler(&self, task_type: &str) -> Option<Arc<dyn CustomTaskHandler>> {
        self.custom_task_handlers.get(task_type).map(Arc::clone)
    }
}