llm_connector/core/traits.rs

//! Unified trait definitions - the core of the V2 architecture.
//!
//! This module defines the core traits of the V2 architecture, providing a
//! clear, unified abstraction layer.

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::any::Any;

// Reuse existing types to preserve compatibility.
use crate::types::{ChatRequest, ChatResponse};
use crate::error::LlmConnectorError;

#[cfg(feature = "streaming")]
use crate::types::ChatStream;

/// Protocol trait - defines a pure API specification.
///
/// This trait represents the protocol specification of an LLM API, such as the
/// OpenAI API or the Anthropic API. It is only concerned with converting
/// between API formats; it does not perform any network communication itself.
#[async_trait]
pub trait Protocol: Send + Sync + Clone + 'static {
    /// Protocol-specific request type.
    type Request: Serialize + Send + Sync;

    /// Protocol-specific response type.
    type Response: for<'de> Deserialize<'de> + Send + Sync;

    /// Protocol name (e.g. "openai", "anthropic").
    fn name(&self) -> &str;

    /// Endpoint URL for chat completions.
    fn chat_endpoint(&self, base_url: &str) -> String;

    /// Endpoint URL for listing models (optional).
    fn models_endpoint(&self, _base_url: &str) -> Option<String> {
        None
    }

    /// Build a protocol-specific request.
    fn build_request(&self, request: &ChatRequest) -> Result<Self::Request, LlmConnectorError>;

    /// Parse a protocol-specific response.
    fn parse_response(&self, response: &str) -> Result<ChatResponse, LlmConnectorError>;

    /// Parse a model-list response.
    fn parse_models(&self, _response: &str) -> Result<Vec<String>, LlmConnectorError> {
        Err(LlmConnectorError::UnsupportedOperation(
            format!("{} does not support model listing", self.name())
        ))
    }

    /// Map an HTTP error to the unified error type.
    fn map_error(&self, status: u16, body: &str) -> LlmConnectorError;

    /// Authentication headers (optional).
    fn auth_headers(&self) -> Vec<(String, String)> {
        Vec::new()
    }

    /// Parse a streaming response (optional).
    #[cfg(feature = "streaming")]
    async fn parse_stream_response(&self, response: reqwest::Response) -> Result<ChatStream, LlmConnectorError> {
        // Default to generic SSE stream parsing.
        Ok(crate::sse::sse_to_streaming_response(response))
    }
}

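// ---------------------------------------------------------------------------
// Example: a minimal sketch of a `Protocol` implementation for an
// OpenAI-style API. `SketchProtocol`, its fields, the endpoint path, and the
// use of `serde_json` are illustrative assumptions, not part of this crate's
// actual protocol set; a real implementation mirrors the target API's schema.
#[derive(Clone)]
#[allow(dead_code)]
struct SketchProtocol {
    api_key: String,
    default_model: String,
}

/// Wire format of the sketch request (illustrative fields only).
#[derive(Serialize)]
struct SketchRequest {
    model: String,
    stream: Option<bool>,
}

/// Wire format of the sketch response; real fields would mirror the API reply.
#[derive(Deserialize)]
#[allow(dead_code)]
struct SketchResponse {}

impl Protocol for SketchProtocol {
    type Request = SketchRequest;
    type Response = SketchResponse;

    fn name(&self) -> &str {
        "sketch"
    }

    fn chat_endpoint(&self, base_url: &str) -> String {
        // Illustrative OpenAI-style path; adjust for the target API.
        format!("{}/v1/chat/completions", base_url.trim_end_matches('/'))
    }

    fn build_request(&self, request: &ChatRequest) -> Result<Self::Request, LlmConnectorError> {
        // Only `request.stream` is mapped here; a real protocol would also
        // convert messages, temperature, and the other `ChatRequest` fields.
        Ok(SketchRequest {
            model: self.default_model.clone(),
            stream: request.stream,
        })
    }

    fn parse_response(&self, response: &str) -> Result<ChatResponse, LlmConnectorError> {
        // Assumption: `ChatResponse` implements `Deserialize`; if it does not,
        // deserialize into `SketchResponse` and convert field by field.
        serde_json::from_str(response).map_err(|e| {
            // `NetworkError` is reused for brevity; the crate may well define
            // a more precise parse-error variant.
            LlmConnectorError::NetworkError(format!("invalid response body: {e}"))
        })
    }

    fn map_error(&self, status: u16, body: &str) -> LlmConnectorError {
        // Sketch: collapse all HTTP failures into one variant; a real protocol
        // distinguishes auth, rate-limit, and validation errors by status code.
        LlmConnectorError::NetworkError(format!("HTTP {status}: {body}"))
    }

    fn auth_headers(&self) -> Vec<(String, String)> {
        vec![("Authorization".to_string(), format!("Bearer {}", self.api_key))]
    }
}
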
/// Provider trait - defines the unified service interface.
///
/// This trait represents a concrete LLM service provider and exposes the full
/// set of service features. It is the interface users interact with directly.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Provider name (e.g. "openai", "aliyun", "ollama").
    fn name(&self) -> &str;

    /// Chat completion.
    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, LlmConnectorError>;

    /// Streaming chat completion.
    #[cfg(feature = "streaming")]
    async fn chat_stream(&self, request: &ChatRequest) -> Result<ChatStream, LlmConnectorError>;

    /// List the available models.
    async fn models(&self) -> Result<Vec<String>, LlmConnectorError>;

    /// Downcasting support (used to access provider-specific features).
    fn as_any(&self) -> &dyn Any;
}

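// Sketch: recovering the concrete type behind a `&dyn Provider` via `as_any`,
// e.g. to reach features a specific provider exposes beyond this trait.
// `SketchProtocol` is the illustrative protocol defined above.
#[allow(dead_code)]
fn downcast_sketch(provider: &dyn Provider) -> Option<&GenericProvider<SketchProtocol>> {
    provider.as_any().downcast_ref::<GenericProvider<SketchProtocol>>()
}
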
/// Generic provider implementation.
///
/// This struct provides a generic implementation for most standard LLM APIs.
/// It uses a `Protocol` implementation to handle API-specific format
/// conversion and an `HttpClient` to handle network communication.
pub struct GenericProvider<P: Protocol> {
    protocol: P,
    client: super::HttpClient,
}

impl<P: Protocol> GenericProvider<P> {
    /// Create a new generic provider.
    pub fn new(protocol: P, client: super::HttpClient) -> Self {
        Self { protocol, client }
    }

    /// Get a reference to the protocol.
    pub fn protocol(&self) -> &P {
        &self.protocol
    }

    /// Get a reference to the HTTP client.
    pub fn client(&self) -> &super::HttpClient {
        &self.client
    }
}

impl<P: Protocol> Clone for GenericProvider<P> {
    fn clone(&self) -> Self {
        Self {
            protocol: self.protocol.clone(),
            client: self.client.clone(),
        }
    }
}

#[async_trait]
impl<P: Protocol> Provider for GenericProvider<P> {
    fn name(&self) -> &str {
        self.protocol.name()
    }

    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, LlmConnectorError> {
        // Build the protocol-specific request.
        let protocol_request = self.protocol.build_request(request)?;

        // Resolve the endpoint URL.
        let url = self.protocol.chat_endpoint(self.client.base_url());

        // Send the request.
        let response = self.client.post(&url, &protocol_request).await?;
        let status = response.status();
        let text = response.text().await
            .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;

        // Check the HTTP status.
        if !status.is_success() {
            return Err(self.protocol.map_error(status.as_u16(), &text));
        }

        // Parse the response.
        self.protocol.parse_response(&text)
    }

    #[cfg(feature = "streaming")]
    async fn chat_stream(&self, request: &ChatRequest) -> Result<ChatStream, LlmConnectorError> {
        // Force streaming on, regardless of the caller's setting.
        let mut streaming_request = request.clone();
        streaming_request.stream = Some(true);

        let protocol_request = self.protocol.build_request(&streaming_request)?;
        let url = self.protocol.chat_endpoint(self.client.base_url());

        let response = self.client.stream(&url, &protocol_request).await?;
        let status = response.status();

        if !status.is_success() {
            let text = response.text().await
                .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;
            return Err(self.protocol.map_error(status.as_u16(), &text));
        }

        self.protocol.parse_stream_response(response).await
    }

    async fn models(&self) -> Result<Vec<String>, LlmConnectorError> {
        let endpoint = self.protocol.models_endpoint(self.client.base_url())
            .ok_or_else(|| LlmConnectorError::UnsupportedOperation(
                format!("{} does not support model listing", self.protocol.name())
            ))?;

        let response = self.client.get(&endpoint).await?;
        let status = response.status();
        let text = response.text().await
            .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;

        if !status.is_success() {
            return Err(self.protocol.map_error(status.as_u16(), &text));
        }

        self.protocol.parse_models(&text)
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
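
// Sketch: wiring a protocol and an HTTP client into a working provider.
// The construction of `super::HttpClient` is deliberately elided because its
// constructor isn't shown in this module, and `ChatRequest::default()` assumes
// a `Default` impl; both are illustrative, not confirmed API.
#[allow(dead_code)]
async fn chat_once_sketch(client: super::HttpClient) -> Result<ChatResponse, LlmConnectorError> {
    let protocol = SketchProtocol {
        api_key: "sk-...".to_string(),              // placeholder credential
        default_model: "example-model".to_string(), // illustrative model name
    };
    let provider = GenericProvider::new(protocol, client);

    let request = ChatRequest::default(); // assumption: `Default` is derived
    provider.chat(&request).await
}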