use anyhow::{Context, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use log::{debug, info};
use crate::config::Config;
/// A single conversation message: who sent it (`role`) and its text (`content`).
///
/// Values of this type are serialized into the `messages` array of outgoing
/// chat requests.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Sender of the message.
/// NOTE(review): presumably values such as "user" / "assistant"; nothing in
/// this file validates the string — confirm against callers.
pub role: String,
/// Text body of the message.
pub content: String,
}
/// JSON request body built and sent by `ClaudeClient::chat`.
#[derive(Debug, Serialize)]
pub struct ChatRequest {
/// Model identifier; `chat` fills this from `config.model.version`.
pub model: String,
/// Token limit for the reply; `chat` fills this from `config.model.max_tokens`.
pub max_tokens: u32,
/// Sampling temperature; `chat` fills this from `config.model.temperature`.
pub temperature: f64,
/// Conversation history passed in by the caller.
pub messages: Vec<Message>,
}
/// Response body that `ClaudeClient::chat` deserializes from the API reply.
///
/// Every field that is not an `Option` must be present in the JSON for
/// deserialization to succeed.
#[derive(Debug, Deserialize)]
pub struct ChatResponse {
/// Identifier of the response as reported by the server.
pub id: String,
/// Content blocks of the reply; `chat` returns the text of the first one.
pub content: Vec<ContentBlock>,
/// Model name as reported by the server.
pub model: String,
/// Role attached to the reply.
pub role: String,
/// Why generation stopped, when the server reports it.
pub stop_reason: Option<String>,
/// Stop sequence that ended generation, when the server reports one.
pub stop_sequence: Option<String>,
/// Value of the JSON `type` field (raw identifier because `type` is a keyword).
pub r#type: String,
/// Token accounting for the request; logged by `chat` at info level.
pub usage: Usage,
}
/// One content block inside a `ChatResponse`.
///
/// Both fields are required, so a block without a `text` field fails to
/// deserialize.
#[derive(Debug, Deserialize)]
pub struct ContentBlock {
/// Text carried by the block.
pub text: String,
/// Value of the JSON `type` field (raw identifier because `type` is a keyword).
pub r#type: String,
}
/// Token usage attached to a response.
///
/// Deserialized directly from `input_tokens` / `output_tokens`, and also
/// constructed by `ClaudeClient::chat_with_tools` from an `OpenAIUsage`.
#[derive(Debug, Deserialize)]
pub struct Usage {
/// Number of tokens counted for the request input.
pub input_tokens: u32,
/// Number of tokens counted for the generated output.
pub output_tokens: u32,
}
/// HTTP client wrapper for the chat API, pairing a `reqwest::Client` with the
/// application `Config` it reads the endpoint, credentials and model settings from.
pub struct ClaudeClient {
// Reusable HTTP client, configured once in `ClaudeClient::new`.
client: Client,
// Source of `anthropic.base_url`, `anthropic.api_key` and the `model.*` settings.
config: Config,
}
impl ClaudeClient {
pub fn new(config: Config) -> Result<Self> {
let client = Client::builder()
.timeout(std::time::Duration::from_secs(120))
.build()
.context("Failed to create HTTP client")?;
Ok(Self { client, config })
}
pub async fn chat(&self, messages: Vec<Message>) -> Result<String> {
let request = ChatRequest {
model: self.config.model.version.clone(),
max_tokens: self.config.model.max_tokens,
temperature: self.config.model.temperature,
messages,
};
debug!("Sending request to Claude API: {:?}", request);
let response = self
.client
.post(&format!("{}/v1/chat/completions", self.config.anthropic.base_url))
.header("Content-Type", "application/json")
.header("Authorization", &format!("Bearer {}", self.config.anthropic.api_key))
.json(&request)
.send()
.await
.context("Failed to send request to Claude API")?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
anyhow::bail!("Claude API request failed: {}", error_text);
}
let response_data: ChatResponse = response
.json()
.await
.context("Failed to parse Claude API response")?;
debug!("Received response from Claude API: {:?}", response_data);
info!(
"Claude API usage - Input tokens: {}, Output tokens: {}",
response_data.usage.input_tokens, response_data.usage.output_tokens
);
if let Some(content_block) = response_data.content.first() {
Ok(content_block.text.clone())
} else {
anyhow::bail!("No content in Claude API response");
}
}
pub async fn chat_with_tools(
&self,
messages: Vec<Message>,
tools: Vec<ToolDefinition>,
) -> Result<ChatResponseWithTools> {
let mut request_data = serde_json::json!({
"model": self.config.model.version,
"max_tokens": self.config.model.max_tokens,
"temperature": self.config.model.temperature,
"messages": messages
});
if !tools.is_empty() {
request_data["tools"] = serde_json::to_value(tools)?;
}
debug!("Sending request to Claude API with tools: {:?}", request_data);
let response = self
.client
.post(&format!("{}/v1/chat/completions", self.config.anthropic.base_url))
.header("Content-Type", "application/json")
.header("Authorization", &format!("Bearer {}", self.config.anthropic.api_key))
.json(&request_data)
.send()
.await
.context("Failed to send request to Claude API")?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
anyhow::bail!("Claude API request failed: {}", error_text);
}
let response_text = response.text().await.context("Failed to read response text")?;
debug!("Raw response: {}", response_text);
match serde_json::from_str::<OpenAIResponse>(&response_text) {
Ok(openai_response) => {
debug!("Parsed as OpenAI response: {:?}", openai_response);
let content = if let Some(choice) = openai_response.choices.first() {
let mut content_blocks = Vec::new();
if let Some(content) = &choice.message.content {
content_blocks.push(ContentBlockWithTools::Text { text: content.clone() });
}
if let Some(tool_calls) = &choice.message.tool_calls {
for tool_call in tool_calls {
content_blocks.push(ContentBlockWithTools::ToolUse {
id: tool_call.id.clone(),
name: tool_call.function.name.clone(),
input: serde_json::from_str(&tool_call.function.arguments).unwrap_or_default(),
});
}
}
content_blocks
} else {
Vec::new()
};
Ok(ChatResponseWithTools {
id: openai_response.id,
content,
model: openai_response.model,
role: "assistant".to_string(),
stop_reason: openai_response.choices.first().and_then(|c| c.finish_reason.clone()),
stop_sequence: None,
r#type: "message".to_string(),
usage: Usage {
input_tokens: openai_response.usage.prompt_tokens,
output_tokens: openai_response.usage.completion_tokens,
},
})
}
Err(e) => {
debug!("Failed to parse as OpenAI format: {}", e);
let response_data: ChatResponseWithTools = serde_json::from_str(&response_text)
.context("Failed to parse Claude API response")?;
Ok(response_data)
}
}
}
}
/// A tool offered to the model; `ClaudeClient::chat_with_tools` serializes a
/// list of these into the request's `tools` field.
#[derive(Debug, Serialize, Deserialize)]
pub struct ToolDefinition {
/// Value of the JSON `type` field (raw identifier because `type` is a keyword).
/// NOTE(review): presumably "function" for OpenAI-style tools — confirm with callers.
pub r#type: String,
/// The callable described by this tool entry.
pub function: FunctionDefinition,
}
/// Name, description and parameter description of a callable tool function.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
/// Function name.
pub name: String,
/// Human-readable description of the function.
pub description: String,
/// Arbitrary JSON describing the parameters.
/// NOTE(review): presumably a JSON Schema object — not enforced here; confirm.
pub parameters: Value,
}
/// Structured reply returned by `ClaudeClient::chat_with_tools`.
///
/// It is either deserialized directly from the response body or assembled
/// from an `OpenAIResponse`.
#[derive(Debug, Deserialize)]
pub struct ChatResponseWithTools {
/// Identifier of the response as reported by the server.
pub id: String,
/// Text and tool-use blocks that make up the reply.
pub content: Vec<ContentBlockWithTools>,
/// Model name as reported by the server.
pub model: String,
/// Role attached to the reply ("assistant" when converted from an OpenAI reply).
pub role: String,
/// Why generation stopped; for converted OpenAI replies this is the first
/// choice's `finish_reason`.
pub stop_reason: Option<String>,
/// Stop sequence that ended generation; always `None` for converted OpenAI replies.
pub stop_sequence: Option<String>,
/// Value of the JSON `type` field ("message" when converted from an OpenAI reply).
pub r#type: String,
/// Token accounting for the request.
pub usage: Usage,
}
/// OpenAI-style chat completion body; `ClaudeClient::chat_with_tools` tries to
/// parse the reply as this type first.
///
/// All fields are required, so a body missing any of them does not parse as
/// this type.
#[derive(Debug, Deserialize)]
pub struct OpenAIResponse {
/// Identifier of the completion.
pub id: String,
/// Value of the JSON `object` field.
pub object: String,
/// Value of the JSON `created` field.
/// NOTE(review): presumably a Unix timestamp in seconds — confirm.
pub created: u64,
/// Model name as reported by the server.
pub model: String,
/// Candidate completions; only the first one is read by `chat_with_tools`.
pub choices: Vec<OpenAIChoice>,
/// Token accounting for the request.
pub usage: OpenAIUsage,
}
/// One entry of `OpenAIResponse::choices`.
#[derive(Debug, Deserialize)]
pub struct OpenAIChoice {
/// Position of this choice as reported by the server.
pub index: u32,
/// The generated message (text and/or tool calls).
pub message: OpenAIMessage,
/// Why generation stopped, when reported; copied into
/// `ChatResponseWithTools::stop_reason` by `chat_with_tools`.
pub finish_reason: Option<String>,
}
/// Message payload of an `OpenAIChoice`.
#[derive(Debug, Deserialize)]
pub struct OpenAIMessage {
/// Role attached to the message.
pub role: String,
/// Text of the message; may be absent (e.g. when only tool calls are present).
pub content: Option<String>,
/// Tool invocations requested by the model, if any.
pub tool_calls: Option<Vec<OpenAIToolCall>>,
}
/// A single tool invocation inside an `OpenAIMessage`.
#[derive(Debug, Deserialize)]
pub struct OpenAIToolCall {
/// Identifier of the call; becomes the `id` of the resulting `ToolUse` block.
pub id: String,
/// Value of the JSON `type` field (raw identifier because `type` is a keyword).
pub r#type: String,
/// Function name and encoded arguments for the call.
pub function: OpenAIFunction,
}
/// Function name and arguments of an `OpenAIToolCall`.
#[derive(Debug, Deserialize)]
pub struct OpenAIFunction {
/// Name of the function to call.
pub name: String,
/// Arguments as a string; `chat_with_tools` parses it as JSON into a `Value`.
pub arguments: String,
}
/// Token usage section of an `OpenAIResponse`.
#[derive(Debug, Deserialize)]
pub struct OpenAIUsage {
/// Mapped to `Usage::input_tokens` by `chat_with_tools`.
pub prompt_tokens: u32,
/// Mapped to `Usage::output_tokens` by `chat_with_tools`.
pub completion_tokens: u32,
/// Required when deserializing, but not read elsewhere in the visible code.
pub total_tokens: u32,
}
/// A content block of a `ChatResponseWithTools`.
///
/// Deserialized as an internally tagged enum: the JSON `type` field selects
/// the variant (`"text"` or `"tool_use"`); any other tag fails to deserialize.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
pub enum ContentBlockWithTools {
/// Plain text produced by the model.
#[serde(rename = "text")]
Text { text: String },
/// A request from the model to invoke the tool `name` with JSON `input`;
/// `id` identifies the call.
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
input: Value,
},
}