use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Error returned by every [`MlPersistence`] operation.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare errors (and whole
/// [`MlPersistenceResult`] values) directly instead of pattern-matching on them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MlPersistenceError {
    /// The storage backend failed; [`InMemoryMlPersistence`] uses this for a poisoned mutex.
    Backend(String),
    /// A persisted value was found to be corrupted; the payload describes the problem.
    /// Not constructed by the in-memory backend.
    Corruption(String),
}

impl std::fmt::Display for MlPersistenceError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Backend(msg) => write!(f, "ml persistence backend error: {msg}"),
            Self::Corruption(msg) => write!(f, "ml persistence value corrupted: {msg}"),
        }
    }
}

impl std::error::Error for MlPersistenceError {}

/// Shorthand for the result type of [`MlPersistence`] operations.
pub type MlPersistenceResult<T> = Result<T, MlPersistenceError>;
/// String key/value storage partitioned by namespace.
///
/// The `Send + Sync` supertraits let trait objects (e.g. `Arc<dyn MlPersistence>`) be shared
/// across threads. Every method reports failure as an [`MlPersistenceError`]; the in-memory
/// backend in this file only ever fails with `Backend("mutex poisoned")`.
pub trait MlPersistence: Send + Sync + std::fmt::Debug {
/// Stores `value` under `key` within `namespace`.
/// `InMemoryMlPersistence` replaces any value already stored under the same pair.
fn put(&self, namespace: &str, key: &str, value: &str) -> MlPersistenceResult<()>;
/// Returns the value stored under (`namespace`, `key`), or `Ok(None)` if there is none.
fn get(&self, namespace: &str, key: &str) -> MlPersistenceResult<Option<String>>;
/// Removes the entry for (`namespace`, `key`).
/// `InMemoryMlPersistence` returns `Ok(())` when the entry does not exist.
fn delete(&self, namespace: &str, key: &str) -> MlPersistenceResult<()>;
/// Returns every `(key, value)` pair stored in `namespace`.
/// Order is not specified; the in-memory backend yields `HashMap` iteration order.
fn list(&self, namespace: &str) -> MlPersistenceResult<Vec<(String, String)>>;
}
/// In-process [`MlPersistence`] backend that keeps every entry in a `HashMap` keyed by
/// `(namespace, key)`.
///
/// The map sits behind an `Arc`, so `clone()` is cheap and all clones share the same data.
/// Nothing is written outside the process; the data lives as long as the last clone.
#[derive(Debug, Default, Clone)]
pub struct InMemoryMlPersistence {
// Shared map; the Mutex serializes access from concurrent callers.
inner: Arc<Mutex<HashMap<(String, String), String>>>,
}
impl InMemoryMlPersistence {
    /// Creates an empty store; identical to `InMemoryMlPersistence::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Acquires the map's mutex.
    ///
    /// A poisoned mutex (a previous holder panicked) is reported as
    /// `MlPersistenceError::Backend("mutex poisoned")` instead of panicking here.
    fn lock(
        &self,
    ) -> MlPersistenceResult<std::sync::MutexGuard<'_, HashMap<(String, String), String>>> {
        match self.inner.lock() {
            Ok(guard) => Ok(guard),
            Err(_) => Err(MlPersistenceError::Backend(String::from("mutex poisoned"))),
        }
    }
}
impl MlPersistence for InMemoryMlPersistence {
    /// Inserts or replaces the value stored under (`namespace`, `key`).
    fn put(&self, namespace: &str, key: &str, value: &str) -> MlPersistenceResult<()> {
        let composite = (namespace.to_owned(), key.to_owned());
        self.lock()?.insert(composite, value.to_owned());
        Ok(())
    }

    /// Returns a copy of the stored value, or `None` when the pair is absent.
    fn get(&self, namespace: &str, key: &str) -> MlPersistenceResult<Option<String>> {
        let store = self.lock()?;
        let found = store.get(&(namespace.to_owned(), key.to_owned()));
        Ok(found.cloned())
    }

    /// Removes the pair if present; deleting a missing pair still returns `Ok(())`.
    fn delete(&self, namespace: &str, key: &str) -> MlPersistenceResult<()> {
        self.lock()?.remove(&(namespace.to_owned(), key.to_owned()));
        Ok(())
    }

