1use std::pin::Pin;
2use std::time::Duration;
3
4use async_stream::stream;
5use futures::Stream;
6use http_body_util::{BodyExt, Full};
7use hyper::body::Bytes;
8use hyper::{Method, Request, StatusCode};
9use hyper_rustls::HttpsConnectorBuilder;
10use hyper_util::client::legacy::Client;
11use hyper_util::rt::TokioExecutor;
12
13use crate::client::error::LlmError;
14
/// Concrete hyper client type used throughout this module: the hyper-util
/// legacy `Client` driving a rustls-backed HTTPS connector layered over the
/// plain `HttpConnector`, with `Full<Bytes>` as the request body type.
type HttpsClient =
    Client<hyper_rustls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>, Full<Bytes>>;
17
/// Number of retries made after the initial request when the server keeps
/// answering with a rate-limit status (so up to `MAX_RETRIES + 1` attempts).
const MAX_RETRIES: u32 = 5;

/// Backoff delay in milliseconds for the first retry; it is doubled for each
/// subsequent attempt.
const BASE_DELAY_MS: u64 = 1_000;

/// Upper bound in milliseconds applied to the exponential backoff delay
/// before jitter is added.
const MAX_DELAY_MS: u64 = 60_000;
26
/// HTTP client used to issue GET/POST requests, wrapping a single
/// [`HttpsClient`].
///
/// `Clone` is derived, so each copy holds a clone of the inner hyper client.
#[derive(Clone)]
pub struct HttpClient {
    /// Underlying hyper client that performs the actual requests.
    client: HttpsClient,
}
32
33fn calculate_backoff_delay(attempt: u32, response_text: &str) -> Duration {
36 if let Some(seconds) = extract_retry_after(response_text) {
39 return Duration::from_secs(seconds);
40 }
41
42 let exponential_delay = BASE_DELAY_MS * (1 << attempt);
44 let capped_delay = exponential_delay.min(MAX_DELAY_MS);
45
46 let jitter = (capped_delay as f64 * 0.25 * rand_factor()) as u64;
48 Duration::from_millis(capped_delay + jitter)
49}
50
/// Looks for a server-provided retry delay (in whole seconds) inside an
/// error response body.
///
/// Two case-insensitive forms are recognised, tried in this order:
///
/// 1. prose: `retry after <digits>` (e.g. "Please retry after 30 seconds",
///    "retry after 30", "retry after 30s");
/// 2. JSON: `"retry_after": <digits>`.
///
/// Only the first occurrence of each marker is inspected. Returns `None`
/// when neither marker is followed by a parseable unsigned integer.
fn extract_retry_after(response_text: &str) -> Option<u64> {
    const PROSE_MARKER: &str = "retry after ";
    const JSON_MARKER: &str = "\"retry_after\":";

    // Offsets are computed and used on the lowercased copy only, so they are
    // always valid char boundaries for the string being sliced.
    let lower = response_text.to_lowercase();

    [PROSE_MARKER, JSON_MARKER].iter().find_map(|marker| {
        let value_start = lower.find(*marker)? + marker.len();
        parse_leading_seconds(&lower[value_start..])
    })
}

