1use std::{
11 any::Any,
12 collections::{BinaryHeap, VecDeque},
13 sync::{
14 Arc, Condvar, Mutex,
15 atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
16 mpsc::{self, Receiver, Sender},
17 },
18 thread::{self, JoinHandle},
19 time::Duration,
20};
21
22use reifydb_core::{
23 Result,
24 interface::version::{ComponentType, HasVersion, SystemVersion},
25};
26use reifydb_engine::StandardEngine;
27pub use reifydb_sub_api::Priority;
28use reifydb_sub_api::{BoxedOnceTask, BoxedTask, HealthStatus, Scheduler, SchedulerService, Subsystem, TaskHandle};
29use tracing::{debug, instrument, warn};
30
31use crate::{
32 client::{SchedulerClient, SchedulerRequest, SchedulerResponse},
33 scheduler::{OnceTaskAdapter, SchedulableTaskAdapter, TaskScheduler},
34 task::{PoolTask, PrioritizedTask},
35 tracker::TaskTracker,
36};
37
/// Configuration for the worker pool subsystem.
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Number of rayon worker threads in the pool.
    pub num_workers: usize,
    /// Maximum number of tasks allowed in the priority queue; submissions
    /// beyond this limit are rejected.
    pub max_queue_size: usize,
    /// Intended tick interval for the scheduler loop.
    /// NOTE(review): not referenced anywhere in this file — the scheduler
    /// thread uses hard-coded 1ms/1s waits; confirm whether it should be.
    pub scheduler_interval: Duration,
    /// Intended threshold for warning about long-running tasks.
    /// NOTE(review): the dispatcher warns at a hard-coded 5s rather than
    /// this value — confirm which is intended.
    pub task_timeout_warning: Duration,
}
50
51impl Default for WorkerConfig {
52 fn default() -> Self {
53 Self {
54 num_workers: 1,
55 max_queue_size: 10000,
56 scheduler_interval: Duration::from_millis(10),
57 task_timeout_warning: Duration::from_secs(30),
58 }
59 }
60}
61
/// Lock-free counters describing pool activity. All fields are updated
/// with `Ordering::Relaxed`, so any read is an approximate snapshot.
#[derive(Debug, Default)]
pub struct PoolStats {
    /// Tasks whose `execute` returned `Ok`.
    pub tasks_completed: AtomicUsize,
    /// Tasks whose `execute` returned `Err`.
    pub tasks_failed: AtomicUsize,
    /// Tasks currently waiting in the priority queue.
    pub tasks_queued: AtomicUsize,
    /// Workers currently executing a task.
    pub active_workers: AtomicUsize,
}
70
/// Priority-based worker pool subsystem: a rayon thread pool fed by a
/// dedicated dispatcher thread from a bounded priority queue, plus a
/// scheduler thread that feeds periodic tasks into the same queue.
pub struct WorkerSubsystem {
    config: WorkerConfig,
    // Shared stop flag observed by the dispatcher and scheduler threads.
    running: Arc<AtomicBool>,
    stats: Arc<PoolStats>,

    // Created on start(), dropped on shutdown().
    thread_pool: Option<Arc<rayon::ThreadPool>>,
    dispatcher_handle: Option<JoinHandle<()>>,

    // Tracks in-flight tasks so shutdown can wait for them to finish.
    task_tracker: Arc<TaskTracker>,

    // Bounded priority queue of pending tasks plus its wakeup condvar.
    task_queue: Arc<Mutex<BinaryHeap<PrioritizedTask>>>,
    task_condvar: Arc<Condvar>,

    // Periodic-task scheduler state, its wakeup condvar and thread handle.
    scheduler: Arc<Mutex<TaskScheduler>>,
    scheduler_condvar: Arc<Condvar>,
    scheduler_handle: Option<JoinHandle<()>>,

    // Receiving end of the client request channel; wrapped in Option so the
    // scheduler thread can take ownership of it after start().
    scheduler_receiver: Arc<Mutex<Option<Receiver<(SchedulerRequest, Sender<SchedulerResponse>)>>>>,

    // Requests buffered by the client before start(); replayed by the
    // scheduler thread when it launches.
    pending_requests: Arc<Mutex<VecDeque<(SchedulerRequest, Sender<SchedulerResponse>)>>>,

    // Monotonic source of TaskHandle ids, shared with the client.
    next_handle: Arc<AtomicU64>,

