#[cfg(test)]
mod tests;
use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Lifecycle stage of a registered model version.
///
/// Stages form the linear sequence `Dev` → `Staging` → `Production`; the registry moves a
/// version exactly one step forward (promotion) or one step back (demotion) at a time.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Stage {
    /// Stage assigned to every newly registered version.
    Dev,
    /// Intermediate stage between `Dev` and `Production`.
    Staging,
    /// Last stage of the sequence; cannot be promoted further.
    Production,
}
impl Stage {
fn ordinal(self) -> u8 {
match self {
Stage::Dev => 0,
Stage::Staging => 1,
Stage::Production => 2,
}
}
pub fn as_str(self) -> &'static str {
match self {
Stage::Dev => "Dev",
Stage::Staging => "Staging",
Stage::Production => "Production",
}
}
}
/// Formats a stage as its variant name (the same text as `Stage::as_str`).
///
/// Width, fill, alignment and precision specifiers are honored, so
/// `format!("{:<10}", Stage::Dev)` yields `"Dev       "`.
impl std::fmt::Display for Stage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` applies the caller's formatting flags; `write!(f, "{}", ..)` would drop them.
        f.pad(self.as_str())
    }
}
/// One registered version of a named model, together with its lifecycle state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelVersion {
    /// Model name; together with `version` it forms the registry key.
    pub name: String,
    /// Version identifier, treated as an opaque string.
    pub version: String,
    /// Current lifecycle stage (`Stage::Dev` when first registered).
    pub stage: Stage,
    /// Free-form string key/value annotations (empty when first registered).
    pub metadata: HashMap<String, String>,
    /// UTC timestamp taken when the version was first registered.
    pub created_at: DateTime<Utc>,
    /// UTC timestamp of the most recent stage change — promotion *or* demotion;
    /// `None` until the first transition.
    pub promoted_at: Option<DateTime<Utc>>,
    /// Path string supplied at registration (presumably the artifact location — not verified here).
    pub path: String,
}
/// Errors produced by `StagingRegistry` operations.
#[derive(Debug, Error)]
pub enum StagingError {
    /// No entry is registered under the given name/version pair.
    #[error("model not found: {name} v{version}")]
    NotFound { name: String, version: String },

    /// The requested stage is not exactly one step forward (promotion) or one step
    /// back (demotion) from the model's current stage.
    #[error("invalid transition from {from} to {to} for {name} v{version}")]
    InvalidTransition {
        name: String,
        version: String,
        from: Stage,
        to: Stage,
    },

    /// An entry with the given name/version pair already exists.
    // NOTE(review): `StagingRegistry::register_model` returns the existing entry instead of
    // this error, so nothing in this module currently constructs it.
    #[error("model already exists: {name} v{version}")]
    AlreadyExists { name: String, version: String },
}

/// Shorthand for `std::result::Result` with [`StagingError`] as the error type.
pub type Result<T> = ::std::result::Result<T, StagingError>;
/// In-memory registry of model versions and their lifecycle stages.
#[derive(Debug, Default)]
pub struct StagingRegistry {
    /// Every registered version, keyed by `(name, version)`; each key mirrors the
    /// `name` and `version` fields of its value.
    models: HashMap<(String, String), ModelVersion>,
}
impl StagingRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self { models: HashMap::new() }
    }

    /// Registers `name`/`version` in `Stage::Dev` and returns a copy of the stored entry.
    ///
    /// Registration is idempotent: if the pair is already registered, the existing entry is
    /// returned unchanged (its stage, path, metadata and timestamps are kept) and `path` is
    /// ignored. The new entry, including its `created_at` timestamp, is only built when the
    /// pair is absent.
    pub fn register_model(&mut self, name: &str, version: &str, path: &str) -> ModelVersion {
        self.models
            .entry((name.to_string(), version.to_string()))
            .or_insert_with(|| ModelVersion {
                name: name.to_string(),
                version: version.to_string(),
                stage: Stage::Dev,
                metadata: HashMap::new(),
                created_at: Utc::now(),
                promoted_at: None,
                path: path.to_string(),
            })
            .clone()
    }

    /// Moves `name`/`version` one stage forward to `target` and returns the updated entry.
    ///
    /// On success `promoted_at` is set to the current UTC time.
    ///
    /// # Errors
    ///
    /// * `StagingError::NotFound` if the pair is not registered.
    /// * `StagingError::InvalidTransition` if `target` is not the stage immediately after
    ///   the current one (stages cannot be skipped, repeated, or reversed here).
    pub fn promote(&mut self, name: &str, version: &str, target: Stage) -> Result<ModelVersion> {
        self.transition(name, version, target, 1)
    }

    /// Moves `name`/`version` one stage back to `target` and returns the updated entry.
    ///
    /// On success `promoted_at` is also set to the current UTC time, so it records the most
    /// recent stage change in either direction.
    ///
    /// # Errors
    ///
    /// * `StagingError::NotFound` if the pair is not registered.
    /// * `StagingError::InvalidTransition` if `target` is not the stage immediately before
    ///   the current one (including any demotion attempted from `Stage::Dev`).
    pub fn demote(&mut self, name: &str, version: &str, target: Stage) -> Result<ModelVersion> {
        self.transition(name, version, target, -1)
    }

    /// Returns the most recently created version of `name` that is currently in `stage`.
    ///
    /// Versions with identical `created_at` timestamps are ranked by their `version` string,
    /// so the result does not depend on `HashMap` iteration order.
    pub fn get_latest(&self, name: &str, stage: Stage) -> Option<&ModelVersion> {
        self.models
            .values()
            .filter(|mv| mv.name == name && mv.stage == stage)
            .max_by(|a, b| Self::creation_order(a, b))
    }

    /// Lists every registered version of `name`, oldest first.
    ///
    /// Versions with identical `created_at` timestamps are ordered by their `version` string,
    /// so the output order is deterministic.
    pub fn list_versions(&self, name: &str) -> Vec<&ModelVersion> {
        let mut versions: Vec<&ModelVersion> =
            self.models.values().filter(|mv| mv.name == name).collect();
        // Keys are unique per (name, version), so this ordering is total and an unstable
        // sort yields a deterministic result.
        versions.sort_unstable_by(|a, b| Self::creation_order(a, b));
        versions
    }

    /// Shared body of `promote`/`demote`: moves the entry to `target`, which must be exactly
    /// `step` stages away from its current stage (`+1` promotes, `-1` demotes).
    fn transition(
        &mut self,
        name: &str,
        version: &str,
        target: Stage,
        step: i16,
    ) -> Result<ModelVersion> {
        let key = (name.to_string(), version.to_string());
        let mv = self.models.get_mut(&key).ok_or_else(|| StagingError::NotFound {
            name: name.to_string(),
            version: version.to_string(),
        })?;

        let from = mv.stage;
        // Signed difference avoids the u8 underflow a `current - 1` check would risk at Dev.
        let delta = i16::from(target.ordinal()) - i16::from(from.ordinal());
        if delta != step {
            return Err(StagingError::InvalidTransition {
                name: name.to_string(),
                version: version.to_string(),
                from,
                to: target,
            });
        }

        mv.stage = target;
        mv.promoted_at = Some(Utc::now());
        Ok(mv.clone())
    }

    /// Total order used for "latest"/"oldest": by `created_at`, then by `version` string.
    fn creation_order(a: &ModelVersion, b: &ModelVersion) -> std::cmp::Ordering {
        a.created_at
            .cmp(&b.created_at)
            .then_with(|| a.version.cmp(&b.version))
    }
}