use super::*;
use crate::ids::RunId;
use crate::llm::{FinishReason, LlmClient};
use crate::memory::EpisodicMemory;
use crate::test_utils::{
noop_bus, FakeLlmClient, FakeStreamStep, FakeToolInvoker, InMemoryEpisodic, InMemoryLongTerm,
InMemoryShortTerm,
};
use std::sync::Arc;
use tokio_stream::StreamExt;
use tokio_util::sync::CancellationToken;
fn ctx_with(llm: Arc<dyn LlmClient>) -> (AgentContext, Arc<InMemoryEpisodic>) {
let (pubsub, request_reply, kv, jobs) = noop_bus();
let episodic = Arc::new(InMemoryEpisodic::default());
(
AgentContext {
llm,
short_term: Arc::new(InMemoryShortTerm::default()),
long_term: Arc::new(InMemoryLongTerm::default()),
episodic: episodic.clone(),
pubsub,
kv,
request_reply,
jobs,
tools: Arc::new(FakeToolInvoker::new()),
run_id: RunId::new(),
cancel: CancellationToken::new(),
agent_name: "stream-test".into(),
},
episodic,
)
}
/// Builds an [`AgentContext`] around `llm` and the supplied `tools` invoker.
///
/// Memory stores are the in-memory fakes and the bus handles come from
/// `noop_bus()`. The second tuple element is a handle to the same episodic
/// store that is placed in the context, so a test can inspect it later.
fn ctx_with_tools(
    llm: Arc<dyn LlmClient>,
    tools: Arc<FakeToolInvoker>,
) -> (AgentContext, Arc<InMemoryEpisodic>) {
    let episodic = Arc::new(InMemoryEpisodic::default());
    let (pubsub, request_reply, kv, jobs) = noop_bus();

    let context = AgentContext {
        agent_name: "stream-test".into(),
        run_id: RunId::new(),
        cancel: CancellationToken::new(),
        llm,
        tools,
        short_term: Arc::new(InMemoryShortTerm::default()),
        long_term: Arc::new(InMemoryLongTerm::default()),
        // Method-call `clone()` keeps the concrete `Arc<InMemoryEpisodic>`
        // type before it is coerced into the field's type.
        episodic: episodic.clone(),
        pubsub,
        request_reply,
        kv,
        jobs,
    };

    (context, episodic)
}
/// Creates a non-terminal [`ChatChunk`] carrying only the text `s`:
/// no tool calls and no finish reason.
fn delta(s: &str) -> ChatChunk {
    ChatChunk {
        finish_reason: None,
        tool_calls: Vec::new(),
        delta: s.to_owned(),
    }
}
fn final_stop_chunk() -> ChatChunk {
ChatChunk {
delta: String::new(),
tool_calls: vec![],
finish_reason: Some(FinishReason::Stop),
}
}
/// A single scripted stream of three text deltas plus a stop chunk must be
/// forwarded item-for-item, and the episodic store must afterwards contain
/// exactly one `LlmCall` and one `Completed` episode.
#[tokio::test(start_paused = true)]
async fn streaming_single_step_forwards_three_chunks() {
    let script = vec![FakeStreamStep::Chunks(vec![
        Ok(delta("hel")),
        Ok(delta("lo ")),
        Ok(delta("world")),
        Ok(final_stop_chunk()),
    ])];
    let llm = Arc::new(FakeLlmClient::new("fake").with_stream_steps(script));
    let (ctx, episodic) = ctx_with(llm);

    let mut stream = run_steps_streaming(&ctx, "sys", ThreadId::new("t"), RunOptions::default())
        .await
        .expect("stream opens");

    let mut received = Vec::new();
    while let Some(item) = stream.next().await {
        received.push(item.expect("ok chunk"));
    }

    assert_eq!(received.len(), 4, "3 deltas + terminal");
    assert_eq!(received[0].delta, "hel");
    assert_eq!(received[1].delta, "lo ");
    assert_eq!(received[2].delta, "world");
    assert_eq!(received[3].finish_reason, Some(FinishReason::Stop));

    // Yield twice before reading the episodic store
    // (presumably lets pending episode writes land; confirm against the producer).
    for _ in 0..2 {
        tokio::task::yield_now().await;
    }

    let events = episodic.replay(ctx.run_id).await.unwrap();
    let (mut llm_calls, mut completed) = (0usize, 0usize);
    for event in events.iter() {
        match event {
            Episode::LlmCall { .. } => llm_calls += 1,
            Episode::Completed => completed += 1,
            _ => {}
        }
    }
    assert_eq!(llm_calls, 1);
    assert_eq!(completed, 1);
}
/// When the first stream finishes with a tool call, the `echo` tool is
/// dispatched and a second LLM stream is consumed. The concatenated text of
/// all forwarded chunks is `"done"`, and the episodic store records two
/// `LlmCall` episodes and one `ToolCall` episode.
#[tokio::test(start_paused = true)]
async fn streaming_tool_call_dispatch_starts_second_cycle() {
    let echo_call = crate::llm::ToolCall {
        id: "tc1".into(),
        name: "echo".into(),
        args: serde_json::json!({"x": 1}),
    };

    // Cycle 1: a single chunk that only requests the tool call.
    let first_cycle = vec![Ok(ChatChunk {
        delta: String::new(),
        tool_calls: vec![echo_call.clone()],
        finish_reason: Some(FinishReason::ToolCalls),
    })];
    // Cycle 2: the textual answer followed by the stop chunk.
    let second_cycle = vec![Ok(delta("done")), Ok(final_stop_chunk())];

    let script = vec![
        FakeStreamStep::Chunks(first_cycle),
        FakeStreamStep::Chunks(second_cycle),
    ];
    let llm = Arc::new(FakeLlmClient::new("fake").with_stream_steps(script));
    let tools = Arc::new(FakeToolInvoker::new().with_tool("echo", "echoes", Ok));
    let (ctx, episodic) = ctx_with_tools(llm, tools);

    let mut stream = run_steps_streaming(&ctx, "sys", ThreadId::new("t"), RunOptions::default())
        .await
        .expect("stream opens");

    let mut received = Vec::new();
    while let Some(item) = stream.next().await {
        received.push(item.expect("ok chunk"));
    }
    assert!(received.len() >= 3);

    let joined: String = received.iter().map(|chunk| chunk.delta.as_str()).collect();
    assert_eq!(joined, "done");

    let events = episodic.replay(ctx.run_id).await.unwrap();
    let (mut llm_calls, mut tool_calls) = (0usize, 0usize);
    for event in events.iter() {
        match event {
            Episode::LlmCall { .. } => llm_calls += 1,
            Episode::ToolCall { .. } => tool_calls += 1,
            _ => {}
        }
    }
    assert_eq!(llm_calls, 2, "two LLM cycles recorded");
    assert_eq!(tool_calls, 1, "one tool dispatched");
}
/// Cancelling the context's token after the first chunk must surface an
/// `Err(LlmError::Cancelled)` item on the stream, and the episodic store must
/// contain neither an `LlmCall` nor a `Completed` episode afterwards.
#[tokio::test(start_paused = true)]
async fn streaming_mid_stream_cancellation_emits_cancelled_error() {
    // Fifty text deltas followed by the terminal stop chunk.
    let chunks: Vec<Result<ChatChunk, LlmError>> = (0..50)
        .map(|i| Ok(delta(&format!("c{i} "))))
        .chain(std::iter::once(Ok(final_stop_chunk())))
        .collect();
    let llm = Arc::new(
        FakeLlmClient::new("fake").with_stream_steps(vec![FakeStreamStep::Chunks(chunks)]),
    );
    let (ctx, episodic) = ctx_with(llm);

    let mut stream = run_steps_streaming(&ctx, "sys", ThreadId::new("t"), RunOptions::default())
        .await
        .expect("stream opens");

    let first = stream.next().await.expect("first chunk").expect("ok");
    assert!(!first.delta.is_empty());

    ctx.cancel.cancel();

    // Drain whatever is still delivered until the cancellation error shows
    // up; any other error is a test failure.
    let mut tail_count = 0usize;
    let got_cancel_err = loop {
        let item = match stream.next().await {
            Some(item) => item,
            None => break false,
        };
        tail_count += 1;
        match item {
            Ok(_) => {}
            Err(LlmError::Cancelled) => break true,
            Err(other) => panic!("unexpected error: {other:?}"),
        }
    };
    assert!(
        got_cancel_err,
        "expected terminal LlmError::Cancelled item, drained {tail_count} items"
    );

    // Yield twice before reading the episodic store.
    for _ in 0..2 {
        tokio::task::yield_now().await;
    }

    let events = episodic.replay(ctx.run_id).await.unwrap();
    let (mut llm_calls, mut completed) = (0usize, 0usize);
    for event in events.iter() {
        match event {
            Episode::LlmCall { .. } => llm_calls += 1,
            Episode::Completed => completed += 1,
            _ => {}
        }
    }
    assert_eq!(
        llm_calls, 0,
        "partial streaming cycle must not be recorded as a successful LlmCall"
    );
    assert_eq!(completed, 0);
}
/// A `Server` error while opening the stream is followed by a second
/// attempt: the stream then opens, every item is `Ok`, and the fake client
/// reports two stream calls in total.
#[tokio::test(start_paused = true)]
async fn streaming_init_retries_server_then_succeeds() {
    let script = vec![
        FakeStreamStep::init_err(LlmError::Server("503".into())),
        FakeStreamStep::Chunks(vec![Ok(delta("ok")), Ok(final_stop_chunk())]),
    ];
    let llm = Arc::new(FakeLlmClient::new("fake").with_stream_steps(script));
    let (ctx, _episodic) = ctx_with(llm.clone());

    let mut stream = run_steps_streaming(&ctx, "sys", ThreadId::new("t"), RunOptions::default())
        .await
        .expect("stream opens after retry");

    // Every forwarded item must be a successful chunk.
    while let Some(item) = stream.next().await {
        let _chunk = item.expect("ok");
    }

    assert_eq!(llm.stream_call_count(), 2, "one failure + one success");
}
#[tokio::test(start_paused = true)]
async fn streaming_init_unauthorized_propagates_after_one_attempt() {
let llm = Arc::new(
FakeLlmClient::new("fake")
.with_stream_steps(vec![FakeStreamStep::init_err(LlmError::Unauthorized)]),
);
let (ctx, _ep) = ctx_with(llm.clone());
let res = run_steps_streaming(&ctx, "sys", ThreadId::new("t"), RunOptions::default()).await;
match res {
Ok(_) => panic!("expected Unauthorized, got Ok"),
Err(Error::Llm(LlmError::Unauthorized)) => {}
Err(other) => panic!("expected Llm(Unauthorized), got {other:?}"),
}
assert_eq!(
llm.stream_call_count(),
1,
"non-retryable: exactly one attempt"
);
}
#[tokio::test(start_paused = true)]
async fn streaming_max_steps_exceeded_when_tools_loop() {
let tool_call = crate::llm::ToolCall {
id: "tc1".into(),
name: "echo".into(),
args: serde_json::json!({}),
};
let mk_tool_stream = || {
FakeStreamStep::Chunks(vec![Ok(ChatChunk {
delta: String::new(),
tool_calls: vec![tool_call.clone()],
finish_reason: Some(FinishReason::ToolCalls),
})])
};
let llm = Arc::new(FakeLlmClient::new("fake").with_stream_steps(vec![
mk_tool_stream(),
mk_tool_stream(),
mk_tool_stream(),
mk_tool_stream(),
mk_tool_stream(),
]));
let tools = Arc::new(
FakeToolInvoker::new().with_tool("echo", "echoes", |_| Ok(serde_json::json!("ok"))),
);
let (ctx, ep) = ctx_with_tools(llm, tools);
let opts = RunOptions {
max_steps: 2,
..RunOptions::default()
};
let mut s = run_steps_streaming(&ctx, "sys", ThreadId::new("t"), opts)
.await
.expect("stream opens");
let mut got_max_steps_err = false;
while let Some(item) = s.next().await {
if let Err(LlmError::Server(ref m)) = item {
if m.contains("max steps exceeded") {
got_max_steps_err = true;
}
}
}
assert!(
got_max_steps_err,
"expected terminal Err(Server(\"max steps exceeded: ...\")) before stream-close"
);
tokio::task::yield_now().await;
tokio::task::yield_now().await;
let events = ep.replay(ctx.run_id).await.unwrap();
let failed = events
.iter()
.find(|e| matches!(e, Episode::Failed { .. }))
.expect("must record Failed");
match failed {
Episode::Failed { error } => {
assert!(
error.contains("max steps"),
"expected max-steps message, got {error}"
);
}
_ => unreachable!(),
}
}
#[tokio::test(start_paused = true)]
async fn streaming_response_byte_cap_aborts_oversized_stream() {
let mut chunks: Vec<Result<ChatChunk, LlmError>> = Vec::new();
for i in 0..20 {
chunks.push(Ok(delta(&format!("payload-block-{i:02}"))));
}
chunks.push(Ok(final_stop_chunk()));
let llm = Arc::new(
FakeLlmClient::new("fake").with_stream_steps(vec![FakeStreamStep::Chunks(chunks)]),
);
let (ctx, ep) = ctx_with(llm);
let opts = RunOptions {
max_response_bytes: 64,
..RunOptions::default()
};
let mut s = run_steps_streaming(&ctx, "sys", ThreadId::new("t"), opts)
.await
.expect("stream opens");
let mut saw_cap_err = false;
while let Some(item) = s.next().await {
if let Err(LlmError::Server(ref m)) = item {
if m.contains("max_response_bytes cap") {
saw_cap_err = true;
break;
}
}
}
assert!(
saw_cap_err,
"expected terminal Err carrying the byte-cap message"
);
tokio::task::yield_now().await;
tokio::task::yield_now().await;
let events = ep.replay(ctx.run_id).await.unwrap();
let saw_failed = events.iter().any(
|e| matches!(e, Episode::Failed { error } if error.contains("max_response_bytes cap")),
);
assert!(
saw_failed,
"expected Failed episode mentioning max_response_bytes cap, events: {events:?}"
);
}
/// Unit tests for `terminal_chunk_for`, which takes an `Error` plus a
/// "cancel observed" flag and returns a `(chunk error, episode text)` pair.
mod terminal_chunk_for_tests {
    use super::super::terminal_chunk_for;
    use crate::error::{Error, LlmError, MemoryError, ToolError};

