use crate::core::error::{Error, Result};
use crate::ml::serving::serialization::{
BinaryModelSerializer, JsonModelSerializer, ModelSerializationFactory, SerializableModel,
TomlModelSerializer, YamlModelSerializer,
};
use crate::ml::serving::{ModelMetadata, ModelSerializer, ModelServing, SerializationFormat};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Arc;
/// Common interface for stores that keep versioned models addressable by name.
///
/// The per-method notes describe what the two implementations in this file
/// (`InMemoryModelRegistry` and `FileSystemModelRegistry`) do; other
/// implementors may behave differently.
pub trait ModelRegistry {
/// Registers `model` under the name and version reported by its metadata.
/// Both implementations in this file return `Error::InvalidOperation` when
/// that name/version pair is already present.
fn register_model(&mut self, model: Box<dyn ModelServing>) -> Result<()>;
/// Returns a shared handle to the model stored as `name`/`version`.
/// Both implementations in this file treat the version strings `"latest"`
/// and `"default"` as aliases and fail with `Error::KeyNotFound` when the
/// model cannot be found.
fn load_model(&self, name: &str, version: &str) -> Result<Arc<dyn ModelServing>>;
/// Returns one summary entry per registered model name.
fn list_models(&self) -> Result<Vec<ModelRegistryEntry>>;
/// Returns the versions recorded for `name`, or `Error::KeyNotFound` in the
/// implementations here when the name is unknown.
fn list_versions(&self, name: &str) -> Result<Vec<String>>;
/// Returns the metadata of `name`/`version`. Both implementations in this
/// file accept the `"latest"` and `"default"` aliases here as well.
fn get_metadata(&self, name: &str, version: &str) -> Result<ModelMetadata>;
/// Removes a single version of a model. In the implementations here the
/// version string is used literally (no alias resolution) and the whole
/// entry disappears once its last version is deleted.
fn delete_model(&mut self, name: &str, version: &str) -> Result<()>;
/// Replaces the metadata stored for `name`/`version` with `metadata`.
fn update_metadata(&mut self, name: &str, version: &str, metadata: ModelMetadata)
-> Result<()>;
/// Reports whether `name`/`version` is currently registered.
fn exists(&self, name: &str, version: &str) -> bool;
/// Returns the version recorded as latest for `name`; in this file that is
/// the last element of the entry's sorted version list.
fn get_latest_version(&self, name: &str) -> Result<String>;
/// Marks `version` as the default for `name`. Both implementations in this
/// file return `Error::KeyNotFound` when that version does not exist.
fn set_default_version(&mut self, name: &str, version: &str) -> Result<()>;
/// Returns the version currently marked as default for `name`.
fn get_default_version(&self, name: &str) -> Result<String>;
}
/// Summary record kept per model name, independent of the model payloads.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRegistryEntry {
/// Model name; the registries in this file also use it as the map key.
pub name: String,
/// Known versions of the model, kept sorted by the registries in this file.
pub versions: Vec<String>,
/// Version served for the `"default"` alias, if one has been chosen.
pub default_version: Option<String>,
/// Version served for the `"latest"` alias (last element of `versions`).
pub latest_version: Option<String>,
/// Description copied from the model metadata when the entry is created.
pub description: String,
/// Free-form tags; the registries in this file create entries with none.
pub tags: Vec<String>,
/// When the entry was first created (UTC).
pub created_at: chrono::DateTime<chrono::Utc>,
/// When the entry was last modified (UTC).
pub updated_at: chrono::DateTime<chrono::Utc>,
}
/// Registry that keeps every model and its bookkeeping in process memory.
pub struct InMemoryModelRegistry {
// Model name -> version -> shared model handle.
models: HashMap<String, HashMap<String, Arc<dyn ModelServing>>>,
// Model name -> summary entry returned by `list_models`.
entries: HashMap<String, ModelRegistryEntry>,
// Model name -> default version; kept alongside `entries[..].default_version`.
default_versions: HashMap<String, String>,
}
impl InMemoryModelRegistry {
    /// Creates an empty registry holding no models, entries, or default versions.
    pub fn new() -> Self {
        Self {
            models: HashMap::new(),
            entries: HashMap::new(),
            default_versions: HashMap::new(),
        }
    }

