use arc_swap::ArcSwap;
use std::sync::Arc;
use uuid::Uuid;
use super::model::RegisteredModel;
/// Holds the currently served [`RegisteredModel`] behind an [`ArcSwap`], so it can be
/// read and replaced through a shared reference.
///
/// Readers receive an `Arc` snapshot that keeps pointing at the same model even if a
/// newer one is published afterwards; writers publish with [`LiveModel::swap`].
pub struct LiveModel {
    /// The published model. It is replaced as a whole and never mutated in place.
    current: ArcSwap<RegisteredModel>,
}

impl LiveModel {
    /// Publishes `model` as the initial live model.
    pub fn new(model: RegisteredModel) -> Self {
        let initial = Arc::new(model);
        Self {
            current: ArcSwap::new(initial),
        }
    }

    /// Returns an owned snapshot of the current model.
    ///
    /// The snapshot is unaffected by later calls to [`LiveModel::swap`].
    pub fn load(&self) -> Arc<RegisteredModel> {
        self.current.load_full()
    }

    /// Installs `new_model` as the live model and hands back the model it replaced.
    pub fn swap(&self, new_model: RegisteredModel) -> Arc<RegisteredModel> {
        let replacement = Arc::new(new_model);
        self.current.swap(replacement)
    }

    /// Returns the id of the current model.
    ///
    /// Every call reads from its own fresh snapshot. To read several fields consistently,
    /// call [`LiveModel::load`] once and read them from that snapshot.
    pub fn model_id(&self) -> Uuid {
        let snapshot = self.current.load();
        snapshot.id
    }

    /// Returns an owned copy of the current model's version string.
    ///
    /// Like [`LiveModel::model_id`], this reads from a fresh snapshot on every call.
    pub fn version(&self) -> String {
        let snapshot = self.current.load();
        snapshot.version.clone()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::registry::model::RegisteredModel;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_live_model_load() {
        let model = RegisteredModel::new("classifier", "1.0.0");
        let expected_id = model.id;
        let live = LiveModel::new(model);

        let snap = live.load();
        assert_eq!(snap.name, "classifier");
        assert_eq!(snap.version, "1.0.0");
        assert_eq!(live.model_id(), expected_id);
        assert_eq!(live.version(), "1.0.0");
    }

    #[test]
    fn test_live_model_swap() {
        let model_v1 = RegisteredModel::new("classifier", "1.0.0");
        let v1_id = model_v1.id;
        let live = LiveModel::new(model_v1);
        assert_eq!(live.version(), "1.0.0");

        let model_v2 = RegisteredModel::new("classifier", "2.0.0");
        let old = live.swap(model_v2);

        // `swap` must hand back exactly the model that was live before.
        assert_eq!(old.version, "1.0.0");
        assert_eq!(old.id, v1_id);
        assert_eq!(live.version(), "2.0.0");
        assert_eq!(live.load().version, "2.0.0");
    }

    #[test]
    fn test_live_model_snapshot_stability() {
        let model_v1 = RegisteredModel::new("classifier", "1.0.0");
        let live = LiveModel::new(model_v1);

        let snap_before = live.load();
        assert_eq!(snap_before.version, "1.0.0");

        let model_v2 = RegisteredModel::new("classifier", "2.0.0");
        let old = live.swap(model_v2);

        // A snapshot taken before the swap is untouched by it, and it is the very
        // allocation that `swap` returned as the replaced model.
        assert_eq!(snap_before.version, "1.0.0");
        assert!(Arc::ptr_eq(&snap_before, &old));
        assert_eq!(live.load().version, "2.0.0");
    }

    #[test]
    fn test_live_model_concurrent_swap() {
        const WRITERS: u32 = 10;
        const READERS: usize = 10;
        const READS_PER_READER: usize = 100;

        let live = Arc::new(LiveModel::new(RegisteredModel::new("classifier", "0.0.0")));

        // Every version a reader may legitimately observe: the initial "0.0.0" plus one
        // per writer ("1.0.0" ..= "10.0.0").
        let known_versions: Vec<String> = (0..=WRITERS).map(|i| format!("{i}.0.0")).collect();

        let mut handles = Vec::new();
        for i in 1..=WRITERS {
            let live = Arc::clone(&live);
            handles.push(thread::spawn(move || {
                live.swap(RegisteredModel::new("classifier", &format!("{i}.0.0")));
            }));
        }
        for _ in 0..READERS {
            let live = Arc::clone(&live);
            let known = known_versions.clone();
            handles.push(thread::spawn(move || {
                for _ in 0..READS_PER_READER {
                    let snap = live.load();
                    assert_eq!(snap.name, "classifier");
                    assert!(
                        known.contains(&snap.version),
                        "reader observed unexpected version {}",
                        snap.version
                    );
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }

        // Every writer has finished, so the live model must be one a writer published;
        // seeing the initial version here would mean swaps were lost.
        let final_version = live.version();
        assert!(
            known_versions[1..].contains(&final_version),
            "final version {final_version} was not published by any writer"
        );
    }
}