use crate::error::{WorkflowError, WorkflowResult};
use serde::de::DeserializeOwned;
use serde_json::Value;
use std::sync::Arc;
/// Read-only bundle of JSON values passed by reference to every handler's `handle` method.
///
/// Instances are built by `HandlerContext::from_vars`, which copies the entries stored
/// under the `crate::context::vars` keys out of a variable map. The fields are private
/// and only exposed through borrowing accessors.
#[derive(Debug, Clone)]
pub struct HandlerContext {
/// Value found under `crate::context::vars::CONTEXT`; `Value::Null` when the key was absent.
context: Value,
/// Value found under `crate::context::vars::SECRET`; `None` when the key was absent.
secret: Option<Value>,
/// Value found under `crate::context::vars::WORKFLOW`; `Value::Null` when the key was absent.
workflow: Value,
/// Value found under `crate::context::vars::AUTHORIZATION`; `None` when the key was absent.
authorization: Option<Value>,
}
impl HandlerContext {
    /// Builds a handler context from a variable map.
    ///
    /// Each of the four `crate::context::vars` keys is looked up in `vars` and the
    /// value, if present, is cloned into the new context. A missing `CONTEXT` or
    /// `WORKFLOW` entry becomes [`Value::Null`]; a missing `SECRET` or
    /// `AUTHORIZATION` entry becomes `None`.
    pub(crate) fn from_vars(vars: &std::collections::HashMap<String, Value>) -> Self {
        // Mandatory slots: substitute JSON null when the key is not in the map.
        let context = vars
            .get(crate::context::vars::CONTEXT)
            .map_or(Value::Null, Value::clone);
        let workflow = vars
            .get(crate::context::vars::WORKFLOW)
            .map_or(Value::Null, Value::clone);

        // Optional slots: absence is preserved as `None`.
        let secret = vars.get(crate::context::vars::SECRET).map(Value::clone);
        let authorization = vars
            .get(crate::context::vars::AUTHORIZATION)
            .map(Value::clone);

        Self {
            context,
            secret,
            workflow,
            authorization,
        }
    }

    /// Borrows the context value (`Value::Null` if none was supplied).
    pub fn context(&self) -> &Value {
        &self.context
    }

    /// Borrows the secret value, if one was supplied.
    pub fn secret(&self) -> Option<&Value> {
        self.secret.as_ref()
    }

    /// Borrows the workflow value (`Value::Null` if none was supplied).
    pub fn workflow(&self) -> &Value {
        &self.workflow
    }

