replicate_rust/
prediction_client.rs

//! Client-side handle for a single Replicate prediction.
//!
//! Used to create a prediction, reload it to fetch the latest status and logs, cancel it, and wait for it to complete.
4//!
5//! # Example
6//! ```
7//! use replicate_rust::{Replicate, config::Config};
8//!
9//! let config = Config::default();
10//! let replicate = Replicate::new(config);
11//!
12//! // Creating the inputs
13//! let mut inputs = std::collections::HashMap::new();
14//! inputs.insert("prompt", "a  19th century portrait of a wombat gentleman");
15//!
16//! let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
17//!
18//! // Create a new prediction
19//! let mut prediction = replicate.predictions.create(version, inputs)?;
20//!
21//! // Reload the prediction to get the latest info and logs
22//! prediction.reload()?;
23//!
24//! // Cancel the prediction
25//! // prediction.cancel()?;
26//!
27//! // Wait for the prediction to complete
28//! let result = prediction.wait()?;
29//!
30//! println!("Result : {:?}", result);
31//! # Ok::<(), replicate_rust::errors::ReplicateError>(())
32//! ```
33
34use std::collections::HashMap;
35
36use crate::{
37    api_definitions::{CreatePrediction, GetPrediction, PredictionStatus, PredictionsUrls},
38    errors::ReplicateError,
39    prediction::PredictionPayload,
40};
41
42use super::retry::{RetryPolicy, RetryStrategy};
43
/// Parse a model version string of the form `owner/model:version` into its
/// model (`owner/model`) and version parts.
///
/// Only the first `:` separates the two parts; anything after it (including
/// further colons) belongs to the version.
///
/// Returns `None` when:
/// - there is no `:` separator,
/// - the model part is not `owner/name` with both `owner` and `name` non-empty,
/// - the version part is empty.
pub fn parse_version(s: &str) -> Option<(&str, &str)> {
    // Split at the first colon into `owner/model` and the version id.
    let (model, version) = s.split_once(':')?;

    // The model must look like `owner/name`; reject empty halves so strings
    // such as "/model:x" or "owner/:x" are reported as invalid up front.
    let (owner, name) = model.split_once('/')?;
    if owner.is_empty() || name.is_empty() || version.is_empty() {
        return None;
    }

    Some((model, version))
}
60
61/// Helper struct for the Prediction struct. Used to create a prediction, reload for latest info, cancel it and wait for prediction to complete.
62#[allow(missing_docs)]
63#[derive(Clone, Debug)]
64pub struct PredictionClient {
65    /// Holds a reference to a Configuration struct, which contains the base url,  auth token among other settings.
66    pub parent: crate::config::Config,
67
68    /// Unique identifier of the prediction
69    pub id: String,
70    pub version: String,
71
72    pub urls: PredictionsUrls,
73
74    pub created_at: String,
75
76    pub status: PredictionStatus,
77
78    pub input: HashMap<String, serde_json::Value>,
79
80    pub error: Option<String>,
81
82    pub logs: Option<String>,
83}
84
85impl PredictionClient {
86    /// Run the prediction of the model version with the given input
87    /// # Example
88    /// ```
89    /// use replicate_rust::{Replicate, config::Config};
90    ///
91    /// let config = Config::default();
92    /// let replicate = Replicate::new(config);
93    ///
94    /// // Creating the inputs
95    /// let mut inputs = std::collections::HashMap::new();
96    /// inputs.insert("prompt", "a  19th century portrait of a wombat gentleman");
97    ///
98    /// let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
99    ///
100    /// // Create a new prediction
101    /// let mut prediction = replicate.predictions.create(version, inputs)?;
102    /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
103    /// ```
104    pub fn create<K: serde::Serialize, V: serde::ser::Serialize>(
105        rep: crate::config::Config,
106        version: &str,
107        inputs: HashMap<K, V>,
108    ) -> Result<PredictionClient, ReplicateError> {
109        // Parse the model version string.
110        let (_model, version) = match parse_version(version) {
111            Some((model, version)) => (model, version),
112            None => return Err(ReplicateError::InvalidVersionString(version.to_string())),
113        };
114
115        // Construct the request payload
116        let payload = PredictionPayload {
117            version: version.to_string(),
118            input: inputs,
119        };
120
121        // println!("Payload : {:?}", &payload);
122        let client = reqwest::blocking::Client::new();
123        let response = client
124            .post(format!("{}/predictions", rep.base_url))
125            .header("Authorization", format!("Token {}", rep.auth))
126            .header("User-Agent", &rep.user_agent)
127            .json(&payload)
128            .send()?;
129
130        if !response.status().is_success() {
131            return Err(ReplicateError::ResponseError(response.text()?));
132        }
133
134        if !response.status().is_success() {
135            return Err(ReplicateError::ResponseError(response.text()?));
136        }
137
138        // println!("Response : {:?}", response.text()?);
139
140        let result: CreatePrediction = response.json()?;
141
142        Ok(Self {
143            parent: rep,
144            id: result.id,
145            version: result.version,
146            urls: result.urls,
147            created_at: result.created_at,
148            status: result.status,
149            input: result.input,
150            error: result.error,
151            logs: result.logs,
152        })
153    }
154
155    /// Returns the latest info of the prediction
156    // # Example
157    /// ```
158    /// use replicate_rust::{Replicate, config::Config};
159    ///
160    /// let config = Config::default();
161    /// let replicate = Replicate::new(config);
162    ///
163    /// // Creating the inputs
164    /// let mut inputs = std::collections::HashMap::new();
165    /// inputs.insert("prompt", "a  19th century portrait of a wombat gentleman");
166    ///
167    /// let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
168    ///
169    /// // Create a new prediction
170    /// let mut prediction = replicate.predictions.create(version, inputs)?;
171    ///
172    /// // Reload the prediction to get the latest info and logs
173    /// prediction.reload()?;
174    ///
175    /// println!("Prediction : {:?}", prediction.status);
176    /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
177    /// ```
178    pub fn reload(&mut self) -> Result<(), ReplicateError> {
179        let client = reqwest::blocking::Client::new();
180
181        let response = client
182            .get(format!("{}/predictions/{}", self.parent.base_url, self.id))
183            .header("Authorization", format!("Token {}", self.parent.auth))
184            .header("User-Agent", &self.parent.user_agent)
185            .send()?;
186
187        if !response.status().is_success() {
188            return Err(ReplicateError::ResponseError(response.text()?));
189        }
190
191        let response_string = response.text()?;
192        let response_struct: GetPrediction = serde_json::from_str(&response_string)?;
193
194        self.id = response_struct.id;
195        self.version = response_struct.version;
196        self.urls = response_struct.urls;
197        self.created_at = response_struct.created_at;
198        self.status = response_struct.status;
199        self.input = response_struct.input;
200        self.error = response_struct.error;
201        self.logs = response_struct.logs;
202
203        Ok(())
204    }
205
206    /// Cancel the prediction
207    /// # Example
208    /// ```
209    /// use replicate_rust::{Replicate, config::Config};
210    ///
211    /// let config = Config::default();
212    /// let replicate = Replicate::new(config);
213    ///
214    /// // Creating the inputs
215    /// let mut inputs = std::collections::HashMap::new();
216    /// inputs.insert("prompt", "a  19th century portrait of a wombat gentleman");
217    ///
218    /// let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
219    ///
220    /// // Create a new prediction
221    /// let mut prediction = replicate.predictions.create(version, inputs)?;
222    ///
223    /// // Cancel the prediction
224    /// prediction.cancel()?;
225    ///
226    /// // Wait for the prediction to complete (or fail).
227    /// let result = prediction.wait()?;
228    ///
229    /// println!("Result : {:?}", result);
230    /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
231    /// ```
232    pub fn cancel(&mut self) -> Result<(), ReplicateError> {
233        let client = reqwest::blocking::Client::new();
234        let response = client
235            .post(format!(
236                "{}/predictions/{}/cancel",
237                self.parent.base_url, self.id
238            ))
239            .header("Authorization", format!("Token {}", &self.parent.auth))
240            .header("User-Agent", &self.parent.user_agent)
241            .send()?;
242
243        if !response.status().is_success() {
244            return Err(ReplicateError::ResponseError(response.text()?));
245        }
246
247        self.reload()?;
248
249        Ok(())
250    }
251
252    /// Blocks until the predictions are ready and returns the predictions
253    /// # Example
254    /// ```
255    /// use replicate_rust::{Replicate, config::Config};
256    ///
257    /// let config = Config::default();
258    /// let replicate = Replicate::new(config);
259    ///
260    /// // Creating the inputs
261    /// let mut inputs = std::collections::HashMap::new();
262    /// inputs.insert("prompt", "a  19th century portrait of a wombat gentleman");
263    ///
264    /// let version = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478";
265    ///
266    /// // Create a new prediction
267    /// let mut prediction = replicate.predictions.create(version, inputs)?;
268    ///
269    /// // Wait for the prediction to complete (or fail).
270    /// let result = prediction.wait()?;
271    ///
272    /// println!("Result : {:?}", result);
273    ///
274    /// # Ok::<(), replicate_rust::errors::ReplicateError>(())
275    /// ```
276    pub fn wait(&self) -> Result<GetPrediction, ReplicateError> {
277        // TODO : Implement a retry policy
278        let retry_policy = RetryPolicy::new(5, RetryStrategy::FixedDelay(1000));
279        let client = reqwest::blocking::Client::new();
280
281        loop {
282            let response = client
283                .get(format!("{}/predictions/{}", self.parent.base_url, self.id))
284                .header("Authorization", format!("Token {}", self.parent.auth))
285                .header("User-Agent", &self.parent.user_agent)
286                .send()?;
287
288            if !response.status().is_success() {
289                return Err(ReplicateError::ResponseError(response.text()?));
290            }
291
292            let response_string = response.text()?;
293            let response_struct: GetPrediction = serde_json::from_str(&response_string)?;
294
295            match response_struct.status {
296                PredictionStatus::succeeded
297                | PredictionStatus::failed
298                | PredictionStatus::canceled => {
299                    return Ok(response_struct);
300                }
301                PredictionStatus::processing | PredictionStatus::starting => {
302                    // Retry
303                    // TODO : Fix the retry implementation
304                    retry_policy.step();
305                }
306            }
307        }
308    }
309}
310
#[cfg(test)]
mod tests {
    use crate::{config::Config, Replicate};

