use super::types::{BatchOperation, BatchRequest};
use std::collections::{HashMap, HashSet};
/// Validated dependency graph over the operations of one batch.
///
/// [`DependencyGraph::new`] builds it and rejects duplicate operation IDs,
/// dependencies on unknown operations, and dependency cycles.
#[derive(Debug)]
pub struct DependencyGraph {
/// Operations keyed by their `id`.
operations: HashMap<String, BatchOperation>,
/// Operation IDs in the order they were supplied to `new`.
operation_order: Vec<String>,
/// Operation ID -> IDs it depends on; operations without dependencies have no entry.
dependencies: HashMap<String, HashSet<String>>,
/// Operation ID -> IDs that depend on it (the reverse of `dependencies`).
dependents: HashMap<String, HashSet<String>>,
}
impl DependencyGraph {
pub fn new(operations: Vec<BatchOperation>) -> Result<Self, String> {
let mut graph = Self {
operations: HashMap::new(),
operation_order: Vec::new(),
dependencies: HashMap::new(),
dependents: HashMap::new(),
};
for op in operations {
if graph.operations.contains_key(&op.id) {
return Err(format!("Duplicate operation ID: {}", op.id));
}
graph.operation_order.push(op.id.clone());
graph.operations.insert(op.id.clone(), op);
}
for (id, op) in &graph.operations {
for dep in &op.depends_on {
if !graph.operations.contains_key(dep) {
return Err(format!(
"Operation '{}' depends on unknown operation '{}'",
id, dep
));
}
graph
.dependencies
.entry(id.clone())
.or_default()
.insert(dep.clone());
graph
.dependents
.entry(dep.clone())
.or_default()
.insert(id.clone());
}
}
graph.validate_acyclic()?;
Ok(graph)
}
fn validate_acyclic(&self) -> Result<(), String> {
let mut visited = HashSet::new();
let mut stack = HashSet::new();
for id in self.operations.keys() {
if !visited.contains(id) {
self.detect_cycle(id, &mut visited, &mut stack)?;
}
}
Ok(())
}
fn detect_cycle(
&self,
node: &str,
visited: &mut HashSet<String>,
stack: &mut HashSet<String>,
) -> Result<(), String> {
visited.insert(node.to_string());
stack.insert(node.to_string());
if let Some(deps) = self.dependencies.get(node) {
for dep in deps {
if !visited.contains(dep) {
self.detect_cycle(dep, visited, stack)?;
} else if stack.contains(dep) {
return Err(format!("Circular dependency detected: {} -> {}", node, dep));
}
}
}
stack.remove(node);
Ok(())
}
pub fn get_ready_operations(&self, completed: &HashSet<String>) -> Vec<BatchOperation> {
self.operations
.values()
.filter(|op| {
self.dependencies
.get(&op.id)
.map(|deps| deps.iter().all(|dep| completed.contains(dep)))
.unwrap_or(true) && !completed.contains(&op.id) })
.cloned()
.collect()
}
pub fn len(&self) -> usize {
self.operations.len()
}
pub fn is_empty(&self) -> bool {
self.operations.is_empty()
}
pub fn operations_in_order(&self) -> Vec<BatchOperation> {
self.operation_order
.iter()
.filter_map(|id| self.operations.get(id).cloned())
.collect()
}
}
impl From<BatchRequest> for DependencyGraph {
fn from(request: BatchRequest) -> Self {
DependencyGraph::new(request.operations).unwrap_or_else(|_e| {
DependencyGraph {
operations: HashMap::new(),
operation_order: Vec::new(),
dependencies: HashMap::new(),
dependents: HashMap::new(),
}
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    /// Builds an operation with the given ID and dependencies. `tool` and
    /// `arguments` do not affect the graph, so they get placeholder values.
    fn op(id: &str, depends_on: &[&str]) -> BatchOperation {
        BatchOperation {
            id: id.to_string(),
            tool: format!("tool_{}", id),
            arguments: Value::Null,
            depends_on: depends_on.iter().map(|dep| dep.to_string()).collect(),
        }
    }

    /// Ready operation IDs for the given completed set, sorted so assertions
    /// do not depend on the order `get_ready_operations` returns them in.
    fn ready_ids(graph: &DependencyGraph, completed: &[&str]) -> Vec<String> {
        let completed: HashSet<String> = completed.iter().map(|id| id.to_string()).collect();
        let mut ids: Vec<String> = graph
            .get_ready_operations(&completed)
            .into_iter()
            .map(|ready| ready.id)
            .collect();
        ids.sort();
        ids
    }

    #[test]
    fn test_dependency_graph_simple() {
        let graph = DependencyGraph::new(vec![op("op1", &[]), op("op2", &[])]).unwrap();
        assert_eq!(graph.len(), 2);
        assert!(!graph.is_empty());
    }

    #[test]
    fn test_dependency_graph_with_dependencies() {
        let graph = DependencyGraph::new(vec![op("op1", &[]), op("op2", &["op1"])]).unwrap();
        assert_eq!(graph.len(), 2);
    }

    #[test]
    fn test_dependency_graph_empty() {
        let graph = DependencyGraph::new(Vec::new()).unwrap();
        assert_eq!(graph.len(), 0);
        assert!(graph.is_empty());
        assert!(ready_ids(&graph, &[]).is_empty());
        assert!(graph.operations_in_order().is_empty());
    }

    #[test]
    fn test_dependency_graph_cycle() {
        let result = DependencyGraph::new(vec![op("op1", &["op2"]), op("op2", &["op1"])]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Circular dependency"));
    }

    #[test]
    fn test_dependency_graph_self_dependency() {
        let result = DependencyGraph::new(vec![op("op1", &["op1"])]);
        assert!(result.unwrap_err().contains("Circular dependency"));
    }

    #[test]
    fn test_dependency_graph_longer_cycle() {
        let result = DependencyGraph::new(vec![
            op("a", &["b"]),
            op("b", &["c"]),
            op("c", &["a"]),
        ]);
        assert!(result.unwrap_err().contains("Circular dependency"));
    }

    #[test]
    fn test_dependency_graph_unknown_dependency() {
        let result = DependencyGraph::new(vec![op("op1", &["unknown"])]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("unknown operation"));
    }

    #[test]
    fn test_dependency_graph_duplicate_id() {
        let result = DependencyGraph::new(vec![op("op1", &[]), op("op1", &[])]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Duplicate operation ID"));
    }

    #[test]
    fn test_get_ready_operations() {
        let graph = DependencyGraph::new(vec![op("op1", &[]), op("op2", &["op1"])]).unwrap();
        assert_eq!(ready_ids(&graph, &[]), vec!["op1"]);
        assert_eq!(ready_ids(&graph, &["op1"]), vec!["op2"]);
    }

    #[test]
    fn test_get_ready_operations_diamond() {
        let graph = DependencyGraph::new(vec![
            op("a", &[]),
            op("b", &["a"]),
            op("c", &["a"]),
            op("d", &["b", "c"]),
        ])
        .unwrap();
        assert_eq!(ready_ids(&graph, &[]), vec!["a"]);
        assert_eq!(ready_ids(&graph, &["a"]), vec!["b", "c"]);
        // `d` must wait until both of its dependencies are done.
        assert_eq!(ready_ids(&graph, &["a", "b"]), vec!["c"]);
        assert_eq!(ready_ids(&graph, &["a", "b", "c"]), vec!["d"]);
        assert!(ready_ids(&graph, &["a", "b", "c", "d"]).is_empty());
    }

    #[test]
    fn test_operations_in_order_preserves_submission_order() {
        let graph = DependencyGraph::new(vec![op("z", &[]), op("a", &[]), op("m", &["a"])]).unwrap();
        let ids: Vec<String> = graph
            .operations_in_order()
            .into_iter()
            .map(|ordered| ordered.id)
            .collect();
        assert_eq!(ids, vec!["z", "a", "m"]);
    }
}