use crate::state::State;
use futures::StreamExt;
use std::collections::HashMap;
use std::sync::Mutex;
/// Selects which [`StreamEvent`]s an [`EventEmitter`] reports as emittable
/// through [`EventEmitter::should_emit`].
///
/// Every single-purpose mode also accepts the terminal [`StreamEvent::End`]
/// event.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum StreamMode {
    /// State snapshots: `Values` and `FilteredValues` events. This is the default mode.
    #[default]
    Values,
    /// Node updates: `Updates` and `FilteredUpdates` events.
    Updates,
    /// Streamed LLM chunks (`Messages` events). When this is the emitter's
    /// mode, chunks whose metadata carries the `nostream` tag are suppressed.
    Messages,
    /// User-defined payloads (`Custom` events), e.g. those sent through
    /// [`StreamWriter::send`].
    Custom,
    /// Every event, unfiltered.
    Debug,
    /// Tool lifecycle events (`Tools`).
    Tools,
    /// Checkpoint notifications (`CheckpointSaved`).
    Checkpoints,
    /// Per-task lifecycle details (`TaskDetail`).
    Tasks,
    /// Union of several modes: an event is accepted if any listed mode accepts it.
    Multi(Vec<StreamMode>),
}
/// An event produced while a graph runs and delivered through an [`EventEmitter`].
///
/// Which variants a consumer actually receives depends on the active
/// [`StreamMode`].
#[derive(Clone, Debug)]
pub enum StreamEvent<S: State> {
    /// A state snapshot (`state`) tagged with the step it belongs to.
    Values { state: S, step: usize },
    /// A state snapshot represented as JSON (`data`) instead of `S`.
    FilteredValues {
        data: serde_json::Value,
        step: usize,
    },
    /// An update of type `S::Update` produced by `node` at `step`.
    Updates {
        node: String,
        update: S::Update,
        step: usize,
    },
    /// A node update represented as JSON (`data`) instead of `S::Update`.
    FilteredUpdates {
        node: String,
        data: serde_json::Value,
        step: usize,
    },
    /// One chunk of a streamed LLM response plus metadata identifying its
    /// source; produced by [`call_llm_streaming`].
    Messages {
        chunk: MessageChunk,
        metadata: MessageStreamMetadata,
    },
    /// A user-defined JSON payload from `node`, tagged with the subgraph
    /// namespace path `ns`.
    Custom {
        node: String,
        data: serde_json::Value,
        ns: Vec<String>,
    },
    /// Task `task_id` running `node` started at `step`.
    TaskStart {
        node: String,
        task_id: String,
        step: usize,
    },
    /// Task `task_id` running `node` ended at `step` after `duration_ms`.
    TaskEnd {
        node: String,
        task_id: String,
        step: usize,
        duration_ms: u64,
    },
    /// Execution was interrupted at `node`; `payload` carries the interrupt
    /// data and `resumable` is the producer's flag for whether it can resume.
    Interrupt {
        node: String,
        payload: serde_json::Value,
        resumable: bool,
        ns: Vec<String>,
    },
    /// A budget limit was hit; `reason` says which, `usage` reports consumption.
    BudgetExceeded {
        reason: crate::pregel::BudgetExceededReason,
        usage: BudgetUsage,
    },
    /// Terminal event carrying the run's output state. Every single-purpose
    /// `StreamMode` accepts it.
    End { output: S },
    /// The run was cancelled at `step`.
    Cancelled { step: usize },
    /// Fine-grained trace event (accepted by `StreamMode::Debug`).
    Debug(DebugEvent),
    /// Tool lifecycle event (accepted by `StreamMode::Tools`).
    Tools(ToolsEvent),
    /// A checkpoint was saved (accepted by `StreamMode::Checkpoints`).
    CheckpointSaved {
        checkpoint_id: String,
        metadata: crate::checkpoint::CheckpointMetadata,
        step: usize,
    },
    /// A task lifecycle transition (`event`) for the given `attempt`
    /// (accepted by `StreamMode::Tasks`).
    TaskDetail {
        task_id: String,
        node: String,
        step: usize,
        attempt: usize,
        event: TaskEventType,
    },
}
impl<S: State> StreamEvent<S> {
    /// Returns the subgraph namespace path this event was emitted under.
    ///
    /// Only `Custom`, `Interrupt` and `Messages` (through its metadata) carry a
    /// namespace. Every other variant reports the root namespace, which is an
    /// empty slice.
    #[must_use]
    #[allow(
        clippy::match_same_arms,
        reason = "each arm is explicit for clarity even when some return the same value"
    )]
    pub fn namespace(&self) -> &[String] {
        match self {
            Self::Messages { metadata, .. } => metadata.ns.as_slice(),
            Self::Custom { ns, .. } | Self::Interrupt { ns, .. } => ns.as_slice(),
            // Variants without namespace information live at the root.
            Self::Values { .. }
            | Self::FilteredValues { .. }
            | Self::Updates { .. }
            | Self::FilteredUpdates { .. }
            | Self::TaskStart { .. }
            | Self::TaskEnd { .. }
            | Self::BudgetExceeded { .. }
            | Self::End { .. }
            | Self::Cancelled { .. }
            | Self::Debug(_)
            | Self::Tools(_)
            | Self::CheckpointSaved { .. }
            | Self::TaskDetail { .. } => &[],
        }
    }
}
/// A [`StreamEvent`] together with routing information for consumers.
#[derive(Clone)]
pub struct StreamPart<S: State> {
    /// Subgraph namespace path of the event (empty at the root).
    pub ns: Vec<String>,
    /// Static label for the event (presumably its kind name — confirm with producers).
    pub event: &'static str,
    /// The event itself.
    pub data: StreamEvent<S>,
    /// Optional extra key/value metadata attached by the producer.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}

impl<S: State> std::fmt::Debug for StreamPart<S> {
    // The `data` payload is rendered as a fixed placeholder instead of being formatted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("StreamPart");
        out.field("ns", &self.ns);
        out.field("event", &self.event);
        out.field("data", &"<StreamEvent>");
        out.field("metadata", &self.metadata);
        out.finish()
    }
}
/// One incremental piece of a streamed LLM response, as carried by
/// [`StreamEvent::Messages`].
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct MessageChunk {
    /// Text delta; `call_llm_streaming` concatenates these into the final content.
    pub content: String,
    /// Tool-call fragments contained in this chunk.
    pub tool_call_chunks: Vec<ToolCallChunk>,
    /// Token usage reported with this chunk, if any.
    pub usage_delta: Option<crate::state::TokenUsage>,
}
/// A fragment of a tool call; fragments sharing `index` belong to the same call.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ToolCallChunk {
    /// Tool-call id, when this fragment carries it (`call_llm_streaming` keeps the last one seen).
    pub id: Option<String>,
    /// Tool name, when this fragment carries it (`call_llm_streaming` keeps the last one seen).
    pub name: Option<String>,
    /// Piece of the argument text, appended to earlier pieces for the same `index`.
    pub args_delta: String,
    /// Position of the tool call this fragment belongs to.
    pub index: usize,
}
/// Describes where a streamed message chunk came from.
#[derive(Clone, Debug)]
pub struct MessageStreamMetadata {
    /// Graph node that issued the LLM call.
    pub node: String,
    /// Name reported by the model (`ChatModel::model_name`).
    pub model: String,
    /// Tags copied from the call options; a `nostream` tag suppresses the chunk
    /// when the emitter's mode is `StreamMode::Messages`.
    pub tags: Vec<String>,
    /// Subgraph namespace path of the emitting node.
    pub ns: Vec<String>,
}
/// Fine-grained execution trace carried by [`StreamEvent::Debug`].
///
/// Derives `serde::Serialize` so trace events can be exported. Fields suffixed
/// `_ms` hold durations in milliseconds.
#[derive(Clone, Debug, serde::Serialize)]
pub enum DebugEvent {
    GraphStart {
        thread_id: String,
        input: serde_json::Value,
    },
    SuperstepStart {
        step: usize,
        pending_nodes: Vec<String>,
    },
    SuperstepEnd { step: usize, duration_ms: u64 },
    NodeStart { node: String, step: usize },
    NodeEnd {
        node: String,
        step: usize,
        duration_ms: u64,
        output_type: String,
    },
    NodeError {
        node: String,
        step: usize,
        error: String,
    },
    ChannelWrite {
        channel: String,
        node: String,
        value_summary: String,
    },
    ChannelUpdate { channel: String, new_version: u64 },
    Merge {
        step: usize,
        channels_updated: Vec<String>,
    },
    EdgeTraversed {
        from: String,
        to: String,
        edge_type: String,
    },
    CheckpointSaved {
        checkpoint_id: String,
        step: usize,
        source: String,
    },
    BudgetCheck {
        tokens_used: u64,
        cost_usd: f64,
        budget_remaining_pct: f32,
    },
    GraphEnd {
        total_steps: usize,
        total_duration_ms: u64,
    },
}
/// Tool lifecycle events carried by [`StreamEvent::Tools`].
#[derive(Clone, Debug)]
pub enum ToolsEvent {
    /// A tool invocation began.
    ToolStarted {
        tool_name: String,
        tool_call_id: String,
        node: String,
        input: serde_json::Value,
        timestamp: chrono::DateTime<chrono::Utc>,
    },
    /// Incremental output text from the invocation `tool_call_id`.
    ToolOutputDelta {
        tool_call_id: String,
        delta: String,
    },
    /// An invocation finished; `success` is the producer's outcome flag.
    ToolFinished {
        tool_call_id: String,
        output: serde_json::Value,
        duration_ms: u64,
        success: bool,
    },
    /// An invocation reported an error message.
    ToolError {
        tool_call_id: String,
        error: String,
    },
}
/// Resource consumption reported with [`StreamEvent::BudgetExceeded`].
#[derive(Clone, Debug)]
pub struct BudgetUsage {
    pub tokens_used: u64,
    pub cost_usd: f64,
    pub duration_ms: u64,
    pub steps_completed: usize,
}
/// Lifecycle transition reported in [`StreamEvent::TaskDetail`].
#[derive(Clone, Debug)]
pub enum TaskEventType {
    Started,
    Completed { duration_ms: u64 },
    Failed { error: String },
    Retrying { attempt: usize },
}
/// A post-processing step applied to JSON stream payloads.
///
/// `transform` returns `Some(output)` when there is something to emit and
/// `None` when nothing should be emitted for this input. For example,
/// `BatchTransformer` returns `None` while it is still filling a batch.
/// Implementations must be `Send + Sync`, so any internal state needs
/// thread-safe interior mutability.
pub trait StreamTransformer: Send + Sync + 'static {
    #[must_use]
    fn transform(&self, data: serde_json::Value) -> Option<serde_json::Value>;
}
/// Producer side of the event stream.
///
/// It sends [`StreamEvent`]s over a bounded tokio mpsc channel, records the
/// subgraph namespace it belongs to, and carries the [`StreamMode`] that
/// callers consult through `should_emit`.
#[derive(Clone)]
pub struct EventEmitter<S: State> {
    /// Channel to the stream consumer.
    pub tx: tokio::sync::mpsc::Sender<StreamEvent<S>>,
    /// Filter deciding which events should be forwarded.
    pub mode: StreamMode,
    /// Namespace path of the (sub)graph owning this emitter; empty at the root.
    ns: Vec<String>,
    // Marker field; `S` is already used by `tx`.
    _phantom: std::marker::PhantomData<S>,
}

