replicate_rust/lib.rs
1//! An Unofficial Rust client for [Replicate](https://replicate.com). Provides a type-safe interface by deserializing API responses into Rust structs.
2//!
3//! ## Getting Started
4//!
//! Add `replicate-rust` to your `Cargo.toml`:
6//!
7//! ```toml
8//! [dependencies]
9//! replicate-rust = "0.0.5"
10//! ```
11//!
12//! Grab your token from [replicate.com/account](https://replicate.com/account) and set it as an environment variable:
13//!
14//! ```sh
15//! export REPLICATE_API_TOKEN=<your token>
16//! ```
17//!
18//! Here's an example using `replicate_rust` to run a model:
19//!
20//! ```rust
21//! use replicate_rust::{config::Config, Replicate, errors::ReplicateError};
22//!
23//! fn main() -> Result<(), ReplicateError> {
24//! let config = Config::default();
//! // Instead of using the default config (which reads the API token from the environment variable), you can also set the token directly:
26//! // let config = Config {
27//! // auth: String::from("REPLICATE_API_TOKEN"),
28//! // ..Default::default()
29//! // };
30//!
31//! let replicate = Replicate::new(config);
32//!
33//! // Construct the inputs.
34//! let mut inputs = std::collections::HashMap::new();
35//! inputs.insert("prompt", "a 19th century portrait of a wombat gentleman");
36//!
37//! let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
38//!
39//! // Run the model.
40//! let result = replicate.run(version, inputs)?;
41//!
42//! // Print the result.
43//! println!("{:?}", result.output);
//! // Some(Array [String("https://pbxt.replicate.delivery/QLDGe2rXuIQ9ByMViQEXrYCkKfDi9I3YWAzPwWsDZWMXeN7iA/out-0.png")])
45//!
46//! Ok(())
47//! }
48//! ```
49//!
50//! ## Usage
51//!
52//! See the [reference docs](https://docs.rs/replicate-rust/) for detailed API documentation.
53//!
54//! ## Examples
55//!
56//! - Run a model in the background:
57//! ```rust
58//! // Construct the inputs.
59//! let mut inputs = std::collections::HashMap::new();
60//! inputs.insert("prompt", "a 19th century portrait of a wombat gentleman");
61//!
62//! let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
63//!
64//! // Run the model.
65//! let mut prediction = replicate.predictions.create(version, inputs)?;
66//!
67//! println!("{:?}", prediction.status);
68//! // 'starting'
69//!
70//! prediction.reload()?;
71//! println!("{:?}", prediction.status);
72//! // 'processing'
73//!
74//! println!("{:?}", prediction.logs);
75//! // Some("Using seed: 3599
76//! // 0%| | 0/50 [00:00<?, ?it/s]
77//! // 4%|▍ | 2/50 [00:00<00:04, 10.00it/s]
78//! // 8%|▊ | 4/50 [00:00<00:03, 11.56it/s]
79//!
80//!
81//! let prediction = prediction.wait()?;
82//!
83//! println!("{:?}", prediction.status);
84//! // 'succeeded'
85//!
86//! println!("{:?}", prediction.output);
//! // Some(Array [String("https://pbxt.replicate.delivery/QLDGe2rXuIQ9ByMViQEXrYCkKfDi9I3YWAzPwWsDZWMXeN7iA/out-0.png")])
88//! ```
89//!
90//! - Cancel a prediction:
91//! ```rust
92//! // Construct the inputs.
93//! let mut inputs = std::collections::HashMap::new();
94//! inputs.insert("prompt", "a 19th century portrait of a wombat gentleman");
95//!
96//! let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
97//!
98//! // Run the model.
99//! let mut prediction = replicate.predictions.create(version, inputs)?;
100//!
101//! println!("{:?}", prediction.status);
102//! // 'starting'
103//!
104//! prediction.cancel()?;
105//!
106//! prediction.reload()?;
107//!
108//! println!("{:?}", prediction.status);
109//! // 'cancelled'
110//! ```
111//!
112//! - List predictions:
113//! ```rust
114//! let predictions = replicate.predictions.list()?;
115//! println!("{:?}", predictions);
116//! // ListPredictions { ... }
117//! ```
118//!
119//! - Get model Information:
120//! ```rust
121//! let model = replicate.models.get("replicate", "hello-world")?;
122//! println!("{:?}", model);
123//! // GetModel { ... }
124//! ```
125//!
126//! - Get Versions List:
127//! ```rust
128//! let versions = replicate.models.versions.list("replicate", "hello-world")?;
129//! println!("{:?}", versions);
130//! // ListModelVersions { ... }
131//! ```
132//!
133//! - Get Model Version Information:
134//! ```rust
135//! let model = replicate.models.versions.get("kvfrans",
136//! "clipdraw",
137//! "5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b",)?;
138//! println!("{:?}", model);
139//! // GetModelVersion { ... }
140//! ```
141//!
142//! - Get Collection Information:
143//! ```rust
144//! let collection = replicate.collections.get("audio-generation")?;
145//! println!("{:?}", collection);
//! // GetCollectionModels { ... }
147//! ```
148//!
149//! - Get Collection Lists:
150//! ```rust
151//! let collections = replicate.collections.list()?;
152//! println!("{:?}", collections);
153//! // ListCollectionModels { ... }
154//! ```
155//!
// Warn about public items that lack documentation.
#![warn(missing_docs)]
// NOTE(review): recent toolchains expose this lint as `rustdoc::missing_doc_code_examples`
// and appear to accept it on nightly only — confirm against the toolchain this crate targets.
#![warn(missing_doc_code_examples)]
158
159use std::collections::HashMap;
160
161use api_definitions::GetPrediction;
162use collection::Collection;
163use config::Config;
164use errors::ReplicateError;
165use model::Model;
166use prediction::Prediction;
167use training::Training;
168
// Endpoint and configuration modules. This file imports `Collection`, `Config`, `Model`,
// `Prediction` and `Training` from them; `version` is declared here but not used directly
// in this file.
pub mod collection;
pub mod config;
pub mod model;
pub mod prediction;
pub mod training;
pub mod version;

