Skip to main content

codetether_agent/
worker_server.rs

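//! Worker HTTP Server
//!
//! Minimal HTTP server for the A2A worker that provides:
//! - /health - liveness probe
//! - /ready - readiness probe
//! - /task - POST endpoint for CloudEvents (Knative integration)
//! - /worker/status - JSON snapshot of connection and heartbeat state
//! - /v1/bus/stream - SSE stream of agent bus messages
//! - /v1/bus/publish - POST endpoint for publishing onto the agent bus
//!
//! This enables Kubernetes probes and ingress routing to work.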
9
10use crate::a2a::worker::{HeartbeatState, WorkerStatus};
11use crate::bus::{AgentBus, BusEnvelope};
12use crate::cli::WorkerServerArgs;
13use crate::cloudevents::parse_cloud_event;
14use anyhow::Result;
15use axum::{
16    Json, Router,
17    body::Body,
18    extract::State,
19    http::{HeaderMap, Request, StatusCode},
20    response::sse::{Event, KeepAlive, Sse},
21    response::{IntoResponse, Response},
22    routing::{get, post},
23};
24use futures::stream;
25use serde::{Deserialize, Serialize};
26use std::convert::Infallible;
27use std::sync::Arc;
28use tokio::sync::{Mutex, broadcast, mpsc};
29
30/// Worker server state shared across handlers
31#[derive(Clone)]
32pub struct WorkerServerState {
33    /// Heartbeat state from the worker (used to determine readiness)
34    pub heartbeat_state: Option<Arc<HeartbeatState>>,
35    /// Whether the SSE stream is currently connected
36    pub connected: Arc<Mutex<bool>>,
37    /// Worker ID for identification
38    pub worker_id: Arc<Mutex<Option<String>>>,
39    /// The actual heartbeat state (internal, for sharing with HTTP server)
40    internal_heartbeat: Arc<Mutex<Option<Arc<HeartbeatState>>>>,
41    /// Channel to notify worker of new tasks (from CloudEvents)
42    task_notification_tx: Arc<Mutex<Option<mpsc::Sender<String>>>>,
43    /// Agent bus for inter-agent messaging (exposed via /v1/bus/*)
44    bus: Arc<Mutex<Option<Arc<AgentBus>>>>,
45}
46
47impl Default for WorkerServerState {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl WorkerServerState {
54    pub fn new() -> Self {
55        Self {
56            heartbeat_state: None,
57            connected: Arc::new(Mutex::new(false)),
58            worker_id: Arc::new(Mutex::new(None)),
59            internal_heartbeat: Arc::new(Mutex::new(None)),
60            task_notification_tx: Arc::new(Mutex::new(None)),
61            bus: Arc::new(Mutex::new(None)),
62        }
63    }
64
65    /// Attach the agent bus (can be called after construction from the worker task)
66    pub async fn set_bus(&self, bus: Arc<AgentBus>) {
67        *self.bus.lock().await = Some(bus);
68    }
69
70    /// Set the task notification channel (called from worker)
71    pub async fn set_task_notification_channel(&self, tx: mpsc::Sender<String>) {
72        *self.task_notification_tx.lock().await = Some(tx);
73    }
74
75    /// Notify worker of a new task (called from /task endpoint)
76    pub async fn notify_new_task(&self, task_id: &str) {
77        if let Some(ref tx) = *self.task_notification_tx.lock().await {
78            let _ = tx.send(task_id.to_string()).await;
79            tracing::debug!("Notified worker of new task: {}", task_id);
80        }
81    }
82
83    #[allow(dead_code)]
84    pub fn with_heartbeat(mut self, state: Option<Arc<HeartbeatState>>) -> Self {
85        self.heartbeat_state = state.clone();
86        self
87    }
88
89    /// Set the heartbeat state (called from worker)
90    pub async fn set_heartbeat_state(&self, state: Arc<HeartbeatState>) {
91        *self.internal_heartbeat.lock().await = Some(state);
92    }
93
94    /// Get the heartbeat state (for HTTP server)
95    #[allow(dead_code)]
96    pub async fn heartbeat_state(&self) -> Arc<HeartbeatState> {
97        let guard: Option<Arc<HeartbeatState>> = self.internal_heartbeat.lock().await.clone();
98        guard.unwrap_or_else(|| {
99            // Create a default heartbeat state if not set
100            let state = HeartbeatState::new("unknown".to_string(), "unknown".to_string());
101            Arc::new(state)
102        })
103    }
104
105    pub async fn set_connected(&self, connected: bool) {
106        *self.connected.lock().await = connected;
107    }
108
109    pub async fn set_worker_id(&self, worker_id: String) {
110        *self.worker_id.lock().await = Some(worker_id);
111    }
112
113    #[allow(dead_code)]
114    pub async fn worker_id(&self) -> String {
115        self.worker_id.lock().await.clone().unwrap_or_default()
116    }
117
118    #[allow(dead_code)]
119    pub async fn is_connected(&self) -> bool {
120        *self.connected.lock().await
121    }
122
123    pub async fn is_ready(&self) -> bool {
124        // Ready if we have a connection to the A2A server
125        // Optional: could also check heartbeat_state for active task count
126        *self.connected.lock().await
127    }
128}
129
130/// Start the worker HTTP server with custom state
131pub async fn start_worker_server_with_state(
132    args: WorkerServerArgs,
133    state: WorkerServerState,
134) -> Result<()> {
135    let addr = format!("{}:{}", args.hostname, args.port);
136
137    tracing::info!("Starting worker HTTP server on http://{}", addr);
138
139    let app = Router::new()
140        .route("/health", get(health))
141        .route("/ready", get(ready))
142        .route("/task", post(receive_task))
143        .route("/worker/status", get(worker_status))
144        .route("/v1/bus/stream", get(stream_bus_events))
145        .route("/v1/bus/publish", post(publish_bus_event))
146        .with_state(state);
147
148    let listener = tokio::net::TcpListener::bind(&addr).await?;
149    tracing::info!("Worker HTTP server listening on http://{}", addr);
150
151    axum::serve(listener, app).await?;
152
153    Ok(())
154}
155
/// Liveness probe: answers `ok` whenever the HTTP server can respond at all.
async fn health() -> &'static str {
    const LIVENESS_BODY: &str = "ok";
    LIVENESS_BODY
}
160
161/// Readiness check - returns OK only when connected to A2A server
162async fn ready(State(state): State<WorkerServerState>) -> (StatusCode, String) {
163    if state.is_ready().await {
164        (StatusCode::OK, "ready".to_string())
165    } else {
166        (StatusCode::SERVICE_UNAVAILABLE, "not connected".to_string())
167    }
168}
169
170/// Worker status endpoint - returns detailed worker state
171async fn worker_status(State(state): State<WorkerServerState>) -> Json<WorkerStatusResponse> {
172    let connected = *state.connected.lock().await;
173    let worker_id = state.worker_id.lock().await.clone();
174    let heartbeat_state = state
175        .internal_heartbeat
176        .lock()
177        .await
178        .clone()
179        .or_else(|| state.heartbeat_state.clone());
180
181    let heartbeat_info = if let Some(ref hb_state) = heartbeat_state {
182        let status: WorkerStatus = *hb_state.status.lock().await;
183        let task_count = hb_state.active_task_count.lock().await;
184        Some(HeartbeatInfo {
185            status: status.as_str().to_string(),
186            active_tasks: *task_count,
187            agent_name: hb_state.agent_name.clone(),
188        })
189    } else {
190        None
191    };
192
193    Json(WorkerStatusResponse {
194        connected,
195        worker_id,
196        heartbeat: heartbeat_info,
197    })
198}
199
200/// Receive CloudEvents POST (for Knative integration)
201/// This endpoint receives tasks pushed via Knative Eventing
202async fn receive_task(
203    State(state): State<WorkerServerState>,
204    headers: HeaderMap,
205    Json(payload): Json<serde_json::Value>,
206) -> StatusCode {
207    let event = match parse_cloud_event(&headers, payload) {
208        Ok(event) => event,
209        Err(error) => {
210            tracing::warn!("Rejected task event: {}", error);
211            return StatusCode::BAD_REQUEST;
212        }
213    };
214
215    let task_id = event
216        .data
217        .get("task_id")
218        .or_else(|| event.data.get("id"))
219        .and_then(|v| v.as_str())
220        .unwrap_or(event.id.as_str());
221
222    tracing::info!(
223        "Received task via CloudEvent: {} ({})",
224        task_id,
225        event.event_type
226    );
227
228    // Notify the worker loop to pick up this task
229    state.notify_new_task(task_id).await;
230
231    // CloudEvent subscribers should return an empty 2xx response body.
232    // Non-empty non-CloudEvent payloads trigger retries in Knative Broker Filter.
233    StatusCode::ACCEPTED
234}
235
236/// SSE stream of agent bus events — subscribe to live bus messages.
237/// Streams all topics; use the `topic` query param to filter (e.g. `?topic=task.*`).
238async fn stream_bus_events(State(state): State<WorkerServerState>, req: Request<Body>) -> Response {
239    let bus = state.bus.lock().await.clone();
240    let Some(bus) = bus else {
241        // Bus not attached — return an empty keep-alive stream
242        let empty = stream::empty::<Result<Event, Infallible>>();
243        return Sse::new(empty)
244            .keep_alive(KeepAlive::new().interval(std::time::Duration::from_secs(15)))
245            .into_response();
246    };
247
248    let topic_filter: Option<String> = req.uri().query().and_then(|q| {
249        q.split('&')
250            .filter_map(|pair| pair.split_once('='))
251            .find(|(k, _)| *k == "topic")
252            .map(|(_, v)| v.to_owned())
253    });
254
255    let bus_handle = bus.handle("worker_server_bus_stream");
256    let rx: broadcast::Receiver<BusEnvelope> = bus_handle.into_receiver();
257
258    let event_stream = stream::unfold(rx, move |mut rx| {
259        let filter = topic_filter.clone();
260        async move {
261            match rx.recv().await {
262                Ok(envelope) => {
263                    let allowed = filter
264                        .as_deref()
265                        .map(|pat| bus_topic_matches(&envelope.topic, pat))
266                        .unwrap_or(true);
267
268                    if allowed {
269                        let payload =
270                            serde_json::to_string(&envelope).unwrap_or_else(|_| "{}".to_string());
271                        Some((
272                            Ok::<Event, Infallible>(Event::default().event("bus").data(payload)),
273                            rx,
274                        ))
275                    } else {
276                        Some((
277                            Ok::<Event, Infallible>(Event::default().event("keepalive").data("")),
278                            rx,
279                        ))
280                    }
281                }
282                Err(broadcast::error::RecvError::Lagged(n)) => Some((
283                    Ok(Event::default().event("lag").data(format!("skipped {}", n))),
284                    rx,
285                )),
286                Err(broadcast::error::RecvError::Closed) => None,
287            }
288        }
289    });
290
291    Sse::new(event_stream)
292        .keep_alive(KeepAlive::new().interval(std::time::Duration::from_secs(15)))
293        .into_response()
294}
295
/// Returns whether a bus `topic` matches a subscription `pattern`.
///
/// Supported patterns:
/// - `*` matches every topic;
/// - `prefix.*` matches `prefix` itself and any topic of the form
///   `prefix.<anything>`, but not topics that merely share leading characters
///   (`task.*` does not match `taskforce.created`);
/// - any other pattern must equal the topic exactly.
fn bus_topic_matches(topic: &str, pattern: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    match pattern.strip_suffix(".*") {
        // Require a segment boundary after the prefix.
        Some(prefix) => topic
            .strip_prefix(prefix)
            .is_some_and(|rest| rest.is_empty() || rest.starts_with('.')),
        None => topic == pattern,
    }
}
306
307/// Publish a message to the agent bus.
308#[derive(Deserialize)]
309struct BusPublishRequest {
310    topic: String,
311    payload: serde_json::Value,
312}
313
314async fn publish_bus_event(
315    State(state): State<WorkerServerState>,
316    Json(req): Json<BusPublishRequest>,
317) -> StatusCode {
318    let bus = state.bus.lock().await.clone();
319    let Some(bus) = bus else {
320        return StatusCode::SERVICE_UNAVAILABLE;
321    };
322    let handle = bus.handle("worker_server_publish");
323    handle.send(
324        &req.topic,
325        crate::bus::BusMessage::SharedResult {
326            key: req.topic.clone(),
327            value: req.payload,
328            tags: vec![],
329        },
330    );
331    StatusCode::ACCEPTED
332}
333
334/// Response types
335#[derive(Serialize)]
336struct WorkerStatusResponse {
337    connected: bool,
338    worker_id: Option<String>,
339    heartbeat: Option<HeartbeatInfo>,
340}
341
342#[derive(Serialize)]
343struct HeartbeatInfo {
344    status: String,
345    active_tasks: usize,
346    agent_name: String,
347}