cckit 0.1.0

Code kit written in Rust for switching Claude models; supports the Claude-compatible models provided by 智普LLM (Zhipu), MiniMax, and Kimi.
use crate::models::{ModelProvider, ProviderTestResult, ModelInfo};
use anyhow::{Context, Result};
use reqwest::Client;
use serde_json::{json, Value};
use std::time::Instant;
use colored::*;
use reqwest::header::{HeaderName, HeaderValue};

/// Static description of how to reach and authenticate against one model provider.
///
/// Every field is a `'static` borrow, so values can be declared in a `const`
/// table; the built-in instances are in `PROVIDER_CONFIGS`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderConfig {
    /// Lookup key for the provider (e.g. `"kimi"`).
    pub provider_type: &'static str,
    /// Human-readable provider name used in messages and `ModelInfo`.
    pub display_name: &'static str,
    /// Model requested when the user has not configured one.
    pub default_model: &'static str,
    /// Base URL used when the user has not configured one; `endpoint` is appended to it.
    pub base_url: Option<&'static str>,
    /// Name of the HTTP header that carries the API key.
    pub auth_header_name: &'static str,
    /// Text placed before the API key in that header (e.g. `"Bearer "`, or empty).
    pub auth_prefix: &'static str,
    /// Request path appended to the base URL.
    pub endpoint: &'static str,
    /// Optional extra `(header name, header value)` pair sent with every request.
    pub api_version_header: Option<(&'static str, &'static str)>,
    /// Capability labels reported for the provider.
    pub capabilities: &'static [&'static str],
}

/// Built-in provider configurations, looked up by `provider_type`.
///
/// Entry order is observable: `get_supported_providers` returns the keys in
/// table order.
// NOTE(review): the non-Claude entries call `/v1/messages` on base URLs that look
// Anthropic-compatible (`.../anthropic`, Kimi's `/coding/`); confirm with each
// provider's docs that their responses use the Anthropic Messages shape.
pub(crate) const PROVIDER_CONFIGS: &[ProviderConfig] = &[
    // 智普LLM (Zhipu): API key sent as a Bearer token.
    ProviderConfig {
        provider_type: "zhipu",
        display_name: "智普LLM (Zhipu)",
        default_model: "GLM-4.6",
        base_url: Some("https://open.bigmodel.cn/api/anthropic"),
        auth_header_name: "Authorization",
        auth_prefix: "Bearer ",
        endpoint: "/v1/messages",
        api_version_header: None,
        capabilities: &["Chat", "Code Generation", "Chinese Support", "Multi-modal"],
    },
    // MiniMax: API key sent as a Bearer token.
    ProviderConfig {
        provider_type: "minimax",
        display_name: "MiniMax",
        default_model: "MiniMax-M2",
        base_url: Some("https://api.minimaxi.com/anthropic"),
        auth_header_name: "Authorization",
        auth_prefix: "Bearer ",
        endpoint: "/v1/messages",
        api_version_header: None,
        capabilities: &["Chat", "Code Generation", "Multi-language", "Voice Synthesis"],
    },
    // Kimi (Moonshot): API key sent as a Bearer token. Unlike the other entries,
    // this base URL ends with '/', so code joining it with `endpoint` must avoid
    // producing ".../coding//v1/messages".
    ProviderConfig {
        provider_type: "kimi",
        display_name: "Kimi (Moonshot)",
        default_model: "kimi-for-coding",
        base_url: Some("https://api.kimi.com/coding/"),
        auth_header_name: "Authorization",
        auth_prefix: "Bearer ",
        endpoint: "/v1/messages",
        api_version_header: None,
        capabilities: &["Chat", "Long Context", "Code Generation", "Document Processing"],
    },
    // Official Anthropic API: the key goes in `x-api-key` with no prefix, and
    // an `anthropic-version` header is sent as well.
    ProviderConfig {
        provider_type: "claude",
        display_name: "Claude (Official)",
        default_model: "claude-3-5-sonnet-20241022",
        base_url: Some("https://api.anthropic.com"),
        auth_header_name: "x-api-key",
        auth_prefix: "",
        endpoint: "/v1/messages",
        api_version_header: Some(("anthropic-version", "2023-06-01")),
        capabilities: &["Chat", "Code Generation", "Analysis", "Multimodal", "Reasoning"],
    },
];

/// Looks up the built-in configuration whose `provider_type` equals `provider_type`.
///
/// Matching is exact and case-sensitive; `None` means the type is not supported.
/// The result points into the static `PROVIDER_CONFIGS` table, so it is returned
/// as `'static` and does not borrow from the lookup key. Callers may therefore
/// pass a temporary `String` and keep the config afterwards.
pub(crate) fn find_provider_config(provider_type: &str) -> Option<&'static ProviderConfig> {
    PROVIDER_CONFIGS
        .iter()
        .find(|config| config.provider_type == provider_type)
}

