#![allow(clippy::collapsible_match)]
use rust_pipe::dispatch::{DispatchError, Dispatcher};
use rust_pipe::schema::{Priority, Task, TaskResult, TaskStatus};
use rust_pipe::transport::{
BackpressureSignal, HeartbeatPayload, Message, WorkerLanguage, WorkerRegistration,
};
use futures_util::{SinkExt, StreamExt};
use serde_json::json;
use std::time::Duration;
use tokio_tungstenite::connect_async;
/// A worker that connects over WebSocket and sends `WorkerRegister` should appear in the
/// pool as both registered (`total`) and active.
#[tokio::test]
async fn test_dispatcher_starts_and_accepts_connections() {
    use tokio_tungstenite::tungstenite::Message as WsMessage;

    let dispatcher = Dispatcher::builder().host("127.0.0.1").port(19876).build();
    dispatcher.start().await.unwrap();
    // Give the listener a moment to bind before connecting.
    tokio::time::sleep(Duration::from_millis(100)).await;

    let (ws_stream, _) = connect_async("ws://127.0.0.1:19876")
        .await
        .expect("Failed to connect");
    // `_read` is a named binding (not `_`), so the read half lives until the end of the test.
    let (mut write, _read) = ws_stream.split();

    let register_msg = Message::WorkerRegister {
        registration: WorkerRegistration {
            worker_id: "test-worker-1".to_string(),
            supported_tasks: vec!["test-task".to_string()],
            max_concurrency: 5,
            language: WorkerLanguage::TypeScript,
            tags: None,
        },
    };
    let json_str = serde_json::to_string(&register_msg).unwrap();
    write.send(WsMessage::Text(json_str)).await.unwrap();

    // Registration is handled asynchronously on the dispatcher side; wait before inspecting.
    tokio::time::sleep(Duration::from_millis(200)).await;
    let stats = dispatcher.pool_stats();
    assert_eq!(stats.total, 1, "Expected 1 worker registered");
    assert_eq!(stats.active, 1, "Expected 1 active worker");
}
/// End-to-end round trip: register a worker, dispatch a task to it, answer with a
/// `TaskResult`, and check that the dispatcher resolves the task handle with that result.
#[tokio::test]
async fn test_full_task_dispatch_and_result() {
    use tokio_tungstenite::tungstenite::Message as WsMessage;

    let dispatcher = Dispatcher::builder().host("127.0.0.1").port(19877).build();
    dispatcher.start().await.unwrap();
    tokio::time::sleep(Duration::from_millis(100)).await;

    let (ws_stream, _) = connect_async("ws://127.0.0.1:19877")
        .await
        .expect("Failed to connect");
    let (mut write, mut read) = ws_stream.split();

    let register_msg = Message::WorkerRegister {
        registration: WorkerRegistration {
            worker_id: "e2e-worker".to_string(),
            supported_tasks: vec!["scan-target".to_string()],
            max_concurrency: 5,
            language: WorkerLanguage::TypeScript,
            tags: None,
        },
    };
    write
        .send(WsMessage::Text(serde_json::to_string(&register_msg).unwrap()))
        .await
        .unwrap();
    tokio::time::sleep(Duration::from_millis(200)).await;

    let task = Task::new("scan-target", json!({"url": "https://example.com"}))
        .with_timeout(10_000)
        .with_priority(Priority::High);
    let task_id = task.id;
    let handle = dispatcher.dispatch(task).await.unwrap();
    assert_eq!(handle.task_id, task_id);

    // The worker side must receive the task. Each unexpected outcome panics with a
    // specific message instead of silently skipping the checks. The read is bounded, so a
    // missing dispatch fails the test rather than hanging it.
    let frame = tokio::time::timeout(Duration::from_secs(5), read.next())
        .await
        .expect("timed out waiting for the dispatched task")
        .expect("connection closed before the task was dispatched")
        .expect("websocket error while waiting for the dispatched task");
    let text = match frame {
        WsMessage::Text(text) => text,
        other => panic!("expected a text frame carrying TaskDispatch, got: {other:?}"),
    };
    let received: Message = serde_json::from_str(&text).unwrap();
    let Message::TaskDispatch {
        task: received_task,
    } = received
    else {
        panic!("expected Message::TaskDispatch, got: {text}");
    };
    assert_eq!(received_task.id, task_id);
    assert_eq!(received_task.task_type, "scan-target");

    // Reply as the worker would after finishing the job.
    let result_msg = Message::TaskResult {
        result: TaskResult {
            task_id,
            status: TaskStatus::Completed,
            payload: Some(json!({"vulnerabilities": 3})),
            error: None,
            duration_ms: 1500,
            worker_id: "e2e-worker".to_string(),
        },
    };
    write
        .send(WsMessage::Text(serde_json::to_string(&result_msg).unwrap()))
        .await
        .unwrap();

    let result = handle
        .await_with_timeout(Duration::from_secs(5))
        .await
        .unwrap();
    assert_eq!(result.status, TaskStatus::Completed);
    assert_eq!(result.worker_id, "e2e-worker");
    assert_eq!(result.duration_ms, 1500);
    assert_eq!(result.payload.unwrap()["vulnerabilities"], 3);
}
/// Dispatching a task type that no worker has registered for must fail with
/// `DispatchError::NoWorkerAvailable`.
#[tokio::test]
async fn test_no_worker_available_error() {
    let dispatcher = Dispatcher::builder().host("127.0.0.1").port(19878).build();
    dispatcher.start().await.unwrap();
    // Let the listener come up; no worker ever connects in this test.
    tokio::time::sleep(Duration::from_millis(100)).await;

    let orphan_task = Task::new("nonexistent-task", json!({}));
    let outcome = dispatcher.dispatch(orphan_task).await;

    assert!(outcome.is_err());
    if let Err(err) = outcome {
        assert!(
            matches!(err, DispatchError::NoWorkerAvailable { .. }),
            "Expected NoWorkerAvailable, got: {err:?}"
        );
    }
}
/// Registers a worker, then closes its WebSocket and lets the dispatcher react.
///
/// Currently this only checks that registration succeeded. After the disconnect it merely
/// waits, and fails only if the dispatcher panics the test task.
#[tokio::test]
async fn test_websocket_worker_disconnect_cleanup() {
    use tokio_tungstenite::tungstenite::Message as WsMessage;

    let dispatcher = Dispatcher::builder().host("127.0.0.1").port(19910).build();
    dispatcher.start().await.unwrap();
    tokio::time::sleep(Duration::from_millis(100)).await;

    let (ws, _) = connect_async("ws://127.0.0.1:19910").await.unwrap();
    let (mut write, _read) = ws.split();

    let reg = Message::WorkerRegister {
        registration: WorkerRegistration {
            worker_id: "disconnect-test".to_string(),
            supported_tasks: vec!["x".to_string()],
            max_concurrency: 1,
            language: WorkerLanguage::TypeScript,
            tags: None,
        },
    };
    write
        .send(WsMessage::Text(serde_json::to_string(&reg).unwrap()))
        .await
        .unwrap();
    tokio::time::sleep(Duration::from_millis(200)).await;
    assert_eq!(dispatcher.pool_stats().total, 1);

    // Close the write half; the result is ignored because only the dispatcher's
    // reaction to the disconnect is of interest.
    write.close().await.ok();
    tokio::time::sleep(Duration::from_millis(200)).await;
    // NOTE(review): nothing is asserted after the disconnect. Once the intended cleanup
    // semantics are confirmed (worker removed vs. marked dead), assert them here via
    // `dispatcher.pool_stats()`.
}
/// A malformed frame must not break the connection: a valid registration sent right
/// after garbage on the same socket should still register the worker.
#[tokio::test]
async fn test_websocket_invalid_json_ignored() {
    use tokio_tungstenite::tungstenite::Message as WsMessage;

    let dispatcher = Dispatcher::builder().host("127.0.0.1").port(19911).build();
    dispatcher.start().await.unwrap();
    tokio::time::sleep(Duration::from_millis(100)).await;

    let (ws, _) = connect_async("ws://127.0.0.1:19911").await.unwrap();
    let (mut write, _read) = ws.split();

    // Deliberately unparseable payload.
    write
        .send(WsMessage::Text("not valid json {{{".to_string()))
        .await
        .unwrap();

    let reg = Message::WorkerRegister {
        registration: WorkerRegistration {
            worker_id: "after-garbage".to_string(),
            supported_tasks: vec!["y".to_string()],
            max_concurrency: 2,
            language: WorkerLanguage::Python,
            tags: None,
        },
    };
    write
        .send(WsMessage::Text(serde_json::to_string(&reg).unwrap()))
        .await
        .unwrap();

    tokio::time::sleep(Duration::from_millis(200)).await;
    assert_eq!(dispatcher.pool_stats().total, 1);
}
/// Three workers connecting and registering concurrently should all be in the pool.
#[tokio::test]
async fn test_websocket_multiple_workers_concurrent() {
    use tokio_tungstenite::tungstenite::Message as WsMessage;

    let dispatcher = Dispatcher::builder().host("127.0.0.1").port(19912).build();
    dispatcher.start().await.unwrap();
    tokio::time::sleep(Duration::from_millis(100)).await;

    let mut handles = Vec::with_capacity(3);
    for i in 0..3 {
        handles.push(tokio::spawn(async move {
            let (ws, _) = connect_async("ws://127.0.0.1:19912").await.unwrap();
            let (mut write, _read) = ws.split();
            let reg = Message::WorkerRegister {
                registration: WorkerRegistration {
                    worker_id: format!("concurrent-{i}"),
                    supported_tasks: vec!["work".to_string()],
                    max_concurrency: 5,
                    language: WorkerLanguage::Go,
                    tags: None,
                },
            };
            write
                .send(WsMessage::Text(serde_json::to_string(&reg).unwrap()))
                .await
                .unwrap();
            // Hand the write half back. It shares ownership of the socket, so holding it
            // keeps the connection open after `_read` is dropped here.
            write
        }));
    }

    // Keep every connection open until after the assertion. The original
    // `let _ = h.await.unwrap()` dropped each writer immediately, closing the sockets and
    // racing the dispatcher's disconnect handling against the `total == 3` check.
    let mut writers = Vec::with_capacity(handles.len());
    for handle in handles {
        writers.push(handle.await.unwrap());
    }

    tokio::time::sleep(Duration::from_millis(300)).await;
    assert_eq!(dispatcher.pool_stats().total, 3);
    drop(writers);
}
/// A dispatcher that was built but never started must reject dispatch attempts.
#[tokio::test]
async fn test_dispatch_without_start_returns_error() {
    // `start()` is intentionally never called.
    let idle_dispatcher = Dispatcher::builder().host("127.0.0.1").port(19913).build();

    let outcome = idle_dispatcher.dispatch(Task::new("x", json!({}))).await;
    assert!(outcome.is_err());
}
/// A registered worker sends a `Backpressure` signal followed by a `Heartbeat`. The
/// dispatcher must accept both without dropping the worker.
#[tokio::test]
async fn test_dispatcher_handles_backpressure_signal() {
    use tokio_tungstenite::tungstenite::Message as WsMessage;

    let dispatcher = Dispatcher::builder().host("127.0.0.1").port(19923).build();
    dispatcher.start().await.unwrap();
    tokio::time::sleep(Duration::from_millis(100)).await;

    let (ws, _) = connect_async("ws://127.0.0.1:19923").await.unwrap();
    // The read half is discarded; the write half alone keeps the socket open.
    let (mut write, _) = ws.split();

    let reg = Message::WorkerRegister {
        registration: WorkerRegistration {
            worker_id: "bp-worker".to_string(),
            supported_tasks: vec!["x".to_string()],
            max_concurrency: 1,
            language: WorkerLanguage::TypeScript,
            tags: None,
        },
    };
    write
        .send(WsMessage::Text(serde_json::to_string(&reg).unwrap()))
        .await
        .unwrap();
    tokio::time::sleep(Duration::from_millis(200)).await;

    let bp = Message::Backpressure {
        signal: BackpressureSignal {
            worker_id: "bp-worker".to_string(),
            current_load: 0.95,
            should_throttle: true,
        },
    };
    write
        .send(WsMessage::Text(serde_json::to_string(&bp).unwrap()))
        .await
        .unwrap();

    let hb = Message::Heartbeat {
        payload: HeartbeatPayload {
            worker_id: "bp-worker".to_string(),
            active_tasks: 1,
            capacity: 1,
            uptime_seconds: 10,
        },
    };
    write
        .send(WsMessage::Text(serde_json::to_string(&hb).unwrap()))
        .await
        .unwrap();

    tokio::time::sleep(Duration::from_millis(200)).await;
    // Previously nothing was asserted. A throttled worker should still be in the pool.
    assert_eq!(
        dispatcher.pool_stats().total,
        1,
        "worker should remain registered after backpressure + heartbeat"
    );
}
/// With a very short heartbeat timeout configured, a freshly registered worker is still
/// reported active immediately after registration. The worker then goes silent and its
/// write half is dropped.
///
/// NOTE(review): the `active == 1` check passing despite a 50-unit timeout presumably
/// relies on the dead-worker sweep running less often than every 100 ms. Confirm the
/// timeout unit (ms?) and the sweep period.
#[tokio::test]
async fn test_dispatcher_dead_worker_detection() {
    use tokio_tungstenite::tungstenite::Message as WsMessage;

    let dispatcher = Dispatcher::builder()
        .host("127.0.0.1")
        .port(19934)
        .heartbeat_timeout(50)
        .build();
    dispatcher.start().await.unwrap();
    tokio::time::sleep(Duration::from_millis(100)).await;

    let (ws, _) = connect_async("ws://127.0.0.1:19934").await.unwrap();
    let (mut write, _read) = ws.split();

    let reg = Message::WorkerRegister {
        registration: WorkerRegistration {
            worker_id: "mortal-worker".to_string(),
            supported_tasks: vec!["x".to_string()],
            max_concurrency: 1,
            language: WorkerLanguage::TypeScript,
            tags: None,
        },
    };
    write
        .send(WsMessage::Text(serde_json::to_string(&reg).unwrap()))
        .await
        .unwrap();
    tokio::time::sleep(Duration::from_millis(100)).await;

    let stats = dispatcher.pool_stats();
    assert_eq!(stats.total, 1);
    assert_eq!(stats.active, 1);

    // No heartbeats are sent from here on; then the write half is dropped.
    tokio::time::sleep(Duration::from_millis(200)).await;
    drop(write);
    tokio::time::sleep(Duration::from_millis(100)).await;
}
/// Dispatching on a dispatcher whose `start()` was never awaited yields an error.
#[tokio::test]
async fn test_dispatcher_dispatch_before_start() {
    let unstarted = Dispatcher::builder().host("127.0.0.1").port(19935).build();
    let probe = Task::new("x", json!({}));

    let attempt = unstarted.dispatch(probe).await;
    assert!(attempt.is_err());
}
/// A worker that registers and then never sends a heartbeat must eventually be counted as
/// dead by the dispatcher's background detection.
///
/// NOTE(review): the 6 s wait presumably covers one full period of the dead-worker sweep
/// (with `heartbeat_timeout(100)`, apparently in ms). If that period is configurable,
/// shortening it here would make this test much faster. Confirm against the dispatcher.
#[tokio::test]
async fn test_dispatch_dead_worker_detection_loop_fires() {
    use tokio_tungstenite::tungstenite::Message as WsMessage;

    let dispatcher = Dispatcher::builder()
        .host("127.0.0.1")
        .port(19932)
        .heartbeat_timeout(100)
        .build();
    dispatcher.start().await.unwrap();
    tokio::time::sleep(Duration::from_millis(100)).await;

    let (ws, _) = connect_async("ws://127.0.0.1:19932").await.unwrap();
    // Both halves stay bound so the connection remains open; the worker is "dead" only
    // because it stops heartbeating, not because it disconnects.
    let (mut write, _read) = ws.split();

    let reg = Message::WorkerRegister {
        registration: WorkerRegistration {
            worker_id: "dying-worker".to_string(),
            supported_tasks: vec!["x".to_string()],
            max_concurrency: 1,
            language: WorkerLanguage::TypeScript,
            tags: None,
        },
    };
    write
        .send(WsMessage::Text(serde_json::to_string(&reg).unwrap()))
        .await
        .unwrap();
    tokio::time::sleep(Duration::from_millis(200)).await;
    assert_eq!(dispatcher.pool_stats().total, 1);

    tokio::time::sleep(Duration::from_secs(6)).await;
    let stats = dispatcher.pool_stats();
    assert_eq!(
        stats.dead, 1,
        "Worker should be marked dead after heartbeat timeout"
    );
}