use async_trait::async_trait;
use futures::{Stream, StreamExt};
use reqwest::Client;
use serde::Deserialize;
use serde_json::Value as JsonValue;
use std::pin::Pin;
use crate::{
error::ProviderError, Api, AssistantMessage, ContentBlock, Context, Model, Provider,
ProviderEvent, StopReason, StreamOptions, Usage,
};
/// Provider that streams completions from the Anthropic Messages API.
#[derive(Clone)]
pub struct AnthropicProvider {
    /// HTTP client reused for every request made by this provider.
    client: Client,
    /// Default API key, used when a request's `StreamOptions` carries none.
    api_key: Option<String>,
}

impl AnthropicProvider {
    /// Creates a provider whose default key is read from `ANTHROPIC_API_KEY`.
    ///
    /// If the variable is unset (or not valid Unicode) the provider has no default
    /// key, and `stream` fails with `MissingApiKey` unless a key is passed per request.
    pub fn new() -> Self {
        Self::with_optional_key(std::env::var("ANTHROPIC_API_KEY").ok())
    }

    /// Creates a provider that uses `api_key` as its default key.
    // Public constructor that may have no caller inside this crate.
    #[allow(dead_code)]
    pub fn with_api_key(api_key: impl Into<String>) -> Self {
        let key: String = api_key.into();
        Self::with_optional_key(Some(key))
    }

    /// Shared constructor: a fresh HTTP client plus the given default key.
    fn with_optional_key(api_key: Option<String>) -> Self {
        Self {
            client: Client::new(),
            api_key,
        }
    }
}

