use async_trait::async_trait;
use futures::StreamExt;
use reqwest::Client;
use serde::Serialize;
use crate::error::ProviderError;
use crate::provider::Provider;
use crate::types::{
ContentBlock, FinishReason, ModelId, Request, Response, Role, StreamChunk, StreamResponse,
ToolDefinition, ToolUse, Usage,
};
const API_BASE: &str = "https://api.anthropic.com";
const API_VERSION: &str = "2023-06-01";
pub struct AnthropicProvider {
client: Client,
api_key: String,
base_url: String,
model: ModelId,
}
impl AnthropicProvider {
pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
Self {
client: Client::new(),
api_key: api_key.into(),
base_url: API_BASE.to_string(),
model: ModelId::new(model),
}
}
pub fn with_base_url(
api_key: impl Into<String>,
model: impl Into<String>,
base_url: impl Into<String>,
) -> Self {
Self {
client: Client::new(),
api_key: api_key.into(),
base_url: base_url.into().trim_end_matches('/').to_string(),
model: ModelId::new(model),
}
}
fn request_builder(&self, url: &str) -> reqwest::RequestBuilder {
self.client
.post(url)
.header("x-api-key", &self.api_key)
.header("anthropic-version", API_VERSION)
.header("content-type", "application/json")
}
}
#[async_trait]
impl Provider for AnthropicProvider {
async fn complete(&self, req: &Request) -> Result<Response, ProviderError> {
let url = format!("{}/v1/messages", self.base_url);
let body = MessagesRequest::from_request(req, &self.model, false);
let start = std::time::Instant::now();
let resp = self.request_builder(&url).json(&body).send().await?;
let status = resp.status();
if !status.is_success() {
let text = resp.text().await.unwrap_or_default();
return Err(ProviderError::Api {
status: status.as_u16(),
message: text,
});
}
let raw: serde_json::Value = resp.json().await?;
let latency = start.elapsed();
parse_messages_response(raw, &self.model, latency)
}
async fn stream(&self, req: &Request) -> Result<StreamResponse<'_>, ProviderError> {
let url = format!("{}/v1/messages", self.base_url);
let body = MessagesRequest::from_request(req, &self.model, true);
let resp = self.request_builder(&url).json(&body).send().await?;
let status = resp.status();
if !status.is_success() {
let text = resp.text().await.unwrap_or_default();
return Err(ProviderError::Api {
status: status.as_u16(),
message: text,
});
}
let byte_stream = resp.bytes_stream();
let stream = futures::stream::unfold(
SseState {
inner: Box::pin(byte_stream),
buf: String::new(),
done: false,
},
|mut state| async move {
if state.done {
return None;
}
loop {
if let Some(chunk) = try_parse_sse_frame(&mut state.buf) {
match chunk {
SseFrame::Chunk(c) => return Some((c, state)),
SseFrame::Done(usage) => {
state.done = true;
return Some((StreamChunk::Done { usage }, state));
}
SseFrame::Skip => continue,
}
}
match state.inner.next().await {
Some(Ok(bytes)) => match std::str::from_utf8(&bytes) {
Ok(s) => state.buf.push_str(s),
Err(e) => {
state.done = true;
return Some((StreamChunk::Error(e.to_string()), state));
}
},
Some(Err(e)) => {
state.done = true;
return Some((StreamChunk::Error(e.to_string()), state));
}
None => {
state.done = true;
return Some((StreamChunk::Done { usage: None }, state));
}
}
}
},
);
Ok(Box::pin(stream))
}
fn model_id(&self) -> &ModelId {
&self.model
}
}
struct SseState {
inner:
std::pin::Pin<Box<dyn futures::Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send>>,
buf: String,
done: bool,
}
#[derive(Debug)]
enum SseFrame {
Chunk(StreamChunk),
Done(Option<Usage>),
Skip,
}
fn try_parse_sse_frame(buf: &mut String) -> Option<SseFrame> {
let frame_end = buf.find("\n\n")?;
let frame: String = buf.drain(..frame_end + 2).collect();
let mut event_type = "";
let mut data = String::new();
for line in frame.lines() {
if let Some(val) = line.strip_prefix("event: ") {
event_type = val.trim();
} else if let Some(val) = line.strip_prefix("event:") {
event_type = val.trim();
} else if let Some(val) = line.strip_prefix("data: ") {
data.push_str(val);
} else if let Some(val) = line.strip_prefix("data:") {
data.push_str(val);
}
}
match event_type {
"content_block_delta" => {
let v: serde_json::Value = serde_json::from_str(&data).ok()?;
let text = v["delta"]["text"].as_str().unwrap_or("").to_string();
if text.is_empty() {
Some(SseFrame::Skip)
} else {
Some(SseFrame::Chunk(StreamChunk::Delta(text)))
}
}
"message_delta" => {
let v: serde_json::Value = serde_json::from_str(&data).ok()?;
let output_tokens = v["usage"]["output_tokens"].as_u64().unwrap_or(0) as u32;
Some(SseFrame::Done(Some(Usage {
input_tokens: 0, output_tokens,
})))
}
"message_stop" => Some(SseFrame::Skip),
"message_start" | "content_block_start" | "content_block_stop" | "ping" => {
Some(SseFrame::Skip)
}
"error" => {
let v: serde_json::Value = serde_json::from_str(&data).ok()?;
let msg = v["error"]["message"]
.as_str()
.unwrap_or("unknown error")
.to_string();
Some(SseFrame::Chunk(StreamChunk::Error(msg)))
}
_ => Some(SseFrame::Skip),
}
}
#[derive(Serialize)]
struct MessagesRequest {
model: String,
messages: Vec<ApiMessage>,
max_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
system: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
#[serde(skip_serializing_if = "Vec::is_empty")]
stop_sequences: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
tools: Vec<ToolDefinition>,
stream: bool,
}
#[derive(Serialize)]
struct ApiMessage {
role: String,
content: serde_json::Value,
}
fn blocks_to_anthropic(content: &[ContentBlock]) -> serde_json::Value {
let arr: Vec<serde_json::Value> = content
.iter()
.map(|b| match b {
ContentBlock::Text { text } => serde_json::json!({ "type": "text", "text": text }),
ContentBlock::ToolUse { id, name, input } => serde_json::json!({
"type": "tool_use",
"id": id,
"name": name,
"input": input,
}),
ContentBlock::ToolResult {
tool_use_id,
content,
is_error,
} => {
let mut v = serde_json::json!({
"type": "tool_result",
"tool_use_id": tool_use_id,
"content": content,
});
if *is_error {
v["is_error"] = serde_json::Value::Bool(true);
}
v
}
})
.collect();
serde_json::Value::Array(arr)
}
impl MessagesRequest {
fn from_request(req: &Request, default_model: &ModelId, stream: bool) -> Self {
let model = if req.model.as_str() == "default" {
default_model.as_str().to_string()
} else {
req.model.as_str().to_string()
};
let max_tokens = req.max_tokens.unwrap_or(4096);
let mut system_parts: Vec<String> = Vec::new();
if let Some(s) = &req.system {
system_parts.push(s.clone());
}
let mut messages: Vec<ApiMessage> = Vec::new();
for m in &req.messages {
match m.role {
Role::System => system_parts.push(m.text()),
Role::User => messages.push(ApiMessage {
role: "user".into(),
content: blocks_to_anthropic(&m.content),
}),
Role::Assistant => messages.push(ApiMessage {
role: "assistant".into(),
content: blocks_to_anthropic(&m.content),
}),
}
}
let system = if system_parts.is_empty() {
None
} else {
Some(system_parts.join("\n"))
};
Self {
model,
messages,
max_tokens,
system,
temperature: req.temperature,
stop_sequences: req.stop.clone(),
tools: req.tools.clone(),
stream,
}
}
}
fn parse_messages_response(
raw: serde_json::Value,
default_model: &ModelId,
latency: std::time::Duration,
) -> Result<Response, ProviderError> {
let mut content = String::new();
let mut tool_calls: Vec<ToolUse> = Vec::new();
if let Some(blocks) = raw["content"].as_array() {
for b in blocks {
match b["type"].as_str() {
Some("text") => content.push_str(b["text"].as_str().unwrap_or("")),
Some("tool_use") => tool_calls.push(ToolUse {
id: b["id"].as_str().unwrap_or("").to_string(),
name: b["name"].as_str().unwrap_or("").to_string(),
input: b["input"].clone(),
}),
_ => {}
}
}
}
let stop_reason = raw["stop_reason"].as_str().unwrap_or("end_turn");
let finish_reason = match stop_reason {
"end_turn" | "stop_sequence" => FinishReason::Stop,
"max_tokens" => FinishReason::MaxTokens,
"tool_use" => FinishReason::ToolUse,
other => FinishReason::Other(other.into()),
};
let model_str = raw["model"].as_str().unwrap_or(default_model.as_str());
let usage = Usage {
input_tokens: raw["usage"]["input_tokens"].as_u64().unwrap_or(0) as u32,
output_tokens: raw["usage"]["output_tokens"].as_u64().unwrap_or(0) as u32,
};
Ok(Response {
content,
tool_calls,
usage,
model: ModelId::new(model_str),
finish_reason,
latency,
raw,
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::Message;
use futures::StreamExt;
use std::time::Duration;
use wiremock::matchers::{header, method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
#[test]
fn parse_response_full() {
let raw = serde_json::json!({
"content": [
{ "type": "text", "text": "Hello!" }
],
"model": "claude-sonnet-4-20250514",
"stop_reason": "end_turn",
"usage": { "input_tokens": 15, "output_tokens": 8 }
});
let resp = parse_messages_response(raw, &ModelId::new("fallback"), Duration::from_secs(2))
.unwrap();
assert_eq!(resp.content, "Hello!");
assert_eq!(resp.model.as_str(), "claude-sonnet-4-20250514");
assert_eq!(resp.finish_reason, FinishReason::Stop);
assert_eq!(resp.usage.input_tokens, 15);
assert_eq!(resp.usage.output_tokens, 8);
assert_eq!(resp.latency, Duration::from_secs(2));
}
#[test]
fn parse_response_stop_sequence() {
let raw = serde_json::json!({ "stop_reason": "stop_sequence", "content": [] });
let resp = parse_messages_response(raw, &ModelId::new("f"), Duration::ZERO).unwrap();
assert_eq!(resp.finish_reason, FinishReason::Stop);
}
#[test]
fn parse_response_max_tokens() {
let raw = serde_json::json!({ "stop_reason": "max_tokens", "content": [] });
let resp = parse_messages_response(raw, &ModelId::new("f"), Duration::ZERO).unwrap();
assert_eq!(resp.finish_reason, FinishReason::MaxTokens);
}
#[test]
fn parse_response_other_stop() {
let raw = serde_json::json!({ "stop_reason": "custom", "content": [] });
let resp = parse_messages_response(raw, &ModelId::new("f"), Duration::ZERO).unwrap();
assert_eq!(resp.finish_reason, FinishReason::Other("custom".into()));
}
#[test]
fn parse_response_missing_stop_reason() {
let raw = serde_json::json!({ "content": [] });
let resp = parse_messages_response(raw, &ModelId::new("f"), Duration::ZERO).unwrap();
assert_eq!(resp.finish_reason, FinishReason::Stop); }
#[test]
fn parse_response_no_content() {
let raw = serde_json::json!({});
let resp = parse_messages_response(raw, &ModelId::new("f"), Duration::ZERO).unwrap();
assert_eq!(resp.content, "");
}
#[test]
fn parse_response_mixed_blocks() {
let raw = serde_json::json!({
"content": [
{ "type": "text", "text": "A" },
{ "type": "tool_use", "id": "x" },
{ "type": "text", "text": "B" },
]
});
let resp = parse_messages_response(raw, &ModelId::new("f"), Duration::ZERO).unwrap();
assert_eq!(resp.content, "AB");
}
#[test]
fn parse_response_missing_model() {
let raw = serde_json::json!({ "content": [] });
let resp = parse_messages_response(raw, &ModelId::new("fallback"), Duration::ZERO).unwrap();
assert_eq!(resp.model.as_str(), "fallback");
}
#[test]
fn parse_response_missing_usage() {
let raw = serde_json::json!({ "content": [] });
let resp = parse_messages_response(raw, &ModelId::new("f"), Duration::ZERO).unwrap();
assert_eq!(resp.usage.input_tokens, 0);
assert_eq!(resp.usage.output_tokens, 0);
}
#[test]
fn sse_incomplete_frame() {
let mut buf = "event: ping\ndata: {}\n".to_string(); let original_len = buf.len();
assert!(try_parse_sse_frame(&mut buf).is_none());
assert_eq!(buf.len(), original_len); }
#[test]
fn sse_content_block_delta() {
let mut buf =
"event: content_block_delta\ndata: {\"delta\":{\"text\":\"hi\"}}\n\n".to_string();
match try_parse_sse_frame(&mut buf) {
Some(SseFrame::Chunk(StreamChunk::Delta(t))) => assert_eq!(t, "hi"),
other => panic!("expected Chunk(Delta), got {other:?}"),
}
assert!(buf.is_empty());
}
#[test]
fn sse_content_block_delta_empty_text() {
let mut buf =
"event: content_block_delta\ndata: {\"delta\":{\"text\":\"\"}}\n\n".to_string();
match try_parse_sse_frame(&mut buf) {
Some(SseFrame::Skip) => {}
other => panic!("expected Skip, got {other:?}"),
}
}
#[test]
fn sse_message_delta() {
let mut buf =
"event: message_delta\ndata: {\"usage\":{\"output_tokens\":42}}\n\n".to_string();
match try_parse_sse_frame(&mut buf) {
Some(SseFrame::Done(Some(usage))) => {
assert_eq!(usage.input_tokens, 0);
assert_eq!(usage.output_tokens, 42);
}
other => panic!("expected Done, got {other:?}"),
}
}
#[test]
fn sse_message_stop() {
let mut buf = "event: message_stop\ndata: {}\n\n".to_string();
match try_parse_sse_frame(&mut buf) {
Some(SseFrame::Skip) => {}
other => panic!("expected Skip, got {other:?}"),
}
}
#[test]
fn sse_message_start() {
let mut buf = "event: message_start\ndata: {}\n\n".to_string();
match try_parse_sse_frame(&mut buf) {
Some(SseFrame::Skip) => {}
other => panic!("expected Skip, got {other:?}"),
}
}
#[test]
fn sse_content_block_start() {
let mut buf = "event: content_block_start\ndata: {}\n\n".to_string();
match try_parse_sse_frame(&mut buf) {
Some(SseFrame::Skip) => {}
other => panic!("expected Skip, got {other:?}"),
}
}
#[test]
fn sse_content_block_stop() {
let mut buf = "event: content_block_stop\ndata: {}\n\n".to_string();
match try_parse_sse_frame(&mut buf) {
Some(SseFrame::Skip) => {}
other => panic!("expected Skip, got {other:?}"),
}
}
#[test]
fn sse_ping() {
let mut buf = "event: ping\ndata: {}\n\n".to_string();
match try_parse_sse_frame(&mut buf) {
Some(SseFrame::Skip) => {}
other => panic!("expected Skip, got {other:?}"),
}
}
#[test]
fn sse_error_event() {
let mut buf =
"event: error\ndata: {\"error\":{\"message\":\"overloaded\"}}\n\n".to_string();
match try_parse_sse_frame(&mut buf) {
Some(SseFrame::Chunk(StreamChunk::Error(msg))) => assert_eq!(msg, "overloaded"),
other => panic!("expected Error chunk, got {other:?}"),
}
}
#[test]
fn sse_error_no_message() {
let mut buf = "event: error\ndata: {\"error\":{}}\n\n".to_string();
match try_parse_sse_frame(&mut buf) {
Some(SseFrame::Chunk(StreamChunk::Error(msg))) => assert_eq!(msg, "unknown error"),
other => panic!("expected Error chunk with unknown, got {other:?}"),
}
}
#[test]
fn sse_unknown_event() {
let mut buf = "event: custom_thing\ndata: {}\n\n".to_string();
match try_parse_sse_frame(&mut buf) {
Some(SseFrame::Skip) => {}
other => panic!("expected Skip, got {other:?}"),
}
}
#[test]
fn sse_no_space_after_colon() {
let mut buf =
"event:content_block_delta\ndata:{\"delta\":{\"text\":\"x\"}}\n\n".to_string();
match try_parse_sse_frame(&mut buf) {
Some(SseFrame::Chunk(StreamChunk::Delta(t))) => assert_eq!(t, "x"),
other => panic!("expected Delta, got {other:?}"),
}
}
#[test]
fn sse_invalid_json_returns_none() {
let mut buf = "event: content_block_delta\ndata: not-json\n\n".to_string();
assert!(try_parse_sse_frame(&mut buf).is_none());
}
#[test]
fn sse_drains_buffer() {
let mut buf = "event: ping\ndata: {}\n\nevent: message_stop\ndata: {}\n\n".to_string();
try_parse_sse_frame(&mut buf); assert!(buf.starts_with("event: message_stop"));
try_parse_sse_frame(&mut buf); assert!(buf.is_empty());
}
#[test]
fn msg_request_default_model() {
let req = Request::default();
let mr = MessagesRequest::from_request(&req, &ModelId::new("claude-3"), false);
assert_eq!(mr.model, "claude-3");
}
#[test]
fn msg_request_explicit_model() {
let req = Request {
model: ModelId::new("claude-opus"),
..Default::default()
};
let mr = MessagesRequest::from_request(&req, &ModelId::new("claude-3"), false);
assert_eq!(mr.model, "claude-opus");
}
#[test]
fn msg_request_max_tokens_default() {
let req = Request::default();
let mr = MessagesRequest::from_request(&req, &ModelId::new("m"), false);
assert_eq!(mr.max_tokens, 4096);
}
#[test]
fn msg_request_max_tokens_explicit() {
let req = Request {
max_tokens: Some(1000),
..Default::default()
};
let mr = MessagesRequest::from_request(&req, &ModelId::new("m"), false);
assert_eq!(mr.max_tokens, 1000);
}
#[test]
fn msg_request_system_combined() {
let req = Request {
system: Some("A".into()),
messages: vec![Message::system("B"), Message::user("hi")],
..Default::default()
};
let mr = MessagesRequest::from_request(&req, &ModelId::new("m"), false);
assert_eq!(mr.system, Some("A\nB".into()));
assert_eq!(mr.messages.len(), 1);
assert_eq!(mr.messages[0].role, "user");
}
#[test]
fn msg_request_no_system() {
let req = Request {
messages: vec![Message::user("hi")],
..Default::default()
};
let mr = MessagesRequest::from_request(&req, &ModelId::new("m"), false);
assert!(mr.system.is_none());
}
#[test]
fn msg_request_system_messages_filtered() {
let req = Request {
messages: vec![
Message::system("sys"),
Message::user("usr"),
Message::assistant("ast"),
],
..Default::default()
};
let mr = MessagesRequest::from_request(&req, &ModelId::new("m"), false);
assert_eq!(mr.messages.len(), 2);
assert_eq!(mr.messages[0].role, "user");
assert_eq!(mr.messages[1].role, "assistant");
assert_eq!(mr.system, Some("sys".into()));
}
#[test]
fn msg_request_stream_flag() {
let req = Request::default();
assert!(MessagesRequest::from_request(&req, &ModelId::new("m"), true).stream);
assert!(!MessagesRequest::from_request(&req, &ModelId::new("m"), false).stream);
}
#[test]
fn msg_request_serializes_tools() {
let req = Request {
tools: vec![crate::types::ToolDefinition::new(
"get_weather",
"Get weather",
serde_json::json!({"type": "object", "properties": {"city": {"type": "string"}}}),
)],
messages: vec![Message::user("hi")],
..Default::default()
};
let mr = MessagesRequest::from_request(&req, &ModelId::new("m"), false);
let body = serde_json::to_value(&mr).unwrap();
assert_eq!(body["tools"][0]["name"], "get_weather");
assert_eq!(body["tools"][0]["description"], "Get weather");
assert_eq!(body["tools"][0]["input_schema"]["type"], "object");
assert_eq!(body["messages"][0]["content"][0]["type"], "text");
assert_eq!(body["messages"][0]["content"][0]["text"], "hi");
}
#[test]
fn msg_request_no_tools_field_when_empty() {
let req = Request::default();
let mr = MessagesRequest::from_request(&req, &ModelId::new("m"), false);
let body = serde_json::to_value(&mr).unwrap();
assert!(body.get("tools").is_none());
}
#[test]
fn msg_request_serializes_tool_result_block() {
let req = Request {
messages: vec![Message::tool_result("tu_1", "42", false)],
..Default::default()
};
let mr = MessagesRequest::from_request(&req, &ModelId::new("m"), false);
let body = serde_json::to_value(&mr).unwrap();
let block = &body["messages"][0]["content"][0];
assert_eq!(block["type"], "tool_result");
assert_eq!(block["tool_use_id"], "tu_1");
assert_eq!(block["content"], "42");
assert!(block.get("is_error").is_none());
}
#[test]
fn parse_response_tool_use() {
let raw = serde_json::json!({
"content": [
{ "type": "text", "text": "Let me check." },
{ "type": "tool_use", "id": "tu_9", "name": "get_weather", "input": {"city": "SF"} },
],
"stop_reason": "tool_use",
"usage": { "input_tokens": 5, "output_tokens": 3 }
});
let resp = parse_messages_response(raw, &ModelId::new("f"), Duration::ZERO).unwrap();
assert_eq!(resp.content, "Let me check.");
assert_eq!(resp.finish_reason, FinishReason::ToolUse);
assert_eq!(resp.tool_calls.len(), 1);
assert_eq!(resp.tool_calls[0].id, "tu_9");
assert_eq!(resp.tool_calls[0].name, "get_weather");
assert_eq!(resp.tool_calls[0].input["city"], "SF");
}
#[test]
fn parse_response_no_tool_calls_when_text_only() {
let raw = serde_json::json!({
"content": [{ "type": "text", "text": "hi" }],
"stop_reason": "end_turn"
});
let resp = parse_messages_response(raw, &ModelId::new("f"), Duration::ZERO).unwrap();
assert!(resp.tool_calls.is_empty());
}
#[test]
fn new_default_base_url() {
let p = AnthropicProvider::new("key", "model");
assert_eq!(p.base_url, "https://api.anthropic.com");
}
#[test]
fn with_base_url_trims_slash() {
let p = AnthropicProvider::with_base_url("k", "m", "http://host/");
assert_eq!(p.base_url, "http://host");
}
#[test]
fn model_id_returns_configured() {
let p = AnthropicProvider::new("key", "claude-3");
assert_eq!(p.model_id().as_str(), "claude-3");
}
fn anthropic_response_json() -> serde_json::Value {
serde_json::json!({
"id": "msg_123",
"type": "message",
"role": "assistant",
"content": [{ "type": "text", "text": "Hello!" }],
"model": "claude-sonnet-4-20250514",
"stop_reason": "end_turn",
"usage": { "input_tokens": 12, "output_tokens": 6 }
})
}
#[tokio::test]
async fn complete_success() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/messages"))
.and(header("x-api-key", "test-key"))
.and(header("anthropic-version", "2023-06-01"))
.respond_with(ResponseTemplate::new(200).set_body_json(anthropic_response_json()))
.mount(&server)
.await;
let provider =
AnthropicProvider::with_base_url("test-key", "claude-sonnet-4-20250514", server.uri());
let resp = provider
.complete(&Request {
messages: vec![Message::user("hi")],
..Default::default()
})
.await
.unwrap();
assert_eq!(resp.content, "Hello!");
assert_eq!(resp.usage.input_tokens, 12);
assert_eq!(resp.usage.output_tokens, 6);
assert_eq!(resp.model.as_str(), "claude-sonnet-4-20250514");
assert!(resp.latency > Duration::ZERO);
}
#[tokio::test]
async fn complete_api_error() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/messages"))
.respond_with(ResponseTemplate::new(429).set_body_string("rate limited"))
.mount(&server)
.await;
let provider = AnthropicProvider::with_base_url("key", "model", server.uri());
let err = provider.complete(&Request::default()).await.unwrap_err();
match err {
ProviderError::Api { status, message } => {
assert_eq!(status, 429);
assert!(message.contains("rate limited"));
}
other => panic!("expected Api error, got {other}"),
}
}
#[tokio::test]
async fn stream_success() {
let server = MockServer::start().await;
let sse_body = [
"event: message_start\ndata: {\"type\":\"message_start\"}\n\n",
"event: content_block_start\ndata: {\"type\":\"content_block_start\"}\n\n",
"event: content_block_delta\ndata: {\"delta\":{\"text\":\"Hello\"}}\n\n",
"event: content_block_delta\ndata: {\"delta\":{\"text\":\" world\"}}\n\n",
"event: content_block_stop\ndata: {\"type\":\"content_block_stop\"}\n\n",
"event: message_delta\ndata: {\"usage\":{\"output_tokens\":5}}\n\n",
"event: message_stop\ndata: {}\n\n",
]
.join("");
Mock::given(method("POST"))
.and(path("/v1/messages"))
.respond_with(ResponseTemplate::new(200).set_body_string(sse_body))
.mount(&server)
.await;
let provider = AnthropicProvider::with_base_url("key", "model", server.uri());
let mut stream = provider
.stream(&Request {
messages: vec![Message::user("hi")],
..Default::default()
})
.await
.unwrap();
let mut text = String::new();
let mut got_done = false;
while let Some(chunk) = stream.next().await {
match chunk {
StreamChunk::Delta(t) => text.push_str(&t),
StreamChunk::Done { usage } => {
got_done = true;
let u = usage.unwrap();
assert_eq!(u.output_tokens, 5);
}
StreamChunk::Error(e) => panic!("unexpected error: {e}"),
}
}
assert_eq!(text, "Hello world");
assert!(got_done);
}
#[tokio::test]
async fn stream_api_error() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/messages"))
.respond_with(ResponseTemplate::new(500).set_body_string("server error"))
.mount(&server)
.await;
let provider = AnthropicProvider::with_base_url("key", "model", server.uri());
match provider.stream(&Request::default()).await {
Err(ProviderError::Api { status, .. }) => assert_eq!(status, 500),
Err(other) => panic!("expected Api error, got {other}"),
Ok(_) => panic!("expected error"),
}
}
#[tokio::test]
async fn stream_ends_without_stop() {
let server = MockServer::start().await;
let sse_body = "event: content_block_delta\ndata: {\"delta\":{\"text\":\"hi\"}}\n\n";
Mock::given(method("POST"))
.and(path("/v1/messages"))
.respond_with(ResponseTemplate::new(200).set_body_string(sse_body))
.mount(&server)
.await;
let provider = AnthropicProvider::with_base_url("key", "model", server.uri());
let mut stream = provider
.stream(&Request {
messages: vec![Message::user("hi")],
..Default::default()
})
.await
.unwrap();
let mut chunks = Vec::new();
while let Some(chunk) = stream.next().await {
chunks.push(chunk);
}
assert!(chunks.len() >= 2);
match &chunks[0] {
StreamChunk::Delta(t) => assert_eq!(t, "hi"),
other => panic!("expected Delta, got {other:?}"),
}
match chunks.last().unwrap() {
StreamChunk::Done { usage } => assert!(usage.is_none()),
other => panic!("expected Done, got {other:?}"),
}
}
}