use std::{
    any::Any,
    collections::{BinaryHeap, VecDeque},
    sync::{
        Arc, Condvar, Mutex,
        atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
        mpsc::{self, Receiver, Sender},
    },
    thread::{self, JoinHandle},
    time::Duration,
};

use reifydb_core::{
    Result,
    interface::version::{ComponentType, HasVersion, SystemVersion},
    log_debug, log_warn,
};
use reifydb_engine::StandardEngine;
pub use reifydb_sub_api::Priority;
use reifydb_sub_api::{
    BoxedOnceTask, BoxedTask, HealthStatus, Scheduler, SchedulerService, Subsystem, TaskHandle,
};

use crate::{
    client::{SchedulerClient, SchedulerRequest, SchedulerResponse},
    scheduler::{OnceTaskAdapter, SchedulableTaskAdapter, TaskScheduler},
    task::{PoolTask, PrioritizedTask},
    tracker::TaskTracker,
};

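/// Configuration for the worker pool subsystem.
///
/// A minimal usage sketch (the field values below are illustrative, not
/// recommendations):
///
/// ```ignore
/// let config = WorkerConfig {
///     num_workers: 4,
///     ..Default::default()
/// };
/// ```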
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Number of worker threads in the rayon pool.
    pub num_workers: usize,
    /// Maximum number of tasks allowed to wait in the queue.
    pub max_queue_size: usize,
    /// How often the scheduler thread checks for ready tasks.
    pub scheduler_interval: Duration,
    /// Threshold above which a long-running task is logged.
    pub task_timeout_warning: Duration,
}

impl Default for WorkerConfig {
    fn default() -> Self {
        Self {
            num_workers: 1,
            max_queue_size: 10000,
            scheduler_interval: Duration::from_millis(10),
            task_timeout_warning: Duration::from_secs(30),
        }
    }
}

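/// Counters describing the state of the worker pool, updated atomically by
/// the dispatcher and worker threads.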
#[derive(Debug, Default)]
pub struct PoolStats {
    pub tasks_completed: AtomicUsize,
    pub tasks_failed: AtomicUsize,
    pub tasks_queued: AtomicUsize,
    pub active_workers: AtomicUsize,
}

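/// A priority-based worker pool subsystem.
///
/// Tasks are dispatched from a priority queue onto a rayon thread pool, while
/// a separate scheduler thread feeds recurring tasks into that queue.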
pub struct WorkerSubsystem {
    config: WorkerConfig,
    running: Arc<AtomicBool>,
    stats: Arc<PoolStats>,

    /// Rayon pool that executes the tasks.
    thread_pool: Option<Arc<rayon::ThreadPool>>,
    dispatcher_handle: Option<JoinHandle<()>>,

    /// Tracks in-flight tasks so shutdown can wait for them.
    task_tracker: Arc<TaskTracker>,

    /// Priority queue of tasks awaiting dispatch.
    task_queue: Arc<Mutex<BinaryHeap<PrioritizedTask>>>,
    task_condvar: Arc<Condvar>,

    scheduler: Arc<Mutex<TaskScheduler>>,
    scheduler_condvar: Arc<Condvar>,
    scheduler_handle: Option<JoinHandle<()>>,

    /// Receiver for requests sent through the `SchedulerClient`.
    scheduler_receiver: Arc<Mutex<Option<Receiver<(SchedulerRequest, Sender<SchedulerResponse>)>>>>,

    /// Requests submitted before the scheduler thread started.
    pending_requests: Arc<Mutex<VecDeque<(SchedulerRequest, Sender<SchedulerResponse>)>>>,

    next_handle: Arc<AtomicU64>,

    scheduler_client: Arc<dyn Scheduler>,

    engine: StandardEngine,
}

impl WorkerSubsystem {
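    /// Creates a new worker subsystem; no threads are started until `start()`.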
    pub fn new(config: WorkerConfig, engine: StandardEngine) -> Self {
        let pending_requests = Arc::new(Mutex::new(VecDeque::new()));
        let next_handle = Arc::new(AtomicU64::new(1));
        let running = Arc::new(AtomicBool::new(false));

        let (sender, receiver) = mpsc::channel();

        let scheduler_client = Arc::new(SchedulerClient::new(
            sender,
            Arc::clone(&pending_requests),
            Arc::clone(&next_handle),
            Arc::clone(&running),
        ));

        let max_queue_size = config.max_queue_size;
        Self {
            config,
            running,
            stats: Arc::new(PoolStats::default()),
            thread_pool: None,
            dispatcher_handle: None,
            task_tracker: Arc::new(TaskTracker::new()),
            task_queue: Arc::new(Mutex::new(BinaryHeap::with_capacity(max_queue_size))),
            task_condvar: Arc::new(Condvar::new()),
            scheduler: Arc::new(Mutex::new(TaskScheduler::new())),
            scheduler_condvar: Arc::new(Condvar::new()),
            scheduler_handle: None,
            scheduler_receiver: Arc::new(Mutex::new(Some(receiver))),
            pending_requests,
            next_handle,
            scheduler_client,
            engine,
        }
    }

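    /// Returns a handle that can schedule tasks through this subsystem.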
    pub fn get_scheduler(&self) -> SchedulerService {
        SchedulerService(self.scheduler_client.clone())
    }

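    /// Submits a one-off task to the pool, respecting priority ordering.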
    pub fn submit(&self, task: Box<dyn PoolTask>) -> Result<()> {
        if !self.running.load(Ordering::Relaxed) {
            return Err(reifydb_core::error!(reifydb_core::diagnostic::internal(
                "Worker pool is not running".to_string(),
            )));
        }

        {
            let mut queue = self.task_queue.lock().unwrap();

            if queue.len() >= self.config.max_queue_size {
                return Err(reifydb_core::error!(reifydb_core::diagnostic::internal(
                    "Task queue is full. Consider increasing max_queue_size or reducing task submission rate"
                        .to_string(),
                )));
            }

            queue.push(PrioritizedTask::new(task));
            self.stats.tasks_queued.fetch_add(1, Ordering::Relaxed);
        }

        self.task_condvar.notify_one();
        Ok(())
    }

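    /// Registers a recurring task with the scheduler and wakes the scheduler
    /// thread so it can pick the task up immediately.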
    fn schedule_every_internal(
        &self,
        task: Box<dyn PoolTask>,
        interval: Duration,
        priority: Priority,
    ) -> Result<TaskHandle> {
        let mut scheduler = self.scheduler.lock().unwrap();
        let handle = scheduler.schedule_every_internal(task, interval, priority);
        drop(scheduler);

        self.scheduler_condvar.notify_one();

        Ok(handle)
    }

    /// Cancels a scheduled task by its handle.
    pub fn cancel_task(&self, handle: TaskHandle) -> Result<()> {
        let mut scheduler = self.scheduler.lock().unwrap();
        scheduler.cancel(handle);
        Ok(())
    }

    /// Returns the pool's statistics counters.
    pub fn stats(&self) -> &PoolStats {
        &self.stats
    }

    /// Number of workers currently executing tasks.
    pub fn active_workers(&self) -> usize {
        self.stats.active_workers.load(Ordering::Relaxed)
    }

    /// Number of tasks waiting in the queue.
    pub fn queued_tasks(&self) -> usize {
        self.task_queue.lock().unwrap().len()
    }

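    /// Dispatcher loop: pops the highest-priority task off the queue and
    /// spawns it onto the rayon pool until `running` is cleared.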
    fn run_dispatcher(
        pool: Arc<rayon::ThreadPool>,
        queue: Arc<Mutex<BinaryHeap<PrioritizedTask>>>,
        condvar: Arc<Condvar>,
        tracker: Arc<TaskTracker>,
        stats: Arc<PoolStats>,
        running: Arc<AtomicBool>,
        engine: StandardEngine,
    ) {
        log_debug!("Dispatcher thread started");

        while running.load(Ordering::Relaxed) {
            let task = {
                let mut queue_guard = queue.lock().unwrap();

                // Wait for work, waking periodically to re-check `running`.
                while queue_guard.is_empty() && running.load(Ordering::Relaxed) {
                    let (guard, timeout_result) =
                        condvar.wait_timeout(queue_guard, Duration::from_millis(100)).unwrap();
                    queue_guard = guard;

                    if timeout_result.timed_out() {
                        continue;
                    }
                }

                queue_guard.pop()
            };

            if let Some(prioritized_task) = task {
                stats.tasks_queued.fetch_sub(1, Ordering::Relaxed);

                let (task_id, cancel_token) = tracker.register(None);

                let tracker_clone = Arc::clone(&tracker);
                let stats_clone = Arc::clone(&stats);
                let engine_clone = engine.clone();

                pool.spawn(move || {
                    // Skip tasks that were cancelled while still queued.
                    if cancel_token.is_cancelled() {
                        tracker_clone.complete(task_id);
                        return;
                    }

                    stats_clone.active_workers.fetch_add(1, Ordering::Relaxed);

                    let ctx = crate::task::InternalTaskContext {
                        cancel_token: Some(cancel_token.clone()),
                        engine: engine_clone,
                    };

                    let start = std::time::Instant::now();
                    let result = prioritized_task.task.execute(&ctx);
                    let duration = start.elapsed();

                    if duration > Duration::from_secs(5) {
                        log_warn!(
                            "Task '{}' took {:?} to execute",
                            prioritized_task.task.name(),
                            duration
                        );
                    }

                    match result {
                        Ok(_) => {
                            stats_clone.tasks_completed.fetch_add(1, Ordering::Relaxed);
                        }
                        Err(e) => {
                            log_warn!("Task '{}' failed: {}", prioritized_task.task.name(), e);
                            stats_clone.tasks_failed.fetch_add(1, Ordering::Relaxed);
                        }
                    }

                    stats_clone.active_workers.fetch_sub(1, Ordering::Relaxed);
                    tracker_clone.complete(task_id);
                });
            }
        }

        // Run anything still queued before the dispatcher exits.
        Self::drain_queue(pool, queue, tracker, stats, engine);

        log_debug!("Dispatcher thread stopped");
    }

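    /// Executes any tasks still queued at shutdown on the pool, without
    /// cancellation tokens.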
    fn drain_queue(
        pool: Arc<rayon::ThreadPool>,
        queue: Arc<Mutex<BinaryHeap<PrioritizedTask>>>,
        tracker: Arc<TaskTracker>,
        stats: Arc<PoolStats>,
        engine: StandardEngine,
    ) {
        log_debug!("Draining task queue during shutdown");

        loop {
            let task = {
                let mut queue_guard = queue.lock().unwrap();
                queue_guard.pop()
            };

            match task {
                Some(prioritized_task) => {
                    stats.tasks_queued.fetch_sub(1, Ordering::Relaxed);

                    let (task_id, _) = tracker.register(None);
                    let tracker_clone = Arc::clone(&tracker);
                    let stats_clone = Arc::clone(&stats);
                    let engine_clone = engine.clone();

                    pool.spawn(move || {
                        let ctx = crate::task::InternalTaskContext {
                            cancel_token: None,
                            engine: engine_clone,
                        };
                        let _ = prioritized_task.task.execute(&ctx);
                        stats_clone.tasks_completed.fetch_add(1, Ordering::Relaxed);
                        tracker_clone.complete(task_id);
                    });
                }
                None => break,
            }
        }
    }

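    /// Spawns the scheduler thread: it first replays requests queued before
    /// startup, then loops handling client requests and promoting due
    /// recurring tasks into the task queue.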
    fn start_scheduler(&mut self) {
        let scheduler = Arc::clone(&self.scheduler);
        let scheduler_condvar = Arc::clone(&self.scheduler_condvar);
        let task_queue = Arc::clone(&self.task_queue);
        let task_condvar = Arc::clone(&self.task_condvar);
        let running = Arc::clone(&self.running);
        let stats = Arc::clone(&self.stats);
        let max_queue_size = self.config.max_queue_size;
        let scheduler_receiver = Arc::clone(&self.scheduler_receiver);
        let pending_requests = Arc::clone(&self.pending_requests);
        let next_handle = Arc::clone(&self.next_handle);
        let engine = self.engine.clone();

        let handle = thread::Builder::new()
            .name("worker-scheduler".to_string())
            .spawn(move || {
                // Replay requests that arrived before this thread started.
                {
                    let mut pending = pending_requests.lock().unwrap();
                    let mut sched = scheduler.lock().unwrap();

                    sched.set_next_handle(next_handle.load(Ordering::Relaxed));

                    while let Some((request, response_tx)) = pending.pop_front() {
                        let response = match request {
                            SchedulerRequest::ScheduleEvery { task, interval } => {
                                let adapter =
                                    Box::new(SchedulableTaskAdapter::new(task, engine.clone()));
                                let priority = adapter.priority();
                                let handle =
                                    sched.schedule_every_internal(adapter, interval, priority);
                                SchedulerResponse::TaskScheduled(handle)
                            }
                            SchedulerRequest::Submit { task, priority: _ } => {
                                let adapter = Box::new(OnceTaskAdapter::new(task, engine.clone()));
                                // Release the scheduler lock before taking the queue lock.
                                drop(sched);
                                {
                                    let mut queue = task_queue.lock().unwrap();
                                    if queue.len() < max_queue_size {
                                        queue.push(PrioritizedTask::new(adapter));
                                        stats.tasks_queued.fetch_add(1, Ordering::Relaxed);
                                        task_condvar.notify_one();
                                    }
                                }
                                sched = scheduler.lock().unwrap();
                                SchedulerResponse::TaskSubmitted
                            }
                            SchedulerRequest::Cancel { handle } => {
                                sched.cancel(handle);
                                SchedulerResponse::TaskCancelled
                            }
                        };
                        let _ = response_tx.send(response);
                    }
                    drop(sched);
                    drop(pending);
                }

                while running.load(Ordering::Relaxed) {
                    // Handle requests arriving over the client channel.
                    {
                        let receiver_guard = scheduler_receiver.lock().unwrap();
                        if let Some(ref receiver) = *receiver_guard {
                            while let Ok((request, response_tx)) = receiver.try_recv() {
                                let mut sched = scheduler.lock().unwrap();
                                let response = match request {
                                    SchedulerRequest::ScheduleEvery { task, interval } => {
                                        let adapter = Box::new(SchedulableTaskAdapter::new(
                                            task,
                                            engine.clone(),
                                        ));
                                        let priority = adapter.priority();
                                        let handle = sched.schedule_every_internal(
                                            adapter, interval, priority,
                                        );
                                        SchedulerResponse::TaskScheduled(handle)
                                    }
                                    SchedulerRequest::Submit { task, priority: _ } => {
                                        let adapter = Box::new(OnceTaskAdapter::new(
                                            task,
                                            engine.clone(),
                                        ));
                                        drop(sched);

                                        {
                                            let mut queue = task_queue.lock().unwrap();
                                            if queue.len() < max_queue_size {
                                                queue.push(PrioritizedTask::new(adapter));
                                                stats.tasks_queued.fetch_add(1, Ordering::Relaxed);
                                                task_condvar.notify_one();
                                            }
                                        }
                                        sched = scheduler.lock().unwrap();
                                        SchedulerResponse::TaskSubmitted
                                    }
                                    SchedulerRequest::Cancel { handle } => {
                                        sched.cancel(handle);
                                        SchedulerResponse::TaskCancelled
                                    }
                                };
                                drop(sched);
                                let _ = response_tx.send(response);
                            }
                        }
                    }

                    let mut sched = scheduler.lock().unwrap();

                    // Nothing scheduled: sleep briefly, then poll the channel again.
                    if sched.task_count() == 0 {
                        let result = scheduler_condvar
                            .wait_timeout(sched, Duration::from_millis(1))
                            .unwrap();

                        sched = result.0;

                        if !running.load(Ordering::Relaxed) {
                            break;
                        }

                        drop(sched);
                        continue;
                    }

                    let ready_tasks = sched.get_ready_tasks();

                    // Sleep until the next task is due, or 1s if none is pending.
                    let wait_duration = if let Some(next_time) = sched.next_run_time() {
                        let now = std::time::Instant::now();
                        if next_time > now {
                            next_time - now
                        } else {
                            Duration::from_millis(0)
                        }
                    } else {
                        Duration::from_secs(1)
                    };

                    drop(sched);

                    if !ready_tasks.is_empty() {
                        let mut queue = task_queue.lock().unwrap();

                        for task in ready_tasks {
                            if queue.len() >= max_queue_size {
                                log_warn!("Scheduler: Queue full, dropping scheduled task");
                                break;
                            }

                            queue.push(PrioritizedTask::new(task));
                            stats.tasks_queued.fetch_add(1, Ordering::Relaxed);
                        }

                        drop(queue);
                        task_condvar.notify_all();
                    }

                    if wait_duration > Duration::from_millis(0) {
                        let sched = scheduler.lock().unwrap();
                        let _ = scheduler_condvar.wait_timeout(sched, wait_duration);
                    }
                }
            })
            .expect("Failed to create scheduler thread");

        self.scheduler_handle = Some(handle);
    }
}

impl Subsystem for WorkerSubsystem {
    fn name(&self) -> &'static str {
        "sub-worker"
    }

    fn start(&mut self) -> Result<()> {
        if self.running.load(Ordering::Relaxed) {
            return Ok(());
        }

        log_debug!("Starting worker subsystem with {} workers", self.config.num_workers);

        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(self.config.num_workers)
            .thread_name(|i| format!("rayon-worker-{}", i))
            .panic_handler(|panic_info| {
                log_warn!("Worker thread panicked: {:?}", panic_info);
            })
            .build()
            .map_err(|e| {
                reifydb_core::error!(reifydb_core::diagnostic::internal(format!(
                    "Failed to create thread pool: {}",
                    e
                )))
            })?;

        self.thread_pool = Some(Arc::new(pool));
        self.running.store(true, Ordering::Relaxed);

        // Spawn the dispatcher thread that feeds tasks to the pool.
        {
            let pool = Arc::clone(self.thread_pool.as_ref().unwrap());
            let queue = Arc::clone(&self.task_queue);
            let condvar = Arc::clone(&self.task_condvar);
            let tracker = Arc::clone(&self.task_tracker);
            let stats = Arc::clone(&self.stats);
            let running = Arc::clone(&self.running);
            let engine = self.engine.clone();

            let handle = thread::Builder::new()
                .name("worker-dispatcher".to_string())
                .spawn(move || {
                    Self::run_dispatcher(pool, queue, condvar, tracker, stats, running, engine)
                })
                .map_err(|e| {
                    reifydb_core::error!(reifydb_core::diagnostic::internal(format!(
                        "Failed to spawn dispatcher thread: {}",
                        e
                    )))
                })?;

            self.dispatcher_handle = Some(handle);
        }

        self.start_scheduler();

        log_debug!("Started with {} workers", self.config.num_workers);

        Ok(())
    }

    fn shutdown(&mut self) -> Result<()> {
        if !self.running.load(Ordering::Relaxed) {
            return Ok(());
        }

        log_debug!("Worker pool shutting down...");

        self.running.store(false, Ordering::Relaxed);
        log_debug!("Signaled threads to stop");

        self.task_condvar.notify_all();
        self.scheduler_condvar.notify_all();

        let cancelled_count = {
            let mut queue = self.task_queue.lock().unwrap();
            let count = queue.len();
            queue.clear();
            count
        };

        if cancelled_count > 0 {
            log_debug!("Cancelled {} queued tasks", cancelled_count);
        }

        if let Some(handle) = self.scheduler_handle.take() {
            let _ = handle.join();
        }

        if let Some(handle) = self.dispatcher_handle.take() {
            let _ = handle.join();
        }

        log_debug!("Scheduler and dispatcher threads stopped");

        let active_count = self.task_tracker.active_count();
        if active_count > 0 {
            log_debug!("Waiting for {} active tasks to complete...", active_count);
        }

        let timeout = Duration::from_secs(10);
        if !self.task_tracker.wait_for_completion(timeout) {
            let remaining = self.task_tracker.active_count();
            log_warn!(
                "Timeout waiting for tasks to complete. {} tasks still running. Forcing shutdown.",
                remaining
            );

            self.thread_pool = None;
            log_warn!("Worker pool shutdown forced with {} tasks still active", remaining);
        } else {
            self.thread_pool = None;
            log_debug!("All tasks completed, worker pool shutdown complete");
        }

        Ok(())
    }

    fn is_running(&self) -> bool {
        self.running.load(Ordering::Relaxed)
    }

    fn health_status(&self) -> HealthStatus {
        if !self.is_running() {
            return HealthStatus::Unknown;
        }

        let active = self.active_workers();
        let queued = self.queued_tasks();

        if active == 0 && queued > 0 {
            HealthStatus::Failed {
                description: "No active workers but tasks are queued".into(),
            }
        } else if queued > self.config.max_queue_size / 2 {
            HealthStatus::Degraded {
                description: format!(
                    "Task queue is {}% full",
                    (queued * 100) / self.config.max_queue_size
                ),
            }
        } else {
            HealthStatus::Healthy
        }
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }
}

impl HasVersion for WorkerSubsystem {
    fn version(&self) -> SystemVersion {
        SystemVersion {
            name: "sub-worker".to_string(),
            version: env!("CARGO_PKG_VERSION").to_string(),
            description: "Priority-based task worker pool subsystem".to_string(),
            r#type: ComponentType::Subsystem,
        }
    }
}

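/// Best-effort shutdown when the subsystem is dropped without an explicit
/// call to `shutdown()`.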
impl Drop for WorkerSubsystem {
    fn drop(&mut self) {
        let _ = self.shutdown();
    }
}

impl Scheduler for WorkerSubsystem {
    fn every(&self, interval: Duration, task: BoxedTask) -> Result<TaskHandle> {
        let adapter = Box::new(SchedulableTaskAdapter::new(task, self.engine.clone()));
        let priority = adapter.priority();
        self.schedule_every_internal(adapter, interval, priority)
    }

    fn cancel(&self, handle: TaskHandle) -> Result<()> {
        self.cancel_task(handle)
    }

    fn once(&self, task: BoxedOnceTask) -> Result<()> {
        let adapter = Box::new(OnceTaskAdapter::new(task, self.engine.clone()));
        WorkerSubsystem::submit(self, adapter)
    }
}