use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Instant;
#[cfg(not(test))]
use std::time::Duration;
use serde_json::Value;
use tracing::info;
use crate::framing::Encoding;
use crate::job_registry::{EventSender, JobRegistry};
use crate::protocol::{ErrorPayload, WireRequest, WireResponse};
/// Delay (ms) used when a `shutdown` request omits `delay_ms` or supplies an unusable value.
const DEFAULT_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (ms) applied to a requested shutdown delay: five minutes.
const MAX_SHUTDOWN_DELAY_MS: u64 = 300_000;
/// Application-specific request handling plugged into [`BaseDispatcher`].
///
/// The dispatcher answers its built-in methods itself and forwards every
/// other method name to [`WorkerHandler::handle_method`].
pub trait WorkerHandler: Send + Sync + 'static {
    /// Handles a request whose method is not one of the dispatcher's built-ins.
    ///
    /// * `req_id` - id of the incoming request (the built-in methods echo it
    ///   back as the response `id`).
    /// * `method` - the method name from the request.
    /// * `params` - the request's optional JSON parameters, passed through as-is.
    /// * `event_tx` - the event sender handed to `dispatch`, passed through unchanged.
    /// * `registry` - a handle to the dispatcher's shared job registry.
    fn handle_method(
        &self,
        req_id: &str,
        method: &str,
        params: Option<Value>,
        event_tx: EventSender,
        registry: Arc<JobRegistry>,
    ) -> WireResponse;

    /// Version string reported by the `health` and `capabilities` methods.
    fn worker_version(&self) -> &str;

    /// Feature names advertised by the `capabilities` method.
    fn features(&self) -> Vec<String>;

    /// Concurrency limit advertised by `capabilities`; defaults to a single job.
    fn max_concurrent_jobs(&self) -> u32 {
        1
    }
}
/// Routes wire requests either to built-in protocol methods or to a [`WorkerHandler`].
pub struct BaseDispatcher<H>
where
    H: WorkerHandler,
{
    /// Application handler that receives every non-built-in method.
    pub handler: H,
    /// Job registry queried by `health`, `cancel_job` and `job_status`, and shared with the handler.
    pub registry: Arc<JobRegistry>,
    /// Encoding reported to peers through the `capabilities` method.
    negotiated_encoding: Encoding,
    /// Construction time; `health` reports whole seconds elapsed since this instant.
    start_time: Instant,
    /// Starts `true`; cleared when a `shutdown` request is dispatched.
    accepting: AtomicBool,
}
impl<H: WorkerHandler> BaseDispatcher<H> {
pub fn new(handler: H, negotiated_encoding: Encoding) -> Self {
Self {
handler,
registry: Arc::new(JobRegistry::new()),
negotiated_encoding,
start_time: Instant::now(),
accepting: AtomicBool::new(true),
}
}
pub async fn dispatch(&self, req: WireRequest, event_tx: EventSender) -> WireResponse {
let id = req.id.clone();
let method = req.method.as_str();
let params = req.params;
match method {
"ping" => ok_response(&id, serde_json::json!({"pong": true})),
"health" => {
let active = self.registry.active_count();
let status = if active > 0 { "busy" } else { "ok" };
ok_response(&id, serde_json::json!({
"status": status,
"active_jobs": active,
"uptime_secs": self.start_time.elapsed().as_secs(),
"pid": std::process::id(),
"version": self.handler.worker_version(),
}))
}
"capabilities" => ok_response(&id, serde_json::json!({
"version": self.handler.worker_version(),
"protocol_version": 1,
"features": self.handler.features(),
"max_concurrent_jobs": self.handler.max_concurrent_jobs(),
"encoding": self.negotiated_encoding.wire_name(),
})),
"shutdown" => {
self.accepting.store(false, Ordering::SeqCst);
let delay_ms = shutdown_delay_ms(¶ms);
info!(delay_ms, "shutdown requested");
#[cfg(not(test))]
{
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
std::process::exit(0);
});
}
ok_response(&id, serde_json::json!({"bye": true}))
}
"cancel_job" => {
let job_id = str_param(¶ms, "job_id").unwrap_or_default();
if self.registry.cancel(job_id) {
ok_response(&id, serde_json::json!({"cancelled": true, "job_id": job_id}))
} else {
err_response(&id, "JOB_NOT_FOUND", &format!("job {job_id} not found"))
}
}
"job_status" => {
let job_id = str_param(¶ms, "job_id").unwrap_or_default();
match self.registry.status(job_id) {
Some(s) => ok_response(&id, serde_json::to_value(s).unwrap_or_default()),
None => err_response(&id, "JOB_NOT_FOUND", &format!("job {job_id} not found")),
}
}
_ => self.handler.handle_method(&id, method, params, event_tx, self.registry.clone()),
}
}
}
/// Builds a successful response for request `req_id` that carries `payload`.
pub fn ok_response(req_id: &str, payload: Value) -> WireResponse {
    let body = Some(payload);
    WireResponse {
        id: String::from(req_id),
        ok: true,
        error: None,
        payload: body,
    }
}
/// Builds a failed response for request `req_id` with the given error `code` and `message`.
///
/// The error's `detail` is left empty and no payload is attached.
pub fn err_response(req_id: &str, code: &str, message: &str) -> WireResponse {
    let failure = ErrorPayload {
        code: String::from(code),
        message: String::from(message),
        detail: String::new(),
    };
    WireResponse {
        id: String::from(req_id),
        ok: false,
        error: Some(failure),
        payload: None,
    }
}
/// Builds the standard `UNKNOWN_METHOD` error response for an unrecognized `method`.
pub fn unknown_method(req_id: &str, method: &str) -> WireResponse {
    let message = format!("unknown method: {method}");
    err_response(req_id, "UNKNOWN_METHOD", &message)
}
fn str_param<'a>(params: &'a Option<Value>, key: &str) -> Option<&'a str> {
params.as_ref()?.get(key)?.as_str()
}
/// Resolves the exit delay (ms) for a `shutdown` request from its `delay_ms` param.
///
/// * Non-negative integers are used directly, capped at `MAX_SHUTDOWN_DELAY_MS`.
/// * Other finite, non-negative numbers are capped the same way and truncated to whole ms.
/// * Anything else (missing, negative, non-numeric) falls back to `DEFAULT_SHUTDOWN_DELAY_MS`.
fn shutdown_delay_ms(params: &Option<Value>) -> u64 {
    const CEILING: f64 = MAX_SHUTDOWN_DELAY_MS as f64;
    let requested = match params {
        Some(obj) => obj.get("delay_ms"),
        None => None,
    };
    requested
        .and_then(|raw| {
            raw.as_u64()
                .map(|whole| whole.min(MAX_SHUTDOWN_DELAY_MS))
                .or_else(|| {
                    raw.as_f64()
                        .filter(|frac| frac.is_finite() && *frac >= 0.0)
                        .map(|frac| frac.min(CEILING) as u64)
                })
        })
        .unwrap_or(DEFAULT_SHUTDOWN_DELAY_MS)
}