Skip to main content

nuro_graph/
graph.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3
4use nuro_core::{NuroError, Result};
5
6use crate::{GraphNode, GraphStateTrait, NodeContext};
7
8/// 状态检查点存储抽象。
9///
10/// 当前仅在 `InMemoryCheckpointer` 中用于开发/调试场景,未来可以替换为
11/// 基于数据库或对象存储的持久化实现。
12pub trait Checkpointer<S>: Send + Sync
13where
14    S: GraphStateTrait,
15{
16    /// 保存指定节点执行后的完整状态快照。
17    fn save_state(&self, node_id: &str, state: &S) -> Result<()>;
18
19    /// 加载某个节点最近一次保存的状态快照。
20    ///
21    /// 默认实现返回 `Ok(None)`,表示未找到对应检查点。
22    fn load_state(&self, _node_id: &str) -> Result<Option<S>> {
23        Ok(None)
24    }
25}
26
27/// 纯内存版检查点存储,用于开发与测试。
28///
29/// - 按节点 id 存储最新一次执行后的状态快照;
30/// - 使用互斥锁保证线程安全;
31/// - 仅适用于单进程、低并发场景。
32pub struct InMemoryCheckpointer<S>
33where
34    S: GraphStateTrait,
35{
36    inner: Mutex<HashMap<String, S>>,
37}
38
39impl<S> InMemoryCheckpointer<S>
40where
41    S: GraphStateTrait,
42{
43    pub fn new() -> Self {
44        Self {
45            inner: Mutex::new(HashMap::new()),
46        }
47    }
48
49    pub fn get(&self, node_id: &str) -> Option<S> {
50        self.inner
51            .lock()
52            .ok()
53            .and_then(|m| m.get(node_id).cloned())
54    }
55}
56
57impl<S> Checkpointer<S> for InMemoryCheckpointer<S>
58where
59    S: GraphStateTrait,
60{
61    fn save_state(&self, node_id: &str, state: &S) -> Result<()> {
62        let mut guard = self
63            .inner
64            .lock()
65            .map_err(|_| NuroError::InvalidInput("failed to lock InMemoryCheckpointer".into()))?;
66        guard.insert(node_id.to_string(), state.clone());
67        Ok(())
68    }
69
70    fn load_state(&self, node_id: &str) -> Result<Option<S>> {
71        let guard = self
72            .inner
73            .lock()
74            .map_err(|_| NuroError::InvalidInput("failed to lock InMemoryCheckpointer".into()))?;
75        Ok(guard.get(node_id).cloned())
76    }
77}
78
79/// 构建中的状态图。
80pub struct StateGraph<S>
81where
82    S: GraphStateTrait,
83{
84    nodes: HashMap<String, Arc<dyn GraphNode<S>>>,
85    edges: HashMap<String, Vec<String>>,             // 普通有向边
86    conditional_edges: HashMap<String, ConditionalEdge<S>>, // 条件边
87    entry: Option<String>,
88    finish: Option<String>,
89}
90
91struct ConditionalEdge<S>
92where
93    S: GraphStateTrait,
94{
95    router: Arc<dyn Fn(&S) -> String + Send + Sync>,
96    routes: HashMap<String, String>,
97}
98
99impl<S> StateGraph<S>
100where
101    S: GraphStateTrait,
102{
103    pub fn new() -> Self {
104        Self {
105            nodes: HashMap::new(),
106            edges: HashMap::new(),
107            conditional_edges: HashMap::new(),
108            entry: None,
109            finish: None,
110        }
111    }
112
113    /// 添加一个节点。
114    pub fn add_node<N>(mut self, id: impl Into<String>, node: N) -> Self
115    where
116        N: GraphNode<S> + 'static,
117    {
118        let id = id.into();
119        self.nodes.insert(id, Arc::new(node));
120        self
121    }
122
123    /// 添加一条普通有向边 `from -> to`。
124    pub fn add_edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
125        let from = from.into();
126        let to = to.into();
127        self.edges.entry(from).or_default().push(to);
128        self
129    }
130
131    /// 添加一条条件边:
132    ///
133    /// - `router` 根据当前状态返回一个路由 key;
134    /// - `routes` 将 key 映射到下一跳节点 id。
135    pub fn add_conditional_edge(
136        mut self,
137        from: impl Into<String>,
138        router: impl Fn(&S) -> String + Send + Sync + 'static,
139        routes: HashMap<String, String>,
140    ) -> Self {
141        let from = from.into();
142        let edge = ConditionalEdge {
143            router: Arc::new(router),
144            routes,
145        };
146        self.conditional_edges.insert(from, edge);
147        self
148    }
149
150    /// 设置入口节点 id。
151    pub fn set_entry_point(mut self, id: impl Into<String>) -> Self {
152        self.entry = Some(id.into());
153        self
154    }
155
156    /// 设置结束节点 id。
157    pub fn set_finish_point(mut self, id: impl Into<String>) -> Self {
158        self.finish = Some(id.into());
159        self
160    }
161
162    /// 编译为只读的 `CompiledGraph`,用于实际运行。
163    ///
164    /// 在编译阶段会做若干完整性检查:
165    /// - 入口节点必须存在;
166    /// - 如果设置了结束节点,则结束节点也必须存在;
167    /// - 所有边的起点/终点都必须在节点集合中出现。
168    pub fn compile(self) -> Result<CompiledGraph<S>> {
169        let entry = self
170            .entry
171            .ok_or_else(|| NuroError::InvalidInput("entry point is not set".into()))?;
172
173        if !self.nodes.contains_key(&entry) {
174            return Err(NuroError::InvalidInput(format!(
175                "entry node '{}' not found in graph",
176                entry
177            )));
178        }
179
180        if let Some(ref finish) = self.finish {
181            if !self.nodes.contains_key(finish) {
182                return Err(NuroError::InvalidInput(format!(
183                    "finish node '{}' not found in graph",
184                    finish
185                )));
186            }
187        }
188
189        // 校验普通边引用的节点是否存在。
190        for (from, tos) in &self.edges {
191            if !self.nodes.contains_key(from) {
192                return Err(NuroError::InvalidInput(format!(
193                    "edge references unknown source node '{}'",
194                    from
195                )));
196            }
197            for to in tos {
198                if !self.nodes.contains_key(to) {
199                    return Err(NuroError::InvalidInput(format!(
200                        "edge from '{}' references unknown target node '{}'",
201                        from, to
202                    )));
203                }
204            }
205        }
206
207        // 校验条件边引用的节点是否存在。
208        for (from, cond) in &self.conditional_edges {
209            if !self.nodes.contains_key(from) {
210                return Err(NuroError::InvalidInput(format!(
211                    "conditional edge references unknown source node '{}'",
212                    from
213                )));
214            }
215            for (key, to) in &cond.routes {
216                if !self.nodes.contains_key(to) {
217                    return Err(NuroError::InvalidInput(format!(
218                        "conditional edge from '{}' with route key '{}' references unknown target node '{}'",
219                        from, key, to
220                    )));
221                }
222            }
223        }
224
225        Ok(CompiledGraph {
226            entry,
227            finish: self.finish,
228            nodes: self.nodes,
229            edges: self.edges,
230            conditional_edges: self.conditional_edges,
231            checkpointer: None,
232        })
233    }
234}
235
236/// 已编译完成、可执行的状态图。
237pub struct CompiledGraph<S>
238where
239    S: GraphStateTrait,
240{
241    entry: String,
242    finish: Option<String>,
243    nodes: HashMap<String, Arc<dyn GraphNode<S>>>,
244    edges: HashMap<String, Vec<String>>,
245    conditional_edges: HashMap<String, ConditionalEdge<S>>,
246    checkpointer: Option<Arc<dyn Checkpointer<S>>>,
247}
248
249impl<S> CompiledGraph<S>
250where
251    S: GraphStateTrait,
252{
253    /// 挂载一个检查点存储实现。图的执行并不强依赖检查点存在。
254    pub fn with_checkpointer<C>(mut self, checkpointer: C) -> Self
255    where
256        C: Checkpointer<S> + 'static,
257    {
258        self.checkpointer = Some(Arc::new(checkpointer));
259        self
260    }
261
262    /// 按有向图顺序依次执行节点:
263    /// - 从 entry 节点开始;
264    /// - 每个节点运行后通过 `apply_update` 合并状态;
265    /// - 若存在条件边,则优先根据 router 结果选下一跳;
266    /// - 否则选择第一条普通出边;
267    /// - 若到达 finish 节点或无出边,则结束执行。
268    pub async fn invoke(&self, mut state: S) -> Result<S> {
269        let mut ctx = NodeContext::new();
270        let mut current = self.entry.clone();
271
272        loop {
273            let node = self.nodes.get(&current).ok_or_else(|| {
274                NuroError::InvalidInput(format!("node '{}' not found in compiled graph", current))
275            })?;
276
277            let update = node.run(&state, &mut ctx).await?;
278            state.apply_update(update);
279
280            if let Some(cp) = &self.checkpointer {
281                cp.save_state(&current, &state)?;
282            }
283
284            if let Some(ref finish) = self.finish {
285                if &current == finish {
286                    break;
287                }
288            }
289
290            // 条件路由优先。
291            if let Some(cond) = self.conditional_edges.get(&current) {
292                let key = (cond.router)(&state);
293                if let Some(next) = cond.routes.get(&key) {
294                    current = next.clone();
295                    continue;
296                }
297            }
298
299            // 普通有向边(按插入顺序取第一条)。
300            if let Some(nexts) = self.edges.get(&current) {
301                if let Some(next) = nexts.first() {
302                    current = next.clone();
303                    continue;
304                }
305            }
306
307            // 没有出边时终止。
308            break;
309        }
310
311        Ok(state)
312    }
313
314    /// 从某个节点的检查点恢复执行的占位接口:
315    ///
316    /// - 若未挂载 `Checkpointer`,返回错误;
317    /// - 若找不到对应节点的检查点,同样返回错误;
318    /// - 目前的实现会从加载出的状态重新执行整张图(仍从 entry 开始),
319    ///   未来可以扩展为从任意节点继续执行。
320    pub async fn resume(&self, node_id: &str) -> Result<S> {
321        let cp = self
322            .checkpointer
323            .as_ref()
324            .ok_or_else(|| NuroError::InvalidInput("cannot resume without a checkpointer".into()))?;
325
326        let state = cp
327            .load_state(node_id)?
328            .ok_or_else(|| NuroError::InvalidInput(format!(
329                "no checkpoint found for node '{}'",
330                node_id
331            )))?;
332
333        self.invoke(state).await
334    }
335}