forge_runtime/workflow/
registry.rs1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use forge_core::workflow::{ForgeWorkflow, WorkflowContext, WorkflowInfo};
7use serde_json::Value;
8
9fn normalize_args(args: Value) -> Value {
13 let unwrapped = match &args {
14 Value::Object(map) if map.len() == 1 => {
15 if map.contains_key("args") {
16 map.get("args").cloned().unwrap_or(Value::Null)
17 } else if map.contains_key("input") {
18 map.get("input").cloned().unwrap_or(Value::Null)
19 } else {
20 args
21 }
22 }
23 _ => args,
24 };
25
26 match &unwrapped {
27 Value::Null => Value::Object(serde_json::Map::new()),
28 _ => unwrapped,
29 }
30}
31
32pub type BoxedWorkflowHandler = Arc<
34 dyn Fn(
35 &WorkflowContext,
36 serde_json::Value,
37 )
38 -> Pin<Box<dyn Future<Output = forge_core::Result<serde_json::Value>> + Send + '_>>
39 + Send
40 + Sync,
41>;
42
43pub struct WorkflowEntry {
45 pub info: WorkflowInfo,
47 pub handler: BoxedWorkflowHandler,
49}
50
51impl WorkflowEntry {
52 pub fn new<W: ForgeWorkflow>() -> Self
54 where
55 W::Input: serde::de::DeserializeOwned,
56 W::Output: serde::Serialize,
57 {
58 Self {
59 info: W::info(),
60 handler: Arc::new(|ctx, input| {
61 Box::pin(async move {
62 let typed_input: W::Input = serde_json::from_value(normalize_args(input))
63 .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
64 let result = W::execute(ctx, typed_input).await?;
65 serde_json::to_value(result).map_err(forge_core::ForgeError::from)
66 })
67 }),
68 }
69 }
70}
71
/// Composite key identifying one registered workflow version.
///
/// Two keys are equal (and hash equally) only when both the workflow `name`
/// and the `version` string match exactly.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct WorkflowVersionKey {
    /// Workflow name.
    pub name: String,
    /// Version string, compared verbatim.
    pub version: String,
}
78
79#[derive(Default)]
81pub struct WorkflowRegistry {
82 entries: HashMap<WorkflowVersionKey, WorkflowEntry>,
84 active_versions: HashMap<String, String>,
86}
87
88impl WorkflowRegistry {
89 pub fn new() -> Self {
91 Self {
92 entries: HashMap::new(),
93 active_versions: HashMap::new(),
94 }
95 }
96
97 pub fn register<W: ForgeWorkflow>(&mut self)
99 where
100 W::Input: serde::de::DeserializeOwned,
101 W::Output: serde::Serialize,
102 {
103 let entry = WorkflowEntry::new::<W>();
104 let info = &entry.info;
105
106 if info.is_active {
107 self.active_versions
108 .insert(info.name.to_string(), info.version.to_string());
109 }
110
111 let key = WorkflowVersionKey {
112 name: info.name.to_string(),
113 version: info.version.to_string(),
114 };
115 self.entries.insert(key, entry);
116 }
117
118 pub fn get_active(&self, name: &str) -> Option<&WorkflowEntry> {
121 let version = self.active_versions.get(name)?;
122 let key = WorkflowVersionKey {
123 name: name.to_string(),
124 version: version.clone(),
125 };
126 self.entries.get(&key)
127 }
128
129 pub fn get_version(&self, name: &str, version: &str) -> Option<&WorkflowEntry> {
132 let key = WorkflowVersionKey {
133 name: name.to_string(),
134 version: version.to_string(),
135 };
136 self.entries.get(&key)
137 }
138
139 pub fn get(&self, name: &str) -> Option<&WorkflowEntry> {
142 self.get_active(name)
143 }
144
145 pub fn has_version_with_signature(&self, name: &str, version: &str, signature: &str) -> bool {
147 self.get_version(name, version)
148 .is_some_and(|entry| entry.info.signature == signature)
149 }
150
151 pub fn validate_resume(
154 &self,
155 name: &str,
156 version: &str,
157 signature: &str,
158 ) -> Result<&WorkflowEntry, ResumeBlockReason> {
159 let has_any = self.entries.keys().any(|k| k.name == name);
161 if !has_any {
162 return Err(ResumeBlockReason::MissingHandler);
163 }
164
165 let entry = self
166 .get_version(name, version)
167 .ok_or(ResumeBlockReason::MissingVersion)?;
168
169 if entry.info.signature != signature {
170 return Err(ResumeBlockReason::SignatureMismatch {
171 expected: signature.to_string(),
172 actual: entry.info.signature.to_string(),
173 });
174 }
175
176 Ok(entry)
177 }
178
179 pub fn list(&self) -> impl Iterator<Item = &WorkflowEntry> {
181 self.entries.values()
182 }
183
184 pub fn len(&self) -> usize {
186 self.entries.len()
187 }
188
189 pub fn is_empty(&self) -> bool {
191 self.entries.is_empty()
192 }
193
194 pub fn names(&self) -> impl Iterator<Item = &str> {
196 self.active_versions.keys().map(|s| s.as_str())
197 }
198
199 pub fn definitions(&self) -> Vec<&WorkflowInfo> {
201 self.entries.values().map(|e| &e.info).collect()
202 }
203}
204
/// Why a workflow run cannot be resumed by the handlers in the current binary.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ResumeBlockReason {
    /// No version of the workflow is registered at all.
    MissingHandler,
    /// The workflow is registered, but not at the requested version.
    MissingVersion,
    /// The requested version is registered, but with a different signature.
    SignatureMismatch {
        /// Signature the run expects.
        expected: String,
        /// Signature of the handler registered in this binary.
        actual: String,
    },
}
215
216impl ResumeBlockReason {
217 pub fn to_status(&self) -> forge_core::workflow::WorkflowStatus {
219 match self {
220 Self::MissingHandler => forge_core::workflow::WorkflowStatus::BlockedMissingHandler,
221 Self::MissingVersion => forge_core::workflow::WorkflowStatus::BlockedMissingVersion,
222 Self::SignatureMismatch { .. } => {
223 forge_core::workflow::WorkflowStatus::BlockedSignatureMismatch
224 }
225 }
226 }
227
228 pub fn description(&self) -> String {
230 match self {
231 Self::MissingHandler => "No handler registered for this workflow".to_string(),
232 Self::MissingVersion => "Workflow version not present in current binary".to_string(),
233 Self::SignatureMismatch { expected, actual } => {
234 format!("Signature mismatch: run expects {expected}, binary has {actual}")
235 }
236 }
237 }
238}
239
240impl Clone for WorkflowRegistry {
241 fn clone(&self) -> Self {
242 Self {
243 entries: self
244 .entries
245 .iter()
246 .map(|(k, v)| {
247 (
248 k.clone(),
249 WorkflowEntry {
250 info: v.info.clone(),
251 handler: v.handler.clone(),
252 },
253 )
254 })
255 .collect(),
256 active_versions: self.active_versions.clone(),
257 }
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use forge_core::workflow::WorkflowStatus;

    #[test]
    fn test_empty_registry() {
        let registry = WorkflowRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert_eq!(registry.names().count(), 0);
        assert!(registry.definitions().is_empty());
        assert!(registry.get("anything").is_none());
        assert!(!registry.has_version_with_signature("anything", "v1", "sig"));
    }

    #[test]
    fn test_empty_registry_blocks_resume_with_missing_handler() {
        let registry = WorkflowRegistry::new();
        assert_eq!(
            registry.validate_resume("anything", "v1", "sig").err(),
            Some(ResumeBlockReason::MissingHandler)
        );
    }

    #[test]
    fn test_resume_block_reasons() {
        assert_eq!(
            ResumeBlockReason::MissingHandler.to_status(),
            WorkflowStatus::BlockedMissingHandler
        );
        assert_eq!(
            ResumeBlockReason::MissingVersion.to_status(),
            WorkflowStatus::BlockedMissingVersion
        );

        let reason = ResumeBlockReason::SignatureMismatch {
            expected: "abc".to_string(),
            actual: "def".to_string(),
        };
        assert_eq!(reason.to_status(), WorkflowStatus::BlockedSignatureMismatch);
        assert!(reason.description().contains("abc"));
        assert!(reason.description().contains("def"));
    }
}