    /// Returns the message carried by `LlmError::Server`, panicking with the
    /// offending variant otherwise.
    ///
    /// `#[track_caller]` makes the panic location point at the calling test
    /// rather than at this helper.
    #[track_caller]
    fn server_message(chunk: LlmError) -> String {
        match chunk {
            LlmError::Server(m) => m,
            other => panic!("expected Server, got {other:?}"),
        }
    }

    /// `Error::Cancelled` with the flag set maps to a `"cancelled"` episode.
    #[test]
    fn cancelled_with_cancel_observed_yields_cancelled_episode() {
        let (chunk, episode) = terminal_chunk_for(Error::Cancelled, true);
        assert!(matches!(chunk, LlmError::Cancelled));
        assert_eq!(episode, "cancelled");
    }

    /// `Error::Cancelled` with the flag cleared maps to a
    /// `"consumer-dropped"` episode.
    #[test]
    fn cancelled_without_cancel_observed_yields_consumer_dropped_episode() {
        let (chunk, episode) = terminal_chunk_for(Error::Cancelled, false);
        assert!(matches!(chunk, LlmError::Cancelled));
        assert_eq!(episode, "consumer-dropped");
    }

    /// The step count appears in both the chunk message and the episode.
    #[test]
    fn max_steps_exceeded_renders_step_count() {
        let (chunk, episode) = terminal_chunk_for(Error::MaxStepsExceeded { steps: 7 }, false);
        let message = server_message(chunk);
        assert_eq!(message, "max steps exceeded: 7");
        assert_eq!(episode, "max steps exceeded: 7");
    }

