#![cfg(feature = "proxy")]
mod common;
use std::sync::{Arc, Mutex};
use futures::StreamExt;
use swink_agent::{
AssistantMessageEvent, ContentBlock, ModelSpec, StopReason, StreamFn, StreamOptions,
accumulate_message,
};
use swink_agent_adapters::ProxyStreamFn;
use tokio_util::sync::CancellationToken;
use wiremock::matchers::{header, method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
use common::{sse_response, test_context};
/// Model spec used by every streaming call in this file.
///
/// Uses the same `"test"` / `"test-model"` pair that the reconstruction tests
/// pass to `accumulate_message`. NOTE(review): the argument roles are presumably
/// (provider, model id) — confirm against `ModelSpec::new`.
fn test_model() -> ModelSpec {
    const PROVIDER: &str = "test";
    const MODEL_ID: &str = "test-model";
    ModelSpec::new(PROVIDER, MODEL_ID)
}
/// SSE body for a plain text reply.
///
/// The body contains `start`, a single text block streaming `"hello"`, and a
/// final `done` event with `stop_reason: "stop"` and usage/cost totals. Each
/// `data:` line is followed by a blank line, so every event ends with `"\n\n"`.
fn text_only_sse_body() -> String {
    const EVENTS: [&str; 5] = [
        r#"data: {"type":"start"}"#,
        r#"data: {"type":"text_start","content_index":0}"#,
        r#"data: {"type":"text_delta","content_index":0,"delta":"hello"}"#,
        r#"data: {"type":"text_end","content_index":0}"#,
        r#"data: {"type":"done","stop_reason":"stop","usage":{"input":10,"output":20,"cache_read":0,"cache_write":0,"total":30},"cost":{"input":0.01,"output":0.02,"cache_read":0.0,"cache_write":0.0,"total":0.03}}"#,
    ];
    EVENTS
        .iter()
        .map(|event| format!("{event}\n\n"))
        .collect()
}
/// SSE body for a reply with a text block followed by a tool call.
///
/// Block 0 streams `"Let me read that."`. Block 1 is a `read_file` tool call
/// (`id` `tc_1`) whose JSON arguments arrive split across two
/// `tool_call_delta` events. The stream ends with `done` and
/// `stop_reason: "tool_use"`. Every event ends with `"\n\n"`.
fn text_and_tool_call_sse_body() -> String {
    const EVENTS: [&str; 8] = [
        r#"data: {"type":"start"}"#,
        r#"data: {"type":"text_start","content_index":0}"#,
        r#"data: {"type":"text_delta","content_index":0,"delta":"Let me read that."}"#,
        r#"data: {"type":"text_end","content_index":0}"#,
        r#"data: {"type":"tool_call_start","content_index":1,"id":"tc_1","name":"read_file"}"#,
        r#"data: {"type":"tool_call_delta","content_index":1,"delta":"{\"path\":"}"#,
        r#"data: {"type":"tool_call_delta","content_index":1,"delta":"\"foo.rs\"}"}"#,
        r#"data: {"type":"tool_call_end","content_index":1}"#,
    ];
    const DONE: &str = r#"data: {"type":"done","stop_reason":"tool_use","usage":{"input":15,"output":25,"cache_read":0,"cache_write":0,"total":40},"cost":{"input":0.015,"output":0.025,"cache_read":0.0,"cache_write":0.0,"total":0.04}}"#;
    EVENTS
        .iter()
        .chain(std::iter::once(&DONE))
        .map(|event| format!("{event}\n\n"))
        .collect()
}
/// Streams once from `proxy` with default [`StreamOptions`] and returns every
/// emitted event.
async fn collect_events(proxy: &ProxyStreamFn) -> Vec<AssistantMessageEvent> {
    let defaults = StreamOptions::default();
    collect_events_with_options(proxy, defaults).await
}
/// Runs one `proxy.stream` call with the shared test model and context and
/// drains the resulting stream into a `Vec`.
///
/// The stream gets a fresh [`CancellationToken`]. No clone of the token is
/// kept here, so nothing in this helper cancels it and the stream runs until
/// it ends on its own.
async fn collect_events_with_options(
    proxy: &ProxyStreamFn,
    options: StreamOptions,
) -> Vec<AssistantMessageEvent> {
    let (spec, ctx) = (test_model(), test_context());
    let never_cancelled = CancellationToken::new();
    proxy
        .stream(&spec, &ctx, &options, never_cancelled)
        .collect::<Vec<_>>()
        .await
}
/// Happy path: a text-only stream accumulates into one `Text` block holding
/// `"hello"`. The message carries the `stop` reason and the usage/cost
/// figures from the `done` event.
#[tokio::test]
async fn successful_stream_reconstructs_message() {
    let server = MockServer::start().await;
    let body = text_only_sse_body();
    Mock::given(method("POST"))
        .and(path("/v1/stream"))
        .respond_with(sse_response(&body))
        .mount(&server)
        .await;

    let proxy = ProxyStreamFn::new(server.uri(), "test-token");
    let msg = accumulate_message(collect_events(&proxy).await, "test", "test-model")
        .expect("accumulate should succeed");

    let expected_block = ContentBlock::Text {
        text: "hello".into(),
    };
    assert_eq!(msg.content.len(), 1);
    assert_eq!(msg.content[0], expected_block);
    assert_eq!(msg.stop_reason, StopReason::Stop);
    assert_eq!(msg.usage.input, 10);
    assert_eq!(msg.usage.output, 20);
    assert!((msg.cost.total - 0.03).abs() < f64::EPSILON);
}
/// A text block followed by a tool call reconstructs both blocks in order.
/// The tool call's arguments, delivered in two deltas, are joined so that
/// `arguments["path"]` resolves, and the stop reason is `ToolUse`.
#[tokio::test]
async fn text_and_tool_call_reconstruct_correctly() {
    let server = MockServer::start().await;
    let body = text_and_tool_call_sse_body();
    Mock::given(method("POST"))
        .and(path("/v1/stream"))
        .respond_with(sse_response(&body))
        .mount(&server)
        .await;

    let proxy = ProxyStreamFn::new(server.uri(), "test-token");
    let msg = accumulate_message(collect_events(&proxy).await, "test", "test-model")
        .expect("accumulate should succeed");

    assert_eq!(msg.content.len(), 2);

    let expected_text = ContentBlock::Text {
        text: "Let me read that.".into(),
    };
    assert_eq!(msg.content[0], expected_text);

    // Only the identifying fields and the reassembled arguments matter here.
    match &msg.content[1] {
        ContentBlock::ToolCall {
            id,
            name,
            arguments,
            ..
        } => {
            assert_eq!(id, "tc_1");
            assert_eq!(name, "read_file");
            assert_eq!(arguments["path"], "foo.rs");
        }
        other => panic!("expected ToolCall, got {other:?}"),
    }

    assert_eq!(msg.stop_reason, StopReason::ToolUse);
}
/// The configured token must reach the proxy as `Authorization: Bearer <token>`.
///
/// The mock matches only requests that carry that exact header and expects
/// exactly one hit. A `Start` event alone proves nothing: `Start` is also
/// emitted before the `Error` on failed requests (for example a 401). This
/// test therefore asserts that a `Done` event was produced, which the error
/// paths do not emit. It then verifies the mock's call count explicitly.
#[tokio::test]
async fn bearer_token_sent_in_auth_header() {
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/v1/stream"))
        .and(header("Authorization", "Bearer test-token"))
        .respond_with(sse_response(&text_only_sse_body()))
        .expect(1)
        .mount(&server)
        .await;
    let proxy = ProxyStreamFn::new(server.uri(), "test-token");
    let events = collect_events(&proxy).await;
    let has_start = events
        .iter()
        .any(|e| matches!(e, AssistantMessageEvent::Start));
    assert!(has_start, "expected Start event from authenticated request");
    let has_done = events
        .iter()
        .any(|e| matches!(e, AssistantMessageEvent::Done { .. }));
    assert!(
        has_done,
        "expected Done event from authenticated request, got: {events:?}"
    );
    // Check `.expect(1)` here so a header mismatch is reported at this point,
    // not only later when `server` is dropped.
    server.verify().await;
}
/// A failed connection yields exactly two events: `Start`, then an `Error`
/// with stop reason `Error` whose message mentions "network error".
#[tokio::test]
async fn connection_failure_produces_network_error() {
    // NOTE(review): assumes nothing is listening on 127.0.0.1:1 on the test host.
    let unreachable = ProxyStreamFn::new("http://127.0.0.1:1", "token");
    let events = collect_events(&unreachable).await;

    assert_eq!(events.len(), 2);
    assert!(matches!(events[0], AssistantMessageEvent::Start));

    match &events[1] {
        AssistantMessageEvent::Error {
            error_message,
            stop_reason,
            ..
        } => {
            assert_eq!(*stop_reason, StopReason::Error);
            let mentions_network = error_message.contains("network error");
            assert!(
                mentions_network,
                "expected 'network error', got: {error_message}"
            );
        }
        other => panic!("expected Error event, got {other:?}"),
    }
}
/// An HTTP 401 from the proxy yields exactly two events: `Start`, then an
/// `Error` with stop reason `Error` whose message mentions "auth error".
#[tokio::test]
async fn http_401_produces_auth_error() {
    let server = MockServer::start().await;
    let unauthorized = ResponseTemplate::new(401);
    Mock::given(method("POST"))
        .and(path("/v1/stream"))
        .respond_with(unauthorized)
        .mount(&server)
        .await;

    let proxy = ProxyStreamFn::new(server.uri(), "bad-token");
    let events = collect_events(&proxy).await;

    assert_eq!(events.len(), 2);
    assert!(matches!(events[0], AssistantMessageEvent::Start));

    match &events[1] {
        AssistantMessageEvent::Error {
            error_message,
            stop_reason,
            ..
        } => {
            assert_eq!(*stop_reason, StopReason::Error);
            let mentions_auth = error_message.contains("auth error");
            assert!(mentions_auth, "expected 'auth error', got: {error_message}");
        }
        other => panic!("expected Error event, got {other:?}"),
    }
}
/// An HTTP 429 from the proxy yields exactly two events: `Start`, then an
/// `Error` with stop reason `Error`. The error message must mention both
/// "rate limit" and the status code `429`.
#[tokio::test]
async fn http_429_produces_rate_limit_error() {
    let server = MockServer::start().await;
    let too_many_requests = ResponseTemplate::new(429);
    Mock::given(method("POST"))
        .and(path("/v1/stream"))
        .respond_with(too_many_requests)
        .mount(&server)
        .await;

    let proxy = ProxyStreamFn::new(server.uri(), "token");
    let events = collect_events(&proxy).await;

    assert_eq!(events.len(), 2);
    assert!(matches!(events[0], AssistantMessageEvent::Start));

    match &events[1] {
        AssistantMessageEvent::Error {
            error_message,
            stop_reason,
            ..
        } => {
            assert_eq!(*stop_reason, StopReason::Error);
            let mentions_rate_limit = error_message.contains("rate limit");
            assert!(
                mentions_rate_limit,
                "expected 'rate limit', got: {error_message}"
            );
            let mentions_status = error_message.contains("429");
            assert!(mentions_status, "expected '429', got: {error_message}");
        }
        other => panic!("expected Error event, got {other:?}"),
    }
}
#[tokio::test]
async fn proxy_on_raw_payload_observes_runtime_sse_lines() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/stream"))
.respond_with(sse_response(&text_only_sse_body()))
.mount(&server)
.await;
let observed = Arc::new(Mutex::new(Vec::<String>::new()));
let callback_lines = Arc::clone(&observed);
let options = StreamOptions {
on_raw_payload: Some(Arc::new(move |line| {
callback_lines
.lock()
.expect("callback buffer poisoned")
.push(line.to_owned());
})),
..StreamOptions::default()
};
let proxy = ProxyStreamFn::new(server.uri(), "test-token");
let events = collect_events_with_options(&proxy, options).await;
let observed = observed.lock().expect("callback buffer poisoned").clone();
assert!(
events
.iter()
.any(|event| matches!(event, AssistantMessageEvent::Done { .. })),
"expected runtime stream to complete successfully"
);
assert_eq!(
observed,
vec![
r#"{"type":"start"}"#.to_string(),
r#"{"type":"text_start","content_index":0}"#.to_string(),
r#"{"type":"text_delta","content_index":0,"delta":"hello"}"#.to_string(),
r#"{"type":"text_end","content_index":0}"#.to_string(),
r#"{"type":"done","stop_reason":"stop","usage":{"input":10,"output":20,"cache_read":0,"cache_write":0,"total":30},"cost":{"input":0.01,"output":0.02,"cache_read":0.0,"cache_write":0.0,"total":0.03}}"#.to_string(),
]
);
}
/// An HTTP 504 from the proxy yields exactly two events: `Start`, then an
/// `Error` with stop reason `Error`.
///
/// NOTE(review): the function name says "network error", but the
/// classification asserted below is "server error". The name is kept so the
/// test's identity is unchanged; consider renaming it.
#[tokio::test]
async fn http_504_produces_network_error() {
    let server = MockServer::start().await;
    let gateway_timeout = ResponseTemplate::new(504);
    Mock::given(method("POST"))
        .and(path("/v1/stream"))
        .respond_with(gateway_timeout)
        .mount(&server)
        .await;

    let proxy = ProxyStreamFn::new(server.uri(), "token");
    let events = collect_events(&proxy).await;

    assert_eq!(events.len(), 2);
    assert!(matches!(events[0], AssistantMessageEvent::Start));

    match &events[1] {
        AssistantMessageEvent::Error {
            error_message,
            stop_reason,
            ..
        } => {
            assert_eq!(*stop_reason, StopReason::Error);
            let mentions_server = error_message.contains("server error");
            assert!(
                mentions_server,
                "expected 'server error', got: {error_message}"
            );
        }
        other => panic!("expected Error event, got {other:?}"),
    }
}
/// A `data:` line that is not valid JSON must produce an `Error` event that
/// mentions "malformed SSE event JSON". Its `error_kind` must be `None`, so
/// the error is not classified as a retryable network error.
#[tokio::test]
async fn malformed_sse_event_produces_stream_error() {
    let body = [
        r#"data: {"type":"start"}"#,
        "",
        "data: {not valid json at all!!!}",
        "",
        "",
    ]
    .join("\n");
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/v1/stream"))
        .respond_with(sse_response(&body))
        .mount(&server)
        .await;

    let proxy = ProxyStreamFn::new(server.uri(), "token");
    let events = collect_events(&proxy).await;

    // Use the first Error event in the stream, wherever it appears.
    let (error_message, error_kind) = events
        .iter()
        .find_map(|event| match event {
            AssistantMessageEvent::Error {
                error_message,
                error_kind,
                ..
            } => Some((error_message, error_kind)),
            _ => None,
        })
        .expect("expected terminal error");

    assert!(
        error_message.contains("malformed SSE event JSON"),
        "expected malformed-JSON diagnostic, got: {error_message}"
    );
    assert_eq!(
        *error_kind, None,
        "malformed proxy payloads must not be marked retryable network errors"
    );
}
/// A body that stops after a partial text delta, with no `text_end` and no
/// `done`, must end the event stream with an `Error`. That error has stop
/// reason `Error` and a message mentioning "network error".
#[tokio::test]
async fn mid_stream_disconnect_produces_network_error() {
    let truncated_body = [
        r#"data: {"type":"start"}"#,
        "",
        r#"data: {"type":"text_start","content_index":0}"#,
        "",
        r#"data: {"type":"text_delta","content_index":0,"delta":"partial"}"#,
        "",
        "",
    ]
    .join("\n");
    let server = MockServer::start().await;
    let truncated = sse_response(&truncated_body);
    Mock::given(method("POST"))
        .and(path("/v1/stream"))
        .respond_with(truncated)
        .mount(&server)
        .await;

    let proxy = ProxyStreamFn::new(server.uri(), "token");
    let events = collect_events(&proxy).await;

    match events.last().expect("should have at least one event") {
        AssistantMessageEvent::Error {
            error_message,
            stop_reason,
            ..
        } => {
            assert_eq!(*stop_reason, StopReason::Error);
            let mentions_network = error_message.contains("network error");
            assert!(
                mentions_network,
                "expected 'network error', got: {error_message}"
            );
        }
        other => panic!("expected terminal Error event, got {other:?}"),
    }
}
/// A bare `data: [DONE]` line, without the protocol's own `done` event, must
/// not produce a `Done` event.
///
/// The stream must instead end with an `Error` that has stop reason `Error`,
/// a message mentioning "protocol terminal event", and `error_kind` `None`
/// (so it is not classified as a retryable network error).
#[tokio::test]
async fn done_sentinel_without_protocol_terminal_produces_stream_error() {
    let body = [
        r#"data: {"type":"start"}"#,
        "",
        r#"data: {"type":"text_start","content_index":0}"#,
        "",
        r#"data: {"type":"text_delta","content_index":0,"delta":"partial"}"#,
        "",
        r#"data: {"type":"text_end","content_index":0}"#,
        "",
        "data: [DONE]",
        "",
        "",
    ]
    .join("\n");
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/v1/stream"))
        .respond_with(sse_response(&body))
        .mount(&server)
        .await;

    let proxy = ProxyStreamFn::new(server.uri(), "token");
    let events = collect_events(&proxy).await;

    let emitted_done = events
        .iter()
        .any(|event| matches!(event, AssistantMessageEvent::Done { .. }));
    assert!(
        !emitted_done,
        "expected missing protocol terminal to avoid emitting Done: {events:?}"
    );

    match events.last().expect("should have at least one event") {
        AssistantMessageEvent::Error {
            error_message,
            stop_reason,
            error_kind,
            ..
        } => {
            assert_eq!(*stop_reason, StopReason::Error);
            let mentions_terminal = error_message.contains("protocol terminal event");
            assert!(
                mentions_terminal,
                "expected terminal-event diagnostic, got: {error_message}"
            );
            assert_eq!(
                *error_kind, None,
                "protocol faults must not be marked retryable network errors"
            );
        }
        other => panic!("expected terminal Error event, got {other:?}"),
    }
}
/// Cancelling the token while the proxy waits on a slow response must
/// produce an `Error` event whose stop reason is `Aborted`.
///
/// The mock delays its response by 30s, and the token is cancelled after
/// 50ms. Collection is capped with a 5s timeout, so a regression in
/// cancellation handling fails quickly instead of stalling for the full
/// mock delay.
#[tokio::test]
async fn cancellation_yields_aborted() {
    let body = text_only_sse_body();
    let slow_response = sse_response(&body).set_delay(std::time::Duration::from_secs(30));
    let server = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/v1/stream"))
        .respond_with(slow_response)
        .mount(&server)
        .await;
    let proxy = ProxyStreamFn::new(server.uri(), "token");
    let model = test_model();
    let context = test_context();
    let options = StreamOptions::default();
    let token = CancellationToken::new();
    let cancel_token = token.clone();
    let canceller = tokio::spawn(async move {
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        cancel_token.cancel();
    });
    let stream = proxy.stream(&model, &context, &options, token);
    // Keep the cap far below the 30s mock delay so that an unhonoured
    // cancellation shows up as a timeout failure.
    let events: Vec<_> = tokio::time::timeout(
        std::time::Duration::from_secs(5),
        stream.collect::<Vec<_>>(),
    )
    .await
    .expect("stream did not terminate after cancellation (timed out)");
    // Join the cancel task rather than leaving it detached.
    canceller.await.expect("cancel task panicked");
    let has_aborted = events.iter().any(|e| match e {
        AssistantMessageEvent::Error { stop_reason, .. } => *stop_reason == StopReason::Aborted,
        _ => false,
    });
    assert!(
        has_aborted,
        "expected an Aborted event after cancellation, got: {events:?}"
    );
}