use async_trait::async_trait;
use rustc_hash::FxHashMap;
use serde_json;
use thiserror::Error;
use crate::channels::errors::ErrorEvent;
use crate::control::{FrontierCommand, NodeRoute};
use crate::event_bus::{Event, EventEmitter, LLMStreamingEvent};
use crate::message::Message;
use crate::state::StateSnapshot;
use crate::types::NodeKind;
use std::sync::Arc;
/// A unit of work in a workflow graph.
///
/// Implementors must be `Send + Sync` so they can be shared across tasks; the
/// trait is made object-safe for async use via `#[async_trait]`.
#[async_trait]
pub trait Node: Send + Sync {
/// Executes the node.
///
/// * `snapshot` — the state handed to this node (presumably a read-only view of
///   the workflow state at this step — confirm against the scheduler).
/// * `ctx` — per-invocation context (node id, step, event emitter).
///
/// Returns a [`NodePartial`] describing the updates this node contributes.
///
/// # Errors
///
/// Returns a [`NodeError`] when the node cannot complete.
async fn run(
&self,
snapshot: StateSnapshot,
ctx: NodeContext,
) -> Result<NodePartial, NodeError>;
}
/// Per-invocation context handed to [`Node::run`].
///
/// Carries the executing node's identity and step, plus the emitter used by
/// the `emit_*` helpers to publish events.
#[derive(Clone, Debug)]
pub struct NodeContext {
/// Identifier of the running node. The node-message and LLM streaming helpers
/// attach it to the events they emit; diagnostics do not.
pub node_id: String,
/// Step counter sent along with node messages (presumably the scheduler's
/// current step number — confirm against the runtime).
pub step: u64,
/// Shared event sink. Any error it returns is surfaced to callers as
/// `NodeContextError::EventBusUnavailable`.
pub event_emitter: Arc<dyn EventEmitter>,
}
impl NodeContext {
    /// Emits a node-scoped message; an alias for [`NodeContext::emit_node`].
    ///
    /// # Errors
    ///
    /// Returns [`NodeContextError::EventBusUnavailable`] if the emitter rejects
    /// the event.
    pub fn emit(
        &self,
        scope: impl Into<String>,
        message: impl Into<String>,
    ) -> Result<(), NodeContextError> {
        Self::emit_node(self, scope, message)
    }

    /// Emits a message event tagged with this context's `node_id` and `step`.
    ///
    /// # Errors
    ///
    /// Returns [`NodeContextError::EventBusUnavailable`] if the emitter rejects
    /// the event.
    pub fn emit_node(
        &self,
        scope: impl Into<String>,
        message: impl Into<String>,
    ) -> Result<(), NodeContextError> {
        let node_event =
            Event::node_message_with_meta(self.node_id.clone(), self.step, scope, message);
        self.emit_event(node_event)
    }

    /// Emits a diagnostic event. Unlike [`NodeContext::emit_node`], the node id
    /// and step are not attached.
    ///
    /// # Errors
    ///
    /// Returns [`NodeContextError::EventBusUnavailable`] if the emitter rejects
    /// the event.
    pub fn emit_diagnostic(
        &self,
        scope: impl Into<String>,
        message: impl Into<String>,
    ) -> Result<(), NodeContextError> {
        let diagnostic_event = Event::diagnostic(scope, message);
        self.emit_event(diagnostic_event)
    }

    /// Publishes an intermediate LLM streaming chunk attributed to this node.
    ///
    /// A `None` `metadata` is sent as an empty map.
    ///
    /// # Errors
    ///
    /// Returns [`NodeContextError::EventBusUnavailable`] if the emitter rejects
    /// the event.
    pub fn emit_llm_chunk(
        &self,
        session_id: Option<String>,
        stream_id: Option<String>,
        chunk: impl Into<String>,
        metadata: Option<FxHashMap<String, serde_json::Value>>,
    ) -> Result<(), NodeContextError> {
        let origin = Some(self.node_id.clone());
        let metadata = metadata.unwrap_or_default();
        let streaming =
            LLMStreamingEvent::chunk_event(session_id, origin, stream_id, chunk, metadata);
        self.emit_event(Event::LLM(streaming))
    }

