use async_trait::async_trait;
use base64::Engine as _;
use futures::stream::{BoxStream, StreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};

use crate::error::{AgentError, Result};
use crate::models::LLM;
use crate::types::{File, GenerationChunk, GenerationResponse, Message, Role};
/// LLM backend that talks to the Anthropic Messages API (`/v1/messages`).
///
/// Construct with [`AnthropicLLM::new`] (reads `ANTHROPIC_API_KEY`) or
/// [`AnthropicLLM::with_api_key`]; adjust output length via
/// [`AnthropicLLM::with_max_tokens`].
pub struct AnthropicLLM {
    // Shared reqwest client (reused across requests for connection pooling).
    client: Client,
    // Sent as the `x-api-key` header on every request.
    api_key: String,
    // Model identifier, e.g. "claude-3-5-sonnet-20241022".
    model: String,
    // Per-request cap on generated tokens (defaults to 4096 in constructors).
    max_tokens: u32,
}
/// Request body for `POST /v1/messages`.
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    model: String,
    messages: Vec<AnthropicMessage>,
    max_tokens: u32,
    // Omitted from the JSON entirely (not serialized as null) when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    // Top-level system prompt; the Messages API does not accept a "system"
    // role inside `messages`, so system content travels in this field.
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
}
/// One conversation turn in the wire format; `role` is "user" or "assistant"
/// (see `AnthropicLLM::convert_role`).
#[derive(Debug, Serialize, Deserialize, Clone)]
struct AnthropicMessage {
    role: String,
    content: AnthropicContent,
}
/// Message content: either a bare string or a list of typed blocks.
///
/// `untagged` makes serde pick the variant from the JSON shape, matching the
/// API's accept-either behavior for the `content` field.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
enum AnthropicContent {
    Text(String),
    Blocks(Vec<ContentBlock>),
}
/// A typed content block; the JSON `type` field discriminates variants on the
/// wire (`"text"` or `"image"`).
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
enum ContentBlock {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image")]
    Image { source: ImageSource },
}
/// Image payload for an image content block.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct ImageSource {
    // Always "base64" in this client; renamed because `type` is a Rust keyword.
    #[serde(rename = "type")]
    source_type: String,
    // MIME type of the image, e.g. "image/png".
    media_type: String,
    // Base64-encoded image bytes.
    data: String,
}
/// Subset of the non-streaming response this client reads: just the content
/// blocks. Other response fields (id, usage, stop_reason, ...) are silently
/// ignored by serde's default unknown-field handling.
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<ContentBlock>,
}
/// Server-sent event payloads from the streaming Messages API, discriminated
/// by the JSON `type` field.
///
/// Only `content_block_delta` (text) and `error` carry data this client acts
/// on; the remaining variants are lifecycle markers that the stream loop
/// matches and skips.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum AnthropicStreamEvent {
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { delta: StreamDelta },
    #[serde(rename = "message_start")]
    MessageStart,
    #[serde(rename = "content_block_start")]
    ContentBlockStart,
    #[serde(rename = "content_block_stop")]
    ContentBlockStop,
    #[serde(rename = "message_delta")]
    MessageDelta,
    #[serde(rename = "message_stop")]
    MessageStop,
    // Keep-alive event sent periodically by the server.
    #[serde(rename = "ping")]
    Ping,
    #[serde(rename = "error")]
    Error { error: AnthropicError },
}
/// Error payload carried by a streaming `error` event; only the message text
/// is surfaced to callers.
#[derive(Debug, Deserialize)]
struct AnthropicError {
    message: String,
}
/// Delta inside a `content_block_delta` event.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum StreamDelta {
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    // Catch-all so unrecognized delta types (e.g. future additions) don't
    // fail deserialization of the whole event.
    #[serde(other)]
    Other,
}
impl AnthropicLLM {
    /// Creates a client using the `ANTHROPIC_API_KEY` environment variable.
    ///
    /// # Errors
    /// Returns [`AgentError::ConfigError`] when the variable is unset.
    pub fn new(model: impl Into<String>) -> Result<Self> {
        let api_key = std::env::var("ANTHROPIC_API_KEY").map_err(|_| {
            AgentError::ConfigError("ANTHROPIC_API_KEY environment variable not set".to_string())
        })?;
        Ok(Self {
            client: Client::new(),
            api_key,
            model: model.into(),
            max_tokens: 4096,
        })
    }

