1use dashmap::DashMap;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::{oneshot, RwLock};
6use tokio::task::JoinHandle;
7use uuid::Uuid;
8
9use crate::schema::{Task, TaskResult};
10use crate::transport::websocket::WebSocketTransport;
11use crate::transport::{Message, Transport, TransportConfig, TransportError};
12use crate::worker::{WorkerInfo, WorkerPool, WorkerStatus};
13
14#[derive(Debug, Clone)]
16pub struct DispatcherBuilder {
17 config: TransportConfig,
18 heartbeat_timeout_ms: u64,
19 dead_worker_check_interval_ms: u64,
20}
21
22impl Default for DispatcherBuilder {
23 fn default() -> Self {
24 Self {
25 config: TransportConfig::default(),
26 heartbeat_timeout_ms: 15_000,
27 dead_worker_check_interval_ms: 5_000,
28 }
29 }
30}
31
32impl DispatcherBuilder {
33 pub fn new() -> Self {
34 Self::default()
35 }
36
37 pub fn host(mut self, host: impl Into<String>) -> Self {
38 self.config.host = host.into();
39 self
40 }
41
42 pub fn port(mut self, port: u16) -> Self {
43 self.config.port = port;
44 self
45 }
46
47 pub fn max_connections(mut self, max: u32) -> Self {
48 self.config.max_connections = max;
49 self
50 }
51
52 pub fn heartbeat_interval(mut self, ms: u64) -> Self {
53 self.config.heartbeat_interval_ms = ms;
54 self
55 }
56
57 pub fn heartbeat_timeout(mut self, ms: u64) -> Self {
58 self.heartbeat_timeout_ms = ms;
59 self
60 }
61
62 pub fn build(self) -> Dispatcher {
63 Dispatcher {
64 pool: Arc::new(WorkerPool::new(self.heartbeat_timeout_ms)),
65 pending: Arc::new(DashMap::new()),
66 transport: Arc::new(RwLock::new(None)),
67 config: self.config,
68 dead_worker_check_interval_ms: self.dead_worker_check_interval_ms,
69 started: AtomicBool::new(false),
70 _dead_worker_task: RwLock::new(None),
71 }
72 }
73}
74
75pub struct Dispatcher {
77 pool: Arc<WorkerPool>,
78 pending: Arc<DashMap<Uuid, PendingTask>>,
79 transport: Arc<RwLock<Option<Arc<WebSocketTransport>>>>,
80 config: TransportConfig,
81 dead_worker_check_interval_ms: u64,
82 started: AtomicBool,
83 _dead_worker_task: RwLock<Option<JoinHandle<()>>>,
84}
85
86struct PendingTask {
87 sender: oneshot::Sender<TaskResult>,
88 worker_id: String,
89}
90
91#[must_use = "dropping a DispatchResult discards the task result"]
93pub struct DispatchResult {
94 pub task_id: Uuid,
95 pub(crate) receiver: oneshot::Receiver<TaskResult>,
96}
97
98impl std::fmt::Debug for DispatchResult {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 f.debug_struct("DispatchResult")
101 .field("task_id", &self.task_id)
102 .finish()
103 }
104}
105
106impl DispatchResult {
107 pub async fn await_result(self) -> Result<TaskResult, DispatchError> {
108 self.receiver
109 .await
110 .map_err(|_| DispatchError::WorkerDisconnected)
111 }
112
113 pub async fn await_with_timeout(self, timeout: Duration) -> Result<TaskResult, DispatchError> {
114 tokio::time::timeout(timeout, self.receiver)
115 .await
116 .map_err(|_| DispatchError::Timeout)?
117 .map_err(|_| DispatchError::WorkerDisconnected)
118 }
119}
120
121impl Dispatcher {
122 pub fn builder() -> DispatcherBuilder {
123 DispatcherBuilder::new()
124 }
125
126 pub async fn start(&self) -> Result<(), DispatchError> {
127 if self.started.swap(true, Ordering::SeqCst) {
128 return Ok(());
129 }
130
131 let pool = self.pool.clone();
132 let pending = self.pending.clone();
133
134 let on_message = move |worker_id: String, message: Message| {
135 let pool = pool.clone();
136 let pending = pending.clone();
137
138 tokio::spawn(async move {
139 match message {
140 Message::WorkerRegister { registration: reg } => {
141 pool.register(WorkerInfo {
142 id: reg.worker_id,
143 language: reg.language,
144 supported_tasks: reg.supported_tasks,
145 max_concurrency: reg.max_concurrency,
146 status: WorkerStatus::Active,
147 active_tasks: 0,
148 registered_at: chrono::Utc::now(),
149 last_heartbeat: chrono::Utc::now(),
150 });
151 }
152 Message::TaskResult { result } => {
153 pool.mark_task_completed(&worker_id);
154 if let Some((_, pending_task)) = pending.remove(&result.task_id) {
155 let _ = pending_task.sender.send(result);
156 }
157 }
158 Message::Heartbeat { payload: hb } => {
159 pool.heartbeat(&hb.worker_id, hb.active_tasks);
160 }
161 Message::Backpressure { signal: bp } => {
162 tracing::warn!(
163 worker_id = %bp.worker_id,
164 load = bp.current_load,
165 "Worker signaled backpressure"
166 );
167 }
168 _ => {}
169 }
170 });
171 };
172
173 let transport = Arc::new(WebSocketTransport::new(self.config.clone(), on_message));
174 transport
175 .start()
176 .await
177 .map_err(DispatchError::TransportError)?;
178
179 *self.transport.write().await = Some(transport);
180
181 let pool = self.pool.clone();
184 let pending = self.pending.clone();
185 let interval = self.dead_worker_check_interval_ms;
186 let handle = tokio::spawn(async move {
187 loop {
188 tokio::time::sleep(Duration::from_millis(interval)).await;
189 let dead = pool.detect_dead_workers();
190 if !dead.is_empty() {
191 for worker_id in &dead {
192 tracing::warn!(worker_id = %worker_id, "Dead worker detected");
193 }
194 pending
197 .retain(|_task_id, pending_task| !dead.contains(&pending_task.worker_id));
198 }
199 }
200 });
201
202 *self._dead_worker_task.write().await = Some(handle);
203
204 Ok(())
205 }
206
207 pub async fn stop(&self) {
209 self.started.store(false, Ordering::SeqCst);
210 if let Some(handle) = self._dead_worker_task.write().await.take() {
212 handle.abort();
213 }
214 if let Some(transport) = self.transport.read().await.as_ref() {
216 let _ = transport.stop().await;
217 }
218 self.pending.clear();
220 }
221
222 pub async fn dispatch(&self, task: Task) -> Result<DispatchResult, DispatchError> {
223 let worker_id = self.pool.select_and_reserve(&task.task_type).ok_or(
225 DispatchError::NoWorkerAvailable {
226 task_type: task.task_type.clone(),
227 },
228 )?;
229
230 let (tx, rx) = oneshot::channel();
231 let task_id = task.id;
232
233 self.pending.insert(
234 task_id,
235 PendingTask {
236 sender: tx,
237 worker_id: worker_id.clone(),
238 },
239 );
240
241 let transport_guard = self.transport.read().await;
243 let transport = transport_guard.as_ref().ok_or_else(|| {
244 self.pending.remove(&task_id);
246 self.pool.mark_task_completed(&worker_id);
247 DispatchError::TransportNotStarted
248 })?;
249
250 if let Err(e) = transport
251 .send(&worker_id, Message::TaskDispatch { task })
252 .await
253 {
254 self.pending.remove(&task_id);
256 self.pool.mark_task_completed(&worker_id);
257 return Err(DispatchError::TransportError(e));
258 }
259
260 tracing::debug!(task_id = %task_id, worker_id = %worker_id, "Task dispatched");
261
262 Ok(DispatchResult {
263 task_id,
264 receiver: rx,
265 })
266 }
267
268 pub fn pool_stats(&self) -> crate::worker::PoolStats {
269 self.pool.stats()
270 }
271}
272
273#[derive(Debug, thiserror::Error)]
274pub enum DispatchError {
275 #[error("No worker available for task type: {task_type}")]
276 NoWorkerAvailable { task_type: String },
277
278 #[error("Worker disconnected before returning result")]
279 WorkerDisconnected,
280
281 #[error("Task timed out")]
282 Timeout,
283
284 #[error("Transport not started — call start() first")]
285 TransportNotStarted,
286
287 #[error("Transport error: {0}")]
288 TransportError(#[from] TransportError),
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::TaskStatus;
    use serde_json::json;

    /// Wraps `receiver` in a `DispatchResult` with a fresh random task id.
    fn handle_for(receiver: oneshot::Receiver<TaskResult>) -> DispatchResult {
        DispatchResult {
            task_id: Uuid::new_v4(),
            receiver,
        }
    }

    // ---- builder defaults ----

    #[test]
    fn test_builder_default_port() {
        assert_eq!(DispatcherBuilder::new().config.port, 9876);
    }

    #[test]
    fn test_builder_default_host() {
        assert_eq!(DispatcherBuilder::new().config.host, "0.0.0.0");
    }

    #[test]
    fn test_builder_default_heartbeat_timeout() {
        assert_eq!(DispatcherBuilder::new().heartbeat_timeout_ms, 15_000);
    }

    // ---- builder setters ----

    #[test]
    fn test_builder_host_sets_value() {
        let b = DispatcherBuilder::new().host("10.0.0.1");
        assert_eq!(b.config.host, "10.0.0.1");
    }

    #[test]
    fn test_builder_port_sets_value() {
        let b = DispatcherBuilder::new().port(8080);
        assert_eq!(b.config.port, 8080);
    }

    #[test]
    fn test_builder_max_connections_sets_value() {
        let b = DispatcherBuilder::new().max_connections(50);
        assert_eq!(b.config.max_connections, 50);
    }

    #[test]
    fn test_builder_heartbeat_interval_sets_value() {
        let b = DispatcherBuilder::new().heartbeat_interval(2000);
        assert_eq!(b.config.heartbeat_interval_ms, 2000);
    }

    #[test]
    fn test_builder_heartbeat_timeout_sets_value() {
        let b = DispatcherBuilder::new().heartbeat_timeout(30000);
        assert_eq!(b.heartbeat_timeout_ms, 30000);
    }

    #[test]
    fn test_builder_chaining() {
        let b = DispatcherBuilder::new()
            .host("1.2.3.4")
            .port(9999)
            .max_connections(200)
            .heartbeat_interval(1000)
            .heartbeat_timeout(5000);
        let cfg = &b.config;
        assert_eq!(cfg.host, "1.2.3.4");
        assert_eq!(cfg.port, 9999);
        assert_eq!(cfg.max_connections, 200);
        assert_eq!(cfg.heartbeat_interval_ms, 1000);
        assert_eq!(b.heartbeat_timeout_ms, 5000);
    }

    #[test]
    fn test_builder_build_pool_starts_empty() {
        let dispatcher = Dispatcher::builder().build();
        assert_eq!(dispatcher.pool_stats().total, 0);
    }

    #[test]
    fn test_dispatcher_builder_shortcut() {
        assert_eq!(Dispatcher::builder().config.port, 9876);
    }

    // ---- DispatchResult ----

    #[tokio::test]
    async fn test_dispatch_result_await_result_receives_value() {
        let (tx, rx) = oneshot::channel();
        let handle = handle_for(rx);
        let expected = TaskResult {
            task_id: handle.task_id,
            status: TaskStatus::Completed,
            payload: Some(json!({"ok": true})),
            error: None,
            duration_ms: 50,
            worker_id: "test".to_string(),
        };
        tx.send(expected.clone()).unwrap();

        let received = handle.await_result().await.unwrap();
        assert_eq!(received.task_id, expected.task_id);
        assert_eq!(received.status, TaskStatus::Completed);
    }

    #[tokio::test]
    async fn test_dispatch_result_worker_disconnected() {
        let (tx, rx) = oneshot::channel::<TaskResult>();
        let handle = handle_for(rx);
        drop(tx);
        let outcome = handle.await_result().await;
        assert!(matches!(outcome, Err(DispatchError::WorkerDisconnected)));
    }

    #[tokio::test]
    async fn test_dispatch_result_timeout() {
        // Keep the sender alive so the wait ends by timeout, not disconnection.
        let (_tx, rx) = oneshot::channel::<TaskResult>();
        let outcome = handle_for(rx)
            .await_with_timeout(Duration::from_millis(10))
            .await;
        assert!(matches!(outcome, Err(DispatchError::Timeout)));
    }

    #[test]
    fn test_dispatch_result_debug_format() {
        let (_tx, rx) = oneshot::channel::<TaskResult>();
        let id = Uuid::new_v4();
        let handle = DispatchResult {
            task_id: id,
            receiver: rx,
        };
        let rendered = format!("{:?}", handle);
        assert!(rendered.contains("DispatchResult"));
        assert!(rendered.contains(&id.to_string()));
    }

    // ---- DispatchError ----

    #[test]
    fn test_dispatch_error_display_no_worker() {
        let err = DispatchError::NoWorkerAvailable {
            task_type: "scan".into(),
        };
        assert_eq!(err.to_string(), "No worker available for task type: scan");
    }

    #[test]
    fn test_dispatch_error_display_worker_disconnected() {
        assert_eq!(
            DispatchError::WorkerDisconnected.to_string(),
            "Worker disconnected before returning result"
        );
    }

    #[test]
    fn test_dispatch_error_display_timeout() {
        assert_eq!(DispatchError::Timeout.to_string(), "Task timed out");
    }

    #[test]
    fn test_dispatch_error_display_transport_not_started() {
        let text = DispatchError::TransportNotStarted.to_string();
        assert!(text.contains("Transport not started"));
    }

    #[test]
    fn test_dispatch_error_from_transport_error() {
        let converted = DispatchError::from(TransportError::Closed);
        assert!(matches!(
            converted,
            DispatchError::TransportError(TransportError::Closed)
        ));
    }
}