#![forbid(unsafe_code)]
use std::future::Future;
use tokio_util::sync::CancellationToken;
/// Maximum number of error-body bytes retained by `read_error_body_or_cancelled` (64 KiB).
/// Bytes beyond this limit are discarded and the returned text is marked as truncated.
const MAX_ERROR_BODY_BYTES: usize = 64 * 1024;
/// Bundles a base URL, an API key, and a `reqwest::Client`.
///
/// The hand-written `Debug` impl for this type prints `[REDACTED]` in place of `api_key`.
// NOTE(review): `dead_code` is allowed here without a stated reason — presumably the type is
// unused under some build configurations; confirm and document why.
#[allow(dead_code)]
pub struct AdapterBase {
/// Base URL; `AdapterBase::new` stores it with every trailing `/` removed.
pub base_url: String,
/// API key as passed to `AdapterBase::new`; never printed by the `Debug` impl.
pub api_key: String,
/// HTTP client; `AdapterBase::new` creates it with `reqwest::Client::new()`.
pub client: reqwest::Client,
}
impl AdapterBase {
    /// Builds an `AdapterBase` from a base URL and an API key.
    ///
    /// Every trailing `/` is removed from `base_url` before it is stored, `api_key` is stored
    /// as given, and a fresh `reqwest::Client` is created.
    #[allow(dead_code)]
    pub fn new(base_url: impl Into<String>, api_key: impl Into<String>) -> Self {
        let mut base_url: String = base_url.into();
        // Trim in place: `/` is ASCII, so the trimmed length is always a char boundary.
        let kept = base_url.trim_end_matches('/').len();
        base_url.truncate(kept);

        Self {
            base_url,
            api_key: api_key.into(),
            client: reqwest::Client::new(),
        }
    }
}
impl std::fmt::Debug for AdapterBase {
    /// Formats the adapter for diagnostics without exposing the API key.
    ///
    /// `base_url` is printed as stored, `api_key` is replaced by the literal `[REDACTED]`,
    /// and the remaining fields are elided by the non-exhaustive struct rendering.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut rendered = f.debug_struct("AdapterBase");
        rendered.field("base_url", &self.base_url);
        rendered.field("api_key", &"[REDACTED]");
        rendered.finish_non_exhaustive()
    }
}
#[must_use]
pub const fn pre_stream_error(
event: swink_agent::AssistantMessageEvent,
) -> [swink_agent::AssistantMessageEvent; 2] {
[swink_agent::AssistantMessageEvent::Start, event]
}
#[must_use]
pub fn cancelled_error(message: impl Into<String>) -> swink_agent::AssistantMessageEvent {
swink_agent::AssistantMessageEvent::Error {
stop_reason: swink_agent::StopReason::Aborted,
error_message: message.into(),
usage: None,
error_kind: None,
}
}
/// Runs `operation` unless the token is, or becomes, cancelled first.
///
/// A token that is already cancelled short-circuits before `operation` is polled at all.
/// Otherwise `operation` is raced against the token's cancellation and whichever completes
/// first decides the result.
///
/// # Errors
///
/// Returns `cancelled_error(cancelled_message)` when cancellation wins the race, or the
/// error produced by `operation` itself.
pub async fn race_pre_stream_cancellation<T, F>(
    cancellation_token: &CancellationToken,
    cancelled_message: &'static str,
    operation: F,
) -> Result<T, swink_agent::AssistantMessageEvent>
where
    F: Future<Output = Result<T, swink_agent::AssistantMessageEvent>>,
{
    // Single place that knows how a cancellation is reported.
    let aborted = move || cancelled_error(cancelled_message);

    if cancellation_token.is_cancelled() {
        return Err(aborted());
    }

    tokio::select! {
        () = cancellation_token.cancelled() => Err(aborted()),
        outcome = operation => outcome,
    }
}
pub async fn read_error_body_or_cancelled(
mut response: reqwest::Response,
cancellation_token: &CancellationToken,
cancelled_message: &'static str,
) -> Result<String, swink_agent::AssistantMessageEvent> {
let mut bytes = Vec::new();
let mut truncated = false;
loop {
tokio::select! {
biased;
() = cancellation_token.cancelled() => {
return Err(cancelled_error(cancelled_message));
}
chunk = response.chunk() => {
match chunk {
Ok(Some(chunk)) => {
let remaining = MAX_ERROR_BODY_BYTES.saturating_sub(bytes.len());
if remaining == 0 {
truncated = true;
break;
}
if remaining > 0 {
let take = remaining.min(chunk.len());
bytes.extend_from_slice(&chunk[..take]);
}
if chunk.len() > remaining {
truncated = true;
break;
}
}
Ok(None) | Err(_) => break,
}
}
}
}
let mut body = String::from_utf8_lossy(&bytes).into_owned();
if truncated {
body.push_str("...[truncated]");
}
Ok(body)
}
// Unit tests for the adapter helpers. The two `read_error_body_*` tests talk to a throwaway
// TCP listener on 127.0.0.1 that hand-writes a raw HTTP/1.1 response.
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
use tokio::net::TcpListener;
// A single trailing slash is removed from the stored base URL.
#[test]
fn trailing_slash_stripped() {
let base = AdapterBase::new("https://api.example.com/", "key");
assert_eq!(base.base_url, "https://api.example.com");
}
// All trailing slashes are removed, not just the last one.
#[test]
fn multiple_trailing_slashes_stripped() {
let base = AdapterBase::new("https://api.example.com///", "key");
assert_eq!(base.base_url, "https://api.example.com");
}
// A URL without a trailing slash is stored unchanged.
#[test]
fn no_trailing_slash_unchanged() {
let base = AdapterBase::new("https://api.example.com", "key");
assert_eq!(base.base_url, "https://api.example.com");
}
// `pre_stream_error` yields `Start` first and the supplied event second.
#[test]
fn pre_stream_error_prefixes_start() {
let events = pre_stream_error(swink_agent::AssistantMessageEvent::error("boom"));
assert!(matches!(
events,
[
swink_agent::AssistantMessageEvent::Start,
swink_agent::AssistantMessageEvent::Error { .. }
]
));
}
// `cancelled_error` produces an `Error` event whose stop reason is `Aborted`.
#[test]
fn cancelled_error_uses_aborted_stop_reason() {
let event = cancelled_error("cancelled");
assert!(matches!(
event,
swink_agent::AssistantMessageEvent::Error {
stop_reason: swink_agent::StopReason::Aborted,
..
}
));
}
// With the token cancelled up front, the result is the aborted error even though the
// operation would have succeeded immediately.
#[tokio::test]
async fn race_pre_stream_cancellation_short_circuits() {
let token = CancellationToken::new();
token.cancel();
let result =
race_pre_stream_cancellation(&token, "cancelled", async { Ok::<_, _>("ok") }).await;
assert!(matches!(
result,
Err(swink_agent::AssistantMessageEvent::Error {
stop_reason: swink_agent::StopReason::Aborted,
..
})
));
}
// Cancelling while the body is still incomplete must yield the aborted error.
#[tokio::test]
async fn read_error_body_returns_aborted_when_cancelled_mid_body() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
// Fake server: announces 128 body bytes but sends only "partial", then keeps the socket
// open for 30 seconds so the body never completes on its own.
tokio::spawn(async move {
if let Ok((mut socket, _)) = listener.accept().await {
let mut request = [0u8; 1024];
let _ = socket.read(&mut request).await;
let response = concat!(
"HTTP/1.1 500 Internal Server Error\r\n",
"Content-Length: 128\r\n\r\n",
"partial",
);
let _ = socket.write_all(response.as_bytes()).await;
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
}
});
let response = reqwest::Client::new()
.get(format!("http://{addr}/"))
.send()
.await
.unwrap();
let token = CancellationToken::new();
let cancel = token.clone();
// Cancel shortly after the read starts, while it is waiting for the rest of the body.
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(25)).await;
cancel.cancel();
});
let result = read_error_body_or_cancelled(response, &token, "cancelled").await;
assert!(matches!(
result,
Err(swink_agent::AssistantMessageEvent::Error {
stop_reason: swink_agent::StopReason::Aborted,
..
})
));
}
// A body larger than the cap is cut to exactly `MAX_ERROR_BODY_BYTES` bytes and suffixed
// with the truncation marker.
#[tokio::test]
async fn read_error_body_is_size_bounded() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
// 16 bytes over the cap, all single-byte characters so byte and char lengths agree.
let body = "x".repeat(MAX_ERROR_BODY_BYTES + 16);
// Fake server: sends the header with a matching Content-Length, then the whole oversized body.
tokio::spawn(async move {
if let Ok((mut socket, _)) = listener.accept().await {
let mut request = [0u8; 1024];
let _ = socket.read(&mut request).await;
let header = format!(
"HTTP/1.1 500 Internal Server Error\r\nContent-Length: {}\r\n\r\n",
body.len()
);
let _ = socket.write_all(header.as_bytes()).await;
let _ = socket.write_all(body.as_bytes()).await;
}
});
let response = reqwest::Client::new()
.get(format!("http://{addr}/"))
.send()
.await
.unwrap();
// The token is never cancelled here, so only the size cap can end the read early.
let token = CancellationToken::new();
let body = read_error_body_or_cancelled(response, &token, "cancelled")
.await
.unwrap();
assert_eq!(body.len(), MAX_ERROR_BODY_BYTES + "...[truncated]".len());
assert!(body.ends_with("...[truncated]"));
}
}