    /// Creates a client with an explicit API key (no environment lookup).
    pub fn with_api_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            client: Client::new(),
            api_key: api_key.into(),
            model: model.into(),
            max_tokens: 4096,
        }
    }

    /// Builder-style override of the per-request output token limit.
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Maps the crate's [`Role`] onto the two roles Anthropic accepts.
    ///
    /// System content is extracted into the top-level `system` field by
    /// `prepare_request`, and tool output has no dedicated role in this
    /// mapping, so both fall back to "user".
    fn convert_role(role: &Role) -> String {
        match role {
            Role::Assistant => "assistant",
            Role::User | Role::System | Role::Tool => "user",
        }
        .to_string()
    }

    /// Builds the JSON payload for `/v1/messages`.
    ///
    /// All `Role::System` messages are pulled out of the conversation and
    /// joined (blank-line separated) into the top-level `system` field.
    /// (Previously only the first system message was kept; any additional
    /// ones were filtered out of `messages` AND dropped from `system`,
    /// losing their content entirely.)
    ///
    /// Image attachments are appended as base64 content blocks on the last
    /// message; non-image files are ignored, since only image blocks are
    /// supported here.
    fn prepare_request(
        &self,
        messages: Vec<Message>,
        files: Option<Vec<File>>,
        stream: bool,
    ) -> AnthropicRequest {
        // Single pass: split system content from conversation turns.
        let mut system_parts: Vec<String> = Vec::new();
        let mut anthropic_messages: Vec<AnthropicMessage> = Vec::new();
        for m in messages {
            if matches!(m.role, Role::System) {
                system_parts.push(m.content);
            } else {
                anthropic_messages.push(AnthropicMessage {
                    role: Self::convert_role(&m.role),
                    content: AnthropicContent::Text(m.content),
                });
            }
        }
        let system_prompt = if system_parts.is_empty() {
            None
        } else {
            Some(system_parts.join("\n\n"))
        };

        if let Some(files) = files {
            if let Some(last_msg) = anthropic_messages.last_mut() {
                // Seed the block list with the existing text so attachments
                // extend, rather than replace, the message body.
                let mut blocks = vec![ContentBlock::Text {
                    text: match &last_msg.content {
                        AnthropicContent::Text(t) => t.clone(),
                        _ => String::new(),
                    },
                }];
                for file in files {
                    if file.mime_type.starts_with("image/") {
                        blocks.push(ContentBlock::Image {
                            source: ImageSource {
                                source_type: "base64".to_string(),
                                media_type: file.mime_type,
                                data: base64::engine::general_purpose::STANDARD.encode(&file.data),
                            },
                        });
                    }
                }
                last_msg.content = AnthropicContent::Blocks(blocks);
            }
        }

        AnthropicRequest {
            model: self.model.clone(),
            messages: anthropic_messages,
            max_tokens: self.max_tokens,
            // The API treats a missing `stream` as false; omit the field for
            // non-streaming calls to keep the payload minimal.
            stream: stream.then_some(true),
            system: system_prompt,
        }
    }
}
#[async_trait]
impl LLM for AnthropicLLM {
    /// Performs one blocking (non-streaming) completion call against
    /// `https://api.anthropic.com/v1/messages` and returns the concatenated
    /// text of all response content blocks.
    ///
    /// # Errors
    /// Returns [`AgentError::ModelError`] when the HTTP request fails, the
    /// API responds with a non-2xx status, or the body cannot be parsed.
    async fn generate(
        &self,
        messages: Vec<Message>,
        files: Option<Vec<File>>,
    ) -> Result<GenerationResponse> {
        let request = self.prepare_request(messages, files, false);
        let response = self
            .client
            .post("https://api.anthropic.com/v1/messages")
            .header("x-api-key", &self.api_key)
            // Required API version pin for the Messages endpoint.
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&request)
            .send()
            .await
            .map_err(|e| AgentError::ModelError(format!("Anthropic request failed: {}", e)))?;
        if !response.status().is_success() {
            let status = response.status();
            // Best-effort body read so the error includes the API's message.
            let text = response.text().await.unwrap_or_default();
            return Err(AgentError::ModelError(format!(
                "Anthropic API error {}: {}",
                status, text
            )));
        }
        let anthropic_response: AnthropicResponse = response
            .json()
            .await
            .map_err(|e| AgentError::ModelError(format!("Failed to parse response: {}", e)))?;
        // Keep only text blocks; non-text blocks are dropped because the
        // response contract here is plain text.
        let content = anthropic_response
            .content
            .into_iter()
            .filter_map(|block| match block {
                ContentBlock::Text { text } => Some(text),
                _ => None,
            })
            .collect::<Vec<_>>()
            .join("\n");
        Ok(GenerationResponse {
            content,
            metadata: None,
        })
    }

    /// Streams a completion as server-sent events, yielding one
    /// [`GenerationChunk`] per `text_delta` event.
    ///
    /// The returned stream is a line-oriented SSE parser: network bytes are
    /// accumulated in a buffer, complete `\n`-terminated lines are drained,
    /// lines beginning with `data: ` are decoded as [`AnthropicStreamEvent`],
    /// and everything else (event-name lines, blank separators, unknown
    /// payloads) is skipped.
    ///
    /// # Errors
    /// Fails fast on HTTP/transport errors before streaming begins; mid-stream
    /// transport errors and API `error` events are surfaced as `Err` items.
    async fn stream_generate(
        &self,
        messages: Vec<Message>,
        files: Option<Vec<File>>,
    ) -> Result<BoxStream<'static, Result<GenerationChunk>>> {
        let request = self.prepare_request(messages, files, true);
        let response = self
            .client
            .post("https://api.anthropic.com/v1/messages")
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&request)
            .send()
            .await
            .map_err(|e| AgentError::ModelError(format!("Anthropic request failed: {}", e)))?;
        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            return Err(AgentError::ModelError(format!(
                "Anthropic API error {}: {}",
                status, text
            )));
        }
        let stream = response.bytes_stream();
        let buffer = Vec::new();
        // Unfold state is (byte stream, partial-line buffer). Each step either
        // emits a chunk/error or pulls more bytes from the network.
        let s = futures::stream::unfold((stream, buffer), |(mut stream, mut buffer)| async move {
            loop {
                // Drain complete lines from the buffer before reading more.
                if let Some(pos) = buffer.iter().position(|&b| b == b'\n') {
                    // Drain through the newline so the buffer keeps only the
                    // unfinished tail.
                    let line = buffer.drain(0..=pos).collect::<Vec<u8>>();
                    let s = String::from_utf8_lossy(&line);
                    let trimmed = s.trim();
                    if trimmed.starts_with("data: ") {
                        let json_str = trimmed.trim_start_matches("data: ").trim();
                        // Unparseable payloads are ignored on purpose so new
                        // event shapes don't break the stream.
                        if let Ok(event) = serde_json::from_str::<AnthropicStreamEvent>(json_str) {
                            match event {
                                AnthropicStreamEvent::ContentBlockDelta {
                                    delta: StreamDelta::TextDelta { text },
                                } => {
                                    return Some((
                                        Ok(GenerationChunk {
                                            content: text,
                                            metadata: None,
                                        }),
                                        (stream, buffer),
                                    ));
                                }
                                AnthropicStreamEvent::Error { error } => {
                                    return Some((
                                        Err(AgentError::ModelError(error.message)),
                                        (stream, buffer),
                                    ));
                                }
                                // Lifecycle events (message_start, ping, ...)
                                // carry no text; keep scanning.
                                _ => {}
                            }
                        }
                    }
                    continue;
                }
                // No complete line buffered: fetch the next network chunk.
                match stream.next().await {
                    Some(Ok(chunk)) => {
                        buffer.extend_from_slice(&chunk);
                    }
                    Some(Err(e)) => {
                        return Some((
                            Err(AgentError::ModelError(e.to_string())),
                            (stream, buffer),
                        ))
                    }
                    None => {
                        // Upstream closed: terminate the chunk stream.
                        return None;
                    }
                }
            }
        });
        Ok(Box::pin(s))
    }

    /// Returns the configured model identifier.
    fn model_name(&self) -> &str {
        &self.model
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Live-API smoke test. Marked `#[ignore]` because it needs network
    /// access and a valid `ANTHROPIC_API_KEY` in the environment; run it
    /// explicitly with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn test_anthropic_generate() {
        let client = AnthropicLLM::new("claude-3-5-sonnet-20241022").unwrap();
        let prompt = Message {
            role: Role::User,
            content: "Say 'Hello' and nothing else.".to_string(),
            metadata: None,
        };
        let reply = client.generate(vec![prompt], None).await.unwrap();
        assert!(reply.content.contains("Hello"));
    }
}