reddb_server/storage/ml/
persist.rs1use std::collections::HashMap;
21use std::sync::{Arc, Mutex};
22
/// Failure reported by an [`MlPersistence`] backend.
///
/// Both variants carry a free-form, human-readable detail message. `PartialEq`/`Eq`
/// are derived so callers and tests can compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MlPersistenceError {
    /// The storage backend itself failed. The in-memory store reports a poisoned
    /// mutex this way.
    Backend(String),
    /// A stored value could not be interpreted. Not produced by the in-memory store;
    /// presumably raised by code that decodes persisted payloads (TODO confirm).
    Corruption(String),
}
32
33impl std::fmt::Display for MlPersistenceError {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 match self {
36 MlPersistenceError::Backend(msg) => write!(f, "ml persistence backend error: {msg}"),
37 MlPersistenceError::Corruption(msg) => {
38 write!(f, "ml persistence value corrupted: {msg}")
39 }
40 }
41 }
42}
43
44impl std::error::Error for MlPersistenceError {}
45
46pub type MlPersistenceResult<T> = Result<T, MlPersistenceError>;
47
48pub trait MlPersistence: Send + Sync + std::fmt::Debug {
55 fn put(&self, namespace: &str, key: &str, value: &str) -> MlPersistenceResult<()>;
58
59 fn get(&self, namespace: &str, key: &str) -> MlPersistenceResult<Option<String>>;
61
62 fn delete(&self, namespace: &str, key: &str) -> MlPersistenceResult<()>;
65
66 fn list(&self, namespace: &str) -> MlPersistenceResult<Vec<(String, String)>>;
70}
71
/// Process-local [`MlPersistence`] backend that keeps every entry in a `HashMap`
/// guarded by a mutex; nothing is written to disk.
///
/// `Clone` copies the `Arc`, so clones share one underlying map rather than taking a
/// snapshot: a write through any clone is visible through all of them.
#[derive(Clone, Debug, Default)]
pub struct InMemoryMlPersistence {
    /// `(namespace, key)` → value, shared by all clones.
    inner: Arc<Mutex<HashMap<(String, String), String>>>,
}
78
79impl InMemoryMlPersistence {
80 pub fn new() -> Self {
81 Self::default()
82 }
83
84 fn lock(
85 &self,
86 ) -> MlPersistenceResult<std::sync::MutexGuard<'_, HashMap<(String, String), String>>> {
87 self.inner
88 .lock()
89 .map_err(|_| MlPersistenceError::Backend("mutex poisoned".to_string()))
90 }
91}
92
93impl MlPersistence for InMemoryMlPersistence {
94 fn put(&self, namespace: &str, key: &str, value: &str) -> MlPersistenceResult<()> {
95 let mut guard = self.lock()?;
96 guard.insert((namespace.to_string(), key.to_string()), value.to_string());
97 Ok(())
98 }
99
100 fn get(&self, namespace: &str, key: &str) -> MlPersistenceResult<Option<String>> {
101 let guard = self.lock()?;
102 Ok(guard
103 .get(&(namespace.to_string(), key.to_string()))
104 .cloned())
105 }
106
107 fn delete(&self, namespace: &str, key: &str) -> MlPersistenceResult<()> {
108 let mut guard = self.lock()?;
109 guard.remove(&(namespace.to_string(), key.to_string()));
110 Ok(())
111 }
112
113 fn list(&self, namespace: &str) -> MlPersistenceResult<Vec<(String, String)>> {
114 let guard = self.lock()?;
115 Ok(guard
116 .iter()
117 .filter(|((ns, _), _)| ns == namespace)
118 .map(|((_, k), v)| (k.clone(), v.clone()))
119 .collect())
120 }
121}
122
/// Namespace identifiers for ML records, intended for the `namespace` argument of the
/// `MlPersistence` methods.
pub mod ns {
    /// Job records (keys presumably built with `key::job`).
    pub const JOBS: &str = "jobs";
    /// Model records (keys presumably built with `key::model`).
    pub const MODELS: &str = "models";
    /// Per-version model records (keys presumably built with `key::model_version`).
    pub const MODEL_VERSIONS: &str = "model_versions";
}
131
/// Builders and parsers for the storage keys used by the ML persistence layer.
///
/// Each `parse_*` function accepts exactly the strings its builder produces. That
/// makes a key read back from storage map to a single value, and that value map back
/// to the very same key: no two distinct keys alias one id.
pub mod key {
    /// Key for a model record: the model name, used verbatim.
    pub fn model(name: &str) -> String {
        name.to_string()
    }

