codetether_agent/tui/worker_bridge.rs
1//! TUI Worker Bridge - connects the TUI to the A2A server substrate.
2//!
3//! This module enables the TUI to:
4//! - Register itself as a worker with the A2A server
5//! - Send heartbeats to maintain registration
6//! - Receive incoming tasks via SSE stream
7//! - Register/deregister sub-agents (relay, autochat, spawned agents)
8//! - Forward bus events to the server for observability
9
10use crate::a2a::worker::{
11 CognitionHeartbeatConfig, HeartbeatState, WorkerStatus, register_worker, start_heartbeat,
12};
13use crate::bus::{AgentBus, BusEnvelope, BusMessage};
14use crate::cli::auth::load_saved_credentials;
15use crate::config::Config;
16use anyhow::Result;
17use futures::StreamExt;
18use reqwest::Client;
19use serde::{Deserialize, Serialize};
20use std::collections::HashSet;
21use std::sync::Arc;
22use tokio::sync::{Mutex, mpsc};
23
/// Control message for the background worker-bridge task.
///
/// These messages update the sub-agents and the status that the bridge
/// reports to the A2A server through its heartbeat state.
#[derive(Clone, Debug)]
pub enum WorkerBridgeCmd {
    /// Start tracking the sub-agent called `name` under this worker.
    RegisterAgent { name: String, instructions: String },
    /// Stop tracking the sub-agent called `name`.
    DeregisterAgent { name: String },
    /// Report the worker as processing (`true`) or idle (`false`).
    SetProcessing(bool),
}
34
35/// Incoming task from the A2A server via SSE
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct IncomingTask {
38 pub task_id: String,
39 pub message: String,
40 pub from_agent: Option<String>,
41}
42
43/// Result of worker bridge initialization
44pub struct TuiWorkerBridge {
45 /// Worker ID assigned by the server
46 pub worker_id: String,
47 /// Worker name
48 pub worker_name: String,
49 /// Sender for commands (register/deregister agents)
50 pub cmd_tx: mpsc::Sender<WorkerBridgeCmd>,
51 /// Receiver for incoming tasks
52 pub task_rx: mpsc::Receiver<IncomingTask>,
53 /// Handle for the bridge task (for shutdown)
54 #[allow(dead_code)]
55 pub handle: tokio::task::JoinHandle<()>,
56}
57
58impl TuiWorkerBridge {
59 /// Spawn the worker bridge if server credentials are available.
60 /// Returns None if no server is configured.
61 pub async fn spawn(
62 server_url: Option<String>,
63 token: Option<String>,
64 worker_name: Option<String>,
65 bus: Arc<AgentBus>,
66 ) -> Result<Option<Self>> {
67 // Try to get server URL from config if not provided
68 let server = match server_url {
69 Some(url) => url,
70 None => {
71 let config = Config::load().await?;
72 match config.a2a.server_url {
73 Some(url) => url,
74 None => {
75 tracing::debug!("No A2A server configured, worker bridge disabled");
76 return Ok(None);
77 }
78 }
79 }
80 };
81
82 // Try to get token from saved credentials if not provided
83 let token = match token {
84 Some(t) => Some(t),
85 None => load_saved_credentials().map(|c| c.access_token),
86 };
87
88 // If no token, we can't register - but we could run in "read-only" mode
89 // For now, let's require a token for registration
90 let token = match token {
91 Some(t) => t,
92 None => {
93 tracing::debug!("No A2A token available, worker bridge disabled");
94 return Ok(None);
95 }
96 };
97
98 // Generate worker ID and name
99 let worker_id = crate::a2a::worker::generate_worker_id();
100 let worker_name = worker_name.unwrap_or_else(|| format!("tui-{}", std::process::id()));
101
102 tracing::info!(
103 worker_id = %worker_id,
104 worker_name = %worker_name,
105 server = %server,
106 "Starting TUI worker bridge"
107 );
108
109 // Create channels
110 let (cmd_tx, cmd_rx) = mpsc::channel::<WorkerBridgeCmd>(256);
111 let (task_tx, task_rx) = mpsc::channel::<IncomingTask>(64);
112
113 // Create shared state
114 let client = Client::new();
115 let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
116 let heartbeat_state = HeartbeatState::new(worker_id.clone(), worker_name.clone());
117 let cognition_config = CognitionHeartbeatConfig::from_env();
118 let codebases = vec![
119 std::env::current_dir()
120 .map(|p| p.display().to_string())
121 .unwrap_or_else(|_| ".".to_string()),
122 ];
123
124 // Spawn the main bridge task
125 let handle = tokio::spawn({
126 let server = server.clone();
127 let token = token.clone();
128 let worker_id = worker_id.clone();
129 let worker_name = worker_name.clone();
130 let client = client.clone();
131 let processing = processing.clone();
132 let heartbeat_state = heartbeat_state.clone();
133 async move {
134 let token_opt = Some(token.clone());
135 if let Err(e) = register_worker(
136 &client,
137 &server,
138 &token_opt,
139 &worker_id,
140 &worker_name,
141 &codebases,
142 None,
143 )
144 .await
145 {
146 tracing::warn!("Failed to register worker with A2A server: {}", e);
147 // Continue anyway - we can still try to process tasks
148 }
149
150 let heartbeat_handle = start_heartbeat(
151 client.clone(),
152 server.clone(),
153 token_opt.clone(),
154 heartbeat_state.clone(),
155 processing.clone(),
156 cognition_config,
157 );
158
159 // Background task: connect to SSE stream for incoming tasks
160 let sse_handle = tokio::spawn({
161 let server = server.clone();
162 let token = token.clone();
163 let worker_id = worker_id.clone();
164 let worker_name = worker_name.clone();
165 async move {
166 loop {
167 let url = format!(
168 "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
169 server,
170 urlencoding::encode(&worker_name),
171 urlencoding::encode(&worker_id)
172 );
173
174 let req = Client::new()
175 .get(&url)
176 .header("Accept", "text/event-stream")
177 .header("X-Worker-ID", &worker_id)
178 .header("X-Agent-Name", &worker_name)
179 .bearer_auth(&token);
180
181 match req.send().await {
182 Ok(res) if res.status().is_success() => {
183 tracing::info!("Connected to A2A task stream");
184 let mut stream = res.bytes_stream();
185 let mut buffer = String::new();
186
187 while let Some(chunk) = stream.next().await {
188 match chunk {
189 Ok(bytes) => {
190 buffer.push_str(&String::from_utf8_lossy(&bytes));
191
192 // Process SSE events
193 while let Some(pos) = buffer.find("\n\n") {
194 let event_str = buffer[..pos].to_string();
195 buffer = buffer[pos + 2..].to_string();
196
197 if let Some(data_line) = event_str
198 .lines()
199 .find(|l| l.starts_with("data:"))
200 {
201 let data = data_line
202 .trim_start_matches("data:")
203 .trim();
204 if data.is_empty() || data == "[DONE]" {
205 continue;
206 }
207
208 // Try to parse as task
209 if let Ok(task) = serde_json::from_str::<
210 serde_json::Value,
211 >(
212 data
213 ) {
214 let task_id = task
215 .get("task_id")
216 .or_else(|| task.get("id"))
217 .and_then(|v| v.as_str())
218 .unwrap_or("unknown")
219 .to_string();
220
221 let message = task
222 .get("message")
223 .or_else(|| task.get("text"))
224 .and_then(|v| v.as_str())
225 .unwrap_or("")
226 .to_string();
227
228 let from_agent = task
229 .get("from_agent")
230 .or_else(|| task.get("agent"))
231 .and_then(|v| v.as_str())
232 .map(String::from);
233
234 let incoming = IncomingTask {
235 task_id,
236 message,
237 from_agent,
238 };
239
240 if task_tx.send(incoming).await.is_err()
241 {
242 tracing::warn!(
243 "Task receiver dropped, stopping SSE stream"
244 );
245 return;
246 }
247 }
248 }
249 }
250 }
251 Err(e) => {
252 tracing::warn!("SSE stream error: {}", e);
253 break;
254 }
255 }
256 }
257 }
258 Ok(res) => {
259 tracing::warn!(
260 "Failed to connect to task stream: {}",
261 res.status()
262 );
263 }
264 Err(e) => {
265 tracing::warn!("Failed to connect to task stream: {}", e);
266 }
267 }
268
269 // Reconnect after delay
270 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
271 }
272 }
273 });
274
275 // Handle commands and bus events
276 let mut cmd_rx = cmd_rx;
277 let mut bus_handle = bus.handle(&worker_id);
278
279 loop {
280 tokio::select! {
281 // Handle command to register/deregister agents
282 Some(cmd) = cmd_rx.recv() => {
283 match cmd {
284 WorkerBridgeCmd::RegisterAgent { name, instructions } => {
285 tracing::info!(agent = %name, "Registering sub-agent with A2A server");
286 tracing::debug!(
287 agent = %name,
288 instructions_len = instructions.len(),
289 "Tracking sub-agent in heartbeat state"
290 );
291 heartbeat_state.register_sub_agent(name).await;
292 }
293 WorkerBridgeCmd::DeregisterAgent { name } => {
294 tracing::info!(agent = %name, "Deregistering sub-agent from A2A server");
295 heartbeat_state.deregister_sub_agent(&name).await;
296 }
297 WorkerBridgeCmd::SetProcessing(processing) => {
298 let status = if processing {
299 WorkerStatus::Processing
300 } else {
301 WorkerStatus::Idle
302 };
303 heartbeat_state.set_status(status).await;
304 }
305 }
306 }
307 // Handle bus events - forward to server
308 Some(envelope) = bus_handle.recv() => {
309 // Forward interesting events to server for observability
310 if let Err(e) = forward_bus_event(&client, &server, &token, &worker_id, &envelope).await {
311 tracing::debug!("Failed to forward bus event: {}", e);
312 }
313 }
314 // Handle shutdown
315 _ = tokio::signal::ctrl_c() => {
316 tracing::info!("Worker bridge received shutdown signal");
317 break;
318 }
319 }
320 }
321
322 // Cleanup
323 heartbeat_handle.abort();
324 sse_handle.abort();
325 tracing::info!("Worker bridge stopped");
326 }
327 });
328
329 Ok(Some(TuiWorkerBridge {
330 worker_id,
331 worker_name,
332 cmd_tx,
333 task_rx,
334 handle,
335 }))
336 }
337}
338
339/// Forward bus events to the A2A server for observability
340async fn forward_bus_event(
341 client: &Client,
342 server: &str,
343 token: &str,
344 worker_id: &str,
345 envelope: &BusEnvelope,
346) -> Result<()> {
347 // Only forward certain event types
348 let payload = match &envelope.message {
349 BusMessage::AgentReady {
350 agent_id,
351 capabilities,
352 } => {
353 serde_json::json!({
354 "type": "agent_ready",
355 "worker_id": worker_id,
356 "agent_id": agent_id,
357 "capabilities": capabilities,
358 })
359 }
360 BusMessage::TaskUpdate {
361 task_id,
362 state,
363 message,
364 } => {
365 serde_json::json!({
366 "type": "task_update",
367 "worker_id": worker_id,
368 "task_id": task_id,
369 "state": format!("{:?}", state),
370 "message": message,
371 })
372 }
373 BusMessage::AgentMessage { from, to, parts } => {
374 let text = parts
375 .iter()
376 .filter_map(|p| {
377 if let crate::a2a::types::Part::Text { text } = p {
378 Some(text.clone())
379 } else {
380 None
381 }
382 })
383 .collect::<Vec<_>>()
384 .join("\n");
385
386 serde_json::json!({
387 "type": "agent_message",
388 "worker_id": worker_id,
389 "from": from,
390 "to": to,
391 "text": text,
392 })
393 }
394 // Skip other event types for now
395 _ => return Ok(()),
396 };
397
398 let url = format!("{}/v1/agent/workers/{}/events", server, worker_id);
399 let res = client
400 .post(&url)
401 .bearer_auth(token)
402 .json(&payload)
403 .send()
404 .await?;
405
406 if !res.status().is_success() {
407 tracing::debug!("Failed to forward bus event: {}", res.status());
408 }
409
410 Ok(())
411}