replicate_rust/
lib.rs

//! An unofficial Rust client for [Replicate](https://replicate.com). It provides a type-safe interface by deserializing API responses into Rust structs.
2//!
3//! ## Getting Started
4//!
//! Add `replicate-rust` to the `[dependencies]` section of `Cargo.toml`:
6//!
7//! ```toml
8//! [dependencies]
9//! replicate-rust = "0.0.5"
10//! ```
11//!
12//! Grab your token from [replicate.com/account](https://replicate.com/account) and set it as an environment variable:
13//!
14//! ```sh
15//! export REPLICATE_API_TOKEN=<your token>
16//! ```
17//!
18//! Here's an example using `replicate_rust` to run a model:
19//!
//! ```rust,no_run
//! use replicate_rust::{config::Config, Replicate, errors::ReplicateError};
//!
//! fn main() -> Result<(), ReplicateError> {
//!    let config = Config::default();
//!    // Instead of using the default config (which reads the API token from the
//!    // `REPLICATE_API_TOKEN` environment variable), you can also set the token directly:
//!    // let config = Config {
//!    //     auth: String::from("<your token>"),
//!    //     ..Default::default()
//!    // };
//!
//!    let replicate = Replicate::new(config);
//!
//!    // Construct the inputs.
//!    let mut inputs = std::collections::HashMap::new();
//!    inputs.insert("prompt", "a 19th century portrait of a wombat gentleman");
//!
//!    let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
//!
//!    // Run the model.
//!    let result = replicate.run(version, inputs)?;
//!
//!    // Print the result.
//!    println!("{:?}", result.output);
//!    // Some(Array [String("https://pbxt.replicate.delivery/QLDGe2rXuIQ9ByMViQEXrYCkKfDi9I3YWAzPwWsDZWMXeN7iA/out-0.png")])
//!
//!    Ok(())
//! }
//! ```
49//!
50//! ## Usage
51//!
52//! See the [reference docs](https://docs.rs/replicate-rust/) for detailed API documentation.
53//!
54//! ## Examples
55//!
//! - Run a model in the background:
//!     ```rust,no_run
//!     # use replicate_rust::{config::Config, Replicate};
//!     # let replicate = Replicate::new(Config::default());
//!     // Construct the inputs.
//!     let mut inputs = std::collections::HashMap::new();
//!     inputs.insert("prompt", "a 19th century portrait of a wombat gentleman");
//!
//!     let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
//!
//!     // Run the model.
//!     let mut prediction = replicate.predictions.create(version, inputs)?;
//!
//!     println!("{:?}", prediction.status);
//!     // 'starting'
//!
//!     prediction.reload()?;
//!     println!("{:?}", prediction.status);
//!     // 'processing'
//!
//!     println!("{:?}", prediction.logs);
//!     // Some("Using seed: 3599
//!     // 0%|          | 0/50 [00:00<?, ?it/s]
//!     // 4%|▍         | 2/50 [00:00<00:04, 10.00it/s]
//!     // 8%|▊         | 4/50 [00:00<00:03, 11.56it/s]
//!     // ...")
//!
//!     let prediction = prediction.wait()?;
//!
//!     println!("{:?}", prediction.status);
//!     // 'succeeded'
//!
//!     println!("{:?}", prediction.output);
//!     // Some(Array [String("https://pbxt.replicate.delivery/QLDGe2rXuIQ9ByMViQEXrYCkKfDi9I3YWAzPwWsDZWMXeN7iA/out-0.png")])
//!     # Ok::<(), replicate_rust::errors::ReplicateError>(())
//!     ```
89//!
//! - Cancel a prediction:
//!   ```rust,no_run
//!   # use replicate_rust::{config::Config, Replicate};
//!   # let replicate = Replicate::new(Config::default());
//!   // Construct the inputs.
//!   let mut inputs = std::collections::HashMap::new();
//!   inputs.insert("prompt", "a 19th century portrait of a wombat gentleman");
//!
//!   let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
//!
//!   // Run the model.
//!   let mut prediction = replicate.predictions.create(version, inputs)?;
//!
//!   println!("{:?}", prediction.status);
//!   // 'starting'
//!
//!   prediction.cancel()?;
//!
//!   prediction.reload()?;
//!
//!   println!("{:?}", prediction.status);
//!   // 'cancelled'
//!   # Ok::<(), replicate_rust::errors::ReplicateError>(())
//!   ```
111//!
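//! - List predictions:
//!   ```rust,no_run
//!   # use replicate_rust::{config::Config, Replicate};
//!   # let replicate = Replicate::new(Config::default());
//!   let predictions = replicate.predictions.list()?;
//!   println!("{:?}", predictions);
//!   // ListPredictions { ... }
//!   # Ok::<(), replicate_rust::errors::ReplicateError>(())
//!   ```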
118//!
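//! - Get model information:
//!   ```rust,no_run
//!   # use replicate_rust::{config::Config, Replicate};
//!   # let replicate = Replicate::new(Config::default());
//!   let model = replicate.models.get("replicate", "hello-world")?;
//!   println!("{:?}", model);
//!   // GetModel { ... }
//!   # Ok::<(), replicate_rust::errors::ReplicateError>(())
//!   ```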
125//!
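//! - List model versions:
//!   ```rust,no_run
//!   # use replicate_rust::{config::Config, Replicate};
//!   # let replicate = Replicate::new(Config::default());
//!   let versions = replicate.models.versions.list("replicate", "hello-world")?;
//!   println!("{:?}", versions);
//!   // ListModelVersions { ... }
//!   # Ok::<(), replicate_rust::errors::ReplicateError>(())
//!   ```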
132//!
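//! - Get model version information:
//!   ```rust,no_run
//!   # use replicate_rust::{config::Config, Replicate};
//!   # let replicate = Replicate::new(Config::default());
//!   let version = replicate.models.versions.get(
//!       "kvfrans",
//!       "clipdraw",
//!       "5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b",
//!   )?;
//!   println!("{:?}", version);
//!   // GetModelVersion { ... }
//!   # Ok::<(), replicate_rust::errors::ReplicateError>(())
//!   ```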
141//!
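//! - Get collection information:
//!   ```rust,no_run
//!   # use replicate_rust::{config::Config, Replicate};
//!   # let replicate = Replicate::new(Config::default());
//!   let collection = replicate.collections.get("audio-generation")?;
//!   println!("{:?}", collection);
//!   // GetCollectionModels { ... }
//!   # Ok::<(), replicate_rust::errors::ReplicateError>(())
//!   ```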
148//!
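//! - List collections:
//!   ```rust,no_run
//!   # use replicate_rust::{config::Config, Replicate};
//!   # let replicate = Replicate::new(Config::default());
//!   let collections = replicate.collections.list()?;
//!   println!("{:?}", collections);
//!   // ListCollectionModels { ... }
//!   # Ok::<(), replicate_rust::errors::ReplicateError>(())
//!   ```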
155//!
156#![warn(missing_docs)]
157#![warn(missing_doc_code_examples)]
158
159use std::collections::HashMap;
160
161use api_definitions::GetPrediction;
162use collection::Collection;
163use config::Config;
164use errors::ReplicateError;
165use model::Model;
166use prediction::Prediction;
167use training::Training;
168
169pub mod collection;
170pub mod config;
171pub mod model;
172pub mod prediction;
173pub mod training;
174pub mod version;
175
176pub mod api_definitions;
177pub mod errors;
178pub mod prediction_client;
179pub mod retry;
180
181/// Rust Client for interacting with the [Replicate API](https://replicate.com/docs/api/). Currently supports the following endpoints:
182/// * [Predictions](https://replicate.com/docs/reference/http#predictions.create)
183/// * [Models](https://replicate.com/docs/reference/http#models.get)
184/// * [Trainings](https://replicate.com/docs/reference/http#trainings.create)
185/// * [Collections](https://replicate.com/docs/reference/http#collections.get)
186#[derive(Clone, Debug)]
187pub struct Replicate {
188    /// Holds a reference to a Config struct.
189    config: Config,
190
191    /// Holds a reference to a Prediction struct. Use to run inference given model inputs and version.
192    pub predictions: Prediction,
193
194    /// Holds a reference to a Model struct. Use to get information about a model.
195    pub models: Model,
196
197    /// Holds a reference to a Training struct. Use to create a new training run.
198    pub trainings: Training,
199
200    /// Holds a reference to a Collection struct. Use to get and list model collections present in Replicate.
201    pub collections: Collection,
202}
203
204/// Rust Client for interacting with the [Replicate API](https://replicate.com/docs/api/).
205impl Replicate {
206    /// Create a new Replicate client.
207    ///
208    /// # Example
209    /// ```
210    /// use replicate_rust::{Replicate, config::Config};
211    ///
212    /// let config = Config::default();
213    /// let replicate = Replicate::new(config);
214    /// ```
215    pub fn new(config: Config) -> Self {
216        // Check if auth is set.
217        config.check_auth();
218
219        // TODO : Maybe reference instead of clone
220        let predictions = Prediction::new(config.clone());
221        let models = Model::new(config.clone());
222        let trainings = Training::new(config.clone());
223        let collections = Collection::new(config.clone());
224
225        Self {
226            config,
227            predictions,
228            models,
229            trainings,
230            collections,
231        }
232    }
233
234    /// Run a model with the given inputs in a blocking manner.
235    /// # Arguments
236    /// * `version` - The version of the model to run.
237    /// * `inputs` - The inputs to the model in the form of a HashMap.
238    /// # Example
239    /// ```
240    /// use replicate_rust::{Replicate, config::Config};
241    ///
242    /// let config = Config::default();
243    /// let replicate = Replicate::new(config);
244    ///
245    /// // Construct the inputs.
246    /// let mut inputs = std::collections::HashMap::new();
247    /// inputs.insert("prompt", "a  19th century portrait of a wombat gentleman");
248    ///
249    /// let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
250    ///
251    /// // Run the model.
252    /// let result = replicate.run(version, inputs)?;
253    ///
254    /// println!("Output : {:?}", result.output);
255    /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
256    /// ```
257    pub fn run<K: serde::Serialize, V: serde::Serialize>(
258        &self,
259        version: &str,
260        inputs: HashMap<K, V>,
261    ) -> Result<GetPrediction, ReplicateError> {
262        let prediction = Prediction::new(self.config.clone()).create(version, inputs)?;
263
264        prediction.wait()
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use httpmock::{
        Method::{GET, POST},
        MockServer,
    };
    use serde_json::{json, Value};

    /// JSON body served by the mock server for prediction `p1`.
    ///
    /// The create (POST) and poll (GET) responses differ only in `status`, `output`
    /// and `logs`, so every other field is fixed here.
    fn prediction_p1(base_url: &str, status: &str, output: Value, logs: Value) -> Value {
        let p1_url = format!("{}/predictions/p1", base_url);
        json!({
            "id": "p1",
            "version": "v1",
            "urls": {
                "get": p1_url,
                "cancel": p1_url,
            },
            "created_at": "2022-04-26T20:00:40.658234Z",
            "completed_at": "2022-04-26T20:02:27.648305Z",
            "source": "api",
            "status": status,
            "input": {"text": "world"},
            "output": output,
            "error": Value::Null,
            "logs": logs,
        })
    }

    #[test]
    fn test_run() -> Result<(), ReplicateError> {
        let server = MockServer::start();
        let base_url = server.base_url();

        // Creating the prediction answers with a still-processing prediction...
        let post_mock = server.mock(|when, then| {
            when.method(POST).path("/predictions").json_body_obj(&json!({
                "version": "v1",
                "input": {"text": "world"}
            }));
            then.status(200).json_body_obj(&prediction_p1(
                &base_url,
                "processing",
                Value::Null,
                Value::Null,
            ));
        });

        // ...and polling it reports success with the final output.
        let get_mock = server.mock(|when, then| {
            when.method(GET).path("/predictions/p1");
            then.status(200).json_body_obj(&prediction_p1(
                &base_url,
                "succeeded",
                json!("hello world"),
                json!(""),
            ));
        });

        let replicate = Replicate::new(Config {
            auth: String::from("test"),
            base_url: base_url.clone(),
            ..Config::default()
        });

        let inputs = std::collections::HashMap::from([("text", "world")]);
        let result = replicate.run("test/model:v1", inputs)?;

        // The client must surface the output from the polled (GET) response.
        assert_eq!(result.output, Some(json!("hello world")));

        // Exactly one create and one poll request should have been made.
        post_mock.assert();
        get_mock.assert();

        Ok(())
    }
}