#![allow(rustdoc::bare_urls)]
#![allow(missing_docs)]
use serde::{Deserialize, Deserializer, Serialize};
use std::collections::HashMap;
/// Deserializes an optional `T`, treating JSON `null` and an empty object `{}` as `None`.
///
/// Meant for use as `#[serde(deserialize_with = "object_empty_as_none")]` on `Option<T>`
/// fields. `T` is always attempted first, so a `T` that can itself be built from `{}` or
/// `null` yields `Some(..)` rather than `None`.
///
/// # Errors
///
/// Returns the deserializer's error when the input is neither a valid `T`, nor an empty
/// object, nor `null` (serde reports this as an untagged-enum mismatch on `Aux`).
pub fn object_empty_as_none<'de, D, T>(deserializer: D) -> Result<Option<T>, D::Error>
where
    D: Deserializer<'de>,
    for<'a> T: Deserialize<'a>,
{
    // Matches only a map with no entries: `deny_unknown_fields` rejects any key.
    #[derive(Deserialize, Debug)]
    #[serde(deny_unknown_fields)]
    struct Empty {}

    // Untagged variants are tried top to bottom, so a genuine `T` wins over the
    // empty/null fallbacks. The enum name appears in serde's mismatch error text.
    #[derive(Deserialize, Debug)]
    #[serde(untagged)]
    enum Aux<T> {
        Present(T),
        Empty(Empty),
        Null,
    }

    let parsed: Aux<T> = Deserialize::deserialize(deserializer)?;
    let outcome = match parsed {
        Aux::Present(inner) => Some(inner),
        Aux::Empty(_) | Aux::Null => None,
    };
    Ok(outcome)
}
/// Newtype around an optional arbitrary JSON value.
///
/// A local wrapper type is required to give `Option<serde_json::Value>` a `Display`
/// implementation, which this file provides for it.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct OptionSerdeJson(pub Option<serde_json::Value>);
/// A model description as deserialized from the API.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct GetModel {
    /// Model URL.
    pub url: String,
    /// Owner (account) name of the model.
    pub owner: String,
    /// Model name.
    pub name: String,
    /// Free-form description.
    pub description: String,
    /// Visibility as reported by the API, kept as a raw string.
    pub visibility: String,
    /// Optional link to the model's source repository.
    pub github_url: Option<String>,
    /// Optional link to an associated paper.
    pub paper_url: Option<String>,
    /// Optional link to the model's license.
    pub license_url: Option<String>,
    /// Number of runs, when reported.
    pub run_count: Option<u32>,
    /// Optional cover image URL.
    pub cover_image_url: Option<String>,
    /// Example prediction; `None` when the value is `null`, `{}`, or the key is absent.
    // `default` is required: serde only treats a missing key as `None` for plain
    // `Option` fields, and `deserialize_with` disables that, making an absent key
    // a hard "missing field" error.
    #[serde(default, deserialize_with = "object_empty_as_none")]
    pub default_example: Option<GetPrediction>,
    /// Latest version; `None` when the value is `null`, `{}`, or the key is absent.
    #[serde(default, deserialize_with = "object_empty_as_none")]
    pub latest_version: Option<GetModelVersion>,
}
/// A collection together with the models it contains.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct GetCollectionModels {
    /// Human-readable collection name.
    pub name: String,
    /// Collection slug (identifier used in URLs — presumably; confirm).
    pub slug: String,
    /// Free-form description.
    pub description: String,
    /// Models belonging to the collection.
    pub models: Vec<GetModel>,
}
/// URL pair attached to predictions and list items, as returned by the API.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct PredictionsUrls {
    /// URL named `cancel` in the payload (presumably the cancel endpoint).
    pub cancel: String,
    /// URL named `get` in the payload (presumably the fetch/poll endpoint).
    pub get: String,
}
/// A single prediction with its inputs, output, and status.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct GetPrediction {
    /// Prediction identifier.
    pub id: String,
    /// Model version the prediction refers to (presumably a `GetModelVersion::id`; confirm).
    pub version: String,
    /// Related URLs (`cancel`, `get`).
    pub urls: PredictionsUrls,
    /// Creation timestamp, kept as the raw string from the API.
    pub created_at: String,
    /// Start timestamp, if present.
    pub started_at: Option<String>,
    /// Completion timestamp, if present.
    pub completed_at: Option<String>,
    /// Where the prediction originated, if reported.
    pub source: Option<PredictionSource>,
    /// Current lifecycle status.
    pub status: PredictionStatus,
    /// Inputs keyed by name; values are arbitrary JSON.
    pub input: HashMap<String, serde_json::Value>,
    /// Output as arbitrary JSON, possibly absent.
    pub output: OptionSerdeJson,
    /// Error message, if any.
    pub error: Option<String>,
    /// Log text, if any.
    pub logs: Option<String>,
    /// Metrics keyed by name; values are arbitrary JSON.
    pub metrics: Option<HashMap<String, serde_json::Value>>,
}
/// A training job as fetched from the API.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct GetTraining {
    /// Training identifier.
    pub id: String,
    /// Model version the training refers to.
    pub version: String,
    /// Current lifecycle status (shares the prediction status enum).
    pub status: PredictionStatus,
    // NOTE(review): input/output are typed as string-to-string maps; non-string JSON
    // values would fail to deserialize here — confirm this matches the API.
    /// Training inputs, if present.
    pub input: Option<HashMap<String, String>>,
    /// Training outputs, if present.
    pub output: Option<HashMap<String, String>>,
    /// Error message, if any.
    pub error: Option<String>,
    /// Log text, if any.
    pub logs: Option<String>,
    /// Value of the `webhook_completed` field, if present.
    pub webhook_completed: Option<String>,
    /// Start timestamp, if present.
    pub started_at: Option<String>,
    /// Creation timestamp, kept as the raw string from the API.
    pub created_at: String,
    /// Completion timestamp, if present.
    pub completed_at: Option<String>,
}
/// Training record returned when a training is created.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct CreateTraining {
    /// Training identifier.
    pub id: String,
    /// Model version the training refers to.
    pub version: String,
    /// Lifecycle status at creation time.
    pub status: PredictionStatus,
    /// Training inputs, if present.
    pub input: Option<HashMap<String, String>>,
    /// Training outputs, if present.
    pub output: Option<HashMap<String, String>>,
    /// Log text, if any.
    pub logs: Option<String>,
    /// Start timestamp, if present.
    pub started_at: Option<String>,
    /// Creation timestamp, kept as the raw string from the API.
    pub created_at: String,
    /// Completion timestamp, if present.
    pub completed_at: Option<String>,
}
/// Prediction record returned when a prediction is created.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct CreatePrediction {
    /// Prediction identifier.
    pub id: String,
    /// Model version the prediction refers to.
    pub version: String,
    /// Related URLs (`cancel`, `get`).
    pub urls: PredictionsUrls,
    /// Creation timestamp, kept as the raw string from the API.
    pub created_at: String,
    /// Lifecycle status at creation time.
    pub status: PredictionStatus,
    /// Inputs keyed by name; values are arbitrary JSON.
    pub input: HashMap<String, serde_json::Value>,
    /// Error message, if any.
    pub error: Option<String>,
    /// Log text, if any.
    pub logs: Option<String>,
}
/// A single model version.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct GetModelVersion {
    /// Version identifier.
    pub id: String,
    /// Creation timestamp, kept as the raw string from the API.
    pub created_at: String,
    /// Value of the `cog_version` field.
    pub cog_version: String,
    /// OpenAPI schema as a raw JSON object (top-level keys mapped to arbitrary JSON).
    pub openapi_schema: HashMap<String, serde_json::Value>,
}
/// Summary entry for one collection in a collection listing (no model list).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ListCollectionModelsItem {
    /// Human-readable collection name.
    pub name: String,
    /// Collection slug.
    pub slug: String,
    /// Free-form description.
    pub description: String,
}
/// One page of a collection listing.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ListCollectionModels {
    /// Link to the previous page, if any (presumably a URL; confirm).
    pub previous: Option<String>,
    /// Link to the next page, if any (presumably a URL; confirm).
    pub next: Option<String>,
    /// Collections on this page.
    pub results: Vec<ListCollectionModelsItem>,
}
/// Summary entry for one prediction in a prediction listing.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct PredictionsListItem {
    /// Prediction identifier.
    pub id: String,
    /// Model version the prediction refers to.
    pub version: String,
    /// Related URLs (`cancel`, `get`).
    pub urls: PredictionsUrls,
    /// Creation timestamp, kept as the raw string from the API.
    pub created_at: String,
    // NOTE(review): `GetPrediction::started_at` is `Option<String>`, but here it is
    // required; a listed prediction that has not started yet (null) would fail to
    // deserialize — confirm against the API.
    /// Start timestamp.
    pub started_at: String,
    /// Completion timestamp, if present.
    pub completed_at: Option<String>,
    /// Where the prediction originated, if reported.
    pub source: Option<PredictionSource>,
    /// Current lifecycle status.
    pub status: PredictionStatus,
}
/// One page of a prediction listing.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ListPredictions {
    /// Link to the previous page, if any.
    pub previous: Option<String>,
    /// Link to the next page, if any.
    pub next: Option<String>,
    /// Predictions on this page.
    pub results: Vec<PredictionsListItem>,
}
/// One page of a model-version listing.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ListModelVersions {
    /// Link to the previous page, if any.
    pub previous: Option<String>,
    /// Link to the next page, if any.
    pub next: Option<String>,
    /// Model versions on this page.
    pub results: Vec<GetModelVersion>,
}
/// Summary entry for one training in a training listing.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ListTrainingItem {
    /// Training identifier.
    pub id: String,
    /// Model version the training refers to.
    pub version: String,
    /// Related URLs (`cancel`, `get`).
    pub urls: PredictionsUrls,
    /// Creation timestamp, kept as the raw string from the API.
    pub created_at: String,
    // NOTE(review): unlike `GetTraining`/`PredictionsListItem`, `started_at`,
    // `completed_at` and `source` are required here, so a null value (e.g. a training
    // still running) would fail to deserialize — confirm against the API.
    /// Start timestamp.
    pub started_at: String,
    /// Completion timestamp.
    pub completed_at: String,
    /// Where the training originated.
    pub source: PredictionSource,
    /// Current lifecycle status.
    pub status: PredictionStatus,
}
/// One page of a training listing.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ListTraining {
    /// Link to the previous page, if any.
    pub previous: Option<String>,
    /// Link to the next page, if any.
    pub next: Option<String>,
    /// Trainings on this page.
    pub results: Vec<ListTrainingItem>,
}
/// Implements `Display` for each listed type by rendering it as pretty-printed JSON.
///
/// If serialization fails, the type's `Debug` representation is written instead.
/// A trailing comma in the type list is accepted.
macro_rules! impl_display {
    ($($t:ty),* $(,)?) => ($(
        impl std::fmt::Display for $t {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                match serde_json::to_string_pretty(self) {
                    // Emit the JSON text verbatim: formatting the `String` with `{:?}`
                    // would wrap it in quotes and escape every newline and quote,
                    // defeating the pretty printing.
                    Ok(formatted) => f.write_str(&formatted),
                    Err(_) => write!(f, "{:?}", self),
                }
            }
        }
    )*)
}
impl_display! {
GetModel,
GetCollectionModels,
PredictionsUrls,
GetPrediction,
GetTraining,
CreateTraining,
CreatePrediction,
GetModelVersion,
ListCollectionModelsItem,
ListCollectionModels,
PredictionsListItem,
ListPredictions,
ListModelVersions,
ListTrainingItem,
ListTraining
}
impl std::fmt::Display for OptionSerdeJson {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.0 {
Some(value) => match serde_json::to_string_pretty(value) {
Ok(formatted) => write!(f, "{:?}", formatted),
Err(_) => write!(f, "{:?}", value),
},
None => write!(f, "None"),
}
}
}
/// Origin of a prediction, as reported by the API (`api` or `web`).
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
// Lowercase variants are deliberate: no `#[serde(rename_all)]` is applied, so serde
// uses the variant name itself as the serialized string.
#[allow(non_camel_case_types)]
pub enum PredictionSource {
    api,
    web,
}
/// Lifecycle status shared by predictions and trainings in this file.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
// Lowercase variants double as the wire strings (no serde renaming is applied).
#[allow(non_camel_case_types)]
pub enum PredictionStatus {
    starting,
    processing,
    succeeded,
    failed,
    canceled,
}
/// Webhook event kinds.
///
/// Not referenced by the types defined here; presumably used when configuring
/// webhook filters — confirm against callers.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
// Lowercase variants double as the wire strings (no serde renaming is applied).
#[allow(non_camel_case_types)]
pub enum WebhookEvents {
    start,
    output,
    logs,
    completed,
}
/// Implements `Display` for each listed type by rendering it as pretty-printed JSON,
/// falling back to `Debug` if serialization fails.
///
/// This re-declares (and, from here on, shadows) the earlier `impl_display!` macro in
/// this file with the same behavior. A trailing comma in the type list is accepted.
macro_rules! impl_display {
    ($($t:ty),* $(,)?) => ($(
        impl std::fmt::Display for $t {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                match serde_json::to_string_pretty(self) {
                    // Write the JSON text verbatim; `{:?}` would quote and escape it.
                    Ok(formatted) => f.write_str(&formatted),
                    Err(_) => write!(f, "{:?}", self),
                }
            }
        }
    )*)
}
// Pretty-JSON `Display` for the status/source/event enums.
impl_display!(PredictionSource, PredictionStatus, WebhookEvents);