reddb_server/storage/ml/
runtime.rs1use std::sync::Arc;
13
14use super::persist::{InMemoryMlPersistence, MlPersistence};
15use super::queue::{MlJobQueue, MlWorkFn};
16use super::registry::ModelRegistry;
17
18#[derive(Clone)]
23pub struct MlRuntime {
24 registry: ModelRegistry,
25 queue: MlJobQueue,
26 backend: Arc<dyn MlPersistence>,
27}
28
29impl std::fmt::Debug for MlRuntime {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("MlRuntime")
32 .field("registry", &self.registry)
33 .field("queue", &self.queue)
34 .finish()
35 }
36}
37
/// Tunables for building an `MlRuntime`.
#[derive(Clone, Debug)]
pub struct MlRuntimeConfig {
    /// Number of workers requested from `MlJobQueue::start_with_backend`.
    pub worker_count: usize,
}

impl Default for MlRuntimeConfig {
    /// Uses `default_worker_count()` for `worker_count`.
    fn default() -> Self {
        let worker_count = default_worker_count();
        MlRuntimeConfig { worker_count }
    }
}

/// One fewer worker than the number of logical CPUs reported by
/// `std::thread::available_parallelism`, but never fewer than one.
///
/// If the CPU count cannot be determined, a single worker is used.
fn default_worker_count() -> usize {
    match std::thread::available_parallelism() {
        // Reserve one logical CPU (presumably for the rest of the server),
        // clamping so single-CPU hosts still get one worker.
        Ok(cpus) => cpus.get().saturating_sub(1).max(1),
        Err(_) => 1,
    }
}
60
61impl MlRuntime {
62 pub fn in_memory(worker_fn: MlWorkFn) -> Self {
65 Self::with_backend(
66 Arc::new(InMemoryMlPersistence::new()),
67 worker_fn,
68 MlRuntimeConfig::default(),
69 )
70 }
71
72 pub fn with_backend(
76 backend: Arc<dyn MlPersistence>,
77 worker_fn: MlWorkFn,
78 config: MlRuntimeConfig,
79 ) -> Self {
80 let registry = ModelRegistry::with_backend(Arc::clone(&backend));
81 let queue =
82 MlJobQueue::start_with_backend(config.worker_count, worker_fn, Arc::clone(&backend));
83 Self {
84 registry,
85 queue,
86 backend,
87 }
88 }
89
90 pub fn registry(&self) -> &ModelRegistry {
91 &self.registry
92 }
93
94 pub fn queue(&self) -> &MlJobQueue {
95 &self.queue
96 }
97
98 pub fn backend(&self) -> &Arc<dyn MlPersistence> {
101 &self.backend
102 }
103
104 pub fn shutdown(&self) {
107 self.queue.shutdown();
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::super::jobs::{MlJobKind, MlJobStatus};
    use super::*;
    use std::time::{Duration, Instant};

    /// Polls `predicate` every 5 ms until it returns `true` or `timeout`
    /// elapses.
    ///
    /// The predicate is evaluated once more after the deadline, so a condition
    /// that becomes true during the last sleep still counts as success.
    fn wait_until<F: Fn() -> bool>(predicate: F, timeout: Duration) -> bool {
        let deadline = Instant::now() + timeout;
        while Instant::now() < deadline {
            if predicate() {
                return true;
            }
            std::thread::sleep(Duration::from_millis(5));
        }
        predicate()
    }

    #[test]
    fn in_memory_runtime_runs_a_training_job() {
        let rt = MlRuntime::in_memory(Arc::new(|_| Ok("{\"ok\":true}".to_string())));
        let id = rt.queue().submit(MlJobKind::Train, "spam", "{}");
        let completed = wait_until(
            || {
                rt.queue()
                    .get(id)
                    .map(|j| j.status == MlJobStatus::Completed)
                    .unwrap_or(false)
            },
            Duration::from_secs(2),
        );
        // Shut the workers down before asserting, so a timeout does not leave
        // the queue running after the test thread panics.
        rt.shutdown();
        assert!(
            completed,
            "training job did not reach MlJobStatus::Completed within 2s"
        );
    }

    #[test]
    fn runtime_exposes_registry() {
        let rt = MlRuntime::in_memory(Arc::new(|_| Ok("{}".to_string())));
        let summaries = rt.registry().summaries();
        // Shut down before unwrapping/asserting so a failure cannot skip it.
        rt.shutdown();
        assert_eq!(summaries.unwrap().len(), 0);
    }
}