Skip to main content

codetether_agent/a2a/
worker.rs

1//! A2A Worker - connects to an A2A server to process tasks
2
3use crate::cli::A2aArgs;
4use crate::session::Session;
5use anyhow::Result;
6use futures::StreamExt;
7use reqwest::Client;
8use std::collections::HashSet;
9use std::sync::Arc;
10use tokio::sync::Mutex;
11
12/// Run the A2A worker
13pub async fn run(args: A2aArgs) -> Result<()> {
14    let server = args.server.trim_end_matches('/');
15    let name = args.name.unwrap_or_else(|| format!("codetether-{}", std::process::id()));
16    let worker_id = generate_worker_id();
17    
18    let codebases: Vec<String> = args
19        .codebases
20        .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
21        .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);
22
23    tracing::info!(
24        "Starting A2A worker: {} ({})",
25        name,
26        worker_id
27    );
28    tracing::info!("Server: {}", server);
29    tracing::info!("Codebases: {:?}", codebases);
30
31    let client = Client::new();
32    let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
33    
34    let auto_approve = match args.auto_approve.as_str() {
35        "all" => AutoApprove::All,
36        "safe" => AutoApprove::Safe,
37        _ => AutoApprove::None,
38    };
39
40    // Register worker
41    register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;
42
43    // Fetch pending tasks
44    fetch_pending_tasks(&client, server, &args.token, &worker_id, &processing, &auto_approve).await?;
45
46    // Connect to SSE stream
47    loop {
48        match connect_stream(&client, server, &args.token, &worker_id, &name, &codebases, &processing, &auto_approve).await {
49            Ok(()) => {
50                tracing::warn!("Stream ended, reconnecting...");
51            }
52            Err(e) => {
53                tracing::error!("Stream error: {}, reconnecting...", e);
54            }
55        }
56        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
57    }
58}
59
60fn generate_worker_id() -> String {
61    format!(
62        "wrk_{}_{:x}",
63        chrono::Utc::now().timestamp(),
64        rand::random::<u64>()
65    )
66}
67
/// Auto-approval mode chosen from the `auto_approve` command-line value.
///
/// `run` maps `"all"` and `"safe"` to the matching variants and every other
/// value to [`AutoApprove::None`]. The mode is passed to task handling but is
/// not consulted yet.
#[derive(Clone, Copy, Debug)]
enum AutoApprove {
    /// Selected by `"all"`.
    All,
    /// Selected by `"safe"`.
    Safe,
    /// Selected by any other value.
    None,
}
74
75async fn register_worker(
76    client: &Client,
77    server: &str,
78    token: &Option<String>,
79    worker_id: &str,
80    name: &str,
81    codebases: &[String],
82) -> Result<()> {
83    let mut req = client.put(format!("{}/v1/worker/codebases", server));
84    
85    if let Some(t) = token {
86        req = req.bearer_auth(t);
87    }
88
89    let res = req
90        .json(&serde_json::json!({
91            "codebases": codebases,
92            "worker_id": worker_id,
93            "agent_name": name,
94        }))
95        .send()
96        .await?;
97
98    if res.status().is_success() {
99        tracing::info!("Worker registered successfully");
100    } else {
101        tracing::warn!("Failed to register worker: {}", res.status());
102    }
103
104    Ok(())
105}
106
107async fn fetch_pending_tasks(
108    client: &Client,
109    server: &str,
110    token: &Option<String>,
111    worker_id: &str,
112    processing: &Arc<Mutex<HashSet<String>>>,
113    auto_approve: &AutoApprove,
114) -> Result<()> {
115    tracing::info!("Checking for pending tasks...");
116    
117    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
118    if let Some(t) = token {
119        req = req.bearer_auth(t);
120    }
121
122    let res = req.send().await?;
123    if !res.status().is_success() {
124        return Ok(());
125    }
126
127    let data: serde_json::Value = res.json().await?;
128    let tasks = data["tasks"].as_array().cloned().unwrap_or_default();
129    
130    tracing::info!("Found {} pending task(s)", tasks.len());
131
132    for task in tasks {
133        if let Some(id) = task["id"].as_str() {
134            let mut proc = processing.lock().await;
135            if !proc.contains(id) {
136                proc.insert(id.to_string());
137                drop(proc);
138                
139                let task_id = id.to_string();
140                let client = client.clone();
141                let server = server.to_string();
142                let token = token.clone();
143                let worker_id = worker_id.to_string();
144                let auto_approve = *auto_approve;
145                let processing = processing.clone();
146                
147                tokio::spawn(async move {
148                    if let Err(e) = handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await {
149                        tracing::error!("Task {} failed: {}", task_id, e);
150                    }
151                    processing.lock().await.remove(&task_id);
152                });
153            }
154        }
155    }
156
157    Ok(())
158}
159
160#[allow(clippy::too_many_arguments)]
161async fn connect_stream(
162    client: &Client,
163    server: &str,
164    token: &Option<String>,
165    worker_id: &str,
166    name: &str,
167    codebases: &[String],
168    processing: &Arc<Mutex<HashSet<String>>>,
169    auto_approve: &AutoApprove,
170) -> Result<()> {
171    let url = format!(
172        "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
173        server,
174        urlencoding::encode(name),
175        urlencoding::encode(worker_id)
176    );
177
178    let mut req = client.get(&url)
179        .header("Accept", "text/event-stream")
180        .header("X-Worker-ID", worker_id)
181        .header("X-Agent-Name", name)
182        .header("X-Codebases", codebases.join(","));
183
184    if let Some(t) = token {
185        req = req.bearer_auth(t);
186    }
187
188    let res = req.send().await?;
189    if !res.status().is_success() {
190        anyhow::bail!("Failed to connect: {}", res.status());
191    }
192
193    tracing::info!("Connected to A2A server");
194
195    let mut stream = res.bytes_stream();
196    let mut buffer = String::new();
197
198    while let Some(chunk) = stream.next().await {
199        let chunk = chunk?;
200        buffer.push_str(&String::from_utf8_lossy(&chunk));
201
202        // Process SSE events
203        while let Some(pos) = buffer.find("\n\n") {
204            let event_str = buffer[..pos].to_string();
205            buffer = buffer[pos + 2..].to_string();
206
207            if let Some(data_line) = event_str.lines().find(|l| l.starts_with("data:")) {
208                let data = data_line.trim_start_matches("data:").trim();
209                if data == "[DONE]" || data.is_empty() {
210                    continue;
211                }
212
213                if let Ok(task) = serde_json::from_str::<serde_json::Value>(data) {
214                    if let Some(id) = task.get("task").and_then(|t| t["id"].as_str())
215                        .or_else(|| task["id"].as_str())
216                    {
217                        let mut proc = processing.lock().await;
218                        if !proc.contains(id) {
219                            proc.insert(id.to_string());
220                            drop(proc);
221
222                            let task_id = id.to_string();
223                            let client = client.clone();
224                            let server = server.to_string();
225                            let token = token.clone();
226                            let worker_id = worker_id.to_string();
227                            let auto_approve = *auto_approve;
228                            let processing_clone = processing.clone();
229
230                            tokio::spawn(async move {
231                                if let Err(e) = handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await {
232                                    tracing::error!("Task {} failed: {}", task_id, e);
233                                }
234                                processing_clone.lock().await.remove(&task_id);
235                            });
236                        }
237                    }
238                }
239            }
240        }
241    }
242
243    Ok(())
244}
245
246async fn handle_task(
247    client: &Client,
248    server: &str,
249    token: &Option<String>,
250    worker_id: &str,
251    task: &serde_json::Value,
252    _auto_approve: AutoApprove,
253) -> Result<()> {
254    let task_id = task.get("task").and_then(|t| t["id"].as_str())
255        .or_else(|| task["id"].as_str())
256        .ok_or_else(|| anyhow::anyhow!("No task ID"))?;
257    let title = task.get("task").and_then(|t| t["title"].as_str())
258        .or_else(|| task["title"].as_str())
259        .unwrap_or("Untitled");
260
261    tracing::info!("Handling task: {} ({})", title, task_id);
262
263    // Claim the task
264    let mut req = client.post(format!("{}/v1/worker/tasks/claim", server))
265        .header("X-Worker-ID", worker_id);
266    if let Some(t) = token {
267        req = req.bearer_auth(t);
268    }
269
270    let res = req
271        .json(&serde_json::json!({ "task_id": task_id }))
272        .send()
273        .await?;
274
275    if !res.status().is_success() {
276        let text = res.text().await?;
277        tracing::warn!("Failed to claim task: {}", text);
278        return Ok(());
279    }
280
281    tracing::info!("Claimed task: {}", task_id);
282
283    // Create a session and process the task
284    let session = Session::new().await?;
285    let prompt = task.get("task").and_then(|t| t["prompt"].as_str())
286        .or_else(|| task["prompt"].as_str())
287        .or_else(|| task.get("task").and_then(|t| t["description"].as_str()))
288        .or_else(|| task["description"].as_str())
289        .unwrap_or(title);
290
291    // TODO: Actually execute the agent here
292    tracing::info!("Would execute prompt: {}", prompt);
293    let result = format!("Task {} processed (session: {})", task_id, session.id);
294
295    // Release the task
296    let mut req = client.post(format!("{}/v1/worker/tasks/release", server))
297        .header("X-Worker-ID", worker_id);
298    if let Some(t) = token {
299        req = req.bearer_auth(t);
300    }
301
302    req.json(&serde_json::json!({
303        "task_id": task_id,
304        "status": "completed",
305        "result": result,
306    }))
307    .send()
308    .await?;
309
310    tracing::info!("Task completed: {}", task_id);
311    Ok(())
312}
313
314// Add rand for worker ID generation
315use rand;