    /// Borrows the authorization value, if one was supplied.
    pub fn authorization(&self) -> Option<&Value> {
        self.authorization.as_ref()
    }
}
/// Handler for one kind of call task.
///
/// Implementations must be `Send + Sync`. `HandlerRegistry::register_call_handler`
/// stores a handler under the string returned by `call_type`.
#[async_trait::async_trait]
pub trait CallHandler: Send + Sync {
/// Returns the call type string that identifies this handler in the registry.
fn call_type(&self) -> &str;
/// Runs the handler for the task named `task_name`.
///
/// Returns the resulting JSON value, or a workflow error on failure.
/// NOTE(review): `call_config` and `input` are presumably the task's call
/// definition and its input data — confirm against the caller.
async fn handle(
&self,
task_name: &str,
call_config: &Value,
input: &Value,
context: &HandlerContext,
) -> WorkflowResult<Value>;
}
/// Handler for one kind of run task.
///
/// Implementations must be `Send + Sync`. `HandlerRegistry::register_run_handler`
/// stores a handler under the string returned by `run_type`.
#[async_trait::async_trait]
pub trait RunHandler: Send + Sync {
/// Returns the run type string that identifies this handler in the registry.
fn run_type(&self) -> &str;
/// Runs the handler for the task named `task_name`.
///
/// Returns the resulting JSON value, or a workflow error on failure.
/// NOTE(review): `run_config` and `input` are presumably the task's run
/// definition and its input data — confirm against the caller.
async fn handle(
&self,
task_name: &str,
run_config: &Value,
input: &Value,
context: &HandlerContext,
) -> WorkflowResult<Value>;
}
/// Handler for a custom task type, working on untyped JSON configuration.
///
/// Implementations must be `Send + Sync`.
/// `HandlerRegistry::register_custom_task_handler` stores a handler under the
/// string returned by `task_type`.
#[async_trait::async_trait]
pub trait CustomTaskHandler: Send + Sync {
/// Returns the task type string that identifies this handler in the registry.
fn task_type(&self) -> &str;
/// Runs the handler for the task named `task_name`.
///
/// `task_type` is supplied by the caller; the registry's adapter passes this
/// handler's own `task_type()` value. Returns the resulting JSON value, or a
/// workflow error on failure.
/// NOTE(review): `task_config` and `input` are presumably the task's
/// definition and its input data — confirm against the caller.
async fn handle(
&self,
task_name: &str,
task_type: &str,
task_config: &Value,
input: &Value,
context: &HandlerContext,
) -> WorkflowResult<Value>;
}
/// Custom task handler whose configuration is a concrete Rust type.
///
/// Implementors pick a `Config` type and receive it already deserialized in
/// `handle`. Calling [`into_boxed`](TypedCustomTaskHandler::into_boxed) converts
/// the handler into a boxed [`CustomTaskHandler`], which is what the registry
/// accepts; the conversion layer turns the raw JSON config into `Config` and
/// reports a validation error if that fails.
#[async_trait::async_trait]
pub trait TypedCustomTaskHandler: Send + Sync + 'static {
    /// Configuration type the raw JSON task config is deserialized into.
    type Config: DeserializeOwned + Send + Sync + 'static;

    /// Returns the task type string that identifies this handler.
    fn task_type(&self) -> &str;

    /// Runs the handler for the task named `task_name` with an already-typed
    /// `config`, returning the resulting JSON value or a workflow error.
    async fn handle(
        &self,
        task_name: &str,
        config: &Self::Config,
        input: &Value,
        context: &HandlerContext,
    ) -> WorkflowResult<Value>;

    /// Consumes the handler and wraps it as an untyped [`CustomTaskHandler`].
    fn into_boxed(self) -> Box<dyn CustomTaskHandler>
    where
        Self: Sized,
    {
        let wrapper = TypedHandlerWrapper(self);
        Box::new(wrapper)
    }
}
struct TypedHandlerWrapper<H: TypedCustomTaskHandler>(H);
#[async_trait::async_trait]
impl<H: TypedCustomTaskHandler> CustomTaskHandler for TypedHandlerWrapper<H> {
fn task_type(&self) -> &str {
self.0.task_type()
}
async fn handle(
&self,
task_name: &str,
_task_type: &str,
task_config: &Value,
input: &Value,
context: &HandlerContext,
) -> WorkflowResult<Value> {
let config: H::Config = serde_json::from_value(task_config.clone()).map_err(|e| {
WorkflowError::validation(
format!("invalid config for '{}' task: {}", self.0.task_type(), e),
task_name,
)
})?;
self.0.handle(task_name, &config, input, context).await
}
}
/// Unified handler interface stored in `HandlerRegistry`'s general handler map.
///
/// Handlers can implement this directly and be added with
/// `HandlerRegistry::register_handler`; call, run and custom task handlers are
/// wrapped in adapters implementing this trait when they are registered.
#[async_trait::async_trait]
pub trait TaskHandler: Send + Sync {
/// Returns the type string that identifies this handler in the registry.
fn handler_type(&self) -> &str;
/// Runs the handler for the task named `task_name`.
///
/// Returns the resulting JSON value, or a workflow error on failure.
/// NOTE(review): `config` and `input` are presumably the task's definition
/// and its input data — confirm against the caller.
async fn handle(
&self,
task_name: &str,
config: &Value,
input: &Value,
context: &HandlerContext,
) -> WorkflowResult<Value>;
}
struct CallHandlerAdapter(std::sync::Arc<dyn CallHandler>);
#[async_trait::async_trait]
impl TaskHandler for CallHandlerAdapter {
fn handler_type(&self) -> &str {
self.0.call_type()
}
async fn handle(
&self,
task_name: &str,
config: &Value,
input: &Value,
context: &HandlerContext,
) -> WorkflowResult<Value> {
self.0.handle(task_name, config, input, context).await
}
}
struct RunHandlerAdapter(std::sync::Arc<dyn RunHandler>);
#[async_trait::async_trait]
impl TaskHandler for RunHandlerAdapter {
fn handler_type(&self) -> &str {
self.0.run_type()
}
async fn handle(
&self,
task_name: &str,
config: &Value,
input: &Value,
context: &HandlerContext,
) -> WorkflowResult<Value> {
self.0.handle(task_name, config, input, context).await
}
}
struct CustomHandlerAdapter(std::sync::Arc<dyn CustomTaskHandler>);
#[async_trait::async_trait]
impl TaskHandler for CustomHandlerAdapter {
fn handler_type(&self) -> &str {
self.0.task_type()
}
async fn handle(
&self,
task_name: &str,
config: &Value,
input: &Value,
context: &HandlerContext,
) -> WorkflowResult<Value> {
self.0
.handle(task_name, self.0.task_type(), config, input, context)
.await
}
}
#[derive(Default, Clone)]
pub struct HandlerRegistry {
handlers: std::sync::Arc<std::collections::HashMap<String, std::sync::Arc<dyn TaskHandler>>>,
call_handlers:
std::sync::Arc<std::collections::HashMap<String, std::sync::Arc<dyn CallHandler>>>,
run_handlers: std::sync::Arc<std::collections::HashMap<String, std::sync::Arc<dyn RunHandler>>>,
custom_task_handlers:
std::sync::Arc<std::collections::HashMap<String, std::sync::Arc<dyn CustomTaskHandler>>>,
}
impl HandlerRegistry {
    /// Creates a registry with no handlers.
    pub fn new() -> Self {
        Default::default()
    }

