use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Identifier of a model, as carried by `ModelCheckpoint::model_id` and
/// `SearchResult::model_id`.
///
/// This is a plain alias for an owned string; any `String` is accepted.
pub type ModelId = std::string::String;
/// Descriptive record for a stored model.
///
/// Serialization notes (from the serde attributes below):
/// - `Option` fields are omitted from the output when `None`.
/// - `tags` and `metadata` become empty collections when absent from the input.
/// - `name`, `model_type`, `version` and `created_at` carry no serde default, so
///   deserialization fails if any of them is missing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Model name.
    pub name: String,
    /// Model family label, e.g. `"NHITS"`.
    pub model_type: String,
    /// Version string (the `Default` impl uses `"1.0.0"`).
    pub version: String,
    /// Optional free-form description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Optional configuration stored as arbitrary JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub config: Option<serde_json::Value>,
    /// Optional training metrics.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metrics: Option<TrainingMetrics>,
    /// Optional architecture summary.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub architecture: Option<ArchitectureInfo>,
    /// Free-form labels; empty when absent from the input.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Extra key/value data as arbitrary JSON values; empty when absent from the input.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
    /// Creation timestamp (UTC).
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Last-update timestamp (UTC), if one has been recorded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub updated_at: Option<chrono::DateTime<chrono::Utc>>,
}
impl Default for ModelMetadata {
    /// Builds an otherwise-empty record: blank `name` and `model_type`, version
    /// `"1.0.0"`, every optional section `None`, and empty `tags`/`metadata`.
    ///
    /// `created_at` is stamped with the current UTC time, so two defaults taken
    /// at different moments will generally not be equal.
    fn default() -> Self {
        let now = chrono::Utc::now();
        Self {
            name: String::default(),
            model_type: String::default(),
            version: String::from("1.0.0"),
            description: None,
            config: None,
            metrics: None,
            architecture: None,
            tags: vec![],
            metadata: HashMap::default(),
            created_at: now,
            updated_at: None,
        }
    }
}
/// Summary numbers recorded for a training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    /// Training loss (whether final or aggregated is not fixed by this type).
    pub train_loss: f64,
    /// Validation loss (same caveat as `train_loss`).
    pub val_loss: f64,
    /// Training duration. NOTE(review): units are not specified here,
    /// presumably seconds — confirm with the producer.
    pub training_time: f64,
    /// Number of epochs.
    pub epochs: usize,
    /// Best validation loss, if recorded; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_val_loss: Option<f64>,
    /// Extra named scalar metrics; empty when absent from the input.
    #[serde(default)]
    pub additional: HashMap<String, f64>,
}
/// Size/shape summary of a model's architecture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureInfo {
    /// Input size.
    pub input_size: usize,
    /// Output size.
    pub output_size: usize,
    /// Hidden size.
    pub hidden_size: usize,
    /// Number of layers.
    pub num_layers: usize,
    /// Parameter count, if known; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_parameters: Option<usize>,
    /// Free-form extra details as JSON values; empty when absent from the input.
    #[serde(default)]
    pub details: HashMap<String, serde_json::Value>,
}
/// A saved training checkpoint belonging to the model identified by `model_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCheckpoint {
    /// Identifier of this checkpoint.
    pub checkpoint_id: String,
    /// Identifier of the owning model.
    pub model_id: ModelId,
    /// Epoch at which the checkpoint was taken.
    pub epoch: usize,
    /// Step at which the checkpoint was taken.
    pub step: usize,
    /// Loss recorded with the checkpoint.
    pub loss: f64,
    /// Validation loss, if recorded; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub val_loss: Option<f64>,
    /// Optimizer state as arbitrary JSON, if saved; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub optimizer_state: Option<serde_json::Value>,
    /// Creation timestamp (UTC).
    pub created_at: chrono::DateTime<chrono::Utc>,
}
/// Criteria for narrowing a model search.
///
/// The derived `Default` leaves every criterion `None`/empty. NOTE(review): the
/// matching logic lives elsewhere; presumably an unset criterion means "no
/// constraint", and bound inclusivity is decided there — confirm.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SearchFilter {
    /// Restrict to this model type.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_type: Option<String>,
    /// Restrict by tags (any-vs-all semantics not defined by this type).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tags: Option<Vec<String>>,
    /// Lower bound on validation loss.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_val_loss: Option<f64>,
    /// Upper bound on validation loss.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_val_loss: Option<f64>,
    /// Only models created after this instant.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_after: Option<chrono::DateTime<chrono::Utc>>,
    /// Only models created before this instant.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_before: Option<chrono::DateTime<chrono::Utc>>,
    /// Key/value constraints on model metadata; empty when absent from the input.
    #[serde(default)]
    pub metadata_filters: HashMap<String, serde_json::Value>,
}
/// Similarity measure selector. NOTE(review): presumably used to compute
/// `SearchResult::score` — confirm with the search implementation.
///
/// `Cosine` is the default. Serde uses the variant names unchanged (`"Cosine"`),
/// whereas `Display` writes lowercase names (`"cosine"`), so the two textual
/// forms are not interchangeable.
///
/// `PartialEq`, `Eq` and `Hash` are derived so metrics can be compared with `==`
/// and used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
pub enum SimilarityMetric {
    #[default]
    Cosine,
    Euclidean,
    Dot,
}
impl std::fmt::Display for SimilarityMetric {
    /// Writes the lowercase metric name: `"cosine"`, `"euclidean"` or `"dot"`.
    ///
    /// Goes through [`std::fmt::Formatter::pad`] so width, fill, alignment and
    /// precision flags (e.g. `{:>10}`) are honored; `write!` with a plain literal
    /// silently ignored them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Cosine => "cosine",
            Self::Euclidean => "euclidean",
            Self::Dot => "dot",
        };
        f.pad(name)
    }
}
/// One hit returned by a model search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Identifier of the matched model.
    pub model_id: ModelId,
    /// Score assigned by the search (scale and direction are not defined by this type).
    pub score: f64,
    /// Metadata of the matched model.
    pub metadata: ModelMetadata,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a populated record through JSON and checks every set field.
    #[test]
    fn test_model_metadata_serialization() {
        let metadata = ModelMetadata {
            name: "test-model".to_string(),
            model_type: "NHITS".to_string(),
            version: "1.0.0".to_string(),
            description: Some("Test model".to_string()),
            tags: vec!["test".to_string(), "neural".to_string()],
            ..Default::default()
        };
        let json = serde_json::to_string(&metadata).unwrap();
        let deserialized: ModelMetadata = serde_json::from_str(&json).unwrap();
        assert_eq!(metadata.name, deserialized.name);
        assert_eq!(metadata.model_type, deserialized.model_type);
        assert_eq!(metadata.version, deserialized.version);
        assert_eq!(metadata.description, deserialized.description);
        assert_eq!(metadata.tags, deserialized.tags);
        assert_eq!(metadata.created_at, deserialized.created_at);
        assert!(deserialized.updated_at.is_none());
    }

    /// `skip_serializing_if = "Option::is_none"` must drop `None` fields, while
    /// the `#[serde(default)]` collections are still emitted.
    #[test]
    fn test_model_metadata_skips_none_fields() {
        let value = serde_json::to_value(ModelMetadata::default()).unwrap();
        let obj = value.as_object().unwrap();
        for key in ["description", "config", "metrics", "architecture", "updated_at"] {
            assert!(!obj.contains_key(key), "field {} should be omitted when None", key);
        }
        assert!(obj.contains_key("tags"));
        assert!(obj.contains_key("metadata"));
    }

    /// Missing `tags`/`metadata` fall back to empty; missing options become `None`.
    #[test]
    fn test_model_metadata_deserialize_defaults() {
        let json = r#"{"name":"m","model_type":"NHITS","version":"2.0.0","created_at":"2024-01-01T00:00:00Z"}"#;
        let metadata: ModelMetadata = serde_json::from_str(json).unwrap();
        assert_eq!(metadata.version, "2.0.0");
        assert!(metadata.tags.is_empty());
        assert!(metadata.metadata.is_empty());
        assert!(metadata.description.is_none());
        assert!(metadata.updated_at.is_none());
    }

    /// `created_at` has no serde default, so omitting it is an error.
    #[test]
    fn test_model_metadata_requires_created_at() {
        let json = r#"{"name":"m","model_type":"NHITS","version":"1.0.0"}"#;
        assert!(serde_json::from_str::<ModelMetadata>(json).is_err());
    }

    #[test]
    fn test_search_filter_default() {
        let filter = SearchFilter::default();
        assert!(filter.model_type.is_none());
        assert!(filter.tags.is_none());
        assert!(filter.min_val_loss.is_none());
        assert!(filter.max_val_loss.is_none());
        assert!(filter.created_after.is_none());
        assert!(filter.created_before.is_none());
        assert!(filter.metadata_filters.is_empty());
    }

    #[test]
    fn test_similarity_metric_display() {
        assert_eq!(SimilarityMetric::Cosine.to_string(), "cosine");
        assert_eq!(SimilarityMetric::Euclidean.to_string(), "euclidean");
        assert_eq!(SimilarityMetric::Dot.to_string(), "dot");
    }

    #[test]
    fn test_similarity_metric_default_is_cosine() {
        assert!(matches!(SimilarityMetric::default(), SimilarityMetric::Cosine));
    }

    /// Pins the current wire format (variant names, not the lowercase Display
    /// form) so an accidental rename is caught.
    #[test]
    fn test_similarity_metric_serde_form() {
        assert_eq!(serde_json::to_string(&SimilarityMetric::Cosine).unwrap(), "\"Cosine\"");
        let parsed: SimilarityMetric = serde_json::from_str("\"Dot\"").unwrap();
        assert!(matches!(parsed, SimilarityMetric::Dot));
    }
}