    use super::*;
    use httpmock::{Method::POST, MockServer};
    use serde_json::json;

    #[test]
    fn test_parse_version() {
        assert_eq!(
            parse_version("owner/model:632231d0d49d34d5"),
            Some(("owner/model", "632231d0d49d34d5"))
        );
        // Missing `:` separator.
        assert_eq!(parse_version("owner/model"), None);
        // The model part must be `owner/name`.
        assert_eq!(parse_version("model:632231d0d49d34d5"), None);
    }

    #[test]
    fn test_create() -> Result<(), ReplicateError> {
        let server = MockServer::start();

        // Also require the auth header so a regression in header construction fails here.
        let post_mock = server.mock(|when, then| {
            when.method(POST)
                .path("/predictions")
                .header("Authorization", "Token test");
            then.status(200).json_body_obj(&json!(  {
                "id": "ufawqhfynnddngldkgtslldrkq",
                "version":
                  "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
                "urls": {
                  "get": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq",
                  "cancel":
                    "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel",
                },
                "created_at": "2022-04-26T22:13:06.224088Z",
                "started_at": None::<String>,
                "completed_at": None::<String>,
                "status": "starting",
                "input": {
                  "text": "Alice",
                },
                "output": None::<String>,
                "error": None::<String>,
                "logs": None::<String>,
                "metrics": {},
              }
            ));
        });

        let config = Config {
            auth: String::from("test"),
            base_url: server.base_url(),
            ..Config::default()
        };
        let replicate = Replicate::new(config);

        let mut input = HashMap::new();
        input.insert("text", "Alice");

        let result = replicate.predictions.create(
            "owner/model:632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532",
            input,
        )?;
        assert_eq!(result.id, "ufawqhfynnddngldkgtslldrkq");

        // Ensure the mocks were called as expected
        post_mock.assert();

        Ok(())
    }
}