#![cfg(feature = "testkit")]
mod common;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use serde_json::json;
use tokio_util::sync::CancellationToken;
use swink_agent::testing::SimpleMockStreamFn;
use swink_agent::{
Agent, AgentMessage, AgentOptions, AgentRegistry, AgentTool, AgentToolResult, ContentBlock,
DefaultRetryStrategy, LlmMessage, ModelSpec, StopReason, ToolExecutionPolicy, TransferChain,
TransferToAgentTool,
};
use common::{
EventCollector, MockStreamFn, MockTool, default_convert, default_model, text_only_events,
tool_call_events, tool_call_events_multi, user_msg,
};
/// Builds a placeholder [`Agent`] that tests register as a transfer target.
///
/// The agent uses the system prompt `"dummy"`, a `test`/`test-model` model spec,
/// and a `SimpleMockStreamFn` created from the text `"hi"`.
fn dummy_agent() -> Agent {
    let model = ModelSpec::new("test", "test-model");
    let stream = Arc::new(SimpleMockStreamFn::from_text("hi"));
    let options = AgentOptions::new("dummy", model, stream, default_convert);
    Agent::new(options)
}
/// Returns a boxed [`DefaultRetryStrategy`] with jitter disabled and a 1 ms base
/// delay.
fn fast_retry() -> Box<DefaultRetryStrategy> {
    let base_delay = Duration::from_millis(1);
    let strategy = DefaultRetryStrategy::default()
        .with_jitter(false)
        .with_base_delay(base_delay);
    Box::new(strategy)
}
/// Serializes the JSON argument object used by the tests for a
/// `transfer_to_agent` tool call, containing `agent_name` and `reason`.
fn transfer_args(agent_name: &str, reason: &str) -> String {
    let payload = json!({
        "agent_name": agent_name,
        "reason": reason
    });
    payload.to_string()
}
/// Like `transfer_args`, but the serialized JSON object additionally carries a
/// `context_summary` field.
fn transfer_args_with_summary(agent_name: &str, reason: &str, summary: &str) -> String {
    let payload = json!({
        "agent_name": agent_name,
        "reason": reason,
        "context_summary": summary
    });
    payload.to_string()
}
/// Creates a shared [`AgentRegistry`] with a single dummy agent registered under
/// the name `"billing"`.
fn registry_with_billing() -> Arc<AgentRegistry> {
    let agents = Arc::new(AgentRegistry::new());
    let billing = dummy_agent();
    agents.register("billing", billing);
    agents
}
/// Builds an unnamed [`Agent`] whose only tool is a [`TransferToAgentTool`]
/// bound to `registry`.
///
/// The agent streams from `stream_fn` and uses the `fast_retry` strategy.
fn make_transfer_agent(stream_fn: Arc<MockStreamFn>, registry: Arc<AgentRegistry>) -> Agent {
    let transfer_tool: Arc<dyn AgentTool> = Arc::new(TransferToAgentTool::new(registry));
    let options = AgentOptions::new(
        "test system prompt",
        default_model(),
        stream_fn,
        default_convert,
    )
    .with_tools(vec![transfer_tool])
    .with_retry_strategy(fast_retry());
    Agent::new(options)
}
/// Test double that records whether it has been invoked.
struct MockCancellationIgnoringTool {
    /// Shared flag; `false` until something stores `true` into it.
    executed: Arc<AtomicBool>,
}

impl MockCancellationIgnoringTool {
    /// Creates a tool whose `executed` flag starts out `false`.
    fn new() -> Self {
        let executed = Arc::new(AtomicBool::new(false));
        Self { executed }
    }