    /// Publishes the final LLM streaming chunk attributed to this node.
    ///
    /// A `None` `metadata` is sent as an empty map.
    ///
    /// # Errors
    ///
    /// Returns [`NodeContextError::EventBusUnavailable`] if the emitter rejects
    /// the event.
    pub fn emit_llm_final(
        &self,
        session_id: Option<String>,
        stream_id: Option<String>,
        chunk: impl Into<String>,
        metadata: Option<FxHashMap<String, serde_json::Value>>,
    ) -> Result<(), NodeContextError> {
        let origin = Some(self.node_id.clone());
        let metadata = metadata.unwrap_or_default();
        let streaming =
            LLMStreamingEvent::final_event(session_id, origin, stream_id, chunk, metadata);
        self.emit_event(Event::LLM(streaming))
    }

    /// Publishes an LLM streaming error attributed to this node.
    ///
    /// # Errors
    ///
    /// Returns [`NodeContextError::EventBusUnavailable`] if the emitter rejects
    /// the event.
    pub fn emit_llm_error(
        &self,
        session_id: Option<String>,
        stream_id: Option<String>,
        error_message: impl Into<String>,
    ) -> Result<(), NodeContextError> {
        let origin = Some(self.node_id.clone());
        let streaming =
            LLMStreamingEvent::error_event(session_id, origin, stream_id, error_message);
        self.emit_event(Event::LLM(streaming))
    }

    /// Forwards `event` to the shared emitter.
    ///
    /// The emitter's own error detail is discarded; every failure is reported as
    /// [`NodeContextError::EventBusUnavailable`].
    fn emit_event(&self, event: Event) -> Result<(), NodeContextError> {
        match self.event_emitter.emit(event) {
            Ok(()) => Ok(()),
            Err(_) => Err(NodeContextError::EventBusUnavailable),
        }
    }
}
/// The partial result a node returns from [`Node::run`].
///
/// Every field is optional; `Default` yields all `None`. (Presumably `None`
/// means "this node contributes nothing for that field" — confirm against the
/// state-merge logic.)
#[derive(Clone, Debug, Default)]
pub struct NodePartial {
/// Messages produced by the node, if any.
pub messages: Option<Vec<Message>>,
/// Arbitrary JSON key/value data produced by the node, if any.
pub extra: Option<FxHashMap<String, serde_json::Value>>,
/// Error events the node wants to record, if any.
pub errors: Option<Vec<ErrorEvent>>,
/// Optional routing instruction for which nodes run next.
pub frontier: Option<FrontierCommand>,
}
impl NodePartial {
    /// Creates an empty partial with every field set to `None`.
    ///
    /// Identical to [`NodePartial::default`]; provided as the entry point for
    /// builder-style chains such as `NodePartial::new().with_messages(msgs)`.
    #[must_use]
    pub fn new() -> Self {
        // The previous `Self { ..Default::default() }` was an empty struct
        // update; calling `default()` directly is the idiomatic form.
        Self::default()
    }

