use std::collections::HashMap;
use crate::node::PipelineNode;
/// A pipeline graph made of named nodes and directed edges between them.
///
/// Edges are `(from, to)` index pairs into [`PipelineDag::nodes`]. A value returned by
/// `DagBuilder::build` has unique node ids, in-range edge indices and no directed
/// cycles. The fields are public, so a value assembled by hand is not checked.
#[derive(Clone, Debug)]
pub struct PipelineDag {
    /// Nodes in insertion order, each paired with its string id.
    pub nodes: Vec<(String, PipelineNode)>,
    /// Directed edges as `(from, to)` indices into `nodes`.
    pub edges: Vec<(usize, usize)>,
}
/// Errors reported while assembling a [`PipelineDag`] with [`DagBuilder`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DagError {
    /// An edge named a node id that was never registered.
    NodeNotFound(String),
    /// The resolved edges contain at least one directed cycle.
    CycleDetected,
    /// The same node id was registered more than once.
    DuplicateNode(String),
}

impl std::fmt::Display for DagError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DagError::NodeNotFound(id) => write!(f, "node not found: {id}"),
            DagError::CycleDetected => f.write_str("cycle detected in DAG"),
            DagError::DuplicateNode(id) => write!(f, "duplicate node id: {id}"),
        }
    }
}

