replicate_rust/
api_definitions.rs

//! This module contains the definitions of the responses returned by the Replicate API.
//! The responses are documented in the [HTTP API reference](https://replicate.com/docs/reference/http).
3//!
4//! The API responses are defined as structs that implement the `serde::Deserialize` trait.
5//!
6
7// Allow rustdoc::bare_urls for the whole module
8#![allow(rustdoc::bare_urls)]
9#![allow(missing_docs)]
10
11use serde::{Deserialize, Deserializer, Serialize};
12use std::collections::HashMap;
13
14/// If the object is empty, return None
15pub fn object_empty_as_none<'de, D, T>(deserializer: D) -> Result<Option<T>, D::Error>
16where
17    D: Deserializer<'de>,
18    for<'a> T: Deserialize<'a>,
19{
20    #[derive(Deserialize, Debug)]
21    #[serde(deny_unknown_fields)]
22    struct Empty {}
23
24    #[derive(Deserialize, Debug)]
25    #[serde(untagged)]
26    enum Aux<T> {
27        T(T),
28        Empty(Empty),
29        Null,
30    }
31
32    match Deserialize::deserialize(deserializer)? {
33        Aux::T(t) => Ok(Some(t)),
34        Aux::Empty(_) | Aux::Null => Ok(None),
35    }
36}
37
38
39/// GET https://api.replicate.com/v1/models/{model_owner}/{model_name}
40#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
41pub struct GetModel {
42    pub url: String,
43
44    pub owner: String,
45    pub name: String,
46    pub description: String,
47    pub visibility: String,
48
49    pub github_url: Option<String>,
50    pub paper_url: Option<String>,
51    pub license_url: Option<String>,
52
53    pub run_count: Option<u32>,
54
55    pub cover_image_url: Option<String>,
56
57    #[serde(deserialize_with = "object_empty_as_none")]
58    pub default_example: Option<GetPrediction>,
59
60    #[serde(deserialize_with = "object_empty_as_none")]
61    pub latest_version: Option<GetModelVersion>,
62}
63
64/// GET https://api.replicate.com/v1/collections/{collection_slug}
65#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
66pub struct GetCollectionModels {
67    pub name: String,
68    pub slug: String,
69
70    pub description: String,
71
72    pub models: Vec<GetModel>,
73}
74
75/// Prediction urls to iether cancel or get the prediction
76#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
77pub struct PredictionsUrls {
78    pub cancel: String,
79    pub get: String,
80}
81
82/// POST https://api.replicate.com/v1/predictions
83#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
84pub struct GetPrediction {
85    // Unique identifier of the prediction
86    pub id: String,
87
88    // Version of the model used for the prediction
89    pub version: String,
90
91    // Urls to cancel or get the prediction
92    pub urls: PredictionsUrls,
93
94    pub created_at: String,
95    pub started_at: Option<String>,
96    pub completed_at: Option<String>,
97
98    pub source: Option<PredictionSource>,
99
100    // Status of the prediction
101    pub status: PredictionStatus,
102
103    // Input and Outputs of the prediction
104    pub input: HashMap<String, serde_json::Value>,
105
106    // Either a vector of string or a simple string
107    pub output: Option<serde_json::Value>,
108
109    pub error: Option<String>,
110    pub logs: Option<String>,
111
112    pub metrics: Option<HashMap<String, serde_json::Value>>,
113}
114
115/// GET https://api.replicate.com/v1/trainings/{training_id}
116#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
117pub struct GetTraining {
118    pub id: String,
119    pub version: String,
120
121    pub status: PredictionStatus,
122
123    pub input: Option<HashMap<String, String>>,
124    pub output: Option<HashMap<String, String>>,
125
126    pub error: Option<String>,
127    pub logs: Option<String>,
128    pub webhook_completed: Option<String>,
129
130    pub started_at: Option<String>,
131    pub created_at: String,
132    pub completed_at: Option<String>,
133}
134
135/// POST https://api.replicate.com/v1/models/{model_owner}/{model_name}/versions/{version_id}/trainings
136#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
137pub struct CreateTraining {
138    pub id: String,
139    pub version: String,
140
141    pub status: PredictionStatus,
142
143    pub input: Option<HashMap<String, String>>,
144    pub output: Option<HashMap<String, String>>,
145
146    pub logs: Option<String>,
147
148    pub started_at: Option<String>,
149    pub created_at: String,
150    pub completed_at: Option<String>,
151}
152
153/// POST https://api.replicate.com/v1/predictions
154#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
155pub struct CreatePrediction {
156    pub id: String,
157    pub version: String,
158
159    pub urls: PredictionsUrls,
160
161    pub created_at: String,
162
163    pub status: PredictionStatus,
164
165    pub input: HashMap<String, serde_json::Value>,
166
167    pub error: Option<String>,
168
169    pub logs: Option<String>,
170}
171
172/// GET https://api.replicate.com/v1/models/{model_owner}/{model_name}/versions/{version_id}
173#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
174pub struct GetModelVersion {
175    pub id: String,
176    pub created_at: String,
177
178    pub cog_version: String,
179
180    pub openapi_schema: HashMap<String, serde_json::Value>,
181}
182
183/// Each item of the list of collections
184#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
185pub struct ListCollectionModelsItem {
186    pub name: String,
187    pub slug: String,
188    pub description: String,
189}
190
191/// GET https://api.replicate.com/v1/collections
192#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
193pub struct ListCollectionModels {
194    pub previous: Option<String>,
195    pub next: Option<String>,
196
197    pub results: Vec<ListCollectionModelsItem>,
198}
199
200/// Represents a prediction in the list of predictions
201#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
202pub struct PredictionsListItem {
203    pub id: String,
204    pub version: String,
205
206    pub urls: PredictionsUrls,
207
208    pub created_at: String,
209    pub started_at: String,
210    pub completed_at: Option<String>,
211
212    pub source: Option<PredictionSource>,
213
214    pub status: PredictionStatus,
215}
216
217/// GET https://api.replicate.com/v1/predictions
218#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
219pub struct ListPredictions {
220    pub previous: Option<String>,
221    pub next: Option<String>,
222
223    pub results: Vec<PredictionsListItem>,
224}
225
226/// GET https://api.replicate.com/v1/models/{model_owner}/{model_name}/versions
227#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
228pub struct ListModelVersions {
229    pub previous: Option<String>,
230
231    pub next: Option<String>,
232
233    pub results: Vec<GetModelVersion>,
234}
235
236/// Each item of the list of trainings
237#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
238pub struct ListTrainingItem {
239    pub id: String,
240
241    pub version: String,
242
243    pub urls: PredictionsUrls,
244
245    pub created_at: String,
246    pub started_at: String,
247    pub completed_at: String,
248
249    pub source: PredictionSource,
250    pub status: PredictionStatus,
251}
252
253/// GET https://api.replicate.com/v1/trainings
254#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
255pub struct ListTraining {
256    pub previous: Option<String>,
257    pub next: Option<String>,
258
259    pub results: Vec<ListTrainingItem>,
260}
261
262///////////////////////////////////////////////////////////
263///
264/// Implement Display for all the structs
265///
266///////////////////////////////////////////////////////////
267macro_rules! impl_display {
268    ($($t:ty),*) => ($(
269        impl std::fmt::Display for $t {
270            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271
272                match serde_json::to_string_pretty(&self) {
273                    Ok(formatted) => write!(f, "{:?}", formatted),
274                    Err(_) => write!(f, "{:?}", self),
275                }
276            }
277        }
278    )*)
279}
280
281impl_display! {
282    GetModel,
283    GetCollectionModels,
284    PredictionsUrls,
285    GetPrediction,
286    GetTraining,
287    CreateTraining,
288    CreatePrediction,
289    GetModelVersion,
290    ListCollectionModelsItem,
291    ListCollectionModels,
292    PredictionsListItem,
293    ListPredictions,
294    ListModelVersions,
295    ListTrainingItem,
296    ListTraining
297}
298
299///////////////////////////////////////////////////////////
300
301/// Source of the prediction, either from the API or from the web
302#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
303#[allow(non_camel_case_types)]
304pub enum PredictionSource {
305    api,
306    web,
307}
308
309/// Status of the prediction, either starting, processing, succeeded, failed or canceled
310#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
311#[allow(non_camel_case_types)]
312pub enum PredictionStatus {
313    starting,
314    processing,
315    succeeded,
316    failed,
317    canceled,
318}
319
320/// Events of the webhook, either start, output, logs or completed
321#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
322#[allow(non_camel_case_types)]
323pub enum WebhookEvents {
324    start,
325    output,
326    logs,
327    completed,
328}
329
330///////////////////////////////////////////////////////////
331///
332/// Implement Display for the enums
333///
334/// ///////////////////////////////////////////////////////
335
336macro_rules! impl_display {
337    ($($t:ty),*) => ($(
338        impl std::fmt::Display for $t {
339            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
340
341                match serde_json::to_string_pretty(&self) {
342                    Ok(formatted) => write!(f, "{:?}", formatted),
343                    Err(_) => write!(f, "{:?}", self),
344                }
345            }
346        }
347    )*)
348}
349
350impl_display! {
351    PredictionSource,
352    PredictionStatus,
353    WebhookEvents
354}