1use std::collections::HashMap;
4use std::sync::Arc;
5
6use cognis_core::{CognisError, Result};
7
8use crate::node::Node;
9use crate::state::GraphState;
10
11#[derive(Clone)]
14pub struct Graph<S: GraphState> {
15 pub(crate) nodes: HashMap<String, Arc<dyn Node<S>>>,
16 pub(crate) edges: HashMap<String, String>,
17 pub(crate) start: Option<String>,
18 pub(crate) version: Option<String>,
20 pub(crate) annotations: HashMap<String, HashMap<String, serde_json::Value>>,
23}
24
25impl<S: GraphState> Default for Graph<S> {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl<S: GraphState> Graph<S> {
32 pub fn new() -> Self {
34 Self {
35 nodes: HashMap::new(),
36 edges: HashMap::new(),
37 start: None,
38 version: None,
39 annotations: HashMap::new(),
40 }
41 }
42
43 pub fn annotate(
51 mut self,
52 node_name: impl Into<String>,
53 key: impl Into<String>,
54 value: impl Into<serde_json::Value>,
55 ) -> Self {
56 let node = node_name.into();
57 if !self.nodes.contains_key(&node) {
58 return self;
59 }
60 self.annotations
61 .entry(node)
62 .or_default()
63 .insert(key.into(), value.into());
64 self
65 }
66
67 pub fn with_version(mut self, v: impl Into<String>) -> Self {
71 self.version = Some(v.into());
72 self
73 }
74
75 pub fn node(mut self, name: impl Into<String>, node: impl Node<S> + 'static) -> Self {
78 self.nodes.insert(name.into(), Arc::new(node));
79 self
80 }
81
82 pub fn edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
87 self.edges.insert(from.into(), to.into());
88 self
89 }
90
91 pub fn start_at(mut self, name: impl Into<String>) -> Self {
93 self.start = Some(name.into());
94 self
95 }
96
97 pub fn compile(self) -> Result<crate::compiled::CompiledGraph<S>> {
99 crate::validate::validate(&self)?;
100 Ok(crate::compiled::CompiledGraph::new(self))
101 }
102}
103
104pub struct LinearBuilder {
111 stages: Vec<(String, Arc<dyn Node<()>>)>,
112}
113
114impl Default for LinearBuilder {
115 fn default() -> Self {
116 Self::new()
117 }
118}
119
120impl LinearBuilder {
121 pub fn new() -> Self {
123 Self { stages: Vec::new() }
124 }
125
126 pub fn then(mut self, node: impl Node<()> + 'static) -> Self {
128 let idx = self.stages.len().to_string();
129 self.stages.push((idx, Arc::new(node)));
130 self
131 }
132
133 pub fn compile(self) -> Result<crate::compiled::CompiledGraph<()>> {
135 if self.stages.is_empty() {
136 return Err(CognisError::Configuration(
137 "Graph::linear() requires at least one stage".into(),
138 ));
139 }
140 let mut g = Graph::<()>::new();
141 let last_idx = self.stages.len() - 1;
142 for (i, (name, node)) in self.stages.into_iter().enumerate() {
143 g.nodes.insert(name.clone(), node);
144 if i < last_idx {
145 let next = (i + 1).to_string();
146 g.edges.insert(name, next);
147 }
148 }
149 g.start = Some("0".to_string());
150 g.compile()
151 }
152}
153
154impl<S: GraphState> Graph<S> {
155 pub fn linear() -> LinearBuilder {
159 LinearBuilder::new()
160 }
161}
162
163impl GraphState for () {
166 type Update = ();
167 fn apply(&mut self, _: Self::Update) {}
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::goto::Goto;
    use crate::node::{node_fn, NodeOut};

    /// Test state: accumulates the `msg` of every applied update.
    #[derive(Default, Clone, Debug)]
    struct S {
        msg: String,
    }

    /// Update for `S`: text appended to `S::msg`.
    #[derive(Default)]
    struct SU {
        msg: String,
    }

    impl GraphState for S {
        type Update = SU;
        fn apply(&mut self, update: Self::Update) {
            self.msg.push_str(update.msg.as_str());
        }
    }

    #[test]
    fn build_with_nodes_and_start() {
        let first = node_fn::<S, _, _>("a", |_s, _c| async move {
            Ok(NodeOut {
                update: SU { msg: "a".into() },
                goto: Goto::node("b"),
            })
        });
        let second = node_fn::<S, _, _>("b", |_s, _c| async move {
            Ok(NodeOut::end_with(SU { msg: "b".into() }))
        });

        let graph = Graph::<S>::new()
            .node("a", first)
            .node("b", second)
            .start_at("a");

        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.start.as_deref(), Some("a"));
    }

    #[tokio::test]
    async fn linear_builder_chains_three_stages() {
        // Factory for identical stages that end the run immediately.
        let noop = || {
            node_fn::<(), _, _>("noop", |_s, _c| async move {
                Ok(NodeOut::goto_only(Goto::end()))
            })
        };

        let compiled = Graph::<()>::linear()
            .then(noop())
            .then(noop())
            .then(noop())
            .compile();

        assert!(compiled.is_ok());
    }

    #[test]
    fn linear_builder_rejects_empty() {
        assert!(LinearBuilder::new().compile().is_err());
    }
}
234}