replicate_rust/
prediction.rs

1//! Used to interact with the [Prediction Endpoints](https://replicate.com/docs/reference/http#predictions.get).
2//!
3//! # Example
4//!
5//! ```
6//! use replicate_rust::{Replicate, config::Config};
7//!
8//! let config = Config::default();
9//! let replicate = Replicate::new(config);
10//!
11//! // Construct the inputs.
12//! let mut inputs = std::collections::HashMap::new();
//! inputs.insert("prompt", "a 19th century portrait of a wombat gentleman");
14//!
15//! let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
16//!
17//! // Run the model.
18//! let result = replicate.predictions.create(version, inputs)?.wait()?;
19//!
20//! // Print the result.
21//! println!("Result : {:?}", result.output);
22//!
23//! # Ok::<(), replicate_rust::errors::ReplicateError>(())
24//! ```
25//!
26//! ## Another example to showcase other methods
27//!
28//! ```
29//! use replicate_rust::{Replicate, config::Config};
30//!
31//! let config = Config::default();
32//! let replicate = Replicate::new(config);
33//!
34//! // Construct the inputs.
35//! let mut inputs = std::collections::HashMap::new();
//! inputs.insert("prompt", "a 19th century portrait of a wombat gentleman");
37//!
38//! let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
39//!
40//! // Run the model.
41//! let mut prediction = replicate.predictions.create(version, inputs)?;
42//!
43//! println!("Prediction : {:?}", prediction.status);
44//!
45//! // Refetch the prediction using the reload method.
46//! let _ = prediction.reload();
47//! println!("Prediction : {:?}", prediction.status);
48//!
49//! // Cancel the prediction.
50//! let _ = prediction.cancel();
//! println!("Prediction : {:?}", prediction.status);
52//!
53//! // Wait for the prediction to complete (or fail).
54//! println!("Prediction : {:?}", prediction.wait()?);
55//!
56//! # Ok::<(), replicate_rust::errors::ReplicateError>(())
57//! ```
58//!
59//!
60
61use serde::Serialize;
62use std::collections::HashMap;
63
64use crate::{
65    api_definitions::{GetPrediction, ListPredictions},
66    errors::ReplicateError,
67    prediction_client::PredictionClient,
68};
69
70/// Used to interact with the [Prediction Endpoints](https://replicate.com/docs/reference/http#predictions.get).
71#[derive(Serialize)]
72pub struct PredictionPayload<K: serde::Serialize, V: serde::ser::Serialize> {
73    /// Version of the model used for the prediction
74    pub version: String,
75
76    /// Input to the model
77    pub input: HashMap<K, V>,
78}
79
80/// Used to interact with the [Prediction Endpoints](https://replicate.com/docs/reference/http#predictions.get).
81#[derive(Clone, Debug)]
82pub struct Prediction {
83    /// Holds a reference to a Config struct. Use to get the base url, auth token among other settings.
84    pub parent: crate::config::Config,
85}
86
87impl Prediction {
88    /// Create a new Prediction struct.
89    pub fn new(rep: crate::config::Config) -> Self {
90        Self { parent: rep }
91    }
92
93    /// Create a new prediction, by passing in the model version and inputs to PredictionClient.
94    /// PredictionClient contains the necessary methods to interact with the prediction such as reload, cancel and wait.
95    ///
96    /// # Example
97    ///
98    /// ```
99    /// use replicate_rust::{Replicate, config::Config};
100    ///
101    /// let config = Config::default();
102    /// let replicate = Replicate::new(config);
103    ///
104    /// // Construct the inputs.
105    /// let mut inputs = std::collections::HashMap::new();
106    /// inputs.insert("prompt", "a  19th century portrait of a wombat gentleman");
107    ///
108    /// let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
109    ///
110    /// // Run the model.
111    /// let mut prediction = replicate.predictions.create(version, inputs)?;
112    ///
113    /// println!("Prediction : {:?}", prediction.status);
114    ///
115    /// // Refetch the prediction using the reload method.
116    /// prediction.reload();
117    /// println!("Prediction : {:?}", prediction.status);
118    ///
119    /// // Wait for the prediction to complete (or fail).
120    /// println!("Prediction : {:?}", prediction.wait()?);
121    ///
122    /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
123    /// ```
124    ///
125    pub fn create<K: serde::Serialize, V: serde::ser::Serialize>(
126        &self,
127        version: &str,
128        inputs: HashMap<K, V>,
129    ) -> Result<PredictionClient, ReplicateError> {
130        Ok(PredictionClient::create(
131            self.parent.clone(),
132            version,
133            inputs,
134        )?)
135    }
136
137    /// List all predictions executed in Replicate by the user.
138    ///
139    /// # Example
140    ///
141    /// ```
142    /// use replicate_rust::{Replicate, config::Config};
143    ///
144    /// let config = Config::default();
145    /// let replicate = Replicate::new(config);
146    ///
147    /// let predictions = replicate.predictions.list()?;
148    /// println!("Predictions : {:?}", predictions);
149    ///
150    /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
151    /// ```
152    pub fn list(&self) -> Result<ListPredictions, ReplicateError> {
153        let client = reqwest::blocking::Client::new();
154
155        let response = client
156            .get(format!("{}/predictions", self.parent.base_url))
157            .header("Authorization", format!("Token {}", self.parent.auth))
158            .header("User-Agent", &self.parent.user_agent)
159            .send()?;
160
161        if !response.status().is_success() {
162            return Err(ReplicateError::ResponseError(response.text()?));
163        }
164
165        let response_string = response.text()?;
166        let response_struct: ListPredictions = serde_json::from_str(&response_string)?;
167
168        Ok(response_struct)
169    }
170
171    /// Get a prediction by passing in the prediction id.
172    /// The prediction id can be obtained from the PredictionClient struct.
173    ///
174    /// # Example
175    ///
176    /// ```
177    /// use replicate_rust::{Replicate, config::Config};
178    ///
179    /// let config = Config::default();
180    /// let replicate = Replicate::new(config);
181    ///
182    /// let prediction = replicate.predictions.get("rrr4z55ocneqzikepnug6xezpe")?;
183    /// println!("Prediction : {:?}", prediction);
184    ///
185    /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
186    /// ```
187    pub fn get(&self, id: &str) -> Result<GetPrediction, ReplicateError> {
188        let client = reqwest::blocking::Client::new();
189
190        let response = client
191            .get(format!("{}/predictions/{}", self.parent.base_url, id))
192            .header("Authorization", format!("Token {}", self.parent.auth))
193            .header("User-Agent", &self.parent.user_agent)
194            .send()?;
195
196        if !response.status().is_success() {
197            return Err(ReplicateError::ResponseError(response.text()?));
198        }
199
200        let response_string = response.text()?;
201        let response_struct: GetPrediction = serde_json::from_str(&response_string)?;
202
203        Ok(response_struct)
204    }
205}
206
#[cfg(test)]
mod tests {
    use crate::{config::Config, Replicate};

