use rustc_hash::FxHashMap;
use std::sync::Arc;
use crate::channels::Channel;
use crate::channels::errors::{ErrorEvent, ErrorScope};
use crate::control::FrontierCommand;
use crate::event_bus::{ChannelSink, EventBus, EventStream};
use crate::message::*;
use crate::node::*;
use crate::reducers::ReducerRegistry;
use crate::runtimes::runner::RunnerError;
use crate::runtimes::{AppRunner, Checkpointer, CheckpointerType, RuntimeConfig, SessionInit};
use crate::state::*;
use crate::types::*;
use crate::utils::collections::new_extra_map;
use crate::utils::id_generator::IdGenerator;
use futures_util::stream::BoxStream;
use thiserror::Error;
use tokio::task::JoinHandle;
use tracing::instrument;
/// Graph application: the node table, edge lists, conditional edges, reducer
/// registry, and runtime configuration that the invocation methods run against.
///
/// `Clone` is derived; the invocation methods in this file hand a clone of the
/// whole `App` to each runner they build.
#[derive(Clone)]
pub struct App {
/// Node implementations keyed by `NodeKind`.
nodes: FxHashMap<NodeKind, Arc<dyn Node>>,
/// Edge lists keyed by `NodeKind`.
/// NOTE(review): presumably source node -> successor nodes; confirm against the graph builder.
edges: FxHashMap<NodeKind, Vec<NodeKind>>,
/// Conditional edges carried over from the graph definition.
conditional_edges: Vec<crate::graphs::ConditionalEdge>,
/// Reducers applied to the state by `apply_barrier`.
reducer_registry: ReducerRegistry,
/// Runtime settings read here for the event bus, checkpointer, and session id.
runtime_config: RuntimeConfig,
}
/// An `EventBus` paired with an `EventStream` subscribed to it.
///
/// The stream is stored in an `Option` so it can be moved out once; accessors
/// report `AppEventStreamError::AlreadyTaken` when it is absent.
pub struct AppEventStream {
/// Bus the stream was subscribed from.
event_bus: EventBus,
/// Subscribed stream; `None` after it has been taken.
event_stream: Option<EventStream>,
}
/// Errors reported by [`AppEventStream`] accessors.
///
/// With the `diagnostics` feature enabled the type also derives `miette::Diagnostic`.
// NOTE(review): in the code visible here the stream is only taken by methods
// that consume the `AppEventStream`, so `AlreadyTaken` looks unreachable from
// this file's own API — confirm before relying on it.
#[derive(Debug, Error)]
#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))]
pub enum AppEventStreamError {
/// The inner `EventStream` was absent when an accessor asked for it.
#[error("event stream has already been taken")]
#[cfg_attr(
feature = "diagnostics",
diagnostic(
code(weavegraph::app::event_stream),
help("Verify stream subscription and event channel capacity.")
)
)]
AlreadyTaken,
}
/// Shorthand for results whose error type is [`AppEventStreamError`].
type AppEventStreamResult<T> = Result<T, AppEventStreamError>;
/// Handle to a session running on a spawned Tokio task.
///
/// Produced by `App::invoke_streaming`, which stores the `JoinHandle` of the
/// spawned run here.
pub struct InvocationHandle {
/// Join handle of the spawned run; `join` takes it out of the `Option`.
join_handle: Option<JoinHandle<Result<VersionedState, RunnerError>>>,
}
/// Summary returned by `App::apply_barrier`.
#[derive(Debug, Clone, Default)]
pub struct BarrierOutcome {
/// Names of the channels whose version was bumped; `apply_barrier` pushes
/// `"messages"` and/or `"extra"`.
pub updated_channels: Vec<&'static str>,
/// All errors collected from the node partials, in the sorted order
/// `apply_barrier` establishes.
pub errors: Vec<ErrorEvent>,
/// Frontier commands emitted by partials, each paired with the node it was attributed to.
pub frontier_commands: Vec<(NodeKind, FrontierCommand)>,
}
impl AppEventStream {
    /// Wraps `event_bus` together with a stream already subscribed to it.
    fn new(event_bus: EventBus, event_stream: EventStream) -> Self {
        let event_stream = Some(event_stream);
        Self {
            event_bus,
            event_stream,
        }
    }

    /// Borrows the underlying event bus.
    pub fn event_bus(&self) -> &EventBus {
        &self.event_bus
    }

