Skip to main content

codetether_agent/
worker_server.rs

//! Worker HTTP Server
//!
//! Minimal HTTP server for the A2A worker that provides:
//! - /health - liveness probe
//! - /ready - readiness probe
//! - /task - POST endpoint for CloudEvents (Knative integration)
//! - /worker/status - JSON snapshot of connection and heartbeat state
//!
//! These endpoints let Kubernetes probes and ingress routing reach the worker.
9
use crate::a2a::worker::{HeartbeatState, WorkerStatus};
use crate::cli::WorkerServerArgs;
use anyhow::{Context, Result};
use axum::{
    Json, Router,
    extract::State,
    http::StatusCode,
    routing::{get, post},
};
use serde::Serialize;
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
22
23/// Worker server state shared across handlers
24#[derive(Clone)]
25pub struct WorkerServerState {
26    /// Heartbeat state from the worker (used to determine readiness)
27    pub heartbeat_state: Option<Arc<HeartbeatState>>,
28    /// Whether the SSE stream is currently connected
29    pub connected: Arc<Mutex<bool>>,
30    /// Worker ID for identification
31    pub worker_id: Arc<Mutex<Option<String>>>,
32    /// The actual heartbeat state (internal, for sharing with HTTP server)
33    internal_heartbeat: Arc<Mutex<Option<Arc<HeartbeatState>>>>,
34    /// Channel to notify worker of new tasks (from CloudEvents)
35    task_notification_tx: Arc<Mutex<Option<mpsc::Sender<String>>>>,
36}
37
38impl WorkerServerState {
39    pub fn new() -> Self {
40        Self {
41            heartbeat_state: None,
42            connected: Arc::new(Mutex::new(false)),
43            worker_id: Arc::new(Mutex::new(None)),
44            internal_heartbeat: Arc::new(Mutex::new(None)),
45            task_notification_tx: Arc::new(Mutex::new(None)),
46        }
47    }
48
49    /// Set the task notification channel (called from worker)
50    pub async fn set_task_notification_channel(&self, tx: mpsc::Sender<String>) {
51        *self.task_notification_tx.lock().await = Some(tx);
52    }
53
54    /// Notify worker of a new task (called from /task endpoint)
55    pub async fn notify_new_task(&self, task_id: &str) {
56        if let Some(ref tx) = *self.task_notification_tx.lock().await {
57            let _ = tx.send(task_id.to_string()).await;
58            tracing::debug!("Notified worker of new task: {}", task_id);
59        }
60    }
61
62    pub fn with_heartbeat(mut self, state: Option<Arc<HeartbeatState>>) -> Self {
63        self.heartbeat_state = state.clone();
64        self
65    }
66
67    /// Set the heartbeat state (called from worker)
68    pub async fn set_heartbeat_state(&self, state: Arc<HeartbeatState>) {
69        *self.internal_heartbeat.lock().await = Some(state);
70    }
71
72    /// Get the heartbeat state (for HTTP server)
73    pub async fn heartbeat_state(&self) -> Arc<HeartbeatState> {
74        let guard: Option<Arc<HeartbeatState>> = self.internal_heartbeat.lock().await.clone();
75        guard.unwrap_or_else(|| {
76            // Create a default heartbeat state if not set
77            let state = HeartbeatState::new("unknown".to_string(), "unknown".to_string());
78            Arc::new(state)
79        })
80    }
81
82    pub async fn set_connected(&self, connected: bool) {
83        *self.connected.lock().await = connected;
84    }
85
86    pub async fn set_worker_id(&self, worker_id: String) {
87        *self.worker_id.lock().await = Some(worker_id);
88    }
89
90    pub async fn worker_id(&self) -> String {
91        self.worker_id.lock().await.clone().unwrap_or_default()
92    }
93
94    pub async fn is_connected(&self) -> bool {
95        *self.connected.lock().await
96    }
97
98    pub async fn is_ready(&self) -> bool {
99        let connected = *self.connected.lock().await;
100        // Ready if we have a connection to the A2A server
101        // Optional: could also check heartbeat_state for active task count
102        connected
103    }
104}
105
106/// Start the worker HTTP server with default state
107pub async fn start_worker_server(args: WorkerServerArgs) -> Result<()> {
108    let state = WorkerServerState::new();
109    start_worker_server_with_state(args, state).await
110}
111
112/// Start the worker HTTP server with custom state
113pub async fn start_worker_server_with_state(
114    args: WorkerServerArgs,
115    state: WorkerServerState,
116) -> Result<()> {
117    let addr = format!("{}:{}", args.hostname, args.port);
118
119    tracing::info!("Starting worker HTTP server on http://{}", addr);
120
121    let app = Router::new()
122        .route("/health", get(health))
123        .route("/ready", get(ready))
124        .route("/task", post(receive_task))
125        .route("/worker/status", get(worker_status))
126        .with_state(state);
127
128    let listener = tokio::net::TcpListener::bind(&addr).await?;
129    tracing::info!("Worker HTTP server listening on http://{}", addr);
130
131    axum::serve(listener, app).await?;
132
133    Ok(())
134}
135
/// Liveness probe handler.
///
/// Unconditionally answers `"ok"`: if this handler runs at all, the HTTP
/// server is up. It consults no other state.
async fn health() -> &'static str {
    const LIVE: &str = "ok";
    LIVE
}
140
141/// Readiness check - returns OK only when connected to A2A server
142async fn ready(State(state): State<WorkerServerState>) -> (StatusCode, String) {
143    if state.is_ready().await {
144        (StatusCode::OK, "ready".to_string())
145    } else {
146        (StatusCode::SERVICE_UNAVAILABLE, "not connected".to_string())
147    }
148}
149
150/// Worker status endpoint - returns detailed worker state
151async fn worker_status(State(state): State<WorkerServerState>) -> Json<WorkerStatusResponse> {
152    let connected = *state.connected.lock().await;
153    let worker_id = state.worker_id.lock().await.clone();
154    let heartbeat_state = state
155        .internal_heartbeat
156        .lock()
157        .await
158        .clone()
159        .or_else(|| state.heartbeat_state.clone());
160
161    let heartbeat_info = if let Some(ref hb_state) = heartbeat_state {
162        let status: WorkerStatus = *hb_state.status.lock().await;
163        let task_count = hb_state.active_task_count.lock().await;
164        Some(HeartbeatInfo {
165            status: status.as_str().to_string(),
166            active_tasks: *task_count,
167            agent_name: hb_state.agent_name.clone(),
168        })
169    } else {
170        None
171    };
172
173    Json(WorkerStatusResponse {
174        connected,
175        worker_id,
176        heartbeat: heartbeat_info,
177    })
178}
179
180/// Receive CloudEvents POST (for Knative integration)
181/// This endpoint receives tasks pushed via Knative Eventing
182async fn receive_task(
183    State(state): State<WorkerServerState>,
184    Json(payload): Json<serde_json::Value>,
185) -> (StatusCode, String) {
186    // Extract task_id from CloudEvent payload
187    let task_id = payload
188        .get("task_id")
189        .or_else(|| payload.get("id"))
190        .and_then(|v| v.as_str())
191        .unwrap_or("unknown");
192
193    tracing::info!("Received task via CloudEvent: {}", task_id);
194
195    // Notify the worker loop to pick up this task
196    state.notify_new_task(task_id).await;
197
198    (StatusCode::ACCEPTED, format!("task {} received", task_id))
199}
200
201/// Response types
202#[derive(Serialize)]
203struct WorkerStatusResponse {
204    connected: bool,
205    worker_id: Option<String>,
206    heartbeat: Option<HeartbeatInfo>,
207}
208
209#[derive(Serialize)]
210struct HeartbeatInfo {
211    status: String,
212    active_tasks: usize,
213    agent_name: String,
214}