use anyhow::{Context, Result};
use futures::{StreamExt, TryStreamExt, future::BoxFuture};
use reqwest::Client;
use serde_json::{Value, json};
use super::{
ContentPart, LlmProvider, LlmRequest, LlmStream, Message, MessageContent, Role, StreamEvent,
TokenUsage,
};
/// Default root URL of the Anthropic API; `/v1/messages` is appended per request.
pub const ANTHROPIC_API_BASE: &str = "https://api.anthropic.com";
/// Value sent in the `anthropic-version` header of every request.
const ANTHROPIC_VERSION: &str = "2023-06-01";
/// `max_tokens` sent when the request leaves it unset.
const DEFAULT_MAX_TOKENS: u32 = 8192;
/// [`LlmProvider`] backed by Anthropic's streaming Messages API.
///
/// NOTE(review): this holds the raw API key, so keep it out of any future
/// `Debug`/`Display` implementation.
pub struct AnthropicProvider {
    /// HTTP client used to send every request.
    client: Client,
    /// Sent verbatim as the `x-api-key` header.
    api_key: String,
    /// API root; a trailing '/' is trimmed before `/v1/messages` is appended.
    base_url: String,
    /// Optional `user-agent` override; `None` means `super::DEFAULT_USER_AGENT` is sent.
    user_agent: Option<String>,
}
impl AnthropicProvider {
    /// Creates a provider for the public endpoint ([`ANTHROPIC_API_BASE`]) using the
    /// default HTTP client from the parent module.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self::with_base_url(api_key, ANTHROPIC_API_BASE)
    }

    /// Creates a provider that talks to `base_url` instead of the public endpoint.
    ///
    /// Uses the parent module's default HTTP client. No user-agent override is stored, so
    /// requests carry `super::DEFAULT_USER_AGENT`.
    pub fn with_base_url(api_key: impl Into<String>, base_url: impl Into<String>) -> Self {
        let client = super::http_client();
        Self {
            client,
            api_key: api_key.into(),
            base_url: base_url.into(),
            user_agent: None,
        }
    }

    /// Creates a provider with an optional custom user agent.
    ///
    /// The value is handed to the HTTP client builder. It is also stored so that every
    /// request sends it as the `user-agent` header, falling back to
    /// `super::DEFAULT_USER_AGENT` when `None`.
    pub fn with_user_agent(
        api_key: impl Into<String>,
        base_url: impl Into<String>,
        user_agent: Option<String>,
    ) -> Self {
        let client = super::http_client_with_ua(user_agent.as_deref());
        Self {
            client,
            api_key: api_key.into(),
            base_url: base_url.into(),
            user_agent,
        }
    }
}
impl LlmProvider for AnthropicProvider {
fn name(&self) -> &str {
"anthropic"
}
fn stream(&self, req: LlmRequest) -> BoxFuture<'_, Result<LlmStream>> {
Box::pin(async move {
super::warn_unsupported_kv_cache_mode_2(self.name(), &req);
let body = build_request_body(&req)?;
let url = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
let model_for_log = req.model.clone();
let resp = self
.client
.post(&url)
.header("x-api-key", &self.api_key)
.header("anthropic-version", ANTHROPIC_VERSION)
.header("content-type", "application/json")
.header(
"user-agent",
self.user_agent
.as_deref()
.unwrap_or(super::DEFAULT_USER_AGENT),
)
.json(&body)
.timeout(std::time::Duration::from_secs(120))
.send()
.await
.with_context(|| format!("Anthropic request failed (url={url})"))?;
let status = resp.status();
if !status.is_success() {
let resp_body = resp.text().await.unwrap_or_default();
let req_body_str = serde_json::to_string(&body).unwrap_or_default();
let req_body_preview = if req_body_str.len() > 4000 {
format!(
"{}...[truncated, total {} bytes]",
rsclaw_util::truncate_str(&req_body_str, 4000),
req_body_str.len()
)
} else {
req_body_str
};
tracing::warn!(
url = %url,
model = %model_for_log,
status = %status,
request_body = %req_body_preview,
response_body = %resp_body,
"Anthropic provider non-2xx response"
);
anyhow::bail!(
"Anthropic API error {status} at {url} (model={model_for_log}): {resp_body}"
);
}
let byte_stream = resp.bytes_stream();
let line_buffer = std::sync::Arc::new(tokio::sync::Mutex::new(String::new()));
let event_stream = byte_stream
.map_err(|e| anyhow::anyhow!("stream read error: {e}"))
.then(move |chunk| {
let line_buffer = line_buffer.clone();
async move { parse_sse_chunk_buffered(chunk, &line_buffer).await }
})
.flat_map(|events| futures::stream::iter(events));
let stream: LlmStream = Box::pin(event_stream);
Ok(stream)
})
}
}
fn build_request_body(req: &LlmRequest) -> Result<Value> {
let (system, messages) = split_system_messages(&req.messages, req.system.as_deref());
let mut body = json!({
"model": req.model,
"max_tokens": req.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
"stream": true,
"messages": messages,
});
if let Some(sys) = system {
body["system"] = json!(sys);
}
inject_cache_control(&mut body);
if let Some(t) = req.temperature {
body["temperature"] = super::json_f32(t);
}
if !req.tools.is_empty() {
let tools: Vec<Value> = req
.tools
.iter()
.map(|t| {
json!({
"name": crate::openai::sanitize_tool_name(&t.name),
"description": t.description,
"input_schema": t.parameters,
})
})
.collect();
body["tools"] = json!(tools);
body["tool_choice"] = json!({ "type": "auto" });
}
if let Some(budget) = req.thinking_budget
&& budget > 0
{
body["thinking"] = json!({
"type": "enabled",
"budget_tokens": budget,
});
}
Ok(body)
}
fn split_system_messages<'a>(
messages: &'a [Message],
extra_system: Option<&'a str>,
) -> (Option<String>, Vec<Value>) {
let mut system_parts: Vec<String> =
extra_system.map(|s| vec![s.to_owned()]).unwrap_or_default();
let mut conv: Vec<Value> = Vec::new();
for msg in messages {
match msg.role {
Role::System => {
if let MessageContent::Text(t) = &msg.content {
system_parts.push(t.clone());
}
}
Role::User | Role::Assistant | Role::Tool => {
conv.push(serialize_message(msg));
}
}
}
let system = if system_parts.is_empty() {
None
} else {
Some(system_parts.join("\n\n"))
};
(system, conv)
}
/// Converts one conversational [`Message`] into an Anthropic `{ "role", "content" }` object.
///
/// `Role::Tool` messages (and, defensively, `Role::System`) are sent with role `"user"`.
///
/// Text blocks that are empty or whitespace-only are dropped, because the API rejects them
/// even when other blocks in the same message are valid. If nothing remains, or a
/// plain-text message is blank, the content becomes the `"(empty turn)"` placeholder and a
/// warning is logged.
fn serialize_message(msg: &Message) -> Value {
    const EMPTY_PLACEHOLDER: &str = "(empty turn)";

    let role = match msg.role {
        Role::Assistant => "assistant",
        // System messages are normally hoisted into the top-level `system` field by
        // `split_system_messages`; mapping them to "user" keeps this match total.
        Role::User | Role::Tool | Role::System => "user",
    };
    let content = match &msg.content {
        MessageContent::Text(t) if t.trim().is_empty() => {
            tracing::warn!(role, "empty text-content message; substituting placeholder");
            json!(EMPTY_PLACEHOLDER)
        }
        MessageContent::Text(t) => json!(t),
        MessageContent::Parts(parts) => {
            let blocks: Vec<Value> = parts
                .iter()
                .map(serialize_part)
                .filter(|block| !is_blank_text_block(block))
                .collect();
            if blocks.is_empty() {
                tracing::warn!(
                    role,
                    parts_len = parts.len(),
                    "all-empty parts array; substituting placeholder"
                );
                json!([{ "type": "text", "text": EMPTY_PLACEHOLDER }])
            } else {
                if blocks.len() < parts.len() {
                    tracing::debug!(
                        role,
                        dropped = parts.len() - blocks.len(),
                        "dropped blank text blocks from parts array"
                    );
                }
                json!(blocks)
            }
        }
    };
    json!({ "role": role, "content": content })
}

