Skip to main content

reddb_server/storage/ml/
persist.rs

//! Persistence abstraction for the ML subsystem.
//!
//! The registry and job queue call into a small [`MlPersistence`]
//! trait rather than touching the storage engine directly. That
//! keeps this module independent of `RedDBRuntime` / `StorageService`,
//! so it stays unit-testable and a runtime binding can be plugged
//! in later without reshaping callers.
//!
//! The default in-crate backend is [`InMemoryMlPersistence`] — a
//! thread-safe hashmap. A future sprint will add a `RedConfigMlPersistence`
//! that writes to the `red.ml.*` KV tree so state survives restart,
//! backup, and replica sync.
//!
//! The surface is intentionally small: three namespaces
//! (`"models"`, `"model_versions"`, `"jobs"`), CRUD by string key,
//! and a list operation per namespace. All values are encoded as
//! JSON strings; the registry / queue own the schema inside each
//! value.
19
20use std::collections::HashMap;
21use std::sync::{Arc, Mutex};
22
/// Errors surfaced by a persistence backend. Intentionally small —
/// callers convert into their own error types.
///
/// Implements `PartialEq` / `Eq` so errors (and results carrying
/// them) can be compared directly, e.g. with `assert_eq!` in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MlPersistenceError {
    /// Underlying store returned an error. Message is backend-specific.
    Backend(String),
    /// Value was expected to parse as JSON but did not.
    Corruption(String),
}
32
33impl std::fmt::Display for MlPersistenceError {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        match self {
36            MlPersistenceError::Backend(msg) => write!(f, "ml persistence backend error: {msg}"),
37            MlPersistenceError::Corruption(msg) => {
38                write!(f, "ml persistence value corrupted: {msg}")
39            }
40        }
41    }
42}
43
44impl std::error::Error for MlPersistenceError {}
45
46pub type MlPersistenceResult<T> = Result<T, MlPersistenceError>;
47
48/// Backend-agnostic storage surface the registry and job queue use.
49///
50/// Implementations must be `Send + Sync` because multiple worker
51/// threads touch the same instance concurrently. Writes must be
52/// at-least-once durable — the caller will re-issue a write on a
53/// re-transition rather than relying on partial success semantics.
54pub trait MlPersistence: Send + Sync + std::fmt::Debug {
55    /// Store `value` under `(namespace, key)`. Overwrites any
56    /// existing value.
57    fn put(&self, namespace: &str, key: &str, value: &str) -> MlPersistenceResult<()>;
58
59    /// Fetch the value at `(namespace, key)`, if any.
60    fn get(&self, namespace: &str, key: &str) -> MlPersistenceResult<Option<String>>;
61
62    /// Drop the entry at `(namespace, key)`. Returns `Ok(())` whether
63    /// the key existed or not — callers do not distinguish.
64    fn delete(&self, namespace: &str, key: &str) -> MlPersistenceResult<()>;
65
66    /// Enumerate every `(key, value)` in `namespace`. Ordering is
67    /// implementation-defined. Callers that need deterministic order
68    /// sort the result themselves.
69    fn list(&self, namespace: &str) -> MlPersistenceResult<Vec<(String, String)>>;
70}
71
/// In-memory backend used by tests and as the default. Entries live
/// in a single hashmap keyed by the `(namespace, key)` pair.
///
/// A clone is another handle to the same map, not a copy of it.
#[derive(Clone, Debug, Default)]
pub struct InMemoryMlPersistence {
    // `Arc` makes clones share one map; `Mutex` serialises access to it.
    inner: Arc<Mutex<HashMap<(String, String), String>>>,
}
78
79impl InMemoryMlPersistence {
80    pub fn new() -> Self {
81        Self::default()
82    }
83
84    fn lock(
85        &self,
86    ) -> MlPersistenceResult<std::sync::MutexGuard<'_, HashMap<(String, String), String>>> {
87        self.inner
88            .lock()
89            .map_err(|_| MlPersistenceError::Backend("mutex poisoned".to_string()))
90    }
91}
92
93impl MlPersistence for InMemoryMlPersistence {
94    fn put(&self, namespace: &str, key: &str, value: &str) -> MlPersistenceResult<()> {
95        let mut guard = self.lock()?;
96        guard.insert((namespace.to_string(), key.to_string()), value.to_string());
97        Ok(())
98    }
99
100    fn get(&self, namespace: &str, key: &str) -> MlPersistenceResult<Option<String>> {
101        let guard = self.lock()?;
102        Ok(guard
103            .get(&(namespace.to_string(), key.to_string()))
104            .cloned())
105    }
106
107    fn delete(&self, namespace: &str, key: &str) -> MlPersistenceResult<()> {
108        let mut guard = self.lock()?;
109        guard.remove(&(namespace.to_string(), key.to_string()));
110        Ok(())
111    }
112
113    fn list(&self, namespace: &str) -> MlPersistenceResult<Vec<(String, String)>> {
114        let guard = self.lock()?;
115        Ok(guard
116            .iter()
117            .filter(|((ns, _), _)| ns == namespace)
118            .map(|((_, k), v)| (k.clone(), v.clone()))
119            .collect())
120    }
121}
122
/// Namespace names. Exposed as `pub const` so the registry and queue
/// modules refer to the same strings, and so a future backend can
/// map them onto the `red.ml.*` KV tree.
pub mod ns {
    /// Namespace holding model records.
    pub const MODELS: &str = "models";
    /// Namespace holding per-version model records.
    pub const MODEL_VERSIONS: &str = "model_versions";
    /// Namespace holding job records.
    pub const JOBS: &str = "jobs";
}
131
/// Composite key helpers. Callers build keys via these helpers so a
/// future schema migration only needs to update one place.
pub mod key {
    /// Key for a model record: the model name, unchanged.
    pub fn model(name: &str) -> String {
        name.to_string()
    }