    /// Formats the composite `"name:version"` key for a model version.
    fn get_model_key(name: &str, version: &str) -> String {
        format!("{}:{}", name, version)
    }

    /// Orders two version strings segment by segment (segments are split on `.`).
    ///
    /// Segments that both parse as unsigned integers are compared numerically,
    /// so `"1.10.0"` sorts after `"1.9.0"`; a plain string sort would put it
    /// before. A numeric segment sorts before a non-numeric one, two
    /// non-numeric segments are compared as strings, and a version that is a
    /// strict prefix of another sorts first. Versions whose segments all
    /// compare equal (e.g. `"01"` vs `"1"`) fall back to a full string
    /// comparison so the ordering stays total.
    fn compare_versions(left: &str, right: &str) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        let mut left_segments = left.split('.');
        let mut right_segments = right.split('.');
        loop {
            let ordering = match (left_segments.next(), right_segments.next()) {
                (None, None) => return left.cmp(right),
                (None, Some(_)) => return Ordering::Less,
                (Some(_), None) => return Ordering::Greater,
                (Some(a), Some(b)) => match (a.parse::<u64>(), b.parse::<u64>()) {
                    (Ok(x), Ok(y)) => x.cmp(&y),
                    (Ok(_), Err(_)) => Ordering::Less,
                    (Err(_), Ok(_)) => Ordering::Greater,
                    (Err(_), Err(_)) => a.cmp(b),
                },
            };
            if ordering != Ordering::Equal {
                return ordering;
            }
        }
    }

    /// Creates or refreshes the summary entry for `name` after `version` was
    /// stored.
    ///
    /// * A missing entry is created with the description from `metadata`.
    /// * `version` is added to the version list if it is not there yet and the
    ///   list is kept in version order (see `compare_versions`).
    /// * `latest_version` becomes the last element of that list.
    /// * If the entry has no default version yet, `version` becomes the
    ///   default, both on the entry and in `default_versions`.
    /// * `updated_at` is set to the current time.
    ///
    /// NOTE(review): the description of an existing entry is not refreshed
    /// from `metadata` — confirm whether that is intended.
    fn update_entry(&mut self, name: &str, version: &str, metadata: &ModelMetadata) {
        let entry = self.entries.entry(name.to_string()).or_insert_with(|| {
            let now = chrono::Utc::now();
            ModelRegistryEntry {
                name: name.to_string(),
                versions: Vec::new(),
                default_version: None,
                latest_version: None,
                description: metadata.description.clone(),
                tags: Vec::new(),
                created_at: now,
                updated_at: now,
            }
        });

        if !entry.versions.iter().any(|known| known == version) {
            entry.versions.push(version.to_string());
            // Numeric-aware ordering: "latest" must not be decided by a plain
            // string comparison ("1.10.0" < "1.9.0" lexicographically).
            entry
                .versions
                .sort_by(|a, b| Self::compare_versions(a, b));
        }
        entry.latest_version = entry.versions.last().cloned();

        if entry.default_version.is_none() {
            entry.default_version = Some(version.to_string());
            self.default_versions
                .insert(name.to_string(), version.to_string());
        }
        entry.updated_at = chrono::Utc::now();
    }
}
impl Default for InMemoryModelRegistry {
fn default() -> Self {
Self::new()
}
}
impl ModelRegistry for InMemoryModelRegistry {
fn register_model(&mut self, model: Box<dyn ModelServing>) -> Result<()> {
let metadata = model.get_metadata().clone(); let name = metadata.name.clone();
let version = metadata.version.clone();
if self.exists(&name, &version) {
return Err(Error::InvalidOperation(format!(
"Model '{}' version '{}' already exists",
name, version
)));
}
let arc_model: Arc<dyn ModelServing> = Arc::from(model);
self.models
.entry(name.clone())
.or_insert_with(HashMap::new)
.insert(version.clone(), arc_model);
self.update_entry(&name, &version, &metadata);
Ok(())
}
fn load_model(&self, name: &str, version: &str) -> Result<Arc<dyn ModelServing>> {
let resolved_version = if version == "latest" {
self.get_latest_version(name)?
} else if version == "default" {
self.get_default_version(name)?
} else {
version.to_string()
};
self.models
.get(name)
.and_then(|versions| versions.get(&resolved_version))
.map(|arc_model| Arc::clone(arc_model))
.ok_or_else(|| {
Error::KeyNotFound(format!(
"Model '{}' version '{}' not found",
name, resolved_version
))
})
}
fn list_models(&self) -> Result<Vec<ModelRegistryEntry>> {
Ok(self.entries.values().cloned().collect())
}
fn list_versions(&self, name: &str) -> Result<Vec<String>> {
self.entries
.get(name)
.map(|entry| entry.versions.clone())
.ok_or_else(|| Error::KeyNotFound(format!("Model '{}' not found", name)))
}
fn get_metadata(&self, name: &str, version: &str) -> Result<ModelMetadata> {
let resolved_version = if version == "latest" {
self.get_latest_version(name)?
} else if version == "default" {
self.get_default_version(name)?
} else {
version.to_string()
};
self.models
.get(name)
.and_then(|versions| versions.get(&resolved_version))
.map(|model| model.get_metadata().clone())
.ok_or_else(|| {
Error::KeyNotFound(format!(
"Model '{}' version '{}' not found",
name, resolved_version
))
})
}
fn delete_model(&mut self, name: &str, version: &str) -> Result<()> {
if let Some(versions) = self.models.get_mut(name) {
if versions.remove(version).is_some() {
if let Some(entry) = self.entries.get_mut(name) {
entry.versions.retain(|v| v != version);
entry.latest_version = entry.versions.last().cloned();
if entry.default_version.as_ref() == Some(&version.to_string()) {
entry.default_version = entry.versions.first().cloned();
if let Some(new_default) = &entry.default_version {
self.default_versions
.insert(name.to_string(), new_default.clone());
} else {
self.default_versions.remove(name);
}
}
if entry.versions.is_empty() {
self.entries.remove(name);
self.models.remove(name);
self.default_versions.remove(name);
}
}
Ok(())
} else {
Err(Error::KeyNotFound(format!(
"Model '{}' version '{}' not found",
name, version
)))
}
} else {
Err(Error::KeyNotFound(format!("Model '{}' not found", name)))
}
}
fn update_metadata(
&mut self,
name: &str,
version: &str,
new_metadata: ModelMetadata,
) -> Result<()> {
let existing_arc = self
.models
.get(name)
.and_then(|versions| versions.get(version))
.cloned()
.ok_or_else(|| {
Error::KeyNotFound(format!("Model '{}' version '{}' not found", name, version))
})?;
use crate::ml::serving::serialization::GenericServingModel;
let mut serializable = SerializableModel {
metadata: existing_arc.get_metadata().clone(),
parameters: std::collections::HashMap::new(),
model_data: serde_json::json!({}),
preprocessing: None,
config: existing_arc.info().configuration,
};
serializable.metadata = new_metadata.clone();
let rebuilt: Arc<dyn ModelServing> =
Arc::new(GenericServingModel::from_serializable(serializable)?);
if let Some(versions) = self.models.get_mut(name) {
versions.insert(version.to_string(), rebuilt);
}
self.update_entry(name, version, &new_metadata);
Ok(())
}
fn exists(&self, name: &str, version: &str) -> bool {
self.models
.get(name)
.map(|versions| versions.contains_key(version))
.unwrap_or(false)
}
fn get_latest_version(&self, name: &str) -> Result<String> {
self.entries
.get(name)
.and_then(|entry| entry.latest_version.clone())
.ok_or_else(|| Error::KeyNotFound(format!("Model '{}' not found", name)))
}
fn set_default_version(&mut self, name: &str, version: &str) -> Result<()> {
if !self.exists(name, version) {
return Err(Error::KeyNotFound(format!(
"Model '{}' version '{}' not found",
name, version
)));
}
self.default_versions
.insert(name.to_string(), version.to_string());
if let Some(entry) = self.entries.get_mut(name) {
entry.default_version = Some(version.to_string());
entry.updated_at = chrono::Utc::now();
}
Ok(())
}
fn get_default_version(&self, name: &str) -> Result<String> {
self.default_versions
.get(name)
.cloned()
.ok_or_else(|| Error::KeyNotFound(format!("Model '{}' not found", name)))
}
}
/// Registry that stores each model version as a file under a base directory
/// and keeps the summary entries in a JSON index file.
pub struct FileSystemModelRegistry {
// Root directory; each model name gets a sub-directory below it.
base_path: PathBuf,
// Path of the JSON index (`registry.json` inside `base_path`).
registry_file: PathBuf,
// Model name -> summary entry, loaded from and saved to `registry_file`.
entries: HashMap<String, ModelRegistryEntry>,
// Format used when saving models and building model file names.
default_format: SerializationFormat,
}
impl FileSystemModelRegistry {
    /// Opens (or initialises) a registry rooted at `base_path`.
    ///
    /// The directory is created if it does not exist, the default format is
    /// JSON, and an existing `registry.json` in that directory is loaded.
    ///
    /// # Errors
    /// Returns an error if the directory cannot be created or the index file
    /// cannot be read or parsed.
    pub fn new<P: AsRef<Path>>(base_path: P) -> Result<Self> {
        let base_path = base_path.as_ref().to_path_buf();
        let registry_file = base_path.join("registry.json");
        if !base_path.exists() {
            fs::create_dir_all(&base_path)?;
        }

        let mut registry = Self {
            base_path,
            registry_file,
            entries: HashMap::new(),
            default_format: SerializationFormat::Json,
        };
        registry.load_registry()?;
        Ok(registry)
    }

