// llm_connector/providers/aliyun.rs

//! Aliyun DashScope provider implementation - V2 architecture
//!
//! This module provides the complete implementation of the Aliyun DashScope service, using the unified V2 architecture.
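//!
//! A minimal usage sketch follows. The `types` and `core` re-export paths and the
//! `Provider` trait import are assumptions about this crate's public API, so the
//! example is marked `ignore`; adjust the paths to match the actual re-exports.
//!
//! ```rust,ignore
//! use llm_connector::providers::aliyun;
//! use llm_connector::types::{ChatRequest, Message, Role};
//! use llm_connector::core::Provider;
//!
//! async fn example() {
//!     let provider = aliyun("sk-...").unwrap();
//!     let request = ChatRequest {
//!         model: "qwen-plus".to_string(),
//!         messages: vec![Message::text(Role::User, "Hello")],
//!         ..Default::default()
//!     };
//!     // `chat` is provided by the `Provider` trait implemented below.
//!     let response = provider.chat(&request).await.unwrap();
//!     println!("{}", response.content);
//! }
//! ```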

use crate::core::{HttpClient, Protocol};
use crate::error::LlmConnectorError;
use crate::types::{ChatRequest, ChatResponse, Role};

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use async_trait::async_trait;

// ============================================================================
// Aliyun Protocol Definition (Private)
// ============================================================================

/// Aliyun DashScope private protocol implementation
///
/// This is Aliyun's own API format, different from both OpenAI's and Anthropic's.
/// Because it is a private protocol, it is defined inside the provider rather than in the public protocols module.
#[derive(Debug, Clone)]
pub struct AliyunProtocol {
    api_key: String,
}

impl AliyunProtocol {
    /// Create a new Aliyun protocol instance
    pub fn new(api_key: &str) -> Self {
        Self {
            api_key: api_key.to_string(),
        }
    }

    /// Get the API key
    pub fn api_key(&self) -> &str {
        &self.api_key
    }

    /// Get the extra headers required for streaming requests
    pub fn streaming_headers(&self) -> Vec<(String, String)> {
        vec![
            ("X-DashScope-SSE".to_string(), "enable".to_string()),
        ]
    }
}

#[async_trait]
impl Protocol for AliyunProtocol {
    type Request = AliyunRequest;
    type Response = AliyunResponse;

    fn name(&self) -> &str {
        "aliyun"
    }

    fn chat_endpoint(&self, base_url: &str) -> String {
        format!("{}/api/v1/services/aigc/text-generation/generation", base_url)
    }

    fn auth_headers(&self) -> Vec<(String, String)> {
        vec![
            ("Authorization".to_string(), format!("Bearer {}", self.api_key)),
            // Note: Content-Type is set automatically by the .json() call in HttpClient::post().
            // Do not set it again here, or the API rejects the request with "Content-Type application/json,application/json is not supported".
        ]
    }

    fn build_request(&self, request: &ChatRequest) -> Result<Self::Request, LlmConnectorError> {
        // Convert to the Aliyun format
        let aliyun_messages: Vec<AliyunMessage> = request.messages.iter().map(|msg| {
            AliyunMessage {
                role: match msg.role {
                    Role::System => "system".to_string(),
                    Role::User => "user".to_string(),
                    Role::Assistant => "assistant".to_string(),
                    Role::Tool => "tool".to_string(),
                },
                // Aliyun uses a plain-text content format
                content: msg.content_as_text(),
            }
        }).collect();

        Ok(AliyunRequest {
            model: request.model.clone(),
            input: AliyunInput {
                messages: aliyun_messages,
            },
            parameters: AliyunParameters {
                max_tokens: request.max_tokens,
                temperature: request.temperature,
                top_p: request.top_p,
                result_format: "message".to_string(),
                // Streaming mode requires incremental_output
                incremental_output: if request.stream.unwrap_or(false) {
                    Some(true)
                } else {
                    None
                },
                // Pass through the user-specified value unchanged
                enable_thinking: request.enable_thinking,
            },
        })
    }


    #[cfg(feature = "streaming")]
    async fn parse_stream_response(&self, response: reqwest::Response) -> Result<crate::types::ChatStream, LlmConnectorError> {
        use futures_util::StreamExt;
        use crate::types::{StreamingResponse, StreamingChoice, Delta};

        let stream = response.bytes_stream();
        let mut lines_buffer = String::new();

        let mapped_stream = stream.map(move |result| {
            match result {
                Ok(bytes) => {
                    let text = String::from_utf8_lossy(&bytes);
                    lines_buffer.push_str(&text);

                    let mut responses = Vec::new();
                    let lines: Vec<&str> = lines_buffer.lines().collect();

                    for line in &lines {
                        if line.starts_with("data:") {
                            let json_str = line.trim_start_matches("data:").trim();
                            if json_str.is_empty() {
                                continue;
                            }

                            // Parse the Aliyun response chunk
                            if let Ok(aliyun_resp) = serde_json::from_str::<AliyunResponse>(json_str) {
                                if let Some(choices) = aliyun_resp.output.choices {
                                    if let Some(first_choice) = choices.first() {
                                        // Convert to a StreamingResponse
                                        let streaming_choice = StreamingChoice {
                                            index: 0,
                                            delta: Delta {
                                                role: Some(Role::Assistant),
                                                content: if first_choice.message.content.is_empty() {
                                                    None
                                                } else {
                                                    Some(first_choice.message.content.clone())
                                                },
                                                tool_calls: None,
                                                reasoning_content: None,
                                                reasoning: None,
                                                thought: None,
                                                thinking: None,
                                            },
                                            finish_reason: if first_choice.finish_reason.as_deref() == Some("stop") {
                                                Some("stop".to_string())
                                            } else {
                                                None
                                            },
                                            logprobs: None,
                                        };

                                        let content = first_choice.message.content.clone();

                                        let streaming_response = StreamingResponse {
                                            id: aliyun_resp.request_id.clone().unwrap_or_default(),
                                            object: "chat.completion.chunk".to_string(),
                                            created: 0,
                                            model: aliyun_resp.model.clone().unwrap_or_else(|| "unknown".to_string()),
                                            choices: vec![streaming_choice],
                                            content,
                                            reasoning_content: None,
                                            usage: aliyun_resp.usage.as_ref().map(|u| crate::types::Usage {
                                                prompt_tokens: u.input_tokens,
                                                completion_tokens: u.output_tokens,
                                                total_tokens: u.total_tokens,
                                                prompt_cache_hit_tokens: None,
                                                prompt_cache_miss_tokens: None,
                                                prompt_tokens_details: None,
                                                completion_tokens_details: None,
                                            }),
                                            system_fingerprint: None,
                                        };

                                        responses.push(Ok(streaming_response));
                                    }
                                }
                            }
                        }
                    }

                    // Keep only a trailing partial line (one not yet terminated by a
                    // newline) so the next chunk can complete it; anything already
                    // terminated has been processed above.
                    if lines_buffer.ends_with('\n') {
                        lines_buffer.clear();
                    } else if let Some(last_line) = lines.last() {
                        lines_buffer = last_line.to_string();
                    } else {
                        lines_buffer.clear();
                    }

                    futures_util::stream::iter(responses)
                }
                Err(e) => {
                    futures_util::stream::iter(vec![Err(crate::error::LlmConnectorError::NetworkError(e.to_string()))])
                }
            }
        }).flatten();

        Ok(Box::pin(mapped_stream))
    }

    fn parse_response(&self, response: &str) -> Result<ChatResponse, LlmConnectorError> {
        let parsed: AliyunResponse = serde_json::from_str(response)
            .map_err(|e| LlmConnectorError::InvalidRequest(format!("Failed to parse response: {}", e)))?;

        if let Some(aliyun_choices) = parsed.output.choices {
            if let Some(first_choice) = aliyun_choices.first() {
                // Build the choices array (matching the OpenAI-style format)
                let choices = vec![crate::types::Choice {
                    index: 0,
                    message: crate::types::Message {
                        role: Role::Assistant,
                        content: vec![crate::types::MessageBlock::text(&first_choice.message.content)],
                        name: None,
                        tool_calls: None,
                        tool_call_id: None,
                        reasoning_content: None,
                        reasoning: None,
                        thought: None,
                        thinking: None,
                    },
                    finish_reason: first_choice.finish_reason.clone(),
                    logprobs: None,
                }];

                // Extract content from choices[0] as a convenience field
                let content = first_choice.message.content.clone();

                // Extract usage information
                let usage = parsed.usage.map(|u| crate::types::Usage {
                    prompt_tokens: u.input_tokens,
                    completion_tokens: u.output_tokens,
                    total_tokens: u.total_tokens,
                    prompt_cache_hit_tokens: None,
                    prompt_cache_miss_tokens: None,
                    prompt_tokens_details: None,
                    completion_tokens_details: None,
                });

                return Ok(ChatResponse {
                    id: parsed.request_id.unwrap_or_default(),
                    object: "chat.completion".to_string(),
                    created: 0,  // Aliyun does not provide a created timestamp
                    model: parsed.model.unwrap_or_else(|| "unknown".to_string()),
                    choices,
                    content,
                    reasoning_content: None,
                    usage,
                    system_fingerprint: None,
                });
            }
        }

        Err(LlmConnectorError::InvalidRequest("Empty or invalid response".to_string()))
    }

    fn map_error(&self, status: u16, body: &str) -> LlmConnectorError {
        LlmConnectorError::from_status_code(status, format!("Aliyun API error: {}", body))
    }
}

// Aliyun-specific data structures
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AliyunRequest {
    pub model: String,
    pub input: AliyunInput,
    pub parameters: AliyunParameters,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AliyunInput {
    pub messages: Vec<AliyunMessage>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AliyunMessage {
    pub role: String,
    pub content: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AliyunParameters {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    pub result_format: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub incremental_output: Option<bool>,

    /// Enable thinking/reasoning mode for hybrid models
    ///
    /// When enabled, hybrid models like qwen-plus will return reasoning content
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enable_thinking: Option<bool>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AliyunResponse {
    pub model: Option<String>,
    pub output: AliyunOutput,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<AliyunUsage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AliyunOutput {
    pub choices: Option<Vec<AliyunChoice>>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AliyunChoice {
    pub message: AliyunMessage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AliyunUsage {
    pub input_tokens: u32,
    pub output_tokens: u32,
    pub total_tokens: u32,
}

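// For reference, a DashScope response in `result_format: "message"` mode deserializes
// into `AliyunResponse` along these lines (values are illustrative, not taken from a
// real API call; fields follow the struct definitions above):
//
// {
//   "request_id": "...",
//   "output": {
//     "choices": [
//       { "message": { "role": "assistant", "content": "..." }, "finish_reason": "stop" }
//     ]
//   },
//   "usage": { "input_tokens": 10, "output_tokens": 5, "total_tokens": 15 }
// }
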
// ============================================================================
// Custom Aliyun Provider Implementation
// ============================================================================

/// Custom Aliyun provider implementation
///
/// Streaming requests need special handling here because Aliyun requires the X-DashScope-SSE header.
pub struct AliyunProviderImpl {
    protocol: AliyunProtocol,
    client: HttpClient,
}

impl AliyunProviderImpl {
    /// Get a reference to the protocol instance
    pub fn protocol(&self) -> &AliyunProtocol {
        &self.protocol
    }

    /// Get a reference to the HTTP client
    pub fn client(&self) -> &HttpClient {
        &self.client
    }
}

#[async_trait]
impl crate::core::Provider for AliyunProviderImpl {
    fn name(&self) -> &str {
        "aliyun"
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, LlmConnectorError> {
        // Use the standard implementation
        let protocol_request = self.protocol.build_request(request)?;
        let url = self.protocol.chat_endpoint(self.client.base_url());

        let response = self.client.post(&url, &protocol_request).await?;
        let status = response.status();

        if !status.is_success() {
            let text = response.text().await
                .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;
            return Err(self.protocol.map_error(status.as_u16(), &text));
        }

        let text = response.text().await
            .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;

        self.protocol.parse_response(&text)
    }

    #[cfg(feature = "streaming")]
    async fn chat_stream(&self, request: &ChatRequest) -> Result<crate::types::ChatStream, LlmConnectorError> {
        let mut streaming_request = request.clone();
        streaming_request.stream = Some(true);

        let protocol_request = self.protocol.build_request(&streaming_request)?;
        let url = self.protocol.chat_endpoint(self.client.base_url());

        // Create a temporary client with the streaming headers added
        let streaming_headers: HashMap<String, String> = self.protocol.streaming_headers().into_iter().collect();
        let streaming_client = self.client.clone().with_headers(streaming_headers);

        let response = streaming_client.stream(&url, &protocol_request).await?;
        let status = response.status();

        if !status.is_success() {
            let text = response.text().await
                .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;
            return Err(self.protocol.map_error(status.as_u16(), &text));
        }

        self.protocol.parse_stream_response(response).await
    }

    async fn models(&self) -> Result<Vec<String>, LlmConnectorError> {
        Err(LlmConnectorError::UnsupportedOperation(
            "Aliyun DashScope does not support model listing".to_string()
        ))
    }
}

// ============================================================================
// Aliyun Provider Public API
// ============================================================================

/// Aliyun DashScope provider type
pub type AliyunProvider = AliyunProviderImpl;

/// Create an Aliyun DashScope provider
///
/// # Arguments
/// - `api_key`: Aliyun DashScope API key
///
/// # Returns
/// A configured Aliyun provider instance
///
/// # Example
/// ```rust,no_run
/// use llm_connector::providers::aliyun;
///
/// let provider = aliyun("sk-...").unwrap();
/// ```
pub fn aliyun(api_key: &str) -> Result<AliyunProvider, LlmConnectorError> {
    aliyun_with_config(api_key, None, None, None)
}

/// Create an Aliyun provider with custom configuration
///
/// # Arguments
/// - `api_key`: API key
/// - `base_url`: Custom base URL (optional; defaults to the official endpoint)
/// - `timeout_secs`: Timeout in seconds (optional)
/// - `proxy`: Proxy URL (optional)
///
/// # Example
/// ```rust,no_run
/// use llm_connector::providers::aliyun_with_config;
///
/// let provider = aliyun_with_config(
///     "sk-...",
///     None, // use the default URL
///     Some(60), // 60-second timeout
///     Some("http://proxy:8080")
/// ).unwrap();
/// ```
pub fn aliyun_with_config(
    api_key: &str,
    base_url: Option<&str>,
    timeout_secs: Option<u64>,
    proxy: Option<&str>,
) -> Result<AliyunProvider, LlmConnectorError> {
    // Create the protocol instance
    let protocol = AliyunProtocol::new(api_key);

    // Create the HTTP client (without streaming headers)
    let client = HttpClient::with_config(
        base_url.unwrap_or("https://dashscope.aliyuncs.com"),
        timeout_secs,
        proxy,
    )?;

    // Add the authentication headers
    let auth_headers: HashMap<String, String> = protocol.auth_headers().into_iter().collect();
    let client = client.with_headers(auth_headers);

    // Create the custom Aliyun provider (streaming requests need special handling)
    Ok(AliyunProviderImpl {
        protocol,
        client,
    })
}

/// Create a provider for Aliyun International
///
/// # Arguments
/// - `api_key`: Aliyun International API key
/// - `region`: Region (e.g. "us-west-1", "ap-southeast-1")
///
/// # Example
/// ```rust,no_run
/// use llm_connector::providers::aliyun_international;
///
/// let provider = aliyun_international("sk-...", "us-west-1").unwrap();
/// ```
pub fn aliyun_international(
    api_key: &str,
    region: &str,
) -> Result<AliyunProvider, LlmConnectorError> {
    let base_url = format!("https://dashscope.{}.aliyuncs.com", region);
    aliyun_with_config(api_key, Some(&base_url), None, None)
}

/// Create a provider for an Aliyun private cloud deployment
///
/// # Arguments
/// - `api_key`: API key
/// - `endpoint`: Private cloud endpoint URL
///
/// # Example
/// ```rust,no_run
/// use llm_connector::providers::aliyun_private;
///
/// let provider = aliyun_private(
///     "sk-...",
///     "https://dashscope.your-private-cloud.com"
/// ).unwrap();
/// ```
pub fn aliyun_private(
    api_key: &str,
    endpoint: &str,
) -> Result<AliyunProvider, LlmConnectorError> {
    aliyun_with_config(api_key, Some(endpoint), None, None)
}

/// Create an Aliyun provider with a custom timeout
///
/// Some Aliyun models can take a long time to respond; this function offers a convenient way to configure the timeout.
///
/// # Arguments
/// - `api_key`: API key
/// - `timeout_secs`: Timeout in seconds
///
/// # Example
/// ```rust,no_run
/// use llm_connector::providers::aliyun_with_timeout;
///
/// // Set a 120-second timeout, suitable for long-form text processing
/// let provider = aliyun_with_timeout("sk-...", 120).unwrap();
/// ```
pub fn aliyun_with_timeout(
    api_key: &str,
    timeout_secs: u64,
) -> Result<AliyunProvider, LlmConnectorError> {
    aliyun_with_config(api_key, None, Some(timeout_secs), None)
}

/// Validate the format of an Aliyun API key
pub fn validate_aliyun_key(api_key: &str) -> bool {
    api_key.starts_with("sk-") && api_key.len() > 20
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aliyun_provider_creation() {
        let provider = aliyun("test-key");
        assert!(provider.is_ok());

        let provider = provider.unwrap();
        assert_eq!(provider.protocol().name(), "aliyun");
    }

    #[test]
    fn test_aliyun_with_config() {
        let provider = aliyun_with_config(
            "test-key",
            Some("https://custom.dashscope.com"),
            Some(60),
            None
        );
        assert!(provider.is_ok());

        let provider = provider.unwrap();
        assert_eq!(provider.client().base_url(), "https://custom.dashscope.com");
    }

    #[test]
    fn test_aliyun_international() {
        let provider = aliyun_international("test-key", "us-west-1");
        assert!(provider.is_ok());

        let provider = provider.unwrap();
        assert_eq!(provider.client().base_url(), "https://dashscope.us-west-1.aliyuncs.com");
    }

    #[test]
    fn test_aliyun_private() {
        let provider = aliyun_private("test-key", "https://private.dashscope.com");
        assert!(provider.is_ok());

        let provider = provider.unwrap();
        assert_eq!(provider.client().base_url(), "https://private.dashscope.com");
    }

    #[test]
    fn test_aliyun_with_timeout() {
        let provider = aliyun_with_timeout("test-key", 120);
        assert!(provider.is_ok());
    }
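
    // A small sketch verifying the header builders and the key-format check
    // defined above (expected values are taken directly from this file).
    #[test]
    fn test_protocol_headers() {
        let protocol = AliyunProtocol::new("test-key");

        let auth = protocol.auth_headers();
        assert_eq!(auth, vec![("Authorization".to_string(), "Bearer test-key".to_string())]);

        let streaming = protocol.streaming_headers();
        assert_eq!(streaming, vec![("X-DashScope-SSE".to_string(), "enable".to_string())]);
    }

    #[test]
    fn test_validate_aliyun_key() {
        // Format check only: the key must start with "sk-" and be longer than 20 characters.
        assert!(validate_aliyun_key("sk-0123456789abcdefghij"));
        assert!(!validate_aliyun_key("sk-short"));
        assert!(!validate_aliyun_key("0123456789abcdefghijklmn"));
    }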

    #[test]
    fn test_enable_thinking_explicit_control() {
        use crate::types::{ChatRequest, Message, Role};

        let protocol = AliyunProtocol::new("test-key");

        // Test explicit enable
        let request = ChatRequest {
            model: "qwen-plus".to_string(),
            messages: vec![Message::text(Role::User, "test")],
            enable_thinking: Some(true),  // explicitly enabled
            ..Default::default()
        };

        let aliyun_request = protocol.build_request(&request).unwrap();
        assert_eq!(aliyun_request.parameters.enable_thinking, Some(true));

        // Test explicit disable
        let request = ChatRequest {
            model: "qwen-plus".to_string(),
            messages: vec![Message::text(Role::User, "test")],
            enable_thinking: Some(false),  // explicitly disabled
            ..Default::default()
        };

        let aliyun_request = protocol.build_request(&request).unwrap();
        assert_eq!(aliyun_request.parameters.enable_thinking, Some(false));

        // Test unspecified (disabled by default)
        let request = ChatRequest {
            model: "qwen-plus".to_string(),
            messages: vec![Message::text(Role::User, "test")],
            // enable_thinking not specified
            ..Default::default()
        };

        let aliyun_request = protocol.build_request(&request).unwrap();
        assert_eq!(aliyun_request.parameters.enable_thinking, None);
    }
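
    // A sketch of the wire shape produced by build_request, based on the struct
    // definitions above: optional parameters marked skip_serializing_if are
    // omitted from the JSON when unset. Assumes content_as_text() returns the
    // plain text of a single text block.
    #[test]
    fn test_build_request_serialization_shape() {
        use crate::types::{ChatRequest, Message, Role};

        let protocol = AliyunProtocol::new("test-key");
        let request = ChatRequest {
            model: "qwen-plus".to_string(),
            messages: vec![Message::text(Role::User, "hello")],
            ..Default::default()
        };

        let aliyun_request = protocol.build_request(&request).unwrap();
        let json = serde_json::to_value(&aliyun_request).unwrap();

        assert_eq!(json["model"], "qwen-plus");
        assert_eq!(json["input"]["messages"][0]["role"], "user");
        assert_eq!(json["input"]["messages"][0]["content"], "hello");
        assert_eq!(json["parameters"]["result_format"], "message");
        // Unset optionals such as enable_thinking and incremental_output are skipped.
        assert!(json["parameters"].get("enable_thinking").is_none());
        assert!(json["parameters"].get("incremental_output").is_none());
    }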
}