1use crate::config::ReplicateConfig;
11use crate::errors::{get_error, ReplicateError, ReplicateResult};
12
13use anyhow::anyhow;
14use bytes::Bytes;
15use eventsource_stream::{EventStream, Eventsource};
16use futures_lite::StreamExt;
17use serde_json::Value;
18
19use crate::models::ModelClient;
20use crate::{api_key, base_url};
21
22#[derive(serde::Serialize, serde::Deserialize, Debug, Eq, PartialEq, Clone)]
24#[serde(rename_all = "lowercase")]
25pub enum PredictionStatus {
26 Starting,
29 Processing,
31 Succeeded,
33 Failed,
35 Canceled,
37}
38
39#[derive(serde::Deserialize, Debug)]
41pub struct PredictionUrls {
42 pub cancel: String,
44 pub get: String,
46 pub stream: Option<String>,
48}
49
50#[derive(serde::Deserialize, Debug)]
52pub struct Prediction {
53 pub id: String,
55 pub model: String,
57 pub version: String,
59 pub input: Value,
61 pub status: PredictionStatus,
63 pub created_at: String,
65 pub urls: PredictionUrls,
67 pub output: Option<Value>,
69}
70
71#[derive(serde::Deserialize, Debug)]
73pub struct Predictions {
74 pub next: Option<String>,
76 pub previous: Option<String>,
78 pub results: Vec<Prediction>,
80}
81
82impl Prediction {
83 pub async fn reload(&mut self) -> anyhow::Result<()> {
85 let api_key = api_key()?;
86 let endpoint = self.urls.get.clone();
87 let client = reqwest::Client::new();
88 let response = client
89 .get(endpoint)
90 .header("Authorization", format!("Token {api_key}"))
91 .send()
92 .await?;
93
94 let data = response.text().await?;
95 let prediction: Prediction = serde_json::from_str(data.as_str())?;
96 *self = prediction;
97 anyhow::Ok(())
98 }
99
100 pub async fn get_status(&mut self) -> PredictionStatus {
102 self.status.clone()
103 }
104
105 pub async fn get_stream(
107 &mut self,
108 ) -> anyhow::Result<EventStream<impl futures_lite::stream::Stream<Item = reqwest::Result<Bytes>>>>
109 {
110 if let Some(stream_url) = self.urls.stream.clone() {
111 let api_key = api_key()?;
112 let client = reqwest::Client::new();
113 let stream = client
114 .get(stream_url)
115 .header("Authorization", format!("Token {api_key}"))
116 .header("Accept", "text/event-stream")
117 .send()
118 .await?
119 .bytes_stream()
120 .eventsource();
121
122 return anyhow::Ok(stream);
123 } else {
124 return Err(anyhow!("prediction has no stream url available"));
125 }
126 }
127}
128
129#[derive(Debug)]
131pub struct PredictionClient {
132 config: ReplicateConfig,
133}
134
135#[derive(serde::Serialize)]
136struct PredictionInput {
137 version: String,
138 input: serde_json::Value,
139 stream: bool,
140}
141
142impl PredictionClient {
143 pub fn from(config: ReplicateConfig) -> Self {
145 PredictionClient { config }
146 }
147 pub async fn create(
149 &self,
150 owner: &str,
151 name: &str,
152 input: serde_json::Value,
153 stream: bool,
154 ) -> ReplicateResult<Prediction> {
155 let api_key = self.config.get_api_key()?;
156 let base_url = self.config.get_base_url();
157
158 let model_client = ModelClient::from(self.config.clone());
159 let version = model_client.get_latest_version(owner, name).await?.id;
160
161 let endpoint = format!("{base_url}/predictions");
162 let input = PredictionInput {
163 version,
164 input,
165 stream,
166 };
167 let body = serde_json::to_string(&input)
168 .map_err(|err| ReplicateError::SerializationError(err.to_string()))?;
169 let client = reqwest::Client::new();
170 let response = client
171 .post(endpoint)
172 .header("Authorization", format!("Token {api_key}"))
173 .body(body)
174 .send()
175 .await
176 .map_err(|err| ReplicateError::ClientError(err.to_string()))?;
177
178 return match response.status() {
179 reqwest::StatusCode::OK | reqwest::StatusCode::CREATED => {
180 let data = response
181 .text()
182 .await
183 .map_err(|err| ReplicateError::ClientError(err.to_string()))?;
184 let prediction: Prediction = serde_json::from_str(&data)
185 .map_err(|err| ReplicateError::SerializationError(err.to_string()))?;
186
187 Ok(prediction)
188 }
189 _ => Err(get_error(
190 response.status(),
191 response
192 .text()
193 .await
194 .map_err(|err| ReplicateError::ClientError(err.to_string()))?
195 .as_str(),
196 )),
197 };
198 }
199
200 pub async fn get(&self, id: String) -> anyhow::Result<Prediction> {
202 let api_key = self.config.get_api_key()?;
203 let base_url = self.config.get_base_url();
204
205 let endpoint = format!("{base_url}/predictions/{id}");
206 let client = reqwest::Client::new();
207 let response = client
208 .get(endpoint)
209 .header("Authorization", format!("Token {api_key}"))
210 .send()
211 .await?;
212
213 let data = response.text().await?;
214 let prediction: Prediction = serde_json::from_str(&data)?;
215
216 anyhow::Ok(prediction)
217 }
218
219 pub async fn list(&self) -> anyhow::Result<Predictions> {
221 let api_key = self.config.get_api_key()?;
222 let base_url = self.config.get_base_url();
223
224 let endpoint = format!("{base_url}/predictions");
225 let client = reqwest::Client::new();
226 let response = client
227 .get(endpoint)
228 .header("Authorization", format!("Token {api_key}"))
229 .send()
230 .await?;
231
232 let data = response.text().await?;
233 let predictions: Predictions = serde_json::from_str(&data)?;
234
235 anyhow::Ok(predictions)
236 }
237
238 pub async fn cancel(&self, id: String) -> anyhow::Result<Prediction> {
240 let api_key = self.config.get_api_key()?;
241 let base_url = self.config.get_base_url();
242 let endpoint = format!("{base_url}/predictions/{id}/cancel");
243 let client = reqwest::Client::new();
244 let response = client
245 .post(endpoint)
246 .header("Authorization", format!("Token {api_key}"))
247 .send()
248 .await?;
249
250 let data = response.text().await?;
251 let prediction: Prediction = serde_json::from_str(&data)?;
252
253 anyhow::Ok(prediction)
254 }
255}
256
#[cfg(test)]
mod tests {
    use httpmock::prelude::*;
    use serde_json::json;