    /// Sets the format used for models saved from now on.
    pub fn set_default_format(&mut self, format: SerializationFormat) {
        self.default_format = format;
    }

    /// Returns the directory that holds all versions of `name`.
    fn get_model_dir(&self, name: &str) -> PathBuf {
        self.base_path.join(name)
    }

    /// Returns the path of the file holding `name`/`version`.
    ///
    /// The path for the current default format is preferred. If that file does
    /// not exist but the version was saved earlier in another known format
    /// (for example before `set_default_format` was called), the existing
    /// file is returned instead, so previously registered models stay
    /// reachable. When no file exists at all, the default-format path is
    /// returned so callers can create it.
    fn get_model_file(&self, name: &str, version: &str) -> PathBuf {
        let model_dir = self.get_model_dir(name);
        let path_for = |candidate_format: SerializationFormat| {
            model_dir.join(format!("{}.{}", version, candidate_format.extension()))
        };

        let preferred = path_for(self.default_format);
        if preferred.exists() {
            return preferred;
        }

        let known_formats = [
            SerializationFormat::Json,
            SerializationFormat::Yaml,
            SerializationFormat::Toml,
            SerializationFormat::Binary,
        ];
        known_formats
            .iter()
            .map(|&candidate_format| path_for(candidate_format))
            .find(|candidate| candidate.exists())
            .unwrap_or(preferred)
    }

