use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Top-level JSON Schema document describing an object-shaped configuration.
///
/// Rust field names are mapped onto the JSON Schema keywords they represent (`type`,
/// `additionalProperties`). Unset optional keywords are omitted from the serialized output
/// instead of being written as `null`, since JSON Schema does not accept `null` for
/// keywords such as `description` or `additionalProperties`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonSchema {
    /// Human-readable title of the schema (`title`).
    pub title: String,
    /// Optional longer description (`description`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// JSON type of the described value (`type`), e.g. `"object"`.
    #[serde(rename = "type")]
    pub schema_type: String,
    /// Sub-schema for each named property (`properties`); empty when absent on input.
    #[serde(default)]
    pub properties: HashMap<String, JsonSchemaProperty>,
    /// Names of properties that must be present (`required`); empty when absent on input.
    #[serde(default)]
    pub required: Vec<String>,
    /// Whether keys not listed in `properties` are permitted (`additionalProperties`).
    /// The legacy snake_case key is still accepted when deserializing.
    #[serde(
        rename = "additionalProperties",
        alias = "additional_properties",
        skip_serializing_if = "Option::is_none"
    )]
    pub additional_properties: Option<bool>,
}
/// Schema for a single property (or nested sub-schema) inside a [`JsonSchema`].
///
/// Each field corresponds to one JSON Schema keyword; the serialized key is given by the
/// `serde` attribute where it differs from the Rust name. Every keyword is optional, and
/// unset keywords are omitted from the serialized output rather than emitted as `null`
/// (validators reject e.g. `"enum": null` or `"$ref": null`). Legacy snake_case keys for
/// `oneOf` / `allOf` are still accepted on input.
///
/// `Default` yields a property with every keyword unset, which allows
/// `JsonSchemaProperty { minimum: Some(1.0), ..Default::default() }`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct JsonSchemaProperty {
    /// JSON type name (`type`), e.g. `"integer"`, `"number"`, `"string"`, `"boolean"`.
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub property_type: Option<String>,
    /// Human-readable description (`description`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Allowed values (`enum`). Note: stored and serialized as JSON strings.
    #[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
    pub enum_values: Option<Vec<String>>,
    /// Semantic format hint (`format`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<String>,
    /// Inclusive lower bound for numeric values (`minimum`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub minimum: Option<f64>,
    /// Inclusive upper bound for numeric values (`maximum`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub maximum: Option<f64>,
    /// Default value (`default`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default: Option<serde_json::Value>,
    /// Schema applied to array elements (`items`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub items: Option<Box<JsonSchemaProperty>>,
    /// Reference to another schema (`$ref`).
    #[serde(rename = "$ref", skip_serializing_if = "Option::is_none")]
    pub reference: Option<String>,
    /// Nested object properties (`properties`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub properties: Option<HashMap<String, JsonSchemaProperty>>,
    /// Required nested property names (`required`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required: Option<Vec<String>>,
    /// Value must match exactly one of these sub-schemas (`oneOf`).
    #[serde(rename = "oneOf", alias = "one_of", skip_serializing_if = "Option::is_none")]
    pub one_of: Option<Vec<JsonSchemaProperty>>,
    /// Value must match all of these sub-schemas (`allOf`).
    #[serde(rename = "allOf", alias = "all_of", skip_serializing_if = "Option::is_none")]
    pub all_of: Option<Vec<JsonSchemaProperty>>,
}
pub struct SchemaRegistry;
impl SchemaRegistry {
pub fn resnet_schema() -> JsonSchema {
let mut properties = HashMap::new();
properties.insert(
"num_layers".to_string(),
JsonSchemaProperty {
property_type: Some("integer".to_string()),
description: Some("Number of layers in the ResNet model".to_string()),
enum_values: Some(vec![
"18".to_string(),
"34".to_string(),
"50".to_string(),
"101".to_string(),
"152".to_string(),
]),
format: None,
minimum: Some(18.0),
maximum: Some(152.0),
default: Some(serde_json::json!(50)),
items: None,
reference: None,
properties: None,
required: None,
one_of: None,
all_of: None,
},
);
properties.insert(
"in_channels".to_string(),
JsonSchemaProperty {
property_type: Some("integer".to_string()),
description: Some("Number of input channels".to_string()),
enum_values: None,
format: None,
minimum: Some(1.0),
maximum: None,
default: Some(serde_json::json!(3)),
items: None,
reference: None,
properties: None,
required: None,
one_of: None,
all_of: None,
},
);
properties.insert(
"num_classes".to_string(),
JsonSchemaProperty {
property_type: Some("integer".to_string()),
description: Some("Number of output classes".to_string()),
enum_values: None,
format: None,
minimum: Some(1.0),
maximum: None,
default: Some(serde_json::json!(1000)),
items: None,
reference: None,
properties: None,
required: None,
one_of: None,
all_of: None,
},
);
properties.insert(
"zero_init_residual".to_string(),
JsonSchemaProperty {
property_type: Some("boolean".to_string()),
description: Some(
"Whether to initialize residual connections with zero".to_string(),
),
enum_values: None,
format: None,
minimum: None,
maximum: None,
default: Some(serde_json::json!(false)),
items: None,
reference: None,
properties: None,
required: None,
one_of: None,
all_of: None,
},
);
JsonSchema {
title: "ResNet Configuration".to_string(),
description: Some("Configuration for ResNet models".to_string()),
schema_type: "object".to_string(),
properties,
required: vec![
"num_layers".to_string(),
"in_channels".to_string(),
"num_classes".to_string(),
],
additional_properties: Some(false),
}
}
pub fn vit_schema() -> JsonSchema {
let mut properties = HashMap::new();
let make_prop = |desc: &str, default: serde_json::Value| JsonSchemaProperty {
property_type: Some("integer".to_string()),
description: Some(desc.to_string()),
enum_values: None,
format: None,
minimum: Some(1.0),
maximum: None,
default: Some(default),
items: None,
reference: None,
properties: None,
required: None,
one_of: None,
all_of: None,
};
properties.insert(
"image_size".to_string(),
make_prop("Size of the input image (square)", serde_json::json!(224)),
);
properties.insert(
"patch_size".to_string(),
make_prop(
"Size of the patches to divide the image into",
serde_json::json!(16),
),
);
properties.insert(
"hidden_size".to_string(),
make_prop(
"Dimension of transformer hidden layers",
serde_json::json!(768),
),
);
properties.insert(
"num_layers".to_string(),
make_prop("Number of transformer layers", serde_json::json!(12)),
);
properties.insert(
"num_heads".to_string(),
make_prop("Number of attention heads", serde_json::json!(12)),
);
properties.insert(
"mlp_dim".to_string(),
make_prop("Dimension of the MLP layers", serde_json::json!(3072)),
);
properties.insert(
"dropout_rate".to_string(),
JsonSchemaProperty {
property_type: Some("number".to_string()),
description: Some("Dropout rate".to_string()),
enum_values: None,
format: None,
minimum: Some(0.0),
maximum: Some(1.0),
default: Some(serde_json::json!(0.1)),
items: None,
reference: None,
properties: None,
required: None,
one_of: None,
all_of: None,
},
);
properties.insert(
"attention_dropout_rate".to_string(),
JsonSchemaProperty {
property_type: Some("number".to_string()),
description: Some("Attention dropout rate".to_string()),
enum_values: None,
format: None,
minimum: Some(0.0),
maximum: Some(1.0),
default: Some(serde_json::json!(0.0)),
items: None,
reference: None,
properties: None,
required: None,
one_of: None,
all_of: None,
},
);
properties.insert(
"classifier".to_string(),
JsonSchemaProperty {
property_type: Some("string".to_string()),
description: Some("Type of classifier ('token' or 'gap')".to_string()),
enum_values: Some(vec!["token".to_string(), "gap".to_string()]),
format: None,
minimum: None,
maximum: None,
default: Some(serde_json::json!("token")),
items: None,
reference: None,
properties: None,
required: None,
one_of: None,
all_of: None,
},
);
properties.insert(
"include_top".to_string(),
JsonSchemaProperty {
property_type: Some("boolean".to_string()),
description: Some("Whether to include the classification head".to_string()),
enum_values: None,
format: None,
minimum: None,
maximum: None,
default: Some(serde_json::json!(true)),
items: None,
reference: None,
properties: None,
required: None,
one_of: None,
all_of: None,
},
);
JsonSchema {
title: "Vision Transformer Configuration".to_string(),
description: Some("Configuration for Vision Transformer models".to_string()),
schema_type: "object".to_string(),
properties,
required: vec![
"image_size".to_string(),
"patch_size".to_string(),
"hidden_size".to_string(),
"num_heads".to_string(),
],
additional_properties: Some(false),
}
}
pub fn get_all_schemas() -> HashMap<String, JsonSchema> {
let mut schemas = HashMap::new();
schemas.insert("resnet".to_string(), Self::resnet_schema());
schemas.insert("vit".to_string(), Self::vit_schema());
schemas
}
}