1use super::error::TransportError;
2use async_trait::async_trait;
3use backoff::{future::retry, ExponentialBackoff};
4use reqwest::{Client, Method, Proxy, Response};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::env;
8use std::time::Duration;
9
10#[async_trait]
17pub trait HttpClient: Send + Sync {
18 async fn request<T, R>(
20 &self,
21 method: Method,
22 url: &str,
23 headers: Option<HashMap<String, String>>,
24 body: Option<&T>,
25 ) -> Result<R, TransportError>
26 where
27 T: Serialize + Send + Sync,
28 R: for<'de> Deserialize<'de>;
29
30 async fn request_with_retry<T, R>(
32 &self,
33 method: Method,
34 url: &str,
35 headers: Option<HashMap<String, String>>,
36 body: Option<&T>,
37 _max_retries: u32,
38 ) -> Result<R, TransportError>
39 where
40 T: Serialize + Send + Sync + Clone,
41 R: for<'de> Deserialize<'de>;
42
43 async fn get<R>(
45 &self,
46 url: &str,
47 headers: Option<HashMap<String, String>>,
48 ) -> Result<R, TransportError>
49 where
50 R: for<'de> Deserialize<'de>,
51 {
52 self.request(Method::GET, url, headers, None::<&()>).await
53 }
54
55 async fn post<T, R>(
57 &self,
58 url: &str,
59 headers: Option<HashMap<String, String>>,
60 body: &T,
61 ) -> Result<R, TransportError>
62 where
63 T: Serialize + Send + Sync,
64 R: for<'de> Deserialize<'de>,
65 {
66 self.request(Method::POST, url, headers, Some(body)).await
67 }
68
69 async fn put<T, R>(
71 &self,
72 url: &str,
73 headers: Option<HashMap<String, String>>,
74 body: &T,
75 ) -> Result<R, TransportError>
76 where
77 T: Serialize + Send + Sync,
78 R: for<'de> Deserialize<'de>,
79 {
80 self.request(Method::PUT, url, headers, Some(body)).await
81 }
82}
83
84pub struct HttpTransport {
90 client: Client,
91 timeout: Duration,
92}
93
/// Explicit configuration for [`HttpTransport::new_with_config`].
#[derive(Debug, Clone)]
pub struct HttpTransportConfig {
    /// Client-wide request timeout.
    pub timeout: Duration,
    /// Optional proxy URL applied to all schemes.
    pub proxy: Option<String>,
    /// Maximum number of idle pooled connections kept per host, if set.
    pub pool_max_idle_per_host: Option<usize>,
    /// How long an idle pooled connection is kept, if set.
    pub pool_idle_timeout: Option<Duration>,
}
103
104impl HttpTransport {
105 pub fn new() -> Self {
112 let timeout_secs = env::var("AI_HTTP_TIMEOUT_SECS")
114 .ok()
115 .and_then(|s| s.parse::<u64>().ok())
116 .or_else(|| {
117 env::var("AI_TIMEOUT_SECS")
118 .ok()
119 .and_then(|s| s.parse::<u64>().ok())
120 })
121 .unwrap_or(30);
122 Self::with_timeout(Duration::from_secs(timeout_secs))
123 }
124
125 pub fn new_without_proxy() -> Self {
130 let timeout_secs = env::var("AI_HTTP_TIMEOUT_SECS")
131 .ok()
132 .and_then(|s| s.parse::<u64>().ok())
133 .or_else(|| {
134 env::var("AI_TIMEOUT_SECS")
135 .ok()
136 .and_then(|s| s.parse::<u64>().ok())
137 })
138 .unwrap_or(30);
139 Self::with_timeout_without_proxy(Duration::from_secs(timeout_secs))
140 }
141
142 pub fn with_timeout(timeout: Duration) -> Self {
146 let mut client_builder = Client::builder().timeout(timeout);
147
148 if let Ok(v) = env::var("AI_HTTP_POOL_MAX_IDLE_PER_HOST") {
150 if let Ok(n) = v.parse::<usize>() {
151 client_builder = client_builder.pool_max_idle_per_host(n);
152 }
153 }
154 if let Ok(v) = env::var("AI_HTTP_POOL_IDLE_TIMEOUT_MS") {
155 if let Ok(ms) = v.parse::<u64>() {
156 client_builder = client_builder.pool_idle_timeout(Duration::from_millis(ms));
157 }
158 }
159
160 if let Ok(proxy_url) = env::var("AI_PROXY_URL") {
162 match Proxy::all(&proxy_url) {
163 Ok(proxy) => {
164 client_builder = client_builder.proxy(proxy);
165 }
166 Err(_) => {
167 }
169 }
170 }
171
172 let client = client_builder
173 .build()
174 .expect("Failed to create HTTP client");
175
176 Self { client, timeout }
177 }
178
179 pub fn with_timeout_without_proxy(timeout: Duration) -> Self {
183 let mut client_builder = Client::builder().timeout(timeout);
184
185 if let Ok(v) = env::var("AI_HTTP_POOL_MAX_IDLE_PER_HOST") {
187 if let Ok(n) = v.parse::<usize>() {
188 client_builder = client_builder.pool_max_idle_per_host(n);
189 }
190 }
191 if let Ok(v) = env::var("AI_HTTP_POOL_IDLE_TIMEOUT_MS") {
192 if let Ok(ms) = v.parse::<u64>() {
193 client_builder = client_builder.pool_idle_timeout(Duration::from_millis(ms));
194 }
195 }
196
197 let client = client_builder
198 .build()
199 .expect("Failed to create HTTP client");
200
201 Self { client, timeout }
202 }
203
204 pub fn with_client(client: Client, timeout: Duration) -> Self {
206 Self { client, timeout }
207 }
208
209 pub fn with_reqwest_client(client: Client, timeout: Duration) -> Self {
214 Self::with_client(client, timeout)
215 }
216
217 pub fn new_with_config(config: HttpTransportConfig) -> Result<Self, TransportError> {
219 let mut client_builder = Client::builder().timeout(config.timeout);
220
221 if let Some(max_idle) = config.pool_max_idle_per_host {
223 client_builder = client_builder.pool_max_idle_per_host(max_idle);
224 }
225 if let Some(idle_timeout) = config.pool_idle_timeout {
226 client_builder = client_builder.pool_idle_timeout(idle_timeout);
227 }
228
229 if let Some(proxy_url) = config.proxy {
230 if let Ok(proxy) = Proxy::all(&proxy_url) {
231 client_builder = client_builder.proxy(proxy);
232 }
233 }
234
235 let client = client_builder
236 .build()
237 .map_err(|e| TransportError::HttpError(e.to_string()))?;
238 Ok(Self {
239 client,
240 timeout: config.timeout,
241 })
242 }
243
244 pub fn with_proxy(timeout: Duration, proxy_url: Option<&str>) -> Result<Self, TransportError> {
246 let mut client_builder = Client::builder().timeout(timeout);
247
248 if let Some(url) = proxy_url {
249 let proxy = Proxy::all(url)
250 .map_err(|e| TransportError::InvalidUrl(format!("Invalid proxy URL: {}", e)))?;
251 client_builder = client_builder.proxy(proxy);
252 }
253
254 let client = client_builder
255 .build()
256 .map_err(|e| TransportError::HttpError(e.to_string()))?;
257
258 Ok(Self { client, timeout })
259 }
260
261 pub fn timeout(&self) -> Duration {
263 self.timeout
264 }
265
266 async fn execute_request<T, R>(
268 &self,
269 method: Method,
270 url: &str,
271 headers: Option<HashMap<String, String>>,
272 body: Option<&T>,
273 ) -> Result<R, TransportError>
274 where
275 T: Serialize + Send + Sync,
276 R: for<'de> Deserialize<'de>,
277 {
278 let mut request_builder = self.client.request(method, url);
279
280 if let Some(headers) = headers {
282 for (key, value) in headers {
283 request_builder = request_builder.header(key, value);
284 }
285 }
286
287 if let Some(body) = body {
289 request_builder = request_builder.json(body);
290 }
291
292 let response = request_builder.send().await?;
294
295 Self::handle_response(response).await
297 }
298
299 fn is_retryable_error(&self, error: &TransportError) -> bool {
301 match error {
302 TransportError::HttpError(err_msg) => {
303 err_msg.contains("timeout") || err_msg.contains("connection")
304 }
305 TransportError::ClientError { status, .. } => {
306 *status == 429 || *status == 502 || *status == 503 || *status == 504
307 }
308 TransportError::ServerError { .. } => true,
309 _ => false,
310 }
311 }
312
313 async fn handle_response<R>(response: Response) -> Result<R, TransportError>
315 where
316 R: for<'de> Deserialize<'de>,
317 {
318 let status = response.status();
319
320 if status.is_success() {
321 let json_text = response.text().await?;
322 let result: R = serde_json::from_str(&json_text)?;
323 Ok(result)
324 } else {
325 let error_text = response.text().await.unwrap_or_default();
326 Err(TransportError::from_status(status.as_u16(), error_text))
327 }
328 }
329}
330
331#[async_trait]
332impl HttpClient for HttpTransport {
333 async fn request<T, R>(
334 &self,
335 method: Method,
336 url: &str,
337 headers: Option<HashMap<String, String>>,
338 body: Option<&T>,
339 ) -> Result<R, TransportError>
340 where
341 T: Serialize + Send + Sync,
342 R: for<'de> Deserialize<'de>,
343 {
344 self.execute_request(method, url, headers, body).await
345 }
346
347 async fn request_with_retry<T, R>(
348 &self,
349 method: Method,
350 url: &str,
351 headers: Option<HashMap<String, String>>,
352 body: Option<&T>,
353 _max_retries: u32,
354 ) -> Result<R, TransportError>
355 where
356 T: Serialize + Send + Sync + Clone,
357 R: for<'de> Deserialize<'de>,
358 {
359 let backoff = ExponentialBackoff {
360 max_elapsed_time: Some(Duration::from_secs(60)),
361 max_interval: Duration::from_secs(10),
362 ..Default::default()
363 };
364
365 let headers_clone = headers.clone();
366 let body_clone = body.cloned();
367 let url_clone = url.to_string();
368
369 retry(backoff, || async {
370 match self
371 .execute_request(
372 method.clone(),
373 &url_clone,
374 headers_clone.clone(),
375 body_clone.as_ref(),
376 )
377 .await
378 {
379 Ok(result) => Ok(result),
380 Err(e) => {
381 if self.is_retryable_error(&e) {
382 Err(backoff::Error::transient(e))
383 } else {
384 Err(backoff::Error::permanent(e))
385 }
386 }
387 }
388 })
389 .await
390 }
391}
392
393impl Default for HttpTransport {
394 fn default() -> Self {
395 Self::new()
396 }
397}
398
399pub struct HttpTransportBoxed {
401 inner: HttpTransport,
402}
403
404impl HttpTransportBoxed {
405 pub fn new(inner: HttpTransport) -> Self {
406 Self { inner }
407 }
408}
409
410use crate::transport::dyn_transport::{DynHttpTransport, DynHttpTransportRef};
411use bytes::Bytes;
412use futures::{Stream, StreamExt};
413use std::pin::Pin;
414use std::sync::Arc;
415
416impl DynHttpTransport for HttpTransportBoxed {
417 fn get_json<'a>(
418 &'a self,
419 url: &'a str,
420 headers: Option<HashMap<String, String>>,
421 ) -> futures::future::BoxFuture<'a, Result<serde_json::Value, crate::types::AiLibError>> {
422 Box::pin(async move {
423 let res: Result<serde_json::Value, TransportError> = self.inner.get(url, headers).await;
424 match res {
425 Ok(v) => Ok(v),
426 Err(e) => Err(map_transport_error_to_ailib(e)),
427 }
428 })
429 }
430
431 fn post_json<'a>(
432 &'a self,
433 url: &'a str,
434 headers: Option<HashMap<String, String>>,
435 body: serde_json::Value,
436 ) -> futures::future::BoxFuture<'a, Result<serde_json::Value, crate::types::AiLibError>> {
437 Box::pin(async move {
438 let res: Result<serde_json::Value, TransportError> =
439 self.inner.post(url, headers, &body).await;
440 match res {
441 Ok(v) => Ok(v),
442 Err(e) => Err(map_transport_error_to_ailib(e)),
443 }
444 })
445 }
446
447 fn post_stream<'a>(
448 &'a self,
449 _url: &'a str,
450 _headers: Option<HashMap<String, String>>,
451 _body: serde_json::Value,
452 ) -> futures::future::BoxFuture<
453 'a,
454 Result<
455 Pin<Box<dyn Stream<Item = Result<Bytes, crate::types::AiLibError>> + Send>>,
456 crate::types::AiLibError,
457 >,
458 > {
459 Box::pin(async move {
460 let mut req = self.inner.client.post(_url).json(&_body);
462 if let Some(h) = _headers {
464 for (k, v) in h.into_iter() {
465 req = req.header(k, v);
466 }
467 }
468 req = req.header("Accept", "text/event-stream");
470
471 let resp = req.send().await.map_err(|e| {
472 if e.is_timeout() {
473 crate::types::AiLibError::TimeoutError(format!("Stream request timeout: {}", e))
474 } else {
475 crate::types::AiLibError::NetworkError(format!("Stream request failed: {}", e))
476 }
477 })?;
478 if !resp.status().is_success() {
479 let status = resp.status();
480 let text = resp.text().await.unwrap_or_default();
481 return Err(map_status_to_ailib(status.as_u16(), text));
482 }
483
484 let byte_stream = resp.bytes_stream().map(|res| match res {
485 Ok(b) => Ok(b),
486 Err(e) => {
487 if e.is_timeout() {
488 Err(crate::types::AiLibError::TimeoutError(format!(
489 "Stream chunk timeout: {}",
490 e
491 )))
492 } else {
493 Err(crate::types::AiLibError::NetworkError(format!(
494 "Stream chunk error: {}",
495 e
496 )))
497 }
498 }
499 });
500
501 let boxed_stream: Pin<
502 Box<dyn Stream<Item = Result<Bytes, crate::types::AiLibError>> + Send>,
503 > = Box::pin(byte_stream);
504 Ok(boxed_stream)
505 })
506 }
507
508 fn upload_multipart<'a>(
509 &'a self,
510 url: &'a str,
511 headers: Option<HashMap<String, String>>,
512 field_name: &'a str,
513 file_name: &'a str,
514 bytes: Vec<u8>,
515 ) -> Pin<
516 Box<
517 dyn futures::Future<Output = Result<serde_json::Value, crate::types::AiLibError>>
518 + Send
519 + 'a,
520 >,
521 > {
522 Box::pin(async move {
523 let part = reqwest::multipart::Part::bytes(bytes).file_name(file_name.to_string());
525 let form = reqwest::multipart::Form::new().part(field_name.to_string(), part);
526
527 let mut req = self.inner.client.post(url).multipart(form);
528 if let Some(h) = headers {
529 for (k, v) in h.into_iter() {
530 req = req.header(k, v);
531 }
532 }
533
534 let resp = req.send().await.map_err(|e| {
535 if e.is_timeout() {
536 crate::types::AiLibError::TimeoutError(format!("upload request timeout: {}", e))
537 } else {
538 crate::types::AiLibError::NetworkError(format!("upload request failed: {}", e))
539 }
540 })?;
541 if !resp.status().is_success() {
542 let status = resp.status();
543 let text = resp.text().await.unwrap_or_default();
544 return Err(map_status_to_ailib(status.as_u16(), text));
545 }
546 let j: serde_json::Value = resp.json().await.map_err(|e| {
547 crate::types::AiLibError::DeserializationError(format!(
548 "parse upload response: {}",
549 e
550 ))
551 })?;
552 Ok(j)
553 })
554 }
555}
556
557impl HttpTransport {
558 pub fn boxed(self) -> DynHttpTransportRef {
560 Arc::new(HttpTransportBoxed::new(self))
561 }
562}
563
564fn map_transport_error_to_ailib(e: TransportError) -> crate::types::AiLibError {
566 use crate::types::AiLibError;
567 match e {
568 TransportError::AuthenticationError(msg) => AiLibError::AuthenticationError(msg),
569 TransportError::RateLimitExceeded => {
570 AiLibError::RateLimitExceeded("rate limited".to_string())
571 }
572 TransportError::Timeout(msg) => AiLibError::TimeoutError(msg),
573 TransportError::ServerError { status, message } => {
574 AiLibError::NetworkError(format!("server {}: {}", status, message))
576 }
577 TransportError::ClientError { status, message } => match status {
578 401 | 403 => AiLibError::AuthenticationError(message),
579 408 => AiLibError::TimeoutError(message),
580 409 | 425 | 429 => AiLibError::RateLimitExceeded(message),
581 _ => AiLibError::InvalidRequest(format!("client {}: {}", status, message)),
582 },
583 TransportError::HttpError(msg) => {
584 if msg.contains("timeout") {
586 AiLibError::TimeoutError(msg)
587 } else {
588 AiLibError::NetworkError(msg)
589 }
590 }
591 TransportError::JsonError(msg) => AiLibError::DeserializationError(msg),
592 TransportError::InvalidUrl(msg) => AiLibError::ConfigurationError(msg),
593 }
594}
595
596fn map_status_to_ailib(status: u16, body: String) -> crate::types::AiLibError {
597 use crate::types::AiLibError;
598 match status {
599 401 | 403 => AiLibError::AuthenticationError(body),
600 408 => AiLibError::TimeoutError(body),
601 409 | 425 | 429 => AiLibError::RateLimitExceeded(body),
602 500..=599 => AiLibError::NetworkError(format!("server {}: {}", status, body)),
603 400 => AiLibError::InvalidRequest(body),
604 _ => AiLibError::ProviderError(format!("http {}: {}", status, body)),
605 }
606}