    /// Loads the summary entries from the index file, if that file exists.
    fn load_registry(&mut self) -> Result<()> {
        if self.registry_file.exists() {
            let registry_data = fs::read_to_string(&self.registry_file)?;
            self.entries = serde_json::from_str(&registry_data)?;
        }
        Ok(())
    }

    /// Writes the summary entries to the index file as pretty-printed JSON.
    fn save_registry(&self) -> Result<()> {
        let registry_data = serde_json::to_string_pretty(&self.entries)?;
        fs::write(&self.registry_file, registry_data)?;
        Ok(())
    }

    /// Orders two version strings segment by segment (segments are split on `.`).
    ///
    /// Segments that both parse as unsigned integers are compared numerically,
    /// so `"1.10.0"` sorts after `"1.9.0"`; a plain string sort would put it
    /// before. A numeric segment sorts before a non-numeric one, two
    /// non-numeric segments are compared as strings, and a version that is a
    /// strict prefix of another sorts first. Versions whose segments all
    /// compare equal fall back to a full string comparison so the ordering
    /// stays total.
    fn compare_versions(left: &str, right: &str) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        let mut left_segments = left.split('.');
        let mut right_segments = right.split('.');
        loop {
            let ordering = match (left_segments.next(), right_segments.next()) {
                (None, None) => return left.cmp(right),
                (None, Some(_)) => return Ordering::Less,
                (Some(_), None) => return Ordering::Greater,
                (Some(a), Some(b)) => match (a.parse::<u64>(), b.parse::<u64>()) {
                    (Ok(x), Ok(y)) => x.cmp(&y),
                    (Ok(_), Err(_)) => Ordering::Less,
                    (Err(_), Ok(_)) => Ordering::Greater,
                    (Err(_), Err(_)) => a.cmp(b),
                },
            };
            if ordering != Ordering::Equal {
                return ordering;
            }
        }
    }

    /// Creates or refreshes the summary entry for `name` after `version` was
    /// written, then persists the index.
    ///
    /// A missing entry is created with the description from `metadata`;
    /// `version` is added if new; the version list is kept in version order;
    /// `latest_version` becomes its last element; the first version seen
    /// becomes the default; `updated_at` is set to the current time.
    ///
    /// # Errors
    /// Returns the error from saving the index file.
    fn update_entry(&mut self, name: &str, version: &str, metadata: &ModelMetadata) -> Result<()> {
        let entry = self.entries.entry(name.to_string()).or_insert_with(|| {
            let now = chrono::Utc::now();
            ModelRegistryEntry {
                name: name.to_string(),
                versions: Vec::new(),
                default_version: None,
                latest_version: None,
                description: metadata.description.clone(),
                tags: Vec::new(),
                created_at: now,
                updated_at: now,
            }
        });

        if !entry.versions.iter().any(|known| known == version) {
            entry.versions.push(version.to_string());
        }
        // Sorted unconditionally (not only after a push) so an index written
        // with the older plain string ordering is corrected on the next update.
        entry
            .versions
            .sort_by(|a, b| Self::compare_versions(a, b));
        entry.latest_version = entry.versions.last().cloned();

        if entry.default_version.is_none() {
            entry.default_version = Some(version.to_string());
        }
        entry.updated_at = chrono::Utc::now();
        self.save_registry()
    }

    /// Builds the on-disk representation of `model` from its metadata and the
    /// configuration reported by `info()`.
    ///
    /// NOTE(review): `parameters`, `model_data` and `preprocessing` are always
    /// written empty here — confirm that nothing else needs to be persisted.
    fn model_to_serializable(&self, model: &dyn ModelServing) -> Result<SerializableModel> {
        let info = model.info();
        Ok(SerializableModel {
            metadata: model.get_metadata().clone(),
            parameters: HashMap::new(),
            model_data: serde_json::json!({}),
            preprocessing: None,
            config: info.configuration,
        })
    }
}
impl ModelRegistry for FileSystemModelRegistry {
fn register_model(&mut self, model: Box<dyn ModelServing>) -> Result<()> {
let metadata = model.get_metadata();
let name = &metadata.name;
let version = &metadata.version;
if self.exists(name, version) {
return Err(Error::InvalidOperation(format!(
"Model '{}' version '{}' already exists",
name, version
)));
}
let model_dir = self.get_model_dir(name);
if !model_dir.exists() {
fs::create_dir_all(&model_dir)?;
}
let serializable_model = self.model_to_serializable(model.as_ref())?;
let model_file = self.get_model_file(name, version);
ModelSerializationFactory::save_model(
&serializable_model,
&model_file,
self.default_format,
)?;
self.update_entry(name, version, metadata)?;
Ok(())
}
fn load_model(&self, name: &str, version: &str) -> Result<Arc<dyn ModelServing>> {
let resolved_version = if version == "latest" {
self.get_latest_version(name)?
} else if version == "default" {
self.get_default_version(name)?
} else {
version.to_string()
};
let model_file = self.get_model_file(name, &resolved_version);
if !model_file.exists() {
return Err(Error::KeyNotFound(format!(
"Model file not found: {:?}",
model_file
)));
}
let boxed = ModelSerializationFactory::auto_detect_and_load(&model_file)?;
Ok(Arc::from(boxed))
}
fn list_models(&self) -> Result<Vec<ModelRegistryEntry>> {
Ok(self.entries.values().cloned().collect())
}
fn list_versions(&self, name: &str) -> Result<Vec<String>> {
self.entries
.get(name)
.map(|entry| entry.versions.clone())
.ok_or_else(|| Error::KeyNotFound(format!("Model '{}' not found", name)))
}
fn get_metadata(&self, name: &str, version: &str) -> Result<ModelMetadata> {
let resolved_version = if version == "latest" {
self.get_latest_version(name)?
} else if version == "default" {
self.get_default_version(name)?
} else {
version.to_string()
};
let model_file = self.get_model_file(name, &resolved_version);
if !model_file.exists() {
return Err(Error::KeyNotFound(format!(
"Model file not found: {:?}",
model_file
)));
}
let format = SerializationFormat::from_extension(
model_file
.extension()
.and_then(|ext| ext.to_str())
.ok_or_else(|| Error::InvalidInput("File has no extension".to_string()))?,
)
.ok_or_else(|| Error::InvalidInput("Unsupported file extension".to_string()))?;
let serializable_model = match format {
SerializationFormat::Json => {
let serializer = JsonModelSerializer;
serializer.deserialize(&fs::read(&model_file)?)?
}
SerializationFormat::Yaml => {
let serializer = YamlModelSerializer;
serializer.deserialize(&fs::read(&model_file)?)?
}
SerializationFormat::Toml => {
let serializer = TomlModelSerializer;
serializer.deserialize(&fs::read(&model_file)?)?
}
SerializationFormat::Binary => {
let serializer = BinaryModelSerializer;
serializer.deserialize(&fs::read(&model_file)?)?
}
};
Ok(serializable_model.metadata)
}
fn delete_model(&mut self, name: &str, version: &str) -> Result<()> {
let model_file = self.get_model_file(name, version);
if !model_file.exists() {
return Err(Error::KeyNotFound(format!(
"Model '{}' version '{}' not found",
name, version
)));
}
fs::remove_file(&model_file)?;
if let Some(entry) = self.entries.get_mut(name) {
entry.versions.retain(|v| v != version);
entry.latest_version = entry.versions.last().cloned();
if entry.default_version.as_ref() == Some(&version.to_string()) {
entry.default_version = entry.versions.first().cloned();
}
if entry.versions.is_empty() {
self.entries.remove(name);
let model_dir = self.get_model_dir(name);
if model_dir.exists() && model_dir.read_dir()?.next().is_none() {
fs::remove_dir(&model_dir)?;
}
}
}
self.save_registry()?;
Ok(())
}
fn update_metadata(
&mut self,
name: &str,
version: &str,
new_metadata: ModelMetadata,
) -> Result<()> {
let model_file = self.get_model_file(name, version);
if !model_file.exists() {
return Err(Error::KeyNotFound(format!(
"Model '{}' version '{}' not found",
name, version
)));
}
let format = SerializationFormat::from_extension(
model_file
.extension()
.and_then(|ext| ext.to_str())
.ok_or_else(|| Error::InvalidInput("File has no extension".to_string()))?,
)
.ok_or_else(|| Error::InvalidInput("Unsupported file extension".to_string()))?;
let mut serializable_model = match format {
SerializationFormat::Json => {
let serializer = JsonModelSerializer;
serializer.deserialize(&fs::read(&model_file)?)?
}
SerializationFormat::Yaml => {
let serializer = YamlModelSerializer;
serializer.deserialize(&fs::read(&model_file)?)?
}
SerializationFormat::Toml => {
let serializer = TomlModelSerializer;
serializer.deserialize(&fs::read(&model_file)?)?
}
SerializationFormat::Binary => {
let serializer = BinaryModelSerializer;
serializer.deserialize(&fs::read(&model_file)?)?
}
};
serializable_model.metadata = new_metadata.clone();
ModelSerializationFactory::save_model(&serializable_model, &model_file, format)?;
self.update_entry(name, version, &new_metadata)?;
Ok(())
}
fn exists(&self, name: &str, version: &str) -> bool {
self.get_model_file(name, version).exists()
}
fn get_latest_version(&self, name: &str) -> Result<String> {
self.entries
.get(name)
.and_then(|entry| entry.latest_version.clone())
.ok_or_else(|| Error::KeyNotFound(format!("Model '{}' not found", name)))
}
fn set_default_version(&mut self, name: &str, version: &str) -> Result<()> {
if !self.exists(name, version) {
return Err(Error::KeyNotFound(format!(
"Model '{}' version '{}' not found",
name, version
)));
}
if let Some(entry) = self.entries.get_mut(name) {
entry.default_version = Some(version.to_string());
entry.updated_at = chrono::Utc::now();
}
self.save_registry()?;
Ok(())
}
fn get_default_version(&self, name: &str) -> Result<String> {
self.entries
.get(name)
.and_then(|entry| entry.default_version.clone())
.ok_or_else(|| Error::KeyNotFound(format!("Model '{}' not found", name)))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A freshly built in-memory registry lists nothing and knows no version.
    #[test]
    fn test_in_memory_registry() {
        let fresh = InMemoryModelRegistry::new();
        let listed = fresh.list_models().expect("operation should succeed");
        assert!(listed.is_empty());
        assert!(!fresh.exists("test_model", "1.0.0"));
    }

    /// Opening a registry on a temporary directory succeeds and leaves the
    /// directory in place.
    #[test]
    fn test_filesystem_registry_creation() {
        let scratch = TempDir::new().expect("operation should succeed");
        let fs_registry =
            FileSystemModelRegistry::new(scratch.path()).expect("operation should succeed");
        assert!(scratch.path().exists());
        let index_present = fs_registry.registry_file.exists();
        assert!(index_present || fs_registry.entries.is_empty());
    }

    /// The entry struct simply carries the values it was built with.
    #[test]
    fn test_model_registry_entry() {
        let now = chrono::Utc::now;
        let known_versions = vec!["1.0.0".to_string(), "1.1.0".to_string()];
        let record = ModelRegistryEntry {
            name: "test_model".to_string(),
            versions: known_versions,
            default_version: Some("1.0.0".to_string()),
            latest_version: Some("1.1.0".to_string()),
            description: "Test model".to_string(),
            tags: vec!["test".to_string()],
            created_at: now(),
            updated_at: now(),
        };
        assert_eq!(record.name, "test_model");
        assert_eq!(record.versions.len(), 2);
        assert_eq!(record.latest_version, Some("1.1.0".to_string()));
    }

    /// Registering a model makes it loadable and its metadata retrievable
    /// under the same name and version.
    #[test]
    fn test_in_memory_registry_load_model() {
        use crate::ml::serving::serialization::{GenericServingModel, SerializableModel};
        use crate::ml::serving::{ModelMetadata, ModelServing};
        use std::collections::HashMap;

        // Build a minimal generic model to register.
        let model_metadata = ModelMetadata {
            name: "test_model".to_string(),
            version: "1.0.0".to_string(),
            model_type: "linear_regression".to_string(),
            feature_names: vec!["x1".to_string(), "x2".to_string()],
            target_name: Some("y".to_string()),
            description: "Unit-test model for load_model".to_string(),
            created_at: chrono::Utc::now(),
            updated_at: chrono::Utc::now(),
            metrics: HashMap::new(),
            metadata: HashMap::new(),
        };
        let payload = SerializableModel {
            metadata: model_metadata,
            parameters: HashMap::new(),
            model_data: serde_json::json!({}),
            preprocessing: None,
            config: HashMap::new(),
        };
        let model: Box<dyn ModelServing> = Box::new(
            GenericServingModel::from_serializable(payload).expect("model creation must succeed"),
        );

        let mut in_memory = InMemoryModelRegistry::new();
        in_memory
            .register_model(model)
            .expect("register_model must succeed");

        // The handle returned by `load_model` reports the registered metadata.
        let handle = in_memory
            .load_model("test_model", "1.0.0")
            .expect("load_model must succeed for a registered model");
        let handle_meta = handle.get_metadata();
        assert_eq!(handle_meta.name, "test_model");
        assert_eq!(handle_meta.version, "1.0.0");
        assert_eq!(handle_meta.model_type, "linear_regression");
        assert_eq!(handle_meta.description, "Unit-test model for load_model");

        // The registry itself agrees.
        assert!(in_memory.exists("test_model", "1.0.0"));
        let registry_meta = in_memory
            .get_metadata("test_model", "1.0.0")
            .expect("get_metadata must succeed");
        assert_eq!(registry_meta.name, "test_model");
        assert_eq!(registry_meta.version, "1.0.0");
    }
}