use crate::core::error::{Error, Result};
use crate::ml::serving::serialization::{
BinaryModelSerializer, JsonModelSerializer, ModelSerializationFactory, SerializableModel,
TomlModelSerializer, YamlModelSerializer,
};
use crate::ml::serving::{ModelMetadata, ModelSerializer, ModelServing, SerializationFormat};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
/// Storage-agnostic interface to a collection of versioned, servable models.
///
/// A model is addressed by a `(name, version)` pair. The two implementations
/// in this file additionally accept the aliases `"latest"` and `"default"` as
/// the `version` argument of [`ModelRegistry::load_model`] and
/// [`ModelRegistry::get_metadata`].
pub trait ModelRegistry {
    /// Stores `model` under the name and version reported by its metadata.
    fn register_model(&mut self, model: Box<dyn ModelServing>) -> Result<()>;

    /// Returns a servable instance of the model identified by `name` and `version`.
    fn load_model(&self, name: &str, version: &str) -> Result<Box<dyn ModelServing>>;

    /// Returns one summary entry per registered model name.
    fn list_models(&self) -> Result<Vec<ModelRegistryEntry>>;

    /// Returns every version registered under `name`.
    fn list_versions(&self, name: &str) -> Result<Vec<String>>;

    /// Returns the metadata stored for the given model version.
    fn get_metadata(&self, name: &str, version: &str) -> Result<ModelMetadata>;

    /// Removes a single version of a model.
    fn delete_model(&mut self, name: &str, version: &str) -> Result<()>;

    /// Replaces the metadata stored for the given model version.
    fn update_metadata(
        &mut self,
        name: &str,
        version: &str,
        metadata: ModelMetadata,
    ) -> Result<()>;

    /// Reports whether the given model version is present.
    fn exists(&self, name: &str, version: &str) -> bool;

    /// Returns the version the registry currently considers the latest for `name`.
    fn get_latest_version(&self, name: &str) -> Result<String>;

    /// Marks `version` as the default version of `name`.
    fn set_default_version(&mut self, name: &str, version: &str) -> Result<()>;

    /// Returns the version currently marked as default for `name`.
    fn get_default_version(&self, name: &str) -> Result<String>;
}
/// Per-model summary kept by a registry: the known versions plus bookkeeping.
///
/// The type is serialisable; the filesystem registry in this file persists a
/// map of these entries as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRegistryEntry {
    /// Model name; also the key under which the entry is stored.
    pub name: String,
    /// Registered versions. The registries in this file keep the list sorted
    /// in ascending order.
    pub versions: Vec<String>,
    /// Version resolved by the `"default"` alias, if any.
    pub default_version: Option<String>,
    /// Version resolved by the `"latest"` alias; the registries in this file
    /// set it to the last element of `versions`.
    pub latest_version: Option<String>,
    /// Free-text description; the registries in this file copy it from the
    /// metadata of the first version registered under this name.
    pub description: String,
    /// Free-form labels. The registries in this file create entries with an
    /// empty list.
    pub tags: Vec<String>,
    /// Time at which the entry was created.
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Time of the most recent change to the entry.
    pub updated_at: chrono::DateTime<chrono::Utc>,
}
/// A [`ModelRegistry`] that keeps every model and all bookkeeping in process memory.
pub struct InMemoryModelRegistry {
    /// Registered models, indexed first by model name and then by version.
    models: HashMap<String, HashMap<String, Box<dyn ModelServing>>>,
    /// Summary entry for each model name.
    entries: HashMap<String, ModelRegistryEntry>,
    /// Default version per model name; kept in step with
    /// `ModelRegistryEntry::default_version` by the registry's methods.
    default_versions: HashMap<String, String>,
}
impl InMemoryModelRegistry {
    /// Creates a registry that holds no models.
    pub fn new() -> Self {
        Self {
            models: HashMap::new(),
            entries: HashMap::new(),
            default_versions: HashMap::new(),
        }
    }

    /// Formats the composite `"<name>:<version>"` key for a model.
    // Nothing in this file calls this helper; it is kept (rather than deleted)
    // in case code outside the visible file relies on it.
    #[allow(dead_code)]
    fn get_model_key(name: &str, version: &str) -> String {
        format!("{}:{}", name, version)
    }

