use crate::models::{ModelProvider, ProviderTestResult, ModelInfo};
use anyhow::{Context, Result};
use reqwest::Client;
use serde_json::{json, Value};
use std::time::Instant;
use colored::*;
use reqwest::header::{HeaderName, HeaderValue};
/// Static description of a supported model provider and how to reach its chat endpoint.
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// Lookup key matched exactly by `find_provider_config` (e.g. `"zhipu"`, `"claude"`).
    pub provider_type: &'static str,
    /// Human-readable provider name used in console output and error context.
    pub display_name: &'static str,
    /// Model name sent when the user's provider entry does not specify one.
    pub default_model: &'static str,
    /// Default base URL, used when the user's provider entry has none; the request
    /// URL is formed by appending `endpoint` to it.
    pub base_url: Option<&'static str>,
    /// Name of the HTTP header that carries the API key.
    pub auth_header_name: &'static str,
    /// Text prepended to the API key in the auth header (`"Bearer "`, or empty for none).
    pub auth_prefix: &'static str,
    /// Path appended to the base URL to form the request URL.
    pub endpoint: &'static str,
    /// Optional extra `(header name, header value)` pair sent with every request.
    pub api_version_header: Option<(&'static str, &'static str)>,
    /// Capability labels reported for this provider.
    pub capabilities: &'static [&'static str],
}
/// Built-in provider table, searched front to back by `find_provider_config`.
///
/// Every entry targets a `/v1/messages` endpoint (the Anthropic Messages API path).
pub(crate) const PROVIDER_CONFIGS: &[ProviderConfig] = &[
    // Zhipu: the base URL path (`/api/anthropic`) suggests an Anthropic-compatible API;
    // the key is sent as a Bearer token.
    // NOTE(review): the company's Chinese name is commonly written 智谱; confirm whether
    // the display name below is intentional before changing this user-facing string.
    ProviderConfig {
        provider_type: "zhipu",
        display_name: "智普LLM (Zhipu)",
        default_model: "GLM-4.6",
        base_url: Some("https://open.bigmodel.cn/api/anthropic"),
        auth_header_name: "Authorization",
        auth_prefix: "Bearer ",
        endpoint: "/v1/messages",
        api_version_header: None,
        capabilities: &["Chat", "Code Generation", "Chinese Support", "Multi-modal"],
    },
    // MiniMax: Anthropic-compatible path (`/anthropic`), Bearer-token auth.
    ProviderConfig {
        provider_type: "minimax",
        display_name: "MiniMax",
        default_model: "MiniMax-M2",
        base_url: Some("https://api.minimaxi.com/anthropic"),
        auth_header_name: "Authorization",
        auth_prefix: "Bearer ",
        endpoint: "/v1/messages",
        api_version_header: None,
        capabilities: &["Chat", "Code Generation", "Multi-language", "Voice Synthesis"],
    },
    // Kimi: this base URL ends in '/' while `endpoint` starts with '/', so code that
    // joins them must avoid producing a doubled slash.
    ProviderConfig {
        provider_type: "kimi",
        display_name: "Kimi (Moonshot)",
        default_model: "kimi-for-coding",
        base_url: Some("https://api.kimi.com/coding/"),
        auth_header_name: "Authorization",
        auth_prefix: "Bearer ",
        endpoint: "/v1/messages",
        api_version_header: None,
        capabilities: &["Chat", "Long Context", "Code Generation", "Document Processing"],
    },
    // Anthropic's official API: the key goes in `x-api-key` with no prefix, and an
    // `anthropic-version` header is sent as well.
    // NOTE(review): the default model is a dated snapshot ID; confirm it is still served.
    ProviderConfig {
        provider_type: "claude",
        display_name: "Claude (Official)",
        default_model: "claude-3-5-sonnet-20241022",
        base_url: Some("https://api.anthropic.com"),
        auth_header_name: "x-api-key",
        auth_prefix: "",
        endpoint: "/v1/messages",
        api_version_header: Some(("anthropic-version", "2023-06-01")),
        capabilities: &["Chat", "Code Generation", "Analysis", "Multimodal", "Reasoning"],
    },
];
pub(crate) fn find_provider_config(provider_type: &str) -> Option<&ProviderConfig> {
PROVIDER_CONFIGS.iter().find(|config| config.provider_type == provider_type)
}
async fn test_provider_generic(provider: &ModelProvider, config: &ProviderConfig) -> Result<()> {
let client = Client::new();
let start_time = Instant::now();
let model_name = provider
.model
.as_ref()
.map(|s| s.as_str())
.unwrap_or(config.default_model);
let base_url = provider
.base_url
.as_ref()
.map(|s| s.as_str())
.unwrap_or(config.base_url.unwrap_or(""));
let url = format!("{}{}", base_url, config.endpoint);
let request_body = if config.provider_type == "claude" {
json!({
"model": model_name,
"max_tokens": 50,
"messages": [
{
"role": "user",
"content": "Hello, this is a test message. Please respond with 'Connection successful'."
}
]
})
} else {
json!({
"model": model_name,
"messages": [
{
"role": "user",
"content": "Hello, this is a test message. Please respond with 'Connection successful'."
}
],
"max_tokens": 50,
"temperature": 0.1
})
};
let auth_header_name = HeaderName::from_bytes(config.auth_header_name.as_bytes())
.expect("Invalid header name");
let auth_header_value = HeaderValue::from_str(&format!("{}{}", config.auth_prefix, provider.api_key))
.expect("Invalid header value");
let mut request_builder = client
.post(&url)
.header(auth_header_name, auth_header_value)
.header("Content-Type", "application/json");
if let Some((header_name, header_value)) = config.api_version_header {
let header_name = HeaderName::from_bytes(header_name.as_bytes())
.expect("Invalid header name");
let header_value = HeaderValue::from_str(header_value)
.expect("Invalid header value");
request_builder = request_builder.header(header_name, header_value);
}
let response = request_builder
.json(&request_body)
.send()
.await
.context(format!("Failed to send request to {}", config.display_name))?;
let response_time = start_time.elapsed().as_millis() as u64;
if response.status().is_success() {
let response_json: Value = response.json().await
.context(format!("Failed to parse response from {}", config.display_name))?;
let content = if config.provider_type == "claude" {
response_json["content"][0]["text"].as_str()
} else {
response_json["choices"][0]["message"]["content"].as_str()
};
if let Some(content_str) = content {
println!("{} {}", "✓ Connection successful!".green(), format!("({}ms)", response_time).dimmed());
println!(" Response: {}", content_str.truncate(100));
let model_info = ModelInfo {
model_name: model_name.to_string(),
provider_name: config.display_name.to_string(),
capabilities: config.capabilities.iter().copied().map(String::from).collect(),
};
let _test_result = ProviderTestResult {
success: true,
message: "Connection successful".to_string(),
model_info: Some(model_info.clone()),
response_time_ms: response_time,
};
println!(" Provider info: {} ({})", model_info.provider_name, model_info.model_name);
} else {
println!("{}", "✗ Invalid response format".red());
}
} else {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
println!("{} {}", "✗ Connection failed:".red(), status);
println!(" Error: {}", error_text);
}
Ok(())
}
/// Runs the generic connectivity check against the Zhipu provider.
///
/// # Panics
///
/// Panics if the `"zhipu"` entry is missing from the built-in provider table.
pub async fn test_zhipu_provider(provider: &ModelProvider) -> Result<()> {
    let zhipu = match find_provider_config("zhipu") {
        Some(cfg) => cfg,
        None => panic!("Zhipu provider config not found"),
    };
    let banner = "Testing 智普LLM (Zhipu) connection...".yellow();
    println!("{}", banner);
    test_provider_generic(provider, zhipu).await
}
/// Runs the generic connectivity check against the MiniMax provider.
///
/// # Panics
///
/// Panics if the `"minimax"` entry is missing from the built-in provider table.
pub async fn test_minimax_provider(provider: &ModelProvider) -> Result<()> {
    let minimax = match find_provider_config("minimax") {
        Some(cfg) => cfg,
        None => panic!("MiniMax provider config not found"),
    };
    let banner = "Testing MiniMax connection...".cyan();
    println!("{}", banner);
    test_provider_generic(provider, minimax).await
}
/// Runs the generic connectivity check against the Kimi (Moonshot) provider.
///
/// # Panics
///
/// Panics if the `"kimi"` entry is missing from the built-in provider table.
pub async fn test_kimi_provider(provider: &ModelProvider) -> Result<()> {
    let kimi = match find_provider_config("kimi") {
        Some(cfg) => cfg,
        None => panic!("Kimi provider config not found"),
    };
    let banner = "Testing Kimi connection...".magenta();
    println!("{}", banner);
    test_provider_generic(provider, kimi).await
}
/// Runs the generic connectivity check against Anthropic's official Claude API.
///
/// # Panics
///
/// Panics if the `"claude"` entry is missing from the built-in provider table.
pub async fn test_claude_provider(provider: &ModelProvider) -> Result<()> {
    let claude = match find_provider_config("claude") {
        Some(cfg) => cfg,
        None => panic!("Claude provider config not found"),
    };
    let banner = "Testing Claude (Official) connection...".white();
    println!("{}", banner);
    test_provider_generic(provider, claude).await
}
/// Returns `(display_name, default_base_url, capabilities)` for `provider_type`.
///
/// A provider without a default base URL yields `""` in that slot; an unknown
/// provider yields `("Unknown Provider", "", &[])`.
// allow(dead_code): no caller in this file; presumably kept as public API.
#[allow(dead_code)]
pub fn get_provider_info(provider_type: &str) -> (&'static str, &'static str, &'static [&'static str]) {
    if let Some(cfg) = find_provider_config(provider_type) {
        let default_url = cfg.base_url.unwrap_or("");
        (cfg.display_name, default_url, cfg.capabilities)
    } else {
        ("Unknown Provider", "", &[])
    }
}
#[allow(dead_code)]
pub fn get_supported_providers() -> Vec<&'static str> {
PROVIDER_CONFIGS.iter().map(|config| config.provider_type).collect()
}
/// Display helpers for string slices.
trait StringExt {
    /// Returns the prefix of `self` containing at most `limit` characters.
    ///
    /// Truncation always lands on a `char` boundary, so multi-byte UTF-8 text (such as
    /// Chinese responses) is never split mid-character and never panics.
    fn truncate(&self, limit: usize) -> &str;
}
impl StringExt for str {
    fn truncate(&self, limit: usize) -> &str {
        // `char_indices().nth(limit)` gives the byte offset of the first character past
        // the limit; if there is none, the whole string already fits.
        match self.char_indices().nth(limit) {
            Some((end, _)) => &self[..end],
            None => self,
        }
    }
}