Skip to main content

codetether_agent/a2a/
worker.rs

//! A2A worker: registers with an A2A server, then receives tasks over its SSE stream and processes them.
2
3use crate::cli::A2aArgs;
4use crate::session::Session;
5use anyhow::Result;
6use futures::StreamExt;
7use reqwest::Client;
8use std::collections::HashSet;
9use std::sync::Arc;
10use tokio::sync::Mutex;
11
12/// Run the A2A worker
13pub async fn run(args: A2aArgs) -> Result<()> {
14    let server = args.server.trim_end_matches('/');
15    let name = args
16        .name
17        .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
18    let worker_id = generate_worker_id();
19
20    let codebases: Vec<String> = args
21        .codebases
22        .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
23        .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);
24
25    tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
26    tracing::info!("Server: {}", server);
27    tracing::info!("Codebases: {:?}", codebases);
28
29    let client = Client::new();
30    let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
31
32    let auto_approve = match args.auto_approve.as_str() {
33        "all" => AutoApprove::All,
34        "safe" => AutoApprove::Safe,
35        _ => AutoApprove::None,
36    };
37
38    // Register worker
39    register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;
40
41    // Fetch pending tasks
42    fetch_pending_tasks(
43        &client,
44        server,
45        &args.token,
46        &worker_id,
47        &processing,
48        &auto_approve,
49    )
50    .await?;
51
52    // Connect to SSE stream
53    loop {
54        match connect_stream(
55            &client,
56            server,
57            &args.token,
58            &worker_id,
59            &name,
60            &codebases,
61            &processing,
62            &auto_approve,
63        )
64        .await
65        {
66            Ok(()) => {
67                tracing::warn!("Stream ended, reconnecting...");
68            }
69            Err(e) => {
70                tracing::error!("Stream error: {}, reconnecting...", e);
71            }
72        }
73        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
74    }
75}
76
77fn generate_worker_id() -> String {
78    format!(
79        "wrk_{}_{:x}",
80        chrono::Utc::now().timestamp(),
81        rand::random::<u64>()
82    )
83}
84
/// Auto-approval mode selected from `args.auto_approve` in [`run`]
/// (`"all"` → `All`, `"safe"` → `Safe`, anything else → `None`).
///
/// NOTE(review): the variant names suggest how much agent activity may
/// proceed without confirmation. However, `handle_task` currently ignores
/// this value (`_auto_approve`), so no variant changes behaviour yet.
#[derive(Clone, Copy, Debug)]
enum AutoApprove {
    /// Selected by `"all"`.
    All,
    /// Selected by `"safe"`.
    Safe,
    /// Fallback for any other value.
    None,
}
91
92async fn register_worker(
93    client: &Client,
94    server: &str,
95    token: &Option<String>,
96    worker_id: &str,
97    name: &str,
98    codebases: &[String],
99) -> Result<()> {
100    let mut req = client.put(format!("{}/v1/worker/codebases", server));
101
102    if let Some(t) = token {
103        req = req.bearer_auth(t);
104    }
105
106    let res = req
107        .json(&serde_json::json!({
108            "codebases": codebases,
109            "worker_id": worker_id,
110            "agent_name": name,
111        }))
112        .send()
113        .await?;
114
115    if res.status().is_success() {
116        tracing::info!("Worker registered successfully");
117    } else {
118        tracing::warn!("Failed to register worker: {}", res.status());
119    }
120
121    Ok(())
122}
123
124async fn fetch_pending_tasks(
125    client: &Client,
126    server: &str,
127    token: &Option<String>,
128    worker_id: &str,
129    processing: &Arc<Mutex<HashSet<String>>>,
130    auto_approve: &AutoApprove,
131) -> Result<()> {
132    tracing::info!("Checking for pending tasks...");
133
134    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
135    if let Some(t) = token {
136        req = req.bearer_auth(t);
137    }
138
139    let res = req.send().await?;
140    if !res.status().is_success() {
141        return Ok(());
142    }
143
144    let data: serde_json::Value = res.json().await?;
145    let tasks = data["tasks"].as_array().cloned().unwrap_or_default();
146
147    tracing::info!("Found {} pending task(s)", tasks.len());
148
149    for task in tasks {
150        if let Some(id) = task["id"].as_str() {
151            let mut proc = processing.lock().await;
152            if !proc.contains(id) {
153                proc.insert(id.to_string());
154                drop(proc);
155
156                let task_id = id.to_string();
157                let client = client.clone();
158                let server = server.to_string();
159                let token = token.clone();
160                let worker_id = worker_id.to_string();
161                let auto_approve = *auto_approve;
162                let processing = processing.clone();
163
164                tokio::spawn(async move {
165                    if let Err(e) =
166                        handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await
167                    {
168                        tracing::error!("Task {} failed: {}", task_id, e);
169                    }
170                    processing.lock().await.remove(&task_id);
171                });
172            }
173        }
174    }
175
176    Ok(())
177}
178
179#[allow(clippy::too_many_arguments)]
180async fn connect_stream(
181    client: &Client,
182    server: &str,
183    token: &Option<String>,
184    worker_id: &str,
185    name: &str,
186    codebases: &[String],
187    processing: &Arc<Mutex<HashSet<String>>>,
188    auto_approve: &AutoApprove,
189) -> Result<()> {
190    let url = format!(
191        "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
192        server,
193        urlencoding::encode(name),
194        urlencoding::encode(worker_id)
195    );
196
197    let mut req = client
198        .get(&url)
199        .header("Accept", "text/event-stream")
200        .header("X-Worker-ID", worker_id)
201        .header("X-Agent-Name", name)
202        .header("X-Codebases", codebases.join(","));
203
204    if let Some(t) = token {
205        req = req.bearer_auth(t);
206    }
207
208    let res = req.send().await?;
209    if !res.status().is_success() {
210        anyhow::bail!("Failed to connect: {}", res.status());
211    }
212
213    tracing::info!("Connected to A2A server");
214
215    let mut stream = res.bytes_stream();
216    let mut buffer = String::new();
217
218    while let Some(chunk) = stream.next().await {
219        let chunk = chunk?;
220        buffer.push_str(&String::from_utf8_lossy(&chunk));
221
222        // Process SSE events
223        while let Some(pos) = buffer.find("\n\n") {
224            let event_str = buffer[..pos].to_string();
225            buffer = buffer[pos + 2..].to_string();
226
227            if let Some(data_line) = event_str.lines().find(|l| l.starts_with("data:")) {
228                let data = data_line.trim_start_matches("data:").trim();
229                if data == "[DONE]" || data.is_empty() {
230                    continue;
231                }
232
233                if let Ok(task) = serde_json::from_str::<serde_json::Value>(data) {
234                    if let Some(id) = task
235                        .get("task")
236                        .and_then(|t| t["id"].as_str())
237                        .or_else(|| task["id"].as_str())
238                    {
239                        let mut proc = processing.lock().await;
240                        if !proc.contains(id) {
241                            proc.insert(id.to_string());
242                            drop(proc);
243
244                            let task_id = id.to_string();
245                            let client = client.clone();
246                            let server = server.to_string();
247                            let token = token.clone();
248                            let worker_id = worker_id.to_string();
249                            let auto_approve = *auto_approve;
250                            let processing_clone = processing.clone();
251
252                            tokio::spawn(async move {
253                                if let Err(e) = handle_task(
254                                    &client,
255                                    &server,
256                                    &token,
257                                    &worker_id,
258                                    &task,
259                                    auto_approve,
260                                )
261                                .await
262                                {
263                                    tracing::error!("Task {} failed: {}", task_id, e);
264                                }
265                                processing_clone.lock().await.remove(&task_id);
266                            });
267                        }
268                    }
269                }
270            }
271        }
272    }
273
274    Ok(())
275}
276
277async fn handle_task(
278    client: &Client,
279    server: &str,
280    token: &Option<String>,
281    worker_id: &str,
282    task: &serde_json::Value,
283    _auto_approve: AutoApprove,
284) -> Result<()> {
285    let task_id = task
286        .get("task")
287        .and_then(|t| t["id"].as_str())
288        .or_else(|| task["id"].as_str())
289        .ok_or_else(|| anyhow::anyhow!("No task ID"))?;
290    let title = task
291        .get("task")
292        .and_then(|t| t["title"].as_str())
293        .or_else(|| task["title"].as_str())
294        .unwrap_or("Untitled");
295
296    tracing::info!("Handling task: {} ({})", title, task_id);
297
298    // Claim the task
299    let mut req = client
300        .post(format!("{}/v1/worker/tasks/claim", server))
301        .header("X-Worker-ID", worker_id);
302    if let Some(t) = token {
303        req = req.bearer_auth(t);
304    }
305
306    let res = req
307        .json(&serde_json::json!({ "task_id": task_id }))
308        .send()
309        .await?;
310
311    if !res.status().is_success() {
312        let text = res.text().await?;
313        tracing::warn!("Failed to claim task: {}", text);
314        return Ok(());
315    }
316
317    tracing::info!("Claimed task: {}", task_id);
318
319    // Create a session and process the task
320    let session = Session::new().await?;
321    let prompt = task
322        .get("task")
323        .and_then(|t| t["prompt"].as_str())
324        .or_else(|| task["prompt"].as_str())
325        .or_else(|| task.get("task").and_then(|t| t["description"].as_str()))
326        .or_else(|| task["description"].as_str())
327        .unwrap_or(title);
328
329    // TODO: Actually execute the agent here
330    tracing::info!("Would execute prompt: {}", prompt);
331    let result = format!("Task {} processed (session: {})", task_id, session.id);
332
333    // Release the task
334    let mut req = client
335        .post(format!("{}/v1/worker/tasks/release", server))
336        .header("X-Worker-ID", worker_id);
337    if let Some(t) = token {
338        req = req.bearer_auth(t);
339    }
340
341    req.json(&serde_json::json!({
342        "task_id": task_id,
343        "status": "completed",
344        "result": result,
345    }))
346    .send()
347    .await?;
348
349    tracing::info!("Task completed: {}", task_id);
350    Ok(())
351}
352
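// `rand` supplies the random suffix used by `generate_worker_id`.
// NOTE(review): since Rust 2018, crates are reachable by path, so this bare `use rand;` is redundant (clippy: `single_component_path_imports`).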
354use rand;