    /// Key for one version of a model, formatted as `{model}@v{version}`.
    pub fn model_version(model: &str, version: u32) -> String {
        format!("{model}@v{version}")
    }

    /// Key for a job: the id rendered as 32 lowercase hex digits.
    pub fn job(id: u128) -> String {
        // Zero-padded hex — sort order is deterministic without extra
        // allocations, and u128 fits in 32 hex characters exactly.
        format!("{id:032x}")
    }

    /// Parse a `job(id)` key back into an id. Returns `None` on any
    /// malformed key — callers skip rather than error out so a single
    /// poisoned record cannot poison a startup sweep.
    ///
    /// A key is well-formed when it is exactly 32 ASCII hex digits.
    /// Upper-case digits are still accepted, as before; a leading sign
    /// is not.
    pub fn parse_job(raw: &str) -> Option<u128> {
        // `from_str_radix` allows an optional leading `+`, so a 32-byte
        // string like "+000…01" would otherwise parse even though
        // `job` never produces it. Insist on hex digits throughout.
        if raw.len() != 32 || !raw.bytes().all(|b| b.is_ascii_hexdigit()) {
            return None;
        }
        u128::from_str_radix(raw, 16).ok()
    }

    /// Parse a `model_version` key back into `(model, version)`.
    ///
    /// The split happens at the *last* `@v`, so model names containing
    /// `@` survive. Returns `None` when the marker is missing or the
    /// version is not a plain decimal `u32` (empty, signed,
    /// non-numeric, or out of range).
    pub fn parse_model_version(raw: &str) -> Option<(String, u32)> {
        let (model, rest) = raw.rsplit_once("@v")?;
        // `u32::from_str` tolerates a leading `+`; `model_version`
        // never emits one, so treat it as malformed.
        if !rest.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        let version = rest.parse::<u32>().ok()?;
        Some((model.to_string(), version))
    }
}
166
// Unit tests for the in-memory backend and the key helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn in_memory_put_then_get() {
        let p = InMemoryMlPersistence::new();
        p.put("jobs", "abc", "{\"status\":\"queued\"}").unwrap();
        assert_eq!(
            p.get("jobs", "abc").unwrap().as_deref(),
            Some("{\"status\":\"queued\"}")
        );
    }

    #[test]
    fn in_memory_put_overwrites_existing_value() {
        // The trait contract says `put` replaces any existing value.
        let p = InMemoryMlPersistence::new();
        p.put("jobs", "k", "first").unwrap();
        p.put("jobs", "k", "second").unwrap();
        assert_eq!(p.get("jobs", "k").unwrap().as_deref(), Some("second"));
        assert_eq!(p.list("jobs").unwrap().len(), 1);
    }

    #[test]
    fn in_memory_get_missing_returns_none() {
        let p = InMemoryMlPersistence::new();
        assert!(p.get("jobs", "nope").unwrap().is_none());
    }

    #[test]
    fn in_memory_delete_is_idempotent() {
        let p = InMemoryMlPersistence::new();
        p.delete("jobs", "missing").unwrap();
        p.put("jobs", "k", "v").unwrap();
        p.delete("jobs", "k").unwrap();
        assert!(p.get("jobs", "k").unwrap().is_none());
    }

    #[test]
    fn in_memory_list_scopes_to_namespace() {
        let p = InMemoryMlPersistence::new();
        p.put("jobs", "j1", "a").unwrap();
        p.put("jobs", "j2", "b").unwrap();
        p.put("models", "spam", "{}").unwrap();
        let mut jobs = p.list("jobs").unwrap();
        jobs.sort();
        assert_eq!(
            jobs,
            vec![
                ("j1".to_string(), "a".to_string()),
                ("j2".to_string(), "b".to_string())
            ]
        );
        assert_eq!(p.list("models").unwrap().len(), 1);
    }

    #[test]
    fn in_memory_namespace_and_key_do_not_bleed() {
        // ("a", "bc") and ("ab", "c") must be distinct entries.
        let p = InMemoryMlPersistence::new();
        p.put("a", "bc", "1").unwrap();
        assert!(p.get("ab", "c").unwrap().is_none());
        assert!(p.list("ab").unwrap().is_empty());
    }

    #[test]
    fn in_memory_clone_shares_state() {
        // Cloning the handle is documented to share the underlying map.
        let writer = InMemoryMlPersistence::new();
        let reader = writer.clone();
        writer.put("models", "spam", "{}").unwrap();
        assert_eq!(reader.get("models", "spam").unwrap().as_deref(), Some("{}"));
        reader.delete("models", "spam").unwrap();
        assert!(writer.get("models", "spam").unwrap().is_none());
    }

    #[test]
    fn job_key_round_trips() {
        let id = 0x0123_4567_89ab_cdef_0123_4567_89ab_cdef_u128;
        let raw = key::job(id);
        assert_eq!(raw.len(), 32);
        assert_eq!(key::parse_job(&raw), Some(id));
    }

    #[test]
    fn job_key_is_zero_padded_lowercase_hex() {
        assert_eq!(key::job(0xAB), format!("{}ab", "0".repeat(30)));
        assert_eq!(key::job(0), "0".repeat(32));
        assert_eq!(key::job(u128::MAX), "f".repeat(32));
    }

    #[test]
    fn job_key_rejects_wrong_length() {
        assert!(key::parse_job("abc").is_none());
        assert!(key::parse_job(&"0".repeat(31)).is_none());
        assert!(key::parse_job(&"0".repeat(33)).is_none());
    }

    #[test]
    fn job_key_rejects_non_hex() {
        assert!(key::parse_job(&"z".repeat(32)).is_none());
    }

    #[test]
    fn model_version_key_round_trips() {
        let raw = key::model_version("spam_classifier", 42);
        assert_eq!(raw, "spam_classifier@v42");
        assert_eq!(
            key::parse_model_version(&raw),
            Some(("spam_classifier".to_string(), 42))
        );
    }

    #[test]
    fn model_version_key_survives_at_in_name() {
        // Model names could in theory contain '@' — rsplit_once picks
        // the *last* occurrence, which is the `@v` prefix.
        let raw = "weird@name@v7";
        assert_eq!(
            key::parse_model_version(raw),
            Some(("weird@name".to_string(), 7))
        );
    }

    #[test]
    fn model_version_key_rejects_non_numeric_version() {
        assert!(key::parse_model_version("spam@vfoo").is_none());
    }

    #[test]
    fn model_version_key_rejects_missing_marker_or_version() {
        assert!(key::parse_model_version("spam").is_none());
        assert!(key::parse_model_version("spam@v").is_none());
    }
}