use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
use tracing::info;
use tracing_error::ErrorLayer;
use tracing_subscriber::fmt::format::FmtSpan;
use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
use weavegraph::channels::Channel;
use weavegraph::event_bus::EventBus;
use weavegraph::message::{Message, Role};
use weavegraph::node::{Node, NodeContext, NodeError, NodePartial};
use weavegraph::state::{StateSnapshot, VersionedState};
use weavegraph::utils::collections::new_extra_map;
/// Result type for this example: success value `T`, or any boxed error that is
/// `Send + Sync` so it can be propagated with `?` from async code.
type ExampleResult<T> = Result<T, Box<dyn std::error::Error + Send + Sync>>;
/// Node that reports how many messages the incoming state already holds.
#[derive(Debug, Clone)]
pub struct MessageCounterNode {
    /// Name used in emitted events, the summary message and the `processor` extra key.
    pub node_name: String,
}
#[async_trait]
impl Node for MessageCounterNode {
async fn run(
&self,
snapshot: StateSnapshot,
ctx: NodeContext,
) -> Result<NodePartial, NodeError> {
let message_count = snapshot.messages.len();
ctx.emit(
"processing",
format!(
"{} starting (found {} existing messages)",
self.node_name, message_count
),
)?;
let summary = format!(
"{} processed {} messages at step {}",
self.node_name, message_count, ctx.step
);
let mut extra = new_extra_map();
extra.insert("processor".to_string(), json!(self.node_name));
extra.insert("message_count".to_string(), json!(message_count));
extra.insert("step".to_string(), json!(ctx.step));
extra.insert(
"timestamp".to_string(),
json!(chrono::Utc::now().to_rfc3339()),
);
ctx.emit(
"completed",
format!("{} finished processing", self.node_name),
)?;
Ok(NodePartial::new()
.with_messages(vec![Message::with_role(Role::Assistant, &summary)])
.with_extra(extra))
}
}
/// Node that rejects a snapshot unless it meets simple structural requirements.
#[derive(Debug, Clone)]
pub struct ValidationNode {
    /// Keys that must be present in the snapshot's `extra` map.
    pub required_fields: Vec<String>,
    /// Minimum number of messages the snapshot must already contain.
    pub min_message_count: usize,
}
#[async_trait]
impl Node for ValidationNode {
    /// Checks the snapshot against this validator's requirements.
    ///
    /// The snapshot must hold at least `min_message_count` messages, and its `extra` map
    /// must contain every key in `required_fields`. On success, it returns a partial
    /// recording `validation_status`, `validated_fields` and `message_count_ok`.
    ///
    /// # Errors
    /// - `NodeError::ValidationFailed` when there are too few messages. This is checked
    ///   first.
    /// - `NodeError::ValidationFailed` when any required field is absent. All missing
    ///   names are listed.
    /// - Any failure returned by `ctx.emit`.
    async fn run(
        &self,
        snapshot: StateSnapshot,
        ctx: NodeContext,
    ) -> Result<NodePartial, NodeError> {
        ctx.emit("validation", "Starting input validation")?;

        let available = snapshot.messages.len();
        let too_few = available < self.min_message_count;
        if too_few {
            return Err(NodeError::ValidationFailed(format!(
                "Expected at least {} messages, found {}",
                self.min_message_count, available
            )));
        }

        // Collect every absent key so the error reports all of them at once.
        let missing: Vec<String> = self
            .required_fields
            .iter()
            .filter(|field| !snapshot.extra.contains_key(*field))
            .cloned()
            .collect();
        if !missing.is_empty() {
            return Err(NodeError::ValidationFailed(format!(
                "Missing required fields: {}",
                missing.join(", ")
            )));
        }

        ctx.emit("validation", "All validations passed")?;

        let mut extra = new_extra_map();
        extra.insert("validation_status".to_string(), json!("passed"));
        extra.insert("validated_fields".to_string(), json!(self.required_fields));
        extra.insert("message_count_ok".to_string(), json!(true));
        Ok(NodePartial::new().with_extra(extra))
    }
}
/// Stateless node that summarises the `processor`/`step` entries left in `extra`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AggregatorNode;
#[async_trait]
impl Node for AggregatorNode {
    /// Summarises the processor/step bookkeeping found in the snapshot's `extra` map.
    ///
    /// - Reads `extra["processor"]` when it is a JSON string.
    /// - Reads `extra["step"]` when `as_u64` yields a value; otherwise the step total is 0.
    /// - Emits a `"warning"` event if the step total exceeds the threshold.
    /// - Returns an assistant summary message plus an `aggregation_summary` object in
    ///   `extra`.
    ///
    /// # Errors
    /// Propagates any failure returned by `ctx.emit`.
    async fn run(
        &self,
        snapshot: StateSnapshot,
        ctx: NodeContext,
    ) -> Result<NodePartial, NodeError> {
        /// Step total above which a `"warning"` event is emitted.
        const STEP_WARNING_THRESHOLD: u64 = 100;

        ctx.emit("aggregation", "Starting data aggregation")?;

        // `extra` is keyed by field name, so each key holds at most one value. Look the
        // two fields up directly instead of scanning every entry.
        let processors: Vec<String> = snapshot
            .extra
            .get("processor")
            .and_then(|value| value.as_str())
            .map(String::from)
            .into_iter()
            .collect();
        let total_steps: u64 = snapshot
            .extra
            .get("step")
            .and_then(|value| value.as_u64())
            .unwrap_or(0);

        if total_steps > STEP_WARNING_THRESHOLD {
            ctx.emit(
                "warning",
                format!(
                    "Total processing steps ({}) exceeds recommended threshold",
                    total_steps
                ),
            )?;
        }

        let summary = format!(
            "Aggregated data from {} processors across {} total steps",
            processors.len(),
            total_steps
        );
        let mut extra = new_extra_map();
        extra.insert(
            "aggregation_summary".to_string(),
            json!({
                "processors": processors,
                "total_steps": total_steps,
                "message_count": snapshot.messages.len(),
                "aggregated_at": chrono::Utc::now().to_rfc3339()
            }),
        );
        ctx.emit("completed", "Data aggregation completed")?;
        Ok(NodePartial::new()
            .with_messages(vec![Message::with_role(Role::Assistant, &summary)])
            .with_extra(extra))
    }
}
fn init_tracing() {
let fmt_layer = fmt::layer()
.with_target(false)
.with_file(false)
.with_line_number(false)
.with_span_events(FmtSpan::NEW | FmtSpan::CLOSE);
let filter = EnvFilter::try_from_default_env()
.or_else(|_| EnvFilter::try_new("error,weavegraph=error"))
.unwrap();
tracing_subscriber::registry()
.with(filter)
.with(fmt_layer)
.with(ErrorLayer::default())
.init();
}
#[tokio::main]
async fn main() -> ExampleResult<()> {
init_tracing();
info!("🔧 Basic Node Examples");
info!("======================");
let event_bus = EventBus::default();
event_bus.listen_for_events();
let mut state = VersionedState::builder()
.with_user_message("Initial user message")
.with_extra("processor", json!("initial"))
.with_extra("step", json!(1))
.build();
info!("\n📊 Initial State:");
info!(" Messages: {}", state.messages.snapshot().len());
info!(
" Extra keys: {:?}",
state.extra.snapshot().keys().collect::<Vec<_>>()
);
info!("\n🔄 Running MessageCounterNode...");
let counter_node = MessageCounterNode {
node_name: "CounterExample".to_string(),
};
let emitter = event_bus.get_emitter();
let ctx1 = NodeContext {
node_id: "counter-1".to_string(),
step: 2,
event_emitter: Arc::clone(&emitter),
};
let result1 = counter_node.run(state.snapshot(), ctx1).await?;
if let Some(messages) = result1.messages {
state.messages.get_mut().extend(messages);
}
if let Some(extra) = result1.extra {
state.extra.get_mut().extend(extra);
}
info!(" ✅ Messages now: {}", state.messages.snapshot().len());
info!(
" ✅ Extra keys: {:?}",
state.extra.snapshot().keys().collect::<Vec<_>>()
);
info!("\n🔍 Running ValidationNode...");
let validation_node = ValidationNode {
required_fields: vec!["processor".to_string(), "step".to_string()],
min_message_count: 1,
};
let ctx2 = NodeContext {
node_id: "validator-1".to_string(),
step: 3,
event_emitter: Arc::clone(&emitter),
};
let result2 = validation_node.run(state.snapshot(), ctx2).await?;
if let Some(extra) = result2.extra {
state.extra.get_mut().extend(extra);
}
info!(" ✅ Validation passed");
info!("\n📈 Running AggregatorNode...");
let aggregator_node = AggregatorNode;
let ctx3 = NodeContext {
node_id: "aggregator-1".to_string(),
step: 4,
event_emitter: Arc::clone(&emitter),
};
let result3 = aggregator_node.run(state.snapshot(), ctx3).await?;
if let Some(messages) = result3.messages {
state.messages.get_mut().extend(messages);
}
if let Some(extra) = result3.extra {
state.extra.get_mut().extend(extra);
}
info!(" ✅ Aggregation completed");
info!("\n📋 Final State:");
let final_snapshot = state.snapshot();
info!(" Messages: {}", final_snapshot.messages.len());
for (i, msg) in final_snapshot.messages.iter().enumerate() {
info!(" {}: [{}] {}", i + 1, msg.role, msg.content);
}
info!(
" Extra data keys: {:?}",
final_snapshot.extra.keys().collect::<Vec<_>>()
);
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
event_bus.stop_listener().await;
info!("\n✅ Example completed successfully!");
Ok(())
}