    /// Orders two version strings so that numeric components compare by value.
    ///
    /// Each string is split on `.`. A component that parses as `u64` is
    /// compared numerically and sorts before any non-numeric component;
    /// non-numeric components compare as plain strings. This makes `"1.9.0"`
    /// sort before `"1.10.0"`, which a plain string sort gets wrong. Versions
    /// whose components all compare equal (for example `"1.01"` and `"1.1"`)
    /// fall back to a plain string comparison so that the ordering is total.
    fn compare_versions(left: &str, right: &str) -> std::cmp::Ordering {
        // Numeric components get rank 0, everything else rank 1, so the two
        // kinds never have to be compared against each other directly.
        fn component_key(part: &str) -> (u8, u64, &str) {
            match part.parse::<u64>() {
                Ok(number) => (0, number, ""),
                Err(_) => (1, 0, part),
            }
        }

        left.split('.')
            .map(component_key)
            .cmp(right.split('.').map(component_key))
            .then_with(|| left.cmp(right))
    }

    /// Records `version` in the summary entry for `name`, creating the entry if needed.
    ///
    /// A newly created entry takes its description from `metadata`. The
    /// version list is kept sorted with [`Self::compare_versions`], the latest
    /// version is the last element of that list, and the first version ever
    /// recorded for a name becomes its default version.
    fn update_entry(&mut self, name: &str, version: &str, metadata: &ModelMetadata) {
        let entry = self.entries.entry(name.to_string()).or_insert_with(|| {
            let now = chrono::Utc::now();
            ModelRegistryEntry {
                name: name.to_string(),
                versions: Vec::new(),
                default_version: None,
                latest_version: None,
                description: metadata.description.clone(),
                tags: Vec::new(),
                created_at: now,
                updated_at: now,
            }
        });

        if !entry.versions.iter().any(|known| known == version) {
            entry.versions.push(version.to_string());
            entry
                .versions
                .sort_by(|a, b| Self::compare_versions(a, b));
        }
        entry.latest_version = entry.versions.last().cloned();

        if entry.default_version.is_none() {
            entry.default_version = Some(version.to_string());
            // Mirror the default into the lookup map used by `get_default_version`.
            self.default_versions
                .insert(name.to_string(), version.to_string());
        }
        entry.updated_at = chrono::Utc::now();
    }
}
impl Default for InMemoryModelRegistry {
fn default() -> Self {
Self::new()
}
}
impl ModelRegistry for InMemoryModelRegistry {
fn register_model(&mut self, model: Box<dyn ModelServing>) -> Result<()> {
let metadata = model.get_metadata().clone(); let name = metadata.name.clone();
let version = metadata.version.clone();
if self.exists(&name, &version) {
return Err(Error::InvalidOperation(format!(
"Model '{}' version '{}' already exists",
name, version
)));
}
self.models
.entry(name.clone())
.or_insert_with(HashMap::new)
.insert(version.clone(), model);
self.update_entry(&name, &version, &metadata);
Ok(())
}
fn load_model(&self, name: &str, version: &str) -> Result<Box<dyn ModelServing>> {
let resolved_version = if version == "latest" {
self.get_latest_version(name)?
} else if version == "default" {
self.get_default_version(name)?
} else {
version.to_string()
};
self.models
.get(name)
.and_then(|versions| versions.get(&resolved_version))
.ok_or_else(|| {
Error::KeyNotFound(format!(
"Model '{}' version '{}' not found",
name, resolved_version
))
})
.map(|_| {
return Err(Error::NotImplemented(
"Loading models from in-memory registry requires cloning support".to_string(),
));
})?
}
fn list_models(&self) -> Result<Vec<ModelRegistryEntry>> {
Ok(self.entries.values().cloned().collect())
}
fn list_versions(&self, name: &str) -> Result<Vec<String>> {
self.entries
.get(name)
.map(|entry| entry.versions.clone())
.ok_or_else(|| Error::KeyNotFound(format!("Model '{}' not found", name)))
}
fn get_metadata(&self, name: &str, version: &str) -> Result<ModelMetadata> {
let resolved_version = if version == "latest" {
self.get_latest_version(name)?
} else if version == "default" {
self.get_default_version(name)?
} else {
version.to_string()
};
self.models
.get(name)
.and_then(|versions| versions.get(&resolved_version))
.map(|model| model.get_metadata().clone())
.ok_or_else(|| {
Error::KeyNotFound(format!(
"Model '{}' version '{}' not found",
name, resolved_version
))
})
}
fn delete_model(&mut self, name: &str, version: &str) -> Result<()> {
if let Some(versions) = self.models.get_mut(name) {
if versions.remove(version).is_some() {
if let Some(entry) = self.entries.get_mut(name) {
entry.versions.retain(|v| v != version);
entry.latest_version = entry.versions.last().cloned();
if entry.default_version.as_ref() == Some(&version.to_string()) {
entry.default_version = entry.versions.first().cloned();
if let Some(new_default) = &entry.default_version {
self.default_versions
.insert(name.to_string(), new_default.clone());
} else {
self.default_versions.remove(name);
}
}
if entry.versions.is_empty() {
self.entries.remove(name);
self.models.remove(name);
self.default_versions.remove(name);
}
}
Ok(())
} else {
Err(Error::KeyNotFound(format!(
"Model '{}' version '{}' not found",
name, version
)))
}
} else {
Err(Error::KeyNotFound(format!("Model '{}' not found", name)))
}
}
fn update_metadata(
&mut self,
name: &str,
version: &str,
metadata: ModelMetadata,
) -> Result<()> {
Err(Error::NotImplemented(
"Updating metadata for in-memory models is not supported".to_string(),
))
}
fn exists(&self, name: &str, version: &str) -> bool {
self.models
.get(name)
.map(|versions| versions.contains_key(version))
.unwrap_or(false)
}
fn get_latest_version(&self, name: &str) -> Result<String> {
self.entries
.get(name)
.and_then(|entry| entry.latest_version.clone())
.ok_or_else(|| Error::KeyNotFound(format!("Model '{}' not found", name)))
}
fn set_default_version(&mut self, name: &str, version: &str) -> Result<()> {
if !self.exists(name, version) {
return Err(Error::KeyNotFound(format!(
"Model '{}' version '{}' not found",
name, version
)));
}
self.default_versions
.insert(name.to_string(), version.to_string());
if let Some(entry) = self.entries.get_mut(name) {
entry.default_version = Some(version.to_string());
entry.updated_at = chrono::Utc::now();
}
Ok(())
}
fn get_default_version(&self, name: &str) -> Result<String> {
self.default_versions
.get(name)
.cloned()
.ok_or_else(|| Error::KeyNotFound(format!("Model '{}' not found", name)))
}
}
/// A [`ModelRegistry`] that persists models and its index beneath a directory.
pub struct FileSystemModelRegistry {
    /// Root directory; each model name gets its own sub-directory below it.
    base_path: PathBuf,
    /// Location of the JSON index file that stores the summary entries.
    registry_file: PathBuf,
    /// In-memory copy of the index, keyed by model name.
    entries: HashMap<String, ModelRegistryEntry>,
    /// Serialization format used when computing model file paths and writing
    /// newly registered models.
    default_format: SerializationFormat,
}
impl FileSystemModelRegistry {
    /// Opens the registry rooted at `base_path`, creating the directory when it
    /// does not exist and loading `registry.json` from it when present.
    ///
    /// The default serialization format of a new instance is JSON.
    ///
    /// # Errors
    /// Propagates failures from creating the directory and from reading or
    /// parsing the index file.
    pub fn new<P: AsRef<Path>>(base_path: P) -> Result<Self> {
        let base_path = base_path.as_ref().to_path_buf();
        let registry_file = base_path.join("registry.json");
        if !base_path.exists() {
            fs::create_dir_all(&base_path)?;
        }

        let mut registry = Self {
            base_path,
            registry_file,
            entries: HashMap::new(),
            default_format: SerializationFormat::Json,
        };
        registry.load_registry()?;
        Ok(registry)
    }

