use crate::a2a::worker::{HeartbeatState, WorkerStatus};
use crate::bus::{AgentBus, BusEnvelope};
use crate::cli::WorkerServerArgs;
use crate::cloudevents::parse_cloud_event;
use anyhow::Result;
use axum::{
Json, Router,
body::Body,
extract::State,
http::{HeaderMap, Request, StatusCode},
response::sse::{Event, KeepAlive, Sse},
response::{IntoResponse, Response},
routing::{get, post},
};
use futures::stream;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::{Mutex, broadcast, mpsc};
#[derive(Clone)]
pub struct WorkerServerState {
pub heartbeat_state: Option<Arc<HeartbeatState>>,
pub connected: Arc<Mutex<bool>>,
pub worker_id: Arc<Mutex<Option<String>>>,
internal_heartbeat: Arc<Mutex<Option<Arc<HeartbeatState>>>>,
task_notification_tx: Arc<Mutex<Option<mpsc::Sender<String>>>>,
bus: Arc<Mutex<Option<Arc<AgentBus>>>>,
}
impl Default for WorkerServerState {
fn default() -> Self {
Self::new()
}
}
impl WorkerServerState {
pub fn new() -> Self {
Self {
heartbeat_state: None,
connected: Arc::new(Mutex::new(false)),
worker_id: Arc::new(Mutex::new(None)),
internal_heartbeat: Arc::new(Mutex::new(None)),
task_notification_tx: Arc::new(Mutex::new(None)),
bus: Arc::new(Mutex::new(None)),
}
}
pub async fn set_bus(&self, bus: Arc<AgentBus>) {
*self.bus.lock().await = Some(bus);
}
pub async fn set_task_notification_channel(&self, tx: mpsc::Sender<String>) {
*self.task_notification_tx.lock().await = Some(tx);
}
pub async fn notify_new_task(&self, task_id: &str) {
if let Some(ref tx) = *self.task_notification_tx.lock().await {
let _ = tx.send(task_id.to_string()).await;
tracing::debug!("Notified worker of new task: {}", task_id);
}
}
#[allow(dead_code)]
pub fn with_heartbeat(mut self, state: Option<Arc<HeartbeatState>>) -> Self {
self.heartbeat_state = state.clone();
self
}
pub async fn set_heartbeat_state(&self, state: Arc<HeartbeatState>) {
*self.internal_heartbeat.lock().await = Some(state);
}
#[allow(dead_code)]
pub async fn heartbeat_state(&self) -> Arc<HeartbeatState> {
let guard: Option<Arc<HeartbeatState>> = self.internal_heartbeat.lock().await.clone();
guard.unwrap_or_else(|| {
let state = HeartbeatState::new("unknown".to_string(), "unknown".to_string());
Arc::new(state)
})
}
pub async fn set_connected(&self, connected: bool) {
*self.connected.lock().await = connected;
}
pub async fn set_worker_id(&self, worker_id: String) {
*self.worker_id.lock().await = Some(worker_id);
}
#[allow(dead_code)]
pub async fn worker_id(&self) -> String {
self.worker_id.lock().await.clone().unwrap_or_default()
}
#[allow(dead_code)]
pub async fn is_connected(&self) -> bool {
*self.connected.lock().await
}
pub async fn is_ready(&self) -> bool {
*self.connected.lock().await
}
}
pub async fn start_worker_server_with_state(
args: WorkerServerArgs,
state: WorkerServerState,
) -> Result<()> {
let addr = format!("{}:{}", args.hostname, args.port);
tracing::info!("Starting worker HTTP server on http://{}", addr);
let app = Router::new()
.route("/health", get(health))
.route("/ready", get(ready))
.route("/task", post(receive_task))
.route("/worker/status", get(worker_status))
.route("/v1/bus/stream", get(stream_bus_events))
.route("/v1/bus/publish", post(publish_bus_event))
.with_state(state);
let listener = tokio::net::TcpListener::bind(&addr).await?;
tracing::info!("Worker HTTP server listening on http://{}", addr);
axum::serve(listener, app).await?;
Ok(())
}
/// Liveness probe: always answers `ok` while the server can respond at all.
async fn health() -> &'static str {
    "ok"
}
/// Readiness probe: 200 `ready` once the worker is connected, otherwise
/// 503 `not connected`.
async fn ready(State(state): State<WorkerServerState>) -> (StatusCode, String) {
    if !state.is_ready().await {
        return (StatusCode::SERVICE_UNAVAILABLE, "not connected".to_string());
    }
    (StatusCode::OK, "ready".to_string())
}
/// `GET /worker/status`: reports the connection flag, the worker id and, when
/// available, a heartbeat snapshot.
///
/// The heartbeat installed at runtime is preferred; otherwise the one supplied
/// at construction time is used. If neither exists, `heartbeat` is `None`.
async fn worker_status(State(state): State<WorkerServerState>) -> Json<WorkerStatusResponse> {
    let connected = *state.connected.lock().await;
    let worker_id = state.worker_id.lock().await.clone();

    let installed = state.internal_heartbeat.lock().await.clone();
    let heartbeat = match installed.or_else(|| state.heartbeat_state.clone()) {
        Some(hb) => {
            let status: WorkerStatus = *hb.status.lock().await;
            let active_tasks = *hb.active_task_count.lock().await;
            Some(HeartbeatInfo {
                status: status.as_str().to_string(),
                active_tasks,
                agent_name: hb.agent_name.clone(),
            })
        }
        None => None,
    };

    Json(WorkerStatusResponse {
        connected,
        worker_id,
        heartbeat,
    })
}
/// `POST /task`: accepts a task announcement encoded as a CloudEvent and
/// forwards its task id to the worker.
///
/// The task id is the first non-empty string found in `data.task_id`, then
/// `data.id`, falling back to the CloudEvent's own `id`.
///
/// Returns 400 if the request is not a valid CloudEvent and 202 otherwise.
async fn receive_task(
    State(state): State<WorkerServerState>,
    headers: HeaderMap,
    Json(payload): Json<serde_json::Value>,
) -> StatusCode {
    let event = match parse_cloud_event(&headers, payload) {
        Ok(event) => event,
        Err(error) => {
            tracing::warn!("Rejected task event: {}", error);
            return StatusCode::BAD_REQUEST;
        }
    };
    // Each candidate key is checked on its own, so a `task_id` that is present
    // but null, non-string or empty still falls through to `id`.
    let task_id = ["task_id", "id"]
        .iter()
        .find_map(|key| {
            event
                .data
                .get(*key)
                .and_then(|v| v.as_str())
                .filter(|id| !id.is_empty())
        })
        .unwrap_or(event.id.as_str());
    tracing::info!(
        "Received task via CloudEvent: {} ({})",
        task_id,
        event.event_type
    );
    state.notify_new_task(task_id).await;
    StatusCode::ACCEPTED
}
async fn stream_bus_events(State(state): State<WorkerServerState>, req: Request<Body>) -> Response {
let bus = state.bus.lock().await.clone();
let Some(bus) = bus else {
let empty = stream::empty::<Result<Event, Infallible>>();
return Sse::new(empty)
.keep_alive(KeepAlive::new().interval(std::time::Duration::from_secs(15)))
.into_response();
};
let topic_filter: Option<String> = req.uri().query().and_then(|q| {
q.split('&')
.filter_map(|pair| pair.split_once('='))
.find(|(k, _)| *k == "topic")
.map(|(_, v)| v.to_owned())
});
let bus_handle = bus.handle("worker_server_bus_stream");
let rx: broadcast::Receiver<BusEnvelope> = bus_handle.into_receiver();
let event_stream = stream::unfold(rx, move |mut rx| {
let filter = topic_filter.clone();
async move {
match rx.recv().await {
Ok(envelope) => {
let allowed = filter
.as_deref()
.map(|pat| bus_topic_matches(&envelope.topic, pat))
.unwrap_or(true);
if allowed {
let payload =
serde_json::to_string(&envelope).unwrap_or_else(|_| "{}".to_string());
Some((
Ok::<Event, Infallible>(Event::default().event("bus").data(payload)),
rx,
))
} else {
Some((
Ok::<Event, Infallible>(Event::default().event("keepalive").data("")),
rx,
))
}
}
Err(broadcast::error::RecvError::Lagged(n)) => Some((
Ok(Event::default().event("lag").data(format!("skipped {}", n))),
rx,
)),
Err(broadcast::error::RecvError::Closed) => None,
}
}
});
Sse::new(event_stream)
.keep_alive(KeepAlive::new().interval(std::time::Duration::from_secs(15)))
.into_response()
}
/// Returns `true` when `topic` is selected by the subscription `pattern`.
///
/// Supported patterns:
/// * `*` — matches every topic;
/// * `prefix.*` — matches topics nested under `prefix`, i.e. those starting
///   with `prefix.` (so `agent.*` matches `agent.ready` but not `agentx.ready`);
/// * anything else — exact, case-sensitive comparison.
fn bus_topic_matches(topic: &str, pattern: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if let Some(prefix) = pattern.strip_suffix(".*") {
        // Require the `.` separator right after the prefix. A bare
        // `starts_with(prefix)` would also accept sibling topics such as
        // `agentx.ready` for `agent.*`.
        return matches!(topic.strip_prefix(prefix), Some(rest) if rest.starts_with('.'));
    }
    topic == pattern
}
/// JSON body accepted by `POST /v1/bus/publish`.
#[derive(Deserialize)]
struct BusPublishRequest {
    /// Topic to publish on; also used as the `SharedResult` key.
    topic: String,
    /// Arbitrary JSON carried as the `SharedResult` value.
    payload: serde_json::Value,
}
/// `POST /v1/bus/publish`: publishes `payload` on the bus under `topic`, wrapped
/// in a `BusMessage::SharedResult` whose key is the topic and whose tags are empty.
///
/// Returns 503 if no bus is attached and 202 otherwise.
async fn publish_bus_event(
    State(state): State<WorkerServerState>,
    Json(req): Json<BusPublishRequest>,
) -> StatusCode {
    let maybe_bus = state.bus.lock().await.clone();
    let bus = match maybe_bus {
        Some(bus) => bus,
        None => return StatusCode::SERVICE_UNAVAILABLE,
    };

    let BusPublishRequest { topic, payload } = req;
    let publisher = bus.handle("worker_server_publish");
    let message = crate::bus::BusMessage::SharedResult {
        key: topic.clone(),
        value: payload,
        tags: Vec::new(),
    };
    publisher.send(&topic, message);

    StatusCode::ACCEPTED
}
/// JSON body returned by `GET /worker/status`.
///
/// Field order and names form the serialized wire format.
#[derive(Serialize)]
struct WorkerStatusResponse {
    /// Current value of the server's connection flag.
    connected: bool,
    /// Worker identifier, serialized as `null` when none has been set.
    worker_id: Option<String>,
    /// Heartbeat snapshot, serialized as `null` when no heartbeat state is available.
    heartbeat: Option<HeartbeatInfo>,
}
/// Point-in-time snapshot of a worker's heartbeat state, as reported by
/// `GET /worker/status`.
#[derive(Serialize)]
struct HeartbeatInfo {
    /// Worker status rendered via `WorkerStatus::as_str`.
    status: String,
    /// Value of the heartbeat state's active task counter.
    active_tasks: usize,
    /// Agent name recorded in the heartbeat state.
    agent_name: String,
}