    // Channel-backed facade handed out via get_scheduler().
    scheduler_client: Arc<dyn Scheduler>,

    engine: StandardEngine,
}
109
110impl WorkerSubsystem {
    /// Build a worker subsystem from `config`, wiring the channel-based
    /// scheduler client but starting no threads; call `start()` to spin
    /// up the pool.
    #[instrument(level = "debug", skip(engine))]
    pub fn new(config: WorkerConfig, engine: StandardEngine) -> Self {
        // Requests issued before start() are parked here and replayed by
        // the scheduler thread once it launches.
        let pending_requests = Arc::new(Mutex::new(VecDeque::new()));
        let next_handle = Arc::new(AtomicU64::new(1));
        let running = Arc::new(AtomicBool::new(false));

        // Channel the SchedulerClient uses to talk to the scheduler thread.
        let (sender, receiver) = mpsc::channel();

        let scheduler_client = Arc::new(SchedulerClient::new(
            sender,
            Arc::clone(&pending_requests),
            Arc::clone(&next_handle),
            Arc::clone(&running),
        ));

        // Read before `config` is moved into the struct below.
        let max_queue_size = config.max_queue_size;
        Self {
            config,
            running,
            stats: Arc::new(PoolStats::default()),
            thread_pool: None,
            dispatcher_handle: None,
            task_tracker: Arc::new(TaskTracker::new()),
            task_queue: Arc::new(Mutex::new(BinaryHeap::with_capacity(max_queue_size))),
            task_condvar: Arc::new(Condvar::new()),
            scheduler: Arc::new(Mutex::new(TaskScheduler::new())),
            scheduler_condvar: Arc::new(Condvar::new()),
            // Option so the scheduler thread can take sole ownership later.
            scheduler_receiver: Arc::new(Mutex::new(Some(receiver))),
            pending_requests,
            next_handle,
            scheduler_client,
            engine,
        }
    }
146
147 #[instrument(level = "trace", skip(self))]
149 pub fn get_scheduler(&self) -> SchedulerService {
150 SchedulerService(self.scheduler_client.clone())
151 }
152
153 #[instrument(level = "debug", skip(self, task))]
155 pub fn submit(&self, task: Box<dyn PoolTask>) -> Result<()> {
156 if !self.running.load(Ordering::Relaxed) {
157 panic!("Worker pool is not running");
158 }
159
160 {
161 let mut queue = self.task_queue.lock().unwrap();
162
163 if queue.len() >= self.config.max_queue_size {
165 panic!(
166 "Task queue is full. Consider increasing max_queue_size or reducing task submission rate"
167 );
168 }
169
170 queue.push(PrioritizedTask::new(task));
171 self.stats.tasks_queued.fetch_add(1, Ordering::Relaxed);
172 }
173
174 self.task_condvar.notify_one();
176 Ok(())
177 }
178
179 fn schedule_every_internal(
181 &self,
182 task: Box<dyn PoolTask>,
183 interval: Duration,
184 priority: Priority,
185 ) -> Result<TaskHandle> {
186 let mut scheduler = self.scheduler.lock().unwrap();
187 let handle = scheduler.schedule_every_internal(task, interval, priority);
188 drop(scheduler);
189
190 self.scheduler_condvar.notify_one();
192
193 Ok(handle)
194 }
195
196 #[instrument(level = "debug", skip(self))]
198 pub fn cancel_task(&self, handle: TaskHandle) -> Result<()> {
199 let mut scheduler = self.scheduler.lock().unwrap();
200 scheduler.cancel(handle);
201 Ok(())
202 }
203
204 #[instrument(level = "trace", skip(self))]
206 pub fn stats(&self) -> &PoolStats {
207 &self.stats
208 }
209
210 #[instrument(level = "trace", skip(self))]
212 pub fn active_workers(&self) -> usize {
213 self.stats.active_workers.load(Ordering::Relaxed)
214 }
215
216 #[instrument(level = "trace", skip(self))]
218 pub fn queued_tasks(&self) -> usize {
219 self.task_queue.lock().unwrap().len()
220 }
221
    /// Dispatcher loop: pops the highest-priority task off the shared
    /// queue and hands it to the rayon pool until `running` is cleared,
    /// then drains whatever is still queued.
    fn run_dispatcher(
        pool: Arc<rayon::ThreadPool>,
        queue: Arc<Mutex<BinaryHeap<PrioritizedTask>>>,
        condvar: Arc<Condvar>,
        tracker: Arc<TaskTracker>,
        stats: Arc<PoolStats>,
        running: Arc<AtomicBool>,
        engine: StandardEngine,
    ) {
        debug!("Dispatcher thread started");

        while running.load(Ordering::Relaxed) {
            let task = {
                let mut queue_guard = queue.lock().unwrap();

                // Wait for work, re-checking `running` at least every
                // 100ms so shutdown is noticed even if a notify is missed.
                while queue_guard.is_empty() && running.load(Ordering::Relaxed) {
                    let (guard, timeout_result) =
                        condvar.wait_timeout(queue_guard, Duration::from_millis(100)).unwrap();
                    queue_guard = guard;

                    if timeout_result.timed_out() {
                        continue;
                    }
                }

                // None when woken for shutdown with an empty queue.
                queue_guard.pop()
            };

            if let Some(prioritized_task) = task {
                stats.tasks_queued.fetch_sub(1, Ordering::Relaxed);

                // Register with the tracker so shutdown can wait for this
                // task; the token lets it be cancelled before it runs.
                let (task_id, cancel_token) = tracker.register(None);

                let tracker_clone = Arc::clone(&tracker);
                let stats_clone = Arc::clone(&stats);
                let engine_clone = engine.clone();

                pool.spawn(move || {
                    // Cancelled while queued: mark complete without running.
                    if cancel_token.is_cancelled() {
                        tracker_clone.complete(task_id);
                        return;
                    }

                    stats_clone.active_workers.fetch_add(1, Ordering::Relaxed);

                    let ctx = crate::task::InternalTaskContext {
                        cancel_token: Some(cancel_token.clone()),
                        engine: engine_clone,
                    };

                    let start = std::time::Instant::now();
                    let result = prioritized_task.task.execute(&ctx);
                    let duration = start.elapsed();

                    // NOTE(review): threshold is hard-coded to 5s; the
                    // config's `task_timeout_warning` (default 30s) is not
                    // consulted here — confirm which is intended.
                    if duration > Duration::from_secs(5) {
                        warn!(
                            "Task '{}' took {:?} to execute",
                            prioritized_task.task.name(),
                            duration
                        );
                    }

                    match result {
                        Ok(_) => {
                            stats_clone.tasks_completed.fetch_add(1, Ordering::Relaxed);
                        }
                        Err(e) => {
                            warn!("Task '{}' failed: {}", prioritized_task.task.name(), e);
                            stats_clone.tasks_failed.fetch_add(1, Ordering::Relaxed);
                        }
                    }

                    stats_clone.active_workers.fetch_sub(1, Ordering::Relaxed);
                    tracker_clone.complete(task_id);
                });
            }
        }

        // Shutdown: execute any tasks still queued before exiting.
        Self::drain_queue(pool, queue, tracker, stats, engine);

        debug!("Dispatcher thread stopped");
    }
318
319 fn drain_queue(
321 pool: Arc<rayon::ThreadPool>,
322 queue: Arc<Mutex<BinaryHeap<PrioritizedTask>>>,
323 tracker: Arc<TaskTracker>,
324 stats: Arc<PoolStats>,
325 engine: StandardEngine,
326 ) {
327 debug!("Draining task queue during shutdown");
328
329 loop {
330 let task = {
331 let mut queue_guard = queue.lock().unwrap();
332 queue_guard.pop()
333 };
334
335 match task {
336 Some(prioritized_task) => {
337 stats.tasks_queued.fetch_sub(1, Ordering::Relaxed);
338
339 let (task_id, _) = tracker.register(None);
341 let tracker_clone = Arc::clone(&tracker);
342 let stats_clone = Arc::clone(&stats);
343 let engine_clone = engine.clone();
344
345 pool.spawn(move || {
346 let ctx = crate::task::InternalTaskContext {
347 cancel_token: None,
348 engine: engine_clone,
349 };
350 let _ = prioritized_task.task.execute(&ctx);
351 stats_clone.tasks_completed.fetch_add(1, Ordering::Relaxed);
352 tracker_clone.complete(task_id);
353 });
354 }
355 None => break,
356 }
357 }
358 }
359
    /// Spawn the "worker-scheduler" thread.
    ///
    /// The thread runs two phases: first it replays requests the client
    /// buffered before start-up, then it loops — servicing new channel
    /// requests, moving ready periodic tasks into the worker queue, and
    /// sleeping until the next task is due or a condvar wakeup arrives.
    fn start_scheduler(&mut self) {
        let scheduler = Arc::clone(&self.scheduler);
        let scheduler_condvar = Arc::clone(&self.scheduler_condvar);
        let task_queue = Arc::clone(&self.task_queue);
        let task_condvar = Arc::clone(&self.task_condvar);
        let running = Arc::clone(&self.running);
        let stats = Arc::clone(&self.stats);
        let max_queue_size = self.config.max_queue_size;
        let scheduler_receiver = Arc::clone(&self.scheduler_receiver);
        let pending_requests = Arc::clone(&self.pending_requests);
        let next_handle = Arc::clone(&self.next_handle);
        let engine = self.engine.clone();

        let handle = thread::Builder::new()
            .name("worker-scheduler".to_string())
            .spawn(move || {
                // Phase 1: replay requests buffered before start().
                {
                    let mut pending = pending_requests.lock().unwrap();
                    let mut sched = scheduler.lock().unwrap();

                    // Keep handle ids consistent with those the client
                    // has already handed out.
                    sched.set_next_handle(next_handle.load(Ordering::Relaxed));

                    while let Some((request, response_tx)) = pending.pop_front() {
                        let response = match request {
                            SchedulerRequest::ScheduleEvery {
                                task,
                                interval,
                            } => {
                                let adapter = Box::new(SchedulableTaskAdapter::new(
                                    task,
                                    engine.clone(),
                                ));
                                let priority = adapter.priority();
                                let handle = sched.schedule_every_internal(
                                    adapter, interval, priority,
                                );
                                SchedulerResponse::TaskScheduled(handle)
                            }
                            SchedulerRequest::Submit {
                                task,
                                priority: _,
                            } => {
                                let adapter = Box::new(OnceTaskAdapter::new(
                                    task,
                                    engine.clone(),
                                ));
                                // Release the scheduler lock before taking
                                // the queue lock (consistent lock order).
                                drop(sched);
                                {
                                    let mut queue = task_queue.lock().unwrap();
                                    // Over-capacity submissions are
                                    // silently dropped here.
                                    if queue.len() < max_queue_size {
                                        queue.push(PrioritizedTask::new(adapter));
                                        stats.tasks_queued.fetch_add(1, Ordering::Relaxed);
                                        task_condvar.notify_one();
                                    }
                                }
                                sched = scheduler.lock().unwrap();
                                SchedulerResponse::TaskSubmitted
                            }
                            SchedulerRequest::Cancel {
                                handle,
                            } => {
                                sched.cancel(handle);
                                SchedulerResponse::TaskCancelled
                            }
                        };
                        // Client may have stopped waiting; ignore send errors.
                        let _ = response_tx.send(response);
                    }
                    drop(sched);
                    drop(pending);
                }

                // Phase 2: main loop until shutdown.
                while running.load(Ordering::Relaxed) {
                    // Service requests arriving through the live channel.
                    {
                        let receiver_guard = scheduler_receiver.lock().unwrap();
                        if let Some(ref receiver) = *receiver_guard {
                            while let Ok((request, response_tx)) = receiver.try_recv() {
                                let mut sched = scheduler.lock().unwrap();
                                let response = match request {
                                    SchedulerRequest::ScheduleEvery {
                                        task,
                                        interval,
                                    } => {
                                        let adapter = Box::new(SchedulableTaskAdapter::new(
                                            task,
                                            engine.clone(),
                                        ));
                                        let priority = adapter.priority();
                                        let handle = sched.schedule_every_internal(
                                            adapter, interval, priority,
                                        );
                                        SchedulerResponse::TaskScheduled(handle)
                                    }
                                    SchedulerRequest::Submit {
                                        task,
                                        priority: _,
                                    } => {
                                        let adapter = Box::new(OnceTaskAdapter::new(
                                            task,
                                            engine.clone(),
                                        ));
                                        // Same lock-order dance as phase 1.
                                        drop(sched);

                                        {
                                            let mut queue = task_queue.lock().unwrap();
                                            if queue.len() < max_queue_size {
                                                queue.push(PrioritizedTask::new(adapter));
                                                stats.tasks_queued.fetch_add(1, Ordering::Relaxed);
                                                task_condvar.notify_one();
                                            }
                                        }
                                        sched = scheduler.lock().unwrap();
                                        SchedulerResponse::TaskSubmitted
                                    }
                                    SchedulerRequest::Cancel {
                                        handle,
                                    } => {
                                        sched.cancel(handle);
                                        SchedulerResponse::TaskCancelled
                                    }
                                };
                                drop(sched);
                                let _ = response_tx.send(response);
                            }
                        }
                    }

                    let mut sched = scheduler.lock().unwrap();

                    // Nothing scheduled: nap briefly so the channel above
                    // is re-polled soon, then start the loop over.
                    if sched.task_count() == 0 {
                        let result = scheduler_condvar
                            .wait_timeout(sched, Duration::from_millis(1))
                            .unwrap();

                        sched = result.0;

                        if !running.load(Ordering::Relaxed) {
                            break;
                        }

                        drop(sched);
                        continue;
                    }

                    // Collect due tasks and compute the time until the
                    // next one, all under a single lock acquisition.
                    let ready_tasks = sched.get_ready_tasks();

                    let wait_duration = if let Some(next_time) = sched.next_run_time() {
                        let now = std::time::Instant::now();
                        if next_time > now {
                            next_time - now
                        } else {
                            Duration::from_millis(0)
                        }
                    } else {
                        Duration::from_secs(1)
                    };

                    drop(sched);

                    // Push due tasks into the worker queue; excess tasks
                    // beyond capacity are dropped with a warning.
                    if !ready_tasks.is_empty() {
                        let mut queue = task_queue.lock().unwrap();

                        for task in ready_tasks {
                            if queue.len() >= max_queue_size {
                                warn!("Scheduler: Queue full, dropping scheduled task");
                                break;
                            }

                            queue.push(PrioritizedTask::new(task));
                            stats.tasks_queued.fetch_add(1, Ordering::Relaxed);
                        }

                        drop(queue);
                        task_condvar.notify_all();
                    }

                    // Sleep until the next task is due or we are woken by
                    // a new schedule request / shutdown.
                    if wait_duration > Duration::from_millis(0) {
                        let sched = scheduler.lock().unwrap();
                        let _ = scheduler_condvar.wait_timeout(sched, wait_duration);
                    }
                }
            })
            .expect("Failed to create scheduler thread");

        self.scheduler_handle = Some(handle);
    }
577}
578
579impl Subsystem for WorkerSubsystem {
    /// Stable subsystem identifier used in logs and registries.
    fn name(&self) -> &'static str {
        "sub-worker"
    }
