use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::{Stream, StreamExt};
use secrecy::ExposeSecret;
use serde_json::{json, Value};
use crate::client::{ChatRequest, ChatResponse, ChatStream, ChatStreamEvent};
use crate::config::Config;
use crate::error::{redact, AiError};
use crate::message::{Citation, ContentBlock, Message, Role, Usage};
use crate::thinking::ThinkingMode;
/// Base URL used for requests when the configuration does not provide `base_url`.
pub(crate) const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
/// Value sent in the `anthropic-version` header of every request built in this file.
pub(crate) const ANTHROPIC_VERSION: &str = "2023-06-01";
/// Returns the full URL of the Messages endpoint.
///
/// Uses `config.base_url` with any trailing `/` removed when it is set, and
/// [`DEFAULT_BASE_URL`] otherwise, then appends `/v1/messages`.
fn messages_url(config: &Config) -> String {
    let root = match config.base_url.as_ref() {
        Some(url) => url.as_str().trim_end_matches('/').to_string(),
        None => DEFAULT_BASE_URL.to_string(),
    };
    format!("{root}/v1/messages")
}
#[doc(hidden)]
#[must_use]
pub fn build_request_body(config: &Config, req: &ChatRequest) -> Value {
let mut body = json!({
"model": config.model,
"max_tokens": req.max_tokens.unwrap_or(1024),
"messages": serialise_messages(&req.messages, req.cache_control),
});
if let Some(system) = &req.system {
body["system"] = serialise_system(system, req.cache_control);
}
if let Some(t) = req.temperature {
body["temperature"] = json!(t);
}
if let Some(thinking) = req.thinking {
body["thinking"] = serialise_thinking(thinking);
}
body
}
fn serialise_system(system: &str, cache_control: bool) -> Value {
let mut block = json!({ "type": "text", "text": system });
if cache_control {
block["cache_control"] = json!({ "type": "ephemeral" });
}
Value::Array(vec![block])
}
/// Converts the conversation into the JSON array sent under `messages`.
///
/// `Role::Assistant` is sent as `"assistant"`; both `Role::User` and `Role::System`
/// are sent as `"user"`. When `cache_control` is true, only the first content block of
/// the first message receives an `ephemeral` cache marker.
fn serialise_messages(messages: &[Message], cache_control: bool) -> Value {
    let serialised: Vec<Value> = messages
        .iter()
        .enumerate()
        .map(|(msg_idx, message)| {
            let role = match message.role {
                Role::Assistant => "assistant",
                Role::User | Role::System => "user",
            };
            let content: Vec<Value> = message
                .content
                .iter()
                .enumerate()
                .map(|(block_idx, block)| {
                    let ContentBlock::Text(text) = block;
                    let is_cache_anchor = cache_control && msg_idx == 0 && block_idx == 0;
                    if is_cache_anchor {
                        json!({
                            "type": "text",
                            "text": text,
                            "cache_control": { "type": "ephemeral" },
                        })
                    } else {
                        json!({ "type": "text", "text": text })
                    }
                })
                .collect();
            json!({ "role": role, "content": content })
        })
        .collect();
    Value::Array(serialised)
}
/// Renders the thinking configuration as `{"type": "enabled", "budget_tokens": N}`.
fn serialise_thinking(thinking: ThinkingMode) -> Value {
    // `Budget` is the only variant handled here, so the destructuring cannot fail.
    let ThinkingMode::Budget { max_tokens: budget } = thinking;
    json!({ "type": "enabled", "budget_tokens": budget })
}
/// Sends a non-streaming Messages request and parses the JSON reply.
///
/// # Errors
///
/// Returns `AiError::Transport` (with the message passed through `redact`) when sending
/// the request or decoding the response body fails, whatever `map_status` returns for a
/// non-success status, and whatever `parse_chat_response` returns for a malformed body.
pub(crate) async fn chat(
    client: &reqwest::Client,
    config: &Config,
    req: ChatRequest,
) -> Result<ChatResponse, AiError> {
    let payload = build_request_body(config, &req);
    let request = client
        .post(messages_url(config))
        .header("x-api-key", config.api_key.expose_secret())
        .header("anthropic-version", ANTHROPIC_VERSION)
        .header("content-type", "application/json")
        .json(&payload);

    let response = match request.send().await {
        Ok(response) => response,
        Err(err) => return Err(AiError::Transport(redact(&err.to_string()))),
    };
    let response = map_status(response, host_for(config)).await?;

    match response.json::<Value>().await {
        Ok(parsed) => parse_chat_response(&parsed),
        Err(err) => Err(AiError::Transport(redact(&err.to_string()))),
    }
}
pub(crate) async fn chat_stream(
client: &reqwest::Client,
config: &Config,
req: ChatRequest,
) -> Result<ChatStream, AiError> {
let mut body = build_request_body(config, &req);
body["stream"] = json!(true);
let resp = client
.post(messages_url(config))
.header("x-api-key", config.api_key.expose_secret())
.header("anthropic-version", ANTHROPIC_VERSION)
.header("content-type", "application/json")
.header("accept", "text/event-stream")
.json(&body)
.send()
.await
.map_err(|e| AiError::Transport(redact(&e.to_string())))?;
let resp = map_status(resp, host_for(config)).await?;
let inner = resp.bytes_stream();
Ok(ChatStream::new(Box::pin(SseEventStream::new(inner))))
}
#[doc(hidden)]
pub fn parse_chat_response(body: &Value) -> Result<ChatResponse, AiError> {
let content_arr = body
.get("content")
.and_then(Value::as_array)
.ok_or_else(|| AiError::Provider(redact("missing `content` array on response")))?;
let mut text_out = String::new();
let mut citations = Vec::new();
for block in content_arr {
if block.get("type").and_then(Value::as_str) == Some("text") {
if let Some(t) = block.get("text").and_then(Value::as_str) {
text_out.push_str(t);
}
if let Some(arr) = block.get("citations").and_then(Value::as_array) {
for c in arr {
citations.push(parse_citation(c));
}
}
}
}
let usage = parse_usage(body.get("usage"));
Ok(ChatResponse {
message: Message { role: Role::Assistant, content: vec![ContentBlock::Text(text_out)] },
usage,
citations,
})
}
/// Builds a [`Citation`] from one entry of a text block's `citations` array.
///
/// * `cited_text` — the `cited_text` string, or empty when absent.
/// * `source` — the first of `document_title`, `file_path`, `source` that holds a
///   string, or empty when none does. A key that is present but not a string (for
///   example `null`) is skipped so the later keys are still considered.
/// * `start_index` / `end_index` — `start_char_index` / `end_char_index` when present
///   and representable as `u32`; otherwise `None`. Out-of-range values are dropped
///   rather than wrapped.
fn parse_citation(c: &Value) -> Citation {
    // Checked narrowing: an `as u32` cast would silently wrap values above `u32::MAX`.
    let index_field =
        |key: &str| c.get(key).and_then(Value::as_u64).and_then(|n| u32::try_from(n).ok());
    let source = ["document_title", "file_path", "source"]
        .iter()
        .find_map(|key| c.get(*key).and_then(Value::as_str))
        .unwrap_or_default()
        .to_string();

    Citation {
        cited_text: c.get("cited_text").and_then(Value::as_str).unwrap_or_default().to_string(),
        source,
        start_index: index_field("start_char_index"),
        end_index: index_field("end_char_index"),
    }
}
/// Reads token counters from a `usage` JSON object.
///
/// Returns `Usage::default()` when `value` is `None`. Each counter that is missing or is
/// not an unsigned integer becomes `0`. A counter larger than `u32::MAX` saturates at
/// `u32::MAX` instead of wrapping around.
fn parse_usage(value: Option<&Value>) -> Usage {
    let Some(v) = value else { return Usage::default() };
    // Checked narrowing: an `as u32` cast would silently wrap oversized counts.
    let count = |key: &str| {
        v.get(key).and_then(Value::as_u64).map_or(0, |n| u32::try_from(n).unwrap_or(u32::MAX))
    };
    Usage {
        input_tokens: count("input_tokens"),
        output_tokens: count("output_tokens"),
        cache_creation_input_tokens: count("cache_creation_input_tokens"),
        cache_read_input_tokens: count("cache_read_input_tokens"),
    }
}
async fn map_status(resp: reqwest::Response, host: String) -> Result<reqwest::Response, AiError> {
let status = resp.status();
if status.is_success() {
return Ok(resp);
}
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
let retry_after = resp
.headers()
.get("retry-after")
.and_then(|h| h.to_str().ok())
.and_then(|s| s.parse::<u64>().ok())
.map(std::time::Duration::from_secs);
return Err(AiError::RateLimited { host, retry_after });
}
let detail = provider_error_detail(&resp.text().await.unwrap_or_default());
if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
return Err(AiError::Auth(redact(&format!("{status}: {detail}"))));
}
Err(AiError::Provider(redact(&format!("status {status} from {host}: {detail}"))))
}
/// Extracts a human-readable message from an error response body.
///
/// * A blank body yields `"(empty response body)"`.
/// * A JSON body yields `error.message` if that is a string, else a top-level `message`
///   string.
/// * Anything else yields the trimmed body, cut to its first 500 characters with `…`
///   appended when something was cut. The cut always falls on a character boundary.
fn provider_error_detail(body: &str) -> String {
    const MAX_CHARS: usize = 500;

    let text = body.trim();
    if text.is_empty() {
        return "(empty response body)".to_string();
    }

    let structured = serde_json::from_str::<Value>(text).ok().and_then(|parsed| {
        parsed
            .get("error")
            .and_then(|error| error.get("message"))
            .and_then(Value::as_str)
            .or_else(|| parsed.get("message").and_then(Value::as_str))
            .map(str::to_string)
    });
    if let Some(message) = structured {
        return message;
    }

    // The byte offset of character number MAX_CHARS exists only when the text is longer
    // than the cap; slicing there keeps exactly MAX_CHARS characters.
    match text.char_indices().nth(MAX_CHARS) {
        Some((cut, _)) => format!("{}…", &text[..cut]),
        None => text.to_string(),
    }
}
/// Returns the host name used in error reports.
///
/// This is the host of `config.base_url` when one is set and has a host, and
/// `"api.anthropic.com"` otherwise.
fn host_for(config: &Config) -> String {
    match config.base_url.as_ref().and_then(|url| url.host_str()) {
        Some(host) => String::from(host),
        None => "api.anthropic.com".to_string(),
    }
}
pub(crate) struct SseEventStream<S> {
inner: S,
buffer: Vec<u8>,
}
impl<S> SseEventStream<S>
where
S: Stream<Item = reqwest::Result<Bytes>> + Send + Unpin,
{
pub(crate) fn new(inner: S) -> Self {
Self { inner, buffer: Vec::with_capacity(4096) }
}
fn drain_event(&mut self) -> Option<ChatStreamEvent> {
let pos = self.buffer.windows(2).position(|w| w == b"\n\n")?;
let event_bytes = self.buffer.drain(..pos + 2).collect::<Vec<_>>();
let event = std::str::from_utf8(&event_bytes[..pos]).ok()?;
let mut data = String::new();
for line in event.split('\n') {
if let Some(rest) = line.strip_prefix("data: ") {
if !data.is_empty() {
data.push('\n');
}
data.push_str(rest);
}
}
if data.is_empty() {
return None;
}
parse_sse_data(&data)
}
}
fn parse_sse_data(data: &str) -> Option<ChatStreamEvent> {
let v: Value = serde_json::from_str(data).ok()?;
let kind = v.get("type").and_then(Value::as_str)?;
match kind {
"content_block_delta" => {
let delta = v.get("delta")?;
match delta.get("type").and_then(Value::as_str)? {
"text_delta" => {
let token = delta.get("text").and_then(Value::as_str)?.to_string();
Some(ChatStreamEvent::Token(token))
}
"thinking_delta" => {
let token = delta.get("thinking").and_then(Value::as_str)?.to_string();
Some(ChatStreamEvent::ThinkingToken(token))
}
_ => None,
}
}
"message_delta" | "message_start" => {
None
}
"message_stop" => {
let usage = parse_usage(v.get("usage").or_else(|| v.pointer("/message/usage")));
Some(ChatStreamEvent::Done(usage))
}
"error" => {
let msg = v
.get("error")
.and_then(|e| e.get("message"))
.and_then(Value::as_str)
.unwrap_or("unknown SSE error")
.to_string();
Some(ChatStreamEvent::Error(AiError::Provider(redact(&msg))))
}
_ => None,
}
}
impl<S> Stream for SseEventStream<S>
where
S: Stream<Item = reqwest::Result<Bytes>> + Send + Unpin,
{
type Item = ChatStreamEvent;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
loop {
if !has_complete_event(&self.buffer) {
break;
}
if let Some(event) = self.drain_event() {
return Poll::Ready(Some(event));
}
}
match self.inner.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(chunk))) => {
self.buffer.extend_from_slice(&chunk);
}
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(ChatStreamEvent::Error(AiError::Transport(redact(
&e.to_string(),
)))));
}
Poll::Ready(None) => return Poll::Ready(None),
Poll::Pending => return Poll::Pending,
}
}
}
}
/// Reports whether `buf` contains a blank-line terminator (two consecutive `\n` bytes),
/// i.e. at least one complete SSE event.
fn has_complete_event(buf: &[u8]) -> bool {
    buf.windows(2).any(|pair| matches!(pair, [b'\n', b'\n']))
}
/// Interface for issuing chat requests; implemented in this file by `ReqwestAnthropic`.
#[async_trait]
pub(crate) trait AnthropicTransport: Send + Sync {
/// Performs a chat request and resolves to the complete response or an `AiError`.
async fn chat(&self, config: &Config, req: ChatRequest) -> Result<ChatResponse, AiError>;
/// Performs a chat request and resolves to a stream of events or an `AiError`.
async fn chat_stream(&self, config: &Config, req: ChatRequest) -> Result<ChatStream, AiError>;
}
/// `AnthropicTransport` implementation backed by a shared `reqwest::Client`.
pub(crate) struct ReqwestAnthropic {
// HTTP client used for every request made through this transport.
client: Arc<reqwest::Client>,
}
impl ReqwestAnthropic {
/// Creates a transport that sends its requests through `client`.
pub(crate) const fn new(client: Arc<reqwest::Client>) -> Self {
Self { client }
}
}
#[async_trait]
impl AnthropicTransport for ReqwestAnthropic {
async fn chat(&self, config: &Config, req: ChatRequest) -> Result<ChatResponse, AiError> {
chat(&self.client, config, req).await
}
async fn chat_stream(&self, config: &Config, req: ChatRequest) -> Result<ChatStream, AiError> {
chat_stream(&self.client, config, req).await
}
}
#[cfg(test)]
mod map_status_tests {
    use super::provider_error_detail;

