Skip to main content

lellm_graph/
session.rs

1//! ExecutionSession — 执行会话,持有 FrameStack,管理恢复。
2//!
3//! # 设计理念
4//!
5//! ```text
6//! ExecutionEngine — 一次执行,借用 State(生命周期短)
7//! ExecutionSession — 持有 FrameStack,管理恢复(生命周期长)
8//!
9//! 职责分离:
10//! - Engine: 执行逻辑,借用 State
11//! - Session: 状态所有权 + FrameStack + Checkpoint 管理
12//! ```
13//!
14//! # P0-1: Checkpoint Projection
15//!
16//! SessionCheckpoint 使用 `S::Checkpoint`(关联类型),不是 `S`(Runtime State)。
17//! 这保证 Runtime State 可以包含不可序列化字段。
18//!
19//! # P0-2: Graph Hash
20//!
21//! SessionCheckpoint 使用 `canonical_hash`(从 DSL 层计算),
22//! 不依赖 compiled graph 的 HashMap 迭代顺序。
23
24use std::fmt::Debug;
25use std::sync::Arc;
26
27use serde::{Deserialize, Serialize};
28
29use crate::checkpoint::FrameStack;
30use crate::graph::Graph;
31use crate::state::{State, StateMerge};
32use crate::workflow_state::{MergeStrategy, WorkflowState};
33
34// ─── SessionError ──────────────────────────────────────────────
35
36/// Session 操作错误。
37#[derive(Debug, thiserror::Error)]
38pub enum SessionError {
39    /// Graph Hash 不匹配 — Checkpoint 与当前 Graph 不兼容
40    #[error("graph hash mismatch: expected {expected:#018x}, got {actual:#018x}")]
41    GraphMismatch { expected: u64, actual: u64 },
42}
43
44impl PartialEq for SessionError {
45    fn eq(&self, other: &Self) -> bool {
46        match (self, other) {
47            (
48                SessionError::GraphMismatch {
49                    expected: e1,
50                    actual: a1,
51                },
52                SessionError::GraphMismatch {
53                    expected: e2,
54                    actual: a2,
55                },
56            ) => e1 == e2 && a1 == a2,
57        }
58    }
59}
60
61// ─── SessionCheckpoint ─────────────────────────────────────────
62
63/// 会话检查点 — 完整恢复快照。
64///
65/// 包含:
66/// - 状态投影(P0-1: `S::Checkpoint`)
67/// - FrameStack(执行位置历史)
68/// - graph_hash(P0-2: canonical hash)
69///
70/// 可序列化 — 用于持久化到文件/数据库。
71///
72/// # 与 Checkpoint 的区别
73///
74/// - `Checkpoint<S>` — 单个 Graph 的检查点(current_node + state)
75/// - `SessionCheckpoint<S>` — 完整会话的检查点(state + frames + graph_hash)
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct SessionCheckpoint<S: WorkflowState = State>
78where
79    S::Checkpoint: Debug,
80{
81    /// 物化状态快照(P0-1: 使用 Checkpoint 关联类型)
82    pub state: S::Checkpoint,
83    /// 执行帧栈(完整执行位置历史)
84    pub frames: FrameStack<S>,
85    /// 图结构指纹(P0-2: canonical hash)
86    pub graph_hash: u64,
87}
88
89// ─── ExecutionSession ──────────────────────────────────────────
90
91/// 执行会话 — 持有 State 所有权 + FrameStack + Graph 引用。
92///
93/// # 职责
94///
95/// - 持有 State 所有权(Engine 只是借用)
96/// - 管理 FrameStack(Subgraph 执行时 push/pop)
97/// - 创建和恢复 SessionCheckpoint
98///
99/// # 设计原则
100///
101/// Graph 是 Immutable 的,多个 Session 共享同一个 Graph 实例。
102/// Session 不拥有 Graph,只持有 `Arc<Graph>` 引用。
103///
104/// ```text
105/// Runtime
106/// └── Arc<Graph>
107///
108/// Session1 ──┐
109/// Session2 ──┼── Arc<Graph>
110/// Session3 ──┘
111/// ```
112pub struct ExecutionSession<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge>
113where
114    S::Checkpoint: Debug,
115{
116    /// 运行时状态(拥有所有权)
117    state: S,
118    /// 执行帧栈(Subgraph 执行时 push/pop)
119    frame_stack: FrameStack<S>,
120    /// 图结构(共享引用)
121    graph: Arc<Graph<S, M>>,
122}
123
124impl<S: WorkflowState, M: MergeStrategy<S>> ExecutionSession<S, M>
125where
126    S::Checkpoint: Debug,
127{
128    /// 创建新的执行会话。
129    pub fn new(state: S, graph: Arc<Graph<S, M>>) -> Self {
130        Self {
131            state,
132            frame_stack: FrameStack::new(),
133            graph,
134        }
135    }
136
137    /// 从 Checkpoint 恢复。
138    ///
139    /// # P0-1: 使用 restore() 恢复 State
140    ///
141    /// `S::restore(checkpoint.state)` 从 checkpoint snapshot 恢复完整 Runtime State。
142    ///
143    /// # Graph 参数
144    ///
145    /// 调用方负责提供 `Arc<Graph>`(从 Runtime 获取),
146    /// Session 不负责存储或查找 Graph。
147    ///
148    /// # 错误
149    ///
150    /// 如果 `checkpoint.graph_hash` 与 `graph.canonical_hash()` 不匹配,
151    /// 返回 `SessionError::GraphMismatch`,拒绝恢复。
152    pub fn restore(
153        checkpoint: SessionCheckpoint<S>,
154        graph: Arc<Graph<S, M>>,
155    ) -> Result<Self, SessionError> {
156        // P0-2: 校验 graph_hash — 不匹配则拒绝恢复
157        if checkpoint.graph_hash != graph.canonical_hash() {
158            return Err(SessionError::GraphMismatch {
159                expected: checkpoint.graph_hash,
160                actual: graph.canonical_hash(),
161            });
162        }
163
164        let state = S::restore(checkpoint.state);
165        Ok(Self {
166            state,
167            frame_stack: checkpoint.frames,
168            graph,
169        })
170    }
171
172    /// 创建 checkpoint — 保存当前执行位置 + 状态投影。
173    ///
174    /// # P0-1: 使用 snapshot() 进行投影
175    ///
176    /// `state.snapshot()` 返回 `S::Checkpoint`,只序列化必要字段。
177    ///
178    /// # P0-2: 使用 canonical_hash
179    ///
180    /// `graph.canonical_hash()` 从 DSL 层计算,不依赖 HashMap 顺序。
181    pub fn checkpoint(&self) -> SessionCheckpoint<S> {
182        SessionCheckpoint {
183            state: self.state.snapshot(),
184            frames: self.frame_stack.clone(),
185            graph_hash: self.graph.canonical_hash(),
186        }
187    }
188
189    /// 获取状态引用。
190    pub fn state(&self) -> &S {
191        &self.state
192    }
193
194    /// 获取状态可变引用。
195    pub fn state_mut(&mut self) -> &mut S {
196        &mut self.state
197    }
198
199    /// 获取帧栈引用。
200    pub fn frame_stack(&self) -> &FrameStack<S> {
201        &self.frame_stack
202    }
203
204    /// 获取帧栈可变引用(用于 Subgraph 执行时 push/pop)。
205    pub fn frame_stack_mut(&mut self) -> &mut FrameStack<S> {
206        &mut self.frame_stack
207    }
208
209    /// 获取图引用。
210    pub fn graph(&self) -> &Graph<S, M> {
211        &self.graph
212    }
213
214    /// 获取图的 Arc 引用(用于共享)。
215    pub fn graph_arc(&self) -> Arc<Graph<S, M>> {
216        self.graph.clone()
217    }
218
219    /// 消费会话,返回最终状态。
220    pub fn into_state(self) -> S {
221        self.state
222    }
223
224    /// 消费会话,返回 (状态, 帧栈)。
225    pub fn into_parts(self) -> (S, FrameStack<S>) {
226        (self.state, self.frame_stack)
227    }
228}
229
230impl<S: WorkflowState, M: MergeStrategy<S>> ExecutionSession<S, M>
231where
232    S::Checkpoint: Debug,
233{
234    /// 使用指定的 Engine 执行。
235    ///
236    /// Session 不知道 Stream,Engine 才知道 Stream。
237    /// 职责分离:Session 负责 state + frame_stack,Engine 负责执行 + stream。
238    ///
239    /// # 示例
240    ///
241    /// ```ignore
242    /// let mut engine = ExecutionEngine::new(
243    ///     &mut session.state,
244    ///     Some(stream),  // Stream 由调用者提供
245    ///     cancel,
246    /// );
247    /// session.run_with(&mut engine).await?;
248    /// ```
249    pub async fn run_with(
250        &mut self,
251        engine: &mut crate::ExecutionEngine<'_, S>,
252    ) -> Result<(), crate::GraphError> {
253        self.graph.run_inline(engine, usize::MAX).await
254    }
255}
256
257// ─── Default for ExecutionSession ──────────────────────────────
258
259impl<S: WorkflowState, M: MergeStrategy<S>> Default for ExecutionSession<S, M>
260where
261    S: Default,
262    S::Checkpoint: Debug,
263{
264    fn default() -> Self {
265        // 注意:Default 需要一个 Graph,这里用空图占位
266        // 实际使用时应该用 new(state, graph)
267        panic!("ExecutionSession::default() not supported — use new(state, graph)")
268    }
269}
270
271// ─── Tests ─────────────────────────────────────────────────────
272
273#[cfg(test)]
274mod tests {
275    use super::*;
276    use crate::state::StateExt;
277    use crate::{GraphBuilder, NodeKind, TaskNode};
278
279    #[test]
280    fn test_session_checkpoint_roundtrip() {
281        // 创建一个简单的 Graph
282        let mut builder = GraphBuilder::<State, StateMerge>::new("test");
283        builder.start("a");
284        builder.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
285        builder.end("a");
286        let graph = Arc::new(builder.build().unwrap());
287
288        // 创建 Session
289        let state = State::new();
290        let mut session = ExecutionSession::new(state, graph.clone());
291
292        // 添加一些数据到 state
293        session
294            .state_mut()
295            .insert("key".to_string(), serde_json::json!("value"));
296
297        // 创建 checkpoint
298        let checkpoint = session.checkpoint();
299
300        // 验证 checkpoint 包含状态
301        assert!(checkpoint.state.contains("key"));
302
303        // 从 checkpoint 恢复
304        let restored_session =
305            ExecutionSession::restore(checkpoint, graph).expect("restore should succeed");
306
307        // 验证恢复后的状态
308        assert!(restored_session.state().contains("key"));
309    }
310
311    #[test]
312    fn test_session_restore_graph_mismatch() {
313        // 验证 graph_hash 不匹配时返回错误
314        let mut builder1 = GraphBuilder::<State, StateMerge>::new("test1");
315        builder1.start("a");
316        builder1.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
317        builder1.end("a");
318        builder1.canonical_hash(0x1111); // 设置不同的 hash
319        let graph1 = Arc::new(builder1.build().unwrap());
320
321        let mut builder2 = GraphBuilder::<State, StateMerge>::new("test2");
322        builder2.start("b");
323        builder2.node("b", NodeKind::Task(TaskNode::new("b", |_| Ok(()))));
324        builder2.end("b");
325        builder2.canonical_hash(0x2222); // 设置不同的 hash
326        let graph2 = Arc::new(builder2.build().unwrap());
327
328        // 用 graph1 创建 checkpoint
329        let session = ExecutionSession::new(State::new(), graph1);
330        let checkpoint = session.checkpoint();
331
332        // 用 graph2 恢复 — 应该失败
333        let result = ExecutionSession::restore(checkpoint, graph2);
334        assert!(result.is_err());
335        // 验证错误信息包含 "graph hash mismatch"
336        match result {
337            Err(e) => assert!(format!("{}", e).contains("graph hash mismatch")),
338            Ok(_) => panic!("expected error"),
339        }
340    }
341
342    #[test]
343    fn test_session_into_parts() {
344        // 创建一个简单的 Graph
345        let mut builder = GraphBuilder::<State, StateMerge>::new("test");
346        builder.start("a");
347        builder.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
348        builder.end("a");
349        let graph = Arc::new(builder.build().unwrap());
350
351        // 创建 Session
352        let state = State::new();
353        let session = ExecutionSession::new(state, graph);
354
355        // 消费 session,获取 state 和 frame_stack
356        let (_state, frame_stack) = session.into_parts();
357
358        // 验证 frame_stack 为空
359        assert!(frame_stack.is_empty());
360    }
361
362    #[test]
363    fn test_session_graph_sharing() {
364        // 验证多个 Session 共享同一个 Graph
365        let mut builder = GraphBuilder::<State, StateMerge>::new("test");
366        builder.start("a");
367        builder.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
368        builder.end("a");
369        let graph = Arc::new(builder.build().unwrap());
370
371        let session1 = ExecutionSession::new(State::new(), graph.clone());
372        let session2 = ExecutionSession::new(State::new(), graph.clone());
373
374        // 验证 Arc 强引用计数
375        assert_eq!(Arc::strong_count(&graph), 3); // original + session1 + session2
376    }
377}