    /// Selects the serialization format used from now on. This call only
    /// stores the value; files already on disk are not touched.
    pub fn set_default_format(&mut self, format: SerializationFormat) {
        self.default_format = format;
    }

    /// Directory that holds every version of the model called `name`.
    // NOTE(review): `name` is joined verbatim, so a name containing path
    // separators or `..` would resolve outside `base_path` — confirm that
    // callers validate model names.
    fn get_model_dir(&self, name: &str) -> PathBuf {
        self.base_path.join(name)
    }

    /// Path of the file for `name`/`version` in the current default format:
    /// `<base>/<name>/<version>.<extension>`.
    fn get_model_file(&self, name: &str, version: &str) -> PathBuf {
        self.get_model_dir(name)
            .join(format!("{}.{}", version, self.default_format.extension()))
    }

    /// Replaces the in-memory index with the content of the index file, if
    /// that file exists; otherwise leaves the index untouched.
    fn load_registry(&mut self) -> Result<()> {
        if self.registry_file.exists() {
            let registry_data = fs::read_to_string(&self.registry_file)?;
            self.entries = serde_json::from_str(&registry_data)?;
        }
        Ok(())
    }

    /// Writes the in-memory index to the index file as pretty-printed JSON.
    fn save_registry(&self) -> Result<()> {
        let registry_data = serde_json::to_string_pretty(&self.entries)?;
        fs::write(&self.registry_file, registry_data)?;
        Ok(())
    }

