Skip to main content

codetether_agent/tui/
worker_bridge.rs

1//! TUI Worker Bridge - connects the TUI to the A2A server substrate.
2//!
3//! This module enables the TUI to:
4//! - Register itself as a worker with the A2A server
5//! - Send heartbeats to maintain registration
6//! - Receive incoming tasks via SSE stream
7//! - Register/deregister sub-agents (relay, autochat, spawned agents)
8//! - Forward bus events to the server for observability
9
10use crate::a2a::worker::{
11    CognitionHeartbeatConfig, HeartbeatState, WorkerStatus, register_worker, start_heartbeat,
12};
13use crate::bus::{AgentBus, BusEnvelope, BusMessage};
14use crate::cli::auth::load_saved_credentials;
15use crate::config::Config;
16use anyhow::Result;
17use futures::StreamExt;
18use reqwest::Client;
19use serde::{Deserialize, Serialize};
20use std::collections::HashSet;
21use std::sync::Arc;
22use tokio::sync::{Mutex, mpsc};
23
/// Commands accepted by the worker bridge task over its command channel.
///
/// The bridge loop handles each variant by updating the heartbeat state it
/// owns; see the individual variants for details.
#[derive(Debug, Clone)]
pub enum WorkerBridgeCmd {
    /// Track a sub-agent under this worker.
    ///
    /// The bridge records `name` in the heartbeat state; `instructions` is
    /// currently only used for a debug log of its length.
    RegisterAgent {
        name: String,
        instructions: String,
    },
    /// Stop tracking the sub-agent called `name` in the heartbeat state.
    DeregisterAgent {
        name: String,
    },
    /// Report whether the worker is busy: `true` sets the heartbeat status to
    /// `Processing`, `false` sets it to `Idle`.
    SetProcessing(bool),
}
34
35/// Incoming task from the A2A server via SSE
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct IncomingTask {
38    pub task_id: String,
39    pub message: String,
40    pub from_agent: Option<String>,
41}
42
43/// Result of worker bridge initialization
44pub struct TuiWorkerBridge {
45    /// Worker ID assigned by the server
46    pub worker_id: String,
47    /// Worker name
48    pub worker_name: String,
49    /// Sender for commands (register/deregister agents)
50    pub cmd_tx: mpsc::Sender<WorkerBridgeCmd>,
51    /// Receiver for incoming tasks
52    pub task_rx: mpsc::Receiver<IncomingTask>,
53    /// Handle for the bridge task (for shutdown)
54    #[allow(dead_code)]
55    pub handle: tokio::task::JoinHandle<()>,
56}
57
58impl TuiWorkerBridge {
59    /// Spawn the worker bridge if server credentials are available.
60    /// Returns None if no server is configured.
61    pub async fn spawn(
62        server_url: Option<String>,
63        token: Option<String>,
64        worker_name: Option<String>,
65        bus: Arc<AgentBus>,
66    ) -> Result<Option<Self>> {
67        // Try to get server URL from config if not provided
68        let server = match server_url {
69            Some(url) => url,
70            None => {
71                let config = Config::load().await?;
72                match config.a2a.server_url {
73                    Some(url) => url,
74                    None => {
75                        tracing::debug!("No A2A server configured, worker bridge disabled");
76                        return Ok(None);
77                    }
78                }
79            }
80        };
81
82        // Try to get token from saved credentials if not provided
83        let token = match token {
84            Some(t) => Some(t),
85            None => load_saved_credentials().map(|c| c.access_token),
86        };
87
88        // If no token, we can't register - but we could run in "read-only" mode
89        // For now, let's require a token for registration
90        let token = match token {
91            Some(t) => t,
92            None => {
93                tracing::debug!("No A2A token available, worker bridge disabled");
94                return Ok(None);
95            }
96        };
97
98        // Generate worker ID and name
99        let worker_id = crate::a2a::worker::generate_worker_id();
100        let worker_name = worker_name.unwrap_or_else(|| format!("tui-{}", std::process::id()));
101
102        tracing::info!(
103            worker_id = %worker_id,
104            worker_name = %worker_name,
105            server = %server,
106            "Starting TUI worker bridge"
107        );
108
109        // Create channels
110        let (cmd_tx, cmd_rx) = mpsc::channel::<WorkerBridgeCmd>(256);
111        let (task_tx, task_rx) = mpsc::channel::<IncomingTask>(64);
112
113        // Create shared state
114        let client = Client::new();
115        let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
116        let heartbeat_state = HeartbeatState::new(worker_id.clone(), worker_name.clone());
117        let cognition_config = CognitionHeartbeatConfig::from_env();
118        let codebases = vec![
119            std::env::current_dir()
120                .map(|p| p.display().to_string())
121                .unwrap_or_else(|_| ".".to_string()),
122        ];
123
124        // Spawn the main bridge task
125        let handle = tokio::spawn({
126            let server = server.clone();
127            let token = token.clone();
128            let worker_id = worker_id.clone();
129            let worker_name = worker_name.clone();
130            let client = client.clone();
131            let processing = processing.clone();
132            let heartbeat_state = heartbeat_state.clone();
133            async move {
134                let token_opt = Some(token.clone());
135                if let Err(e) = register_worker(
136                    &client,
137                    &server,
138                    &token_opt,
139                    &worker_id,
140                    &worker_name,
141                    &codebases,
142                    None,
143                )
144                .await
145                {
146                    tracing::warn!("Failed to register worker with A2A server: {}", e);
147                    // Continue anyway - we can still try to process tasks
148                }
149
150                let heartbeat_handle = start_heartbeat(
151                    client.clone(),
152                    server.clone(),
153                    token_opt.clone(),
154                    heartbeat_state.clone(),
155                    processing.clone(),
156                    cognition_config,
157                    Arc::new(Mutex::new(
158                        crate::a2a::worker::task_timeline::TaskProgressState::new(),
159                    )),
160                );
161
162                // Background task: connect to SSE stream for incoming tasks
163                let sse_handle = tokio::spawn({
164                    let server = server.clone();
165                    let token = token.clone();
166                    let worker_id = worker_id.clone();
167                    let worker_name = worker_name.clone();
168                    async move {
169                        loop {
170                            let url = format!(
171                                "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
172                                server,
173                                urlencoding::encode(&worker_name),
174                                urlencoding::encode(&worker_id)
175                            );
176
177                            let req = Client::new()
178                                .get(&url)
179                                .header("Accept", "text/event-stream")
180                                .header("X-Worker-ID", &worker_id)
181                                .header("X-Agent-Name", &worker_name)
182                                .bearer_auth(&token);
183
184                            match req.send().await {
185                                Ok(res) if res.status().is_success() => {
186                                    tracing::info!("Connected to A2A task stream");
187                                    let mut stream = res.bytes_stream();
188                                    let mut buffer = String::new();
189
190                                    while let Some(chunk) = stream.next().await {
191                                        match chunk {
192                                            Ok(bytes) => {
193                                                buffer.push_str(&String::from_utf8_lossy(&bytes));
194
195                                                // Process SSE events
196                                                while let Some(pos) = buffer.find("\n\n") {
197                                                    let event_str = buffer[..pos].to_string();
198                                                    buffer = buffer[pos + 2..].to_string();
199
200                                                    if let Some(data_line) = event_str
201                                                        .lines()
202                                                        .find(|l| l.starts_with("data:"))
203                                                    {
204                                                        let data = data_line
205                                                            .trim_start_matches("data:")
206                                                            .trim();
207                                                        if data.is_empty() || data == "[DONE]" {
208                                                            continue;
209                                                        }
210
211                                                        // Try to parse as task
212                                                        if let Ok(task) = serde_json::from_str::<
213                                                            serde_json::Value,
214                                                        >(
215                                                            data
216                                                        ) {
217                                                            let task_id = task
218                                                                .get("task_id")
219                                                                .or_else(|| task.get("id"))
220                                                                .and_then(|v| v.as_str())
221                                                                .unwrap_or("unknown")
222                                                                .to_string();
223
224                                                            let message = task
225                                                                .get("message")
226                                                                .or_else(|| task.get("text"))
227                                                                .and_then(|v| v.as_str())
228                                                                .unwrap_or("")
229                                                                .to_string();
230
231                                                            let from_agent = task
232                                                                .get("from_agent")
233                                                                .or_else(|| task.get("agent"))
234                                                                .and_then(|v| v.as_str())
235                                                                .map(String::from);
236
237                                                            let incoming = IncomingTask {
238                                                                task_id,
239                                                                message,
240                                                                from_agent,
241                                                            };
242
243                                                            if task_tx.send(incoming).await.is_err()
244                                                            {
245                                                                tracing::warn!(
246                                                                    "Task receiver dropped, stopping SSE stream"
247                                                                );
248                                                                return;
249                                                            }
250                                                        }
251                                                    }
252                                                }
253                                            }
254                                            Err(e) => {
255                                                tracing::warn!("SSE stream error: {}", e);
256                                                break;
257                                            }
258                                        }
259                                    }
260                                }
261                                Ok(res) => {
262                                    tracing::warn!(
263                                        "Failed to connect to task stream: {}",
264                                        res.status()
265                                    );
266                                }
267                                Err(e) => {
268                                    tracing::warn!("Failed to connect to task stream: {}", e);
269                                }
270                            }
271
272                            // Reconnect after delay
273                            tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
274                        }
275                    }
276                });
277
278                // Handle commands and bus events
279                let mut cmd_rx = cmd_rx;
280                let mut bus_handle = bus.handle(&worker_id);
281
282                loop {
283                    tokio::select! {
284                        // Handle command to register/deregister agents
285                        Some(cmd) = cmd_rx.recv() => {
286                            match cmd {
287                                WorkerBridgeCmd::RegisterAgent { name, instructions } => {
288                                    tracing::info!(agent = %name, "Registering sub-agent with A2A server");
289                                    tracing::debug!(
290                                        agent = %name,
291                                        instructions_len = instructions.len(),
292                                        "Tracking sub-agent in heartbeat state"
293                                    );
294                                    heartbeat_state.register_sub_agent(name).await;
295                                }
296                                WorkerBridgeCmd::DeregisterAgent { name } => {
297                                    tracing::info!(agent = %name, "Deregistering sub-agent from A2A server");
298                                    heartbeat_state.deregister_sub_agent(&name).await;
299                                }
300                                WorkerBridgeCmd::SetProcessing(processing) => {
301                                    let status = if processing {
302                                        WorkerStatus::Processing
303                                    } else {
304                                        WorkerStatus::Idle
305                                    };
306                                    heartbeat_state.set_status(status).await;
307                                }
308                            }
309                        }
310                        // Handle bus events - forward to server
311                        Some(envelope) = bus_handle.recv() => {
312                            // Forward interesting events to server for observability
313                            if let Err(e) = forward_bus_event(&client, &server, &token, &worker_id, &envelope).await {
314                                tracing::debug!("Failed to forward bus event: {}", e);
315                            }
316                        }
317                        // Handle shutdown
318                        _ = tokio::signal::ctrl_c() => {
319                            tracing::info!("Worker bridge received shutdown signal");
320                            break;
321                        }
322                    }
323                }
324
325                // Cleanup
326                heartbeat_handle.abort();
327                sse_handle.abort();
328                tracing::info!("Worker bridge stopped");
329            }
330        });
331
332        Ok(Some(TuiWorkerBridge {
333            worker_id,
334            worker_name,
335            cmd_tx,
336            task_rx,
337            handle,
338        }))
339    }
340}
341
342/// Forward bus events to the A2A server for observability
343async fn forward_bus_event(
344    client: &Client,
345    server: &str,
346    token: &str,
347    worker_id: &str,
348    envelope: &BusEnvelope,
349) -> Result<()> {
350    // Only forward certain event types
351    let payload = match &envelope.message {
352        BusMessage::AgentReady {
353            agent_id,
354            capabilities,
355        } => {
356            serde_json::json!({
357                "type": "agent_ready",
358                "worker_id": worker_id,
359                "agent_id": agent_id,
360                "capabilities": capabilities,
361            })
362        }
363        BusMessage::TaskUpdate {
364            task_id,
365            state,
366            message,
367        } => {
368            serde_json::json!({
369                "type": "task_update",
370                "worker_id": worker_id,
371                "task_id": task_id,
372                "state": format!("{:?}", state),
373                "message": message,
374            })
375        }
376        BusMessage::AgentMessage { from, to, parts } => {
377            let text = parts
378                .iter()
379                .filter_map(|p| {
380                    if let crate::a2a::types::Part::Text { text } = p {
381                        Some(text.clone())
382                    } else {
383                        None
384                    }
385                })
386                .collect::<Vec<_>>()
387                .join("\n");
388
389            serde_json::json!({
390                "type": "agent_message",
391                "worker_id": worker_id,
392                "from": from,
393                "to": to,
394                "text": text,
395            })
396        }
397        // Skip other event types for now
398        _ => return Ok(()),
399    };
400
401    let url = format!("{}/v1/agent/workers/{}/events", server, worker_id);
402    let res = client
403        .post(&url)
404        .bearer_auth(token)
405        .json(&payload)
406        .send()
407        .await?;
408
409    if !res.status().is_success() {
410        tracing::debug!("Failed to forward bus event: {}", res.status());
411    }
412
413    Ok(())
414}