583
    /// Start the pool: build the rayon thread pool, then launch the
    /// dispatcher and scheduler threads. Idempotent — a second call while
    /// running is a no-op.
    #[instrument(level = "info", skip(self))]
    fn start(&mut self) -> Result<()> {
        if self.running.load(Ordering::Relaxed) {
            return Ok(());
        }

        debug!("Starting worker subsystem with {} workers", self.config.num_workers);

        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(self.config.num_workers)
            .thread_name(|i| format!("rayon-worker-{}", i))
            // Log worker panics instead of aborting the process.
            .panic_handler(|panic_info| {
                warn!("Worker thread panicked: {:?}", panic_info);
            })
            .build()
            .map_err(|e| {
                reifydb_core::error!(reifydb_core::diagnostic::internal(format!(
                    "Failed to create thread pool: {}",
                    e
                )))
            })?;

        self.thread_pool = Some(Arc::new(pool));
        // Must be set before spawning: both thread loops gate on this flag.
        self.running.store(true, Ordering::Relaxed);

        {
            let pool = Arc::clone(self.thread_pool.as_ref().unwrap());
            let queue = Arc::clone(&self.task_queue);
            let condvar = Arc::clone(&self.task_condvar);
            let tracker = Arc::clone(&self.task_tracker);
            let stats = Arc::clone(&self.stats);
            let running = Arc::clone(&self.running);
            let engine = self.engine.clone();

            let handle = thread::Builder::new()
                .name("worker-dispatcher".to_string())
                .spawn(move || {
                    Self::run_dispatcher(pool, queue, condvar, tracker, stats, running, engine)
                })
                .map_err(|e| {
                    reifydb_core::error!(reifydb_core::diagnostic::internal(format!(
                        "Failed to spawn dispatcher thread: {}",
                        e
                    )))
                })?;

            self.dispatcher_handle = Some(handle);
        }

        self.start_scheduler();

        debug!("Started with {} workers", self.config.num_workers);

        Ok(())
    }