    use super::*;
    use httpmock::{Method::GET, MockServer};
    use serde_json::json;

    /// `list` should hit `GET /predictions` on the configured base url and
    /// deserialize the returned page of results.
    #[test]
    fn test_list() -> Result<(), ReplicateError> {
        let server = MockServer::start();

        let get_mock = server.mock(|when, then| {
            when.method(GET).path("/predictions");
            then.status(200).json_body_obj(&json!( {
                "next": "https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw",
                "previous": None::<String>,
                "results": [
                  {
                    "id": "jpzd7hm5gfcapbfyt4mqytarku",
                    "version":
                      "b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05",
                    "urls": {
                      "get": "https://api.replicate.com/v1/predictions/jpzd7hm5gfcapbfyt4mqytarku",
                      "cancel":
                        "https://api.replicate.com/v1/predictions/jpzd7hm5gfcapbfyt4mqytarku/cancel",
                    },
                    "created_at": "2022-04-26T20:00:40.658234Z",
                    "started_at": "2022-04-26T20:00:84.583803Z",
                    "completed_at": "2022-04-26T20:02:27.648305Z",
                    "source": "web",
                    "status": "succeeded",
                  },
                ],
              }
            ));
        });

        // Point the client at the mock server instead of the real API.
        let config = Config {
            auth: String::from("test"),
            base_url: server.base_url(),
            ..Config::default()
        };
        let replicate = Replicate::new(config);

        let result = replicate.predictions.list()?;

        assert_eq!(result.results.len(), 1);
        assert_eq!(result.results[0].id, "jpzd7hm5gfcapbfyt4mqytarku");

        // Ensure the mocks were called as expected
        get_mock.assert();

        Ok(())
    }

    /// `get` should hit `GET /predictions/{id}` on the configured base url and
    /// deserialize the returned prediction.
    #[test]
    fn test_get() -> Result<(), ReplicateError> {
        let server = MockServer::start();

        let get_mock = server.mock(|when, then| {
            when.method(GET).path("/predictions/rrr4z55ocneqzikepnug6xezpe");
            then.status(200).json_body_obj(&json!(  {
                "id": "rrr4z55ocneqzikepnug6xezpe",
                "version":
                  "be04660a5b93ef2aff61e3668dedb4cbeb14941e62a3fd5998364a32d613e35e",
                "urls": {
                  "get": "https://api.replicate.com/v1/predictions/rrr4z55ocneqzikepnug6xezpe",
                  "cancel":
                    "https://api.replicate.com/v1/predictions/rrr4z55ocneqzikepnug6xezpe/cancel",
                },
                "created_at": "2022-09-13T22:54:18.578761Z",
                "started_at": "2022-09-13T22:54:19.438525Z",
                "completed_at": "2022-09-13T22:54:23.236610Z",
                "source": "api",
                "status": "succeeded",
                "input": {
                  "prompt": "oak tree with boletus growing on its branches",
                },
                "output": [
                  "https://replicate.com/api/models/stability-ai/stable-diffusion/files/9c3b6fe4-2d37-4571-a17a-83951b1cb120/out-0.png",
                ],
                "error": None::<String>,
                "logs": "Using seed: 36941...",
                "metrics": {
                  "predict_time": 4.484541,
                },
              }
            ));
        });

        // Point the client at the mock server instead of the real API.
        let config = Config {
            auth: String::from("test"),
            base_url: server.base_url(),
            ..Config::default()
        };
        let replicate = Replicate::new(config);

        let result = replicate.predictions.get("rrr4z55ocneqzikepnug6xezpe")?;

        assert_eq!(result.id, "rrr4z55ocneqzikepnug6xezpe");

        // Ensure the mocks were called as expected
        get_mock.assert();

        Ok(())
    }
}
317}