1use std::fmt::Debug;
25use std::sync::Arc;
26
27use serde::{Deserialize, Serialize};
28
29use crate::checkpoint::FrameStack;
30use crate::graph::Graph;
31use crate::state::{State, StateMerge};
32use crate::workflow_state::{MergeStrategy, WorkflowState};
33
34#[derive(Debug, thiserror::Error)]
38pub enum SessionError {
39 #[error("graph hash mismatch: expected {expected:#018x}, got {actual:#018x}")]
41 GraphMismatch { expected: u64, actual: u64 },
42}
43
44impl PartialEq for SessionError {
45 fn eq(&self, other: &Self) -> bool {
46 match (self, other) {
47 (
48 SessionError::GraphMismatch {
49 expected: e1,
50 actual: a1,
51 },
52 SessionError::GraphMismatch {
53 expected: e2,
54 actual: a2,
55 },
56 ) => e1 == e2 && a1 == a2,
57 }
58 }
59}
60
/// Serializable snapshot of an [`ExecutionSession`].
///
/// Produced by [`ExecutionSession::checkpoint`] and consumed by
/// [`ExecutionSession::restore`], which rejects the checkpoint if
/// `graph_hash` does not equal the target graph's `canonical_hash()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionCheckpoint<S: WorkflowState = State>
where
    S::Checkpoint: Debug,
{
    /// State snapshot, as returned by the session state's `snapshot()`.
    pub state: S::Checkpoint,
    /// Clone of the session's frame stack at the time of the checkpoint.
    pub frames: FrameStack<S>,
    /// `canonical_hash()` of the graph the session was running; used by
    /// `restore` to detect resuming onto a different graph.
    pub graph_hash: u64,
}
88
/// One workflow execution: owns the mutable state and frame stack, and holds a
/// shared (`Arc`) handle to the graph being executed, so several sessions can
/// run over the same graph without copying it.
///
/// Construct with [`ExecutionSession::new`] or resume from a
/// [`SessionCheckpoint`] with [`ExecutionSession::restore`].
pub struct ExecutionSession<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge>
where
    S::Checkpoint: Debug,
{
    // Current workflow state, exposed via `state` / `state_mut`.
    state: S,
    // Frame stack; starts empty in `new`, restored verbatim in `restore`.
    frame_stack: FrameStack<S>,
    // Shared, immutable graph; its canonical hash is stored in checkpoints.
    graph: Arc<Graph<S, M>>,
}
123
124impl<S: WorkflowState, M: MergeStrategy<S>> ExecutionSession<S, M>
125where
126 S::Checkpoint: Debug,
127{
128 pub fn new(state: S, graph: Arc<Graph<S, M>>) -> Self {
130 Self {
131 state,
132 frame_stack: FrameStack::new(),
133 graph,
134 }
135 }
136
137 pub fn restore(
153 checkpoint: SessionCheckpoint<S>,
154 graph: Arc<Graph<S, M>>,
155 ) -> Result<Self, SessionError> {
156 if checkpoint.graph_hash != graph.canonical_hash() {
158 return Err(SessionError::GraphMismatch {
159 expected: checkpoint.graph_hash,
160 actual: graph.canonical_hash(),
161 });
162 }
163
164 let state = S::restore(checkpoint.state);
165 Ok(Self {
166 state,
167 frame_stack: checkpoint.frames,
168 graph,
169 })
170 }
171
172 pub fn checkpoint(&self) -> SessionCheckpoint<S> {
182 SessionCheckpoint {
183 state: self.state.snapshot(),
184 frames: self.frame_stack.clone(),
185 graph_hash: self.graph.canonical_hash(),
186 }
187 }
188
189 pub fn state(&self) -> &S {
191 &self.state
192 }
193
194 pub fn state_mut(&mut self) -> &mut S {
196 &mut self.state
197 }
198
199 pub fn frame_stack(&self) -> &FrameStack<S> {
201 &self.frame_stack
202 }
203
204 pub fn frame_stack_mut(&mut self) -> &mut FrameStack<S> {
206 &mut self.frame_stack
207 }
208
209 pub fn graph(&self) -> &Graph<S, M> {
211 &self.graph
212 }
213
214 pub fn graph_arc(&self) -> Arc<Graph<S, M>> {
216 self.graph.clone()
217 }
218
219 pub fn into_state(self) -> S {
221 self.state
222 }
223
224 pub fn into_parts(self) -> (S, FrameStack<S>) {
226 (self.state, self.frame_stack)
227 }
228}
229
230impl<S: WorkflowState, M: MergeStrategy<S>> ExecutionSession<S, M>
231where
232 S::Checkpoint: Debug,
233{
234 pub async fn run_with(
250 &mut self,
251 engine: &mut crate::ExecutionEngine<'_, S>,
252 ) -> Result<(), crate::GraphError> {
253 self.graph.run_inline(engine, usize::MAX).await
254 }
255}
256
257impl<S: WorkflowState, M: MergeStrategy<S>> Default for ExecutionSession<S, M>
260where
261 S: Default,
262 S::Checkpoint: Debug,
263{
264 fn default() -> Self {
265 panic!("ExecutionSession::default() not supported — use new(state, graph)")
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::StateExt;
    use crate::{GraphBuilder, NodeKind, TaskNode};

    #[test]
    fn test_session_checkpoint_roundtrip() {
        let mut builder = GraphBuilder::<State, StateMerge>::new("test");
        builder.start("a");
        builder.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
        builder.end("a");
        let graph = Arc::new(builder.build().unwrap());

        let state = State::new();
        let mut session = ExecutionSession::new(state, graph.clone());

        session
            .state_mut()
            .insert("key".to_string(), serde_json::json!("value"));

        let checkpoint = session.checkpoint();

        assert!(checkpoint.state.contains("key"));
        assert_eq!(checkpoint.graph_hash, graph.canonical_hash());

        let restored_session =
            ExecutionSession::restore(checkpoint, graph).expect("restore should succeed");

        assert!(restored_session.state().contains("key"));
        // A freshly created session has no frames, so neither should the restored one.
        assert!(restored_session.frame_stack().is_empty());
    }

    #[test]
    fn test_session_restore_graph_mismatch() {
        let mut builder1 = GraphBuilder::<State, StateMerge>::new("test1");
        builder1.start("a");
        builder1.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
        builder1.end("a");
        builder1.canonical_hash(0x1111);
        let graph1 = Arc::new(builder1.build().unwrap());

        let mut builder2 = GraphBuilder::<State, StateMerge>::new("test2");
        builder2.start("b");
        builder2.node("b", NodeKind::Task(TaskNode::new("b", |_| Ok(()))));
        builder2.end("b");
        builder2.canonical_hash(0x2222);
        let graph2 = Arc::new(builder2.build().unwrap());

        // Read the hashes before the graphs are moved into session/restore.
        let expected = graph1.canonical_hash();
        let actual = graph2.canonical_hash();

        let session = ExecutionSession::new(State::new(), graph1);
        let checkpoint = session.checkpoint();

        let err = match ExecutionSession::restore(checkpoint, graph2) {
            Ok(_) => panic!("expected restore to fail with a graph hash mismatch"),
            Err(e) => e,
        };
        assert_eq!(err, SessionError::GraphMismatch { expected, actual });
        assert!(err.to_string().contains("graph hash mismatch"));
    }

    #[test]
    fn test_session_error_eq() {
        let a = SessionError::GraphMismatch {
            expected: 1,
            actual: 2,
        };
        let b = SessionError::GraphMismatch {
            expected: 1,
            actual: 3,
        };
        assert_eq!(
            a,
            SessionError::GraphMismatch {
                expected: 1,
                actual: 2
            }
        );
        assert_ne!(a, b);
    }

    #[test]
    fn test_session_into_parts() {
        let mut builder = GraphBuilder::<State, StateMerge>::new("test");
        builder.start("a");
        builder.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
        builder.end("a");
        let graph = Arc::new(builder.build().unwrap());

        let state = State::new();
        let session = ExecutionSession::new(state, graph);

        let (_state, frame_stack) = session.into_parts();

        assert!(frame_stack.is_empty());
    }

    #[test]
    fn test_session_graph_sharing() {
        let mut builder = GraphBuilder::<State, StateMerge>::new("test");
        builder.start("a");
        builder.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
        builder.end("a");
        let graph = Arc::new(builder.build().unwrap());

        // Underscore-prefixed (not `_`) so the sessions stay alive to the end of
        // the test and keep holding their graph handles.
        let _session1 = ExecutionSession::new(State::new(), graph.clone());
        let _session2 = ExecutionSession::new(State::new(), graph.clone());

        assert_eq!(Arc::strong_count(&graph), 3);
    }
}