use async_trait::async_trait;
use rustc_hash::FxHashMap;
use serde_json;
use thiserror::Error;
use crate::channels::errors::ErrorEvent;
use crate::control::{FrontierCommand, NodeRoute};
use crate::event_bus::{Event, EventEmitter, LLMStreamingEvent};
use crate::message::Message;
use crate::state::{StateKey, StateSlotError, StateSnapshot};
use crate::types::NodeKind;
use crate::utils::clock::Clock;
use std::sync::Arc;
/// An asynchronous unit of work that can be shared across threads (`Send + Sync`).
///
/// Implementors receive a [`StateSnapshot`] plus a [`NodeContext`] and report their
/// output as a [`NodePartial`].
#[async_trait]
pub trait Node: Send + Sync {
/// Runs the node once.
///
/// * `snapshot` - the state handed to this node, taken by value.
/// * `ctx` - per-call context (node id, step number, event emitter), taken by value.
///
/// # Errors
///
/// Returns a [`NodeError`] when the implementation cannot produce a [`NodePartial`].
async fn run(
&self,
snapshot: StateSnapshot,
ctx: NodeContext,
) -> Result<NodePartial, NodeError>;
}
/// Per-call context handed to [`Node::run`].
///
/// Carries the identity of the running node, the current step number, and the
/// emitter used by the `emit_*` helpers. Marked `#[non_exhaustive]`, so code outside
/// this crate cannot build it with a struct literal and must go through
/// [`NodeContext::new`].
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct NodeContext {
/// Identifier of the node this context belongs to; attached to emitted events.
pub node_id: String,
/// Step number attached to node message events.
pub step: u64,
/// Sink that every `emit_*` helper forwards its event to.
pub event_emitter: Arc<dyn EventEmitter>,
/// Optional time source; when `None`, `now_unix_ms` returns `None`.
pub clock: Option<Arc<dyn Clock>>,
/// Optional invocation identifier; when set, it is added to node message metadata.
pub invocation_id: Option<String>,
}
impl NodeContext {
    /// Builds a context for `node_id` at `step` that reports through `event_emitter`.
    ///
    /// `clock` and `invocation_id` start out as `None`.
    pub fn new(
        node_id: impl Into<String>,
        step: u64,
        event_emitter: Arc<dyn EventEmitter>,
    ) -> Self {
        let node_id = node_id.into();
        Self {
            node_id,
            step,
            event_emitter,
            clock: None,
            invocation_id: None,
        }
    }

    /// Returns the clock's current reading, or `None` when no clock is configured.
    #[must_use]
    pub fn now_unix_ms(&self) -> Option<i64> {
        let clock = self.clock.as_ref()?;
        Some(clock.now_unix_ms())
    }

    /// Returns the invocation identifier as a string slice, if one is set.
    #[must_use]
    pub fn invocation_id(&self) -> Option<&str> {
        self.invocation_id.as_deref()
    }

    /// Shorthand for [`NodeContext::emit_node`]; arguments are forwarded unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`NodeContextError::EventBusUnavailable`] when the emitter rejects the event.
    pub fn emit(
        &self,
        scope: impl Into<String>,
        message: impl Into<String>,
    ) -> Result<(), NodeContextError> {
        self.emit_node(scope, message)
    }

    /// Emits a node message event carrying this context's node id and step.
    ///
    /// Metadata is collected first: `"invocation_id"` when an invocation id is set and
    /// `"now_unix_ms"` when a clock is configured. If neither is available the event is
    /// built through `Event::node_message_with_meta` (no metadata map is passed);
    /// otherwise `Event::node_message_with_metadata` receives the collected map.
    ///
    /// # Errors
    ///
    /// Returns [`NodeContextError::EventBusUnavailable`] when the emitter rejects the event.
    pub fn emit_node(
        &self,
        scope: impl Into<String>,
        message: impl Into<String>,
    ) -> Result<(), NodeContextError> {
        let mut metadata: FxHashMap<String, serde_json::Value> = FxHashMap::default();

        if let Some(id) = self.invocation_id.as_ref() {
            metadata.insert(
                String::from("invocation_id"),
                serde_json::Value::String(id.clone()),
            );
        }
        if let Some(timestamp) = self.now_unix_ms() {
            metadata.insert(
                String::from("now_unix_ms"),
                serde_json::Value::from(timestamp),
            );
        }

        let node_id = self.node_id.clone();
        let event = if metadata.is_empty() {
            Event::node_message_with_meta(node_id, self.step, scope, message)
        } else {
            Event::node_message_with_metadata(node_id, self.step, scope, message, metadata)
        };
        self.emit_event(event)
    }

    /// Emits a diagnostic event built from `scope` and `message`.
    ///
    /// Unlike [`NodeContext::emit_node`], neither the node id nor the step is passed along.
    ///
    /// # Errors
    ///
    /// Returns [`NodeContextError::EventBusUnavailable`] when the emitter rejects the event.
    pub fn emit_diagnostic(
        &self,
        scope: impl Into<String>,
        message: impl Into<String>,
    ) -> Result<(), NodeContextError> {
        let event = Event::diagnostic(scope, message);
        self.emit_event(event)
    }