    /// Registers a unified task handler under its `handler_type()`.
    ///
    /// An existing entry with the same key is replaced.
    pub fn register_handler(&mut self, handler: Arc<dyn TaskHandler>) {
        let name = handler.handler_type().to_owned();
        let table = Arc::make_mut(&mut self.handlers);
        table.insert(name, handler);
    }

    /// Registers a call handler under its `call_type()`.
    ///
    /// The handler is stored in the call-handler table and, wrapped in an
    /// adapter, in the unified table. Existing entries with the same key are
    /// replaced in both.
    pub fn register_call_handler(&mut self, handler: Box<dyn CallHandler>) {
        let name = handler.call_type().to_owned();
        let shared: Arc<dyn CallHandler> = Arc::from(handler);
        let unified: Arc<dyn TaskHandler> = Arc::new(CallHandlerAdapter(Arc::clone(&shared)));
        Arc::make_mut(&mut self.handlers).insert(name.clone(), unified);
        Arc::make_mut(&mut self.call_handlers).insert(name, shared);
    }

    /// Registers a run handler under its `run_type()`.
    ///
    /// The handler is stored in the run-handler table and, wrapped in an
    /// adapter, in the unified table. Existing entries with the same key are
    /// replaced in both.
    pub fn register_run_handler(&mut self, handler: Box<dyn RunHandler>) {
        let name = handler.run_type().to_owned();
        let shared: Arc<dyn RunHandler> = Arc::from(handler);
        let unified: Arc<dyn TaskHandler> = Arc::new(RunHandlerAdapter(Arc::clone(&shared)));
        Arc::make_mut(&mut self.handlers).insert(name.clone(), unified);
        Arc::make_mut(&mut self.run_handlers).insert(name, shared);
    }

