1use crate::a2a::worker::{HeartbeatState, WorkerStatus};
11use crate::bus::{AgentBus, BusEnvelope};
12use crate::cli::WorkerServerArgs;
13use crate::cloudevents::parse_cloud_event;
14use anyhow::Result;
15use axum::{
16 Json, Router,
17 body::Body,
18 extract::State,
19 http::{HeaderMap, Request, StatusCode},
20 response::sse::{Event, KeepAlive, Sse},
21 response::{IntoResponse, Response},
22 routing::{get, post},
23};
24use futures::stream;
25use serde::{Deserialize, Serialize};
26use std::convert::Infallible;
27use std::sync::Arc;
28use tokio::sync::{Mutex, broadcast, mpsc};
29
30#[derive(Clone)]
32pub struct WorkerServerState {
33 pub heartbeat_state: Option<Arc<HeartbeatState>>,
35 pub connected: Arc<Mutex<bool>>,
37 pub worker_id: Arc<Mutex<Option<String>>>,
39 internal_heartbeat: Arc<Mutex<Option<Arc<HeartbeatState>>>>,
41 task_notification_tx: Arc<Mutex<Option<mpsc::Sender<String>>>>,
43 bus: Arc<Mutex<Option<Arc<AgentBus>>>>,
45}
46
47impl Default for WorkerServerState {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl WorkerServerState {
54 pub fn new() -> Self {
55 Self {
56 heartbeat_state: None,
57 connected: Arc::new(Mutex::new(false)),
58 worker_id: Arc::new(Mutex::new(None)),
59 internal_heartbeat: Arc::new(Mutex::new(None)),
60 task_notification_tx: Arc::new(Mutex::new(None)),
61 bus: Arc::new(Mutex::new(None)),
62 }
63 }
64
65 pub async fn set_bus(&self, bus: Arc<AgentBus>) {
67 *self.bus.lock().await = Some(bus);
68 }
69
70 pub async fn set_task_notification_channel(&self, tx: mpsc::Sender<String>) {
72 *self.task_notification_tx.lock().await = Some(tx);
73 }
74
75 pub async fn notify_new_task(&self, task_id: &str) {
77 if let Some(ref tx) = *self.task_notification_tx.lock().await {
78 let _ = tx.send(task_id.to_string()).await;
79 tracing::debug!("Notified worker of new task: {}", task_id);
80 }
81 }
82
83 #[allow(dead_code)]
84 pub fn with_heartbeat(mut self, state: Option<Arc<HeartbeatState>>) -> Self {
85 self.heartbeat_state = state.clone();
86 self
87 }
88
89 pub async fn set_heartbeat_state(&self, state: Arc<HeartbeatState>) {
91 *self.internal_heartbeat.lock().await = Some(state);
92 }
93
94 #[allow(dead_code)]
96 pub async fn heartbeat_state(&self) -> Arc<HeartbeatState> {
97 let guard: Option<Arc<HeartbeatState>> = self.internal_heartbeat.lock().await.clone();
98 guard.unwrap_or_else(|| {
99 let state = HeartbeatState::new("unknown".to_string(), "unknown".to_string());
101 Arc::new(state)
102 })
103 }
104
105 pub async fn set_connected(&self, connected: bool) {
106 *self.connected.lock().await = connected;
107 }
108
109 pub async fn set_worker_id(&self, worker_id: String) {
110 *self.worker_id.lock().await = Some(worker_id);
111 }
112
113 #[allow(dead_code)]
114 pub async fn worker_id(&self) -> String {
115 self.worker_id.lock().await.clone().unwrap_or_default()
116 }
117
118 #[allow(dead_code)]
119 pub async fn is_connected(&self) -> bool {
120 *self.connected.lock().await
121 }
122
123 pub async fn is_ready(&self) -> bool {
124 *self.connected.lock().await
127 }
128}
129
130pub async fn start_worker_server_with_state(
132 args: WorkerServerArgs,
133 state: WorkerServerState,
134) -> Result<()> {
135 let addr = format!("{}:{}", args.hostname, args.port);
136
137 tracing::info!("Starting worker HTTP server on http://{}", addr);
138
139 let app = Router::new()
140 .route("/health", get(health))
141 .route("/ready", get(ready))
142 .route("/task", post(receive_task))
143 .route("/worker/status", get(worker_status))
144 .route("/v1/bus/stream", get(stream_bus_events))
145 .route("/v1/bus/publish", post(publish_bus_event))
146 .with_state(state);
147
148 let listener = tokio::net::TcpListener::bind(&addr).await?;
149 tracing::info!("Worker HTTP server listening on http://{}", addr);
150
151 axum::serve(listener, app).await?;
152
153 Ok(())
154}
155
/// Liveness probe: answers `ok` whenever the server is able to respond.
async fn health() -> &'static str {
    const HEALTHY: &str = "ok";
    HEALTHY
}
160
161async fn ready(State(state): State<WorkerServerState>) -> (StatusCode, String) {
163 if state.is_ready().await {
164 (StatusCode::OK, "ready".to_string())
165 } else {
166 (StatusCode::SERVICE_UNAVAILABLE, "not connected".to_string())
167 }
168}
169
170async fn worker_status(State(state): State<WorkerServerState>) -> Json<WorkerStatusResponse> {
172 let connected = *state.connected.lock().await;
173 let worker_id = state.worker_id.lock().await.clone();
174 let heartbeat_state = state
175 .internal_heartbeat
176 .lock()
177 .await
178 .clone()
179 .or_else(|| state.heartbeat_state.clone());
180
181 let heartbeat_info = if let Some(ref hb_state) = heartbeat_state {
182 let status: WorkerStatus = *hb_state.status.lock().await;
183 let task_count = hb_state.active_task_count.lock().await;
184 Some(HeartbeatInfo {
185 status: status.as_str().to_string(),
186 active_tasks: *task_count,
187 agent_name: hb_state.agent_name.clone(),
188 })
189 } else {
190 None
191 };
192
193 Json(WorkerStatusResponse {
194 connected,
195 worker_id,
196 heartbeat: heartbeat_info,
197 })
198}
199
200async fn receive_task(
203 State(state): State<WorkerServerState>,
204 headers: HeaderMap,
205 Json(payload): Json<serde_json::Value>,
206) -> StatusCode {
207 let event = match parse_cloud_event(&headers, payload) {
208 Ok(event) => event,
209 Err(error) => {
210 tracing::warn!("Rejected task event: {}", error);
211 return StatusCode::BAD_REQUEST;
212 }
213 };
214
215 let task_id = event
216 .data
217 .get("task_id")
218 .or_else(|| event.data.get("id"))
219 .and_then(|v| v.as_str())
220 .unwrap_or(event.id.as_str());
221
222 tracing::info!(
223 "Received task via CloudEvent: {} ({})",
224 task_id,
225 event.event_type
226 );
227
228 state.notify_new_task(task_id).await;
230
231 StatusCode::ACCEPTED
234}
235
236async fn stream_bus_events(State(state): State<WorkerServerState>, req: Request<Body>) -> Response {
239 let bus = state.bus.lock().await.clone();
240 let Some(bus) = bus else {
241 let empty = stream::empty::<Result<Event, Infallible>>();
243 return Sse::new(empty)
244 .keep_alive(KeepAlive::new().interval(std::time::Duration::from_secs(15)))
245 .into_response();
246 };
247
248 let topic_filter: Option<String> = req.uri().query().and_then(|q| {
249 q.split('&')
250 .filter_map(|pair| pair.split_once('='))
251 .find(|(k, _)| *k == "topic")
252 .map(|(_, v)| v.to_owned())
253 });
254
255 let bus_handle = bus.handle("worker_server_bus_stream");
256 let rx: broadcast::Receiver<BusEnvelope> = bus_handle.into_receiver();
257
258 let event_stream = stream::unfold(rx, move |mut rx| {
259 let filter = topic_filter.clone();
260 async move {
261 match rx.recv().await {
262 Ok(envelope) => {
263 let allowed = filter
264 .as_deref()
265 .map(|pat| bus_topic_matches(&envelope.topic, pat))
266 .unwrap_or(true);
267
268 if allowed {
269 let payload =
270 serde_json::to_string(&envelope).unwrap_or_else(|_| "{}".to_string());
271 Some((
272 Ok::<Event, Infallible>(Event::default().event("bus").data(payload)),
273 rx,
274 ))
275 } else {
276 Some((
277 Ok::<Event, Infallible>(Event::default().event("keepalive").data("")),
278 rx,
279 ))
280 }
281 }
282 Err(broadcast::error::RecvError::Lagged(n)) => Some((
283 Ok(Event::default().event("lag").data(format!("skipped {}", n))),
284 rx,
285 )),
286 Err(broadcast::error::RecvError::Closed) => None,
287 }
288 }
289 });
290
291 Sse::new(event_stream)
292 .keep_alive(KeepAlive::new().interval(std::time::Duration::from_secs(15)))
293 .into_response()
294}
295
/// Returns whether a bus `topic` matches a subscription `pattern`.
///
/// Supported patterns:
/// - `*` matches every topic (so does `.*`, as before).
/// - `prefix.*` matches `prefix` itself and any topic under it (`prefix.x`,
///   `prefix.x.y`). It does not match topics that merely share the leading
///   characters: `agent.*` does not match `agentsmith.x`.
/// - Anything else must equal the topic exactly.
fn bus_topic_matches(topic: &str, pattern: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if let Some(prefix) = pattern.strip_suffix(".*") {
        if prefix.is_empty() {
            return true;
        }
        // Enforce the segment boundary: whatever follows the prefix must be
        // empty or start with the `.` separator.
        return topic
            .strip_prefix(prefix)
            .is_some_and(|rest| rest.is_empty() || rest.starts_with('.'));
    }
    topic == pattern
}
306
307#[derive(Deserialize)]
309struct BusPublishRequest {
310 topic: String,
311 payload: serde_json::Value,
312}
313
314async fn publish_bus_event(
315 State(state): State<WorkerServerState>,
316 Json(req): Json<BusPublishRequest>,
317) -> StatusCode {
318 let bus = state.bus.lock().await.clone();
319 let Some(bus) = bus else {
320 return StatusCode::SERVICE_UNAVAILABLE;
321 };
322 let handle = bus.handle("worker_server_publish");
323 handle.send(
324 &req.topic,
325 crate::bus::BusMessage::SharedResult {
326 key: req.topic.clone(),
327 value: req.payload,
328 tags: vec![],
329 },
330 );
331 StatusCode::ACCEPTED
332}
333
334#[derive(Serialize)]
336struct WorkerStatusResponse {
337 connected: bool,
338 worker_id: Option<String>,
339 heartbeat: Option<HeartbeatInfo>,
340}
341
342#[derive(Serialize)]
343struct HeartbeatInfo {
344 status: String,
345 active_tasks: usize,
346 agent_name: String,
347}