use crate::app::{App, BarrierOutcome};
use crate::channels::Channel;
use crate::channels::errors::{ErrorEvent, ErrorScope, WeaveError};
use crate::control::{FrontierCommand, NodeRoute};
use crate::event_bus::{EventBus, EventStream};
use crate::node::NodePartial;
use crate::runtimes::CheckpointerType;
use crate::runtimes::execution::{
PausedReason, PausedReport, SchedulerOutcome, StepOptions, StepReport, StepResult,
};
use crate::runtimes::session::{SessionInit, SessionState, StateVersions};
use crate::runtimes::streaming::{StreamEndReason, finalize_event_stream};
use crate::runtimes::{
Checkpoint, Checkpointer, CheckpointerError, InMemoryCheckpointer, restore_session_state,
};
use crate::schedulers::{Scheduler, SchedulerError, SchedulerState};
use crate::state::VersionedState;
use crate::types::NodeKind;
use rustc_hash::FxHashMap;
use std::sync::Arc;
use thiserror::Error;
use tokio::task::JoinError;
use tracing::instrument;
/// Drives execution of a compiled [`App`]: owns per-session state, an
/// optional checkpoint backend, and the runtime event bus.
pub struct AppRunner {
    // The compiled workflow graph being executed.
    app: Arc<App>,
    // Live session state keyed by session id.
    sessions: FxHashMap<String, SessionState>,
    // Optional persistence backend; `None` disables checkpointing entirely.
    checkpointer: Option<Arc<dyn Checkpointer>>,
    // When true, a checkpoint is saved automatically after each step.
    autosave: bool,
    // Bus over which runtime events are emitted to subscribers.
    event_bus: EventBus,
    // Enforces the single-consumer semantics of `event_stream()`.
    event_stream_taken: bool,
}
/// Errors returned by [`AppRunner`] session and execution APIs.
#[derive(Debug, Error)]
#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))]
pub enum RunnerError {
    /// No session with the given id is held by this runner.
    #[error("session not found: {session_id}")]
    #[cfg_attr(
        feature = "diagnostics",
        diagnostic(code(weavegraph::runner::session_not_found))
    )]
    SessionNotFound { session_id: String },
    /// The graph has no outgoing edges from `Start`, so nothing can run.
    #[error("no nodes to run from START (empty frontier)")]
    #[cfg_attr(
        feature = "diagnostics",
        diagnostic(
            code(weavegraph::runner::no_start_nodes),
            help("Add edges from Start or set the entry node correctly.")
        )
    )]
    NoStartNodes,
    /// A step paused even though `run_until_complete` uses default
    /// (non-interrupting) options — indicates a logic error.
    #[error("unexpected pause during run_until_complete")]
    #[cfg_attr(
        feature = "diagnostics",
        diagnostic(code(weavegraph::runner::unexpected_pause))
    )]
    UnexpectedPause,
    /// An invocation handle's `join()` was called more than once.
    #[error("join handle already consumed")]
    #[cfg_attr(
        feature = "diagnostics",
        diagnostic(
            code(weavegraph::runner::join_handle_consumed),
            help("InvocationHandle::join() can only be called once.")
        )
    )]
    JoinHandleConsumed,
    /// The spawned workflow task panicked or was cancelled.
    #[error("workflow task join error: {0}")]
    #[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::runner::join)))]
    Join(#[from] JoinError),
    /// Checkpoint persistence failed (e.g. loading the latest checkpoint).
    #[error(transparent)]
    #[cfg_attr(
        feature = "diagnostics",
        diagnostic(code(weavegraph::runner::checkpointer))
    )]
    Checkpointer(#[from] CheckpointerError),
    /// Applying the barrier to session state failed.
    #[error("app barrier error: {0}")]
    #[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::runner::barrier)))]
    AppBarrier(#[source] Box<dyn std::error::Error + Send + Sync>),
    /// The scheduler failed while running a superstep.
    #[error(transparent)]
    #[cfg_attr(
        feature = "diagnostics",
        diagnostic(code(weavegraph::runner::scheduler))
    )]
    Scheduler(#[from] SchedulerError),
}
/// Fluent builder for [`AppRunner`]; obtain one via [`AppRunner::builder`].
pub struct AppRunnerBuilder {
    // Required: the app to run. `try_build` yields `None` when unset.
    app: Option<Arc<App>>,
    // Backend kind used when no custom checkpointer is supplied.
    checkpointer_type: CheckpointerType,
    // Caller-provided checkpointer; takes precedence over `checkpointer_type`.
    checkpointer_custom: Option<Arc<dyn Checkpointer>>,
    // Whether to persist a checkpoint automatically after each step.
    autosave: bool,
    // Explicit event bus; defaults to the app's configured bus when `None`.
    event_bus: Option<EventBus>,
    // Whether to start the bus listener at build time.
    start_listener: bool,
}
impl Default for AppRunnerBuilder {
fn default() -> Self {
Self::new()
}
}
impl AppRunnerBuilder {
    /// Creates a builder with defaults: in-memory checkpointer, autosave
    /// on, app-configured event bus, listener started at build time.
    #[must_use]
    pub fn new() -> Self {
        Self {
            app: None,
            checkpointer_type: CheckpointerType::InMemory,
            checkpointer_custom: None,
            autosave: true,
            event_bus: None,
            start_listener: true,
        }
    }
    /// Sets the app to run (takes ownership and wraps it in an `Arc`).
    #[must_use]
    pub fn app(mut self, app: App) -> Self {
        self.app = Some(Arc::new(app));
        self
    }
    /// Sets the app to run from an already-shared handle.
    #[must_use]
    pub fn app_arc(mut self, app: Arc<App>) -> Self {
        self.app = Some(app);
        self
    }
    /// Selects the checkpointer backend kind. Ignored if
    /// [`checkpointer_custom`](Self::checkpointer_custom) is also set.
    #[must_use]
    pub fn checkpointer(mut self, checkpointer_type: CheckpointerType) -> Self {
        self.checkpointer_type = checkpointer_type;
        self
    }
    /// Supplies a caller-built checkpointer, overriding the backend kind.
    #[must_use]
    pub fn checkpointer_custom(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
        self.checkpointer_custom = Some(checkpointer);
        self
    }
    /// Enables or disables automatic checkpointing after each step.
    #[must_use]
    pub fn autosave(mut self, autosave: bool) -> Self {
        self.autosave = autosave;
        self
    }
    /// Uses an explicit event bus instead of the app's configured one.
    #[must_use]
    pub fn event_bus(mut self, event_bus: EventBus) -> Self {
        self.event_bus = Some(event_bus);
        self
    }
    /// Controls whether the event-bus listener is started at build time.
    #[must_use]
    pub fn start_listener(mut self, start: bool) -> Self {
        self.start_listener = start;
        self
    }
    /// Builds the runner.
    ///
    /// # Panics
    /// Panics if no app was set; use [`try_build`](Self::try_build) for a
    /// non-panicking variant.
    pub async fn build(self) -> AppRunner {
        self.try_build()
            .await
            .expect("AppRunnerBuilder requires an app to be set")
    }
    /// Builds the runner, returning `None` when no app was set.
    pub async fn try_build(self) -> Option<AppRunner> {
        let app = self.app?;
        // Fall back to the bus configured on the app's runtime config.
        let event_bus = self
            .event_bus
            .unwrap_or_else(|| app.runtime_config().event_bus.build_event_bus());
        Some(
            AppRunner::with_arc_and_bus(
                app,
                self.checkpointer_type,
                self.checkpointer_custom,
                self.autosave,
                event_bus,
                self.start_listener,
            )
            .await,
        )
    }
}
impl AppRunner {
/// Entry point for fluently configuring an [`AppRunner`].
#[must_use]
pub fn builder() -> AppRunnerBuilder {
    AppRunnerBuilder::default()
}
/// Deprecated constructor retained for pre-0.2 callers; autosave is `true`.
#[deprecated(
    since = "0.2.0",
    note = "Use AppRunner::builder().app(app).checkpointer(type).build().await instead"
)]
#[must_use]
#[allow(deprecated)] // delegates to the equally deprecated `with_options`
pub async fn new(app: App, checkpointer_type: CheckpointerType) -> Self {
    Self::with_options(app, checkpointer_type, true).await
}
/// Deprecated `Arc`-taking constructor; autosave is `true`.
#[deprecated(
    since = "0.2.0",
    note = "Use AppRunner::builder().app_arc(app).checkpointer(type).build().await instead"
)]
#[must_use]
#[allow(deprecated)] // delegates to the equally deprecated `with_options_arc`
pub async fn from_arc(app: Arc<App>, checkpointer_type: CheckpointerType) -> Self {
    Self::with_options_arc(app, checkpointer_type, true).await
}
/// Builds a checkpointer for the requested backend kind.
///
/// Returns `None` when backend initialization fails; the failure is
/// logged and the runner simply proceeds without persistence.
async fn create_checkpointer(
    checkpointer_type: CheckpointerType,
    sqlite_db_name: Option<String>,
) -> Option<Arc<dyn Checkpointer>> {
    match checkpointer_type {
        CheckpointerType::InMemory => {
            Some(Arc::new(InMemoryCheckpointer::new()) as Arc<dyn Checkpointer>)
        }
        #[cfg(feature = "sqlite")]
        CheckpointerType::SQLite => {
            // URL precedence: WEAVEGRAPH_SQLITE_URL env var, then the
            // configured db name, then SQLITE_DB_NAME env var, finally
            // the literal "weavegraph.db".
            let db_url = std::env::var("WEAVEGRAPH_SQLITE_URL")
                .ok()
                .or_else(|| {
                    sqlite_db_name
                        .as_ref()
                        .map(|name| format!("sqlite://{name}"))
                })
                .unwrap_or_else(|| {
                    let fallback = std::env::var("SQLITE_DB_NAME")
                        .unwrap_or_else(|_| "weavegraph.db".to_string());
                    format!("sqlite://{fallback}")
                });
            // Pre-create the database file (and parent directories) so the
            // first connection does not fail on a missing path. Failures
            // here are deliberately ignored: they surface at connect time.
            // NOTE(review): these are blocking fs calls inside an async fn;
            // acceptable only because this runs once at startup.
            if let Some(path) = db_url.strip_prefix("sqlite://") {
                let path = path.trim();
                if !path.is_empty() {
                    let p = std::path::Path::new(path);
                    if let Some(parent) = p.parent() {
                        let _ = std::fs::create_dir_all(parent);
                    }
                    if !p.exists() {
                        let _ = std::fs::File::create_new(p);
                    }
                }
            }
            match crate::runtimes::SQLiteCheckpointer::connect(&db_url).await {
                Ok(cp) => Some(Arc::new(cp) as Arc<dyn Checkpointer>),
                Err(e) => {
                    tracing::error!(
                        url = %db_url,
                        error = %e,
                        "SQLiteCheckpointer initialization failed"
                    );
                    None
                }
            }
        }
        #[cfg(feature = "postgres")]
        CheckpointerType::Postgres => {
            // URL precedence: WEAVEGRAPH_POSTGRES_URL, then DATABASE_URL,
            // then a localhost default database.
            let db_url = std::env::var("WEAVEGRAPH_POSTGRES_URL")
                .ok()
                .or_else(|| std::env::var("DATABASE_URL").ok())
                .unwrap_or_else(|| "postgresql://localhost/weavegraph".to_string());
            match crate::runtimes::PostgresCheckpointer::connect(&db_url).await {
                Ok(cp) => Some(Arc::new(cp) as Arc<dyn Checkpointer>),
                Err(e) => {
                    tracing::error!(
                        url = %db_url,
                        error = %e,
                        "PostgresCheckpointer initialization failed"
                    );
                    None
                }
            }
        }
    }
}
/// Deprecated constructor taking an owned app plus an autosave flag.
#[deprecated(
    since = "0.2.0",
    note = "Use AppRunner::builder().app(app).checkpointer(type).autosave(bool).build().await instead"
)]
pub async fn with_options(
    app: App,
    checkpointer_type: CheckpointerType,
    autosave: bool,
) -> Self {
    // Bus must be built before the app is moved into the Arc.
    let bus = app.runtime_config().event_bus.build_event_bus();
    let app = Arc::new(app);
    Self::with_arc_and_bus(app, checkpointer_type, None, autosave, bus, true).await
}
/// Deprecated constructor taking a shared app plus an autosave flag.
#[deprecated(
    since = "0.2.0",
    note = "Use AppRunner::builder().app_arc(app).checkpointer(type).autosave(bool).build().await instead"
)]
pub async fn with_options_arc(
    app: Arc<App>,
    checkpointer_type: CheckpointerType,
    autosave: bool,
) -> Self {
    let bus = app.runtime_config().event_bus.build_event_bus();
    Self::with_arc_and_bus(app, checkpointer_type, None, autosave, bus, true).await
}
/// Deprecated constructor taking an owned app and an explicit event bus.
#[deprecated(
    since = "0.2.0",
    note = "Use AppRunner::builder().app(app).checkpointer(type).autosave(bool).event_bus(bus).start_listener(bool).build().await instead"
)]
pub async fn with_options_and_bus(
    app: App,
    checkpointer_type: CheckpointerType,
    autosave: bool,
    event_bus: EventBus,
    start_listener: bool,
) -> Self {
    let app = Arc::new(app);
    Self::with_arc_and_bus(
        app,
        checkpointer_type,
        None,
        autosave,
        event_bus,
        start_listener,
    )
    .await
}
/// Deprecated constructor taking a shared app and an explicit event bus.
#[deprecated(
    since = "0.2.0",
    note = "Use AppRunner::builder().app_arc(app).checkpointer(type).autosave(bool).event_bus(bus).start_listener(bool).build().await instead"
)]
pub async fn with_options_arc_and_bus(
    app: Arc<App>,
    checkpointer_type: CheckpointerType,
    autosave: bool,
    event_bus: EventBus,
    start_listener: bool,
) -> Self {
    Self::with_arc_and_bus(
        app,
        checkpointer_type,
        None,
        autosave,
        event_bus,
        start_listener,
    )
    .await
}
async fn with_arc_and_bus(
app: Arc<App>,
checkpointer_type: CheckpointerType,
checkpointer_custom: Option<Arc<dyn Checkpointer>>,
autosave: bool,
event_bus: EventBus,
start_listener: bool,
) -> Self {
let checkpointer = if let Some(custom) = checkpointer_custom {
Some(custom)
} else {
let sqlite_db_name = app.runtime_config().sqlite_db_name.clone();
Self::create_checkpointer(checkpointer_type, sqlite_db_name).await
};
if start_listener {
event_bus.listen_for_events();
}
Self {
app,
sessions: FxHashMap::default(),
checkpointer,
autosave,
event_bus,
event_stream_taken: false,
}
}
/// Takes the single external event stream for this runner.
///
/// Returns `None` on every call after the first — the stream is a
/// single-consumer resource.
pub fn event_stream(&mut self) -> Option<EventStream> {
    if !self.event_stream_taken {
        self.event_stream_taken = true;
        return Some(self.event_bus.subscribe());
    }
    None
}
/// Creates a session, or resumes it from the latest checkpoint.
///
/// If a checkpointer is configured and has a checkpoint for `session_id`,
/// the session is restored from it and [`SessionInit::Resumed`] is
/// returned (ignoring `initial_state`). Otherwise a fresh session is
/// seeded from the graph's `Start` edges.
///
/// # Errors
/// - [`RunnerError::Checkpointer`] if loading the latest checkpoint fails.
/// - [`RunnerError::NoStartNodes`] if the graph has no edges from `Start`.
#[instrument(skip(self, initial_state, session_id), err)]
pub async fn create_session(
    &mut self,
    session_id: String,
    initial_state: VersionedState,
) -> Result<SessionInit, RunnerError> {
    let restored_checkpoint = if let Some(cp) = &self.checkpointer {
        cp.load_latest(&session_id)
            .await
            .map_err(RunnerError::Checkpointer)?
    } else {
        None
    };
    if let Some(stored) = restored_checkpoint {
        let restored = restore_session_state(&stored);
        self.sessions.insert(session_id, restored);
        return Ok(SessionInit::Resumed {
            checkpoint_step: stored.step,
        });
    }
    // Fresh session: the initial frontier is everything reachable from Start.
    let frontier = self
        .app
        .edges()
        .get(&NodeKind::Start)
        .cloned()
        .unwrap_or_default();
    if frontier.is_empty() {
        return Err(RunnerError::NoStartNodes);
    }
    // Default scheduler concurrency: one task per available core.
    let default_limit = std::thread::available_parallelism()
        .map(|n| n.get())
        .unwrap_or(1);
    let scheduler = Scheduler::new(default_limit);
    let session_state = SessionState {
        state: initial_state,
        step: 0,
        frontier,
        scheduler,
        scheduler_state: SchedulerState::default(),
    };
    self.sessions
        .insert(session_id.clone(), session_state.clone());
    // Best-effort initial checkpoint; save errors are intentionally ignored.
    if let Some(cp) = &self.checkpointer {
        let _ = cp
            .save(Checkpoint::from_session(&session_id, &session_state))
            .await;
    }
    Ok(SessionInit::Fresh)
}
/// Executes at most one superstep for the session, honoring the
/// interrupt options.
///
/// Returns [`StepResult::Completed`] when a step ran (or the session was
/// already terminal), and [`StepResult::Paused`] when an interrupt
/// matched. On superstep failure, the error is recorded into the session
/// state via the barrier, the session is re-inserted, a best-effort
/// autosave runs, and the error is propagated.
///
/// # Errors
/// [`RunnerError::SessionNotFound`] for an unknown id, plus any error
/// bubbled up from the scheduler or barrier.
#[instrument(skip(self, options), err)]
pub async fn run_step(
    &mut self,
    session_id: &str,
    options: StepOptions,
) -> Result<StepResult, RunnerError> {
    // Snapshot step/frontier/versions inside a scope so the map borrow
    // ends before we mutate `self.sessions` below.
    let (current_step, current_frontier, current_versions) = {
        let current_session_state =
            self.sessions
                .get(session_id)
                .ok_or_else(|| RunnerError::SessionNotFound {
                    session_id: session_id.to_string(),
                })?;
        let versions = StateVersions {
            messages_version: current_session_state.state.messages.version(),
            extra_version: current_session_state.state.extra.version(),
        };
        (
            current_session_state.step,
            current_session_state.frontier.clone(),
            versions,
        )
    };
    // Terminal frontier (empty or all End): report completion, run nothing.
    if current_frontier.is_empty() || current_frontier.iter().all(|n| *n == NodeKind::End) {
        return Ok(StepResult::Completed(StepReport {
            step: current_step,
            ran_nodes: vec![],
            skipped_nodes: current_frontier.clone(),
            barrier_outcome: BarrierOutcome::default(),
            next_frontier: vec![],
            state_versions: current_versions,
            completed: true,
        }));
    }
    // Pause before running if any frontier node matches `interrupt_before`.
    // FIX: this loop header was mojibake (`¤t_frontier`, an HTML-entity
    // corruption of `&current_frontier`) and did not compile.
    for node in &current_frontier {
        if options.interrupt_before.contains(node) {
            let session_state = self
                .sessions
                .get(session_id)
                .ok_or_else(|| RunnerError::SessionNotFound {
                    session_id: session_id.to_string(),
                })?
                .clone();
            return Ok(StepResult::Paused(PausedReport {
                session_state,
                reason: PausedReason::BeforeNode(node.clone()),
            }));
        }
    }
    // Take ownership of the session state for the duration of the step.
    let mut session_state =
        self.sessions
            .remove(session_id)
            .ok_or_else(|| RunnerError::SessionNotFound {
                session_id: session_id.to_string(),
            })?;
    let step_report = match self.run_one_superstep(&mut session_state).await {
        Ok(rep) => rep,
        Err(e) => {
            // Translate the failure into an ErrorEvent on the session's
            // error channel so it is visible in the persisted state.
            let event = match &e {
                RunnerError::Scheduler(source) => match source {
                    crate::schedulers::SchedulerError::NodeNotFound { kind, step } => {
                        ErrorEvent {
                            when: chrono::Utc::now(),
                            scope: ErrorScope::Scheduler { step: *step },
                            error: WeaveError::msg(format!(
                                "node {:?} not found in registry",
                                kind
                            )),
                            tags: vec!["scheduler".into(), "node_not_found".into()],
                            context: serde_json::json!({
                                "kind": kind.encode()
                            }),
                        }
                    }
                    crate::schedulers::SchedulerError::NodeRun { kind, step, source } => {
                        ErrorEvent {
                            when: chrono::Utc::now(),
                            scope: ErrorScope::Node {
                                kind: kind.encode().to_string(),
                                step: *step,
                            },
                            error: WeaveError::msg(format!("{}", source)),
                            tags: vec!["node".into()],
                            context: serde_json::json!({}),
                        }
                    }
                    crate::schedulers::SchedulerError::Join(_) => ErrorEvent {
                        when: chrono::Utc::now(),
                        scope: ErrorScope::Scheduler {
                            step: session_state.step,
                        },
                        error: WeaveError::msg(format!("{}", e)),
                        tags: vec!["scheduler".into()],
                        context: serde_json::json!({}),
                    },
                },
                _ => ErrorEvent {
                    when: chrono::Utc::now(),
                    scope: ErrorScope::Runner {
                        session: session_id.to_string(),
                        step: session_state.step,
                    },
                    error: WeaveError::msg(format!("{}", e)),
                    tags: vec!["runner".into()],
                    context: serde_json::json!({
                        "frontier": session_state.frontier.iter().map(|k| k.encode()).collect::<Vec<_>>()
                    }),
                },
            };
            // Record the error through the barrier; the barrier's own
            // result is intentionally ignored here (best-effort).
            let mut update_state = session_state.state.clone();
            let partial = NodePartial::new().with_errors(vec![event]);
            let _ = self
                .app
                .apply_barrier(&mut update_state, &[], vec![partial])
                .await;
            session_state.state = update_state;
            self.sessions.insert(session_id.to_string(), session_state);
            if self.autosave
                && let Some(cp) = &self.checkpointer
                && let Some(s) = self.sessions.get(session_id)
            {
                let _ = cp.save(Checkpoint::from_session(session_id, s)).await;
            }
            return Err(e);
        }
    };
    // Pause after the step if any node that ran matches `interrupt_after`.
    if let Some(node) = step_report
        .ran_nodes
        .iter()
        .find(|n| options.interrupt_after.contains(n))
    {
        let persisted = session_state.clone();
        self.sessions.insert(session_id.to_string(), persisted);
        self.maybe_checkpoint(session_id, step_report.step).await;
        return Ok(StepResult::Paused(PausedReport {
            session_state,
            reason: PausedReason::AfterNode(node.clone()),
        }));
    }
    // Step-wise execution mode: pause after every step.
    if options.interrupt_each_step {
        let persisted = session_state.clone();
        self.sessions.insert(session_id.to_string(), persisted);
        self.maybe_checkpoint(session_id, step_report.step).await;
        return Ok(StepResult::Paused(PausedReport {
            session_state,
            reason: PausedReason::AfterStep(step_report.step),
        }));
    }
    self.sessions.insert(session_id.to_string(), session_state);
    self.maybe_checkpoint(session_id, step_report.step).await;
    Ok(StepResult::Completed(step_report))
}
/// Runs one scheduler superstep and pairs each ran node with its output
/// partial, preserving the scheduler's ran-order.
#[inline]
async fn schedule_step(
    &self,
    session_state: &mut SessionState,
    step: u64,
) -> Result<SchedulerOutcome, RunnerError> {
    let snapshot = session_state.state.snapshot();
    let superstep = session_state
        .scheduler
        .superstep(
            &mut session_state.scheduler_state,
            self.app.nodes(),
            session_state.frontier.clone(),
            snapshot.clone(),
            step,
            self.event_bus.get_emitter(),
        )
        .await?;
    // Index outputs by node kind, then drain them in ran-order so
    // `partials` lines up with `ran_nodes`.
    let mut outputs_by_kind: FxHashMap<NodeKind, NodePartial> =
        superstep.outputs.into_iter().collect();
    let ran_nodes = superstep.ran_nodes.clone();
    let partials: Vec<NodePartial> = ran_nodes
        .iter()
        .filter_map(|kind| outputs_by_kind.remove(kind))
        .collect();
    Ok(SchedulerOutcome {
        ran_nodes,
        skipped_nodes: superstep.skipped_nodes,
        partials,
    })
}
/// Applies the barrier for this step's partials and commits the result.
///
/// Works on a cloned copy of the state so that a barrier failure leaves
/// the session state untouched.
#[tracing::instrument(skip(self, session_state, partials, ran), err)]
async fn apply_barrier_and_update(
    &self,
    session_state: &mut SessionState,
    ran: &[NodeKind],
    partials: Vec<NodePartial>,
) -> Result<BarrierOutcome, RunnerError> {
    let mut staged = session_state.state.clone();
    let outcome = match self.app.apply_barrier(&mut staged, ran, partials).await {
        Ok(outcome) => outcome,
        Err(source) => return Err(RunnerError::AppBarrier(source)),
    };
    session_state.state = staged;
    Ok(outcome)
}
/// Computes the frontier for the next superstep.
///
/// For each node that ran this step, in order:
/// 1. Barrier-emitted frontier commands take precedence: one `Replace`
///    overrides the node's static edges (later `Replace`s in the same
///    step are ignored with a warning); `Append` extends the static edges.
/// 2. Unless the frontier was replaced, conditional edges from the node
///    are evaluated against the current state snapshot.
/// 3. Targets are validated against the node registry (`Start`/`End` are
///    always valid) and de-duplicated while preserving insertion order.
#[inline]
fn compute_next_frontier(
    &self,
    session_state: &SessionState,
    ran: &[NodeKind],
    barrier: &BarrierOutcome,
    step: u64,
) -> Vec<NodeKind> {
    let mut next_frontier: Vec<NodeKind> = Vec::new();
    let graph_edges = self.app.edges();
    let conditional_edges = self.app.conditional_edges();
    let state_snapshot = session_state.state.snapshot();
    // Group commands by their origin node so we can process them per node.
    let mut frontier_commands_by_node: FxHashMap<NodeKind, Vec<FrontierCommand>> =
        FxHashMap::default();
    for (origin, command) in &barrier.frontier_commands {
        frontier_commands_by_node
            .entry(origin.clone())
            .or_default()
            .push(command.clone());
    }
    for id in ran.iter() {
        let default_edges = graph_edges.get(id).cloned().unwrap_or_default();
        let mut next_targets: Vec<NodeKind> = Vec::new();
        let mut frontier_replaced = false;
        if let Some(commands) = frontier_commands_by_node.get(id) {
            for command in commands {
                match command {
                    FrontierCommand::Replace(entries) => {
                        // Only the first Replace per node per step wins.
                        if frontier_replaced {
                            tracing::warn!(
                                step,
                                origin = %id.encode(),
                                // NOTE(review): this fold yields a leading
                                // " + " before the first entry — cosmetic.
                                target = %entries.iter().fold(String::new(),
                                    |acc, e| format!("{} + {}", acc, e.to_node_kind())
                                ),
                                "Replace frontier command has been issued once already during this step, skipping."
                            );
                            continue;
                        }
                        next_targets = entries.iter().map(NodeRoute::to_node_kind).collect();
                        frontier_replaced = true;
                    }
                    FrontierCommand::Append(entries) => {
                        // Append implies the static edges as a base, unless
                        // a Replace already defined the target set.
                        if next_targets.is_empty() && !frontier_replaced {
                            next_targets.extend(default_edges.clone());
                        }
                        next_targets.extend(entries.iter().map(NodeRoute::to_node_kind));
                    }
                }
            }
            // Commands resolved to nothing (e.g. Append of zero entries):
            // fall back to the static edges.
            if next_targets.is_empty() && !frontier_replaced {
                next_targets.extend(default_edges.clone());
            }
        } else {
            // No commands: static edges apply.
            next_targets.extend(default_edges.clone());
        }
        // Conditional routing is skipped when a Replace took over routing.
        if !frontier_replaced {
            for conditional_edge in conditional_edges.iter().filter(|ce| ce.from() == id) {
                tracing::debug!(from = ?conditional_edge.from(), step, "evaluating conditional edge");
                let target_node_names = (conditional_edge.predicate())(state_snapshot.clone());
                for target_name in target_node_names {
                    // Predicates return node names; map the reserved names
                    // onto the built-in kinds.
                    let target = if target_name == "End" {
                        NodeKind::End
                    } else if target_name == "Start" {
                        NodeKind::Start
                    } else {
                        NodeKind::Custom(target_name.clone())
                    };
                    tracing::debug!(target = ?target, step, "conditional edge routed");
                    next_targets.push(target);
                }
            }
        }
        // Validate and de-duplicate while preserving order.
        for target in next_targets {
            let is_valid_target = match &target {
                NodeKind::End | NodeKind::Start => true,
                NodeKind::Custom(_) => self.app.nodes().contains_key(&target),
            };
            if is_valid_target {
                if !next_frontier.contains(&target) {
                    next_frontier.push(target);
                }
            } else {
                tracing::warn!(
                    step,
                    origin = %id.encode(),
                    target = %target.encode(),
                    "frontier target not found; skipping"
                );
            }
        }
    }
    next_frontier
}
/// Saves a checkpoint for the session if autosave is enabled and a
/// checkpointer is configured. Save failures are intentionally ignored
/// (best-effort persistence).
async fn maybe_checkpoint(&self, session_id: &str, step: u64) {
    let checkpoint_span = tracing::info_span!("checkpoint", step);
    // FIX: `Span::in_scope(|| async { .. }).await` only enters the span
    // while *constructing* the future; the actual save was polled outside
    // the span. Attach the span to the future so every poll runs inside it.
    tracing::Instrument::instrument(
        async {
            if self.autosave
                && let Some(checkpointer) = &self.checkpointer
                && let Some(session_state) = self.sessions.get(session_id)
            {
                let _ = checkpointer
                    .save(Checkpoint::from_session(session_id, session_state))
                    .await;
            }
        },
        checkpoint_span,
    )
    .await;
}
/// Runs a single superstep: schedule the frontier, apply the barrier,
/// and compute the next frontier. Increments the session's step counter
/// first, so the report carries the new step number.
///
/// # Errors
/// Propagates scheduler and barrier failures.
#[instrument(skip(self, session_state), err)]
async fn run_one_superstep(
    &self,
    session_state: &mut SessionState,
) -> Result<StepReport, RunnerError> {
    session_state.step += 1;
    let step = session_state.step;
    tracing::debug!(step, "starting superstep");
    let schedule_span = tracing::info_span!(
        "schedule",
        step,
        frontier_len = session_state.frontier.len()
    );
    // FIX: `in_scope(|| fut).await` only entered the span while creating
    // the future, not while polling it; instrument the future instead.
    let scheduler_outcome = tracing::Instrument::instrument(
        self.schedule_step(session_state, step),
        schedule_span,
    )
    .await?;
    // Count node-reported errors for the barrier span's metadata.
    let errors_in_partials = scheduler_outcome
        .partials
        .iter()
        .filter_map(|p| p.errors.as_ref())
        .map(|e| e.len())
        .sum::<usize>();
    let barrier_span = tracing::info_span!(
        "barrier",
        ran_nodes_len = scheduler_outcome.ran_nodes.len(),
        errors_in_partials
    );
    // FIX: same async-span issue as above.
    let barrier_outcome = tracing::Instrument::instrument(
        self.apply_barrier_and_update(
            session_state,
            &scheduler_outcome.ran_nodes,
            scheduler_outcome.partials,
        ),
        barrier_span,
    )
    .await?;
    let commands_count = barrier_outcome.frontier_commands.len();
    let conditional_edges_evaluated = self.app.conditional_edges().len();
    let frontier_span =
        tracing::info_span!("frontier", commands_count, conditional_edges_evaluated);
    // Frontier computation is synchronous, so `in_scope` is correct here.
    let next_frontier = frontier_span.in_scope(|| {
        self.compute_next_frontier(
            session_state,
            &scheduler_outcome.ran_nodes,
            &barrier_outcome,
            step,
        )
    });
    tracing::debug!(
        step,
        updated_channels = ?barrier_outcome.updated_channels,
        error_count = barrier_outcome.errors.len(),
        "barrier applied"
    );
    tracing::debug!(step, next_frontier = ?next_frontier, "computed next frontier");
    // Terminal when nothing remains, or only End nodes remain.
    let completed =
        next_frontier.is_empty() || next_frontier.iter().all(|n| *n == NodeKind::End);
    session_state.frontier = next_frontier.clone();
    let state_versions = StateVersions {
        messages_version: session_state.state.messages.version(),
        extra_version: session_state.state.extra.version(),
    };
    Ok(StepReport {
        step,
        ran_nodes: scheduler_outcome.ran_nodes,
        skipped_nodes: scheduler_outcome.skipped_nodes,
        barrier_outcome,
        next_frontier,
        state_versions,
        completed,
    })
}
/// Runs the session's workflow to completion and returns the final state.
///
/// Repeats `run_step` with default options until the frontier is
/// terminal. The event stream is always finalized: with an `Error` reason
/// on failure or unexpected pause, and `Completed` on success.
///
/// # Errors
/// - [`RunnerError::SessionNotFound`] if the session does not exist.
/// - [`RunnerError::UnexpectedPause`] if a step pauses — default options
///   never request a pause, so this indicates a logic error.
/// - Any error propagated from `run_step`.
#[instrument(skip(self, session_id), err)]
pub async fn run_until_complete(
    &mut self,
    session_id: &str,
) -> Result<VersionedState, RunnerError> {
    tracing::info!(session = %session_id, "workflow run started");
    loop {
        let session_state =
            self.sessions
                .get(session_id)
                .ok_or_else(|| RunnerError::SessionNotFound {
                    session_id: session_id.to_string(),
                })?;
        if self.is_session_complete(session_state) {
            tracing::info!(
                session = %session_id,
                step = session_state.step,
                "frontier reached terminal state"
            );
            break;
        }
        let step_result = match self.run_step(session_id, StepOptions::default()).await {
            Ok(res) => res,
            Err(err) => {
                // Finalize the event stream before surfacing the error so
                // subscribers observe a definitive end-of-stream.
                let reason = err.to_string();
                let step = self.sessions.get(session_id).map(|state| state.step);
                self.finalize_event_stream(
                    session_id,
                    StreamEndReason::Error {
                        step,
                        error: reason,
                    },
                );
                return Err(err);
            }
        };
        match step_result {
            StepResult::Completed(report) => {
                if report.completed {
                    break;
                }
            }
            StepResult::Paused(_) => {
                // Should be unreachable with default StepOptions.
                let step = self.sessions.get(session_id).map(|state| state.step);
                self.finalize_event_stream(
                    session_id,
                    StreamEndReason::Error {
                        step,
                        error: "execution paused unexpectedly".to_string(),
                    },
                );
                return Err(RunnerError::UnexpectedPause);
            }
        }
    }
    tracing::info!(session = %session_id, "workflow run completed");
    let (final_state, versions, final_step) = self.finalize_state_snapshot(session_id)?;
    // Everything below is debug-level summary logging of the final state.
    let messages_snapshot = final_state.messages.snapshot();
    let extra_snapshot = final_state.extra.snapshot();
    let messages_version = versions.messages_version;
    let extra_version = versions.extra_version;
    for (i, m) in messages_snapshot.iter().enumerate() {
        tracing::debug!(
            session = %session_id,
            message_index = i,
            role = %m.role,
            content = %m.content,
            "final message snapshot entry"
        );
    }
    tracing::debug!(
        session = %session_id,
        messages_version,
        "messages channel version"
    );
    tracing::debug!(
        session = %session_id,
        extra_version,
        keys = extra_snapshot.len(),
        "extra channel summary"
    );
    for (k, v) in extra_snapshot.iter() {
        tracing::debug!(
            session = %session_id,
            key = %k,
            value = %v,
            "final extra entry"
        );
    }
    self.finalize_event_stream(session_id, StreamEndReason::Completed { step: final_step });
    Ok(final_state)
}
/// Borrows the live state for `session_id`, if such a session exists.
#[must_use]
pub fn get_session(&self, session_id: &str) -> Option<&SessionState> {
    let sessions = &self.sessions;
    sessions.get(session_id)
}
/// Ids of all sessions currently held in memory (unordered).
#[must_use]
pub fn list_sessions(&self) -> Vec<&String> {
    self.sessions.keys().collect::<Vec<&String>>()
}
}
impl AppRunner {
/// A session is complete when its frontier is empty or consists solely
/// of `End` nodes. (`all` is vacuously true on an empty frontier, but the
/// explicit emptiness check keeps the intent obvious.)
#[inline]
fn is_session_complete(&self, session_state: &SessionState) -> bool {
    let frontier = &session_state.frontier;
    frontier.is_empty() || frontier.iter().all(|node| matches!(node, NodeKind::End))
}
/// Clones the session's final state and returns it together with its
/// channel versions and the last executed step number.
#[inline]
fn finalize_state_snapshot(
    &self,
    session_id: &str,
) -> Result<(VersionedState, StateVersions, u64), RunnerError> {
    let Some(session_state) = self.sessions.get(session_id) else {
        return Err(RunnerError::SessionNotFound {
            session_id: session_id.to_string(),
        });
    };
    let final_state = session_state.state.clone();
    let versions = StateVersions {
        messages_version: final_state.messages.version(),
        extra_version: final_state.extra.version(),
    };
    Ok((final_state, versions, session_state.step))
}
/// Signals end-of-stream for this session on the event bus.
///
/// Delegates to the free function of the same name, passing the
/// single-consumer flag by mutable reference so the helper can update it
/// (presumably resetting it for a future stream — confirm in
/// `runtimes::streaming`).
fn finalize_event_stream(&mut self, session_id: &str, reason: StreamEndReason) {
    finalize_event_stream(
        &self.event_bus,
        session_id,
        reason,
        &mut self.event_stream_taken,
    );
}
}