/// True for `{"type": "text"}` blocks whose `text` is missing, non-string, empty, or
/// whitespace-only.
fn is_blank_text_block(block: &Value) -> bool {
    block["type"].as_str() == Some("text")
        && block["text"].as_str().is_none_or(|s| s.trim().is_empty())
}
/// Converts one [`ContentPart`] into an Anthropic content block.
///
/// - Images: base64 `data:` URLs become a `base64` source; any other URL is sent as a
///   `url` source.
/// - Tool names: sanitized the same way as in the tool definitions.
/// - Reasoning parts: replayed as plain `text` blocks. They carry no thinking signature,
///   so the model sees earlier reasoning as ordinary text.
fn serialize_part(part: &ContentPart) -> Value {
    match part {
        ContentPart::Text { text } => json!({ "type": "text", "text": text }),
        ContentPart::Image { url } => json!({
            "type": "image",
            "source": image_source(url),
        }),
        ContentPart::ToolUse { id, name, input } => json!({
            "type": "tool_use",
            "id": id,
            "name": crate::openai::sanitize_tool_name(name),
            "input": input,
        }),
        ContentPart::ToolResult {
            tool_use_id,
            content,
            is_error,
        } => json!({
            "type": "tool_result",
            "tool_use_id": tool_use_id,
            "content": content,
            "is_error": is_error.unwrap_or(false),
        }),
        ContentPart::Reasoning { text } => json!({
            "type": "text",
            "text": text,
        }),
    }
}

