#![cfg(feature = "testkit")]
mod common;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use common::{
MockStreamFn, default_convert, default_model, error_events, text_only_events, tool_call_events,
user_msg,
};
use futures::stream::StreamExt;
use serde_json::json;
use swink_agent::{
Agent, AgentError, AgentMessage, AgentOptions, CustomMessage, DefaultRetryStrategy, StopReason,
StreamFn,
};
/// Builds the [`Agent`] shared by most tests in this file.
///
/// The agent uses the fixed system prompt `"test system prompt"`, the model
/// and message converter supplied by the `common` test helpers, the given
/// `stream_fn`, and a retry strategy with jitter disabled and a 1 ms base
/// delay.
fn make_agent(stream_fn: Arc<dyn StreamFn>) -> Agent {
    let retry_strategy = DefaultRetryStrategy::default()
        .with_jitter(false)
        .with_base_delay(Duration::from_millis(1));

    let options = AgentOptions::new(
        "test system prompt",
        default_model(),
        stream_fn,
        default_convert,
    )
    .with_retry_strategy(Box::new(retry_strategy));

    Agent::new(options)
}
/// A callback registered through `subscribe` must be invoked at least once
/// while a prompt runs to completion.
#[tokio::test]
async fn subscribe_receives_events() {
    let mock = Arc::new(MockStreamFn::new(vec![text_only_events("subscribed")]));
    let mut agent = make_agent(mock);

    // Counter shared with the callback; bumped once per delivered event.
    let counter = Arc::new(AtomicU32::new(0));
    let counter_for_callback = Arc::clone(&counter);
    let _subscription = agent.subscribe(move |_event| {
        counter_for_callback.fetch_add(1, Ordering::SeqCst);
    });

    let _outcome = agent
        .prompt_async(vec![user_msg("Hi")])
        .await
        .unwrap();

    let count = counter.load(Ordering::SeqCst);
    assert!(
        count > 0,
        "subscriber should have received events, got {count}"
    );
}
/// `unsubscribe` must actually detach a listener.
///
/// Checks three things:
/// 1. the listener receives events while it is registered;
/// 2. `unsubscribe` returns `true` for a live id and `false` once that id has
///    already been removed;
/// 3. after removal, running another prompt delivers no further events to the
///    removed callback.
#[tokio::test]
async fn unsubscribe_removes_listener() {
    // Three scripted responses: the first run has an initial prompt plus a
    // queued follow-up (presumably one response each), and one more response
    // is needed for the prompt issued after unsubscribing.
    let stream_fn = Arc::new(MockStreamFn::new(vec![
        text_only_events("first"),
        text_only_events("second"),
        text_only_events("third"),
    ]));
    let mut agent = make_agent(stream_fn);

    let events_received = Arc::new(AtomicU32::new(0));
    let events_clone = Arc::clone(&events_received);
    let id = agent.subscribe(move |_event| {
        events_clone.fetch_add(1, Ordering::SeqCst);
    });

    agent.follow_up(user_msg("follow up"));
    let _result = agent.prompt_async(vec![user_msg("Hi")]).await.unwrap();
    let count_after_first = events_received.load(Ordering::SeqCst);
    assert!(count_after_first > 0, "should have received events");

    let removed = agent.unsubscribe(id);
    assert!(removed, "unsubscribe should return true for existing id");
    let removed_again = agent.unsubscribe(id);
    assert!(!removed_again, "second unsubscribe should return false");

    // The return values alone do not prove the callback is gone: run another
    // prompt and make sure the counter did not move.
    let _result = agent
        .prompt_async(vec![user_msg("after unsubscribe")])
        .await
        .unwrap();
    let count_after_unsubscribe = events_received.load(Ordering::SeqCst);
    assert_eq!(
        count_after_unsubscribe, count_after_first,
        "removed listener must not receive events from later prompts"
    );
}
/// A subscriber that panics must not abort the prompt, and must not stop
/// a second, well-behaved subscriber from being notified.
#[tokio::test]
async fn subscriber_panic_does_not_crash() {
    let mock = Arc::new(MockStreamFn::new(vec![text_only_events("safe")]));
    let mut agent = make_agent(mock);

    // First subscriber: panics on every event.
    let _faulty = agent.subscribe(|_event| {
        panic!("subscriber panic test");
    });

    // Second subscriber: simply counts the events it sees.
    let healthy_counter = Arc::new(AtomicU32::new(0));
    let healthy_handle = Arc::clone(&healthy_counter);
    let _healthy = agent.subscribe(move |_event| {
        healthy_handle.fetch_add(1, Ordering::SeqCst);
    });

    let outcome = agent
        .prompt_async(vec![user_msg("Hi")])
        .await
        .unwrap();
    assert_eq!(outcome.stop_reason, StopReason::Stop);

    let delivered = healthy_counter.load(Ordering::SeqCst);
    assert!(
        delivered > 0,
        "good subscriber should still receive events despite panicking sibling"
    );
}
/// When the scripted `__structured_output` tool call carries arguments that
/// include both required fields, `structured_output` returns that value.
#[tokio::test]
async fn structured_output_valid() {
    // Object schema requiring a string `name` and an integer `age`.
    let person_schema = json!({
        "type": "object",
        "properties": {
            "name": { "type": "string" },
            "age": { "type": "integer" }
        },
        "required": ["name", "age"]
    });

    // Scripted turns: the structured-output tool call, then a closing text reply.
    let scripted = vec![
        tool_call_events(
            "so_1",
            "__structured_output",
            r#"{"name": "Alice", "age": 30}"#,
        ),
        text_only_events("done"),
    ];
    let mock = Arc::new(MockStreamFn::new(scripted));
    let mut agent = make_agent(mock);

    let extracted = agent
        .structured_output(String::from("Extract name and age"), person_schema)
        .await
        .unwrap();

    assert_eq!(extracted["name"], "Alice");
    assert_eq!(extracted["age"], 30);
}
/// The first scripted attempt omits the required `name` field and the second
/// supplies it; with up to 3 retries configured, `structured_output` must
/// return the second attempt's value.
#[tokio::test]
async fn structured_output_retries() {
    let name_schema = json!({
        "type": "object",
        "properties": {
            "name": { "type": "string" }
        },
        "required": ["name"]
    });

    // Attempt 1 sends an empty object; attempt 2 sends a usable one.
    let scripted = vec![
        tool_call_events("so_1", "__structured_output", r"{}"),
        text_only_events("done"),
        tool_call_events("so_2", "__structured_output", r#"{"name": "Bob"}"#),
        text_only_events("done"),
    ];
    let mock = Arc::new(MockStreamFn::new(scripted));

    let retry_strategy = DefaultRetryStrategy::default()
        .with_jitter(false)
        .with_base_delay(Duration::from_millis(1));
    let options = AgentOptions::new("test", default_model(), mock, default_convert)
        .with_retry_strategy(Box::new(retry_strategy))
        .with_structured_output_max_retries(3);
    let mut agent = Agent::new(options);

    let extracted = agent
        .structured_output(String::from("Extract name"), name_schema)
        .await
        .unwrap();

    assert_eq!(extracted["name"], "Bob");
}
/// With a retry budget of 2 and three scripted attempts that all omit the
/// required field, `structured_output` must fail with
/// `StructuredOutputFailed` reporting 3 attempts, and the synthetic
/// `__structured_output` tool must be gone from the agent's tool list.
#[tokio::test]
async fn structured_output_fails_after_max_retries() {
    let name_schema = json!({
        "type": "object",
        "properties": {
            "name": { "type": "string" }
        },
        "required": ["name"]
    });

    // Every attempt sends `{}`, which lacks the required `name` field.
    let scripted = vec![
        tool_call_events("so_1", "__structured_output", r"{}"),
        text_only_events("done"),
        tool_call_events("so_2", "__structured_output", r"{}"),
        text_only_events("done"),
        tool_call_events("so_3", "__structured_output", r"{}"),
        text_only_events("done"),
    ];
    let mock = Arc::new(MockStreamFn::new(scripted));

    let retry_strategy = DefaultRetryStrategy::default()
        .with_jitter(false)
        .with_base_delay(Duration::from_millis(1));
    let options = AgentOptions::new("test", default_model(), mock, default_convert)
        .with_retry_strategy(Box::new(retry_strategy))
        .with_structured_output_max_retries(2);
    let mut agent = Agent::new(options);

    let err = agent
        .structured_output(String::from("Extract name"), name_schema)
        .await
        .unwrap_err();

    assert!(
        matches!(&err, AgentError::StructuredOutputFailed { attempts, .. } if *attempts == 3),
        "expected StructuredOutputFailed with 3 attempts, got {err:?}"
    );

    let synthetic_tool_present = agent
        .state()
        .tools
        .iter()
        .any(|tool| tool.name() == "__structured_output");
    assert!(
        !synthetic_tool_present,
        "synthetic structured output tool should always be removed after failure"
    );
}
/// Runs two consecutive turns through `prompt_stream`, feeding every yielded
/// event back into `handle_stream_event`, and checks after each turn that the
/// agent is idle and its state holds messages.
#[tokio::test]
async fn multi_turn_via_prompt_stream_and_handle_stream_event() {
    let scripted = vec![
        text_only_events("first response"),
        text_only_events("second response"),
    ];
    let mock = Arc::new(MockStreamFn::new(scripted));
    let mut agent = make_agent(mock);

    // Turn 1. The stream lives in its own scope so it is dropped before the
    // assertions below inspect the agent.
    {
        let first_turn = agent.prompt_stream(vec![user_msg("hello")]).unwrap();
        let mut first_turn = std::pin::pin!(first_turn);
        loop {
            let Some(event) = first_turn.next().await else {
                break;
            };
            agent.handle_stream_event(&event);
        }
    }
    assert!(
        !agent.state().is_running,
        "agent should be idle after stream is fully consumed"
    );
    assert!(
        !agent.state().messages.is_empty(),
        "agent state should have messages after first turn"
    );

    // Turn 2. Starting another stream must succeed now that the first one has
    // been fully handled.
    {
        let second_turn = agent
            .prompt_stream(vec![user_msg("follow up")])
            .expect("second prompt_stream should not return AlreadyRunning");
        let mut second_turn = std::pin::pin!(second_turn);
        loop {
            let Some(event) = second_turn.next().await else {
                break;
            };
            agent.handle_stream_event(&event);
        }
    }
    assert!(
        !agent.state().is_running,
        "agent should be idle after second turn"
    );

    let total_messages = agent.state().messages.len();
    assert!(
        total_messages >= 4,
        "state should have messages from both turns (2 user + 2 assistant), got {}",
        total_messages
    );
}
/// Events passed to `handle_stream_event` must reach subscribers: the
/// recorded event names have to include both `AgentStart` and `AgentEnd`.
#[tokio::test]
async fn handle_stream_event_dispatches_to_subscribers() {
    let mock = Arc::new(MockStreamFn::new(vec![text_only_events("hello")]));
    let mut agent = make_agent(mock);

    // The subscriber records the leading identifier of each event's `Debug`
    // output, i.e. the text before the first space, `{` or `(`.
    let seen: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
    let recorder = Arc::clone(&seen);
    let _subscription = agent.subscribe(move |event| {
        let rendered = format!("{event:?}");
        let variant = rendered
            .split([' ', '{', '('])
            .next()
            .unwrap_or("")
            .to_string();
        recorder.lock().unwrap().push(variant);
    });

    let events = agent.prompt_stream(vec![user_msg("hi")]).unwrap();
    let mut events = std::pin::pin!(events);
    loop {
        let Some(event) = events.next().await else {
            break;
        };
        agent.handle_stream_event(&event);
    }

    let names = seen.lock().unwrap().clone();
    assert!(
        names.iter().any(|name| name == "AgentStart"),
        "subscriber should receive AgentStart"
    );
    assert!(
        names.iter().any(|name| name == "AgentEnd"),
        "subscriber should receive AgentEnd"
    );
}
/// If the stream is drained without passing its events to
/// `handle_stream_event`, the agent must still report `is_running`, and a
/// second `prompt_stream` call must be rejected with `AlreadyRunning`.
#[tokio::test]
async fn prompt_stream_without_handle_stream_event_stays_running() {
    let scripted = vec![text_only_events("first"), text_only_events("second")];
    let mock = Arc::new(MockStreamFn::new(scripted));
    let mut agent = make_agent(mock);

    // Drain the stream and deliberately discard every event.
    let events = agent.prompt_stream(vec![user_msg("hello")]).unwrap();
    let mut events = std::pin::pin!(events);
    while events.next().await.is_some() {}

    assert!(
        agent.state().is_running,
        "agent should still think it is running"
    );

    let second_attempt = agent.prompt_stream(vec![user_msg("follow up")]);
    assert!(
        matches!(second_attempt, Err(AgentError::AlreadyRunning)),
        "second prompt should fail with AlreadyRunning"
    );
}
#[tokio::test]
async fn handle_stream_event_preserves_terminal_error_through_agent_end() {
let stream_fn = Arc::new(MockStreamFn::new(vec![error_events("fatal error", None)]));
let mut agent = make_agent(stream_fn);
let stream = agent
.prompt_stream(vec![user_msg("trigger error")])
.unwrap();
let mut stream = std::pin::pin!(stream);
while let Some(event) = stream.next().await {
agent.handle_stream_event(&event);
}
assert!(
!agent.state().is_running,
"agent should be idle after error stream completes"
);
assert_eq!(
agent.state().error.as_deref(),
Some("stream error: fatal error"),
"terminal error must survive through AgentEnd"
);
}
/// Minimal custom message used by the tests below; it carries a single label
/// and can produce a boxed copy of itself through `clone_box`.
#[derive(Debug, Clone)]
struct CloneableCustomMsg {
    /// Marker text used to recognise this message later.
    label: String,
}

impl CustomMessage for CloneableCustomMsg {
    /// Exposes the value as `Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Always returns `Some`, wrapping a fresh clone of this message.
    fn clone_box(&self) -> Option<Box<dyn CustomMessage>> {
        let duplicate: Box<dyn CustomMessage> = Box::new(self.clone());
        Some(duplicate)
    }
}
/// A custom message sent with the prompt must still be in the agent's
/// message list, exactly once and with its label intact, after the stream has
/// been fully processed through `handle_stream_event`.
#[tokio::test]
async fn handle_stream_event_preserves_custom_messages() {
    let mock = Arc::new(MockStreamFn::new(vec![text_only_events("response")]));
    let mut agent = make_agent(mock);

    // Prompt made of one ordinary user message plus one custom message.
    let tagged = CloneableCustomMsg {
        label: "my-custom".to_string(),
    };
    let input = vec![user_msg("hello"), AgentMessage::Custom(Box::new(tagged))];

    let events = agent.prompt_stream(input).unwrap();
    let mut events = std::pin::pin!(events);
    loop {
        let Some(event) = events.next().await else {
            break;
        };
        agent.handle_stream_event(&event);
    }

    assert!(!agent.state().is_running, "agent should be idle");

    // Exactly one custom entry must remain in the message list.
    let custom_count = agent
        .state()
        .messages
        .iter()
        .filter(|message| matches!(message, AgentMessage::Custom(_)))
        .count();
    assert_eq!(
        custom_count, 1,
        "custom message must survive stream-driven state rebuild (got {custom_count})"
    );

    // It must also downcast back to the concrete type with its label intact.
    let recovered = agent
        .state()
        .messages
        .iter()
        .find_map(|message| {
            if let AgentMessage::Custom(inner) = message {
                inner.as_any().downcast_ref::<CloneableCustomMsg>()
            } else {
                None
            }
        })
        .expect("should find CloneableCustomMsg in rebuilt state");
    assert_eq!(recovered.label, "my-custom");
}