use std::any::Any;
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::util::epoch_millis_for_ordering;
/// Opaque identifier for a statistics model.
///
/// A thin newtype over `String`. Build one from an existing string with
/// `StatsModelId::new`, or mint a timestamp-based one with
/// `StatsModelId::generate`.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct StatsModelId(pub String);
impl StatsModelId {
    /// Wraps an existing identifier string as-is (no validation).
    pub fn new(id: impl Into<String>) -> Self {
        Self(id.into())
    }

    /// Generates a fresh identifier of the form `stats-<hex timestamp>`.
    ///
    /// The timestamp comes from `epoch_millis_for_ordering`. A millisecond
    /// clock can return the same value for calls made in quick succession,
    /// which previously produced duplicate IDs. To prevent that, the value
    /// actually used is clamped to be strictly greater than the one used by
    /// the previous call in this process. IDs minted by a single process are
    /// therefore unique and increasing; IDs from separate processes can still
    /// coincide. The textual format is unchanged.
    pub fn generate() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering};

        // Last timestamp value handed out by `generate` in this process.
        static LAST_ISSUED: AtomicU64 = AtomicU64::new(0);

        // `try_from` tolerates whatever integer width the helper returns. A
        // value that does not fit in u64 falls back to 0, and the clamp below
        // still yields a fresh value.
        let now = u64::try_from(epoch_millis_for_ordering()).unwrap_or(0);
        let next_after = |last: u64| now.max(last.saturating_add(1));

        // The closure always returns `Some`, so `fetch_update` never fails.
        // Either arm carries the value that was stored before our update.
        let prev = LAST_ISSUED
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |last| {
                Some(next_after(last))
            })
            .unwrap_or_else(|prev| prev);

        Self(format!("stats-{:x}", next_after(prev)))
    }

    /// Borrows the identifier as a string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}
impl std::fmt::Display for StatsModelId {
    /// Emits the raw identifier text. Width/fill flags from the caller's
    /// format spec are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let raw: &str = &self.0;
        f.write_str(raw)
    }
}
/// Common interface shared by model implementations.
///
/// The `Send + Sync` supertraits require every implementor to be safe to
/// move to and share between threads.
pub trait Model: Send + Sync {
    /// The kind of this model.
    fn model_type(&self) -> ModelType;

    /// Version information for this model instance.
    fn version(&self) -> &ModelVersion;

    /// Descriptive metadata (optional name/description plus tags).
    fn metadata(&self) -> &ModelMetadata;

    /// Creation timestamp.
    ///
    /// NOTE(review): the unit and epoch are not fixed by this trait;
    /// presumably epoch milliseconds — confirm against implementors.
    fn created_at(&self) -> u64;

    /// Returns `self` as `&dyn Any`, e.g. so a `&dyn Model` can be
    /// downcast to its concrete type.
    fn as_any(&self) -> &dyn Any;
}
/// Kind of model.
///
/// `ActionScore` and `OptimalParams` are built-in kinds. `Custom` carries a
/// caller-chosen name. Variant order is left as-is because serde's derived
/// impls record variant indices.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum ModelType {
    /// Built-in kind stored under `action_scores`.
    ActionScore,
    /// Built-in kind stored under `optimal_params`.
    OptimalParams,
    /// User-defined kind, identified by its name.
    Custom(String),
}
impl ModelType {
    /// Directory name associated with this model kind.
    ///
    /// Built-in kinds map to fixed snake_case names. `Custom` returns its
    /// name verbatim.
    ///
    /// NOTE(review): the custom name is not sanitized here. A name like
    /// `"../x"` or `"action_scores"` comes back unchanged — confirm that
    /// callers validate it before using it as a path component.
    pub fn dir_name(&self) -> &str {
        match self {
            Self::Custom(custom) => custom.as_str(),
            Self::OptimalParams => "optimal_params",
            Self::ActionScore => "action_scores",
        }
    }
}
impl std::fmt::Display for ModelType {
    /// Writes the variant name. Custom kinds render as `Custom(<name>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Custom(name) => return write!(f, "Custom({})", name),
            Self::OptimalParams => "OptimalParams",
            Self::ActionScore => "ActionScore",
        };
        f.write_str(label)
    }
}
/// Two-part version number attached to a model.
///
/// `Default` yields `0.0` with no source IDs.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct ModelVersion {
    /// Major version component.
    pub major: u32,
    /// Minor version component.
    pub minor: u32,
    /// Free-form source identifiers (presumably the inputs this version was
    /// derived from — confirm with producers). Not part of `Display` output.
    pub source_ids: Vec<String>,
}
impl ModelVersion {
    /// Creates version `major.minor` with an empty `source_ids` list.
    pub fn new(major: u32, minor: u32) -> Self {
        Self::with_sources(major, minor, Vec::new())
    }

