1use std::collections::VecDeque;
15use std::sync::{Arc, Condvar, Mutex};
16use std::thread::{self, JoinHandle};
17
18use super::jobs::{now_ms, MlJob, MlJobId, MlJobKind, MlJobStatus};
19use super::persist::{key, ns, MlPersistence};
20
/// Work executed by a worker thread for each job. Receives a `JobHandle` for
/// progress reporting and cooperative cancellation polling; returns metrics
/// JSON on success or an error message on failure.
pub type MlWorkFn = Arc<dyn Fn(JobHandle) -> Result<String, String> + Send + Sync>;
28
/// Cheap, cloneable handle given to a worker function so it can report
/// progress for — and observe cancellation of — the job it is running.
#[derive(Clone)]
pub struct JobHandle {
    id: MlJobId,        // id of the job this handle refers to
    shared: Arc<Shared>, // queue state plus optional persistence backend
}
37
38impl JobHandle {
39 pub fn id(&self) -> MlJobId {
40 self.id
41 }
42
43 pub fn set_progress(&self, progress: u8) {
47 let snapshot = {
48 let mut guard = match self.shared.state.lock() {
49 Ok(g) => g,
50 Err(p) => p.into_inner(),
51 };
52 if let Some(job) = find_job_mut(&mut guard.jobs, self.id) {
53 if !job.is_terminal() {
54 job.progress = progress.min(100);
55 Some(job.clone())
56 } else {
57 None
58 }
59 } else {
60 None
61 }
62 };
63 if let Some(job) = snapshot {
64 persist_job(&self.shared, &job);
65 }
66 }
67
68 pub fn is_cancelled(&self) -> bool {
72 let guard = match self.shared.state.lock() {
73 Ok(g) => g,
74 Err(p) => p.into_inner(),
75 };
76 guard
77 .jobs
78 .iter()
79 .find(|j| j.id == self.id)
80 .map(|j| j.status == MlJobStatus::Cancelled)
81 .unwrap_or(false)
82 }
83}
84
/// State shared between the queue front-end and all worker threads.
struct Shared {
    state: Mutex<QueueState>, // jobs, pending queue, shutdown flag
    signal: Condvar,          // notified when work arrives or shutdown begins
    backend: Option<Arc<dyn MlPersistence>>, // optional durable job store
}
90
91impl std::fmt::Debug for Shared {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 f.debug_struct("Shared")
94 .field("has_backend", &self.backend.is_some())
95 .finish()
96 }
97}
98
99fn persist_job(shared: &Arc<Shared>, job: &MlJob) {
100 let Some(backend) = shared.backend.as_ref() else {
101 return;
102 };
103 let raw = job.to_json();
104 let _ = backend.put(ns::JOBS, &key::job(job.id), &raw);
105}
106
/// Mutable queue state, guarded by `Shared::state`.
#[derive(Debug)]
struct QueueState {
    // Ids of jobs waiting to run, in FIFO order.
    pending: VecDeque<MlJobId>,
    // Every job known to the queue (queued, running, and terminal); entries
    // are never removed, only updated in place.
    jobs: Vec<MlJob>,
    // Once true, workers exit after draining `pending`.
    shutting_down: bool,
    // Monotonically increasing id for the next submitted job.
    next_id: u128,
}
120
/// Thread-backed ML job queue; cloning yields another handle to the same
/// underlying queue and worker pool.
#[derive(Clone)]
pub struct MlJobQueue {
    shared: Arc<Shared>, // state shared with the worker threads
    worker_fn: MlWorkFn, // per-job work function (workers hold their own clones; not read after startup)
    workers: Arc<Mutex<Vec<JoinHandle<()>>>>, // joined by `shutdown()`
}
129
130impl std::fmt::Debug for MlJobQueue {
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 f.debug_struct("MlJobQueue")
133 .field(
134 "worker_count",
135 &self.workers.lock().map(|w| w.len()).unwrap_or(0),
136 )
137 .finish()
138 }
139}
140
141impl MlJobQueue {
142 pub fn start(worker_count: usize, worker_fn: MlWorkFn) -> Self {
146 Self::start_with(worker_count, worker_fn, None)
147 }
148
149 pub fn start_with_backend(
154 worker_count: usize,
155 worker_fn: MlWorkFn,
156 backend: Arc<dyn MlPersistence>,
157 ) -> Self {
158 Self::start_with(worker_count, worker_fn, Some(backend))
159 }
160
161 fn start_with(
162 worker_count: usize,
163 worker_fn: MlWorkFn,
164 backend: Option<Arc<dyn MlPersistence>>,
165 ) -> Self {
166 let mut initial_jobs: Vec<MlJob> = Vec::new();
169 let mut initial_pending: VecDeque<MlJobId> = VecDeque::new();
170 let mut resume_next_id: u128 = 1;
171 if let Some(be) = backend.as_ref() {
172 if let Ok(rows) = be.list(ns::JOBS) {
173 for (_, raw) in rows {
174 let Some(mut job) = MlJob::from_json(&raw) else {
175 continue;
176 };
177 if job.status == MlJobStatus::Running {
182 job.status = MlJobStatus::Queued;
183 job.progress = 0;
184 job.started_at_ms = 0;
185 }
186 if job.status == MlJobStatus::Queued {
187 initial_pending.push_back(job.id);
188 }
189 resume_next_id = resume_next_id.max(job.id.saturating_add(1));
190 initial_jobs.push(job);
191 }
192 }
193 }
194
195 let shared = Arc::new(Shared {
196 state: Mutex::new(QueueState {
197 pending: initial_pending,
198 jobs: initial_jobs.clone(),
199 shutting_down: false,
200 next_id: resume_next_id,
201 }),
202 signal: Condvar::new(),
203 backend,
204 });
205
206 for job in &initial_jobs {
209 if job.status == MlJobStatus::Queued {
210 persist_job(&shared, job);
211 }
212 }
213
214 let workers = Arc::new(Mutex::new(Vec::with_capacity(worker_count.max(1))));
215 let queue = MlJobQueue {
216 shared: Arc::clone(&shared),
217 worker_fn: Arc::clone(&worker_fn),
218 workers: Arc::clone(&workers),
219 };
220 for _ in 0..worker_count.max(1) {
221 let shared_w = Arc::clone(&shared);
222 let worker_fn_w = Arc::clone(&worker_fn);
223 let handle = thread::spawn(move || worker_loop(shared_w, worker_fn_w));
224 if let Ok(mut guard) = workers.lock() {
225 guard.push(handle);
226 }
227 }
228 queue
229 }
230
231 pub fn submit(
234 &self,
235 kind: MlJobKind,
236 target_name: impl Into<String>,
237 spec_json: impl Into<String>,
238 ) -> MlJobId {
239 let snapshot = {
240 let mut guard = match self.shared.state.lock() {
241 Ok(g) => g,
242 Err(p) => p.into_inner(),
243 };
244 let id = guard.next_id;
245 guard.next_id = guard.next_id.saturating_add(1);
246 let job = MlJob::new(id, kind, target_name.into(), spec_json.into());
247 let snapshot = job.clone();
248 guard.jobs.push(job);
249 guard.pending.push_back(id);
250 snapshot
251 };
252 persist_job(&self.shared, &snapshot);
253 self.shared.signal.notify_one();
254 snapshot.id
255 }
256
257 pub fn get(&self, id: MlJobId) -> Option<MlJob> {
259 let guard = match self.shared.state.lock() {
260 Ok(g) => g,
261 Err(p) => p.into_inner(),
262 };
263 guard.jobs.iter().find(|j| j.id == id).cloned()
264 }
265
266 pub fn list(&self) -> Vec<MlJob> {
269 let guard = match self.shared.state.lock() {
270 Ok(g) => g,
271 Err(p) => p.into_inner(),
272 };
273 guard.jobs.clone()
274 }
275
276 pub fn cancel(&self, id: MlJobId) -> bool {
280 let snapshot = {
281 let mut guard = match self.shared.state.lock() {
282 Ok(g) => g,
283 Err(p) => p.into_inner(),
284 };
285 let Some(job) = find_job_mut(&mut guard.jobs, id) else {
286 return false;
287 };
288 if job.is_terminal() {
289 return false;
290 }
291 let was_queued = job.status == MlJobStatus::Queued;
292 job.status = MlJobStatus::Cancelled;
293 job.finished_at_ms = now_ms();
294 let snapshot = job.clone();
295 if was_queued {
296 guard.pending.retain(|pid| *pid != id);
299 }
300 snapshot
301 };
302 persist_job(&self.shared, &snapshot);
303 true
304 }
305
306 pub fn shutdown(&self) {
310 {
311 let mut guard = match self.shared.state.lock() {
312 Ok(g) => g,
313 Err(p) => p.into_inner(),
314 };
315 guard.shutting_down = true;
316 }
317 self.shared.signal.notify_all();
318 let Ok(mut workers) = self.workers.lock() else {
319 return;
320 };
321 for handle in workers.drain(..) {
322 let _ = handle.join();
323 }
324 }
325}
326
327fn find_job_mut(jobs: &mut [MlJob], id: MlJobId) -> Option<&mut MlJob> {
328 jobs.iter_mut().find(|j| j.id == id)
329}
330
331fn worker_loop(shared: Arc<Shared>, worker_fn: MlWorkFn) {
332 loop {
333 let (next_id, running_snapshot) = {
336 let guard = match shared.state.lock() {
337 Ok(g) => g,
338 Err(p) => p.into_inner(),
339 };
340 let mut guard = match shared
341 .signal
342 .wait_while(guard, |s| s.pending.is_empty() && !s.shutting_down)
343 {
344 Ok(g) => g,
345 Err(p) => p.into_inner(),
346 };
347 if guard.shutting_down && guard.pending.is_empty() {
348 return;
349 }
350 let id = match guard.pending.pop_front() {
351 Some(id) => id,
352 None => continue,
353 };
354 let mut snapshot = None;
355 if let Some(job) = find_job_mut(&mut guard.jobs, id) {
356 if job.status == MlJobStatus::Cancelled {
359 continue;
360 }
361 job.status = MlJobStatus::Running;
362 job.started_at_ms = now_ms();
363 snapshot = Some(job.clone());
364 }
365 (id, snapshot)
366 };
367 if let Some(job) = running_snapshot {
368 persist_job(&shared, &job);
369 }
370
371 let handle = JobHandle {
372 id: next_id,
373 shared: Arc::clone(&shared),
374 };
375 let outcome = (worker_fn)(handle);
376
377 let post_snapshot = {
378 let mut guard = match shared.state.lock() {
379 Ok(g) => g,
380 Err(p) => p.into_inner(),
381 };
382 if let Some(job) = find_job_mut(&mut guard.jobs, next_id) {
383 if job.status == MlJobStatus::Cancelled {
386 if job.finished_at_ms == 0 {
387 job.finished_at_ms = now_ms();
388 }
389 Some(job.clone())
390 } else {
391 match outcome {
392 Ok(metrics) => {
393 job.status = MlJobStatus::Completed;
394 job.progress = 100;
395 job.metrics_json = Some(metrics);
396 }
397 Err(err) => {
398 job.status = MlJobStatus::Failed;
399 job.error_message = Some(err);
400 }
401 }
402 job.finished_at_ms = now_ms();
403 Some(job.clone())
404 }
405 } else {
406 None
407 }
408 };
409 if let Some(job) = post_snapshot {
410 persist_job(&shared, &job);
411 }
412 }
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{Duration, Instant};

    /// Poll `predicate` every 5 ms until it holds or `timeout` elapses;
    /// returns the predicate's final value.
    fn wait_until<F: Fn() -> bool>(predicate: F, timeout: Duration) -> bool {
        let deadline = Instant::now() + timeout;
        while Instant::now() < deadline {
            if predicate() {
                return true;
            }
            thread::sleep(Duration::from_millis(5));
        }
        predicate()
    }

    // Happy path: a submitted job runs exactly once, reports progress, and
    // ends Completed with its metrics and 100% progress recorded.
    #[test]
    fn submit_and_run_to_completion() {
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_w = Arc::clone(&counter);
        let q = MlJobQueue::start(
            1,
            Arc::new(move |handle| {
                counter_w.fetch_add(1, Ordering::SeqCst);
                handle.set_progress(50);
                handle.set_progress(100);
                Ok("{\"ok\":true}".to_string())
            }),
        );
        let id = q.submit(MlJobKind::Train, "spam", "{}");
        assert!(wait_until(
            || q.get(id).map(|j| j.is_terminal()).unwrap_or(false),
            Duration::from_secs(2),
        ));
        let job = q.get(id).unwrap();
        assert_eq!(job.status, MlJobStatus::Completed);
        assert_eq!(job.progress, 100);
        assert_eq!(job.metrics_json.as_deref(), Some("{\"ok\":true}"));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        q.shutdown();
    }

    // A worker returning Err must leave the job Failed with the message kept.
    #[test]
    fn failed_work_records_error() {
        let q = MlJobQueue::start(1, Arc::new(|_| Err("bad hyperparameters".to_string())));
        let id = q.submit(MlJobKind::Train, "spam", "{}");
        assert!(wait_until(
            || q.get(id).map(|j| j.is_terminal()).unwrap_or(false),
            Duration::from_secs(2),
        ));
        let job = q.get(id).unwrap();
        assert_eq!(job.status, MlJobStatus::Failed);
        assert_eq!(job.error_message.as_deref(), Some("bad hyperparameters"));
        q.shutdown();
    }

    // Cancelling a job while it is still queued must prevent it from ever
    // executing. The first job (id 1) sleeps to keep the single worker busy
    // while the second job is cancelled; only non-first jobs bump `ran`.
    #[test]
    fn cancel_while_queued_prevents_execution() {
        let ran = Arc::new(AtomicUsize::new(0));
        let ran_w = Arc::clone(&ran);
        let q = MlJobQueue::start(
            1,
            Arc::new(move |handle| {
                if handle.id() == 1 {
                    thread::sleep(Duration::from_millis(100));
                } else {
                    ran_w.fetch_add(1, Ordering::SeqCst);
                }
                Ok("{}".to_string())
            }),
        );
        let _first = q.submit(MlJobKind::Train, "a", "{}");
        let second = q.submit(MlJobKind::Train, "b", "{}");
        assert!(q.cancel(second));
        thread::sleep(Duration::from_millis(250));
        let job = q.get(second).unwrap();
        assert_eq!(job.status, MlJobStatus::Cancelled);
        assert_eq!(ran.load(Ordering::SeqCst), 0, "cancelled job must not run");
        q.shutdown();
    }

    // Cancelling an already-terminal job must be a no-op returning false.
    #[test]
    fn cancel_after_terminal_is_noop() {
        let q = MlJobQueue::start(1, Arc::new(|_| Ok("{}".to_string())));
        let id = q.submit(MlJobKind::Train, "x", "{}");
        assert!(wait_until(
            || q.get(id).map(|j| j.is_terminal()).unwrap_or(false),
            Duration::from_secs(2),
        ));
        assert!(!q.cancel(id));
        q.shutdown();
    }

    // A running worker polling `is_cancelled` must observe a mid-run cancel
    // and stop; the job stays Cancelled (not Failed) afterwards.
    #[test]
    fn cooperative_cancellation_observed_by_worker() {
        let observed = Arc::new(AtomicUsize::new(0));
        let iters = Arc::new(AtomicUsize::new(0));
        let observed_w = Arc::clone(&observed);
        let iters_w = Arc::clone(&iters);
        let q = MlJobQueue::start(
            1,
            Arc::new(move |handle| {
                for _ in 0..200 {
                    iters_w.fetch_add(1, Ordering::SeqCst);
                    if handle.is_cancelled() {
                        observed_w.fetch_add(1, Ordering::SeqCst);
                        return Err("cancelled".to_string());
                    }
                    handle.set_progress(10);
                    thread::sleep(Duration::from_millis(5));
                }
                Ok("{}".to_string())
            }),
        );
        let id = q.submit(MlJobKind::Train, "slow", "{}");
        // Wait until the worker is actually inside its loop before cancelling.
        assert!(wait_until(
            || iters.load(Ordering::SeqCst) > 0,
            Duration::from_secs(2),
        ));
        assert!(q.cancel(id));
        assert!(wait_until(
            || observed.load(Ordering::SeqCst) >= 1,
            Duration::from_secs(2),
        ));
        let job = q.get(id).unwrap();
        assert_eq!(job.status, MlJobStatus::Cancelled);
        q.shutdown();
    }

    // With a backend attached, the terminal snapshot written to the store
    // must round-trip with the completed status and metrics.
    #[test]
    fn backend_persists_submit_and_completion() {
        use super::super::persist::InMemoryMlPersistence;
        let backend = Arc::new(InMemoryMlPersistence::new());
        let q = MlJobQueue::start_with_backend(
            1,
            Arc::new(|_| Ok("{\"f1\":0.9}".to_string())),
            backend.clone(),
        );
        let id = q.submit(MlJobKind::Train, "spam", "{}");
        assert!(wait_until(
            || q.get(id).map(|j| j.is_terminal()).unwrap_or(false),
            Duration::from_secs(2),
        ));
        let raw = backend
            .get(super::ns::JOBS, &super::key::job(id))
            .unwrap()
            .unwrap();
        let decoded = MlJob::from_json(&raw).unwrap();
        assert_eq!(decoded.status, MlJobStatus::Completed);
        assert_eq!(decoded.metrics_json.as_deref(), Some("{\"f1\":0.9}"));
        q.shutdown();
    }

    // Startup resume: Queued jobs re-enter the queue, a job stuck in Running
    // (previous process died) is restarted, Completed jobs stay terminal,
    // and the id counter resumes above every persisted id.
    #[test]
    fn resume_from_backend_requeues_running_jobs() {
        use super::super::persist::InMemoryMlPersistence;
        let backend: Arc<dyn super::MlPersistence> = Arc::new(InMemoryMlPersistence::new());

        // Pre-populate the store as if a previous process had died.
        let pending = MlJob {
            id: 5,
            kind: MlJobKind::Train,
            target_name: "a".into(),
            status: MlJobStatus::Queued,
            progress: 0,
            created_at_ms: 1,
            started_at_ms: 0,
            finished_at_ms: 0,
            error_message: None,
            spec_json: "{}".into(),
            metrics_json: None,
        };
        let stuck = MlJob {
            id: 6,
            kind: MlJobKind::Train,
            target_name: "b".into(),
            status: MlJobStatus::Running,
            progress: 40,
            created_at_ms: 2,
            started_at_ms: 3,
            finished_at_ms: 0,
            error_message: None,
            spec_json: "{}".into(),
            metrics_json: None,
        };
        let done = MlJob {
            id: 7,
            kind: MlJobKind::Train,
            target_name: "c".into(),
            status: MlJobStatus::Completed,
            progress: 100,
            created_at_ms: 3,
            started_at_ms: 4,
            finished_at_ms: 5,
            error_message: None,
            spec_json: "{}".into(),
            metrics_json: Some("{}".into()),
        };
        for j in [&pending, &stuck, &done] {
            backend
                .put(super::ns::JOBS, &super::key::job(j.id), &j.to_json())
                .unwrap();
        }

        let ran = Arc::new(AtomicUsize::new(0));
        let ran_w = Arc::clone(&ran);
        let q = MlJobQueue::start_with_backend(
            2,
            Arc::new(move |_| {
                ran_w.fetch_add(1, Ordering::SeqCst);
                Ok("{}".to_string())
            }),
            backend.clone(),
        );

        // Both the queued and the formerly-running job must execute.
        assert!(wait_until(
            || ran.load(Ordering::SeqCst) >= 2,
            Duration::from_secs(3),
        ));
        assert_eq!(q.get(5).unwrap().status, MlJobStatus::Completed);
        assert_eq!(q.get(6).unwrap().status, MlJobStatus::Completed);
        assert_eq!(q.get(7).unwrap().status, MlJobStatus::Completed);

        // Fresh submissions must not reuse resumed ids.
        let fresh_id = q.submit(MlJobKind::Train, "d", "{}");
        assert!(fresh_id > 7);

        q.shutdown();
    }

    // A 3-worker pool must drain a 20-job backlog to completion.
    #[test]
    fn multiple_workers_drain_backlog() {
        let q = MlJobQueue::start(
            3,
            Arc::new(|handle| {
                handle.set_progress(50);
                thread::sleep(Duration::from_millis(20));
                Ok("{}".to_string())
            }),
        );
        let ids: Vec<_> = (0..20)
            .map(|i| q.submit(MlJobKind::Train, format!("m{i}"), "{}"))
            .collect();
        assert!(wait_until(
            || ids
                .iter()
                .all(|id| q.get(*id).map(|j| j.is_terminal()).unwrap_or(false)),
            Duration::from_secs(5),
        ));
        for id in ids {
            assert_eq!(q.get(id).unwrap().status, MlJobStatus::Completed);
        }
        q.shutdown();
    }
}