tiny-trae 0.1.0

An AI coding assistant with tool integration
use std::collections::HashMap;

use anyhow::{Context, Result};
use log::{debug, info, warn};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::config::Config;

/// One conversation turn, serialized verbatim into the request's `messages` array.
///
/// Both fields are plain strings; which `role` values are accepted is decided by
/// the server, not by this type.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct Message {
    /// Speaker of this turn; written to the `role` JSON key.
    pub role: String,
    /// Plain-text body of this turn; written to the `content` JSON key.
    pub content: String,
}

/// JSON body for a tool-less chat request.
///
/// Field names are serialized unchanged, so the wire keys are `model`,
/// `max_tokens`, `temperature` and `messages`, in that order.
#[derive(Debug)]
#[derive(Serialize)]
pub struct ChatRequest {
    /// Model identifier sent to the server.
    pub model: String,
    /// Upper bound on generated tokens.
    pub max_tokens: u32,
    /// Sampling temperature forwarded as-is.
    pub temperature: f64,
    /// Conversation history, oldest turn first (as supplied by the caller).
    pub messages: Vec<Message>,
}

/// Non-OpenAI response body whose shape appears to mirror Anthropic's Messages API
/// (NOTE(review): inferred from the field names — confirm against the server).
///
/// Every non-`Option` field must be present in the JSON for deserialization to succeed.
#[derive(Debug)]
#[derive(Deserialize)]
pub struct ChatResponse {
    /// Server-assigned response id.
    pub id: String,
    /// Ordered content blocks; each must carry a `text` field.
    pub content: Vec<ContentBlock>,
    /// Model that produced the response.
    pub model: String,
    /// Role of the responder.
    pub role: String,
    /// Why generation stopped, if reported.
    pub stop_reason: Option<String>,
    /// Stop sequence that ended generation, if any.
    pub stop_sequence: Option<String>,
    /// Value of the JSON `type` key.
    pub r#type: String,
    /// Token accounting for this request.
    pub usage: Usage,
}

/// A single text content block inside a [`ChatResponse`].
///
/// Both keys are required; a block without `text` fails to deserialize.
#[derive(Debug)]
#[derive(Deserialize)]
pub struct ContentBlock {
    /// The block's text.
    pub text: String,
    /// Value of the JSON `type` key for this block.
    pub r#type: String,
}

/// Token counts reported for one request.
#[derive(Debug)]
#[derive(Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt.
    pub input_tokens: u32,
    /// Tokens generated in the reply.
    pub output_tokens: u32,
}

/// Client for the chat-completions endpoint described by a [`Config`].
///
/// The HTTP client is built once in `ClaudeClient::new` and reused by every
/// request; the config supplies model name, sampling settings, base URL and API key.
pub struct ClaudeClient {
    /// Underlying HTTP client (120 s timeout, set in `new`).
    client: Client,
    /// Application configuration consulted on each request.
    config: Config,
}

impl ClaudeClient {
    pub fn new(config: Config) -> Result<Self> {
        let client = Client::builder()
            .timeout(std::time::Duration::from_secs(120))
            .build()
            .context("Failed to create HTTP client")?;
        
        Ok(Self { client, config })
    }

    pub async fn chat(&self, messages: Vec<Message>) -> Result<String> {
        let request = ChatRequest {
            model: self.config.model.version.clone(),
            max_tokens: self.config.model.max_tokens,
            temperature: self.config.model.temperature,
            messages,
        };

        debug!("Sending request to Claude API: {:?}", request);

        let response = self
            .client
            .post(&format!("{}/v1/chat/completions", self.config.anthropic.base_url))
            .header("Content-Type", "application/json")
            .header("Authorization", &format!("Bearer {}", self.config.anthropic.api_key))
            .json(&request)
            .send()
            .await
            .context("Failed to send request to Claude API")?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Claude API request failed: {}", error_text);
        }

        let response_data: ChatResponse = response
            .json()
            .await
            .context("Failed to parse Claude API response")?;

        debug!("Received response from Claude API: {:?}", response_data);

        info!(
            "Claude API usage - Input tokens: {}, Output tokens: {}",
            response_data.usage.input_tokens, response_data.usage.output_tokens
        );

        if let Some(content_block) = response_data.content.first() {
            Ok(content_block.text.clone())
        } else {
            anyhow::bail!("No content in Claude API response");
        }
    }

    pub async fn chat_with_tools(
        &self,
        messages: Vec<Message>,
        tools: Vec<ToolDefinition>,
    ) -> Result<ChatResponseWithTools> {
        let mut request_data = serde_json::json!({
            "model": self.config.model.version,
            "max_tokens": self.config.model.max_tokens,
            "temperature": self.config.model.temperature,
            "messages": messages
        });

        if !tools.is_empty() {
            request_data["tools"] = serde_json::to_value(tools)?;
        }

        debug!("Sending request to Claude API with tools: {:?}", request_data);

        let response = self
            .client
            .post(&format!("{}/v1/chat/completions", self.config.anthropic.base_url))
            .header("Content-Type", "application/json")
            .header("Authorization", &format!("Bearer {}", self.config.anthropic.api_key))
            .json(&request_data)
            .send()
            .await
            .context("Failed to send request to Claude API")?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Claude API request failed: {}", error_text);
        }

        // Try to parse as OpenAI format first
        let response_text = response.text().await.context("Failed to read response text")?;
        debug!("Raw response: {}", response_text);
        
        match serde_json::from_str::<OpenAIResponse>(&response_text) {
            Ok(openai_response) => {
                debug!("Parsed as OpenAI response: {:?}", openai_response);
                
                // Convert OpenAI response to ChatResponseWithTools
                let content = if let Some(choice) = openai_response.choices.first() {
                    let mut content_blocks = Vec::new();
                    
                    if let Some(content) = &choice.message.content {
                        content_blocks.push(ContentBlockWithTools::Text { text: content.clone() });
                    }
                    
                    if let Some(tool_calls) = &choice.message.tool_calls {
                        for tool_call in tool_calls {
                            content_blocks.push(ContentBlockWithTools::ToolUse {
                                id: tool_call.id.clone(),
                                name: tool_call.function.name.clone(),
                                input: serde_json::from_str(&tool_call.function.arguments).unwrap_or_default(),
                            });
                        }
                    }
                    
                    content_blocks
                } else {
                    Vec::new()
                };

                Ok(ChatResponseWithTools {
                    id: openai_response.id,
                    content,
                    model: openai_response.model,
                    role: "assistant".to_string(),
                    stop_reason: openai_response.choices.first().and_then(|c| c.finish_reason.clone()),
                    stop_sequence: None,
                    r#type: "message".to_string(),
                    usage: Usage {
                        input_tokens: openai_response.usage.prompt_tokens,
                        output_tokens: openai_response.usage.completion_tokens,
                    },
                })
            }
            Err(e) => {
                debug!("Failed to parse as OpenAI format: {}", e);
                // Fall back to original format
                let response_data: ChatResponseWithTools = serde_json::from_str(&response_text)
                    .context("Failed to parse Claude API response")?;
                Ok(response_data)
            }
        }
    }
}

