replicate_client/models/prediction.rs

//! Prediction-related types and structures.

use crate::models::file::{FileEncodingStrategy, FileInput};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;

/// Status of a prediction.
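///
/// # Examples
///
/// ```
/// // Sketch of typical usage: statuses map to lowercase strings on the wire
/// // (see `rename_all` below). The import path assumes the crate is named
/// // `replicate_client` and publicly exposes this module as `models::prediction`.
/// use replicate_client::models::prediction::PredictionStatus;
///
/// let status: PredictionStatus = serde_json::from_str("\"succeeded\"").unwrap();
/// assert_eq!(status, PredictionStatus::Succeeded);
/// assert_eq!(serde_json::to_string(&status).unwrap(), "\"succeeded\"");
/// ```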
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum PredictionStatus {
    /// The prediction is starting up
    Starting,
    /// The prediction is currently processing
    Processing,
    /// The prediction completed successfully
    Succeeded,
    /// The prediction failed
    Failed,
    /// The prediction was canceled
    Canceled,
}

impl PredictionStatus {
    /// Check if the prediction is in a terminal state
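    ///
    /// # Examples
    ///
    /// ```
    /// // Sketch: terminal vs. still-running states. The import path assumes
    /// // the crate is named `replicate_client`.
    /// use replicate_client::models::prediction::PredictionStatus;
    ///
    /// assert!(PredictionStatus::Succeeded.is_terminal());
    /// assert!(PredictionStatus::Canceled.is_terminal());
    /// assert!(!PredictionStatus::Processing.is_terminal());
    /// assert!(PredictionStatus::Processing.is_running());
    /// ```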
    pub fn is_terminal(&self) -> bool {
        matches!(self, Self::Succeeded | Self::Failed | Self::Canceled)
    }

    /// Check if the prediction is still running
    pub fn is_running(&self) -> bool {
        matches!(self, Self::Starting | Self::Processing)
    }
}

/// URLs associated with a prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionUrls {
    /// URL to fetch the prediction
    pub get: String,
    /// URL to cancel the prediction
    pub cancel: String,
    /// URL to stream the prediction output (if supported)
    pub stream: Option<String>,
}

/// A prediction made by a model hosted on Replicate.
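///
/// # Examples
///
/// ```
/// // Sketch: building a `Prediction` from a minimal JSON payload, as might be
/// // returned by the API. The id, model, and version values are placeholders,
/// // and the import path assumes the crate is named `replicate_client`.
/// // Optional fields that are absent from the JSON deserialize to `None`.
/// use replicate_client::models::prediction::{Prediction, PredictionStatus};
///
/// let prediction: Prediction = serde_json::from_value(serde_json::json!({
///     "id": "some-prediction-id",
///     "model": "owner/some-model",
///     "version": "some-version-id",
///     "status": "succeeded",
///     "output": "hello world"
/// })).unwrap();
///
/// assert_eq!(prediction.status, PredictionStatus::Succeeded);
/// assert!(prediction.is_complete());
/// assert!(prediction.is_successful());
/// assert!(prediction.error.is_none());
/// ```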
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Prediction {
    /// The unique ID of the prediction
    pub id: String,

    /// The model used to create the prediction (format: owner/name)
    pub model: String,

    /// The version ID of the model used
    pub version: String,

    /// The current status of the prediction
    pub status: PredictionStatus,

    /// The input parameters for the prediction
    pub input: Option<HashMap<String, Value>>,

    /// The output of the prediction (if completed)
    pub output: Option<Value>,

    /// Logs from the prediction execution
    pub logs: Option<String>,

    /// Error message if the prediction failed
    pub error: Option<String>,

    /// Metrics about the prediction performance
    pub metrics: Option<HashMap<String, Value>>,

    /// When the prediction was created
    pub created_at: Option<String>,

    /// When the prediction started processing
    pub started_at: Option<String>,

    /// When the prediction completed
    pub completed_at: Option<String>,

    /// URLs associated with the prediction
    pub urls: Option<PredictionUrls>,
}

impl Prediction {
    /// Check if the prediction is complete
    pub fn is_complete(&self) -> bool {
        self.status.is_terminal()
    }

    /// Check if the prediction succeeded
    pub fn is_successful(&self) -> bool {
        self.status == PredictionStatus::Succeeded
    }

    /// Check if the prediction failed
    pub fn is_failed(&self) -> bool {
        self.status == PredictionStatus::Failed
    }

    /// Check if the prediction was canceled
    pub fn is_canceled(&self) -> bool {
        self.status == PredictionStatus::Canceled
    }
}

/// Request to create a new prediction.
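///
/// # Examples
///
/// ```
/// // Sketch of the builder flow; the version id, input, and webhook URL are
/// // placeholders, and the import path assumes the crate is named
/// // `replicate_client`. Unset optional fields are omitted when serialized.
/// use replicate_client::models::prediction::CreatePredictionRequest;
///
/// let request = CreatePredictionRequest::new("some-version-id")
///     .with_input("prompt", "an astronaut riding a horse")
///     .with_webhook("https://example.com/webhooks/replicate")
///     .with_streaming();
///
/// let body = serde_json::to_value(&request).unwrap();
/// assert_eq!(body["version"], "some-version-id");
/// assert_eq!(body["input"]["prompt"], "an astronaut riding a horse");
/// assert_eq!(body["stream"], true);
/// assert!(body.get("webhook_completed").is_none());
/// ```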
#[derive(Debug, Clone, Serialize)]
pub struct CreatePredictionRequest {
    /// The version ID of the model to run
    pub version: String,

    /// Input parameters for the model
    pub input: HashMap<String, Value>,

    /// Optional webhook URL for notifications
    #[serde(skip_serializing_if = "Option::is_none")]
    pub webhook: Option<String>,

    /// Optional webhook URL for completion notifications
    #[serde(skip_serializing_if = "Option::is_none")]
    pub webhook_completed: Option<String>,

    /// Events to filter for webhooks
    #[serde(skip_serializing_if = "Option::is_none")]
    pub webhook_events_filter: Option<Vec<String>>,

    /// Enable streaming of output
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,

    /// File inputs that need to be processed
    #[serde(skip)]
    pub file_inputs: HashMap<String, FileInput>,

    /// File encoding strategy
    #[serde(skip)]
    pub file_encoding_strategy: FileEncodingStrategy,
}

impl CreatePredictionRequest {
    /// Create a new prediction request
    pub fn new(version: impl Into<String>) -> Self {
        Self {
            version: version.into(),
            input: HashMap::new(),
            webhook: None,
            webhook_completed: None,
            webhook_events_filter: None,
            stream: None,
            file_inputs: HashMap::new(),
            file_encoding_strategy: FileEncodingStrategy::default(),
        }
    }

    /// Add an input parameter
    pub fn with_input(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
        self.input.insert(key.into(), value.into());
        self
    }

    /// Set the webhook URL
    pub fn with_webhook(mut self, webhook: impl Into<String>) -> Self {
        self.webhook = Some(webhook.into());
        self
    }

    /// Enable streaming
    pub fn with_streaming(mut self) -> Self {
        self.stream = Some(true);
        self
    }
}