    /// Creates version `major.minor`, taking ownership of `source_ids`.
    pub fn with_sources(major: u32, minor: u32, source_ids: Vec<String>) -> Self {
        Self { major, minor, source_ids }
    }
}
impl std::fmt::Display for ModelVersion {
    /// Renders as `major.minor`. `source_ids` is intentionally omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { major, minor, .. } = self;
        write!(f, "{}.{}", major, minor)
    }
}
/// Descriptive, non-functional information attached to a model.
///
/// `Default` leaves every field empty. Use the `with_*` builder methods to
/// populate it.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Optional display name.
    pub name: Option<String>,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Arbitrary key/value tags.
    pub tags: HashMap<String, String>,
}
impl ModelMetadata {
    /// Empty metadata; equivalent to `ModelMetadata::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Builder: sets the name, replacing any previous one.
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }

    /// Builder: sets the description, replacing any previous one.
    pub fn with_description(self, desc: impl Into<String>) -> Self {
        Self {
            description: Some(desc.into()),
            ..self
        }
    }

    /// Builder: adds a tag. An existing value for the same key is overwritten.
    pub fn with_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (tag_key, tag_value) = (key.into(), value.into());
        self.tags.insert(tag_key, tag_value);
        self
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the model identity/version/metadata types.
    use super::*;

    #[test]
    fn test_stats_model_id_generate() {
        let id1 = StatsModelId::generate();
        let id2 = StatsModelId::generate();
        assert!(!id1.0.is_empty());
        assert!(!id2.0.is_empty());
        assert!(id1.as_str().starts_with("stats-"));
        assert!(id2.as_str().starts_with("stats-"));
    }

    #[test]
    fn test_stats_model_id_new_and_display() {
        let id = StatsModelId::new("model-42");
        assert_eq!(id.as_str(), "model-42");
        assert_eq!(id.to_string(), "model-42");
    }

    #[test]
    fn test_model_type_dir_name() {
        assert_eq!(ModelType::ActionScore.dir_name(), "action_scores");
        assert_eq!(ModelType::OptimalParams.dir_name(), "optimal_params");
        assert_eq!(
            ModelType::Custom("my_model".to_string()).dir_name(),
            "my_model"
        );
    }

    #[test]
    fn test_model_type_display() {
        assert_eq!(ModelType::ActionScore.to_string(), "ActionScore");
        assert_eq!(ModelType::OptimalParams.to_string(), "OptimalParams");
        assert_eq!(
            ModelType::Custom("my_model".to_string()).to_string(),
            "Custom(my_model)"
        );
    }

    #[test]
    fn test_model_version() {
        let v = ModelVersion::new(1, 2);
        assert_eq!(format!("{}", v), "1.2");
        assert!(v.source_ids.is_empty());
    }

    #[test]
    fn test_model_version_with_sources_and_default() {
        let sources = vec!["a".to_string(), "b".to_string()];
        let v = ModelVersion::with_sources(3, 0, sources.clone());
        // Display deliberately ignores source_ids.
        assert_eq!(v.to_string(), "3.0");
        assert_eq!(v.source_ids, sources);
        assert_eq!(ModelVersion::default().to_string(), "0.0");
    }

    #[test]
    fn test_model_metadata_builder() {
        let meta = ModelMetadata::new()
            .with_name("test")
            .with_description("desc")
            .with_tag("env", "prod");
        assert_eq!(meta.name.as_deref(), Some("test"));
        assert_eq!(meta.description.as_deref(), Some("desc"));
        assert_eq!(meta.tags.get("env").map(|s| s.as_str()), Some("prod"));
    }

    #[test]
    fn test_model_metadata_tag_overwrites() {
        let meta = ModelMetadata::new()
            .with_tag("env", "dev")
            .with_tag("env", "prod");
        assert_eq!(meta.tags.len(), 1);
        assert_eq!(meta.tags.get("env").map(|s| s.as_str()), Some("prod"));
        assert!(meta.name.is_none());
        assert!(meta.description.is_none());
    }
}