642
643 #[instrument(level = "info", skip(self))]
644 fn shutdown(&mut self) -> Result<()> {
645 if !self.running.load(Ordering::Relaxed) {
646 return Ok(()); }
648
649 debug!("Worker pool shutting down...");
650
651 self.running.store(false, Ordering::Relaxed);
652 debug!("Signaled threads to stop");
653
654 self.task_condvar.notify_all();
655 self.scheduler_condvar.notify_all();
656
657 let cancelled_count = {
658 let mut queue = self.task_queue.lock().unwrap();
659 let count = queue.len();
660 queue.clear();
661 count
662 };
663
664 if cancelled_count > 0 {
665 debug!("Cancelled {} queued tasks", cancelled_count);
666 }
667
668 if let Some(handle) = self.scheduler_handle.take() {
670 let _ = handle.join();
671 }
672
673 if let Some(handle) = self.dispatcher_handle.take() {
675 let _ = handle.join();
676 }
677
678 debug!("Scheduler and dispatcher threads stopped");
679
680 let active_count = self.task_tracker.active_count();
681 if active_count > 0 {
682 debug!("Waiting for {} active tasks to complete...", active_count);
683 }
684
685 let timeout = Duration::from_secs(10);
686 if !self.task_tracker.wait_for_completion(timeout) {
687 let remaining = self.task_tracker.active_count();
688 warn!(
689 "Timeout waiting for tasks to complete. {} tasks still running. Forcing shutdown.",
690 remaining
691 );
692
693 self.thread_pool = None;
694 warn!("Worker pool shutdown forced with {} tasks still active", remaining);
695 } else {
696 self.thread_pool = None;
697 debug!("All tasks completed, worker pool shutdown complete");
698 }
699
700 Ok(())
701 }
702
    /// Whether `start` has been called and `shutdown` has not.
    #[instrument(level = "trace", skip(self))]
    fn is_running(&self) -> bool {
        self.running.load(Ordering::Relaxed)
    }