    /// Registers a custom task handler under its `task_type()`.
    ///
    /// The handler is stored in the custom-task table and, wrapped in an
    /// adapter, in the unified table. Existing entries with the same key are
    /// replaced in both.
    pub fn register_custom_task_handler(&mut self, handler: Box<dyn CustomTaskHandler>) {
        let name = handler.task_type().to_owned();
        let shared: Arc<dyn CustomTaskHandler> = Arc::from(handler);
        let unified: Arc<dyn TaskHandler> = Arc::new(CustomHandlerAdapter(Arc::clone(&shared)));
        Arc::make_mut(&mut self.handlers).insert(name.clone(), unified);
        Arc::make_mut(&mut self.custom_task_handlers).insert(name, shared);
    }

    /// Looks up a handler in the unified table.
    pub fn get_handler(&self, handler_type: &str) -> Option<Arc<dyn TaskHandler>> {
        self.handlers.get(handler_type).map(Arc::clone)
    }

    /// Looks up a call handler by call type.
    pub fn get_call_handler(&self, call_type: &str) -> Option<Arc<dyn CallHandler>> {
        self.call_handlers.get(call_type).map(Arc::clone)
    }

    /// Looks up a run handler by run type.
    pub fn get_run_handler(&self, run_type: &str) -> Option<Arc<dyn RunHandler>> {
        self.run_handlers.get(run_type).map(Arc::clone)
    }

    /// Looks up a custom task handler by task type.
    pub fn get_custom_task_handler(&self, task_type: &str) -> Option<Arc<dyn CustomTaskHandler>> {
        self.custom_task_handlers.get(task_type).map(Arc::clone)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use serde_json::json;
    use std::collections::HashMap;

    /// Typed config used by the test handler; `count` defaults to 0 when omitted.
    #[derive(Deserialize)]
    struct TestConfig {
        name: String,
        #[serde(default)]
        count: u32,
    }

    /// Typed handler that echoes its config back as JSON.
    struct TestTypedHandler;

    #[async_trait::async_trait]
    impl TypedCustomTaskHandler for TestTypedHandler {
        type Config = TestConfig;

        fn task_type(&self) -> &str {
            "test_typed"
        }

        async fn handle(
            &self,
            _task_name: &str,
            config: &TestConfig,
            _input: &Value,
            _context: &HandlerContext,
        ) -> WorkflowResult<Value> {
            let TestConfig { name, count } = config;
            Ok(json!({
                "name": name,
                "count": count,
            }))
        }
    }

    /// Context built from an empty variable map.
    fn empty_context() -> HandlerContext {
        HandlerContext::from_vars(&HashMap::new())
    }

    // A valid config is deserialized and reaches the typed handler.
    #[tokio::test]
    async fn test_typed_handler_wrapper() {
        let ctx = empty_context();
        let config = json!({ "name": "hello", "count": 42 });
        let handler = TestTypedHandler.into_boxed();

        let output = handler
            .handle("task1", "test_typed", &config, &json!({}), &ctx)
            .await
            .unwrap();

        assert_eq!(output["name"], "hello");
        assert_eq!(output["count"], 42);
    }

    // A config missing the required `name` field yields a validation error
    // that names the task type.
    #[tokio::test]
    async fn test_typed_handler_invalid_config_returns_validation_error() {
        let ctx = empty_context();
        let bad_config = json!({ "count": 5 });
        let handler = TestTypedHandler.into_boxed();

        let outcome = handler
            .handle("task1", "test_typed", &bad_config, &json!({}), &ctx)
            .await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("invalid config"));
        assert!(message.contains("test_typed"));
    }

    // A boxed typed handler can be registered and retrieved as a custom task
    // handler, and omitted `count` falls back to its default.
    #[tokio::test]
    async fn test_typed_handler_register_in_registry() {
        let mut registry = HandlerRegistry::new();
        registry.register_custom_task_handler(TestTypedHandler.into_boxed());

        let found = registry.get_custom_task_handler("test_typed");
        assert!(found.is_some());

        let ctx = empty_context();
        let config = json!({ "name": "world" });
        let output = found
            .unwrap()
            .handle("task1", "test_typed", &config, &json!({}), &ctx)
            .await
            .unwrap();

        assert_eq!(output["name"], "world");
        assert_eq!(output["count"], 0);
    }
}