1#![allow(rustdoc::bare_urls)]
9#![allow(missing_docs)]
10
11use serde::{Deserialize, Deserializer, Serialize};
12use std::collections::HashMap;
13
14pub fn object_empty_as_none<'de, D, T>(deserializer: D) -> Result<Option<T>, D::Error>
16where
17 D: Deserializer<'de>,
18 for<'a> T: Deserialize<'a>,
19{
20 #[derive(Deserialize, Debug)]
21 #[serde(deny_unknown_fields)]
22 struct Empty {}
23
24 #[derive(Deserialize, Debug)]
25 #[serde(untagged)]
26 enum Aux<T> {
27 T(T),
28 Empty(Empty),
29 Null,
30 }
31
32 match Deserialize::deserialize(deserializer)? {
33 Aux::T(t) => Ok(Some(t)),
34 Aux::Empty(_) | Aux::Null => Ok(None),
35 }
36}
37
38
39#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
41pub struct GetModel {
42 pub url: String,
43
44 pub owner: String,
45 pub name: String,
46 pub description: String,
47 pub visibility: String,
48
49 pub github_url: Option<String>,
50 pub paper_url: Option<String>,
51 pub license_url: Option<String>,
52
53 pub run_count: Option<u32>,
54
55 pub cover_image_url: Option<String>,
56
57 #[serde(deserialize_with = "object_empty_as_none")]
58 pub default_example: Option<GetPrediction>,
59
60 #[serde(deserialize_with = "object_empty_as_none")]
61 pub latest_version: Option<GetModelVersion>,
62}
63
/// A single collection: its identifying `name`/`slug`, a description, and the
/// full [`GetModel`] records it contains.
///
/// Compare [`ListCollectionModelsItem`], which carries the same metadata
/// without the `models` list.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct GetCollectionModels {
    pub name: String,
    pub slug: String,

    pub description: String,

    /// Every model in the collection, each deserialized as a full [`GetModel`].
    pub models: Vec<GetModel>,
}
74
/// The `urls` object embedded in prediction and training records.
///
/// Both keys are required strings. Any other keys in the incoming object are
/// ignored (no `deny_unknown_fields`). Presumably these are the endpoints for
/// cancelling and fetching the job; confirm against the API reference.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct PredictionsUrls {
    pub cancel: String,
    pub get: String,
}
81
/// A single prediction record.
///
/// Timestamps are kept as raw strings; nothing in this type parses them.
/// Every `Option` field is `None` when the JSON value is `null` or the key is
/// absent.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct GetPrediction {
    pub id: String,

    /// Model version this prediction ran against (presumably the version id;
    /// confirm).
    pub version: String,

    pub urls: PredictionsUrls,

    pub created_at: String,
    pub started_at: Option<String>,
    pub completed_at: Option<String>,

    pub source: Option<PredictionSource>,

    pub status: PredictionStatus,

    /// Inputs as a JSON object: string keys mapped to arbitrary JSON values.
    pub input: HashMap<String, serde_json::Value>,

    /// Output as arbitrary, untyped JSON.
    pub output: Option<serde_json::Value>,

    pub error: Option<String>,
    pub logs: Option<String>,

    /// Metrics as a JSON object of arbitrary values, when present.
    pub metrics: Option<HashMap<String, serde_json::Value>>,
}
114
/// A single training record.
///
/// NOTE(review): unlike [`GetPrediction::input`], `input` and `output` here are
/// `HashMap<String, String>`. Any non-string JSON value inside them (number,
/// bool, array, object) makes deserialization of the whole record fail.
/// Widening them to `serde_json::Value` would be a breaking change to these
/// public fields; confirm what the API actually returns.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct GetTraining {
    pub id: String,
    pub version: String,

    pub status: PredictionStatus,

    pub input: Option<HashMap<String, String>>,
    pub output: Option<HashMap<String, String>>,

    pub error: Option<String>,
    pub logs: Option<String>,
    pub webhook_completed: Option<String>,

    // Timestamps are raw strings; `None` when null or absent.
    pub started_at: Option<String>,
    pub created_at: String,
    pub completed_at: Option<String>,
}
134
/// Training record shape used for creation responses.
///
/// Same fields as [`GetTraining`] minus `error` and `webhook_completed`. Any
/// such keys in the incoming JSON are ignored. `input`/`output` share
/// `GetTraining`'s string-only value type.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct CreateTraining {
    pub id: String,
    pub version: String,

    pub status: PredictionStatus,

    pub input: Option<HashMap<String, String>>,
    pub output: Option<HashMap<String, String>>,

    pub logs: Option<String>,

    pub started_at: Option<String>,
    pub created_at: String,
    pub completed_at: Option<String>,
}
152
/// Prediction record shape used for creation responses.
///
/// Compared with [`GetPrediction`] it has no `started_at`, `completed_at`,
/// `source`, `output` or `metrics` fields. Such keys in the incoming JSON are
/// ignored rather than rejected.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct CreatePrediction {
    pub id: String,
    pub version: String,

    pub urls: PredictionsUrls,

    pub created_at: String,

    pub status: PredictionStatus,

    /// Inputs as a JSON object: string keys mapped to arbitrary JSON values.
    pub input: HashMap<String, serde_json::Value>,

    pub error: Option<String>,

    pub logs: Option<String>,
}
171
/// A single model version.
///
/// `openapi_schema` is kept as an untyped JSON object (string keys, arbitrary
/// JSON values); nothing here validates its contents.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct GetModelVersion {
    pub id: String,
    pub created_at: String,

    pub cog_version: String,

    pub openapi_schema: HashMap<String, serde_json::Value>,
}
182
/// Summary of one collection inside a [`ListCollectionModels`] page.
///
/// Carries the same metadata as [`GetCollectionModels`] but not the model list.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ListCollectionModelsItem {
    pub name: String,
    pub slug: String,
    pub description: String,
}
190
/// One page of collection summaries.
///
/// `previous`/`next` are `None` when the JSON value is `null` or the key is
/// absent. Presumably they are links to the neighbouring pages; confirm.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ListCollectionModels {
    pub previous: Option<String>,
    pub next: Option<String>,

    pub results: Vec<ListCollectionModelsItem>,
}
199
/// Summary of one prediction inside a [`ListPredictions`] page.
///
/// NOTE(review): `started_at` is a required `String` here, whereas
/// [`GetPrediction`] declares it `Option<String>`. If the list contains a
/// prediction whose `started_at` is `null` (or missing), deserializing this
/// item, and with it the whole page, fails. Switching to `Option<String>` would
/// be a breaking change to this public field; confirm what the API returns for
/// predictions that have not started.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct PredictionsListItem {
    pub id: String,
    pub version: String,

    pub urls: PredictionsUrls,

    pub created_at: String,
    pub started_at: String,
    pub completed_at: Option<String>,

    pub source: Option<PredictionSource>,

    pub status: PredictionStatus,
}
216
/// One page of prediction summaries.
///
/// `previous`/`next` are `None` when `null` or absent. They are presumably
/// links to the neighbouring pages; confirm.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ListPredictions {
    pub previous: Option<String>,
    pub next: Option<String>,

    pub results: Vec<PredictionsListItem>,
}
225
/// One page of model versions, each a full [`GetModelVersion`].
///
/// `previous`/`next` are `None` when `null` or absent. They are presumably
/// links to the neighbouring pages; confirm.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ListModelVersions {
    pub previous: Option<String>,

    pub next: Option<String>,

    pub results: Vec<GetModelVersion>,
}
235
/// Summary of one training inside a [`ListTraining`] page.
///
/// NOTE(review): `started_at`, `completed_at` and `source` are all required
/// here. By contrast, [`GetTraining`] declares `started_at`/`completed_at` as
/// `Option<String>` and [`GetPrediction`] declares `source` as
/// `Option<PredictionSource>`. A `null` or missing value in any of them fails
/// deserialization of this item and therefore of the whole page. Making them
/// `Option` would be a breaking change to these public fields; confirm what the
/// API returns for trainings that are still running.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ListTrainingItem {
    pub id: String,

    pub version: String,

    pub urls: PredictionsUrls,

    pub created_at: String,
    pub started_at: String,
    pub completed_at: String,

    pub source: PredictionSource,
    pub status: PredictionStatus,
}
252
/// One page of training summaries.
///
/// `previous`/`next` are `None` when `null` or absent. They are presumably
/// links to the neighbouring pages; confirm.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ListTraining {
    pub previous: Option<String>,
    pub next: Option<String>,

    pub results: Vec<ListTrainingItem>,
}
261
262macro_rules! impl_display {
268 ($($t:ty),*) => ($(
269 impl std::fmt::Display for $t {
270 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271
272 match serde_json::to_string_pretty(&self) {
273 Ok(formatted) => write!(f, "{:?}", formatted),
274 Err(_) => write!(f, "{:?}", self),
275 }
276 }
277 }
278 )*)
279}
280
281impl_display! {
282 GetModel,
283 GetCollectionModels,
284 PredictionsUrls,
285 GetPrediction,
286 GetTraining,
287 CreateTraining,
288 CreatePrediction,
289 GetModelVersion,
290 ListCollectionModelsItem,
291 ListCollectionModels,
292 PredictionsListItem,
293 ListPredictions,
294 ListModelVersions,
295 ListTrainingItem,
296 ListTraining
297}
298
/// Value of the `source` field on prediction and training records.
///
/// With no `rename` attributes, serde uses each variant name verbatim as its
/// JSON string (`"api"`, `"web"`). Any other string fails deserialization.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
// Lowercase variant names are deliberate: they double as the wire strings.
#[allow(non_camel_case_types)]
pub enum PredictionSource {
    api,
    web,
}
308
/// Status shared by prediction and training records.
///
/// Variant names are used verbatim as the JSON strings. The enum is closed, so
/// a status string not listed here makes deserialization of the containing
/// record fail.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
// Lowercase variant names are deliberate: they double as the wire strings.
#[allow(non_camel_case_types)]
pub enum PredictionStatus {
    starting,
    processing,
    succeeded,
    failed,
    canceled,
}
319
/// Webhook event kinds, serialized as their lowercase variant names.
///
/// Not referenced by any struct in this file. Presumably it is used by
/// request-building code elsewhere; verify against callers.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
// Lowercase variant names are deliberate: they double as the wire strings.
#[allow(non_camel_case_types)]
pub enum WebhookEvents {
    start,
    output,
    logs,
    completed,
}
329
330macro_rules! impl_display {
337 ($($t:ty),*) => ($(
338 impl std::fmt::Display for $t {
339 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
340
341 match serde_json::to_string_pretty(&self) {
342 Ok(formatted) => write!(f, "{:?}", formatted),
343 Err(_) => write!(f, "{:?}", self),
344 }
345 }
346 }
347 )*)
348}
349
350impl_display! {
351 PredictionSource,
352 PredictionStatus,
353 WebhookEvents
354}