use crate::core::{GenericProvider, HttpClient, Protocol};
use crate::error::LlmConnectorError;
use crate::types::{ChatRequest, ChatResponse, Role};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Protocol adapter for the Zhipu (BigModel) chat completions API.
///
/// Holds the API key used for bearer authentication and a flag recording
/// whether the value was created in "OpenAI-compatible" mode.
#[derive(Clone)]
pub struct ZhipuProtocol {
    /// Secret API key, sent as `Authorization: Bearer <key>`.
    api_key: String,
    /// `true` when constructed through `ZhipuProtocol::new_openai_compatible`.
    use_openai_format: bool,
}

// `Debug` is implemented by hand instead of derived so that the secret API key
// never ends up in logs or panic messages that format this value with `{:?}`.
impl std::fmt::Debug for ZhipuProtocol {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ZhipuProtocol")
            .field("api_key", &format_args!("<redacted>"))
            .field("use_openai_format", &self.use_openai_format)
            .finish()
    }
}
impl ZhipuProtocol {
    /// Builds a protocol in native (non-OpenAI-compatible) mode.
    pub fn new(api_key: &str) -> Self {
        let use_openai_format = false;
        Self {
            api_key: api_key.to_owned(),
            use_openai_format,
        }
    }

    /// Builds a protocol flagged as OpenAI-compatible.
    ///
    /// NOTE(review): within this file the flag is only reported back through
    /// `is_openai_compatible`; request building does not consult it.
    pub fn new_openai_compatible(api_key: &str) -> Self {
        let use_openai_format = true;
        Self {
            api_key: String::from(api_key),
            use_openai_format,
        }
    }

    /// Borrows the API key this protocol authenticates with.
    pub fn api_key(&self) -> &str {
        self.api_key.as_str()
    }

    /// Reports whether this value was built via `new_openai_compatible`.
    pub fn is_openai_compatible(&self) -> bool {
        self.use_openai_format
    }
}
impl Protocol for ZhipuProtocol {
    type Request = ZhipuRequest;
    type Response = ZhipuResponse;

    /// Provider identifier reported by this protocol.
    fn name(&self) -> &str {
        "zhipu"
    }

    /// Builds the chat completions URL under `base_url`.
    ///
    /// Trailing slashes on `base_url` are stripped so that a base such as
    /// `https://open.bigmodel.cn/` does not produce a `//api/...` path.
    fn chat_endpoint(&self, base_url: &str) -> String {
        format!(
            "{}/api/paas/v4/chat/completions",
            base_url.trim_end_matches('/')
        )
    }

    /// Bearer-token authentication plus JSON content-type headers.
    fn auth_headers(&self) -> Vec<(String, String)> {
        vec![
            ("Authorization".to_string(), format!("Bearer {}", self.api_key)),
            ("Content-Type".to_string(), "application/json".to_string()),
        ]
    }

    /// Converts a generic `ChatRequest` into Zhipu's wire format.
    ///
    /// Each role maps onto its lowercase Zhipu role string. Message content,
    /// model and sampling parameters are copied through unchanged.
    ///
    /// # Errors
    /// Never fails at present; the `Result` is required by the `Protocol` trait.
    fn build_request(&self, request: &ChatRequest) -> Result<Self::Request, LlmConnectorError> {
        let messages: Vec<ZhipuMessage> = request
            .messages
            .iter()
            .map(|msg| ZhipuMessage {
                role: match msg.role {
                    Role::System => "system",
                    Role::User => "user",
                    Role::Assistant => "assistant",
                    Role::Tool => "tool",
                }
                .to_string(),
                content: msg.content.clone(),
            })
            .collect();

        Ok(ZhipuRequest {
            model: request.model.clone(),
            messages,
            max_tokens: request.max_tokens,
            temperature: request.temperature,
            top_p: request.top_p,
        })
    }

    /// Parses a raw JSON response body into a `ChatResponse`.
    ///
    /// Only the first choice is used. A missing `model` is reported as
    /// `"unknown"`. All other `ChatResponse` fields take their defaults.
    ///
    /// # Errors
    /// Returns `LlmConnectorError::InvalidRequest` when the body is not valid
    /// JSON for `ZhipuResponse`, or when `choices` is missing or empty.
    fn parse_response(&self, response: &str) -> Result<ChatResponse, LlmConnectorError> {
        // NOTE(review): a malformed *response* is reported as `InvalidRequest`;
        // consider a dedicated parse/response variant if `LlmConnectorError` has one.
        let parsed: ZhipuResponse = serde_json::from_str(response).map_err(|e| {
            LlmConnectorError::InvalidRequest(format!("Failed to parse response: {}", e))
        })?;

        // Take ownership of the first choice rather than cloning its content.
        let first_choice = parsed
            .choices
            .and_then(|choices| choices.into_iter().next());

        match first_choice {
            Some(choice) => Ok(ChatResponse {
                content: choice.message.content,
                model: parsed.model.unwrap_or_else(|| "unknown".to_string()),
                ..Default::default()
            }),
            None => Err(LlmConnectorError::InvalidRequest(
                "Empty or invalid response".to_string(),
            )),
        }
    }