    /// Sets `messages`, overwriting any previously set value.
    #[must_use]
    pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
        self.messages = Some(messages);
        self
    }

    /// Sets `extra`, overwriting any previously set map. The map is not merged
    /// with the old one.
    #[must_use]
    pub fn with_extra(mut self, extra: FxHashMap<String, serde_json::Value>) -> Self {
        self.extra = Some(extra);
        self
    }

    /// Sets `errors`, overwriting any previously set list.
    #[must_use]
    pub fn with_errors(mut self, errors: Vec<ErrorEvent>) -> Self {
        self.errors = Some(errors);
        self
    }

    /// Sets the frontier to [`FrontierCommand::Replace`] over `targets`.
    ///
    /// Each `NodeKind` is converted with `NodeRoute::from`. Any frontier command
    /// set earlier on this partial is overwritten.
    #[must_use]
    pub fn with_frontier_replace<I>(mut self, targets: I) -> Self
    where
        I: IntoIterator<Item = NodeKind>,
    {
        let routes = targets.into_iter().map(NodeRoute::from).collect();
        self.frontier = Some(FrontierCommand::Replace(routes));
        self
    }

    /// Sets the frontier to [`FrontierCommand::Append`] over `targets`.
    ///
    /// Each `NodeKind` is converted with `NodeRoute::from`. Any frontier command
    /// set earlier on this partial is overwritten, not combined.
    #[must_use]
    pub fn with_frontier_append<I>(mut self, targets: I) -> Self
    where
        I: IntoIterator<Item = NodeKind>,
    {
        let routes = targets.into_iter().map(NodeRoute::from).collect();
        self.frontier = Some(FrontierCommand::Append(routes));
        self
    }

    /// Sets an explicit, pre-built frontier command, overwriting any earlier one.
    #[must_use]
    pub fn with_frontier_command(mut self, command: FrontierCommand) -> Self {
        self.frontier = Some(command);
        self
    }
}
/// Errors raised by the [`NodeContext`] event helpers.
///
/// Converts into [`NodeError::EventBus`] via `?` because of that variant's
/// `#[from]` attribute.
#[derive(Debug, Error)]
#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))]
pub enum NodeContextError {
/// The emitter returned an error. Its original error value is discarded by the
/// helpers, so no further detail is available here.
#[error("failed to emit event: event bus unavailable")]
#[cfg_attr(
feature = "diagnostics",
diagnostic(
code(weavegraph::node::event_bus_unavailable),
help("The event bus may be disconnected or at capacity. Check workflow state.")
)
)]
EventBusUnavailable,
}
/// Errors a [`Node::run`] implementation can return.
///
/// Boxed errors, `serde_json::Error`, and [`NodeContextError`] all convert
/// automatically with `?` thanks to the `#[from]` attributes below. With the
/// `diagnostics` feature enabled, each variant also carries a `miette` code.
#[derive(Debug, Error)]
#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))]
pub enum NodeError {
/// A required input was absent. `what` names the missing item and is
/// interpolated into both the message and the diagnostic help text.
#[error("missing expected input: {what}")]
#[cfg_attr(
feature = "diagnostics",
diagnostic(
code(weavegraph::node::missing_input),
help("Check that the previous node produced the required data: {what}.")
)
)]
MissingInput { what: &'static str },
/// An external provider failed; `provider` names it and `message` carries
/// the provider's error text.
#[error("provider error ({provider}): {message}")]
#[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::node::provider)))]
Provider {
provider: &'static str,
message: String,
},
/// Any other error, boxed. `transparent` forwards `Display` and `source()`
/// to the inner error. Build it with [`NodeError::other`] or `?` on a
/// `Box<dyn Error + Send + Sync>`.
#[error(transparent)]
#[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::node::other)))]
Other(#[from] Box<dyn std::error::Error + Send + Sync>),
/// JSON (de)serialization failure, passed through transparently.
#[error(transparent)]
#[cfg_attr(
feature = "diagnostics",
diagnostic(code(weavegraph::node::serde_json))
)]
Serde(#[from] serde_json::Error),
/// Input failed a validation check; the string describes the failure.
#[error("validation failed: {0}")]
#[cfg_attr(
feature = "diagnostics",
diagnostic(
code(weavegraph::node::validation),
help("Check input data format and required fields.")
)
)]
ValidationFailed(String),
/// An event could not be emitted through the [`NodeContext`].
#[error("event bus error: {0}")]
#[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::node::event_bus)))]
EventBus(#[from] NodeContextError),
}
impl NodeError {
    /// Wraps any thread-safe, `'static` error in [`NodeError::Other`].
    ///
    /// Use this when `?` cannot infer the conversion, for example in
    /// `map_err(NodeError::other)`.
    #[must_use]
    pub fn other<E>(error: E) -> Self
    where
        E: std::error::Error + Send + Sync + 'static,
    {
        // Unsize to a trait object explicitly, then store it in the variant.
        let boxed: Box<dyn std::error::Error + Send + Sync> = Box::new(error);
        Self::Other(boxed)
    }
}
/// Shorthand for a `Result` whose error type is [`NodeError`].
pub type NodeResult<T> = std::result::Result<T, NodeError>;
/// Extension trait for turning arbitrary `Result`s into [`NodeResult`]s.
pub trait NodeResultExt<T> {
/// Converts the error, if any, into [`NodeError::Other`] and leaves `Ok`
/// values untouched.
fn node_err(self) -> NodeResult<T>;
}
/// Blanket implementation for any `Result` with a thread-safe, `'static` error.
///
/// Note: `NodeError` itself satisfies these bounds, so calling `node_err()` on
/// a `NodeResult` re-wraps the error as `NodeError::Other` instead of
/// preserving its original variant. Prefer plain `?` in that case.
impl<T, E> NodeResultExt<T> for std::result::Result<T, E>
where
    E: std::error::Error + Send + Sync + 'static,
{
    fn node_err(self) -> NodeResult<T> {
        match self {
            Ok(value) => Ok(value),
            Err(cause) => Err(NodeError::other(cause)),
        }
    }
}