use super::backends::InMemoryStorage;
use super::{FileSystemStorage, StorageBackend, sha256_hash};
use crate::error::{Result, TuneError};
use crate::registry::model::{ModelStatus, RegisteredModel};
use arc_swap::ArcSwap;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use uuid::Uuid;
/// Registry of [`RegisteredModel`] entries combined with a pluggable [`StorageBackend`].
///
/// The two maps are published through `ArcSwap`: read methods work on a loaded
/// snapshot, while each mutating method takes `write_lock` before it copies, edits
/// and stores a new map.
pub struct ModelRegistry {
/// Backend used to save, load and delete weights; guarded by its own mutex.
storage: parking_lot::Mutex<Box<dyn StorageBackend>>,
/// Current snapshot of all registered models, keyed by model id.
models: ArcSwap<HashMap<Uuid, RegisteredModel>>,
/// Current snapshot mapping `RegisteredModel::full_name()` keys to model ids.
name_index: ArcSwap<HashMap<String, Uuid>>,
/// Taken by the mutating methods so their load/copy/store sequences do not interleave.
write_lock: parking_lot::Mutex<()>,
}
impl ModelRegistry {
pub fn in_memory() -> Self {
Self {
storage: parking_lot::Mutex::new(Box::new(InMemoryStorage::new())),
models: ArcSwap::new(Arc::new(HashMap::new())),
name_index: ArcSwap::new(Arc::new(HashMap::new())),
write_lock: parking_lot::Mutex::new(()),
}
}
pub fn with_path(path: impl Into<PathBuf>) -> Result<Self> {
let storage = FileSystemStorage::new(path)?;
Ok(Self {
storage: parking_lot::Mutex::new(Box::new(storage)),
models: ArcSwap::new(Arc::new(HashMap::new())),
name_index: ArcSwap::new(Arc::new(HashMap::new())),
write_lock: parking_lot::Mutex::new(()),
})
}
pub fn with_storage(storage: Box<dyn StorageBackend>) -> Self {
Self {
storage: parking_lot::Mutex::new(storage),
models: ArcSwap::new(Arc::new(HashMap::new())),
name_index: ArcSwap::new(Arc::new(HashMap::new())),
write_lock: parking_lot::Mutex::new(()),
}
}
/// Registers `model` and saves `weights` through the storage backend.
///
/// The model is validated, checked against the existing entries, its weights are
/// saved, and the model (updated via `with_weights` with the saved path, the byte
/// length and `sha256_hash(weights)`) is published in the id map and the name index.
///
/// Returns the id of the registered model.
///
/// # Errors
///
/// * [`TuneError::Validation`] when `model.validate()` fails.
/// * [`TuneError::DuplicateModel`] when the `full_name()` key **or** the model id is
///   already registered.
/// * Whatever error the storage backend's `save` returns.
pub fn register(&self, mut model: RegisteredModel, weights: &[u8]) -> Result<Uuid> {
    model.validate().map_err(TuneError::Validation)?;
    let key = model.full_name();

    // Serialize writers so the check-then-insert below cannot interleave.
    let _wg = self.write_lock.lock();

    // An id collision must be rejected just like a name collision: inserting under
    // an existing id would overwrite that entry and leave its name-index key
    // pointing at the new model. Checked before saving so no weights are orphaned.
    if self.name_index.load().contains_key(&key) || self.models.load().contains_key(&model.id) {
        return Err(TuneError::DuplicateModel {
            name: model.name.clone(),
            version: model.version.clone(),
        });
    }

    let weights_path = self.storage.lock().save(&model, weights)?;
    let weights_size = weights.len();
    let weights_hash = sha256_hash(weights);
    model = model.with_weights(weights_path, weights_size, weights_hash);
    let id = model.id;

    // Copy, edit and publish each map as a fresh snapshot.
    let mut new_models = (**self.models.load()).clone();
    new_models.insert(id, model);
    self.models.store(Arc::new(new_models));

    let mut new_index = (**self.name_index.load()).clone();
    new_index.insert(key, id);
    self.name_index.store(Arc::new(new_index));

    Ok(id)
}
/// Registers `model` as-is, without saving any weights to the storage backend.
///
/// Returns the id of the registered model.
///
/// # Errors
///
/// * [`TuneError::Validation`] when `model.validate()` fails.
/// * [`TuneError::DuplicateModel`] when the `full_name()` key **or** the model id is
///   already registered.
pub fn register_metadata(&self, model: RegisteredModel) -> Result<Uuid> {
    model.validate().map_err(TuneError::Validation)?;
    let key = model.full_name();
    let id = model.id;

    // Serialize writers so the check-then-insert below cannot interleave.
    let _wg = self.write_lock.lock();

    // An id collision must be rejected just like a name collision: inserting under
    // an existing id would overwrite that entry and leave its name-index key
    // pointing at the new model.
    if self.name_index.load().contains_key(&key) || self.models.load().contains_key(&id) {
        return Err(TuneError::DuplicateModel {
            name: model.name.clone(),
            version: model.version.clone(),
        });
    }

    // Copy, edit and publish each map as a fresh snapshot.
    let mut new_models = (**self.models.load()).clone();
    new_models.insert(id, model);
    self.models.store(Arc::new(new_models));

    let mut new_index = (**self.name_index.load()).clone();
    new_index.insert(key, id);
    self.name_index.store(Arc::new(new_index));

    Ok(id)
}
/// Sets the status of the model with the given id and refreshes its `updated_at`
/// timestamp to `chrono::Utc::now()`.
///
/// # Errors
///
/// [`TuneError::ModelNotFound`] (with the id as `name` and an empty `version`) when
/// no model with that id is registered.
pub fn update_status(&self, id: &Uuid, status: ModelStatus) -> Result<()> {
    let _wg = self.write_lock.lock();
    let current = self.models.load();

    // Fail on the snapshot so an unknown id costs a lookup, not a full map copy.
    if !current.contains_key(id) {
        return Err(TuneError::ModelNotFound {
            name: id.to_string(),
            version: String::new(),
        });
    }

    let mut new_models = (**current).clone();
    if let Some(model) = new_models.get_mut(id) {
        model.status = status;
        model.updated_at = chrono::Utc::now();
    }
    self.models.store(Arc::new(new_models));
    Ok(())
}
/// Marks the model with the given id as `ModelStatus::Production`.
///
/// Every other model that shares its name and is currently in `Production` is moved
/// to `ModelStatus::Staged`. All touched models get the same `updated_at` timestamp.
///
/// # Errors
///
/// [`TuneError::ModelNotFound`] (with the id as `name` and an empty `version`) when
/// no model with that id is registered.
pub fn promote_to_production(&self, id: &Uuid) -> Result<()> {
    let _wg = self.write_lock.lock();
    let current = self.models.load();

    // Resolve the target on the snapshot so an unknown id never triggers a map copy.
    let name = current
        .get(id)
        .ok_or_else(|| TuneError::ModelNotFound {
            name: id.to_string(),
            version: String::new(),
        })?
        .name
        .clone();

    let now = chrono::Utc::now();
    let mut new_models = (**current).clone();

    // Single pass: promote the target, demote same-named production models.
    // Relies on every entry being stored under its own `id`, which is how the
    // registration methods insert them.
    for (key, m) in new_models.iter_mut() {
        if key == id {
            m.status = ModelStatus::Production;
            m.updated_at = now;
        } else if m.name == name && m.status == ModelStatus::Production {
            m.status = ModelStatus::Staged;
            m.updated_at = now;
        }
    }

    self.models.store(Arc::new(new_models));
    Ok(())
}
/// Removes the model with the given id from the registry.
///
/// If the model has a `weights_path`, the storage backend is asked to delete it
/// first; the in-memory maps are only changed once that succeeds.
///
/// # Errors
///
/// * [`TuneError::ModelNotFound`] (with the id as `name` and an empty `version`)
///   when no model with that id is registered.
/// * Whatever error the storage backend's `delete` returns; the registry is left
///   unchanged in that case.
pub fn delete(&self, id: &Uuid) -> Result<()> {
    let _wg = self.write_lock.lock();
    let current_models = self.models.load();

    // Look the model up on the snapshot: neither the not-found path nor a storage
    // failure should pay for copying the maps.
    let model = current_models
        .get(id)
        .ok_or_else(|| TuneError::ModelNotFound {
            name: id.to_string(),
            version: String::new(),
        })?;

    if let Some(path) = &model.weights_path {
        self.storage.lock().delete(path)?;
    }
    let key = model.full_name();

    let mut new_models = (**current_models).clone();
    new_models.remove(id);
    let mut new_index = (**self.name_index.load()).clone();
    new_index.remove(&key);

    self.models.store(Arc::new(new_models));
    self.name_index.store(Arc::new(new_index));
    Ok(())
}
/// Returns a clone of the model registered under `id`, or `None` if there is none.
pub fn get_by_id(&self, id: &Uuid) -> Option<RegisteredModel> {
    let snapshot = self.models.load();
    let found = snapshot.get(id)?;
    Some(found.clone())
}
/// Looks a model up by name and version and returns a clone of it.
///
/// The name index is queried with the key `"{name}:{version}"`.
/// NOTE(review): this presumably mirrors `RegisteredModel::full_name()`, which
/// produces the keys stored in the index — confirm the two formats agree.
///
/// Returns `None` when the key is not indexed or the indexed id is not present in
/// the model map.
pub fn get(&self, name: &str, version: &str) -> Option<RegisteredModel> {
    let key = format!("{name}:{version}");
    let index = self.name_index.load();
    let models = self.models.load();
    let id = index.get(&key)?;
    models.get(id).cloned()
}
/// Returns the model named `name` with the greatest `version_tuple()`.
///
/// Models whose `version_tuple()` yields no value are ranked as `(0, 0, 0)`.
/// Returns `None` when no model has that name.
pub fn get_latest(&self, name: &str) -> Option<RegisteredModel> {
    let snapshot = self.models.load();
    // Pick the winner by reference and clone only that one entry, instead of
    // cloning every version just to discard all but one.
    snapshot
        .values()
        .filter(|m| m.name == name)
        .max_by_key(|m| m.version_tuple().unwrap_or((0, 0, 0)))
        .cloned()
}
/// Returns a model named `name` whose status is `ModelStatus::Production`, if any.
///
/// When several entries match, the first one encountered while iterating the map
/// is returned.
pub fn get_production(&self, name: &str) -> Option<RegisteredModel> {
    let snapshot = self.models.load();
    // Search by reference and clone only the hit, rather than cloning every
    // version of the model first.
    snapshot
        .values()
        .find(|m| m.name == name && m.status == ModelStatus::Production)
        .cloned()
}
/// Returns clones of all registered models whose name equals `name`.
///
/// The order follows the map's iteration order.
pub fn list_versions(&self, name: &str) -> Vec<RegisteredModel> {
    let snapshot = self.models.load();
    let same_name = snapshot.values().filter(|entry| entry.name == name);

    let mut versions = Vec::new();
    versions.extend(same_name.cloned());
    versions
}
/// Returns clones of every registered model, in the map's iteration order.
pub fn list_all(&self) -> Vec<RegisteredModel> {
    let snapshot = self.models.load();
    let mut all = Vec::with_capacity(snapshot.len());
    all.extend(snapshot.values().cloned());
    all
}
/// Returns clones of all registered models whose status equals `status`.
///
/// The order follows the map's iteration order.
pub fn list_by_status(&self, status: ModelStatus) -> Vec<RegisteredModel> {
    let snapshot = self.models.load();
    let matching = snapshot.values().filter(|entry| entry.status == status);
    matching.cloned().collect()
}
/// Returns the distinct model names in the registry, sorted ascending.
pub fn list_names(&self) -> Vec<String> {
    let snapshot = self.models.load();

    // Sort and deduplicate borrowed names, then allocate one `String` per survivor.
    let mut unique: Vec<&str> = snapshot.values().map(|m| m.name.as_str()).collect();
    unique.sort_unstable();
    unique.dedup();

    unique.into_iter().map(String::from).collect()
}
/// Loads the weights of `model` from the storage backend.
///
/// # Errors
///
/// * [`TuneError::Storage`] with the message `"No weights path"` when the model has
///   no `weights_path`.
/// * Whatever error the storage backend's `load` returns.
pub fn load_weights(&self, model: &RegisteredModel) -> Result<Vec<u8>> {
    match model.weights_path.as_ref() {
        Some(path) => self.storage.lock().load(path),
        None => Err(TuneError::Storage("No weights path".to_string())),
    }
}
/// Loads the weights of `model` and compares `sha256_hash` of the loaded bytes with
/// the model's recorded `weights_hash`.
///
/// When the model carries no `weights_hash`, the bytes are returned without any
/// comparison.
///
/// # Errors
///
/// * Any error from [`Self::load_weights`].
/// * [`TuneError::WeightIntegrityError`] when the computed hash differs from the
///   recorded one.
pub fn load_weights_verified(&self, model: &RegisteredModel) -> Result<Vec<u8>> {
    let weights = self.load_weights(model)?;

    match model.weights_hash.as_ref() {
        // Nothing recorded to compare against.
        None => Ok(weights),
        Some(expected) => {
            let actual = sha256_hash(&weights);
            if actual == *expected {
                Ok(weights)
            } else {
                Err(TuneError::WeightIntegrityError {
                    expected: expected.clone(),
                    actual,
                })
            }
        }
    }
}
/// Returns the number of models in the current snapshot of the registry.
pub fn len(&self) -> usize {
    let snapshot = self.models.load();
    snapshot.len()
}
/// Returns `true` when the current snapshot of the registry holds no models.
pub fn is_empty(&self) -> bool {
    let snapshot = self.models.load();
    snapshot.is_empty()
}
}