impl std::error::Error for DagError {}
impl PipelineDag {
    /// Returns how many nodes the DAG holds.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Returns how many edges the DAG holds, counting duplicates separately.
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }

    /// Returns the index of the first node whose id equals `id`, or `None` if no node matches.
    pub fn node_index(&self, id: &str) -> Option<usize> {
        self.nodes
            .iter()
            .enumerate()
            .find(|(_, (name, _))| name.as_str() == id)
            .map(|(position, _)| position)
    }

    /// Returns the source index of every edge that ends at `idx`.
    ///
    /// Results follow edge insertion order. An edge that appears twice contributes twice.
    pub fn predecessors(&self, idx: usize) -> Vec<usize> {
        let mut sources = Vec::new();
        for &(from, to) in &self.edges {
            if to == idx {
                sources.push(from);
            }
        }
        sources
    }

    /// Returns the target index of every edge that starts at `idx`.
    ///
    /// Results follow edge insertion order. An edge that appears twice contributes twice.
    pub fn successors(&self, idx: usize) -> Vec<usize> {
        self.edges
            .iter()
            .filter(|&&(from, _)| from == idx)
            .map(|&(_, to)| to)
            .collect()
    }
}
/// Incremental builder for [`PipelineDag`] that refers to nodes by string id.
///
/// Nodes and edges are only recorded here. All validation is deferred to `build`.
#[derive(Default, Debug)]
pub struct DagBuilder {
    /// Nodes in insertion order, paired with their ids.
    nodes: Vec<(String, PipelineNode)>,
    /// Edges as `(from_id, to_id)` pairs, resolved to indices when the DAG is built.
    edges: Vec<(String, String)>,
}
impl DagBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn node(mut self, id: impl Into<String>, node: PipelineNode) -> Self {
self.nodes.push((id.into(), node));
self
}
pub fn edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
self.edges.push((from.into(), to.into()));
self
}
pub fn build(self) -> Result<PipelineDag, DagError> {
let mut seen = HashMap::new();
for (i, (id, _)) in self.nodes.iter().enumerate() {
if seen.insert(id.clone(), i).is_some() {
return Err(DagError::DuplicateNode(id.clone()));
}
}
let mut edges = Vec::with_capacity(self.edges.len());
for (from_id, to_id) in &self.edges {
let from = *seen
.get(from_id)
.ok_or_else(|| DagError::NodeNotFound(from_id.clone()))?;
let to = *seen
.get(to_id)
.ok_or_else(|| DagError::NodeNotFound(to_id.clone()))?;
edges.push((from, to));
}
let dag = PipelineDag {
nodes: self.nodes,
edges,
};
if has_cycle(&dag) {
return Err(DagError::CycleDetected);
}
Ok(dag)
}
}
/// Reports whether the edges of `dag` contain a directed cycle.
///
/// Uses Kahn's algorithm. Nodes with no remaining incoming edges are removed one at a
/// time, and removing a node decrements the in-degree of its successors. Any node that is
/// never removed lies on a cycle or downstream of one, so the graph is cyclic exactly
/// when fewer than `dag.nodes.len()` nodes get removed.
///
/// Successor lists are built once up front, so the check runs in O(V + E). The previous
/// version rescanned the whole edge list for every removed node, which was O(V * E).
///
/// # Panics
///
/// Panics if an edge endpoint is not a valid index into `dag.nodes`. Edges produced by
/// `DagBuilder::build` always satisfy this.
fn has_cycle(dag: &PipelineDag) -> bool {
    let n = dag.nodes.len();

    let mut in_degree = vec![0usize; n];
    let mut successors: Vec<Vec<usize>> = vec![Vec::new(); n];
    for &(from, to) in &dag.edges {
        successors[from].push(to);
        in_degree[to] += 1;
    }

    // Stack of nodes whose incoming edges have all been removed.
    let mut ready: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
    let mut visited = 0usize;
    while let Some(node) = ready.pop() {
        visited += 1;
        for &next in &successors[node] {
            // Duplicate edges appear twice in `successors` and were counted twice
            // in `in_degree`, so the counts stay consistent.
            in_degree[next] -= 1;
            if in_degree[next] == 0 {
                ready.push(next);
            }
        }
    }
    visited != n
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::PipelineNode;

    fn embed_node() -> PipelineNode {
        PipelineNode::Embed {
            model: "m".into(),
            dimensions: 128,
        }
    }

    fn search_node() -> PipelineNode {
        PipelineNode::Search {
            index: "idx".into(),
            top_k: 10,
        }
    }

    #[test]
    fn test_build_simple() {
        let dag = DagBuilder::new()
            .node("a", embed_node())
            .node("b", search_node())
            .edge("a", "b")
            .build()
            .unwrap();
        assert_eq!(dag.node_count(), 2);
        assert_eq!(dag.edge_count(), 1);
    }

    #[test]
    fn test_duplicate_node() {
        let res = DagBuilder::new()
            .node("a", embed_node())
            .node("a", search_node())
            .build();
        assert_eq!(res.unwrap_err(), DagError::DuplicateNode("a".into()));
    }

    #[test]
    fn test_node_not_found() {
        let res = DagBuilder::new()
            .node("a", embed_node())
            .edge("a", "missing")
            .build();
        assert_eq!(res.unwrap_err(), DagError::NodeNotFound("missing".into()));
    }

    #[test]
    fn test_cycle_detected() {
        let res = DagBuilder::new()
            .node("a", embed_node())
            .node("b", search_node())
            .edge("a", "b")
            .edge("b", "a")
            .build();
        assert_eq!(res.unwrap_err(), DagError::CycleDetected);
    }

    #[test]
    fn test_predecessors_successors() {
        let dag = DagBuilder::new()
            .node("a", embed_node())
            .node("b", search_node())
            .edge("a", "b")
            .build()
            .unwrap();
        assert_eq!(dag.predecessors(1), vec![0]);
        assert_eq!(dag.successors(0), vec![1]);
    }

    /// A builder with nothing in it yields an empty, valid DAG.
    #[test]
    fn test_empty_dag() {
        let dag = DagBuilder::new().build().unwrap();
        assert_eq!(dag.node_count(), 0);
        assert_eq!(dag.edge_count(), 0);
    }

    /// An edge from a node to itself is a cycle.
    #[test]
    fn test_self_loop_is_cycle() {
        let res = DagBuilder::new()
            .node("a", embed_node())
            .edge("a", "a")
            .build();
        assert_eq!(res.unwrap_err(), DagError::CycleDetected);
    }

    /// A cycle spanning three nodes is detected.
    #[test]
    fn test_three_node_cycle_detected() {
        let res = DagBuilder::new()
            .node("a", embed_node())
            .node("b", search_node())
            .node("c", embed_node())
            .edge("a", "b")
            .edge("b", "c")
            .edge("c", "a")
            .build();
        assert_eq!(res.unwrap_err(), DagError::CycleDetected);
    }

    /// A cycle that is reachable from an acyclic root is still detected.
    #[test]
    fn test_cycle_below_root_detected() {
        let res = DagBuilder::new()
            .node("a", embed_node())
            .node("b", search_node())
            .node("c", embed_node())
            .node("d", search_node())
            .edge("a", "b")
            .edge("b", "c")
            .edge("c", "b")
            .edge("c", "d")
            .build();
        assert_eq!(res.unwrap_err(), DagError::CycleDetected);
    }

    /// Two paths converging on one node form a diamond, which is not a cycle.
    #[test]
    fn test_diamond_is_acyclic() {
        let dag = DagBuilder::new()
            .node("a", embed_node())
            .node("b", search_node())
            .node("c", search_node())
            .node("d", embed_node())
            .edge("a", "b")
            .edge("a", "c")
            .edge("b", "d")
            .edge("c", "d")
            .build()
            .unwrap();
        assert_eq!(dag.successors(0), vec![1, 2]);
        assert_eq!(dag.predecessors(3), vec![1, 2]);
        assert!(dag.predecessors(0).is_empty());
        assert!(dag.successors(3).is_empty());
    }

    /// Repeating the same edge is allowed and does not look like a cycle.
    #[test]
    fn test_duplicate_edges_are_acyclic() {
        let dag = DagBuilder::new()
            .node("a", embed_node())
            .node("b", search_node())
            .edge("a", "b")
            .edge("a", "b")
            .build()
            .unwrap();
        assert_eq!(dag.edge_count(), 2);
        assert_eq!(dag.predecessors(1), vec![0, 0]);
    }

    /// When both endpoints are unknown, the source id is the one reported.
    #[test]
    fn test_missing_source_reported_first() {
        let res = DagBuilder::new()
            .node("a", embed_node())
            .edge("x", "y")
            .build();
        assert_eq!(res.unwrap_err(), DagError::NodeNotFound("x".into()));
    }

    /// `node_index` maps ids to insertion positions and returns `None` for unknown ids.
    #[test]
    fn test_node_index() {
        let dag = DagBuilder::new()
            .node("a", embed_node())
            .node("b", search_node())
            .build()
            .unwrap();
        assert_eq!(dag.node_index("a"), Some(0));
        assert_eq!(dag.node_index("b"), Some(1));
        assert_eq!(dag.node_index("z"), None);
    }
}