1use serde::{Deserialize, Serialize};
2use std::collections::{HashMap, HashSet, VecDeque};
3use crate::error::{GoblinError, Result};
4
/// One input to a [`Step`], resolved against the outputs of other steps.
///
/// Uses `#[serde(untagged)]`, so serde tries the variants in declaration
/// order: a bare string deserializes as `Literal`, a table with a `step` key
/// as `StepReference`, and a table with a `template` key as `Template`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum StepInput {
    /// A fixed string that resolves to itself.
    Literal(String),
    /// The output of the step named `step`, looked up by name at resolve time.
    StepReference { step: String },
    /// A string whose `{name}` placeholders are filled with step outputs.
    Template { template: String },
}
16
17impl StepInput {
18 pub fn literal(value: impl Into<String>) -> Self {
20 Self::Literal(value.into())
21 }
22
23 pub fn step_ref(step: impl Into<String>) -> Self {
25 Self::StepReference { step: step.into() }
26 }
27
28 pub fn template(template: impl Into<String>) -> Self {
30 Self::Template { template: template.into() }
31 }
32
33 pub fn get_dependencies(&self) -> Vec<String> {
35 match self {
36 Self::Literal(_) => Vec::new(),
37 Self::StepReference { step } => vec![step.clone()],
38 Self::Template { template } => {
39 let mut deps = Vec::new();
41 let mut chars = template.chars().peekable();
42 while let Some(ch) = chars.next() {
43 if ch == '{' {
44 let mut dep = String::new();
45 while let Some(&next_ch) = chars.peek() {
46 if next_ch == '}' {
47 chars.next(); break;
49 }
50 dep.push(chars.next().unwrap());
51 }
52 if !dep.is_empty() {
53 deps.push(dep);
54 }
55 }
56 }
57 deps
58 }
59 }
60 }
61
62 pub fn resolve(&self, context: &HashMap<String, String>) -> Result<String> {
64 match self {
65 Self::Literal(value) => Ok(value.clone()),
66 Self::StepReference { step } => {
67 context.get(step)
68 .cloned()
69 .ok_or_else(|| GoblinError::missing_dependency("unknown", step))
70 }
71 Self::Template { template } => {
72 let mut result = template.clone();
73 for (key, value) in context {
74 let placeholder = format!("{{{}}}", key);
75 result = result.replace(&placeholder, value);
76 }
77 Ok(result)
78 }
79 }
80 }
81}
82
/// Raw, deserialized form of one `[[steps]]` entry in a plan's TOML file.
///
/// Converted into a runnable [`Step`] via `From<StepConfig>`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepConfig {
    /// Step name; other steps depend on this step by naming it.
    pub name: String,
    /// Script/function to run. When absent, conversion to `Step` uses `name`.
    #[serde(default)]
    pub function: Option<String>,
    /// Raw input strings, classified into `StepInput`s during conversion.
    #[serde(default)]
    pub inputs: Vec<String>,
    /// Optional timeout in whole seconds (converted with `Duration::from_secs`).
    #[serde(default)]
    pub timeout: Option<u64>,
}
94
/// A single executable step of a [`Plan`].
#[derive(Debug, Clone)]
pub struct Step {
    /// Step name; must be unique within a plan (enforced by `Plan::validate`).
    pub name: String,
    /// Name of the script/function this step runs.
    pub function: String,
    /// Inputs passed to `function`, resolved against other steps' outputs.
    pub inputs: Vec<StepInput>,
    /// Optional execution time limit; `None` means no limit is configured.
    pub timeout: Option<std::time::Duration>,
}
103
104impl Step {
105 pub fn new(
107 name: impl Into<String>,
108 function: impl Into<String>,
109 inputs: Vec<StepInput>
110 ) -> Self {
111 Self {
112 name: name.into(),
113 function: function.into(),
114 inputs,
115 timeout: None,
116 }
117 }
118
119 pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
121 self.timeout = Some(timeout);
122 self
123 }
124
125 pub fn get_dependencies(&self) -> Vec<String> {
127 let mut deps = Vec::new();
128 for input in &self.inputs {
129 deps.extend(input.get_dependencies());
130 }
131 deps.sort();
132 deps.dedup();
133 deps
134 }
135
136 pub fn resolve_inputs(&self, context: &HashMap<String, String>) -> Result<Vec<String>> {
138 self.inputs
139 .iter()
140 .map(|input| input.resolve(context))
141 .collect()
142 }
143}
144
145impl From<StepConfig> for Step {
146 fn from(config: StepConfig) -> Self {
147 let function = config.function.unwrap_or_else(|| config.name.clone());
148 let inputs = config.inputs
149 .into_iter()
150 .map(|input| {
151 if input.contains('{') && input.contains('}') {
153 StepInput::Template { template: input }
154 } else if input == "default_input" || input.chars().all(|c| c.is_alphanumeric() || c == '_') {
155 if input.starts_with('"') && input.ends_with('"') {
157 StepInput::Literal(input[1..input.len()-1].to_string())
159 } else {
160 StepInput::StepReference { step: input }
162 }
163 } else {
164 StepInput::Literal(input)
165 }
166 })
167 .collect();
168
169 let mut step = Self::new(config.name, function, inputs);
170 if let Some(timeout_secs) = config.timeout {
171 step = step.with_timeout(std::time::Duration::from_secs(timeout_secs));
172 }
173 step
174 }
175}
176
/// Raw, deserialized form of a plan's TOML file.
///
/// Converted into a [`Plan`] via `From<PlanConfig>`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanConfig {
    /// Plan name (also used in circular-dependency errors).
    pub name: String,
    /// The `[[steps]]` entries, in file order; empty when the key is absent.
    #[serde(default)]
    pub steps: Vec<StepConfig>,
}
184
/// A named collection of steps whose inputs may depend on one another.
#[derive(Debug, Clone)]
pub struct Plan {
    /// Plan name, reported in circular-dependency errors.
    pub name: String,
    /// Steps in declaration order (not necessarily execution order).
    pub steps: Vec<Step>,
}
191
192impl Plan {
193 pub fn new(name: impl Into<String>, steps: Vec<Step>) -> Self {
195 Self {
196 name: name.into(),
197 steps,
198 }
199 }
200
201 pub fn from_toml_file(path: impl AsRef<std::path::Path>) -> Result<Self> {
203 let content = std::fs::read_to_string(path)?;
204 Self::from_toml_str(&content)
205 }
206
207 pub fn from_toml_str(toml_str: &str) -> Result<Self> {
209 let config: PlanConfig = toml::from_str(toml_str)?;
210 Ok(Self::from(config))
211 }
212
213 pub fn get_required_scripts(&self) -> Vec<String> {
215 let mut scripts = HashSet::new();
216 for step in &self.steps {
217 scripts.insert(step.function.clone());
218 }
219 let mut result: Vec<String> = scripts.into_iter().collect();
220 result.sort();
221 result
222 }
223
224 pub fn validate(&self) -> Result<()> {
226 self.check_circular_dependencies()?;
228
229 let mut step_names = HashSet::new();
231 for step in &self.steps {
232 if !step_names.insert(step.name.clone()) {
233 return Err(GoblinError::invalid_step_config(format!(
234 "Duplicate step name: {}", step.name
235 )));
236 }
237 }
238
239 for step in &self.steps {
241 let deps = step.get_dependencies();
242 for dep in deps {
243 if dep != "default_input" && !step_names.contains(&dep) {
244 return Err(GoblinError::missing_dependency(&step.name, &dep));
245 }
246 }
247 }
248
249 Ok(())
250 }
251
252 fn check_circular_dependencies(&self) -> Result<()> {
254 let mut graph: HashMap<String, Vec<String>> = HashMap::new();
255
256 for step in &self.steps {
258 let deps = step.get_dependencies();
259 graph.insert(step.name.clone(), deps);
260 }
261
262 let mut visiting = HashSet::new();
264 let mut visited = HashSet::new();
265
266 for step in &self.steps {
267 if !visited.contains(&step.name) {
268 if self.has_cycle(&graph, &step.name, &mut visiting, &mut visited)? {
269 return Err(GoblinError::circular_dependency(&self.name));
270 }
271 }
272 }
273
274 Ok(())
275 }
276
277 fn has_cycle(
279 &self,
280 graph: &HashMap<String, Vec<String>>,
281 node: &str,
282 visiting: &mut HashSet<String>,
283 visited: &mut HashSet<String>,
284 ) -> Result<bool> {
285 if visiting.contains(node) {
286 return Ok(true); }
288
289 if visited.contains(node) {
290 return Ok(false); }
292
293 visiting.insert(node.to_string());
294
295 if let Some(deps) = graph.get(node) {
296 for dep in deps {
297 if dep != "default_input" {
298 if self.has_cycle(graph, dep, visiting, visited)? {
299 return Ok(true);
300 }
301 }
302 }
303 }
304
305 visiting.remove(node);
306 visited.insert(node.to_string());
307 Ok(false)
308 }
309
310 pub fn get_execution_order(&self) -> Result<Vec<String>> {
312 self.validate()?;
313
314 let mut graph: HashMap<String, Vec<String>> = HashMap::new();
315 let mut in_degree: HashMap<String, usize> = HashMap::new();
316
317 for step in &self.steps {
319 in_degree.insert(step.name.clone(), 0);
320 graph.insert(step.name.clone(), Vec::new());
321 }
322
323 for step in &self.steps {
324 let deps = step.get_dependencies();
325 for dep in deps {
326 if dep != "default_input" {
327 graph.entry(dep.clone()).or_default().push(step.name.clone());
328 *in_degree.entry(step.name.clone()).or_insert(0) += 1;
329 }
330 }
331 }
332
333 let mut queue: VecDeque<String> = VecDeque::new();
335 let mut result = Vec::new();
336
337 for (node, °ree) in &in_degree {
339 if degree == 0 {
340 queue.push_back(node.clone());
341 }
342 }
343
344 while let Some(node) = queue.pop_front() {
345 result.push(node.clone());
346
347 if let Some(neighbors) = graph.get(&node) {
348 for neighbor in neighbors {
349 let degree = in_degree.get_mut(neighbor).unwrap();
350 *degree -= 1;
351 if *degree == 0 {
352 queue.push_back(neighbor.clone());
353 }
354 }
355 }
356 }
357
358 if result.len() != self.steps.len() {
359 return Err(GoblinError::circular_dependency(&self.name));
360 }
361
362 Ok(result)
363 }
364}
365
366impl From<PlanConfig> for Plan {
367 fn from(config: PlanConfig) -> Self {
368 let steps = config.steps.into_iter().map(Step::from).collect();
369 Self::new(config.name, steps)
370 }
371}
372
373impl From<String> for StepInput {
375 fn from(s: String) -> Self {
376 if s.contains('{') && s.contains('}') {
377 Self::Template { template: s }
378 } else {
379 Self::Literal(s)
380 }
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an inline TOML fixture, panicking if the fixture is malformed.
    fn plan_from(toml_content: &str) -> Plan {
        Plan::from_toml_str(toml_content).unwrap()
    }

    #[test]
    fn test_step_input_literal() {
        let input = StepInput::literal("hello world");
        let empty_context: HashMap<String, String> = HashMap::new();

        assert!(input.get_dependencies().is_empty());
        assert_eq!(input.resolve(&empty_context).unwrap(), "hello world");
    }

    #[test]
    fn test_step_input_template() {
        let input = StepInput::template("Result: {step1} and {step2}");
        let context: HashMap<String, String> = [("step1", "foo"), ("step2", "bar")]
            .iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect();

        let resolved = input.resolve(&context).unwrap();
        assert_eq!(resolved, "Result: foo and bar");
        assert_eq!(input.get_dependencies(), vec!["step1", "step2"]);
    }

    #[test]
    fn test_plan_from_toml() {
        let plan = plan_from(
            r#"
name = "test_plan"

[[steps]]
name = "step1"
function = "script1"
inputs = ["default_input"]

[[steps]]
name = "step2"
function = "script2"
inputs = ["step1"]
"#,
        );

        assert_eq!(plan.name, "test_plan");
        let names: Vec<&str> = plan.steps.iter().map(|step| step.name.as_str()).collect();
        assert_eq!(names.len(), 2);
        assert_eq!(names[0], "step1");
        assert_eq!(names[1], "step2");
    }

    #[test]
    fn test_execution_order() {
        let plan = plan_from(
            r#"
name = "test_plan"

[[steps]]
name = "step3"
function = "script3"
inputs = ["step1", "step2"]

[[steps]]
name = "step1"
function = "script1"
inputs = ["default_input"]

[[steps]]
name = "step2"
function = "script2"
inputs = ["step1"]
"#,
        );
        let order = plan.get_execution_order().unwrap();
        let position_of = |name: &str| order.iter().position(|step| step == name).unwrap();

        let (first, second, third) = (position_of("step1"), position_of("step2"), position_of("step3"));
        assert!(first < second);
        assert!(second < third);
        assert!(first < third);
    }

    #[test]
    fn test_circular_dependency_detection() {
        let plan = plan_from(
            r#"
name = "circular_plan"

[[steps]]
name = "step1"
function = "script1"
inputs = ["step2"]

[[steps]]
name = "step2"
function = "script2"
inputs = ["step1"]
"#,
        );

        assert!(
            plan.validate().is_err(),
            "Expected circular dependency error, but validation passed"
        );
        assert!(
            plan.get_execution_order().is_err(),
            "Expected circular dependency error in execution order"
        );
    }
}