Skip to main content

reddb_server/storage/ml/
runtime.rs

//! `MlRuntime` — convenience bundle of [`ModelRegistry`] + [`MlJobQueue`]
//! so feature code (classifier, symbolic, semantic cache, …) only
//! needs to hold a single handle.
//!
//! The runtime is independent of [`crate::runtime::RedDBRuntime`]: it
//! can be constructed standalone (in-memory) for tests or bound to a
//! shared [`MlPersistence`] backend for durable deployments. A
//! future sprint will add a `RedDBRuntime::ml()` accessor that
//! returns the bound instance — this module provides the pieces that
//! accessor will wire up.
11
12use std::sync::Arc;
13
14use super::persist::{InMemoryMlPersistence, MlPersistence};
15use super::queue::{MlJobQueue, MlWorkFn};
16use super::registry::ModelRegistry;
17
18/// Shared entrypoint used by every ML feature.
19///
20/// Cloning the runtime is cheap — registry and queue both wrap
21/// `Arc`s internally.
22#[derive(Clone)]
23pub struct MlRuntime {
24    registry: ModelRegistry,
25    queue: MlJobQueue,
26    backend: Arc<dyn MlPersistence>,
27}
28
29impl std::fmt::Debug for MlRuntime {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("MlRuntime")
32            .field("registry", &self.registry)
33            .field("queue", &self.queue)
34            .finish()
35    }
36}
37
/// Tunables for an [`MlRuntime`].
///
/// The [`Default`] impl sizes the worker pool from the host's available
/// parallelism at run time (not at compile time). Production callers
/// can supply their own worker count instead.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MlRuntimeConfig {
    /// Number of worker threads the job queue is started with.
    ///
    /// NOTE(review): the value is forwarded to
    /// `MlJobQueue::start_with_backend` unchanged. Whether `0` is
    /// rejected, clamped, or produces a queue that never runs jobs is
    /// decided there — confirm before passing `0`.
    pub worker_count: usize,
}

impl Default for MlRuntimeConfig {
    fn default() -> Self {
        Self {
            worker_count: default_worker_count(),
        }
    }
}

/// Worker count used by [`MlRuntimeConfig::default`].
///
/// Returns one fewer than the logical CPU count, but never less than
/// one. If the CPU count cannot be determined, two logical CPUs are
/// assumed, which yields a single worker.
fn default_worker_count() -> usize {
    let logical = std::thread::available_parallelism().map_or(2, std::num::NonZeroUsize::get);
    // Leave one core for OLTP; guarantee at least one worker.
    logical.saturating_sub(1).max(1)
}
60
61impl MlRuntime {
62    /// Build a fully in-memory runtime. Jobs and versions disappear
63    /// on drop — good for unit tests, bad for production.
64    pub fn in_memory(worker_fn: MlWorkFn) -> Self {
65        Self::with_backend(
66            Arc::new(InMemoryMlPersistence::new()),
67            worker_fn,
68            MlRuntimeConfig::default(),
69        )
70    }
71
72    /// Build a runtime that persists registry + job state into
73    /// `backend`. On construction the registry and queue rehydrate
74    /// automatically so prior state is observable immediately.
75    pub fn with_backend(
76        backend: Arc<dyn MlPersistence>,
77        worker_fn: MlWorkFn,
78        config: MlRuntimeConfig,
79    ) -> Self {
80        let registry = ModelRegistry::with_backend(Arc::clone(&backend));
81        let queue =
82            MlJobQueue::start_with_backend(config.worker_count, worker_fn, Arc::clone(&backend));
83        Self {
84            registry,
85            queue,
86            backend,
87        }
88    }
89
90    pub fn registry(&self) -> &ModelRegistry {
91        &self.registry
92    }
93
94    pub fn queue(&self) -> &MlJobQueue {
95        &self.queue
96    }
97
98    /// Access the raw persistence backend — used by features that
99    /// need their own namespace (e.g. semantic cache stats).
100    pub fn backend(&self) -> &Arc<dyn MlPersistence> {
101        &self.backend
102    }
103
104    /// Stop worker threads. Idempotent — safe to call more than
105    /// once; subsequent calls are no-ops.
106    pub fn shutdown(&self) {
107        self.queue.shutdown();
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::super::jobs::{MlJobKind, MlJobStatus};
    use super::*;
    use std::time::{Duration, Instant};

    /// Polls `predicate` every 5 ms until it returns `true` or `timeout`
    /// elapses. Returns the result of one final check, so a predicate
    /// that becomes true right at the deadline still counts.
    fn wait_until<F: Fn() -> bool>(predicate: F, timeout: Duration) -> bool {
        let deadline = Instant::now() + timeout;
        while Instant::now() < deadline {
            if predicate() {
                return true;
            }
            std::thread::sleep(Duration::from_millis(5));
        }
        predicate()
    }

    #[test]
    fn in_memory_runtime_runs_a_training_job() {
        let rt = MlRuntime::in_memory(Arc::new(|_| Ok("{\"ok\":true}".to_string())));
        let id = rt.queue().submit(MlJobKind::Train, "spam", "{}");
        assert!(wait_until(
            || rt
                .queue()
                .get(id)
                .map(|j| j.status == MlJobStatus::Completed)
                .unwrap_or(false),
            Duration::from_secs(2),
        ));
        rt.shutdown();
    }

    #[test]
    fn runtime_exposes_registry() {
        let rt = MlRuntime::in_memory(Arc::new(|_| Ok("{}".to_string())));
        assert_eq!(rt.registry().summaries().unwrap().len(), 0);
        rt.shutdown();
    }

    /// `MlRuntime::shutdown` is documented as idempotent; a second call
    /// must be a no-op rather than a panic or deadlock.
    #[test]
    fn shutdown_is_idempotent() {
        let rt = MlRuntime::in_memory(Arc::new(|_| Ok("{}".to_string())));
        rt.shutdown();
        rt.shutdown();
    }

    /// The default config must never ask the queue for zero workers.
    #[test]
    fn default_config_has_at_least_one_worker() {
        assert!(MlRuntimeConfig::default().worker_count >= 1);
    }

    #[test]
    fn debug_output_names_the_runtime() {
        let rt = MlRuntime::in_memory(Arc::new(|_| Ok("{}".to_string())));
        assert!(format!("{:?}", rt).starts_with("MlRuntime"));
        rt.shutdown();
    }
}