/// Parses the run of ASCII digits at the start of `text` (after optional
/// leading whitespace) as a `u64`; `None` if there are no digits or the
/// number does not fit.
fn parse_leading_seconds(text: &str) -> Option<u64> {
    let trimmed = text.trim_start();
    let digits_end = trimmed
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(trimmed.len());
    trimmed[..digits_end].parse().ok()
}
82
/// Returns a pseudo-random factor in `[0.0, 1.0)` with a granularity of
/// 1/1000, used to jitter retry delays.
///
/// This is not cryptographically secure and is not meant to be. The
/// sub-second clock reading is mixed through a `RandomState` hasher (which
/// std seeds with random keys) rather than used directly: taking
/// `nanos % 1000` on its own yields a constant on systems whose clock
/// resolution is coarser than one nanosecond, which would silently disable
/// jitter.
fn rand_factor() -> f64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::SystemTime;

    // A clock set before the Unix epoch just contributes 0; the hasher keys
    // still provide variation.
    let nanos = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .map(|d| d.subsec_nanos())
        .unwrap_or(0);

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u32(nanos);
    (hasher.finish() % 1000) as f64 / 1000.0
}
93
94impl HttpClient {
95 pub fn new() -> Result<Self, LlmError> {
97 let https = HttpsConnectorBuilder::new()
98 .with_native_roots()
99 .map_err(|e| {
100 LlmError::new(
101 "TLS_INIT_FAILED",
102 format!("failed to load native TLS roots: {}", e),
103 )
104 })?
105 .https_or_http()
106 .enable_http1()
107 .build();
108
109 let client = Client::builder(TokioExecutor::new()).build(https);
110 Ok(Self { client })
111 }
112
113 pub async fn get(&self, uri: &str) -> Result<String, LlmError> {
115 let uri: hyper::Uri = uri
116 .parse()
117 .map_err(|e| LlmError::new("HTTP_INVALID_URI", format!("{}", e)))?;
118
119 let request = Request::builder()
120 .method(Method::GET)
121 .uri(uri)
122 .body(Full::new(Bytes::new()))
123 .map_err(|e| LlmError::new("HTTP_REQUEST_BUILD", format!("{}", e)))?;
124
125 let res = self
126 .client
127 .request(request)
128 .await
129 .map_err(|e| LlmError::new("HTTP_REQUEST_FAILED", format!("{}", e)))?;
130
131 let body = res
132 .collect()
133 .await
134 .map_err(|e| LlmError::new("HTTP_BODY_READ", format!("{}", e)))?
135 .to_bytes();
136
137 String::from_utf8(body.to_vec())
138 .map_err(|e| LlmError::new("HTTP_INVALID_UTF8", format!("{}", e)))
139 }
140
141 pub async fn post(
145 &self,
146 uri: &str,
147 headers: &[(&str, &str)],
148 body: &str,
149 ) -> Result<String, LlmError> {
150 let parsed_uri: hyper::Uri = uri
151 .parse()
152 .map_err(|e| LlmError::new("HTTP_INVALID_URI", format!("{}", e)))?;
153
154 let mut last_error = None;
155
156 for attempt in 0..=MAX_RETRIES {
157 let mut builder = Request::builder()
158 .method(Method::POST)
159 .uri(parsed_uri.clone());
160
161 for (key, value) in headers {
162 builder = builder.header(*key, *value);
163 }
164
165 let request = builder
166 .body(Full::new(Bytes::from(body.to_string())))
167 .map_err(|e| LlmError::new("HTTP_REQUEST_BUILD", format!("{}", e)))?;
168
169 let res = self
170 .client
171 .request(request)
172 .await
173 .map_err(|e| LlmError::new("HTTP_REQUEST_FAILED", format!("{}", e)))?;
174
175 let status = res.status();
176
177 let response_body = res
178 .collect()
179 .await
180 .map_err(|e| LlmError::new("HTTP_BODY_READ", format!("{}", e)))?
181 .to_bytes();
182
183 let response_text = String::from_utf8(response_body.to_vec())
184 .map_err(|e| LlmError::new("HTTP_INVALID_UTF8", format!("{}", e)))?;
185
186 if status == StatusCode::TOO_MANY_REQUESTS || status.as_u16() == 529 {
188 if attempt < MAX_RETRIES {
189 let delay = calculate_backoff_delay(attempt, &response_text);
190 tracing::warn!(
191 status = %status,
192 attempt = attempt + 1,
193 max_retries = MAX_RETRIES,
194 delay_ms = delay.as_millis(),
195 "Rate limited, retrying after delay"
196 );
197 tokio::time::sleep(delay).await;
198 last_error = Some(LlmError::new(
199 format!("HTTP_{}", status.as_u16()),
200 response_text,
201 ));
202 continue;
203 }
204 }
205
206 return Ok(response_text);
208 }
209
210 Err(last_error.unwrap_or_else(|| {
212 LlmError::new("RATE_LIMIT_EXHAUSTED", "Rate limit retries exhausted")
213 }))
214 }
215
216 pub async fn post_stream(
218 &self,
219 uri: &str,
220 headers: &[(&str, &str)],
221 body: &str,
222 ) -> Result<Pin<Box<dyn Stream<Item = Result<Bytes, LlmError>> + Send>>, LlmError> {
223 let parsed_uri: hyper::Uri = uri
224 .parse()
225 .map_err(|e| LlmError::new("HTTP_INVALID_URI", format!("{}", e)))?;
226
227 let mut last_error = None;
228
229 for attempt in 0..=MAX_RETRIES {
230 let mut builder = Request::builder()
231 .method(Method::POST)
232 .uri(parsed_uri.clone());
233
234 for (key, value) in headers {
235 builder = builder.header(*key, *value);
236 }
237
238 let request = builder
239 .body(Full::new(Bytes::from(body.to_string())))
240 .map_err(|e| LlmError::new("HTTP_REQUEST_BUILD", format!("{}", e)))?;
241
242 let res = self
243 .client
244 .request(request)
245 .await
246 .map_err(|e| LlmError::new("HTTP_REQUEST_FAILED", format!("{}", e)))?;
247
248 let status = res.status();
249
250 if status == StatusCode::TOO_MANY_REQUESTS || status.as_u16() == 529 {
252 let error_body = res
253 .collect()
254 .await
255 .map_err(|e| LlmError::new("HTTP_BODY_READ", format!("{}", e)))?
256 .to_bytes();
257 let error_text = String::from_utf8_lossy(&error_body).to_string();
258
259 if attempt < MAX_RETRIES {
260 let delay = calculate_backoff_delay(attempt, &error_text);
261 tracing::warn!(
262 status = %status,
263 attempt = attempt + 1,
264 max_retries = MAX_RETRIES,
265 delay_ms = delay.as_millis(),
266 "Rate limited on stream request, retrying after delay"
267 );
268 tokio::time::sleep(delay).await;
269 last_error = Some(LlmError::new(
270 format!("HTTP_{}", status.as_u16()),
271 error_text,
272 ));
273 continue;
274 }
275
276 return Err(LlmError::new(
278 format!("HTTP_{}", status.as_u16()),
279 error_text,
280 ));
281 }
282
283 if !status.is_success() {
285 let error_body = res
286 .collect()
287 .await
288 .map_err(|e| LlmError::new("HTTP_BODY_READ", format!("{}", e)))?
289 .to_bytes();
290 let error_text = String::from_utf8_lossy(&error_body);
291 return Err(LlmError::new(
292 format!("HTTP_{}", status.as_u16()),
293 error_text.to_string(),
294 ));
295 }
296
297 let response_body = res.into_body();
299 let byte_stream = stream! {
300 use http_body_util::BodyExt;
301 let mut body = response_body;
302 while let Some(frame_result) = body.frame().await {
303 match frame_result {
304 Ok(frame) => {
305 if let Some(data) = frame.data_ref() {
306 yield Ok(data.clone());
307 }
308 }
309 Err(e) => {
310 yield Err(LlmError::new("HTTP_STREAM_ERROR", format!("{}", e)));
311 break;
312 }
313 }
314 }
315 };
316
317 return Ok(Box::pin(byte_stream)
318 as Pin<Box<dyn Stream<Item = Result<Bytes, LlmError>> + Send>>);
319 }
320
321 Err(last_error.unwrap_or_else(|| {
323 LlmError::new("RATE_LIMIT_EXHAUSTED", "Rate limit retries exhausted")
324 }))
325 }
326}