1use crate::cli::A2aArgs;
4use crate::session::Session;
5use anyhow::Result;
6use futures::StreamExt;
7use reqwest::Client;
8use std::collections::HashSet;
9use std::sync::Arc;
10use tokio::sync::Mutex;
11
12pub async fn run(args: A2aArgs) -> Result<()> {
14 let server = args.server.trim_end_matches('/');
15 let name = args.name.unwrap_or_else(|| format!("codetether-{}", std::process::id()));
16 let worker_id = generate_worker_id();
17
18 let codebases: Vec<String> = args
19 .codebases
20 .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
21 .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);
22
23 tracing::info!(
24 "Starting A2A worker: {} ({})",
25 name,
26 worker_id
27 );
28 tracing::info!("Server: {}", server);
29 tracing::info!("Codebases: {:?}", codebases);
30
31 let client = Client::new();
32 let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
33
34 let auto_approve = match args.auto_approve.as_str() {
35 "all" => AutoApprove::All,
36 "safe" => AutoApprove::Safe,
37 _ => AutoApprove::None,
38 };
39
40 register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;
42
43 fetch_pending_tasks(&client, server, &args.token, &worker_id, &processing, &auto_approve).await?;
45
46 loop {
48 match connect_stream(&client, server, &args.token, &worker_id, &name, &codebases, &processing, &auto_approve).await {
49 Ok(()) => {
50 tracing::warn!("Stream ended, reconnecting...");
51 }
52 Err(e) => {
53 tracing::error!("Stream error: {}, reconnecting...", e);
54 }
55 }
56 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
57 }
58}
59
60fn generate_worker_id() -> String {
61 format!(
62 "wrk_{}_{:x}",
63 chrono::Utc::now().timestamp(),
64 rand::random::<u64>()
65 )
66}
67
/// Auto-approval policy selected by the `auto_approve` CLI argument.
///
/// `run` maps `"all"` to [`AutoApprove::All`], `"safe"` to
/// [`AutoApprove::Safe`], and every other string to [`AutoApprove::None`].
/// NOTE(review): `handle_task` currently receives but ignores this value.
#[derive(Clone, Copy, Debug)]
enum AutoApprove {
    /// Selected by the `"all"` setting.
    All,
    /// Selected by the `"safe"` setting.
    Safe,
    /// Fallback for any setting other than `"all"` / `"safe"`.
    None,
}
74
75async fn register_worker(
76 client: &Client,
77 server: &str,
78 token: &Option<String>,
79 worker_id: &str,
80 name: &str,
81 codebases: &[String],
82) -> Result<()> {
83 let mut req = client.put(format!("{}/v1/worker/codebases", server));
84
85 if let Some(t) = token {
86 req = req.bearer_auth(t);
87 }
88
89 let res = req
90 .json(&serde_json::json!({
91 "codebases": codebases,
92 "worker_id": worker_id,
93 "agent_name": name,
94 }))
95 .send()
96 .await?;
97
98 if res.status().is_success() {
99 tracing::info!("Worker registered successfully");
100 } else {
101 tracing::warn!("Failed to register worker: {}", res.status());
102 }
103
104 Ok(())
105}
106
107async fn fetch_pending_tasks(
108 client: &Client,
109 server: &str,
110 token: &Option<String>,
111 worker_id: &str,
112 processing: &Arc<Mutex<HashSet<String>>>,
113 auto_approve: &AutoApprove,
114) -> Result<()> {
115 tracing::info!("Checking for pending tasks...");
116
117 let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
118 if let Some(t) = token {
119 req = req.bearer_auth(t);
120 }
121
122 let res = req.send().await?;
123 if !res.status().is_success() {
124 return Ok(());
125 }
126
127 let data: serde_json::Value = res.json().await?;
128 let tasks = data["tasks"].as_array().cloned().unwrap_or_default();
129
130 tracing::info!("Found {} pending task(s)", tasks.len());
131
132 for task in tasks {
133 if let Some(id) = task["id"].as_str() {
134 let mut proc = processing.lock().await;
135 if !proc.contains(id) {
136 proc.insert(id.to_string());
137 drop(proc);
138
139 let task_id = id.to_string();
140 let client = client.clone();
141 let server = server.to_string();
142 let token = token.clone();
143 let worker_id = worker_id.to_string();
144 let auto_approve = *auto_approve;
145 let processing = processing.clone();
146
147 tokio::spawn(async move {
148 if let Err(e) = handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await {
149 tracing::error!("Task {} failed: {}", task_id, e);
150 }
151 processing.lock().await.remove(&task_id);
152 });
153 }
154 }
155 }
156
157 Ok(())
158}
159
160#[allow(clippy::too_many_arguments)]
161async fn connect_stream(
162 client: &Client,
163 server: &str,
164 token: &Option<String>,
165 worker_id: &str,
166 name: &str,
167 codebases: &[String],
168 processing: &Arc<Mutex<HashSet<String>>>,
169 auto_approve: &AutoApprove,
170) -> Result<()> {
171 let url = format!(
172 "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
173 server,
174 urlencoding::encode(name),
175 urlencoding::encode(worker_id)
176 );
177
178 let mut req = client.get(&url)
179 .header("Accept", "text/event-stream")
180 .header("X-Worker-ID", worker_id)
181 .header("X-Agent-Name", name)
182 .header("X-Codebases", codebases.join(","));
183
184 if let Some(t) = token {
185 req = req.bearer_auth(t);
186 }
187
188 let res = req.send().await?;
189 if !res.status().is_success() {
190 anyhow::bail!("Failed to connect: {}", res.status());
191 }
192
193 tracing::info!("Connected to A2A server");
194
195 let mut stream = res.bytes_stream();
196 let mut buffer = String::new();
197
198 while let Some(chunk) = stream.next().await {
199 let chunk = chunk?;
200 buffer.push_str(&String::from_utf8_lossy(&chunk));
201
202 while let Some(pos) = buffer.find("\n\n") {
204 let event_str = buffer[..pos].to_string();
205 buffer = buffer[pos + 2..].to_string();
206
207 if let Some(data_line) = event_str.lines().find(|l| l.starts_with("data:")) {
208 let data = data_line.trim_start_matches("data:").trim();
209 if data == "[DONE]" || data.is_empty() {
210 continue;
211 }
212
213 if let Ok(task) = serde_json::from_str::<serde_json::Value>(data) {
214 if let Some(id) = task.get("task").and_then(|t| t["id"].as_str())
215 .or_else(|| task["id"].as_str())
216 {
217 let mut proc = processing.lock().await;
218 if !proc.contains(id) {
219 proc.insert(id.to_string());
220 drop(proc);
221
222 let task_id = id.to_string();
223 let client = client.clone();
224 let server = server.to_string();
225 let token = token.clone();
226 let worker_id = worker_id.to_string();
227 let auto_approve = *auto_approve;
228 let processing_clone = processing.clone();
229
230 tokio::spawn(async move {
231 if let Err(e) = handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await {
232 tracing::error!("Task {} failed: {}", task_id, e);
233 }
234 processing_clone.lock().await.remove(&task_id);
235 });
236 }
237 }
238 }
239 }
240 }
241 }
242
243 Ok(())
244}
245
246async fn handle_task(
247 client: &Client,
248 server: &str,
249 token: &Option<String>,
250 worker_id: &str,
251 task: &serde_json::Value,
252 _auto_approve: AutoApprove,
253) -> Result<()> {
254 let task_id = task.get("task").and_then(|t| t["id"].as_str())
255 .or_else(|| task["id"].as_str())
256 .ok_or_else(|| anyhow::anyhow!("No task ID"))?;
257 let title = task.get("task").and_then(|t| t["title"].as_str())
258 .or_else(|| task["title"].as_str())
259 .unwrap_or("Untitled");
260
261 tracing::info!("Handling task: {} ({})", title, task_id);
262
263 let mut req = client.post(format!("{}/v1/worker/tasks/claim", server))
265 .header("X-Worker-ID", worker_id);
266 if let Some(t) = token {
267 req = req.bearer_auth(t);
268 }
269
270 let res = req
271 .json(&serde_json::json!({ "task_id": task_id }))
272 .send()
273 .await?;
274
275 if !res.status().is_success() {
276 let text = res.text().await?;
277 tracing::warn!("Failed to claim task: {}", text);
278 return Ok(());
279 }
280
281 tracing::info!("Claimed task: {}", task_id);
282
283 let session = Session::new().await?;
285 let prompt = task.get("task").and_then(|t| t["prompt"].as_str())
286 .or_else(|| task["prompt"].as_str())
287 .or_else(|| task.get("task").and_then(|t| t["description"].as_str()))
288 .or_else(|| task["description"].as_str())
289 .unwrap_or(title);
290
291 tracing::info!("Would execute prompt: {}", prompt);
293 let result = format!("Task {} processed (session: {})", task_id, session.id);
294
295 let mut req = client.post(format!("{}/v1/worker/tasks/release", server))
297 .header("X-Worker-ID", worker_id);
298 if let Some(t) = token {
299 req = req.bearer_auth(t);
300 }
301
302 req.json(&serde_json::json!({
303 "task_id": task_id,
304 "status": "completed",
305 "result": result,
306 }))
307 .send()
308 .await?;
309
310 tracing::info!("Task completed: {}", task_id);
311 Ok(())
312}
313
314use rand;