/// Generic test function for any provider
async fn test_provider_generic(provider: &ModelProvider, config: &ProviderConfig) -> Result<()> {
    let client = Client::new();
    let start_time = Instant::now();

    // Get the model name (use default if not set)
    let model_name = provider
        .model
        .as_ref()
        .map(|s| s.as_str())
        .unwrap_or(config.default_model);

    // Build the request URL
    let base_url = provider
        .base_url
        .as_ref()
        .map(|s| s.as_str())
        .unwrap_or(config.base_url.unwrap_or(""));

    let url = format!("{}{}", base_url, config.endpoint);

    // Build request body (handle different provider formats)
    let request_body = if config.provider_type == "claude" {
        json!({
            "model": model_name,
            "max_tokens": 50,
            "messages": [
                {
                    "role": "user",
                    "content": "Hello, this is a test message. Please respond with 'Connection successful'."
                }
            ]
        })
    } else {
        json!({
            "model": model_name,
            "messages": [
                {
                    "role": "user",
                    "content": "Hello, this is a test message. Please respond with 'Connection successful'."
                }
            ],
            "max_tokens": 50,
            "temperature": 0.1
        })
    };

    // Build the request
    let auth_header_name = HeaderName::from_bytes(config.auth_header_name.as_bytes())
        .expect("Invalid header name");
    let auth_header_value = HeaderValue::from_str(&format!("{}{}", config.auth_prefix, provider.api_key))
        .expect("Invalid header value");

    let mut request_builder = client
        .post(&url)
        .header(auth_header_name, auth_header_value)
        .header("Content-Type", "application/json");

    // Add API version header if needed
    if let Some((header_name, header_value)) = config.api_version_header {
        let header_name = HeaderName::from_bytes(header_name.as_bytes())
            .expect("Invalid header name");
        let header_value = HeaderValue::from_str(header_value)
            .expect("Invalid header value");
        request_builder = request_builder.header(header_name, header_value);
    }

    // Send the request
    let response = request_builder
        .json(&request_body)
        .send()
        .await
        .context(format!("Failed to send request to {}", config.display_name))?;

    let response_time = start_time.elapsed().as_millis() as u64;

    // Handle successful response
    if response.status().is_success() {
        let response_json: Value = response.json().await
            .context(format!("Failed to parse response from {}", config.display_name))?;

        // Extract content based on provider (Claude uses different JSON path)
        let content = if config.provider_type == "claude" {
            response_json["content"][0]["text"].as_str()
        } else {
            response_json["choices"][0]["message"]["content"].as_str()
        };

        if let Some(content_str) = content {
            println!("{} {}", "✓ Connection successful!".green(), format!("({}ms)", response_time).dimmed());
            println!("  Response: {}", content_str.truncate(100));

            let model_info = ModelInfo {
                model_name: model_name.to_string(),
                provider_name: config.display_name.to_string(),
                capabilities: config.capabilities.iter().copied().map(String::from).collect(),
            };

            // Store test result (currently unused but ready for future use)
            let _test_result = ProviderTestResult {
                success: true,
                message: "Connection successful".to_string(),
                model_info: Some(model_info.clone()),
                response_time_ms: response_time,
            };

            println!("  Provider info: {} ({})", model_info.provider_name, model_info.model_name);
        } else {
            println!("{}", "✗ Invalid response format".red());
        }
    } else {
        // Handle error response
        let status = response.status();
        let error_text = response.text().await.unwrap_or_default();
        println!("{} {}", "✗ Connection failed:".red(), status);
        println!("  Error: {}", error_text);
    }

    Ok(())
}

// Provider-specific test functions (now just wrappers)

pub async fn test_zhipu_provider(provider: &ModelProvider) -> Result<()> {
    let config = find_provider_config("zhipu")
        .expect("Zhipu provider config not found");
    println!("{}", "Testing 智普LLM (Zhipu) connection...".yellow());
    test_provider_generic(provider, config).await
}

pub async fn test_minimax_provider(provider: &ModelProvider) -> Result<()> {
    let config = find_provider_config("minimax")
        .expect("MiniMax provider config not found");
    println!("{}", "Testing MiniMax connection...".cyan());
    test_provider_generic(provider, config).await
}

pub async fn test_kimi_provider(provider: &ModelProvider) -> Result<()> {
    let config = find_provider_config("kimi")
        .expect("Kimi provider config not found");
    println!("{}", "Testing Kimi connection...".magenta());
    test_provider_generic(provider, config).await
}

pub async fn test_claude_provider(provider: &ModelProvider) -> Result<()> {
    let config = find_provider_config("claude")
        .expect("Claude provider config not found");
    println!("{}", "Testing Claude (Official) connection...".white());
    test_provider_generic(provider, config).await
}

/// Returns `(display_name, default_base_url, capabilities)` for `provider_type`.
///
/// An unknown type yields `("Unknown Provider", "", &[])`. A known provider
/// without a default base URL yields `""` for the URL.
// The allow suppresses the unused-function lint: nothing in this file calls it.
#[allow(dead_code)]
pub fn get_provider_info(provider_type: &str) -> (&'static str, &'static str, &'static [&'static str]) {
    let config = match find_provider_config(provider_type) {
        Some(found) => found,
        None => return ("Unknown Provider", "", &[]),
    };
    let base_url: &'static str = config.base_url.unwrap_or_default();
    (config.display_name, base_url, config.capabilities)
}

/// Get all supported provider types
#[allow(dead_code)]
pub fn get_supported_providers() -> Vec<&'static str> {
    PROVIDER_CONFIGS.iter().map(|config| config.provider_type).collect()
}

/// Display helper for shortening text, e.g. the provider reply preview.
trait StringExt {
    /// Returns the prefix of `self` containing at most `limit` characters.
    ///
    /// Counts Unicode scalar values (`char`s), not bytes. The cut therefore always
    /// lands on a character boundary, and multi-byte text (such as Chinese
    /// replies) can never cause a slicing panic. Text already within the limit
    /// is returned unchanged.
    fn truncate(&self, limit: usize) -> &str;
}

impl StringExt for str {
    fn truncate(&self, limit: usize) -> &str {
        // `nth(limit)` is the first character past the limit; cut right before it.
        match self.char_indices().nth(limit) {
            Some((cut, _)) => &self[..cut],
            None => self,
        }
    }
}