use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
use arc_swap::ArcSwap;
use crate::{
cache::ModelCache,
error::{RealizarError, Result},
layers::Model,
tokenizer::BPETokenizer,
};
/// Descriptive metadata stored alongside each registered model.
///
/// Values of this type are returned (as clones) by `ModelRegistry::get_info`
/// and `ModelRegistry::list`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModelInfo {
/// Identifier of the model; `ModelRegistry` uses it as the map key.
pub id: String,
/// Name of the model. `ModelRegistry::register` sets this to the id.
pub name: String,
/// Free-form description. `ModelRegistry::register` leaves this empty.
pub description: String,
/// Format label. `ModelRegistry::register` sets this to `"unknown"`.
// NOTE(review): presumably the on-disk model format (the tests use "GGUF") — confirm.
pub format: String,
/// Load flag. Both registration methods store `true` here, and
/// `register_with_info` overwrites whatever the caller supplied.
pub loaded: bool,
}
/// One value of the registry map: a model, its tokenizer and their metadata.
///
/// The model and tokenizer sit behind `Arc`, so cloning an entry (which happens
/// whenever the whole map is cloned) bumps reference counts instead of copying
/// them; only `info` is cloned by value.
#[derive(Clone)]
struct ModelEntry {
// Shared handle returned to callers by `ModelRegistry::get`.
model: Arc<Model>,
// Shared handle returned together with `model`.
tokenizer: Arc<BPETokenizer>,
// Metadata returned by `ModelRegistry::get_info` and `ModelRegistry::list`.
info: ModelInfo,
}
/// Registry contents, keyed by model id.
type ModelsMap = HashMap<String, ModelEntry>;
/// Pair of shared handles returned by `ModelRegistry::get`: (model, tokenizer).
type ModelTuple = (Arc<Model>, Arc<BPETokenizer>);
/// Registry of models addressed by string id.
///
/// The current map is held in an `ArcSwap`. Read methods only `load()` it;
/// mutating methods take `write_lock`, clone the map, modify the clone and
/// `store()` it back.
pub struct ModelRegistry {
// Current map of registered models; replaced wholesale on every mutation.
models: ArcSwap<ModelsMap>,
// Held by mutating methods for the duration of their load/clone/store sequence.
write_lock: Mutex<()>,
// The field is constructed in `new` but not read by any method of the impl
// below, hence the lint suppression.
// NOTE(review): presumably reserved for future cache integration — confirm.
#[allow(dead_code)]
cache: Arc<ModelCache>,
}
impl ModelRegistry {
    /// Creates an empty registry.
    ///
    /// `cache_capacity` is passed straight to `ModelCache::new`; none of the
    /// methods in this `impl` consult the cache afterwards.
    #[must_use]
    pub fn new(cache_capacity: usize) -> Self {
        Self {
            models: ArcSwap::from_pointee(HashMap::new()),
            write_lock: Mutex::new(()),
            cache: Arc::new(ModelCache::new(cache_capacity)),
        }
    }