    /// Scans the whole map and copies out the entries belonging to `namespace`.
    fn list(&self, namespace: &str) -> MlPersistenceResult<Vec<(String, String)>> {
        let store = self.lock()?;
        let mut entries = Vec::new();
        for ((ns, k), v) in store.iter() {
            if ns == namespace {
                entries.push((k.clone(), v.clone()));
            }
        }
        Ok(entries)
    }
}
/// Namespace names intended for the `namespace` argument of [`MlPersistence`] methods.
pub mod ns {
/// Namespace for model records (presumably keyed with `key::model` — confirm with callers).
pub const MODELS: &str = "models";
/// Namespace for model versions (presumably keyed with `key::model_version` — confirm).
pub const MODEL_VERSIONS: &str = "model_versions";
/// Namespace for jobs (presumably keyed with `key::job` — confirm with callers).
pub const JOBS: &str = "jobs";
}
/// Canonical key encodings plus their inverse parsers.
///
/// Each `parse_*` function accepts exactly the strings its encoder produces. So
/// `parse(encode(x)) == Some(x)`, and for every accepted `raw`, `encode(parse(raw)) == raw`.
/// Each id therefore maps to exactly one stored key.
pub mod key {
    /// Key for a model record: the model name, unchanged.
    pub fn model(name: &str) -> String {
        name.to_string()
    }

    /// Key for a specific model version, formatted as `"{model}@v{version}"`.
    pub fn model_version(model: &str, version: u32) -> String {
        format!("{model}@v{version}")
    }

    /// Key for a job: the id as 32 lowercase, zero-padded hexadecimal digits.
    pub fn job(id: u128) -> String {
        format!("{id:032x}")
    }

    /// Inverse of [`job`].
    ///
    /// Returns `None` unless `raw` is exactly 32 lowercase hex digits. On its own,
    /// `u128::from_str_radix` also accepts a leading `+` and uppercase digits. [`job`] never
    /// emits either, so both are rejected here.
    pub fn parse_job(raw: &str) -> Option<u128> {
        let canonical =
            raw.len() == 32 && raw.bytes().all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'));
        if !canonical {
            return None;
        }
        u128::from_str_radix(raw, 16).ok()
    }