    /// Maps a non-success HTTP status and body onto an `LlmConnectorError`.
    fn map_error(&self, status: u16, body: &str) -> LlmConnectorError {
        LlmConnectorError::from_status_code(status, format!("Zhipu API error: {}", body))
    }
}
/// Request body serialized and sent to Zhipu's chat completions endpoint.
///
/// Fields are serialized in declaration order. Sampling parameters that are
/// `None` are omitted from the JSON entirely (`skip_serializing_if`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZhipuRequest {
/// Model identifier, copied from the generic request's `model`.
pub model: String,
/// Conversation history in Zhipu's `{role, content}` message shape.
pub messages: Vec<ZhipuMessage>,
/// Optional cap on generated tokens; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// Optional sampling temperature; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// Optional nucleus-sampling threshold; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
}
/// A single chat message in Zhipu's wire format.
///
/// Used both for outgoing request messages and for the `message` of each
/// response choice.
///
/// NOTE(review): `content` is a non-optional `String`, so a response whose
/// `content` is JSON `null` will fail to deserialize. Confirm whether Zhipu can
/// return `null` content (e.g. for tool-call responses).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZhipuMessage {
/// Role string: "system", "user", "assistant" or "tool" when built by this module.
pub role: String,
/// Message text.
pub content: String,
}
/// The subset of a chat completions response that this module reads.
///
/// Unknown JSON fields are ignored, since there is no `deny_unknown_fields`.
/// Both fields are optional: the response parser falls back to "unknown" for a
/// missing `model`, and reports an error for missing or empty `choices`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZhipuResponse {
/// Model that produced the completion, if reported.
pub model: Option<String>,
/// Completion candidates; only the first one is used.
pub choices: Option<Vec<ZhipuChoice>>,
}
/// One entry of the response `choices` array.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZhipuChoice {
/// The generated message for this choice.
pub message: ZhipuMessage,
}
/// Ready-to-use Zhipu provider: the generic provider driven by `ZhipuProtocol`.
pub type ZhipuProvider = GenericProvider<ZhipuProtocol>;
/// Creates a Zhipu provider in native mode using the default base URL,
/// default timeout and no proxy.
///
/// # Errors
/// Propagates any error from building the underlying HTTP client.
pub fn zhipu(api_key: &str) -> Result<ZhipuProvider, LlmConnectorError> {
    let (base_url, timeout_secs, proxy) = (None, None, None);
    zhipu_with_config(api_key, false, base_url, timeout_secs, proxy)
}
/// Creates a Zhipu provider flagged as OpenAI-compatible, using the default
/// base URL, default timeout and no proxy.
///
/// # Errors
/// Propagates any error from building the underlying HTTP client.
pub fn zhipu_openai_compatible(api_key: &str) -> Result<ZhipuProvider, LlmConnectorError> {
    let openai_compatible = true;
    zhipu_with_config(api_key, openai_compatible, None, None, None)
}
/// Creates a fully configured Zhipu provider.
///
/// * `api_key` - key installed as a bearer token on the HTTP client.
/// * `openai_compatible` - selects `ZhipuProtocol::new_openai_compatible`
///   instead of `ZhipuProtocol::new`.
/// * `base_url` - overrides the default `https://open.bigmodel.cn`.
/// * `timeout_secs` / `proxy` - forwarded unchanged to `HttpClient::with_config`.
///
/// # Errors
/// Propagates any error from `HttpClient::with_config`.
pub fn zhipu_with_config(
    api_key: &str,
    openai_compatible: bool,
    base_url: Option<&str>,
    timeout_secs: Option<u64>,
    proxy: Option<&str>,
) -> Result<ZhipuProvider, LlmConnectorError> {
    const DEFAULT_BASE_URL: &str = "https://open.bigmodel.cn";

    let protocol = match openai_compatible {
        true => ZhipuProtocol::new_openai_compatible(api_key),
        false => ZhipuProtocol::new(api_key),
    };

    let http = HttpClient::with_config(
        base_url.unwrap_or(DEFAULT_BASE_URL),
        timeout_secs,
        proxy,
    )?;

    // Install the protocol's auth headers as default headers on the client.
    let default_headers = protocol
        .auth_headers()
        .into_iter()
        .collect::<HashMap<String, String>>();

    Ok(GenericProvider::new(protocol, http.with_headers(default_headers)))
}
pub fn zhipu_default(api_key: &str) -> Result<ZhipuProvider, LlmConnectorError> {
zhipu_openai_compatible(api_key)
}
/// Creates an OpenAI-compatible Zhipu provider with a custom request timeout
/// (in seconds), the default base URL and no proxy.
///
/// # Errors
/// Propagates any error from building the underlying HTTP client.
pub fn zhipu_with_timeout(
    api_key: &str,
    timeout_secs: u64,
) -> Result<ZhipuProvider, LlmConnectorError> {
    let timeout = Some(timeout_secs);
    zhipu_with_config(api_key, true, None, timeout, None)
}
/// Creates an OpenAI-compatible Zhipu provider pointed at a custom
/// (e.g. enterprise/private) base URL, with the default timeout and no proxy.
///
/// # Errors
/// Propagates any error from building the underlying HTTP client.
pub fn zhipu_enterprise(
    api_key: &str,
    enterprise_endpoint: &str,
) -> Result<ZhipuProvider, LlmConnectorError> {
    let base_url = Some(enterprise_endpoint);
    zhipu_with_config(api_key, true, base_url, None, None)
}
/// Performs a cheap, offline sanity check on a Zhipu API key.
///
/// Returns `true` when the key is longer than 10 bytes and contains no
/// whitespace. Whitespace is rejected because a bearer token cannot contain it
/// (RFC 6750 `b64token`). A stray space or newline usually means the key was
/// copied carelessly out of a file or environment variable.
///
/// This does not contact the API; `true` does not mean Zhipu will accept the key.
pub fn validate_zhipu_key(api_key: &str) -> bool {
    // `len() > 10` already rules out the empty string.
    api_key.len() > 10 && !api_key.chars().any(char::is_whitespace)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zhipu_provider_creation() {
        let provider = zhipu("test-key");
        assert!(provider.is_ok());
        let provider = provider.unwrap();
        assert_eq!(provider.protocol().name(), "zhipu");
    }

    #[test]
    fn test_zhipu_openai_compatible() {
        let provider = zhipu_openai_compatible("test-key");
        assert!(provider.is_ok());
        let provider = provider.unwrap();
        assert_eq!(provider.protocol().name(), "zhipu");
        assert!(provider.protocol().is_openai_compatible());
    }

    #[test]
    fn test_zhipu_with_config() {
        let provider = zhipu_with_config(
            "test-key",
            true,
            Some("https://custom.bigmodel.cn"),
            Some(60),
            None,
        );
        assert!(provider.is_ok());
        let provider = provider.unwrap();
        assert_eq!(provider.client().base_url(), "https://custom.bigmodel.cn");
        assert!(provider.protocol().is_openai_compatible());
    }

    #[test]
    fn test_zhipu_default() {
        let provider = zhipu_default("test-key");
        assert!(provider.is_ok());
        let provider = provider.unwrap();
        assert!(provider.protocol().is_openai_compatible());
    }

    #[test]
    fn test_zhipu_with_timeout() {
        let provider = zhipu_with_timeout("test-key", 120);
        assert!(provider.is_ok());
    }

    #[test]
    fn test_zhipu_enterprise() {
        let provider = zhipu_enterprise("test-key", "https://enterprise.bigmodel.cn");
        assert!(provider.is_ok());
        let provider = provider.unwrap();
        assert_eq!(provider.client().base_url(), "https://enterprise.bigmodel.cn");
    }

    #[test]
    fn test_validate_zhipu_key() {
        assert!(validate_zhipu_key("valid-test-key"));
        assert!(validate_zhipu_key("another-valid-key-12345"));
        assert!(!validate_zhipu_key("short"));
        assert!(!validate_zhipu_key(""));
    }

    #[test]
    fn test_chat_endpoint() {
        let protocol = ZhipuProtocol::new("test-key");
        assert_eq!(
            protocol.chat_endpoint("https://open.bigmodel.cn"),
            "https://open.bigmodel.cn/api/paas/v4/chat/completions"
        );
    }

    #[test]
    fn test_auth_headers() {
        let headers = ZhipuProtocol::new("test-key").auth_headers();
        assert!(headers.contains(&("Authorization".to_string(), "Bearer test-key".to_string())));
        assert!(headers.contains(&("Content-Type".to_string(), "application/json".to_string())));
    }

    #[test]
    fn test_parse_response_uses_first_choice() {
        let protocol = ZhipuProtocol::new("test-key");
        let body = r#"{"model":"glm-4","choices":[{"message":{"role":"assistant","content":"hello"}},{"message":{"role":"assistant","content":"ignored"}}]}"#;
        let response = match protocol.parse_response(body) {
            Ok(response) => response,
            Err(_) => panic!("a well-formed response should parse"),
        };
        assert_eq!(response.content, "hello");
        assert_eq!(response.model, "glm-4");
    }

    #[test]
    fn test_parse_response_missing_model_is_unknown() {
        let protocol = ZhipuProtocol::new("test-key");
        let body = r#"{"choices":[{"message":{"role":"assistant","content":"hi"}}]}"#;
        let response = match protocol.parse_response(body) {
            Ok(response) => response,
            Err(_) => panic!("a response without `model` should still parse"),
        };
        assert_eq!(response.content, "hi");
        assert_eq!(response.model, "unknown");
    }

    #[test]
    fn test_parse_response_errors() {
        let protocol = ZhipuProtocol::new("test-key");
        assert!(protocol.parse_response("not json").is_err());
        assert!(protocol.parse_response(r#"{"model":"glm-4"}"#).is_err());
        assert!(protocol.parse_response(r#"{"choices":[]}"#).is_err());
    }
}