1use super::error::TransportError;
2use async_trait::async_trait;
3use backoff::{future::retry, ExponentialBackoff};
4use reqwest::{Client, Method, Proxy, Response};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::env;
8use std::time::Duration;
9
10#[async_trait]
17pub trait HttpClient: Send + Sync {
18 async fn request<T, R>(
20 &self,
21 method: Method,
22 url: &str,
23 headers: Option<HashMap<String, String>>,
24 body: Option<&T>,
25 ) -> Result<R, TransportError>
26 where
27 T: Serialize + Send + Sync,
28 R: for<'de> Deserialize<'de>;
29
30 async fn request_with_retry<T, R>(
32 &self,
33 method: Method,
34 url: &str,
35 headers: Option<HashMap<String, String>>,
36 body: Option<&T>,
37 _max_retries: u32,
38 ) -> Result<R, TransportError>
39 where
40 T: Serialize + Send + Sync + Clone,
41 R: for<'de> Deserialize<'de>;
42
43 async fn get<R>(
45 &self,
46 url: &str,
47 headers: Option<HashMap<String, String>>,
48 ) -> Result<R, TransportError>
49 where
50 R: for<'de> Deserialize<'de>,
51 {
52 self.request(Method::GET, url, headers, None::<&()>).await
53 }
54
55 async fn post<T, R>(
57 &self,
58 url: &str,
59 headers: Option<HashMap<String, String>>,
60 body: &T,
61 ) -> Result<R, TransportError>
62 where
63 T: Serialize + Send + Sync,
64 R: for<'de> Deserialize<'de>,
65 {
66 self.request(Method::POST, url, headers, Some(body)).await
67 }
68
69 async fn put<T, R>(
71 &self,
72 url: &str,
73 headers: Option<HashMap<String, String>>,
74 body: &T,
75 ) -> Result<R, TransportError>
76 where
77 T: Serialize + Send + Sync,
78 R: for<'de> Deserialize<'de>,
79 {
80 self.request(Method::PUT, url, headers, Some(body)).await
81 }
82}
83
84pub struct HttpTransport {
90 client: Client,
91 timeout: Duration,
92}
93
/// Explicit configuration consumed by [`HttpTransport::new_with_config`].
///
/// Derives `Debug` so it can be logged and `Clone` so one configuration can build several
/// transports (it is taken by value).
#[derive(Debug, Clone)]
pub struct HttpTransportConfig {
    /// Request timeout set on the underlying client.
    pub timeout: Duration,
    /// Optional proxy URL, registered for all schemes when present.
    pub proxy: Option<String>,
    /// Maximum idle connections kept per host; not set on the client builder when `None`.
    pub pool_max_idle_per_host: Option<usize>,
    /// Idle timeout for pooled connections; not set on the client builder when `None`.
    pub pool_idle_timeout: Option<Duration>,
}
103
104impl HttpTransport {
105 pub fn new() -> Self {
112 let timeout_secs = env::var("AI_HTTP_TIMEOUT_SECS")
114 .ok()
115 .and_then(|s| s.parse::<u64>().ok())
116 .or_else(|| env::var("AI_TIMEOUT_SECS").ok().and_then(|s| s.parse::<u64>().ok()))
117 .unwrap_or(30);
118 Self::with_timeout(Duration::from_secs(timeout_secs))
119 }
120
121 pub fn new_without_proxy() -> Self {
126 let timeout_secs = env::var("AI_HTTP_TIMEOUT_SECS")
127 .ok()
128 .and_then(|s| s.parse::<u64>().ok())
129 .or_else(|| env::var("AI_TIMEOUT_SECS").ok().and_then(|s| s.parse::<u64>().ok()))
130 .unwrap_or(30);
131 Self::with_timeout_without_proxy(Duration::from_secs(timeout_secs))
132 }
133
134 pub fn with_timeout(timeout: Duration) -> Self {
138 let mut client_builder = Client::builder().timeout(timeout);
139
140 if let Ok(v) = env::var("AI_HTTP_POOL_MAX_IDLE_PER_HOST") {
142 if let Ok(n) = v.parse::<usize>() {
143 client_builder = client_builder.pool_max_idle_per_host(n);
144 }
145 }
146 if let Ok(v) = env::var("AI_HTTP_POOL_IDLE_TIMEOUT_MS") {
147 if let Ok(ms) = v.parse::<u64>() {
148 client_builder = client_builder.pool_idle_timeout(Duration::from_millis(ms));
149 }
150 }
151
152 if let Ok(proxy_url) = env::var("AI_PROXY_URL") {
154 match Proxy::all(&proxy_url) {
155 Ok(proxy) => {
156 client_builder = client_builder.proxy(proxy);
157 }
158 Err(_) => {
159 }
161 }
162 }
163
164 let client = client_builder
165 .build()
166 .expect("Failed to create HTTP client");
167
168 Self { client, timeout }
169 }
170
171 pub fn with_timeout_without_proxy(timeout: Duration) -> Self {
175 let mut client_builder = Client::builder().timeout(timeout);
176
177 if let Ok(v) = env::var("AI_HTTP_POOL_MAX_IDLE_PER_HOST") {
179 if let Ok(n) = v.parse::<usize>() {
180 client_builder = client_builder.pool_max_idle_per_host(n);
181 }
182 }
183 if let Ok(v) = env::var("AI_HTTP_POOL_IDLE_TIMEOUT_MS") {
184 if let Ok(ms) = v.parse::<u64>() {
185 client_builder = client_builder.pool_idle_timeout(Duration::from_millis(ms));
186 }
187 }
188
189 let client = client_builder
190 .build()
191 .expect("Failed to create HTTP client");
192
193 Self { client, timeout }
194 }
195
196 pub fn with_client(client: Client, timeout: Duration) -> Self {
198 Self { client, timeout }
199 }
200
201 pub fn with_reqwest_client(client: Client, timeout: Duration) -> Self {
206 Self::with_client(client, timeout)
207 }
208
209 pub fn new_with_config(config: HttpTransportConfig) -> Result<Self, TransportError> {
211 let mut client_builder = Client::builder().timeout(config.timeout);
212
213 if let Some(max_idle) = config.pool_max_idle_per_host {
215 client_builder = client_builder.pool_max_idle_per_host(max_idle);
216 }
217 if let Some(idle_timeout) = config.pool_idle_timeout {
218 client_builder = client_builder.pool_idle_timeout(idle_timeout);
219 }
220
221 if let Some(proxy_url) = config.proxy {
222 if let Ok(proxy) = Proxy::all(&proxy_url) {
223 client_builder = client_builder.proxy(proxy);
224 }
225 }
226
227 let client = client_builder
228 .build()
229 .map_err(|e| TransportError::HttpError(e.to_string()))?;
230 Ok(Self {
231 client,
232 timeout: config.timeout,
233 })
234 }
235
236 pub fn with_proxy(timeout: Duration, proxy_url: Option<&str>) -> Result<Self, TransportError> {
238 let mut client_builder = Client::builder().timeout(timeout);
239
240 if let Some(url) = proxy_url {
241 let proxy = Proxy::all(url)
242 .map_err(|e| TransportError::InvalidUrl(format!("Invalid proxy URL: {}", e)))?;
243 client_builder = client_builder.proxy(proxy);
244 }
245
246 let client = client_builder
247 .build()
248 .map_err(|e| TransportError::HttpError(e.to_string()))?;
249
250 Ok(Self { client, timeout })
251 }
252
253 pub fn timeout(&self) -> Duration {
255 self.timeout
256 }
257
258 async fn execute_request<T, R>(
260 &self,
261 method: Method,
262 url: &str,
263 headers: Option<HashMap<String, String>>,
264 body: Option<&T>,
265 ) -> Result<R, TransportError>
266 where
267 T: Serialize + Send + Sync,
268 R: for<'de> Deserialize<'de>,
269 {
270 let mut request_builder = self.client.request(method, url);
271
272 if let Some(headers) = headers {
274 for (key, value) in headers {
275 request_builder = request_builder.header(key, value);
276 }
277 }
278
279 if let Some(body) = body {
281 request_builder = request_builder.json(body);
282 }
283
284 let response = request_builder.send().await?;
286
287 Self::handle_response(response).await
289 }
290
291 fn is_retryable_error(&self, error: &TransportError) -> bool {
293 match error {
294 TransportError::HttpError(err_msg) => {
295 err_msg.contains("timeout") || err_msg.contains("connection")
296 }
297 TransportError::ClientError { status, .. } => {
298 *status == 429 || *status == 502 || *status == 503 || *status == 504
299 }
300 TransportError::ServerError { .. } => true,
301 _ => false,
302 }
303 }
304
305 async fn handle_response<R>(response: Response) -> Result<R, TransportError>
307 where
308 R: for<'de> Deserialize<'de>,
309 {
310 let status = response.status();
311
312 if status.is_success() {
313 let json_text = response.text().await?;
314 let result: R = serde_json::from_str(&json_text)?;
315 Ok(result)
316 } else {
317 let error_text = response.text().await.unwrap_or_default();
318 Err(TransportError::from_status(status.as_u16(), error_text))
319 }
320 }
321}
322
323#[async_trait]
324impl HttpClient for HttpTransport {
325 async fn request<T, R>(
326 &self,
327 method: Method,
328 url: &str,
329 headers: Option<HashMap<String, String>>,
330 body: Option<&T>,
331 ) -> Result<R, TransportError>
332 where
333 T: Serialize + Send + Sync,
334 R: for<'de> Deserialize<'de>,
335 {
336 self.execute_request(method, url, headers, body).await
337 }
338
339 async fn request_with_retry<T, R>(
340 &self,
341 method: Method,
342 url: &str,
343 headers: Option<HashMap<String, String>>,
344 body: Option<&T>,
345 _max_retries: u32,
346 ) -> Result<R, TransportError>
347 where
348 T: Serialize + Send + Sync + Clone,
349 R: for<'de> Deserialize<'de>,
350 {
351 let backoff = ExponentialBackoff {
352 max_elapsed_time: Some(Duration::from_secs(60)),
353 max_interval: Duration::from_secs(10),
354 ..Default::default()
355 };
356
357 let headers_clone = headers.clone();
358 let body_clone = body.cloned();
359 let url_clone = url.to_string();
360
361 retry(backoff, || async {
362 match self
363 .execute_request(
364 method.clone(),
365 &url_clone,
366 headers_clone.clone(),
367 body_clone.as_ref(),
368 )
369 .await
370 {
371 Ok(result) => Ok(result),
372 Err(e) => {
373 if self.is_retryable_error(&e) {
374 Err(backoff::Error::transient(e))
375 } else {
376 Err(backoff::Error::permanent(e))
377 }
378 }
379 }
380 })
381 .await
382 }
383}
384
385impl Default for HttpTransport {
386 fn default() -> Self {
387 Self::new()
388 }
389}
390
391pub struct HttpTransportBoxed {
393 inner: HttpTransport,
394}
395
396impl HttpTransportBoxed {
397 pub fn new(inner: HttpTransport) -> Self {
398 Self { inner }
399 }
400}
401
402use crate::transport::dyn_transport::{DynHttpTransport, DynHttpTransportRef};
403use bytes::Bytes;
404use futures::{Stream, StreamExt};
405use std::pin::Pin;
406use std::sync::Arc;
407
408impl DynHttpTransport for HttpTransportBoxed {
409 fn get_json<'a>(
410 &'a self,
411 url: &'a str,
412 headers: Option<HashMap<String, String>>,
413 ) -> futures::future::BoxFuture<'a, Result<serde_json::Value, crate::types::AiLibError>> {
414 Box::pin(async move {
415 let res: Result<serde_json::Value, TransportError> = self.inner.get(url, headers).await;
416 match res {
417 Ok(v) => Ok(v),
418 Err(e) => Err(map_transport_error_to_ailib(e)),
419 }
420 })
421 }
422
423 fn post_json<'a>(
424 &'a self,
425 url: &'a str,
426 headers: Option<HashMap<String, String>>,
427 body: serde_json::Value,
428 ) -> futures::future::BoxFuture<'a, Result<serde_json::Value, crate::types::AiLibError>> {
429 Box::pin(async move {
430 let res: Result<serde_json::Value, TransportError> =
431 self.inner.post(url, headers, &body).await;
432 match res {
433 Ok(v) => Ok(v),
434 Err(e) => Err(map_transport_error_to_ailib(e)),
435 }
436 })
437 }
438
439 fn post_stream<'a>(
440 &'a self,
441 _url: &'a str,
442 _headers: Option<HashMap<String, String>>,
443 _body: serde_json::Value,
444 ) -> futures::future::BoxFuture<
445 'a,
446 Result<
447 Pin<Box<dyn Stream<Item = Result<Bytes, crate::types::AiLibError>> + Send>>,
448 crate::types::AiLibError,
449 >,
450 > {
451 Box::pin(async move {
452 let mut req = self.inner.client.post(_url).json(&_body);
454 if let Some(h) = _headers {
456 for (k, v) in h.into_iter() {
457 req = req.header(k, v);
458 }
459 }
460 req = req.header("Accept", "text/event-stream");
462
463 let resp = req.send().await.map_err(|e| {
464 if e.is_timeout() {
465 crate::types::AiLibError::TimeoutError(format!("Stream request timeout: {}", e))
466 } else {
467 crate::types::AiLibError::NetworkError(format!("Stream request failed: {}", e))
468 }
469 })?;
470 if !resp.status().is_success() {
471 let status = resp.status();
472 let text = resp.text().await.unwrap_or_default();
473 return Err(map_status_to_ailib(status.as_u16(), text));
474 }
475
476 let byte_stream = resp.bytes_stream().map(|res| match res {
477 Ok(b) => Ok(b),
478 Err(e) => {
479 if e.is_timeout() {
480 Err(crate::types::AiLibError::TimeoutError(format!(
481 "Stream chunk timeout: {}",
482 e
483 )))
484 } else {
485 Err(crate::types::AiLibError::NetworkError(format!(
486 "Stream chunk error: {}",
487 e
488 )))
489 }
490 }
491 });
492
493 let boxed_stream: Pin<
494 Box<dyn Stream<Item = Result<Bytes, crate::types::AiLibError>> + Send>,
495 > = Box::pin(byte_stream);
496 Ok(boxed_stream)
497 })
498 }
499
500 fn upload_multipart<'a>(
501 &'a self,
502 url: &'a str,
503 headers: Option<HashMap<String, String>>,
504 field_name: &'a str,
505 file_name: &'a str,
506 bytes: Vec<u8>,
507 ) -> Pin<
508 Box<
509 dyn futures::Future<Output = Result<serde_json::Value, crate::types::AiLibError>>
510 + Send
511 + 'a,
512 >,
513 > {
514 Box::pin(async move {
515 let part = reqwest::multipart::Part::bytes(bytes).file_name(file_name.to_string());
517 let form = reqwest::multipart::Form::new().part(field_name.to_string(), part);
518
519 let mut req = self.inner.client.post(url).multipart(form);
520 if let Some(h) = headers {
521 for (k, v) in h.into_iter() {
522 req = req.header(k, v);
523 }
524 }
525
526 let resp = req.send().await.map_err(|e| {
527 if e.is_timeout() {
528 crate::types::AiLibError::TimeoutError(format!("upload request timeout: {}", e))
529 } else {
530 crate::types::AiLibError::NetworkError(format!("upload request failed: {}", e))
531 }
532 })?;
533 if !resp.status().is_success() {
534 let status = resp.status();
535 let text = resp.text().await.unwrap_or_default();
536 return Err(map_status_to_ailib(status.as_u16(), text));
537 }
538 let j: serde_json::Value = resp.json().await.map_err(|e| {
539 crate::types::AiLibError::DeserializationError(format!(
540 "parse upload response: {}",
541 e
542 ))
543 })?;
544 Ok(j)
545 })
546 }
547}
548
549impl HttpTransport {
550 pub fn boxed(self) -> DynHttpTransportRef {
552 Arc::new(HttpTransportBoxed::new(self))
553 }
554}
555
556fn map_transport_error_to_ailib(e: TransportError) -> crate::types::AiLibError {
558 use crate::types::AiLibError;
559 match e {
560 TransportError::AuthenticationError(msg) => AiLibError::AuthenticationError(msg),
561 TransportError::RateLimitExceeded => {
562 AiLibError::RateLimitExceeded("rate limited".to_string())
563 }
564 TransportError::Timeout(msg) => AiLibError::TimeoutError(msg),
565 TransportError::ServerError { status, message } => {
566 AiLibError::NetworkError(format!("server {}: {}", status, message))
568 }
569 TransportError::ClientError { status, message } => match status {
570 401 | 403 => AiLibError::AuthenticationError(message),
571 408 => AiLibError::TimeoutError(message),
572 409 | 425 | 429 => AiLibError::RateLimitExceeded(message),
573 _ => AiLibError::InvalidRequest(format!("client {}: {}", status, message)),
574 },
575 TransportError::HttpError(msg) => {
576 if msg.contains("timeout") {
578 AiLibError::TimeoutError(msg)
579 } else {
580 AiLibError::NetworkError(msg)
581 }
582 }
583 TransportError::JsonError(msg) => AiLibError::DeserializationError(msg),
584 TransportError::InvalidUrl(msg) => AiLibError::ConfigurationError(msg),
585 }
586}
587
588fn map_status_to_ailib(status: u16, body: String) -> crate::types::AiLibError {
589 use crate::types::AiLibError;
590 match status {
591 401 | 403 => AiLibError::AuthenticationError(body),
592 408 => AiLibError::TimeoutError(body),
593 409 | 425 | 429 => AiLibError::RateLimitExceeded(body),
594 500..=599 => AiLibError::NetworkError(format!("server {}: {}", status, body)),
595 400 => AiLibError::InvalidRequest(body),
596 _ => AiLibError::ProviderError(format!("http {}: {}", status, body)),
597 }
598}