/// A tool the model may call, sent in the request's `tools` array.
///
/// NOTE(review): `type` is presumably `"function"` for OpenAI-compatible servers;
/// this type does not enforce any particular value.
#[derive(Debug)]
#[derive(Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Value of the JSON `type` key.
    pub r#type: String,
    /// Name, description and parameter spec of the callable function.
    pub function: FunctionDefinition,
}

/// The callable part of a [`ToolDefinition`].
#[derive(Debug)]
#[derive(Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Function name the model uses when calling the tool.
    pub name: String,
    /// Human-readable description shown to the model.
    pub description: String,
    /// Arbitrary JSON describing the parameters (presumably a JSON Schema — confirm with callers).
    pub parameters: Value,
}

/// Reply returned by `ClaudeClient::chat_with_tools`.
///
/// Produced either by converting an OpenAI-compatible response or, as a
/// fallback, by deserializing the response body directly into this type.
#[derive(Debug)]
#[derive(Deserialize)]
pub struct ChatResponseWithTools {
    /// Server-assigned response id.
    pub id: String,
    /// Ordered text and tool-use blocks.
    pub content: Vec<ContentBlockWithTools>,
    /// Model that produced the response.
    pub model: String,
    /// Role of the responder.
    pub role: String,
    /// Why generation stopped, if reported.
    pub stop_reason: Option<String>,
    /// Stop sequence that ended generation, if any.
    pub stop_sequence: Option<String>,
    /// Value of the JSON `type` key.
    pub r#type: String,
    /// Token accounting for this request.
    pub usage: Usage,
}

// OpenAI-compatible response structures.

/// Top-level body of an OpenAI-compatible chat-completions response.
///
/// All fields are required for deserialization to succeed.
#[derive(Debug)]
#[derive(Deserialize)]
pub struct OpenAIResponse {
    /// Server-assigned response id.
    pub id: String,
    /// Object kind string reported by the server.
    pub object: String,
    /// Creation timestamp as reported by the server.
    pub created: u64,
    /// Model that produced the response.
    pub model: String,
    /// Candidate completions.
    pub choices: Vec<OpenAIChoice>,
    /// Token accounting for this request.
    pub usage: OpenAIUsage,
}

/// One candidate completion in an [`OpenAIResponse`].
#[derive(Debug)]
#[derive(Deserialize)]
pub struct OpenAIChoice {
    /// Position of this choice in the `choices` array.
    pub index: u32,
    /// The generated assistant message.
    pub message: OpenAIMessage,
    /// Why generation stopped, if reported.
    pub finish_reason: Option<String>,
}

/// The message inside an [`OpenAIChoice`]; text and tool calls are both optional.
#[derive(Debug)]
#[derive(Deserialize)]
pub struct OpenAIMessage {
    /// Role of the message author.
    pub role: String,
    /// Text content, absent or null when the server omits it.
    pub content: Option<String>,
    /// Tool invocations requested by the model, if any.
    pub tool_calls: Option<Vec<OpenAIToolCall>>,
}

/// A single tool invocation requested in an [`OpenAIMessage`].
#[derive(Debug)]
#[derive(Deserialize)]
pub struct OpenAIToolCall {
    /// Identifier of this call.
    pub id: String,
    /// Value of the JSON `type` key.
    pub r#type: String,
    /// Function name plus its JSON-encoded arguments.
    pub function: OpenAIFunction,
}

/// Function part of an [`OpenAIToolCall`].
#[derive(Debug)]
#[derive(Deserialize)]
pub struct OpenAIFunction {
    /// Name of the function to call.
    pub name: String,
    /// Arguments as a JSON-encoded string (decoded separately by the client).
    pub arguments: String,
}

/// Token counts in an OpenAI-compatible response.
#[derive(Debug)]
#[derive(Deserialize)]
pub struct OpenAIUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
    /// Total tokens as reported by the server.
    pub total_tokens: u32,
}

/// One block of a [`ChatResponseWithTools`], discriminated by its JSON `type` key.
///
/// `rename_all = "snake_case"` maps `Text` to `"text"` and `ToolUse` to `"tool_use"`.
#[derive(Debug)]
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlockWithTools {
    /// Plain text produced by the model.
    Text { text: String },
    /// A request from the model to invoke the tool `name` with `input`.
    ToolUse {
        id: String,
        name: String,
        input: Value,
    },
}