#![cfg(feature = "testkit")]
mod common;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use common::{
MockContextCapturingStreamFn, MockStreamFn, MockTool, default_model, text_only_events,
tool_call_events, user_msg,
};
use futures::Stream;
use futures::stream::StreamExt;
use tokio_util::sync::CancellationToken;
use swink_agent::{
AgentEvent, AgentLoopConfig, AgentMessage, AssistantMessageEvent, ContentBlock,
DefaultRetryStrategy, LlmMessage, StreamFn, StreamOptions, UserMessage, agent_loop,
};
/// Scripted stream response consisting of a single context-overflow error event,
/// built with `AssistantMessageEvent::error_context_overflow`.
fn overflow_error_events() -> Vec<AssistantMessageEvent> {
    let overflow =
        AssistantMessageEvent::error_context_overflow("context_length_exceeded: too many tokens");
    vec![overflow]
}
/// Boxed message converter; this is the type stored in `AgentLoopConfig::convert_to_llm` below.
type ConvertToLlmBoxed = Box<dyn Fn(&AgentMessage) -> Option<LlmMessage> + Send + Sync>;

/// Converter that passes `AgentMessage::Llm` messages through (cloned) and drops
/// `AgentMessage::Custom` messages by returning `None`.
fn default_convert_to_llm() -> ConvertToLlmBoxed {
    let forward_llm_only = |message: &AgentMessage| -> Option<LlmMessage> {
        match message {
            AgentMessage::Llm(inner) => Some(inner.clone()),
            AgentMessage::Custom(..) => None,
        }
    };
    Box::new(forward_llm_only)
}
fn default_config(stream_fn: Arc<dyn StreamFn>) -> AgentLoopConfig {
AgentLoopConfig {
agent_name: None,
transfer_chain: None,
model: default_model(),
stream_options: StreamOptions::default(),
retry_strategy: Box::new(
DefaultRetryStrategy::default()
.with_jitter(false)
.with_base_delay(Duration::from_millis(1)),
),
stream_fn,
tools: vec![],
convert_to_llm: default_convert_to_llm(),
transform_context: None,
get_api_key: None,
message_provider: None,
pending_message_snapshot: Arc::default(),
loop_context_snapshot: Arc::default(),
approve_tool: None,
approval_mode: swink_agent::ApprovalMode::default(),
pre_turn_policies: vec![],
pre_dispatch_policies: vec![],
post_turn_policies: vec![],
post_loop_policies: vec![],
async_transform_context: None,
metrics_collector: None,
fallback: None,
tool_execution_policy: swink_agent::ToolExecutionPolicy::default(),
session_state: Arc::new(std::sync::RwLock::new(swink_agent::SessionState::new())),
credential_resolver: None,
cache_config: None,
cache_state: std::sync::Mutex::new(swink_agent::CacheState::default()),
dynamic_system_prompt: None,
}
}
/// Builds a user message whose text is `"{label}:"` followed by `token_count * 4`
/// `'x'` filler characters (presumably a ~4 chars/token sizing heuristic).
fn large_user_msg(label: &str, token_count: usize) -> AgentMessage {
    let filler_len = token_count * 4;
    let mut text = String::with_capacity(label.len() + 1 + filler_len);
    text.push_str(label);
    text.push(':');
    text.extend(std::iter::repeat('x').take(filler_len));

    let user = UserMessage {
        content: vec![ContentBlock::Text { text }],
        timestamp: 0,
        cache_hint: None,
    };
    AgentMessage::Llm(LlmMessage::User(user))
}
/// Drains the agent event stream until it ends, returning every event in arrival order.
async fn collect_events(stream: Pin<Box<dyn Stream<Item = AgentEvent> + Send>>) -> Vec<AgentEvent> {
    let mut stream = stream;
    let mut events = Vec::new();
    while let Some(event) = stream.next().await {
        events.push(event);
    }
    events
}
/// Returns `true` as soon as any event's variant name (per `common::event_variant_name`)
/// equals `name`.
fn has_event(events: &[AgentEvent], name: &str) -> bool {
    for event in events {
        if common::event_variant_name(event) == name {
            return true;
        }
    }
    false
}
/// Counts the events whose variant name (per `common::event_variant_name`) equals `name`.
fn count_events(events: &[AgentEvent], name: &str) -> usize {
    events.iter().fold(0usize, |total, event| {
        if common::event_variant_name(event) == name {
            total + 1
        } else {
            total
        }
    })
}
/// Test double for `swink_agent::AsyncContextTransformer`.
///
/// Records the `overflow` flag of every `transform` call. When `compact` is set, an
/// overflow call truncates the history to its first message and returns a report.
struct MockAsyncTransformer {
    /// Every `overflow` argument passed to `transform`, in call order.
    overflow_flags: Arc<Mutex<Vec<bool>>>,
    /// Whether an overflow call should actually drop messages.
    compact: bool,
}