    /// Reports the current value of the `executed` flag.
    fn was_executed(&self) -> bool {
        let flag: &AtomicBool = &self.executed;
        flag.load(Ordering::SeqCst)
    }
}
impl AgentTool for MockCancellationIgnoringTool {
fn name(&self) -> &str {
"blocking_tool"
}
fn label(&self) -> &str {
"blocking_tool"
}
fn description(&self) -> &'static str {
"A tool that ignores cancellation and never completes unless aborted"
}
fn parameters_schema(&self) -> &serde_json::Value {
static SCHEMA: std::sync::OnceLock<serde_json::Value> = std::sync::OnceLock::new();
SCHEMA.get_or_init(|| {
json!({
"type": "object",
"properties": {},
"additionalProperties": true
})
})
}
fn execute(
&self,
_tool_call_id: &str,
_params: serde_json::Value,
_cancellation_token: CancellationToken,
_on_update: Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
_state: std::sync::Arc<std::sync::RwLock<swink_agent::SessionState>>,
_credential: Option<swink_agent::ResolvedCredential>,
) -> Pin<Box<dyn Future<Output = AgentToolResult> + Send + '_>> {
self.executed.store(true, Ordering::SeqCst);
Box::pin(async move { std::future::pending::<AgentToolResult>().await })
}
}
/// A `transfer_to_agent` tool call targeting a registered agent must end the run
/// with `StopReason::Transfer` and surface the transfer signal on the result.
#[tokio::test]
async fn agent_loop_detects_transfer_and_terminates_with_transfer_stop_reason() {
    let registry = registry_with_billing();
    let args = transfer_args("billing", "billing question");
    let scripted = vec![
        tool_call_events("tc_transfer", "transfer_to_agent", &args),
        text_only_events("should not reach this"),
    ];
    let mut agent = make_transfer_agent(Arc::new(MockStreamFn::new(scripted)), registry);

    let prompt = vec![user_msg("transfer me to billing")];
    let result = agent.prompt_async(prompt).await.unwrap();

    assert_eq!(
        result.stop_reason,
        StopReason::Transfer,
        "stop_reason should be Transfer"
    );
    assert!(
        result.transfer_signal.is_some(),
        "transfer_signal should be present on AgentResult"
    );

    // The signal must echo the target and reason passed in the tool arguments.
    let signal = result.transfer_signal.as_ref().unwrap();
    assert_eq!(signal.target_agent(), "billing");
    assert_eq!(signal.reason(), "billing question");
}
/// A successful transfer must emit a `TransferInitiated` event, and that event
/// must be recorded before `AgentEnd`.
#[tokio::test]
async fn transfer_initiated_event_emitted_on_transfer() {
    let registry = registry_with_billing();
    let args = transfer_args("billing", "billing question");
    let scripted = vec![
        tool_call_events("tc_transfer", "transfer_to_agent", &args),
        text_only_events("fallback"),
    ];
    let mut agent = make_transfer_agent(Arc::new(MockStreamFn::new(scripted)), registry);

    // Subscribe before prompting so every event of the run is captured.
    let collector = EventCollector::new();
    agent.subscribe(collector.subscriber());

    let prompt = vec![user_msg("transfer me")];
    let result = agent.prompt_async(prompt).await.unwrap();
    assert_eq!(result.stop_reason, StopReason::Transfer);

    let events = collector.events();
    assert!(
        events.iter().any(|name| name == "TransferInitiated"),
        "should emit TransferInitiated event, got: {events:?}"
    );

    // Index of the first occurrence of a given event name, if any.
    let first_index_of = |wanted: &str| events.iter().position(|e| e == wanted);
    let transfer_pos = first_index_of("TransferInitiated").unwrap();
    let agent_end_pos = first_index_of("AgentEnd").unwrap();
    assert!(
        transfer_pos < agent_end_pos,
        "TransferInitiated ({transfer_pos}) should precede AgentEnd ({agent_end_pos})"
    );
}
/// The transfer signal must carry the supplied context summary and a non-empty
/// conversation history that includes at least one user message.
#[tokio::test]
async fn transfer_signal_contains_conversation_history() {
    let registry = registry_with_billing();
    let args = transfer_args_with_summary("billing", "billing dispute", "User disputes $50 charge");
    let scripted = vec![
        tool_call_events("tc_transfer", "transfer_to_agent", &args),
        text_only_events("fallback"),
    ];
    let mut agent = make_transfer_agent(Arc::new(MockStreamFn::new(scripted)), registry);

    let prompt = vec![user_msg("I need help with my bill")];
    let result = agent.prompt_async(prompt).await.unwrap();
    assert_eq!(result.stop_reason, StopReason::Transfer);

    let signal = result.transfer_signal.as_ref().unwrap();
    assert_eq!(signal.context_summary(), Some("User disputes $50 charge"));

    let history = signal.conversation_history();
    assert!(
        !history.is_empty(),
        "conversation_history should not be empty"
    );
    assert!(
        history
            .iter()
            .any(|message| matches!(message, LlmMessage::User(_))),
        "conversation_history should contain a User message"
    );
}
/// Two `transfer_to_agent` calls in the same turn must still yield exactly one
/// transfer signal on the result.
///
/// NOTE(review): despite the test name, the assertion accepts either target
/// (`billing` or `tech`); confirm whether "first" is meant to be deterministic.
#[tokio::test]
async fn only_first_transfer_signal_honored_when_multiple_transfers_in_same_turn() {
    let registry = Arc::new(AgentRegistry::new());
    for agent_name in ["billing", "tech"] {
        registry.register(agent_name, dummy_agent());
    }

    let billing_args = transfer_args("billing", "billing question");
    let tech_args = transfer_args("tech", "tech question");
    let scripted = vec![
        tool_call_events_multi(&[
            ("tc_1", "transfer_to_agent", billing_args.as_str()),
            ("tc_2", "transfer_to_agent", tech_args.as_str()),
        ]),
        text_only_events("fallback"),
    ];
    let stream_fn = Arc::new(MockStreamFn::new(scripted));

    let transfer_tool: Arc<dyn AgentTool> = Arc::new(TransferToAgentTool::new(registry));
    let options = AgentOptions::new("test", default_model(), stream_fn, default_convert)
        .with_tools(vec![transfer_tool])
        .with_retry_strategy(fast_retry());
    let mut agent = Agent::new(options);

    let prompt = vec![user_msg("transfer me")];
    let result = agent.prompt_async(prompt).await.unwrap();
    assert_eq!(result.stop_reason, StopReason::Transfer);

    let signal = result.transfer_signal.as_ref().unwrap();
    assert!(
        signal.target_agent() == "billing" || signal.target_agent() == "tech",
        "transfer should target one of the two agents, got: {}",
        signal.target_agent()
    );
}
/// Runs a turn containing both a slow tool call and a transfer call on an agent
/// that has had `abort()` invoked, and checks the resulting stop reason.
///
/// NOTE(review): `abort()` is called before the prompt starts and the assertion
/// accepts either `Aborted` or `Transfer`, so this does not strictly pin the
/// precedence the test name describes; confirm the intended contract.
#[tokio::test]
async fn cancellation_takes_precedence_over_transfer() {
    let registry = registry_with_billing();
    let slow_tool = Arc::new(MockTool::new("slow_tool").with_delay(Duration::from_secs(10)));
    // Fixed: the borrow was corrupted to `®istry`, which is not valid Rust.
    let transfer_tool: Arc<dyn AgentTool> =
        Arc::new(TransferToAgentTool::new(Arc::clone(&registry)));
    let stream_fn = Arc::new(MockStreamFn::new(vec![
        tool_call_events_multi(&[
            ("tc_slow", "slow_tool", "{}"),
            (
                "tc_transfer",
                "transfer_to_agent",
                &transfer_args("billing", "billing question"),
            ),
        ]),
        text_only_events("fallback"),
    ]));
    let mut agent = Agent::new(
        AgentOptions::new("test", default_model(), stream_fn, default_convert)
            .with_tools(vec![slow_tool, transfer_tool])
            .with_retry_strategy(fast_retry()),
    );
    agent.abort();
    let result = agent
        .prompt_async(vec![user_msg("do stuff")])
        .await
        .unwrap();
    assert!(
        result.stop_reason == StopReason::Aborted || result.stop_reason == StopReason::Transfer,
        "expected Aborted or Transfer, got {:?}",
        result.stop_reason
    );
}
/// When a transfer call shares a turn with a tool that never completes, the run
/// must still end with `StopReason::Transfer`, the blocking tool must have been
/// started, and its tool result must mention "transfer initiated".
#[tokio::test]
async fn transfer_cancels_same_group_siblings() {
    let registry = registry_with_billing();
    let blocking_tool = Arc::new(MockCancellationIgnoringTool::new());
    // Fixed: the borrow was corrupted to `®istry`, which is not valid Rust.
    let transfer_tool: Arc<dyn AgentTool> =
        Arc::new(TransferToAgentTool::new(Arc::clone(&registry)));
    let stream_fn = Arc::new(MockStreamFn::new(vec![
        tool_call_events_multi(&[
            ("tc_blocking", "blocking_tool", "{}"),
            (
                "tc_transfer",
                "transfer_to_agent",
                &transfer_args("billing", "billing question"),
            ),
        ]),
        text_only_events("fallback"),
    ]));
    let mut agent = Agent::new(
        AgentOptions::new("test", default_model(), stream_fn, default_convert)
            .with_tools(vec![
                // Keep our own handle so `was_executed` can be queried afterwards.
                blocking_tool.clone() as Arc<dyn AgentTool>,
                transfer_tool,
            ])
            .with_retry_strategy(fast_retry()),
    );
    let result = agent.prompt_async(vec![user_msg("do both")]).await.unwrap();
    assert_eq!(
        result.stop_reason,
        StopReason::Transfer,
        "stop_reason should be Transfer even when other tools also ran"
    );
    assert!(result.transfer_signal.is_some());
    assert!(
        blocking_tool.was_executed(),
        "same-group sibling should have started before transfer cancellation"
    );
    // Text of the first tool result recorded for the blocking call.
    let blocking_result = result
        .messages
        .iter()
        .find_map(|msg| match msg {
            AgentMessage::Llm(LlmMessage::ToolResult(tool_result))
                if tool_result.tool_call_id == "tc_blocking" =>
            {
                Some(ContentBlock::extract_text(&tool_result.content))
            }
            _ => None,
        })
        .expect("blocking tool result should be present");
    assert!(
        blocking_result.contains("transfer initiated"),
        "blocking tool should be cancelled once transfer wins, got: {blocking_result}"
    );
}
/// Under a priority execution policy where the transfer tool is assigned 10 and
/// every other tool 0, a transfer must end the turn and the other tool must
/// never be executed.
#[tokio::test]
async fn transfer_skips_later_priority_groups() {
    let registry = registry_with_billing();
    let later_group_tool = Arc::new(
        MockTool::new("low_priority_tool").with_result(AgentToolResult::text("should not run")),
    );
    // Fixed: the borrow was corrupted to `®istry`, which is not valid Rust.
    let transfer_tool: Arc<dyn AgentTool> =
        Arc::new(TransferToAgentTool::new(Arc::clone(&registry)));
    let stream_fn = Arc::new(MockStreamFn::new(vec![
        tool_call_events_multi(&[
            (
                "tc_transfer",
                "transfer_to_agent",
                &transfer_args("billing", "billing question"),
            ),
            ("tc_low", "low_priority_tool", "{}"),
        ]),
        text_only_events("fallback"),
    ]));
    // Assigns the transfer call a different priority value than all other calls.
    let priority_fn = Arc::new(|summary: &swink_agent::ToolCallSummary<'_>| {
        if summary.name == "transfer_to_agent" {
            10
        } else {
            0
        }
    });
    let mut agent = Agent::new(
        AgentOptions::new("test", default_model(), stream_fn, default_convert)
            .with_tools(vec![
                transfer_tool,
                later_group_tool.clone() as Arc<dyn AgentTool>,
            ])
            .with_tool_execution_policy(ToolExecutionPolicy::Priority(priority_fn))
            .with_retry_strategy(fast_retry()),
    );
    let result = agent.prompt_async(vec![user_msg("do both")]).await.unwrap();
    // `assert_eq!` reports the actual stop reason on failure, matching the
    // style of the other tests in this file.
    assert_eq!(
        result.stop_reason,
        StopReason::Transfer,
        "priority-group transfer should terminate the turn"
    );
    assert!(
        !later_group_tool.was_executed(),
        "later priority groups must not run after a transfer signal"
    );
}
/// A transfer to an agent that is not in the registry must not produce a
/// `Transfer` stop reason or a transfer signal; instead an error tool result
/// containing "not found in registry" must be recorded in the messages.
#[tokio::test]
async fn transfer_to_nonexistent_agent_produces_error_and_loop_continues() {
    // The registry is intentionally left empty.
    let registry = Arc::new(AgentRegistry::new());
    // Fixed: the borrow was corrupted to `®istry`, which is not valid Rust.
    let transfer_tool: Arc<dyn AgentTool> =
        Arc::new(TransferToAgentTool::new(Arc::clone(&registry)));
    let stream_fn = Arc::new(MockStreamFn::new(vec![
        tool_call_events(
            "tc_transfer",
            "transfer_to_agent",
            &transfer_args("nonexistent", "test"),
        ),
        text_only_events("I could not transfer, let me help directly"),
    ]));
    let mut agent = Agent::new(
        AgentOptions::new("test", default_model(), stream_fn, default_convert)
            .with_tools(vec![transfer_tool])
            .with_retry_strategy(fast_retry()),
    );
    let result = agent
        .prompt_async(vec![user_msg("transfer me")])
        .await
        .unwrap();
    assert_ne!(
        result.stop_reason,
        StopReason::Transfer,
        "should not be Transfer when target agent does not exist"
    );
    assert!(
        result.transfer_signal.is_none(),
        "transfer_signal should be None when transfer failed"
    );
    // Look for a tool result flagged as an error whose text names the failure.
    let has_error_result = result.messages.iter().any(|msg| {
        if let AgentMessage::Llm(LlmMessage::ToolResult(tr)) = msg {
            tr.is_error
                && tr.content.iter().any(|b| {
                    matches!(b, ContentBlock::Text { text } if text.contains("not found in registry"))
                })
        } else {
            false
        }
    });
    assert!(
        has_error_result,
        "should have an error tool result for nonexistent agent"
    );
}
fn make_named_transfer_agent(
name: &str,
stream_fn: Arc<MockStreamFn>,
registry: Arc<AgentRegistry>,
) -> Agent {
make_named_transfer_agent_with_chain(name, stream_fn, registry, None)
}
/// Builds an agent named `name` whose only tool is a [`TransferToAgentTool`]
/// bound to `registry`.
///
/// When `transfer_chain` is `Some`, the chain is installed on the options via
/// `with_transfer_chain`; when it is `None`, the options are left as built.
fn make_named_transfer_agent_with_chain(
    name: &str,
    stream_fn: Arc<MockStreamFn>,
    registry: Arc<AgentRegistry>,
    transfer_chain: Option<TransferChain>,
) -> Agent {
    let transfer_tool: Arc<dyn AgentTool> = Arc::new(TransferToAgentTool::new(registry));
    let base = AgentOptions::new(
        "test system prompt",
        default_model(),
        stream_fn,
        default_convert,
    )
    .with_tools(vec![transfer_tool])
    .with_retry_strategy(fast_retry())
    .with_agent_name(name);

    let options = match transfer_chain {
        Some(chain) => base.with_transfer_chain(chain),
        None => base,
    };
    Agent::new(options)
}
/// An agent named `support` that tries to transfer to `support` must not end
/// with `StopReason::Transfer` and must not return a transfer signal.
#[tokio::test]
async fn transfer_chain_blocks_self_transfer() {
    let registry = Arc::new(AgentRegistry::new());
    registry.register("support", dummy_agent());

    let args = transfer_args("support", "self-transfer");
    let scripted = vec![
        tool_call_events("tc_transfer", "transfer_to_agent", &args),
        text_only_events("I'll help you directly instead"),
    ];
    let stream_fn = Arc::new(MockStreamFn::new(scripted));
    let mut agent = make_named_transfer_agent("support", stream_fn, registry);

    let prompt = vec![user_msg("transfer me to support")];
    let result = agent.prompt_async(prompt).await.unwrap();

    assert_ne!(
        result.stop_reason,
        StopReason::Transfer,
        "self-transfer should be blocked by TransferChain"
    );
    assert!(
        result.transfer_signal.is_none(),
        "transfer_signal should be None when self-transfer is blocked"
    );
}
/// Exercises a two-hop handoff: `support` transfers to `billing`, then a
/// `billing` agent built with the chain carried on the first signal tries to
/// transfer back to `support`. The first hop must succeed and the second must
/// be blocked.
#[tokio::test]
async fn transfer_chain_blocks_circular_a_to_b_to_a() {
    let registry = Arc::new(AgentRegistry::new());
    registry.register("support", dummy_agent());
    registry.register("billing", dummy_agent());

    // Hop 1: support -> billing.
    let support_stream = Arc::new(MockStreamFn::new(vec![
        tool_call_events(
            "tc_transfer",
            "transfer_to_agent",
            &transfer_args("billing", "billing question"),
        ),
        text_only_events("fallback"),
    ]));
    // Fixed: the borrow was corrupted to `®istry`, which is not valid Rust.
    let mut support_agent =
        make_named_transfer_agent("support", support_stream, Arc::clone(&registry));
    let first_result = support_agent
        .prompt_async(vec![user_msg("transfer me to billing")])
        .await
        .unwrap();
    assert_eq!(
        first_result.stop_reason,
        StopReason::Transfer,
        "first transfer in chain should succeed"
    );
    let first_signal = first_result.transfer_signal.as_ref().unwrap();
    assert_eq!(first_signal.target_agent(), "billing");

    // Hop 2: billing -> support, using the chain taken from the first signal.
    let billing_stream = Arc::new(MockStreamFn::new(vec![
        tool_call_events(
            "tc_transfer_back",
            "transfer_to_agent",
            &transfer_args("support", "route back"),
        ),
        text_only_events("I'll handle this directly"),
    ]));
    let carried_chain = first_signal
        .transfer_chain()
        .expect("handoff signal should carry transfer chain")
        .clone();
    // Fixed: same corrupted `®istry` borrow as above.
    let mut billing_agent = make_named_transfer_agent_with_chain(
        "billing",
        billing_stream,
        Arc::clone(&registry),
        Some(carried_chain),
    );
    let second_result = billing_agent
        .prompt_async(vec![user_msg("continue on billing and transfer back")])
        .await
        .unwrap();
    assert_ne!(
        second_result.stop_reason,
        StopReason::Transfer,
        "A->B->A must be blocked across handoffs"
    );
    assert!(
        second_result.transfer_signal.is_none(),
        "blocked cross-handoff transfer must not return a transfer signal"
    );
}
/// An agent that starts with a `TransferChain::new(2)` already holding two
/// entries must have any further transfer blocked.
#[tokio::test]
async fn transfer_chain_max_depth_is_enforced_in_loop() {
    let registry = Arc::new(AgentRegistry::new());
    registry.register("target", dummy_agent());

    // Pre-fill the chain with two hops before the agent under test runs.
    let mut carried_chain = TransferChain::new(2);
    for hop in ["agent-0", "agent-1"] {
        carried_chain.push(hop).unwrap();
    }

    let args = transfer_args("target", "handoff");
    let scripted = vec![
        tool_call_events("tc_transfer", "transfer_to_agent", &args),
        text_only_events("cannot transfer further"),
    ];
    let stream_fn = Arc::new(MockStreamFn::new(scripted));
    let mut agent =
        make_named_transfer_agent_with_chain("agent-1", stream_fn, registry, Some(carried_chain));

    let prompt = vec![user_msg("transfer me")];
    let result = agent.prompt_async(prompt).await.unwrap();

    assert_ne!(
        result.stop_reason,
        StopReason::Transfer,
        "transfer should be blocked when carried chain is already at max depth"
    );
    assert!(
        result.transfer_signal.is_none(),
        "blocked max-depth transfer must not return a transfer signal"
    );
}
/// An agent built without `with_agent_name` must still be able to transfer to a
/// registered target.
#[tokio::test]
async fn transfer_works_without_agent_name() {
    let registry = registry_with_billing();
    let args = transfer_args("billing", "billing question");
    let scripted = vec![
        tool_call_events("tc_transfer", "transfer_to_agent", &args),
        text_only_events("fallback"),
    ];
    // `make_transfer_agent` never sets an agent name on the options.
    let mut agent = make_transfer_agent(Arc::new(MockStreamFn::new(scripted)), registry);

    let prompt = vec![user_msg("transfer me")];
    let result = agent.prompt_async(prompt).await.unwrap();

    assert_eq!(
        result.stop_reason,
        StopReason::Transfer,
        "transfer without agent_name should succeed"
    );
    assert!(result.transfer_signal.is_some());
}