Skip to main content

codetether_agent/tui/
worker_bridge.rs

//! TUI Worker Bridge - connects the TUI to the A2A server substrate.
//!
//! This module enables the TUI to:
//! - Register itself as a worker with the A2A server
//! - Send heartbeats to keep that registration alive
//! - Receive incoming tasks via the server's SSE task stream
//! - Register/deregister sub-agents (relay, autochat, spawned agents)
//! - Forward selected bus events to the server for observability
9
10use crate::a2a::worker::{
11    CognitionHeartbeatConfig, HeartbeatState, WorkerStatus, register_worker, start_heartbeat,
12};
13use crate::bus::{AgentBus, BusEnvelope, BusMessage};
14use crate::cli::auth::load_saved_credentials;
15use crate::config::Config;
16use anyhow::Result;
17use futures::StreamExt;
18use reqwest::Client;
19use serde::{Deserialize, Serialize};
20use std::collections::HashSet;
21use std::sync::Arc;
22use tokio::sync::{Mutex, mpsc};
23
/// Control message for the worker bridge's background task.
///
/// Sent through [`TuiWorkerBridge::cmd_tx`] to update what the bridge reports about this
/// worker.
#[derive(Debug, Clone)]
pub enum WorkerBridgeCmd {
    /// Start tracking a sub-agent in the heartbeat state.
    RegisterAgent {
        /// Sub-agent name to track.
        name: String,
        /// Sub-agent instructions (the bridge currently only logs their length).
        instructions: String,
    },
    /// Stop tracking a previously registered sub-agent.
    DeregisterAgent {
        /// Name given when the sub-agent was registered.
        name: String,
    },
    /// Switch the heartbeat status between processing (`true`) and idle (`false`).
    SetProcessing(bool),
}
34
35/// Incoming task from the A2A server via SSE
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct IncomingTask {
38    pub task_id: String,
39    pub message: String,
40    pub from_agent: Option<String>,
41}
42
43/// Result of worker bridge initialization
44pub struct TuiWorkerBridge {
45    /// Worker ID assigned by the server
46    pub worker_id: String,
47    /// Worker name
48    pub worker_name: String,
49    /// Sender for commands (register/deregister agents)
50    pub cmd_tx: mpsc::Sender<WorkerBridgeCmd>,
51    /// Receiver for incoming tasks
52    pub task_rx: mpsc::Receiver<IncomingTask>,
53    /// Handle for the bridge task (for shutdown)
54    #[allow(dead_code)]
55    pub handle: tokio::task::JoinHandle<()>,
56}
57
58impl TuiWorkerBridge {
59    /// Spawn the worker bridge if server credentials are available.
60    /// Returns None if no server is configured.
61    pub async fn spawn(
62        server_url: Option<String>,
63        token: Option<String>,
64        worker_name: Option<String>,
65        bus: Arc<AgentBus>,
66    ) -> Result<Option<Self>> {
67        // Try to get server URL from config if not provided
68        let server = match server_url {
69            Some(url) => url,
70            None => {
71                let config = Config::load().await?;
72                match config.a2a.server_url {
73                    Some(url) => url,
74                    None => {
75                        tracing::debug!("No A2A server configured, worker bridge disabled");
76                        return Ok(None);
77                    }
78                }
79            }
80        };
81
82        // Try to get token from saved credentials if not provided
83        let token = match token {
84            Some(t) => Some(t),
85            None => load_saved_credentials().map(|c| c.access_token),
86        };
87
88        // If no token, we can't register - but we could run in "read-only" mode
89        // For now, let's require a token for registration
90        let token = match token {
91            Some(t) => t,
92            None => {
93                tracing::debug!("No A2A token available, worker bridge disabled");
94                return Ok(None);
95            }
96        };
97
98        // Generate worker ID and name
99        let worker_id = crate::a2a::worker::generate_worker_id();
100        let worker_name = worker_name.unwrap_or_else(|| format!("tui-{}", std::process::id()));
101
102        tracing::info!(
103            worker_id = %worker_id,
104            worker_name = %worker_name,
105            server = %server,
106            "Starting TUI worker bridge"
107        );
108
109        // Create channels
110        let (cmd_tx, cmd_rx) = mpsc::channel::<WorkerBridgeCmd>(256);
111        let (task_tx, task_rx) = mpsc::channel::<IncomingTask>(64);
112
113        // Create shared state
114        let client = Client::new();
115        let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
116        let heartbeat_state = HeartbeatState::new(worker_id.clone(), worker_name.clone());
117        let cognition_config = CognitionHeartbeatConfig::from_env();
118        let codebases = vec![
119            std::env::current_dir()
120                .map(|p| p.display().to_string())
121                .unwrap_or_else(|_| ".".to_string()),
122        ];
123
124        // Spawn the main bridge task
125        let handle = tokio::spawn({
126            let server = server.clone();
127            let token = token.clone();
128            let worker_id = worker_id.clone();
129            let worker_name = worker_name.clone();
130            let client = client.clone();
131            let processing = processing.clone();
132            let heartbeat_state = heartbeat_state.clone();
133            async move {
134                let token_opt = Some(token.clone());
135                if let Err(e) = register_worker(
136                    &client,
137                    &server,
138                    &token_opt,
139                    &worker_id,
140                    &worker_name,
141                    &codebases,
142                    None,
143                )
144                .await
145                {
146                    tracing::warn!("Failed to register worker with A2A server: {}", e);
147                    // Continue anyway - we can still try to process tasks
148                }
149
150                let heartbeat_handle = start_heartbeat(
151                    client.clone(),
152                    server.clone(),
153                    token_opt.clone(),
154                    heartbeat_state.clone(),
155                    processing.clone(),
156                    cognition_config,
157                );
158
159                // Background task: connect to SSE stream for incoming tasks
160                let sse_handle = tokio::spawn({
161                    let server = server.clone();
162                    let token = token.clone();
163                    let worker_id = worker_id.clone();
164                    let worker_name = worker_name.clone();
165                    async move {
166                        loop {
167                            let url = format!(
168                                "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
169                                server,
170                                urlencoding::encode(&worker_name),
171                                urlencoding::encode(&worker_id)
172                            );
173
174                            let req = Client::new()
175                                .get(&url)
176                                .header("Accept", "text/event-stream")
177                                .header("X-Worker-ID", &worker_id)
178                                .header("X-Agent-Name", &worker_name)
179                                .bearer_auth(&token);
180
181                            match req.send().await {
182                                Ok(res) if res.status().is_success() => {
183                                    tracing::info!("Connected to A2A task stream");
184                                    let mut stream = res.bytes_stream();
185                                    let mut buffer = String::new();
186
187                                    while let Some(chunk) = stream.next().await {
188                                        match chunk {
189                                            Ok(bytes) => {
190                                                buffer.push_str(&String::from_utf8_lossy(&bytes));
191
192                                                // Process SSE events
193                                                while let Some(pos) = buffer.find("\n\n") {
194                                                    let event_str = buffer[..pos].to_string();
195                                                    buffer = buffer[pos + 2..].to_string();
196
197                                                    if let Some(data_line) = event_str
198                                                        .lines()
199                                                        .find(|l| l.starts_with("data:"))
200                                                    {
201                                                        let data = data_line
202                                                            .trim_start_matches("data:")
203                                                            .trim();
204                                                        if data.is_empty() || data == "[DONE]" {
205                                                            continue;
206                                                        }
207
208                                                        // Try to parse as task
209                                                        if let Ok(task) = serde_json::from_str::<
210                                                            serde_json::Value,
211                                                        >(
212                                                            data
213                                                        ) {
214                                                            let task_id = task
215                                                                .get("task_id")
216                                                                .or_else(|| task.get("id"))
217                                                                .and_then(|v| v.as_str())
218                                                                .unwrap_or("unknown")
219                                                                .to_string();
220
221                                                            let message = task
222                                                                .get("message")
223                                                                .or_else(|| task.get("text"))
224                                                                .and_then(|v| v.as_str())
225                                                                .unwrap_or("")
226                                                                .to_string();
227
228                                                            let from_agent = task
229                                                                .get("from_agent")
230                                                                .or_else(|| task.get("agent"))
231                                                                .and_then(|v| v.as_str())
232                                                                .map(String::from);
233
234                                                            let incoming = IncomingTask {
235                                                                task_id,
236                                                                message,
237                                                                from_agent,
238                                                            };
239
240                                                            if task_tx.send(incoming).await.is_err()
241                                                            {
242                                                                tracing::warn!(
243                                                                    "Task receiver dropped, stopping SSE stream"
244                                                                );
245                                                                return;
246                                                            }
247                                                        }
248                                                    }
249                                                }
250                                            }
251                                            Err(e) => {
252                                                tracing::warn!("SSE stream error: {}", e);
253                                                break;
254                                            }
255                                        }
256                                    }
257                                }
258                                Ok(res) => {
259                                    tracing::warn!(
260                                        "Failed to connect to task stream: {}",
261                                        res.status()
262                                    );
263                                }
264                                Err(e) => {
265                                    tracing::warn!("Failed to connect to task stream: {}", e);
266                                }
267                            }
268
269                            // Reconnect after delay
270                            tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
271                        }
272                    }
273                });
274
275                // Handle commands and bus events
276                let mut cmd_rx = cmd_rx;
277                let mut bus_handle = bus.handle(&worker_id);
278
279                loop {
280                    tokio::select! {
281                        // Handle command to register/deregister agents
282                        Some(cmd) = cmd_rx.recv() => {
283                            match cmd {
284                                WorkerBridgeCmd::RegisterAgent { name, instructions } => {
285                                    tracing::info!(agent = %name, "Registering sub-agent with A2A server");
286                                    tracing::debug!(
287                                        agent = %name,
288                                        instructions_len = instructions.len(),
289                                        "Tracking sub-agent in heartbeat state"
290                                    );
291                                    heartbeat_state.register_sub_agent(name).await;
292                                }
293                                WorkerBridgeCmd::DeregisterAgent { name } => {
294                                    tracing::info!(agent = %name, "Deregistering sub-agent from A2A server");
295                                    heartbeat_state.deregister_sub_agent(&name).await;
296                                }
297                                WorkerBridgeCmd::SetProcessing(processing) => {
298                                    let status = if processing {
299                                        WorkerStatus::Processing
300                                    } else {
301                                        WorkerStatus::Idle
302                                    };
303                                    heartbeat_state.set_status(status).await;
304                                }
305                            }
306                        }
307                        // Handle bus events - forward to server
308                        Some(envelope) = bus_handle.recv() => {
309                            // Forward interesting events to server for observability
310                            if let Err(e) = forward_bus_event(&client, &server, &token, &worker_id, &envelope).await {
311                                tracing::debug!("Failed to forward bus event: {}", e);
312                            }
313                        }
314                        // Handle shutdown
315                        _ = tokio::signal::ctrl_c() => {
316                            tracing::info!("Worker bridge received shutdown signal");
317                            break;
318                        }
319                    }
320                }
321
322                // Cleanup
323                heartbeat_handle.abort();
324                sse_handle.abort();
325                tracing::info!("Worker bridge stopped");
326            }
327        });
328
329        Ok(Some(TuiWorkerBridge {
330            worker_id,
331            worker_name,
332            cmd_tx,
333            task_rx,
334            handle,
335        }))
336    }
337}
338
339/// Forward bus events to the A2A server for observability
340async fn forward_bus_event(
341    client: &Client,
342    server: &str,
343    token: &str,
344    worker_id: &str,
345    envelope: &BusEnvelope,
346) -> Result<()> {
347    // Only forward certain event types
348    let payload = match &envelope.message {
349        BusMessage::AgentReady {
350            agent_id,
351            capabilities,
352        } => {
353            serde_json::json!({
354                "type": "agent_ready",
355                "worker_id": worker_id,
356                "agent_id": agent_id,
357                "capabilities": capabilities,
358            })
359        }
360        BusMessage::TaskUpdate {
361            task_id,
362            state,
363            message,
364        } => {
365            serde_json::json!({
366                "type": "task_update",
367                "worker_id": worker_id,
368                "task_id": task_id,
369                "state": format!("{:?}", state),
370                "message": message,
371            })
372        }
373        BusMessage::AgentMessage { from, to, parts } => {
374            let text = parts
375                .iter()
376                .filter_map(|p| {
377                    if let crate::a2a::types::Part::Text { text } = p {
378                        Some(text.clone())
379                    } else {
380                        None
381                    }
382                })
383                .collect::<Vec<_>>()
384                .join("\n");
385
386            serde_json::json!({
387                "type": "agent_message",
388                "worker_id": worker_id,
389                "from": from,
390                "to": to,
391                "text": text,
392            })
393        }
394        // Skip other event types for now
395        _ => return Ok(()),
396    };
397
398    let url = format!("{}/v1/agent/workers/{}/events", server, worker_id);
399    let res = client
400        .post(&url)
401        .bearer_auth(token)
402        .json(&payload)
403        .send()
404        .await?;
405
406    if !res.status().is_success() {
407        tracing::debug!("Failed to forward bus event: {}", res.status());
408    }
409
410    Ok(())
411}