use std::time::Duration;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use tokio_stream::StreamExt;
use tokio_util::sync::CancellationToken;
use secrecy::{ExposeSecret, SecretString};
use crate::config::Config;
use crate::error::{Error, Result};
use super::MAX_RESPONSE_BYTES;
/// Value of the `anthropic-version` header sent with every request.
const API_VERSION: &str = "2023-06-01";
/// API root used when the configuration does not provide `anthropic_base_url`.
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com/v1";
/// Client for Anthropic's Messages API.
///
/// Holds an HTTP client plus the request settings copied out of [`Config`]
/// by [`AnthropicProvider::new`].
pub struct AnthropicProvider {
    /// HTTP client built in `new` with the configured request timeout.
    client: Client,
    /// API root with any trailing `/` removed; `/messages` is appended per request.
    base_url: String,
    /// Model identifier sent in every request body.
    model: String,
    /// Value of the `x-api-key` header; empty when no key was configured.
    api_key: SecretString,
    /// Sampling temperature forwarded to the API.
    temperature: f32,
    /// Upper bound on generated tokens (taken from `Config::num_predict`).
    max_tokens: u32,
}
/// JSON body for `POST {base_url}/messages`.
///
/// Field declaration order is the key order of the serialized body.
#[derive(Serialize)]
struct MessagesRequest {
    /// Model identifier.
    model: String,
    /// System prompt, sent as a top-level field rather than as a message.
    system: String,
    /// Conversation turns; `generate` sends a single `user` message.
    messages: Vec<Message>,
    /// Sampling temperature.
    temperature: f32,
    /// Maximum number of tokens to generate.
    max_tokens: u32,
    /// Asks for a server-sent-events stream; `generate` always sets this to `true`.
    stream: bool,
}
/// One conversation turn inside [`MessagesRequest::messages`].
#[derive(Serialize)]
struct Message {
    /// Speaker role, e.g. `"user"`.
    role: String,
    /// Plain-text content of the turn.
    content: String,
}
/// One decoded `data:` payload from the Messages SSE stream.
///
/// Only the fields the streaming loop reads are declared; serde ignores any
/// other keys present in the payload.
#[derive(Deserialize)]
struct StreamEvent {
    /// Incremental content, present on `content_block_delta` events.
    delta: Option<ContentDelta>,
    /// Event discriminator taken from the JSON `type` key.
    #[serde(rename = "type")]
    event_type: String,
}

/// The `delta` object of a stream event.
///
/// `text` is optional, so deltas that carry no text deserialize to `None`.
#[derive(Deserialize)]
struct ContentDelta {
    text: Option<String>,
}
impl AnthropicProvider {
pub fn new(config: &Config) -> Result<Self> {
let client = Client::builder()
.timeout(Duration::from_secs(config.timeout_secs))
.build()
.map_err(|e| Error::Provider {
provider: "anthropic".into(),
message: format!("failed to build HTTP client: {e}"),
})?;
Ok(Self {
client,
base_url: config
.anthropic_base_url
.clone()
.unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
.trim_end_matches('/')
.to_string(),
model: config.model.clone(),
api_key: config.api_key.clone().unwrap_or_default(),
temperature: config.temperature,
max_tokens: config.num_predict,
})
}
pub async fn verify_connection(&self) -> Result<()> {
if self.api_key.expose_secret().is_empty() {
return Err(Error::Provider {
provider: "anthropic".into(),
message: "API key not configured".into(),
});
}
Ok(())
}
pub async fn generate(
&self,
prompt: &str,
system_prompt: &str,
token_tx: mpsc::Sender<String>,
cancel: CancellationToken,
) -> Result<String> {
let url = format!("{}/messages", self.base_url);
let response = self
.client
.post(&url)
.header("x-api-key", self.api_key.expose_secret())
.header("anthropic-version", API_VERSION)
.header("content-type", "application/json")
.json(&MessagesRequest {
model: self.model.clone(),
system: system_prompt.into(),
messages: vec![Message {
role: "user".into(),
content: prompt.to_string(),
}],
temperature: self.temperature,
max_tokens: self.max_tokens,
stream: true,
})
.send()
.await
.map_err(|e| {
if e.is_timeout() {
Error::Provider {
provider: "anthropic".into(),
message: "request timed out".into(),
}
} else {
Error::Provider {
provider: "anthropic".into(),
message: e.without_url().to_string(),
}
}
})?;
if !response.status().is_success() {
let status = response.status();
let body = response
.text()
.await
.unwrap_or_else(|e| format!("(failed to read body: {e})"));
return Err(Error::Provider {
provider: "anthropic".into(),
message: format!("HTTP {status}: {body}"),
});
}
let mut stream = response.bytes_stream();
let mut full_response = String::new();
let mut line_buffer = String::new();
loop {
tokio::select! {
_ = cancel.cancelled() => {
return Err(Error::Cancelled);
}
chunk = stream.next() => {
let Some(chunk) = chunk else { break };
let chunk = chunk.map_err(|e| Error::Provider {
provider: "anthropic".into(),
message: e.without_url().to_string(),
})?;
line_buffer.push_str(&String::from_utf8_lossy(&chunk));
if line_buffer.len() > MAX_RESPONSE_BYTES {
return Err(Error::Provider {
provider: "anthropic".into(),
message: "line buffer exceeded 1 MB limit".into(),
});
}
while let Some(newline_pos) = line_buffer.find('\n') {
let result = {
let line = line_buffer[..newline_pos].trim();
if line.is_empty() || line.starts_with("event:") {
None
} else if let Some(data) = line.strip_prefix("data: ") {
serde_json::from_str::<StreamEvent>(data).ok()
} else {
None
}
};
line_buffer.drain(..=newline_pos);
if let Some(event) = result {
match event.event_type.as_str() {
"content_block_delta" => {
if let Some(delta) = &event.delta
&& let Some(text) = &delta.text
{
let _ = token_tx.send(text.clone()).await;
full_response.push_str(text);
}
if full_response.len() > MAX_RESPONSE_BYTES {
return Err(Error::Provider {
provider: "anthropic".into(),
message: "response exceeded 1 MB limit".into(),
});
}
}
"message_stop" => {
return Ok(full_response.trim().to_string());
}
_ => {}
}
}
}
}
}
}
if !line_buffer.is_empty() {
let line = line_buffer.trim();
if !line.is_empty()
&& !line.starts_with("event:")
&& let Some(data) = line.strip_prefix("data: ")
&& let Ok(event) = serde_json::from_str::<StreamEvent>(data)
&& event.event_type == "content_block_delta"
&& let Some(delta) = &event.delta
&& let Some(text) = &delta.text
{
full_response.push_str(text);
}
}
Ok(full_response.trim().to_string())
}
pub fn name(&self) -> &str {
"anthropic"
}
}