use crate::error::{GraphError, GraphResult};
use crate::graph::Graph;
use crate::persistence::StatePersistence;
use crate::state::{generate_run_id, GraphRunResult, GraphState};
use std::sync::Arc;
use tracing::{info, span, Instrument, Level};
/// Runs a [`Graph`] with optional state persistence and tracing.
///
/// `P` is the persistence backend type; it defaults to [`NoPersistence`].
pub struct GraphExecutor<State, Deps, End, P = NoPersistence>
where
    State: GraphState,
{
    /// The graph being executed.
    graph: Arc<Graph<State, Deps, End>>,
    /// Backend used by `resume`, `get_result` and `list_runs`; those return an
    /// error when this is `None`.
    persistence: Option<Arc<P>>,
    /// When set, `resume` saves the final result through `persistence`.
    auto_save: bool,
    /// Step limit placed in the options built by `run` and `resume`.
    max_steps: u32,
    /// Tracing flag placed in the options built by `run` and `resume`.
    instrumentation: bool,
    /// Type marker for `P`. `P` already appears in `persistence`, so this is
    /// redundant; it is kept because the constructors initialise it.
    _persistence_type: std::marker::PhantomData<P>,
}
/// Marker used as the persistence type parameter when an executor has no
/// persistence backend (the type produced by `GraphExecutor::new`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct NoPersistence;
impl<State, Deps, End> GraphExecutor<State, Deps, End, NoPersistence>
where
State: GraphState,
Deps: Clone + Send + Sync + 'static,
End: Clone + Send + Sync + 'static,
{
pub fn new(graph: Graph<State, Deps, End>) -> Self {
Self {
_persistence_type: std::marker::PhantomData,
graph: Arc::new(graph),
persistence: None,
auto_save: false,
max_steps: 100,
instrumentation: true,
}
}
}
impl<State, Deps, End, P> GraphExecutor<State, Deps, End, P>
where
State: GraphState,
Deps: Clone + Send + Sync + 'static,
End: Clone + Send + Sync + 'static,
P: StatePersistence<State, End> + 'static,
{
pub fn with_persistence(graph: Graph<State, Deps, End>, persistence: P) -> Self {
Self {
_persistence_type: std::marker::PhantomData,
graph: Arc::new(graph),
persistence: Some(Arc::new(persistence)),
auto_save: true,
max_steps: 100,
instrumentation: true,
}
}
pub fn auto_save(mut self, enabled: bool) -> Self {
self.auto_save = enabled;
self
}
pub fn max_steps(mut self, max: u32) -> Self {
self.max_steps = max;
self
}
pub fn without_instrumentation(mut self) -> Self {
self.instrumentation = false;
self
}
pub fn graph(&self) -> &Graph<State, Deps, End> {
&self.graph
}
pub async fn run(&self, state: State, deps: Deps) -> GraphResult<GraphRunResult<State, End>> {
let options = ExecutionOptions::new()
.max_steps(self.max_steps)
.tracing(self.instrumentation);
self.run_with_options(state, deps, options).await
}
pub async fn run_with_options(
&self,
state: State,
deps: Deps,
mut options: ExecutionOptions,
) -> GraphResult<GraphRunResult<State, End>> {
let run_id = options.run_id.clone().unwrap_or_else(generate_run_id);
options.run_id = Some(run_id.clone());
if options.tracing {
let _span = span!(Level::INFO, "graph_run", run_id = %run_id).entered();
info!("Starting graph execution");
}
self.graph.run_with_options(state, deps, options).await
}
pub async fn resume(
&self,
run_id: &str,
deps: Deps,
) -> GraphResult<Option<GraphRunResult<State, End>>> {
let Some(ref persistence) = self.persistence else {
return Err(GraphError::persistence("No persistence configured"));
};
let Some((state, _step)) = persistence.load_state(run_id).await? else {
return Ok(None);
};
let options = ExecutionOptions::new()
.max_steps(self.max_steps)
.tracing(self.instrumentation)
.run_id(run_id.to_string());
let result = self.graph.run_with_options(state, deps, options).await?;
if self.auto_save {
persistence.save_result(run_id, &result.result).await?;
}
Ok(Some(result))
}
pub async fn get_result(&self, run_id: &str) -> GraphResult<Option<End>> {
let Some(ref persistence) = self.persistence else {
return Err(GraphError::persistence("No persistence configured"));
};
Ok(persistence.load_result(run_id).await?)
}
pub async fn list_runs(&self) -> GraphResult<Vec<String>> {
let Some(ref persistence) = self.persistence else {
return Err(GraphError::persistence("No persistence configured"));
};
Ok(persistence.list_runs().await?)
}
}
/// Settings for a single graph run, handed to the graph runner.
#[derive(Debug, Clone)]
pub struct ExecutionOptions {
    /// Step limit handed to the graph runner.
    pub max_steps: u32,
    /// Enables the executor's `graph_run` tracing span and start-of-run log.
    pub tracing: bool,
    /// Optional checkpoint interval, in steps. The executor does not read this
    /// field itself; presumably the graph runner does — confirm.
    pub checkpoint_interval: Option<u32>,
    /// Identifier of the run; the executor generates one when this is `None`.
    pub run_id: Option<String>,
}
impl Default for ExecutionOptions {
    /// Builds the baseline options: a 100-step limit, tracing enabled, no
    /// checkpoint interval and no preset run id.
    fn default() -> Self {
        let max_steps = 100;
        Self {
            run_id: None,
            checkpoint_interval: None,
            tracing: true,
            max_steps,
        }
    }
}
impl ExecutionOptions {
    /// Creates options equal to [`ExecutionOptions::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the step limit handed to the graph runner.
    #[must_use]
    pub fn max_steps(mut self, max: u32) -> Self {
        self.max_steps = max;
        self
    }

    /// Enables or disables tracing for the run.
    #[must_use]
    pub fn tracing(mut self, enabled: bool) -> Self {
        self.tracing = enabled;
        self
    }

    /// Sets `checkpoint_interval` to `Some(steps)`.
    #[must_use]
    pub fn checkpoint_every(mut self, steps: u32) -> Self {
        self.checkpoint_interval = Some(steps);
        self
    }

    /// Sets an explicit run id, replacing any previously set one.
    #[must_use]
    pub fn run_id(mut self, id: impl Into<String>) -> Self {
        self.run_id = Some(id.into());
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder method stores its argument in the matching field.
    #[test]
    fn test_execution_options() {
        let opts = ExecutionOptions::new()
            .max_steps(50)
            .tracing(false)
            .checkpoint_every(10)
            .run_id("custom-run");
        assert_eq!(opts.max_steps, 50);
        assert!(!opts.tracing);
        assert_eq!(opts.checkpoint_interval, Some(10));
        assert_eq!(opts.run_id, Some("custom-run".to_string()));
    }

    /// `new()` yields the documented defaults.
    #[test]
    fn test_execution_options_defaults() {
        let opts = ExecutionOptions::new();
        assert_eq!(opts.max_steps, 100);
        assert!(opts.tracing);
        assert_eq!(opts.checkpoint_interval, None);
        assert_eq!(opts.run_id, None);
    }

    /// A later `run_id` call overrides an earlier one.
    #[test]
    fn test_run_id_override() {
        let opts = ExecutionOptions::new()
            .run_id("first")
            .run_id(String::from("second"));
        assert_eq!(opts.run_id.as_deref(), Some("second"));
    }
}