#![cfg(feature = "testkit")]
mod common;
use std::sync::Arc;
use std::time::Duration;
use common::{
MockStreamFn, MockTool, abort_events, default_convert, default_model, text_only_events,
tool_call_events, user_msg,
};
use futures::stream::StreamExt;
use swink_agent::{
Agent, AgentError, AgentEvent, AgentMessage, AgentOptions, AgentTool, AssistantMessageEvent,
ContentBlock, DefaultRetryStrategy, LlmMessage, ModelSpec, StopReason, StreamFn,
};
fn make_agent(stream_fn: Arc<dyn StreamFn>) -> Agent {
Agent::new(
AgentOptions::new(
"test system prompt",
default_model(),
stream_fn,
default_convert,
)
.with_retry_strategy(Box::new(
DefaultRetryStrategy::default()
.with_jitter(false)
.with_base_delay(Duration::from_millis(1)),
)),
)
}
fn make_agent_with_tools(stream_fn: Arc<dyn StreamFn>, tools: Vec<Arc<dyn AgentTool>>) -> Agent {
Agent::new(
AgentOptions::new(
"test system prompt",
default_model(),
stream_fn,
default_convert,
)
.with_tools(tools)
.with_retry_strategy(Box::new(
DefaultRetryStrategy::default()
.with_jitter(false)
.with_base_delay(Duration::from_millis(1)),
)),
)
}
#[tokio::test]
async fn prompt_async_returns_correct_result() {
    // A single scripted response: a plain text reply.
    let mock = Arc::new(MockStreamFn::new(vec![text_only_events("Hello world")]));
    let mut agent = make_agent(mock);

    let outcome = agent.prompt_async(vec![user_msg("Hi")]).await.unwrap();

    assert_eq!(outcome.stop_reason, StopReason::Stop);
    assert!(outcome.error.is_none());
    assert!(!outcome.messages.is_empty());

    // Some assistant message must carry the scripted text verbatim.
    let has_assistant_text = outcome.messages.iter().any(|msg| {
        let AgentMessage::Llm(LlmMessage::Assistant(assistant)) = msg else {
            return false;
        };
        assistant
            .content
            .iter()
            .any(|block| matches!(block, ContentBlock::Text { text } if text == "Hello world"))
    });
    assert!(has_assistant_text, "result should contain assistant text");
    assert!(!agent.state().is_running);
}
#[test]
fn prompt_sync_returns_result() {
    let mock = Arc::new(MockStreamFn::new(vec![text_only_events("sync result")]));
    let mut agent = make_agent(mock);

    let result = agent.prompt_sync(vec![user_msg("Hi")]).unwrap();
    assert_eq!(result.stop_reason, StopReason::Stop);
    assert!(result.error.is_none());

    // Flatten every assistant content block and look for the scripted reply.
    let has_text = result
        .messages
        .iter()
        .filter_map(|message| match message {
            AgentMessage::Llm(LlmMessage::Assistant(assistant)) => Some(assistant),
            _ => None,
        })
        .flat_map(|assistant| assistant.content.iter())
        .any(|block| matches!(block, ContentBlock::Text { text } if text == "sync result"));
    assert!(has_text, "sync result should contain assistant text");
    assert!(!agent.state().is_running);
}
#[tokio::test]
async fn prompt_stream_yields_events_in_order() {
let stream_fn = Arc::new(MockStreamFn::new(vec![text_only_events("streamed")]));
let mut agent = make_agent(stream_fn);
let mut stream = agent.prompt_stream(vec![user_msg("Hi")]).unwrap();
let mut event_names: Vec<String> = Vec::new();
while let Some(event) = stream.next().await {
let name = format!("{event:?}");
let prefix = name.split([' ', '{', '(']).next().unwrap_or("").to_string();
event_names.push(prefix);
}
let find = |name: &str| event_names.iter().position(|n| n == name);
let agent_start = find("AgentStart").expect("should have AgentStart");
let turn_start = find("TurnStart").expect("should have TurnStart");
let msg_start = find("MessageStart").expect("should have MessageStart");
let msg_end = find("MessageEnd").expect("should have MessageEnd");
let turn_end = find("TurnEnd").expect("should have TurnEnd");
let agent_end = find("AgentEnd").expect("should have AgentEnd");
assert!(agent_start < turn_start);
assert!(turn_start < msg_start);
assert!(msg_start < msg_end);
assert!(msg_end < turn_end);
assert!(turn_end < agent_end);
}
#[tokio::test]
async fn already_running_error() {
    let mock = Arc::new(MockStreamFn::new(vec![text_only_events("first")]));
    let mut agent = make_agent(mock);

    // Keep the first stream alive so the agent remains in the running state.
    let _active = agent.prompt_stream(vec![user_msg("first")]).unwrap();
    assert!(agent.state().is_running);

    // A second prompt while running must be rejected.
    let Err(err) = agent.prompt_stream(vec![user_msg("second")]) else {
        panic!("should be an error");
    };
    assert!(
        matches!(err, AgentError::AlreadyRunning),
        "expected AlreadyRunning, got {err:?}"
    );
}
/// Aborting while a tool is executing must end the run with a `TurnEnd` whose
/// assistant message carries `StopReason::Aborted`.
#[tokio::test]
async fn abort_causes_aborted_stop() {
    let stream_fn = Arc::new(MockStreamFn::new(vec![
        tool_call_events("tc_1", "slow_tool", "{}"),
        text_only_events("should not reach"),
    ]));
    // Long delay so the abort lands while the tool is still running.
    let tool = Arc::new(MockTool::new("slow_tool").with_delay(Duration::from_secs(10)));
    let mut agent = make_agent_with_tools(stream_fn, vec![tool]);
    let mut stream = agent.prompt_stream(vec![user_msg("go")]).unwrap();
    let mut found_abort = false;
    let mut saw_tool_start = false;
    while let Some(event) = stream.next().await {
        if matches!(event, AgentEvent::ToolExecutionStart { .. }) {
            saw_tool_start = true;
            agent.abort();
        }
        if let AgentEvent::TurnEnd {
            ref assistant_message,
            ..
        } = event
            && assistant_message.stop_reason == StopReason::Aborted
        {
            found_abort = true;
        }
    }
    assert!(saw_tool_start, "should have seen tool execution start");
    // Previously computed but discarded (`let _ = found_abort;`), so the
    // aborted stop was never actually checked.
    assert!(
        found_abort,
        "aborting during tool execution should emit a TurnEnd with StopReason::Aborted"
    );
}
#[tokio::test]
async fn abort_during_tool_turn_keeps_single_turn_and_tool_payloads() {
let stream_fn = Arc::new(MockStreamFn::new(vec![
tool_call_events("tc_abort", "slow_tool", "{}"),
text_only_events("should not reach"),
]));
let tool = Arc::new(MockTool::new("slow_tool").with_delay(Duration::from_secs(10)));
let mut agent = make_agent_with_tools(stream_fn, vec![tool]);
let mut stream = agent.prompt_stream(vec![user_msg("go")]).unwrap();
let mut turn_start_count = 0;
let mut aborted_turn: Option<(
swink_agent::AssistantMessage,
Vec<swink_agent::ToolResultMessage>,
swink_agent::TurnEndReason,
)> = None;
while let Some(event) = stream.next().await {
match event {
AgentEvent::TurnStart => turn_start_count += 1,
AgentEvent::ToolExecutionStart { .. } => agent.abort(),
AgentEvent::TurnEnd {
assistant_message,
tool_results,
reason,
..
} if assistant_message.stop_reason == StopReason::Aborted => {
aborted_turn = Some((assistant_message, tool_results, reason));
}
_ => {}
}
}
let (assistant_message, tool_results, reason) =
aborted_turn.expect("abort during tool execution should emit an aborted TurnEnd");
assert_eq!(
turn_start_count, 1,
"aborting a tool turn should not synthesize a second TurnStart"
);
assert_eq!(
reason,
swink_agent::TurnEndReason::Cancelled,
"external cancellation should still surface as a cancelled turn"
);
assert!(
assistant_message.content.iter().any(|block| matches!(
block,
ContentBlock::ToolCall { id, name, .. }
if id == "tc_abort" && name == "slow_tool"
)),
"the terminal assistant payload should preserve the original tool call"
);
assert_eq!(
tool_results.len(),
1,
"aborted tool turns should preserve deterministic tool-result parity"
);
assert_eq!(tool_results[0].tool_call_id, "tc_abort");
let tool_text = ContentBlock::extract_text(&tool_results[0].content);
assert!(
tool_text.contains("aborted") || tool_text.contains("cancelled"),
"expected the preserved tool result to explain the abort, got: {tool_text}"
);
}
#[tokio::test]
async fn abort_stop_reason_emits_turn_end_aborted() {
let stream_fn = Arc::new(MockStreamFn::new(vec![abort_events("user cancelled")]));
let mut agent = make_agent(stream_fn);
let mut stream = agent.prompt_stream(vec![user_msg("go")]).unwrap();
let mut found_aborted_reason = false;
let mut found_error_reason = false;
while let Some(event) = stream.next().await {
if let AgentEvent::TurnEnd { reason, .. } = &event {
match reason {
swink_agent::TurnEndReason::Aborted => found_aborted_reason = true,
swink_agent::TurnEndReason::Error => found_error_reason = true,
_ => {}
}
}
}
assert!(
found_aborted_reason,
"abort path should emit TurnEndReason::Aborted"
);
assert!(
!found_error_reason,
"abort path should NOT emit TurnEndReason::Error for StopReason::Aborted"
);
}
#[tokio::test]
async fn reset_clears_state() {
    let mock = Arc::new(MockStreamFn::new(vec![text_only_events("before reset")]));
    let mut agent = make_agent(mock);

    // Populate history and both message queues so reset has something to clear.
    let _result = agent.prompt_async(vec![user_msg("Hi")]).await.unwrap();
    assert!(
        !agent.state().messages.is_empty(),
        "should have messages after prompt"
    );
    agent.steer(user_msg("steering"));
    agent.follow_up(user_msg("follow up"));
    assert!(agent.has_pending_messages());

    agent.reset();

    {
        let state = agent.state();
        assert!(state.messages.is_empty(), "messages should be cleared");
        assert!(!state.is_running, "should not be running");
        assert!(state.error.is_none(), "error should be cleared");
        assert!(
            state.stream_message.is_none(),
            "stream_message should be cleared"
        );
        assert!(
            state.pending_tool_calls.is_empty(),
            "pending_tool_calls should be cleared"
        );
    }
    assert!(!agent.has_pending_messages(), "queues should be cleared");
}
#[tokio::test]
async fn wait_for_idle_resolves_immediately_when_idle() {
    let mock = Arc::new(MockStreamFn::new(vec![text_only_events("done")]));
    let mut agent = make_agent(mock);

    // Fresh agent: nothing has run yet, so this must not block.
    agent.wait_for_idle().await;

    // After a completed run the agent is idle again.
    let _outcome = agent.prompt_async(vec![user_msg("Hi")]).await.unwrap();
    agent.wait_for_idle().await;
}
#[test]
fn default_state_initialization() {
    let mock = Arc::new(MockStreamFn::new(vec![]));
    let agent = make_agent(mock);
    let state = agent.state();

    // Configured value carried through from the options.
    assert_eq!(state.system_prompt, "test system prompt");

    // Everything else starts empty/idle.
    assert!(!state.is_running);
    assert!(state.messages.is_empty());
    assert!(state.stream_message.is_none());
    assert!(state.pending_tool_calls.is_empty());
    assert!(state.error.is_none());
}
#[tokio::test]
async fn state_mutators() {
    let mock = Arc::new(MockStreamFn::new(vec![]));
    let mut agent = make_agent(mock);

    // System prompt.
    agent.set_system_prompt("new prompt");
    assert_eq!(agent.state().system_prompt, "new prompt");

    // Model swap: provider and id come from the new spec.
    agent.set_model(ModelSpec::new("other", "other-model"));
    assert_eq!(agent.state().model.provider, "other");
    assert_eq!(agent.state().model.model_id, "other-model");

    // Thinking level lives on the model spec.
    let level = swink_agent::ThinkingLevel::High;
    agent.set_thinking_level(level);
    assert_eq!(
        agent.state().model.thinking_level,
        swink_agent::ThinkingLevel::High
    );

    // Message history: replace, clear, then append.
    agent.set_messages(vec![user_msg("hello")]);
    assert_eq!(agent.state().messages.len(), 1);
    agent.clear_messages();
    assert!(agent.state().messages.is_empty());
    agent.append_messages(vec![user_msg("a"), user_msg("b")]);
    assert_eq!(agent.state().messages.len(), 2);
}
#[tokio::test]
async fn error_sets_state_error() {
    // Scripted response: a start event immediately followed by a provider error.
    let failing_response = vec![
        AssistantMessageEvent::Start,
        AssistantMessageEvent::Error {
            stop_reason: StopReason::Error,
            error_message: "something went wrong".to_string(),
            error_kind: None,
            usage: None,
        },
    ];
    let mock = Arc::new(MockStreamFn::new(vec![failing_response]));
    let mut agent = make_agent(mock);

    let result = agent.prompt_async(vec![user_msg("hi")]).await.unwrap();
    assert!(result.error.is_some());

    // The same error must be mirrored into the agent state.
    let state_error = agent.state().error.as_ref();
    assert!(state_error.is_some(), "agent state should have error set");
    assert_eq!(state_error, result.error.as_ref());
}
#[tokio::test]
async fn wait_for_idle_returns_immediately_when_not_running() {
    let mock = Arc::new(MockStreamFn::new(vec![text_only_events("hi")]));
    let agent = make_agent(mock);

    // The agent has never been prompted, so waiting must not block.
    assert!(!agent.state().is_running);
    agent.wait_for_idle().await;
}
#[tokio::test]
async fn wait_for_idle_resolves_on_completion() {
    let mock = Arc::new(MockStreamFn::new(vec![text_only_events("done")]));
    let mut agent = make_agent(mock);

    let outcome = agent.prompt_async(vec![user_msg("hi")]).await.unwrap();
    assert_eq!(outcome.stop_reason, StopReason::Stop);

    // The run has finished, so waiting for idle must return promptly.
    agent.wait_for_idle().await;
    assert!(!agent.state().is_running);
}
#[tokio::test]
async fn wait_for_idle_resolves_after_abort() {
use common::tool_call_events;
let stream_fn = Arc::new(MockStreamFn::new(vec![
tool_call_events("call_1", "slow_tool", "{}"),
text_only_events("done"),
]));
let tool = Arc::new(MockTool::new("slow_tool"));
let mut agent = Agent::new(
AgentOptions::new(
"sys",
default_model(),
stream_fn as Arc<dyn StreamFn>,
default_convert,
)
.with_tools(vec![tool as Arc<dyn AgentTool>])
.with_retry_strategy(Box::new(DefaultRetryStrategy::default().with_jitter(false))),
);
let _result = agent.prompt_async(vec![user_msg("do stuff")]).await;
agent.wait_for_idle().await;
assert!(!agent.state().is_running);
}
#[tokio::test]
async fn wait_for_idle_multiple_waiters() {
    let mock = Arc::new(MockStreamFn::new(vec![text_only_events("done")]));
    let mut agent = make_agent(mock);
    let _result = agent.prompt_async(vec![user_msg("hi")]).await.unwrap();

    // Two concurrent waiters on the same (already idle) agent must both resolve.
    // NOTE(review): the run has completed before either waiter starts, so this
    // does not exercise waiters parked during an active run.
    let first_waiter = agent.wait_for_idle();
    let second_waiter = agent.wait_for_idle();
    let ((), ()) = tokio::join!(first_waiter, second_waiter);
    assert!(!agent.state().is_running);
}
#[tokio::test]
async fn reset_cancels_active_loop_and_allows_new_run() {
    let mock = Arc::new(MockStreamFn::new(vec![
        text_only_events("first"),
        text_only_events("second"),
    ]));
    let mut agent = make_agent(mock);

    // Start a run and pull exactly one event so the loop is mid-flight.
    let mut in_flight = agent.prompt_stream(vec![user_msg("go")]).unwrap();
    let _first_event = in_flight.next().await;
    assert!(
        agent.state().is_running,
        "agent should be running mid-stream"
    );

    // Reset while the stream is still alive, then release the stream.
    agent.reset();
    drop(in_flight);
    assert!(
        !agent.state().is_running,
        "agent should be idle after reset"
    );

    // A brand-new run must be accepted and complete normally.
    let rerun = agent
        .prompt_async(vec![user_msg("go again")])
        .await
        .unwrap();
    assert_eq!(rerun.stop_reason, StopReason::Stop);
    assert!(!agent.state().is_running);
}