    /// Mutably borrows the subscribed stream.
    ///
    /// # Errors
    /// Returns [`AppEventStreamError::AlreadyTaken`] if the stream is absent.
    pub fn event_stream(&mut self) -> AppEventStreamResult<&mut EventStream> {
        match self.event_stream.as_mut() {
            Some(stream) => Ok(stream),
            None => Err(AppEventStreamError::AlreadyTaken),
        }
    }

    /// Consumes the handle and returns the stream; the bus is dropped with the handle.
    ///
    /// # Errors
    /// Returns [`AppEventStreamError::AlreadyTaken`] if the stream is absent.
    pub fn into_stream(mut self) -> AppEventStreamResult<EventStream> {
        match self.event_stream.take() {
            Some(stream) => Ok(stream),
            None => Err(AppEventStreamError::AlreadyTaken),
        }
    }

    /// Consumes the handle and returns the bus; the stream (if any) is dropped.
    pub fn into_event_bus(self) -> EventBus {
        self.event_bus
    }

    /// Consumes the handle and returns both the bus and the stream.
    ///
    /// # Errors
    /// Returns [`AppEventStreamError::AlreadyTaken`] if the stream is absent.
    pub fn split(mut self) -> AppEventStreamResult<(EventBus, EventStream)> {
        match self.event_stream.take() {
            Some(stream) => Ok((self.event_bus, stream)),
            None => Err(AppEventStreamError::AlreadyTaken),
        }
    }

    /// Consumes the handle and converts its stream via `EventStream::into_blocking_iter`.
    ///
    /// # Errors
    /// Returns [`AppEventStreamError::AlreadyTaken`] if the stream is absent.
    pub fn into_blocking_iter(self) -> AppEventStreamResult<crate::event_bus::BlockingEventIter> {
        let stream = self.into_stream()?;
        Ok(stream.into_blocking_iter())
    }

    /// Consumes the handle and converts its stream via `EventStream::into_async_stream`.
    ///
    /// # Errors
    /// Returns [`AppEventStreamError::AlreadyTaken`] if the stream is absent.
    pub fn into_async_stream(
        self,
    ) -> AppEventStreamResult<BoxStream<'static, crate::event_bus::Event>> {
        let stream = self.into_stream()?;
        Ok(stream.into_async_stream())
    }

    /// Forwards to `EventStream::next_timeout` with the given `duration`.
    ///
    /// # Errors
    /// Returns [`AppEventStreamError::AlreadyTaken`] if the stream is absent.
    pub async fn next_timeout(
        &mut self,
        duration: std::time::Duration,
    ) -> AppEventStreamResult<Option<crate::event_bus::Event>> {
        let stream = self.event_stream()?;
        Ok(stream.next_timeout(duration).await)
    }
}
impl InvocationHandle {
    /// Requests cancellation of the spawned task, if the handle is still present.
    pub fn abort(&self) {
        if let Some(task) = self.join_handle.as_ref() {
            task.abort();
        }
    }

    /// Reports whether the task has finished; an absent handle counts as finished.
    #[must_use]
    pub fn is_finished(&self) -> bool {
        match &self.join_handle {
            Some(task) => task.is_finished(),
            None => true,
        }
    }

