use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use crate::error::{GoblinError, Result};
/// One input to a [`Step`].
///
/// Serialized `#[serde(untagged)]`: serde tries the variants in the order they are declared
/// here, so a bare string deserializes as [`StepInput::Literal`], a map with a `step` key as
/// [`StepInput::StepReference`], and a map with a `template` key as [`StepInput::Template`].
/// The variant order is therefore part of the format — do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum StepInput {
    /// Fixed text, passed through unchanged.
    Literal(String),
    /// The output of the step with the given name.
    StepReference {
        step: String,
    },
    /// Text containing `{name}` placeholders that are filled in from step outputs.
    Template {
        template: String,
    },
}
impl StepInput {
    /// Creates an input that always resolves to `value`, unchanged.
    pub fn literal(value: impl Into<String>) -> Self {
        Self::Literal(value.into())
    }

    /// Creates an input that resolves to the output recorded for the step named `step`.
    pub fn step_ref(step: impl Into<String>) -> Self {
        Self::StepReference { step: step.into() }
    }

    /// Creates an input whose `{name}` placeholders are filled in from step outputs.
    pub fn template(template: impl Into<String>) -> Self {
        Self::Template { template: template.into() }
    }

    /// Returns the names of the steps this input reads from.
    ///
    /// For a template these are the non-empty names of its closed `{name}` placeholders, in
    /// order of appearance (repeats included). An unmatched `{` is plain text and contributes
    /// no dependency — the same rule [`StepInput::resolve`] uses when substituting.
    pub fn get_dependencies(&self) -> Vec<String> {
        match self {
            Self::Literal(_) => Vec::new(),
            Self::StepReference { step } => vec![step.clone()],
            Self::Template { template } => Self::placeholders(template)
                .into_iter()
                .filter(|(_, name)| !name.is_empty())
                .map(|(_, name)| name.to_string())
                .collect(),
        }
    }

    /// Produces the concrete value of this input, looking step outputs up in `context` by
    /// step name.
    ///
    /// Template placeholders are substituted in a single left-to-right pass over the template:
    /// text that comes from a substituted value is never re-expanded, and the result does not
    /// depend on `context`'s iteration order. Placeholders with no entry in `context`, and
    /// unmatched braces, are copied to the output verbatim.
    ///
    /// # Errors
    /// A [`StepInput::StepReference`] whose step has no entry in `context` yields a
    /// missing-dependency error, reported against step `"unknown"` because the consuming
    /// step is not known at this level.
    pub fn resolve(&self, context: &HashMap<String, String>) -> Result<String> {
        match self {
            Self::Literal(value) => Ok(value.clone()),
            Self::StepReference { step } => context
                .get(step)
                .cloned()
                .ok_or_else(|| GoblinError::missing_dependency("unknown", step)),
            Self::Template { template } => {
                let mut rendered = String::with_capacity(template.len());
                let mut copied_up_to = 0;
                for (span, name) in Self::placeholders(template) {
                    rendered.push_str(&template[copied_up_to..span.start]);
                    match context.get(name) {
                        Some(value) => rendered.push_str(value),
                        // Unknown placeholder: keep `{name}` as written.
                        None => rendered.push_str(&template[span.clone()]),
                    }
                    copied_up_to = span.end;
                }
                rendered.push_str(&template[copied_up_to..]);
                Ok(rendered)
            }
        }
    }