    use super::*;

    /// Model version id used by every fixture.
    const VERSION_ID: &str = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa";
    /// Prediction id returned by the mocked create/list endpoints.
    const CREATED_ID: &str = "gm3qorzdhgbfurvjtvhg6dckhu";

    /// Builds the JSON body the mock server returns for a prediction with the given `id`.
    fn prediction_body(id: &str) -> serde_json::Value {
        json!({
            "id": id,
            "model": "replicate/hello-world",
            "version": VERSION_ID,
            "input": {
                "text": "Alice"
            },
            "logs": "",
            "error": null,
            "status": "starting",
            "created_at": "2023-09-08T16:19:34.765994657Z",
            "urls": {
                "cancel": format!("https://api.replicate.com/v1/predictions/{id}/cancel"),
                "get": format!("https://api.replicate.com/v1/predictions/{id}")
            }
        })
    }

    /// Builds the JSON body the mock server returns for the model's version list.
    fn versions_body() -> serde_json::Value {
        json!({
            "next": null,
            "previous": null,
            "results": [{
                "id": VERSION_ID,
                "created_at": "2022-04-26T19:29:04.418669Z",
                "cog_version": "0.3.0",
                "openapi_schema": null
            }]
        })
    }

    /// Checks the fields every fixture built by `prediction_body` must deserialize to.
    fn assert_fixture(prediction: &Prediction, id: &str) {
        assert_eq!(prediction.id, id);
        assert_eq!(prediction.model, "replicate/hello-world");
        assert_eq!(prediction.version, VERSION_ID);
        assert_eq!(prediction.status, PredictionStatus::Starting);
        assert_eq!(prediction.input, json!({"text": "Alice"}));
        assert!(prediction.output.is_none());
        assert!(prediction.urls.stream.is_none());
    }

