replicate_client/api/
predictions.rs1use serde_json::Value;
4use std::collections::HashMap;
5use std::time::Duration;
6use tokio::time::{interval, timeout};
7
8use crate::api::files::{FilesApi, process_file_input};
9use crate::error::{Error, Result};
10use crate::http::HttpClient;
11use crate::models::{
12 common::PaginatedResponse,
13 file::{FileEncodingStrategy, FileInput},
14 prediction::{CreatePredictionRequest, Prediction},
15};
16
17#[derive(Debug, Clone)]
19pub struct PredictionsApi {
20 http: HttpClient,
21 files_api: Option<FilesApi>,
22}
23
24impl PredictionsApi {
25 pub fn new(http: HttpClient) -> Self {
27 Self {
28 http: http.clone(),
29 files_api: Some(FilesApi::new(http)),
30 }
31 }
32
33 pub async fn create(&self, mut request: CreatePredictionRequest) -> Result<Prediction> {
35 if !request.file_inputs.is_empty() {
37 for (key, file_input) in request.file_inputs.iter() {
38 let processed_value = process_file_input(
39 file_input,
40 &request.file_encoding_strategy,
41 self.files_api.as_ref(),
42 )
43 .await?;
44
45 request
46 .input
47 .insert(key.clone(), serde_json::Value::String(processed_value));
48 }
49 }
50
51 let prediction: Prediction = self.http.post_json("/v1/predictions", &request).await?;
52 Ok(prediction)
53 }
54
55 pub async fn get(&self, id: &str) -> Result<Prediction> {
57 let path = format!("/v1/predictions/{}", id);
58 let prediction: Prediction = self.http.get_json(&path).await?;
59 Ok(prediction)
60 }
61
62 pub async fn list(&self, cursor: Option<&str>) -> Result<PaginatedResponse<Prediction>> {
64 let path = match cursor {
65 Some(cursor) => cursor.to_string(),
66 None => "/v1/predictions".to_string(),
67 };
68
69 let response: PaginatedResponse<Prediction> = self.http.get_json(&path).await?;
70 Ok(response)
71 }
72
73 pub async fn cancel(&self, id: &str) -> Result<Prediction> {
75 let path = format!("/v1/predictions/{}/cancel", id);
76 let prediction: Prediction = self.http.post_empty_json(&path).await?;
77 Ok(prediction)
78 }
79
80 pub async fn wait_for_completion(
82 &self,
83 id: &str,
84 max_duration: Option<Duration>,
85 poll_interval: Option<Duration>,
86 ) -> Result<Prediction> {
87 let poll_interval = poll_interval.unwrap_or(Duration::from_millis(500));
88 let mut interval = interval(poll_interval);
89
90 let wait_future = async {
91 loop {
92 interval.tick().await;
93 let prediction = self.get(id).await?;
94
95 if prediction.status.is_terminal() {
96 if prediction.is_failed() {
97 return Err(Error::model_execution(
98 id,
99 prediction.error.clone(),
100 prediction.logs.clone(),
101 ));
102 }
103 return Ok(prediction);
104 }
105 }
106 };
107
108 match max_duration {
109 Some(duration) => timeout(duration, wait_future).await.map_err(|_| {
110 Error::Timeout(format!(
111 "Prediction {} did not complete within {:?}",
112 id, duration
113 ))
114 })?,
115 None => wait_future.await,
116 }
117 }
118}
119
120#[derive(Debug)]
122pub struct PredictionBuilder {
123 api: PredictionsApi,
124 request: CreatePredictionRequest,
125}
126
127impl PredictionBuilder {
128 pub fn new(api: PredictionsApi, version: impl Into<String>) -> Self {
130 Self {
131 api,
132 request: CreatePredictionRequest::new(version),
133 }
134 }
135
136 pub fn input<K, V>(mut self, key: K, value: V) -> Self
138 where
139 K: Into<String>,
140 V: Into<Value>,
141 {
142 self.request = self.request.with_input(key, value);
143 self
144 }
145
146 pub fn inputs(mut self, inputs: HashMap<String, Value>) -> Self {
148 for (key, value) in inputs {
149 self.request = self.request.with_input(key, value);
150 }
151 self
152 }
153
154 pub fn file_input<K>(mut self, key: K, file: FileInput) -> Self
156 where
157 K: Into<String>,
158 {
159 self.request.file_inputs.insert(key.into(), file);
161 self
162 }
163
164 pub fn file_input_with_strategy<K>(
166 mut self,
167 key: K,
168 file: FileInput,
169 strategy: FileEncodingStrategy,
170 ) -> Self
171 where
172 K: Into<String>,
173 {
174 self.request.file_inputs.insert(key.into(), file);
176 self.request.file_encoding_strategy = strategy;
177 self
178 }
179
180 pub fn webhook(mut self, webhook: impl Into<String>) -> Self {
182 self.request = self.request.with_webhook(webhook);
183 self
184 }
185
186 pub fn stream(mut self) -> Self {
188 self.request = self.request.with_streaming();
189 self
190 }
191
192 pub async fn send(self) -> Result<Prediction> {
194 self.api.create(self.request).await
195 }
196
197 pub async fn send_and_wait(self) -> Result<Prediction> {
199 let prediction = self.api.create(self.request).await?;
200 self.api
201 .wait_for_completion(&prediction.id, None, None)
202 .await
203 }
204
205 pub async fn send_and_wait_with_timeout(self, max_duration: Duration) -> Result<Prediction> {
207 let prediction = self.api.create(self.request).await?;
208 self.api
209 .wait_for_completion(&prediction.id, Some(max_duration), None)
210 .await
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::HttpClient;

    /// Builds a `PredictionsApi` over a client with a dummy token.
    ///
    /// The builder tests below never call `send*`, so no request is issued.
    fn create_test_api() -> PredictionsApi {
        let http = HttpClient::new("test-token")
            .expect("HttpClient::new should accept a static test token");
        PredictionsApi::new(http)
    }

    #[test]
    fn test_prediction_builder() {
        let api = create_test_api();
        let builder = PredictionBuilder::new(api, "test-version")
            .input("prompt", "test prompt")
            .webhook("https://example.com/webhook")
            .stream();

        assert_eq!(builder.request.version, "test-version");
        assert_eq!(
            builder.request.input.get("prompt"),
            Some(&Value::String("test prompt".to_string()))
        );
        assert_eq!(
            builder.request.webhook,
            Some("https://example.com/webhook".to_string())
        );
        assert_eq!(builder.request.stream, Some(true));
    }

    /// `inputs` must add every map entry alongside inputs set individually.
    #[test]
    fn test_prediction_builder_inputs() {
        let mut extra = HashMap::new();
        extra.insert("steps".to_string(), Value::from(30));
        extra.insert("seed".to_string(), Value::from(42));

        let builder = PredictionBuilder::new(create_test_api(), "test-version")
            .input("prompt", "test prompt")
            .inputs(extra);

        assert_eq!(
            builder.request.input.get("prompt"),
            Some(&Value::String("test prompt".to_string()))
        );
        assert_eq!(builder.request.input.get("steps"), Some(&Value::from(30)));
        assert_eq!(builder.request.input.get("seed"), Some(&Value::from(42)));
    }
}