    /// Locates every closed `{...}` placeholder in `template`.
    ///
    /// Each entry holds the byte range of the whole placeholder (braces included) and the text
    /// between the braces. A placeholder runs from a `{` to the next `}`, and scanning resumes
    /// just after that `}`. A `{` with no later `}` does not start a placeholder.
    fn placeholders(template: &str) -> Vec<(std::ops::Range<usize>, &str)> {
        let mut found = Vec::new();
        let mut cursor = 0;
        while let Some(offset) = template[cursor..].find('{') {
            let open = cursor + offset;
            let close = match template[open + 1..].find('}') {
                Some(offset) => open + 1 + offset,
                None => break,
            };
            // `{` and `}` are single-byte ASCII, so these indices are char boundaries.
            found.push((open..close + 1, &template[open + 1..close]));
            cursor = close + 1;
        }
        found
    }
}
/// Deserialized form of one `[[steps]]` entry in a plan file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepConfig {
    /// Name of the step; other steps refer to its output by this name.
    pub name: String,

    /// Function/script to run. When absent, the step name is used instead.
    #[serde(default)]
    pub function: Option<String>,

    /// Raw input strings, classified into [`StepInput`]s when converted into a [`Step`].
    #[serde(default)]
    pub inputs: Vec<String>,

    /// Optional timeout, in whole seconds.
    #[serde(default)]
    pub timeout: Option<u64>,
}
/// A runnable step of a [`Plan`]: a named call of a function over a list of inputs.
#[derive(Debug, Clone)]
pub struct Step {
    /// Name of the step; [`StepInput::StepReference`]s and template placeholders use it to
    /// refer to this step's output.
    pub name: String,

    /// Name of the function/script this step runs.
    pub function: String,

    /// Inputs, turned into concrete strings by [`Step::resolve_inputs`].
    pub inputs: Vec<StepInput>,

    /// Optional timeout; how it is enforced is up to the executor.
    pub timeout: Option<std::time::Duration>,
}
impl Step {
pub fn new(
name: impl Into<String>,
function: impl Into<String>,
inputs: Vec<StepInput>
) -> Self {
Self {
name: name.into(),
function: function.into(),
inputs,
timeout: None,
}
}
pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
self.timeout = Some(timeout);
self
}
pub fn get_dependencies(&self) -> Vec<String> {
let mut deps = Vec::new();
for input in &self.inputs {
deps.extend(input.get_dependencies());
}
deps.sort();
deps.dedup();
deps
}
pub fn resolve_inputs(&self, context: &HashMap<String, String>) -> Result<Vec<String>> {
self.inputs
.iter()
.map(|input| input.resolve(context))
.collect()
}
}
impl From<StepConfig> for Step {
fn from(config: StepConfig) -> Self {
let function = config.function.unwrap_or_else(|| config.name.clone());
let inputs = config.inputs
.into_iter()
.map(|input| {
if input.contains('{') && input.contains('}') {
StepInput::Template { template: input }
} else if input == "default_input" || input.chars().all(|c| c.is_alphanumeric() || c == '_') {
if input.starts_with('"') && input.ends_with('"') {
StepInput::Literal(input[1..input.len()-1].to_string())
} else {
StepInput::StepReference { step: input }
}
} else {
StepInput::Literal(input)
}
})
.collect();
let mut step = Self::new(config.name, function, inputs);
if let Some(timeout_secs) = config.timeout {
step = step.with_timeout(std::time::Duration::from_secs(timeout_secs));
}
step
}
}
/// Deserialized top level of a plan file: a plan name plus its `[[steps]]` tables.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanConfig {
    /// Name of the plan.
    pub name: String,

    /// Step definitions in file order; a missing `steps` key yields an empty list.
    #[serde(default)]
    pub steps: Vec<StepConfig>,
}
/// A named set of [`Step`]s, usually built from a [`PlanConfig`].
///
/// Construction performs no checks; call [`Plan::validate`] to reject duplicate step names,
/// dependency cycles and references to unknown steps.
#[derive(Debug, Clone)]
pub struct Plan {
    /// Name of the plan; it is the subject of circular-dependency errors.
    pub name: String,

    /// Steps in declaration order.
    pub steps: Vec<Step>,
}
impl Plan {
    /// Pseudo-dependency naming the plan's external input rather than a step. It is exempt
    /// from the "must name an existing step" check and from the dependency graph.
    const DEFAULT_INPUT: &'static str = "default_input";

    /// Creates a plan named `name` from already-built steps. No validation is performed.
    pub fn new(name: impl Into<String>, steps: Vec<Step>) -> Self {
        Self {
            name: name.into(),
            steps,
        }
    }

