Skip to main content

atomr_agents_workflow/
dag.rs

1use std::collections::{BTreeMap, HashMap};
2
3use atomr_agents_core::{AgentError, Result};
4use serde::{Deserialize, Serialize};
5
/// Identifier of a workflow step: a newtype wrapping the id's `String`.
///
/// `#[serde(transparent)]` makes the id (de)serialize as the bare inner
/// string. The derived `Ord`/`Hash` let it serve as a key in both the
/// `BTreeMap` and `HashMap` used by `Dag`.
6#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
7#[serde(transparent)]
8pub struct StepId(pub String);
9
10impl StepId {
11    pub fn new(s: impl Into<String>) -> Self {
12        Self(s.into())
13    }
14
15    pub fn as_str(&self) -> &str {
16        &self.0
17    }
18}
19
20impl From<&str> for StepId {
21    fn from(s: &str) -> Self {
22        Self(s.into())
23    }
24}
25
26/// Static structure of a workflow. Steps are stored separately from
27/// their adjacency to keep the runtime types simple.
28pub struct Dag<S> {
29    pub steps: BTreeMap<StepId, S>,
30    pub edges: HashMap<StepId, Vec<StepId>>,
31    pub entry: StepId,
32}
33
34impl<S> Dag<S> {
35    pub fn builder(entry: impl Into<StepId>) -> DagBuilder<S> {
36        DagBuilder {
37            steps: BTreeMap::new(),
38            edges: HashMap::new(),
39            entry: entry.into(),
40        }
41    }
42
43    /// Topological order of step ids. Errors on cycles.
44    pub fn topo_sort(&self) -> Result<Vec<StepId>> {
45        let mut indeg: HashMap<StepId, usize> = self.steps.keys().map(|k| (k.clone(), 0)).collect();
46        for tos in self.edges.values() {
47            for to in tos {
48                if let Some(d) = indeg.get_mut(to) {
49                    *d += 1;
50                }
51            }
52        }
53        let mut queue: Vec<StepId> = indeg
54            .iter()
55            .filter(|(_, d)| **d == 0)
56            .map(|(k, _)| k.clone())
57            .collect();
58        queue.sort();
59        let mut out = Vec::with_capacity(self.steps.len());
60        while let Some(n) = queue.pop() {
61            out.push(n.clone());
62            if let Some(succ) = self.edges.get(&n) {
63                for s in succ {
64                    if let Some(d) = indeg.get_mut(s) {
65                        *d -= 1;
66                        if *d == 0 {
67                            queue.push(s.clone());
68                        }
69                    }
70                }
71            }
72            queue.sort();
73        }
74        if out.len() != self.steps.len() {
75            return Err(AgentError::Workflow("dag has a cycle".into()));
76        }
77        Ok(out)
78    }
79}
80
81pub struct DagBuilder<S> {
82    steps: BTreeMap<StepId, S>,
83    edges: HashMap<StepId, Vec<StepId>>,
84    entry: StepId,
85}
86
87impl<S> DagBuilder<S> {
88    pub fn step(mut self, id: impl Into<StepId>, step: S) -> Self {
89        self.steps.insert(id.into(), step);
90        self
91    }
92
93    pub fn edge(mut self, from: impl Into<StepId>, to: impl Into<StepId>) -> Self {
94        self.edges.entry(from.into()).or_default().push(to.into());
95        self
96    }
97
98    pub fn build(self) -> Dag<S> {
99        Dag {
100            steps: self.steps,
101            edges: self.edges,
102            entry: self.entry,
103        }
104    }
105}