replicate_rs/
predictions.rs

//! Utilities for interacting with all prediction endpoints.
//!
//! The following endpoints are covered:
//! - [Create Prediction](https://replicate.com/docs/reference/http#predictions.create)
//! - [Get Prediction](https://replicate.com/docs/reference/http#predictions.get)
//! - [List Predictions](https://replicate.com/docs/reference/http#predictions.list)
//! - [Cancel Prediction](https://replicate.com/docs/reference/http#predictions.cancel)
9
10use crate::config::ReplicateConfig;
11use crate::errors::{get_error, ReplicateError, ReplicateResult};
12
13use anyhow::anyhow;
14use bytes::Bytes;
15use eventsource_stream::{EventStream, Eventsource};
16use futures_lite::StreamExt;
17use serde_json::Value;
18
19use crate::models::ModelClient;
20use crate::{api_key, base_url};
21
22/// Status of a retrieved or created prediction
23#[derive(serde::Serialize, serde::Deserialize, Debug, Eq, PartialEq, Clone)]
24#[serde(rename_all = "lowercase")]
25pub enum PredictionStatus {
26    /// The prediction is starting up. If this status lasts longer than a few seconds, then it's
27    /// typically because a new worker is being started to run the prediction.
28    Starting,
29    /// The `predict()` method of the model is currently running.
30    Processing,
31    /// The prediction completed successfully.
32    Succeeded,
33    /// The prediction was canceled by its creator.
34    Failed,
35    /// The prediction was canceled by its creator.
36    Canceled,
37}
38
39/// Provided urls to either cancel or retrieve updated details for the specific prediction.
40#[derive(serde::Deserialize, Debug)]
41pub struct PredictionUrls {
42    /// Url endpoint to cancel the specific prediction
43    pub cancel: String,
44    /// Url endpoint to retrieve the specific prediction
45    pub get: String,
46    /// Url endpoint to receive streamed output
47    pub stream: Option<String>,
48}
49
50/// Details for a specific prediction
51#[derive(serde::Deserialize, Debug)]
52pub struct Prediction {
53    /// Id of the prediction
54    pub id: String,
55    /// Model used during the prediction
56    pub model: String,
57    /// Specific version used during prediction
58    pub version: String,
59    /// The inputs provided for the specific prediction
60    pub input: Value,
61    /// The current status of the prediction
62    pub status: PredictionStatus,
63    /// The created time for the prediction
64    pub created_at: String,
65    /// Urls to either retrieve or cancel details for this prediction
66    pub urls: PredictionUrls,
67    /// The output of the prediction if completed
68    pub output: Option<Value>,
69}
70
71/// Paginated list of available predictions
72#[derive(serde::Deserialize, Debug)]
73pub struct Predictions {
74    /// Identify for status in pagination
75    pub next: Option<String>,
76    /// Identify for status of pagination
77    pub previous: Option<String>,
78    /// List of predictions
79    pub results: Vec<Prediction>,
80}
81
82impl Prediction {
83    /// Leverage the get url provided, to refresh struct attributes
84    pub async fn reload(&mut self) -> anyhow::Result<()> {
85        let api_key = api_key()?;
86        let endpoint = self.urls.get.clone();
87        let client = reqwest::Client::new();
88        let response = client
89            .get(endpoint)
90            .header("Authorization", format!("Token {api_key}"))
91            .send()
92            .await?;
93
94        let data = response.text().await?;
95        let prediction: Prediction = serde_json::from_str(data.as_str())?;
96        *self = prediction;
97        anyhow::Ok(())
98    }
99
100    /// Get the status for the current prediction
101    pub async fn get_status(&mut self) -> PredictionStatus {
102        self.status.clone()
103    }
104
105    /// Get the stream from a prediction
106    pub async fn get_stream(
107        &mut self,
108    ) -> anyhow::Result<EventStream<impl futures_lite::stream::Stream<Item = reqwest::Result<Bytes>>>>
109    {
110        if let Some(stream_url) = self.urls.stream.clone() {
111            let api_key = api_key()?;
112            let client = reqwest::Client::new();
113            let stream = client
114                .get(stream_url)
115                .header("Authorization", format!("Token {api_key}"))
116                .header("Accept", "text/event-stream")
117                .send()
118                .await?
119                .bytes_stream()
120                .eventsource();
121
122            return anyhow::Ok(stream);
123        } else {
124            return Err(anyhow!("prediction has no stream url available"));
125        }
126    }
127}
128
129/// A client for interacting with 'predictions' endpoint
130#[derive(Debug)]
131pub struct PredictionClient {
132    config: ReplicateConfig,
133}
134
135#[derive(serde::Serialize)]
136struct PredictionInput {
137    version: String,
138    input: serde_json::Value,
139    stream: bool,
140}
141
142impl PredictionClient {
143    /// Create a new `PredictionClient` based upon a `ReplicateConfig` object
144    pub fn from(config: ReplicateConfig) -> Self {
145        PredictionClient { config }
146    }
147    /// Create a new prediction
148    pub async fn create(
149        &self,
150        owner: &str,
151        name: &str,
152        input: serde_json::Value,
153        stream: bool,
154    ) -> ReplicateResult<Prediction> {
155        let api_key = self.config.get_api_key()?;
156        let base_url = self.config.get_base_url();
157
158        let model_client = ModelClient::from(self.config.clone());
159        let version = model_client.get_latest_version(owner, name).await?.id;
160
161        let endpoint = format!("{base_url}/predictions");
162        let input = PredictionInput {
163            version,
164            input,
165            stream,
166        };
167        let body = serde_json::to_string(&input)
168            .map_err(|err| ReplicateError::SerializationError(err.to_string()))?;
169        let client = reqwest::Client::new();
170        let response = client
171            .post(endpoint)
172            .header("Authorization", format!("Token {api_key}"))
173            .body(body)
174            .send()
175            .await
176            .map_err(|err| ReplicateError::ClientError(err.to_string()))?;
177
178        return match response.status() {
179            reqwest::StatusCode::OK | reqwest::StatusCode::CREATED => {
180                let data = response
181                    .text()
182                    .await
183                    .map_err(|err| ReplicateError::ClientError(err.to_string()))?;
184                let prediction: Prediction = serde_json::from_str(&data)
185                    .map_err(|err| ReplicateError::SerializationError(err.to_string()))?;
186
187                Ok(prediction)
188            }
189            _ => Err(get_error(
190                response.status(),
191                response
192                    .text()
193                    .await
194                    .map_err(|err| ReplicateError::ClientError(err.to_string()))?
195                    .as_str(),
196            )),
197        };
198    }
199
200    /// Get details for an existing prediction
201    pub async fn get(&self, id: String) -> anyhow::Result<Prediction> {
202        let api_key = self.config.get_api_key()?;
203        let base_url = self.config.get_base_url();
204
205        let endpoint = format!("{base_url}/predictions/{id}");
206        let client = reqwest::Client::new();
207        let response = client
208            .get(endpoint)
209            .header("Authorization", format!("Token {api_key}"))
210            .send()
211            .await?;
212
213        let data = response.text().await?;
214        let prediction: Prediction = serde_json::from_str(&data)?;
215
216        anyhow::Ok(prediction)
217    }
218
219    /// List all existing predictions for the current user
220    pub async fn list(&self) -> anyhow::Result<Predictions> {
221        let api_key = self.config.get_api_key()?;
222        let base_url = self.config.get_base_url();
223
224        let endpoint = format!("{base_url}/predictions");
225        let client = reqwest::Client::new();
226        let response = client
227            .get(endpoint)
228            .header("Authorization", format!("Token {api_key}"))
229            .send()
230            .await?;
231
232        let data = response.text().await?;
233        let predictions: Predictions = serde_json::from_str(&data)?;
234
235        anyhow::Ok(predictions)
236    }
237
238    /// Cancel an existing prediction
239    pub async fn cancel(&self, id: String) -> anyhow::Result<Prediction> {
240        let api_key = self.config.get_api_key()?;
241        let base_url = self.config.get_base_url();
242        let endpoint = format!("{base_url}/predictions/{id}/cancel");
243        let client = reqwest::Client::new();
244        let response = client
245            .post(endpoint)
246            .header("Authorization", format!("Token {api_key}"))
247            .send()
248            .await?;
249
250        let data = response.text().await?;
251        let prediction: Prediction = serde_json::from_str(&data)?;
252
253        anyhow::Ok(prediction)
254    }
255}
256
#[cfg(test)]
mod tests {
    use httpmock::prelude::*;
    use serde_json::json;