    /// Orders two version strings so that numeric components compare by value.
    ///
    /// Each string is split on `.`. A component that parses as `u64` is
    /// compared numerically and sorts before any non-numeric component;
    /// non-numeric components compare as plain strings. This makes `"1.9.0"`
    /// sort before `"1.10.0"`, which a plain string sort gets wrong. Versions
    /// whose components all compare equal (for example `"1.01"` and `"1.1"`)
    /// fall back to a plain string comparison so that the ordering is total.
    fn compare_versions(left: &str, right: &str) -> std::cmp::Ordering {
        // Numeric components get rank 0, everything else rank 1, so the two
        // kinds never have to be compared against each other directly.
        fn component_key(part: &str) -> (u8, u64, &str) {
            match part.parse::<u64>() {
                Ok(number) => (0, number, ""),
                Err(_) => (1, 0, part),
            }
        }

        left.split('.')
            .map(component_key)
            .cmp(right.split('.').map(component_key))
            .then_with(|| left.cmp(right))
    }

    /// Records `version` in the summary entry for `name` (creating the entry
    /// if needed) and persists the index.
    ///
    /// A newly created entry takes its description from `metadata`. The
    /// latest version is the last element of the sorted version list, and the
    /// first version ever recorded for a name becomes its default version.
    fn update_entry(&mut self, name: &str, version: &str, metadata: &ModelMetadata) -> Result<()> {
        let entry = self.entries.entry(name.to_string()).or_insert_with(|| {
            let now = chrono::Utc::now();
            ModelRegistryEntry {
                name: name.to_string(),
                versions: Vec::new(),
                default_version: None,
                latest_version: None,
                description: metadata.description.clone(),
                tags: Vec::new(),
                created_at: now,
                updated_at: now,
            }
        });

        if !entry.versions.iter().any(|known| known == version) {
            entry.versions.push(version.to_string());
        }
        // Sorted unconditionally: the list may have been loaded from an index
        // file that was written with a different ordering.
        entry
            .versions
            .sort_by(|a, b| Self::compare_versions(a, b));
        entry.latest_version = entry.versions.last().cloned();

        if entry.default_version.is_none() {
            entry.default_version = Some(version.to_string());
        }
        entry.updated_at = chrono::Utc::now();
        self.save_registry()
    }

