// llm_connector/core/client.rs

use std::collections::HashMap;
use std::time::Duration;

use reqwest::Client;
use serde::Serialize;

use crate::error::LlmConnectorError;

11#[derive(Clone)]
15pub struct HttpClient {
16 client: Client,
17 base_url: String,
18 headers: HashMap<String, String>,
19}
20
21impl HttpClient {
22 pub fn new(base_url: &str) -> Result<Self, LlmConnectorError> {
24 let client = Client::builder()
25 .timeout(Duration::from_secs(30))
26 .build()
27 .map_err(|e| LlmConnectorError::ConfigError(format!("Failed to create HTTP client: {}", e)))?;
28
29 Ok(Self {
30 client,
31 base_url: base_url.trim_end_matches('/').to_string(),
32 headers: HashMap::new(),
33 })
34 }
35
36 pub fn with_config(
38 base_url: &str,
39 timeout_secs: Option<u64>,
40 proxy: Option<&str>,
41 ) -> Result<Self, LlmConnectorError> {
42 let mut builder = Client::builder();
43
44 if let Some(timeout) = timeout_secs {
46 builder = builder.timeout(Duration::from_secs(timeout));
47 } else {
48 builder = builder.timeout(Duration::from_secs(30));
49 }
50
51 if let Some(proxy_url) = proxy {
53 let proxy = reqwest::Proxy::all(proxy_url)
54 .map_err(|e| LlmConnectorError::ConfigError(format!("Invalid proxy URL: {}", e)))?;
55 builder = builder.proxy(proxy);
56 }
57
58 let client = builder.build()
59 .map_err(|e| LlmConnectorError::ConfigError(format!("Failed to create HTTP client: {}", e)))?;
60
61 Ok(Self {
62 client,
63 base_url: base_url.trim_end_matches('/').to_string(),
64 headers: HashMap::new(),
65 })
66 }
67
68 pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
70 self.headers.extend(headers);
71 self
72 }
73
74 pub fn with_header(mut self, key: String, value: String) -> Self {
76 self.headers.insert(key, value);
77 self
78 }
79
80 pub fn base_url(&self) -> &str {
82 &self.base_url
83 }
84
85 pub async fn get(&self, url: &str) -> Result<reqwest::Response, LlmConnectorError> {
87 let mut request = self.client.get(url);
88
89 for (key, value) in &self.headers {
91 request = request.header(key, value);
92 }
93
94 request.send().await
95 .map_err(|e| {
96 if e.is_timeout() {
97 LlmConnectorError::TimeoutError(format!("GET request timeout: {}", e))
98 } else if e.is_connect() {
99 LlmConnectorError::ConnectionError(format!("GET connection failed: {}", e))
100 } else {
101 LlmConnectorError::NetworkError(format!("GET request failed: {}", e))
102 }
103 })
104 }
105
106 pub async fn post<T: Serialize>(
108 &self,
109 url: &str,
110 body: &T
111 ) -> Result<reqwest::Response, LlmConnectorError> {
112 let mut request = self.client.post(url).json(body);
113
114 for (key, value) in &self.headers {
116 request = request.header(key, value);
117 }
118
119 request.send().await
120 .map_err(|e| {
121 if e.is_timeout() {
122 LlmConnectorError::TimeoutError(format!("POST request timeout: {}", e))
123 } else if e.is_connect() {
124 LlmConnectorError::ConnectionError(format!("POST connection failed: {}", e))
125 } else {
126 LlmConnectorError::NetworkError(format!("POST request failed: {}", e))
127 }
128 })
129 }
130
131 #[cfg(feature = "streaming")]
133 pub async fn stream<T: Serialize>(
134 &self,
135 url: &str,
136 body: &T,
137 ) -> Result<reqwest::Response, LlmConnectorError> {
138 let mut request = self.client.post(url).json(body);
139
140 for (key, value) in &self.headers {
142 request = request.header(key, value);
143 }
144
145 request.send().await
146 .map_err(|e| {
147 if e.is_timeout() {
148 LlmConnectorError::TimeoutError(format!("Stream request timeout: {}", e))
149 } else if e.is_connect() {
150 LlmConnectorError::ConnectionError(format!("Stream connection failed: {}", e))
151 } else {
152 LlmConnectorError::NetworkError(format!("Stream request failed: {}", e))
153 }
154 })
155 }
156
157 pub async fn post_with_custom_headers<T: Serialize>(
159 &self,
160 url: &str,
161 body: &T,
162 custom_headers: &HashMap<String, String>,
163 ) -> Result<reqwest::Response, LlmConnectorError> {
164 let mut request = self.client.post(url).json(body);
165
166 for (key, value) in custom_headers {
168 request = request.header(key, value);
169 }
170
171 for (key, value) in &self.headers {
173 request = request.header(key, value);
174 }
175
176 request.send().await
177 .map_err(|e| {
178 if e.is_timeout() {
179 LlmConnectorError::TimeoutError(format!("POST request timeout: {}", e))
180 } else if e.is_connect() {
181 LlmConnectorError::ConnectionError(format!("POST connection failed: {}", e))
182 } else {
183 LlmConnectorError::NetworkError(format!("POST request failed: {}", e))
184 }
185 })
186 }
187}
188
189impl std::fmt::Debug for HttpClient {
190 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191 f.debug_struct("HttpClient")
192 .field("base_url", &self.base_url)
193 .field("headers_count", &self.headers.len())
194 .finish()
195 }
196}