Skip to main content

zer_schema/
registry.rs

1use std::{
2    collections::HashMap,
3    path::{Path, PathBuf},
4    sync::Mutex,
5};
6
7use zer_core::error::ZerError;
8
9use crate::{
10    artifact::ModelArtifact,
11    fingerprint::SchemaFingerprint,
12    similarity::{fingerprint_distance, WARM_START_THRESHOLD},
13};
14
/// File signature every `.zsm` registry file starts with: the ASCII bytes `ZSM`
/// followed by `0x01` (presumably a format version — confirm before changing it).
const MAGIC: &[u8] = &[b'Z', b'S', b'M', 0x01];
16
17/// Decides how the pipeline should initialize when a new dataset arrives.
18#[derive(Debug)]
19pub enum StartupMode {
20    /// Schema hash matches exactly, skip EM and use the saved params directly.
21    WarmLoad(ModelArtifact),
22    /// Schema is similar (distance ≤ threshold), use saved params as the EM
23    /// warm-start initializer and run 2–3 iterations to fine-tune.
24    WarmStart { artifact: ModelArtifact, distance: f32 },
25    /// Schema is new or too different, initialize from priors and run full EM.
26    ColdStart,
27}
28
29struct RegistryInner {
30    path:      Option<PathBuf>,
31    artifacts: HashMap<[u8; 32], ModelArtifact>,
32}
33
34/// Persistent store for trained [`ModelArtifact`]s.
35///
36/// Backed by a single portable `.zsm` binary file (`b"ZSM\x01"` magic +
37/// bincode-serialized `HashMap`). The file is written atomically on every
38/// mutation, a `.zsm.tmp` file is written first then renamed into place, so a
39/// crash during flush can never leave a partially-written registry.
40///
41/// The registry is small in practice (< 1 000 entries), so nearest-neighbor
42/// lookup performs a full linear scan without an index.
43pub struct SchemaRegistry {
44    inner: Mutex<RegistryInner>,
45}
46
47impl SchemaRegistry {
48    /// Open (or create) a registry at the given `.zsm` file path.
49    ///
50    /// If the file does not exist yet it is created on the first [`Self::save`] call.
51    pub fn open(path: &Path) -> Result<Self, ZerError> {
52        let artifacts = load(path)?;
53        Ok(Self {
54            inner: Mutex::new(RegistryInner {
55                path: Some(path.to_path_buf()),
56                artifacts,
57            }),
58        })
59    }
60
61    /// Create an in-memory registry. No file I/O; data is lost on drop.
62    #[cfg(test)]
63    pub(crate) fn open_temporary() -> Result<Self, ZerError> {
64        Ok(Self {
65            inner: Mutex::new(RegistryInner {
66                path: None,
67                artifacts: HashMap::new(),
68            }),
69        })
70    }
71
72    // ── Write ────────────────────────────────────────────────────────────────
73
74    /// Persist a trained model artifact. Overwrites any existing artifact with
75    /// the same schema hash and atomically flushes to disk.
76    pub fn save(&self, artifact: &ModelArtifact) -> Result<(), ZerError> {
77        let mut inner = self.inner.lock().unwrap();
78        inner.artifacts.insert(artifact.fingerprint.schema_hash, artifact.clone());
79        flush(&inner)?;
80        tracing::debug!(tag = artifact.tag.as_deref(), "saved model artifact");
81        Ok(())
82    }
83
84    // ── Read ─────────────────────────────────────────────────────────────────
85
86    /// Exact lookup by schema hash. Returns `None` if no matching artifact exists.
87    pub fn get_exact(
88        &self,
89        fingerprint: &SchemaFingerprint,
90    ) -> Result<Option<ModelArtifact>, ZerError> {
91        let inner = self.inner.lock().unwrap();
92        Ok(inner.artifacts.get(&fingerprint.schema_hash).cloned())
93    }
94
95    /// Nearest-neighbor lookup: returns the closest artifact and its distance.
96    ///
97    /// Performs a full linear scan, acceptable because the registry is expected
98    /// to hold far fewer than 1 000 entries.
99    ///
100    /// Returns `None` when the registry is empty.
101    pub fn get_nearest(
102        &self,
103        fingerprint: &SchemaFingerprint,
104    ) -> Result<Option<(ModelArtifact, f32)>, ZerError> {
105        let inner = self.inner.lock().unwrap();
106        let best = inner
107            .artifacts
108            .values()
109            .map(|a| {
110                let dist = fingerprint_distance(fingerprint, &a.fingerprint);
111                (a.clone(), dist)
112            })
113            .min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap_or(std::cmp::Ordering::Equal));
114        Ok(best)
115    }
116
117    /// Determine the startup mode for an incoming dataset given its fingerprint.
118    ///
119    /// ```text
120    /// exact hash match         → WarmLoad   (skip EM entirely)
121    /// distance ≤ 0.25          → WarmStart  (2–3 EM iterations from saved init)
122    /// distance  > 0.25 / empty → ColdStart  (full EM from priors)
123    /// ```
124    pub fn lookup_startup_mode(
125        &self,
126        fingerprint: &SchemaFingerprint,
127    ) -> Result<StartupMode, ZerError> {
128        if let Some(exact) = self.get_exact(fingerprint)? {
129            tracing::info!("exact schema match, warm load");
130            return Ok(StartupMode::WarmLoad(exact));
131        }
132
133        match self.get_nearest(fingerprint)? {
134            Some((artifact, dist)) if dist <= WARM_START_THRESHOLD => {
135                tracing::info!(dist, "similar schema, warm start");
136                Ok(StartupMode::WarmStart { artifact, distance: dist })
137            }
138            _ => {
139                tracing::info!("no suitable prior, cold start");
140                Ok(StartupMode::ColdStart)
141            }
142        }
143    }
144
145    // ── Enumeration / deletion ────────────────────────────────────────────────
146
147    /// Return all stored artifacts in arbitrary order.
148    pub fn list_all(&self) -> Result<Vec<ModelArtifact>, ZerError> {
149        let inner = self.inner.lock().unwrap();
150        Ok(inner.artifacts.values().cloned().collect())
151    }
152
153    /// Delete the artifact for the given schema hash.
154    ///
155    /// Returns `true` if an artifact was found and removed, `false` otherwise.
156    pub fn delete(&self, schema_hash: &[u8; 32]) -> Result<bool, ZerError> {
157        let mut inner = self.inner.lock().unwrap();
158        let removed = inner.artifacts.remove(schema_hash).is_some();
159        if removed {
160            flush(&inner)?;
161        }
162        Ok(removed)
163    }
164}
165
166// ── File I/O ──────────────────────────────────────────────────────────────────
167
168fn flush(inner: &RegistryInner) -> Result<(), ZerError> {
169    let Some(path) = &inner.path else {
170        return Ok(());
171    };
172    let payload = bincode::serialize(&inner.artifacts)
173        .map_err(|e| ZerError::Serialization(e.to_string()))?;
174    let mut buf = Vec::with_capacity(4 + payload.len());
175    buf.extend_from_slice(MAGIC);
176    buf.extend(payload);
177    let tmp = path.with_extension("zsm.tmp");
178    std::fs::write(&tmp, &buf).map_err(|e| ZerError::Store(e.to_string()))?;
179    std::fs::rename(&tmp, path).map_err(|e| ZerError::Store(e.to_string()))?;
180    Ok(())
181}
182
183fn load(path: &Path) -> Result<HashMap<[u8; 32], ModelArtifact>, ZerError> {
184    if !path.exists() {
185        return Ok(HashMap::new());
186    }
187    let bytes = std::fs::read(path).map_err(|e| ZerError::Store(e.to_string()))?;
188    if bytes.get(..4) != Some(MAGIC) {
189        return Err(ZerError::Store("invalid .zsm magic".into()));
190    }
191    bincode::deserialize(&bytes[4..]).map_err(|e| ZerError::Serialization(e.to_string()))
192}
193
194// ── Unit tests ────────────────────────────────────────────────────────────────
195
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{
        schema::{FieldKind, SchemaBuilder},
        scoring::ModelParams,
    };

    use crate::{artifact::ModelArtifact, fingerprint::SchemaFingerprint};

    /// Fixed, untrained parameters sized for `n_fields` fields.
    fn dummy_params(n_fields: usize) -> ModelParams {
        ModelParams {
            m: vec![vec![0.02, 0.06, 0.12, 0.80]; n_fields],
            u: vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields],
            log_prior_odds: -2.0,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        }
    }

    fn make_artifact(schema: &zer_core::schema::Schema, tag: &str) -> ModelArtifact {
        ModelArtifact {
            fingerprint: SchemaFingerprint::from_schema(schema),
            params: dummy_params(schema.len()),
            tag: Some(tag.into()),
            trained_on: 0,
            em_iterations: 25,
        }
    }

    fn brp_schema() -> zer_core::schema::Schema {
        SchemaBuilder::new()
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date)
            .field("nationaliteit", FieldKind::Categorical)
            .field("postcode", FieldKind::Id)
            .build()
            .unwrap()
    }

    /// `brp_schema` with one extra categorical field appended.
    fn brp_extended_schema() -> zer_core::schema::Schema {
        SchemaBuilder::new()
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date)
            .field("nationaliteit", FieldKind::Categorical)
            .field("postcode", FieldKind::Id)
            .field("verblijfstitel", FieldKind::Categorical)
            .build()
            .unwrap()
    }

    fn sim_schema() -> zer_core::schema::Schema {
        SchemaBuilder::new()
            .field("sim_id", FieldKind::Id)
            .field("msisdn", FieldKind::Phone)
            .field("imsi", FieldKind::Id)
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date)
            .field("nationaliteit", FieldKind::Categorical)
            .build()
            .unwrap()
    }

    /// Prepare a fresh per-test scratch directory under the system temp dir and
    /// return the path of a registry file inside it (the file itself is not created).
    fn scratch_registry_path(test_name: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "zer_schema_registry_{}_{}",
            std::process::id(),
            test_name
        ));
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();
        dir.join("registry.zsm")
    }

    /// Remove the scratch directory created by `scratch_registry_path`.
    fn remove_scratch(path: &std::path::Path) {
        if let Some(dir) = path.parent() {
            let _ = std::fs::remove_dir_all(dir);
        }
    }

    #[test]
    fn roundtrip_save_and_get_exact() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let schema = brp_schema();
        let artifact = make_artifact(&schema, "brp_test");

        registry.save(&artifact).unwrap();

        let fp = SchemaFingerprint::from_schema(&schema);
        let loaded = registry.get_exact(&fp).unwrap().unwrap();

        assert_eq!(loaded.tag.as_deref(), Some("brp_test"));
        assert_eq!(
            loaded.fingerprint.schema_hash,
            artifact.fingerprint.schema_hash
        );
        assert_eq!(loaded.params.upper_threshold, artifact.params.upper_threshold);
    }

    #[test]
    fn get_exact_returns_none_for_unknown_schema() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let fp = SchemaFingerprint::from_schema(&brp_schema());
        let result = registry.get_exact(&fp).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn list_all_returns_all_artifacts() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let brp = brp_schema();
        let sim = sim_schema();

        registry.save(&make_artifact(&brp, "brp")).unwrap();
        registry.save(&make_artifact(&sim, "sim")).unwrap();

        let all = registry.list_all().unwrap();
        assert_eq!(all.len(), 2);
    }

    #[test]
    fn delete_removes_artifact_and_returns_true() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let schema = brp_schema();
        let artifact = make_artifact(&schema, "brp");
        registry.save(&artifact).unwrap();

        let removed = registry.delete(&artifact.fingerprint.schema_hash).unwrap();
        assert!(removed, "delete should return true when the key existed");

        let fp = SchemaFingerprint::from_schema(&schema);
        assert!(registry.get_exact(&fp).unwrap().is_none());
    }

    #[test]
    fn delete_returns_false_for_missing_key() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let hash = [0u8; 32];
        assert!(!registry.delete(&hash).unwrap());
    }

    #[test]
    fn startup_mode_exact_match_is_warm_load() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let schema = brp_schema();
        registry.save(&make_artifact(&schema, "brp")).unwrap();

        let fp = SchemaFingerprint::from_schema(&schema);
        let mode = registry.lookup_startup_mode(&fp).unwrap();

        assert!(
            matches!(mode, StartupMode::WarmLoad(_)),
            "exact schema match must return WarmLoad"
        );
    }

    #[test]
    fn startup_mode_added_field_is_warm_start() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        registry.save(&make_artifact(&brp_schema(), "brp")).unwrap();

        let fp = SchemaFingerprint::from_schema(&brp_extended_schema());
        let mode = registry.lookup_startup_mode(&fp).unwrap();

        assert!(
            matches!(mode, StartupMode::WarmStart { .. }),
            "one added field should return WarmStart"
        );
    }

    #[test]
    fn startup_mode_incompatible_schema_is_cold_start() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        registry.save(&make_artifact(&brp_schema(), "brp")).unwrap();

        let fp = SchemaFingerprint::from_schema(&sim_schema());
        let mode = registry.lookup_startup_mode(&fp).unwrap();

        assert!(
            matches!(mode, StartupMode::ColdStart),
            "BRP artifact vs SIM schema should return ColdStart"
        );
    }

    #[test]
    fn startup_mode_empty_registry_is_cold_start() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let fp = SchemaFingerprint::from_schema(&brp_schema());
        assert!(matches!(
            registry.lookup_startup_mode(&fp).unwrap(),
            StartupMode::ColdStart
        ));
    }

    #[test]
    fn nearest_prefers_closer_artifact() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        registry.save(&make_artifact(&brp_schema(), "brp")).unwrap();
        registry.save(&make_artifact(&sim_schema(), "sim")).unwrap();

        let (nearest, _dist) = registry
            .get_nearest(&SchemaFingerprint::from_schema(&brp_extended_schema()))
            .unwrap()
            .expect("registry is not empty");

        assert_eq!(
            nearest.tag.as_deref(),
            Some("brp"),
            "BRP-like schema should match the BRP artifact, not SIM"
        );
    }

    // ── On-disk behavior ─────────────────────────────────────────────────────

    #[test]
    fn saved_artifact_survives_reopen() {
        let path = scratch_registry_path("reopen");
        let schema = brp_schema();
        {
            let registry = SchemaRegistry::open(&path).unwrap();
            registry.save(&make_artifact(&schema, "brp")).unwrap();
        }
        assert!(path.exists(), "save must create the registry file");
        assert!(
            !path.with_extension("zsm.tmp").exists(),
            "the temp file must have been renamed into place"
        );

        let reopened = SchemaRegistry::open(&path).unwrap();
        let loaded = reopened
            .get_exact(&SchemaFingerprint::from_schema(&schema))
            .unwrap()
            .expect("artifact should be read back from disk");
        assert_eq!(loaded.tag.as_deref(), Some("brp"));
        remove_scratch(&path);
    }

    #[test]
    fn delete_is_persisted() {
        let path = scratch_registry_path("delete");
        let artifact = make_artifact(&brp_schema(), "brp");
        {
            let registry = SchemaRegistry::open(&path).unwrap();
            registry.save(&artifact).unwrap();
            assert!(registry.delete(&artifact.fingerprint.schema_hash).unwrap());
        }

        let reopened = SchemaRegistry::open(&path).unwrap();
        assert!(reopened.list_all().unwrap().is_empty());
        remove_scratch(&path);
    }

    #[test]
    fn open_missing_file_starts_empty() {
        let path = scratch_registry_path("missing");
        let registry = SchemaRegistry::open(&path).unwrap();
        assert!(registry.list_all().unwrap().is_empty());
        assert!(!path.exists(), "opening alone must not create the file");
        remove_scratch(&path);
    }

    #[test]
    fn open_rejects_wrong_magic() {
        let path = scratch_registry_path("bad_magic");
        std::fs::write(&path, b"NOPE-not-a-registry").unwrap();
        assert!(SchemaRegistry::open(&path).is_err());
        remove_scratch(&path);
    }
}