use crate::auth::unmark_anthropic_oauth_bearer_token;
use crate::error::{Error, Result};
use crate::http::client::Client;
use crate::model::{
AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, TextContent, ThinkingContent,
ThinkingLevel, ToolCall, Usage, UserContent,
};
use crate::models::CompatConfig;
use crate::provider::{CacheRetention, Context, Provider, StreamOptions, ToolDef};
use crate::provider_metadata::canonical_provider_id;
use crate::sse::SseStream;
use async_trait::async_trait;
use futures::StreamExt;
use futures::stream::{self, Stream};
use serde::{Deserialize, Serialize};
use std::fs;
use std::pin::Pin;
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
const ANTHROPIC_API_VERSION: &str = "2023-06-01";
const DEFAULT_MAX_TOKENS: u32 = 8192;
const ANTHROPIC_OAUTH_TOKEN_PREFIX: &str = "sk-ant-oat";
const ANTHROPIC_OAUTH_BETA_FLAGS: &str = "claude-code-20250219,oauth-2025-04-20";
const KIMI_SHARE_DIR_ENV_KEY: &str = "KIMI_SHARE_DIR";
#[inline]
fn is_anthropic_oauth_token(token: &str) -> bool {
token.contains(ANTHROPIC_OAUTH_TOKEN_PREFIX)
}
#[inline]
fn is_anthropic_provider(provider: &str) -> bool {
canonical_provider_id(provider).unwrap_or(provider) == "anthropic"
}
#[inline]
fn is_anthropic_bearer_token(provider: &str, token: &str) -> bool {
if !is_anthropic_provider(provider) {
return false;
}
let token = token.trim();
if token.is_empty() {
return false;
}
if is_anthropic_oauth_token(token) {
return true;
}
!token.starts_with("sk-ant-")
}
#[inline]
fn is_kimi_coding_provider(provider: &str) -> bool {
canonical_provider_id(provider).unwrap_or(provider) == "kimi-for-coding"
}
#[inline]
fn is_kimi_oauth_token(provider: &str, token: &str) -> bool {
is_kimi_coding_provider(provider) && !token.starts_with("sk-")
}
fn sanitize_ascii_header_value(value: &str, fallback: &str) -> String {
if value.is_ascii() && !value.trim().is_empty() {
return value.to_string();
}
let sanitized = value
.chars()
.filter(char::is_ascii)
.collect::<String>()
.trim()
.to_string();
if sanitized.is_empty() {
fallback.to_string()
} else {
sanitized
}
}
fn home_dir_with_env_lookup<F>(env_lookup: F) -> Option<std::path::PathBuf>
where
F: Fn(&str) -> Option<String>,
{
env_lookup("HOME")
.map(|value| value.trim().to_string())
.filter(|value| !value.is_empty())
.map(std::path::PathBuf::from)
.or_else(|| {
env_lookup("USERPROFILE")
.map(|value| value.trim().to_string())
.filter(|value| !value.is_empty())
.map(std::path::PathBuf::from)
})
.or_else(|| {
let drive = env_lookup("HOMEDRIVE")
.map(|value| value.trim().to_string())
.filter(|value| !value.is_empty())?;
let path = env_lookup("HOMEPATH")
.map(|value| value.trim().to_string())
.filter(|value| !value.is_empty())?;
if path.starts_with('\\') || path.starts_with('/') {
Some(std::path::PathBuf::from(format!("{drive}{path}")))
} else {
let mut combined = std::path::PathBuf::from(drive);
combined.push(path);
Some(combined)
}
})
}
fn home_dir() -> Option<std::path::PathBuf> {
home_dir_with_env_lookup(|key| std::env::var(key).ok())
}
fn kimi_share_dir_with_env_lookup<F>(env_lookup: F) -> Option<std::path::PathBuf>
where
F: Fn(&str) -> Option<String>,
{
env_lookup(KIMI_SHARE_DIR_ENV_KEY)
.map(|value| value.trim().to_string())
.filter(|value| !value.is_empty())
.map(std::path::PathBuf::from)
.or_else(|| home_dir_with_env_lookup(env_lookup).map(|home| home.join(".kimi")))
}
fn kimi_share_dir() -> Option<std::path::PathBuf> {
kimi_share_dir_with_env_lookup(|key| std::env::var(key).ok())
}
fn kimi_device_id_paths() -> Option<(std::path::PathBuf, std::path::PathBuf)> {
let primary = kimi_share_dir()?.join("device_id");
let legacy = home_dir().map_or_else(
|| primary.clone(),
|home| home.join(".pi").join("agent").join("kimi-device-id"),
);
Some((primary, legacy))
}
fn kimi_device_id() -> String {
let generated = uuid::Uuid::new_v4().simple().to_string();
let Some((primary, legacy)) = kimi_device_id_paths() else {
return generated;
};
for path in [&primary, &legacy] {
if let Ok(existing) = fs::read_to_string(path) {
let existing = existing.trim();
if !existing.is_empty() {
return existing.to_string();
}
}
}
if let Some(parent) = primary.parent() {
let _ = fs::create_dir_all(parent);
}
let mut options = fs::OpenOptions::new();
options.write(true).create(true).truncate(true);
#[cfg(unix)]
{
use std::os::unix::fs::OpenOptionsExt;
options.mode(0o600);
}
if let Ok(mut file) = options.open(&primary) {
use std::io::Write;
let _ = file.write_all(generated.as_bytes());
}
generated
}
fn kimi_common_headers() -> Vec<(String, String)> {
let device_name = std::env::var("HOSTNAME")
.ok()
.or_else(|| std::env::var("COMPUTERNAME").ok())
.unwrap_or_else(|| "unknown".to_string());
let device_model = format!("{} {}", std::env::consts::OS, std::env::consts::ARCH);
let os_version = std::env::consts::OS.to_string();
vec![
(
"X-Msh-Platform".to_string(),
sanitize_ascii_header_value("kimi_cli", "unknown"),
),
(
"X-Msh-Version".to_string(),
sanitize_ascii_header_value(env!("CARGO_PKG_VERSION"), "unknown"),
),
(
"X-Msh-Device-Name".to_string(),
sanitize_ascii_header_value(&device_name, "unknown"),
),
(
"X-Msh-Device-Model".to_string(),
sanitize_ascii_header_value(&device_model, "unknown"),
),
(
"X-Msh-Os-Version".to_string(),
sanitize_ascii_header_value(&os_version, "unknown"),
),
(
"X-Msh-Device-Id".to_string(),
sanitize_ascii_header_value(&kimi_device_id(), "unknown"),
),
]
}
pub struct AnthropicProvider {
client: Client,
model: String,
base_url: String,
provider: String,
compat: Option<CompatConfig>,
}
impl AnthropicProvider {
pub fn new(model: impl Into<String>) -> Self {
Self {
client: Client::new(),
model: model.into(),
base_url: ANTHROPIC_API_URL.to_string(),
provider: "anthropic".to_string(),
compat: None,
}
}
#[must_use]
pub fn with_provider_name(mut self, provider: impl Into<String>) -> Self {
self.provider = provider.into();
self
}
#[must_use]
pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
self.base_url = base_url.into();
self
}
#[must_use]
pub fn with_client(mut self, client: Client) -> Self {
self.client = client;
self
}
#[must_use]
pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
self.compat = compat;
self
}
pub fn build_request<'a>(
&'a self,
context: &'a Context<'_>,
options: &StreamOptions,
) -> AnthropicRequest<'a> {
let messages = context
.messages
.iter()
.map(convert_message_to_anthropic)
.collect();
let tools: Option<Vec<AnthropicTool<'_>>> = if context.tools.is_empty() {
None
} else {
Some(
context
.tools
.iter()
.map(convert_tool_to_anthropic)
.collect(),
)
};
let thinking = options.thinking_level.and_then(|level| {
if level == ThinkingLevel::Off {
None
} else {
let budget = options.thinking_budgets.as_ref().map_or_else(
|| level.default_budget(),
|b| match level {
ThinkingLevel::Off => 0,
ThinkingLevel::Minimal => b.minimal,
ThinkingLevel::Low => b.low,
ThinkingLevel::Medium => b.medium,
ThinkingLevel::High => b.high,
ThinkingLevel::XHigh => b.xhigh,
},
);
Some(AnthropicThinking {
r#type: "enabled",
budget_tokens: budget,
})
}
});
let mut max_tokens = options.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
if let Some(t) = &thinking {
if max_tokens <= t.budget_tokens {
max_tokens = t.budget_tokens + 4096;
}
}
AnthropicRequest {
model: &self.model,
messages,
system: context.system_prompt.as_deref(),
max_tokens,
temperature: options.temperature,
tools,
stream: true,
thinking,
}
}
}
#[async_trait]
impl Provider for AnthropicProvider {
fn name(&self) -> &str {
&self.provider
}
fn api(&self) -> &'static str {
"anthropic-messages"
}
fn model_id(&self) -> &str {
&self.model
}
#[allow(clippy::too_many_lines)]
async fn stream(
&self,
context: &Context<'_>,
options: &StreamOptions,
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
let raw_auth_value = options
.api_key
.clone()
.or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
.ok_or_else(|| {
Error::provider(
self.name(),
"Missing API key for provider. Configure credentials with /login <provider> or set the provider's API key env var.",
)
})?;
let forced_bearer_token = if is_anthropic_provider(&self.provider) {
unmark_anthropic_oauth_bearer_token(&raw_auth_value).map(ToString::to_string)
} else {
None
};
let force_bearer = forced_bearer_token.is_some();
let auth_value = forced_bearer_token.unwrap_or(raw_auth_value);
let request_body = self.build_request(context, options);
let anthropic_bearer_token =
force_bearer || is_anthropic_bearer_token(&self.provider, &auth_value);
let kimi_oauth_token = is_kimi_oauth_token(&self.provider, &auth_value);
let mut request = self
.client
.post(&self.base_url)
.header("Accept", "text/event-stream")
.header("anthropic-version", ANTHROPIC_API_VERSION);
if anthropic_bearer_token {
request = request
.header("Authorization", format!("Bearer {auth_value}"))
.header("anthropic-dangerous-direct-browser-access", "true")
.header("x-app", "cli")
.header(
"user-agent",
format!(
"pi_agent_rust/{} (external, cli)",
env!("CARGO_PKG_VERSION")
),
);
} else if kimi_oauth_token {
request = request
.header("Authorization", format!("Bearer {auth_value}"))
.header(
"user-agent",
format!(
"pi_agent_rust/{} (kimi-oauth, cli)",
env!("CARGO_PKG_VERSION")
),
);
for (name, value) in kimi_common_headers() {
request = request.header(name, value);
}
} else {
request = request.header("X-API-Key", &auth_value);
}
let mut beta_flags: Vec<&str> = Vec::new();
if anthropic_bearer_token {
beta_flags.push(ANTHROPIC_OAUTH_BETA_FLAGS);
}
if options.cache_retention != CacheRetention::None {
beta_flags.push("prompt-caching-2024-07-31");
}
if !beta_flags.is_empty() {
request = request.header("anthropic-beta", beta_flags.join(","));
}
if let Some(compat) = &self.compat {
if let Some(custom_headers) = &compat.custom_headers {
for (key, value) in custom_headers {
request = request.header(key, value);
}
}
}
for (key, value) in &options.headers {
request = request.header(key, value);
}
let request = request.json(&request_body)?;
let response = Box::pin(request.send()).await?;
let status = response.status();
if !(200..300).contains(&status) {
let body = response
.text()
.await
.unwrap_or_else(|e| format!("<failed to read body: {e}>"));
return Err(Error::provider(
self.name(),
format!("Anthropic API error (HTTP {status}): {body}"),
));
}
let event_source = SseStream::new(response.bytes_stream());
let model = self.model.clone();
let api = self.api().to_string();
let provider = self.name().to_string();
let stream = stream::unfold(
StreamState::new(event_source, model, api, provider),
|mut state| async move {
if state.done {
return None;
}
loop {
match state.event_source.next().await {
Some(Ok(msg)) => {
if msg.event == "ping" {
} else {
match state.process_event(&msg.data) {
Ok(Some(event)) => {
if matches!(
&event,
StreamEvent::Done { .. } | StreamEvent::Error { .. }
) {
state.done = true;
}
return Some((Ok(event), state));
}
Ok(None) => {}
Err(e) => {
state.done = true;
return Some((Err(e), state));
}
}
}
}
Some(Err(e)) => {
state.done = true;
let err = Error::api(format!("SSE error: {e}"));
return Some((Err(err), state));
}
None => {
state.done = true;
let reason = state.partial.stop_reason;
let message = std::mem::take(&mut state.partial);
return Some((Ok(StreamEvent::Done { reason, message }), state));
}
}
}
},
);
Ok(Box::pin(stream))
}
}
struct StreamState<S>
where
S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
event_source: SseStream<S>,
partial: AssistantMessage,
current_tool_json: String,
current_tool_id: Option<String>,
current_tool_name: Option<String>,
done: bool,
}
impl<S> StreamState<S>
where
S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
const fn recompute_total_tokens(&mut self) {
self.partial.usage.total_tokens = self
.partial
.usage
.input
.saturating_add(self.partial.usage.output)
.saturating_add(self.partial.usage.cache_read)
.saturating_add(self.partial.usage.cache_write);
}
fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
Self {
event_source,
partial: AssistantMessage {
content: Vec::new(),
api,
provider,
model,
usage: Usage::default(),
stop_reason: StopReason::Stop,
error_message: None,
timestamp: chrono::Utc::now().timestamp_millis(),
},
current_tool_json: String::new(),
current_tool_id: None,
current_tool_name: None,
done: false,
}
}
#[allow(clippy::too_many_lines)]
fn process_event(&mut self, data: &str) -> Result<Option<StreamEvent>> {
let event: AnthropicStreamEvent =
serde_json::from_str(data).map_err(|e| Error::api(format!("JSON parse error: {e}")))?;
match event {
AnthropicStreamEvent::MessageStart { message } => {
Ok(Some(self.handle_message_start(message)))
}
AnthropicStreamEvent::ContentBlockStart {
index,
content_block,
} => Ok(Some(self.handle_content_block_start(index, content_block))),
AnthropicStreamEvent::ContentBlockDelta { index, delta } => {
Ok(self.handle_content_block_delta(index, delta))
}
AnthropicStreamEvent::ContentBlockStop { index } => {
Ok(self.handle_content_block_stop(index))
}
AnthropicStreamEvent::MessageDelta { delta, usage } => {
self.handle_message_delta(&delta, usage);
Ok(None)
}
AnthropicStreamEvent::MessageStop => {
let reason = self.partial.stop_reason;
Ok(Some(StreamEvent::Done {
reason,
message: std::mem::take(&mut self.partial),
}))
}
AnthropicStreamEvent::Error { error } => {
self.partial.stop_reason = StopReason::Error;
self.partial.error_message = Some(error.message);
Ok(Some(StreamEvent::Error {
reason: StopReason::Error,
error: std::mem::take(&mut self.partial),
}))
}
AnthropicStreamEvent::Ping => Ok(None),
}
}
fn handle_message_start(&mut self, message: AnthropicMessageStart) -> StreamEvent {
if let Some(usage) = message.usage {
self.partial.usage.input = usage.input;
self.partial.usage.cache_read = usage.cache_read.unwrap_or(0);
self.partial.usage.cache_write = usage.cache_write.unwrap_or(0);
self.recompute_total_tokens();
}
StreamEvent::Start {
partial: self.partial.clone(),
}
}
fn handle_content_block_start(
&mut self,
index: u32,
content_block: AnthropicContentBlock,
) -> StreamEvent {
let content_index = index as usize;
match content_block {
AnthropicContentBlock::Text => {
self.partial
.content
.push(ContentBlock::Text(TextContent::new("")));
StreamEvent::TextStart { content_index }
}
AnthropicContentBlock::Thinking => {
self.partial
.content
.push(ContentBlock::Thinking(ThinkingContent {
thinking: String::new(),
thinking_signature: None,
}));
StreamEvent::ThinkingStart { content_index }
}
AnthropicContentBlock::ToolUse { id, name } => {
self.current_tool_json.clear();
self.current_tool_id = id;
self.current_tool_name = name;
self.partial.content.push(ContentBlock::ToolCall(ToolCall {
id: self.current_tool_id.clone().unwrap_or_default(),
name: self.current_tool_name.clone().unwrap_or_default(),
arguments: serde_json::Value::Null,
thought_signature: None,
}));
StreamEvent::ToolCallStart { content_index }
}
}
}
fn handle_content_block_delta(
&mut self,
index: u32,
delta: AnthropicDelta,
) -> Option<StreamEvent> {
let idx = index as usize;
match delta {
AnthropicDelta::TextDelta { text } => {
if let Some(text) = text {
if let Some(ContentBlock::Text(t)) = self.partial.content.get_mut(idx) {
t.text.push_str(&text);
}
Some(StreamEvent::TextDelta {
content_index: idx,
delta: text,
})
} else {
None
}
}
AnthropicDelta::ThinkingDelta { thinking } => {
if let Some(thinking) = thinking {
if let Some(ContentBlock::Thinking(t)) = self.partial.content.get_mut(idx) {
t.thinking.push_str(&thinking);
}
Some(StreamEvent::ThinkingDelta {
content_index: idx,
delta: thinking,
})
} else {
None
}
}
AnthropicDelta::InputJsonDelta { partial_json } => {
if let Some(partial_json) = partial_json {
self.current_tool_json.push_str(&partial_json);
Some(StreamEvent::ToolCallDelta {
content_index: idx,
delta: partial_json,
})
} else {
None
}
}
AnthropicDelta::SignatureDelta { signature } => {
if let Some(sig) = signature {
if let Some(ContentBlock::Thinking(t)) = self.partial.content.get_mut(idx) {
t.thinking_signature = Some(sig);
}
}
None
}
}
}
fn handle_content_block_stop(&mut self, index: u32) -> Option<StreamEvent> {
let idx = index as usize;
match self.partial.content.get_mut(idx) {
Some(ContentBlock::Text(t)) => {
let content = t.text.clone();
Some(StreamEvent::TextEnd {
content_index: idx,
content,
})
}
Some(ContentBlock::Thinking(t)) => {
let content = t.thinking.clone();
Some(StreamEvent::ThinkingEnd {
content_index: idx,
content,
})
}
Some(ContentBlock::ToolCall(tc)) => {
let arguments: serde_json::Value =
match serde_json::from_str(&self.current_tool_json) {
Ok(args) => args,
Err(e) => {
tracing::warn!(
error = %e,
raw = %self.current_tool_json,
"Failed to parse tool arguments as JSON"
);
serde_json::Value::Null
}
};
let tool_call = ToolCall {
id: self.current_tool_id.take().unwrap_or_default(),
name: self.current_tool_name.take().unwrap_or_default(),
arguments: arguments.clone(),
thought_signature: None,
};
tc.arguments = arguments;
self.current_tool_json.clear();
Some(StreamEvent::ToolCallEnd {
content_index: idx,
tool_call,
})
}
_ => None,
}
}
#[allow(clippy::missing_const_for_fn)]
fn handle_message_delta(
&mut self,
delta: &AnthropicMessageDelta,
usage: Option<AnthropicDeltaUsage>,
) {
if let Some(stop_reason) = delta.stop_reason {
self.partial.stop_reason = match stop_reason {
AnthropicStopReason::MaxTokens => StopReason::Length,
AnthropicStopReason::ToolUse => StopReason::ToolUse,
AnthropicStopReason::EndTurn | AnthropicStopReason::StopSequence => {
StopReason::Stop
}
};
}
if let Some(u) = usage {
self.partial.usage.output = u.output_tokens;
self.recompute_total_tokens();
}
}
}
#[derive(Debug, Serialize)]
pub struct AnthropicRequest<'a> {
model: &'a str,
messages: Vec<AnthropicMessage<'a>>,
#[serde(skip_serializing_if = "Option::is_none")]
system: Option<&'a str>,
max_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<AnthropicTool<'a>>>,
stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
thinking: Option<AnthropicThinking>,
}
#[derive(Debug, Serialize)]
struct AnthropicThinking {
r#type: &'static str,
budget_tokens: u32,
}
#[derive(Debug, Serialize)]
struct AnthropicMessage<'a> {
role: &'static str,
content: Vec<AnthropicContent<'a>>,
}
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicContent<'a> {
Text {
text: &'a str,
},
Thinking {
thinking: &'a str,
signature: &'a str,
},
Image {
source: AnthropicImageSource<'a>,
},
ToolUse {
id: &'a str,
name: &'a str,
input: &'a serde_json::Value,
},
ToolResult {
tool_use_id: &'a str,
content: Vec<AnthropicToolResultContent<'a>>,
#[serde(skip_serializing_if = "Option::is_none")]
is_error: Option<bool>,
},
}
#[derive(Debug, Serialize)]
struct AnthropicImageSource<'a> {
r#type: &'static str,
media_type: &'a str,
data: &'a str,
}
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicToolResultContent<'a> {
Text { text: &'a str },
Image { source: AnthropicImageSource<'a> },
}
#[derive(Debug, Serialize)]
struct AnthropicTool<'a> {
name: &'a str,
description: &'a str,
input_schema: &'a serde_json::Value,
}
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicStreamEvent {
MessageStart {
message: AnthropicMessageStart,
},
ContentBlockStart {
index: u32,
content_block: AnthropicContentBlock,
},
ContentBlockDelta {
index: u32,
delta: AnthropicDelta,
},
ContentBlockStop {
index: u32,
},
MessageDelta {
delta: AnthropicMessageDelta,
#[serde(default)]
usage: Option<AnthropicDeltaUsage>,
},
MessageStop,
Error {
error: AnthropicError,
},
Ping,
}
#[derive(Debug, Deserialize)]
struct AnthropicMessageStart {
#[serde(default)]
usage: Option<AnthropicUsage>,
}
#[derive(Debug, Deserialize)]
#[allow(clippy::struct_field_names)]
struct AnthropicUsage {
#[serde(rename = "input_tokens")]
input: u64,
#[serde(default, rename = "cache_read_input_tokens")]
cache_read: Option<u64>,
#[serde(default, rename = "cache_creation_input_tokens")]
cache_write: Option<u64>,
}
#[derive(Debug, Deserialize)]
struct AnthropicDeltaUsage {
output_tokens: u64,
}
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicContentBlock {
Text,
Thinking,
ToolUse {
#[serde(default)]
id: Option<String>,
#[serde(default)]
name: Option<String>,
},
}
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[allow(clippy::enum_variant_names)] enum AnthropicDelta {
TextDelta {
#[serde(default)]
text: Option<String>,
},
ThinkingDelta {
#[serde(default)]
thinking: Option<String>,
},
InputJsonDelta {
#[serde(default)]
partial_json: Option<String>,
},
SignatureDelta {
#[serde(default)]
signature: Option<String>,
},
}
#[derive(Debug, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum AnthropicStopReason {
EndTurn,
MaxTokens,
ToolUse,
StopSequence,
}
#[derive(Debug, Deserialize)]
struct AnthropicMessageDelta {
#[serde(default)]
stop_reason: Option<AnthropicStopReason>,
}
#[derive(Debug, Deserialize)]
struct AnthropicError {
message: String,
}
fn convert_message_to_anthropic(message: &Message) -> AnthropicMessage<'_> {
match message {
Message::User(user) => AnthropicMessage {
role: "user",
content: convert_user_content(&user.content),
},
Message::Custom(custom) => AnthropicMessage {
role: "user",
content: vec![AnthropicContent::Text {
text: &custom.content,
}],
},
Message::Assistant(assistant) => AnthropicMessage {
role: "assistant",
content: assistant
.content
.iter()
.filter_map(convert_content_block_to_anthropic)
.collect(),
},
Message::ToolResult(result) => AnthropicMessage {
role: "user",
content: vec![AnthropicContent::ToolResult {
tool_use_id: &result.tool_call_id,
content: result
.content
.iter()
.filter_map(|block| match block {
ContentBlock::Text(t) => {
Some(AnthropicToolResultContent::Text { text: &t.text })
}
ContentBlock::Image(img) => Some(AnthropicToolResultContent::Image {
source: AnthropicImageSource {
r#type: "base64",
media_type: &img.mime_type,
data: &img.data,
},
}),
_ => None,
})
.collect(),
is_error: if result.is_error { Some(true) } else { None },
}],
},
}
}
fn convert_user_content(content: &UserContent) -> Vec<AnthropicContent<'_>> {
match content {
UserContent::Text(text) => vec![AnthropicContent::Text { text }],
UserContent::Blocks(blocks) => blocks
.iter()
.filter_map(|block| match block {
ContentBlock::Text(t) => Some(AnthropicContent::Text { text: &t.text }),
ContentBlock::Image(img) => Some(AnthropicContent::Image {
source: AnthropicImageSource {
r#type: "base64",
media_type: &img.mime_type,
data: &img.data,
},
}),
_ => None,
})
.collect(),
}
}
fn convert_content_block_to_anthropic(block: &ContentBlock) -> Option<AnthropicContent<'_>> {
match block {
ContentBlock::Text(t) => Some(AnthropicContent::Text { text: &t.text }),
ContentBlock::ToolCall(tc) => Some(AnthropicContent::ToolUse {
id: &tc.id,
name: &tc.name,
input: &tc.arguments,
}),
ContentBlock::Thinking(t) => {
t.thinking_signature
.as_ref()
.map(|sig| AnthropicContent::Thinking {
thinking: &t.thinking,
signature: sig,
})
}
ContentBlock::Image(_) => None,
}
}
fn convert_tool_to_anthropic(tool: &ToolDef) -> AnthropicTool<'_> {
AnthropicTool {
name: &tool.name,
description: &tool.description,
input_schema: &tool.parameters,
}
}
#[cfg(test)]
mod tests {
use super::*;
use asupersync::runtime::RuntimeBuilder;
use futures::{StreamExt, stream};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_json::json;
use std::collections::HashMap;
use std::io::{Read, Write};
use std::net::TcpListener;
use std::path::PathBuf;
use std::sync::mpsc;
use std::time::Duration;
#[test]
fn home_dir_lookup_falls_back_to_userprofile() {
let home = home_dir_with_env_lookup(|key| match key {
"USERPROFILE" => Some("C:\\Users\\Ada".to_string()),
_ => None,
});
assert_eq!(home, Some(PathBuf::from("C:\\Users\\Ada")));
}
#[test]
fn home_dir_lookup_falls_back_to_homedrive_homepath() {
let home = home_dir_with_env_lookup(|key| match key {
"HOMEDRIVE" => Some("D:".to_string()),
"HOMEPATH" => Some("\\Users\\Grace".to_string()),
_ => None,
});
assert_eq!(home, Some(PathBuf::from("D:\\Users\\Grace")));
}
#[test]
fn test_convert_user_text_message() {
let message = Message::User(crate::model::UserMessage {
content: UserContent::Text("Hello".to_string()),
timestamp: 0,
});
let converted = convert_message_to_anthropic(&message);
assert_eq!(converted.role, "user");
assert_eq!(converted.content.len(), 1);
}
#[test]
fn test_thinking_budget() {
assert_eq!(ThinkingLevel::Minimal.default_budget(), 1024);
assert_eq!(ThinkingLevel::Low.default_budget(), 2048);
assert_eq!(ThinkingLevel::Medium.default_budget(), 8192);
assert_eq!(ThinkingLevel::High.default_budget(), 16384);
}
#[test]
fn test_build_request_includes_system_tools_and_thinking() {
let provider = AnthropicProvider::new("claude-test");
let context = Context {
system_prompt: Some("System prompt".to_string().into()),
messages: vec![Message::User(crate::model::UserMessage {
content: UserContent::Text("Ping".to_string()),
timestamp: 0,
})]
.into(),
tools: vec![ToolDef {
name: "echo".to_string(),
description: "Echo a string.".to_string(),
parameters: json!({
"type": "object",
"properties": {
"text": { "type": "string" }
},
"required": ["text"]
}),
}]
.into(),
};
let options = StreamOptions {
max_tokens: Some(128),
temperature: Some(0.2),
thinking_level: Some(ThinkingLevel::Medium),
thinking_budgets: Some(crate::provider::ThinkingBudgets {
minimal: 1024,
low: 2048,
medium: 9000,
high: 16384,
xhigh: 32768,
}),
..Default::default()
};
let request = provider.build_request(&context, &options);
assert_eq!(request.model, "claude-test");
assert_eq!(request.system, Some("System prompt"));
assert_eq!(request.temperature, Some(0.2));
assert!(request.stream);
assert_eq!(request.max_tokens, 13_096);
let thinking = request.thinking.expect("thinking config");
assert_eq!(thinking.r#type, "enabled");
assert_eq!(thinking.budget_tokens, 9000);
assert_eq!(request.messages.len(), 1);
assert_eq!(request.messages[0].role, "user");
assert_eq!(request.messages[0].content.len(), 1);
match &request.messages[0].content[0] {
AnthropicContent::Text { text } => assert_eq!(*text, "Ping"),
other => panic!("expected text content, got {other:?}"),
}
let tools = request.tools.expect("tools");
assert_eq!(tools.len(), 1);
assert_eq!(tools[0].name, "echo");
assert_eq!(tools[0].description, "Echo a string.");
assert_eq!(
*tools[0].input_schema,
json!({
"type": "object",
"properties": {
"text": { "type": "string" }
},
"required": ["text"]
})
);
}
#[test]
fn test_build_request_omits_optional_fields_by_default() {
let provider = AnthropicProvider::new("claude-test");
let context = Context::default();
let options = StreamOptions::default();
let request = provider.build_request(&context, &options);
assert_eq!(request.model, "claude-test");
assert_eq!(request.system, None);
assert!(request.tools.is_none());
assert!(request.thinking.is_none());
assert_eq!(request.max_tokens, DEFAULT_MAX_TOKENS);
assert!(request.stream);
}
#[test]
#[allow(clippy::too_many_lines)]
fn test_stream_parses_thinking_and_tool_call_events() {
let events = vec![
json!({
"type": "message_start",
"message": { "usage": { "input_tokens": 3 } }
}),
json!({
"type": "content_block_start",
"index": 0,
"content_block": { "type": "thinking" }
}),
json!({
"type": "content_block_delta",
"index": 0,
"delta": { "type": "thinking_delta", "thinking": "step 1" }
}),
json!({
"type": "content_block_stop",
"index": 0
}),
json!({
"type": "content_block_start",
"index": 1,
"content_block": { "type": "tool_use", "id": "tool_123", "name": "search" }
}),
json!({
"type": "content_block_delta",
"index": 1,
"delta": { "type": "input_json_delta", "partial_json": "{\"q\":\"ru" }
}),
json!({
"type": "content_block_delta",
"index": 1,
"delta": { "type": "input_json_delta", "partial_json": "st\"}" }
}),
json!({
"type": "content_block_stop",
"index": 1
}),
json!({
"type": "content_block_start",
"index": 2,
"content_block": { "type": "text" }
}),
json!({
"type": "content_block_delta",
"index": 2,
"delta": { "type": "text_delta", "text": "done" }
}),
json!({
"type": "content_block_stop",
"index": 2
}),
json!({
"type": "message_delta",
"delta": { "stop_reason": "tool_use" },
"usage": { "output_tokens": 5 }
}),
json!({
"type": "message_stop"
}),
];
let out = collect_events(&events);
assert_eq!(out.len(), 12, "expected full stream event sequence");
assert!(matches!(&out[0], StreamEvent::Start { .. }));
assert!(matches!(
&out[1],
StreamEvent::ThinkingStart {
content_index: 0,
..
}
));
assert!(matches!(
&out[2],
StreamEvent::ThinkingDelta {
content_index: 0,
delta,
..
} if delta == "step 1"
));
assert!(matches!(
&out[3],
StreamEvent::ThinkingEnd {
content_index: 0,
content,
..
} if content == "step 1"
));
assert!(matches!(
&out[4],
StreamEvent::ToolCallStart {
content_index: 1,
..
}
));
assert!(matches!(
&out[5],
StreamEvent::ToolCallDelta {
content_index: 1,
delta,
..
} if delta == "{\"q\":\"ru"
));
assert!(matches!(
&out[6],
StreamEvent::ToolCallDelta {
content_index: 1,
delta,
..
} if delta == "st\"}"
));
if let StreamEvent::ToolCallEnd {
content_index,
tool_call,
..
} = &out[7]
{
assert_eq!(*content_index, 1);
assert_eq!(tool_call.id, "tool_123");
assert_eq!(tool_call.name, "search");
assert_eq!(tool_call.arguments, json!({ "q": "rust" }));
} else {
panic!("expected ToolCallEnd event, got {:?}", out[7]);
}
assert!(matches!(
&out[8],
StreamEvent::TextStart {
content_index: 2,
..
}
));
assert!(matches!(
&out[9],
StreamEvent::TextDelta {
content_index: 2,
delta,
..
} if delta == "done"
));
assert!(matches!(
&out[10],
StreamEvent::TextEnd {
content_index: 2,
content,
..
} if content == "done"
));
if let StreamEvent::Done { reason, message } = &out[11] {
assert_eq!(*reason, StopReason::ToolUse);
assert_eq!(message.stop_reason, StopReason::ToolUse);
} else {
panic!("expected Done event, got {:?}", out[11]);
}
}
#[test]
fn test_message_delta_sets_length_stop_reason_and_usage() {
let events = vec![
json!({
"type": "message_start",
"message": { "usage": { "input_tokens": 5 } }
}),
json!({
"type": "message_delta",
"delta": { "stop_reason": "max_tokens" },
"usage": { "output_tokens": 7 }
}),
json!({
"type": "message_stop"
}),
];
let out = collect_events(&events);
assert_eq!(out.len(), 2);
if let StreamEvent::Done { reason, message } = &out[1] {
assert_eq!(*reason, StopReason::Length);
assert_eq!(message.stop_reason, StopReason::Length);
assert_eq!(message.usage.input, 5);
assert_eq!(message.usage.output, 7);
assert_eq!(message.usage.total_tokens, 12);
} else {
panic!("expected Done event, got {:?}", out[1]);
}
}
#[test]
fn test_usage_total_tokens_saturates_on_large_values() {
let events = vec![
json!({
"type": "message_start",
"message": {
"usage": {
"input_tokens": u64::MAX,
"cache_read_input_tokens": 1,
"cache_creation_input_tokens": 1
}
}
}),
json!({
"type": "message_delta",
"delta": { "stop_reason": "end_turn" },
"usage": { "output_tokens": 1 }
}),
json!({
"type": "message_stop"
}),
];
let out = collect_events(&events);
assert_eq!(out.len(), 2);
if let StreamEvent::Done { message, .. } = &out[1] {
assert_eq!(message.usage.total_tokens, u64::MAX);
} else {
panic!("expected Done event, got {:?}", out[1]);
}
}
#[derive(Debug, Deserialize)]
struct ProviderFixture {
cases: Vec<ProviderCase>,
}
#[derive(Debug, Deserialize)]
struct ProviderCase {
name: String,
events: Vec<Value>,
expected: Vec<EventSummary>,
}
#[derive(Debug, Deserialize, Serialize, PartialEq)]
struct EventSummary {
kind: String,
#[serde(default)]
content_index: Option<usize>,
#[serde(default)]
delta: Option<String>,
#[serde(default)]
content: Option<String>,
#[serde(default)]
reason: Option<String>,
}
#[test]
fn test_stream_fixtures() {
let fixture = load_fixture("anthropic_stream.json");
for case in fixture.cases {
let events = collect_events(&case.events);
let summaries: Vec<EventSummary> = events.iter().map(summarize_event).collect();
assert_eq!(summaries, case.expected, "case {}", case.name);
}
}
#[test]
fn test_stream_error_event_maps_to_stop_reason_error() {
let events = vec![json!({
"type": "error",
"error": { "message": "nope" }
})];
let out = collect_events(&events);
assert_eq!(out.len(), 1);
assert!(
matches!(&out[0], StreamEvent::Error { .. }),
"expected StreamEvent::Error, got {:?}",
out[0]
);
if let StreamEvent::Error { reason, error } = &out[0] {
assert_eq!(*reason, StopReason::Error);
assert_eq!(error.stop_reason, StopReason::Error);
assert_eq!(error.error_message.as_deref(), Some("nope"));
}
}
#[test]
fn test_stream_emits_single_done_when_transport_ends_after_message_stop() {
let out = collect_stream_items_from_body(&success_sse_body());
let done_count = out
.iter()
.filter(|item| matches!(item, Ok(StreamEvent::Done { .. })))
.count();
assert_eq!(done_count, 1, "expected exactly one terminal Done event");
}
#[test]
fn test_stream_error_event_is_terminal() {
let body = [
r#"data: {"type":"error","error":{"message":"boom"}}"#,
"",
r#"data: {"type":"message_stop"}"#,
"",
]
.join("\n");
let out = collect_stream_items_from_body(&body);
assert_eq!(out.len(), 1, "Error should terminate the stream");
assert!(matches!(out[0], Ok(StreamEvent::Error { .. })));
}
#[test]
fn test_stream_parse_error_is_terminal() {
let body = [
r#"data: {"type":"message_start","message":{"usage":{"input_tokens":1}}}"#,
"",
r"data: {invalid-json}",
"",
r#"data: {"type":"message_stop"}"#,
"",
]
.join("\n");
let out = collect_stream_items_from_body(&body);
assert_eq!(out.len(), 2, "parse error should stop further events");
assert!(matches!(out[0], Ok(StreamEvent::Start { .. })));
match &out[1] {
Ok(event) => panic!("expected parse error item, got event: {event:?}"),
Err(err) => assert!(err.to_string().contains("JSON parse error")),
}
}
#[test]
fn test_stream_sets_required_headers() {
let captured = run_stream_and_capture_headers(CacheRetention::None)
.expect("captured request for required headers");
assert_eq!(
captured.headers.get("x-api-key").map(String::as_str),
Some("sk-ant-test-key")
);
assert_eq!(
captured
.headers
.get("anthropic-version")
.map(String::as_str),
Some(ANTHROPIC_API_VERSION)
);
assert!(!captured.headers.contains_key("anthropic-beta"));
assert!(captured.body.contains("\"stream\":true"));
}
#[test]
fn test_stream_adds_prompt_caching_beta_header_when_enabled() {
let captured = run_stream_and_capture_headers(CacheRetention::Short)
.expect("captured request for beta header");
assert_eq!(
captured.headers.get("anthropic-beta").map(String::as_str),
Some("prompt-caching-2024-07-31")
);
}
#[test]
fn test_stream_uses_oauth_bearer_auth_headers() {
let captured =
run_stream_and_capture_headers_with_api_key(CacheRetention::None, "sk-ant-oat-test")
.expect("captured request for oauth headers");
assert_eq!(
captured.headers.get("authorization").map(String::as_str),
Some("Bearer sk-ant-oat-test")
);
assert!(!captured.headers.contains_key("x-api-key"));
assert_eq!(
captured
.headers
.get("anthropic-dangerous-direct-browser-access")
.map(String::as_str),
Some("true")
);
assert_eq!(
captured.headers.get("x-app").map(String::as_str),
Some("cli")
);
assert!(
captured
.headers
.get("anthropic-beta")
.is_some_and(|value| value.contains("oauth-2025-04-20"))
);
assert!(
captured
.headers
.get("user-agent")
.is_some_and(|value| value.contains("pi_agent_rust/"))
);
}
#[test]
fn test_stream_uses_bearer_headers_for_marked_anthropic_oauth_token() {
let marked = "__pi_anthropic_oauth_bearer__:sk-ant-api-like-token";
let captured = run_stream_and_capture_headers_with_api_key(CacheRetention::None, marked)
.expect("captured request for marked oauth headers");
assert_eq!(
captured.headers.get("authorization").map(String::as_str),
Some("Bearer sk-ant-api-like-token")
);
assert!(!captured.headers.contains_key("x-api-key"));
assert!(
captured
.headers
.get("anthropic-beta")
.is_some_and(|value| value.contains("oauth-2025-04-20"))
);
}
#[test]
fn test_stream_claude_style_non_sk_token_uses_bearer_auth_headers() {
let captured =
run_stream_and_capture_headers_with_api_key(CacheRetention::None, "claude-oauth-token")
.expect("captured request for claude bearer headers");
assert_eq!(
captured.headers.get("authorization").map(String::as_str),
Some("Bearer claude-oauth-token")
);
assert!(!captured.headers.contains_key("x-api-key"));
}
#[test]
fn test_stream_kimi_oauth_uses_bearer_and_kimi_headers() {
let captured = run_stream_and_capture_headers_for_provider_with_api_key(
CacheRetention::None,
"kimi-for-coding",
"kimi-oauth-token",
)
.expect("captured request for kimi oauth headers");
assert_eq!(
captured.headers.get("authorization").map(String::as_str),
Some("Bearer kimi-oauth-token")
);
assert!(!captured.headers.contains_key("x-api-key"));
assert!(
!captured
.headers
.contains_key("anthropic-dangerous-direct-browser-access")
);
assert!(!captured.headers.contains_key("anthropic-beta"));
assert_eq!(
captured.headers.get("x-msh-platform").map(String::as_str),
Some("kimi_cli")
);
assert!(captured.headers.contains_key("x-msh-version"));
assert!(captured.headers.contains_key("x-msh-device-name"));
assert!(captured.headers.contains_key("x-msh-device-model"));
assert!(captured.headers.contains_key("x-msh-os-version"));
assert!(captured.headers.contains_key("x-msh-device-id"));
}
#[test]
fn test_stream_kimi_api_key_uses_x_api_key_header() {
let captured = run_stream_and_capture_headers_for_provider_with_api_key(
CacheRetention::None,
"kimi-for-coding",
"sk-kimi-api-key",
)
.expect("captured request for kimi api-key headers");
assert_eq!(
captured.headers.get("x-api-key").map(String::as_str),
Some("sk-kimi-api-key")
);
assert!(!captured.headers.contains_key("authorization"));
assert!(!captured.headers.contains_key("x-msh-platform"));
}
#[test]
fn test_stream_oauth_beta_header_includes_prompt_caching_when_enabled() {
let captured =
run_stream_and_capture_headers_with_api_key(CacheRetention::Short, "sk-ant-oat-test")
.expect("captured request for oauth + cache beta header");
let beta = captured
.headers
.get("anthropic-beta")
.expect("anthropic-beta header");
assert!(beta.contains("oauth-2025-04-20"));
assert!(beta.contains("prompt-caching-2024-07-31"));
}
#[test]
fn test_stream_http_error_includes_status_and_body_message() {
let (base_url, _rx) = spawn_test_server(
401,
"application/json",
r#"{"type":"error","error":{"type":"authentication_error","message":"Invalid API key"}}"#,
);
let provider = AnthropicProvider::new("claude-test").with_base_url(base_url);
let context = Context {
system_prompt: None,
messages: vec![Message::User(crate::model::UserMessage {
content: UserContent::Text("ping".to_string()),
timestamp: 0,
})]
.into(),
tools: Vec::new().into(),
};
let options = StreamOptions {
api_key: Some("test-key".to_string()),
..Default::default()
};
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
let result = runtime.block_on(async { provider.stream(&context, &options).await });
let Err(err) = result else {
panic!("expected HTTP error");
};
let message = err.to_string();
assert!(message.contains("Anthropic API error (HTTP 401)"));
assert!(message.contains("Invalid API key"));
}
#[test]
fn test_provider_name_reflects_override() {
let provider = AnthropicProvider::new("claude-test").with_provider_name("kimi-for-coding");
assert_eq!(provider.name(), "kimi-for-coding");
}
#[derive(Debug)]
struct CapturedRequest {
headers: HashMap<String, String>,
body: String,
}
fn run_stream_and_capture_headers(cache_retention: CacheRetention) -> Option<CapturedRequest> {
run_stream_and_capture_headers_with_api_key(cache_retention, "sk-ant-test-key")
}
fn run_stream_and_capture_headers_with_api_key(
cache_retention: CacheRetention,
api_key: &str,
) -> Option<CapturedRequest> {
run_stream_and_capture_headers_for_provider_with_api_key(
cache_retention,
"anthropic",
api_key,
)
}
fn run_stream_and_capture_headers_for_provider_with_api_key(
cache_retention: CacheRetention,
provider_name: &str,
api_key: &str,
) -> Option<CapturedRequest> {
let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
let provider = AnthropicProvider::new("claude-test")
.with_provider_name(provider_name)
.with_base_url(base_url);
let context = Context {
system_prompt: Some("test system".to_string().into()),
messages: vec![Message::User(crate::model::UserMessage {
content: UserContent::Text("ping".to_string()),
timestamp: 0,
})]
.into(),
tools: Vec::new().into(),
};
let options = StreamOptions {
api_key: Some(api_key.to_string()),
cache_retention,
..Default::default()
};
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async {
let mut stream = provider.stream(&context, &options).await.expect("stream");
while let Some(event) = stream.next().await {
if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
break;
}
}
});
rx.recv_timeout(Duration::from_secs(2)).ok()
}
fn collect_stream_items_from_body(body: &str) -> Vec<Result<StreamEvent>> {
let (base_url, _rx) = spawn_test_server(200, "text/event-stream", body);
let provider = AnthropicProvider::new("claude-test").with_base_url(base_url);
let context = Context {
system_prompt: Some("test system".to_string().into()),
messages: vec![Message::User(crate::model::UserMessage {
content: UserContent::Text("ping".to_string()),
timestamp: 0,
})]
.into(),
tools: Vec::new().into(),
};
let options = StreamOptions {
api_key: Some("sk-ant-test-key".to_string()),
..Default::default()
};
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async {
let mut stream = provider.stream(&context, &options).await.expect("stream");
let mut items = Vec::new();
while let Some(item) = stream.next().await {
items.push(item);
}
items
})
}
fn success_sse_body() -> String {
[
r#"data: {"type":"message_start","message":{"usage":{"input_tokens":1}}}"#,
"",
r#"data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":1}}"#,
"",
r#"data: {"type":"message_stop"}"#,
"",
]
.join("\n")
}
fn spawn_test_server(
status_code: u16,
content_type: &str,
body: &str,
) -> (String, mpsc::Receiver<CapturedRequest>) {
let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
let addr = listener.local_addr().expect("local addr");
let (tx, rx) = mpsc::channel();
let body = body.to_string();
let content_type = content_type.to_string();
std::thread::spawn(move || {
let (mut socket, _) = listener.accept().expect("accept");
socket
.set_read_timeout(Some(Duration::from_secs(2)))
.expect("set read timeout");
let mut bytes = Vec::new();
let mut chunk = [0_u8; 4096];
loop {
match socket.read(&mut chunk) {
Ok(0) => break,
Ok(n) => {
bytes.extend_from_slice(&chunk[..n]);
if bytes.windows(4).any(|window| window == b"\r\n\r\n") {
break;
}
}
Err(err)
if err.kind() == std::io::ErrorKind::WouldBlock
|| err.kind() == std::io::ErrorKind::TimedOut =>
{
break;
}
Err(err) => panic!("read request failed: {err}"),
}
}
let header_end = bytes
.windows(4)
.position(|window| window == b"\r\n\r\n")
.expect("request header boundary");
let header_text = String::from_utf8_lossy(&bytes[..header_end]).to_string();
let headers = parse_headers(&header_text);
let mut request_body = bytes[header_end + 4..].to_vec();
let content_length = headers
.get("content-length")
.and_then(|value| value.parse::<usize>().ok())
.unwrap_or(0);
while request_body.len() < content_length {
match socket.read(&mut chunk) {
Ok(0) => break,
Ok(n) => request_body.extend_from_slice(&chunk[..n]),
Err(err)
if err.kind() == std::io::ErrorKind::WouldBlock
|| err.kind() == std::io::ErrorKind::TimedOut =>
{
break;
}
Err(err) => panic!("read request body failed: {err}"),
}
}
let captured = CapturedRequest {
headers,
body: String::from_utf8_lossy(&request_body).to_string(),
};
tx.send(captured).expect("send captured request");
let reason = match status_code {
401 => "Unauthorized",
500 => "Internal Server Error",
_ => "OK",
};
let response = format!(
"HTTP/1.1 {status_code} {reason}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
body.len()
);
socket
.write_all(response.as_bytes())
.expect("write response");
socket.flush().expect("flush response");
});
(format!("http://{addr}/messages"), rx)
}
fn parse_headers(header_text: &str) -> HashMap<String, String> {
let mut headers = HashMap::new();
for line in header_text.lines().skip(1) {
if let Some((name, value)) = line.split_once(':') {
headers.insert(name.trim().to_ascii_lowercase(), value.trim().to_string());
}
}
headers
}
fn load_fixture(file_name: &str) -> ProviderFixture {
let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("tests/fixtures/provider_responses")
.join(file_name);
let raw = std::fs::read_to_string(path).expect("fixture read");
serde_json::from_str(&raw).expect("fixture parse")
}
fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async move {
let byte_stream = stream::iter(
events
.iter()
.map(|event| {
let data = match event {
Value::String(text) => text.clone(),
_ => serde_json::to_string(event).expect("serialize event"),
};
format!("data: {data}\n\n").into_bytes()
})
.map(Ok),
);
let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
let mut state = StreamState::new(
event_source,
"claude-test".to_string(),
"anthropic-messages".to_string(),
"anthropic".to_string(),
);
let mut out = Vec::new();
while let Some(item) = state.event_source.next().await {
let msg = item.expect("SSE event");
if msg.event == "ping" {
continue;
}
if let Some(event) = state.process_event(&msg.data).expect("process_event") {
out.push(event);
}
}
out
})
}
fn summarize_event(event: &StreamEvent) -> EventSummary {
match event {
StreamEvent::Start { .. } => EventSummary {
kind: "start".to_string(),
content_index: None,
delta: None,
content: None,
reason: None,
},
StreamEvent::TextStart { content_index, .. } => EventSummary {
kind: "text_start".to_string(),
content_index: Some(*content_index),
delta: None,
content: None,
reason: None,
},
StreamEvent::TextDelta {
content_index,
delta,
..
} => EventSummary {
kind: "text_delta".to_string(),
content_index: Some(*content_index),
delta: Some(delta.clone()),
content: None,
reason: None,
},
StreamEvent::TextEnd {
content_index,
content,
..
} => EventSummary {
kind: "text_end".to_string(),
content_index: Some(*content_index),
delta: None,
content: Some(content.clone()),
reason: None,
},
StreamEvent::Done { reason, .. } => EventSummary {
kind: "done".to_string(),
content_index: None,
delta: None,
content: None,
reason: Some(reason_to_string(*reason)),
},
StreamEvent::Error { reason, .. } => EventSummary {
kind: "error".to_string(),
content_index: None,
delta: None,
content: None,
reason: Some(reason_to_string(*reason)),
},
_ => EventSummary {
kind: "other".to_string(),
content_index: None,
delta: None,
content: None,
reason: None,
},
}
}
fn reason_to_string(reason: StopReason) -> String {
match reason {
StopReason::Stop => "stop",
StopReason::Length => "length",
StopReason::ToolUse => "tool_use",
StopReason::Error => "error",
StopReason::Aborted => "aborted",
}
.to_string()
}
#[test]
fn test_compat_custom_headers_injected_into_request() {
let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
let mut custom = HashMap::new();
custom.insert("X-Custom-Tag".to_string(), "anthropic-override".to_string());
custom.insert("X-Routing-Hint".to_string(), "us-east-1".to_string());
let compat = crate::models::CompatConfig {
custom_headers: Some(custom),
..Default::default()
};
let provider = AnthropicProvider::new("claude-test")
.with_base_url(base_url)
.with_compat(Some(compat));
let context = Context {
system_prompt: Some("test".to_string().into()),
messages: vec![Message::User(crate::model::UserMessage {
content: UserContent::Text("hi".to_string()),
timestamp: 0,
})]
.into(),
tools: Vec::new().into(),
};
let options = StreamOptions {
api_key: Some("sk-ant-test-key".to_string()),
..Default::default()
};
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async {
let mut stream = provider.stream(&context, &options).await.expect("stream");
while let Some(event) = stream.next().await {
if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
break;
}
}
});
let captured = rx
.recv_timeout(Duration::from_secs(2))
.expect("captured request");
assert_eq!(
captured.headers.get("x-custom-tag").map(String::as_str),
Some("anthropic-override"),
"compat custom header X-Custom-Tag missing"
);
assert_eq!(
captured.headers.get("x-routing-hint").map(String::as_str),
Some("us-east-1"),
"compat custom header X-Routing-Hint missing"
);
assert_eq!(
captured.headers.get("x-api-key").map(String::as_str),
Some("sk-ant-test-key"),
);
}
#[test]
fn test_compat_none_does_not_affect_headers() {
let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
let provider = AnthropicProvider::new("claude-test")
.with_base_url(base_url)
.with_compat(None);
let context = Context {
system_prompt: Some("test".to_string().into()),
messages: vec![Message::User(crate::model::UserMessage {
content: UserContent::Text("hi".to_string()),
timestamp: 0,
})]
.into(),
tools: Vec::new().into(),
};
let options = StreamOptions {
api_key: Some("sk-ant-test-key".to_string()),
..Default::default()
};
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async {
let mut stream = provider.stream(&context, &options).await.expect("stream");
while let Some(event) = stream.next().await {
if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
break;
}
}
});
let captured = rx
.recv_timeout(Duration::from_secs(2))
.expect("captured request");
assert_eq!(
captured.headers.get("x-api-key").map(String::as_str),
Some("sk-ant-test-key"),
);
assert!(
!captured.headers.contains_key("x-custom-tag"),
"No custom headers should be present with compat=None"
);
}
mod proptest_process_event {
use super::*;
use proptest::prelude::*;
fn make_state()
-> StreamState<impl Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin>
{
let empty = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
let sse = crate::sse::SseStream::new(Box::pin(empty));
StreamState::new(
sse,
"claude-test".into(),
"anthropic-messages".into(),
"anthropic".into(),
)
}
fn small_string() -> impl Strategy<Value = String> {
prop_oneof![Just(String::new()), "[a-zA-Z0-9_]{1,16}", "[ -~]{0,32}",]
}
fn optional_string() -> impl Strategy<Value = Option<String>> {
prop_oneof![Just(None), small_string().prop_map(Some),]
}
fn token_count() -> impl Strategy<Value = u64> {
prop_oneof![
5 => 0u64..10_000u64,
2 => Just(0u64),
1 => Just(u64::MAX),
1 => (u64::MAX - 100)..=u64::MAX,
]
}
fn block_type() -> impl Strategy<Value = String> {
prop_oneof![
Just("text".to_string()),
Just("thinking".to_string()),
Just("tool_use".to_string()),
Just("unknown_block_type".to_string()),
"[a-z_]{1,12}",
]
}
fn delta_type() -> impl Strategy<Value = String> {
prop_oneof![
Just("text_delta".to_string()),
Just("thinking_delta".to_string()),
Just("input_json_delta".to_string()),
Just("signature_delta".to_string()),
Just("unknown_delta".to_string()),
"[a-z_]{1,16}",
]
}
fn content_index() -> impl Strategy<Value = u32> {
prop_oneof![
5 => 0u32..5u32,
2 => Just(0u32),
1 => Just(u32::MAX),
1 => 1000u32..2000u32,
]
}
fn stop_reason_str() -> impl Strategy<Value = String> {
prop_oneof![
Just("end_turn".to_string()),
Just("max_tokens".to_string()),
Just("tool_use".to_string()),
Just("stop_sequence".to_string()),
Just("unknown_reason".to_string()),
"[a-z_]{1,12}",
]
}
fn anthropic_event_json() -> impl Strategy<Value = String> {
prop_oneof![
3 => token_count().prop_flat_map(|input| {
(Just(input), token_count(), token_count()).prop_map(
move |(cache_read, cache_write, _)| {
serde_json::json!({
"type": "message_start",
"message": {
"usage": {
"input_tokens": input,
"cache_read_input_tokens": cache_read,
"cache_creation_input_tokens": cache_write
}
}
})
.to_string()
},
)
}),
1 => Just(r#"{"type":"message_start","message":{}}"#.to_string()),
3 => (content_index(), block_type(), optional_string(), optional_string())
.prop_map(|(idx, bt, id, name)| {
let mut block = serde_json::json!({"type": bt});
if let Some(id) = id {
block["id"] = serde_json::Value::String(id);
}
if let Some(name) = name {
block["name"] = serde_json::Value::String(name);
}
serde_json::json!({
"type": "content_block_start",
"index": idx,
"content_block": block
})
.to_string()
}),
3 => (content_index(), delta_type(), optional_string(), optional_string(), optional_string(), optional_string())
.prop_map(|(idx, dt, text, thinking, partial_json, sig)| {
let mut delta = serde_json::json!({"type": dt});
if let Some(t) = text { delta["text"] = serde_json::Value::String(t); }
if let Some(t) = thinking { delta["thinking"] = serde_json::Value::String(t); }
if let Some(p) = partial_json { delta["partial_json"] = serde_json::Value::String(p); }
if let Some(s) = sig { delta["signature"] = serde_json::Value::String(s); }
serde_json::json!({
"type": "content_block_delta",
"index": idx,
"delta": delta
})
.to_string()
}),
2 => content_index().prop_map(|idx| {
serde_json::json!({"type": "content_block_stop", "index": idx}).to_string()
}),
2 => (stop_reason_str(), token_count()).prop_map(|(sr, out)| {
serde_json::json!({
"type": "message_delta",
"delta": {"stop_reason": sr},
"usage": {"output_tokens": out}
})
.to_string()
}),
1 => stop_reason_str().prop_map(|sr| {
serde_json::json!({
"type": "message_delta",
"delta": {"stop_reason": sr}
})
.to_string()
}),
2 => Just(r#"{"type":"message_stop"}"#.to_string()),
2 => small_string().prop_map(|msg| {
serde_json::json!({"type": "error", "error": {"message": msg}}).to_string()
}),
2 => Just(r#"{"type":"ping"}"#.to_string()),
]
}
fn chaos_json() -> impl Strategy<Value = String> {
prop_oneof![
Just(String::new()),
Just("{}".to_string()),
Just("[]".to_string()),
Just("null".to_string()),
Just("true".to_string()),
Just("42".to_string()),
Just("{".to_string()),
Just(r#"{"type":}"#.to_string()),
Just(r#"{"type":null}"#.to_string()),
"[a-z_]{1,20}".prop_map(|t| format!(r#"{{"type":"{t}"}}"#)),
"[ -~]{0,64}",
Just(r#"{"type":"message_start"}"#.to_string()),
Just(r#"{"type":"content_block_delta"}"#.to_string()),
Just(r#"{"type":"error"}"#.to_string()),
]
}
proptest! {
#![proptest_config(ProptestConfig {
cases: 256,
max_shrink_iters: 100,
.. ProptestConfig::default()
})]
#[test]
fn process_event_valid_never_panics(data in anthropic_event_json()) {
let mut state = make_state();
let _ = state.process_event(&data);
}
#[test]
fn process_event_chaos_never_panics(data in chaos_json()) {
let mut state = make_state();
let _ = state.process_event(&data);
}
#[test]
fn process_event_sequence_never_panics(
events in prop::collection::vec(anthropic_event_json(), 1..8)
) {
let mut state = make_state();
for event in &events {
let _ = state.process_event(event);
}
}
}
}
}
#[cfg(feature = "fuzzing")]
pub mod fuzz {
use super::*;
use futures::stream;
use std::pin::Pin;
type FuzzStream =
Pin<Box<futures::stream::Empty<std::result::Result<Vec<u8>, std::io::Error>>>>;
pub struct Processor(StreamState<FuzzStream>);
impl Default for Processor {
fn default() -> Self {
Self::new()
}
}
impl Processor {
pub fn new() -> Self {
let empty = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
Self(StreamState::new(
crate::sse::SseStream::new(Box::pin(empty)),
"claude-fuzz".into(),
"anthropic-messages".into(),
"anthropic".into(),
))
}
pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
Ok(self
.0
.process_event(data)?
.map_or_else(Vec::new, |event| vec![event]))
}
}
}