codetether_agent/
worker_server.rs1use crate::a2a::worker::{HeartbeatState, WorkerStatus};
11use crate::cli::WorkerServerArgs;
12use anyhow::Result;
13use axum::{
14 Json, Router,
15 extract::State,
16 http::StatusCode,
17 routing::{get, post},
18};
19use serde::Serialize;
20use std::sync::Arc;
21use tokio::sync::{Mutex, mpsc};
22
23#[derive(Clone)]
25pub struct WorkerServerState {
26 pub heartbeat_state: Option<Arc<HeartbeatState>>,
28 pub connected: Arc<Mutex<bool>>,
30 pub worker_id: Arc<Mutex<Option<String>>>,
32 internal_heartbeat: Arc<Mutex<Option<Arc<HeartbeatState>>>>,
34 task_notification_tx: Arc<Mutex<Option<mpsc::Sender<String>>>>,
36}
37
38impl WorkerServerState {
39 pub fn new() -> Self {
40 Self {
41 heartbeat_state: None,
42 connected: Arc::new(Mutex::new(false)),
43 worker_id: Arc::new(Mutex::new(None)),
44 internal_heartbeat: Arc::new(Mutex::new(None)),
45 task_notification_tx: Arc::new(Mutex::new(None)),
46 }
47 }
48
49 pub async fn set_task_notification_channel(&self, tx: mpsc::Sender<String>) {
51 *self.task_notification_tx.lock().await = Some(tx);
52 }
53
54 pub async fn notify_new_task(&self, task_id: &str) {
56 if let Some(ref tx) = *self.task_notification_tx.lock().await {
57 let _ = tx.send(task_id.to_string()).await;
58 tracing::debug!("Notified worker of new task: {}", task_id);
59 }
60 }
61
62 pub fn with_heartbeat(mut self, state: Option<Arc<HeartbeatState>>) -> Self {
63 self.heartbeat_state = state.clone();
64 self
65 }
66
67 pub async fn set_heartbeat_state(&self, state: Arc<HeartbeatState>) {
69 *self.internal_heartbeat.lock().await = Some(state);
70 }
71
72 pub async fn heartbeat_state(&self) -> Arc<HeartbeatState> {
74 let guard: Option<Arc<HeartbeatState>> = self.internal_heartbeat.lock().await.clone();
75 guard.unwrap_or_else(|| {
76 let state = HeartbeatState::new("unknown".to_string(), "unknown".to_string());
78 Arc::new(state)
79 })
80 }
81
82 pub async fn set_connected(&self, connected: bool) {
83 *self.connected.lock().await = connected;
84 }
85
86 pub async fn set_worker_id(&self, worker_id: String) {
87 *self.worker_id.lock().await = Some(worker_id);
88 }
89
90 pub async fn worker_id(&self) -> String {
91 self.worker_id.lock().await.clone().unwrap_or_default()
92 }
93
94 pub async fn is_connected(&self) -> bool {
95 *self.connected.lock().await
96 }
97
98 pub async fn is_ready(&self) -> bool {
99 let connected = *self.connected.lock().await;
100 connected
103 }
104}
105
106pub async fn start_worker_server(args: WorkerServerArgs) -> Result<()> {
108 let state = WorkerServerState::new();
109 start_worker_server_with_state(args, state).await
110}
111
112pub async fn start_worker_server_with_state(
114 args: WorkerServerArgs,
115 state: WorkerServerState,
116) -> Result<()> {
117 let addr = format!("{}:{}", args.hostname, args.port);
118
119 tracing::info!("Starting worker HTTP server on http://{}", addr);
120
121 let app = Router::new()
122 .route("/health", get(health))
123 .route("/ready", get(ready))
124 .route("/task", post(receive_task))
125 .route("/worker/status", get(worker_status))
126 .with_state(state);
127
128 let listener = tokio::net::TcpListener::bind(&addr).await?;
129 tracing::info!("Worker HTTP server listening on http://{}", addr);
130
131 axum::serve(listener, app).await?;
132
133 Ok(())
134}
135
/// Liveness probe: replies with the plain-text body `ok` unconditionally.
async fn health() -> &'static str {
    const HEALTHY: &str = "ok";
    HEALTHY
}
140
141async fn ready(State(state): State<WorkerServerState>) -> (StatusCode, String) {
143 if state.is_ready().await {
144 (StatusCode::OK, "ready".to_string())
145 } else {
146 (StatusCode::SERVICE_UNAVAILABLE, "not connected".to_string())
147 }
148}
149
150async fn worker_status(State(state): State<WorkerServerState>) -> Json<WorkerStatusResponse> {
152 let connected = *state.connected.lock().await;
153 let worker_id = state.worker_id.lock().await.clone();
154 let heartbeat_state = state
155 .internal_heartbeat
156 .lock()
157 .await
158 .clone()
159 .or_else(|| state.heartbeat_state.clone());
160
161 let heartbeat_info = if let Some(ref hb_state) = heartbeat_state {
162 let status: WorkerStatus = *hb_state.status.lock().await;
163 let task_count = hb_state.active_task_count.lock().await;
164 Some(HeartbeatInfo {
165 status: status.as_str().to_string(),
166 active_tasks: *task_count,
167 agent_name: hb_state.agent_name.clone(),
168 })
169 } else {
170 None
171 };
172
173 Json(WorkerStatusResponse {
174 connected,
175 worker_id,
176 heartbeat: heartbeat_info,
177 })
178}
179
180async fn receive_task(
183 State(state): State<WorkerServerState>,
184 Json(payload): Json<serde_json::Value>,
185) -> (StatusCode, String) {
186 let task_id = payload
188 .get("task_id")
189 .or_else(|| payload.get("id"))
190 .and_then(|v| v.as_str())
191 .unwrap_or("unknown");
192
193 tracing::info!("Received task via CloudEvent: {}", task_id);
194
195 state.notify_new_task(task_id).await;
197
198 (StatusCode::ACCEPTED, format!("task {} received", task_id))
199}
200
201#[derive(Serialize)]
203struct WorkerStatusResponse {
204 connected: bool,
205 worker_id: Option<String>,
206 heartbeat: Option<HeartbeatInfo>,
207}
208
209#[derive(Serialize)]
210struct HeartbeatInfo {
211 status: String,
212 active_tasks: usize,
213 agent_name: String,
214}