    use super::*;

    /// Version id shared by every fixture in this module.
    const HELLO_WORLD_VERSION: &str =
        "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa";

    /// Id used by the create/list fixtures.
    const CREATED_ID: &str = "gm3qorzdhgbfurvjtvhg6dckhu";

    /// JSON body the mocked API returns for a freshly started prediction with the given `id`.
    fn prediction_body(id: &str) -> serde_json::Value {
        json!({
            "id": id,
            "model": "replicate/hello-world",
            "version": HELLO_WORLD_VERSION,
            "input": {
                "text": "Alice"
            },
            "logs": "",
            "error": null,
            "status": "starting",
            "created_at": "2023-09-08T16:19:34.765994657Z",
            "urls": {
                "cancel": format!("https://api.replicate.com/v1/predictions/{id}/cancel"),
                "get": format!("https://api.replicate.com/v1/predictions/{id}")
            }
        })
    }

    /// Register the version listing that `create` consults to resolve the latest version.
    fn mock_hello_world_versions(server: &MockServer) {
        server.mock(|when, then| {
            when.method(GET)
                .path("/models/replicate/hello-world/versions");

            then.status(200).json_body_obj(&json!({
                "next": null,
                "previous": null,
                "results": [{
                    "id": HELLO_WORLD_VERSION,
                    "created_at": "2022-04-26T19:29:04.418669Z",
                    "cog_version": "0.3.0",
                    "openapi_schema": null
                }]
            }));
        });
    }

