use std::fmt::Debug;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::checkpoint::FrameStack;
use crate::graph::Graph;
use crate::state::{State, StateMerge};
use crate::workflow_state::{MergeStrategy, WorkflowState};
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
#[error("graph hash mismatch: expected {expected:#018x}, got {actual:#018x}")]
GraphMismatch { expected: u64, actual: u64 },
}
impl PartialEq for SessionError {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(
SessionError::GraphMismatch {
expected: e1,
actual: a1,
},
SessionError::GraphMismatch {
expected: e2,
actual: a2,
},
) => e1 == e2 && a1 == a2,
}
}
}
/// Serializable snapshot of an [`ExecutionSession`].
///
/// Produced by [`ExecutionSession::checkpoint`] and consumed by
/// [`ExecutionSession::restore`].
///
/// NOTE: serde-derived structs are encoded positionally by some formats
/// (e.g. bincode), so the field order below is part of the persisted layout.
/// Do not reorder fields without a migration plan.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionCheckpoint<S: WorkflowState = State>
where
    S::Checkpoint: Debug,
{
    /// Snapshot of the workflow state, as returned by the state's `snapshot()`.
    pub state: S::Checkpoint,
    /// The session's frame stack at the time the checkpoint was taken.
    pub frames: FrameStack<S>,
    /// `canonical_hash()` of the graph the session was bound to.
    ///
    /// `restore` rejects the checkpoint if the target graph's hash differs.
    pub graph_hash: u64,
}
/// An execution of a workflow graph.
///
/// Holds the workflow state, the frame stack, and a shared handle to the graph.
/// Create one with [`ExecutionSession::new`], snapshot it with
/// [`ExecutionSession::checkpoint`], and resume it with [`ExecutionSession::restore`].
pub struct ExecutionSession<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge>
where
    S::Checkpoint: Debug,
{
    /// Workflow state owned by this session.
    state: S,
    /// Frame stack; empty after `new`, taken from the checkpoint on `restore`.
    frame_stack: FrameStack<S>,
    /// Graph behind an `Arc` so several sessions can share one built graph.
    graph: Arc<Graph<S, M>>,
}
impl<S: WorkflowState, M: MergeStrategy<S>> ExecutionSession<S, M>
where
    S::Checkpoint: Debug,
{
    /// Creates a session over `state`, bound to `graph`, with an empty frame stack.
    pub fn new(state: S, graph: Arc<Graph<S, M>>) -> Self {
        Self {
            state,
            frame_stack: FrameStack::new(),
            graph,
        }
    }

    /// Rebuilds a session from `checkpoint` and binds it to `graph`.
    ///
    /// The state is rebuilt with `S::restore`, and the frame stack is taken
    /// from the checkpoint as-is.
    ///
    /// # Errors
    ///
    /// Returns [`SessionError::GraphMismatch`] when the hash stored in the
    /// checkpoint differs from `graph.canonical_hash()`. In that case
    /// `expected` is the checkpoint's hash and `actual` is the graph's hash.
    pub fn restore(
        checkpoint: SessionCheckpoint<S>,
        graph: Arc<Graph<S, M>>,
    ) -> Result<Self, SessionError> {
        // Hash once: the compared value and the reported value must come from
        // the same call, and hashing a whole graph may not be cheap.
        let actual = graph.canonical_hash();
        if checkpoint.graph_hash != actual {
            return Err(SessionError::GraphMismatch {
                expected: checkpoint.graph_hash,
                actual,
            });
        }
        Ok(Self {
            state: S::restore(checkpoint.state),
            frame_stack: checkpoint.frames,
            graph,
        })
    }

    /// Captures the session as a [`SessionCheckpoint`].
    ///
    /// The checkpoint holds a snapshot of the state, a clone of the frame
    /// stack, and the graph's canonical hash.
    pub fn checkpoint(&self) -> SessionCheckpoint<S> {
        SessionCheckpoint {
            state: self.state.snapshot(),
            frames: self.frame_stack.clone(),
            graph_hash: self.graph.canonical_hash(),
        }
    }

    /// Returns the workflow state.
    pub fn state(&self) -> &S {
        &self.state
    }

    /// Returns the workflow state mutably.
    pub fn state_mut(&mut self) -> &mut S {
        &mut self.state
    }

    /// Returns the frame stack.
    pub fn frame_stack(&self) -> &FrameStack<S> {
        &self.frame_stack
    }

    /// Returns the frame stack mutably.
    pub fn frame_stack_mut(&mut self) -> &mut FrameStack<S> {
        &mut self.frame_stack
    }

    /// Returns the graph this session is bound to.
    pub fn graph(&self) -> &Graph<S, M> {
        &self.graph
    }

    /// Returns a new shared handle to the graph (a refcount bump, not a deep copy).
    pub fn graph_arc(&self) -> Arc<Graph<S, M>> {
        Arc::clone(&self.graph)
    }

    /// Consumes the session and returns its workflow state.
    pub fn into_state(self) -> S {
        self.state
    }

    /// Consumes the session and returns its state and frame stack.
    ///
    /// The session's handle to the graph is dropped.
    pub fn into_parts(self) -> (S, FrameStack<S>) {
        (self.state, self.frame_stack)
    }
}
impl<S: WorkflowState, M: MergeStrategy<S>> ExecutionSession<S, M>
where
    S::Checkpoint: Debug,
{
    /// Runs this session's graph inline on `engine`.
    ///
    /// Delegates to `Graph::run_inline`, passing `usize::MAX` as its second
    /// argument. That argument is presumably an unbounded step/depth limit;
    /// confirm against `run_inline`.
    ///
    /// NOTE(review): only the graph is used here. The session's own `state` and
    /// `frame_stack` are not handed to the engine; confirm this is intended.
    ///
    /// # Errors
    ///
    /// Propagates any `GraphError` returned by `Graph::run_inline`.
    pub async fn run_with(
        &mut self,
        engine: &mut crate::ExecutionEngine<'_, S>,
    ) -> Result<(), crate::GraphError> {
        let graph = &self.graph;
        graph.run_inline(engine, usize::MAX).await
    }
}
/// Placeholder `Default` impl that cannot produce a session.
///
/// NOTE(review): a `Default` impl that always panics is a trap for generic code
/// (e.g. `std::mem::take`, `Option::unwrap_or_default`). Misuse compiles and
/// fails only at runtime. Note also that the `S: Default` bound is never used by
/// the body. Consider removing this impl in a breaking release, after confirming
/// no caller depends on the `Default` bound.
impl<S: WorkflowState, M: MergeStrategy<S>> Default for ExecutionSession<S, M>
where
    S: Default,
    S::Checkpoint: Debug,
{
    /// # Panics
    ///
    /// Always panics. Sessions must be built from a state and a graph via
    /// [`ExecutionSession::new`], or from a checkpoint via
    /// [`ExecutionSession::restore`].
    fn default() -> Self {
        panic!("ExecutionSession::default() not supported — use new(state, graph)")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::StateExt;
    use crate::{GraphBuilder, NodeKind, TaskNode};

    /// Builds graph `name` with a single no-op task `node` that is both start and end.
    fn single_task_graph(name: &'static str, node: &'static str) -> Arc<Graph<State, StateMerge>> {
        let mut builder = GraphBuilder::<State, StateMerge>::new(name);
        builder.start(node);
        builder.node(node, NodeKind::Task(TaskNode::new(node, |_| Ok(()))));
        builder.end(node);
        Arc::new(builder.build().expect("single-task graph should build"))
    }

    #[test]
    fn test_session_checkpoint_roundtrip() {
        let graph = single_task_graph("test", "a");
        let mut session = ExecutionSession::new(State::new(), Arc::clone(&graph));
        session
            .state_mut()
            .insert("key".to_string(), serde_json::json!("value"));

        let checkpoint = session.checkpoint();
        assert!(checkpoint.state.contains("key"));
        assert_eq!(checkpoint.graph_hash, graph.canonical_hash());

        let restored =
            ExecutionSession::restore(checkpoint, graph).expect("restore should succeed");
        assert!(restored.state().contains("key"));
        assert!(restored.frame_stack().is_empty());
    }

    #[test]
    fn test_session_restore_graph_mismatch() {
        let mut builder1 = GraphBuilder::<State, StateMerge>::new("test1");
        builder1.start("a");
        builder1.node("a", NodeKind::Task(TaskNode::new("a", |_| Ok(()))));
        builder1.end("a");
        builder1.canonical_hash(0x1111);
        let graph1 = Arc::new(builder1.build().unwrap());

        let mut builder2 = GraphBuilder::<State, StateMerge>::new("test2");
        builder2.start("b");
        builder2.node("b", NodeKind::Task(TaskNode::new("b", |_| Ok(()))));
        builder2.end("b");
        builder2.canonical_hash(0x2222);
        let graph2 = Arc::new(builder2.build().unwrap());

        let hash1 = graph1.canonical_hash();
        let hash2 = graph2.canonical_hash();

        let session = ExecutionSession::new(State::new(), graph1);
        let checkpoint = session.checkpoint();

        match ExecutionSession::restore(checkpoint, graph2) {
            Err(err) => {
                assert!(err.to_string().contains("graph hash mismatch"));
                assert_eq!(
                    err,
                    SessionError::GraphMismatch {
                        expected: hash1,
                        actual: hash2,
                    }
                );
            }
            Ok(_) => panic!("expected GraphMismatch error"),
        }
    }

    #[test]
    fn test_session_into_parts() {
        let graph = single_task_graph("test", "a");
        let session = ExecutionSession::new(State::new(), graph);
        let (_state, frame_stack) = session.into_parts();
        assert!(frame_stack.is_empty());
    }

    #[test]
    fn test_session_graph_sharing() {
        let graph = single_task_graph("test", "a");
        // Keep both sessions alive (named `_x`, not `_`, so they are not dropped
        // immediately) so their graph handles count toward the strong count.
        let _session1 = ExecutionSession::new(State::new(), Arc::clone(&graph));
        let _session2 = ExecutionSession::new(State::new(), Arc::clone(&graph));
        assert_eq!(Arc::strong_count(&graph), 3);
    }
}