replicate_client/
client.rs

1//! Main client implementation for the Replicate API.
2
3use crate::api::{FilesApi, PredictionsApi, predictions::PredictionBuilder};
4use crate::error::{Error, Result};
5use crate::http::{HttpClient, HttpConfig, TimeoutConfig};
6use std::{env, time::Duration};
7
8/// Main client for interacting with the Replicate API.
9#[derive(Debug, Clone)]
10pub struct Client {
11    http: HttpClient,
12    predictions_api: PredictionsApi,
13    files_api: FilesApi,
14}
15
16impl Client {
17    /// Create a new client with the given API token.
18    pub fn new(api_token: impl Into<String>) -> Result<Self> {
19        let http = HttpClient::new(api_token)?;
20        let predictions_api = PredictionsApi::new(http.clone());
21        let files_api = FilesApi::new(http.clone());
22
23        Ok(Self {
24            http,
25            predictions_api,
26            files_api,
27        })
28    }
29
30    /// Create a new client using the API token from the environment.
31    ///
32    /// Looks for the token in the `REPLICATE_API_TOKEN` environment variable.
33    pub fn from_env() -> Result<Self> {
34        let api_token = env::var("REPLICATE_API_TOKEN")
35            .map_err(|_| Error::auth_error("REPLICATE_API_TOKEN environment variable not found"))?;
36        Self::new(api_token)
37    }
38
39    /// Create a new client with custom base URL.
40    pub fn with_base_url(
41        api_token: impl Into<String>,
42        base_url: impl Into<String>,
43    ) -> Result<Self> {
44        let http = HttpClient::with_base_url(api_token, base_url)?;
45        let predictions_api = PredictionsApi::new(http.clone());
46        let files_api = FilesApi::new(http.clone());
47
48        Ok(Self {
49            http,
50            predictions_api,
51            files_api,
52        })
53    }
54
55    /// Get access to the predictions API.
56    pub fn predictions(&self) -> &PredictionsApi {
57        &self.predictions_api
58    }
59
60    /// Get access to the files API.
61    pub fn files(&self) -> &FilesApi {
62        &self.files_api
63    }
64
65    /// Create a new prediction with a fluent builder API.
66    ///
67    /// # Examples
68    ///
69    /// ```no_run
70    /// # use replicate_client::Client;
71    /// # #[tokio::main]
72    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
73    /// let client = Client::new("your-api-token")?;
74    ///
75    /// let prediction = client
76    ///     .create_prediction("stability-ai/sdxl:version-id")
77    ///     .input("prompt", "A futuristic city skyline")
78    ///     .input("width", 1024)
79    ///     .input("height", 1024)
80    ///     .send()
81    ///     .await?;
82    ///
83    /// println!("Prediction ID: {}", prediction.id);
84    /// # Ok(())
85    /// # }
86    /// ```
87    pub fn create_prediction(&self, version: impl Into<String>) -> PredictionBuilder {
88        PredictionBuilder::new(self.predictions_api.clone(), version)
89    }
90
91    /// Run a model and wait for completion (convenience method).
92    ///
93    /// This is equivalent to creating a prediction and waiting for it to complete.
94    ///
95    /// # Examples
96    ///
97    /// ```no_run
98    /// # use replicate_client::Client;
99    /// # #[tokio::main]
100    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
101    /// let client = Client::new("your-api-token")?;
102    ///
103    /// let result = client
104    ///     .run("stability-ai/sdxl:version-id")
105    ///     .input("prompt", "A futuristic city skyline")
106    ///     .send_and_wait()
107    ///     .await?;
108    ///
109    /// println!("Result: {:?}", result.output);
110    /// # Ok(())
111    /// # }
112    /// ```
113    pub fn run(&self, version: impl Into<String>) -> PredictionBuilder {
114        self.create_prediction(version)
115    }
116
117    /// Get the underlying HTTP client.
118    pub fn http_client(&self) -> &HttpClient {
119        &self.http
120    }
121
122    /// Get mutable access to the underlying HTTP client.
123    ///
124    /// This allows configuring retry settings after client creation.
125    pub fn http_client_mut(&mut self) -> &mut HttpClient {
126        &mut self.http
127    }
128
129    /// Configure retry settings for this client.
130    ///
131    /// This is a convenience method that delegates to the HTTP client.
132    ///
133    /// # Examples
134    ///
135    /// ```no_run
136    /// # use replicate_client::Client;
137    /// # use std::time::Duration;
138    /// # #[tokio::main]
139    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
140    /// let mut client = Client::new("your-api-token")?;
141    ///
142    /// // Configure more aggressive retry settings
143    /// client.configure_retries(
144    ///     5,                               // max_retries
145    ///     Duration::from_millis(100),      // min_delay
146    ///     Duration::from_secs(60),         // max_delay
147    /// )?;
148    /// # Ok(())
149    /// # }
150    /// ```
151    pub fn configure_retries(
152        &mut self,
153        max_retries: u32,
154        min_delay: Duration,
155        max_delay: Duration,
156    ) -> Result<()> {
157        self.http
158            .configure_retries(max_retries, min_delay, max_delay)
159    }
160
161    /// Configure timeout settings for this client.
162    ///
163    /// This is a convenience method that delegates to the HTTP client.
164    ///
165    /// # Examples
166    ///
167    /// ```no_run
168    /// # use replicate_client::Client;
169    /// # use std::time::Duration;
170    /// # #[tokio::main]
171    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
172    /// let mut client = Client::new("your-api-token")?;
173    ///
174    /// // Configure custom timeouts
175    /// client.configure_timeouts(
176    ///     Some(Duration::from_secs(10)),   // connect_timeout
177    ///     Some(Duration::from_secs(120)),  // request_timeout
178    /// )?;
179    /// # Ok(())
180    /// # }
181    /// ```
182    pub fn configure_timeouts(
183        &mut self,
184        connect_timeout: Option<Duration>,
185        request_timeout: Option<Duration>,
186    ) -> Result<()> {
187        self.http
188            .configure_timeouts(connect_timeout, request_timeout)
189    }
190
191    /// Create a new client with custom HTTP configuration.
192    pub fn with_http_config(api_token: impl Into<String>, http_config: HttpConfig) -> Result<Self> {
193        let http = HttpClient::with_http_config(api_token, http_config)?;
194        let predictions_api = PredictionsApi::new(http.clone());
195        let files_api = FilesApi::new(http.clone());
196
197        Ok(Self {
198            http,
199            predictions_api,
200            files_api,
201        })
202    }
203
204    /// Get the current timeout configuration.
205    pub fn timeout_config(&self) -> &TimeoutConfig {
206        self.http.timeout_config()
207    }
208
209    /// Get the current HTTP configuration.
210    pub fn http_config(&self) -> &HttpConfig {
211        self.http.http_config()
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    /// Environment variable read by `Client::from_env`.
    const TOKEN_VAR: &str = "REPLICATE_API_TOKEN";

    /// Serializes the tests in this module that mutate `REPLICATE_API_TOKEN`.
    ///
    /// The default test harness runs tests on parallel threads, and the process
    /// environment is global. Without this lock, the env tests could observe
    /// each other's writes and fail intermittently.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds `ENV_LOCK` and restores the variable's original value on drop.
    ///
    /// Because restoration happens on drop, it also runs when an assertion
    /// panics.
    struct EnvVarGuard {
        // `var_os` keeps non-Unicode originals, which `var().ok()` would drop.
        original: Option<OsString>,
        // `Drop::drop` runs before fields are dropped, so the lock is still
        // held while the original value is restored.
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvVarGuard {
        fn acquire() -> Self {
            // A panicking test poisons the lock. The guarded data is `()`, so
            // poisoning carries no broken invariant and can be ignored.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            Self {
                original: env::var_os(TOKEN_VAR),
                _lock: lock,
            }
        }

        fn set(&self, value: &str) {
            // SAFETY: mutating the environment is unsound only if another
            // thread accesses it concurrently. `ENV_LOCK` (held by `self`)
            // serializes the env-mutating tests in this module.
            // NOTE(review): confirm no other test in the crate reads the
            // environment without taking this lock.
            unsafe { env::set_var(TOKEN_VAR, value) };
        }

        fn remove(&self) {
            // SAFETY: same reasoning as in `set`.
            unsafe { env::remove_var(TOKEN_VAR) };
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same reasoning as in `set`; the lock is still held here.
            unsafe {
                match self.original.take() {
                    Some(value) => env::set_var(TOKEN_VAR, value),
                    None => env::remove_var(TOKEN_VAR),
                }
            }
        }
    }

    #[test]
    fn test_client_creation() {
        assert!(Client::new("test-token").is_ok());
    }

    #[test]
    fn test_client_empty_token() {
        let result = Client::new("");
        assert!(
            matches!(result, Err(Error::Auth(_))),
            "expected auth error, got {result:?}"
        );
    }

    #[test]
    fn test_client_from_env_missing() {
        let guard = EnvVarGuard::acquire();
        guard.remove();

        let result = Client::from_env();
        assert!(
            matches!(result, Err(Error::Auth(_))),
            "expected auth error, got {result:?}"
        );
    }

    #[test]
    fn test_client_from_env_empty() {
        let guard = EnvVarGuard::acquire();
        guard.set("");

        // An empty token in the environment is rejected just like `Client::new("")`.
        let result = Client::from_env();
        assert!(
            matches!(result, Err(Error::Auth(_))),
            "expected auth error, got {result:?}"
        );
    }

    #[test]
    fn test_client_from_env_present() {
        let guard = EnvVarGuard::acquire();
        guard.set("test-token");

        assert!(Client::from_env().is_ok());
    }
}