Skip to main content

forge_runtime/workflow/
registry.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use forge_core::workflow::{ForgeWorkflow, WorkflowContext, WorkflowInfo};
7use serde_json::Value;
8
9/// Normalize args for deserialization.
10/// - Converts `null` to `{}` so both unit `()` and empty structs deserialize correctly.
11/// - Unwraps `{"args": ...}` or `{"input": ...}` wrapper if present (callers may use either format).
12fn normalize_args(args: Value) -> Value {
13    let unwrapped = match &args {
14        Value::Object(map) if map.len() == 1 => {
15            if map.contains_key("args") {
16                map.get("args").cloned().unwrap_or(Value::Null)
17            } else if map.contains_key("input") {
18                map.get("input").cloned().unwrap_or(Value::Null)
19            } else {
20                args
21            }
22        }
23        _ => args,
24    };
25
26    match &unwrapped {
27        Value::Null => Value::Object(serde_json::Map::new()),
28        _ => unwrapped,
29    }
30}
31
32/// Type alias for boxed workflow handler function.
33pub type BoxedWorkflowHandler = Arc<
34    dyn Fn(
35            &WorkflowContext,
36            serde_json::Value,
37        )
38            -> Pin<Box<dyn Future<Output = forge_core::Result<serde_json::Value>> + Send + '_>>
39        + Send
40        + Sync,
41>;
42
43/// A registered workflow entry.
44pub struct WorkflowEntry {
45    /// Workflow metadata.
46    pub info: WorkflowInfo,
47    /// Execution handler (takes serialized input, returns serialized output).
48    pub handler: BoxedWorkflowHandler,
49}
50
51impl WorkflowEntry {
52    /// Create a new workflow entry from a ForgeWorkflow implementor.
53    pub fn new<W: ForgeWorkflow>() -> Self
54    where
55        W::Input: serde::de::DeserializeOwned,
56        W::Output: serde::Serialize,
57    {
58        Self {
59            info: W::info(),
60            handler: Arc::new(|ctx, input| {
61                Box::pin(async move {
62                    let typed_input: W::Input = serde_json::from_value(normalize_args(input))
63                        .map_err(|e| forge_core::ForgeError::Validation(e.to_string()))?;
64                    let result = W::execute(ctx, typed_input).await?;
65                    serde_json::to_value(result).map_err(forge_core::ForgeError::from)
66                })
67            }),
68        }
69    }
70}
71
/// Lookup key identifying one specific version of a workflow.
///
/// Two keys are equal exactly when both `name` and `version` match. This lets the registry
/// hold several versions of the same workflow side by side.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct WorkflowVersionKey {
    /// Workflow name.
    pub name: String,
    /// Version string of that workflow.
    pub version: String,
}
78
79/// Registry of all workflows, supporting multiple versions per workflow name.
80#[derive(Default)]
81pub struct WorkflowRegistry {
82    /// All entries keyed by (name, version).
83    entries: HashMap<WorkflowVersionKey, WorkflowEntry>,
84    /// Maps workflow name to its active version string.
85    active_versions: HashMap<String, String>,
86}
87
88impl WorkflowRegistry {
89    /// Create a new empty registry.
90    pub fn new() -> Self {
91        Self {
92            entries: HashMap::new(),
93            active_versions: HashMap::new(),
94        }
95    }
96
97    /// Register a workflow handler.
98    pub fn register<W: ForgeWorkflow>(&mut self)
99    where
100        W::Input: serde::de::DeserializeOwned,
101        W::Output: serde::Serialize,
102    {
103        let entry = WorkflowEntry::new::<W>();
104        let info = &entry.info;
105
106        if info.is_active {
107            self.active_versions
108                .insert(info.name.to_string(), info.version.to_string());
109        }
110
111        let key = WorkflowVersionKey {
112            name: info.name.to_string(),
113            version: info.version.to_string(),
114        };
115        self.entries.insert(key, entry);
116    }
117
118    /// Get the active version entry for a workflow by name.
119    /// Used when starting new runs.
120    pub fn get_active(&self, name: &str) -> Option<&WorkflowEntry> {
121        let version = self.active_versions.get(name)?;
122        let key = WorkflowVersionKey {
123            name: name.to_string(),
124            version: version.clone(),
125        };
126        self.entries.get(&key)
127    }
128
129    /// Get a specific workflow version.
130    /// Used when resuming runs pinned to a specific version.
131    pub fn get_version(&self, name: &str, version: &str) -> Option<&WorkflowEntry> {
132        let key = WorkflowVersionKey {
133            name: name.to_string(),
134            version: version.to_string(),
135        };
136        self.entries.get(&key)
137    }
138
139    /// Get a workflow entry by name (returns the active version).
140    /// Backward-compatible with code that only knows the workflow name.
141    pub fn get(&self, name: &str) -> Option<&WorkflowEntry> {
142        self.get_active(name)
143    }
144
145    /// Check if a specific version+signature combination is available.
146    pub fn has_version_with_signature(&self, name: &str, version: &str, signature: &str) -> bool {
147        self.get_version(name, version)
148            .is_some_and(|entry| entry.info.signature == signature)
149    }
150
151    /// Validate that a run can be safely resumed.
152    /// Returns the matching entry, or a blocking reason.
153    pub fn validate_resume(
154        &self,
155        name: &str,
156        version: &str,
157        signature: &str,
158    ) -> Result<&WorkflowEntry, ResumeBlockReason> {
159        // Check if any version of this workflow is registered
160        let has_any = self.entries.keys().any(|k| k.name == name);
161        if !has_any {
162            return Err(ResumeBlockReason::MissingHandler);
163        }
164
165        let entry = self
166            .get_version(name, version)
167            .ok_or(ResumeBlockReason::MissingVersion)?;
168
169        if entry.info.signature != signature {
170            return Err(ResumeBlockReason::SignatureMismatch {
171                expected: signature.to_string(),
172                actual: entry.info.signature.to_string(),
173            });
174        }
175
176        Ok(entry)
177    }
178
179    /// List all registered workflow entries.
180    pub fn list(&self) -> impl Iterator<Item = &WorkflowEntry> {
181        self.entries.values()
182    }
183
184    /// Get the number of registered workflow entries (all versions).
185    pub fn len(&self) -> usize {
186        self.entries.len()
187    }
188
189    /// Check if the registry is empty.
190    pub fn is_empty(&self) -> bool {
191        self.entries.is_empty()
192    }
193
194    /// Get all workflow names (deduplicated).
195    pub fn names(&self) -> impl Iterator<Item = &str> {
196        self.active_versions.keys().map(|s| s.as_str())
197    }
198
199    /// Get all registered definitions for startup persistence.
200    pub fn definitions(&self) -> Vec<&WorkflowInfo> {
201        self.entries.values().map(|e| &e.info).collect()
202    }
203}
204
/// Why a persisted workflow run cannot be resumed by the current binary.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ResumeBlockReason {
    /// No version of the workflow's name is registered at all.
    MissingHandler,
    /// The workflow name is known, but the run's pinned version is absent from this binary.
    MissingVersion,
    /// The pinned version exists, but its signature differs from the one the run recorded.
    SignatureMismatch { expected: String, actual: String },
}
215
216impl ResumeBlockReason {
217    /// Convert to the corresponding WorkflowStatus.
218    pub fn to_status(&self) -> forge_core::workflow::WorkflowStatus {
219        match self {
220            Self::MissingHandler => forge_core::workflow::WorkflowStatus::BlockedMissingHandler,
221            Self::MissingVersion => forge_core::workflow::WorkflowStatus::BlockedMissingVersion,
222            Self::SignatureMismatch { .. } => {
223                forge_core::workflow::WorkflowStatus::BlockedSignatureMismatch
224            }
225        }
226    }
227
228    /// Human-readable description for the blocking_reason column.
229    pub fn description(&self) -> String {
230        match self {
231            Self::MissingHandler => "No handler registered for this workflow".to_string(),
232            Self::MissingVersion => "Workflow version not present in current binary".to_string(),
233            Self::SignatureMismatch { expected, actual } => {
234                format!("Signature mismatch: run expects {expected}, binary has {actual}")
235            }
236        }
237    }
238}
239
240impl Clone for WorkflowRegistry {
241    fn clone(&self) -> Self {
242        Self {
243            entries: self
244                .entries
245                .iter()
246                .map(|(k, v)| {
247                    (
248                        k.clone(),
249                        WorkflowEntry {
250                            info: v.info.clone(),
251                            handler: v.handler.clone(),
252                        },
253                    )
254                })
255                .collect(),
256            active_versions: self.active_versions.clone(),
257        }
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use forge_core::workflow::WorkflowStatus;

    #[test]
    fn test_empty_registry() {
        let registry = WorkflowRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert_eq!(registry.list().count(), 0);
        assert!(registry.names().next().is_none());
        assert!(registry.definitions().is_empty());
    }

    #[test]
    fn test_default_and_clone_are_empty() {
        let registry = WorkflowRegistry::default();
        assert!(registry.is_empty());
        assert!(registry.clone().is_empty());
    }

    #[test]
    fn test_lookups_on_empty_registry() {
        let registry = WorkflowRegistry::new();
        assert!(registry.get("wf").is_none());
        assert!(registry.get_active("wf").is_none());
        assert!(registry.get_version("wf", "v1").is_none());
        assert!(!registry.has_version_with_signature("wf", "v1", "sig"));
        // With nothing registered under the name, the handler itself is missing.
        assert!(matches!(
            registry.validate_resume("wf", "v1", "sig"),
            Err(ResumeBlockReason::MissingHandler)
        ));
    }

    #[test]
    fn test_resume_block_reasons() {
        let reason = ResumeBlockReason::MissingHandler;
        assert_eq!(reason.to_status(), WorkflowStatus::BlockedMissingHandler);
        assert!(reason.description().contains("handler"));

        let reason = ResumeBlockReason::MissingVersion;
        assert_eq!(reason.to_status(), WorkflowStatus::BlockedMissingVersion);
        assert!(reason.description().contains("version"));

        let reason = ResumeBlockReason::SignatureMismatch {
            expected: "abc".to_string(),
            actual: "def".to_string(),
        };
        assert_eq!(reason.to_status(), WorkflowStatus::BlockedSignatureMismatch);
        assert!(reason.description().contains("abc"));
        assert!(reason.description().contains("def"));
    }
}