Skip to main content

codetether_agent/tui/
worker_bridge.rs

//! TUI Worker Bridge - connects the TUI to the A2A server substrate.
//!
//! This module enables the TUI to:
//! - Register itself as a worker with the A2A server
//! - Send heartbeats to maintain that registration
//! - Receive incoming tasks via the server's SSE task stream
//! - Accept sub-agent register/deregister commands (relay, autochat, spawned
//!   agents); these are currently only logged, not sent to the server
//! - Forward selected bus events (agent ready, task updates, agent messages)
//!   to the server for observability
9
10use crate::a2a::worker::{
11    CognitionHeartbeatConfig, HeartbeatState, WorkerStatus, register_worker, start_heartbeat,
12};
13use crate::bus::{AgentBus, BusEnvelope, BusMessage};
14use crate::cli::auth::load_saved_credentials;
15use crate::config::Config;
16use anyhow::Result;
17use futures::StreamExt;
18use reqwest::Client;
19use serde::{Deserialize, Serialize};
20use std::collections::HashSet;
21use std::sync::Arc;
22use tokio::sync::{Mutex, mpsc};
23
/// Control messages the TUI sends to a running worker bridge over its
/// command channel.
#[derive(Debug, Clone)]
pub enum WorkerBridgeCmd {
    /// Announce a sub-agent (relay, autochat, spawned agent) to the A2A
    /// server. NOTE: the bridge currently only logs this command.
    RegisterAgent {
        /// Name of the sub-agent.
        name: String,
        /// The sub-agent's instructions (not yet used by the bridge).
        instructions: String,
    },
    /// Withdraw a previously announced sub-agent (currently only logged).
    DeregisterAgent {
        /// Name of the sub-agent.
        name: String,
    },
    /// Report whether the TUI is busy: `true` sets the heartbeat status to
    /// processing, `false` sets it back to idle.
    SetProcessing(bool),
}
34
35/// Incoming task from the A2A server via SSE
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct IncomingTask {
38    pub task_id: String,
39    pub message: String,
40    pub from_agent: Option<String>,
41}
42
43/// Result of worker bridge initialization
44pub struct TuiWorkerBridge {
45    /// Worker ID assigned by the server
46    pub worker_id: String,
47    /// Worker name
48    pub worker_name: String,
49    /// Sender for commands (register/deregister agents)
50    pub cmd_tx: mpsc::Sender<WorkerBridgeCmd>,
51    /// Receiver for incoming tasks
52    pub task_rx: mpsc::Receiver<IncomingTask>,
53    /// Handle for the bridge task (for shutdown)
54    pub handle: tokio::task::JoinHandle<()>,
55}
56
57impl TuiWorkerBridge {
58    /// Spawn the worker bridge if server credentials are available.
59    /// Returns None if no server is configured.
60    pub async fn spawn(
61        server_url: Option<String>,
62        token: Option<String>,
63        worker_name: Option<String>,
64        bus: Arc<AgentBus>,
65    ) -> Result<Option<Self>> {
66        // Try to get server URL from config if not provided
67        let server = match server_url {
68            Some(url) => url,
69            None => {
70                let config = Config::load().await?;
71                match config.a2a.server_url {
72                    Some(url) => url,
73                    None => {
74                        tracing::debug!("No A2A server configured, worker bridge disabled");
75                        return Ok(None);
76                    }
77                }
78            }
79        };
80
81        // Try to get token from saved credentials if not provided
82        let token = match token {
83            Some(t) => Some(t),
84            None => load_saved_credentials().map(|c| c.access_token),
85        };
86
87        // If no token, we can't register - but we could run in "read-only" mode
88        // For now, let's require a token for registration
89        let token = match token {
90            Some(t) => t,
91            None => {
92                tracing::debug!("No A2A token available, worker bridge disabled");
93                return Ok(None);
94            }
95        };
96
97        // Generate worker ID and name
98        let worker_id = crate::a2a::worker::generate_worker_id();
99        let worker_name = worker_name.unwrap_or_else(|| format!("tui-{}", std::process::id()));
100
101        tracing::info!(
102            worker_id = %worker_id,
103            worker_name = %worker_name,
104            server = %server,
105            "Starting TUI worker bridge"
106        );
107
108        // Create channels
109        let (cmd_tx, cmd_rx) = mpsc::channel::<WorkerBridgeCmd>(32);
110        let (task_tx, task_rx) = mpsc::channel::<IncomingTask>(64);
111
112        // Create shared state
113        let client = Client::new();
114        let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
115        let heartbeat_state = HeartbeatState::new(worker_id.clone(), worker_name.clone());
116        let cognition_config = CognitionHeartbeatConfig::from_env();
117        let codebases = vec![
118            std::env::current_dir()
119                .map(|p| p.display().to_string())
120                .unwrap_or_else(|_| ".".to_string()),
121        ];
122
123        // Register with the server
124        let token_opt = Some(token.clone());
125        if let Err(e) = register_worker(
126            &client,
127            &server,
128            &token_opt,
129            &worker_id,
130            &worker_name,
131            &codebases,
132        )
133        .await
134        {
135            tracing::warn!("Failed to register worker with A2A server: {}", e);
136            // Continue anyway - we can still try to process tasks
137        }
138
139        // Start heartbeat
140        let heartbeat_handle = start_heartbeat(
141            client.clone(),
142            server.clone(),
143            token_opt.clone(),
144            heartbeat_state.clone(),
145            processing.clone(),
146            cognition_config,
147        );
148
149        // Spawn the main bridge task
150        let handle = tokio::spawn({
151            let server = server.clone();
152            let token = token.clone();
153            let worker_id = worker_id.clone();
154            let worker_name = worker_name.clone();
155            async move {
156                // Background task: connect to SSE stream for incoming tasks
157                let sse_handle = tokio::spawn({
158                    let server = server.clone();
159                    let token = token.clone();
160                    let worker_id = worker_id.clone();
161                    let worker_name = worker_name.clone();
162                    async move {
163                        loop {
164                            let url = format!(
165                                "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
166                                server,
167                                urlencoding::encode(&worker_name),
168                                urlencoding::encode(&worker_id)
169                            );
170
171                            let req = Client::new()
172                                .get(&url)
173                                .header("Accept", "text/event-stream")
174                                .header("X-Worker-ID", &worker_id)
175                                .header("X-Agent-Name", &worker_name)
176                                .bearer_auth(&token);
177
178                            match req.send().await {
179                                Ok(res) if res.status().is_success() => {
180                                    tracing::info!("Connected to A2A task stream");
181                                    let mut stream = res.bytes_stream();
182                                    let mut buffer = String::new();
183
184                                    while let Some(chunk) = stream.next().await {
185                                        match chunk {
186                                            Ok(bytes) => {
187                                                buffer.push_str(&String::from_utf8_lossy(&bytes));
188
189                                                // Process SSE events
190                                                while let Some(pos) = buffer.find("\n\n") {
191                                                    let event_str = buffer[..pos].to_string();
192                                                    buffer = buffer[pos + 2..].to_string();
193
194                                                    if let Some(data_line) = event_str
195                                                        .lines()
196                                                        .find(|l| l.starts_with("data:"))
197                                                    {
198                                                        let data = data_line
199                                                            .trim_start_matches("data:")
200                                                            .trim();
201                                                        if data.is_empty() || data == "[DONE]" {
202                                                            continue;
203                                                        }
204
205                                                        // Try to parse as task
206                                                        if let Ok(task) = serde_json::from_str::<
207                                                            serde_json::Value,
208                                                        >(
209                                                            data
210                                                        ) {
211                                                            let task_id = task
212                                                                .get("task_id")
213                                                                .or_else(|| task.get("id"))
214                                                                .and_then(|v| v.as_str())
215                                                                .unwrap_or("unknown")
216                                                                .to_string();
217
218                                                            let message = task
219                                                                .get("message")
220                                                                .or_else(|| task.get("text"))
221                                                                .and_then(|v| v.as_str())
222                                                                .unwrap_or("")
223                                                                .to_string();
224
225                                                            let from_agent = task
226                                                                .get("from_agent")
227                                                                .or_else(|| task.get("agent"))
228                                                                .and_then(|v| v.as_str())
229                                                                .map(String::from);
230
231                                                            let incoming = IncomingTask {
232                                                                task_id,
233                                                                message,
234                                                                from_agent,
235                                                            };
236
237                                                            if task_tx.send(incoming).await.is_err()
238                                                            {
239                                                                tracing::warn!(
240                                                                    "Task receiver dropped, stopping SSE stream"
241                                                                );
242                                                                return;
243                                                            }
244                                                        }
245                                                    }
246                                                }
247                                            }
248                                            Err(e) => {
249                                                tracing::warn!("SSE stream error: {}", e);
250                                                break;
251                                            }
252                                        }
253                                    }
254                                }
255                                Ok(res) => {
256                                    tracing::warn!(
257                                        "Failed to connect to task stream: {}",
258                                        res.status()
259                                    );
260                                }
261                                Err(e) => {
262                                    tracing::warn!("Failed to connect to task stream: {}", e);
263                                }
264                            }
265
266                            // Reconnect after delay
267                            tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
268                        }
269                    }
270                });
271
272                // Handle commands and bus events
273                let mut cmd_rx = cmd_rx;
274                let mut bus_handle = bus.handle(&worker_id);
275
276                loop {
277                    tokio::select! {
278                        // Handle command to register/deregister agents
279                        Some(cmd) = cmd_rx.recv() => {
280                            match cmd {
281                                WorkerBridgeCmd::RegisterAgent { name, instructions: _ } => {
282                                    tracing::info!(agent = %name, "Registering sub-agent with A2A server");
283                                    // For now, just log - full implementation would POST to server
284                                    // The server tracks sub-agents via the heartbeat's agent list
285                                }
286                                WorkerBridgeCmd::DeregisterAgent { name } => {
287                                    tracing::info!(agent = %name, "Deregistering sub-agent from A2A server");
288                                }
289                                WorkerBridgeCmd::SetProcessing(processing) => {
290                                    let status = if processing {
291                                        WorkerStatus::Processing
292                                    } else {
293                                        WorkerStatus::Idle
294                                    };
295                                    heartbeat_state.set_status(status).await;
296                                }
297                            }
298                        }
299                        // Handle bus events - forward to server
300                        Some(envelope) = bus_handle.recv() => {
301                            // Forward interesting events to server for observability
302                            if let Err(e) = forward_bus_event(&client, &server, &token, &worker_id, &envelope).await {
303                                tracing::debug!("Failed to forward bus event: {}", e);
304                            }
305                        }
306                        // Handle shutdown
307                        _ = tokio::signal::ctrl_c() => {
308                            tracing::info!("Worker bridge received shutdown signal");
309                            break;
310                        }
311                    }
312                }
313
314                // Cleanup
315                heartbeat_handle.abort();
316                sse_handle.abort();
317                tracing::info!("Worker bridge stopped");
318            }
319        });
320
321        Ok(Some(TuiWorkerBridge {
322            worker_id,
323            worker_name,
324            cmd_tx,
325            task_rx,
326            handle,
327        }))
328    }
329}
330
331/// Forward bus events to the A2A server for observability
332async fn forward_bus_event(
333    client: &Client,
334    server: &str,
335    token: &str,
336    worker_id: &str,
337    envelope: &BusEnvelope,
338) -> Result<()> {
339    // Only forward certain event types
340    let payload = match &envelope.message {
341        BusMessage::AgentReady {
342            agent_id,
343            capabilities,
344        } => {
345            serde_json::json!({
346                "type": "agent_ready",
347                "worker_id": worker_id,
348                "agent_id": agent_id,
349                "capabilities": capabilities,
350            })
351        }
352        BusMessage::TaskUpdate {
353            task_id,
354            state,
355            message,
356        } => {
357            serde_json::json!({
358                "type": "task_update",
359                "worker_id": worker_id,
360                "task_id": task_id,
361                "state": format!("{:?}", state),
362                "message": message,
363            })
364        }
365        BusMessage::AgentMessage { from, to, parts } => {
366            let text = parts
367                .iter()
368                .filter_map(|p| {
369                    if let crate::a2a::types::Part::Text { text } = p {
370                        Some(text.clone())
371                    } else {
372                        None
373                    }
374                })
375                .collect::<Vec<_>>()
376                .join("\n");
377
378            serde_json::json!({
379                "type": "agent_message",
380                "worker_id": worker_id,
381                "from": from,
382                "to": to,
383                "text": text,
384            })
385        }
386        // Skip other event types for now
387        _ => return Ok(()),
388    };
389
390    let url = format!("{}/v1/agent/workers/{}/events", server, worker_id);
391    let res = client
392        .post(&url)
393        .bearer_auth(token)
394        .json(&payload)
395        .send()
396        .await?;
397
398    if !res.status().is_success() {
399        tracing::debug!("Failed to forward bus event: {}", res.status());
400    }
401
402    Ok(())
403}