llm_connector/core/
client.rs

1//! HTTP客户端实现 - V2架构
2//!
3//! 提供统一的HTTP通信层,支持标准请求和流式请求。
4
5use crate::error::LlmConnectorError;
6use reqwest::Client;
7use serde::Serialize;
8use std::collections::HashMap;
9use std::time::Duration;
10
11/// HTTP客户端
12/// 
13/// 封装了HTTP通信的所有细节,包括认证、超时、代理等配置。
14#[derive(Clone)]
15pub struct HttpClient {
16    client: Client,
17    base_url: String,
18    headers: HashMap<String, String>,
19}
20
21impl HttpClient {
22    /// 创建新的HTTP客户端
23    pub fn new(base_url: &str) -> Result<Self, LlmConnectorError> {
24        let client = Client::builder()
25            .timeout(Duration::from_secs(30))
26            .build()
27            .map_err(|e| LlmConnectorError::ConfigError(format!("Failed to create HTTP client: {}", e)))?;
28            
29        Ok(Self {
30            client,
31            base_url: base_url.trim_end_matches('/').to_string(),
32            headers: HashMap::new(),
33        })
34    }
35    
36    /// 创建带有自定义配置的HTTP客户端
37    pub fn with_config(
38        base_url: &str,
39        timeout_secs: Option<u64>,
40        proxy: Option<&str>,
41    ) -> Result<Self, LlmConnectorError> {
42        let mut builder = Client::builder();
43        
44        // 设置超时
45        if let Some(timeout) = timeout_secs {
46            builder = builder.timeout(Duration::from_secs(timeout));
47        } else {
48            builder = builder.timeout(Duration::from_secs(30));
49        }
50        
51        // 设置代理
52        if let Some(proxy_url) = proxy {
53            let proxy = reqwest::Proxy::all(proxy_url)
54                .map_err(|e| LlmConnectorError::ConfigError(format!("Invalid proxy URL: {}", e)))?;
55            builder = builder.proxy(proxy);
56        }
57        
58        let client = builder.build()
59            .map_err(|e| LlmConnectorError::ConfigError(format!("Failed to create HTTP client: {}", e)))?;
60            
61        Ok(Self {
62            client,
63            base_url: base_url.trim_end_matches('/').to_string(),
64            headers: HashMap::new(),
65        })
66    }
67    
68    /// 添加请求头
69    pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
70        self.headers.extend(headers);
71        self
72    }
73    
74    /// 添加单个请求头
75    pub fn with_header(mut self, key: String, value: String) -> Self {
76        self.headers.insert(key, value);
77        self
78    }
79    
80    /// 获取基础URL
81    pub fn base_url(&self) -> &str {
82        &self.base_url
83    }
84    
85    /// 发送GET请求
86    pub async fn get(&self, url: &str) -> Result<reqwest::Response, LlmConnectorError> {
87        let mut request = self.client.get(url);
88        
89        // 添加所有配置的请求头
90        for (key, value) in &self.headers {
91            request = request.header(key, value);
92        }
93        
94        request.send().await
95            .map_err(|e| {
96                if e.is_timeout() {
97                    LlmConnectorError::TimeoutError(format!("GET request timeout: {}", e))
98                } else if e.is_connect() {
99                    LlmConnectorError::ConnectionError(format!("GET connection failed: {}", e))
100                } else {
101                    LlmConnectorError::NetworkError(format!("GET request failed: {}", e))
102                }
103            })
104    }
105    
106    /// 发送POST请求
107    pub async fn post<T: Serialize>(
108        &self, 
109        url: &str, 
110        body: &T
111    ) -> Result<reqwest::Response, LlmConnectorError> {
112        let mut request = self.client.post(url).json(body);
113        
114        // 添加所有配置的请求头
115        for (key, value) in &self.headers {
116            request = request.header(key, value);
117        }
118        
119        request.send().await
120            .map_err(|e| {
121                if e.is_timeout() {
122                    LlmConnectorError::TimeoutError(format!("POST request timeout: {}", e))
123                } else if e.is_connect() {
124                    LlmConnectorError::ConnectionError(format!("POST connection failed: {}", e))
125                } else {
126                    LlmConnectorError::NetworkError(format!("POST request failed: {}", e))
127                }
128            })
129    }
130    
131    /// 发送流式POST请求
132    #[cfg(feature = "streaming")]
133    pub async fn stream<T: Serialize>(
134        &self,
135        url: &str,
136        body: &T,
137    ) -> Result<reqwest::Response, LlmConnectorError> {
138        let mut request = self.client.post(url).json(body);
139        
140        // 添加所有配置的请求头
141        for (key, value) in &self.headers {
142            request = request.header(key, value);
143        }
144        
145        request.send().await
146            .map_err(|e| {
147                if e.is_timeout() {
148                    LlmConnectorError::TimeoutError(format!("Stream request timeout: {}", e))
149                } else if e.is_connect() {
150                    LlmConnectorError::ConnectionError(format!("Stream connection failed: {}", e))
151                } else {
152                    LlmConnectorError::NetworkError(format!("Stream request failed: {}", e))
153                }
154            })
155    }
156    
157    /// 发送带有自定义头的POST请求
158    pub async fn post_with_custom_headers<T: Serialize>(
159        &self,
160        url: &str,
161        body: &T,
162        custom_headers: &HashMap<String, String>,
163    ) -> Result<reqwest::Response, LlmConnectorError> {
164        let mut request = self.client.post(url).json(body);
165        
166        // 先添加自定义头
167        for (key, value) in custom_headers {
168            request = request.header(key, value);
169        }
170        
171        // 再添加配置的请求头 (可能会覆盖自定义头)
172        for (key, value) in &self.headers {
173            request = request.header(key, value);
174        }
175        
176        request.send().await
177            .map_err(|e| {
178                if e.is_timeout() {
179                    LlmConnectorError::TimeoutError(format!("POST request timeout: {}", e))
180                } else if e.is_connect() {
181                    LlmConnectorError::ConnectionError(format!("POST connection failed: {}", e))
182                } else {
183                    LlmConnectorError::NetworkError(format!("POST request failed: {}", e))
184                }
185            })
186    }
187}
188
189impl std::fmt::Debug for HttpClient {
190    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191        f.debug_struct("HttpClient")
192            .field("base_url", &self.base_url)
193            .field("headers_count", &self.headers.len())
194            .finish()
195    }
196}