707
708 #[instrument(level = "debug", skip(self))]
709 fn health_status(&self) -> HealthStatus {
710 if !self.is_running() {
711 return HealthStatus::Unknown;
712 }
713
714 let active = self.active_workers();
715 let queued = self.queued_tasks();
716
717 if active == 0 && queued > 0 {
718 HealthStatus::Failed {
720 description: "No active workers but tasks are queued".into(),
721 }
722 } else if queued > self.config.max_queue_size / 2 {
723 HealthStatus::Degraded {
725 description: format!(
726 "Task queue is {}% full",
727 (queued * 100) / self.config.max_queue_size
728 ),
729 }
730 } else {
731 HealthStatus::Healthy
732 }
733 }
734
    /// Downcast support for callers holding a `dyn Subsystem`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Mutable downcast support for callers holding a `dyn Subsystem`.
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }
742}
743
744impl HasVersion for WorkerSubsystem {
745 fn version(&self) -> SystemVersion {
746 SystemVersion {
747 name: "sub-worker".to_string(),
748 version: env!("CARGO_PKG_VERSION").to_string(),
749 description: "Priority-based task worker pool subsystem".to_string(),
750 r#type: ComponentType::Subsystem,
751 }
752 }
753}
754
755impl Drop for WorkerSubsystem {
756 fn drop(&mut self) {
757 let _ = self.shutdown();
758 }
759}
760
761impl Scheduler for WorkerSubsystem {
762 fn every(&self, interval: Duration, task: BoxedTask) -> Result<TaskHandle> {
763 let adapter = Box::new(SchedulableTaskAdapter::new(task, self.engine.clone()));
764 let priority = adapter.priority();
765 self.schedule_every_internal(adapter, interval, priority)
766 }
767
768 fn cancel(&self, handle: TaskHandle) -> Result<()> {
769 self.cancel_task(handle)
770 }
771
772 fn once(&self, task: BoxedOnceTask) -> Result<()> {
773 let adapter = Box::new(OnceTaskAdapter::new(task, self.engine.clone()));
774 WorkerSubsystem::submit(self, adapter)
775 }
776}