impl Default for AnthropicProvider {
    /// Equivalent to [`AnthropicProvider::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl Provider for AnthropicProvider {
async fn stream(
&self,
model: &Model,
context: &Context,
options: Option<StreamOptions>,
) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
let options = options.unwrap_or_default();
let url = format!("{}/v1/messages", model.base_url);
let api_key = options
.api_key
.as_ref()
.or(self.api_key.as_ref())
.ok_or_else(|| ProviderError::MissingApiKey)?;
let messages = build_anthropic_messages(context)?;
let mut body = serde_json::json!({
"model": model.id,
"messages": messages,
"stream": true,
});
if let Some(ref prompt) = context.system_prompt {
body["system"] = serde_json::json!(prompt);
}
if let Some(temp) = options.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(max) = options.max_tokens {
body["max_tokens"] = serde_json::json!(max);
}
if !context.tools.is_empty() {
body["tools"] = build_anthropic_tools(&context.tools)?;
}
let mut headers = reqwest::header::HeaderMap::new();
headers.insert("x-api-key", api_key.parse().unwrap());
headers.insert("content-type", "application/json".parse().unwrap());
headers.insert("anthropic-version", "2023-06-01".parse().unwrap());
for (k, v) in &options.headers {
if let (Ok(name), Ok(value)) = (
k.parse::<reqwest::header::HeaderName>(),
v.parse::<reqwest::header::HeaderValue>(),
) {
headers.insert(name, value);
}
}
let response = self
.client
.post(&url)
.headers(headers)
.json(&body)
.send()
.await
.map_err(ProviderError::RequestFailed)?;
if !response.status().is_success() {
let status = response.status();
let body: String = response.text().await.unwrap_or_default();
return Err(ProviderError::HttpError(status.as_u16(), body));
}
let model_name = model.id.clone();
let stream = response.bytes_stream().flat_map(move |chunk| match chunk {
Ok(bytes) => {
let text = String::from_utf8_lossy(&bytes);
futures::stream::iter(parse_anthropic_events(&text, &model_name))
}
Err(e) => futures::stream::iter(vec![ProviderEvent::Error {
reason: StopReason::Error,
error: create_error_message(&e.to_string()),
}]),
});
Ok(Box::pin(stream))
}
fn name(&self) -> &str {
"anthropic"
}
}
/// Translates `context.messages` into the JSON `messages` array sent to the
/// Anthropic Messages API.
///
/// - `User` messages become `role: "user"`; plain text is wrapped in a single
///   `text` block, block content goes through [`blocks_to_anthropic_content`].
/// - `Assistant` messages become `role: "assistant"` with converted blocks.
/// - `ToolResult` messages become a `role: "user"` message holding one
///   `tool_result` block that references `tool_call_id`.
///
/// # Errors
/// Returns the first error produced while converting a message's content blocks.
fn build_anthropic_messages(context: &Context) -> Result<Vec<JsonValue>, ProviderError> {
    context
        .messages
        .iter()
        .map(|message| -> Result<JsonValue, ProviderError> {
            match message {
                crate::Message::User(user) => {
                    let content = match &user.content {
                        crate::MessageContent::Text(text) => {
                            vec![serde_json::json!({ "type": "text", "text": text })]
                        }
                        crate::MessageContent::Blocks(blocks) => {
                            blocks_to_anthropic_content(blocks)?
                        }
                    };
                    Ok(serde_json::json!({ "role": "user", "content": content }))
                }
                crate::Message::Assistant(assistant) => {
                    let content = blocks_to_anthropic_content(&assistant.content)?;
                    Ok(serde_json::json!({ "role": "assistant", "content": content }))
                }
                crate::Message::ToolResult(result) => {
                    let content = blocks_to_anthropic_content(&result.content)?;
                    Ok(serde_json::json!({
                        "role": "user",
                        "content": [{
                            "type": "tool_result",
                            "tool_use_id": result.tool_call_id,
                            "content": content,
                        }],
                    }))
                }
            }
        })
        .collect()
}
/// Converts content blocks into Anthropic content items, preserving order.
///
/// Text, tool calls (as `tool_use`), thinking and base64 images map one-to-one;
/// `Unknown` blocks are dropped. No conversion here can fail; the `Result` return
/// type lets callers propagate with `?`.
// NOTE(review): replayed thinking blocks are sent without a `signature` — confirm
// the API accepts them in this form.
fn blocks_to_anthropic_content(blocks: &[ContentBlock]) -> Result<Vec<JsonValue>, ProviderError> {
    let converted: Vec<JsonValue> = blocks
        .iter()
        .filter_map(|block| {
            let item = match block {
                ContentBlock::Text(text_block) => serde_json::json!({
                    "type": "text",
                    "text": text_block.text,
                }),
                ContentBlock::ToolCall(call) => serde_json::json!({
                    "type": "tool_use",
                    "id": call.id,
                    "name": call.name,
                    "input": call.arguments,
                }),
                ContentBlock::Thinking(reasoning) => serde_json::json!({
                    "type": "thinking",
                    "thinking": reasoning.thinking,
                }),
                ContentBlock::Image(image) => serde_json::json!({
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": image.mime_type,
                        "data": image.data,
                    },
                }),
                // Blocks this provider cannot represent are omitted from the request.
                ContentBlock::Unknown(_) => return None,
            };
            Some(item)
        })
        .collect();
    Ok(converted)
}
fn build_anthropic_tools(tools: &[crate::Tool]) -> Result<JsonValue, ProviderError> {
let items: Vec<_> = tools
.iter()
.map(|tool| {
serde_json::json!({
"name": tool.name,
"description": tool.description,
"input_schema": tool.parameters,
})
})
.collect();
Ok(serde_json::json!(items))
}
fn parse_anthropic_events(text: &str, model_id: &str) -> Vec<ProviderEvent> {
let mut events = Vec::new();
let partial_message = AssistantMessage::new(Api::AnthropicMessages, "anthropic", model_id);
let estimated = text.split('\n').filter(|l| l.starts_with("data: ")).count();
events.reserve(estimated);
let mut accumulated_usage = Usage::default();
for line in text.split('\n') {
let line = line.trim_end_matches('\r');
if line.is_empty() {
continue;
}
if !line.starts_with("data: ") {
continue;
}
let data = &line[6..];
if data == "[DONE]" || data.is_empty() {
continue;
}
let event = match serde_json::from_str::<AnthropicEvent>(data) {
Ok(e) => e,
Err(_) => continue,
};
let event_type = event.type_.as_deref();
match event_type {
Some("message_start") => {
events.push(ProviderEvent::Start {
partial: partial_message.clone(),
});
}
Some("content_block_start") => {
if let Some(block) = &event.content_block {
match block.type_.as_deref() {
Some("text") => {
events.push(ProviderEvent::TextStart {
content_index: block.index.unwrap_or(0),
partial: partial_message.clone(),
});
}
Some("thinking") => {
events.push(ProviderEvent::ThinkingStart {
content_index: block.index.unwrap_or(0),
partial: partial_message.clone(),
});
}
Some("tool_use") => {
events.push(ProviderEvent::ToolCallStart {
content_index: block.index.unwrap_or(0),
partial: partial_message.clone(),
});
}
_ => {}
}
}
}
Some("content_block_delta") => {
if let Some(delta) = &event.delta {
match delta.type_.as_deref() {
Some("text_delta") => {
if let Some(text) = &delta.text {
events.push(ProviderEvent::TextDelta {
content_index: event.index.unwrap_or(0),
delta: text.clone(),
partial: partial_message.clone(),
});
}
}
Some("thinking_delta") => {
if let Some(text) = &delta.thinking {
events.push(ProviderEvent::ThinkingDelta {
content_index: event.index.unwrap_or(0),
delta: text.clone(),
partial: partial_message.clone(),
});
}
}
Some("input_json_delta") => {
if let Some(args) = &delta.partial_json {
events.push(ProviderEvent::ToolCallDelta {
content_index: event.index.unwrap_or(0),
delta: args.clone(),
partial: partial_message.clone(),
});
}
}
_ => {}
}
}
}
Some("message_delta") => {
if let Some(delta) = &event.delta {
let reason = match delta.stop_reason.as_deref() {
Some("end_turn") => StopReason::Stop,
Some("max_tokens") => StopReason::Length,
Some("stop_sequence") => StopReason::Stop,
_ => StopReason::Stop,
};
let mut done_msg = partial_message.clone();
done_msg.usage = accumulated_usage.clone();
events.push(ProviderEvent::Done {
reason,
message: done_msg,
});
}
}
Some("message_stop") => {
}
_ => {}
}
if let Some(usage) = event.usage {
accumulated_usage.input = usage.input_tokens;
accumulated_usage.output = usage.output_tokens;
accumulated_usage.cache_read = usage.cache_read;
accumulated_usage.cache_write = usage.cache_creation;
accumulated_usage.total_tokens = usage.input_tokens + usage.output_tokens;
}
}
events
}
fn create_error_message(msg: &str) -> AssistantMessage {
let mut message = AssistantMessage::new(Api::AnthropicMessages, "anthropic", "unknown");
message.stop_reason = StopReason::Error;
message.error_message = Some(msg.to_string());
message
}
#[derive(Debug, Deserialize)]
struct AnthropicEvent {
#[serde(rename = "type")]
type_: Option<String>,
#[serde(rename = "index")]
index: Option<usize>,
content_block: Option<ContentBlockStart>,
delta: Option<Delta>,
usage: Option<AnthropicUsage>,
}
#[derive(Debug, Deserialize)]
struct ContentBlockStart {
#[serde(rename = "type")]
type_: Option<String>,
index: Option<usize>,
}
#[derive(Debug, Deserialize)]
struct Delta {
#[serde(rename = "type")]
type_: Option<String>,
text: Option<String>,
thinking: Option<String>,
partial_json: Option<String>,
#[serde(rename = "stop_reason")]
stop_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
struct AnthropicUsage {
#[serde(rename = "input_tokens")]
input_tokens: usize,
#[serde(rename = "output_tokens")]
output_tokens: usize,
#[serde(rename = "cache_read")]
cache_read: usize,
#[serde(rename = "cache_creation")]
cache_creation: usize,
}