// Supporting modules. This file imports `GetPrediction` from `api_definitions` and
// `ReplicateError` from `errors`; `prediction_client` and `retry` are declared here but not
// used directly in this file.
pub mod api_definitions;
pub mod errors;
pub mod prediction_client;
pub mod retry;
180
/// Rust Client for interacting with the [Replicate API](https://replicate.com/docs/api/). Currently supports the following endpoints:
/// * [Predictions](https://replicate.com/docs/reference/http#predictions.create)
/// * [Models](https://replicate.com/docs/reference/http#models.get)
/// * [Trainings](https://replicate.com/docs/reference/http#trainings.create)
/// * [Collections](https://replicate.com/docs/reference/http#collections.get)
///
/// Every field is an owned value. `Replicate::new` builds each endpoint handle from its own
/// clone of the supplied `Config`, so the handles do not borrow from the client.
#[derive(Clone, Debug)]
pub struct Replicate {
    /// The configuration this client was created with (owned, private).
    /// `run` clones it to build a fresh `Prediction` handle for each call.
    config: Config,

    /// Prediction endpoint handle. Use to run inference given model inputs and version.
    pub predictions: Prediction,

    /// Model endpoint handle. Use to get information about a model.
    pub models: Model,

    /// Training endpoint handle. Use to create a new training run.
    pub trainings: Training,