    /// Inverse of [`model_version`].
    ///
    /// Splits on the *last* `"@v"`, so model names may themselves contain `"@v"`. The version
    /// must be plain decimal digits with no sign and no leading zeros (except `"0"` itself),
    /// matching what [`model_version`] produces. `str::parse::<u32>` alone would also accept
    /// `"+7"` and `"007"`.
    pub fn parse_model_version(raw: &str) -> Option<(String, u32)> {
        let (model, digits) = raw.rsplit_once("@v")?;
        let canonical = !digits.is_empty()
            && digits.bytes().all(|b| b.is_ascii_digit())
            && (digits == "0" || !digits.starts_with('0'));
        if !canonical {
            return None;
        }
        // Still fallible: an all-digit string may overflow u32.
        let version = digits.parse::<u32>().ok()?;
        Some((model.to_string(), version))
    }
}
// Unit tests for the in-memory backend and the key helpers. Every test here holds
// against the behavior visible in this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn in_memory_put_then_get() {
        let p = InMemoryMlPersistence::new();
        p.put("jobs", "abc", "{\"status\":\"queued\"}").unwrap();
        assert_eq!(
            p.get("jobs", "abc").unwrap().as_deref(),
            Some("{\"status\":\"queued\"}")
        );
    }

    #[test]
    fn in_memory_get_missing_returns_none() {
        let p = InMemoryMlPersistence::new();
        assert!(p.get("jobs", "nope").unwrap().is_none());
    }

    #[test]
    fn in_memory_delete_is_idempotent() {
        let p = InMemoryMlPersistence::new();
        p.delete("jobs", "missing").unwrap();
        p.put("jobs", "k", "v").unwrap();
        p.delete("jobs", "k").unwrap();
        assert!(p.get("jobs", "k").unwrap().is_none());
    }

    #[test]
    fn in_memory_list_scopes_to_namespace() {
        let p = InMemoryMlPersistence::new();
        p.put("jobs", "j1", "a").unwrap();
        p.put("jobs", "j2", "b").unwrap();
        p.put("models", "spam", "{}").unwrap();
        let mut jobs = p.list("jobs").unwrap();
        jobs.sort();
        assert_eq!(
            jobs,
            vec![
                ("j1".to_string(), "a".to_string()),
                ("j2".to_string(), "b".to_string())
            ]
        );
        assert_eq!(p.list("models").unwrap().len(), 1);
    }

    #[test]
    fn in_memory_put_overwrites_existing_value() {
        let p = InMemoryMlPersistence::new();
        p.put("jobs", "k", "old").unwrap();
        p.put("jobs", "k", "new").unwrap();
        assert_eq!(p.get("jobs", "k").unwrap().as_deref(), Some("new"));
        assert_eq!(p.list("jobs").unwrap().len(), 1);
    }

    #[test]
    fn in_memory_same_key_is_isolated_per_namespace() {
        let p = InMemoryMlPersistence::new();
        p.put(ns::JOBS, "k", "job").unwrap();
        p.put(ns::MODELS, "k", "model").unwrap();
        p.delete(ns::JOBS, "k").unwrap();
        assert!(p.get(ns::JOBS, "k").unwrap().is_none());
        assert_eq!(p.get(ns::MODELS, "k").unwrap().as_deref(), Some("model"));
    }

    #[test]
    fn in_memory_list_unknown_namespace_is_empty() {
        let p = InMemoryMlPersistence::new();
        p.put(ns::JOBS, "j", "v").unwrap();
        assert!(p.list(ns::MODEL_VERSIONS).unwrap().is_empty());
    }

    #[test]
    fn in_memory_clones_share_state() {
        let p = InMemoryMlPersistence::new();
        let q = p.clone();
        p.put("jobs", "k", "v").unwrap();
        assert_eq!(q.get("jobs", "k").unwrap().as_deref(), Some("v"));
    }

    #[test]
    fn in_memory_works_as_shared_trait_object() {
        let p: std::sync::Arc<dyn MlPersistence> =
            std::sync::Arc::new(InMemoryMlPersistence::new());
        let writer = std::sync::Arc::clone(&p);
        std::thread::spawn(move || writer.put("jobs", "k", "v").unwrap())
            .join()
            .unwrap();
        assert_eq!(p.get("jobs", "k").unwrap().as_deref(), Some("v"));
    }

    #[test]
    fn in_memory_poisoned_lock_reports_backend_error() {
        let p = InMemoryMlPersistence::new();
        let q = p.clone();
        // Panic while holding the lock so the shared mutex becomes poisoned.
        let outcome = std::thread::spawn(move || {
            let _guard = q.inner.lock().unwrap();
            panic!("poison the mutex");
        })
        .join();
        assert!(outcome.is_err());
        assert!(matches!(
            p.get("jobs", "k"),
            Err(MlPersistenceError::Backend(_))
        ));
        assert!(matches!(
            p.put("jobs", "k", "v"),
            Err(MlPersistenceError::Backend(_))
        ));
    }

    #[test]
    fn job_key_round_trips() {
        let id = 0x0123_4567_89ab_cdef_0123_4567_89ab_cdef_u128;
        let raw = key::job(id);
        assert_eq!(raw.len(), 32);
        assert_eq!(key::parse_job(&raw), Some(id));
    }

    #[test]
    fn job_key_round_trips_extremes() {
        for id in [0_u128, 1, u128::MAX] {
            let raw = key::job(id);
            assert_eq!(raw.len(), 32);
            assert_eq!(key::parse_job(&raw), Some(id));
        }
    }

    #[test]
    fn job_key_rejects_wrong_length() {
        assert!(key::parse_job("abc").is_none());
        assert!(key::parse_job(&"0".repeat(31)).is_none());
        assert!(key::parse_job(&"0".repeat(33)).is_none());
    }

    #[test]
    fn model_version_key_round_trips() {
        let raw = key::model_version("spam_classifier", 42);
        assert_eq!(raw, "spam_classifier@v42");
        assert_eq!(
            key::parse_model_version(&raw),
            Some(("spam_classifier".to_string(), 42))
        );
    }

    #[test]
    fn model_version_key_survives_at_in_name() {
        let raw = "weird@name@v7";
        assert_eq!(
            key::parse_model_version(raw),
            Some(("weird@name".to_string(), 7))
        );
    }

    #[test]
    fn model_version_key_rejects_non_numeric_version() {
        assert!(key::parse_model_version("spam@vfoo").is_none());
    }
}