forge_runtime/workflow/
registry.rs1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use forge_core::workflow::{ForgeWorkflow, WorkflowContext, WorkflowInfo};
7use serde_json::Value;
8
9fn normalize_args(args: Value) -> Value {
13 let unwrapped = match &args {
14 Value::Object(map) if map.len() == 1 => {
15 if map.contains_key("args") {
16 map.get("args").cloned().unwrap_or(Value::Null)
17 } else if map.contains_key("input") {
18 map.get("input").cloned().unwrap_or(Value::Null)
19 } else {
20 args
21 }
22 }
23 _ => args,
24 };
25
26 match &unwrapped {
27 Value::Object(map) if map.is_empty() => Value::Null,
28 _ => unwrapped,
29 }
30}
31
32pub type BoxedWorkflowHandler = Arc<
34 dyn Fn(
35 &WorkflowContext,
36 serde_json::Value,
37 )
38 -> Pin<Box<dyn Future<Output = forge_core::Result<serde_json::Value>> + Send + '_>>
39 + Send
40 + Sync,
41>;
42
43pub struct WorkflowEntry {
45 pub info: WorkflowInfo,
47 pub handler: BoxedWorkflowHandler,
49}
50
51impl WorkflowEntry {
52 pub fn new<W: ForgeWorkflow>() -> Self
54 where
55 W::Input: serde::de::DeserializeOwned,
56 W::Output: serde::Serialize,
57 {
58 Self {
59 info: W::info(),
60 handler: Arc::new(|ctx, input| {
61 Box::pin(async move {
62 let typed_input: W::Input = serde_json::from_value(normalize_args(input))
63 .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
64 let result = W::execute(ctx, typed_input).await?;
65 serde_json::to_value(result).map_err(forge_core::ForgeError::from)
66 })
67 }),
68 }
69 }
70}
71
72#[derive(Default)]
74pub struct WorkflowRegistry {
75 workflows: HashMap<String, WorkflowEntry>,
76}
77
78impl WorkflowRegistry {
79 pub fn new() -> Self {
81 Self {
82 workflows: HashMap::new(),
83 }
84 }
85
86 pub fn register<W: ForgeWorkflow>(&mut self)
88 where
89 W::Input: serde::de::DeserializeOwned,
90 W::Output: serde::Serialize,
91 {
92 let entry = WorkflowEntry::new::<W>();
93 self.workflows.insert(entry.info.name.to_string(), entry);
94 }
95
96 pub fn get(&self, name: &str) -> Option<&WorkflowEntry> {
98 self.workflows.get(name)
99 }
100
101 pub fn get_version(&self, name: &str, version: u32) -> Option<&WorkflowEntry> {
103 self.workflows
104 .get(name)
105 .filter(|e| e.info.version == version)
106 }
107
108 pub fn list(&self) -> Vec<&WorkflowEntry> {
110 self.workflows.values().collect()
111 }
112
113 pub fn len(&self) -> usize {
115 self.workflows.len()
116 }
117
118 pub fn is_empty(&self) -> bool {
120 self.workflows.is_empty()
121 }
122
123 pub fn names(&self) -> Vec<&str> {
125 self.workflows.keys().map(|s| s.as_str()).collect()
126 }
127}
128
129impl Clone for WorkflowRegistry {
130 fn clone(&self) -> Self {
131 Self {
132 workflows: self
133 .workflows
134 .iter()
135 .map(|(k, v)| {
136 (
137 k.clone(),
138 WorkflowEntry {
139 info: v.info.clone(),
140 handler: v.handler.clone(),
141 },
142 )
143 })
144 .collect(),
145 }
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// A registry created with `new` reports zero workflows.
    #[test]
    fn test_empty_registry() {
        let empty = WorkflowRegistry::new();

        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());
    }
}
159}