use crate::engine::SynaDB;
use crate::error::Result;
use crate::types::Atom;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::path::Path;
/// Lifecycle stage of a stored model version.
///
/// Versions created by `ModelRegistry::save_model` start in [`ModelStage::Development`],
/// which is also the `Default`. `ModelRegistry::set_stage` moves a version between stages.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ModelStage {
/// Initial stage of every newly saved version (the `#[default]` variant).
#[default]
Development,
Staging,
/// The stage `ModelRegistry::get_production` searches for.
Production,
Archived,
}
impl std::fmt::Display for ModelStage {
    /// Writes the variant's name exactly as spelled in the enum, e.g. `Production`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match self {
            ModelStage::Development => "Development",
            ModelStage::Staging => "Staging",
            ModelStage::Production => "Production",
            ModelStage::Archived => "Archived",
        };
        f.write_str(label)
    }
}
/// Metadata describing one stored version of a model.
///
/// `ModelRegistry` stores this JSON-encoded (via serde) under the version's `/meta` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelVersion {
/// Name of the model this version belongs to.
pub name: String,
/// Version number; `ModelRegistry::save_model` assigns these starting at 1.
pub version: u32,
/// Creation time; `save_model` records whole seconds since the Unix epoch.
pub created_at: u64,
/// Lowercase hex SHA-256 of the model bytes; `load_model` re-verifies it.
pub checksum: String,
/// Length of the model bytes.
pub size_bytes: u64,
/// Free-form key/value pairs supplied by the caller.
pub metadata: HashMap<String, String>,
/// Current lifecycle stage; changed via `ModelRegistry::set_stage`.
pub stage: ModelStage,
}
impl ModelVersion {
pub fn new(
name: String,
version: u32,
created_at: u64,
checksum: String,
size_bytes: u64,
metadata: HashMap<String, String>,
stage: ModelStage,
) -> Self {
Self {
name,
version,
created_at,
checksum,
size_bytes,
metadata,
stage,
}
}
}
/// Versioned store for model binaries and their metadata, backed by a single [`SynaDB`].
///
/// Version `v` of model `name` occupies two keys: `model/{name}/v{v}/data` holds the
/// raw bytes, and `model/{name}/v{v}/meta` holds the JSON-encoded [`ModelVersion`].
pub struct ModelRegistry {
// Underlying storage; the registry keeps no state outside this database.
db: SynaDB,
}
impl ModelRegistry {
pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
let db = SynaDB::new(path)?;
Ok(Self { db })
}
pub fn db(&self) -> &SynaDB {
&self.db
}
pub fn db_mut(&mut self) -> &mut SynaDB {
&mut self.db
}
pub fn save_model(
&mut self,
name: &str,
data: &[u8],
metadata: HashMap<String, String>,
) -> Result<ModelVersion> {
let mut hasher = Sha256::new();
hasher.update(data);
let checksum = format!("{:x}", hasher.finalize());
let version = self.get_next_version(name);
let model_version = ModelVersion {
name: name.to_string(),
version,
created_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
checksum: checksum.clone(),
size_bytes: data.len() as u64,
metadata,
stage: ModelStage::Development,
};
let data_key = format!("model/{}/v{}/data", name, version);
self.db.append(&data_key, Atom::Bytes(data.to_vec()))?;
let meta_key = format!("model/{}/v{}/meta", name, version);
let meta_json = serde_json::to_string(&model_version)
.map_err(|e| crate::error::SynaError::InvalidPath(e.to_string()))?;
self.db.append(&meta_key, Atom::Text(meta_json))?;
Ok(model_version)
}
fn get_next_version(&self, name: &str) -> u32 {
let prefix = format!("model/{}/v", name);
let versions: Vec<u32> = self
.db
.keys()
.iter()
.filter(|k| k.starts_with(&prefix) && k.ends_with("/meta"))
.filter_map(|k| k.strip_prefix(&prefix)?.strip_suffix("/meta")?.parse().ok())
.collect();
versions.into_iter().max().unwrap_or(0) + 1
}
fn get_latest_version(&self, name: &str) -> Result<u32> {
let prefix = format!("model/{}/v", name);
self.db
.keys()
.iter()
.filter(|k| k.starts_with(&prefix) && k.ends_with("/meta"))
.filter_map(|k| k.strip_prefix(&prefix)?.strip_suffix("/meta")?.parse().ok())
.max()
.ok_or_else(|| crate::error::SynaError::ModelNotFound(name.to_string()))
}
pub fn load_model(
&mut self,
name: &str,
version: Option<u32>,
) -> Result<(Vec<u8>, ModelVersion)> {
let v = match version {
Some(v) => v,
None => self.get_latest_version(name)?,
};
let meta_key = format!("model/{}/v{}/meta", name, v);
let meta_json = match self.db.get(&meta_key)? {
Some(Atom::Text(s)) => s,
_ => return Err(crate::error::SynaError::ModelNotFound(name.to_string())),
};
let model_version: ModelVersion = serde_json::from_str(&meta_json)
.map_err(|e| crate::error::SynaError::InvalidPath(e.to_string()))?;
let data_key = format!("model/{}/v{}/data", name, v);
let data = match self.db.get(&data_key)? {
Some(Atom::Bytes(b)) => b,
_ => return Err(crate::error::SynaError::ModelNotFound(name.to_string())),
};
let mut hasher = Sha256::new();
hasher.update(&data);
let computed = format!("{:x}", hasher.finalize());
if computed != model_version.checksum {
return Err(crate::error::SynaError::ChecksumMismatch {
expected: model_version.checksum.clone(),
got: computed,
});
}
Ok((data, model_version))
}
pub fn list_versions(&mut self, name: &str) -> Result<Vec<ModelVersion>> {
let prefix = format!("model/{}/v", name);
let mut versions = Vec::new();
for key in self.db.keys() {
if key.starts_with(&prefix) && key.ends_with("/meta") {
if let Some(Atom::Text(json)) = self.db.get(&key)? {
let v: ModelVersion = serde_json::from_str(&json)
.map_err(|e| crate::error::SynaError::InvalidPath(e.to_string()))?;
versions.push(v);
}
}
}
versions.sort_by_key(|v| v.version);
Ok(versions)
}
pub fn set_stage(&mut self, name: &str, version: u32, stage: ModelStage) -> Result<()> {
let meta_key = format!("model/{}/v{}/meta", name, version);
let meta_json = match self.db.get(&meta_key)? {
Some(Atom::Text(s)) => s,
_ => return Err(crate::error::SynaError::ModelNotFound(name.to_string())),
};
let mut model_version: ModelVersion = serde_json::from_str(&meta_json)
.map_err(|e| crate::error::SynaError::InvalidPath(e.to_string()))?;
model_version.stage = stage;
let updated_json = serde_json::to_string(&model_version)
.map_err(|e| crate::error::SynaError::InvalidPath(e.to_string()))?;
self.db.append(&meta_key, Atom::Text(updated_json))?;
Ok(())
}
pub fn get_production(&mut self, name: &str) -> Result<Option<ModelVersion>> {
let versions = self.list_versions(name)?;
Ok(versions
.into_iter()
.find(|v| v.stage == ModelStage::Production))
}
}
#[cfg(test)]
mod tests {
    // Unit tests for `ModelStage`, `ModelVersion` and the `ModelRegistry` API. Each
    // registry test uses its own temporary directory, so the tests are independent.
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_model_stage_default() {
        let stage = ModelStage::default();
        assert_eq!(stage, ModelStage::Development);
    }

