use std::collections::HashMap;
use async_trait::async_trait;
use eventsource_stream::Eventsource;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::ProviderError;
use crate::message::{CacheControl, Content, StopReason, Usage};
use crate::provider::{LlmProvider, Request, Response, SystemBlock};
use crate::stream::{ProviderEventStream, StreamEvent};
pub mod batch;
pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
pub(crate) const API_VERSION: &str = "2023-06-01";
pub struct Anthropic {
api_key: String,
client: reqwest::Client,
base_url: String,
}
impl Anthropic {
pub fn new(api_key: impl Into<String>) -> Self {
Self {
api_key: api_key.into(),
client: reqwest::Client::new(),
base_url: DEFAULT_BASE_URL.to_string(),
}
}
pub fn from_env() -> Self {
let api_key =
std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY env var is required");
Self::new(api_key)
}
pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
self.base_url = base_url.into();
self
}
pub(crate) fn messages_url(&self) -> String {
format!("{}/v1/messages", self.base_url)
}
pub(crate) fn batches_url(&self) -> String {
format!("{}/v1/messages/batches", self.base_url)
}
}
#[async_trait]
impl LlmProvider for Anthropic {
async fn stream(&self, request: Request) -> Result<ProviderEventStream, ProviderError> {
let mut body = build_request_body(&request);
body.stream = true;
let response = self
.client
.post(self.messages_url())
.header("x-api-key", &self.api_key)
.header("anthropic-version", API_VERSION)
.header("content-type", "application/json")
.header("accept", "text/event-stream")
.json(&body)
.send()
.await?;
let status = response.status().as_u16();
if status >= 400 {
let retry_after_ms = parse_retry_after(response.headers());
let text = response.text().await.unwrap_or_default();
return Err(classify_error(status, text, retry_after_ms));
}
let event_stream = response.bytes_stream().eventsource();
Ok(Box::pin(anthropic_event_stream(event_stream)))
}
async fn complete(&self, request: Request) -> Result<Response, ProviderError> {
let body = build_request_body(&request);
let response = self
.client
.post(self.messages_url())
.header("x-api-key", &self.api_key)
.header("anthropic-version", API_VERSION)
.header("content-type", "application/json")
.json(&body)
.send()
.await?;
let status = response.status().as_u16();
if status >= 400 {
let retry_after_ms = parse_retry_after(response.headers());
let text = response.text().await.unwrap_or_default();
return Err(classify_error(status, text, retry_after_ms));
}
let api_response: ApiResponse = response.json().await?;
Ok(convert_response(api_response))
}
}
pub(crate) fn classify_error(
status: u16,
message: String,
retry_after_ms: Option<u64>,
) -> ProviderError {
match status {
429 => ProviderError::RateLimit { retry_after_ms },
529 | 503 => ProviderError::Overloaded { retry_after_ms },
500 | 502 | 504 => ProviderError::Api {
status,
message,
retryable: true,
},
s => ProviderError::Api {
status: s,
message,
retryable: (500..600).contains(&s),
},
}
}
pub(crate) fn parse_retry_after(headers: &reqwest::header::HeaderMap) -> Option<u64> {
let raw = headers.get(reqwest::header::RETRY_AFTER)?.to_str().ok()?;
raw.trim().parse::<u64>().ok().map(|s| s * 1_000)
}
#[derive(Deserialize)]
#[serde(tag = "type")]
enum StreamingPayload {
#[serde(rename = "message_start")]
MessageStart {
#[serde(default)]
message: MessageStartPayload,
},
#[serde(rename = "content_block_start")]
ContentBlockStart {
index: usize,
content_block: ContentBlockStart,
},
#[serde(rename = "content_block_delta")]
ContentBlockDelta { index: usize, delta: BlockDelta },
#[serde(rename = "content_block_stop")]
ContentBlockStop { index: usize },
#[serde(rename = "message_delta")]
MessageDelta {
delta: MessageDeltaPayload,
#[serde(default)]
usage: Option<ApiUsage>,
},
#[serde(rename = "message_stop")]
MessageStop,
#[serde(rename = "ping")]
Ping,
#[serde(rename = "error")]
Error { error: ErrorPayload },
}
#[derive(Deserialize, Default)]
struct MessageStartPayload {
#[serde(default)]
usage: Option<ApiUsage>,
}
#[derive(Deserialize)]
#[serde(tag = "type")]
enum ContentBlockStart {
#[serde(rename = "text")]
Text,
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
#[serde(default)]
input: Value,
},
}
#[derive(Deserialize)]
#[serde(tag = "type")]
enum BlockDelta {
#[serde(rename = "text_delta")]
TextDelta { text: String },
#[serde(rename = "input_json_delta")]
InputJsonDelta { partial_json: String },
}
#[derive(Deserialize)]
struct MessageDeltaPayload {
#[serde(default)]
stop_reason: Option<String>,
}
#[derive(Deserialize)]
struct ErrorPayload {
#[serde(default, rename = "type")]
kind: String,
#[serde(default)]
message: String,
}
enum BlockState {
Text,
ToolUse {
id: String,
name: String,
json_buf: String,
},
}
struct StreamState<S> {
sse: S,
blocks: HashMap<usize, BlockState>,
buffer: std::collections::VecDeque<Result<StreamEvent, ProviderError>>,
usage: Usage,
}
fn anthropic_event_stream<S>(
sse: S,
) -> impl futures::Stream<Item = Result<StreamEvent, ProviderError>>
where
S: futures::Stream<
Item = Result<
eventsource_stream::Event,
eventsource_stream::EventStreamError<reqwest::Error>,
>,
> + Send
+ Unpin
+ 'static,
{
use std::collections::VecDeque;
let initial = StreamState {
sse,
blocks: HashMap::new(),
buffer: VecDeque::new(),
usage: Usage::default(),
};
futures::stream::unfold(initial, |mut state| async move {
loop {
if let Some(ev) = state.buffer.pop_front() {
return Some((ev, state));
}
let next = state.sse.next().await?;
let event = match next {
Ok(ev) => ev,
Err(e) => {
let err = ProviderError::Other(format!("SSE read error: {e}"));
return Some((Err(err), state));
}
};
let payload: StreamingPayload = match serde_json::from_str(&event.data) {
Ok(p) => p,
Err(_) => continue, };
process_payload(
payload,
&mut state.blocks,
&mut state.buffer,
&mut state.usage,
);
}
})
}
fn process_payload(
payload: StreamingPayload,
blocks: &mut HashMap<usize, BlockState>,
buffer: &mut std::collections::VecDeque<Result<StreamEvent, ProviderError>>,
running: &mut Usage,
) {
match payload {
StreamingPayload::MessageStart { message } => {
if let Some(usage) = message.usage {
running.merge_max(&usage_from_api(&usage));
buffer.push_back(Ok(StreamEvent::Usage(running.clone())));
}
}
StreamingPayload::ContentBlockStart {
index,
content_block,
} => {
let state = match content_block {
ContentBlockStart::Text => BlockState::Text,
ContentBlockStart::ToolUse { id, name, input } => BlockState::ToolUse {
id,
name,
json_buf: if input.is_null() || input == Value::Object(Default::default()) {
String::new()
} else {
serde_json::to_string(&input).unwrap_or_default()
},
},
};
blocks.insert(index, state);
}
StreamingPayload::ContentBlockDelta { index, delta } => match delta {
BlockDelta::TextDelta { text } => {
buffer.push_back(Ok(StreamEvent::ContentDelta(text)));
}
BlockDelta::InputJsonDelta { partial_json } => {
if let Some(BlockState::ToolUse { json_buf, .. }) = blocks.get_mut(&index) {
json_buf.push_str(&partial_json);
}
}
},
StreamingPayload::ContentBlockStop { index } => {
if let Some(BlockState::ToolUse { id, name, json_buf }) = blocks.remove(&index) {
let input: Value = if json_buf.trim().is_empty() {
Value::Object(Default::default())
} else {
serde_json::from_str(&json_buf)
.unwrap_or_else(|_| Value::String(json_buf.clone()))
};
buffer.push_back(Ok(StreamEvent::ToolUse { id, name, input }));
}
}
StreamingPayload::MessageDelta { delta, usage } => {
if let Some(stop) = delta.stop_reason {
let stop_reason = map_stop_reason(&stop);
buffer.push_back(Ok(StreamEvent::MessageDelta { stop_reason }));
}
if let Some(u) = usage {
running.merge_max(&usage_from_api(&u));
buffer.push_back(Ok(StreamEvent::Usage(running.clone())));
}
}
StreamingPayload::MessageStop => {
buffer.push_back(Ok(StreamEvent::Done));
}
StreamingPayload::Ping => {
}
StreamingPayload::Error { error } => {
buffer.push_back(Err(ProviderError::Other(format!(
"anthropic stream error ({}): {}",
error.kind, error.message
))));
}
}
}
fn map_stop_reason(s: &str) -> StopReason {
match s {
"end_turn" => StopReason::EndTurn,
"tool_use" => StopReason::ToolUse,
"max_tokens" => StopReason::MaxTokens,
"stop_sequence" => StopReason::StopSequence,
"pause_turn" => StopReason::PauseTurn,
_ => StopReason::EndTurn,
}
}
pub(crate) fn usage_from_api(api: &ApiUsage) -> Usage {
Usage {
input_tokens: api.input_tokens,
output_tokens: api.output_tokens,
cache_creation_input_tokens: api.cache_creation_input_tokens,
cache_read_input_tokens: api.cache_read_input_tokens,
}
}
#[derive(Serialize)]
pub(crate) struct ApiRequest {
pub(crate) model: String,
pub(crate) max_tokens: u32,
pub(crate) messages: Vec<ApiMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) system: Option<Vec<ApiSystemBlock>>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub(crate) tools: Vec<ApiTool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) temperature: Option<f32>,
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub(crate) stream: bool,
}
#[derive(Serialize)]
pub(crate) struct ApiMessage {
pub(crate) role: String,
pub(crate) content: Vec<ApiContent>,
}
#[derive(Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub(crate) enum ApiContent {
#[serde(rename = "text")]
Text {
text: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
cache_control: Option<CacheControl>,
},
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
input: Value,
},
#[serde(rename = "tool_result")]
ToolResult {
tool_use_id: String,
content: String,
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
is_error: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
cache_control: Option<CacheControl>,
},
}
#[derive(Serialize)]
pub(crate) struct ApiSystemBlock {
#[serde(rename = "type")]
pub(crate) kind: &'static str,
pub(crate) text: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) cache_control: Option<CacheControl>,
}
#[derive(Serialize)]
pub(crate) struct ApiTool {
pub(crate) name: String,
pub(crate) description: String,
pub(crate) input_schema: Value,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) cache_control: Option<CacheControl>,
}
#[derive(Deserialize)]
pub(crate) struct ApiResponse {
pub(crate) content: Vec<ApiContent>,
pub(crate) stop_reason: String,
pub(crate) usage: ApiUsage,
}
#[derive(Deserialize)]
pub(crate) struct ApiUsage {
pub(crate) input_tokens: u32,
pub(crate) output_tokens: u32,
#[serde(default)]
pub(crate) cache_creation_input_tokens: u32,
#[serde(default)]
pub(crate) cache_read_input_tokens: u32,
}
pub(crate) fn build_request_body(request: &Request) -> ApiRequest {
let messages = request
.messages
.iter()
.map(|msg| ApiMessage {
role: match msg.role {
crate::message::Role::User => "user".to_string(),
crate::message::Role::Assistant => "assistant".to_string(),
},
content: msg.content.iter().map(content_to_api).collect(),
})
.collect();
let tools = request
.tools
.iter()
.map(|t| ApiTool {
name: t.name.clone(),
description: t.description.clone(),
input_schema: t.input_schema.clone(),
cache_control: t.cache_control.clone(),
})
.collect();
ApiRequest {
model: request.model.clone(),
max_tokens: request.max_tokens,
messages,
system: request.system.as_ref().map(|blocks| {
blocks
.iter()
.map(|b: &SystemBlock| ApiSystemBlock {
kind: "text",
text: b.text.clone(),
cache_control: b.cache_control.clone(),
})
.collect()
}),
tools,
temperature: request.temperature,
stream: false,
}
}
fn content_to_api(content: &Content) -> ApiContent {
match content {
Content::Text {
text,
cache_control,
} => ApiContent::Text {
text: text.clone(),
cache_control: cache_control.clone(),
},
Content::ToolUse { id, name, input } => ApiContent::ToolUse {
id: id.clone(),
name: name.clone(),
input: input.clone(),
},
Content::ToolResult {
tool_use_id,
content,
is_error,
cache_control,
} => ApiContent::ToolResult {
tool_use_id: tool_use_id.clone(),
content: content.clone(),
is_error: *is_error,
cache_control: cache_control.clone(),
},
}
}
pub(crate) fn convert_response(api: ApiResponse) -> Response {
let content = api
.content
.into_iter()
.map(|c| match c {
ApiContent::Text {
text,
cache_control,
} => Content::Text {
text,
cache_control,
},
ApiContent::ToolUse { id, name, input } => Content::ToolUse { id, name, input },
ApiContent::ToolResult {
tool_use_id,
content,
is_error,
cache_control,
} => Content::ToolResult {
tool_use_id,
content,
is_error,
cache_control,
},
})
.collect();
let stop_reason = map_stop_reason(&api.stop_reason);
Response {
content,
stop_reason,
usage: usage_from_api(&api.usage),
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::message::{CacheTtl, Message};
use crate::provider::ToolDefinition;
use serde_json::json;
fn req_with_system(blocks: Vec<SystemBlock>) -> Request {
Request {
model: "claude-test".into(),
system: Some(blocks),
messages: vec![Message::user_text("hi")],
tools: vec![],
max_tokens: 100,
temperature: None,
}
}
fn system_blocks_json(blocks: Vec<SystemBlock>) -> serde_json::Value {
let req = req_with_system(blocks);
let body = build_request_body(&req);
serde_json::to_value(&body).unwrap()
}
#[test]
fn system_blocks_serialize_as_typed_array() {
let json = system_blocks_json(vec![
SystemBlock::text("base instructions"),
SystemBlock::cached("long stable context"),
]);
let arr = json["system"].as_array().expect("system should be array");
assert_eq!(arr.len(), 2);
assert_eq!(arr[0]["type"], "text");
assert_eq!(arr[1]["type"], "text");
}
#[test]
fn system_blocks_serialize_text_payloads() {
let json = system_blocks_json(vec![
SystemBlock::text("base instructions"),
SystemBlock::cached("long stable context"),
]);
assert_eq!(json["system"][0]["text"], "base instructions");
assert_eq!(json["system"][1]["text"], "long stable context");
}
#[test]
fn system_blocks_serialize_cache_control_only_when_set() {
let json = system_blocks_json(vec![
SystemBlock::text("base instructions"),
SystemBlock::cached("long stable context"),
]);
assert!(json["system"][0].get("cache_control").is_none());
assert_eq!(json["system"][1]["cache_control"]["type"], "ephemeral");
}
#[test]
fn system_blocks_with_one_hour_ttl_serialize_inline() {
let req = req_with_system(vec![SystemBlock::cached_1h("long-lived prefix")]);
let body = build_request_body(&req);
let json = serde_json::to_value(&body).unwrap();
let cc = &json["system"][0]["cache_control"];
assert_eq!(cc["type"], "ephemeral");
assert_eq!(cc["ttl"], "1h");
}
#[test]
fn tool_definition_cache_control_threads_to_api_tool() {
let req = Request {
model: "m".into(),
system: None,
messages: vec![Message::user_text("hi")],
tools: vec![
ToolDefinition {
name: "first".into(),
description: "first tool".into(),
input_schema: json!({"type":"object"}),
cache_control: None,
},
ToolDefinition {
name: "last".into(),
description: "last tool".into(),
input_schema: json!({"type":"object"}),
cache_control: Some(CacheControl::ephemeral()),
},
],
max_tokens: 10,
temperature: None,
};
let body = build_request_body(&req);
let json = serde_json::to_value(&body).unwrap();
let tools = json["tools"].as_array().unwrap();
assert!(tools[0].get("cache_control").is_none());
assert_eq!(tools[1]["cache_control"]["type"], "ephemeral");
}
#[test]
fn content_text_cache_control_threads_through() {
let req = Request {
model: "m".into(),
system: None,
messages: vec![Message::user(vec![Content::Text {
text: "stable user prefix".into(),
cache_control: Some(CacheControl::Ephemeral {
ttl: Some(CacheTtl::FiveMinutes),
}),
}])],
tools: vec![],
max_tokens: 10,
temperature: None,
};
let body = build_request_body(&req);
let json = serde_json::to_value(&body).unwrap();
let block = &json["messages"][0]["content"][0];
assert_eq!(block["type"], "text");
assert_eq!(block["cache_control"]["type"], "ephemeral");
assert_eq!(block["cache_control"]["ttl"], "5m");
}
#[test]
fn tool_result_cache_control_threads_through() {
let req = Request {
model: "m".into(),
system: None,
messages: vec![Message::user(vec![Content::ToolResult {
tool_use_id: "t1".into(),
content: "long output".into(),
is_error: false,
cache_control: Some(CacheControl::ephemeral()),
}])],
tools: vec![],
max_tokens: 10,
temperature: None,
};
let body = build_request_body(&req);
let json = serde_json::to_value(&body).unwrap();
let block = &json["messages"][0]["content"][0];
assert_eq!(block["type"], "tool_result");
assert_eq!(block["cache_control"]["type"], "ephemeral");
}
#[test]
fn response_with_cache_usage_parses_all_four_fields() {
let raw = json!({
"content": [{"type":"text","text":"ok"}],
"stop_reason": "end_turn",
"usage": {
"input_tokens": 10,
"output_tokens": 5,
"cache_creation_input_tokens": 100,
"cache_read_input_tokens": 200
}
});
let api: ApiResponse = serde_json::from_value(raw).unwrap();
let resp = convert_response(api);
assert_eq!(resp.usage.input_tokens, 10);
assert_eq!(resp.usage.output_tokens, 5);
assert_eq!(resp.usage.cache_creation_input_tokens, 100);
assert_eq!(resp.usage.cache_read_input_tokens, 200);
}
#[test]
fn response_without_cache_usage_defaults_to_zero() {
let raw = json!({
"content": [{"type":"text","text":"ok"}],
"stop_reason": "end_turn",
"usage": {"input_tokens": 1, "output_tokens": 1}
});
let api: ApiResponse = serde_json::from_value(raw).unwrap();
let resp = convert_response(api);
assert_eq!(resp.usage.cache_creation_input_tokens, 0);
assert_eq!(resp.usage.cache_read_input_tokens, 0);
}
#[test]
fn streaming_usage_merges_cache_fields_across_message_start_and_delta() {
use std::collections::VecDeque;
let mut blocks: HashMap<usize, BlockState> = HashMap::new();
let mut buffer: VecDeque<Result<StreamEvent, ProviderError>> = VecDeque::new();
let mut running = Usage::default();
let start = StreamingPayload::MessageStart {
message: MessageStartPayload {
usage: Some(ApiUsage {
input_tokens: 50,
output_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 1000,
}),
},
};
process_payload(start, &mut blocks, &mut buffer, &mut running);
let delta = StreamingPayload::MessageDelta {
delta: MessageDeltaPayload {
stop_reason: Some("end_turn".into()),
},
usage: Some(ApiUsage {
input_tokens: 50,
output_tokens: 75,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
}),
};
process_payload(delta, &mut blocks, &mut buffer, &mut running);
let usages: Vec<Usage> = buffer
.into_iter()
.filter_map(|r| match r.ok()? {
StreamEvent::Usage(u) => Some(u),
_ => None,
})
.collect();
assert_eq!(usages.len(), 2);
let last = usages.last().unwrap();
assert_eq!(last.input_tokens, 50);
assert_eq!(last.output_tokens, 75);
assert_eq!(last.cache_read_input_tokens, 1000);
}
}