    /// Tool errors become a `Server` message starting with `"tool: "`.
    #[test]
    fn tool_error_prefixed_with_tool() {
        let (chunk, episode) =
            terminal_chunk_for(Error::Tool(ToolError::Permanent("boom".into())), false);
        let message = server_message(chunk);
        assert!(message.starts_with("tool: "), "got {message}");
        assert!(message.contains("boom"));
        assert_eq!(message, episode);
    }

    /// An `Error::Llm` wrapper hands back the inner `LlmError` variant as-is.
    #[test]
    fn llm_variant_preserved_end_to_end_not_rewrapped_as_server() {
        let (chunk, episode) = terminal_chunk_for(Error::Llm(LlmError::Unauthorized), false);
        assert!(
            matches!(chunk, LlmError::Unauthorized),
            "Llm-wrapped variants must round-trip unchanged"
        );
        assert!(episode.to_lowercase().contains("unauthorized"));
    }

    /// The `retry_after_secs` payload of `RateLimit` survives the mapping.
    #[test]
    fn llm_rate_limit_preserves_retry_after_seconds() {
        let (chunk, _) = terminal_chunk_for(
            Error::Llm(LlmError::RateLimit {
                retry_after_secs: 11,
            }),
            false,
        );
        match chunk {
            LlmError::RateLimit { retry_after_secs } => assert_eq!(retry_after_secs, 11),
            other => panic!("expected RateLimit, got {other:?}"),
        }
    }

    /// Refusals render as `"refused: <reason>"` in both outputs.
    #[test]
    fn refused_prefixed_with_refused() {
        let (chunk, episode) = terminal_chunk_for(
            Error::Refused {
                reason: "policy".into(),
            },
            false,
        );
        let message = server_message(chunk);
        assert_eq!(message, "refused: policy");
        assert_eq!(message, episode);
    }

    /// Handoffs render as `"handoff to <agent>: <reason>"` in both outputs.
    #[test]
    fn handoff_includes_target_agent_and_reason() {
        let (chunk, episode) = terminal_chunk_for(
            Error::Handoff {
                agent: "safety".into(),
                reason: "needs human review".into(),
            },
            false,
        );
        let message = server_message(chunk);
        assert_eq!(message, "handoff to safety: needs human review");
        assert_eq!(message, episode);
    }

    /// A memory error becomes a `Server` message containing the underlying
    /// text, and the episode matches that message.
    #[test]
    fn unmatched_variant_falls_through_to_server_with_display() {
        let (chunk, episode) =
            terminal_chunk_for(Error::Memory(MemoryError::Store("disk full".into())), false);
        let message = server_message(chunk);
        assert!(message.contains("disk full"), "got {message}");
        assert_eq!(message, episode);
    }
}