    /// Builds the on-disk representation of `model`.
    ///
    /// Only the model's metadata and the configuration reported by
    /// `model.info()` are carried over; `parameters` and `model_data` are
    /// written empty and `preprocessing` is `None`.
    fn model_to_serializable(&self, model: &dyn ModelServing) -> Result<SerializableModel> {
        let metadata = model.get_metadata().clone();
        let info = model.info();
        Ok(SerializableModel {
            metadata,
            parameters: HashMap::new(),
            model_data: serde_json::json!({}),
            preprocessing: None,
            config: info.configuration,
        })
    }
}
impl ModelRegistry for FileSystemModelRegistry {
fn register_model(&mut self, model: Box<dyn ModelServing>) -> Result<()> {
let metadata = model.get_metadata();
let name = &metadata.name;
let version = &metadata.version;
if self.exists(name, version) {
return Err(Error::InvalidOperation(format!(
"Model '{}' version '{}' already exists",
name, version
)));
}
let model_dir = self.get_model_dir(name);
if !model_dir.exists() {
fs::create_dir_all(&model_dir)?;
}
let serializable_model = self.model_to_serializable(model.as_ref())?;
let model_file = self.get_model_file(name, version);
ModelSerializationFactory::save_model(
&serializable_model,
&model_file,
self.default_format,
)?;
self.update_entry(name, version, metadata)?;
Ok(())
}
fn load_model(&self, name: &str, version: &str) -> Result<Box<dyn ModelServing>> {
let resolved_version = if version == "latest" {
self.get_latest_version(name)?
} else if version == "default" {
self.get_default_version(name)?
} else {
version.to_string()
};
let model_file = self.get_model_file(name, &resolved_version);
if !model_file.exists() {
return Err(Error::KeyNotFound(format!(
"Model file not found: {:?}",
model_file
)));
}
ModelSerializationFactory::auto_detect_and_load(&model_file)
}
fn list_models(&self) -> Result<Vec<ModelRegistryEntry>> {
Ok(self.entries.values().cloned().collect())
}
fn list_versions(&self, name: &str) -> Result<Vec<String>> {
self.entries
.get(name)
.map(|entry| entry.versions.clone())
.ok_or_else(|| Error::KeyNotFound(format!("Model '{}' not found", name)))
}
fn get_metadata(&self, name: &str, version: &str) -> Result<ModelMetadata> {
let resolved_version = if version == "latest" {
self.get_latest_version(name)?
} else if version == "default" {
self.get_default_version(name)?
} else {
version.to_string()
};
let model_file = self.get_model_file(name, &resolved_version);
if !model_file.exists() {
return Err(Error::KeyNotFound(format!(
"Model file not found: {:?}",
model_file
)));
}
let format = SerializationFormat::from_extension(
model_file
.extension()
.and_then(|ext| ext.to_str())
.ok_or_else(|| Error::InvalidInput("File has no extension".to_string()))?,
)
.ok_or_else(|| Error::InvalidInput("Unsupported file extension".to_string()))?;
let serializable_model = match format {
SerializationFormat::Json => {
let serializer = JsonModelSerializer;
serializer.deserialize(&fs::read(&model_file)?)?
}
SerializationFormat::Yaml => {
let serializer = YamlModelSerializer;
serializer.deserialize(&fs::read(&model_file)?)?
}
SerializationFormat::Toml => {
let serializer = TomlModelSerializer;
serializer.deserialize(&fs::read(&model_file)?)?
}
SerializationFormat::Binary => {
let serializer = BinaryModelSerializer;
serializer.deserialize(&fs::read(&model_file)?)?
}
};
Ok(serializable_model.metadata)
}
fn delete_model(&mut self, name: &str, version: &str) -> Result<()> {
let model_file = self.get_model_file(name, version);
if !model_file.exists() {
return Err(Error::KeyNotFound(format!(
"Model '{}' version '{}' not found",
name, version
)));
}
fs::remove_file(&model_file)?;
if let Some(entry) = self.entries.get_mut(name) {
entry.versions.retain(|v| v != version);
entry.latest_version = entry.versions.last().cloned();
if entry.default_version.as_ref() == Some(&version.to_string()) {
entry.default_version = entry.versions.first().cloned();
}
if entry.versions.is_empty() {
self.entries.remove(name);
let model_dir = self.get_model_dir(name);
if model_dir.exists() && model_dir.read_dir()?.next().is_none() {
fs::remove_dir(&model_dir)?;
}
}
}
self.save_registry()?;
Ok(())
}
fn update_metadata(
&mut self,
name: &str,
version: &str,
new_metadata: ModelMetadata,
) -> Result<()> {
let model_file = self.get_model_file(name, version);
if !model_file.exists() {
return Err(Error::KeyNotFound(format!(
"Model '{}' version '{}' not found",
name, version
)));
}
let format = SerializationFormat::from_extension(
model_file
.extension()
.and_then(|ext| ext.to_str())
.ok_or_else(|| Error::InvalidInput("File has no extension".to_string()))?,
)
.ok_or_else(|| Error::InvalidInput("Unsupported file extension".to_string()))?;
let mut serializable_model = match format {
SerializationFormat::Json => {
let serializer = JsonModelSerializer;
serializer.deserialize(&fs::read(&model_file)?)?
}
SerializationFormat::Yaml => {
let serializer = YamlModelSerializer;
serializer.deserialize(&fs::read(&model_file)?)?
}
SerializationFormat::Toml => {
let serializer = TomlModelSerializer;
serializer.deserialize(&fs::read(&model_file)?)?
}
SerializationFormat::Binary => {
let serializer = BinaryModelSerializer;
serializer.deserialize(&fs::read(&model_file)?)?
}
};
serializable_model.metadata = new_metadata.clone();
ModelSerializationFactory::save_model(&serializable_model, &model_file, format)?;
self.update_entry(name, version, &new_metadata)?;
Ok(())
}
fn exists(&self, name: &str, version: &str) -> bool {
self.get_model_file(name, version).exists()
}
fn get_latest_version(&self, name: &str) -> Result<String> {
self.entries
.get(name)
.and_then(|entry| entry.latest_version.clone())
.ok_or_else(|| Error::KeyNotFound(format!("Model '{}' not found", name)))
}
fn set_default_version(&mut self, name: &str, version: &str) -> Result<()> {
if !self.exists(name, version) {
return Err(Error::KeyNotFound(format!(
"Model '{}' version '{}' not found",
name, version
)));
}
if let Some(entry) = self.entries.get_mut(name) {
entry.default_version = Some(version.to_string());
entry.updated_at = chrono::Utc::now();
}
self.save_registry()?;
Ok(())
}
fn get_default_version(&self, name: &str) -> Result<String> {
self.entries
.get(name)
.and_then(|entry| entry.default_version.clone())
.ok_or_else(|| Error::KeyNotFound(format!("Model '{}' not found", name)))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A freshly constructed in-memory registry lists nothing and knows no model.
    #[test]
    fn test_in_memory_registry() {
        let registry = InMemoryModelRegistry::new();

        let listed = registry.list_models().expect("operation should succeed");
        assert!(listed.is_empty());

        let known = registry.exists("test_model", "1.0.0");
        assert!(!known);
    }

    /// Opening a filesystem registry in a temporary directory succeeds and
    /// either finds an index file or starts with an empty index.
    #[test]
    fn test_filesystem_registry_creation() {
        let temp_dir = TempDir::new().expect("operation should succeed");
        let root = temp_dir.path();

        let registry = FileSystemModelRegistry::new(root).expect("operation should succeed");

        assert!(root.exists());
        let has_index_file = registry.registry_file.exists();
        assert!(has_index_file || registry.entries.is_empty());
    }

    /// A hand-built entry exposes the values it was constructed with.
    #[test]
    fn test_model_registry_entry() {
        let versions = vec!["1.0.0".to_string(), "1.1.0".to_string()];
        let now = chrono::Utc::now();
        let entry = ModelRegistryEntry {
            name: "test_model".to_string(),
            versions,
            default_version: Some("1.0.0".to_string()),
            latest_version: Some("1.1.0".to_string()),
            description: "Test model".to_string(),
            tags: vec!["test".to_string()],
            created_at: now,
            updated_at: chrono::Utc::now(),
        };

        assert_eq!(entry.name, "test_model");
        assert_eq!(entry.versions.len(), 2);
        assert_eq!(entry.latest_version, Some("1.1.0".to_string()));
    }
}