1use std::pin::Pin;
2use std::time::Duration;
3
4use async_stream::stream;
5use futures::Stream;
6use http_body_util::{BodyExt, Full};
7use hyper::body::Bytes;
8use hyper::{Method, Request, StatusCode};
9use hyper_rustls::HttpsConnectorBuilder;
10use hyper_util::client::legacy::Client;
11use hyper_util::rt::TokioExecutor;
12
13use crate::client::error::LlmError;
14
15type HttpsClient =
16 Client<hyper_rustls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>, Full<Bytes>>;
17
/// How many times a rate-limited (429/529) request is re-sent after the first
/// attempt, i.e. at most `MAX_RETRIES + 1` requests in total.
const MAX_RETRIES: u32 = 5;

/// Initial exponential-backoff delay in milliseconds; doubled per attempt.
const BASE_DELAY_MS: u64 = 1_000;

/// Ceiling, in milliseconds, on the exponential part of the backoff delay
/// (jitter is added on top of this cap).
const MAX_DELAY_MS: u64 = 60_000;
26
/// HTTP client used for LLM API calls.
///
/// Wraps the `HttpsClient` constructed in `HttpClient::new` and provides
/// `get`, `post` (with rate-limit retries) and `post_stream` (streams the
/// response body as `Bytes` chunks).
#[derive(Clone)]
pub struct HttpClient {
    // Underlying hyper-util client; all requests are dispatched through it.
    client: HttpsClient,
}
31
32fn calculate_backoff_delay(attempt: u32, response_text: &str) -> Duration {
35 if let Some(seconds) = extract_retry_after(response_text) {
38 return Duration::from_secs(seconds);
39 }
40
41 let exponential_delay = BASE_DELAY_MS * (1 << attempt);
43 let capped_delay = exponential_delay.min(MAX_DELAY_MS);
44
45 let jitter = (capped_delay as f64 * 0.25 * rand_factor()) as u64;
47 Duration::from_millis(capped_delay + jitter)
48}
49
/// Extracts a server-suggested retry delay, in whole seconds, from a response body.
///
/// Two hints are recognised, case-insensitively, and tried in this order:
/// * prose such as `"Please retry after 5 seconds."`
/// * a JSON field such as `"retry_after": 5` (the value may also be quoted)
///
/// Fractional values are rounded up so the caller never retries early
/// (`1.5` -> `2`, `0.5` -> `1`). A number directly followed by a letter or
/// digit (e.g. `20ms`) is rejected because its unit is not seconds. Returns
/// `None` when no usable hint is found or the value does not fit in a `u64`.
fn extract_retry_after(response_text: &str) -> Option<u64> {
    /// Parses `<int>[.<frac>]` at the start of `s` (after leading whitespace and
    /// an optional opening quote) and returns it rounded up to whole seconds.
    fn leading_seconds(s: &str) -> Option<u64> {
        let s = s.trim_start();
        let s = s.strip_prefix('"').unwrap_or(s);

        let int_len = s.bytes().take_while(u8::is_ascii_digit).count();
        if int_len == 0 {
            return None;
        }
        // Overflowing integers fail here and yield `None`.
        let mut seconds: u64 = s[..int_len].parse().ok()?;

        let mut rest = &s[int_len..];
        if let Some(frac) = rest.strip_prefix('.') {
            let frac_len = frac.bytes().take_while(u8::is_ascii_digit).count();
            if frac_len > 0 {
                // Any non-zero fractional part rounds the value up.
                if frac[..frac_len].bytes().any(|b| b != b'0') {
                    seconds = seconds.checked_add(1)?;
                }
                rest = &frac[frac_len..];
            }
        }

        match rest.chars().next() {
            Some(c) if c.is_alphanumeric() => None,
            _ => Some(seconds),
        }
    }

    // ASCII lowercasing keeps byte offsets identical to the input.
    let lower = response_text.to_ascii_lowercase();
    ["retry after ", "\"retry_after\":"]
        .iter()
        .find_map(|marker| {
            let pos = lower.find(marker)?;
            leading_seconds(&lower[pos + marker.len()..])
        })
}
81
/// Returns a pseudo-random value in `[0.0, 1.0)` used as backoff jitter.
///
/// Not cryptographically secure. The previous `subsec_nanos() % 1000` scheme
/// degenerates on platforms whose clock is coarser than a nanosecond (always 0
/// at microsecond resolution). It also gave identical jitter to callers
/// throttled at the same instant. Hashing the time with a freshly, randomly
/// keyed `RandomState` avoids both problems without extra dependencies.
fn rand_factor() -> f64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::SystemTime;

    let nanos = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .map(|d| d.subsec_nanos())
        .unwrap_or(0);

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u32(nanos);

    // Keep the top 53 bits so the value maps exactly onto an f64 mantissa,
    // giving a result strictly below 1.0.
    (hasher.finish() >> 11) as f64 / (1u64 << 53) as f64
}
92
93impl HttpClient {
94 pub fn new() -> Result<Self, LlmError> {
95 let https = HttpsConnectorBuilder::new()
96 .with_native_roots()
97 .map_err(|e| {
98 LlmError::new(
99 "TLS_INIT_FAILED",
100 format!("failed to load native TLS roots: {}", e),
101 )
102 })?
103 .https_or_http()
104 .enable_http1()
105 .build();
106
107 let client = Client::builder(TokioExecutor::new()).build(https);
108 Ok(Self { client })
109 }
110
111 pub async fn get(&self, uri: &str) -> Result<String, LlmError> {
112 let uri: hyper::Uri = uri
113 .parse()
114 .map_err(|e| LlmError::new("HTTP_INVALID_URI", format!("{}", e)))?;
115
116 let request = Request::builder()
117 .method(Method::GET)
118 .uri(uri)
119 .body(Full::new(Bytes::new()))
120 .map_err(|e| LlmError::new("HTTP_REQUEST_BUILD", format!("{}", e)))?;
121
122 let res = self
123 .client
124 .request(request)
125 .await
126 .map_err(|e| LlmError::new("HTTP_REQUEST_FAILED", format!("{}", e)))?;
127
128 let body = res
129 .collect()
130 .await
131 .map_err(|e| LlmError::new("HTTP_BODY_READ", format!("{}", e)))?
132 .to_bytes();
133
134 String::from_utf8(body.to_vec())
135 .map_err(|e| LlmError::new("HTTP_INVALID_UTF8", format!("{}", e)))
136 }
137
138 pub async fn post(
139 &self,
140 uri: &str,
141 headers: &[(&str, &str)],
142 body: &str,
143 ) -> Result<String, LlmError> {
144 let parsed_uri: hyper::Uri = uri
145 .parse()
146 .map_err(|e| LlmError::new("HTTP_INVALID_URI", format!("{}", e)))?;
147
148 let mut last_error = None;
149
150 for attempt in 0..=MAX_RETRIES {
151 let mut builder = Request::builder()
152 .method(Method::POST)
153 .uri(parsed_uri.clone());
154
155 for (key, value) in headers {
156 builder = builder.header(*key, *value);
157 }
158
159 let request = builder
160 .body(Full::new(Bytes::from(body.to_string())))
161 .map_err(|e| LlmError::new("HTTP_REQUEST_BUILD", format!("{}", e)))?;
162
163 let res = self
164 .client
165 .request(request)
166 .await
167 .map_err(|e| LlmError::new("HTTP_REQUEST_FAILED", format!("{}", e)))?;
168
169 let status = res.status();
170
171 let response_body = res
172 .collect()
173 .await
174 .map_err(|e| LlmError::new("HTTP_BODY_READ", format!("{}", e)))?
175 .to_bytes();
176
177 let response_text = String::from_utf8(response_body.to_vec())
178 .map_err(|e| LlmError::new("HTTP_INVALID_UTF8", format!("{}", e)))?;
179
180 if status == StatusCode::TOO_MANY_REQUESTS || status.as_u16() == 529 {
182 if attempt < MAX_RETRIES {
183 let delay = calculate_backoff_delay(attempt, &response_text);
184 tracing::warn!(
185 status = %status,
186 attempt = attempt + 1,
187 max_retries = MAX_RETRIES,
188 delay_ms = delay.as_millis(),
189 "Rate limited, retrying after delay"
190 );
191 tokio::time::sleep(delay).await;
192 last_error = Some(LlmError::new(
193 format!("HTTP_{}", status.as_u16()),
194 response_text,
195 ));
196 continue;
197 }
198 }
199
200 return Ok(response_text);
202 }
203
204 Err(last_error.unwrap_or_else(|| {
206 LlmError::new("RATE_LIMIT_EXHAUSTED", "Rate limit retries exhausted")
207 }))
208 }
209
210 pub async fn post_stream(
212 &self,
213 uri: &str,
214 headers: &[(&str, &str)],
215 body: &str,
216 ) -> Result<Pin<Box<dyn Stream<Item = Result<Bytes, LlmError>> + Send>>, LlmError> {
217 let parsed_uri: hyper::Uri = uri
218 .parse()
219 .map_err(|e| LlmError::new("HTTP_INVALID_URI", format!("{}", e)))?;
220
221 let mut last_error = None;
222
223 for attempt in 0..=MAX_RETRIES {
224 let mut builder = Request::builder()
225 .method(Method::POST)
226 .uri(parsed_uri.clone());
227
228 for (key, value) in headers {
229 builder = builder.header(*key, *value);
230 }
231
232 let request = builder
233 .body(Full::new(Bytes::from(body.to_string())))
234 .map_err(|e| LlmError::new("HTTP_REQUEST_BUILD", format!("{}", e)))?;
235
236 let res = self
237 .client
238 .request(request)
239 .await
240 .map_err(|e| LlmError::new("HTTP_REQUEST_FAILED", format!("{}", e)))?;
241
242 let status = res.status();
243
244 if status == StatusCode::TOO_MANY_REQUESTS || status.as_u16() == 529 {
246 let error_body = res
247 .collect()
248 .await
249 .map_err(|e| LlmError::new("HTTP_BODY_READ", format!("{}", e)))?
250 .to_bytes();
251 let error_text = String::from_utf8_lossy(&error_body).to_string();
252
253 if attempt < MAX_RETRIES {
254 let delay = calculate_backoff_delay(attempt, &error_text);
255 tracing::warn!(
256 status = %status,
257 attempt = attempt + 1,
258 max_retries = MAX_RETRIES,
259 delay_ms = delay.as_millis(),
260 "Rate limited on stream request, retrying after delay"
261 );
262 tokio::time::sleep(delay).await;
263 last_error = Some(LlmError::new(
264 format!("HTTP_{}", status.as_u16()),
265 error_text,
266 ));
267 continue;
268 }
269
270 return Err(LlmError::new(
272 format!("HTTP_{}", status.as_u16()),
273 error_text,
274 ));
275 }
276
277 if !status.is_success() {
279 let error_body = res
280 .collect()
281 .await
282 .map_err(|e| LlmError::new("HTTP_BODY_READ", format!("{}", e)))?
283 .to_bytes();
284 let error_text = String::from_utf8_lossy(&error_body);
285 return Err(LlmError::new(
286 format!("HTTP_{}", status.as_u16()),
287 error_text.to_string(),
288 ));
289 }
290
291 let response_body = res.into_body();
293 let byte_stream = stream! {
294 use http_body_util::BodyExt;
295 let mut body = response_body;
296 while let Some(frame_result) = body.frame().await {
297 match frame_result {
298 Ok(frame) => {
299 if let Some(data) = frame.data_ref() {
300 yield Ok(data.clone());
301 }
302 }
303 Err(e) => {
304 yield Err(LlmError::new("HTTP_STREAM_ERROR", format!("{}", e)));
305 break;
306 }
307 }
308 }
309 };
310
311 return Ok(Box::pin(byte_stream)
312 as Pin<Box<dyn Stream<Item = Result<Bytes, LlmError>> + Send>>);
313 }
314
315 Err(last_error.unwrap_or_else(|| {
317 LlmError::new("RATE_LIMIT_EXHAUSTED", "Rate limit retries exhausted")
318 }))
319 }
320}