1use crate::cli::A2aArgs;
4use crate::session::Session;
5use anyhow::Result;
6use futures::StreamExt;
7use reqwest::Client;
8use std::collections::HashSet;
9use std::sync::Arc;
10use tokio::sync::Mutex;
11
12pub async fn run(args: A2aArgs) -> Result<()> {
14 let server = args.server.trim_end_matches('/');
15 let name = args
16 .name
17 .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
18 let worker_id = generate_worker_id();
19
20 let codebases: Vec<String> = args
21 .codebases
22 .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
23 .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);
24
25 tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
26 tracing::info!("Server: {}", server);
27 tracing::info!("Codebases: {:?}", codebases);
28
29 let client = Client::new();
30 let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
31
32 let auto_approve = match args.auto_approve.as_str() {
33 "all" => AutoApprove::All,
34 "safe" => AutoApprove::Safe,
35 _ => AutoApprove::None,
36 };
37
38 register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;
40
41 fetch_pending_tasks(
43 &client,
44 server,
45 &args.token,
46 &worker_id,
47 &processing,
48 &auto_approve,
49 )
50 .await?;
51
52 loop {
54 match connect_stream(
55 &client,
56 server,
57 &args.token,
58 &worker_id,
59 &name,
60 &codebases,
61 &processing,
62 &auto_approve,
63 )
64 .await
65 {
66 Ok(()) => {
67 tracing::warn!("Stream ended, reconnecting...");
68 }
69 Err(e) => {
70 tracing::error!("Stream error: {}, reconnecting...", e);
71 }
72 }
73 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
74 }
75}
76
77fn generate_worker_id() -> String {
78 format!(
79 "wrk_{}_{:x}",
80 chrono::Utc::now().timestamp(),
81 rand::random::<u64>()
82 )
83}
84
/// Auto-approval policy chosen from the `auto_approve` CLI argument.
///
/// NOTE(review): presumably controls which task actions may run without
/// confirmation — TODO confirm; the value is passed to `handle_task` but not
/// yet acted on there.
#[derive(Clone, Copy, Debug)]
enum AutoApprove {
    /// Selected by the argument `"all"`.
    All,
    /// Selected by the argument `"safe"`.
    Safe,
    /// Selected by any other argument value.
    None,
}
91
92async fn register_worker(
93 client: &Client,
94 server: &str,
95 token: &Option<String>,
96 worker_id: &str,
97 name: &str,
98 codebases: &[String],
99) -> Result<()> {
100 let mut req = client.put(format!("{}/v1/worker/codebases", server));
101
102 if let Some(t) = token {
103 req = req.bearer_auth(t);
104 }
105
106 let res = req
107 .json(&serde_json::json!({
108 "codebases": codebases,
109 "worker_id": worker_id,
110 "agent_name": name,
111 }))
112 .send()
113 .await?;
114
115 if res.status().is_success() {
116 tracing::info!("Worker registered successfully");
117 } else {
118 tracing::warn!("Failed to register worker: {}", res.status());
119 }
120
121 Ok(())
122}
123
124async fn fetch_pending_tasks(
125 client: &Client,
126 server: &str,
127 token: &Option<String>,
128 worker_id: &str,
129 processing: &Arc<Mutex<HashSet<String>>>,
130 auto_approve: &AutoApprove,
131) -> Result<()> {
132 tracing::info!("Checking for pending tasks...");
133
134 let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
135 if let Some(t) = token {
136 req = req.bearer_auth(t);
137 }
138
139 let res = req.send().await?;
140 if !res.status().is_success() {
141 return Ok(());
142 }
143
144 let data: serde_json::Value = res.json().await?;
145 let tasks = data["tasks"].as_array().cloned().unwrap_or_default();
146
147 tracing::info!("Found {} pending task(s)", tasks.len());
148
149 for task in tasks {
150 if let Some(id) = task["id"].as_str() {
151 let mut proc = processing.lock().await;
152 if !proc.contains(id) {
153 proc.insert(id.to_string());
154 drop(proc);
155
156 let task_id = id.to_string();
157 let client = client.clone();
158 let server = server.to_string();
159 let token = token.clone();
160 let worker_id = worker_id.to_string();
161 let auto_approve = *auto_approve;
162 let processing = processing.clone();
163
164 tokio::spawn(async move {
165 if let Err(e) =
166 handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await
167 {
168 tracing::error!("Task {} failed: {}", task_id, e);
169 }
170 processing.lock().await.remove(&task_id);
171 });
172 }
173 }
174 }
175
176 Ok(())
177}
178
179#[allow(clippy::too_many_arguments)]
180async fn connect_stream(
181 client: &Client,
182 server: &str,
183 token: &Option<String>,
184 worker_id: &str,
185 name: &str,
186 codebases: &[String],
187 processing: &Arc<Mutex<HashSet<String>>>,
188 auto_approve: &AutoApprove,
189) -> Result<()> {
190 let url = format!(
191 "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
192 server,
193 urlencoding::encode(name),
194 urlencoding::encode(worker_id)
195 );
196
197 let mut req = client
198 .get(&url)
199 .header("Accept", "text/event-stream")
200 .header("X-Worker-ID", worker_id)
201 .header("X-Agent-Name", name)
202 .header("X-Codebases", codebases.join(","));
203
204 if let Some(t) = token {
205 req = req.bearer_auth(t);
206 }
207
208 let res = req.send().await?;
209 if !res.status().is_success() {
210 anyhow::bail!("Failed to connect: {}", res.status());
211 }
212
213 tracing::info!("Connected to A2A server");
214
215 let mut stream = res.bytes_stream();
216 let mut buffer = String::new();
217
218 while let Some(chunk) = stream.next().await {
219 let chunk = chunk?;
220 buffer.push_str(&String::from_utf8_lossy(&chunk));
221
222 while let Some(pos) = buffer.find("\n\n") {
224 let event_str = buffer[..pos].to_string();
225 buffer = buffer[pos + 2..].to_string();
226
227 if let Some(data_line) = event_str.lines().find(|l| l.starts_with("data:")) {
228 let data = data_line.trim_start_matches("data:").trim();
229 if data == "[DONE]" || data.is_empty() {
230 continue;
231 }
232
233 if let Ok(task) = serde_json::from_str::<serde_json::Value>(data) {
234 if let Some(id) = task
235 .get("task")
236 .and_then(|t| t["id"].as_str())
237 .or_else(|| task["id"].as_str())
238 {
239 let mut proc = processing.lock().await;
240 if !proc.contains(id) {
241 proc.insert(id.to_string());
242 drop(proc);
243
244 let task_id = id.to_string();
245 let client = client.clone();
246 let server = server.to_string();
247 let token = token.clone();
248 let worker_id = worker_id.to_string();
249 let auto_approve = *auto_approve;
250 let processing_clone = processing.clone();
251
252 tokio::spawn(async move {
253 if let Err(e) = handle_task(
254 &client,
255 &server,
256 &token,
257 &worker_id,
258 &task,
259 auto_approve,
260 )
261 .await
262 {
263 tracing::error!("Task {} failed: {}", task_id, e);
264 }
265 processing_clone.lock().await.remove(&task_id);
266 });
267 }
268 }
269 }
270 }
271 }
272 }
273
274 Ok(())
275}
276
277async fn handle_task(
278 client: &Client,
279 server: &str,
280 token: &Option<String>,
281 worker_id: &str,
282 task: &serde_json::Value,
283 _auto_approve: AutoApprove,
284) -> Result<()> {
285 let task_id = task
286 .get("task")
287 .and_then(|t| t["id"].as_str())
288 .or_else(|| task["id"].as_str())
289 .ok_or_else(|| anyhow::anyhow!("No task ID"))?;
290 let title = task
291 .get("task")
292 .and_then(|t| t["title"].as_str())
293 .or_else(|| task["title"].as_str())
294 .unwrap_or("Untitled");
295
296 tracing::info!("Handling task: {} ({})", title, task_id);
297
298 let mut req = client
300 .post(format!("{}/v1/worker/tasks/claim", server))
301 .header("X-Worker-ID", worker_id);
302 if let Some(t) = token {
303 req = req.bearer_auth(t);
304 }
305
306 let res = req
307 .json(&serde_json::json!({ "task_id": task_id }))
308 .send()
309 .await?;
310
311 if !res.status().is_success() {
312 let text = res.text().await?;
313 tracing::warn!("Failed to claim task: {}", text);
314 return Ok(());
315 }
316
317 tracing::info!("Claimed task: {}", task_id);
318
319 let session = Session::new().await?;
321 let prompt = task
322 .get("task")
323 .and_then(|t| t["prompt"].as_str())
324 .or_else(|| task["prompt"].as_str())
325 .or_else(|| task.get("task").and_then(|t| t["description"].as_str()))
326 .or_else(|| task["description"].as_str())
327 .unwrap_or(title);
328
329 tracing::info!("Would execute prompt: {}", prompt);
331 let result = format!("Task {} processed (session: {})", task_id, session.id);
332
333 let mut req = client
335 .post(format!("{}/v1/worker/tasks/release", server))
336 .header("X-Worker-ID", worker_id);
337 if let Some(t) = token {
338 req = req.bearer_auth(t);
339 }
340
341 req.json(&serde_json::json!({
342 "task_id": task_id,
343 "status": "completed",
344 "result": result,
345 }))
346 .send()
347 .await?;
348
349 tracing::info!("Task completed: {}", task_id);
350 Ok(())
351}
352
353use rand;