1use dashmap::DashMap;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::{oneshot, RwLock};
6use tokio::task::JoinHandle;
7use uuid::Uuid;
8
9use crate::schema::{Task, TaskResult};
10use crate::transport::websocket::WebSocketTransport;
11use crate::transport::{Message, Transport, TransportConfig, TransportError};
12use crate::worker::{PoolError, WorkerInfo, WorkerPool, WorkerStatus};
13
14pub struct DispatcherBuilder {
16 config: TransportConfig,
17 heartbeat_timeout_ms: u64,
18 dead_worker_check_interval_ms: u64,
19 max_pool_size: Option<u32>,
20 min_pool_size: Option<u32>,
21 on_pool_below_min: Option<Arc<dyn Fn(u32) + Send + Sync>>,
22}
23
24impl std::fmt::Debug for DispatcherBuilder {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 f.debug_struct("DispatcherBuilder")
27 .field("config", &self.config)
28 .field("heartbeat_timeout_ms", &self.heartbeat_timeout_ms)
29 .field(
30 "dead_worker_check_interval_ms",
31 &self.dead_worker_check_interval_ms,
32 )
33 .field("max_pool_size", &self.max_pool_size)
34 .field("min_pool_size", &self.min_pool_size)
35 .field("on_pool_below_min", &self.on_pool_below_min.is_some())
36 .finish()
37 }
38}
39
40impl Default for DispatcherBuilder {
41 fn default() -> Self {
42 Self {
43 config: TransportConfig::default(),
44 heartbeat_timeout_ms: 15_000,
45 dead_worker_check_interval_ms: 5_000,
46 max_pool_size: None,
47 min_pool_size: None,
48 on_pool_below_min: None,
49 }
50 }
51}
52
53impl DispatcherBuilder {
54 pub fn new() -> Self {
55 Self::default()
56 }
57
58 pub fn host(mut self, host: impl Into<String>) -> Self {
59 self.config.host = host.into();
60 self
61 }
62
63 pub fn port(mut self, port: u16) -> Self {
64 self.config.port = port;
65 self
66 }
67
68 pub fn max_connections(mut self, max: u32) -> Self {
69 self.config.max_connections = max;
70 self
71 }
72
73 pub fn heartbeat_interval(mut self, ms: u64) -> Self {
74 self.config.heartbeat_interval_ms = ms;
75 self
76 }
77
78 pub fn heartbeat_timeout(mut self, ms: u64) -> Self {
79 self.heartbeat_timeout_ms = ms;
80 self
81 }
82
83 pub fn max_pool_size(mut self, max: u32) -> Self {
86 self.max_pool_size = Some(max);
87 self
88 }
89
90 pub fn min_pool_size(mut self, min: u32) -> Self {
93 self.min_pool_size = Some(min);
94 self
95 }
96
97 pub fn on_pool_below_min(mut self, cb: impl Fn(u32) + Send + Sync + 'static) -> Self {
100 self.on_pool_below_min = Some(Arc::new(cb));
101 self
102 }
103
104 pub fn build(self) -> Dispatcher {
105 Dispatcher {
106 pool: Arc::new(WorkerPool::with_limits(
107 self.heartbeat_timeout_ms,
108 self.max_pool_size,
109 self.min_pool_size,
110 self.on_pool_below_min,
111 )),
112 pending: Arc::new(DashMap::new()),
113 transport: Arc::new(RwLock::new(None)),
114 config: self.config,
115 dead_worker_check_interval_ms: self.dead_worker_check_interval_ms,
116 started: AtomicBool::new(false),
117 _dead_worker_task: RwLock::new(None),
118 }
119 }
120}
121
122pub struct Dispatcher {
124 pool: Arc<WorkerPool>,
125 pending: Arc<DashMap<Uuid, PendingTask>>,
126 transport: Arc<RwLock<Option<Arc<WebSocketTransport>>>>,
127 config: TransportConfig,
128 dead_worker_check_interval_ms: u64,
129 started: AtomicBool,
130 _dead_worker_task: RwLock<Option<JoinHandle<()>>>,
131}
132
133struct PendingTask {
134 sender: oneshot::Sender<TaskResult>,
135 worker_id: String,
136}
137
138#[must_use = "dropping a DispatchResult discards the task result"]
140pub struct DispatchResult {
141 pub task_id: Uuid,
142 pub(crate) receiver: oneshot::Receiver<TaskResult>,
143}
144
145impl std::fmt::Debug for DispatchResult {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 f.debug_struct("DispatchResult")
148 .field("task_id", &self.task_id)
149 .finish()
150 }
151}
152
153impl DispatchResult {
154 pub async fn await_result(self) -> Result<TaskResult, DispatchError> {
155 self.receiver
156 .await
157 .map_err(|_| DispatchError::WorkerDisconnected)
158 }
159
160 pub async fn await_with_timeout(self, timeout: Duration) -> Result<TaskResult, DispatchError> {
161 tokio::time::timeout(timeout, self.receiver)
162 .await
163 .map_err(|_| DispatchError::Timeout)?
164 .map_err(|_| DispatchError::WorkerDisconnected)
165 }
166}
167
168impl Dispatcher {
169 pub fn builder() -> DispatcherBuilder {
170 DispatcherBuilder::new()
171 }
172
173 pub async fn start(&self) -> Result<(), DispatchError> {
174 if self.started.swap(true, Ordering::SeqCst) {
175 return Ok(());
176 }
177
178 let pool = self.pool.clone();
179 let pending = self.pending.clone();
180
181 let on_message = move |worker_id: String, message: Message| {
182 let pool = pool.clone();
183 let pending = pending.clone();
184
185 tokio::spawn(async move {
186 match message {
187 Message::WorkerRegister { registration: reg } => {
188 pool.register(WorkerInfo {
189 id: reg.worker_id,
190 language: reg.language,
191 supported_tasks: reg.supported_tasks,
192 max_concurrency: reg.max_concurrency,
193 status: WorkerStatus::Active,
194 active_tasks: 0,
195 registered_at: chrono::Utc::now(),
196 last_heartbeat: chrono::Utc::now(),
197 tags: reg.tags.unwrap_or_default(),
198 });
199 }
200 Message::TaskResult { result } => {
201 pool.mark_task_completed(&worker_id);
202 if let Some((_, pending_task)) = pending.remove(&result.task_id) {
203 let _ = pending_task.sender.send(result);
204 }
205 }
206 Message::Heartbeat { payload: hb } => {
207 pool.heartbeat(&hb.worker_id, hb.active_tasks);
208 }
209 Message::Backpressure { signal: bp } => {
210 tracing::warn!(
211 worker_id = %bp.worker_id,
212 load = bp.current_load,
213 "Worker signaled backpressure"
214 );
215 }
216 _ => {}
217 }
218 });
219 };
220
221 let transport = Arc::new(WebSocketTransport::new(self.config.clone(), on_message));
222 transport
223 .start()
224 .await
225 .map_err(DispatchError::TransportError)?;
226
227 *self.transport.write().await = Some(transport);
228
229 let pool = self.pool.clone();
232 let pending = self.pending.clone();
233 let interval = self.dead_worker_check_interval_ms;
234 let handle = tokio::spawn(async move {
235 loop {
236 tokio::time::sleep(Duration::from_millis(interval)).await;
237 let dead = pool.detect_dead_workers();
238 if !dead.is_empty() {
239 for worker_id in &dead {
240 tracing::warn!(worker_id = %worker_id, "Dead worker detected");
241 }
242 pending
245 .retain(|_task_id, pending_task| !dead.contains(&pending_task.worker_id));
246 }
247 }
248 });
249
250 *self._dead_worker_task.write().await = Some(handle);
251
252 Ok(())
253 }
254
255 pub async fn stop(&self) {
257 self.started.store(false, Ordering::SeqCst);
258 if let Some(handle) = self._dead_worker_task.write().await.take() {
260 handle.abort();
261 }
262 if let Some(transport) = self.transport.read().await.as_ref() {
264 let _ = transport.stop().await;
265 }
266 self.pending.clear();
268 }
269
270 pub async fn dispatch(&self, task: Task) -> Result<DispatchResult, DispatchError> {
271 let worker_id = self.pool.select_and_reserve(&task.task_type).ok_or(
273 DispatchError::NoWorkerAvailable {
274 task_type: task.task_type.clone(),
275 },
276 )?;
277
278 let (tx, rx) = oneshot::channel();
279 let task_id = task.id;
280
281 self.pending.insert(
282 task_id,
283 PendingTask {
284 sender: tx,
285 worker_id: worker_id.clone(),
286 },
287 );
288
289 let transport_guard = self.transport.read().await;
291 let transport = transport_guard.as_ref().ok_or_else(|| {
292 self.pending.remove(&task_id);
294 self.pool.mark_task_completed(&worker_id);
295 DispatchError::TransportNotStarted
296 })?;
297
298 if let Err(e) = transport
299 .send(&worker_id, Message::TaskDispatch { task })
300 .await
301 {
302 self.pending.remove(&task_id);
304 self.pool.mark_task_completed(&worker_id);
305 return Err(DispatchError::TransportError(e));
306 }
307
308 tracing::debug!(task_id = %task_id, worker_id = %worker_id, "Task dispatched");
309
310 Ok(DispatchResult {
311 task_id,
312 receiver: rx,
313 })
314 }
315
316 pub fn pool_stats(&self) -> crate::worker::PoolStats {
317 self.pool.stats()
318 }
319
320 pub fn workers(&self) -> Vec<WorkerInfo> {
322 self.pool.workers()
323 }
324
325 pub fn drain_worker(&self, worker_id: &str) -> Result<(), PoolError> {
328 self.pool.drain_worker(worker_id)
329 }
330
331 pub fn remove_worker(&self, worker_id: &str) -> Result<(), PoolError> {
333 self.pool.remove_worker(worker_id)?;
334 self.pending
336 .retain(|_task_id, pending_task| pending_task.worker_id != worker_id);
337 Ok(())
338 }
339
340 pub async fn dispatch_to(
342 &self,
343 worker_id: &str,
344 task: Task,
345 ) -> Result<DispatchResult, DispatchError> {
346 self.pool.reserve_specific_worker(worker_id)?;
347
348 let (tx, rx) = oneshot::channel();
349 let task_id = task.id;
350
351 self.pending.insert(
352 task_id,
353 PendingTask {
354 sender: tx,
355 worker_id: worker_id.to_string(),
356 },
357 );
358
359 let transport_guard = self.transport.read().await;
361 let transport = transport_guard.as_ref().ok_or_else(|| {
362 self.pending.remove(&task_id);
363 self.pool.mark_task_completed(worker_id);
364 DispatchError::TransportNotStarted
365 })?;
366
367 if let Err(e) = transport
368 .send(worker_id, Message::TaskDispatch { task })
369 .await
370 {
371 self.pending.remove(&task_id);
372 self.pool.mark_task_completed(worker_id);
373 return Err(DispatchError::TransportError(e));
374 }
375
376 tracing::debug!(task_id = %task_id, worker_id = %worker_id, "Task dispatched to specific worker");
377
378 Ok(DispatchResult {
379 task_id,
380 receiver: rx,
381 })
382 }
383
384 pub async fn dispatch_with_tag(
387 &self,
388 tag: &str,
389 task: Task,
390 ) -> Result<DispatchResult, DispatchError> {
391 let worker_id = self
392 .pool
393 .select_and_reserve_with_tag(tag, &task.task_type)
394 .ok_or(DispatchError::NoWorkerAvailable {
395 task_type: task.task_type.clone(),
396 })?;
397
398 let (tx, rx) = oneshot::channel();
399 let task_id = task.id;
400
401 self.pending.insert(
402 task_id,
403 PendingTask {
404 sender: tx,
405 worker_id: worker_id.clone(),
406 },
407 );
408
409 let transport_guard = self.transport.read().await;
411 let transport = transport_guard.as_ref().ok_or_else(|| {
412 self.pending.remove(&task_id);
413 self.pool.mark_task_completed(&worker_id);
414 DispatchError::TransportNotStarted
415 })?;
416
417 if let Err(e) = transport
418 .send(&worker_id, Message::TaskDispatch { task })
419 .await
420 {
421 self.pending.remove(&task_id);
422 self.pool.mark_task_completed(&worker_id);
423 return Err(DispatchError::TransportError(e));
424 }
425
426 tracing::debug!(task_id = %task_id, worker_id = %worker_id, tag = %tag, "Task dispatched with tag");
427
428 Ok(DispatchResult {
429 task_id,
430 receiver: rx,
431 })
432 }
433}
434
435#[derive(Debug, thiserror::Error)]
436pub enum DispatchError {
437 #[error("No worker available for task type: {task_type}")]
438 NoWorkerAvailable { task_type: String },
439
440 #[error("Worker disconnected before returning result")]
441 WorkerDisconnected,
442
443 #[error("Task timed out")]
444 Timeout,
445
446 #[error("Transport not started — call start() first")]
447 TransportNotStarted,
448
449 #[error("Transport error: {0}")]
450 TransportError(#[from] TransportError),
451
452 #[error("Pool error: {0}")]
453 PoolError(#[from] PoolError),
454}
455
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{TaskResult, TaskStatus};
    use serde_json::json;

    /// Builds a completed `TaskResult` for `task_id`, shared by the channel tests.
    fn completed_result(task_id: Uuid) -> TaskResult {
        TaskResult {
            task_id,
            status: TaskStatus::Completed,
            payload: Some(json!({"ok": true})),
            error: None,
            duration_ms: 50,
            worker_id: "test".to_string(),
        }
    }

    #[test]
    fn test_builder_default_port() {
        let builder = DispatcherBuilder::new();
        assert_eq!(builder.config.port, 9876);
    }

    #[test]
    fn test_builder_default_host() {
        let builder = DispatcherBuilder::new();
        assert_eq!(builder.config.host, "0.0.0.0");
    }

    #[test]
    fn test_builder_default_heartbeat_timeout() {
        let builder = DispatcherBuilder::new();
        assert_eq!(builder.heartbeat_timeout_ms, 15_000);
    }

    #[test]
    fn test_builder_default_dead_worker_check_interval() {
        let builder = DispatcherBuilder::new();
        assert_eq!(builder.dead_worker_check_interval_ms, 5_000);
    }

    #[test]
    fn test_builder_host_sets_value() {
        let builder = DispatcherBuilder::new().host("10.0.0.1");
        assert_eq!(builder.config.host, "10.0.0.1");
    }

    #[test]
    fn test_builder_port_sets_value() {
        let builder = DispatcherBuilder::new().port(8080);
        assert_eq!(builder.config.port, 8080);
    }

    #[test]
    fn test_builder_max_connections_sets_value() {
        let builder = DispatcherBuilder::new().max_connections(50);
        assert_eq!(builder.config.max_connections, 50);
    }

    #[test]
    fn test_builder_heartbeat_interval_sets_value() {
        let builder = DispatcherBuilder::new().heartbeat_interval(2000);
        assert_eq!(builder.config.heartbeat_interval_ms, 2000);
    }

    #[test]
    fn test_builder_heartbeat_timeout_sets_value() {
        let builder = DispatcherBuilder::new().heartbeat_timeout(30000);
        assert_eq!(builder.heartbeat_timeout_ms, 30000);
    }

    #[test]
    fn test_builder_chaining() {
        let builder = DispatcherBuilder::new()
            .host("1.2.3.4")
            .port(9999)
            .max_connections(200)
            .heartbeat_interval(1000)
            .heartbeat_timeout(5000);
        assert_eq!(builder.config.host, "1.2.3.4");
        assert_eq!(builder.config.port, 9999);
        assert_eq!(builder.config.max_connections, 200);
        assert_eq!(builder.config.heartbeat_interval_ms, 1000);
        assert_eq!(builder.heartbeat_timeout_ms, 5000);
    }

    #[test]
    fn test_builder_build_pool_starts_empty() {
        let dispatcher = Dispatcher::builder().build();
        let stats = dispatcher.pool_stats();
        assert_eq!(stats.total, 0);
    }

    #[test]
    fn test_dispatcher_builder_shortcut() {
        let builder = Dispatcher::builder();
        assert_eq!(builder.config.port, 9876);
    }

    #[tokio::test]
    async fn test_dispatch_result_await_result_receives_value() {
        let (tx, rx) = oneshot::channel();
        let result = DispatchResult {
            task_id: Uuid::new_v4(),
            receiver: rx,
        };
        let task_result = completed_result(result.task_id);
        tx.send(task_result.clone()).unwrap();
        let received = result.await_result().await.unwrap();
        assert_eq!(received.task_id, task_result.task_id);
        assert_eq!(received.status, TaskStatus::Completed);
    }

    #[tokio::test]
    async fn test_dispatch_result_worker_disconnected() {
        let (tx, rx) = oneshot::channel::<TaskResult>();
        let result = DispatchResult {
            task_id: Uuid::new_v4(),
            receiver: rx,
        };
        drop(tx);
        let err = result.await_result().await.unwrap_err();
        assert!(matches!(err, DispatchError::WorkerDisconnected));
    }

    #[tokio::test]
    async fn test_dispatch_result_timeout() {
        let (_tx, rx) = oneshot::channel::<TaskResult>();
        let result = DispatchResult {
            task_id: Uuid::new_v4(),
            receiver: rx,
        };
        let err = result
            .await_with_timeout(Duration::from_millis(10))
            .await
            .unwrap_err();
        assert!(matches!(err, DispatchError::Timeout));
    }

    #[tokio::test]
    async fn test_dispatch_result_await_with_timeout_receives_value() {
        let (tx, rx) = oneshot::channel();
        let result = DispatchResult {
            task_id: Uuid::new_v4(),
            receiver: rx,
        };
        let task_result = completed_result(result.task_id);
        tx.send(task_result.clone()).unwrap();
        let received = result
            .await_with_timeout(Duration::from_secs(1))
            .await
            .unwrap();
        assert_eq!(received.task_id, task_result.task_id);
        assert_eq!(received.status, TaskStatus::Completed);
    }

    #[tokio::test]
    async fn test_dispatch_result_await_with_timeout_worker_disconnected() {
        let (tx, rx) = oneshot::channel::<TaskResult>();
        let result = DispatchResult {
            task_id: Uuid::new_v4(),
            receiver: rx,
        };
        drop(tx);
        let err = result
            .await_with_timeout(Duration::from_secs(1))
            .await
            .unwrap_err();
        assert!(matches!(err, DispatchError::WorkerDisconnected));
    }

    #[test]
    fn test_dispatch_result_debug_format() {
        let (_tx, rx) = oneshot::channel::<TaskResult>();
        let id = Uuid::new_v4();
        let result = DispatchResult {
            task_id: id,
            receiver: rx,
        };
        let debug = format!("{:?}", result);
        assert!(debug.contains("DispatchResult"));
        assert!(debug.contains(&id.to_string()));
    }

    #[test]
    fn test_dispatch_error_display_no_worker() {
        let err = DispatchError::NoWorkerAvailable {
            task_type: "scan".into(),
        };
        assert_eq!(err.to_string(), "No worker available for task type: scan");
    }

    #[test]
    fn test_dispatch_error_display_worker_disconnected() {
        let err = DispatchError::WorkerDisconnected;
        assert_eq!(
            err.to_string(),
            "Worker disconnected before returning result"
        );
    }

    #[test]
    fn test_dispatch_error_display_timeout() {
        let err = DispatchError::Timeout;
        assert_eq!(err.to_string(), "Task timed out");
    }

    #[test]
    fn test_dispatch_error_display_transport_not_started() {
        let err = DispatchError::TransportNotStarted;
        assert!(err.to_string().contains("Transport not started"));
    }

    #[test]
    fn test_dispatch_error_from_transport_error() {
        let transport_err = TransportError::Closed;
        let dispatch_err: DispatchError = transport_err.into();
        assert!(matches!(
            dispatch_err,
            DispatchError::TransportError(TransportError::Closed)
        ));
    }

    #[test]
    fn test_dispatch_error_from_pool_error() {
        let dispatcher = Dispatcher::builder().build();
        let pool_err = dispatcher.drain_worker("ghost").unwrap_err();
        let dispatch_err: DispatchError = pool_err.into();
        assert!(matches!(
            dispatch_err,
            DispatchError::PoolError(PoolError::WorkerNotFound { .. })
        ));
    }

    #[test]
    fn test_builder_max_pool_size() {
        let builder = DispatcherBuilder::new().max_pool_size(10);
        assert_eq!(builder.max_pool_size, Some(10));
    }

    #[test]
    fn test_builder_min_pool_size() {
        let builder = DispatcherBuilder::new().min_pool_size(2);
        assert_eq!(builder.min_pool_size, Some(2));
    }

    #[test]
    fn test_builder_on_pool_below_min() {
        let builder = DispatcherBuilder::new().on_pool_below_min(|_| {});
        assert!(builder.on_pool_below_min.is_some());
    }

    #[test]
    fn test_builder_pool_limits_chaining() {
        let builder = DispatcherBuilder::new()
            .max_pool_size(50)
            .min_pool_size(5)
            .on_pool_below_min(|_| {});
        assert_eq!(builder.max_pool_size, Some(50));
        assert_eq!(builder.min_pool_size, Some(5));
        assert!(builder.on_pool_below_min.is_some());
    }

    #[test]
    fn test_dispatcher_workers_empty() {
        let dispatcher = Dispatcher::builder().build();
        assert!(dispatcher.workers().is_empty());
    }

    #[test]
    fn test_dispatcher_drain_worker_not_found() {
        let dispatcher = Dispatcher::builder().build();
        let err = dispatcher.drain_worker("ghost").unwrap_err();
        assert!(matches!(err, PoolError::WorkerNotFound { .. }));
    }

    #[test]
    fn test_dispatcher_remove_worker_not_found() {
        let dispatcher = Dispatcher::builder().build();
        let err = dispatcher.remove_worker("ghost").unwrap_err();
        assert!(matches!(err, PoolError::WorkerNotFound { .. }));
    }

    #[tokio::test]
    async fn test_stop_before_start_is_noop() {
        let dispatcher = Dispatcher::builder().build();
        dispatcher.stop().await;
        assert_eq!(dispatcher.pool_stats().total, 0);
        assert!(dispatcher.workers().is_empty());
    }

    #[test]
    fn test_builder_debug_format() {
        let builder = DispatcherBuilder::new()
            .max_pool_size(10)
            .min_pool_size(2)
            .on_pool_below_min(|_| {});
        let debug = format!("{:?}", builder);
        assert!(debug.contains("DispatcherBuilder"));
        assert!(debug.contains("max_pool_size"));
        assert!(debug.contains("min_pool_size"));
    }
}