// atomr_agents_workflow/dag.rs
use std::collections::{BTreeMap, BinaryHeap, HashMap};

use atomr_agents_core::{AgentError, Result};
use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
7#[serde(transparent)]
8pub struct StepId(pub String);
9
10impl StepId {
11 pub fn new(s: impl Into<String>) -> Self {
12 Self(s.into())
13 }
14
15 pub fn as_str(&self) -> &str {
16 &self.0
17 }
18}
19
20impl From<&str> for StepId {
21 fn from(s: &str) -> Self {
22 Self(s.into())
23 }
24}
25
26pub struct Dag<S> {
29 pub steps: BTreeMap<StepId, S>,
30 pub edges: HashMap<StepId, Vec<StepId>>,
31 pub entry: StepId,
32}
33
34impl<S> Dag<S> {
35 pub fn builder(entry: impl Into<StepId>) -> DagBuilder<S> {
36 DagBuilder {
37 steps: BTreeMap::new(),
38 edges: HashMap::new(),
39 entry: entry.into(),
40 }
41 }
42
43 pub fn topo_sort(&self) -> Result<Vec<StepId>> {
45 let mut indeg: HashMap<StepId, usize> = self.steps.keys().map(|k| (k.clone(), 0)).collect();
46 for tos in self.edges.values() {
47 for to in tos {
48 if let Some(d) = indeg.get_mut(to) {
49 *d += 1;
50 }
51 }
52 }
53 let mut queue: Vec<StepId> = indeg
54 .iter()
55 .filter(|(_, d)| **d == 0)
56 .map(|(k, _)| k.clone())
57 .collect();
58 queue.sort();
59 let mut out = Vec::with_capacity(self.steps.len());
60 while let Some(n) = queue.pop() {
61 out.push(n.clone());
62 if let Some(succ) = self.edges.get(&n) {
63 for s in succ {
64 if let Some(d) = indeg.get_mut(s) {
65 *d -= 1;
66 if *d == 0 {
67 queue.push(s.clone());
68 }
69 }
70 }
71 }
72 queue.sort();
73 }
74 if out.len() != self.steps.len() {
75 return Err(AgentError::Workflow("dag has a cycle".into()));
76 }
77 Ok(out)
78 }
79}
80
81pub struct DagBuilder<S> {
82 steps: BTreeMap<StepId, S>,
83 edges: HashMap<StepId, Vec<StepId>>,
84 entry: StepId,
85}
86
87impl<S> DagBuilder<S> {
88 pub fn step(mut self, id: impl Into<StepId>, step: S) -> Self {
89 self.steps.insert(id.into(), step);
90 self
91 }
92
93 pub fn edge(mut self, from: impl Into<StepId>, to: impl Into<StepId>) -> Self {
94 self.edges.entry(from.into()).or_default().push(to.into());
95 self
96 }
97
98 pub fn build(self) -> Dag<S> {
99 Dag {
100 steps: self.steps,
101 edges: self.edges,
102 entry: self.entry,
103 }
104 }
105}