    /// Acquires the writer mutex that every mutating method holds for its
    /// whole load/clone/store sequence.
    ///
    /// # Errors
    ///
    /// Returns `RealizarError::RegistryError` if the mutex is poisoned.
    fn lock_writes(&self) -> Result<std::sync::MutexGuard<'_, ()>> {
        self.write_lock.lock().map_err(|_| {
            RealizarError::RegistryError("Failed to acquire write lock".to_string())
        })
    }

    /// Registers `model` and `tokenizer` under `id` with default metadata.
    ///
    /// The stored [`ModelInfo`] uses `id` for both `id` and `name`, an empty
    /// description, the format `"unknown"` and `loaded = true`.
    ///
    /// # Errors
    ///
    /// * `RealizarError::ModelAlreadyExists` if `id` is already registered.
    /// * `RealizarError::RegistryError` if the write lock cannot be acquired.
    pub fn register(&self, id: &str, model: Model, tokenizer: BPETokenizer) -> Result<()> {
        let info = ModelInfo {
            id: id.to_string(),
            name: id.to_string(),
            description: String::new(),
            format: "unknown".to_string(),
            loaded: true,
        };
        // Same checks, errors and stored entry as the metadata-aware path.
        self.register_with_info(info, model, tokenizer)
    }

    /// Registers `model` and `tokenizer` under `info.id` with caller-supplied
    /// metadata.
    ///
    /// `info.loaded` is overwritten with `true` before the entry is stored.
    ///
    /// # Errors
    ///
    /// * `RealizarError::ModelAlreadyExists` if `info.id` is already registered.
    /// * `RealizarError::RegistryError` if the write lock cannot be acquired.
    pub fn register_with_info(
        &self,
        mut info: ModelInfo,
        model: Model,
        tokenizer: BPETokenizer,
    ) -> Result<()> {
        let _guard = self.lock_writes()?;
        let current = self.models.load();
        if current.contains_key(&info.id) {
            return Err(RealizarError::ModelAlreadyExists(info.id));
        }

        info.loaded = true;
        let key = info.id.clone();
        let entry = ModelEntry {
            model: Arc::new(model),
            tokenizer: Arc::new(tokenizer),
            info,
        };

        // Modify a private copy, then swap it in as the new current map.
        let mut next: ModelsMap = (**current).clone();
        next.insert(key, entry);
        self.models.store(Arc::new(next));
        Ok(())
    }

    /// Returns shared handles to the model and tokenizer registered under `id`.
    ///
    /// # Errors
    ///
    /// Returns `RealizarError::ModelNotFound` if `id` is not registered.
    pub fn get(&self, id: &str) -> Result<ModelTuple> {
        let models = self.models.load();
        let entry = models
            .get(id)
            .ok_or_else(|| RealizarError::ModelNotFound(id.to_string()))?;
        Ok((Arc::clone(&entry.model), Arc::clone(&entry.tokenizer)))
    }

    /// Returns a copy of the metadata registered under `id`.
    ///
    /// # Errors
    ///
    /// Returns `RealizarError::ModelNotFound` if `id` is not registered.
    pub fn get_info(&self, id: &str) -> Result<ModelInfo> {
        let models = self.models.load();
        let entry = models
            .get(id)
            .ok_or_else(|| RealizarError::ModelNotFound(id.to_string()))?;
        Ok(entry.info.clone())
    }

    /// Returns a copy of the metadata of every registered model.
    ///
    /// The order follows `HashMap` iteration and is therefore unspecified.
    #[must_use]
    pub fn list(&self) -> Vec<ModelInfo> {
        let models = self.models.load();
        models.values().map(|entry| entry.info.clone()).collect()
    }

    /// Removes the model registered under `id`.
    ///
    /// Handles previously returned by [`ModelRegistry::get`] stay valid because
    /// they are independent `Arc` clones.
    ///
    /// # Errors
    ///
    /// * `RealizarError::ModelNotFound` if `id` is not registered.
    /// * `RealizarError::RegistryError` if the write lock cannot be acquired.
    pub fn unregister(&self, id: &str) -> Result<()> {
        let _guard = self.lock_writes()?;
        let current = self.models.load();
        // Check before cloning so a miss does not pay for a map copy.
        if !current.contains_key(id) {
            return Err(RealizarError::ModelNotFound(id.to_string()));
        }

        let mut next: ModelsMap = (**current).clone();
        next.remove(id);
        self.models.store(Arc::new(next));
        Ok(())
    }

    /// Swaps the model and tokenizer registered under `id`, keeping the
    /// existing [`ModelInfo`] unchanged.
    ///
    /// # Errors
    ///
    /// * `RealizarError::ModelNotFound` if `id` is not registered.
    /// * `RealizarError::RegistryError` if the write lock cannot be acquired.
    pub fn replace(&self, id: &str, model: Model, tokenizer: BPETokenizer) -> Result<()> {
        let _guard = self.lock_writes()?;
        let current = self.models.load();
        // A single lookup both validates the id and fetches the metadata to keep.
        let info = current
            .get(id)
            .ok_or_else(|| RealizarError::ModelNotFound(id.to_string()))?
            .info
            .clone();

        let entry = ModelEntry {
            model: Arc::new(model),
            tokenizer: Arc::new(tokenizer),
            info,
        };

        let mut next: ModelsMap = (**current).clone();
        next.insert(id.to_string(), entry);
        self.models.store(Arc::new(next));
        Ok(())
    }

    /// Returns `true` if a model is registered under `id`.
    #[must_use]
    pub fn contains(&self, id: &str) -> bool {
        let models = self.models.load();
        models.contains_key(id)
    }

    /// Returns the number of registered models.
    #[must_use]
    pub fn len(&self) -> usize {
        let models = self.models.load();
        models.len()
    }

    /// Returns `true` if no models are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::ModelConfig;

    /// Builds a tiny model and a 100-entry tokenizer for use as test fixtures.
    fn create_test_model() -> (Model, BPETokenizer) {
        let config = ModelConfig {
            vocab_size: 100,
            hidden_dim: 32,
            num_heads: 1,
            num_layers: 1,
            intermediate_dim: 64,
            eps: 1e-5,
        };
        let model = Model::new(config).unwrap();
        let vocab: Vec<String> = (0..100)
            .map(|i| {
                if i == 0 {
                    "<unk>".to_string()
                } else {
                    format!("token{i}")
                }
            })
            .collect();
        let tokenizer = BPETokenizer::new(vocab, vec![], "<unk>").unwrap();
        (model, tokenizer)
    }

    #[test]
    fn test_registry_creation() {
        let registry = ModelRegistry::new(5);
        assert_eq!(registry.len(), 0);
        assert!(registry.is_empty());
    }

    #[test]
    fn test_register_model() {
        let registry = ModelRegistry::new(5);
        let (model, tokenizer) = create_test_model();
        registry.register("test-model", model, tokenizer).unwrap();
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
        assert!(registry.contains("test-model"));
    }

    #[test]
    fn test_register_duplicate_error() {
        let registry = ModelRegistry::new(5);
        let (model1, tokenizer1) = create_test_model();
        let (model2, tokenizer2) = create_test_model();
        registry.register("test-model", model1, tokenizer1).unwrap();
        let result = registry.register("test-model", model2, tokenizer2);
        assert!(matches!(
            result,
            Err(RealizarError::ModelAlreadyExists(_))
        ));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_get_model() {
        let registry = ModelRegistry::new(5);
        let (model, tokenizer) = create_test_model();
        registry.register("test-model", model, tokenizer).unwrap();
        let (retrieved_model, retrieved_tokenizer) = registry.get("test-model").unwrap();
        // The registry holds one reference and the returned handle is another.
        assert!(Arc::strong_count(&retrieved_model) >= 2);
        assert!(Arc::strong_count(&retrieved_tokenizer) >= 2);
    }

    #[test]
    fn test_get_nonexistent_model() {
        let registry = ModelRegistry::new(5);
        let result = registry.get("nonexistent");
        assert!(matches!(result, Err(RealizarError::ModelNotFound(_))));
    }

    #[test]
    fn test_register_with_info() {
        let registry = ModelRegistry::new(5);
        let (model, tokenizer) = create_test_model();
        let info = ModelInfo {
            id: "llama-7b".to_string(),
            name: "Llama 7B".to_string(),
            description: "7B parameter Llama model".to_string(),
            format: "GGUF".to_string(),
            loaded: false,
        };
        registry
            .register_with_info(info.clone(), model, tokenizer)
            .unwrap();
        let retrieved_info = registry.get_info("llama-7b").unwrap();
        assert_eq!(retrieved_info.id, "llama-7b");
        assert_eq!(retrieved_info.name, "Llama 7B");
        assert_eq!(retrieved_info.description, "7B parameter Llama model");
        assert_eq!(retrieved_info.format, "GGUF");
        // Registration forces the flag on even though `false` was supplied.
        assert!(retrieved_info.loaded);
    }

    #[test]
    fn test_list_models() {
        let registry = ModelRegistry::new(5);
        let (model1, tokenizer1) = create_test_model();
        let (model2, tokenizer2) = create_test_model();
        registry.register("model-1", model1, tokenizer1).unwrap();
        registry.register("model-2", model2, tokenizer2).unwrap();
        let model_list = registry.list();
        assert_eq!(model_list.len(), 2);
        let ids: Vec<String> = model_list.iter().map(|m| m.id.clone()).collect();
        assert!(ids.contains(&"model-1".to_string()));
        assert!(ids.contains(&"model-2".to_string()));
    }

    #[test]
    fn test_unregister_model() {
        let registry = ModelRegistry::new(5);
        let (model, tokenizer) = create_test_model();
        registry.register("test-model", model, tokenizer).unwrap();
        assert_eq!(registry.len(), 1);
        registry.unregister("test-model").unwrap();
        assert_eq!(registry.len(), 0);
        assert!(!registry.contains("test-model"));
    }

    #[test]
    fn test_unregister_nonexistent() {
        let registry = ModelRegistry::new(5);
        let result = registry.unregister("nonexistent");
        assert!(matches!(result, Err(RealizarError::ModelNotFound(_))));
    }

    #[test]
    fn test_concurrent_access() {
        use std::thread;

        let registry = Arc::new(ModelRegistry::new(10));
        let mut handles = vec![];
        for i in 0..5 {
            let registry_clone = Arc::clone(&registry);
            let handle = thread::spawn(move || {
                let (model, tokenizer) = create_test_model();
                registry_clone
                    .register(&format!("model-{i}"), model, tokenizer)
                    .unwrap();
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(registry.len(), 5);
    }

    #[test]
    fn test_multiple_get_same_model() {
        let registry = ModelRegistry::new(5);
        let (model, tokenizer) = create_test_model();
        registry.register("test-model", model, tokenizer).unwrap();
        let (model1, _) = registry.get("test-model").unwrap();
        let (model2, _) = registry.get("test-model").unwrap();
        assert!(Arc::ptr_eq(&model1, &model2));
    }

    #[test]
    fn test_replace_model() {
        let registry = ModelRegistry::new(5);
        let (model1, tokenizer1) = create_test_model();
        let (model2, tokenizer2) = create_test_model();
        registry.register("test-model", model1, tokenizer1).unwrap();
        assert_eq!(registry.len(), 1);
        let (before, _) = registry.get("test-model").unwrap();
        registry.replace("test-model", model2, tokenizer2).unwrap();
        assert_eq!(registry.len(), 1);
        let (retrieved, _) = registry.get("test-model").unwrap();
        assert!(Arc::strong_count(&retrieved) >= 2);
        // `before` is still alive, so a differing pointer proves a new model was installed.
        assert!(!Arc::ptr_eq(&before, &retrieved));
    }

    #[test]
    fn test_replace_nonexistent_model() {
        let registry = ModelRegistry::new(5);
        let (model, tokenizer) = create_test_model();
        let result = registry.replace("nonexistent", model, tokenizer);
        assert!(matches!(result, Err(RealizarError::ModelNotFound(_))));
    }

    #[test]
    fn test_contains_method() {
        let registry = ModelRegistry::new(5);
        let (model, tokenizer) = create_test_model();
        assert!(!registry.contains("test-model"));
        registry.register("test-model", model, tokenizer).unwrap();
        assert!(registry.contains("test-model"));
    }

    #[test]
    fn test_is_empty_method() {
        let registry = ModelRegistry::new(5);
        assert!(registry.is_empty());
        let (model, tokenizer) = create_test_model();
        registry.register("test-model", model, tokenizer).unwrap();
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_get_info_nonexistent() {
        let registry = ModelRegistry::new(5);
        let result = registry.get_info("nonexistent");
        assert!(matches!(result, Err(RealizarError::ModelNotFound(_))));
    }

    #[test]
    fn test_len_method() {
        let registry = ModelRegistry::new(5);
        assert_eq!(registry.len(), 0);
        let (model1, tokenizer1) = create_test_model();
        registry.register("model-1", model1, tokenizer1).unwrap();
        assert_eq!(registry.len(), 1);
        let (model2, tokenizer2) = create_test_model();
        registry.register("model-2", model2, tokenizer2).unwrap();
        assert_eq!(registry.len(), 2);
    }
}