/// Builds the `source` object of an image block from an image URL.
fn image_source(url: &str) -> Value {
    match parse_base64_data_url(url) {
        Some((media_type, data)) => json!({
            "type": "base64",
            "media_type": media_type,
            "data": data,
        }),
        None => json!({ "type": "url", "url": url }),
    }
}

/// Splits `data:<media-type>[;params];base64,<payload>` into `(media_type, payload)`.
///
/// Returns `None` for anything else: non-`data:` URLs, URLs without a `base64` parameter,
/// or URLs with an empty media type.
fn parse_base64_data_url(url: &str) -> Option<(&str, &str)> {
    let (header, data) = url.strip_prefix("data:")?.split_once(',')?;
    let mut params = header.split(';');
    let media_type = params.next().filter(|m| !m.is_empty())?;
    params
        .any(|p| p.eq_ignore_ascii_case("base64"))
        .then_some((media_type, data))
}
fn inject_cache_control(body: &mut Value) {
let cache_marker = json!({"type": "ephemeral"});
if let Some(system_val) = body.get_mut("system") {
match system_val {
Value::String(text) => {
let block = json!([{
"type": "text",
"text": text.clone(),
"cache_control": cache_marker.clone(),
}]);
*system_val = block;
}
Value::Array(blocks) => {
if let Some(last) = blocks.last_mut() {
last["cache_control"] = cache_marker.clone();
}
}
_ => {}
}
}
if let Some(Value::Array(messages)) = body.get_mut("messages") {
let len = messages.len();
if len == 0 {
return;
}
if messages[0].get("role").and_then(|r| r.as_str()) == Some("user") {
tag_last_content_block(&mut messages[0], &cache_marker);
}
if len > 1 {
let last_idx = len - 1;
tag_last_content_block(&mut messages[last_idx], &cache_marker);
}
}
}
/// Sets `marker` as the `cache_control` of the final content block of `msg`.
///
/// A plain-string `content` is first rewritten into a one-element text-block array so it
/// can carry the marker. Nothing changes if `msg` has no `content`, if the block array is
/// empty, or if `content` is any other JSON type.
fn tag_last_content_block(msg: &mut Value, marker: &Value) {
    let Some(content) = msg.get_mut("content") else {
        return;
    };
    if let Value::Array(blocks) = content {
        if let Some(last) = blocks.last_mut() {
            last["cache_control"] = marker.clone();
        }
    } else if let Value::String(text) = content {
        let text = std::mem::take(text);
        *content = json!([{
            "type": "text",
            "text": text,
            "cache_control": marker.clone(),
        }]);
    }
}
async fn parse_sse_chunk_buffered(
chunk: Result<bytes::Bytes>,
line_buffer: &tokio::sync::Mutex<String>,
) -> Vec<Result<StreamEvent>> {
let bytes = match chunk {
Ok(b) => b,
Err(e) => return vec![Err(e)],
};
let text = match std::str::from_utf8(&bytes) {
Ok(t) => std::borrow::Cow::Borrowed(t),
Err(e) => {
tracing::warn!(
"anthropic: UTF-8 decode error at byte {}, replacing: {}",
e.valid_up_to(),
e
);
std::borrow::Cow::Owned(String::from_utf8_lossy(&bytes).into_owned())
}
};
let mut buffer = line_buffer.lock().await;
buffer.push_str(&text);
let last_newline_pos = match buffer.rfind('\n') {
Some(pos) => pos,
None => return vec![],
};
let complete_portion = buffer[..last_newline_pos].to_owned();
let incomplete_portion = buffer[last_newline_pos + 1..].to_owned();
buffer.clear();
buffer.push_str(&incomplete_portion);
let mut events = Vec::new();
for line in complete_portion.lines() {
if let Some(data) = line
.strip_prefix("data: ")
.or_else(|| line.strip_prefix("data:"))
{
if data == "[DONE]" {
continue;
}
if let Some(event) = parse_event(data) {
events.push(Ok(event));
}
}
}
events
}
fn parse_event(data: &str) -> Option<StreamEvent> {
let v: Value = serde_json::from_str(data).ok()?;
let event_type = v["type"].as_str()?;
match event_type {
"content_block_delta" => {
let delta_type = v["delta"]["type"].as_str()?;
match delta_type {
"text_delta" => {
let text = v["delta"]["text"].as_str()?.to_owned();
Some(StreamEvent::TextDelta(text))
}
"thinking_delta" => {
let text = v["delta"]["thinking"].as_str().unwrap_or("").to_owned();
if text.is_empty() {
None
} else {
Some(StreamEvent::ReasoningDelta(text))
}
}
"input_json_delta" => {
let partial = v["delta"]["partial_json"].as_str().unwrap_or("");
if partial.is_empty() {
None
} else {
Some(StreamEvent::ToolCall {
id: String::new(),
name: String::new(),
input: Value::String(partial.to_owned()),
})
}
}
_ => None,
}
}
"content_block_start" => {
let block = &v["content_block"];
match block["type"].as_str() {
Some("tool_use") => {
Some(StreamEvent::ToolCall {
id: block["id"].as_str().unwrap_or("").to_owned(),
name: block["name"].as_str().unwrap_or("").to_owned(),
input: serde_json::Value::Object(Default::default()),
})
}
Some("thinking") => {
None
}
_ => None,
}
}
"message_delta" => {
let usage = v["usage"].as_object().map(|u| TokenUsage {
input: u.get("input_tokens").and_then(Value::as_u64).unwrap_or(0),
output: u.get("output_tokens").and_then(Value::as_u64).unwrap_or(0),
cache_creation: u
.get("cache_creation_input_tokens")
.and_then(Value::as_u64)
.unwrap_or(0),
cache_read: u
.get("cache_read_input_tokens")
.and_then(Value::as_u64)
.unwrap_or(0),
..Default::default()
});
if v["delta"]["stop_reason"].is_string() {
Some(StreamEvent::Done { usage })
} else {
None
}
}
"error" => {
let msg = v["error"]["message"]
.as_str()
.unwrap_or("unknown error")
.to_owned();
Some(StreamEvent::Error(msg))
}
_ => None,
}
}
#[cfg(test)]
mod tests {
    use super::{
        super::{LlmRequest, Message, MessageContent, Role},
        *,
    };