    #[tokio::test]
    async fn test_get() {
        let server = MockServer::start();

        let prediction_mock = server.mock(|when, then| {
            when.method(GET).path("/predictions/1234");
            then.status(200).json_body_obj(&prediction_body("1234"));
        });

        let config = ReplicateConfig::test(server.base_url()).unwrap();
        let prediction_client = PredictionClient::from(config);

        let prediction = prediction_client.get("1234".to_string()).await.unwrap();

        prediction_mock.assert();
        assert_fixture(&prediction, "1234");
    }

    #[tokio::test]
    async fn test_create() {
        let server = MockServer::start();

        let create_mock = server.mock(|when, then| {
            when.method(POST).path("/predictions");
            then.status(200).json_body_obj(&prediction_body(CREATED_ID));
        });

        server.mock(|when, then| {
            when.method(GET)
                .path("/models/replicate/hello-world/versions");
            then.status(200).json_body_obj(&versions_body());
        });

        let config = ReplicateConfig::test(server.base_url()).unwrap();
        let prediction_client = PredictionClient::from(config);

        let prediction = prediction_client
            .create(
                "replicate",
                "hello-world",
                json!({"text": "This is test input"}),
                false,
            )
            .await
            .unwrap();

        create_mock.assert();
        assert_fixture(&prediction, CREATED_ID);
    }

    #[tokio::test]
    async fn test_list_predictions() {
        let server = MockServer::start();

        let list_mock = server.mock(|when, then| {
            when.method(GET).path("/predictions");
            then.status(200).json_body_obj(&json!({
                "next": null,
                "previous": null,
                "results": [prediction_body(CREATED_ID), prediction_body(CREATED_ID)]
            }));
        });

        let config = ReplicateConfig::test(server.base_url()).unwrap();
        let prediction_client = PredictionClient::from(config);

        let predictions = prediction_client.list().await.unwrap();

        list_mock.assert();
        assert!(predictions.next.is_none());
        assert!(predictions.previous.is_none());
        assert_eq!(predictions.results.len(), 2);
        for prediction in &predictions.results {
            assert_fixture(prediction, CREATED_ID);
        }
    }

    #[tokio::test]
    async fn test_create_and_reload() {
        let server = MockServer::start();

        let create_mock = server.mock(|when, then| {
            when.method(POST).path("/predictions");
            then.status(200).json_body_obj(&prediction_body(CREATED_ID));
        });

        server.mock(|when, then| {
            when.method(GET)
                .path("/models/replicate/hello-world/versions");
            then.status(200).json_body_obj(&versions_body());
        });

        let config = ReplicateConfig::test(server.base_url()).unwrap();
        let prediction_client = PredictionClient::from(config);

        let prediction = prediction_client
            .create(
                "replicate",
                "hello-world",
                json!({"text": "This is test input"}),
                false,
            )
            .await
            .unwrap();

        create_mock.assert();
        assert_fixture(&prediction, CREATED_ID);

        // NOTE(review): `Prediction::reload` is intentionally not called here. It fetches
        // `urls.get`, which in this fixture is an absolute https://api.replicate.com URL,
        // so reloading would leave the mock server. Exercising it needs a fixture whose
        // `get` URL points at the mock server.
        assert_eq!(
            prediction.urls.get,
            format!("https://api.replicate.com/v1/predictions/{CREATED_ID}")
        );
    }

    #[tokio::test]
    async fn test_cancel() {
        let server = MockServer::start();

        let prediction_mock = server.mock(|when, then| {
            when.method(POST).path("/predictions/1234/cancel");
            then.status(200).json_body_obj(&prediction_body("1234"));
        });

        let config = ReplicateConfig::test(server.base_url()).unwrap();
        let prediction_client = PredictionClient::from(config);

        let prediction = prediction_client.cancel("1234".to_string()).await.unwrap();

        prediction_mock.assert();
        assert_fixture(&prediction, "1234");
    }
}