// goblin_engine/plan.rs

1use serde::{Deserialize, Serialize};
2use std::collections::{HashMap, HashSet, VecDeque};
3use crate::error::{GoblinError, Result};
4
5/// Represents an input to a step, which can be a literal value or reference to another step's output
6#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
7#[serde(untagged)]
8pub enum StepInput {
9    /// A literal string value
10    Literal(String),
11    /// A reference to another step's output
12    StepReference { step: String },
13    /// A formatted string with placeholders
14    Template { template: String },
15}
16
17impl StepInput {
18    /// Create a new literal input
19    pub fn literal(value: impl Into<String>) -> Self {
20        Self::Literal(value.into())
21    }
22
23    /// Create a new step reference input
24    pub fn step_ref(step: impl Into<String>) -> Self {
25        Self::StepReference { step: step.into() }
26    }
27
28    /// Create a new template input
29    pub fn template(template: impl Into<String>) -> Self {
30        Self::Template { template: template.into() }
31    }
32
33    /// Get all step dependencies from this input
34    pub fn get_dependencies(&self) -> Vec<String> {
35        match self {
36            Self::Literal(_) => Vec::new(),
37            Self::StepReference { step } => vec![step.clone()],
38            Self::Template { template } => {
39                // Extract step references from template format strings like {step_name}
40                let mut deps = Vec::new();
41                let mut chars = template.chars().peekable();
42                while let Some(ch) = chars.next() {
43                    if ch == '{' {
44                        let mut dep = String::new();
45                        while let Some(&next_ch) = chars.peek() {
46                            if next_ch == '}' {
47                                chars.next(); // consume the '}'
48                                break;
49                            }
50                            dep.push(chars.next().unwrap());
51                        }
52                        if !dep.is_empty() {
53                            deps.push(dep);
54                        }
55                    }
56                }
57                deps
58            }
59        }
60    }
61
62    /// Resolve this input given a context of step results
63    pub fn resolve(&self, context: &HashMap<String, String>) -> Result<String> {
64        match self {
65            Self::Literal(value) => Ok(value.clone()),
66            Self::StepReference { step } => {
67                context.get(step)
68                    .cloned()
69                    .ok_or_else(|| GoblinError::missing_dependency("unknown", step))
70            }
71            Self::Template { template } => {
72                let mut result = template.clone();
73                for (key, value) in context {
74                    let placeholder = format!("{{{}}}", key);
75                    result = result.replace(&placeholder, value);
76                }
77                Ok(result)
78            }
79        }
80    }
81}
82
83/// Configuration for a step in a plan, as loaded from TOML
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct StepConfig {
86    pub name: String,
87    #[serde(default)]
88    pub function: Option<String>, // Optional for backward compatibility
89    #[serde(default)]
90    pub inputs: Vec<String>,
91    #[serde(default)]
92    pub timeout: Option<u64>,
93}
94
95/// Runtime representation of a step in an execution plan
96#[derive(Debug, Clone)]
97pub struct Step {
98    pub name: String,
99    pub function: String, // The script/function to execute
100    pub inputs: Vec<StepInput>,
101    pub timeout: Option<std::time::Duration>,
102}
103
104impl Step {
105    /// Create a new step
106    pub fn new(
107        name: impl Into<String>, 
108        function: impl Into<String>,
109        inputs: Vec<StepInput>
110    ) -> Self {
111        Self {
112            name: name.into(),
113            function: function.into(),
114            inputs,
115            timeout: None,
116        }
117    }
118
119    /// Create a step with a custom timeout
120    pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
121        self.timeout = Some(timeout);
122        self
123    }
124
125    /// Get all step dependencies for this step
126    pub fn get_dependencies(&self) -> Vec<String> {
127        let mut deps = Vec::new();
128        for input in &self.inputs {
129            deps.extend(input.get_dependencies());
130        }
131        deps.sort();
132        deps.dedup();
133        deps
134    }
135
136    /// Resolve all inputs for this step given a context
137    pub fn resolve_inputs(&self, context: &HashMap<String, String>) -> Result<Vec<String>> {
138        self.inputs
139            .iter()
140            .map(|input| input.resolve(context))
141            .collect()
142    }
143}
144
145impl From<StepConfig> for Step {
146    fn from(config: StepConfig) -> Self {
147        let function = config.function.unwrap_or_else(|| config.name.clone());
148        let inputs = config.inputs
149            .into_iter()
150            .map(|input| {
151                // Parse input strings into appropriate StepInput types
152                if input.contains('{') && input.contains('}') {
153                    StepInput::Template { template: input }
154                } else if input == "default_input" || input.chars().all(|c| c.is_alphanumeric() || c == '_') {
155                    // Treat simple identifiers as step references (except special literals)
156                    if input.starts_with('"') && input.ends_with('"') {
157                        // If it's quoted, treat as literal
158                        StepInput::Literal(input[1..input.len()-1].to_string())
159                    } else {
160                        // Treat as step reference
161                        StepInput::StepReference { step: input }
162                    }
163                } else {
164                    StepInput::Literal(input)
165                }
166            })
167            .collect();
168        
169        let mut step = Self::new(config.name, function, inputs);
170        if let Some(timeout_secs) = config.timeout {
171            step = step.with_timeout(std::time::Duration::from_secs(timeout_secs));
172        }
173        step
174    }
175}
176
177/// Configuration for a plan, as loaded from TOML
178#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct PlanConfig {
180    pub name: String,
181    #[serde(default)]
182    pub steps: Vec<StepConfig>,
183}
184
185/// Represents an execution plan with multiple steps
186#[derive(Debug, Clone)]
187pub struct Plan {
188    pub name: String,
189    pub steps: Vec<Step>,
190}
191
192impl Plan {
193    /// Create a new plan
194    pub fn new(name: impl Into<String>, steps: Vec<Step>) -> Self {
195        Self {
196            name: name.into(),
197            steps,
198        }
199    }
200
201    /// Load a plan from a TOML file
202    pub fn from_toml_file(path: impl AsRef<std::path::Path>) -> Result<Self> {
203        let content = std::fs::read_to_string(path)?;
204        Self::from_toml_str(&content)
205    }
206
207    /// Load a plan from a TOML string
208    pub fn from_toml_str(toml_str: &str) -> Result<Self> {
209        let config: PlanConfig = toml::from_str(toml_str)?;
210        Ok(Self::from(config))
211    }
212
213    /// Get all unique script names referenced in this plan
214    pub fn get_required_scripts(&self) -> Vec<String> {
215        let mut scripts = HashSet::new();
216        for step in &self.steps {
217            scripts.insert(step.function.clone());
218        }
219        let mut result: Vec<String> = scripts.into_iter().collect();
220        result.sort();
221        result
222    }
223
224    /// Validate the plan for circular dependencies and other issues
225    pub fn validate(&self) -> Result<()> {
226        // Check for circular dependencies
227        self.check_circular_dependencies()?;
228        
229        // Check that all step names are unique
230        let mut step_names = HashSet::new();
231        for step in &self.steps {
232            if !step_names.insert(step.name.clone()) {
233                return Err(GoblinError::invalid_step_config(format!(
234                    "Duplicate step name: {}", step.name
235                )));
236            }
237        }
238
239        // Check that all dependencies exist
240        for step in &self.steps {
241            let deps = step.get_dependencies();
242            for dep in deps {
243                if dep != "default_input" && !step_names.contains(&dep) {
244                    return Err(GoblinError::missing_dependency(&step.name, &dep));
245                }
246            }
247        }
248
249        Ok(())
250    }
251
252    /// Check for circular dependencies in the plan
253    fn check_circular_dependencies(&self) -> Result<()> {
254        let mut graph: HashMap<String, Vec<String>> = HashMap::new();
255        
256        // Build dependency graph
257        for step in &self.steps {
258            let deps = step.get_dependencies();
259            graph.insert(step.name.clone(), deps);
260        }
261
262        // Detect cycles using DFS
263        let mut visiting = HashSet::new();
264        let mut visited = HashSet::new();
265        
266        for step in &self.steps {
267            if !visited.contains(&step.name) {
268                if self.has_cycle(&graph, &step.name, &mut visiting, &mut visited)? {
269                    return Err(GoblinError::circular_dependency(&self.name));
270                }
271            }
272        }
273
274        Ok(())
275    }
276
277    /// DFS helper for cycle detection
278    fn has_cycle(
279        &self,
280        graph: &HashMap<String, Vec<String>>,
281        node: &str,
282        visiting: &mut HashSet<String>,
283        visited: &mut HashSet<String>,
284    ) -> Result<bool> {
285        if visiting.contains(node) {
286            return Ok(true); // Found a cycle
287        }
288        
289        if visited.contains(node) {
290            return Ok(false); // Already processed
291        }
292
293        visiting.insert(node.to_string());
294
295        if let Some(deps) = graph.get(node) {
296            for dep in deps {
297                if dep != "default_input" {
298                    if self.has_cycle(graph, dep, visiting, visited)? {
299                        return Ok(true);
300                    }
301                }
302            }
303        }
304
305        visiting.remove(node);
306        visited.insert(node.to_string());
307        Ok(false)
308    }
309
310    /// Get the execution order for steps based on dependencies
311    pub fn get_execution_order(&self) -> Result<Vec<String>> {
312        self.validate()?;
313        
314        let mut graph: HashMap<String, Vec<String>> = HashMap::new();
315        let mut in_degree: HashMap<String, usize> = HashMap::new();
316        
317        // Build graph and calculate in-degrees
318        for step in &self.steps {
319            in_degree.insert(step.name.clone(), 0);
320            graph.insert(step.name.clone(), Vec::new());
321        }
322        
323        for step in &self.steps {
324            let deps = step.get_dependencies();
325            for dep in deps {
326                if dep != "default_input" {
327                    graph.entry(dep.clone()).or_default().push(step.name.clone());
328                    *in_degree.entry(step.name.clone()).or_insert(0) += 1;
329                }
330            }
331        }
332        
333        // Topological sort using Kahn's algorithm
334        let mut queue: VecDeque<String> = VecDeque::new();
335        let mut result = Vec::new();
336        
337        // Find all nodes with in-degree 0
338        for (node, &degree) in &in_degree {
339            if degree == 0 {
340                queue.push_back(node.clone());
341            }
342        }
343        
344        while let Some(node) = queue.pop_front() {
345            result.push(node.clone());
346            
347            if let Some(neighbors) = graph.get(&node) {
348                for neighbor in neighbors {
349                    let degree = in_degree.get_mut(neighbor).unwrap();
350                    *degree -= 1;
351                    if *degree == 0 {
352                        queue.push_back(neighbor.clone());
353                    }
354                }
355            }
356        }
357        
358        if result.len() != self.steps.len() {
359            return Err(GoblinError::circular_dependency(&self.name));
360        }
361        
362        Ok(result)
363    }
364}
365
366impl From<PlanConfig> for Plan {
367    fn from(config: PlanConfig) -> Self {
368        let steps = config.steps.into_iter().map(Step::from).collect();
369        Self::new(config.name, steps)
370    }
371}
372
373// Implement string parsing for backwards compatibility with original format
374impl From<String> for StepInput {
375    fn from(s: String) -> Self {
376        if s.contains('{') && s.contains('}') {
377            Self::Template { template: s }
378        } else {
379            Self::Literal(s)
380        }
381    }
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    /// Position of `name` within an execution order; panics if it is absent.
    fn index_of(order: &[String], name: &str) -> usize {
        order.iter().position(|entry| entry == name).unwrap()
    }

    #[test]
    fn test_step_input_literal() {
        let literal = StepInput::literal("hello world");
        assert!(literal.get_dependencies().is_empty());

        let empty_context: HashMap<String, String> = HashMap::new();
        assert_eq!(literal.resolve(&empty_context).unwrap(), "hello world");
    }

    #[test]
    fn test_step_input_template() {
        let input = StepInput::template("Result: {step1} and {step2}");
        let context: HashMap<String, String> = [("step1", "foo"), ("step2", "bar")]
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect();

        assert_eq!(input.resolve(&context).unwrap(), "Result: foo and bar");
        assert_eq!(input.get_dependencies(), vec!["step1", "step2"]);
    }

    #[test]
    fn test_plan_from_toml() {
        let toml_content = r#"
            name = "test_plan"
            
            [[steps]]
            name = "step1"
            function = "script1"
            inputs = ["default_input"]
            
            [[steps]]
            name = "step2"
            function = "script2"
            inputs = ["step1"]
        "#;

        let plan = Plan::from_toml_str(toml_content).unwrap();
        assert_eq!(plan.name, "test_plan");

        let step_names: Vec<&str> = plan.steps.iter().map(|step| step.name.as_str()).collect();
        assert_eq!(step_names, ["step1", "step2"]);
    }

    #[test]
    fn test_execution_order() {
        let toml_content = r#"
            name = "test_plan"
            
            [[steps]]
            name = "step3"
            function = "script3"
            inputs = ["step1", "step2"]
            
            [[steps]]
            name = "step1"
            function = "script1"
            inputs = ["default_input"]
            
            [[steps]]
            name = "step2"
            function = "script2"
            inputs = ["step1"]
        "#;

        let plan = Plan::from_toml_str(toml_content).unwrap();
        let order = plan.get_execution_order().unwrap();

        // step1 must precede step2, and both must precede step3.
        let (first, second, third) = (
            index_of(&order, "step1"),
            index_of(&order, "step2"),
            index_of(&order, "step3"),
        );
        assert!(first < second);
        assert!(second < third);
        assert!(first < third);
    }

    #[test]
    fn test_circular_dependency_detection() {
        let toml_content = r#"
            name = "circular_plan"
            
            [[steps]]
            name = "step1"
            function = "script1"
            inputs = ["step2"]
            
            [[steps]]
            name = "step2"
            function = "script2"
            inputs = ["step1"]
        "#;

        let plan = Plan::from_toml_str(toml_content).unwrap();

        // Both entry points must reject the step1 <-> step2 cycle.
        assert!(
            plan.validate().is_err(),
            "Expected circular dependency error, but validation passed"
        );
        assert!(
            plan.get_execution_order().is_err(),
            "Expected circular dependency error in execution order"
        );
    }
}