replicate_rust/prediction.rs
1//! Used to interact with the [Prediction Endpoints](https://replicate.com/docs/reference/http#predictions.get).
2//!
3//! # Example
4//!
5//! ```
6//! use replicate_rust::{Replicate, config::Config};
7//!
8//! let config = Config::default();
9//! let replicate = Replicate::new(config);
10//!
11//! // Construct the inputs.
12//! let mut inputs = std::collections::HashMap::new();
13//! inputs.insert("prompt", "a 19th century portrait of a wombat gentleman");
14//!
15//! let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
16//!
17//! // Run the model.
18//! let result = replicate.predictions.create(version, inputs)?.wait()?;
19//!
20//! // Print the result.
21//! println!("Result : {:?}", result.output);
22//!
23//! # Ok::<(), replicate_rust::errors::ReplicateError>(())
24//! ```
25//!
26//! ## Another example to showcase other methods
27//!
28//! ```
29//! use replicate_rust::{Replicate, config::Config};
30//!
31//! let config = Config::default();
32//! let replicate = Replicate::new(config);
33//!
34//! // Construct the inputs.
35//! let mut inputs = std::collections::HashMap::new();
36//! inputs.insert("prompt", "a 19th century portrait of a wombat gentleman");
37//!
38//! let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
39//!
40//! // Run the model.
41//! let mut prediction = replicate.predictions.create(version, inputs)?;
42//!
43//! println!("Prediction : {:?}", prediction.status);
44//!
45//! // Refetch the prediction using the reload method.
46//! let _ = prediction.reload();
47//! println!("Prediction : {:?}", prediction.status);
48//!
49//! // Cancel the prediction.
50//! let _ = prediction.cancel();
//! println!("Prediction : {:?}", prediction.status);
52//!
53//! // Wait for the prediction to complete (or fail).
54//! println!("Prediction : {:?}", prediction.wait()?);
55//!
56//! # Ok::<(), replicate_rust::errors::ReplicateError>(())
57//! ```
58//!
59//!
60
61use serde::Serialize;
62use std::collections::HashMap;
63
64use crate::{
65 api_definitions::{GetPrediction, ListPredictions},
66 errors::ReplicateError,
67 prediction_client::PredictionClient,
68};
69
70/// Used to interact with the [Prediction Endpoints](https://replicate.com/docs/reference/http#predictions.get).
71#[derive(Serialize)]
72pub struct PredictionPayload<K: serde::Serialize, V: serde::ser::Serialize> {
73 /// Version of the model used for the prediction
74 pub version: String,
75
76 /// Input to the model
77 pub input: HashMap<K, V>,
78}
79
80/// Used to interact with the [Prediction Endpoints](https://replicate.com/docs/reference/http#predictions.get).
81#[derive(Clone, Debug)]
82pub struct Prediction {
83 /// Holds a reference to a Config struct. Use to get the base url, auth token among other settings.
84 pub parent: crate::config::Config,
85}
86
87impl Prediction {
88 /// Create a new Prediction struct.
89 pub fn new(rep: crate::config::Config) -> Self {
90 Self { parent: rep }
91 }
92
93 /// Create a new prediction, by passing in the model version and inputs to PredictionClient.
94 /// PredictionClient contains the necessary methods to interact with the prediction such as reload, cancel and wait.
95 ///
96 /// # Example
97 ///
98 /// ```
99 /// use replicate_rust::{Replicate, config::Config};
100 ///
101 /// let config = Config::default();
102 /// let replicate = Replicate::new(config);
103 ///
104 /// // Construct the inputs.
105 /// let mut inputs = std::collections::HashMap::new();
106 /// inputs.insert("prompt", "a 19th century portrait of a wombat gentleman");
107 ///
108 /// let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
109 ///
110 /// // Run the model.
111 /// let mut prediction = replicate.predictions.create(version, inputs)?;
112 ///
113 /// println!("Prediction : {:?}", prediction.status);
114 ///
115 /// // Refetch the prediction using the reload method.
116 /// prediction.reload();
117 /// println!("Prediction : {:?}", prediction.status);
118 ///
119 /// // Wait for the prediction to complete (or fail).
120 /// println!("Prediction : {:?}", prediction.wait()?);
121 ///
122 /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
123 /// ```
124 ///
125 pub fn create<K: serde::Serialize, V: serde::ser::Serialize>(
126 &self,
127 version: &str,
128 inputs: HashMap<K, V>,
129 ) -> Result<PredictionClient, ReplicateError> {
130 Ok(PredictionClient::create(
131 self.parent.clone(),
132 version,
133 inputs,
134 )?)
135 }
136
137 /// List all predictions executed in Replicate by the user.
138 ///
139 /// # Example
140 ///
141 /// ```
142 /// use replicate_rust::{Replicate, config::Config};
143 ///
144 /// let config = Config::default();
145 /// let replicate = Replicate::new(config);
146 ///
147 /// let predictions = replicate.predictions.list()?;
148 /// println!("Predictions : {:?}", predictions);
149 ///
150 /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
151 /// ```
152 pub fn list(&self) -> Result<ListPredictions, ReplicateError> {
153 let client = reqwest::blocking::Client::new();
154
155 let response = client
156 .get(format!("{}/predictions", self.parent.base_url))
157 .header("Authorization", format!("Token {}", self.parent.auth))
158 .header("User-Agent", &self.parent.user_agent)
159 .send()?;
160
161 if !response.status().is_success() {
162 return Err(ReplicateError::ResponseError(response.text()?));
163 }
164
165 let response_string = response.text()?;
166 let response_struct: ListPredictions = serde_json::from_str(&response_string)?;
167
168 Ok(response_struct)
169 }
170
171 /// Get a prediction by passing in the prediction id.
172 /// The prediction id can be obtained from the PredictionClient struct.
173 ///
174 /// # Example
175 ///
176 /// ```
177 /// use replicate_rust::{Replicate, config::Config};
178 ///
179 /// let config = Config::default();
180 /// let replicate = Replicate::new(config);
181 ///
182 /// let prediction = replicate.predictions.get("rrr4z55ocneqzikepnug6xezpe")?;
183 /// println!("Prediction : {:?}", prediction);
184 ///
185 /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
186 /// ```
187 pub fn get(&self, id: &str) -> Result<GetPrediction, ReplicateError> {
188 let client = reqwest::blocking::Client::new();
189
190 let response = client
191 .get(format!("{}/predictions/{}", self.parent.base_url, id))
192 .header("Authorization", format!("Token {}", self.parent.auth))
193 .header("User-Agent", &self.parent.user_agent)
194 .send()?;
195
196 if !response.status().is_success() {
197 return Err(ReplicateError::ResponseError(response.text()?));
198 }
199
200 let response_string = response.text()?;
201 let response_struct: GetPrediction = serde_json::from_str(&response_string)?;
202
203 Ok(response_struct)
204 }
205}
206
207#[cfg(test)]
208mod tests {
209 use crate::{config::Config, Replicate};
210
211 use super::*;
212 use httpmock::{Method::GET, MockServer};
213 use serde_json::json;
214
215 #[test]
216 fn test_list() -> Result<(), ReplicateError> {
217 let server = MockServer::start();
218
219 let get_mock = server.mock(|when, then| {
220 when.method(GET).path("/predictions");
221 then.status(200).json_body_obj(&json!( {
222 "next": "https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw",
223 "previous": None::<String>,
224 "results": [
225 {
226 "id": "jpzd7hm5gfcapbfyt4mqytarku",
227 "version":
228 "b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05",
229 "urls": {
230 "get": "https://api.replicate.com/v1/predictions/jpzd7hm5gfcapbfyt4mqytarku",
231 "cancel":
232 "https://api.replicate.com/v1/predictions/jpzd7hm5gfcapbfyt4mqytarku/cancel",
233 },
234 "created_at": "2022-04-26T20:00:40.658234Z",
235 "started_at": "2022-04-26T20:00:84.583803Z",
236 "completed_at": "2022-04-26T20:02:27.648305Z",
237 "source": "web",
238 "status": "succeeded",
239 },
240 ],
241 }
242 ));
243 });
244
245 let config = Config {
246 auth: String::from("test"),
247 base_url: server.base_url(),
248 ..Config::default()
249 };
250 let replicate = Replicate::new(config);
251
252 let mut input = HashMap::new();
253 input.insert("text", "...");
254
255 let result = replicate.predictions.list()?;
256
257 assert_eq!(result.results.len(), 1);
258 assert_eq!(result.results[0].id, "jpzd7hm5gfcapbfyt4mqytarku");
259
260 // Ensure the mocks were called as expected
261 get_mock.assert();
262
263 Ok(())
264 }
265
266 #[test]
267 fn test_get() -> Result<(), ReplicateError> {
268 let server = MockServer::start();
269
270 let get_mock = server.mock(|when, then| {
271 when.method(GET).path("/predictions/rrr4z55ocneqzikepnug6xezpe");
272 then.status(200).json_body_obj(&json!( {
273 "id": "rrr4z55ocneqzikepnug6xezpe",
274 "version":
275 "be04660a5b93ef2aff61e3668dedb4cbeb14941e62a3fd5998364a32d613e35e",
276 "urls": {
277 "get": "https://api.replicate.com/v1/predictions/rrr4z55ocneqzikepnug6xezpe",
278 "cancel":
279 "https://api.replicate.com/v1/predictions/rrr4z55ocneqzikepnug6xezpe/cancel",
280 },
281 "created_at": "2022-09-13T22:54:18.578761Z",
282 "started_at": "2022-09-13T22:54:19.438525Z",
283 "completed_at": "2022-09-13T22:54:23.236610Z",
284 "source": "api",
285 "status": "succeeded",
286 "input": {
287 "prompt": "oak tree with boletus growing on its branches",
288 },
289 "output": [
290 "https://replicate.com/api/models/stability-ai/stable-diffusion/files/9c3b6fe4-2d37-4571-a17a-83951b1cb120/out-0.png",
291 ],
292 "error": None::<String>,
293 "logs": "Using seed: 36941...",
294 "metrics": {
295 "predict_time": 4.484541,
296 },
297 }
298 ));
299 });
300
301 let config = Config {
302 auth: String::from("test"),
303 base_url: server.base_url(),
304 ..Config::default()
305 };
306 let replicate = Replicate::new(config);
307
308 let result = replicate.predictions.get("rrr4z55ocneqzikepnug6xezpe")?;
309
310 assert_eq!(result.id, "rrr4z55ocneqzikepnug6xezpe");
311
312 // Ensure the mocks were called as expected
313 get_mock.assert();
314
315 Ok(())
316 }
317}