use anyhow::{Context, Result};
use bincode::{config, Decode, Encode};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::tcp::OwnedWriteHalf;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Mutex;
use tracing::{debug, error, info};
/// Settings for running a worker node via `start_worker`.
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Identifier the worker registers under and stamps on its `ScanResult` messages.
    pub worker_id: String,
    /// Master address, passed as-is to `TcpStream::connect`
    /// (presumably `host:port` — confirm against whatever builds this config).
    pub master: String,
    /// Advertised to the master as `WorkerCapabilities::max_concurrency` on registration.
    pub concurrency: usize,
}
/// Messages exchanged between master and worker nodes over TCP.
///
/// Each message is framed as a 4-byte big-endian length followed by its bincode
/// (`config::standard()`) encoding (see `send_message` / `read_message`). bincode's derived
/// encoding identifies enum variants by position, so reordering variants or their fields changes
/// the wire format and breaks compatibility between differently-built nodes.
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
pub enum WorkerMessage {
    /// Worker → master: first frame on a new connection (sent by `start_worker`).
    Register {
        worker_id: String,
        capabilities: WorkerCapabilities,
    },
    /// Sent by the master to a worker right after registration in `handle_worker_connection`,
    /// carrying a status snapshot.
    Heartbeat {
        worker_id: String,
        status: WorkerStatus,
    },
    /// A batch of domains to scan, tagged with `batch_id`; handled by the worker in
    /// `start_worker`.
    ScanRequest {
        domains: Vec<String>,
        batch_id: String,
    },
    /// Worker → master: findings for the batch identified by `batch_id`.
    ScanResult {
        worker_id: String,
        batch_id: String,
        findings: Vec<ScanFinding>,
    },
    /// Sent by `stop_worker`; a worker ends its message loop when it receives this.
    Shutdown { worker_id: String },
}
/// What a worker advertises about itself when it registers with the master.
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
pub struct WorkerCapabilities {
    /// Taken from `WorkerConfig::concurrency` in `start_worker`.
    pub max_concurrency: usize,
    /// Worker crate version (`CARGO_PKG_VERSION` at build time).
    pub version: String,
}
/// Status counters reported for a worker (all zero by `Default`).
///
/// NOTE(review): in this file it is only ever built with `WorkerStatus::default()`, so the
/// counters shown by `worker_status` are never updated — confirm whether that is intended.
#[derive(Debug, Clone, Serialize, Deserialize, Default, Encode, Decode)]
pub struct WorkerStatus {
    /// Presumably scans currently in progress.
    pub active_scans: usize,
    /// Presumably scans finished so far.
    pub completed_scans: usize,
    /// Presumably total findings produced.
    pub findings: usize,
    /// Presumably seconds since the worker started; not displayed by `worker_status`.
    pub uptime_seconds: u64,
}
/// One result entry carried in `WorkerMessage::ScanResult`.
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
pub struct ScanFinding {
    /// Domain the finding refers to.
    pub domain: String,
    /// Name of the rule that produced this entry (presumably a `ScanRule::name` — confirm).
    pub rule_name: String,
    /// Path that was checked on the domain (assumed to be one of the rule's paths — confirm).
    pub matched_path: String,
    /// Whether the rule fired for this path.
    pub detected: bool,
}
/// Master-side record of a registered worker, stored in the global `WORKERS` registry.
pub struct ConnectedWorker {
    // Written on registration but never read in this file, hence the dead_code allow.
    #[allow(dead_code)]
    pub id: String,
    /// Capabilities the worker advertised in its `Register` message.
    pub capabilities: WorkerCapabilities,
    /// Write half of the worker's TCP connection. `send_message` locks it for the whole
    /// length-prefix + payload write, so concurrent senders cannot interleave frames.
    pub writer: Arc<Mutex<OwnedWriteHalf>>,
    /// Last known status. NOTE(review): set to `WorkerStatus::default()` on registration and
    /// immutable behind the `Arc` this record is stored in, so it never changes.
    pub status: WorkerStatus,
}
// Process-wide registry of workers connected to this master, keyed by worker id.
// `handle_worker_connection` inserts entries on registration; `stop_worker` and `worker_status`
// read it. It is a `tokio::sync::Mutex`, so its guard may be held across `.await` points.
lazy_static! {
    static ref WORKERS: Mutex<HashMap<String, Arc<ConnectedWorker>>> = Mutex::new(HashMap::new());
}
/// Asks a connected worker to shut down by sending it a `WorkerMessage::Shutdown` frame.
///
/// The worker is looked up by id in the global `WORKERS` registry. The registry lock is released
/// before the network write, so a slow or stalled socket does not block other registry users
/// (status queries, new registrations, other stop requests).
///
/// # Errors
///
/// Returns an error if no worker with `worker_id` is registered, or if writing the shutdown
/// frame to the worker's socket fails.
pub async fn stop_worker(worker_id: &str) -> Result<()> {
    // Clone the `Arc` out of the map; the guard is a temporary of this statement and is
    // dropped at its end, i.e. before the `.await` on the socket below.
    let worker = WORKERS
        .lock()
        .await
        .get(worker_id)
        .cloned()
        .ok_or_else(|| anyhow::anyhow!("Worker not found: {}", worker_id))?;

    let shutdown_msg = WorkerMessage::Shutdown {
        worker_id: worker_id.to_string(),
    };
    send_message(&worker.writer, &shutdown_msg)
        .await
        .with_context(|| format!("Failed to send shutdown message to worker {}", worker_id))?;
    info!("⏹️ Sent shutdown request to worker: {}", worker_id);
    Ok(())
}
/// Logs a summary of every worker in the global `WORKERS` registry.
///
/// With an empty registry a single "no workers" line is emitted. Otherwise one line gives the
/// worker count, followed by one line per worker with its recorded status counters and
/// advertised max concurrency. The registry lock is held while logging. Always returns `Ok(())`.
pub async fn worker_status() -> Result<()> {
    let registry = WORKERS.lock().await;
    match registry.len() {
        0 => info!("🔍 No workers connected"),
        count => {
            info!("🔍 Connected Workers: {}", count);
            registry.iter().for_each(|(worker_id, entry)| {
                let status = &entry.status;
                info!(
                    "👷 Worker {}: Active={}, Completed={}, Findings={}, MaxConcurrency={}",
                    worker_id,
                    status.active_scans,
                    status.completed_scans,
                    status.findings,
                    entry.capabilities.max_concurrency
                );
            });
        }
    }
    Ok(())
}
pub async fn start_worker(config: &WorkerConfig) -> Result<()> {
info!("🚀 Starting worker node with ID: {}", config.worker_id);
let stream = TcpStream::connect(&config.master)
.await
.context(format!("Failed to connect to master at {}", config.master))?;
let (mut reader, write_half) = stream.into_split();
let writer = Arc::new(Mutex::new(write_half));
let capabilities = WorkerCapabilities {
max_concurrency: config.concurrency,
version: env!("CARGO_PKG_VERSION").to_string(),
};
let register_msg = WorkerMessage::Register {
worker_id: config.worker_id.clone(),
capabilities: capabilities.clone(),
};
send_message(&writer, ®ister_msg)
.await
.context("Failed to register with master")?;
info!("✅ Registered with master at {}", config.master);
loop {
let mut len_bytes = [0u8; 4];
reader
.read_exact(&mut len_bytes)
.await
.context("Failed to read message length")?;
let len = u32::from_be_bytes(len_bytes) as usize;
let mut buffer = vec![0u8; len];
reader
.read_exact(&mut buffer)
.await
.context("Failed to read message")?;
let message: WorkerMessage =
bincode::decode_from_slice(&buffer, bincode::config::standard())
.context("Failed to deserialize message")?
.0;
debug!("📩 Received message: {:?}", message);
match message {
WorkerMessage::ScanRequest { domains, batch_id } => {
info!(
"🔍 Received scan request for {} domains (batch: {})",
domains.len(),
batch_id
);
let _scan_config = config.clone();
let result_msg = WorkerMessage::ScanResult {
worker_id: config.worker_id.clone(),
batch_id,
findings: vec![],
};
send_message(&writer, &result_msg)
.await
.context("Failed to send scan results")?;
}
WorkerMessage::Shutdown { .. } => {
info!("⏹️ Received shutdown request, stopping worker");
break;
}
_ => {
error!("❓ Received unexpected message type");
}
}
}
Ok(())
}
/// Writes one framed [`WorkerMessage`] to `writer`.
///
/// Frame layout: a 4-byte big-endian payload length, then the bincode (`config::standard()`)
/// encoding of `message`. The message is encoded before taking the writer lock, and the whole
/// frame is written with a single `write_all` while the lock is held. Concurrent senders
/// therefore cannot interleave frames, and the lock is not held during encoding.
///
/// # Errors
///
/// Returns an error if encoding fails, if the encoded message does not fit the 32-bit length
/// prefix, or if writing/flushing the socket fails.
async fn send_message(writer: &Arc<Mutex<OwnedWriteHalf>>, message: &WorkerMessage) -> Result<()> {
    let payload = bincode::encode_to_vec(message, config::standard())?;
    // The prefix is a u32; refuse oversized payloads instead of silently truncating the length
    // (which would desynchronise the reader on the other end).
    let payload_len = u32::try_from(payload.len())
        .with_context(|| format!("Message too large to frame: {} bytes", payload.len()))?;

    let mut frame = Vec::with_capacity(4 + payload.len());
    frame.extend_from_slice(&payload_len.to_be_bytes());
    frame.extend_from_slice(&payload);

    let mut guard = writer.lock().await;
    guard.write_all(&frame).await?;
    guard.flush().await?;
    Ok(())
}
#[allow(dead_code)]
async fn read_message(stream: &mut TcpStream) -> Result<WorkerMessage> {
let mut len_bytes = [0u8; 4];
stream.read_exact(&mut len_bytes).await?;
let msg_len = u32::from_be_bytes(len_bytes) as usize;
let mut buffer = vec![0u8; msg_len];
stream.read_exact(&mut buffer).await?;
let config = config::standard();
let (message, _): (WorkerMessage, _) = bincode::decode_from_slice(&buffer, config)?;
Ok(message)
}
#[allow(dead_code)]
pub async fn start_master(
listen_addr: &str,
_scan_config: crate::config::ScanConfig,
) -> Result<()> {
info!("🌐 Starting master node on {}", listen_addr);
let listener = TcpListener::bind(listen_addr)
.await
.context(format!("Failed to bind to {}", listen_addr))?;
info!("✅ Master node started, waiting for workers to connect");
let workers = Arc::new(Mutex::new(Vec::new()));
loop {
let (socket, addr) = listener
.accept()
.await
.context("Failed to accept connection")?;
info!("✅ New connection from: {}", addr);
let workers_clone = workers.clone();
tokio::spawn(async move {
if let Err(e) = handle_worker_connection(socket, workers_clone).await {
error!("❌ Error handling worker connection: {}", e);
}
});
}
}
#[allow(dead_code)]
async fn handle_worker_connection(
mut stream: TcpStream,
_workers: Arc<Mutex<Vec<ConnectedWorker>>>,
) -> Result<()> {
info!("🔌 Worker connected from: {}", stream.peer_addr()?);
let message = read_message(&mut stream).await?;
match message {
WorkerMessage::Register {
worker_id,
capabilities,
} => {
info!(
"👷 Worker registered: {} (concurrency={})",
worker_id, capabilities.max_concurrency
);
let (_read_half, write_half) = stream.into_split();
let worker = Arc::new(ConnectedWorker {
id: worker_id.clone(),
capabilities,
writer: Arc::new(Mutex::new(write_half)),
status: WorkerStatus::default(),
});
{
let mut workers = WORKERS.lock().await;
workers.insert(worker_id.clone(), worker.clone());
}
let heartbeat = WorkerMessage::Heartbeat {
worker_id: worker_id.clone(),
status: WorkerStatus::default(),
};
send_message(&worker.writer, &heartbeat).await?;
Ok(())
}
_ => {
error!("❌ Expected Register message from worker, got something else");
anyhow::bail!("Invalid initial message from worker")
}
}
}
/// Messages presumably intended to flow from master to worker.
///
/// NOTE(review): not referenced anywhere in this file. Unlike `WorkerMessage`, it derives only
/// serde traits (no bincode `Encode`/`Decode`), so it cannot be framed the way `WorkerMessage`
/// is today. Confirm whether it is used elsewhere or is a planned replacement protocol.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MasterMessage {
    /// Presumably the reply to a registration; `message` an optional human-readable reason.
    RegisterResponse {
        accepted: bool,
        message: Option<String>,
    },
    /// A batch of domains to scan together with the rules to apply to them.
    WorkAssignment {
        batch_id: String,
        domains: Vec<String>,
        rules: Vec<ScanRule>,
    },
    /// Presumably tells a worker there is currently nothing to do.
    NoWorkAvailable,
    /// Asks the worker to stop, with an optional reason.
    Shutdown { reason: Option<String> },
}
/// A named scan rule, as carried in `MasterMessage::WorkAssignment`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScanRule {
    /// Rule identifier (presumably what ends up in `ScanFinding::rule_name` — confirm).
    pub name: String,
    /// Paths the rule applies to (assumed to be probed on each assigned domain — confirm).
    pub paths: Vec<String>,
    /// Free-form severity label; no set of allowed values is defined in this file.
    pub severity: String,
}