    /// Reads a plan from the TOML file at `path`.
    ///
    /// # Errors
    /// Fails if the file cannot be read or its contents do not parse as a [`PlanConfig`].
    pub fn from_toml_file(path: impl AsRef<std::path::Path>) -> Result<Self> {
        let content = std::fs::read_to_string(path)?;
        Self::from_toml_str(&content)
    }

    /// Parses a plan from TOML text. The result is not validated; see [`Plan::validate`].
    ///
    /// # Errors
    /// Fails if `toml_str` does not parse as a [`PlanConfig`].
    pub fn from_toml_str(toml_str: &str) -> Result<Self> {
        let config: PlanConfig = toml::from_str(toml_str)?;
        Ok(Self::from(config))
    }

    /// Returns the distinct functions/scripts run by this plan's steps, sorted ascending.
    pub fn get_required_scripts(&self) -> Vec<String> {
        let mut scripts: Vec<String> =
            self.steps.iter().map(|step| step.function.clone()).collect();
        scripts.sort_unstable();
        scripts.dedup();
        scripts
    }

    /// Checks that the plan is well formed.
    ///
    /// # Errors
    /// - an invalid-step-config error if two steps share a name;
    /// - a circular-dependency error if the step dependencies form a cycle;
    /// - a missing-dependency error if a step depends on a name that is neither a step nor
    ///   `default_input`.
    pub fn validate(&self) -> Result<()> {
        // Uniqueness is checked first: the cycle check keys its graph by step name, where
        // duplicate names would silently overwrite each other.
        let mut step_names = HashSet::with_capacity(self.steps.len());
        for step in &self.steps {
            if !step_names.insert(step.name.as_str()) {
                return Err(GoblinError::invalid_step_config(format!(
                    "Duplicate step name: {}", step.name
                )));
            }
        }
        self.check_circular_dependencies()?;
        for step in &self.steps {
            for dep in step.get_dependencies() {
                if dep != Self::DEFAULT_INPUT && !step_names.contains(dep.as_str()) {
                    return Err(GoblinError::missing_dependency(&step.name, &dep));
                }
            }
        }
        Ok(())
    }

    /// Returns a circular-dependency error if any step (transitively) depends on itself.
    fn check_circular_dependencies(&self) -> Result<()> {
        let graph: HashMap<String, Vec<String>> = self
            .steps
            .iter()
            .map(|step| (step.name.clone(), step.get_dependencies()))
            .collect();
        let mut visiting = HashSet::new();
        let mut visited = HashSet::new();
        for step in &self.steps {
            if !visited.contains(&step.name)
                && self.has_cycle(&graph, &step.name, &mut visiting, &mut visited)?
            {
                return Err(GoblinError::circular_dependency(&self.name));
            }
        }
        Ok(())
    }

    /// Depth-first search from `node`; returns `true` if it reaches a node that is still on
    /// the current search path (`visiting`), i.e. a cycle. Nodes proven cycle-free are added
    /// to `visited` so they are not explored again.
    fn has_cycle(
        &self,
        graph: &HashMap<String, Vec<String>>,
        node: &str,
        visiting: &mut HashSet<String>,
        visited: &mut HashSet<String>,
    ) -> Result<bool> {
        if visiting.contains(node) {
            // Back edge: `node` is an ancestor on the current path.
            return Ok(true);
        }
        if visited.contains(node) {
            return Ok(false);
        }
        visiting.insert(node.to_string());
        if let Some(deps) = graph.get(node) {
            for dep in deps.iter().filter(|dep| dep.as_str() != Self::DEFAULT_INPUT) {
                if self.has_cycle(graph, dep, visiting, visited)? {
                    return Ok(true);
                }
            }
        }
        visiting.remove(node);
        visited.insert(node.to_string());
        Ok(false)
    }

