use crabllm_core::{FinishReason, FunctionCall, Role, ToolCall};
use crabtalk_core::{
AgentBuilder, AgentConfig, AgentEvent, AgentStopReason, ToolDispatcher, ToolFuture,
model::{HistoryEntry, Model},
testing::provider::{
TestProvider, finish_chunk, mixed_chunk, text_chunk, text_chunks, text_response,
thinking_chunk, tool_chunks, tool_response,
},
};
use futures_util::StreamExt;
use std::{future::Future, pin::Pin, sync::Arc};
use tokio::sync::mpsc;
type BoxFut = Pin<Box<dyn Future<Output = Result<String, String>> + Send>>;
struct FnDispatcher<F>(F);
impl<F> ToolDispatcher for FnDispatcher<F>
where
F: Fn(String) -> BoxFut + Send + Sync + 'static,
{
fn dispatch<'a>(
&'a self,
name: &'a str,
_args: &'a str,
_agent: &'a str,
_sender: &'a str,
_conversation_id: Option<u64>,
) -> ToolFuture<'a> {
(self.0)(name.to_owned())
}
}
fn dispatcher<F>(f: F) -> Arc<dyn ToolDispatcher>
where
F: Fn(String) -> BoxFut + Send + Sync + 'static,
{
Arc::new(FnDispatcher(f))
}
fn make_tool_call(name: &str, args: &str) -> ToolCall {
ToolCall {
index: Some(0),
id: format!("call_{name}"),
function: FunctionCall {
name: name.into(),
arguments: args.into(),
},
..Default::default()
}
}
fn build_agent_no_tools(model: TestProvider) -> crabtalk_core::Agent<TestProvider> {
AgentBuilder::new(Model::new(model))
.config(AgentConfig::new("test-agent"))
.build()
}
#[tokio::test]
async fn step_text_response_appends_to_history() {
let model = TestProvider::new(vec![text_response("hello world")]);
let agent = build_agent_no_tools(model);
let mut history = vec![HistoryEntry::user("hi")];
let step = agent.step(&mut history, None).await.unwrap();
assert_eq!(history.len(), 2); assert_eq!(*history[1].role(), Role::Assistant);
assert_eq!(history[1].text(), "hello world");
assert!(step.tool_calls.is_empty());
assert!(step.tool_results.is_empty());
}
#[tokio::test]
async fn step_tool_calls_dispatched_and_appended() {
let calls = vec![make_tool_call("bash", r#"{"command":"ls"}"#)];
let model = TestProvider::new(vec![tool_response(calls)]);
let agent = AgentBuilder::new(Model::new(model))
.config(AgentConfig::new("test-agent"))
.dispatcher(dispatcher(|name| {
Box::pin(async move { Ok(format!("result for {name}")) })
}))
.build();
let mut history = vec![HistoryEntry::user("run ls")];
let step = agent.step(&mut history, None).await.unwrap();
assert_eq!(step.tool_calls.len(), 1);
assert_eq!(step.tool_calls[0].function.name, "bash");
assert_eq!(step.tool_results.len(), 1);
assert_eq!(step.tool_results[0].text(), "result for bash");
assert_eq!(history.len(), 3);
}
#[tokio::test]
async fn step_no_dispatcher_returns_error_message() {
let calls = vec![make_tool_call("bash", "{}")];
let model = TestProvider::new(vec![tool_response(calls)]);
let agent = build_agent_no_tools(model);
let mut history = vec![HistoryEntry::user("hi")];
let step = agent.step(&mut history, None).await.unwrap();
assert_eq!(step.tool_results.len(), 1);
assert!(
step.tool_results[0]
.text()
.contains("no tool dispatcher configured")
);
}
#[tokio::test]
async fn step_send_error_propagates() {
let model = TestProvider::new(vec![]);
let agent = build_agent_no_tools(model);
let mut history = vec![HistoryEntry::user("hi")];
let result = agent.step(&mut history, None).await;
assert!(result.is_err());
}
#[tokio::test]
async fn run_stream_surfaces_handler_error_end_to_end() {
let calls = vec![make_tool_call("bash", "{}")];
let model = TestProvider::with_chunks(vec![tool_chunks(calls), text_chunks("ack")]);
let agent = AgentBuilder::new(Model::new(model))
.config(AgentConfig::new("test-agent"))
.dispatcher(dispatcher(|name| {
Box::pin(async move { Err(format!("{name} blew up")) })
}))
.build();
let mut history = vec![HistoryEntry::user("go")];
let mut events: Vec<AgentEvent> = Vec::new();
let mut stream = std::pin::pin!(agent.run_stream(&mut history, None, None, None));
while let Some(event) = stream.next().await {
events.push(event);
}
let tool_results: Vec<&AgentEvent> = events
.iter()
.filter(|e| matches!(e, AgentEvent::ToolResult { .. }))
.collect();
assert_eq!(tool_results.len(), 1);
match tool_results[0] {
AgentEvent::ToolResult { output, .. } => {
assert_eq!(
output.as_ref().err().map(String::as_str),
Some("bash blew up")
);
}
_ => unreachable!(),
}
if let AgentEvent::Done(resp) = events.last().unwrap() {
let tool_entry = resp.steps[0]
.tool_results
.last()
.expect("step has a tool_result entry");
assert_eq!(tool_entry.text(), "bash blew up");
} else {
panic!("last event should be Done");
}
}
#[tokio::test]
async fn run_stream_text_response() {
let model = TestProvider::with_chunks(vec![text_chunks("hello")]);
let agent = build_agent_no_tools(model);
let mut history = vec![HistoryEntry::user("hi")];
let mut events: Vec<AgentEvent> = Vec::new();
let mut stream = std::pin::pin!(agent.run_stream(&mut history, None, None, None));
while let Some(event) = stream.next().await {
events.push(event);
}
let done = events.last().unwrap();
if let AgentEvent::Done(resp) = done {
assert_eq!(resp.stop_reason, AgentStopReason::TextResponse);
assert_eq!(resp.final_response.as_deref(), Some("hello"));
assert_eq!(resp.iterations, 1);
} else {
panic!("last event should be Done");
}
let text_deltas: Vec<&str> = events
.iter()
.filter_map(|e| match e {
AgentEvent::TextDelta(s) => Some(s.as_str()),
_ => None,
})
.collect();
assert_eq!(text_deltas.join(""), "hello");
}
#[tokio::test]
async fn run_stream_tool_call_then_text() {
let calls = vec![make_tool_call("recall", r#"{"query":"test"}"#)];
let model = TestProvider::with_chunks(vec![tool_chunks(calls), text_chunks("answer")]);
let agent = AgentBuilder::new(Model::new(model))
.config(AgentConfig::new("test-agent"))
.dispatcher(dispatcher(|_name| {
Box::pin(async { Ok("tool output".to_owned()) })
}))
.build();
let mut history = vec![HistoryEntry::user("question")];
let mut events: Vec<AgentEvent> = Vec::new();
let mut stream = std::pin::pin!(agent.run_stream(&mut history, None, None, None));
while let Some(event) = stream.next().await {
events.push(event);
}
let done = events.last().unwrap();
if let AgentEvent::Done(resp) = done {
assert_eq!(resp.stop_reason, AgentStopReason::TextResponse);
assert_eq!(resp.iterations, 2);
assert_eq!(resp.steps.len(), 2);
} else {
panic!("last event should be Done");
}
let mut seen_begin = false;
let mut seen_start = false;
let mut seen_result = false;
let mut seen_complete = false;
for event in &events {
match event {
AgentEvent::ToolCallsBegin(_) => {
assert!(!seen_start, "Begin should come before Start");
seen_begin = true;
}
AgentEvent::ToolCallsStart(_) => {
assert!(seen_begin, "Start should come after Begin");
assert!(!seen_result, "Start should come before Result");
seen_start = true;
}
AgentEvent::ToolResult { .. } => {
assert!(seen_start, "Result should come after Start");
seen_result = true;
}
AgentEvent::ToolCallsComplete => {
assert!(seen_result, "Complete should come after Result");
seen_complete = true;
}
_ => {}
}
}
assert!(seen_begin);
assert!(seen_start);
assert!(seen_result);
assert!(seen_complete);
}
#[tokio::test]
async fn run_stream_multiple_tool_calls_in_one_step() {
let calls = vec![
ToolCall {
index: Some(0),
id: "call_1".into(),
function: FunctionCall {
name: "bash".into(),
arguments: r#"{"cmd":"ls"}"#.into(),
},
..Default::default()
},
ToolCall {
index: Some(1),
id: "call_2".into(),
function: FunctionCall {
name: "recall".into(),
arguments: r#"{"query":"x"}"#.into(),
},
..Default::default()
},
];
let model = TestProvider::with_chunks(vec![tool_chunks(calls), text_chunks("done")]);
let agent = AgentBuilder::new(Model::new(model))
.config(AgentConfig::new("test-agent"))
.dispatcher(dispatcher(|name| {
Box::pin(async move { Ok(format!("result:{name}")) })
}))
.build();
let mut history = vec![HistoryEntry::user("multi")];
let mut events: Vec<AgentEvent> = Vec::new();
let mut stream = std::pin::pin!(agent.run_stream(&mut history, None, None, None));
while let Some(event) = stream.next().await {
events.push(event);
}
let tool_results: Vec<&AgentEvent> = events
.iter()
.filter(|e| matches!(e, AgentEvent::ToolResult { .. }))
.collect();
assert_eq!(tool_results.len(), 2);
if let AgentEvent::Done(resp) = events.last().unwrap() {
assert_eq!(resp.stop_reason, AgentStopReason::TextResponse);
assert_eq!(resp.steps[0].tool_calls.len(), 2);
assert_eq!(resp.steps[0].tool_results.len(), 2);
} else {
panic!("last event should be Done");
}
}
#[tokio::test]
async fn run_stream_dispatches_tool_calls_concurrently() {
let calls = vec![
ToolCall {
index: Some(0),
id: "call_slow".into(),
function: FunctionCall {
name: "slow".into(),
arguments: "{}".into(),
},
..Default::default()
},
ToolCall {
index: Some(1),
id: "call_med".into(),
function: FunctionCall {
name: "med".into(),
arguments: "{}".into(),
},
..Default::default()
},
ToolCall {
index: Some(2),
id: "call_fast".into(),
function: FunctionCall {
name: "fast".into(),
arguments: "{}".into(),
},
..Default::default()
},
];
let model = TestProvider::with_chunks(vec![tool_chunks(calls), text_chunks("done")]);
let agent = AgentBuilder::new(Model::new(model))
.config(AgentConfig::new("test-agent"))
.dispatcher(dispatcher(|name| {
Box::pin(async move {
let sleep_ms = match name.as_str() {
"slow" => 300,
"med" => 200,
"fast" => 100,
_ => 0,
};
tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
Ok(format!("result:{name}"))
})
}))
.build();
let mut history = vec![HistoryEntry::user("multi")];
let start = std::time::Instant::now();
let mut events: Vec<AgentEvent> = Vec::new();
let mut stream = std::pin::pin!(agent.run_stream(&mut history, None, None, None));
while let Some(event) = stream.next().await {
events.push(event);
}
let elapsed = start.elapsed();
assert!(
elapsed < std::time::Duration::from_millis(500),
"expected concurrent dispatch (< 500ms), got {elapsed:?}"
);
if let AgentEvent::Done(resp) = events.last().unwrap() {
let results = &resp.steps[0].tool_results;
assert_eq!(results.len(), 3);
assert_eq!(results[0].text(), "result:slow");
assert_eq!(results[1].text(), "result:med");
assert_eq!(results[2].text(), "result:fast");
} else {
panic!("last event should be Done");
}
let result_count = events
.iter()
.filter(|e| matches!(e, AgentEvent::ToolResult { .. }))
.count();
let start_count = events
.iter()
.filter(|e| matches!(e, AgentEvent::ToolCallsStart(_)))
.count();
let complete_count = events
.iter()
.filter(|e| matches!(e, AgentEvent::ToolCallsComplete))
.count();
assert_eq!(result_count, 3);
assert_eq!(start_count, 1);
assert_eq!(complete_count, 1);
let result_ids: Vec<String> = events
.iter()
.filter_map(|e| match e {
AgentEvent::ToolResult { call_id, .. } => Some(call_id.clone()),
_ => None,
})
.collect();
assert_eq!(
result_ids,
vec!["call_fast", "call_med", "call_slow"],
"ToolResult events should fire in completion order"
);
}
#[tokio::test]
async fn run_stream_max_iterations() {
let calls = vec![make_tool_call("bash", "{}")];
let model = TestProvider::with_chunks(vec![
tool_chunks(calls.clone()),
tool_chunks(calls.clone()),
tool_chunks(calls),
]);
let mut config = AgentConfig::new("test-agent");
config.max_iterations = 2;
let agent = AgentBuilder::new(Model::new(model))
.config(config)
.dispatcher(dispatcher(|_name| Box::pin(async { Ok("ok".to_owned()) })))
.build();
let mut history = vec![HistoryEntry::user("loop")];
let mut events: Vec<AgentEvent> = Vec::new();
let mut stream = std::pin::pin!(agent.run_stream(&mut history, None, None, None));
while let Some(event) = stream.next().await {
events.push(event);
}
if let AgentEvent::Done(resp) = events.last().unwrap() {
assert_eq!(resp.stop_reason, AgentStopReason::MaxIterations);
assert_eq!(resp.iterations, 2);
} else {
panic!("last event should be Done");
}
}
#[tokio::test]
async fn run_stream_no_content_no_tools_stops_with_no_action() {
let model = TestProvider::with_chunks(vec![vec![finish_chunk(FinishReason::Stop)]]);
let agent = build_agent_no_tools(model);
let mut history = vec![HistoryEntry::user("hi")];
let mut events: Vec<AgentEvent> = Vec::new();
let mut stream = std::pin::pin!(agent.run_stream(&mut history, None, None, None));
while let Some(event) = stream.next().await {
events.push(event);
}
if let AgentEvent::Done(resp) = events.last().unwrap() {
assert_eq!(resp.stop_reason, AgentStopReason::NoAction);
} else {
panic!("last event should be Done");
}
}
#[tokio::test]
async fn run_stream_error_in_stream() {
let model = TestProvider::with_chunks(vec![]);
let agent = build_agent_no_tools(model);
let mut history = vec![HistoryEntry::user("hi")];
let mut events: Vec<AgentEvent> = Vec::new();
let mut stream = std::pin::pin!(agent.run_stream(&mut history, None, None, None));
while let Some(event) = stream.next().await {
events.push(event);
}
if let AgentEvent::Done(resp) = events.last().unwrap() {
assert!(matches!(resp.stop_reason, AgentStopReason::Error(_)));
} else {
panic!("last event should be Done");
}
}
#[tokio::test]
async fn run_stream_thinking_delta() {
let chunks = vec![
thinking_chunk("thinking..."),
text_chunk("answer"),
finish_chunk(FinishReason::Stop),
];
let model = TestProvider::with_chunks(vec![chunks]);
let agent = build_agent_no_tools(model);
let mut history = vec![HistoryEntry::user("think")];
let mut events: Vec<AgentEvent> = Vec::new();
let mut stream = std::pin::pin!(agent.run_stream(&mut history, None, None, None));
while let Some(event) = stream.next().await {
events.push(event);
}
let thinking: Vec<&str> = events
.iter()
.filter_map(|e| match e {
AgentEvent::ThinkingDelta(s) => Some(s.as_str()),
_ => None,
})
.collect();
assert_eq!(thinking, vec!["thinking..."]);
}
fn boundary_shape(events: &[AgentEvent]) -> Vec<&'static str> {
events
.iter()
.filter_map(|e| match e {
AgentEvent::TextStart => Some("TextStart"),
AgentEvent::TextDelta(_) => Some("TextDelta"),
AgentEvent::TextEnd => Some("TextEnd"),
AgentEvent::ThinkingStart => Some("ThinkingStart"),
AgentEvent::ThinkingDelta(_) => Some("ThinkingDelta"),
AgentEvent::ThinkingEnd => Some("ThinkingEnd"),
AgentEvent::ToolCallsBegin(_) => Some("ToolCallsBegin"),
AgentEvent::ToolCallsStart(_) => Some("ToolCallsStart"),
AgentEvent::ToolCallsComplete => Some("ToolCallsComplete"),
_ => None,
})
.collect()
}
#[tokio::test]
async fn run_stream_text_segment_is_bracketed() {
let model = TestProvider::with_chunks(vec![text_chunks("hi")]);
let agent = build_agent_no_tools(model);
let mut history = vec![HistoryEntry::user("ping")];
let mut events: Vec<AgentEvent> = Vec::new();
let mut stream = std::pin::pin!(agent.run_stream(&mut history, None, None, None));
while let Some(event) = stream.next().await {
events.push(event);
}
let shape = boundary_shape(&events);
let starts = shape.iter().filter(|s| **s == "TextStart").count();
let ends = shape.iter().filter(|s| **s == "TextEnd").count();
assert_eq!(starts, 1, "expected exactly one TextStart, got {shape:?}");
assert_eq!(ends, 1, "expected exactly one TextEnd, got {shape:?}");
assert_eq!(shape.first(), Some(&"TextStart"));
assert_eq!(shape.last(), Some(&"TextEnd"));
}
#[tokio::test]
async fn run_stream_thinking_then_text_brackets_atomically() {
let chunks = vec![
thinking_chunk("plan"),
text_chunk("answer"),
finish_chunk(FinishReason::Stop),
];
let model = TestProvider::with_chunks(vec![chunks]);
let agent = build_agent_no_tools(model);
let mut history = vec![HistoryEntry::user("think")];
let mut events: Vec<AgentEvent> = Vec::new();
let mut stream = std::pin::pin!(agent.run_stream(&mut history, None, None, None));
while let Some(event) = stream.next().await {
events.push(event);
}
let shape = boundary_shape(&events);
assert_eq!(
shape,
vec![
"ThinkingStart",
"ThinkingDelta",
"ThinkingEnd",
"TextStart",
"TextDelta",
"TextEnd",
],
"expected clean atomic flip from thinking to text",
);
}
#[tokio::test]
async fn run_stream_chunk_with_both_text_and_reasoning_does_not_overlap() {
let chunks = vec![mixed_chunk("a", "b"), finish_chunk(FinishReason::Stop)];
let model = TestProvider::with_chunks(vec![chunks]);
let agent = build_agent_no_tools(model);
let mut history = vec![HistoryEntry::user("hi")];
let mut events: Vec<AgentEvent> = Vec::new();
let mut stream = std::pin::pin!(agent.run_stream(&mut history, None, None, None));
while let Some(event) = stream.next().await {
events.push(event);
}
let mut open_text = false;
let mut open_thinking = false;
for e in &events {
match e {
AgentEvent::TextStart => {
assert!(!open_text, "TextStart while text already open");
assert!(!open_thinking, "TextStart while thinking still open");
open_text = true;
}
AgentEvent::TextEnd => {
assert!(open_text, "TextEnd without TextStart");
open_text = false;
}
AgentEvent::ThinkingStart => {
assert!(!open_thinking, "ThinkingStart while thinking already open");
assert!(!open_text, "ThinkingStart while text still open");
open_thinking = true;
}
AgentEvent::ThinkingEnd => {
assert!(open_thinking, "ThinkingEnd without ThinkingStart");
open_thinking = false;
}
_ => {}
}
}
assert!(!open_text, "text segment left open at end");
assert!(!open_thinking, "thinking segment left open at end");
}
#[tokio::test]
async fn run_stream_text_then_tools_closes_text_before_tools() {
let calls = vec![make_tool_call("recall", r#"{"q":"x"}"#)];
let model = TestProvider::with_chunks(vec![tool_chunks(calls), text_chunks("ok")]);
let agent = AgentBuilder::new(Model::new(model))
.config(AgentConfig::new("test-agent"))
.dispatcher(dispatcher(|_name| Box::pin(async { Ok("out".to_owned()) })))
.build();
let mut history = vec![HistoryEntry::user("q")];
let mut events: Vec<AgentEvent> = Vec::new();
let mut stream = std::pin::pin!(agent.run_stream(&mut history, None, None, None));
while let Some(event) = stream.next().await {
events.push(event);
}
let shape = boundary_shape(&events);
let mut text_open = false;
for marker in &shape {
match *marker {
"TextStart" => {
assert!(!text_open, "nested TextStart in {shape:?}");
text_open = true;
}
"TextEnd" => {
assert!(text_open, "TextEnd without TextStart in {shape:?}");
text_open = false;
}
"ToolCallsStart" => {
assert!(
!text_open,
"ToolCallsStart while text segment open in {shape:?}"
);
}
_ => {}
}
}
assert!(!text_open, "text segment left open at end");
}
#[tokio::test]
async fn run_forwards_events_through_channel() {
let model = TestProvider::with_chunks(vec![text_chunks("done")]);
let agent = build_agent_no_tools(model);
let mut history = vec![HistoryEntry::user("hi")];
let (tx, mut rx) = mpsc::unbounded_channel();
let response = agent.run(&mut history, tx, None, None).await;
assert_eq!(response.stop_reason, AgentStopReason::TextResponse);
assert_eq!(response.final_response.as_deref(), Some("done"));
let mut event_count = 0;
let mut has_done = false;
while let Ok(event) = rx.try_recv() {
event_count += 1;
if matches!(event, AgentEvent::Done(_)) {
has_done = true;
}
}
assert!(event_count > 0, "events should have been sent");
assert!(has_done, "Done event should have been sent");
}