use crate::agent::messages::{Content, Message, Role};
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::{self, BoxStream, StreamExt};
use serde::{Deserialize, Serialize};
use tokio::time::{sleep, Duration};
use crate::tools::Tool;
/// One incremental item yielded by a [`Provider`] response stream.
#[derive(Debug, Clone, PartialEq)]
pub enum ResponseChunk {
/// A fragment of assistant text, in arrival order.
TextDelta(String),
/// A fragment of a tool call's JSON input. `id` is the tool-call id known to the
/// provider when the fragment arrived (it may be empty if none was known yet).
ToolUseInputDelta { id: String, input_json: String },
/// The assembled message for the whole response.
MessageDone(Message),
}
/// A chat backend that can stream a model response for a conversation.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Starts a request for `messages` (offering `tools`) and returns a stream of
    /// [`ResponseChunk`]s.
    ///
    /// # Errors
    /// Returns an error when the request itself cannot be started; failures that
    /// happen later are delivered as `Err` items inside the stream.
    async fn stream_messages(
        &self,
        messages: &[Message],
        tools: &[Box<dyn Tool>],
    ) -> Result<BoxStream<'static, Result<ResponseChunk>>>;

    /// Whether this provider returns canned output. Defaults to `false`.
    fn is_static(&self) -> bool {
        false
    }

    /// Non-streaming convenience wrapper around [`Provider::stream_messages`].
    ///
    /// Opening the stream is attempted up to three times; a failure whose text
    /// contains `429` is retried after sleeping 2s, then 4s. Once a stream is
    /// open it is drained and the last `MessageDone` message is returned. If no
    /// `MessageDone` arrived, the concatenated text deltas are returned as a
    /// single assistant text message.
    ///
    /// # Errors
    /// Propagates the first error from opening or reading the stream, and fails
    /// when the stream ends with neither a `MessageDone` nor any text.
    #[allow(dead_code)]
    async fn send_messages(
        &self,
        messages: &[Message],
        tools: &[Box<dyn Tool>],
    ) -> Result<Message> {
        const MAX_ATTEMPTS: u32 = 3;

        // Phase 1: open the stream, backing off exponentially on rate limiting.
        let mut attempt: u32 = 0;
        let mut chunks = loop {
            attempt += 1;
            match self.stream_messages(messages, tools).await {
                Ok(opened) => break opened,
                Err(err) if attempt < MAX_ATTEMPTS && err.to_string().contains("429") => {
                    sleep(Duration::from_secs(2_u64.pow(attempt))).await;
                }
                Err(err) => return Err(err),
            }
        };

        // Phase 2: drain the stream, remembering the final message and any text.
        let mut final_message: Option<Message> = None;
        let mut text = String::new();
        while let Some(item) = chunks.next().await {
            match item? {
                ResponseChunk::MessageDone(message) => final_message = Some(message),
                ResponseChunk::TextDelta(delta) => text.push_str(&delta),
                ResponseChunk::ToolUseInputDelta { .. } => {}
            }
        }

        if let Some(message) = final_message {
            return Ok(message);
        }
        if text.is_empty() {
            return Err(anyhow::anyhow!(
                "Stream ended without MessageDone or content"
            ));
        }
        Ok(Message {
            role: Role::Assistant,
            content: vec![Content::Text { text }],
        })
    }
}
/// Upper bound (8 MiB) on bytes held while waiting for an SSE event boundary.
/// Both providers in this file emit an error instead of buffering past it.
const MAX_SSE_BUFFER_BYTES: usize = 8 * 1024 * 1024;
/// Parses the accumulated JSON text of a tool call's input.
///
/// Empty or whitespace-only input is treated as an empty JSON object.
///
/// # Errors
/// Returns the `serde_json` error message when `acc` is not valid JSON.
fn parse_tool_input(acc: &str) -> Result<serde_json::Value, String> {
    match acc.trim() {
        "" => Ok(serde_json::Value::Object(serde_json::Map::new())),
        _ => serde_json::from_str(acc).map_err(|err| err.to_string()),
    }
}
/// Removes every complete SSE event from the front of `buffer` and returns them.
///
/// An event is complete once it is followed by a blank line. The SSE format
/// allows lines to end in LF, CRLF or CR, so `\n\n`, `\r\n\r\n` and `\r\r` are
/// all accepted as boundaries. Each returned string still includes its
/// terminator, and bytes after the last boundary stay in `buffer` for the next
/// call. Decoding is lossy but happens per complete event, so a multi-byte
/// character split across network chunks is decoded intact once its event ends.
fn drain_sse_events(buffer: &mut Vec<u8>) -> Vec<String> {
    // Longest terminator first, so a CRLF boundary is consumed as a whole.
    const TERMINATORS: [&[u8]; 3] = [b"\r\n\r\n", b"\n\n", b"\r\r"];

    let mut blocks = Vec::new();
    let mut consumed = 0; // end of the last complete event
    let mut cursor = 0;
    while cursor < buffer.len() {
        let rest = &buffer[cursor..];
        match TERMINATORS
            .iter()
            .copied()
            .find(|terminator| rest.starts_with(terminator))
        {
            Some(terminator) => {
                let end = cursor + terminator.len();
                blocks.push(String::from_utf8_lossy(&buffer[consumed..end]).into_owned());
                consumed = end;
                cursor = end;
            }
            None => cursor += 1,
        }
    }
    // One drain for all events instead of shifting the buffer once per event.
    buffer.drain(..consumed);
    blocks
}
/// Turns an accumulated `(id, name, raw JSON input)` tool call into a
/// `Content::ToolUse` entry appended to `full_content`.
///
/// `None` is a no-op. Input that fails to parse is logged to stderr and
/// replaced with an empty JSON object, so the tool call itself is still kept.
fn finalize_tool(tool: Option<(String, String, String)>, full_content: &mut Vec<Content>) {
    let Some((id, name, raw_input)) = tool else {
        return;
    };
    let input = match parse_tool_input(&raw_input) {
        Ok(value) => value,
        Err(reason) => {
            eprintln!(
                "WARNING: malformed tool_use input JSON for tool '{}' (id {}): {}; using empty object",
                name, id, reason
            );
            serde_json::Value::Object(serde_json::Map::new())
        }
    };
    full_content.push(Content::ToolUse { id, name, input });
}
/// Provider that makes no network calls and always streams the same fixed greeting.
pub struct StaticProvider;
#[async_trait]
impl Provider for StaticProvider {
    /// Ignores its inputs and yields two text deltas followed by a
    /// `MessageDone` carrying the full greeting.
    async fn stream_messages(
        &self,
        _messages: &[Message],
        _tools: &[Box<dyn Tool>],
    ) -> Result<BoxStream<'static, Result<ResponseChunk>>> {
        const PARTS: [&str; 2] = [
            "I am a Rust-powered assistant. ",
            "How can I help you today?",
        ];
        let mut chunks: Vec<Result<ResponseChunk>> = PARTS
            .iter()
            .map(|part| Ok(ResponseChunk::TextDelta((*part).to_string())))
            .collect();
        chunks.push(Ok(ResponseChunk::MessageDone(Message::assistant(
            "I am a Rust-powered assistant. How can I help you today?",
        ))));
        Ok(stream::iter(chunks).boxed())
    }

    /// Always `true`: this provider's output is canned.
    fn is_static(&self) -> bool {
        true
    }
}
/// Tool description as serialized into an Anthropic request.
#[derive(Debug, Serialize)]
struct AnthropicTool {
// Filled from `Tool::name()`.
name: String,
// Filled from `Tool::description()`.
description: String,
// Filled from `Tool::input_schema()`.
input_schema: serde_json::Value,
}
/// JSON body POSTed to the Anthropic `/messages` endpoint.
#[derive(Debug, Serialize)]
struct AnthropicRequest {
model: String,
messages: Vec<Message>,
max_tokens: u32,
// Left out of the JSON entirely when no tools are offered.
#[serde(skip_serializing_if = "Vec::is_empty")]
tools: Vec<AnthropicTool>,
stream: bool,
}
/// Error envelope that the Anthropic provider tries to parse from the body of
/// a non-success HTTP response.
#[derive(Debug, Deserialize)]
struct AnthropicErrorResponse {
error: AnthropicErrorDetail,
}
/// The `error` object of an Anthropic error payload.
#[derive(Debug, Deserialize)]
struct AnthropicErrorDetail {
// JSON key is `type`, which is a Rust keyword, hence the rename.
#[serde(rename = "type")]
error_type: String,
message: String,
}
/// Payload of one Anthropic SSE `data:` line, selected by its JSON `type`
/// field (variant names map to snake_case, e.g. `content_block_delta`).
// NOTE(review): `dead_code` is presumably allowed because several fields are
// only deserialized and never read — confirm.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicSseEvent {
MessageStart {
message: AnthropicMessageStart,
},
ContentBlockStart {
index: usize,
// Kept as raw JSON; the consumer inspects its `type`, `id` and `name` keys.
content_block: serde_json::Value,
},
ContentBlockDelta {
index: usize,
delta: AnthropicDelta,
},
ContentBlockStop {
index: usize,
},
MessageDelta {
delta: serde_json::Value,
usage: serde_json::Value,
},
MessageStop,
Ping,
Error {
error: AnthropicErrorDetail,
},
}
/// The `message` object carried by a `message_start` event; only the fields
/// listed here are deserialized.
// NOTE(review): `dead_code` is presumably allowed because `id` and `model` are
// never read in this file — confirm.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct AnthropicMessageStart {
id: String,
role: Role,
model: String,
}
/// The `delta` object of a `content_block_delta` event, selected by its JSON
/// `type` field. Any other delta type fails to deserialize.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicDelta {
// Matches `"type": "text_delta"`.
TextDelta {
text: String,
},
// Matches `"type": "input_json_delta"`: a fragment of a tool call's JSON input.
#[serde(rename = "input_json_delta")]
InputDelta {
partial_json: String,
},
}
/// JSON body POSTed to an OpenAI-compatible `/chat/completions` endpoint.
#[derive(Debug, Serialize)]
struct OpenAiRequest {
model: String,
messages: Vec<OpenAiMessage>,
// Left out of the JSON entirely when no tools are offered.
#[serde(skip_serializing_if = "Vec::is_empty")]
tools: Vec<OpenAiTool>,
stream: bool,
}
/// One chat message in OpenAI wire format. Fields set to `None` are not serialized.
#[derive(Debug, Serialize)]
struct OpenAiMessage {
// `map_messages` emits "assistant", "user" or "tool".
role: String,
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
// Set only on assistant messages that request tool calls.
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<OpenAiToolCall>>,
// Set only on "tool" messages; names the call being answered.
#[serde(skip_serializing_if = "Option::is_none")]
tool_call_id: Option<String>,
}
/// A tool call attached to an assistant message in OpenAI wire format.
#[derive(Debug, Serialize)]
struct OpenAiToolCall {
id: String,
// Serialized as `type`; `map_messages` always sets it to "function".
#[serde(rename = "type")]
kind: String,
function: OpenAiToolCallFn,
}
/// Function name and arguments of an [`OpenAiToolCall`].
#[derive(Debug, Serialize)]
struct OpenAiToolCallFn {
name: String,
// The arguments as a JSON-encoded string, not a nested object.
arguments: String,
}
/// A tool offered to the model in an OpenAI-format request.
#[derive(Debug, Serialize)]
struct OpenAiTool {
// Serialized as `type`; `map_tools` always sets it to "function".
#[serde(rename = "type")]
kind: String,
function: OpenAiFunction,
}
/// Function declaration inside an [`OpenAiTool`].
#[derive(Debug, Serialize)]
struct OpenAiFunction {
name: String,
description: String,
// Filled from `Tool::input_schema()`.
parameters: serde_json::Value,
}
/// Converts internal messages into OpenAI chat messages.
///
/// * An assistant message becomes at most one `assistant` message: its text
///   parts are concatenated and its tool uses become `tool_calls`. Tool
///   results inside it are ignored, and a message with neither text nor tool
///   uses is dropped.
/// * A user message is expanded part by part, in order: text becomes a `user`
///   message and a tool result becomes a `tool` message. Tool uses inside it
///   are ignored.
fn map_messages(messages: &[Message]) -> Vec<OpenAiMessage> {
    let mut mapped = Vec::new();
    for message in messages {
        match message.role {
            Role::Assistant => {
                let mut text: Option<String> = None;
                let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();
                for part in &message.content {
                    match part {
                        Content::Text { text: fragment } => match text.as_mut() {
                            Some(existing) => existing.push_str(fragment),
                            None => text = Some(fragment.clone()),
                        },
                        Content::ToolUse { id, name, input } => {
                            tool_calls.push(OpenAiToolCall {
                                id: id.clone(),
                                kind: "function".into(),
                                function: OpenAiToolCallFn {
                                    name: name.clone(),
                                    arguments: input.to_string(),
                                },
                            });
                        }
                        Content::ToolResult { .. } => {}
                    }
                }
                if text.is_none() && tool_calls.is_empty() {
                    continue;
                }
                mapped.push(OpenAiMessage {
                    role: "assistant".into(),
                    content: text,
                    tool_calls: (!tool_calls.is_empty()).then_some(tool_calls),
                    tool_call_id: None,
                });
            }
            Role::User => {
                mapped.extend(message.content.iter().filter_map(|part| match part {
                    Content::Text { text: body } => Some(OpenAiMessage {
                        role: "user".into(),
                        content: Some(body.clone()),
                        tool_calls: None,
                        tool_call_id: None,
                    }),
                    Content::ToolResult {
                        tool_use_id,
                        content,
                        ..
                    } => Some(OpenAiMessage {
                        role: "tool".into(),
                        content: Some(content.clone()),
                        tool_calls: None,
                        tool_call_id: Some(tool_use_id.clone()),
                    }),
                    Content::ToolUse { .. } => None,
                }));
            }
        }
    }
    mapped
}
/// Describes each tool as an OpenAI `function` tool declaration, preserving order.
fn map_tools(tools: &[Box<dyn Tool>]) -> Vec<OpenAiTool> {
    let mut declarations = Vec::with_capacity(tools.len());
    for tool in tools {
        let function = OpenAiFunction {
            name: tool.name().to_string(),
            description: tool.description().to_string(),
            parameters: tool.input_schema(),
        };
        declarations.push(OpenAiTool {
            kind: "function".into(),
            function,
        });
    }
    declarations
}
/// One streamed `data:` payload from an OpenAI-compatible endpoint.
#[derive(Debug, Deserialize)]
struct OpenAiStreamChunk {
// Defaults to empty when the key is absent.
#[serde(default)]
choices: Vec<OpenAiChoice>,
}
/// One entry of a stream chunk's `choices` array; both fields are optional on the wire.
#[derive(Debug, Deserialize)]
struct OpenAiChoice {
#[serde(default)]
delta: OpenAiStreamDelta,
// Any non-null value makes the consumer finalize the message.
#[serde(default)]
finish_reason: Option<String>,
}
/// Incremental content of a streamed choice: a text fragment and/or tool-call fragments.
#[derive(Debug, Default, Deserialize)]
struct OpenAiStreamDelta {
#[serde(default)]
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}
/// A fragment of one streamed tool call.
#[derive(Debug, Deserialize)]
struct OpenAiToolCallDelta {
// Required. Used as the accumulator slot that this fragment is merged into.
index: usize,
#[serde(default)]
id: Option<String>,
#[serde(default)]
function: Option<OpenAiFnDelta>,
}
/// Function part of a streamed tool-call fragment.
#[derive(Debug, Default, Deserialize)]
struct OpenAiFnDelta {
#[serde(default)]
name: Option<String>,
// A piece of the JSON argument string; the consumer concatenates the pieces.
#[serde(default)]
arguments: Option<String>,
}
/// Connection settings consumed by `OpenAiCompatibleProvider::new`.
pub struct OpenAiSettings {
// Prefix to which `/chat/completions` is appended when building the request URL.
pub base_url: String,
// Sent as `Bearer <api_key>` in the `authorization` header.
pub api_key: String,
pub model: String,
}
/// Streamed tool-call fragments whose `index` is at or above this cap are
/// dropped with a warning, which bounds how far the accumulator vector grows.
const MAX_TOOL_CALL_SLOTS: usize = 64;
/// [`Provider`] for endpoints that speak the OpenAI chat-completions streaming protocol.
pub struct OpenAiCompatibleProvider {
client: reqwest::Client,
// Prefix to which `/chat/completions` is appended.
base_url: String,
api_key: String,
model: String,
}
impl OpenAiCompatibleProvider {
    /// Builds a provider from `s` with a fresh HTTP client.
    ///
    /// Trailing `/` characters are removed from `base_url`, so a value such as
    /// `https://host/v1/` does not produce `https://host/v1//chat/completions`
    /// when the request URL is assembled.
    pub fn new(s: OpenAiSettings) -> Self {
        Self {
            client: reqwest::Client::new(),
            base_url: s.base_url.trim_end_matches('/').to_string(),
            api_key: s.api_key,
            model: s.model,
        }
    }
}
/// Mutable state threaded through the OpenAI response stream.
struct OaiState {
// Raw response body chunks.
src: BoxStream<'static, Result<Vec<u8>>>,
// Bytes received but not yet split into complete SSE events.
buffer: Vec<u8>,
// Content of the assistant message assembled so far.
full_content: Vec<Content>,
// Per-index tool-call accumulators: (id, name, concatenated JSON arguments).
tool_accs: Vec<(String, String, String)>,
// Items ready to be handed to the consumer, in order.
pending: std::collections::VecDeque<Result<ResponseChunk>>,
// Set once `finalize` has queued `MessageDone`.
done: bool,
// Set once `src` has ended or the stream was aborted with an error.
src_done: bool,
}
impl OaiState {
fn finalize(&mut self) {
if self.done {
return;
}
for acc in std::mem::take(&mut self.tool_accs) {
if acc.0.is_empty() && acc.1.is_empty() {
continue;
}
finalize_tool(Some(acc), &mut self.full_content);
}
self.pending
.push_back(Ok(ResponseChunk::MessageDone(Message {
role: Role::Assistant,
content: self.full_content.clone(),
})));
self.done = true;
}
fn process_buffer(&mut self) {
if self.done {
return;
}
for block in drain_sse_events(&mut self.buffer) {
if self.done {
continue;
}
for line in block.lines() {
let Some(rest) = line.strip_prefix("data:") else {
continue;
};
let data = rest.trim_start();
if data.trim() == "[DONE]" {
self.finalize();
continue;
}
let Ok(parsed) = serde_json::from_str::<OpenAiStreamChunk>(data) else {
continue;
};
let Some(choice) = parsed.choices.into_iter().next() else {
continue;
};
if let Some(text) = choice.delta.content {
if let Some(Content::Text { text: existing }) = self.full_content.last_mut() {
existing.push_str(&text);
} else {
self.full_content.push(Content::Text { text: text.clone() });
}
self.pending.push_back(Ok(ResponseChunk::TextDelta(text)));
}
if let Some(tcs) = choice.delta.tool_calls {
for tc in tcs {
if tc.index >= MAX_TOOL_CALL_SLOTS {
eprintln!(
"WARNING: tool_call index {} exceeds cap {}; dropping",
tc.index, MAX_TOOL_CALL_SLOTS
);
continue;
}
if self.tool_accs.len() <= tc.index {
self.tool_accs.resize(
tc.index + 1,
(String::new(), String::new(), String::new()),
);
}
let slot = &mut self.tool_accs[tc.index];
if let Some(id) = tc.id {
if !id.is_empty() {
slot.0 = id;
}
}
if let Some(f) = tc.function {
if let Some(name) = f.name {
if !name.is_empty() {
slot.1 = name;
}
}
if let Some(args) = f.arguments {
if slot.0.is_empty() && slot.1.is_empty() {
eprintln!(
"WARNING: tool_call args fragment arrived at slot {} before id/name; skipping",
tc.index
);
continue;
}
slot.2.push_str(&args);
self.pending.push_back(Ok(ResponseChunk::ToolUseInputDelta {
id: slot.0.clone(),
input_json: args,
}));
}
}
}
}
if choice.finish_reason.is_some() {
self.finalize();
}
}
}
}
}
#[async_trait]
impl Provider for OpenAiCompatibleProvider {
    /// POSTs a streaming chat-completions request and adapts the SSE response
    /// into [`ResponseChunk`]s.
    ///
    /// The returned stream ends as soon as the terminal item has been
    /// delivered, without waiting for the server to close the connection.
    ///
    /// # Errors
    /// Fails when the request cannot be sent or the status is not a success
    /// (the status and body are included in the message). Network failures and
    /// an SSE buffer exceeding `MAX_SSE_BUFFER_BYTES` are reported as `Err`
    /// items in the stream, after which the stream ends.
    async fn stream_messages(
        &self,
        messages: &[Message],
        tools: &[Box<dyn Tool>],
    ) -> Result<BoxStream<'static, Result<ResponseChunk>>> {
        let url = format!("{}/chat/completions", self.base_url);
        let request = OpenAiRequest {
            model: self.model.clone(),
            messages: map_messages(messages),
            tools: map_tools(tools),
            stream: true,
        };
        let response = self
            .client
            .post(&url)
            .header("authorization", format!("Bearer {}", self.api_key))
            .header("content-type", "application/json")
            .json(&request)
            .send()
            .await?;
        if !response.status().is_success() {
            let status = response.status();
            // Best effort: an unreadable body still yields an error carrying the status.
            let body = response.text().await.unwrap_or_default();
            return Err(anyhow::anyhow!("OpenAI API Error [{}]: {}", status, body));
        }
        let st = OaiState {
            src: response
                .bytes_stream()
                .map(|r| {
                    r.map(|b| b.to_vec())
                        .map_err(|e| anyhow::anyhow!("Network error: {}", e))
                })
                .boxed(),
            buffer: Vec::new(),
            full_content: Vec::new(),
            tool_accs: Vec::new(),
            pending: std::collections::VecDeque::new(),
            done: false,
            src_done: false,
        };
        let out = stream::unfold(st, |mut st| async move {
            loop {
                // Deliver everything already queued before reading more bytes.
                if let Some(item) = st.pending.pop_front() {
                    return Some((item, st));
                }
                // Once the response is finalized nothing further is emitted, so
                // stop here rather than polling the body until the server closes it.
                if st.done || st.src_done {
                    return None;
                }
                match st.src.next().await {
                    Some(Ok(chunk)) => {
                        if st.buffer.len() + chunk.len() > MAX_SSE_BUFFER_BYTES {
                            st.pending.push_back(Err(anyhow::anyhow!(
                                "SSE buffer would exceed {} bytes without an event boundary; aborting to avoid OOM (limit: 8 MiB)",
                                MAX_SSE_BUFFER_BYTES
                            )));
                            st.src_done = true;
                        } else {
                            st.buffer.extend_from_slice(&chunk);
                            st.process_buffer();
                        }
                    }
                    Some(Err(e)) => {
                        st.pending.push_back(Err(e));
                        st.src_done = true;
                    }
                    None => {
                        st.src_done = true;
                        // Flush a trailing event that lacked its blank-line terminator.
                        if !st.buffer.is_empty() {
                            st.buffer.extend_from_slice(b"\n\n");
                            st.process_buffer();
                        }
                        st.finalize();
                    }
                }
            }
        });
        Ok(Box::pin(out))
    }
}
/// [`Provider`] backed by the Anthropic Messages API.
pub struct AnthropicProvider {
client: reqwest::Client,
// Sent in the `x-api-key` header.
api_key: String,
model: String,
// Prefix to which `/messages` is appended.
base_url: String,
}
impl AnthropicProvider {
    /// Creates a provider that targets `https://api.anthropic.com/v1`.
    pub fn new(api_key: String, model: String) -> Self {
        let client = reqwest::Client::new();
        let base_url = String::from("https://api.anthropic.com/v1");
        Self {
            client,
            api_key,
            model,
            base_url,
        }
    }