    /// Collection endpoint handle. Use to get and list model collections present in Replicate.
    pub collections: Collection,
}
203
204/// Rust Client for interacting with the [Replicate API](https://replicate.com/docs/api/).
205impl Replicate {
206 /// Create a new Replicate client.
207 ///
208 /// # Example
209 /// ```
210 /// use replicate_rust::{Replicate, config::Config};
211 ///
212 /// let config = Config::default();
213 /// let replicate = Replicate::new(config);
214 /// ```
215 pub fn new(config: Config) -> Self {
216 // Check if auth is set.
217 config.check_auth();
218
219 // TODO : Maybe reference instead of clone
220 let predictions = Prediction::new(config.clone());
221 let models = Model::new(config.clone());
222 let trainings = Training::new(config.clone());
223 let collections = Collection::new(config.clone());
224
225 Self {
226 config,
227 predictions,
228 models,
229 trainings,
230 collections,
231 }
232 }
233
234 /// Run a model with the given inputs in a blocking manner.
235 /// # Arguments
236 /// * `version` - The version of the model to run.
237 /// * `inputs` - The inputs to the model in the form of a HashMap.
238 /// # Example
239 /// ```
240 /// use replicate_rust::{Replicate, config::Config};
241 ///
242 /// let config = Config::default();
243 /// let replicate = Replicate::new(config);
244 ///
245 /// // Construct the inputs.
246 /// let mut inputs = std::collections::HashMap::new();
247 /// inputs.insert("prompt", "a 19th century portrait of a wombat gentleman");
248 ///
249 /// let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
250 ///
251 /// // Run the model.
252 /// let result = replicate.run(version, inputs)?;
253 ///
254 /// println!("Output : {:?}", result.output);
255 /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
256 /// ```
257 pub fn run<K: serde::Serialize, V: serde::Serialize>(
258 &self,
259 version: &str,
260 inputs: HashMap<K, V>,
261 ) -> Result<GetPrediction, ReplicateError> {
262 let prediction = Prediction::new(self.config.clone()).create(version, inputs)?;
263
264 prediction.wait()
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use httpmock::{
        Method::{GET, POST},
        MockServer,
    };
    use serde_json::json;

    /// Build the JSON document the mock server returns for prediction `p1`.
    ///
    /// Only `status`, `output` and `logs` differ between the POST and GET responses;
    /// `None` serializes to JSON `null`.
    fn prediction_body(
        base_url: &str,
        status: &str,
        output: Option<&str>,
        logs: Option<&str>,
    ) -> serde_json::Value {
        // The mocked payload uses the same URL for both "get" and "cancel".
        let prediction_url = format!("{}/predictions/p1", base_url);

        json!({
            "id": "p1",
            "version": "v1",
            "urls": {
                "get": prediction_url,
                "cancel": prediction_url,
            },
            "created_at": "2022-04-26T20:00:40.658234Z",
            "completed_at": "2022-04-26T20:02:27.648305Z",
            "source": "api",
            "status": status,
            "input": {"text": "world"},
            "output": output,
            "error": None::<String>,
            "logs": logs,
        })
    }

    /// `Replicate::run` should POST the prediction, then GET it and return the output
    /// reported by the GET response.
    #[test]
    fn test_run() -> Result<(), ReplicateError> {
        let server = MockServer::start();
        let base_url = server.base_url();

        // Creation request: answered with a prediction that is still processing.
        let post_mock = server.mock(|when, then| {
            when.method(POST)
                .path("/predictions")
                .json_body_obj(&json!({
                    "version": "v1",
                    "input": {"text": "world"}
                }));
            then.status(200)
                .json_body_obj(&prediction_body(&base_url, "processing", None, None));
        });

        // Status request: answered with the finished prediction.
        let get_mock = server.mock(|when, then| {
            when.method(GET).path("/predictions/p1");
            then.status(200).json_body_obj(&prediction_body(
                &base_url,
                "succeeded",
                Some("hello world"),
                Some(""),
            ));
        });

        // Point the client at the mock server instead of the real API.
        let replicate = Replicate::new(Config {
            auth: String::from("test"),
            base_url: base_url.clone(),
            ..Config::default()
        });

        let inputs = std::collections::HashMap::from([("text", "world")]);
        let result = replicate.run("test/model:v1", inputs)?;

        // Assert that the returned value is correct
        assert_eq!(result.output, Some(json!("hello world")));

        // Ensure the mocks were called as expected
        post_mock.assert();
        get_mock.assert();

        Ok(())
    }
}