    /// Returns the step names in an order where every step follows all the steps it
    /// depends on (Kahn's topological sort).
    ///
    /// The order is deterministic: the ready queue is seeded in declaration order, and
    /// dependents are released in declaration order.
    ///
    /// # Errors
    /// Any error from [`Plan::validate`], or a circular-dependency error if not every step
    /// could be scheduled.
    pub fn get_execution_order(&self) -> Result<Vec<String>> {
        self.validate()?;
        // `dependents[d]` lists the steps consuming `d`'s output; `in_degree[s]` counts the
        // dependencies of `s` that have not been scheduled yet.
        let mut dependents: HashMap<String, Vec<String>> = HashMap::new();
        let mut in_degree: HashMap<String, usize> = HashMap::new();
        for step in &self.steps {
            in_degree.insert(step.name.clone(), 0);
            dependents.insert(step.name.clone(), Vec::new());
        }
        for step in &self.steps {
            for dep in step.get_dependencies() {
                if dep != Self::DEFAULT_INPUT {
                    dependents.entry(dep).or_default().push(step.name.clone());
                    *in_degree.entry(step.name.clone()).or_insert(0) += 1;
                }
            }
        }
        // Seed from `self.steps` rather than iterating the HashMap, whose order is arbitrary
        // and would make the schedule of independent steps vary between runs.
        let mut queue: VecDeque<String> = self
            .steps
            .iter()
            .filter(|step| in_degree.get(&step.name) == Some(&0))
            .map(|step| step.name.clone())
            .collect();
        let mut order = Vec::with_capacity(self.steps.len());
        while let Some(node) = queue.pop_front() {
            if let Some(consumers) = dependents.get(&node) {
                for consumer in consumers {
                    let degree = in_degree
                        .get_mut(consumer)
                        .expect("every dependent is a declared step name");
                    *degree -= 1;
                    if *degree == 0 {
                        queue.push_back(consumer.clone());
                    }
                }
            }
            order.push(node);
        }
        if order.len() != self.steps.len() {
            return Err(GoblinError::circular_dependency(&self.name));
        }
        Ok(order)
    }
}
impl From<PlanConfig> for Plan {
fn from(config: PlanConfig) -> Self {
let steps = config.steps.into_iter().map(Step::from).collect();
Self::new(config.name, steps)
}
}
impl From<String> for StepInput {
    /// Converts raw text into an input: text containing both a `{` and a `}` becomes a
    /// [`StepInput::Template`]; anything else becomes a [`StepInput::Literal`]. This
    /// conversion never produces a [`StepInput::StepReference`].
    fn from(s: String) -> Self {
        let has_both_braces = ['{', '}'].iter().all(|&brace| s.contains(brace));
        if has_both_braces {
            Self::Template { template: s }
        } else {
            Self::Literal(s)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_step_input_literal() {
        let input = StepInput::literal("hello world");
        let context = HashMap::new();
        assert_eq!(input.resolve(&context).unwrap(), "hello world");
        assert!(input.get_dependencies().is_empty());
    }

    #[test]
    fn test_step_input_template() {
        let input = StepInput::template("Result: {step1} and {step2}");
        let mut context = HashMap::new();
        context.insert("step1".to_string(), "foo".to_string());
        context.insert("step2".to_string(), "bar".to_string());
        assert_eq!(input.resolve(&context).unwrap(), "Result: foo and bar");
        let deps = input.get_dependencies();
        assert_eq!(deps, vec!["step1", "step2"]);
    }

    #[test]
    fn test_step_input_step_reference() {
        let input = StepInput::step_ref("step1");
        assert_eq!(input.get_dependencies(), vec!["step1"]);
        let mut context = HashMap::new();
        assert!(input.resolve(&context).is_err(), "Expected missing dependency error");
        context.insert("step1".to_string(), "output".to_string());
        assert_eq!(input.resolve(&context).unwrap(), "output");
    }

    #[test]
    fn test_template_keeps_unknown_placeholders() {
        let input = StepInput::template("{known} and {unknown}");
        let mut context = HashMap::new();
        context.insert("known".to_string(), "K".to_string());
        assert_eq!(input.resolve(&context).unwrap(), "K and {unknown}");
    }

    #[test]
    fn test_plan_from_toml() {
        let toml_content = r#"
name = "test_plan"
[[steps]]
name = "step1"
function = "script1"
inputs = ["default_input"]
[[steps]]
name = "step2"
function = "script2"
inputs = ["step1"]
"#;
        let plan = Plan::from_toml_str(toml_content).unwrap();
        assert_eq!(plan.name, "test_plan");
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.steps[0].name, "step1");
        assert_eq!(plan.steps[1].name, "step2");
    }

    #[test]
    fn test_step_defaults_and_timeout() {
        let toml_content = r#"
name = "defaults_plan"
[[steps]]
name = "fetch"
inputs = ["default_input"]
timeout = 30
"#;
        let plan = Plan::from_toml_str(toml_content).unwrap();
        let step = &plan.steps[0];
        assert_eq!(step.function, "fetch", "function should default to the step name");
        assert_eq!(step.timeout, Some(std::time::Duration::from_secs(30)));
        assert_eq!(step.inputs, vec![StepInput::step_ref("default_input")]);
    }

    #[test]
    fn test_required_scripts_sorted_and_deduplicated() {
        let toml_content = r#"
name = "scripts_plan"
[[steps]]
name = "a"
function = "zeta"
inputs = ["default_input"]
[[steps]]
name = "b"
function = "alpha"
inputs = ["a"]
[[steps]]
name = "c"
function = "zeta"
inputs = ["b"]
[[steps]]
name = "d"
inputs = ["c"]
"#;
        let plan = Plan::from_toml_str(toml_content).unwrap();
        assert_eq!(plan.get_required_scripts(), vec!["alpha", "d", "zeta"]);
    }

    #[test]
    fn test_execution_order() {
        let toml_content = r#"
name = "test_plan"
[[steps]]
name = "step3"
function = "script3"
inputs = ["step1", "step2"]
[[steps]]
name = "step1"
function = "script1"
inputs = ["default_input"]
[[steps]]
name = "step2"
function = "script2"
inputs = ["step1"]
"#;
        let plan = Plan::from_toml_str(toml_content).unwrap();
        let order = plan.get_execution_order().unwrap();
        let step1_pos = order.iter().position(|x| x == "step1").unwrap();
        let step2_pos = order.iter().position(|x| x == "step2").unwrap();
        let step3_pos = order.iter().position(|x| x == "step3").unwrap();
        assert!(step1_pos < step2_pos);
        assert!(step2_pos < step3_pos);
        assert!(step1_pos < step3_pos);
    }

    #[test]
    fn test_circular_dependency_detection() {
        let toml_content = r#"
name = "circular_plan"
[[steps]]
name = "step1"
function = "script1"
inputs = ["step2"]
[[steps]]
name = "step2"
function = "script2"
inputs = ["step1"]
"#;
        let plan = Plan::from_toml_str(toml_content).unwrap();
        let result = plan.validate();
        assert!(result.is_err(), "Expected circular dependency error, but validation passed");
        let execution_result = plan.get_execution_order();
        assert!(execution_result.is_err(), "Expected circular dependency error in execution order");
    }

    #[test]
    fn test_self_dependency_is_circular() {
        let toml_content = r#"
name = "self_loop_plan"
[[steps]]
name = "loop"
function = "script"
inputs = ["loop"]
"#;
        let plan = Plan::from_toml_str(toml_content).unwrap();
        assert!(plan.validate().is_err(), "A step depending on itself must be rejected");
        assert!(plan.get_execution_order().is_err());
    }

    #[test]
    fn test_duplicate_step_names_rejected() {
        let toml_content = r#"
name = "duplicate_plan"
[[steps]]
name = "a"
function = "script1"
inputs = ["default_input"]
[[steps]]
name = "a"
function = "script2"
inputs = ["default_input"]
"#;
        let plan = Plan::from_toml_str(toml_content).unwrap();
        assert!(plan.validate().is_err(), "Duplicate step names must be rejected");
        assert!(plan.get_execution_order().is_err());
    }

    #[test]
    fn test_missing_dependency_rejected() {
        let toml_content = r#"
name = "missing_plan"
[[steps]]
name = "a"
function = "script1"
inputs = ["ghost"]
"#;
        let plan = Plan::from_toml_str(toml_content).unwrap();
        assert!(plan.validate().is_err(), "A dependency on an unknown step must be rejected");
        assert!(plan.get_execution_order().is_err());
    }
}