1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3
4use nuro_core::{NuroError, Result};
5
6use crate::{GraphNode, GraphStateTrait, NodeContext};
7
8pub trait Checkpointer<S>: Send + Sync
13where
14 S: GraphStateTrait,
15{
16 fn save_state(&self, node_id: &str, state: &S) -> Result<()>;
18
19 fn load_state(&self, _node_id: &str) -> Result<Option<S>> {
23 Ok(None)
24 }
25}
26
27pub struct InMemoryCheckpointer<S>
33where
34 S: GraphStateTrait,
35{
36 inner: Mutex<HashMap<String, S>>,
37}
38
39impl<S> InMemoryCheckpointer<S>
40where
41 S: GraphStateTrait,
42{
43 pub fn new() -> Self {
44 Self {
45 inner: Mutex::new(HashMap::new()),
46 }
47 }
48
49 pub fn get(&self, node_id: &str) -> Option<S> {
50 self.inner
51 .lock()
52 .ok()
53 .and_then(|m| m.get(node_id).cloned())
54 }
55}
56
57impl<S> Checkpointer<S> for InMemoryCheckpointer<S>
58where
59 S: GraphStateTrait,
60{
61 fn save_state(&self, node_id: &str, state: &S) -> Result<()> {
62 let mut guard = self
63 .inner
64 .lock()
65 .map_err(|_| NuroError::InvalidInput("failed to lock InMemoryCheckpointer".into()))?;
66 guard.insert(node_id.to_string(), state.clone());
67 Ok(())
68 }
69
70 fn load_state(&self, node_id: &str) -> Result<Option<S>> {
71 let guard = self
72 .inner
73 .lock()
74 .map_err(|_| NuroError::InvalidInput("failed to lock InMemoryCheckpointer".into()))?;
75 Ok(guard.get(node_id).cloned())
76 }
77}
78
79pub struct StateGraph<S>
81where
82 S: GraphStateTrait,
83{
84 nodes: HashMap<String, Arc<dyn GraphNode<S>>>,
85 edges: HashMap<String, Vec<String>>, conditional_edges: HashMap<String, ConditionalEdge<S>>, entry: Option<String>,
88 finish: Option<String>,
89}
90
91struct ConditionalEdge<S>
92where
93 S: GraphStateTrait,
94{
95 router: Arc<dyn Fn(&S) -> String + Send + Sync>,
96 routes: HashMap<String, String>,
97}
98
99impl<S> StateGraph<S>
100where
101 S: GraphStateTrait,
102{
103 pub fn new() -> Self {
104 Self {
105 nodes: HashMap::new(),
106 edges: HashMap::new(),
107 conditional_edges: HashMap::new(),
108 entry: None,
109 finish: None,
110 }
111 }
112
113 pub fn add_node<N>(mut self, id: impl Into<String>, node: N) -> Self
115 where
116 N: GraphNode<S> + 'static,
117 {
118 let id = id.into();
119 self.nodes.insert(id, Arc::new(node));
120 self
121 }
122
123 pub fn add_edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
125 let from = from.into();
126 let to = to.into();
127 self.edges.entry(from).or_default().push(to);
128 self
129 }
130
131 pub fn add_conditional_edge(
136 mut self,
137 from: impl Into<String>,
138 router: impl Fn(&S) -> String + Send + Sync + 'static,
139 routes: HashMap<String, String>,
140 ) -> Self {
141 let from = from.into();
142 let edge = ConditionalEdge {
143 router: Arc::new(router),
144 routes,
145 };
146 self.conditional_edges.insert(from, edge);
147 self
148 }
149
150 pub fn set_entry_point(mut self, id: impl Into<String>) -> Self {
152 self.entry = Some(id.into());
153 self
154 }
155
156 pub fn set_finish_point(mut self, id: impl Into<String>) -> Self {
158 self.finish = Some(id.into());
159 self
160 }
161
162 pub fn compile(self) -> Result<CompiledGraph<S>> {
169 let entry = self
170 .entry
171 .ok_or_else(|| NuroError::InvalidInput("entry point is not set".into()))?;
172
173 if !self.nodes.contains_key(&entry) {
174 return Err(NuroError::InvalidInput(format!(
175 "entry node '{}' not found in graph",
176 entry
177 )));
178 }
179
180 if let Some(ref finish) = self.finish {
181 if !self.nodes.contains_key(finish) {
182 return Err(NuroError::InvalidInput(format!(
183 "finish node '{}' not found in graph",
184 finish
185 )));
186 }
187 }
188
189 for (from, tos) in &self.edges {
191 if !self.nodes.contains_key(from) {
192 return Err(NuroError::InvalidInput(format!(
193 "edge references unknown source node '{}'",
194 from
195 )));
196 }
197 for to in tos {
198 if !self.nodes.contains_key(to) {
199 return Err(NuroError::InvalidInput(format!(
200 "edge from '{}' references unknown target node '{}'",
201 from, to
202 )));
203 }
204 }
205 }
206
207 for (from, cond) in &self.conditional_edges {
209 if !self.nodes.contains_key(from) {
210 return Err(NuroError::InvalidInput(format!(
211 "conditional edge references unknown source node '{}'",
212 from
213 )));
214 }
215 for (key, to) in &cond.routes {
216 if !self.nodes.contains_key(to) {
217 return Err(NuroError::InvalidInput(format!(
218 "conditional edge from '{}' with route key '{}' references unknown target node '{}'",
219 from, key, to
220 )));
221 }
222 }
223 }
224
225 Ok(CompiledGraph {
226 entry,
227 finish: self.finish,
228 nodes: self.nodes,
229 edges: self.edges,
230 conditional_edges: self.conditional_edges,
231 checkpointer: None,
232 })
233 }
234}
235
236pub struct CompiledGraph<S>
238where
239 S: GraphStateTrait,
240{
241 entry: String,
242 finish: Option<String>,
243 nodes: HashMap<String, Arc<dyn GraphNode<S>>>,
244 edges: HashMap<String, Vec<String>>,
245 conditional_edges: HashMap<String, ConditionalEdge<S>>,
246 checkpointer: Option<Arc<dyn Checkpointer<S>>>,
247}
248
249impl<S> CompiledGraph<S>
250where
251 S: GraphStateTrait,
252{
253 pub fn with_checkpointer<C>(mut self, checkpointer: C) -> Self
255 where
256 C: Checkpointer<S> + 'static,
257 {
258 self.checkpointer = Some(Arc::new(checkpointer));
259 self
260 }
261
262 pub async fn invoke(&self, mut state: S) -> Result<S> {
269 let mut ctx = NodeContext::new();
270 let mut current = self.entry.clone();
271
272 loop {
273 let node = self.nodes.get(¤t).ok_or_else(|| {
274 NuroError::InvalidInput(format!("node '{}' not found in compiled graph", current))
275 })?;
276
277 let update = node.run(&state, &mut ctx).await?;
278 state.apply_update(update);
279
280 if let Some(cp) = &self.checkpointer {
281 cp.save_state(¤t, &state)?;
282 }
283
284 if let Some(ref finish) = self.finish {
285 if ¤t == finish {
286 break;
287 }
288 }
289
290 if let Some(cond) = self.conditional_edges.get(¤t) {
292 let key = (cond.router)(&state);
293 if let Some(next) = cond.routes.get(&key) {
294 current = next.clone();
295 continue;
296 }
297 }
298
299 if let Some(nexts) = self.edges.get(¤t) {
301 if let Some(next) = nexts.first() {
302 current = next.clone();
303 continue;
304 }
305 }
306
307 break;
309 }
310
311 Ok(state)
312 }
313
314 pub async fn resume(&self, node_id: &str) -> Result<S> {
321 let cp = self
322 .checkpointer
323 .as_ref()
324 .ok_or_else(|| NuroError::InvalidInput("cannot resume without a checkpointer".into()))?;
325
326 let state = cp
327 .load_state(node_id)?
328 .ok_or_else(|| NuroError::InvalidInput(format!(
329 "no checkpoint found for node '{}'",
330 node_id
331 )))?;
332
333 self.invoke(state).await
334 }
335}