#![allow(dead_code)]
use super::{AIClient, AIConfig, AIResponse};
use anyhow::Result;
use futures_util::stream::StreamExt;
use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
use serde::{Deserialize, Serialize};
use tracing::debug;
/// HTTP client for Anthropic's Claude Messages API.
///
/// All per-request settings are copied out of `AIConfig` by
/// [`ClaudeClient::new`] and reused for every call.
pub struct ClaudeClient {
    /// Reusable reqwest client shared by every request (and cloned into
    /// streaming tasks).
    client: reqwest::Client,
    /// Credential sent in the `x-api-key` header.
    api_key: String,
    /// API root with trailing slashes removed; `/v1/messages` is appended to it.
    base_url: String,
    /// Model identifier placed in each request body.
    model: String,
    /// `max_tokens` placed in each request body.
    max_tokens: u32,
    /// Sampling temperature placed in each request body.
    temperature: f32,
}
/// JSON body sent to `POST /v1/messages`.
#[derive(Debug, Serialize)]
struct ClaudeRequest {
    /// Model identifier.
    model: String,
    /// Upper bound on the number of generated tokens.
    max_tokens: u32,
    /// Sampling temperature.
    temperature: f32,
    /// Conversation turns sent to the model.
    messages: Vec<ClaudeMessage>,
    /// Optional system prompt. When `None` the key is omitted entirely instead
    /// of being serialized as `"system": null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    /// Requests a server-sent-event stream; the key is omitted when `false`.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    stream: bool,
}
/// Payload of one server-sent event from a streaming Messages call.
///
/// Only the discriminating `type` and the optional `delta` object are kept;
/// other JSON fields are ignored by serde.
#[derive(Deserialize, Debug)]
struct StreamDeltaEvent {
    /// The JSON `type` field, e.g. `"content_block_delta"` or `"message_stop"`.
    #[serde(rename = "type")]
    event_type: String,
    /// The `delta` object, for events that carry one.
    delta: Option<StreamDeltaInner>,
}
/// The `delta` object nested in a streaming event.
#[derive(Deserialize, Debug)]
struct StreamDeltaInner {
    /// The JSON `type` field of the delta; text fragments use `"text_delta"`.
    #[serde(rename = "type")]
    delta_type: String,
    /// Text fragment; an absent field deserializes as an empty string.
    #[serde(default)]
    text: String,
}
/// Shape of an error event received on the stream.
#[derive(Deserialize, Debug)]
struct StreamErrorEvent {
    /// The JSON `type` field; `"error"` for error events.
    #[serde(rename = "type")]
    event_type: String,
    /// Nested error details.
    error: StreamErrorInner,
}
/// The nested `error` object of a stream error event.
#[derive(Deserialize, Debug)]
struct StreamErrorInner {
    /// Human-readable error description.
    message: String,
}
/// One conversation turn in a Messages request.
#[derive(Serialize, Deserialize, Debug)]
struct ClaudeMessage {
    /// Speaker of the turn; this file only sends `"user"`.
    role: String,
    /// Plain-text content of the turn.
    content: String,
}
/// JSON body of a successful non-streaming Messages response.
#[derive(Deserialize, Debug)]
struct ClaudeResponse {
    /// Response identifier (currently unused).
    id: String,
    /// Content blocks making up the reply.
    content: Vec<ContentBlock>,
    /// Model that produced the reply.
    model: String,
    /// Why generation stopped, when reported.
    stop_reason: Option<String>,
    /// Token accounting for the call.
    usage: Usage,
}
/// One element of a response's `content` array.
#[derive(Debug, Deserialize)]
struct ContentBlock {
    /// Block kind from the JSON `type` field (`"text"` for text blocks).
    #[serde(rename = "type")]
    content_type: String,
    /// Text of the block. Defaults to empty so that a block without a `text`
    /// field (presumably a non-text block kind) does not make the whole
    /// response fail to deserialize.
    #[serde(default)]
    text: String,
}
/// Token counts reported for a Messages call.
#[derive(Deserialize, Debug)]
struct Usage {
    /// Tokens consumed by the prompt.
    input_tokens: u32,
    /// Tokens generated in the reply.
    output_tokens: u32,
}
impl ClaudeClient {
pub fn new(config: &AIConfig) -> Result<Self> {
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(60))
.build()?;
let base_url = std::env::var("ANTHROPIC_BASE_URL")
.ok()
.filter(|s| !s.is_empty())
.unwrap_or_else(|| "https://api.anthropic.com".to_string())
.trim_end_matches('/')
.to_string();
Ok(Self {
client,
api_key: config.api_key.clone(),
base_url,
model: config.model.clone(),
max_tokens: config.max_tokens,
temperature: config.temperature,
})
}
fn build_headers(&self) -> Result<HeaderMap> {
let mut headers = HeaderMap::new();
headers.insert(
CONTENT_TYPE,
HeaderValue::from_static("application/json"),
);
let auth_header = format!("{}", self.api_key);
headers.insert(
"x-api-key",
HeaderValue::from_str(&auth_header)?,
);
headers.insert(
"anthropic-version",
HeaderValue::from_static("2023-06-01"),
);
Ok(headers)
}
}
#[async_trait::async_trait]
impl AIClient for ClaudeClient {
async fn complete(
&self,
prompt: &str,
_context: Option<&super::AIContext>,
) -> Result<AIResponse> {
debug!("Sending completion request to Claude");
let url = format!("{}/v1/messages", self.base_url);
let messages = vec![ClaudeMessage {
role: "user".to_string(),
content: prompt.to_string(),
}];
let request = ClaudeRequest {
model: self.model.clone(),
max_tokens: self.max_tokens,
temperature: self.temperature,
messages,
system: Some("You are a helpful AI assistant for developers.".to_string()),
stream: false,
};
let response = self
.client
.post(url)
.headers(self.build_headers()?)
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let error_text = response.text().await?;
anyhow::bail!("Claude API error: {}", error_text);
}
let completion: ClaudeResponse = response.json().await?;
let content = completion
.content
.get(0)
.map(|c| c.text.clone())
.unwrap_or_default();
let tokens_used = Some(completion.usage.input_tokens + completion.usage.output_tokens);
Ok(AIResponse {
content,
tokens_used,
model: completion.model,
finish_reason: completion.stop_reason,
cost_estimate: None,
})
}
async fn complete_stream(
&self,
prompt: &str,
_context: Option<&super::AIContext>,
) -> Result<tokio::sync::mpsc::Receiver<Result<String>>> {
debug!("Sending streaming completion request to Claude");
let (tx, rx) = tokio::sync::mpsc::channel(100);
let url = format!("{}/v1/messages", self.base_url);
let headers = self.build_headers()?;
let client = self.client.clone();
let model = self.model.clone();
let max_tokens = self.max_tokens;
let temperature = self.temperature;
let prompt = prompt.to_string();
tokio::spawn(async move {
let request = ClaudeRequest {
model,
max_tokens,
temperature,
messages: vec![ClaudeMessage {
role: "user".to_string(),
content: prompt,
}],
system: Some("You are a helpful AI assistant for developers.".to_string()),
stream: true,
};
let response = match client.post(url).headers(headers).json(&request).send().await {
Ok(r) => r,
Err(e) => {
let _ = tx.send(Err(e.into())).await;
return;
}
};
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
let _ = tx
.send(Err(anyhow::anyhow!("Claude API error {}: {}", status, body)))
.await;
return;
}
let mut stream = response.bytes_stream();
let mut buf = String::new();
while let Some(chunk) = stream.next().await {
let bytes = match chunk {
Ok(b) => b,
Err(e) => {
let _ = tx.send(Err(anyhow::anyhow!("stream error: {}", e))).await;
return;
}
};
buf.push_str(&String::from_utf8_lossy(&bytes));
while let Some(idx) = buf.find("\n\n") {
let raw = buf[..idx].to_string();
buf.drain(..idx + 2);
for line in raw.lines() {
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
return;
}
if let Ok(ev) = serde_json::from_str::<StreamDeltaEvent>(data) {
if ev.event_type == "content_block_delta" {
if let Some(delta) = ev.delta {
if delta.delta_type == "text_delta" && !delta.text.is_empty() {
if tx.send(Ok(delta.text)).await.is_err() {
return;
}
}
}
} else if ev.event_type == "message_stop" {
return;
}
} else if let Ok(err) = serde_json::from_str::<StreamErrorEvent>(data) {
if err.event_type == "error" {
let _ = tx
.send(Err(anyhow::anyhow!(
"Claude stream error: {}",
err.error.message
)))
.await;
return;
}
}
}
}
}
}
});
Ok(rx)
}
fn name(&self) -> &str {
"Claude"
}
fn model_info(&self) -> super::ModelConfig {
super::ModelConfig {
name: self.model.clone(),
context_window: 200_000,
supports_functions: true,
supports_vision: false,
cost_per_1k_input: 0.015,
cost_per_1k_output: 0.075,
}
}
}