Skip to main content

lellm_graph/exec/
session.rs

1//! ExecutionSession — 执行会话,持有 FrameStack,管理恢复。
2//!
3//! # 设计理念
4//!
5//! ```text
6//! ExecutionEngine — 一次执行,借用 State(生命周期短)
7//! ExecutionSession — 持有 FrameStack,管理恢复(生命周期长)
8//!
9//! 职责分离:
10//! - Engine: 执行逻辑,借用 State
11//! - Session: 状态所有权 + FrameStack + Checkpoint 管理
12//! ```
13//!
14//! # P0-1: Checkpoint Projection
15//!
16//! SessionCheckpoint 使用 `S::Checkpoint`(关联类型),不是 `S`(Runtime State)。
17//! 这保证 Runtime State 可以包含不可序列化字段。
18//!
19//! # P0-2: Graph Hash
20//!
21//! SessionCheckpoint 使用 `canonical_hash`(从 DSL 层计算),
22//! 不依赖 compiled graph 的 HashMap 迭代顺序。
23
24use std::fmt::Debug;
25use std::sync::Arc;
26
27use serde::{Deserialize, Serialize};
28
29use crate::checkpoint::{CheckpointSink, Frame, FrameInfo, FrameStack};
30use crate::graph::Graph;
31use crate::state::workflow_state::{MergeStrategy, WorkflowState};
32use crate::state::{State, StateMerge};
33
34// ─── SessionError ──────────────────────────────────────────────
35
36/// Session 操作错误。
37#[derive(Debug, thiserror::Error)]
38pub enum SessionError {
39    /// Graph Hash 不匹配 — Checkpoint 与当前 Graph 不兼容
40    #[error("graph hash mismatch: expected {expected:#018x}, got {actual:#018x}")]
41    GraphMismatch { expected: u64, actual: u64 },
42}
43
44impl PartialEq for SessionError {
45    fn eq(&self, other: &Self) -> bool {
46        match (self, other) {
47            (
48                SessionError::GraphMismatch {
49                    expected: e1,
50                    actual: a1,
51                },
52                SessionError::GraphMismatch {
53                    expected: e2,
54                    actual: a2,
55                },
56            ) => e1 == e2 && a1 == a2,
57        }
58    }
59}
60
61// ─── SessionCheckpointSink ─────────────────────────────────────
62
63/// Session 的 Checkpoint Sink — 将 checkpoint 事件写入 FrameStack。
64///
65/// 这是 CheckpointSink SPI 的实现之一。Engine 通过借用 `&dyn CheckpointSink<S>`
66/// 通知到达恢复边界,SessionCheckpointSink 负责将 Frame 推入 FrameStack。
67///
68/// # 设计原则
69///
70/// Engine 不知道 FrameStack 的存在,只调用 `sink.on_checkpoint(&state, &frame_info)`。
71/// SessionCheckpointSink 是适配器,将通用的 checkpoint 事件转换为 FrameStack 操作。
72pub struct SessionCheckpointSink<'a, S: WorkflowState = State>
73where
74    S::Checkpoint: Debug,
75{
76    frame_stack: &'a mut FrameStack<S>,
77    graph_name: String,
78}
79
80impl<'a, S: WorkflowState> SessionCheckpointSink<'a, S>
81where
82    S::Checkpoint: Debug,
83{
84    /// 创建 SessionCheckpointSink,绑定到 FrameStack。
85    pub fn new(frame_stack: &'a mut FrameStack<S>, graph_name: impl Into<String>) -> Self {
86        Self {
87            frame_stack,
88            graph_name: graph_name.into(),
89        }
90    }
91}
92
93impl<S: WorkflowState> CheckpointSink<S> for SessionCheckpointSink<'_, S>
94where
95    S::Checkpoint: Debug + Sync,
96{
97    fn on_checkpoint(&mut self, state: &S, frame: &FrameInfo) {
98        self.frame_stack.push(Frame {
99            graph_id: self.graph_name.clone(),
100            node_id: frame.node_id.clone(),
101            state: state.snapshot(),
102            cursor: frame.step,
103        });
104    }
105}
106
107// ─── SessionCheckpoint ─────────────────────────────────────────
108
109/// 会话检查点 — 完整恢复快照。
110///
111/// 包含:
112/// - 状态投影(P0-1: `S::Checkpoint`)
113/// - FrameStack(执行位置历史)
114/// - graph_hash(P0-2: canonical hash)
115///
116/// 可序列化 — 用于持久化到文件/数据库。
117///
118/// # 与 Checkpoint 的区别
119///
120/// - `Checkpoint<S>` — 单个 Graph 的检查点(current_node + state)
121/// - `SessionCheckpoint<S>` — 完整会话的检查点(state + frames + graph_hash)
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct SessionCheckpoint<S: WorkflowState = State>
124where
125    S::Checkpoint: Debug,
126{
127    /// 物化状态快照(P0-1: 使用 Checkpoint 关联类型)
128    pub state: S::Checkpoint,
129    /// 执行帧栈(完整执行位置历史)
130    pub frames: FrameStack<S>,
131    /// 图结构指纹(P0-2: canonical hash)
132    pub graph_hash: u64,
133}
134
135// ─── ExecutionSession ──────────────────────────────────────────
136
137/// 执行会话 — 持有 State 所有权 + FrameStack + Graph 引用。
138///
139/// # 职责
140///
141/// - 持有 State 所有权(Engine 只是借用)
142/// - 管理 FrameStack(Subgraph 执行时 push/pop)
143/// - 创建和恢复 SessionCheckpoint
144///
145/// # 设计原则
146///
147/// Graph 是 Immutable 的,多个 Session 共享同一个 Graph 实例。
148/// Session 不拥有 Graph,只持有 `Arc<Graph>` 引用。
149///
150/// ```text
151/// Runtime
152/// └── Arc<Graph>
153///
154/// Session1 ──┐
155/// Session2 ──┼── Arc<Graph>
156/// Session3 ──┘
157/// ```
158pub struct ExecutionSession<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge>
159where
160    S::Checkpoint: Debug,
161{
162    /// 运行时状态(拥有所有权)
163    state: S,
164    /// 执行帧栈(Subgraph 执行时 push/pop)
165    frame_stack: FrameStack<S>,
166    /// 图结构(共享引用)
167    graph: Arc<Graph<S, M>>,
168}
169
170impl<S: WorkflowState, M: MergeStrategy<S>> ExecutionSession<S, M>
171where
172    S::Checkpoint: Debug,
173{
174    /// 创建新的执行会话。
175    pub fn new(state: S, graph: Arc<Graph<S, M>>) -> Self {
176        Self {
177            state,
178            frame_stack: FrameStack::new(),
179            graph,
180        }
181    }
182
183    /// 从 Checkpoint 恢复。
184    ///
185    /// # P0-1: 使用 restore() 恢复 State
186    ///
187    /// `S::restore(checkpoint.state)` 从 checkpoint snapshot 恢复完整 Runtime State。
188    ///
189    /// # Graph 参数
190    ///
191    /// 调用方负责提供 `Arc<Graph>`(从 Runtime 获取),
192    /// Session 不负责存储或查找 Graph。
193    ///
194    /// # 错误
195    ///
196    /// 如果 `checkpoint.graph_hash` 与 `graph.canonical_hash()` 不匹配,
197    /// 返回 `SessionError::GraphMismatch`,拒绝恢复。
198    pub fn restore(
199        checkpoint: SessionCheckpoint<S>,
200        graph: Arc<Graph<S, M>>,
201    ) -> Result<Self, SessionError> {
202        // P0-2: 校验 graph_hash — 不匹配则拒绝恢复
203        if checkpoint.graph_hash != graph.canonical_hash() {
204            return Err(SessionError::GraphMismatch {
205                expected: checkpoint.graph_hash,
206                actual: graph.canonical_hash(),
207            });
208        }
209
210        let state = S::restore(checkpoint.state);
211        Ok(Self {
212            state,
213            frame_stack: checkpoint.frames,
214            graph,
215        })
216    }
217
218    /// 创建 checkpoint — 保存当前执行位置 + 状态投影。
219    ///
220    /// # P0-1: 使用 snapshot() 进行投影
221    ///
222    /// `state.snapshot()` 返回 `S::Checkpoint`,只序列化必要字段。
223    ///
224    /// # P0-2: 使用 canonical_hash
225    ///
226    /// `graph.canonical_hash()` 从 DSL 层计算,不依赖 HashMap 顺序。
227    pub fn checkpoint(&self) -> SessionCheckpoint<S> {
228        SessionCheckpoint {
229            state: self.state.snapshot(),
230            frames: self.frame_stack.clone(),
231            graph_hash: self.graph.canonical_hash(),
232        }
233    }
234
235    /// 获取状态引用。
236    pub fn state(&self) -> &S {
237        &self.state
238    }
239
240    /// 获取状态可变引用。
241    pub fn state_mut(&mut self) -> &mut S {
242        &mut self.state
243    }
244
245    /// 获取帧栈引用。
246    pub fn frame_stack(&self) -> &FrameStack<S> {
247        &self.frame_stack
248    }
249
250    /// 获取帧栈可变引用(用于 Subgraph 执行时 push/pop)。
251    pub fn frame_stack_mut(&mut self) -> &mut FrameStack<S> {
252        &mut self.frame_stack
253    }
254
255    /// 获取图引用。
256    pub fn graph(&self) -> &Graph<S, M> {
257        &self.graph
258    }
259
260    /// 获取图的 Arc 引用(用于共享)。
261    pub fn graph_arc(&self) -> Arc<Graph<S, M>> {
262        self.graph.clone()
263    }
264
265    /// 消费会话,返回最终状态。
266    pub fn into_state(self) -> S {
267        self.state
268    }
269
270    /// 消费会话,返回 (状态, 帧栈)。
271    pub fn into_parts(self) -> (S, FrameStack<S>) {
272        (self.state, self.frame_stack)
273    }
274}
275
276impl<S: WorkflowState, M: MergeStrategy<S>> ExecutionSession<S, M>
277where
278    S::Checkpoint: Debug,
279{
280    /// 使用指定的 Engine 执行。
281    ///
282    /// Session 不知道 Stream,Engine 才知道 Stream。
283    /// 职责分离:Session 负责 state + frame_stack,Engine 负责执行 + stream。
284    ///
285    /// # 示例
286    ///
287    /// ```ignore
288    /// // 创建 Checkpoint Sink(可选)
289    /// let mut sink = SessionCheckpointSink::new(
290    ///     session.frame_stack_mut(),
291    ///     session.graph().name(),
292    /// );
293    ///
294    /// let mut engine = ExecutionEngine::new(
295    ///     session.state_mut(),
296    ///     Some(stream),       // Stream 由调用者提供
297    ///     cancel,
298    ///     Some(&mut sink),    // 启用自动 checkpoint
299    ///     None,               // 不需要 Barrier Sink
300    /// );
301    /// session.run_with(&mut engine).await?;
302    /// ```
303    pub async fn run_with(
304        &mut self,
305        engine: &mut crate::ExecutionEngine<'_, S>,
306    ) -> Result<(), crate::GraphError> {
307        let mut cb = crate::graph::NoopStepCallback;
308        self.graph.run_inline(engine, usize::MAX, &mut cb).await
309    }
310}
311
312// ─── Default for ExecutionSession ──────────────────────────────
313
314impl<S: WorkflowState, M: MergeStrategy<S>> Default for ExecutionSession<S, M>
315where
316    S: Default,
317    S::Checkpoint: Debug,
318{
319    fn default() -> Self {
320        // 注意:Default 需要一个 Graph,这里用空图占位
321        // 实际使用时应该用 new(state, graph)
322        panic!("ExecutionSession::default() not supported — use new(state, graph)")
323    }
324}
325
326// ─── Tests ─────────────────────────────────────────────────────
327
328#[cfg(test)]
329mod tests {
330    use super::*;
331    use crate::state::StateExt;
332    use crate::{GraphBuilder, NodeKind, TaskNode};
333
334    #[test]
335    fn test_session_checkpoint_roundtrip() {
336        // 创建一个简单的 Graph
337        let mut builder = GraphBuilder::<State, StateMerge>::new("test");
338        builder.start("a");
339        builder.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
340        builder.end("a");
341        let graph = Arc::new(builder.build().unwrap());
342
343        // 创建 Session
344        let state = State::new();
345        let mut session = ExecutionSession::new(state, graph.clone());
346
347        // 添加一些数据到 state
348        session
349            .state_mut()
350            .insert("key".to_string(), serde_json::json!("value"));
351
352        // 创建 checkpoint
353        let checkpoint = session.checkpoint();
354
355        // 验证 checkpoint 包含状态
356        assert!(checkpoint.state.contains("key"));
357
358        // 从 checkpoint 恢复
359        let restored_session =
360            ExecutionSession::restore(checkpoint, graph).expect("restore should succeed");
361
362        // 验证恢复后的状态
363        assert!(restored_session.state().contains("key"));
364    }
365
366    #[test]
367    fn test_session_restore_graph_mismatch() {
368        // 验证 graph_hash 不匹配时返回错误
369        let mut builder1 = GraphBuilder::<State, StateMerge>::new("test1");
370        builder1.start("a");
371        builder1.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
372        builder1.end("a");
373        builder1.canonical_hash(0x1111); // 设置不同的 hash
374        let graph1 = Arc::new(builder1.build().unwrap());
375
376        let mut builder2 = GraphBuilder::<State, StateMerge>::new("test2");
377        builder2.start("b");
378        builder2.node("b", NodeKind::Task(TaskNode::new("b", |_| Ok(()))));
379        builder2.end("b");
380        builder2.canonical_hash(0x2222); // 设置不同的 hash
381        let graph2 = Arc::new(builder2.build().unwrap());
382
383        // 用 graph1 创建 checkpoint
384        let session = ExecutionSession::new(State::new(), graph1);
385        let checkpoint = session.checkpoint();
386
387        // 用 graph2 恢复 — 应该失败
388        let result = ExecutionSession::restore(checkpoint, graph2);
389        assert!(result.is_err());
390        // 验证错误信息包含 "graph hash mismatch"
391        match result {
392            Err(e) => assert!(format!("{}", e).contains("graph hash mismatch")),
393            Ok(_) => panic!("expected error"),
394        }
395    }
396
397    #[test]
398    fn test_session_into_parts() {
399        // 创建一个简单的 Graph
400        let mut builder = GraphBuilder::<State, StateMerge>::new("test");
401        builder.start("a");
402        builder.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
403        builder.end("a");
404        let graph = Arc::new(builder.build().unwrap());
405
406        // 创建 Session
407        let state = State::new();
408        let session = ExecutionSession::new(state, graph);
409
410        // 消费 session,获取 state 和 frame_stack
411        let (_state, frame_stack) = session.into_parts();
412
413        // 验证 frame_stack 为空
414        assert!(frame_stack.is_empty());
415    }
416
417    #[test]
418    fn test_session_graph_sharing() {
419        // 验证多个 Session 共享同一个 Graph
420        let mut builder = GraphBuilder::<State, StateMerge>::new("test");
421        builder.start("a");
422        builder.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
423        builder.end("a");
424        let graph = Arc::new(builder.build().unwrap());
425
426        let session1 = ExecutionSession::new(State::new(), graph.clone());
427        let session2 = ExecutionSession::new(State::new(), graph.clone());
428
429        // 验证 Arc 强引用计数
430        assert_eq!(Arc::strong_count(&graph), 3); // original + session1 + session2
431    }
432}