    /// The nested `error.message` of an error envelope is returned verbatim.
    #[test]
    fn extracts_anthropic_error_message() {
        let envelope = r#"{"type":"error","error":{"type":"invalid_request_error","message":"Your credit balance is too low to access the Anthropic API."}}"#;
        let detail = provider_error_detail(envelope);
        assert_eq!(detail, "Your credit balance is too low to access the Anthropic API.");
    }

    /// Without an `error` object, a top-level `message` string is used.
    #[test]
    fn falls_back_to_message_field() {
        let detail = provider_error_detail(r#"{"message":"model not found"}"#);
        assert_eq!(detail, "model not found");
    }

    /// A body that is not JSON is returned as-is.
    #[test]
    fn falls_back_to_raw_body_for_non_json() {
        let detail = provider_error_detail("upstream 502 bad gateway");
        assert_eq!(detail, "upstream 502 bad gateway");
    }

    /// A whitespace-only body is reported as empty.
    #[test]
    fn handles_empty_body() {
        let detail = provider_error_detail(" ");
        assert_eq!(detail, "(empty response body)");
    }

    /// Long bodies are cut to 500 characters plus an ellipsis, never mid-character.
    #[test]
    fn caps_long_non_json_body_on_char_boundary() {
        let oversized = "é".repeat(600);
        let detail = provider_error_detail(&oversized);
        assert!(detail.ends_with('…'));
        assert_eq!(detail.chars().count(), 501);
    }
}