1use crate::VERSION;
4use crate::error::{Error, Result, StatusCodeExt};
5use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT};
6use reqwest::{Method, Response};
7use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
8use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
9use retry_policies::Jitter;
10use serde::{Deserialize, Serialize};
11use std::path::Path;
12use std::time::Duration;
13
/// Base URL used by `HttpClient` unless one of the `with_base_url*` constructors overrides it.
const DEFAULT_BASE_URL: &str = "https://api.replicate.com";
16
/// Retry policy for transient HTTP failures.
///
/// The values are fed into an exponential-backoff policy: `min_delay`/`max_delay` bound the
/// wait between attempts, `base_multiplier` is the exponential base, and `max_retries` caps
/// the number of retries.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryConfig {
    /// Maximum number of retries after the initial attempt.
    pub max_retries: u32,
    /// Lower bound of the delay between attempts; should not exceed `max_delay`.
    pub min_delay: Duration,
    /// Upper bound of the delay between attempts.
    pub max_delay: Duration,
    /// Exponential base by which the backoff grows between attempts.
    pub base_multiplier: u32,
}

impl Default for RetryConfig {
    /// 3 retries, 500 ms to 30 s between attempts, doubling each time.
    fn default() -> Self {
        Self {
            max_retries: 3,
            min_delay: Duration::from_millis(500),
            max_delay: Duration::from_secs(30),
            base_multiplier: 2,
        }
    }
}
36
/// Connection and request timeouts for the underlying `reqwest` client.
///
/// `None` leaves the corresponding `reqwest` timeout unset.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TimeoutConfig {
    /// Maximum time allowed to establish a connection.
    pub connect_timeout: Option<Duration>,
    /// Maximum total time allowed for a request.
    pub request_timeout: Option<Duration>,
}

