1use std::fmt::Debug;
25use std::sync::Arc;
26
27use serde::{Deserialize, Serialize};
28
29use crate::checkpoint::{CheckpointSink, Frame, FrameInfo, FrameStack};
30use crate::graph::Graph;
31use crate::state::workflow_state::{MergeStrategy, WorkflowState};
32use crate::state::{State, StateMerge};
33
34#[derive(Debug, thiserror::Error)]
38pub enum SessionError {
39 #[error("graph hash mismatch: expected {expected:#018x}, got {actual:#018x}")]
41 GraphMismatch { expected: u64, actual: u64 },
42}
43
44impl PartialEq for SessionError {
45 fn eq(&self, other: &Self) -> bool {
46 match (self, other) {
47 (
48 SessionError::GraphMismatch {
49 expected: e1,
50 actual: a1,
51 },
52 SessionError::GraphMismatch {
53 expected: e2,
54 actual: a2,
55 },
56 ) => e1 == e2 && a1 == a2,
57 }
58 }
59}
60
61pub struct SessionCheckpointSink<'a, S: WorkflowState = State>
73where
74 S::Checkpoint: Debug,
75{
76 frame_stack: &'a mut FrameStack<S>,
77 graph_name: String,
78}
79
80impl<'a, S: WorkflowState> SessionCheckpointSink<'a, S>
81where
82 S::Checkpoint: Debug,
83{
84 pub fn new(frame_stack: &'a mut FrameStack<S>, graph_name: impl Into<String>) -> Self {
86 Self {
87 frame_stack,
88 graph_name: graph_name.into(),
89 }
90 }
91}
92
93impl<S: WorkflowState> CheckpointSink<S> for SessionCheckpointSink<'_, S>
94where
95 S::Checkpoint: Debug + Sync,
96{
97 fn on_checkpoint(&mut self, state: &S, frame: &FrameInfo) {
98 self.frame_stack.push(Frame {
99 graph_id: self.graph_name.clone(),
100 node_id: frame.node_id.clone(),
101 state: state.snapshot(),
102 cursor: frame.step,
103 });
104 }
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct SessionCheckpoint<S: WorkflowState = State>
124where
125 S::Checkpoint: Debug,
126{
127 pub state: S::Checkpoint,
129 pub frames: FrameStack<S>,
131 pub graph_hash: u64,
133}
134
135pub struct ExecutionSession<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge>
159where
160 S::Checkpoint: Debug,
161{
162 state: S,
164 frame_stack: FrameStack<S>,
166 graph: Arc<Graph<S, M>>,
168}
169
170impl<S: WorkflowState, M: MergeStrategy<S>> ExecutionSession<S, M>
171where
172 S::Checkpoint: Debug,
173{
174 pub fn new(state: S, graph: Arc<Graph<S, M>>) -> Self {
176 Self {
177 state,
178 frame_stack: FrameStack::new(),
179 graph,
180 }
181 }
182
183 pub fn restore(
199 checkpoint: SessionCheckpoint<S>,
200 graph: Arc<Graph<S, M>>,
201 ) -> Result<Self, SessionError> {
202 if checkpoint.graph_hash != graph.canonical_hash() {
204 return Err(SessionError::GraphMismatch {
205 expected: checkpoint.graph_hash,
206 actual: graph.canonical_hash(),
207 });
208 }
209
210 let state = S::restore(checkpoint.state);
211 Ok(Self {
212 state,
213 frame_stack: checkpoint.frames,
214 graph,
215 })
216 }
217
218 pub fn checkpoint(&self) -> SessionCheckpoint<S> {
228 SessionCheckpoint {
229 state: self.state.snapshot(),
230 frames: self.frame_stack.clone(),
231 graph_hash: self.graph.canonical_hash(),
232 }
233 }
234
235 pub fn state(&self) -> &S {
237 &self.state
238 }
239
240 pub fn state_mut(&mut self) -> &mut S {
242 &mut self.state
243 }
244
245 pub fn frame_stack(&self) -> &FrameStack<S> {
247 &self.frame_stack
248 }
249
250 pub fn frame_stack_mut(&mut self) -> &mut FrameStack<S> {
252 &mut self.frame_stack
253 }
254
255 pub fn graph(&self) -> &Graph<S, M> {
257 &self.graph
258 }
259
260 pub fn graph_arc(&self) -> Arc<Graph<S, M>> {
262 self.graph.clone()
263 }
264
265 pub fn into_state(self) -> S {
267 self.state
268 }
269
270 pub fn into_parts(self) -> (S, FrameStack<S>) {
272 (self.state, self.frame_stack)
273 }
274}
275
276impl<S: WorkflowState, M: MergeStrategy<S>> ExecutionSession<S, M>
277where
278 S::Checkpoint: Debug,
279{
280 pub async fn run_with(
304 &mut self,
305 engine: &mut crate::ExecutionEngine<'_, S>,
306 ) -> Result<(), crate::GraphError> {
307 let mut cb = crate::graph::NoopStepCallback;
308 self.graph.run_inline(engine, usize::MAX, &mut cb).await
309 }
310}
311
312impl<S: WorkflowState, M: MergeStrategy<S>> Default for ExecutionSession<S, M>
315where
316 S: Default,
317 S::Checkpoint: Debug,
318{
319 fn default() -> Self {
320 panic!("ExecutionSession::default() not supported — use new(state, graph)")
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::StateExt;
    use crate::{GraphBuilder, NodeKind, TaskNode};

    /// A value inserted before `checkpoint()` is visible both in the
    /// checkpoint and in the session restored from it.
    #[test]
    fn test_session_checkpoint_roundtrip() {
        let mut builder = GraphBuilder::<State, StateMerge>::new("test");
        builder.start("a");
        builder.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
        builder.end("a");
        let graph = Arc::new(builder.build().unwrap());

        let state = State::new();
        let mut session = ExecutionSession::new(state, Arc::clone(&graph));

        session
            .state_mut()
            .insert("key".to_string(), serde_json::json!("value"));

        let checkpoint = session.checkpoint();

        assert!(checkpoint.state.contains("key"));

        let restored_session =
            ExecutionSession::restore(checkpoint, graph).expect("restore should succeed");

        assert!(restored_session.state().contains("key"));
    }

    /// Restoring a checkpoint against a different graph is rejected with a
    /// hash-mismatch error.
    #[test]
    fn test_session_restore_graph_mismatch() {
        let mut builder1 = GraphBuilder::<State, StateMerge>::new("test1");
        builder1.start("a");
        builder1.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
        builder1.end("a");
        builder1.canonical_hash(0x1111);
        let graph1 = Arc::new(builder1.build().unwrap());

        let mut builder2 = GraphBuilder::<State, StateMerge>::new("test2");
        builder2.start("b");
        builder2.node("b", NodeKind::Task(TaskNode::new("b", |_| Ok(()))));
        builder2.end("b");
        builder2.canonical_hash(0x2222);
        let graph2 = Arc::new(builder2.build().unwrap());

        let session = ExecutionSession::new(State::new(), graph1);
        let checkpoint = session.checkpoint();

        // The match is exhaustive, so the `Ok` arm already fails the test; no
        // separate `is_err()` assertion is needed.
        match ExecutionSession::restore(checkpoint, graph2) {
            Err(e) => assert!(e.to_string().contains("graph hash mismatch")),
            Ok(_) => panic!("expected error"),
        }
    }

    /// A freshly created session decomposes into its state and an empty
    /// frame stack.
    #[test]
    fn test_session_into_parts() {
        let mut builder = GraphBuilder::<State, StateMerge>::new("test");
        builder.start("a");
        builder.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
        builder.end("a");
        let graph = Arc::new(builder.build().unwrap());

        let state = State::new();
        let session = ExecutionSession::new(state, graph);

        let (_state, frame_stack) = session.into_parts();

        assert!(frame_stack.is_empty());
    }

    /// Two sessions built from clones of one `Arc` share the same graph
    /// allocation.
    #[test]
    fn test_session_graph_sharing() {
        let mut builder = GraphBuilder::<State, StateMerge>::new("test");
        builder.start("a");
        builder.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
        builder.end("a");
        let graph = Arc::new(builder.build().unwrap());

        // Underscore-prefixed names silence `unused_variables` while keeping
        // both sessions (and their `Arc` handles) alive until the end of the
        // test; a bare `_` pattern would drop them immediately.
        let _session1 = ExecutionSession::new(State::new(), Arc::clone(&graph));
        let _session2 = ExecutionSession::new(State::new(), Arc::clone(&graph));

        // One handle held by the test plus one per session.
        assert_eq!(Arc::strong_count(&graph), 3);
    }
}