    /// Emits an LLM streaming chunk event tagged with this node's id.
    ///
    /// A `metadata` of `None` is replaced by an empty map.
    ///
    /// # Errors
    ///
    /// Returns [`NodeContextError::EventBusUnavailable`] when the emitter rejects the event.
    pub fn emit_llm_chunk(
        &self,
        session_id: Option<String>,
        stream_id: Option<String>,
        chunk: impl Into<String>,
        metadata: Option<FxHashMap<String, serde_json::Value>>,
    ) -> Result<(), NodeContextError> {
        let origin = Some(self.node_id.clone());
        self.emit_event(Event::LLM(LLMStreamingEvent::chunk_event(
            session_id,
            origin,
            stream_id,
            chunk,
            metadata.unwrap_or_default(),
        )))
    }

    /// Emits the final LLM streaming event tagged with this node's id.
    ///
    /// A `metadata` of `None` is replaced by an empty map.
    ///
    /// # Errors
    ///
    /// Returns [`NodeContextError::EventBusUnavailable`] when the emitter rejects the event.
    pub fn emit_llm_final(
        &self,
        session_id: Option<String>,
        stream_id: Option<String>,
        chunk: impl Into<String>,
        metadata: Option<FxHashMap<String, serde_json::Value>>,
    ) -> Result<(), NodeContextError> {
        let origin = Some(self.node_id.clone());
        self.emit_event(Event::LLM(LLMStreamingEvent::final_event(
            session_id,
            origin,
            stream_id,
            chunk,
            metadata.unwrap_or_default(),
        )))
    }

    /// Emits an LLM streaming error event tagged with this node's id.
    ///
    /// # Errors
    ///
    /// Returns [`NodeContextError::EventBusUnavailable`] when the emitter rejects the event.
    pub fn emit_llm_error(
        &self,
        session_id: Option<String>,
        stream_id: Option<String>,
        error_message: impl Into<String>,
    ) -> Result<(), NodeContextError> {
        let origin = Some(self.node_id.clone());
        self.emit_event(Event::LLM(LLMStreamingEvent::error_event(
            session_id,
            origin,
            stream_id,
            error_message,
        )))
    }

    /// Hands `event` to the emitter.
    ///
    /// Whatever error the emitter reports is discarded and replaced by
    /// [`NodeContextError::EventBusUnavailable`].
    fn emit_event(&self, event: Event) -> Result<(), NodeContextError> {
        match self.event_emitter.emit(event) {
            Ok(()) => Ok(()),
            Err(_) => Err(NodeContextError::EventBusUnavailable),
        }
    }
}
/// Output of a single [`Node::run`] call.
///
/// Every field is optional; the derived `Default` leaves all of them as `None`.
#[derive(Clone, Debug, Default)]
pub struct NodePartial {
/// Messages produced by the node, if any.
pub messages: Option<Vec<Message>>,
/// Extra JSON values keyed by string. The `clear_*` helpers store
/// `serde_json::Value::Null` under the keys they are given.
pub extra: Option<FxHashMap<String, serde_json::Value>>,
/// Error events reported by the node, if any.
pub errors: Option<Vec<ErrorEvent>>,
/// Frontier command requested by the node, if any.
pub frontier: Option<FrontierCommand>,
}
impl NodePartial {
    /// Creates an empty partial; identical to `NodePartial::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets `messages`, replacing any previous value.
    #[must_use]
    pub fn with_messages(self, messages: Vec<Message>) -> Self {
        Self {
            messages: Some(messages),
            ..self
        }
    }

    /// Sets `extra`, replacing any map that was already present.
    #[must_use]
    pub fn with_extra(self, extra: FxHashMap<String, serde_json::Value>) -> Self {
        Self {
            extra: Some(extra),
            ..self
        }
    }

    /// Serializes `value` to JSON and stores it in `extra` under `key.storage_key()`.
    ///
    /// The `extra` map is created on demand; an existing entry under the same key is
    /// overwritten.
    ///
    /// # Errors
    ///
    /// Returns [`StateSlotError::Serialize`] (carrying the storage key and the
    /// underlying `serde_json` error) when `value` cannot be converted to JSON.
    pub fn with_typed_extra<T: serde::Serialize>(
        mut self,
        key: StateKey<T>,
        value: T,
    ) -> Result<Self, StateSlotError> {
        let storage_key = key.storage_key();
        match serde_json::to_value(value) {
            Ok(json_value) => {
                let slots = self.extra.get_or_insert_with(FxHashMap::default);
                slots.insert(storage_key, json_value);
                Ok(self)
            }
            Err(source) => Err(StateSlotError::Serialize {
                key: storage_key,
                source,
            }),
        }
    }

    /// Sets `errors`, replacing any previous value.
    #[must_use]
    pub fn with_errors(self, errors: Vec<ErrorEvent>) -> Self {
        Self {
            errors: Some(errors),
            ..self
        }
    }

