#![cfg(feature = "testkit")]
mod common;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use common::{
MockContextCapturingStreamFn, MockTool, default_exhausted_fallback, default_model,
next_response, text_only_events, tool_call_events, user_msg,
};
use futures::Stream;
use futures::stream::StreamExt;
use tokio_util::sync::CancellationToken;
use swink_agent::{
AgentContext, AgentEvent, AgentLoopConfig, AgentMessage, AssistantMessageEvent, ContentBlock,
ContextTransformer, DefaultRetryStrategy, LlmMessage, ModelSpec, SlidingWindowTransformer,
StopReason, StreamFn, StreamOptions, UserMessage, agent_loop,
};
/// Scripted [`StreamFn`] that snapshots the LLM-visible messages of every
/// request before replaying the next canned response.
struct MockMessageCapturingStreamFn {
    /// Canned event sequences, consumed through `next_response`; when it runs
    /// out, `default_exhausted_fallback()` is what gets passed as the fallback.
    responses: Mutex<Vec<Vec<AssistantMessageEvent>>>,
    /// One entry per `stream` call: the `AgentMessage::Llm` messages found in
    /// the context at that call, in context order.
    captured_messages: Arc<Mutex<Vec<Vec<LlmMessage>>>>,
}
impl StreamFn for MockMessageCapturingStreamFn {
    fn stream<'a>(
        &'a self,
        _model: &'a ModelSpec,
        context: &'a AgentContext,
        _options: &'a StreamOptions,
        _cancellation_token: CancellationToken,
    ) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>> {
        // Snapshot only the LLM messages; custom messages are skipped.
        let mut snapshot: Vec<LlmMessage> = Vec::new();
        for message in context.messages.iter() {
            match message {
                AgentMessage::Llm(llm) => snapshot.push(llm.clone()),
                AgentMessage::Custom(_) => {}
            }
        }
        self.captured_messages.lock().unwrap().push(snapshot);

        let replay = next_response(&self.responses, default_exhausted_fallback());
        Box::pin(futures::stream::iter(replay))
    }
}
fn overflow_error_events() -> Vec<AssistantMessageEvent> {
vec![AssistantMessageEvent::Error {
stop_reason: StopReason::Error,
error_message: "context_length_exceeded: too many tokens".to_string(),
error_kind: None,
usage: None,
}]
}
/// Boxed message converter; `default_config` stores one of these in
/// `AgentLoopConfig::convert_to_llm`.
type ConvertToLlmBoxed = Box<dyn Fn(&AgentMessage) -> Option<LlmMessage> + Send + Sync>;
/// Converter that forwards `AgentMessage::Llm` messages (cloned) and drops
/// `AgentMessage::Custom` ones.
fn default_convert_to_llm() -> ConvertToLlmBoxed {
    fn llm_only(message: &AgentMessage) -> Option<LlmMessage> {
        match message {
            AgentMessage::Llm(inner) => Some(inner.clone()),
            AgentMessage::Custom(_) => None,
        }
    }
    Box::new(llm_only)
}
fn default_config(stream_fn: Arc<dyn StreamFn>) -> AgentLoopConfig {
AgentLoopConfig {
agent_name: None,
transfer_chain: None,
model: default_model(),
stream_options: StreamOptions::default(),
retry_strategy: Box::new(
DefaultRetryStrategy::default()
.with_jitter(false)
.with_base_delay(Duration::from_millis(1)),
),
stream_fn,
tools: vec![],
convert_to_llm: default_convert_to_llm(),
transform_context: None,
get_api_key: None,
message_provider: None,
pending_message_snapshot: Arc::default(),
loop_context_snapshot: Arc::default(),
approve_tool: None,
approval_mode: swink_agent::ApprovalMode::default(),
pre_turn_policies: vec![],
pre_dispatch_policies: vec![],
post_turn_policies: vec![],
post_loop_policies: vec![],
async_transform_context: None,
metrics_collector: None,
fallback: None,
tool_execution_policy: swink_agent::ToolExecutionPolicy::default(),
session_state: std::sync::Arc::new(
std::sync::RwLock::new(swink_agent::SessionState::new()),
),
credential_resolver: None,
cache_config: None,
cache_state: std::sync::Mutex::new(swink_agent::CacheState::default()),
dynamic_system_prompt: None,
}
}
/// Builds a user message whose text is `"{label}:"` followed by
/// `token_count * 4` `x` characters.
///
/// NOTE(review): the 4-characters-per-token ratio presumably mirrors the
/// token estimate used by `SlidingWindowTransformer`; confirm before relying
/// on exact token counts.
fn large_user_msg(label: &str, token_count: usize) -> AgentMessage {
    let pad_len = token_count * 4;
    let mut text = String::with_capacity(label.len() + 1 + pad_len);
    text.push_str(label);
    text.push(':');
    text.extend(std::iter::repeat('x').take(pad_len));

    let message = UserMessage {
        cache_hint: None,
        timestamp: 0,
        content: vec![ContentBlock::Text { text }],
    };
    AgentMessage::Llm(LlmMessage::User(message))
}
/// Drains an agent event stream to completion, returning every event in
/// emission order.
async fn collect_events(stream: Pin<Box<dyn Stream<Item = AgentEvent> + Send>>) -> Vec<AgentEvent> {
    stream.collect::<Vec<AgentEvent>>().await
}
/// Returns `true` if any event's variant name, as reported by
/// `common::event_variant_name`, equals `name`.
fn has_event(events: &[AgentEvent], name: &str) -> bool {
    for event in events {
        if common::event_variant_name(event) == name {
            return true;
        }
    }
    false
}
/// Scenario: the first request overflows and the second succeeds.
///
/// The test checks that:
/// - the context transformer is first invoked with `overflow = false`, then
///   with `overflow = true`;
/// - the retried request carries fewer messages than the original.
#[tokio::test]
async fn overflow_triggers_compaction() {
    let recorder = Arc::new(MockContextCapturingStreamFn::new(vec![
        overflow_error_events(),
        text_only_events("recovered"),
    ]));
    let observed_flags: Arc<Mutex<Vec<bool>>> = Arc::default();

    let mut config = default_config(Arc::clone(&recorder) as Arc<dyn StreamFn>);
    let window = SlidingWindowTransformer::new(10_000, 200, 1);
    let flag_sink = Arc::clone(&observed_flags);
    config.transform_context = Some(Arc::new(
        move |msgs: &mut Vec<AgentMessage>, overflow: bool| {
            // Record the flag for every invocation, then delegate to the real transformer.
            flag_sink.lock().unwrap().push(overflow);
            window.transform(msgs, overflow);
        },
    ));

    let initial_messages: Vec<AgentMessage> = (0..10)
        .map(|i| large_user_msg(&format!("msg{i}"), 100))
        .collect();

    let events = collect_events(agent_loop(
        initial_messages,
        "system".to_string(),
        config,
        CancellationToken::new(),
    ))
    .await;
    assert!(has_event(&events, "AgentEnd"), "loop should complete");

    let flags = observed_flags.lock().unwrap().clone();
    assert!(flags.len() >= 2, "transform_context called at least twice");
    assert!(!flags[0], "first call should not have overflow");
    assert!(flags[1], "second call should have overflow=true");

    let counts: Vec<usize> = recorder.captured_message_counts.lock().unwrap().clone();
    assert!(
        counts.len() >= 2,
        "stream should be called at least twice, got {}",
        counts.len()
    );
    let (before, after) = (counts[0], counts[1]);
    assert!(
        after < before,
        "context should be smaller after compaction: first={}, second={}",
        before,
        after
    );
}
/// After an overflow, the compacted context must keep the first
/// `anchor_count` messages at its front and still be smaller than the
/// 10-message original.
#[tokio::test]
async fn compacted_context_preserves_anchors() {
    let captured_messages: Arc<Mutex<Vec<Vec<LlmMessage>>>> = Arc::default();
    let mock = MockMessageCapturingStreamFn {
        captured_messages: Arc::clone(&captured_messages),
        responses: Mutex::new(vec![overflow_error_events(), text_only_events("ok")]),
    };
    let mut config = default_config(Arc::new(mock) as Arc<dyn StreamFn>);

    let anchor_count = 2;
    let window = SlidingWindowTransformer::new(10_000, 300, anchor_count);
    config.transform_context = Some(Arc::new(
        move |msgs: &mut Vec<AgentMessage>, overflow: bool| {
            window.transform(msgs, overflow);
        },
    ));

    // Two short anchors first, then eight large fillers that force compaction.
    let mut initial_messages = vec![user_msg("ANCHOR_ONE"), user_msg("ANCHOR_TWO")];
    initial_messages.extend((0..8).map(|i| large_user_msg(&format!("filler{i}"), 100)));

    let events = collect_events(agent_loop(
        initial_messages,
        "system".to_string(),
        config,
        CancellationToken::new(),
    ))
    .await;
    assert!(has_event(&events, "AgentEnd"));

    let all_captured = captured_messages.lock().unwrap().clone();
    assert!(
        all_captured.len() >= 2,
        "should have at least 2 stream calls"
    );
    // Index 1 is the request issued after the overflow-triggered compaction.
    let post_overflow = &all_captured[1];
    assert!(
        post_overflow.len() >= anchor_count,
        "post-overflow context should have at least {anchor_count} messages, got {}",
        post_overflow.len()
    );

    // Text of a user message; empty for any other message kind.
    let user_text = |message: &LlmMessage| match message {
        LlmMessage::User(u) => ContentBlock::extract_text(&u.content),
        LlmMessage::Assistant(_) | LlmMessage::ToolResult(_) => String::new(),
    };
    let first_text = user_text(&post_overflow[0]);
    let second_text = user_text(&post_overflow[1]);

    assert!(
        first_text.contains("ANCHOR_ONE"),
        "first anchor should survive compaction, got: {first_text}"
    );
    assert!(
        second_text.contains("ANCHOR_TWO"),
        "second anchor should survive compaction, got: {second_text}"
    );
    assert!(
        post_overflow.len() < 10,
        "context should be compacted, got {} messages",
        post_overflow.len()
    );
}
/// Scenario: a tool call turn followed by an overflow.
///
/// In the request sent after compaction, a surviving tool result must not
/// appear without an assistant message that contains a tool call. The check
/// is vacuous if compaction removed every tool result.
#[tokio::test]
async fn compacted_context_preserves_tool_pairs() {
    let captured_messages: Arc<Mutex<Vec<Vec<LlmMessage>>>> = Arc::default();
    let mock = MockMessageCapturingStreamFn {
        captured_messages: Arc::clone(&captured_messages),
        responses: Mutex::new(vec![
            tool_call_events("tc_1", "mock_tool", "{}"),
            overflow_error_events(),
            text_only_events("done"),
        ]),
    };
    let tool = Arc::new(MockTool::new("mock_tool"));
    let window = SlidingWindowTransformer::new(10_000, 500, 1);

    let mut config = default_config(Arc::new(mock) as Arc<dyn StreamFn>);
    config.tools = vec![tool];
    config.transform_context = Some(Arc::new(
        move |msgs: &mut Vec<AgentMessage>, overflow: bool| {
            window.transform(msgs, overflow);
        },
    ));

    let initial_messages: Vec<AgentMessage> = (0..6)
        .map(|i| large_user_msg(&format!("filler{i}"), 100))
        .collect();

    let events = collect_events(agent_loop(
        initial_messages,
        "system".to_string(),
        config,
        CancellationToken::new(),
    ))
    .await;
    assert!(has_event(&events, "AgentEnd"));

    let captured = captured_messages.lock().unwrap().clone();
    assert!(
        captured.len() >= 3,
        "should have at least 3 stream calls, got {}",
        captured.len()
    );
    // Call 0: tool call, call 1: overflow, call 2: first request after compaction.
    let post_overflow = &captured[2];

    let mut has_tool_result = false;
    let mut has_tool_call = false;
    for message in post_overflow {
        match message {
            LlmMessage::ToolResult(_) => has_tool_result = true,
            LlmMessage::Assistant(a)
                if a.content.iter().any(|b| matches!(b, ContentBlock::ToolCall { .. })) =>
            {
                has_tool_call = true;
            }
            _ => {}
        }
    }

    assert!(
        !has_tool_result || has_tool_call,
        "tool result survived compaction but its tool call did not"
    );
}
/// Scenario: two consecutive overflows.
///
/// The loop compacts and retries once. When the retry overflows too, the loop
/// ends without a third stream call. The retried request must be smaller than
/// the original, and the overflow compaction hook must actually have run.
#[tokio::test]
async fn double_overflow_surfaces_error() {
    let capturing_fn = Arc::new(MockContextCapturingStreamFn::new(vec![
        overflow_error_events(),
        overflow_error_events(),
        text_only_events("should not reach this"),
    ]));
    let stream_fn: Arc<dyn StreamFn> = Arc::clone(&capturing_fn) as Arc<dyn StreamFn>;
    let overflow_count: Arc<Mutex<usize>> = Arc::new(Mutex::new(0));
    let overflow_clone = Arc::clone(&overflow_count);
    let mut config = default_config(stream_fn);
    config.transform_context = Some(Arc::new(
        move |msgs: &mut Vec<AgentMessage>, overflow: bool| {
            if !overflow {
                return;
            }
            *overflow_clone.lock().unwrap() += 1;
            // Drop everything except the two most recent messages. Always drop
            // at least one message, but only while more than that remains, so
            // the context is never emptied.
            let drop_count = msgs.len().saturating_sub(2).max(1);
            if drop_count < msgs.len() {
                msgs.drain(..drop_count);
            }
        },
    ));
    let mut initial_messages = Vec::new();
    for i in 0..10 {
        initial_messages.push(large_user_msg(&format!("msg{i}"), 50));
    }
    let events = collect_events(agent_loop(
        initial_messages,
        "system".to_string(),
        config,
        CancellationToken::new(),
    ))
    .await;
    assert!(has_event(&events, "AgentEnd"));
    // The counter was previously incremented but never checked; make sure the
    // overflow path really invoked the compaction hook.
    assert!(
        *overflow_count.lock().unwrap() >= 1,
        "transform_context should have been invoked with overflow=true"
    );
    let counts: Vec<usize> = capturing_fn.captured_message_counts.lock().unwrap().clone();
    assert_eq!(
        counts.len(),
        2,
        "should have exactly 2 stream calls (initial + retry), got {}",
        counts.len()
    );
    assert!(
        counts[1] < counts[0],
        "second call should have fewer messages than first: {} vs {}",
        counts[1],
        counts[0]
    );
}
/// Scenario: compaction cannot shrink the context (a single oversized message
/// with one anchor).
///
/// The overflow is surfaced without a retry: the loop still reaches
/// `AgentEnd`, and the stream is called exactly once.
#[tokio::test]
async fn overflow_with_no_compaction_surfaces_error() {
    let recorder = Arc::new(MockContextCapturingStreamFn::new(vec![
        overflow_error_events(),
        text_only_events("should not reach this"),
    ]));
    let window = SlidingWindowTransformer::new(10_000, 10, 1);

    let mut config = default_config(Arc::clone(&recorder) as Arc<dyn StreamFn>);
    config.transform_context = Some(Arc::new(
        move |msgs: &mut Vec<AgentMessage>, overflow: bool| {
            window.transform(msgs, overflow);
        },
    ));

    let events = collect_events(agent_loop(
        vec![large_user_msg("huge", 500)],
        "system".to_string(),
        config,
        CancellationToken::new(),
    ))
    .await;
    assert!(
        has_event(&events, "AgentEnd"),
        "loop should complete even when compaction cannot help"
    );

    let counts = recorder.captured_message_counts.lock().unwrap().clone();
    assert_eq!(
        counts.len(),
        1,
        "should have exactly 1 stream call (no retry when no compaction), got {}",
        counts.len()
    );
}