impl Default for TimeoutConfig {
    /// 30 s to connect, 60 s per request.
    fn default() -> Self {
        Self {
            connect_timeout: Some(Duration::from_secs(30)),
            request_timeout: Some(Duration::from_secs(60)),
        }
    }
}
52
53#[derive(Debug, Clone, Default)]
55pub struct HttpConfig {
56 pub retry: RetryConfig,
57 pub timeout: TimeoutConfig,
58}
59
60#[derive(Debug, Clone)]
62pub struct HttpClient {
63 client: ClientWithMiddleware,
64 base_url: String,
65 api_token: String,
66 http_config: HttpConfig,
67}
68
69impl HttpClient {
70 pub fn new(api_token: impl Into<String>) -> Result<Self> {
72 Self::with_retry_config(api_token, RetryConfig::default())
73 }
74
75 pub fn with_retry_config(
77 api_token: impl Into<String>,
78 retry_config: RetryConfig,
79 ) -> Result<Self> {
80 let http_config = HttpConfig {
81 retry: retry_config,
82 timeout: TimeoutConfig::default(),
83 };
84 Self::with_http_config(api_token, http_config)
85 }
86
87 pub fn with_http_config(api_token: impl Into<String>, http_config: HttpConfig) -> Result<Self> {
89 let api_token = api_token.into();
90 if api_token.is_empty() {
91 return Err(Error::auth_error("API token cannot be empty"));
92 }
93
94 let client = Self::build_client_with_config(&http_config)?;
95
96 Ok(Self {
97 client,
98 base_url: DEFAULT_BASE_URL.to_string(),
99 api_token,
100 http_config,
101 })
102 }
103
104 fn build_client_with_config(http_config: &HttpConfig) -> Result<ClientWithMiddleware> {
106 let retry_policy = ExponentialBackoff::builder()
108 .retry_bounds(http_config.retry.min_delay, http_config.retry.max_delay)
109 .jitter(Jitter::Bounded)
110 .base(http_config.retry.base_multiplier)
111 .build_with_max_retries(http_config.retry.max_retries);
112
113 let mut client_builder =
115 reqwest::Client::builder().user_agent(format!("replicate-rs/{}", crate::VERSION));
116
117 if let Some(connect_timeout) = http_config.timeout.connect_timeout {
118 client_builder = client_builder.connect_timeout(connect_timeout);
119 }
120
121 if let Some(request_timeout) = http_config.timeout.request_timeout {
122 client_builder = client_builder.timeout(request_timeout);
123 }
124
125 let reqwest_client = client_builder.build()?;
126
127 let client = ClientBuilder::new(reqwest_client)
129 .with(RetryTransientMiddleware::new_with_policy(retry_policy))
130 .build();
131
132 Ok(client)
133 }
134
135 pub fn with_base_url(
137 api_token: impl Into<String>,
138 base_url: impl Into<String>,
139 ) -> Result<Self> {
140 let mut client = Self::new(api_token)?;
141 client.base_url = base_url.into();
142 Ok(client)
143 }
144
145 pub fn with_base_url_and_retry(
147 api_token: impl Into<String>,
148 base_url: impl Into<String>,
149 retry_config: RetryConfig,
150 ) -> Result<Self> {
151 let mut client = Self::with_retry_config(api_token, retry_config)?;
152 client.base_url = base_url.into();
153 Ok(client)
154 }
155
156 pub fn with_base_url_and_http_config(
158 api_token: impl Into<String>,
159 base_url: impl Into<String>,
160 http_config: HttpConfig,
161 ) -> Result<Self> {
162 let mut client = Self::with_http_config(api_token, http_config)?;
163 client.base_url = base_url.into();
164 Ok(client)
165 }
166
167 pub fn inner(&self) -> &ClientWithMiddleware {
169 &self.client
170 }
171
172 fn build_url(&self, path: &str) -> String {
174 let path = path.strip_prefix('/').unwrap_or(path);
175 format!("{}/{}", self.base_url.trim_end_matches('/'), path)
176 }
177
178 async fn execute_request(&self, method: Method, path: &str) -> Result<Response> {
180 let url = self.build_url(path);
181 let response = self
182 .client
183 .request(method, &url)
184 .header("Authorization", format!("Token {}", self.api_token))
185 .header("Content-Type", "application/json")
186 .send()
187 .await?;
188
189 if response.status().is_success() {
190 Ok(response)
191 } else {
192 let status = response.status();
193 let body = response.text().await.unwrap_or_default();
194 Err(status.to_replicate_error(body))
195 }
196 }
197
198 async fn execute_request_with_json<T: Serialize>(
200 &self,
201 method: Method,
202 path: &str,
203 body: &T,
204 ) -> Result<Response> {
205 let url = self.build_url(path);
206 let json_body = serde_json::to_vec(body)?;
207 let response = self
208 .client
209 .request(method, &url)
210 .header("Authorization", format!("Token {}", self.api_token))
211 .header("Content-Type", "application/json")
212 .body(json_body)
213 .send()
214 .await?;
215
216 if response.status().is_success() {
217 Ok(response)
218 } else {
219 let status = response.status();
220 let body = response.text().await.unwrap_or_default();
221 Err(status.to_replicate_error(body))
222 }
223 }
224
225 pub async fn get(&self, path: &str) -> Result<Response> {
227 self.execute_request(Method::GET, path).await
228 }
229
230 pub async fn post<T: Serialize>(&self, path: &str, body: &T) -> Result<Response> {
232 self.execute_request_with_json(Method::POST, path, body)
233 .await
234 }
235
236 pub async fn post_empty(&self, path: &str) -> Result<Response> {
238 self.execute_request(Method::POST, path).await
239 }
240
241 pub async fn put<T: Serialize>(&self, path: &str, body: &T) -> Result<Response> {
243 self.execute_request_with_json(Method::PUT, path, body)
244 .await
245 }
246
247 pub async fn delete(&self, path: &str) -> Result<Response> {
249 self.execute_request(Method::DELETE, path).await
250 }
251
252 pub async fn get_json<T: for<'de> Deserialize<'de>>(&self, path: &str) -> Result<T> {
254 let response = self.get(path).await?;
255 let json = response.json().await?;
256 Ok(json)
257 }
258
259 pub async fn post_json<B: Serialize, T: for<'de> Deserialize<'de>>(
261 &self,
262 path: &str,
263 body: &B,
264 ) -> Result<T> {
265 let response = self.post(path, body).await?;
266 let json = response.json().await?;
267 Ok(json)
268 }
269
270 pub async fn post_empty_json<T: for<'de> Deserialize<'de>>(&self, path: &str) -> Result<T> {
272 let response = self.post_empty(path).await?;
273 let json = response.json().await?;
274 Ok(json)
275 }
276
277 pub fn configure_retries(
300 &mut self,
301 max_retries: u32,
302 min_delay: Duration,
303 max_delay: Duration,
304 ) -> Result<()> {
305 self.configure_retries_advanced(max_retries, min_delay, max_delay, 2)
306 }
307
308 pub fn configure_retries_advanced(
319 &mut self,
320 max_retries: u32,
321 min_delay: Duration,
322 max_delay: Duration,
323 base_multiplier: u32,
324 ) -> Result<()> {
325 let new_retry_config = RetryConfig {
326 max_retries,
327 min_delay,
328 max_delay,
329 base_multiplier,
330 };
331
332 let new_http_config = HttpConfig {
333 retry: new_retry_config,
334 timeout: self.http_config.timeout.clone(),
335 };
336
337 let new_client = Self::build_client_with_config(&new_http_config)?;
339
340 self.client = new_client;
342 self.http_config = new_http_config;
343
344 Ok(())
345 }
346
347 pub fn configure_timeouts(
356 &mut self,
357 connect_timeout: Option<Duration>,
358 request_timeout: Option<Duration>,
359 ) -> Result<()> {
360 let new_timeout_config = TimeoutConfig {
361 connect_timeout,
362 request_timeout,
363 };
364
365 let new_http_config = HttpConfig {
366 retry: self.http_config.retry.clone(),
367 timeout: new_timeout_config,
368 };
369
370 let new_client = Self::build_client_with_config(&new_http_config)?;
372
373 self.client = new_client;
375 self.http_config = new_http_config;
376
377 Ok(())
378 }
379
380 pub fn retry_config(&self) -> &RetryConfig {
382 &self.http_config.retry
383 }
384
385 pub fn timeout_config(&self) -> &TimeoutConfig {
387 &self.http_config.timeout
388 }
389
390 pub fn http_config(&self) -> &HttpConfig {
392 &self.http_config
393 }
394
395 async fn execute_multipart_request(
397 &self,
398 method: Method,
399 path: &str,
400 form: reqwest::multipart::Form,
401 ) -> Result<Response> {
402 let url = self.build_url(path);
403
404 let mut headers = HeaderMap::new();
405 headers.insert(
406 AUTHORIZATION,
407 HeaderValue::from_str(&format!("Token {}", self.api_token))
408 .map_err(|_| Error::auth_error("Invalid API token format"))?,
409 );
410 headers.insert(
411 USER_AGENT,
412 HeaderValue::from_str(&format!("replicate-rs/{}", VERSION))
413 .map_err(|_| Error::InvalidInput("Invalid user agent format".to_string()))?,
414 );
415
416 let inner_client = reqwest::Client::new();
419 let request = inner_client
420 .request(method, &url)
421 .headers(headers)
422 .multipart(form);
423
424 let response = request.send().await?;
425
426 if response.status().is_success() {
427 Ok(response)
428 } else {
429 let status = response.status().as_u16();
430 let text = response.text().await.unwrap_or_default();
431
432 if let Ok(api_error) = serde_json::from_str::<serde_json::Value>(&text) {
434 let message = api_error
435 .get("detail")
436 .and_then(|v| v.as_str())
437 .unwrap_or("Unknown API error");
438
439 Err(Error::Api {
440 status,
441 message: message.to_string(),
442 detail: Some(text),
443 })
444 } else {
445 Err(Error::Api {
446 status,
447 message: text,
448 detail: None,
449 })
450 }
451 }
452 }
453
454 pub async fn post_multipart(
456 &self,
457 path: &str,
458 form: reqwest::multipart::Form,
459 ) -> Result<Response> {
460 self.execute_multipart_request(Method::POST, path, form)
461 .await
462 }
463
464 pub async fn post_multipart_json<T: for<'de> serde::Deserialize<'de>>(
466 &self,
467 path: &str,
468 form: reqwest::multipart::Form,
469 ) -> Result<T> {
470 let response = self.post_multipart(path, form).await?;
471 let text = response.text().await?;
472 serde_json::from_str(&text).map_err(Into::into)
473 }
474
475 pub async fn create_file_form(
477 file_content: &[u8],
478 filename: Option<&str>,
479 content_type: Option<&str>,
480 metadata: Option<&std::collections::HashMap<String, serde_json::Value>>,
481 ) -> Result<reqwest::multipart::Form> {
482 let filename = filename.unwrap_or("file").to_string();
483 let content_type = content_type
484 .unwrap_or("application/octet-stream")
485 .to_string();
486
487 let file_part = reqwest::multipart::Part::bytes(file_content.to_vec())
488 .file_name(filename)
489 .mime_str(&content_type)
490 .map_err(|e| Error::InvalidInput(format!("Invalid content type: {}", e)))?;
491
492 let mut form = reqwest::multipart::Form::new().part("content", file_part);
493
494 if let Some(metadata) = metadata {
496 let metadata_json = serde_json::to_string(metadata)?;
497 form = form.text("metadata", metadata_json);
498 }
499
500 Ok(form)
501 }
502
503 pub async fn create_file_form_from_path(
505 file_path: &Path,
506 metadata: Option<&std::collections::HashMap<String, serde_json::Value>>,
507 ) -> Result<reqwest::multipart::Form> {
508 let file_content = tokio::fs::read(file_path).await?;
510
511 let filename = file_path
513 .file_name()
514 .and_then(|n| n.to_str())
515 .unwrap_or("file");
516
517 let content_type = mime_guess::from_path(file_path)
518 .first_or_octet_stream()
519 .to_string();
520
521 Self::create_file_form(&file_content, Some(filename), Some(&content_type), metadata).await
522 }
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    /// Client with a dummy token; construction performs no network I/O.
    fn client() -> HttpClient {
        HttpClient::new("test-token").expect("a non-empty token must be accepted")
    }

    /// Asserts every field of a retry configuration.
    fn assert_retry(cfg: &RetryConfig, retries: u32, min: Duration, max: Duration, base: u32) {
        assert_eq!(cfg.max_retries, retries);
        assert_eq!(cfg.min_delay, min);
        assert_eq!(cfg.max_delay, max);
        assert_eq!(cfg.base_multiplier, base);
    }

    /// Asserts both timeouts of a timeout configuration.
    fn assert_timeouts(cfg: &TimeoutConfig, connect: Option<Duration>, request: Option<Duration>) {
        assert_eq!(cfg.connect_timeout, connect);
        assert_eq!(cfg.request_timeout, request);
    }

    #[test]
    fn test_build_url() {
        let http = client();
        for path in ["/v1/predictions", "v1/predictions"] {
            assert_eq!(
                http.build_url(path),
                "https://api.replicate.com/v1/predictions"
            );
        }
    }

    #[test]
    fn test_empty_token_error() {
        match HttpClient::new("") {
            Err(err) => assert!(matches!(err, Error::Auth(_))),
            Ok(_) => panic!("an empty token must be rejected"),
        }
    }

    #[test]
    fn test_client_creation_with_retry() {
        let http = client();
        let _inner = http.inner();
        assert_retry(
            http.retry_config(),
            3,
            Duration::from_millis(500),
            Duration::from_secs(30),
            2,
        );
    }

    #[test]
    fn test_retry_configuration() {
        let mut http = client();
        assert_eq!(http.retry_config().max_retries, 3);

        http.configure_retries(5, Duration::from_millis(100), Duration::from_secs(60))
            .expect("valid retry settings must be accepted");

        assert_retry(
            http.retry_config(),
            5,
            Duration::from_millis(100),
            Duration::from_secs(60),
            2,
        );
    }

    #[test]
    fn test_custom_retry_config() {
        let custom = RetryConfig {
            max_retries: 2,
            min_delay: Duration::from_millis(200),
            max_delay: Duration::from_secs(10),
            base_multiplier: 3,
        };

        let http = HttpClient::with_retry_config("test-token", custom.clone())
            .expect("custom retry config must be accepted");

        assert_retry(
            http.retry_config(),
            custom.max_retries,
            custom.min_delay,
            custom.max_delay,
            custom.base_multiplier,
        );
    }

    #[test]
    fn test_timeout_configuration() {
        let http_config = HttpConfig {
            retry: RetryConfig::default(),
            timeout: TimeoutConfig {
                connect_timeout: Some(Duration::from_secs(15)),
                request_timeout: Some(Duration::from_secs(90)),
            },
        };

        let http = HttpClient::with_http_config("test-token", http_config)
            .expect("custom http config must be accepted");

        assert_timeouts(
            http.timeout_config(),
            Some(Duration::from_secs(15)),
            Some(Duration::from_secs(90)),
        );
    }

    #[test]
    fn test_timeout_reconfiguration() {
        let mut http = client();
        assert_timeouts(
            http.timeout_config(),
            Some(Duration::from_secs(30)),
            Some(Duration::from_secs(60)),
        );

        http.configure_timeouts(Some(Duration::from_secs(5)), Some(Duration::from_secs(120)))
            .expect("valid timeouts must be accepted");

        assert_timeouts(
            http.timeout_config(),
            Some(Duration::from_secs(5)),
            Some(Duration::from_secs(120)),
        );
    }

    #[test]
    fn test_timeout_disable() {
        let mut http = client();

        http.configure_timeouts(None, None)
            .expect("disabling timeouts must be accepted");

        assert_timeouts(http.timeout_config(), None, None);
    }

    #[test]
    fn test_http_config_accessors() {
        let http_config = HttpConfig {
            retry: RetryConfig {
                max_retries: 2,
                min_delay: Duration::from_millis(100),
                max_delay: Duration::from_secs(20),
                base_multiplier: 4,
            },
            timeout: TimeoutConfig {
                connect_timeout: Some(Duration::from_secs(10)),
                request_timeout: Some(Duration::from_secs(45)),
            },
        };

        let http = HttpClient::with_http_config("test-token", http_config)
            .expect("custom http config must be accepted");

        let returned = http.http_config();
        assert_retry(
            &returned.retry,
            2,
            Duration::from_millis(100),
            Duration::from_secs(20),
            4,
        );
        assert_timeouts(
            &returned.timeout,
            Some(Duration::from_secs(10)),
            Some(Duration::from_secs(45)),
        );
    }
}