impl<S: State> std::fmt::Debug for EventEmitter<S> {
    // The sender is shown as a fixed placeholder; the marker field is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("EventEmitter");
        builder.field("tx", &"<mpsc::Sender>");
        builder.field("mode", &self.mode);
        builder.field("ns", &self.ns);
        builder.finish()
    }
}
impl<S: State> EventEmitter<S> {
    /// Creates an emitter at the root namespace that filters events with `mode`.
    #[must_use]
    pub const fn new(tx: tokio::sync::mpsc::Sender<StreamEvent<S>>, mode: StreamMode) -> Self {
        Self {
            tx,
            mode,
            ns: Vec::new(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Returns a copy of this emitter with `ns_segment` appended to its
    /// namespace, for use inside a nested subgraph. The channel and mode are shared.
    #[must_use]
    pub fn with_subgraph_ns(&self, ns_segment: String) -> Self {
        let mut new_ns = self.ns.clone();
        new_ns.push(ns_segment);
        Self {
            tx: self.tx.clone(),
            mode: self.mode.clone(),
            ns: new_ns,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Namespace path of this emitter (empty at the root).
    #[must_use]
    pub fn ns(&self) -> &[String] {
        &self.ns
    }

    /// The stream mode used by [`Self::should_emit`].
    #[must_use]
    pub const fn mode(&self) -> &StreamMode {
        &self.mode
    }

    /// Sends `event` without filtering; callers check [`Self::should_emit`] first.
    ///
    /// Waits for channel capacity. Delivery is best-effort: if the receiver
    /// has been dropped, the event is discarded silently.
    pub async fn emit(&self, event: StreamEvent<S>) {
        // A closed channel means nobody is listening any more; that is not an
        // error for the producer.
        let _ = self.tx.send(event).await;
    }

    /// Creates a [`StreamWriter`] for `node`.
    ///
    /// The writer shares this emitter's channel and mode and starts in this
    /// emitter's namespace.
    #[must_use]
    pub fn stream_writer(&self, node: String) -> StreamWriter<S> {
        // Carry the namespace over so `Custom` events from inside a subgraph are
        // attributed to it, consistent with the `Messages` events built by
        // `call_llm_streaming` from `emitter.ns()`.
        let mut writer = StreamWriter::new(self.tx.clone(), node, self.mode.clone());
        writer.ns.clone_from(&self.ns);
        writer
    }

    /// Reports whether `event` passes this emitter's [`StreamMode`] filter.
    ///
    /// - Each single-purpose mode accepts its own event kinds plus `End`.
    /// - `Debug` accepts everything.
    /// - `Messages` rejects chunks whose metadata carries the `nostream` tag.
    /// - `Multi` accepts an event if any listed mode does. The `nostream` rule
    ///   applies there too, and nested `Multi` lists are searched recursively.
    #[must_use]
    pub fn should_emit(&self, event: &StreamEvent<S>) -> bool {
        Self::mode_matches_single(&self.mode, event)
    }

    /// `true` when the chunk metadata carries the `nostream` tag.
    #[must_use]
    fn has_nostream_tag_in_metadata(metadata: &MessageStreamMetadata) -> bool {
        metadata.tags.iter().any(|tag| tag == "nostream")
    }

    /// `true` when `options` are present and carry the `nostream` tag.
    #[must_use]
    pub fn has_nostream_tag(&self, options: Option<&crate::llm::CallOptions>) -> bool {
        options.is_some_and(|opts| opts.tags.iter().any(|tag| tag == "nostream"))
    }

    /// `true` if any mode in `modes` accepts `event`. An empty list accepts nothing.
    #[must_use]
    fn mode_matches_multi(modes: &[StreamMode], event: &StreamEvent<S>) -> bool {
        modes.iter().any(|m| Self::mode_matches_single(m, event))
    }

    /// The single source of truth for mode filtering, used both for the
    /// emitter's own mode and for each member of a `Multi` list.
    #[must_use]
    #[allow(clippy::match_same_arms, reason = "each arm is explicit for clarity")]
    fn mode_matches_single(mode: &StreamMode, event: &StreamEvent<S>) -> bool {
        match (mode, event) {
            // Composite modes delegate to their members (recursively for nested lists).
            (StreamMode::Multi(modes), _) => Self::mode_matches_multi(modes, event),
            // Debug is unfiltered.
            (StreamMode::Debug, _) => true,
            // LLM chunks can opt out of streaming with the `nostream` tag.
            (StreamMode::Messages, StreamEvent::Messages { metadata, .. }) => {
                !Self::has_nostream_tag_in_metadata(metadata)
            }
            // Every concrete mode forwards the terminal event.
            (_, StreamEvent::End { .. }) => true,
            (
                StreamMode::Values,
                StreamEvent::Values { .. } | StreamEvent::FilteredValues { .. },
            ) => true,
            (
                StreamMode::Updates,
                StreamEvent::Updates { .. } | StreamEvent::FilteredUpdates { .. },
            ) => true,
            (StreamMode::Custom, StreamEvent::Custom { .. }) => true,
            (StreamMode::Tools, StreamEvent::Tools(_)) => true,
            (StreamMode::Checkpoints, StreamEvent::CheckpointSaved { .. }) => true,
            (StreamMode::Tasks, StreamEvent::TaskDetail { .. }) => true,
            _ => false,
        }
    }
}
/// Handle that lets node code push user-defined JSON payloads into the stream
/// as [`StreamEvent::Custom`] events.
///
/// A writer built with [`StreamWriter::disconnected`] has no channel and
/// silently drops everything it is given.
#[derive(Clone)]
pub struct StreamWriter<S: State> {
    // Channel to the consumer; `None` for a disconnected writer.
    tx: Option<tokio::sync::mpsc::Sender<StreamEvent<S>>>,
    // Node name stamped on every emitted event.
    node: String,
    // Mode used to decide whether a `Custom` event may be sent.
    mode: StreamMode,
    // Namespace path stamped on every emitted event.
    ns: Vec<String>,
}

impl<S: State> StreamWriter<S> {
    /// Creates a connected writer for `node` at the root namespace.
    #[must_use]
    pub const fn new(
        tx: tokio::sync::mpsc::Sender<StreamEvent<S>>,
        node: String,
        mode: StreamMode,
    ) -> Self {
        let tx = Some(tx);
        Self {
            tx,
            node,
            mode,
            ns: Vec::new(),
        }
    }

    /// Creates a writer with no channel; `send` becomes a no-op.
    #[must_use]
    pub const fn disconnected(node: String, mode: StreamMode) -> Self {
        Self {
            node,
            mode,
            tx: None,
            ns: Vec::new(),
        }
    }

    /// Returns a copy of this writer with `ns_segment` appended to its namespace.
    #[must_use]
    pub fn with_ns(&self, ns_segment: String) -> Self {
        let ns = self
            .ns
            .iter()
            .cloned()
            .chain(std::iter::once(ns_segment))
            .collect();
        Self {
            tx: self.tx.clone(),
            node: self.node.clone(),
            mode: self.mode.clone(),
            ns,
        }
    }

    /// Sends `data` as a `Custom` event if the writer is connected and its mode
    /// accepts `Custom` events. Send failures (receiver dropped) are ignored.
    pub async fn send(&self, data: serde_json::Value) {
        if let Some(sender) = &self.tx {
            let event = StreamEvent::Custom {
                node: self.node.clone(),
                data,
                ns: self.ns.clone(),
            };
            // Reuse the emitter's filtering rules instead of duplicating them.
            let filter = EventEmitter::new(sender.clone(), self.mode.clone());
            if filter.should_emit(&event) {
                let _ = sender.send(event).await;
            }
        }
    }
}

impl<S: State> std::fmt::Debug for StreamWriter<S> {
    // `tx` is summarised as whether the writer is connected.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let connected = self.tx.is_some();
        let mut out = f.debug_struct("StreamWriter");
        out.field("tx", &connected);
        out.field("node", &self.node);
        out.field("mode", &self.mode);
        out.field("ns", &self.ns);
        out.finish()
    }
}
/// Streams a chat completion from `model` and assembles the final AI message.
///
/// Each received chunk is wrapped in a [`StreamEvent::Messages`] event and
/// forwarded to `emitter` when [`EventEmitter::should_emit`] accepts it.
/// Meanwhile:
/// - text deltas are concatenated into the message content;
/// - tool-call fragments are merged by their `index`;
/// - input/output token counts are summed.
///
/// Once the stream ends, each tool call's accumulated argument text is parsed
/// as JSON. If it is not valid JSON, the raw text is kept as a JSON string.
///
/// # Errors
///
/// Returns the [`crate::llm::LlmError`] raised when opening the stream or
/// carried by any chunk. Chunks already emitted are not retracted.
pub async fn call_llm_streaming<S: State, M: crate::llm::ChatModel>(
    model: &M,
    messages: &[crate::state::Message],
    options: Option<&crate::llm::CallOptions>,
    emitter: &EventEmitter<S>,
    node_name: &str,
) -> Result<crate::state::Message, crate::llm::LlmError> {
    let mut stream = model.stream(messages, options).await?;
    let mut full_content = String::new();
    let mut tool_calls: Vec<crate::state::ToolCall> = Vec::new();
    let mut total_usage = crate::state::TokenUsage::default();
    #[allow(clippy::option_if_let_else, reason = "explicit match is clearer")]
    let tags: Vec<String> = match options {
        Some(opts) => opts.tags.clone(),
        None => Vec::new(),
    };
    // Metadata that is identical for every chunk; computed once.
    let model_name = model.model_name().to_string();
    let ns = emitter.ns().to_vec();

    while let Some(chunk_result) = stream.next().await {
        let chunk = chunk_result?;
        full_content.push_str(&chunk.content);

        for tc_chunk in &chunk.tool_call_chunks {
            // Fragments are addressed by `index`; grow the list with empty
            // placeholders so that index is valid.
            if tool_calls.len() <= tc_chunk.index {
                tool_calls.resize_with(tc_chunk.index + 1, || crate::state::ToolCall {
                    id: String::new(),
                    name: String::new(),
                    arguments: serde_json::Value::Null,
                });
            }
            let tc = &mut tool_calls[tc_chunk.index];
            if let Some(ref id) = tc_chunk.id {
                id.clone_into(&mut tc.id);
            }
            if let Some(ref name) = tc_chunk.name {
                name.clone_into(&mut tc.name);
            }
            if !tc_chunk.args_delta.is_empty() {
                // Arguments accumulate as raw text and are parsed after the stream ends.
                match &mut tc.arguments {
                    serde_json::Value::String(buf) => buf.push_str(&tc_chunk.args_delta),
                    slot => *slot = serde_json::Value::String(tc_chunk.args_delta.clone()),
                }
            }
        }

        if let Some(ref usage) = chunk.usage {
            total_usage.input_tokens += usage.input_tokens;
            total_usage.output_tokens += usage.output_tokens;
        }

        let event = StreamEvent::Messages {
            chunk: MessageChunk {
                content: chunk.content,
                tool_call_chunks: chunk.tool_call_chunks,
                usage_delta: chunk.usage,
            },
            metadata: MessageStreamMetadata {
                node: node_name.to_string(),
                model: model_name.clone(),
                tags: tags.clone(),
                ns: ns.clone(),
            },
        };
        if emitter.should_emit(&event) {
            emitter.emit(event).await;
        }
    }

    for tc in &mut tool_calls {
        // Parse the accumulated argument text. On failure, leave the raw text in
        // place as a JSON string. Re-serialising the `Value` with `to_string`
        // would wrap it in an extra layer of quotes and escapes.
        let parsed = match &tc.arguments {
            serde_json::Value::String(raw) => serde_json::from_str::<serde_json::Value>(raw).ok(),
            _ => None,
        };
        if let Some(value) = parsed {
            tc.arguments = value;
        }
    }

    // The total is recomputed from the accumulated input and output counts.
    total_usage.total_tokens = total_usage.input_tokens + total_usage.output_tokens;
    Ok(crate::state::Message {
        id: uuid::Uuid::new_v4().to_string(),
        role: crate::state::Role::Ai,
        content: crate::state::Content::Text(full_content),
        tool_calls,
        tool_call_id: None,
        name: None,
        usage: Some(total_usage),
    })
}
/// Batching parameters for streamed message chunks.
///
/// This type only stores the settings; the code that applies them is not in
/// this module.
#[derive(Clone, Debug)]
pub struct MessageBatchConfig {
    /// Maximum chunks per batch; [`MessageBatchConfig::no_batching`] uses `1`.
    pub max_chunks: usize,
    /// Optional flush interval in milliseconds; `None` means no timer.
    pub flush_interval_ms: Option<u64>,
}

impl Default for MessageBatchConfig {
    /// `max_chunks = 10`, `flush_interval_ms = Some(100)`.
    fn default() -> Self {
        Self::new(10, Some(100))
    }
}

impl MessageBatchConfig {
    /// Builds a config from explicit values.
    #[must_use]
    pub const fn new(max_chunks: usize, flush_interval_ms: Option<u64>) -> Self {
        Self {
            max_chunks,
            flush_interval_ms,
        }
    }

    /// One chunk per batch and no flush timer.
    #[must_use]
    pub const fn no_batching() -> Self {
        Self::new(1, None)
    }
}
/// Keeps only the entries of a JSON object whose keys appear in `keys`.
///
/// An empty `keys` slice means "no filtering": `value` is returned untouched.
/// Non-object values are also returned unchanged.
pub(crate) fn filter_json_by_keys(value: serde_json::Value, keys: &[String]) -> serde_json::Value {
    match value {
        serde_json::Value::Object(mut fields) if !keys.is_empty() => {
            // Build a set once so each key lookup is O(1).
            let wanted: std::collections::HashSet<&str> =
                keys.iter().map(String::as_str).collect();
            fields.retain(|name, _| wanted.contains(name.as_str()));
            serde_json::Value::Object(fields)
        }
        untouched => untouched,
    }
}
#[derive(Clone, Debug, Default)]
pub struct StreamConfig {
pub mode: StreamMode,
pub include_subgraphs: bool,
pub subgraph_filter: Option<Vec<String>>,
pub output_keys: Option<Vec<String>>,
pub message_batch_config: MessageBatchConfig,
pub resumption: Option<StreamResumption>,
}
impl StreamConfig {
#[must_use]
pub const fn new(mode: StreamMode) -> Self {
Self {
mode,
include_subgraphs: false,
subgraph_filter: None,
output_keys: None,
message_batch_config: MessageBatchConfig {
max_chunks: 10,
flush_interval_ms: Some(100),
},
resumption: None,
}
}
#[must_use]
pub const fn with_subgraphs(mut self, include: bool) -> Self {
self.include_subgraphs = include;
self
}
#[must_use]
pub fn with_subgraph_filter(mut self, filter: Vec<String>) -> Self {
self.subgraph_filter = Some(filter);
self
}
#[must_use]
pub fn with_output_keys(mut self, keys: Vec<String>) -> Self {
self.output_keys = Some(keys);
self
}
#[must_use]
pub const fn with_message_batch_config(mut self, config: MessageBatchConfig) -> Self {
self.message_batch_config = config;
self
}
#[must_use]
pub fn with_resumption(mut self, resumption: StreamResumption) -> Self {
self.resumption = Some(resumption);
self
}
}
/// Marks where a client's view of a stream left off, so a resumed run can
/// skip steps the client has already received.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct StreamResumption {
    /// Identifier of the run being resumed.
    pub run_id: String,
    /// Last checkpoint the client observed, if known.
    pub last_checkpoint_id: Option<String>,
    /// Last step the client received; steps up to and including it are skipped.
    pub last_step: Option<usize>,
}

impl StreamResumption {
    /// Builds a resumption marker from its parts.
    #[must_use]
    pub const fn new(
        run_id: String,
        last_checkpoint_id: Option<String>,
        last_step: Option<usize>,
    ) -> Self {
        Self {
            run_id,
            last_checkpoint_id,
            last_step,
        }
    }

    /// `true` when `current_step` is at or before `last_step`, i.e. it has
    /// already been delivered. With no recorded step, nothing is skipped.
    #[must_use]
    pub const fn should_skip(&self, current_step: usize) -> bool {
        matches!(self.last_step, Some(seen) if current_step <= seen)
    }
}
/// Stateless transformer that decodes JSON-encoded string payloads.
///
/// A string payload is parsed as JSON, and a string that fails to parse
/// becomes `Value::Null`. Non-string payloads pass through unchanged. The
/// transformer never returns `None`.
#[derive(Clone, Debug, Default)]
pub struct JsonParseTransformer;

impl JsonParseTransformer {
    /// Creates the transformer.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}

impl StreamTransformer for JsonParseTransformer {
    #[allow(
        clippy::option_if_let_else,
        reason = "project rules prohibit map_or with unwrap; match is explicit and readable"
    )]
    fn transform(&self, data: serde_json::Value) -> Option<serde_json::Value> {
        let decoded = match data {
            serde_json::Value::String(text) => {
                match serde_json::from_str::<serde_json::Value>(&text) {
                    Ok(value) => value,
                    Err(_) => serde_json::Value::Null,
                }
            }
            passthrough => passthrough,
        };
        Some(decoded)
    }
}
/// Transformer that keeps only the listed top-level fields of JSON objects.
///
/// Non-object payloads pass through unchanged. NOTE(review): unlike
/// `filter_json_by_keys`, an empty `fields` list removes every key from an
/// object — confirm this is intended.
#[derive(Clone, Debug)]
pub struct FilterFieldsTransformer {
    /// Field names to keep.
    pub fields: Vec<String>,
}

impl FilterFieldsTransformer {
    /// Creates a transformer keeping `fields`.
    #[must_use]
    pub const fn new(fields: Vec<String>) -> Self {
        Self { fields }
    }
}

impl StreamTransformer for FilterFieldsTransformer {
    fn transform(&self, data: serde_json::Value) -> Option<serde_json::Value> {
        if let serde_json::Value::Object(mut obj) = data {
            let allowed: std::collections::HashSet<&String> = self.fields.iter().collect();
            obj.retain(|key, _| allowed.contains(key));
            return Some(serde_json::Value::Object(obj));
        }
        Some(data)
    }
}
/// Transformer that groups consecutive payloads into JSON arrays of `size` items.
///
/// `transform` buffers each input. Once `size` items have accumulated, it
/// returns `Some(array)`; otherwise it returns `None`.
/// [`BatchTransformer::flush`] drains a partial batch. The buffer sits behind a
/// `Mutex` because `StreamTransformer` requires `Send + Sync`.
#[derive(Debug)]
pub struct BatchTransformer {
    /// Items per emitted batch; `new` clamps it to at least 1.
    pub size: usize,
    buffer: Mutex<Vec<serde_json::Value>>,
}

impl BatchTransformer {
    /// Creates an empty transformer; a `size` of 0 is treated as 1.
    #[must_use]
    pub fn new(size: usize) -> Self {
        let buffer = Mutex::new(Vec::new());
        Self {
            size: size.max(1),
            buffer,
        }
    }

    /// Drains buffered items into a JSON array, or returns `None` if nothing is buffered.
    ///
    /// # Panics
    ///
    /// Panics if the buffer mutex is poisoned.
    #[must_use]
    pub fn flush(&self) -> Option<serde_json::Value> {
        let drained = {
            let mut guard = self.buffer.lock().expect("BatchTransformer buffer lock");
            std::mem::take(&mut *guard)
        };
        if drained.is_empty() {
            None
        } else {
            Some(serde_json::Value::Array(drained))
        }
    }
}

impl Clone for BatchTransformer {
    // Copies the batch size only; the clone starts with its own empty buffer.
    fn clone(&self) -> Self {
        Self {
            size: self.size,
            buffer: Mutex::default(),
        }
    }
}

impl StreamTransformer for BatchTransformer {
    fn transform(&self, data: serde_json::Value) -> Option<serde_json::Value> {
        // Keep the lock scope tight: only the push and the optional drain run under it.
        let ready = {
            let mut pending = self.buffer.lock().expect("BatchTransformer buffer lock");
            pending.push(data);
            if pending.len() < self.size {
                None
            } else {
                Some(std::mem::take(&mut *pending))
            }
        };
        ready.map(serde_json::Value::Array)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the streaming primitives: batching config defaults,
    // resumption step skipping, `EventEmitter::should_emit` filtering, and
    // `BatchTransformer` buffering.
    use super::{
        BatchTransformer, EventEmitter, MessageBatchConfig, MessageChunk, MessageStreamMetadata,
        StreamConfig, StreamEvent, StreamMode, StreamResumption, StreamTransformer, ToolsEvent,
    };
    use crate::state::{FieldVersions, FieldsChanged, State};

    // Minimal `State` for the tests: it holds no data, and `apply` reports
    // that no fields changed (`FieldsChanged(0)`).
    #[derive(Clone, Debug, Default)]
    struct TestState;
    impl State for TestState {
        type Update = TestStateUpdate;
        type FieldVersions = FieldVersions;
        fn apply(&mut self, _update: Self::Update) -> FieldsChanged {
            FieldsChanged(0)
        }
        fn reset_ephemeral(&mut self) {}
    }
    #[derive(Clone, Debug, Default)]
    struct TestStateUpdate;

    // --- MessageBatchConfig constructors ---
    #[test]
    fn message_batch_config_default() {
        let config = MessageBatchConfig::default();
        assert_eq!(config.max_chunks, 10);
        assert_eq!(config.flush_interval_ms, Some(100));
    }
    #[test]
    fn message_batch_config_no_batching() {
        let config = MessageBatchConfig::no_batching();
        assert_eq!(config.max_chunks, 1);
        assert_eq!(config.flush_interval_ms, None);
    }
    #[test]
    fn message_batch_config_new_custom() {
        let config = MessageBatchConfig::new(50, Some(200));
        assert_eq!(config.max_chunks, 50);
        assert_eq!(config.flush_interval_ms, Some(200));
    }

    // --- StreamResumption::should_skip: steps at or before `last_step` are skipped ---
    #[test]
    fn resumption_should_skip_returns_true_when_step_at_last_step() {
        let r = StreamResumption::new("run1".to_string(), None, Some(3));
        assert!(r.should_skip(3));
    }
    #[test]
    fn resumption_should_skip_returns_true_when_step_before_last_step() {
        let r = StreamResumption::new("run1".to_string(), None, Some(3));
        assert!(r.should_skip(2));
        assert!(r.should_skip(0));
    }
    #[test]
    fn resumption_should_skip_returns_false_when_step_after_last_step() {
        let r = StreamResumption::new("run1".to_string(), None, Some(3));
        assert!(!r.should_skip(4));
        assert!(!r.should_skip(100));
    }
    #[test]
    fn resumption_should_skip_returns_false_when_last_step_is_none() {
        let r = StreamResumption::new("run1".to_string(), None, None);
        assert!(!r.should_skip(0));
        assert!(!r.should_skip(100));
    }

    // --- StreamConfig resumption wiring ---
    #[test]
    fn stream_config_default_has_no_resumption() {
        let config = StreamConfig::default();
        assert!(config.resumption.is_none());
    }
    #[test]
    fn stream_config_new_has_no_resumption() {
        let config = StreamConfig::new(StreamMode::Values);
        assert!(config.resumption.is_none());
    }
    #[test]
    fn stream_config_with_resumption_sets_field() {
        let r = StreamResumption::new("run1".to_string(), Some("cp-5".to_string()), Some(5));
        let config = StreamConfig::new(StreamMode::Values).with_resumption(r);
        assert!(config.resumption.is_some());
        let resumption = config.resumption.expect("resumption should be set");
        assert_eq!(resumption.run_id, "run1");
        assert_eq!(resumption.last_checkpoint_id, Some("cp-5".to_string()));
        assert_eq!(resumption.last_step, Some(5));
    }

    // --- EventEmitter::should_emit in Messages mode (`nostream` tag filtering) ---
    #[test]
    fn should_emit_messages_event_without_nostream() {
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let emitter = EventEmitter::<TestState>::new(tx, StreamMode::Messages);
        let event = StreamEvent::Messages {
            chunk: MessageChunk {
                content: "hello".to_string(),
                tool_call_chunks: Vec::new(),
                usage_delta: None,
            },
            metadata: MessageStreamMetadata {
                node: "agent".to_string(),
                model: "test".to_string(),
                tags: vec![],
                ns: Vec::new(),
            },
        };
        assert!(emitter.should_emit(&event));
    }
    #[test]
    fn should_emit_messages_event_with_nostream_suppressed() {
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let emitter = EventEmitter::<TestState>::new(tx, StreamMode::Messages);
        let event = StreamEvent::Messages {
            chunk: MessageChunk {
                content: "hello".to_string(),
                tool_call_chunks: Vec::new(),
                usage_delta: None,
            },
            metadata: MessageStreamMetadata {
                node: "agent".to_string(),
                model: "test".to_string(),
                tags: vec!["nostream".to_string()],
                ns: Vec::new(),
            },
        };
        assert!(!emitter.should_emit(&event));
    }
    #[test]
    fn should_emit_messages_event_with_other_tags_not_suppressed() {
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let emitter = EventEmitter::<TestState>::new(tx, StreamMode::Messages);
        let event = StreamEvent::Messages {
            chunk: MessageChunk {
                content: "hello".to_string(),
                tool_call_chunks: Vec::new(),
                usage_delta: None,
            },
            metadata: MessageStreamMetadata {
                node: "agent".to_string(),
                model: "test".to_string(),
                tags: vec!["fast".to_string(), "stream".to_string()],
                ns: Vec::new(),
            },
        };
        assert!(emitter.should_emit(&event));
    }
    #[test]
    fn should_emit_end_event_always_in_messages_mode() {
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let emitter = EventEmitter::<TestState>::new(tx, StreamMode::Messages);
        let event = StreamEvent::End { output: TestState };
        assert!(emitter.should_emit(&event));
    }

    // --- EventEmitter::should_emit in Tools mode ---
    #[test]
    fn should_emit_tools_event_in_tools_mode() {
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let emitter = EventEmitter::<TestState>::new(tx, StreamMode::Tools);
        let event = StreamEvent::Tools(ToolsEvent::ToolStarted {
            tool_name: "search".to_string(),
            tool_call_id: "call_1".to_string(),
            node: "tools".to_string(),
            input: serde_json::json!({}),
            timestamp: chrono::Utc::now(),
        });
        assert!(emitter.should_emit(&event));
    }
    #[test]
    fn should_emit_tool_output_delta_in_tools_mode() {
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let emitter = EventEmitter::<TestState>::new(tx, StreamMode::Tools);
        let event = StreamEvent::Tools(ToolsEvent::ToolOutputDelta {
            tool_call_id: "call_1".to_string(),
            delta: "partial".to_string(),
        });
        assert!(emitter.should_emit(&event));
    }
    #[test]
    fn should_emit_tool_finished_in_tools_mode() {
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let emitter = EventEmitter::<TestState>::new(tx, StreamMode::Tools);
        let event = StreamEvent::Tools(ToolsEvent::ToolFinished {
            tool_call_id: "call_1".to_string(),
            output: serde_json::json!({"result": "ok"}),
            duration_ms: 100,
            success: true,
        });
        assert!(emitter.should_emit(&event));
    }

    // --- BatchTransformer buffering, flushing and cloning ---
    #[test]
    fn batch_transformer_emits_batch_when_max_size_reached() {
        let transformer = BatchTransformer::new(3);
        let item = serde_json::json!({"token": "hello"});
        assert!(transformer.transform(item.clone()).is_none());
        assert!(transformer.transform(item.clone()).is_none());
        let result = transformer.transform(item);
        assert!(result.is_some());
        let batch = result.expect("batch should be emitted");
        assert!(batch.is_array());
        assert_eq!(batch.as_array().expect("batch should be an array").len(), 3);
    }
    #[test]
    fn batch_transformer_returns_none_below_threshold() {
        let transformer = BatchTransformer::new(5);
        let item = serde_json::json!("test");
        assert!(transformer.transform(item).is_none());
    }
    #[test]
    fn batch_transformer_flush_returns_remaining() {
        let transformer = BatchTransformer::new(10);
        let item = serde_json::json!("data");
        let _ = transformer.transform(item.clone());
        let _ = transformer.transform(item.clone());
        let _ = transformer.transform(item);
        let flushed = transformer.flush();
        assert!(flushed.is_some());
        let batch = flushed.expect("flush should return items");
        assert_eq!(
            batch.as_array().expect("flush should return array").len(),
            3
        );
    }
    #[test]
    fn batch_transformer_flush_empty_returns_none() {
        let transformer = BatchTransformer::new(10);
        assert!(transformer.flush().is_none());
    }
    #[test]
    fn batch_transformer_size_one_emits_immediately() {
        let transformer = BatchTransformer::new(1);
        let result = transformer.transform(serde_json::json!("single"));
        assert!(result.is_some());
        let batch = result.expect("batch should be emitted");
        assert_eq!(batch.as_array().expect("batch should be array").len(), 1);
    }
    #[test]
    fn batch_transformer_size_zero_clamped_to_one() {
        let transformer = BatchTransformer::new(0);
        let result = transformer.transform(serde_json::json!("clamped"));
        assert!(result.is_some());
        let batch = result.expect("batch should be emitted immediately");
        assert_eq!(batch.as_array().expect("batch should be array").len(), 1);
    }
    #[test]
    fn batch_transformer_multiple_batches() {
        let transformer = BatchTransformer::new(2);
        let item = serde_json::json!("x");
        assert!(transformer.transform(item.clone()).is_none());
        let batch1 = transformer.transform(item.clone());
        assert!(batch1.is_some());
        assert_eq!(
            batch1
                .expect("batch1")
                .as_array()
                .expect("batch1 array")
                .len(),
            2
        );
        assert!(transformer.transform(item.clone()).is_none());
        let batch2 = transformer.transform(item);
        assert!(batch2.is_some());
        assert_eq!(
            batch2
                .expect("batch2")
                .as_array()
                .expect("batch2 array")
                .len(),
            2
        );
    }
    // Cloning copies only `size`: the clone's buffer starts empty and is
    // independent of the original's buffered items.
    #[test]
    fn batch_transformer_clone_maintains_independent_buffer() {
        let transformer = BatchTransformer::new(3);
        let item = serde_json::json!("x");
        let _ = transformer.transform(item);
        let cloned = transformer.clone();
        let flushed_original = transformer.flush();
        assert!(flushed_original.is_some());
        assert_eq!(
            flushed_original
                .expect("original flush")
                .as_array()
                .expect("original flush array")
                .len(),
            1
        );
        assert!(cloned.flush().is_none());
    }
}