replicate_client/http/client.rs

1//! HTTP client implementation for the Replicate API with retry logic.
2
3use crate::VERSION;
4use crate::error::{Error, Result, StatusCodeExt};
5use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT};
6use reqwest::{Method, Response};
7use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
8use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
9use retry_policies::Jitter;
10use serde::{Deserialize, Serialize};
11use std::path::Path;
12use std::time::Duration;
13
/// Root of the public Replicate HTTP API; request paths are joined onto it (no trailing slash).
const DEFAULT_BASE_URL: &str = "https://api.replicate.com";
16
/// Configuration for retry behavior.
///
/// Transient failures are retried by the client's retry middleware using an
/// exponential backoff bounded by `min_delay` and `max_delay`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryConfig {
    /// Maximum number of retry attempts.
    pub max_retries: u32,
    /// Minimum delay between retries; should not exceed `max_delay`.
    pub min_delay: Duration,
    /// Maximum delay between retries.
    pub max_delay: Duration,
    /// Base multiplier for the exponential backoff (typically 2).
    pub base_multiplier: u32,
}
25
26impl Default for RetryConfig {
27    fn default() -> Self {
28        Self {
29            max_retries: 3,
30            min_delay: Duration::from_millis(500),
31            max_delay: Duration::from_secs(30),
32            base_multiplier: 2,
33        }
34    }
35}
36
/// Configuration for HTTP timeouts.
///
/// Each field is optional; `None` disables that timeout entirely.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TimeoutConfig {
    /// Maximum time to wait for the connection to be established (`None` = no timeout).
    pub connect_timeout: Option<Duration>,
    /// Maximum time to wait for the complete request (`None` = no timeout).
    pub request_timeout: Option<Duration>,
}
43
44impl Default for TimeoutConfig {
45    fn default() -> Self {
46        Self {
47            connect_timeout: Some(Duration::from_secs(30)),
48            request_timeout: Some(Duration::from_secs(60)),
49        }
50    }
51}
52
53/// Combined HTTP client configuration.
54#[derive(Debug, Clone, Default)]
55pub struct HttpConfig {
56    pub retry: RetryConfig,
57    pub timeout: TimeoutConfig,
58}
59
60/// HTTP client for making requests to the Replicate API with retry logic.
61#[derive(Debug, Clone)]
62pub struct HttpClient {
63    client: ClientWithMiddleware,
64    base_url: String,
65    api_token: String,
66    http_config: HttpConfig,
67}
68
69impl HttpClient {
70    /// Create a new HTTP client with the given API token and default retry logic.
71    pub fn new(api_token: impl Into<String>) -> Result<Self> {
72        Self::with_retry_config(api_token, RetryConfig::default())
73    }
74
75    /// Create a new HTTP client with the given API token and custom retry configuration.
76    pub fn with_retry_config(
77        api_token: impl Into<String>,
78        retry_config: RetryConfig,
79    ) -> Result<Self> {
80        let http_config = HttpConfig {
81            retry: retry_config,
82            timeout: TimeoutConfig::default(),
83        };
84        Self::with_http_config(api_token, http_config)
85    }
86
87    /// Create a new HTTP client with the given API token and custom HTTP configuration.
88    pub fn with_http_config(api_token: impl Into<String>, http_config: HttpConfig) -> Result<Self> {
89        let api_token = api_token.into();
90        if api_token.is_empty() {
91            return Err(Error::auth_error("API token cannot be empty"));
92        }
93
94        let client = Self::build_client_with_config(&http_config)?;
95
96        Ok(Self {
97            client,
98            base_url: DEFAULT_BASE_URL.to_string(),
99            api_token,
100            http_config,
101        })
102    }
103
104    /// Build a reqwest client with retry middleware and timeout configuration.
105    fn build_client_with_config(http_config: &HttpConfig) -> Result<ClientWithMiddleware> {
106        // Create exponential backoff retry policy
107        let retry_policy = ExponentialBackoff::builder()
108            .retry_bounds(http_config.retry.min_delay, http_config.retry.max_delay)
109            .jitter(Jitter::Bounded)
110            .base(http_config.retry.base_multiplier)
111            .build_with_max_retries(http_config.retry.max_retries);
112
113        // Build reqwest client with timeout configuration
114        let mut client_builder =
115            reqwest::Client::builder().user_agent(format!("replicate-rs/{}", crate::VERSION));
116
117        if let Some(connect_timeout) = http_config.timeout.connect_timeout {
118            client_builder = client_builder.connect_timeout(connect_timeout);
119        }
120
121        if let Some(request_timeout) = http_config.timeout.request_timeout {
122            client_builder = client_builder.timeout(request_timeout);
123        }
124
125        let reqwest_client = client_builder.build()?;
126
127        // Build client with retry middleware
128        let client = ClientBuilder::new(reqwest_client)
129            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
130            .build();
131
132        Ok(client)
133    }
134
135    /// Create a new HTTP client with custom base URL.
136    pub fn with_base_url(
137        api_token: impl Into<String>,
138        base_url: impl Into<String>,
139    ) -> Result<Self> {
140        let mut client = Self::new(api_token)?;
141        client.base_url = base_url.into();
142        Ok(client)
143    }
144
145    /// Create a new HTTP client with custom base URL and retry configuration.
146    pub fn with_base_url_and_retry(
147        api_token: impl Into<String>,
148        base_url: impl Into<String>,
149        retry_config: RetryConfig,
150    ) -> Result<Self> {
151        let mut client = Self::with_retry_config(api_token, retry_config)?;
152        client.base_url = base_url.into();
153        Ok(client)
154    }
155
156    /// Create a new HTTP client with custom base URL and HTTP configuration.
157    pub fn with_base_url_and_http_config(
158        api_token: impl Into<String>,
159        base_url: impl Into<String>,
160        http_config: HttpConfig,
161    ) -> Result<Self> {
162        let mut client = Self::with_http_config(api_token, http_config)?;
163        client.base_url = base_url.into();
164        Ok(client)
165    }
166
167    /// Get a reference to the underlying client with middleware.
168    pub fn inner(&self) -> &ClientWithMiddleware {
169        &self.client
170    }
171
172    /// Build a full URL from a path.
173    fn build_url(&self, path: &str) -> String {
174        let path = path.strip_prefix('/').unwrap_or(path);
175        format!("{}/{}", self.base_url.trim_end_matches('/'), path)
176    }
177
178    /// Execute a request and handle errors.
179    async fn execute_request(&self, method: Method, path: &str) -> Result<Response> {
180        let url = self.build_url(path);
181        let response = self
182            .client
183            .request(method, &url)
184            .header("Authorization", format!("Token {}", self.api_token))
185            .header("Content-Type", "application/json")
186            .send()
187            .await?;
188
189        if response.status().is_success() {
190            Ok(response)
191        } else {
192            let status = response.status();
193            let body = response.text().await.unwrap_or_default();
194            Err(status.to_replicate_error(body))
195        }
196    }
197
198    /// Execute a request with JSON body and handle errors.
199    async fn execute_request_with_json<T: Serialize>(
200        &self,
201        method: Method,
202        path: &str,
203        body: &T,
204    ) -> Result<Response> {
205        let url = self.build_url(path);
206        let json_body = serde_json::to_vec(body)?;
207        let response = self
208            .client
209            .request(method, &url)
210            .header("Authorization", format!("Token {}", self.api_token))
211            .header("Content-Type", "application/json")
212            .body(json_body)
213            .send()
214            .await?;
215
216        if response.status().is_success() {
217            Ok(response)
218        } else {
219            let status = response.status();
220            let body = response.text().await.unwrap_or_default();
221            Err(status.to_replicate_error(body))
222        }
223    }
224
225    /// Make a GET request.
226    pub async fn get(&self, path: &str) -> Result<Response> {
227        self.execute_request(Method::GET, path).await
228    }
229
230    /// Make a POST request with JSON body.
231    pub async fn post<T: Serialize>(&self, path: &str, body: &T) -> Result<Response> {
232        self.execute_request_with_json(Method::POST, path, body)
233            .await
234    }
235
236    /// Make a POST request without a body.
237    pub async fn post_empty(&self, path: &str) -> Result<Response> {
238        self.execute_request(Method::POST, path).await
239    }
240
241    /// Make a PUT request with JSON body.
242    pub async fn put<T: Serialize>(&self, path: &str, body: &T) -> Result<Response> {
243        self.execute_request_with_json(Method::PUT, path, body)
244            .await
245    }
246
247    /// Make a DELETE request.
248    pub async fn delete(&self, path: &str) -> Result<Response> {
249        self.execute_request(Method::DELETE, path).await
250    }
251
252    /// Make a GET request and deserialize the response as JSON.
253    pub async fn get_json<T: for<'de> Deserialize<'de>>(&self, path: &str) -> Result<T> {
254        let response = self.get(path).await?;
255        let json = response.json().await?;
256        Ok(json)
257    }
258
259    /// Make a POST request and deserialize the response as JSON.
260    pub async fn post_json<B: Serialize, T: for<'de> Deserialize<'de>>(
261        &self,
262        path: &str,
263        body: &B,
264    ) -> Result<T> {
265        let response = self.post(path, body).await?;
266        let json = response.json().await?;
267        Ok(json)
268    }
269
270    /// Make a POST request without body and deserialize the response as JSON.
271    pub async fn post_empty_json<T: for<'de> Deserialize<'de>>(&self, path: &str) -> Result<T> {
272        let response = self.post_empty(path).await?;
273        let json = response.json().await?;
274        Ok(json)
275    }
276
277    /// Configure retry policy for this client.
278    ///
279    /// This rebuilds the underlying HTTP client with new retry settings.
280    ///
281    /// # Examples
282    ///
283    /// ```no_run
284    /// # use replicate_client::Client;
285    /// # use std::time::Duration;
286    /// # #[tokio::main]
287    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
288    /// let mut client = Client::new("your-api-token")?;
289    ///
290    /// // Configure more aggressive retry settings
291    /// client.http_client_mut().configure_retries(
292    ///     5,                               // max_retries
293    ///     Duration::from_millis(100),      // min_delay
294    ///     Duration::from_secs(60),         // max_delay
295    /// )?;
296    /// # Ok(())
297    /// # }
298    /// ```
299    pub fn configure_retries(
300        &mut self,
301        max_retries: u32,
302        min_delay: Duration,
303        max_delay: Duration,
304    ) -> Result<()> {
305        self.configure_retries_advanced(max_retries, min_delay, max_delay, 2)
306    }
307
308    /// Configure retry policy with advanced settings.
309    ///
310    /// This rebuilds the underlying HTTP client with new retry settings.
311    ///
312    /// # Arguments
313    ///
314    /// * `max_retries` - Maximum number of retry attempts
315    /// * `min_delay` - Minimum delay between retries
316    /// * `max_delay` - Maximum delay between retries
317    /// * `base_multiplier` - Base multiplier for exponential backoff (typically 2)
318    pub fn configure_retries_advanced(
319        &mut self,
320        max_retries: u32,
321        min_delay: Duration,
322        max_delay: Duration,
323        base_multiplier: u32,
324    ) -> Result<()> {
325        let new_retry_config = RetryConfig {
326            max_retries,
327            min_delay,
328            max_delay,
329            base_multiplier,
330        };
331
332        let new_http_config = HttpConfig {
333            retry: new_retry_config,
334            timeout: self.http_config.timeout.clone(),
335        };
336
337        // Rebuild the client with new configuration
338        let new_client = Self::build_client_with_config(&new_http_config)?;
339
340        // Update the client and configuration
341        self.client = new_client;
342        self.http_config = new_http_config;
343
344        Ok(())
345    }
346
347    /// Configure timeout settings for this client.
348    ///
349    /// This rebuilds the underlying HTTP client with new timeout settings.
350    ///
351    /// # Arguments
352    ///
353    /// * `connect_timeout` - Maximum time to wait for connection establishment (None = no timeout)
354    /// * `request_timeout` - Maximum time to wait for complete request (None = no timeout)
355    pub fn configure_timeouts(
356        &mut self,
357        connect_timeout: Option<Duration>,
358        request_timeout: Option<Duration>,
359    ) -> Result<()> {
360        let new_timeout_config = TimeoutConfig {
361            connect_timeout,
362            request_timeout,
363        };
364
365        let new_http_config = HttpConfig {
366            retry: self.http_config.retry.clone(),
367            timeout: new_timeout_config,
368        };
369
370        // Rebuild the client with new configuration
371        let new_client = Self::build_client_with_config(&new_http_config)?;
372
373        // Update the client and configuration
374        self.client = new_client;
375        self.http_config = new_http_config;
376
377        Ok(())
378    }
379
380    /// Get the current retry configuration.
381    pub fn retry_config(&self) -> &RetryConfig {
382        &self.http_config.retry
383    }
384
385    /// Get the current timeout configuration.
386    pub fn timeout_config(&self) -> &TimeoutConfig {
387        &self.http_config.timeout
388    }
389
390    /// Get the current HTTP configuration.
391    pub fn http_config(&self) -> &HttpConfig {
392        &self.http_config
393    }
394
395    /// Execute a multipart form request.
396    async fn execute_multipart_request(
397        &self,
398        method: Method,
399        path: &str,
400        form: reqwest::multipart::Form,
401    ) -> Result<Response> {
402        let url = self.build_url(path);
403
404        let mut headers = HeaderMap::new();
405        headers.insert(
406            AUTHORIZATION,
407            HeaderValue::from_str(&format!("Token {}", self.api_token))
408                .map_err(|_| Error::auth_error("Invalid API token format"))?,
409        );
410        headers.insert(
411            USER_AGENT,
412            HeaderValue::from_str(&format!("replicate-rs/{}", VERSION))
413                .map_err(|_| Error::InvalidInput("Invalid user agent format".to_string()))?,
414        );
415
416        // For multipart requests, we need to use the underlying reqwest client directly
417        // since reqwest-middleware doesn't support multipart forms
418        let inner_client = reqwest::Client::new();
419        let request = inner_client
420            .request(method, &url)
421            .headers(headers)
422            .multipart(form);
423
424        let response = request.send().await?;
425
426        if response.status().is_success() {
427            Ok(response)
428        } else {
429            let status = response.status().as_u16();
430            let text = response.text().await.unwrap_or_default();
431
432            // Try to parse as JSON error
433            if let Ok(api_error) = serde_json::from_str::<serde_json::Value>(&text) {
434                let message = api_error
435                    .get("detail")
436                    .and_then(|v| v.as_str())
437                    .unwrap_or("Unknown API error");
438
439                Err(Error::Api {
440                    status,
441                    message: message.to_string(),
442                    detail: Some(text),
443                })
444            } else {
445                Err(Error::Api {
446                    status,
447                    message: text,
448                    detail: None,
449                })
450            }
451        }
452    }
453
454    /// POST request with multipart form data.
455    pub async fn post_multipart(
456        &self,
457        path: &str,
458        form: reqwest::multipart::Form,
459    ) -> Result<Response> {
460        self.execute_multipart_request(Method::POST, path, form)
461            .await
462    }
463
464    /// POST multipart form data and parse JSON response.
465    pub async fn post_multipart_json<T: for<'de> serde::Deserialize<'de>>(
466        &self,
467        path: &str,
468        form: reqwest::multipart::Form,
469    ) -> Result<T> {
470        let response = self.post_multipart(path, form).await?;
471        let text = response.text().await?;
472        serde_json::from_str(&text).map_err(Into::into)
473    }
474
475    /// Create a multipart form from file and optional metadata.
476    pub async fn create_file_form(
477        file_content: &[u8],
478        filename: Option<&str>,
479        content_type: Option<&str>,
480        metadata: Option<&std::collections::HashMap<String, serde_json::Value>>,
481    ) -> Result<reqwest::multipart::Form> {
482        let filename = filename.unwrap_or("file").to_string();
483        let content_type = content_type
484            .unwrap_or("application/octet-stream")
485            .to_string();
486
487        let file_part = reqwest::multipart::Part::bytes(file_content.to_vec())
488            .file_name(filename)
489            .mime_str(&content_type)
490            .map_err(|e| Error::InvalidInput(format!("Invalid content type: {}", e)))?;
491
492        let mut form = reqwest::multipart::Form::new().part("content", file_part);
493
494        // Add metadata if provided
495        if let Some(metadata) = metadata {
496            let metadata_json = serde_json::to_string(metadata)?;
497            form = form.text("metadata", metadata_json);
498        }
499
500        Ok(form)
501    }
502
503    /// Create a multipart form from a file path.
504    pub async fn create_file_form_from_path(
505        file_path: &Path,
506        metadata: Option<&std::collections::HashMap<String, serde_json::Value>>,
507    ) -> Result<reqwest::multipart::Form> {
508        // Read file content
509        let file_content = tokio::fs::read(file_path).await?;
510
511        // Determine filename and content type
512        let filename = file_path
513            .file_name()
514            .and_then(|n| n.to_str())
515            .unwrap_or("file");
516
517        let content_type = mime_guess::from_path(file_path)
518            .first_or_octet_stream()
519            .to_string();
520
521        Self::create_file_form(&file_content, Some(filename), Some(&content_type), metadata).await
522    }
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a client with the default configuration, failing the test loudly otherwise.
    fn default_client() -> HttpClient {
        HttpClient::new("test-token").expect("a non-empty token should build a client")
    }

    #[test]
    fn test_build_url() {
        let client = default_client();

        assert_eq!(
            client.build_url("/v1/predictions"),
            "https://api.replicate.com/v1/predictions"
        );
        assert_eq!(
            client.build_url("v1/predictions"),
            "https://api.replicate.com/v1/predictions"
        );
    }

    #[test]
    fn test_build_url_with_custom_base_url() {
        // A trailing slash on the base URL must not produce `//` in the joined URL.
        let client = HttpClient::with_base_url("test-token", "http://localhost:8080/")
            .expect("custom base URL client should build");

        assert_eq!(client.build_url("/v1/files"), "http://localhost:8080/v1/files");
        assert_eq!(client.build_url("v1/files"), "http://localhost:8080/v1/files");
    }

    #[test]
    fn test_empty_token_error() {
        assert!(matches!(HttpClient::new(""), Err(Error::Auth(_))));

        // The base-URL constructors must reject an empty token too.
        assert!(matches!(
            HttpClient::with_base_url("", "http://localhost:8080"),
            Err(Error::Auth(_))
        ));
    }

    #[test]
    fn test_client_creation_with_retry() {
        let client = default_client();

        // The underlying client is wrapped in retry middleware.
        let _inner: &ClientWithMiddleware = client.inner();

        let retry_config = client.retry_config();
        assert_eq!(retry_config.max_retries, 3);
        assert_eq!(retry_config.min_delay, Duration::from_millis(500));
        assert_eq!(retry_config.max_delay, Duration::from_secs(30));
        assert_eq!(retry_config.base_multiplier, 2);
    }

    #[test]
    fn test_retry_configuration() {
        let mut client = default_client();
        assert_eq!(client.retry_config().max_retries, 3);

        client
            .configure_retries(5, Duration::from_millis(100), Duration::from_secs(60))
            .expect("valid retry settings should apply");

        let new_config = client.retry_config();
        assert_eq!(new_config.max_retries, 5);
        assert_eq!(new_config.min_delay, Duration::from_millis(100));
        assert_eq!(new_config.max_delay, Duration::from_secs(60));
        assert_eq!(new_config.base_multiplier, 2);
    }

    #[test]
    fn test_custom_retry_config() {
        let custom_config = RetryConfig {
            max_retries: 2,
            min_delay: Duration::from_millis(200),
            max_delay: Duration::from_secs(10),
            base_multiplier: 3,
        };

        let client = HttpClient::with_retry_config("test-token", custom_config.clone())
            .expect("custom retry config should build");

        let actual_config = client.retry_config();
        assert_eq!(actual_config.max_retries, custom_config.max_retries);
        assert_eq!(actual_config.min_delay, custom_config.min_delay);
        assert_eq!(actual_config.max_delay, custom_config.max_delay);
        assert_eq!(actual_config.base_multiplier, custom_config.base_multiplier);
    }

    #[test]
    fn test_timeout_configuration() {
        let http_config = HttpConfig {
            retry: RetryConfig::default(),
            timeout: TimeoutConfig {
                connect_timeout: Some(Duration::from_secs(15)),
                request_timeout: Some(Duration::from_secs(90)),
            },
        };

        let client = HttpClient::with_http_config("test-token", http_config)
            .expect("custom timeout config should build");

        let timeouts = client.timeout_config();
        assert_eq!(timeouts.connect_timeout, Some(Duration::from_secs(15)));
        assert_eq!(timeouts.request_timeout, Some(Duration::from_secs(90)));
    }

    #[test]
    fn test_timeout_reconfiguration() {
        let mut client = default_client();

        // Initial state should be the defaults.
        let initial = client.timeout_config();
        assert_eq!(initial.connect_timeout, Some(Duration::from_secs(30)));
        assert_eq!(initial.request_timeout, Some(Duration::from_secs(60)));

        client
            .configure_timeouts(Some(Duration::from_secs(5)), Some(Duration::from_secs(120)))
            .expect("valid timeouts should apply");

        let updated = client.timeout_config();
        assert_eq!(updated.connect_timeout, Some(Duration::from_secs(5)));
        assert_eq!(updated.request_timeout, Some(Duration::from_secs(120)));
    }

    #[test]
    fn test_timeout_disable() {
        let mut client = default_client();

        client
            .configure_timeouts(None, None)
            .expect("disabling timeouts should succeed");

        let config = client.timeout_config();
        assert_eq!(config.connect_timeout, None);
        assert_eq!(config.request_timeout, None);
    }

    #[test]
    fn test_reconfiguration_preserves_other_settings() {
        let mut client = default_client();

        client
            .configure_retries_advanced(4, Duration::from_millis(250), Duration::from_secs(8), 3)
            .expect("valid retry settings should apply");
        client
            .configure_timeouts(Some(Duration::from_secs(2)), None)
            .expect("valid timeouts should apply");

        // Changing timeouts keeps the retry settings.
        assert_eq!(client.retry_config().max_retries, 4);
        assert_eq!(client.retry_config().base_multiplier, 3);

        client
            .configure_retries(1, Duration::from_millis(10), Duration::from_millis(20))
            .expect("valid retry settings should apply");

        // Changing retries keeps the timeout settings.
        assert_eq!(client.timeout_config().connect_timeout, Some(Duration::from_secs(2)));
        assert_eq!(client.timeout_config().request_timeout, None);
    }

    #[test]
    fn test_base_url_and_http_config_constructor() {
        let http_config = HttpConfig {
            retry: RetryConfig {
                max_retries: 1,
                min_delay: Duration::from_millis(50),
                max_delay: Duration::from_secs(5),
                base_multiplier: 3,
            },
            timeout: TimeoutConfig {
                connect_timeout: None,
                request_timeout: Some(Duration::from_secs(10)),
            },
        };

        let client = HttpClient::with_base_url_and_http_config(
            "test-token",
            "http://localhost:9000",
            http_config,
        )
        .expect("client with base URL and HTTP config should build");

        assert_eq!(client.build_url("v1/models"), "http://localhost:9000/v1/models");
        assert_eq!(client.retry_config().max_retries, 1);
        assert_eq!(client.retry_config().base_multiplier, 3);
        assert_eq!(client.timeout_config().connect_timeout, None);
        assert_eq!(
            client.timeout_config().request_timeout,
            Some(Duration::from_secs(10))
        );
    }

    #[test]
    fn test_http_config_accessors() {
        let http_config = HttpConfig {
            retry: RetryConfig {
                max_retries: 2,
                min_delay: Duration::from_millis(100),
                max_delay: Duration::from_secs(20),
                base_multiplier: 4,
            },
            timeout: TimeoutConfig {
                connect_timeout: Some(Duration::from_secs(10)),
                request_timeout: Some(Duration::from_secs(45)),
            },
        };

        let client = HttpClient::with_http_config("test-token", http_config)
            .expect("custom HTTP config should build");

        let returned = client.http_config();
        assert_eq!(returned.retry.max_retries, 2);
        assert_eq!(returned.retry.min_delay, Duration::from_millis(100));
        assert_eq!(returned.retry.max_delay, Duration::from_secs(20));
        assert_eq!(returned.retry.base_multiplier, 4);
        assert_eq!(returned.timeout.connect_timeout, Some(Duration::from_secs(10)));
        assert_eq!(returned.timeout.request_timeout, Some(Duration::from_secs(45)));
    }
}