    /// Sets the frontier to `FrontierCommand::Replace` with one route per target.
    #[must_use]
    pub fn with_frontier_replace<I>(self, targets: I) -> Self
    where
        I: IntoIterator<Item = NodeKind>,
    {
        let routes = targets.into_iter().map(NodeRoute::from).collect();
        self.with_frontier_command(FrontierCommand::Replace(routes))
    }

    /// Sets the frontier to `FrontierCommand::Append` with one route per target.
    #[must_use]
    pub fn with_frontier_append<I>(self, targets: I) -> Self
    where
        I: IntoIterator<Item = NodeKind>,
    {
        let routes = targets.into_iter().map(NodeRoute::from).collect();
        self.with_frontier_command(FrontierCommand::Append(routes))
    }

    /// Sets `frontier` to `command`, replacing any previous command.
    #[must_use]
    pub fn with_frontier_command(self, command: FrontierCommand) -> Self {
        Self {
            frontier: Some(command),
            ..self
        }
    }

    /// Stores `serde_json::Value::Null` in `extra` under each of `keys`.
    ///
    /// The `extra` map is created even when `keys` yields nothing.
    #[must_use]
    pub fn clear_extra_keys<I, S>(mut self, keys: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.extra
            .get_or_insert_with(FxHashMap::default)
            .extend(keys.into_iter().map(|key| {
                let name: String = key.into();
                (name, serde_json::Value::Null)
            }));
        self
    }

    /// Typed counterpart of [`NodePartial::clear_extra_keys`] for a single key.
    #[must_use]
    pub fn clear_typed_extra_key<T>(self, key: crate::state::StateKey<T>) -> Self {
        self.clear_extra_keys(std::iter::once(key.storage_key()))
    }
}
/// Errors produced by the `emit_*` helpers on [`NodeContext`].
///
/// With the `diagnostics` feature enabled the enum also derives `miette::Diagnostic`.
#[derive(Debug, Error)]
#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))]
pub enum NodeContextError {
/// The event emitter returned an error when asked to emit an event.
/// The original emitter error is not preserved in this variant.
#[error("failed to emit event: event bus unavailable")]
#[cfg_attr(
feature = "diagnostics",
diagnostic(
code(weavegraph::node::event_bus_unavailable),
help("The event bus may be disconnected or at capacity. Check workflow state.")
)
)]
EventBusUnavailable,
}
/// Errors a [`Node`] can return from `run`.
///
/// Marked `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
/// With the `diagnostics` feature enabled the enum also derives `miette::Diagnostic`.
#[derive(Debug, Error)]
#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))]
#[non_exhaustive]
pub enum NodeError {
/// An input the node expected was not available; `what` names it.
#[error("missing expected input: {what}")]
#[cfg_attr(
feature = "diagnostics",
diagnostic(
code(weavegraph::node::missing_input),
help("Check that the previous node produced the required data: {what}.")
)
)]
MissingInput {
what: &'static str,
},
/// A named provider reported a failure described by `message`.
#[error("provider error ({provider}): {message}")]
#[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::node::provider)))]
Provider {
provider: &'static str,
message: String,
},
/// Any other boxed error; displayed transparently as the wrapped error.
#[error(transparent)]
#[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::node::other)))]
Other(#[from] Box<dyn std::error::Error + Send + Sync>),
/// A `serde_json` error; `#[from]` lets `?` convert it automatically.
#[error(transparent)]
#[cfg_attr(
feature = "diagnostics",
diagnostic(code(weavegraph::node::serde_json))
)]
Serde(#[from] serde_json::Error),
/// Validation failed; the string describes the failure.
#[error("validation failed: {0}")]
#[cfg_attr(
feature = "diagnostics",
diagnostic(
code(weavegraph::node::validation),
help("Check input data format and required fields.")
)
)]
ValidationFailed(String),
/// An event could not be emitted; `#[from]` lets `?` convert a [`NodeContextError`].
#[error("event bus error: {0}")]
#[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::node::event_bus)))]
EventBus(#[from] NodeContextError),
}
impl NodeError {
    /// Wraps an arbitrary error in [`NodeError::Other`] by boxing it.
    #[must_use]
    pub fn other<E>(error: E) -> Self
    where
        E: std::error::Error + Send + Sync + 'static,
    {
        let boxed: Box<dyn std::error::Error + Send + Sync> = Box::new(error);
        Self::Other(boxed)
    }
}
/// Shorthand for a `Result` whose error type is [`NodeError`].
pub type NodeResult<T> = std::result::Result<T, NodeError>;
/// Extension trait for turning a result into a [`NodeResult`].
pub trait NodeResultExt<T> {
/// Converts `self` into a [`NodeResult`], consuming it.
fn node_err(self) -> NodeResult<T>;
}
impl<T, E> NodeResultExt<T> for std::result::Result<T, E>
where
E: std::error::Error + Send + Sync + 'static,
{
fn node_err(self) -> NodeResult<T> {
self.map_err(NodeError::other)
}
}