    /// Request with a fixed model name and every other field defaulted.
    fn make_request() -> LlmRequest {
        LlmRequest {
            fallback_models: Vec::new(),
            model: "claude-3-5-sonnet-20241022".to_owned(),
            ..Default::default()
        }
    }

    /// Plain-text message without hidden metadata.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_owned()),
            rsclaw_hidden: None,
        }
    }

    /// Reads `block.cache_control.type`, panicking with `what` when it is not a string.
    fn cache_control_type<'a>(block: &'a serde_json::Value, what: &str) -> &'a str {
        block["cache_control"]["type"].as_str().expect(what)
    }

    #[test]
    fn request_serializes_messages() {
        let req = LlmRequest {
            fallback_models: Vec::new(),
            messages: vec![
                text_message(Role::User, "hi"),
                text_message(Role::Assistant, "hello"),
            ],
            ..make_request()
        };
        let body = build_request_body(&req).expect("build request body");
        let msgs = body["messages"].as_array().expect("messages is array");
        assert_eq!(msgs.len(), 2);
        let roles: Vec<&str> = msgs
            .iter()
            .map(|m| m["role"].as_str().expect("role is str"))
            .collect();
        assert_eq!(roles, ["user", "assistant"]);
    }

    #[test]
    fn system_field_present() {
        let req = LlmRequest {
            fallback_models: Vec::new(),
            system: Some("hello".to_owned()),
            ..make_request()
        };
        let body = build_request_body(&req).expect("build request body");
        let blocks = body["system"]
            .as_array()
            .expect("system should be content-block array");
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0]["text"].as_str().expect("text is str"), "hello");
        assert_eq!(
            cache_control_type(&blocks[0], "cache_control type is str"),
            "ephemeral"
        );
    }

    #[test]
    fn tool_choice_auto_set_when_tools_present() {
        let shell = crate::ToolDef {
            name: "shell".to_owned(),
            description: "run a shell command".to_owned(),
            parameters: serde_json::json!({"type": "object"}),
        };
        let req = LlmRequest {
            fallback_models: Vec::new(),
            tools: vec![shell],
            ..make_request()
        };
        let body = build_request_body(&req).expect("build request body");
        assert_eq!(
            body["tool_choice"]["type"].as_str(),
            Some("auto"),
            "tool_choice must be set to auto when tools are present"
        );
        let tools = body["tools"].as_array();
        assert!(tools.is_some_and(|t| t.len() == 1));
    }

    #[test]
    fn tool_choice_absent_when_no_tools() {
        let body = build_request_body(&make_request()).expect("build request body");
        assert!(
            body.get("tool_choice").is_none(),
            "tool_choice must be omitted when there are no tools"
        );
    }

    #[test]
    fn cache_control_system_and_anchors() {
        let conversation = [
            (Role::User, "m1"),
            (Role::Assistant, "m2"),
            (Role::User, "m3"),
            (Role::Assistant, "m4"),
            (Role::User, "m5"),
        ];
        let req = LlmRequest {
            fallback_models: Vec::new(),
            system: Some("system prompt".to_owned()),
            messages: conversation
                .into_iter()
                .map(|(role, text)| text_message(role, text))
                .collect(),
            ..make_request()
        };
        let body = build_request_body(&req).expect("build request body");

        let sys_blocks = body["system"].as_array().expect("system content blocks");
        assert_eq!(
            cache_control_type(&sys_blocks[0], "cache_control type"),
            "ephemeral"
        );

        let msgs = body["messages"].as_array().expect("messages is array");
        assert_eq!(msgs.len(), 5);

        let first = &msgs[0]["content"];
        assert!(
            first.is_array(),
            "m1 should be converted to content-block array"
        );
        assert_eq!(
            cache_control_type(&first[0], "m1 cache_control type"),
            "ephemeral"
        );

        // The middle messages must stay unmarked.
        for (i, msg) in msgs.iter().enumerate().take(4).skip(1) {
            let content = &msg["content"];
            if content.is_array() {
                assert!(
                    content[0].get("cache_control").is_none(),
                    "message {i} should not have cache_control"
                );
            }
        }

        let last = &msgs[4]["content"];
        assert!(last.is_array(), "m5 should be content-block array");
        assert_eq!(
            cache_control_type(&last[0], "m5 cache_control type"),
            "ephemeral"
        );
    }

    #[test]
    fn cache_control_fewer_than_3_messages() {
        let req = LlmRequest {
            fallback_models: Vec::new(),
            messages: vec![text_message(Role::User, "only one")],
            ..make_request()
        };
        let body = build_request_body(&req).expect("build request body");
        let msgs = body["messages"].as_array().expect("messages is array");
        let content = &msgs[0]["content"];
        assert!(content.is_array());
        assert_eq!(
            cache_control_type(&content[0], "cache_control type"),
            "ephemeral"
        );
    }

    #[test]
    fn temperature_serializes() {
        let req = LlmRequest {
            fallback_models: Vec::new(),
            temperature: Some(0.7),
            ..make_request()
        };
        let body = build_request_body(&req).expect("build request body");
        let t = body["temperature"].as_f64().expect("temperature is f64");
        assert!((t - 0.7).abs() < 1e-4);
    }
}