impl swink_agent::AsyncContextTransformer for MockAsyncTransformer {
    fn transform<'a>(
        &'a self,
        messages: &'a mut Vec<AgentMessage>,
        overflow: bool,
    ) -> Pin<Box<dyn Future<Output = Option<swink_agent::CompactionReport>> + Send + 'a>> {
        // Recorded at call time, before the returned future is ever polled.
        self.overflow_flags.lock().unwrap().push(overflow);
        let wants_compaction = overflow && self.compact;

        Box::pin(async move {
            if !wants_compaction || messages.len() <= 1 {
                return None;
            }
            // Keep only the first message; everything after it counts as dropped.
            let dropped_count = messages.len() - 1;
            messages.truncate(1);
            Some(swink_agent::CompactionReport {
                dropped_count,
                // Synthetic numbers: 100 per dropped message before, 100 after.
                tokens_before: dropped_count * 100,
                tokens_after: 100,
                overflow: true,
                dropped_messages: Vec::new(),
            })
        })
    }
}
/// The first request overflows. The async transformer compacts, and the retry succeeds.
///
/// Checks that both transformers see `overflow = false` on their first call and
/// `overflow = true` on their second. Also checks that `ContextCompacted` is emitted
/// and that the retried request carries fewer messages than the first.
#[tokio::test]
async fn emergency_overflow_recovery() {
    let capturing_fn = Arc::new(MockContextCapturingStreamFn::new(vec![
        overflow_error_events(),
        text_only_events("recovered after compaction"),
    ]));
    let async_flags = Arc::new(Mutex::new(Vec::<bool>::new()));
    let sync_flags = Arc::new(Mutex::new(Vec::<bool>::new()));

    let mut config = default_config(Arc::clone(&capturing_fn) as Arc<dyn StreamFn>);
    config.async_transform_context = Some(Arc::new(MockAsyncTransformer {
        overflow_flags: Arc::clone(&async_flags),
        compact: true,
    }));
    let sync_recorder = Arc::clone(&sync_flags);
    config.transform_context = Some(Arc::new(
        move |_msgs: &mut Vec<AgentMessage>, overflow: bool| {
            sync_recorder.lock().unwrap().push(overflow);
        },
    ));

    // Five large user messages (100 * 4 filler chars each) so there is something to drop.
    let initial_messages: Vec<_> = (0..5)
        .map(|i| large_user_msg(&format!("msg{i}"), 100))
        .collect();

    let events = collect_events(agent_loop(
        initial_messages,
        "system".to_string(),
        config,
        CancellationToken::new(),
    ))
    .await;
    assert!(has_event(&events, "AgentEnd"), "loop should complete");

    // Both transformers share the same contract: pre-turn call first, recovery call second.
    let check_flags = |flags: &[bool], kind: &str| {
        assert!(flags.len() >= 2, "{kind} transformer called at least twice");
        assert!(!flags[0], "first call: overflow=false (pre-turn)");
        assert!(flags[1], "second call: overflow=true (recovery)");
    };
    let af = async_flags.lock().unwrap().clone();
    check_flags(&af, "async");
    let sf = sync_flags.lock().unwrap().clone();
    check_flags(&sf, "sync");

    assert!(
        count_events(&events, "ContextCompacted") >= 1,
        "should emit at least one ContextCompacted event"
    );

    let counts = capturing_fn.captured_message_counts.lock().unwrap().clone();
    assert_eq!(counts.len(), 2, "exactly 2 stream calls");
    let (first, second) = (&counts[0], &counts[1]);
    assert!(
        second < first,
        "retry should see fewer messages: first={}, second={}",
        first,
        second
    );
}
#[tokio::test]
async fn double_overflow_surfaces_error() {
let stream_fn = Arc::new(MockStreamFn::new(vec![
overflow_error_events(),
overflow_error_events(),
]));
let mut config = default_config(stream_fn as Arc<dyn StreamFn>);
config.transform_context = Some(Arc::new(|msgs: &mut Vec<AgentMessage>, overflow: bool| {
if overflow && msgs.len() > 1 {
msgs.truncate(1);
}
}));
let mut initial_messages = Vec::new();
for i in 0..5 {
initial_messages.push(user_msg(&format!("msg{i}")));
}
let events = collect_events(agent_loop(
initial_messages,
"system".to_string(),
config,
CancellationToken::new(),
))
.await;
assert!(has_event(&events, "AgentEnd"));
assert!(
events.iter().any(|e| {
matches!(
e,
AgentEvent::TurnEnd { reason, .. }
if *reason == swink_agent::TurnEndReason::Error
)
}),
"should have a TurnEnd with Error reason"
);
let message_end_count = events
.iter()
.filter(|e| common::event_variant_name(e) == "MessageEnd")
.count();
assert_eq!(
message_end_count, 1,
"unrecoverable overflow should emit exactly one MessageEnd"
);
}
#[tokio::test]
async fn no_transformer_overflow_surfaces_error() {
let stream_fn = Arc::new(MockStreamFn::new(vec![
overflow_error_events(),
text_only_events("should not reach this"),
]));
let config = default_config(stream_fn as Arc<dyn StreamFn>);
let events = collect_events(agent_loop(
vec![user_msg("hello")],
"system".to_string(),
config,
CancellationToken::new(),
))
.await;
assert!(has_event(&events, "AgentEnd"));
assert_eq!(
count_events(&events, "ContextCompacted"),
0,
"no ContextCompacted when no transformer"
);
assert!(
events.iter().any(|e| {
matches!(
e,
AgentEvent::TurnEnd { reason, .. }
if *reason == swink_agent::TurnEndReason::Error
)
}),
"should have TurnEnd with Error reason"
);
}
#[tokio::test]
async fn overflow_recovery_resets_per_turn() {
let stream_fn = Arc::new(MockStreamFn::new(vec![
tool_call_events("tc_1", "mock_tool", "{}"),
overflow_error_events(),
text_only_events("recovered in turn 2"),
]));
let tool = Arc::new(MockTool::new("mock_tool"));
let mut config = default_config(stream_fn as Arc<dyn StreamFn>);
config.tools = vec![tool];
config.transform_context = Some(Arc::new(|msgs: &mut Vec<AgentMessage>, overflow: bool| {
if overflow && msgs.len() > 1 {
msgs.truncate(1);
}
}));
let mut initial_messages = Vec::new();
for i in 0..5 {
initial_messages.push(user_msg(&format!("msg{i}")));
}
let events = collect_events(agent_loop(
initial_messages,
"system".to_string(),
config,
CancellationToken::new(),
))
.await;
assert!(has_event(&events, "AgentEnd"), "loop should complete");
assert!(
count_events(&events, "TurnStart") >= 2,
"should have at least 2 turns"
);
assert!(
count_events(&events, "ContextCompacted") >= 1,
"turn 2 should recover from overflow"
);
assert!(
!events.iter().any(|e| {
matches!(
e,
AgentEvent::TurnEnd { reason, .. }
if *reason == swink_agent::TurnEndReason::Error
)
}),
"should not have error TurnEnd — recovery should succeed"
);
}
/// Recovery is attempted but no compaction happens, so the loop must not retry.
///
/// Setup: on overflow the async transformer is consulted but declines to compact
/// (`compact: false`), and the sync transformer is a no-op.
///
/// Expected: exactly one stream call, and the turn ends with `TurnEndReason::Error`.
/// The test also asserts that the async transformer actually saw `overflow = true`.
/// Without that check, the test would pass even if the loop never attempted recovery.
#[tokio::test]
async fn no_compaction_skip_surfaces_error() {
    let capturing_fn = Arc::new(MockContextCapturingStreamFn::new(vec![
        overflow_error_events(),
        text_only_events("should not reach this"),
    ]));
    let stream_fn: Arc<dyn StreamFn> = Arc::clone(&capturing_fn) as Arc<dyn StreamFn>;
    // Keep a handle on the recorded flags so we can prove recovery was attempted.
    let async_flags: Arc<Mutex<Vec<bool>>> = Arc::new(Mutex::new(Vec::new()));

    let mut config = default_config(stream_fn);
    config.async_transform_context = Some(Arc::new(MockAsyncTransformer {
        overflow_flags: Arc::clone(&async_flags),
        compact: false,
    }));
    // No-op sync transformer: never removes anything.
    config.transform_context = Some(Arc::new(
        |_msgs: &mut Vec<AgentMessage>, _overflow: bool| {},
    ));

    let events = collect_events(agent_loop(
        vec![user_msg("hello")],
        "system".to_string(),
        config,
        CancellationToken::new(),
    ))
    .await;

    assert!(has_event(&events, "AgentEnd"));
    let counts = capturing_fn.captured_message_counts.lock().unwrap().clone();
    assert_eq!(
        counts.len(),
        1,
        "should have exactly 1 stream call (no retry when no compaction), got {}",
        counts.len()
    );
    assert!(
        events.iter().any(|e| {
            matches!(
                e,
                AgentEvent::TurnEnd { reason, .. }
                    if *reason == swink_agent::TurnEndReason::Error
            )
        }),
        "should have TurnEnd with Error reason"
    );
    let flags = async_flags.lock().unwrap().clone();
    assert!(
        flags.contains(&true),
        "async transformer should be consulted with overflow=true during recovery, got {flags:?}"
    );
}
#[tokio::test]
async fn cancellation_during_recovery_aborts() {
let cancel_token = CancellationToken::new();
let stream_fn = Arc::new(MockStreamFn::new(vec![
overflow_error_events(),
text_only_events("should not reach — cancelled"),
]));
let mut config = default_config(stream_fn);
let cancel_clone = cancel_token.clone();
config.transform_context = Some(Arc::new(
move |msgs: &mut Vec<AgentMessage>, overflow: bool| {
if overflow && msgs.len() > 1 {
cancel_clone.cancel();
msgs.truncate(1);
}
},
));
let events = collect_events(agent_loop(
vec![user_msg("msg1"), user_msg("msg2"), user_msg("msg3")],
"system".to_string(),
config,
cancel_token,
))
.await;
assert!(has_event(&events, "AgentEnd"), "loop should complete");
assert_eq!(
count_events(&events, "TurnStart"),
1,
"cancellation during overflow recovery must not emit a duplicate TurnStart"
);
assert!(
events.iter().any(|e| {
matches!(
e,
AgentEvent::TurnEnd { reason, .. }
if matches!(reason, swink_agent::TurnEndReason::Cancelled | swink_agent::TurnEndReason::Aborted)
)
}),
"should have TurnEnd with Cancelled/Aborted reason"
);
}