    #[test]
    fn test_model_stage_display() {
        assert_eq!(format!("{}", ModelStage::Development), "Development");
        assert_eq!(format!("{}", ModelStage::Staging), "Staging");
        assert_eq!(format!("{}", ModelStage::Production), "Production");
        assert_eq!(format!("{}", ModelStage::Archived), "Archived");
    }

    #[test]
    fn test_model_version_new() {
        let mut metadata = HashMap::new();
        metadata.insert("accuracy".to_string(), "0.95".to_string());
        let version = ModelVersion::new(
            "test_model".to_string(),
            1,
            1234567890,
            "abc123".to_string(),
            1024,
            metadata.clone(),
            ModelStage::Development,
        );
        assert_eq!(version.name, "test_model");
        assert_eq!(version.version, 1);
        assert_eq!(version.created_at, 1234567890);
        assert_eq!(version.checksum, "abc123");
        assert_eq!(version.size_bytes, 1024);
        assert_eq!(version.metadata.get("accuracy"), Some(&"0.95".to_string()));
        assert_eq!(version.stage, ModelStage::Development);
    }

    #[test]
    fn test_model_registry_new() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_registry.db");
        let registry = ModelRegistry::new(&db_path);
        assert!(registry.is_ok());
    }

    #[test]
    fn test_model_stage_serialization() {
        let stage = ModelStage::Production;
        let serialized = serde_json::to_string(&stage).unwrap();
        let deserialized: ModelStage = serde_json::from_str(&serialized).unwrap();
        assert_eq!(stage, deserialized);
    }

    #[test]
    fn test_model_version_serialization() {
        let mut metadata = HashMap::new();
        metadata.insert("key".to_string(), "value".to_string());
        let version = ModelVersion::new(
            "model".to_string(),
            1,
            1000,
            "checksum".to_string(),
            512,
            metadata,
            ModelStage::Staging,
        );
        let serialized = serde_json::to_string(&version).unwrap();
        let deserialized: ModelVersion = serde_json::from_str(&serialized).unwrap();
        assert_eq!(version.name, deserialized.name);
        assert_eq!(version.version, deserialized.version);
        assert_eq!(version.stage, deserialized.stage);
        // Every field must survive the round trip, not just the identifying ones.
        assert_eq!(version.created_at, deserialized.created_at);
        assert_eq!(version.checksum, deserialized.checksum);
        assert_eq!(version.size_bytes, deserialized.size_bytes);
        assert_eq!(version.metadata, deserialized.metadata);
    }

    #[test]
    fn test_save_model_basic() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_save.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let model_data = vec![0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let mut metadata = HashMap::new();
        metadata.insert("accuracy".to_string(), "0.95".to_string());
        let version = registry
            .save_model("test_model", &model_data, metadata)
            .unwrap();
        assert_eq!(version.name, "test_model");
        assert_eq!(version.version, 1);
        assert_eq!(version.size_bytes, 10);
        assert_eq!(version.stage, ModelStage::Development);
        assert!(!version.checksum.is_empty());
        assert_eq!(version.metadata.get("accuracy"), Some(&"0.95".to_string()));
    }

    #[test]
    fn test_save_model_auto_versioning() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_versioning.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let model_data = vec![1u8, 2, 3];
        let metadata = HashMap::new();
        let v1 = registry
            .save_model("model", &model_data, metadata.clone())
            .unwrap();
        assert_eq!(v1.version, 1);
        let v2 = registry
            .save_model("model", &model_data, metadata.clone())
            .unwrap();
        assert_eq!(v2.version, 2);
        let v3 = registry.save_model("model", &model_data, metadata).unwrap();
        assert_eq!(v3.version, 3);
    }

    #[test]
    fn test_save_model_checksum_consistency() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_checksum.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let model_data = vec![0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let metadata = HashMap::new();
        let v1 = registry
            .save_model("model", &model_data, metadata.clone())
            .unwrap();
        let v2 = registry.save_model("model", &model_data, metadata).unwrap();
        assert_eq!(v1.checksum, v2.checksum);
    }

    #[test]
    fn test_save_model_different_data_different_checksum() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_diff_checksum.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let model_data1 = vec![0u8, 1, 2, 3];
        let model_data2 = vec![4u8, 5, 6, 7];
        let metadata = HashMap::new();
        let v1 = registry
            .save_model("model", &model_data1, metadata.clone())
            .unwrap();
        let v2 = registry
            .save_model("model", &model_data2, metadata)
            .unwrap();
        assert_ne!(v1.checksum, v2.checksum);
    }

    #[test]
    fn test_save_model_multiple_models() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_multi_model.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let data = vec![1u8, 2, 3];
        let metadata = HashMap::new();
        let v1 = registry
            .save_model("model_a", &data, metadata.clone())
            .unwrap();
        let v2 = registry
            .save_model("model_b", &data, metadata.clone())
            .unwrap();
        let v3 = registry.save_model("model_a", &data, metadata).unwrap();
        assert_eq!(v1.version, 1);
        assert_eq!(v2.version, 1);
        assert_eq!(v3.version, 2);
    }

    #[test]
    fn test_save_model_empty_data() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_empty.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let model_data: Vec<u8> = vec![];
        let metadata = HashMap::new();
        let version = registry
            .save_model("empty_model", &model_data, metadata)
            .unwrap();
        assert_eq!(version.size_bytes, 0);
        assert!(!version.checksum.is_empty());
        // The SHA-256 of the empty input is a well-known constant; pin it so that a
        // change in digest algorithm or hex formatting is caught.
        assert_eq!(
            version.checksum,
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }

    #[test]
    fn test_load_model_basic() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_load.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let model_data = vec![0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let mut metadata = HashMap::new();
        metadata.insert("accuracy".to_string(), "0.95".to_string());
        let saved_version = registry
            .save_model("test_model", &model_data, metadata)
            .unwrap();
        let (loaded_data, loaded_version) = registry.load_model("test_model", None).unwrap();
        assert_eq!(loaded_data, model_data);
        assert_eq!(loaded_version.name, saved_version.name);
        assert_eq!(loaded_version.version, saved_version.version);
        assert_eq!(loaded_version.checksum, saved_version.checksum);
        assert_eq!(loaded_version.size_bytes, saved_version.size_bytes);
    }

    #[test]
    fn test_load_model_specific_version() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_load_version.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let data_v1 = vec![1u8, 2, 3];
        let data_v2 = vec![4u8, 5, 6, 7];
        let metadata = HashMap::new();
        registry
            .save_model("model", &data_v1, metadata.clone())
            .unwrap();
        registry.save_model("model", &data_v2, metadata).unwrap();
        let (loaded_data, loaded_version) = registry.load_model("model", Some(1)).unwrap();
        assert_eq!(loaded_data, data_v1);
        assert_eq!(loaded_version.version, 1);
        let (loaded_data, loaded_version) = registry.load_model("model", Some(2)).unwrap();
        assert_eq!(loaded_data, data_v2);
        assert_eq!(loaded_version.version, 2);
    }

    #[test]
    fn test_load_model_latest_version() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_load_latest.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let data_v1 = vec![1u8, 2, 3];
        let data_v2 = vec![4u8, 5, 6, 7];
        let data_v3 = vec![8u8, 9, 10, 11, 12];
        let metadata = HashMap::new();
        registry
            .save_model("model", &data_v1, metadata.clone())
            .unwrap();
        registry
            .save_model("model", &data_v2, metadata.clone())
            .unwrap();
        registry.save_model("model", &data_v3, metadata).unwrap();
        let (loaded_data, loaded_version) = registry.load_model("model", None).unwrap();
        assert_eq!(loaded_data, data_v3);
        assert_eq!(loaded_version.version, 3);
    }

    #[test]
    fn test_load_model_not_found() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_not_found.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let result = registry.load_model("nonexistent", None);
        assert!(result.is_err());
    }

    #[test]
    fn test_load_model_version_not_found() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_version_not_found.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let model_data = vec![1u8, 2, 3];
        let metadata = HashMap::new();
        registry.save_model("model", &model_data, metadata).unwrap();
        let result = registry.load_model("model", Some(999));
        assert!(result.is_err());
    }

    #[test]
    fn test_load_model_checksum_verification() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_checksum_verify.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let model_data = vec![0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let metadata = HashMap::new();
        let saved_version = registry
            .save_model("test_model", &model_data, metadata)
            .unwrap();
        let (loaded_data, loaded_version) = registry.load_model("test_model", None).unwrap();
        let mut hasher = Sha256::new();
        hasher.update(&loaded_data);
        let computed_checksum = format!("{:x}", hasher.finalize());
        assert_eq!(computed_checksum, saved_version.checksum);
        assert_eq!(computed_checksum, loaded_version.checksum);
    }

    #[test]
    fn test_load_model_empty_data() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_load_empty.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let model_data: Vec<u8> = vec![];
        let metadata = HashMap::new();
        registry
            .save_model("empty_model", &model_data, metadata)
            .unwrap();
        let (loaded_data, loaded_version) = registry.load_model("empty_model", None).unwrap();
        assert!(loaded_data.is_empty());
        assert_eq!(loaded_version.size_bytes, 0);
    }

    #[test]
    fn test_load_model_preserves_metadata() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_load_metadata.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let model_data = vec![1u8, 2, 3];
        let mut metadata = HashMap::new();
        metadata.insert("accuracy".to_string(), "0.95".to_string());
        metadata.insert("framework".to_string(), "pytorch".to_string());
        metadata.insert("description".to_string(), "Test model".to_string());
        registry
            .save_model("model", &model_data, metadata.clone())
            .unwrap();
        let (_, loaded_version) = registry.load_model("model", None).unwrap();
        assert_eq!(
            loaded_version.metadata.get("accuracy"),
            Some(&"0.95".to_string())
        );
        assert_eq!(
            loaded_version.metadata.get("framework"),
            Some(&"pytorch".to_string())
        );
        assert_eq!(
            loaded_version.metadata.get("description"),
            Some(&"Test model".to_string())
        );
    }

    #[test]
    fn test_list_versions_empty() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_list_empty.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let versions = registry.list_versions("nonexistent").unwrap();
        assert!(versions.is_empty());
    }

    #[test]
    fn test_list_versions_single() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_list_single.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let data = vec![1u8, 2, 3];
        registry.save_model("model", &data, HashMap::new()).unwrap();
        let versions = registry.list_versions("model").unwrap();
        assert_eq!(versions.len(), 1);
        assert_eq!(versions[0].version, 1);
        assert_eq!(versions[0].name, "model");
    }

    #[test]
    fn test_list_versions_multiple() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_list_multiple.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let data = vec![1u8, 2, 3];
        registry.save_model("model", &data, HashMap::new()).unwrap();
        registry.save_model("model", &data, HashMap::new()).unwrap();
        registry.save_model("model", &data, HashMap::new()).unwrap();
        let versions = registry.list_versions("model").unwrap();
        assert_eq!(versions.len(), 3);
        assert_eq!(versions[0].version, 1);
        assert_eq!(versions[1].version, 2);
        assert_eq!(versions[2].version, 3);
    }

    #[test]
    fn test_list_versions_sorted() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_list_sorted.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let data = vec![1u8, 2, 3];
        for _ in 0..5 {
            registry.save_model("model", &data, HashMap::new()).unwrap();
        }
        let versions = registry.list_versions("model").unwrap();
        for i in 1..versions.len() {
            assert!(versions[i - 1].version < versions[i].version);
        }
    }

    #[test]
    fn test_list_versions_multiple_models() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_list_multi_model.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let data = vec![1u8, 2, 3];
        registry
            .save_model("model_a", &data, HashMap::new())
            .unwrap();
        registry
            .save_model("model_a", &data, HashMap::new())
            .unwrap();
        registry
            .save_model("model_b", &data, HashMap::new())
            .unwrap();
        let versions_a = registry.list_versions("model_a").unwrap();
        assert_eq!(versions_a.len(), 2);
        assert!(versions_a.iter().all(|v| v.name == "model_a"));
        let versions_b = registry.list_versions("model_b").unwrap();
        assert_eq!(versions_b.len(), 1);
        assert!(versions_b.iter().all(|v| v.name == "model_b"));
    }

    #[test]
    fn test_set_stage_basic() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_set_stage.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let data = vec![1u8, 2, 3];
        let version = registry.save_model("model", &data, HashMap::new()).unwrap();
        assert_eq!(version.stage, ModelStage::Development);
        registry.set_stage("model", 1, ModelStage::Staging).unwrap();
        let (_, loaded) = registry.load_model("model", Some(1)).unwrap();
        assert_eq!(loaded.stage, ModelStage::Staging);
    }

    #[test]
    fn test_set_stage_to_production() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_set_stage_prod.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let data = vec![1u8, 2, 3];
        registry.save_model("model", &data, HashMap::new()).unwrap();
        registry
            .set_stage("model", 1, ModelStage::Production)
            .unwrap();
        let (_, loaded) = registry.load_model("model", Some(1)).unwrap();
        assert_eq!(loaded.stage, ModelStage::Production);
    }

    #[test]
    fn test_set_stage_to_archived() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_set_stage_archived.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let data = vec![1u8, 2, 3];
        registry.save_model("model", &data, HashMap::new()).unwrap();
        registry
            .set_stage("model", 1, ModelStage::Archived)
            .unwrap();
        let (_, loaded) = registry.load_model("model", Some(1)).unwrap();
        assert_eq!(loaded.stage, ModelStage::Archived);
    }

    #[test]
    fn test_set_stage_model_not_found() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_set_stage_not_found.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let result = registry.set_stage("nonexistent", 1, ModelStage::Production);
        assert!(result.is_err());
    }

    #[test]
    fn test_set_stage_version_not_found() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_set_stage_version_not_found.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let data = vec![1u8, 2, 3];
        registry.save_model("model", &data, HashMap::new()).unwrap();
        let result = registry.set_stage("model", 999, ModelStage::Production);
        assert!(result.is_err());
    }

    #[test]
    fn test_set_stage_preserves_other_metadata() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_set_stage_preserves.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let data = vec![1u8, 2, 3];
        let mut metadata = HashMap::new();
        metadata.insert("accuracy".to_string(), "0.95".to_string());
        metadata.insert("framework".to_string(), "pytorch".to_string());
        let original = registry.save_model("model", &data, metadata).unwrap();
        registry
            .set_stage("model", 1, ModelStage::Production)
            .unwrap();
        let (_, loaded) = registry.load_model("model", Some(1)).unwrap();
        assert_eq!(loaded.stage, ModelStage::Production);
        assert_eq!(loaded.name, original.name);
        assert_eq!(loaded.version, original.version);
        assert_eq!(loaded.checksum, original.checksum);
        assert_eq!(loaded.size_bytes, original.size_bytes);
        assert_eq!(loaded.metadata.get("accuracy"), Some(&"0.95".to_string()));
        assert_eq!(
            loaded.metadata.get("framework"),
            Some(&"pytorch".to_string())
        );
    }

    #[test]
    fn test_set_stage_multiple_versions() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_set_stage_multi.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let data = vec![1u8, 2, 3];
        registry.save_model("model", &data, HashMap::new()).unwrap();
        registry.save_model("model", &data, HashMap::new()).unwrap();
        registry.save_model("model", &data, HashMap::new()).unwrap();
        registry
            .set_stage("model", 1, ModelStage::Archived)
            .unwrap();
        registry.set_stage("model", 2, ModelStage::Staging).unwrap();
        registry
            .set_stage("model", 3, ModelStage::Production)
            .unwrap();
        let (_, v1) = registry.load_model("model", Some(1)).unwrap();
        let (_, v2) = registry.load_model("model", Some(2)).unwrap();
        let (_, v3) = registry.load_model("model", Some(3)).unwrap();
        assert_eq!(v1.stage, ModelStage::Archived);
        assert_eq!(v2.stage, ModelStage::Staging);
        assert_eq!(v3.stage, ModelStage::Production);
    }

    #[test]
    fn test_get_production_none() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_get_prod_none.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let data = vec![1u8, 2, 3];
        registry.save_model("model", &data, HashMap::new()).unwrap();
        let prod = registry.get_production("model").unwrap();
        assert!(prod.is_none());
    }

    #[test]
    fn test_get_production_exists() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_get_prod_exists.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let data = vec![1u8, 2, 3];
        registry.save_model("model", &data, HashMap::new()).unwrap();
        registry
            .set_stage("model", 1, ModelStage::Production)
            .unwrap();
        let prod = registry.get_production("model").unwrap();
        assert!(prod.is_some());
        assert_eq!(prod.unwrap().version, 1);
    }

    #[test]
    fn test_get_production_multiple_versions() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_get_prod_multi.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let data = vec![1u8, 2, 3];
        registry.save_model("model", &data, HashMap::new()).unwrap();
        registry.save_model("model", &data, HashMap::new()).unwrap();
        registry.save_model("model", &data, HashMap::new()).unwrap();
        registry
            .set_stage("model", 2, ModelStage::Production)
            .unwrap();
        let prod = registry.get_production("model").unwrap();
        assert!(prod.is_some());
        assert_eq!(prod.unwrap().version, 2);
    }

    #[test]
    fn test_get_production_nonexistent_model() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_get_prod_nonexistent.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let prod = registry.get_production("nonexistent").unwrap();
        assert!(prod.is_none());
    }

    #[test]
    fn test_get_production_after_stage_change() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_get_prod_change.db");
        let mut registry = ModelRegistry::new(&db_path).unwrap();
        let data = vec![1u8, 2, 3];
        registry.save_model("model", &data, HashMap::new()).unwrap();
        registry.save_model("model", &data, HashMap::new()).unwrap();
        registry
            .set_stage("model", 1, ModelStage::Production)
            .unwrap();
        let prod = registry.get_production("model").unwrap();
        assert_eq!(prod.unwrap().version, 1);
        registry
            .set_stage("model", 1, ModelStage::Archived)
            .unwrap();
        registry
            .set_stage("model", 2, ModelStage::Production)
            .unwrap();
        let prod = registry.get_production("model").unwrap();
        assert_eq!(prod.unwrap().version, 2);
    }
}