    /// Test-only constructor that targets an arbitrary base URL (e.g. a mock server).
    #[cfg(test)]
    pub fn with_base_url(api_key: String, model: String, base_url: String) -> Self {
        let client = reqwest::Client::new();
        Self {
            client,
            api_key,
            model,
            base_url,
        }
    }
}
#[async_trait]
impl Provider for AnthropicProvider {
    /// POSTs a streaming Messages request and adapts the SSE response into
    /// [`ResponseChunk`]s.
    ///
    /// Text deltas and tool-input deltas are forwarded as they arrive, and
    /// `message_stop` yields a `MessageDone` with the assembled content. An SSE
    /// `error` event is surfaced as an `Err` item. Data lines that fail to
    /// deserialize are skipped.
    ///
    /// # Errors
    /// Fails when the request cannot be sent or the status is not a success
    /// (the status and the parsed or raw body are included in the message).
    /// Network failures and an SSE buffer exceeding `MAX_SSE_BUFFER_BYTES` are
    /// reported as `Err` items in the stream.
    async fn stream_messages(
        &self,
        messages: &[Message],
        tools: &[Box<dyn Tool>],
    ) -> Result<BoxStream<'static, Result<ResponseChunk>>> {
        let url = format!("{}/messages", self.base_url);
        let anthropic_tools: Vec<AnthropicTool> = tools
            .iter()
            .map(|t| AnthropicTool {
                name: t.name().to_string(),
                description: t.description().to_string(),
                input_schema: t.input_schema(),
            })
            .collect();
        let request = AnthropicRequest {
            model: self.model.clone(),
            messages: messages.to_vec(),
            max_tokens: 4096,
            tools: anthropic_tools,
            stream: true,
        };
        let response = self
            .client
            .post(&url)
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&request)
            .send()
            .await?;
        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await?;
            // Prefer the structured error; fall back to the raw body.
            if let Ok(error_res) = serde_json::from_str::<AnthropicErrorResponse>(&body) {
                return Err(anyhow::anyhow!(
                    "Anthropic API Error [{}] ({}): {}",
                    status,
                    error_res.error.error_type,
                    error_res.error.message
                ));
            }
            return Err(anyhow::anyhow!(
                "Anthropic API Error [{}]: Raw Body: {}",
                status,
                body
            ));
        }
        let bytes_stream = response.bytes_stream();
        // State owned by the closure below and carried across network chunks.
        let mut buffer: Vec<u8> = Vec::new();
        let mut full_content: Vec<Content> = Vec::new();
        let mut current_role = Role::Assistant;
        // Tool call being assembled: (id, name, accumulated JSON input).
        let mut current_tool: Option<(String, String, String)> = None;
        let output_stream = bytes_stream.flat_map(move |chunk_res| {
            let chunk = match chunk_res {
                Ok(c) => c,
                Err(e) => {
                    return stream::iter(vec![Err(anyhow::anyhow!("Network error: {}", e))]).boxed()
                }
            };
            if buffer.len() + chunk.len() > MAX_SSE_BUFFER_BYTES {
                return stream::iter(vec![Err(anyhow::anyhow!(
                    "SSE buffer would exceed {} bytes without an event boundary; aborting to avoid OOM (limit: 8 MiB)",
                    MAX_SSE_BUFFER_BYTES
                ))])
                .boxed();
            }
            buffer.extend_from_slice(&chunk);
            let mut chunks = Vec::new();
            for block in drain_sse_events(&mut buffer) {
                for line in block.lines() {
                    // Accept `data:` with or without the optional space, like
                    // the OpenAI path in this file.
                    let Some(rest) = line.strip_prefix("data:") else {
                        continue;
                    };
                    let data = rest.trim_start();
                    let Ok(event) = serde_json::from_str::<AnthropicSseEvent>(data) else {
                        continue;
                    };
                    match event {
                        AnthropicSseEvent::MessageStart { message } => {
                            current_role = message.role;
                        }
                        AnthropicSseEvent::ContentBlockStart { content_block, .. } => {
                            // Flush a previous tool whose content_block_stop never arrived.
                            finalize_tool(current_tool.take(), &mut full_content);
                            if content_block.get("type").and_then(|t| t.as_str())
                                == Some("tool_use")
                            {
                                let id = content_block
                                    .get("id")
                                    .and_then(|v| v.as_str())
                                    .unwrap_or_default()
                                    .to_string();
                                let name = content_block
                                    .get("name")
                                    .and_then(|v| v.as_str())
                                    .unwrap_or_default()
                                    .to_string();
                                current_tool = Some((id, name, String::new()));
                            }
                        }
                        AnthropicSseEvent::ContentBlockDelta { delta, .. } => match delta {
                            AnthropicDelta::TextDelta { text } => {
                                // Extend the trailing text part, or start a new one.
                                if let Some(Content::Text { text: existing }) =
                                    full_content.last_mut()
                                {
                                    existing.push_str(&text);
                                } else {
                                    full_content.push(Content::Text { text: text.clone() });
                                }
                                chunks.push(Ok(ResponseChunk::TextDelta(text)));
                            }
                            AnthropicDelta::InputDelta { partial_json } => {
                                let id = if let Some((id, _, acc)) = current_tool.as_mut() {
                                    acc.push_str(&partial_json);
                                    id.clone()
                                } else {
                                    String::new()
                                };
                                chunks.push(Ok(ResponseChunk::ToolUseInputDelta {
                                    id,
                                    input_json: partial_json,
                                }));
                            }
                        },
                        AnthropicSseEvent::ContentBlockStop { .. } => {
                            finalize_tool(current_tool.take(), &mut full_content);
                        }
                        AnthropicSseEvent::MessageStop => {
                            finalize_tool(current_tool.take(), &mut full_content);
                            let msg = Message {
                                role: current_role.clone(),
                                content: full_content.clone(),
                            };
                            chunks.push(Ok(ResponseChunk::MessageDone(msg)));
                        }
                        // Report a mid-stream failure instead of letting the
                        // response look like a truncated success.
                        AnthropicSseEvent::Error { error } => {
                            chunks.push(Err(anyhow::anyhow!(
                                "Anthropic stream error ({}): {}",
                                error.error_type,
                                error.message
                            )));
                        }
                        AnthropicSseEvent::MessageDelta { .. } | AnthropicSseEvent::Ping => {}
                    }
                }
            }
            stream::iter(chunks).boxed()
        });
        Ok(Box::pin(output_stream))
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::agent::messages::{Content, Role};
use mockito::Server;
use serde_json::json;
#[test]
fn test_parse_tool_input_empty_is_object() {
assert_eq!(parse_tool_input("").unwrap(), json!({}));
assert_eq!(parse_tool_input(" ").unwrap(), json!({}));
}
#[test]
fn test_parse_tool_input_valid_json() {
assert_eq!(
parse_tool_input(r#"{"path":"."}"#).unwrap(),
json!({"path":"."})
);
}
#[test]
fn test_parse_tool_input_malformed_is_err() {
assert!(parse_tool_input(r#"{"path":"#).is_err());
}
#[test]
fn test_drain_sse_events_handles_multibyte_split_across_chunks() {
let mut buf: Vec<u8> = Vec::new();
buf.extend_from_slice("data: caf".as_bytes());
buf.push(0xC3);
assert!(
drain_sse_events(&mut buf).is_empty(),
"no event before the boundary"
);
buf.push(0xA9);
buf.extend_from_slice(b"\n\n");
assert_eq!(
drain_sse_events(&mut buf),
vec!["data: café\n\n".to_string()]
);
assert!(buf.is_empty());
}
#[test]
fn test_drain_sse_events_multiple_events_and_remainder() {
let mut buf: Vec<u8> = b"event: a\n\nevent: b\n\nevent: c-incomplete".to_vec();
assert_eq!(
drain_sse_events(&mut buf),
vec!["event: a\n\n".to_string(), "event: b\n\n".to_string()]
);
assert_eq!(buf, b"event: c-incomplete".to_vec());
}
#[test]
fn test_finalize_tool_pushes_parsed_tooluse() {
let mut content: Vec<Content> = Vec::new();
finalize_tool(
Some(("id1".into(), "ls".into(), r#"{"path":"."}"#.into())),
&mut content,
);
assert_eq!(
content,
vec![Content::ToolUse {
id: "id1".into(),
name: "ls".into(),
input: json!({"path":"."}),
}]
);
}
#[test]
fn test_finalize_tool_empty_input_is_object() {
let mut content: Vec<Content> = Vec::new();
finalize_tool(Some(("id".into(), "n".into(), String::new())), &mut content);
assert_eq!(
content,
vec![Content::ToolUse {
id: "id".into(),
name: "n".into(),
input: json!({}),
}]
);
}
#[test]
fn test_finalize_tool_none_is_noop() {
let mut content: Vec<Content> = Vec::new();
finalize_tool(None, &mut content);
assert!(content.is_empty());
}
#[tokio::test]
async fn test_anthropic_provider_simple_response() {
let mut server = Server::new_async().await;
let url = server.url();
let sse_body =
"event: message_start\ndata: {\"type\": \"message_start\", \"message\": {\"id\": \"msg_123\", \"role\": \"assistant\", \"model\": \"claude-3-5-sonnet\"}}\n\n\
event: content_block_delta\ndata: {\"type\": \"content_block_delta\", \"index\": 0, \"delta\": {\"type\": \"text_delta\", \"text\": \"Hello from Mockito!\"}}\n\n\
event: message_stop\ndata: {\"type\": \"message_stop\"}\n\n";
let _m = server
.mock("POST", "/messages")
.with_status(200)
.with_header("content-type", "text/event-stream")
.with_body(sse_body)
.create_async()
.await;
let provider = AnthropicProvider::with_base_url(
"test_key".to_string(),
"claude-3-5-sonnet".to_string(),
url,
);
let messages = vec![Message::user("Hi")];
let response = provider.send_messages(&messages, &[]).await.unwrap();
assert_eq!(response.role, Role::Assistant);
if let Content::Text { text } = &response.content[0] {
assert_eq!(text, "Hello from Mockito!");
} else {
panic!("Expected text content");
}
}
#[tokio::test]
async fn test_anthropic_provider_tool_use() {
let mut server = Server::new_async().await;
let url = server.url();
let sse_body =
"event: message_start\ndata: {\"type\": \"message_start\", \"message\": {\"id\": \"msg_tool_1\", \"role\": \"assistant\", \"model\": \"claude-3-5-sonnet\"}}\n\n\
event: content_block_delta\ndata: {\"type\": \"content_block_delta\", \"index\": 0, \"delta\": {\"type\": \"text_delta\", \"text\": \"Listing \"}}\n\n\
event: content_block_delta\ndata: {\"type\": \"content_block_delta\", \"index\": 0, \"delta\": {\"type\": \"text_delta\", \"text\": \"files in .\"}}\n\n\
event: message_stop\ndata: {\"type\": \"message_stop\"}\n\n";
let _m = server
.mock("POST", "/messages")
.with_status(200)
.with_header("content-type", "text/event-stream")
.with_body(sse_body)
.create_async()
.await;
let provider = AnthropicProvider::with_base_url(
"test_key".to_string(),
"claude-3-5-sonnet".to_string(),
url,
);
let messages = vec![Message::user("List files")];
let response = provider.send_messages(&messages, &[]).await.unwrap();
assert_eq!(response.role, Role::Assistant);
assert_eq!(response.content.len(), 1);
if let Content::Text { text } = &response.content[0] {
assert_eq!(text, "Listing files in .");
} else {
panic!("Expected text content, got {:?}", response.content[0]);
}
}
#[tokio::test]
async fn test_anthropic_provider_invalid_key_error() {
let mut server = Server::new_async().await;
let url = server.url();
let _m = server
.mock("POST", "/messages")
.with_status(401)
.with_header("content-type", "application/json")
.with_body(
json!({
"type": "error",
"error": {
"type": "authentication_error",
"message": "invalid x-api-key"
}
})
.to_string(),
)
.create_async()
.await;
let provider = AnthropicProvider::with_base_url(
"invalid_key".to_string(),
"claude-3-5-sonnet".to_string(),
url,
);
let result = provider.send_messages(&[], &[]).await;
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.contains("authentication_error"),
"Error should mention auth error type"
);
assert!(
err_msg.contains("invalid x-api-key"),
"Error should contain the specific API message"
);
assert!(
err_msg.contains("401"),
"Error should contain the status code"
);
}
#[tokio::test]
async fn test_anthropic_provider_streaming_parsing() {
let mut server = Server::new_async().await;
let url = server.url();
let sse_body =
"event: message_start\ndata: {\"type\": \"message_start\", \"message\": {\"id\": \"msg_1\", \"role\": \"assistant\", \"content\": [], \"model\": \"claude-3\", \"stop_reason\": null, \"stop_sequence\": null, \"usage\": {\"input_tokens\": 1, \"output_tokens\": 1}}}\n\n\
event: content_block_start\ndata: {\"type\": \"content_block_start\", \"index\":0, \"content_block\": {\"type\": \"text\", \"text\": \"\"}}\n\n\
event: content_block_delta\ndata: {\"type\": \"content_block_delta\", \"index\":0, \"delta\": {\"type\": \"text_delta\", \"text\": \"Hello \"}}\n\n\
event: content_block_delta\ndata: {\"type\": \"content_block_delta\", \"index\":0, \"delta\": {\"type\": \"text_delta\", \"text\": \"world!\"}}\n\n\
event: message_stop\ndata: {\"type\": \"message_stop\"}\n\n";
let _m = server
.mock("POST", "/messages")
.with_status(200)
.with_header("content-type", "text/event-stream")
.with_body(sse_body)
.create_async()
.await;
let provider = AnthropicProvider::with_base_url(
"test_key".to_string(),
"claude-3-5-sonnet".to_string(),
url,
);
let mut stream = provider.stream_messages(&[], &[]).await.unwrap();
let mut full_text = String::new();
while let Some(chunk_result) = stream.next().await {
if let Ok(ResponseChunk::TextDelta(delta)) = chunk_result {
full_text.push_str(&delta);
}
}
assert_eq!(full_text, "Hello world!");
}
#[tokio::test]
async fn test_anthropic_provider_malformed_sse() {
let mut server = Server::new_async().await;
let url = server.url();
let sse_body =
"event: content_block_delta\ndata: {\"type\": \"content_block_delta\", \"index\":0, \"delta\": {\"type\": \"text_delta\", \"text\": \"Valid\"}}\n\n\
event: content_block_delta\ndata: {MALFORMED_JSON}\n\n\
event: message_stop\ndata: {\"type\": \"message_stop\"}\n\n";
let _m = server
.mock("POST", "/messages")
.with_status(200)
.with_header("content-type", "text/event-stream")
.with_body(sse_body)
.create_async()
.await;
let provider = AnthropicProvider::with_base_url(
"test_key".to_string(),
"claude-3-5-sonnet".to_string(),
url,
);
let mut stream = provider.stream_messages(&[], &[]).await.unwrap();
let mut full_text = String::new();
while let Some(chunk_result) = stream.next().await {
if let Ok(ResponseChunk::TextDelta(delta)) = chunk_result {
full_text.push_str(&delta);
}
}
assert_eq!(full_text, "Valid");
}
#[tokio::test]
async fn test_anthropic_provider_sse_buffer_cap_aborts_without_separator() {
let mut server = Server::new_async().await;
let url = server.url();
let oversized = "a".repeat(9 * 1024 * 1024);
let _m = server
.mock("POST", "/messages")
.with_status(200)
.with_header("content-type", "text/event-stream")
.with_body(oversized)
.create_async()
.await;
let provider = AnthropicProvider::with_base_url(
"test_key".to_string(),
"claude-3-5-sonnet".to_string(),
url,
);
let mut stream = provider.stream_messages(&[], &[]).await.unwrap();
let mut saw_error = false;
while let Some(chunk_result) = stream.next().await {
if let Err(e) = chunk_result {
let msg = e.to_string();
assert!(
msg.contains("buffer") || msg.contains("8 MiB") || msg.contains("limit"),
"error should mention the SSE buffer cap, got: {}",
msg
);
saw_error = true;
break;
}
}
assert!(
saw_error,
"oversized separator-less stream must abort with an error"
);
}
#[tokio::test]
async fn test_anthropic_provider_streaming_assembles_tool_use() {
let mut server = Server::new_async().await;
let url = server.url();
let sse_body = concat!(
"event: message_start\n",
"data: {\"type\": \"message_start\", \"message\": {\"id\": \"msg_tu\", \"role\": \"assistant\", \"model\": \"claude-3-5-sonnet\"}}\n\n",
"event: content_block_start\n",
"data: {\"type\": \"content_block_start\", \"index\": 0, \"content_block\": {\"type\": \"tool_use\", \"id\": \"toolu_abc\", \"name\": \"ls\", \"input\": {}}}\n\n",
"event: content_block_delta\n",
"data: {\"type\": \"content_block_delta\", \"index\": 0, \"delta\": {\"type\": \"input_json_delta\", \"partial_json\": \"{\\\"path\\\": \"}}\n\n",
"event: content_block_delta\n",
"data: {\"type\": \"content_block_delta\", \"index\": 0, \"delta\": {\"type\": \"input_json_delta\", \"partial_json\": \"\\\".\\\"}\"}}\n\n",
"event: content_block_stop\n",
"data: {\"type\": \"content_block_stop\", \"index\": 0}\n\n",
"event: message_stop\n",
"data: {\"type\": \"message_stop\"}\n\n",
);
let _m = server
.mock("POST", "/messages")
.with_status(200)
.with_header("content-type", "text/event-stream")
.with_body(sse_body)
.create_async()
.await;
let provider = AnthropicProvider::with_base_url(
"test_key".to_string(),
"claude-3-5-sonnet".to_string(),
url,
);
let response = provider
.send_messages(&[Message::user("list")], &[])
.await
.unwrap();
let tool = response.content.iter().find_map(|c| match c {
Content::ToolUse { id, name, input } => Some((id.clone(), name.clone(), input.clone())),
_ => None,
});
let (id, name, input) =
tool.expect("streaming response must assemble a ToolUse content block");
assert_eq!(id, "toolu_abc");
assert_eq!(name, "ls");
assert_eq!(input, serde_json::json!({"path": "."}));
let tool_count = response
.content
.iter()
.filter(|c| matches!(c, Content::ToolUse { .. }))
.count();
assert_eq!(
tool_count, 1,
"normal flow must assemble exactly one ToolUse (no double-push)"
);
}
#[tokio::test]
async fn test_missing_content_block_stop_does_not_drop_prior_tool() {
let mut server = Server::new_async().await;
let url = server.url();
let sse_body = concat!(
"event: message_start\n",
"data: {\"type\": \"message_start\", \"message\": {\"id\": \"m\", \"role\": \"assistant\", \"model\": \"x\"}}\n\n",
"event: content_block_start\n",
"data: {\"type\": \"content_block_start\", \"index\": 0, \"content_block\": {\"type\": \"tool_use\", \"id\": \"toolu_A\", \"name\": \"ls\", \"input\": {}}}\n\n",
"event: content_block_delta\n",
"data: {\"type\": \"content_block_delta\", \"index\": 0, \"delta\": {\"type\": \"input_json_delta\", \"partial_json\": \"{}\"}}\n\n",
"event: content_block_start\n",
"data: {\"type\": \"content_block_start\", \"index\": 1, \"content_block\": {\"type\": \"tool_use\", \"id\": \"toolu_B\", \"name\": \"view\", \"input\": {}}}\n\n",
"event: content_block_delta\n",
"data: {\"type\": \"content_block_delta\", \"index\": 1, \"delta\": {\"type\": \"input_json_delta\", \"partial_json\": \"{}\"}}\n\n",
"event: message_stop\n",
"data: {\"type\": \"message_stop\"}\n\n",
);
let _m = server
.mock("POST", "/messages")
.with_status(200)
.with_header("content-type", "text/event-stream")
.with_body(sse_body)
.create_async()
.await;
let provider = AnthropicProvider::with_base_url("k".to_string(), "x".to_string(), url);
let response = provider
.send_messages(&[Message::user("go")], &[])
.await
.unwrap();
let ids: Vec<String> = response
.content
.iter()
.filter_map(|c| match c {
Content::ToolUse { id, .. } => Some(id.clone()),
_ => None,
})
.collect();
assert_eq!(
ids,
vec!["toolu_A".to_string(), "toolu_B".to_string()],
"a missing content_block_stop must not drop the first tool"
);
}
#[tokio::test]
async fn test_tool_input_delta_chunk_carries_tool_id() {
let mut server = Server::new_async().await;
let url = server.url();
let sse_body = concat!(
"event: message_start\n",
"data: {\"type\": \"message_start\", \"message\": {\"id\": \"m\", \"role\": \"assistant\", \"model\": \"x\"}}\n\n",
"event: content_block_start\n",
"data: {\"type\": \"content_block_start\", \"index\": 0, \"content_block\": {\"type\": \"tool_use\", \"id\": \"toolu_x\", \"name\": \"ls\", \"input\": {}}}\n\n",
"event: content_block_delta\n",
"data: {\"type\": \"content_block_delta\", \"index\": 0, \"delta\": {\"type\": \"input_json_delta\", \"partial_json\": \"{}\"}}\n\n",
"event: message_stop\n",
"data: {\"type\": \"message_stop\"}\n\n",
);
let _m = server
.mock("POST", "/messages")
.with_status(200)
.with_header("content-type", "text/event-stream")
.with_body(sse_body)
.create_async()
.await;
let provider = AnthropicProvider::with_base_url("k".to_string(), "x".to_string(), url);
let mut stream = provider
.stream_messages(&[Message::user("go")], &[])
.await
.unwrap();
let mut delta_ids = Vec::new();
while let Some(chunk) = stream.next().await {
if let Ok(ResponseChunk::ToolUseInputDelta { id, .. }) = chunk {
delta_ids.push(id);
}
}
assert_eq!(
delta_ids,
vec!["toolu_x".to_string()],
"ToolUseInputDelta chunk must carry the tool id"
);
}
#[test]
fn test_map_user_and_assistant_text() {
let v = serde_json::to_value(map_messages(&[
Message::user("hi"),
Message::assistant("yo"),
]))
.unwrap();
assert_eq!(v[0]["role"], "user");
assert_eq!(v[0]["content"], "hi");
assert_eq!(v[1]["role"], "assistant");
assert_eq!(v[1]["content"], "yo");
}
#[test]
fn test_map_parallel_tooluse_coalesced_into_one_assistant_message() {
let msgs = vec![Message {
role: Role::Assistant,
content: vec![
Content::ToolUse {
id: "c1".into(),
name: "ls".into(),
input: json!({"path": "."}),
},
Content::ToolUse {
id: "c2".into(),
name: "view".into(),
input: json!({"path": "a"}),
},
],
}];
let v = serde_json::to_value(map_messages(&msgs)).unwrap();
assert_eq!(
v.as_array().unwrap().len(),
1,
"parallel tool calls must be ONE assistant message"
);
assert_eq!(v[0]["role"], "assistant");
assert_eq!(v[0]["tool_calls"].as_array().unwrap().len(), 2);
assert_eq!(v[0]["tool_calls"][0]["id"], "c1");
assert_eq!(v[0]["tool_calls"][0]["function"]["name"], "ls");
assert_eq!(v[0]["tool_calls"][1]["id"], "c2");
}
#[test]
fn test_map_assistant_text_plus_tooluse_one_message() {
let msgs = vec![Message {
role: Role::Assistant,
content: vec![
Content::Text {
text: "calling".into(),
},
Content::ToolUse {
id: "c1".into(),
name: "ls".into(),
input: json!({}),
},
],
}];
let v = serde_json::to_value(map_messages(&msgs)).unwrap();
assert_eq!(v.as_array().unwrap().len(), 1);
assert_eq!(v[0]["content"], "calling");
assert_eq!(v[0]["tool_calls"][0]["id"], "c1");
}
#[test]
fn test_map_toolresult_becomes_tool_role() {
let msgs = vec![Message {
role: Role::User,
content: vec![Content::ToolResult {
tool_use_id: "c1".into(),
content: "out".into(),
is_error: false,
}],
}];
let v = serde_json::to_value(map_messages(&msgs)).unwrap();
assert_eq!(v[0]["role"], "tool");
assert_eq!(v[0]["tool_call_id"], "c1");
assert_eq!(v[0]["content"], "out");
}
/// A user message holding text followed by a tool result is split into two
/// entries, and the text entry stays ahead of the tool entry.
#[test]
fn test_map_user_text_before_toolresult_preserves_order() {
    let leading_text = Content::Text { text: "ctx".into() };
    let result_block = Content::ToolResult {
        tool_use_id: "c1".into(),
        content: "out".into(),
        is_error: false,
    };
    let history = vec![Message {
        role: Role::User,
        content: vec![leading_text, result_block],
    }];

    let mapped = serde_json::to_value(map_messages(&history)).unwrap();
    let entries = mapped.as_array().unwrap();
    assert_eq!(entries.len(), 2);

    let (text_entry, tool_entry) = (&mapped[0], &mapped[1]);
    assert_eq!(text_entry["role"], "user");
    assert_eq!(text_entry["content"], "ctx");
    assert_eq!(tool_entry["role"], "tool");
    assert_eq!(tool_entry["tool_call_id"], "c1");
}
/// A 429 response followed by a successful SSE stream must be retried
/// transparently: `send_messages` has to return the recovered assistant text,
/// and each mocked endpoint must have been hit exactly once.
///
/// NOTE(review): if `AnthropicProvider` uses the trait's default
/// `send_messages`, the first back-off is 2 seconds, so this test is slow by
/// design.
#[tokio::test]
async fn test_anthropic_provider_retry_on_429() {
    let mut server = Server::new_async().await;
    let url = server.url();

    // First response: rate limited.
    // NOTE(review): this relies on mockito serving the earlier of two matching
    // mocks until its expected hit count is reached — confirm against the
    // mockito documentation.
    let rate_limited = server
        .mock("POST", "/messages")
        .with_status(429)
        .with_header("content-type", "application/json")
        .with_body(
            json!({
                "type": "error",
                "error": {
                    "type": "rate_limit_error",
                    "message": "Too many requests"
                }
            })
            .to_string(),
        )
        .expect(1)
        .create_async()
        .await;

    // Second response: a minimal successful stream with one text delta.
    let sse_body = concat!(
        "event: content_block_delta\ndata: {\"type\": \"content_block_delta\", \"index\":0, \"delta\": {\"type\": \"text_delta\", \"text\": \"Recovered!\"}}\n\n",
        "event: message_stop\ndata: {\"type\": \"message_stop\"}\n\n"
    );
    let recovered = server
        .mock("POST", "/messages")
        .with_status(200)
        .with_header("content-type", "text/event-stream")
        .with_body(sse_body)
        .expect(1)
        .create_async()
        .await;

    let provider =
        AnthropicProvider::with_base_url("test_key".to_string(), "test-model".to_string(), url);
    let response = provider.send_messages(&[], &[]).await.unwrap();
    assert_eq!(response.role, Role::Assistant);

    // Fail loudly when the first block is missing or is not text; an `if let`
    // here would let the test pass without checking anything.
    match response.content.first() {
        Some(Content::Text { text }) => assert_eq!(text, "Recovered!"),
        other => panic!("expected a leading text block, got {other:?}"),
    }

    // Verify the `.expect(1)` hit counts: exactly one rate-limited request
    // and exactly one successful retry.
    rate_limited.assert_async().await;
    recovered.assert_async().await;
}
/// Text deltas spread over several SSE events, terminated by a `stop` finish
/// reason and a `[DONE]` sentinel, are concatenated into one text block.
#[tokio::test]
async fn test_openai_streams_text_finalizes_on_done() {
    let mut server = Server::new_async().await;
    let base_url = server.url();
    let body = concat!(
        "data: {\"choices\":[{\"delta\":{\"content\":\"Hello \"}}]}\n\n",
        "data: {\"choices\":[{\"delta\":{\"content\":\"world!\"},\"finish_reason\":null}]}\n\n",
        "data: {\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n",
        "data: [DONE]\n\n"
    );
    let _mock = server
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_header("content-type", "text/event-stream")
        .with_body(body)
        .create_async()
        .await;

    let settings = OpenAiSettings {
        base_url,
        api_key: "k".into(),
        model: "m".into(),
    };
    let provider = OpenAiCompatibleProvider::new(settings);
    let reply = provider
        .send_messages(&[Message::user("hi")], &[])
        .await
        .unwrap();

    let expected = vec![Content::Text {
        text: "Hello world!".into(),
    }];
    assert_eq!(reply.content, expected);
}
/// A stream whose only event carries both text and a `stop` finish reason,
/// with no `[DONE]` sentinel, still yields the text as the final message.
#[tokio::test]
async fn test_openai_finalizes_without_done_sentinel() {
    let mut server = Server::new_async().await;
    let base_url = server.url();
    let body =
        "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"},\"finish_reason\":\"stop\"}]}\n\n";
    let _mock = server
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_header("content-type", "text/event-stream")
        .with_body(body)
        .create_async()
        .await;

    let settings = OpenAiSettings {
        base_url,
        api_key: "k".into(),
        model: "m".into(),
    };
    let provider = OpenAiCompatibleProvider::new(settings);
    let reply = provider
        .send_messages(&[Message::user("hi")], &[])
        .await
        .unwrap();

    let expected = vec![Content::Text { text: "hi".into() }];
    assert_eq!(reply.content, expected);
}
/// An unparsable `data:` payload between a valid delta and `[DONE]` must not
/// fail the request; the reply contains only the valid text.
#[tokio::test]
async fn test_openai_swallows_malformed_line() {
    let mut server = Server::new_async().await;
    let base_url = server.url();
    let body = concat!(
        "data: {\"choices\":[{\"delta\":{\"content\":\"Valid\"}}]}\n\n",
        "data: {MALFORMED}\n\n",
        "data: [DONE]\n\n"
    );
    let _mock = server
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_header("content-type", "text/event-stream")
        .with_body(body)
        .create_async()
        .await;

    let settings = OpenAiSettings {
        base_url,
        api_key: "k".into(),
        model: "m".into(),
    };
    let provider = OpenAiCompatibleProvider::new(settings);
    let reply = provider
        .send_messages(&[Message::user("hi")], &[])
        .await
        .unwrap();

    let expected = vec![Content::Text {
        text: "Valid".into(),
    }];
    assert_eq!(reply.content, expected);
}
/// A 401 response must surface as an error whose message mentions the status
/// code.
#[tokio::test]
async fn test_openai_http_error_surfaces() {
    let mut server = Server::new_async().await;
    let base_url = server.url();
    let _mock = server
        .mock("POST", "/chat/completions")
        .with_status(401)
        .with_body("{\"error\":{\"message\":\"bad key\"}}")
        .create_async()
        .await;

    let settings = OpenAiSettings {
        base_url,
        api_key: "k".into(),
        model: "m".into(),
    };
    let provider = OpenAiCompatibleProvider::new(settings);

    let failure = provider
        .send_messages(&[Message::user("hi")], &[])
        .await
        .unwrap_err();
    let rendered = failure.to_string();
    assert!(rendered.contains("401"));
}
/// Tool-call arguments delivered in two fragments for the same index are
/// joined into a single `ToolUse` with the id, name and parsed JSON input.
#[tokio::test]
async fn test_openai_assembles_fragmented_tool_call() {
    let mut server = Server::new_async().await;
    let base_url = server.url();
    let body = concat!(
        "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_x\",\"function\":{\"name\":\"ls\",\"arguments\":\"{\\\"path\\\":\"}}]}}]}\n\n",
        "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\\".\\\"}\"}}]}}]}\n\n",
        "data: {\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n",
        "data: [DONE]\n\n"
    );
    let _mock = server
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_header("content-type", "text/event-stream")
        .with_body(body)
        .create_async()
        .await;

    let settings = OpenAiSettings {
        base_url,
        api_key: "k".into(),
        model: "m".into(),
    };
    let provider = OpenAiCompatibleProvider::new(settings);
    let reply = provider
        .send_messages(&[Message::user("list")], &[])
        .await
        .unwrap();

    // Gather every tool call in the reply so the first one can be inspected
    // and the total can be checked afterwards.
    let mut assembled = Vec::new();
    for block in &reply.content {
        if let Content::ToolUse { id, name, input } = block {
            assembled.push((id.clone(), name.clone(), input.clone()));
        }
    }

    let (call_id, call_name, call_input) = assembled
        .first()
        .cloned()
        .expect("must assemble a ToolUse");
    assert_eq!(call_id, "call_x");
    assert_eq!(call_name, "ls");
    assert_eq!(call_input, json!({"path":"."}));
    assert_eq!(assembled.len(), 1);
}
/// A tool-call delta with an absurdly large `index` must not produce any
/// `ToolUse` block in the final message.
#[tokio::test]
async fn test_openai_bounds_tool_call_index() {
    let mut server = Server::new_async().await;
    let base_url = server.url();
    let body = concat!(
        "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":999999999,\"id\":\"x\",\"function\":{\"name\":\"ls\",\"arguments\":\"{}\"}}]}}]}\n\n",
        "data: {\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n",
        "data: [DONE]\n\n"
    );
    let _mock = server
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_header("content-type", "text/event-stream")
        .with_body(body)
        .create_async()
        .await;

    let settings = OpenAiSettings {
        base_url,
        api_key: "k".into(),
        model: "m".into(),
    };
    let provider = OpenAiCompatibleProvider::new(settings);
    let reply = provider
        .send_messages(&[Message::user("x")], &[])
        .await
        .unwrap();

    let tool_use_total = reply
        .content
        .iter()
        .filter(|block| matches!(block, Content::ToolUse { .. }))
        .count();
    assert_eq!(tool_use_total, 0);
}
/// When the body ends after a single text delta, with neither a finish
/// reason nor `[DONE]`, the stream must still emit exactly one `MessageDone`
/// holding that text.
#[tokio::test]
async fn test_openai_finalizes_on_stream_end_without_finish_or_done() {
    let mut server = Server::new_async().await;
    let base_url = server.url();
    let body = "data: {\"choices\":[{\"delta\":{\"content\":\"only\"}}]}\n\n";
    let _mock = server
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_header("content-type", "text/event-stream")
        .with_body(body)
        .create_async()
        .await;

    let settings = OpenAiSettings {
        base_url,
        api_key: "k".into(),
        model: "m".into(),
    };
    let provider = OpenAiCompatibleProvider::new(settings);
    let mut chunks = provider
        .stream_messages(&[Message::user("hi")], &[])
        .await
        .unwrap();

    // Keep only the final messages; every other chunk (and any error) is
    // ignored, as the assertions below concern `MessageDone` alone.
    let mut finals = Vec::new();
    while let Some(item) = chunks.next().await {
        match item {
            Ok(ResponseChunk::MessageDone(message)) => finals.push(message),
            _ => {}
        }
    }

    assert_eq!(finals.len(), 1, "exactly one MessageDone on stream end");
    let expected = vec![Content::Text {
        text: "only".into(),
    }];
    assert_eq!(finals[0].content, expected);
}
/// Events that arrive after the `stop` finish reason and `[DONE]` must not
/// leak out: exactly one `MessageDone` is emitted and no `TextDelta` follows
/// it.
#[tokio::test]
async fn test_openai_suppresses_post_stop_text_deltas() {
    /// Minimal classification of the chunks this test cares about.
    #[derive(PartialEq, Eq)]
    enum Seen {
        Text,
        Done,
    }

    let mut server = Server::new_async().await;
    let base_url = server.url();
    let body = concat!(
        "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"},\"finish_reason\":\"stop\"}]}\n\n",
        "data: [DONE]\n\n",
        "data: {\"choices\":[{\"delta\":{\"content\":\"ghost1\"}}]}\n\n",
        "data: {\"choices\":[{\"delta\":{\"content\":\"ghost2\"}}]}\n\n",
    );
    let _mock = server
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_header("content-type", "text/event-stream")
        .with_body(body)
        .create_async()
        .await;

    let settings = OpenAiSettings {
        base_url,
        api_key: "k".into(),
        model: "m".into(),
    };
    let provider = OpenAiCompatibleProvider::new(settings);
    let mut chunks = provider
        .stream_messages(&[Message::user("hi")], &[])
        .await
        .unwrap();

    // Record the arrival order of text deltas and final messages only.
    let mut timeline: Vec<Seen> = Vec::new();
    while let Some(item) = chunks.next().await {
        if let Ok(chunk) = item {
            match chunk {
                ResponseChunk::TextDelta(_) => timeline.push(Seen::Text),
                ResponseChunk::MessageDone(_) => timeline.push(Seen::Done),
                _ => {}
            }
        }
    }

    let done_count = timeline.iter().filter(|seen| **seen == Seen::Done).count();
    assert_eq!(done_count, 1, "exactly one MessageDone");

    let done_pos = timeline
        .iter()
        .position(|seen| *seen == Seen::Done)
        .expect("MessageDone present");

    // Everything strictly after the final message must not be a text delta.
    for (i, seen) in timeline.iter().enumerate().skip(done_pos + 1) {
        assert!(
            *seen != Seen::Text,
            "TextDelta at index {i} arrived after MessageDone at {done_pos}"
        );
    }
}
/// An arguments fragment for a tool-call index that never received an id or
/// name must be dropped: no `ToolUseInputDelta` is streamed and the final
/// message carries no `ToolUse`.
#[tokio::test]
async fn test_openai_skips_args_fragment_when_slot_empty() {
    let mut server = Server::new_async().await;
    let base_url = server.url();
    let body = concat!(
        "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"{\\\"orphan\\\":true}\"}}]}}]}\n\n",
        "data: {\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n",
        "data: [DONE]\n\n",
    );
    let _mock = server
        .mock("POST", "/chat/completions")
        .with_status(200)
        .with_header("content-type", "text/event-stream")
        .with_body(body)
        .create_async()
        .await;

    let settings = OpenAiSettings {
        base_url,
        api_key: "k".into(),
        model: "m".into(),
    };
    let provider = OpenAiCompatibleProvider::new(settings);
    let mut chunks = provider
        .stream_messages(&[Message::user("x")], &[])
        .await
        .unwrap();

    let mut input_deltas = 0usize;
    let mut tool_uses = 0usize;
    while let Some(item) = chunks.next().await {
        // Errors are ignored here; only successful chunks are tallied.
        if let Ok(chunk) = item {
            match chunk {
                ResponseChunk::ToolUseInputDelta { .. } => input_deltas += 1,
                ResponseChunk::MessageDone(message) => {
                    for block in &message.content {
                        if matches!(block, Content::ToolUse { .. }) {
                            tool_uses += 1;
                        }
                    }
                }
                _ => {}
            }
        }
    }

    assert_eq!(
        input_deltas, 0,
        "no ToolUseInputDelta emitted for orphan args fragment"
    );
    assert_eq!(
        tool_uses, 0,
        "no Content::ToolUse emitted for orphan args fragment"
    );
}
}