replicate_client/api/
predictions.rs

1//! Predictions API implementation.
2
3use serde_json::Value;
4use std::collections::HashMap;
5use std::time::Duration;
6use tokio::time::{interval, timeout};
7
8use crate::api::files::{FilesApi, process_file_input};
9use crate::error::{Error, Result};
10use crate::http::HttpClient;
11use crate::models::{
12    common::PaginatedResponse,
13    file::{FileEncodingStrategy, FileInput},
14    prediction::{CreatePredictionRequest, Prediction},
15};
16
17/// API for managing predictions.
18#[derive(Debug, Clone)]
19pub struct PredictionsApi {
20    http: HttpClient,
21    files_api: Option<FilesApi>,
22}
23
24impl PredictionsApi {
25    /// Create a new predictions API instance.
26    pub fn new(http: HttpClient) -> Self {
27        Self {
28            http: http.clone(),
29            files_api: Some(FilesApi::new(http)),
30        }
31    }
32
33    /// Create a new prediction.
34    pub async fn create(&self, mut request: CreatePredictionRequest) -> Result<Prediction> {
35        // Process file inputs if any
36        if !request.file_inputs.is_empty() {
37            for (key, file_input) in request.file_inputs.iter() {
38                let processed_value = process_file_input(
39                    file_input,
40                    &request.file_encoding_strategy,
41                    self.files_api.as_ref(),
42                )
43                .await?;
44
45                request
46                    .input
47                    .insert(key.clone(), serde_json::Value::String(processed_value));
48            }
49        }
50
51        let prediction: Prediction = self.http.post_json("/v1/predictions", &request).await?;
52        Ok(prediction)
53    }
54
55    /// Get a prediction by ID.
56    pub async fn get(&self, id: &str) -> Result<Prediction> {
57        let path = format!("/v1/predictions/{}", id);
58        let prediction: Prediction = self.http.get_json(&path).await?;
59        Ok(prediction)
60    }
61
62    /// List predictions with optional pagination.
63    pub async fn list(&self, cursor: Option<&str>) -> Result<PaginatedResponse<Prediction>> {
64        let path = match cursor {
65            Some(cursor) => cursor.to_string(),
66            None => "/v1/predictions".to_string(),
67        };
68
69        let response: PaginatedResponse<Prediction> = self.http.get_json(&path).await?;
70        Ok(response)
71    }
72
73    /// Cancel a prediction.
74    pub async fn cancel(&self, id: &str) -> Result<Prediction> {
75        let path = format!("/v1/predictions/{}/cancel", id);
76        let prediction: Prediction = self.http.post_empty_json(&path).await?;
77        Ok(prediction)
78    }
79
80    /// Wait for a prediction to complete with polling.
81    pub async fn wait_for_completion(
82        &self,
83        id: &str,
84        max_duration: Option<Duration>,
85        poll_interval: Option<Duration>,
86    ) -> Result<Prediction> {
87        let poll_interval = poll_interval.unwrap_or(Duration::from_millis(500));
88        let mut interval = interval(poll_interval);
89
90        let wait_future = async {
91            loop {
92                interval.tick().await;
93                let prediction = self.get(id).await?;
94
95                if prediction.status.is_terminal() {
96                    if prediction.is_failed() {
97                        return Err(Error::model_execution(
98                            id,
99                            prediction.error.clone(),
100                            prediction.logs.clone(),
101                        ));
102                    }
103                    return Ok(prediction);
104                }
105            }
106        };
107
108        match max_duration {
109            Some(duration) => timeout(duration, wait_future).await.map_err(|_| {
110                Error::Timeout(format!(
111                    "Prediction {} did not complete within {:?}",
112                    id, duration
113                ))
114            })?,
115            None => wait_future.await,
116        }
117    }
118}
119
120/// Builder for creating predictions with a fluent API.
121#[derive(Debug)]
122pub struct PredictionBuilder {
123    api: PredictionsApi,
124    request: CreatePredictionRequest,
125}
126
127impl PredictionBuilder {
128    /// Create a new prediction builder.
129    pub fn new(api: PredictionsApi, version: impl Into<String>) -> Self {
130        Self {
131            api,
132            request: CreatePredictionRequest::new(version),
133        }
134    }
135
136    /// Add an input parameter.
137    pub fn input<K, V>(mut self, key: K, value: V) -> Self
138    where
139        K: Into<String>,
140        V: Into<Value>,
141    {
142        self.request = self.request.with_input(key, value);
143        self
144    }
145
146    /// Add multiple input parameters from a HashMap.
147    pub fn inputs(mut self, inputs: HashMap<String, Value>) -> Self {
148        for (key, value) in inputs {
149            self.request = self.request.with_input(key, value);
150        }
151        self
152    }
153
154    /// Add a file input parameter.
155    pub fn file_input<K>(mut self, key: K, file: FileInput) -> Self
156    where
157        K: Into<String>,
158    {
159        // Store the file input for later processing
160        self.request.file_inputs.insert(key.into(), file);
161        self
162    }
163
164    /// Add a file input with specific encoding strategy.
165    pub fn file_input_with_strategy<K>(
166        mut self,
167        key: K,
168        file: FileInput,
169        strategy: FileEncodingStrategy,
170    ) -> Self
171    where
172        K: Into<String>,
173    {
174        // Store the file input and strategy for later processing
175        self.request.file_inputs.insert(key.into(), file);
176        self.request.file_encoding_strategy = strategy;
177        self
178    }
179
180    /// Set a webhook URL.
181    pub fn webhook(mut self, webhook: impl Into<String>) -> Self {
182        self.request = self.request.with_webhook(webhook);
183        self
184    }
185
186    /// Enable streaming output.
187    pub fn stream(mut self) -> Self {
188        self.request = self.request.with_streaming();
189        self
190    }
191
192    /// Send the prediction request.
193    pub async fn send(self) -> Result<Prediction> {
194        self.api.create(self.request).await
195    }
196
197    /// Send the prediction request and wait for completion.
198    pub async fn send_and_wait(self) -> Result<Prediction> {
199        let prediction = self.api.create(self.request).await?;
200        self.api
201            .wait_for_completion(&prediction.id, None, None)
202            .await
203    }
204
205    /// Send the prediction request and wait for completion with custom timeout.
206    pub async fn send_and_wait_with_timeout(self, max_duration: Duration) -> Result<Prediction> {
207        let prediction = self.api.create(self.request).await?;
208        self.api
209            .wait_for_completion(&prediction.id, Some(max_duration), None)
210            .await
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::HttpClient;

    /// Build a `PredictionsApi` with a dummy token (these tests send no requests).
    fn test_api() -> PredictionsApi {
        PredictionsApi::new(HttpClient::new("test-token").unwrap())
    }

    #[test]
    fn test_prediction_builder() {
        let builder = PredictionBuilder::new(test_api(), "test-version")
            .input("prompt", "test prompt")
            .webhook("https://example.com/webhook")
            .stream();
        let request = &builder.request;

        assert_eq!(request.version, "test-version");
        assert_eq!(
            request.input.get("prompt"),
            Some(&Value::String("test prompt".to_string()))
        );
        assert_eq!(
            request.webhook.as_deref(),
            Some("https://example.com/webhook")
        );
        assert_eq!(request.stream, Some(true));
    }
}