reddb_server/storage/ml/queue.rs

//! Async ML job queue — FIFO queue + worker pool + job-state registry.
//!
//! Callers `submit()` a job and get back an [`MlJobId`] immediately.
//! Worker threads pick jobs off the queue, invoke the caller-supplied
//! [`MlWorkFn`] to perform the actual work, and record progress +
//! terminal status back into the queue's job table.
//!
//! Cancellation is cooperative: setting a job to [`MlJobStatus::Cancelled`]
//! flips a flag that the worker polls between checkpoints. Workers
//! that never check the flag cannot be forcibly killed — this is a
//! deliberate trade-off (no unsafe thread termination, cleanup is
//! the algorithm's responsibility).
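//!
//! # Example
//!
//! A minimal usage sketch (marked `ignore`: it needs the surrounding
//! crate to compile, and the trivial work fn stands in for real
//! training):
//!
//! ```ignore
//! use std::sync::Arc;
//!
//! let queue = MlJobQueue::start(
//!     2,
//!     Arc::new(|handle| {
//!         // Check cancellation at safe boundaries, report progress.
//!         if handle.is_cancelled() {
//!             return Err("cancelled".to_string());
//!         }
//!         handle.set_progress(100);
//!         Ok("{\"ok\":true}".to_string())
//!     }),
//! );
//! let id = queue.submit(MlJobKind::Train, "spam", "{}");
//! // Poll `queue.get(id)` until `is_terminal()`, then inspect
//! // `metrics_json` / `error_message`.
//! queue.shutdown();
//! ```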

use std::collections::VecDeque;
use std::sync::{Arc, Condvar, Mutex};
use std::thread::{self, JoinHandle};

use super::jobs::{now_ms, MlJob, MlJobId, MlJobKind, MlJobStatus};
use super::persist::{key, ns, MlPersistence};

/// Callback invoked on a worker thread to perform the actual work.
///
/// The closure receives a [`JobHandle`] it uses to update progress
/// and to check cancellation. It returns `Ok(metrics_json)` on
/// success (which will be stored on the job record) or `Err(msg)` on
/// failure (surfaced as `error_message`).
pub type MlWorkFn = Arc<dyn Fn(JobHandle) -> Result<String, String> + Send + Sync>;

/// Handle passed into an [`MlWorkFn`]. The worker uses it to report
/// progress and observe cancellation — no other mutations are
/// possible, which keeps contracts small.
#[derive(Clone)]
pub struct JobHandle {
    id: MlJobId,
    shared: Arc<Shared>,
}

impl JobHandle {
    pub fn id(&self) -> MlJobId {
        self.id
    }

    /// Update the `progress` field (0..=100). Values > 100 are
    /// clamped. Non-monotonic updates are allowed — workers that
    /// retry a checkpoint can move progress backwards.
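    ///
    /// Clamping sketch (`ignore`d by rustdoc):
    ///
    /// ```ignore
    /// handle.set_progress(150); // stored as 100
    /// handle.set_progress(40);  // allowed: a retried checkpoint moves back
    /// ```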
    pub fn set_progress(&self, progress: u8) {
        let snapshot = {
            // A poisoned lock only means another thread panicked while
            // holding it; the state is plain data, so recover the guard
            // rather than propagate the panic (same pattern throughout
            // this module).
            let mut guard = match self.shared.state.lock() {
                Ok(g) => g,
                Err(p) => p.into_inner(),
            };
            if let Some(job) = find_job_mut(&mut guard.jobs, self.id) {
                if !job.is_terminal() {
                    job.progress = progress.min(100);
                    Some(job.clone())
                } else {
                    None
                }
            } else {
                None
            }
        };
        if let Some(job) = snapshot {
            persist_job(&self.shared, &job);
        }
    }

    /// True when the operator has requested cancellation. Workers
    /// should poll this at safe boundaries (per batch / per
    /// generation) and return promptly on a positive.
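    ///
    /// The intended polling pattern (`ignore`d by rustdoc;
    /// `train_one_batch` is a hypothetical helper):
    ///
    /// ```ignore
    /// for (i, batch) in batches.iter().enumerate() {
    ///     if handle.is_cancelled() {
    ///         return Err("cancelled".to_string());
    ///     }
    ///     train_one_batch(batch);
    ///     handle.set_progress((100 * (i + 1) / batches.len()) as u8);
    /// }
    /// ```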
    pub fn is_cancelled(&self) -> bool {
        let guard = match self.shared.state.lock() {
            Ok(g) => g,
            Err(p) => p.into_inner(),
        };
        guard
            .jobs
            .iter()
            .find(|j| j.id == self.id)
            .map(|j| j.status == MlJobStatus::Cancelled)
            .unwrap_or(false)
    }
}

struct Shared {
    state: Mutex<QueueState>,
    signal: Condvar,
    backend: Option<Arc<dyn MlPersistence>>,
}

impl std::fmt::Debug for Shared {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Shared")
            .field("has_backend", &self.backend.is_some())
            .finish()
    }
}

fn persist_job(shared: &Arc<Shared>, job: &MlJob) {
    let Some(backend) = shared.backend.as_ref() else {
        return;
    };
    let raw = job.to_json();
    let _ = backend.put(ns::JOBS, &key::job(job.id), &raw);
}

#[derive(Debug)]
struct QueueState {
    /// Pending job ids ordered FIFO.
    pending: VecDeque<MlJobId>,
    /// All jobs known to the queue, terminal or not. Callers list /
    /// inspect through this vec.
    jobs: Vec<MlJob>,
    /// True once `shutdown()` has been called.
    shutting_down: bool,
    /// Monotonic id counter. u128 so replicas can eventually mint
    /// ids without coordination.
    next_id: u128,
}

/// Queue + worker pool pair. Safe to clone — every clone shares the
/// underlying queue via `Arc`.
#[derive(Clone)]
pub struct MlJobQueue {
    shared: Arc<Shared>,
    worker_fn: MlWorkFn,
    workers: Arc<Mutex<Vec<JoinHandle<()>>>>,
}

impl std::fmt::Debug for MlJobQueue {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MlJobQueue")
            .field(
                "worker_count",
                &self.workers.lock().map(|w| w.len()).unwrap_or(0),
            )
            .finish()
    }
}

impl MlJobQueue {
    /// Spin up a queue with `worker_count` threads. `worker_fn` is
    /// invoked once per job. No durable backend — see
    /// [`Self::start_with_backend`] for the persisted variant.
    pub fn start(worker_count: usize, worker_fn: MlWorkFn) -> Self {
        Self::start_with(worker_count, worker_fn, None)
    }

    /// Spin up a queue that persists every state transition to
    /// `backend`. On startup the queue rehydrates any non-terminal
    /// records and re-enqueues them (they were Running when the
    /// previous process died — treat as Queued on resume).
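    ///
    /// A startup sketch (`ignore`d by rustdoc; `InMemoryMlPersistence`
    /// is the test backend from `super::persist`, and production
    /// callers pass their own [`MlPersistence`] implementation):
    ///
    /// ```ignore
    /// use std::sync::Arc;
    ///
    /// let backend: Arc<dyn MlPersistence> = Arc::new(InMemoryMlPersistence::new());
    /// // Rehydrates anything Queued or Running in the store, then
    /// // starts the workers.
    /// let queue = MlJobQueue::start_with_backend(
    ///     2,
    ///     Arc::new(|_handle| Ok("{}".to_string())),
    ///     backend,
    /// );
    /// ```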
    pub fn start_with_backend(
        worker_count: usize,
        worker_fn: MlWorkFn,
        backend: Arc<dyn MlPersistence>,
    ) -> Self {
        Self::start_with(worker_count, worker_fn, Some(backend))
    }

    fn start_with(
        worker_count: usize,
        worker_fn: MlWorkFn,
        backend: Option<Arc<dyn MlPersistence>>,
    ) -> Self {
        // Rehydrate first so `next_id` is strictly greater than every
        // previously-issued id — collision-free across restarts.
        let mut initial_jobs: Vec<MlJob> = Vec::new();
        let mut initial_pending: VecDeque<MlJobId> = VecDeque::new();
        let mut resume_next_id: u128 = 1;
        if let Some(be) = backend.as_ref() {
            if let Ok(rows) = be.list(ns::JOBS) {
                for (_, raw) in rows {
                    let Some(mut job) = MlJob::from_json(&raw) else {
                        continue;
                    };
                    // A Running job from a prior process is now stuck —
                    // requeue it so a worker can pick it up. Progress
                    // resets to zero so the operator can tell it's a
                    // resumed job from `SELECT * FROM ML_JOBS`.
                    if job.status == MlJobStatus::Running {
                        job.status = MlJobStatus::Queued;
                        job.progress = 0;
                        job.started_at_ms = 0;
                    }
                    if job.status == MlJobStatus::Queued {
                        initial_pending.push_back(job.id);
                    }
                    resume_next_id = resume_next_id.max(job.id.saturating_add(1));
                    initial_jobs.push(job);
                }
            }
        }

        let shared = Arc::new(Shared {
            state: Mutex::new(QueueState {
                pending: initial_pending,
                jobs: initial_jobs.clone(),
                shutting_down: false,
                next_id: resume_next_id,
            }),
            signal: Condvar::new(),
            backend,
        });

        // Flush rehydrated pending-state back to the backend so the
        // status change (Running → Queued) is durable.
        for job in &initial_jobs {
            if job.status == MlJobStatus::Queued {
                persist_job(&shared, job);
            }
        }

        let workers = Arc::new(Mutex::new(Vec::with_capacity(worker_count.max(1))));
        let queue = MlJobQueue {
            shared: Arc::clone(&shared),
            worker_fn: Arc::clone(&worker_fn),
            workers: Arc::clone(&workers),
        };
        // Clamp to at least one worker so submitted jobs always make
        // progress, even when the caller passes `worker_count == 0`.
        for _ in 0..worker_count.max(1) {
            let shared_w = Arc::clone(&shared);
            let worker_fn_w = Arc::clone(&worker_fn);
            let handle = thread::spawn(move || worker_loop(shared_w, worker_fn_w));
            if let Ok(mut guard) = workers.lock() {
                guard.push(handle);
            }
        }
        queue
    }

    /// Enqueue a new job. Returns the assigned id so the caller can
    /// poll status later.
    pub fn submit(
        &self,
        kind: MlJobKind,
        target_name: impl Into<String>,
        spec_json: impl Into<String>,
    ) -> MlJobId {
        let snapshot = {
            let mut guard = match self.shared.state.lock() {
                Ok(g) => g,
                Err(p) => p.into_inner(),
            };
            let id = guard.next_id;
            guard.next_id = guard.next_id.saturating_add(1);
            let job = MlJob::new(id, kind, target_name.into(), spec_json.into());
            let snapshot = job.clone();
            guard.jobs.push(job);
            guard.pending.push_back(id);
            snapshot
        };
        persist_job(&self.shared, &snapshot);
        self.shared.signal.notify_one();
        snapshot.id
    }

    /// Fetch a job by id.
    pub fn get(&self, id: MlJobId) -> Option<MlJob> {
        let guard = match self.shared.state.lock() {
            Ok(g) => g,
            Err(p) => p.into_inner(),
        };
        guard.jobs.iter().find(|j| j.id == id).cloned()
    }

    /// Snapshot every job (terminal + live). Callers use this to
    /// back `SELECT * FROM ML_JOBS`.
    pub fn list(&self) -> Vec<MlJob> {
        let guard = match self.shared.state.lock() {
            Ok(g) => g,
            Err(p) => p.into_inner(),
        };
        guard.jobs.clone()
    }

    /// Request cooperative cancellation. Returns `true` if the job
    /// was still cancellable, `false` if it had already reached a
    /// terminal state or does not exist.
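    ///
    /// Semantics sketch (`ignore`d by rustdoc):
    ///
    /// ```ignore
    /// if queue.cancel(id) {
    ///     // Queued: dropped from `pending` before any worker sees it.
    ///     // Running: the worker stops at its next `is_cancelled()` check.
    /// }
    /// ```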
    pub fn cancel(&self, id: MlJobId) -> bool {
        let snapshot = {
            let mut guard = match self.shared.state.lock() {
                Ok(g) => g,
                Err(p) => p.into_inner(),
            };
            let Some(job) = find_job_mut(&mut guard.jobs, id) else {
                return false;
            };
            if job.is_terminal() {
                return false;
            }
            let was_queued = job.status == MlJobStatus::Queued;
            job.status = MlJobStatus::Cancelled;
            job.finished_at_ms = now_ms();
            let snapshot = job.clone();
            if was_queued {
                // Drop from pending so no worker picks it up; workers
                // already running observe `is_cancelled()` themselves.
                guard.pending.retain(|pid| *pid != id);
            }
            snapshot
        };
        persist_job(&self.shared, &snapshot);
        true
    }

    /// Stop every worker thread after it finishes its current job.
    /// Pending jobs are left in place; a process restarted with the
    /// same persistence backend re-enqueues them on startup.
    pub fn shutdown(&self) {
        {
            let mut guard = match self.shared.state.lock() {
                Ok(g) => g,
                Err(p) => p.into_inner(),
            };
            guard.shutting_down = true;
        }
        self.shared.signal.notify_all();
        let Ok(mut workers) = self.workers.lock() else {
            return;
        };
        for handle in workers.drain(..) {
            let _ = handle.join();
        }
    }
}

fn find_job_mut(jobs: &mut [MlJob], id: MlJobId) -> Option<&mut MlJob> {
    jobs.iter_mut().find(|j| j.id == id)
}

fn worker_loop(shared: Arc<Shared>, worker_fn: MlWorkFn) {
    loop {
        // Claim the next queued job, marking it running in the same
        // critical section so two workers can't pick the same one.
        let (next_id, running_snapshot) = {
            let guard = match shared.state.lock() {
                Ok(g) => g,
                Err(p) => p.into_inner(),
            };
            let mut guard = match shared
                .signal
                .wait_while(guard, |s| s.pending.is_empty() && !s.shutting_down)
            {
                Ok(g) => g,
                Err(p) => p.into_inner(),
            };
            if guard.shutting_down && guard.pending.is_empty() {
                return;
            }
            let id = match guard.pending.pop_front() {
                Some(id) => id,
                None => continue,
            };
            let mut snapshot = None;
            if let Some(job) = find_job_mut(&mut guard.jobs, id) {
                // Defensive: `cancel()` drops queued ids from `pending`
                // under this same lock, so a Cancelled job should never
                // be popped; skip the work if one ever is.
                if job.status == MlJobStatus::Cancelled {
                    continue;
                }
                job.status = MlJobStatus::Running;
                job.started_at_ms = now_ms();
                snapshot = Some(job.clone());
            }
            (id, snapshot)
        };
        if let Some(job) = running_snapshot {
            persist_job(&shared, &job);
        }

        let handle = JobHandle {
            id: next_id,
            shared: Arc::clone(&shared),
        };
        let outcome = (worker_fn)(handle);

        let post_snapshot = {
            let mut guard = match shared.state.lock() {
                Ok(g) => g,
                Err(p) => p.into_inner(),
            };
            if let Some(job) = find_job_mut(&mut guard.jobs, next_id) {
                // The operator may have cancelled mid-flight — respect
                // that state rather than overwriting it.
                if job.status == MlJobStatus::Cancelled {
                    if job.finished_at_ms == 0 {
                        job.finished_at_ms = now_ms();
                    }
                    Some(job.clone())
                } else {
                    match outcome {
                        Ok(metrics) => {
                            job.status = MlJobStatus::Completed;
                            job.progress = 100;
                            job.metrics_json = Some(metrics);
                        }
                        Err(err) => {
                            job.status = MlJobStatus::Failed;
                            job.error_message = Some(err);
                        }
                    }
                    job.finished_at_ms = now_ms();
                    Some(job.clone())
                }
            } else {
                None
            }
        };
        if let Some(job) = post_snapshot {
            persist_job(&shared, &job);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{Duration, Instant};

    fn wait_until<F: Fn() -> bool>(predicate: F, timeout: Duration) -> bool {
        let deadline = Instant::now() + timeout;
        while Instant::now() < deadline {
            if predicate() {
                return true;
            }
            thread::sleep(Duration::from_millis(5));
        }
        predicate()
    }

    #[test]
    fn submit_and_run_to_completion() {
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_w = Arc::clone(&counter);
        let q = MlJobQueue::start(
            1,
            Arc::new(move |handle| {
                counter_w.fetch_add(1, Ordering::SeqCst);
                handle.set_progress(50);
                handle.set_progress(100);
                Ok("{\"ok\":true}".to_string())
            }),
        );
        let id = q.submit(MlJobKind::Train, "spam", "{}");
        assert!(wait_until(
            || q.get(id).map(|j| j.is_terminal()).unwrap_or(false),
            Duration::from_secs(2),
        ));
        let job = q.get(id).unwrap();
        assert_eq!(job.status, MlJobStatus::Completed);
        assert_eq!(job.progress, 100);
        assert_eq!(job.metrics_json.as_deref(), Some("{\"ok\":true}"));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        q.shutdown();
    }

    #[test]
    fn failed_work_records_error() {
        let q = MlJobQueue::start(1, Arc::new(|_| Err("bad hyperparameters".to_string())));
        let id = q.submit(MlJobKind::Train, "spam", "{}");
        assert!(wait_until(
            || q.get(id).map(|j| j.is_terminal()).unwrap_or(false),
            Duration::from_secs(2),
        ));
        let job = q.get(id).unwrap();
        assert_eq!(job.status, MlJobStatus::Failed);
        assert_eq!(job.error_message.as_deref(), Some("bad hyperparameters"));
        q.shutdown();
    }

    #[test]
    fn cancel_while_queued_prevents_execution() {
        let ran = Arc::new(AtomicUsize::new(0));
        let ran_w = Arc::clone(&ran);
        // One worker, occupied by a long job to force queueing.
        let q = MlJobQueue::start(
            1,
            Arc::new(move |handle| {
                if handle.id() == 1 {
                    // Hold the first job long enough for #2 to sit queued.
                    thread::sleep(Duration::from_millis(100));
                } else {
                    ran_w.fetch_add(1, Ordering::SeqCst);
                }
                Ok("{}".to_string())
            }),
        );
        let _first = q.submit(MlJobKind::Train, "a", "{}");
        let second = q.submit(MlJobKind::Train, "b", "{}");
        assert!(q.cancel(second));
        thread::sleep(Duration::from_millis(250));
        let job = q.get(second).unwrap();
        assert_eq!(job.status, MlJobStatus::Cancelled);
        assert_eq!(ran.load(Ordering::SeqCst), 0, "cancelled job must not run");
        q.shutdown();
    }

    #[test]
    fn cancel_after_terminal_is_noop() {
        let q = MlJobQueue::start(1, Arc::new(|_| Ok("{}".to_string())));
        let id = q.submit(MlJobKind::Train, "x", "{}");
        assert!(wait_until(
            || q.get(id).map(|j| j.is_terminal()).unwrap_or(false),
            Duration::from_secs(2),
        ));
        assert!(!q.cancel(id));
        q.shutdown();
    }

    #[test]
    fn cooperative_cancellation_observed_by_worker() {
        // Uses a barrier-like counter so the main thread can see that
        // the worker actually ran, then cancels and waits for the worker
        // to observe the flag. `cancel()` flips terminal status
        // immediately, so we cannot poll `is_terminal()` to prove the
        // worker cooperated — we poll the observation counter.
        let observed = Arc::new(AtomicUsize::new(0));
        let iters = Arc::new(AtomicUsize::new(0));
        let observed_w = Arc::clone(&observed);
        let iters_w = Arc::clone(&iters);
        let q = MlJobQueue::start(
            1,
            Arc::new(move |handle| {
                for _ in 0..200 {
                    iters_w.fetch_add(1, Ordering::SeqCst);
                    if handle.is_cancelled() {
                        observed_w.fetch_add(1, Ordering::SeqCst);
                        return Err("cancelled".to_string());
                    }
                    handle.set_progress(10);
                    thread::sleep(Duration::from_millis(5));
                }
                Ok("{}".to_string())
            }),
        );
        let id = q.submit(MlJobKind::Train, "slow", "{}");
        assert!(wait_until(
            || iters.load(Ordering::SeqCst) > 0,
            Duration::from_secs(2),
        ));
        assert!(q.cancel(id));
        assert!(wait_until(
            || observed.load(Ordering::SeqCst) >= 1,
            Duration::from_secs(2),
        ));
        let job = q.get(id).unwrap();
        assert_eq!(job.status, MlJobStatus::Cancelled);
        q.shutdown();
    }

    #[test]
    fn backend_persists_submit_and_completion() {
        use super::super::persist::InMemoryMlPersistence;
        let backend = Arc::new(InMemoryMlPersistence::new());
        let q = MlJobQueue::start_with_backend(
            1,
            Arc::new(|_| Ok("{\"f1\":0.9}".to_string())),
            backend.clone(),
        );
        let id = q.submit(MlJobKind::Train, "spam", "{}");
        assert!(wait_until(
            || q.get(id).map(|j| j.is_terminal()).unwrap_or(false),
            Duration::from_secs(2),
        ));
        // Raw record must exist and must reflect the completed status.
        let raw = backend
            .get(super::ns::JOBS, &super::key::job(id))
            .unwrap()
            .unwrap();
        let decoded = MlJob::from_json(&raw).unwrap();
        assert_eq!(decoded.status, MlJobStatus::Completed);
        assert_eq!(decoded.metrics_json.as_deref(), Some("{\"f1\":0.9}"));
        q.shutdown();
    }

    #[test]
    fn resume_from_backend_requeues_running_jobs() {
        use super::super::persist::InMemoryMlPersistence;
        let backend: Arc<dyn super::MlPersistence> = Arc::new(InMemoryMlPersistence::new());

        // Simulate a prior process: one queued + one running + one
        // completed job in the store.
        let pending = MlJob {
            id: 5,
            kind: MlJobKind::Train,
            target_name: "a".into(),
            status: MlJobStatus::Queued,
            progress: 0,
            created_at_ms: 1,
            started_at_ms: 0,
            finished_at_ms: 0,
            error_message: None,
            spec_json: "{}".into(),
            metrics_json: None,
        };
        let stuck = MlJob {
            id: 6,
            kind: MlJobKind::Train,
            target_name: "b".into(),
            status: MlJobStatus::Running,
            progress: 40,
            created_at_ms: 2,
            started_at_ms: 3,
            finished_at_ms: 0,
            error_message: None,
            spec_json: "{}".into(),
            metrics_json: None,
        };
        let done = MlJob {
            id: 7,
            kind: MlJobKind::Train,
            target_name: "c".into(),
            status: MlJobStatus::Completed,
            progress: 100,
            created_at_ms: 3,
            started_at_ms: 4,
            finished_at_ms: 5,
            error_message: None,
            spec_json: "{}".into(),
            metrics_json: Some("{}".into()),
        };
        for j in [&pending, &stuck, &done] {
            backend
                .put(super::ns::JOBS, &super::key::job(j.id), &j.to_json())
                .unwrap();
        }

        let ran = Arc::new(AtomicUsize::new(0));
        let ran_w = Arc::clone(&ran);
        let q = MlJobQueue::start_with_backend(
            2,
            Arc::new(move |_| {
                ran_w.fetch_add(1, Ordering::SeqCst);
                Ok("{}".to_string())
            }),
            backend.clone(),
        );

        assert!(wait_until(
            || ran.load(Ordering::SeqCst) >= 2,
            Duration::from_secs(3),
        ));
        // Both previously non-terminal jobs were re-run; the completed
        // one stayed as-is.
        assert_eq!(q.get(5).unwrap().status, MlJobStatus::Completed);
        assert_eq!(q.get(6).unwrap().status, MlJobStatus::Completed);
        assert_eq!(q.get(7).unwrap().status, MlJobStatus::Completed);

        // next_id must have advanced past the largest resumed id.
        let fresh_id = q.submit(MlJobKind::Train, "d", "{}");
        assert!(fresh_id > 7);

        q.shutdown();
    }

    #[test]
    fn multiple_workers_drain_backlog() {
        let q = MlJobQueue::start(
            3,
            Arc::new(|handle| {
                handle.set_progress(50);
                thread::sleep(Duration::from_millis(20));
                Ok("{}".to_string())
            }),
        );
        let ids: Vec<_> = (0..20)
            .map(|i| q.submit(MlJobKind::Train, format!("m{i}"), "{}"))
            .collect();
        assert!(wait_until(
            || ids
                .iter()
                .all(|id| q.get(*id).map(|j| j.is_terminal()).unwrap_or(false)),
            Duration::from_secs(5),
        ));
        for id in ids {
            assert_eq!(q.get(id).unwrap().status, MlJobStatus::Completed);
        }
        q.shutdown();
    }
}