    /// Waits for the spawned session and returns its result.
    ///
    /// # Errors
    /// - `RunnerError::JoinHandleConsumed` if the join handle is absent.
    /// - `RunnerError::Join` wrapping the join error if awaiting the task fails.
    /// - Any `RunnerError` the session itself returned.
    pub async fn join(mut self) -> Result<VersionedState, RunnerError> {
        let Some(task) = self.join_handle.take() else {
            return Err(RunnerError::JoinHandleConsumed);
        };
        task.await
            .unwrap_or_else(|join_err| Err(RunnerError::Join(join_err)))
    }
}
impl App {
pub(crate) fn from_parts(
nodes: FxHashMap<NodeKind, Arc<dyn Node>>,
edges: FxHashMap<NodeKind, Vec<NodeKind>>,
conditional_edges: Vec<crate::graphs::ConditionalEdge>,
runtime_config: RuntimeConfig,
reducer_registry: ReducerRegistry,
) -> Self {
App {
nodes,
edges,
conditional_edges,
reducer_registry,
runtime_config,
}
}
/// Borrows the application's conditional edges.
#[must_use]
pub fn conditional_edges(&self) -> &Vec<crate::graphs::ConditionalEdge> {
&self.conditional_edges
}
/// Borrows the node table, keyed by `NodeKind`.
#[must_use]
pub fn nodes(&self) -> &FxHashMap<NodeKind, Arc<dyn Node>> {
&self.nodes
}
/// Borrows the edge lists, keyed by `NodeKind`.
#[must_use]
pub fn edges(&self) -> &FxHashMap<NodeKind, Vec<NodeKind>> {
&self.edges
}
/// Borrows the runtime configuration.
#[must_use]
pub fn runtime_config(&self) -> &RuntimeConfig {
&self.runtime_config
}
/// Builds a new event bus from the runtime configuration, subscribes to it,
/// and returns both wrapped in an [`AppEventStream`].
///
/// Each call builds its own bus via `build_event_bus()`.
#[must_use]
pub fn event_stream(&self) -> AppEventStream {
    let bus = self.runtime_config.event_bus.build_event_bus();
    let subscription = bus.subscribe();
    AppEventStream::new(bus, subscription)
}
/// Picks the checkpointer settings for a run.
///
/// The checkpointer type is `override_config` if given, otherwise the type
/// from the runtime configuration, otherwise `CheckpointerType::InMemory`.
/// The second element is the runtime configuration's custom checkpointer,
/// if any; callers in this file prefer it over the type when it is present.
fn resolve_checkpointer(
    &self,
    override_config: Option<CheckpointerType>,
) -> (CheckpointerType, Option<Arc<dyn Checkpointer>>) {
    let configured = override_config.or_else(|| self.runtime_config.checkpointer_type());
    (
        configured.unwrap_or(CheckpointerType::InMemory),
        self.runtime_config.custom_checkpointer(),
    )
}
/// Shared driver behind the `invoke*` methods: runs one session to completion
/// on the current task.
///
/// `build_event_bus` is called first and returns the bus to use plus an
/// arbitrary value `R` (for example a channel receiver) that is handed back
/// to the caller next to the run result. The runner is built with a clone of
/// this `App`, the given `autosave` flag, the listener started, and either
/// the custom checkpointer (when configured) or the resolved checkpointer type.
async fn invoke_with_bus_builder<R, F>(
    &self,
    initial_state: VersionedState,
    autosave: bool,
    checkpointer_override: Option<CheckpointerType>,
    build_event_bus: F,
) -> (Result<VersionedState, RunnerError>, R)
where
    F: FnOnce() -> (EventBus, R),
{
    let (event_bus, output) = build_event_bus();
    let (checkpointer_type, custom_checkpointer) =
        self.resolve_checkpointer(checkpointer_override);

    let base_builder = AppRunner::builder()
        .app(self.clone())
        .autosave(autosave)
        .event_bus(event_bus)
        .start_listener(true);
    // A custom checkpointer takes precedence over the resolved type.
    let configured_builder = match custom_checkpointer {
        Some(custom) => base_builder.checkpointer_custom(custom),
        None => base_builder.checkpointer(checkpointer_type),
    };

    let runner = configured_builder.build().await;
    let result = Self::run_session(runner, self.next_session_id(), initial_state).await;
    (result, output)
}
/// Starts a session on a spawned Tokio task and returns immediately with a
/// handle to it and the event stream subscribed to its bus.
///
/// The runner is built with autosave enabled and the listener started, using
/// the custom checkpointer when one is configured and otherwise the resolved
/// checkpointer type (no override). Because the run is launched with
/// `tokio::spawn`, this must be called from within a Tokio runtime.
pub async fn invoke_streaming(
    &self,
    initial_state: VersionedState,
) -> (InvocationHandle, EventStream) {
    let (checkpointer_type, custom_checkpointer) = self.resolve_checkpointer(None);

    // A freshly created `AppEventStream` still owns its stream, so `split` cannot fail.
    let Ok((event_bus, event_stream)) = self.event_stream().split() else {
        unreachable!("fresh App::event_stream() always yields unused stream")
    };

    let base_builder = AppRunner::builder()
        .app(self.clone())
        .autosave(true)
        .event_bus(event_bus)
        .start_listener(true);
    let configured_builder = match custom_checkpointer {
        Some(custom) => base_builder.checkpointer_custom(custom),
        None => base_builder.checkpointer(checkpointer_type),
    };

    let runner = configured_builder.build().await;
    let session_id = self.next_session_id();
    let task = tokio::spawn(Self::run_session(runner, session_id, initial_state));

    let handle = InvocationHandle {
        join_handle: Some(task),
    };
    (handle, event_stream)
}
/// Runs a session to completion with autosave enabled, using an event bus
/// built from the runtime configuration, and returns the final state.
///
/// # Errors
/// Propagates any `RunnerError` from session creation or execution.
#[instrument(skip(self, initial_state), err)]
pub async fn invoke(
    &self,
    initial_state: VersionedState,
) -> Result<VersionedState, RunnerError> {
    let (outcome, ()) = self
        .invoke_with_bus_builder(initial_state, true, None, || {
            let bus = self.runtime_config.event_bus.build_event_bus();
            (bus, ())
        })
        .await;
    outcome
}
/// Runs a session to completion and also returns the receiving end of an
/// unbounded `flume` channel that was attached to the event bus as a sink.
///
/// Autosave is disabled for this entry point. The run result is returned as
/// the first tuple element rather than through `?`, so the receiver is
/// available whether or not the run succeeded.
#[instrument(skip(self, initial_state))]
pub async fn invoke_with_channel(
    &self,
    initial_state: VersionedState,
) -> (
    Result<VersionedState, RunnerError>,
    flume::Receiver<crate::event_bus::Event>,
) {
    self.invoke_with_bus_builder(initial_state, false, None, || {
        let (sender, receiver) = flume::unbounded();
        let bus = self.runtime_config.event_bus.build_event_bus();
        bus.add_sink(ChannelSink::new(sender));
        (bus, receiver)
    })
    .await
}
/// Runs a session to completion after attaching each of `sinks` to an event
/// bus built from the runtime configuration.
///
/// Autosave is disabled for this entry point.
///
/// # Errors
/// Propagates any `RunnerError` from session creation or execution.
#[instrument(skip(self, initial_state, sinks), err)]
pub async fn invoke_with_sinks(
    &self,
    initial_state: VersionedState,
    sinks: Vec<Box<dyn crate::event_bus::EventSink>>,
) -> Result<VersionedState, RunnerError> {
    let (outcome, ()) = self
        .invoke_with_bus_builder(initial_state, false, None, move || {
            let bus = self.runtime_config.event_bus.build_event_bus();
            sinks.into_iter().for_each(|sink| {
                bus.add_boxed_sink(sink);
            });
            (bus, ())
        })
        .await;
    outcome
}
/// Returns the session id from the runtime configuration when one is set;
/// otherwise generates a run id with a new `IdGenerator`.
fn next_session_id(&self) -> String {
    match &self.runtime_config.session_id {
        Some(configured) => configured.clone(),
        None => IdGenerator::new().generate_run_id(),
    }
}
/// Creates (or resumes) the session `session_id` on `runner`, then drives it
/// until completion.
///
/// When `create_session` reports `SessionInit::Resumed`, the checkpoint step
/// is logged at info level before the run continues.
///
/// # Errors
/// Propagates errors from `create_session` and `run_until_complete`.
async fn run_session(
    mut runner: AppRunner,
    session_id: String,
    initial_state: VersionedState,
) -> Result<VersionedState, RunnerError> {
    let session_init = runner
        .create_session(session_id.clone(), initial_state)
        .await?;

    if let SessionInit::Resumed { checkpoint_step } = session_init {
        tracing::info!(
            session = %session_id,
            checkpoint_step,
            "Resuming session from checkpoint"
        );
    }

    runner.run_until_complete(&session_id).await
}
/// Merges the partial outputs of a set of nodes into `state` and reports what changed.
///
/// Steps:
/// 1. Collect messages, extra entries, errors, and frontier commands from
///    every partial. `run_ids[i]` is the node attributed to `node_partials[i]`;
///    when `run_ids` is shorter, `NodeKind::Custom("?")` is used as a placeholder.
///    Extra entries from later partials overwrite earlier ones with the same key.
/// 2. Sort the collected errors by scope, then timestamp, then message. The
///    sort is stable, so fully equal entries keep their arrival order.
/// 3. Apply all reducers once to the merged update.
/// 4. Bump the `messages` version when the message count changed and the
///    `extra` version when its snapshot differs from before; each bump is
///    `previous + 1`, saturating.
///
/// The partials are consumed, so their contents are moved into the merged
/// update instead of being cloned.
///
/// # Errors
/// Returns whatever error `ReducerRegistry::apply_all` produces.
// NOTE(review): the `messages` check compares lengths only, so a reducer that
// rewrites messages without changing the count would not bump the version —
// confirm that is intended.
#[instrument(skip(self, state, run_ids, node_partials), err)]
pub async fn apply_barrier(
    &self,
    state: &mut VersionedState,
    run_ids: &[NodeKind],
    node_partials: Vec<NodePartial>,
) -> Result<BarrierOutcome, Box<dyn std::error::Error + Send + Sync>> {
    // Sort key for error scopes: variant rank, then the scope's string id, then step.
    fn scope_sort_key(scope: &ErrorScope) -> (u8, &str, u64) {
        match scope {
            ErrorScope::Node { kind, step } => (0, kind.as_str(), *step),
            ErrorScope::Scheduler { step } => (1, "", *step),
            ErrorScope::Runner { session, step } => (2, session.as_str(), *step),
            ErrorScope::App => (3, "", 0),
        }
    }

    let mut msgs_all: Vec<Message> = Vec::new();
    let mut extra_all = new_extra_map();
    let mut errors_all: Vec<ErrorEvent> = Vec::new();
    let mut frontier_commands: Vec<(NodeKind, FrontierCommand)> = Vec::new();

    // Placeholder node id for partials without a matching `run_ids` entry;
    // built once instead of once per partial.
    let fallback = NodeKind::Custom("?".to_string());

    for (i, partial) in node_partials.into_iter().enumerate() {
        let nid = run_ids.get(i).unwrap_or(&fallback);
        let NodePartial {
            messages,
            extra,
            errors,
            frontier,
        } = partial;

        if let Some(ms) = messages
            && !ms.is_empty()
        {
            tracing::debug!(node = ?nid, count = ms.len(), "Node produced messages");
            msgs_all.extend(ms);
        }
        if let Some(ex) = extra
            && !ex.is_empty()
        {
            tracing::debug!(node = ?nid, keys = ex.len(), "Node produced extra data");
            // Insert in key order so the merge order does not depend on map iteration order.
            let mut sorted_pairs: Vec<_> = ex.into_iter().collect();
            sorted_pairs.sort_by(|(left, _), (right, _)| left.cmp(right));
            for (k, v) in sorted_pairs {
                extra_all.insert(k, v);
            }
        }
        if let Some(errs) = errors
            && !errs.is_empty()
        {
            tracing::debug!(node = ?nid, count = errs.len(), "Node produced errors");
            errors_all.extend(errs);
        }
        if let Some(command) = frontier {
            frontier_commands.push((nid.clone(), command));
        }
    }

    errors_all.sort_by(|a, b| {
        let key_a = scope_sort_key(&a.scope);
        let key_b = scope_sort_key(&b.scope);
        key_a
            .cmp(&key_b)
            .then_with(|| a.when.cmp(&b.when))
            .then_with(|| a.error.message.cmp(&b.error.message))
    });

    // The errors go both into the state (via the reducers) and into the
    // returned outcome, so this one clone is required.
    let errors_for_state = if errors_all.is_empty() {
        None
    } else {
        Some(errors_all.clone())
    };
    let merged_updates = NodePartial {
        messages: if msgs_all.is_empty() {
            None
        } else {
            Some(msgs_all)
        },
        extra: if extra_all.is_empty() {
            None
        } else {
            Some(extra_all)
        },
        errors: errors_for_state,
        frontier: None,
    };

    // Capture the pre-reduction view so changes can be detected afterwards.
    let msgs_before_len = state.messages.len();
    let msgs_before_ver = state.messages.version();
    let extra_before = state.extra.snapshot();
    let extra_before_ver = state.extra.version();

    self.reducer_registry
        .apply_all(&mut *state, &merged_updates)?;

    let mut updated: Vec<&'static str> = Vec::new();

    let msgs_changed = state.messages.len() != msgs_before_len;
    if msgs_changed {
        state
            .messages
            .set_version(msgs_before_ver.saturating_add(1));
        tracing::info!(
            target: "weavegraph::app",
            channel = "messages",
            before_count = msgs_before_len,
            after_count = state.messages.len(),
            before_version = msgs_before_ver,
            after_version = state.messages.version(),
            "channel updated"
        );
        updated.push("messages");
    }

    let extra_after = state.extra.snapshot();
    let extra_changed = extra_after != extra_before;
    if extra_changed {
        state.extra.set_version(extra_before_ver.saturating_add(1));
        tracing::info!(
            target: "weavegraph::app",
            channel = "extra",
            before_count = extra_before.len(),
            after_count = extra_after.len(),
            before_version = extra_before_ver,
            after_version = state.extra.version(),
            "channel updated"
        );
        updated.push("extra");
    }

    Ok(BarrierOutcome {
        updated_channels: updated,
        errors: errors_all,
        frontier_commands,
    })
}
}