    #[tokio::test]
    async fn test_get() {
        let server = MockServer::start();

        let prediction_mock = server.mock(|when, then| {
            when.method(GET).path("/predictions/1234");
            then.status(200).json_body_obj(&prediction_body("1234"));
        });

        let config = ReplicateConfig::test(server.base_url()).unwrap();
        let prediction_client = PredictionClient::from(config);

        let prediction = prediction_client.get("1234".to_string()).await.unwrap();

        prediction_mock.assert();
        assert_eq!(prediction.id, "1234");
        assert_eq!(prediction.status, PredictionStatus::Starting);
        assert!(prediction.urls.stream.is_none());
        assert!(prediction.output.is_none());
    }

    #[tokio::test]
    async fn test_get_reports_http_errors() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(GET).path("/predictions/missing");
            then.status(404)
                .json_body_obj(&json!({ "detail": "Not found." }));
        });

        let config = ReplicateConfig::test(server.base_url()).unwrap();
        let prediction_client = PredictionClient::from(config);

        let result = prediction_client.get("missing".to_string()).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_create() {
        let server = MockServer::start();

        let prediction_mock = server.mock(|when, then| {
            when.method(POST).path("/predictions");
            then.status(200).json_body_obj(&prediction_body(CREATED_ID));
        });
        mock_hello_world_versions(&server);

        let config = ReplicateConfig::test(server.base_url()).unwrap();
        let prediction_client = PredictionClient::from(config);

        let prediction = prediction_client
            .create(
                "replicate",
                "hello-world",
                json!({"text": "This is test input"}),
                false,
            )
            .await
            .unwrap();

        prediction_mock.assert();
        assert_eq!(prediction.id, CREATED_ID);
        assert_eq!(prediction.version, HELLO_WORLD_VERSION);
        assert_eq!(prediction.status, PredictionStatus::Starting);
    }

    #[tokio::test]
    async fn test_list_predictions() {
        let server = MockServer::start();

        let list_mock = server.mock(|when, then| {
            when.method(GET).path("/predictions");
            then.status(200).json_body_obj(&json!({
                "next": null,
                "previous": null,
                "results": [prediction_body(CREATED_ID), prediction_body(CREATED_ID)]
            }));
        });

        let config = ReplicateConfig::test(server.base_url()).unwrap();
        let prediction_client = PredictionClient::from(config);

        let predictions = prediction_client.list().await.unwrap();

        list_mock.assert();
        assert!(predictions.next.is_none());
        assert!(predictions.previous.is_none());
        assert_eq!(predictions.results.len(), 2);
        assert!(predictions.results.iter().all(|p| p.id == CREATED_ID));
    }

    #[tokio::test]
    async fn test_create_and_reload() {
        let server = MockServer::start();

        let prediction_mock = server.mock(|when, then| {
            when.method(POST).path("/predictions");
            then.status(200).json_body_obj(&prediction_body(CREATED_ID));
        });
        mock_hello_world_versions(&server);

        let config = ReplicateConfig::test(server.base_url()).unwrap();
        let prediction_client = PredictionClient::from(config);

        let prediction = prediction_client
            .create(
                "replicate",
                "hello-world",
                json!({"text": "This is test input"}),
                false,
            )
            .await
            .unwrap();

        prediction_mock.assert();
        assert_eq!(prediction.id, CREATED_ID);
        assert_eq!(
            prediction.urls.get,
            format!("https://api.replicate.com/v1/predictions/{CREATED_ID}")
        );

        // NOTE(review): `Prediction::reload` is deliberately not called here. It requests the
        // absolute `urls.get` address from the response (which points at the public API, not
        // at the mock server) and authenticates through the crate-level `api_key()` rather
        // than this client's `ReplicateConfig`.
    }

    #[tokio::test]
    async fn test_cancel() {
        let server = MockServer::start();

        let prediction_mock = server.mock(|when, then| {
            when.method(POST).path("/predictions/1234/cancel");
            then.status(200).json_body_obj(&prediction_body("1234"));
        });

        let config = ReplicateConfig::test(server.base_url()).unwrap();
        let prediction_client = PredictionClient::from(config);

        let prediction = prediction_client.cancel("1234".to_string()).await.unwrap();

        prediction_mock.assert();
        assert_eq!(prediction.id, "1234");
    }
}