replicate_rust/prediction_client.rs
1//! Helper struct for the prediction struct
2//!
3//! Used to create a prediction, reload for latest info, cancel it and wait for prediction to complete.
4//!
5//! # Example
6//! ```
7//! use replicate_rust::{Replicate, config::Config};
8//!
9//! let config = Config::default();
10//! let replicate = Replicate::new(config);
11//!
12//! // Creating the inputs
13//! let mut inputs = std::collections::HashMap::new();
14//! inputs.insert("prompt", "a 19th century portrait of a wombat gentleman");
15//!
16//! let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
17//!
18//! // Create a new prediction
19//! let mut prediction = replicate.predictions.create(version, inputs)?;
20//!
21//! // Reload the prediction to get the latest info and logs
22//! prediction.reload()?;
23//!
24//! // Cancel the prediction
25//! // prediction.cancel()?;
26//!
27//! // Wait for the prediction to complete
28//! let result = prediction.wait()?;
29//!
30//! println!("Result : {:?}", result);
31//! # Ok::<(), replicate_rust::errors::ReplicateError>(())
32//! ```
33
34use std::collections::HashMap;
35
36use crate::{
37 api_definitions::{CreatePrediction, GetPrediction, PredictionStatus, PredictionsUrls},
38 errors::ReplicateError,
39 prediction::PredictionPayload,
40};
41
42use super::retry::{RetryPolicy, RetryStrategy};
43
/// Parse a model version string of the form `owner/name:version`.
///
/// The string is split at the first `:`; everything before it is the model
/// (`owner/name`) and everything after it is the version, so a version that itself
/// contains `:` is kept intact.
///
/// Returns `Some((model, version))` on success, where `model` still includes the
/// `owner/` prefix. Returns `None` when:
/// - there is no `:` separator,
/// - the model part has no `/`,
/// - the owner, the name or the version is empty (e.g. `"owner/model:"`).
pub fn parse_version(s: &str) -> Option<(&str, &str)> {
    // Split once at the first colon: "<model>:<version>".
    let (model, version) = s.split_once(':')?;

    // The model must look like "<owner>/<name>".
    let (owner, name) = model.split_once('/')?;

    // An empty component can never identify a real model version, so reject it
    // here instead of forwarding it to the API.
    if owner.is_empty() || name.is_empty() || version.is_empty() {
        return None;
    }

    Some((model, version))
}
60
61/// Helper struct for the Prediction struct. Used to create a prediction, reload for latest info, cancel it and wait for prediction to complete.
62#[allow(missing_docs)]
63#[derive(Clone, Debug)]
64pub struct PredictionClient {
65 /// Holds a reference to a Configuration struct, which contains the base url, auth token among other settings.
66 pub parent: crate::config::Config,
67
68 /// Unique identifier of the prediction
69 pub id: String,
70 pub version: String,
71
72 pub urls: PredictionsUrls,
73
74 pub created_at: String,
75
76 pub status: PredictionStatus,
77
78 pub input: HashMap<String, serde_json::Value>,
79
80 pub error: Option<String>,
81
82 pub logs: Option<String>,
83}
84
85impl PredictionClient {
86 /// Run the prediction of the model version with the given input
87 /// # Example
88 /// ```
89 /// use replicate_rust::{Replicate, config::Config};
90 ///
91 /// let config = Config::default();
92 /// let replicate = Replicate::new(config);
93 ///
94 /// // Creating the inputs
95 /// let mut inputs = std::collections::HashMap::new();
96 /// inputs.insert("prompt", "a 19th century portrait of a wombat gentleman");
97 ///
98 /// let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
99 ///
100 /// // Create a new prediction
101 /// let mut prediction = replicate.predictions.create(version, inputs)?;
102 /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
103 /// ```
104 pub fn create<K: serde::Serialize, V: serde::ser::Serialize>(
105 rep: crate::config::Config,
106 version: &str,
107 inputs: HashMap<K, V>,
108 ) -> Result<PredictionClient, ReplicateError> {
109 // Parse the model version string.
110 let (_model, version) = match parse_version(version) {
111 Some((model, version)) => (model, version),
112 None => return Err(ReplicateError::InvalidVersionString(version.to_string())),
113 };
114
115 // Construct the request payload
116 let payload = PredictionPayload {
117 version: version.to_string(),
118 input: inputs,
119 };
120
121 // println!("Payload : {:?}", &payload);
122 let client = reqwest::blocking::Client::new();
123 let response = client
124 .post(format!("{}/predictions", rep.base_url))
125 .header("Authorization", format!("Token {}", rep.auth))
126 .header("User-Agent", &rep.user_agent)
127 .json(&payload)
128 .send()?;
129
130 if !response.status().is_success() {
131 return Err(ReplicateError::ResponseError(response.text()?));
132 }
133
134 if !response.status().is_success() {
135 return Err(ReplicateError::ResponseError(response.text()?));
136 }
137
138 // println!("Response : {:?}", response.text()?);
139
140 let result: CreatePrediction = response.json()?;
141
142 Ok(Self {
143 parent: rep,
144 id: result.id,
145 version: result.version,
146 urls: result.urls,
147 created_at: result.created_at,
148 status: result.status,
149 input: result.input,
150 error: result.error,
151 logs: result.logs,
152 })
153 }
154
155 /// Returns the latest info of the prediction
156 // # Example
157 /// ```
158 /// use replicate_rust::{Replicate, config::Config};
159 ///
160 /// let config = Config::default();
161 /// let replicate = Replicate::new(config);
162 ///
163 /// // Creating the inputs
164 /// let mut inputs = std::collections::HashMap::new();
165 /// inputs.insert("prompt", "a 19th century portrait of a wombat gentleman");
166 ///
167 /// let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
168 ///
169 /// // Create a new prediction
170 /// let mut prediction = replicate.predictions.create(version, inputs)?;
171 ///
172 /// // Reload the prediction to get the latest info and logs
173 /// prediction.reload()?;
174 ///
175 /// println!("Prediction : {:?}", prediction.status);
176 /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
177 /// ```
178 pub fn reload(&mut self) -> Result<(), ReplicateError> {
179 let client = reqwest::blocking::Client::new();
180
181 let response = client
182 .get(format!("{}/predictions/{}", self.parent.base_url, self.id))
183 .header("Authorization", format!("Token {}", self.parent.auth))
184 .header("User-Agent", &self.parent.user_agent)
185 .send()?;
186
187 if !response.status().is_success() {
188 return Err(ReplicateError::ResponseError(response.text()?));
189 }
190
191 let response_string = response.text()?;
192 let response_struct: GetPrediction = serde_json::from_str(&response_string)?;
193
194 self.id = response_struct.id;
195 self.version = response_struct.version;
196 self.urls = response_struct.urls;
197 self.created_at = response_struct.created_at;
198 self.status = response_struct.status;
199 self.input = response_struct.input;
200 self.error = response_struct.error;
201 self.logs = response_struct.logs;
202
203 Ok(())
204 }
205
206 /// Cancel the prediction
207 /// # Example
208 /// ```
209 /// use replicate_rust::{Replicate, config::Config};
210 ///
211 /// let config = Config::default();
212 /// let replicate = Replicate::new(config);
213 ///
214 /// // Creating the inputs
215 /// let mut inputs = std::collections::HashMap::new();
216 /// inputs.insert("prompt", "a 19th century portrait of a wombat gentleman");
217 ///
218 /// let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
219 ///
220 /// // Create a new prediction
221 /// let mut prediction = replicate.predictions.create(version, inputs)?;
222 ///
223 /// // Cancel the prediction
224 /// prediction.cancel()?;
225 ///
226 /// // Wait for the prediction to complete (or fail).
227 /// let result = prediction.wait()?;
228 ///
229 /// println!("Result : {:?}", result);
230 /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
231 /// ```
232 pub fn cancel(&mut self) -> Result<(), ReplicateError> {
233 let client = reqwest::blocking::Client::new();
234 let response = client
235 .post(format!(
236 "{}/predictions/{}/cancel",
237 self.parent.base_url, self.id
238 ))
239 .header("Authorization", format!("Token {}", &self.parent.auth))
240 .header("User-Agent", &self.parent.user_agent)
241 .send()?;
242
243 if !response.status().is_success() {
244 return Err(ReplicateError::ResponseError(response.text()?));
245 }
246
247 self.reload()?;
248
249 Ok(())
250 }
251
252 /// Blocks until the predictions are ready and returns the predictions
253 /// # Example
254 /// ```
255 /// use replicate_rust::{Replicate, config::Config};
256 ///
257 /// let config = Config::default();
258 /// let replicate = Replicate::new(config);
259 ///
260 /// // Creating the inputs
261 /// let mut inputs = std::collections::HashMap::new();
262 /// inputs.insert("prompt", "a 19th century portrait of a wombat gentleman");
263 ///
264 /// let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
265 ///
266 /// // Create a new prediction
267 /// let mut prediction = replicate.predictions.create(version, inputs)?;
268 ///
269 /// // Wait for the prediction to complete (or fail).
270 /// let result = prediction.wait()?;
271 ///
272 /// println!("Result : {:?}", result);
273 ///
274 /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
275 /// ```
276 pub fn wait(&self) -> Result<GetPrediction, ReplicateError> {
277 // TODO : Implement a retry policy
278 let retry_policy = RetryPolicy::new(5, RetryStrategy::FixedDelay(1000));
279 let client = reqwest::blocking::Client::new();
280
281 loop {
282 let response = client
283 .get(format!("{}/predictions/{}", self.parent.base_url, self.id))
284 .header("Authorization", format!("Token {}", self.parent.auth))
285 .header("User-Agent", &self.parent.user_agent)
286 .send()?;
287
288 if !response.status().is_success() {
289 return Err(ReplicateError::ResponseError(response.text()?));
290 }
291
292 let response_string = response.text()?;
293 let response_struct: GetPrediction = serde_json::from_str(&response_string)?;
294
295 match response_struct.status {
296 PredictionStatus::succeeded
297 | PredictionStatus::failed
298 | PredictionStatus::canceled => {
299 return Ok(response_struct);
300 }
301 PredictionStatus::processing | PredictionStatus::starting => {
302 // Retry
303 // TODO : Fix the retry implementation
304 retry_policy.step();
305 }
306 }
307 }
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{config::Config, Replicate};
    use httpmock::{Method::POST, MockServer};
    use serde_json::json;

    /// Creating a prediction against a mocked `/predictions` endpoint returns a client
    /// carrying the id from the mocked response.
    #[test]
    fn test_create() -> Result<(), ReplicateError> {
        let mock_server = MockServer::start();

        // Canned body returned by the mocked endpoint.
        let canned_body = json!({
            "id": "ufawqhfynnddngldkgtslldrkq",
            "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
            "urls": {
                "get": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq",
                "cancel": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel",
            },
            "created_at": "2022-04-26T22:13:06.224088Z",
            "started_at": None::<String>,
            "completed_at": None::<String>,
            "status": "starting",
            "input": {
                "text": "Alice",
            },
            "output": None::<String>,
            "error": None::<String>,
            "logs": None::<String>,
            "metrics": {},
        });

        let create_mock = mock_server.mock(|when, then| {
            when.method(POST).path("/predictions");
            then.status(200).json_body_obj(&canned_body);
        });

        // Point the client at the mock server instead of the real API.
        let replicate = Replicate::new(Config {
            auth: String::from("test"),
            base_url: mock_server.base_url(),
            ..Config::default()
        });

        let mut prediction_input = HashMap::new();
        prediction_input.insert("text", "Alice");

        let created = replicate.predictions.create(
            "owner/model:632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532",
            prediction_input,
        )?;
        assert_eq!(created.id, "ufawqhfynnddngldkgtslldrkq");

        // Ensure the mocks were called as expected
        create_mock.assert();

        Ok(())
    }
}