use std::{
collections::HashMap,
path::{Path, PathBuf},
sync::Mutex,
};
use zer_core::error::ZerError;
use crate::{
artifact::ModelArtifact,
fingerprint::SchemaFingerprint,
similarity::{fingerprint_distance, WARM_START_THRESHOLD},
};
const MAGIC: &[u8] = b"ZSM\x01";
#[derive(Debug)]
pub enum StartupMode {
WarmLoad(ModelArtifact),
WarmStart { artifact: ModelArtifact, distance: f32 },
ColdStart,
}
struct RegistryInner {
path: Option<PathBuf>,
artifacts: HashMap<[u8; 32], ModelArtifact>,
}
pub struct SchemaRegistry {
inner: Mutex<RegistryInner>,
}
impl SchemaRegistry {
pub fn open(path: &Path) -> Result<Self, ZerError> {
let artifacts = load(path)?;
Ok(Self {
inner: Mutex::new(RegistryInner {
path: Some(path.to_path_buf()),
artifacts,
}),
})
}
#[cfg(test)]
pub(crate) fn open_temporary() -> Result<Self, ZerError> {
Ok(Self {
inner: Mutex::new(RegistryInner {
path: None,
artifacts: HashMap::new(),
}),
})
}
pub fn save(&self, artifact: &ModelArtifact) -> Result<(), ZerError> {
let mut inner = self.inner.lock().unwrap();
inner.artifacts.insert(artifact.fingerprint.schema_hash, artifact.clone());
flush(&inner)?;
tracing::debug!(tag = artifact.tag.as_deref(), "saved model artifact");
Ok(())
}
pub fn get_exact(
&self,
fingerprint: &SchemaFingerprint,
) -> Result<Option<ModelArtifact>, ZerError> {
let inner = self.inner.lock().unwrap();
Ok(inner.artifacts.get(&fingerprint.schema_hash).cloned())
}
pub fn get_nearest(
&self,
fingerprint: &SchemaFingerprint,
) -> Result<Option<(ModelArtifact, f32)>, ZerError> {
let inner = self.inner.lock().unwrap();
let best = inner
.artifacts
.values()
.map(|a| {
let dist = fingerprint_distance(fingerprint, &a.fingerprint);
(a.clone(), dist)
})
.min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap_or(std::cmp::Ordering::Equal));
Ok(best)
}
pub fn lookup_startup_mode(
&self,
fingerprint: &SchemaFingerprint,
) -> Result<StartupMode, ZerError> {
if let Some(exact) = self.get_exact(fingerprint)? {
tracing::info!("exact schema match, warm load");
return Ok(StartupMode::WarmLoad(exact));
}
match self.get_nearest(fingerprint)? {
Some((artifact, dist)) if dist <= WARM_START_THRESHOLD => {
tracing::info!(dist, "similar schema, warm start");
Ok(StartupMode::WarmStart { artifact, distance: dist })
}
_ => {
tracing::info!("no suitable prior, cold start");
Ok(StartupMode::ColdStart)
}
}
}
pub fn list_all(&self) -> Result<Vec<ModelArtifact>, ZerError> {
let inner = self.inner.lock().unwrap();
Ok(inner.artifacts.values().cloned().collect())
}
pub fn delete(&self, schema_hash: &[u8; 32]) -> Result<bool, ZerError> {
let mut inner = self.inner.lock().unwrap();
let removed = inner.artifacts.remove(schema_hash).is_some();
if removed {
flush(&inner)?;
}
Ok(removed)
}
}
fn flush(inner: &RegistryInner) -> Result<(), ZerError> {
let Some(path) = &inner.path else {
return Ok(());
};
let payload = bincode::serialize(&inner.artifacts)
.map_err(|e| ZerError::Serialization(e.to_string()))?;
let mut buf = Vec::with_capacity(4 + payload.len());
buf.extend_from_slice(MAGIC);
buf.extend(payload);
let tmp = path.with_extension("zsm.tmp");
std::fs::write(&tmp, &buf).map_err(|e| ZerError::Store(e.to_string()))?;
std::fs::rename(&tmp, path).map_err(|e| ZerError::Store(e.to_string()))?;
Ok(())
}
fn load(path: &Path) -> Result<HashMap<[u8; 32], ModelArtifact>, ZerError> {
if !path.exists() {
return Ok(HashMap::new());
}
let bytes = std::fs::read(path).map_err(|e| ZerError::Store(e.to_string()))?;
if bytes.get(..4) != Some(MAGIC) {
return Err(ZerError::Store("invalid .zsm magic".into()));
}
bincode::deserialize(&bytes[4..]).map_err(|e| ZerError::Serialization(e.to_string()))
}
#[cfg(test)]
mod tests {
use super::*;
use zer_core::{
schema::{FieldKind, SchemaBuilder},
scoring::ModelParams,
};
use crate::{artifact::ModelArtifact, fingerprint::SchemaFingerprint};
fn dummy_params(n_fields: usize) -> ModelParams {
ModelParams {
m: vec![vec![0.02, 0.06, 0.12, 0.80]; n_fields],
u: vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields],
log_prior_odds: -2.0,
upper_threshold: 0.9,
lower_threshold: 0.1,
}
}
fn make_artifact(schema: &zer_core::schema::Schema, tag: &str) -> ModelArtifact {
ModelArtifact {
fingerprint: SchemaFingerprint::from_schema(schema),
params: dummy_params(schema.len()),
tag: Some(tag.into()),
trained_on: 0,
em_iterations: 25,
}
}
fn brp_schema() -> zer_core::schema::Schema {
SchemaBuilder::new()
.field("voornamen", FieldKind::Name)
.field("achternaam", FieldKind::Name)
.field("geboortedatum", FieldKind::Date)
.field("nationaliteit", FieldKind::Categorical)
.field("postcode", FieldKind::Id)
.build()
.unwrap()
}
fn sim_schema() -> zer_core::schema::Schema {
SchemaBuilder::new()
.field("sim_id", FieldKind::Id)
.field("msisdn", FieldKind::Phone)
.field("imsi", FieldKind::Id)
.field("voornamen", FieldKind::Name)
.field("achternaam", FieldKind::Name)
.field("geboortedatum", FieldKind::Date)
.field("nationaliteit", FieldKind::Categorical)
.build()
.unwrap()
}
#[test]
fn roundtrip_save_and_get_exact() {
let registry = SchemaRegistry::open_temporary().unwrap();
let schema = brp_schema();
let artifact = make_artifact(&schema, "brp_test");
registry.save(&artifact).unwrap();
let fp = SchemaFingerprint::from_schema(&schema);
let loaded = registry.get_exact(&fp).unwrap().unwrap();
assert_eq!(loaded.tag.as_deref(), Some("brp_test"));
assert_eq!(
loaded.fingerprint.schema_hash,
artifact.fingerprint.schema_hash
);
assert_eq!(loaded.params.upper_threshold, artifact.params.upper_threshold);
}
#[test]
fn get_exact_returns_none_for_unknown_schema() {
let registry = SchemaRegistry::open_temporary().unwrap();
let fp = SchemaFingerprint::from_schema(&brp_schema());
let result = registry.get_exact(&fp).unwrap();
assert!(result.is_none());
}
#[test]
fn list_all_returns_all_artifacts() {
let registry = SchemaRegistry::open_temporary().unwrap();
let brp = brp_schema();
let sim = sim_schema();
registry.save(&make_artifact(&brp, "brp")).unwrap();
registry.save(&make_artifact(&sim, "sim")).unwrap();
let all = registry.list_all().unwrap();
assert_eq!(all.len(), 2);
}
#[test]
fn delete_removes_artifact_and_returns_true() {
let registry = SchemaRegistry::open_temporary().unwrap();
let schema = brp_schema();
let artifact = make_artifact(&schema, "brp");
registry.save(&artifact).unwrap();
let removed = registry.delete(&artifact.fingerprint.schema_hash).unwrap();
assert!(removed, "delete should return true when the key existed");
let fp = SchemaFingerprint::from_schema(&schema);
assert!(registry.get_exact(&fp).unwrap().is_none());
}
#[test]
fn delete_returns_false_for_missing_key() {
let registry = SchemaRegistry::open_temporary().unwrap();
let hash = [0u8; 32];
assert!(!registry.delete(&hash).unwrap());
}
#[test]
fn startup_mode_exact_match_is_warm_load() {
let registry = SchemaRegistry::open_temporary().unwrap();
let schema = brp_schema();
registry.save(&make_artifact(&schema, "brp")).unwrap();
let fp = SchemaFingerprint::from_schema(&schema);
let mode = registry.lookup_startup_mode(&fp).unwrap();
assert!(
matches!(mode, StartupMode::WarmLoad(_)),
"exact schema match must return WarmLoad"
);
}
#[test]
fn startup_mode_added_field_is_warm_start() {
let registry = SchemaRegistry::open_temporary().unwrap();
registry.save(&make_artifact(&brp_schema(), "brp")).unwrap();
let extended = SchemaBuilder::new()
.field("voornamen", FieldKind::Name)
.field("achternaam", FieldKind::Name)
.field("geboortedatum", FieldKind::Date)
.field("nationaliteit", FieldKind::Categorical)
.field("postcode", FieldKind::Id)
.field("verblijfstitel", FieldKind::Categorical)
.build()
.unwrap();
let fp = SchemaFingerprint::from_schema(&extended);
let mode = registry.lookup_startup_mode(&fp).unwrap();
assert!(
matches!(mode, StartupMode::WarmStart { .. }),
"one added field should return WarmStart"
);
}
#[test]
fn startup_mode_incompatible_schema_is_cold_start() {
let registry = SchemaRegistry::open_temporary().unwrap();
registry.save(&make_artifact(&brp_schema(), "brp")).unwrap();
let fp = SchemaFingerprint::from_schema(&sim_schema());
let mode = registry.lookup_startup_mode(&fp).unwrap();
assert!(
matches!(mode, StartupMode::ColdStart),
"BRP artifact vs SIM schema should return ColdStart"
);
}
#[test]
fn startup_mode_empty_registry_is_cold_start() {
let registry = SchemaRegistry::open_temporary().unwrap();
let fp = SchemaFingerprint::from_schema(&brp_schema());
assert!(matches!(
registry.lookup_startup_mode(&fp).unwrap(),
StartupMode::ColdStart
));
}
#[test]
fn nearest_prefers_closer_artifact() {
let registry = SchemaRegistry::open_temporary().unwrap();
registry.save(&make_artifact(&brp_schema(), "brp")).unwrap();
registry.save(&make_artifact(&sim_schema(), "sim")).unwrap();
let brp_like = SchemaBuilder::new()
.field("voornamen", FieldKind::Name)
.field("achternaam", FieldKind::Name)
.field("geboortedatum", FieldKind::Date)
.field("nationaliteit", FieldKind::Categorical)
.field("postcode", FieldKind::Id)
.field("verblijfstitel", FieldKind::Categorical)
.build()
.unwrap();
let (nearest, _dist) = registry
.get_nearest(&SchemaFingerprint::from_schema(&brp_like))
.unwrap()
.expect("registry is not empty");
assert_eq!(
nearest.tag.as_deref(),
Some("brp"),
"BRP-like schema should match the BRP artifact, not SIM"
);
}
}