    /// Key for one version of a model, formatted as `{model}@v{version}`
    /// (e.g. `spam_classifier@v42`).
    pub fn model_version(model: &str, version: u32) -> String {
        format!("{model}@v{version}")
    }

    /// Key for a job: `id` rendered as exactly 32 lowercase, zero-padded hex digits.
    ///
    /// Because the width is fixed, lexicographic order of these keys matches numeric
    /// order of the ids.
    pub fn job(id: u128) -> String {
        format!("{id:032x}")
    }

    /// Inverse of [`job`]: parses a 32-digit lowercase hex key back into its id.
    ///
    /// Returns `None` for any string `job` could not have produced. The explicit digit
    /// check is required because `u128::from_str_radix` on its own also accepts a
    /// leading `+` (31 digits plus the sign still passes the length check) and
    /// uppercase hex digits. Either would let a second, non-canonical key alias an
    /// existing job id.
    pub fn parse_job(raw: &str) -> Option<u128> {
        let canonical = raw.len() == 32
            && raw
                .bytes()
                .all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'));
        if !canonical {
            return None;
        }
        u128::from_str_radix(raw, 16).ok()
    }

    /// Inverse of [`model_version`]: splits `{model}@v{version}` into its parts.
    ///
    /// Splits on the *last* `@v`, so model names that themselves contain `@` or `@v`
    /// survive the round trip. The version must be in canonical decimal form: ASCII
    /// digits only, no sign, and no leading zeros except the single digit `0`. That is
    /// exactly what `model_version` emits. `u32::from_str` would otherwise also accept
    /// `+7` and `007`, which alias `@v7`. Versions that overflow `u32` yield `None`.
    pub fn parse_model_version(raw: &str) -> Option<(String, u32)> {
        let (model, digits) = raw.rsplit_once("@v")?;
        let canonical = !digits.is_empty()
            && digits.bytes().all(|b| b.is_ascii_digit())
            && (digits.len() == 1 || !digits.starts_with('0'));
        if !canonical {
            return None;
        }
        let version = digits.parse::<u32>().ok()?;
        Some((model.to_string(), version))
    }
}
166
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn in_memory_put_then_get() {
        let p = InMemoryMlPersistence::new();
        p.put("jobs", "abc", "{\"status\":\"queued\"}").unwrap();
        assert_eq!(
            p.get("jobs", "abc").unwrap().as_deref(),
            Some("{\"status\":\"queued\"}")
        );
    }

    #[test]
    fn in_memory_get_missing_returns_none() {
        let p = InMemoryMlPersistence::new();
        assert!(p.get("jobs", "nope").unwrap().is_none());
    }

    #[test]
    fn in_memory_delete_is_idempotent() {
        let p = InMemoryMlPersistence::new();
        p.delete("jobs", "missing").unwrap();
        p.put("jobs", "k", "v").unwrap();
        p.delete("jobs", "k").unwrap();
        assert!(p.get("jobs", "k").unwrap().is_none());
    }

    #[test]
    fn in_memory_list_scopes_to_namespace() {
        let p = InMemoryMlPersistence::new();
        p.put("jobs", "j1", "a").unwrap();
        p.put("jobs", "j2", "b").unwrap();
        p.put("models", "spam", "{}").unwrap();
        // `list` order is unspecified (HashMap iteration), so sort before comparing.
        let mut jobs = p.list("jobs").unwrap();
        jobs.sort();
        assert_eq!(
            jobs,
            vec![
                ("j1".to_string(), "a".to_string()),
                ("j2".to_string(), "b".to_string())
            ]
        );
        assert_eq!(p.list("models").unwrap().len(), 1);
    }

    #[test]
    fn in_memory_put_overwrites_existing_value() {
        let p = InMemoryMlPersistence::new();
        p.put("models", "spam", "v1").unwrap();
        p.put("models", "spam", "v2").unwrap();
        assert_eq!(p.get("models", "spam").unwrap().as_deref(), Some("v2"));
        assert_eq!(p.list("models").unwrap().len(), 1);
    }

    #[test]
    fn in_memory_same_key_in_other_namespace_is_independent() {
        let p = InMemoryMlPersistence::new();
        p.put("jobs", "k", "job").unwrap();
        p.put("models", "k", "model").unwrap();
        p.delete("jobs", "k").unwrap();
        assert!(p.get("jobs", "k").unwrap().is_none());
        assert_eq!(p.get("models", "k").unwrap().as_deref(), Some("model"));
    }

    #[test]
    fn in_memory_clones_share_storage() {
        // `Clone` copies the inner `Arc`, so writes through one handle are visible
        // through the other.
        let a = InMemoryMlPersistence::new();
        let b = a.clone();
        a.put("jobs", "k", "v").unwrap();
        assert_eq!(b.get("jobs", "k").unwrap().as_deref(), Some("v"));
    }

    #[test]
    fn in_memory_list_of_unknown_namespace_is_empty() {
        let p = InMemoryMlPersistence::new();
        p.put("jobs", "k", "v").unwrap();
        assert!(p.list("nope").unwrap().is_empty());
    }

    #[test]
    fn error_display_names_the_failure_kind() {
        assert_eq!(
            MlPersistenceError::Backend("disk full".to_string()).to_string(),
            "ml persistence backend error: disk full"
        );
        assert_eq!(
            MlPersistenceError::Corruption("bad json".to_string()).to_string(),
            "ml persistence value corrupted: bad json"
        );
    }

    #[test]
    fn job_key_round_trips() {
        let id = 0x0123_4567_89ab_cdef_0123_4567_89ab_cdef_u128;
        let raw = key::job(id);
        assert_eq!(raw.len(), 32);
        assert_eq!(key::parse_job(&raw), Some(id));
    }

    #[test]
    fn job_key_is_fixed_width_lowercase_hex() {
        assert_eq!(key::job(0), "0".repeat(32));
        assert_eq!(key::job(u128::MAX), "f".repeat(32));
        assert_eq!(key::parse_job(&key::job(u128::MAX)), Some(u128::MAX));
    }

    #[test]
    fn job_key_rejects_wrong_length() {
        assert!(key::parse_job("abc").is_none());
        assert!(key::parse_job(&"0".repeat(31)).is_none());
        assert!(key::parse_job(&"0".repeat(33)).is_none());
    }

    #[test]
    fn model_version_key_round_trips() {
        let raw = key::model_version("spam_classifier", 42);
        assert_eq!(raw, "spam_classifier@v42");
        assert_eq!(
            key::parse_model_version(&raw),
            Some(("spam_classifier".to_string(), 42))
        );
    }

    #[test]
    fn model_version_key_survives_at_in_name() {
        let raw = "weird@name@v7";
        // The parser splits on the last `@v`, so an `@` inside the model name is kept.
        assert_eq!(
            key::parse_model_version(raw),
            Some(("weird@name".to_string(), 7))
        );
    }

    #[test]
    fn model_version_key_rejects_non_numeric_version() {
        assert!(key::parse_model_version("spam@vfoo").is_none());
    }
}