Skip to main content

codetether_agent/a2a/
worker.rs

1//! A2A Worker - connects to an A2A server to process tasks
2
3use crate::cli::A2aArgs;
4use crate::provider::ProviderRegistry;
5use crate::session::Session;
6use crate::swarm::{DecompositionStrategy, SwarmConfig, SwarmExecutor};
7use crate::tui::swarm_view::SwarmEvent;
8use anyhow::Result;
9use futures::StreamExt;
10use reqwest::Client;
11use std::collections::HashMap;
12use std::collections::HashSet;
13use std::sync::Arc;
14use tokio::sync::{Mutex, mpsc};
15use tokio::task::JoinHandle;
16
/// Lifecycle state reported in the worker heartbeat.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum WorkerStatus {
    /// Serialized as `"idle"`.
    Idle,
    /// Serialized as `"processing"`.
    Processing,
}

impl WorkerStatus {
    /// Returns the lowercase wire name of this status.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Processing => "processing",
            Self::Idle => "idle",
        }
    }
}
32
33/// Heartbeat state shared between the heartbeat task and the main worker
34#[derive(Clone)]
35struct HeartbeatState {
36    worker_id: String,
37    agent_name: String,
38    status: Arc<Mutex<WorkerStatus>>,
39    active_task_count: Arc<Mutex<usize>>,
40}
41
42impl HeartbeatState {
43    fn new(worker_id: String, agent_name: String) -> Self {
44        Self {
45            worker_id,
46            agent_name,
47            status: Arc::new(Mutex::new(WorkerStatus::Idle)),
48            active_task_count: Arc::new(Mutex::new(0)),
49        }
50    }
51
52    async fn set_status(&self, status: WorkerStatus) {
53        *self.status.lock().await = status;
54    }
55
56    async fn set_task_count(&self, count: usize) {
57        *self.active_task_count.lock().await = count;
58    }
59}
60
61/// Run the A2A worker
62pub async fn run(args: A2aArgs) -> Result<()> {
63    let server = args.server.trim_end_matches('/');
64    let name = args
65        .name
66        .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
67    let worker_id = generate_worker_id();
68
69    let codebases: Vec<String> = args
70        .codebases
71        .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
72        .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);
73
74    tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
75    tracing::info!("Server: {}", server);
76    tracing::info!("Codebases: {:?}", codebases);
77
78    let client = Client::new();
79    let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
80
81    let auto_approve = match args.auto_approve.as_str() {
82        "all" => AutoApprove::All,
83        "safe" => AutoApprove::Safe,
84        _ => AutoApprove::None,
85    };
86
87    // Create heartbeat state
88    let heartbeat_state = HeartbeatState::new(worker_id.clone(), name.clone());
89
90    // Register worker
91    register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;
92
93    // Fetch pending tasks
94    fetch_pending_tasks(
95        &client,
96        server,
97        &args.token,
98        &worker_id,
99        &processing,
100        &auto_approve,
101    )
102    .await?;
103
104    // Connect to SSE stream
105    loop {
106        // Re-register worker on each reconnection to report updated models/capabilities
107        if let Err(e) =
108            register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await
109        {
110            tracing::warn!("Failed to re-register worker on reconnection: {}", e);
111        }
112
113        // Start heartbeat task for this connection
114        let heartbeat_handle = start_heartbeat(
115            client.clone(),
116            server.to_string(),
117            args.token.clone(),
118            heartbeat_state.clone(),
119            processing.clone(),
120        );
121
122        match connect_stream(
123            &client,
124            server,
125            &args.token,
126            &worker_id,
127            &name,
128            &codebases,
129            &processing,
130            &auto_approve,
131        )
132        .await
133        {
134            Ok(()) => {
135                tracing::warn!("Stream ended, reconnecting...");
136            }
137            Err(e) => {
138                tracing::error!("Stream error: {}, reconnecting...", e);
139            }
140        }
141
142        // Cancel heartbeat on disconnection
143        heartbeat_handle.abort();
144        tracing::debug!("Heartbeat cancelled for reconnection");
145
146        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
147    }
148}
149
150fn generate_worker_id() -> String {
151    format!(
152        "wrk_{}_{:x}",
153        chrono::Utc::now().timestamp(),
154        rand::random::<u64>()
155    )
156}
157
/// Auto-approval policy selected from `A2aArgs::auto_approve` in `run`.
#[derive(Clone, Copy, Debug)]
enum AutoApprove {
    /// Selected by the value `"all"`.
    All,
    /// Selected by the value `"safe"`.
    Safe,
    /// Selected by any other value.
    None,
}

/// Capability tags sent to the server when this worker registers.
const WORKER_CAPABILITIES: &[&str] = &["ralph", "swarm", "rlm", "a2a", "mcp"];
167
168fn task_value<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a serde_json::Value> {
169    task.get("task")
170        .and_then(|t| t.get(key))
171        .or_else(|| task.get(key))
172}
173
174fn task_str<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a str> {
175    task_value(task, key).and_then(|v| v.as_str())
176}
177
178fn task_metadata(task: &serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
179    task_value(task, "metadata")
180        .and_then(|m| m.as_object())
181        .cloned()
182        .unwrap_or_default()
183}
184
/// Normalizes a model reference to the `provider/model` form.
///
/// A reference written as `provider:model` (e.g. `openai:gpt-4o`) becomes
/// `openai/gpt-4o`; only the first colon is replaced. Everything else is returned
/// unchanged, in particular:
/// - references that already contain `/` (e.g. `bedrock/amazon.nova-micro-v1:0`);
/// - bare model IDs whose colon is a version separator (e.g. `amazon.nova-micro-v1:0`).
///   These are recognised because the text before the colon is not a plausible provider
///   ID, which must be non-empty ASCII alphanumerics, `-` or `_` (no `.`).
fn model_ref_to_provider_model(model: &str) -> String {
    let looks_like_provider = |candidate: &str| {
        !candidate.is_empty()
            && candidate
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
    };

    match model.split_once(':') {
        Some((provider, rest)) if !model.contains('/') && looks_like_provider(provider) => {
            format!("{}/{}", provider, rest)
        }
        _ => model.to_string(),
    }
}
195
/// Provider preference order for a model tier, most preferred first.
///
/// The tier names are matched case-sensitively:
/// - `"fast"` / `"quick"` put Google first;
/// - `"heavy"` / `"deep"` put Anthropic first;
/// - anything else, including `None` (treated as `"balanced"`), puts OpenAI first.
fn provider_preferences_for_tier(model_tier: Option<&str>) -> &'static [&'static str] {
    const FAST: &[&str] = &[
        "google",
        "openai",
        "moonshotai",
        "zhipuai",
        "anthropic",
        "openrouter",
        "novita",
    ];
    const HEAVY: &[&str] = &[
        "anthropic",
        "openai",
        "google",
        "moonshotai",
        "zhipuai",
        "openrouter",
        "novita",
    ];
    const BALANCED: &[&str] = &[
        "openai",
        "anthropic",
        "google",
        "moonshotai",
        "zhipuai",
        "openrouter",
        "novita",
    ];

    match model_tier {
        Some("fast" | "quick") => FAST,
        Some("heavy" | "deep") => HEAVY,
        _ => BALANCED,
    }
}
227
228fn choose_provider_for_tier<'a>(providers: &'a [&'a str], model_tier: Option<&str>) -> &'a str {
229    for preferred in provider_preferences_for_tier(model_tier) {
230        if let Some(found) = providers.iter().copied().find(|p| *p == *preferred) {
231            return found;
232        }
233    }
234    if let Some(found) = providers.iter().copied().find(|p| *p == "zhipuai") {
235        return found;
236    }
237    providers[0]
238}
239
/// Default model ID for `provider` at the given tier (`None` is treated as `"balanced"`).
///
/// Only Anthropic, OpenAI and Google vary by tier:
/// - `"fast"` / `"quick"` pick their smaller models;
/// - `"heavy"` / `"deep"` pick `o3` for OpenAI.
///
/// Unknown providers get `glm-4.7`.
fn default_model_for_provider(provider: &str, model_tier: Option<&str>) -> String {
    let tier = model_tier.unwrap_or("balanced");
    let fast = matches!(tier, "fast" | "quick");
    let heavy = matches!(tier, "heavy" | "deep");

    let model = match provider {
        "moonshotai" => "kimi-k2.5",
        "anthropic" if fast => "claude-haiku-4-5",
        "anthropic" => "claude-sonnet-4-20250514",
        "openai" if fast => "gpt-4o-mini",
        "openai" if heavy => "o3",
        "openai" => "gpt-4o",
        "google" if fast => "gemini-2.5-flash",
        "google" => "gemini-2.5-pro",
        "zhipuai" => "glm-4.7",
        "openrouter" => "z-ai/glm-4.7",
        "novita" => "qwen/qwen3-coder-next",
        _ => "glm-4.7",
    };
    model.to_string()
}
274
/// Returns `true` when `agent_type` is one of the swarm aliases.
///
/// The aliases are `swarm`, `parallel` and `multi-agent`; ASCII case and surrounding
/// whitespace are ignored.
fn is_swarm_agent(agent_type: &str) -> bool {
    const SWARM_ALIASES: [&str; 3] = ["swarm", "parallel", "multi-agent"];
    let normalized = agent_type.trim().to_ascii_lowercase();
    SWARM_ALIASES.contains(&normalized.as_str())
}
281
282fn metadata_lookup<'a>(
283    metadata: &'a serde_json::Map<String, serde_json::Value>,
284    key: &str,
285) -> Option<&'a serde_json::Value> {
286    metadata
287        .get(key)
288        .or_else(|| {
289            metadata
290                .get("routing")
291                .and_then(|v| v.as_object())
292                .and_then(|obj| obj.get(key))
293        })
294        .or_else(|| {
295            metadata
296                .get("swarm")
297                .and_then(|v| v.as_object())
298                .and_then(|obj| obj.get(key))
299        })
300}
301
302fn metadata_str(
303    metadata: &serde_json::Map<String, serde_json::Value>,
304    keys: &[&str],
305) -> Option<String> {
306    for key in keys {
307        if let Some(value) = metadata_lookup(metadata, key).and_then(|v| v.as_str()) {
308            let trimmed = value.trim();
309            if !trimmed.is_empty() {
310                return Some(trimmed.to_string());
311            }
312        }
313    }
314    None
315}
316
317fn metadata_usize(
318    metadata: &serde_json::Map<String, serde_json::Value>,
319    keys: &[&str],
320) -> Option<usize> {
321    for key in keys {
322        if let Some(value) = metadata_lookup(metadata, key) {
323            if let Some(v) = value.as_u64() {
324                return usize::try_from(v).ok();
325            }
326            if let Some(v) = value.as_i64() {
327                if v >= 0 {
328                    return usize::try_from(v as u64).ok();
329                }
330            }
331            if let Some(v) = value.as_str() {
332                if let Ok(parsed) = v.trim().parse::<usize>() {
333                    return Some(parsed);
334                }
335            }
336        }
337    }
338    None
339}
340
341fn metadata_u64(
342    metadata: &serde_json::Map<String, serde_json::Value>,
343    keys: &[&str],
344) -> Option<u64> {
345    for key in keys {
346        if let Some(value) = metadata_lookup(metadata, key) {
347            if let Some(v) = value.as_u64() {
348                return Some(v);
349            }
350            if let Some(v) = value.as_i64() {
351                if v >= 0 {
352                    return Some(v as u64);
353                }
354            }
355            if let Some(v) = value.as_str() {
356                if let Ok(parsed) = v.trim().parse::<u64>() {
357                    return Some(parsed);
358                }
359            }
360        }
361    }
362    None
363}
364
365fn metadata_bool(
366    metadata: &serde_json::Map<String, serde_json::Value>,
367    keys: &[&str],
368) -> Option<bool> {
369    for key in keys {
370        if let Some(value) = metadata_lookup(metadata, key) {
371            if let Some(v) = value.as_bool() {
372                return Some(v);
373            }
374            if let Some(v) = value.as_str() {
375                match v.trim().to_ascii_lowercase().as_str() {
376                    "1" | "true" | "yes" | "on" => return Some(true),
377                    "0" | "false" | "no" | "off" => return Some(false),
378                    _ => {}
379                }
380            }
381        }
382    }
383    None
384}
385
386fn parse_swarm_strategy(
387    metadata: &serde_json::Map<String, serde_json::Value>,
388) -> DecompositionStrategy {
389    match metadata_str(
390        metadata,
391        &[
392            "decomposition_strategy",
393            "swarm_strategy",
394            "strategy",
395            "swarm_decomposition",
396        ],
397    )
398    .as_deref()
399    .map(|s| s.to_ascii_lowercase())
400    .as_deref()
401    {
402        Some("none") | Some("single") => DecompositionStrategy::None,
403        Some("domain") | Some("by_domain") => DecompositionStrategy::ByDomain,
404        Some("data") | Some("by_data") => DecompositionStrategy::ByData,
405        Some("stage") | Some("by_stage") => DecompositionStrategy::ByStage,
406        _ => DecompositionStrategy::Automatic,
407    }
408}
409
410async fn resolve_swarm_model(
411    explicit_model: Option<String>,
412    model_tier: Option<&str>,
413) -> Option<String> {
414    if let Some(model) = explicit_model {
415        if !model.trim().is_empty() {
416            return Some(model);
417        }
418    }
419
420    let registry = ProviderRegistry::from_vault().await.ok()?;
421    let providers = registry.list();
422    if providers.is_empty() {
423        return None;
424    }
425    let provider = choose_provider_for_tier(providers.as_slice(), model_tier);
426    let model = default_model_for_provider(provider, model_tier);
427    Some(format!("{}/{}", provider, model))
428}
429
430fn format_swarm_event_for_output(event: &SwarmEvent) -> Option<String> {
431    match event {
432        SwarmEvent::Started {
433            task,
434            total_subtasks,
435        } => Some(format!(
436            "[swarm] started task={} planned_subtasks={}",
437            task, total_subtasks
438        )),
439        SwarmEvent::StageComplete {
440            stage,
441            completed,
442            failed,
443        } => Some(format!(
444            "[swarm] stage={} completed={} failed={}",
445            stage, completed, failed
446        )),
447        SwarmEvent::SubTaskUpdate { id, status, .. } => Some(format!(
448            "[swarm] subtask id={} status={}",
449            &id.chars().take(8).collect::<String>(),
450            format!("{status:?}").to_ascii_lowercase()
451        )),
452        SwarmEvent::AgentToolCall {
453            subtask_id,
454            tool_name,
455        } => Some(format!(
456            "[swarm] subtask id={} tool={}",
457            &subtask_id.chars().take(8).collect::<String>(),
458            tool_name
459        )),
460        SwarmEvent::AgentError { subtask_id, error } => Some(format!(
461            "[swarm] subtask id={} error={}",
462            &subtask_id.chars().take(8).collect::<String>(),
463            error
464        )),
465        SwarmEvent::Complete { success, stats } => Some(format!(
466            "[swarm] complete success={} subtasks={} speedup={:.2}",
467            success,
468            stats.subagents_completed + stats.subagents_failed,
469            stats.speedup_factor
470        )),
471        SwarmEvent::Error(err) => Some(format!("[swarm] error message={}", err)),
472        _ => None,
473    }
474}
475
476async fn register_worker(
477    client: &Client,
478    server: &str,
479    token: &Option<String>,
480    worker_id: &str,
481    name: &str,
482    codebases: &[String],
483) -> Result<()> {
484    // Load ProviderRegistry and collect available models
485    let models = match load_provider_models().await {
486        Ok(m) => m,
487        Err(e) => {
488            tracing::warn!(
489                "Failed to load provider models: {}, proceeding without model info",
490                e
491            );
492            HashMap::new()
493        }
494    };
495
496    // Register via the workers/register endpoint
497    let mut req = client.post(format!("{}/v1/opencode/workers/register", server));
498
499    if let Some(t) = token {
500        req = req.bearer_auth(t);
501    }
502
503    // Flatten models HashMap into array of {id, name, provider, provider_id} objects
504    // matching the format expected by the A2A server's /models and /workers endpoints
505    let models_array: Vec<serde_json::Value> = models
506        .iter()
507        .flat_map(|(provider, model_ids)| {
508            model_ids.iter().map(move |model_id| {
509                serde_json::json!({
510                    "id": format!("{}/{}", provider, model_id),
511                    "name": model_id,
512                    "provider": provider,
513                    "provider_id": provider,
514                })
515            })
516        })
517        .collect();
518
519    tracing::info!(
520        "Registering worker with {} models from {} providers",
521        models_array.len(),
522        models.len()
523    );
524
525    let hostname = std::env::var("HOSTNAME")
526        .or_else(|_| std::env::var("COMPUTERNAME"))
527        .unwrap_or_else(|_| "unknown".to_string());
528
529    let res = req
530        .json(&serde_json::json!({
531            "worker_id": worker_id,
532            "name": name,
533            "capabilities": WORKER_CAPABILITIES,
534            "hostname": hostname,
535            "models": models_array,
536            "codebases": codebases,
537        }))
538        .send()
539        .await?;
540
541    if res.status().is_success() {
542        tracing::info!("Worker registered successfully");
543    } else {
544        tracing::warn!("Failed to register worker: {}", res.status());
545    }
546
547    Ok(())
548}
549
550/// Load ProviderRegistry and collect all available models grouped by provider.
551/// Tries Vault first, then falls back to config/env vars if Vault is unreachable.
552async fn load_provider_models() -> Result<HashMap<String, Vec<String>>> {
553    // Try Vault first
554    let registry = match ProviderRegistry::from_vault().await {
555        Ok(r) if !r.list().is_empty() => {
556            tracing::info!("Loaded {} providers from Vault", r.list().len());
557            r
558        }
559        Ok(_) => {
560            tracing::warn!("Vault returned 0 providers, falling back to config/env vars");
561            fallback_registry().await?
562        }
563        Err(e) => {
564            tracing::warn!("Vault unreachable ({}), falling back to config/env vars", e);
565            fallback_registry().await?
566        }
567    };
568
569    let mut models_by_provider: HashMap<String, Vec<String>> = HashMap::new();
570
571    for provider_name in registry.list() {
572        if let Some(provider) = registry.get(provider_name) {
573            match provider.list_models().await {
574                Ok(models) => {
575                    let model_ids: Vec<String> = models.into_iter().map(|m| m.id).collect();
576                    if !model_ids.is_empty() {
577                        tracing::debug!("Provider {}: {} models", provider_name, model_ids.len());
578                        models_by_provider.insert(provider_name.to_string(), model_ids);
579                    }
580                }
581                Err(e) => {
582                    tracing::debug!("Failed to list models for {}: {}", provider_name, e);
583                }
584            }
585        }
586    }
587
588    // If we still have 0, try the models.dev catalog as last resort
589    // NOTE: We list ALL models from the catalog (not just ones with verified API keys)
590    // because the worker should advertise what it can handle. The server handles routing.
591    if models_by_provider.is_empty() {
592        tracing::info!(
593            "No authenticated providers found, fetching models.dev catalog (all providers)"
594        );
595        if let Ok(catalog) = crate::provider::models::ModelCatalog::fetch().await {
596            // Use all_providers_with_models() to get every provider+model from catalog
597            // regardless of API key availability (Vault may be down)
598            for (provider_id, provider_info) in catalog.all_providers() {
599                let model_ids: Vec<String> = provider_info
600                    .models
601                    .values()
602                    .map(|m| m.id.clone())
603                    .collect();
604                if !model_ids.is_empty() {
605                    tracing::debug!(
606                        "Catalog provider {}: {} models",
607                        provider_id,
608                        model_ids.len()
609                    );
610                    models_by_provider.insert(provider_id.clone(), model_ids);
611                }
612            }
613            tracing::info!(
614                "Loaded {} providers with {} total models from catalog",
615                models_by_provider.len(),
616                models_by_provider.values().map(|v| v.len()).sum::<usize>()
617            );
618        }
619    }
620
621    Ok(models_by_provider)
622}
623
624/// Fallback: build a ProviderRegistry from config file + environment variables
625async fn fallback_registry() -> Result<ProviderRegistry> {
626    let config = crate::config::Config::load().await.unwrap_or_default();
627    ProviderRegistry::from_config(&config).await
628}
629
630async fn fetch_pending_tasks(
631    client: &Client,
632    server: &str,
633    token: &Option<String>,
634    worker_id: &str,
635    processing: &Arc<Mutex<HashSet<String>>>,
636    auto_approve: &AutoApprove,
637) -> Result<()> {
638    tracing::info!("Checking for pending tasks...");
639
640    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
641    if let Some(t) = token {
642        req = req.bearer_auth(t);
643    }
644
645    let res = req.send().await?;
646    if !res.status().is_success() {
647        return Ok(());
648    }
649
650    let data: serde_json::Value = res.json().await?;
651    // Handle both plain array response and {tasks: [...]} wrapper
652    let tasks = if let Some(arr) = data.as_array() {
653        arr.clone()
654    } else {
655        data["tasks"].as_array().cloned().unwrap_or_default()
656    };
657
658    tracing::info!("Found {} pending task(s)", tasks.len());
659
660    for task in tasks {
661        if let Some(id) = task["id"].as_str() {
662            let mut proc = processing.lock().await;
663            if !proc.contains(id) {
664                proc.insert(id.to_string());
665                drop(proc);
666
667                let task_id = id.to_string();
668                let client = client.clone();
669                let server = server.to_string();
670                let token = token.clone();
671                let worker_id = worker_id.to_string();
672                let auto_approve = *auto_approve;
673                let processing = processing.clone();
674
675                tokio::spawn(async move {
676                    if let Err(e) =
677                        handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await
678                    {
679                        tracing::error!("Task {} failed: {}", task_id, e);
680                    }
681                    processing.lock().await.remove(&task_id);
682                });
683            }
684        }
685    }
686
687    Ok(())
688}
689
690#[allow(clippy::too_many_arguments)]
691async fn connect_stream(
692    client: &Client,
693    server: &str,
694    token: &Option<String>,
695    worker_id: &str,
696    name: &str,
697    codebases: &[String],
698    processing: &Arc<Mutex<HashSet<String>>>,
699    auto_approve: &AutoApprove,
700) -> Result<()> {
701    let url = format!(
702        "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
703        server,
704        urlencoding::encode(name),
705        urlencoding::encode(worker_id)
706    );
707
708    let mut req = client
709        .get(&url)
710        .header("Accept", "text/event-stream")
711        .header("X-Worker-ID", worker_id)
712        .header("X-Agent-Name", name)
713        .header("X-Codebases", codebases.join(","));
714
715    if let Some(t) = token {
716        req = req.bearer_auth(t);
717    }
718
719    let res = req.send().await?;
720    if !res.status().is_success() {
721        anyhow::bail!("Failed to connect: {}", res.status());
722    }
723
724    tracing::info!("Connected to A2A server");
725
726    let mut stream = res.bytes_stream();
727    let mut buffer = String::new();
728    let mut poll_interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
729    poll_interval.tick().await; // consume the initial immediate tick
730
731    loop {
732        tokio::select! {
733            chunk = stream.next() => {
734                match chunk {
735                    Some(Ok(chunk)) => {
736                        buffer.push_str(&String::from_utf8_lossy(&chunk));
737
738                        // Process SSE events
739                        while let Some(pos) = buffer.find("\n\n") {
740                            let event_str = buffer[..pos].to_string();
741                            buffer = buffer[pos + 2..].to_string();
742
743                            if let Some(data_line) = event_str.lines().find(|l| l.starts_with("data:")) {
744                                let data = data_line.trim_start_matches("data:").trim();
745                                if data == "[DONE]" || data.is_empty() {
746                                    continue;
747                                }
748
749                                if let Ok(task) = serde_json::from_str::<serde_json::Value>(data) {
750                                    spawn_task_handler(
751                                        &task, client, server, token, worker_id,
752                                        processing, auto_approve,
753                                    ).await;
754                                }
755                            }
756                        }
757                    }
758                    Some(Err(e)) => {
759                        return Err(e.into());
760                    }
761                    None => {
762                        // Stream ended
763                        return Ok(());
764                    }
765                }
766            }
767            _ = poll_interval.tick() => {
768                // Periodic poll for pending tasks the SSE stream may have missed
769                if let Err(e) = poll_pending_tasks(
770                    client, server, token, worker_id, processing, auto_approve,
771                ).await {
772                    tracing::warn!("Periodic task poll failed: {}", e);
773                }
774            }
775        }
776    }
777}
778
779async fn spawn_task_handler(
780    task: &serde_json::Value,
781    client: &Client,
782    server: &str,
783    token: &Option<String>,
784    worker_id: &str,
785    processing: &Arc<Mutex<HashSet<String>>>,
786    auto_approve: &AutoApprove,
787) {
788    if let Some(id) = task
789        .get("task")
790        .and_then(|t| t["id"].as_str())
791        .or_else(|| task["id"].as_str())
792    {
793        let mut proc = processing.lock().await;
794        if !proc.contains(id) {
795            proc.insert(id.to_string());
796            drop(proc);
797
798            let task_id = id.to_string();
799            let task = task.clone();
800            let client = client.clone();
801            let server = server.to_string();
802            let token = token.clone();
803            let worker_id = worker_id.to_string();
804            let auto_approve = *auto_approve;
805            let processing_clone = processing.clone();
806
807            tokio::spawn(async move {
808                if let Err(e) =
809                    handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await
810                {
811                    tracing::error!("Task {} failed: {}", task_id, e);
812                }
813                processing_clone.lock().await.remove(&task_id);
814            });
815        }
816    }
817}
818
819async fn poll_pending_tasks(
820    client: &Client,
821    server: &str,
822    token: &Option<String>,
823    worker_id: &str,
824    processing: &Arc<Mutex<HashSet<String>>>,
825    auto_approve: &AutoApprove,
826) -> Result<()> {
827    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
828    if let Some(t) = token {
829        req = req.bearer_auth(t);
830    }
831
832    let res = req.send().await?;
833    if !res.status().is_success() {
834        return Ok(());
835    }
836
837    let data: serde_json::Value = res.json().await?;
838    let tasks = if let Some(arr) = data.as_array() {
839        arr.clone()
840    } else {
841        data["tasks"].as_array().cloned().unwrap_or_default()
842    };
843
844    if !tasks.is_empty() {
845        tracing::debug!("Poll found {} pending task(s)", tasks.len());
846    }
847
848    for task in &tasks {
849        spawn_task_handler(
850            task,
851            client,
852            server,
853            token,
854            worker_id,
855            processing,
856            auto_approve,
857        )
858        .await;
859    }
860
861    Ok(())
862}
863
864async fn handle_task(
865    client: &Client,
866    server: &str,
867    token: &Option<String>,
868    worker_id: &str,
869    task: &serde_json::Value,
870    auto_approve: AutoApprove,
871) -> Result<()> {
872    let task_id = task_str(task, "id").ok_or_else(|| anyhow::anyhow!("No task ID"))?;
873    let title = task_str(task, "title").unwrap_or("Untitled");
874
875    tracing::info!("Handling task: {} ({})", title, task_id);
876
877    // Claim the task
878    let mut req = client
879        .post(format!("{}/v1/worker/tasks/claim", server))
880        .header("X-Worker-ID", worker_id);
881    if let Some(t) = token {
882        req = req.bearer_auth(t);
883    }
884
885    let res = req
886        .json(&serde_json::json!({ "task_id": task_id }))
887        .send()
888        .await?;
889
890    if !res.status().is_success() {
891        let text = res.text().await?;
892        tracing::warn!("Failed to claim task: {}", text);
893        return Ok(());
894    }
895
896    tracing::info!("Claimed task: {}", task_id);
897
898    let metadata = task_metadata(task);
899    let resume_session_id = metadata
900        .get("resume_session_id")
901        .and_then(|v| v.as_str())
902        .map(|s| s.trim().to_string())
903        .filter(|s| !s.is_empty());
904    let complexity_hint = metadata_str(&metadata, &["complexity"]);
905    let model_tier = metadata_str(&metadata, &["model_tier", "tier"])
906        .map(|s| s.to_ascii_lowercase())
907        .or_else(|| {
908            complexity_hint.as_ref().map(|complexity| {
909                match complexity.to_ascii_lowercase().as_str() {
910                    "quick" => "fast".to_string(),
911                    "deep" => "heavy".to_string(),
912                    _ => "balanced".to_string(),
913                }
914            })
915        });
916    let worker_personality = metadata_str(
917        &metadata,
918        &["worker_personality", "personality", "agent_personality"],
919    );
920    let target_agent_name = metadata_str(&metadata, &["target_agent_name", "agent_name"]);
921    let raw_model = task_str(task, "model_ref")
922        .or_else(|| metadata_lookup(&metadata, "model_ref").and_then(|v| v.as_str()))
923        .or_else(|| task_str(task, "model"))
924        .or_else(|| metadata_lookup(&metadata, "model").and_then(|v| v.as_str()));
925    let selected_model = raw_model.map(model_ref_to_provider_model);
926
927    // Resume existing session when requested; fall back to a fresh session if missing.
928    let mut session = if let Some(ref sid) = resume_session_id {
929        match Session::load(sid).await {
930            Ok(existing) => {
931                tracing::info!("Resuming session {} for task {}", sid, task_id);
932                existing
933            }
934            Err(e) => {
935                tracing::warn!(
936                    "Could not load session {} for task {} ({}), starting a new session",
937                    sid,
938                    task_id,
939                    e
940                );
941                Session::new().await?
942            }
943        }
944    } else {
945        Session::new().await?
946    };
947
948    let agent_type = task_str(task, "agent_type")
949        .or_else(|| task_str(task, "agent"))
950        .unwrap_or("build");
951    session.agent = agent_type.to_string();
952
953    if let Some(model) = selected_model.clone() {
954        session.metadata.model = Some(model);
955    }
956
957    let prompt = task_str(task, "prompt")
958        .or_else(|| task_str(task, "description"))
959        .unwrap_or(title);
960
961    tracing::info!("Executing prompt: {}", prompt);
962
963    // Set up output streaming to forward progress to the server
964    let stream_client = client.clone();
965    let stream_server = server.to_string();
966    let stream_token = token.clone();
967    let stream_worker_id = worker_id.to_string();
968    let stream_task_id = task_id.to_string();
969
970    let output_callback = move |output: String| {
971        let c = stream_client.clone();
972        let s = stream_server.clone();
973        let t = stream_token.clone();
974        let w = stream_worker_id.clone();
975        let tid = stream_task_id.clone();
976        tokio::spawn(async move {
977            let mut req = c
978                .post(format!("{}/v1/opencode/tasks/{}/output", s, tid))
979                .header("X-Worker-ID", &w);
980            if let Some(tok) = &t {
981                req = req.bearer_auth(tok);
982            }
983            let _ = req
984                .json(&serde_json::json!({
985                    "worker_id": w,
986                    "output": output,
987                }))
988                .send()
989                .await;
990        });
991    };
992
993    // Execute swarm tasks via SwarmExecutor; all other agents use the standard session loop.
994    let (status, result, error, session_id) = if is_swarm_agent(agent_type) {
995        match execute_swarm_with_policy(
996            &mut session,
997            prompt,
998            model_tier.as_deref(),
999            selected_model,
1000            &metadata,
1001            complexity_hint.as_deref(),
1002            worker_personality.as_deref(),
1003            target_agent_name.as_deref(),
1004            Some(&output_callback),
1005        )
1006        .await
1007        {
1008            Ok((session_result, true)) => {
1009                tracing::info!("Swarm task completed successfully: {}", task_id);
1010                (
1011                    "completed",
1012                    Some(session_result.text),
1013                    None,
1014                    Some(session_result.session_id),
1015                )
1016            }
1017            Ok((session_result, false)) => {
1018                tracing::warn!("Swarm task completed with failures: {}", task_id);
1019                (
1020                    "failed",
1021                    Some(session_result.text),
1022                    Some("Swarm execution completed with failures".to_string()),
1023                    Some(session_result.session_id),
1024                )
1025            }
1026            Err(e) => {
1027                tracing::error!("Swarm task failed: {} - {}", task_id, e);
1028                ("failed", None, Some(format!("Error: {}", e)), None)
1029            }
1030        }
1031    } else {
1032        match execute_session_with_policy(
1033            &mut session,
1034            prompt,
1035            auto_approve,
1036            model_tier.as_deref(),
1037            Some(&output_callback),
1038        )
1039        .await
1040        {
1041            Ok(session_result) => {
1042                tracing::info!("Task completed successfully: {}", task_id);
1043                (
1044                    "completed",
1045                    Some(session_result.text),
1046                    None,
1047                    Some(session_result.session_id),
1048                )
1049            }
1050            Err(e) => {
1051                tracing::error!("Task failed: {} - {}", task_id, e);
1052                ("failed", None, Some(format!("Error: {}", e)), None)
1053            }
1054        }
1055    };
1056
1057    // Release the task with full details
1058    let mut req = client
1059        .post(format!("{}/v1/worker/tasks/release", server))
1060        .header("X-Worker-ID", worker_id);
1061    if let Some(t) = token {
1062        req = req.bearer_auth(t);
1063    }
1064
1065    req.json(&serde_json::json!({
1066        "task_id": task_id,
1067        "status": status,
1068        "result": result,
1069        "error": error,
1070        "session_id": session_id.unwrap_or_else(|| session.id.clone()),
1071    }))
1072    .send()
1073    .await?;
1074
1075    tracing::info!("Task released: {} with status: {}", task_id, status);
1076    Ok(())
1077}
1078
1079async fn execute_swarm_with_policy<F>(
1080    session: &mut Session,
1081    prompt: &str,
1082    model_tier: Option<&str>,
1083    explicit_model: Option<String>,
1084    metadata: &serde_json::Map<String, serde_json::Value>,
1085    complexity_hint: Option<&str>,
1086    worker_personality: Option<&str>,
1087    target_agent_name: Option<&str>,
1088    output_callback: Option<&F>,
1089) -> Result<(crate::session::SessionResult, bool)>
1090where
1091    F: Fn(String),
1092{
1093    use crate::provider::{ContentPart, Message, Role};
1094
1095    session.add_message(Message {
1096        role: Role::User,
1097        content: vec![ContentPart::Text {
1098            text: prompt.to_string(),
1099        }],
1100    });
1101
1102    if session.title.is_none() {
1103        session.generate_title().await?;
1104    }
1105
1106    let strategy = parse_swarm_strategy(metadata);
1107    let max_subagents = metadata_usize(
1108        metadata,
1109        &["swarm_max_subagents", "max_subagents", "subagents"],
1110    )
1111    .unwrap_or(10)
1112    .clamp(1, 100);
1113    let max_steps_per_subagent = metadata_usize(
1114        metadata,
1115        &[
1116            "swarm_max_steps_per_subagent",
1117            "max_steps_per_subagent",
1118            "max_steps",
1119        ],
1120    )
1121    .unwrap_or(50)
1122    .clamp(1, 200);
1123    let timeout_secs = metadata_u64(metadata, &["swarm_timeout_secs", "timeout_secs", "timeout"])
1124        .unwrap_or(600)
1125        .clamp(30, 3600);
1126    let parallel_enabled =
1127        metadata_bool(metadata, &["swarm_parallel_enabled", "parallel_enabled"]).unwrap_or(true);
1128
1129    let model = resolve_swarm_model(explicit_model, model_tier).await;
1130    if let Some(ref selected_model) = model {
1131        session.metadata.model = Some(selected_model.clone());
1132    }
1133
1134    if let Some(cb) = output_callback {
1135        cb(format!(
1136            "[swarm] routing complexity={} tier={} personality={} target_agent={}",
1137            complexity_hint.unwrap_or("standard"),
1138            model_tier.unwrap_or("balanced"),
1139            worker_personality.unwrap_or("auto"),
1140            target_agent_name.unwrap_or("auto")
1141        ));
1142        cb(format!(
1143            "[swarm] config strategy={:?} max_subagents={} max_steps={} timeout={}s tier={}",
1144            strategy,
1145            max_subagents,
1146            max_steps_per_subagent,
1147            timeout_secs,
1148            model_tier.unwrap_or("balanced")
1149        ));
1150    }
1151
1152    let swarm_config = SwarmConfig {
1153        max_subagents,
1154        max_steps_per_subagent,
1155        subagent_timeout_secs: timeout_secs,
1156        parallel_enabled,
1157        model,
1158        working_dir: session
1159            .metadata
1160            .directory
1161            .as_ref()
1162            .map(|p| p.to_string_lossy().to_string()),
1163        ..Default::default()
1164    };
1165
1166    let swarm_result = if output_callback.is_some() {
1167        let (event_tx, mut event_rx) = mpsc::channel(256);
1168        let executor = SwarmExecutor::new(swarm_config).with_event_tx(event_tx);
1169        let prompt_owned = prompt.to_string();
1170        let mut exec_handle =
1171            tokio::spawn(async move { executor.execute(&prompt_owned, strategy).await });
1172
1173        let mut final_result: Option<crate::swarm::SwarmResult> = None;
1174
1175        while final_result.is_none() {
1176            tokio::select! {
1177                maybe_event = event_rx.recv() => {
1178                    if let Some(event) = maybe_event {
1179                        if let Some(cb) = output_callback {
1180                            if let Some(line) = format_swarm_event_for_output(&event) {
1181                                cb(line);
1182                            }
1183                        }
1184                    }
1185                }
1186                join_result = &mut exec_handle => {
1187                    let joined = join_result.map_err(|e| anyhow::anyhow!("Swarm join failure: {}", e))?;
1188                    final_result = Some(joined?);
1189                }
1190            }
1191        }
1192
1193        while let Ok(event) = event_rx.try_recv() {
1194            if let Some(cb) = output_callback {
1195                if let Some(line) = format_swarm_event_for_output(&event) {
1196                    cb(line);
1197                }
1198            }
1199        }
1200
1201        final_result.ok_or_else(|| anyhow::anyhow!("Swarm execution returned no result"))?
1202    } else {
1203        SwarmExecutor::new(swarm_config)
1204            .execute(prompt, strategy)
1205            .await?
1206    };
1207
1208    let final_text = if swarm_result.result.trim().is_empty() {
1209        if swarm_result.success {
1210            "Swarm completed without textual output.".to_string()
1211        } else {
1212            "Swarm finished with failures and no textual output.".to_string()
1213        }
1214    } else {
1215        swarm_result.result.clone()
1216    };
1217
1218    session.add_message(Message {
1219        role: Role::Assistant,
1220        content: vec![ContentPart::Text {
1221            text: final_text.clone(),
1222        }],
1223    });
1224    session.save().await?;
1225
1226    Ok((
1227        crate::session::SessionResult {
1228            text: final_text,
1229            session_id: session.id.clone(),
1230        },
1231        swarm_result.success,
1232    ))
1233}
1234
1235/// Execute a session with the given auto-approve policy
1236/// Optionally streams output chunks via the callback
1237async fn execute_session_with_policy<F>(
1238    session: &mut Session,
1239    prompt: &str,
1240    auto_approve: AutoApprove,
1241    model_tier: Option<&str>,
1242    output_callback: Option<&F>,
1243) -> Result<crate::session::SessionResult>
1244where
1245    F: Fn(String),
1246{
1247    use crate::provider::{
1248        CompletionRequest, ContentPart, Message, ProviderRegistry, Role, parse_model_string,
1249    };
1250    use std::sync::Arc;
1251
1252    // Load provider registry from Vault
1253    let registry = ProviderRegistry::from_vault().await?;
1254    let providers = registry.list();
1255    tracing::info!("Available providers: {:?}", providers);
1256
1257    if providers.is_empty() {
1258        anyhow::bail!("No providers available. Configure API keys in HashiCorp Vault.");
1259    }
1260
1261    // Parse model string
1262    let (provider_name, model_id) = if let Some(ref model_str) = session.metadata.model {
1263        let (prov, model) = parse_model_string(model_str);
1264        if prov.is_some() {
1265            (prov.map(|s| s.to_string()), model.to_string())
1266        } else if providers.contains(&model) {
1267            (Some(model.to_string()), String::new())
1268        } else {
1269            (None, model.to_string())
1270        }
1271    } else {
1272        (None, String::new())
1273    };
1274
1275    let provider_slice = providers.as_slice();
1276    let provider_requested_but_unavailable = provider_name
1277        .as_deref()
1278        .map(|p| !providers.contains(&p))
1279        .unwrap_or(false);
1280
1281    // Determine which provider to use, preferring explicit request first, then model tier.
1282    let selected_provider = provider_name
1283        .as_deref()
1284        .filter(|p| providers.contains(p))
1285        .unwrap_or_else(|| choose_provider_for_tier(provider_slice, model_tier));
1286
1287    let provider = registry
1288        .get(selected_provider)
1289        .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;
1290
1291    // Add user message
1292    session.add_message(Message {
1293        role: Role::User,
1294        content: vec![ContentPart::Text {
1295            text: prompt.to_string(),
1296        }],
1297    });
1298
1299    // Generate title
1300    if session.title.is_none() {
1301        session.generate_title().await?;
1302    }
1303
1304    // Determine model. If a specific provider was requested but not available,
1305    // ignore that model id and fall back to the tier-based default model.
1306    let model = if !model_id.is_empty() && !provider_requested_but_unavailable {
1307        model_id
1308    } else {
1309        default_model_for_provider(selected_provider, model_tier)
1310    };
1311
1312    // Create tool registry with filtering based on auto-approve policy
1313    let tool_registry =
1314        create_filtered_registry(Arc::clone(&provider), model.clone(), auto_approve);
1315    let tool_definitions = tool_registry.definitions();
1316
1317    let temperature = if model.starts_with("kimi-k2") {
1318        Some(1.0)
1319    } else {
1320        Some(0.7)
1321    };
1322
1323    tracing::info!(
1324        "Using model: {} via provider: {} (tier: {:?})",
1325        model,
1326        selected_provider,
1327        model_tier
1328    );
1329    tracing::info!(
1330        "Available tools: {} (auto_approve: {:?})",
1331        tool_definitions.len(),
1332        auto_approve
1333    );
1334
1335    // Build system prompt
1336    let cwd = std::env::var("PWD")
1337        .map(std::path::PathBuf::from)
1338        .unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
1339    let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);
1340
1341    let mut final_output = String::new();
1342    let max_steps = 50;
1343
1344    for step in 1..=max_steps {
1345        tracing::info!(step = step, "Agent step starting");
1346
1347        // Build messages with system prompt first
1348        let mut messages = vec![Message {
1349            role: Role::System,
1350            content: vec![ContentPart::Text {
1351                text: system_prompt.clone(),
1352            }],
1353        }];
1354        messages.extend(session.messages.clone());
1355
1356        let request = CompletionRequest {
1357            messages,
1358            tools: tool_definitions.clone(),
1359            model: model.clone(),
1360            temperature,
1361            top_p: None,
1362            max_tokens: Some(8192),
1363            stop: Vec::new(),
1364        };
1365
1366        let response = provider.complete(request).await?;
1367
1368        crate::telemetry::TOKEN_USAGE.record_model_usage(
1369            &model,
1370            response.usage.prompt_tokens as u64,
1371            response.usage.completion_tokens as u64,
1372        );
1373
1374        // Extract tool calls
1375        let tool_calls: Vec<(String, String, serde_json::Value)> = response
1376            .message
1377            .content
1378            .iter()
1379            .filter_map(|part| {
1380                if let ContentPart::ToolCall {
1381                    id,
1382                    name,
1383                    arguments,
1384                } = part
1385                {
1386                    let args: serde_json::Value =
1387                        serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
1388                    Some((id.clone(), name.clone(), args))
1389                } else {
1390                    None
1391                }
1392            })
1393            .collect();
1394
1395        // Collect text output and stream it
1396        for part in &response.message.content {
1397            if let ContentPart::Text { text } = part {
1398                if !text.is_empty() {
1399                    final_output.push_str(text);
1400                    final_output.push('\n');
1401                    if let Some(cb) = output_callback {
1402                        cb(text.clone());
1403                    }
1404                }
1405            }
1406        }
1407
1408        // If no tool calls, we're done
1409        if tool_calls.is_empty() {
1410            session.add_message(response.message.clone());
1411            break;
1412        }
1413
1414        session.add_message(response.message.clone());
1415
1416        tracing::info!(
1417            step = step,
1418            num_tools = tool_calls.len(),
1419            "Executing tool calls"
1420        );
1421
1422        // Execute each tool call
1423        for (tool_id, tool_name, tool_input) in tool_calls {
1424            tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");
1425
1426            // Stream tool start event
1427            if let Some(cb) = output_callback {
1428                cb(format!("[tool:start:{}]", tool_name));
1429            }
1430
1431            // Check if tool is allowed based on auto-approve policy
1432            if !is_tool_allowed(&tool_name, auto_approve) {
1433                let msg = format!(
1434                    "Tool '{}' requires approval but auto-approve policy is {:?}",
1435                    tool_name, auto_approve
1436                );
1437                tracing::warn!(tool = %tool_name, "Tool blocked by auto-approve policy");
1438                session.add_message(Message {
1439                    role: Role::Tool,
1440                    content: vec![ContentPart::ToolResult {
1441                        tool_call_id: tool_id,
1442                        content: msg,
1443                    }],
1444                });
1445                continue;
1446            }
1447
1448            let content = if let Some(tool) = tool_registry.get(&tool_name) {
1449                let exec_result: Result<crate::tool::ToolResult> =
1450                    tool.execute(tool_input.clone()).await;
1451                match exec_result {
1452                    Ok(result) => {
1453                        tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
1454                        if let Some(cb) = output_callback {
1455                            let status = if result.success { "ok" } else { "err" };
1456                            cb(format!(
1457                                "[tool:{}:{}] {}",
1458                                tool_name,
1459                                status,
1460                                &result.output[..result.output.len().min(500)]
1461                            ));
1462                        }
1463                        result.output
1464                    }
1465                    Err(e) => {
1466                        tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
1467                        if let Some(cb) = output_callback {
1468                            cb(format!("[tool:{}:err] {}", tool_name, e));
1469                        }
1470                        format!("Error: {}", e)
1471                    }
1472                }
1473            } else {
1474                tracing::warn!(tool = %tool_name, "Tool not found");
1475                format!("Error: Unknown tool '{}'", tool_name)
1476            };
1477
1478            session.add_message(Message {
1479                role: Role::Tool,
1480                content: vec![ContentPart::ToolResult {
1481                    tool_call_id: tool_id,
1482                    content,
1483                }],
1484            });
1485        }
1486    }
1487
1488    session.save().await?;
1489
1490    Ok(crate::session::SessionResult {
1491        text: final_output.trim().to_string(),
1492        session_id: session.id.clone(),
1493    })
1494}
1495
1496/// Check if a tool is allowed based on the auto-approve policy
1497fn is_tool_allowed(tool_name: &str, auto_approve: AutoApprove) -> bool {
1498    match auto_approve {
1499        AutoApprove::All => true,
1500        AutoApprove::Safe | AutoApprove::None => is_safe_tool(tool_name),
1501    }
1502}
1503
/// Return `true` if `tool_name` belongs to the fixed allow-list of read-only tools.
///
/// The comparison is exact and case-sensitive.
fn is_safe_tool(tool_name: &str) -> bool {
    matches!(
        tool_name,
        "read"
            | "list"
            | "glob"
            | "grep"
            | "codesearch"
            | "lsp"
            | "webfetch"
            | "websearch"
            | "todo_read"
            | "skill"
    )
}
1520
1521/// Create a filtered tool registry based on the auto-approve policy
1522fn create_filtered_registry(
1523    provider: Arc<dyn crate::provider::Provider>,
1524    model: String,
1525    auto_approve: AutoApprove,
1526) -> crate::tool::ToolRegistry {
1527    use crate::tool::*;
1528
1529    let mut registry = ToolRegistry::new();
1530
1531    // Always add safe tools
1532    registry.register(Arc::new(file::ReadTool::new()));
1533    registry.register(Arc::new(file::ListTool::new()));
1534    registry.register(Arc::new(file::GlobTool::new()));
1535    registry.register(Arc::new(search::GrepTool::new()));
1536    registry.register(Arc::new(lsp::LspTool::new()));
1537    registry.register(Arc::new(webfetch::WebFetchTool::new()));
1538    registry.register(Arc::new(websearch::WebSearchTool::new()));
1539    registry.register(Arc::new(codesearch::CodeSearchTool::new()));
1540    registry.register(Arc::new(todo::TodoReadTool::new()));
1541    registry.register(Arc::new(skill::SkillTool::new()));
1542
1543    // Add potentially dangerous tools only if auto_approve is All
1544    if matches!(auto_approve, AutoApprove::All) {
1545        registry.register(Arc::new(file::WriteTool::new()));
1546        registry.register(Arc::new(edit::EditTool::new()));
1547        registry.register(Arc::new(bash::BashTool::new()));
1548        registry.register(Arc::new(multiedit::MultiEditTool::new()));
1549        registry.register(Arc::new(patch::ApplyPatchTool::new()));
1550        registry.register(Arc::new(todo::TodoWriteTool::new()));
1551        registry.register(Arc::new(task::TaskTool::new()));
1552        registry.register(Arc::new(plan::PlanEnterTool::new()));
1553        registry.register(Arc::new(plan::PlanExitTool::new()));
1554        registry.register(Arc::new(rlm::RlmTool::new()));
1555        registry.register(Arc::new(ralph::RalphTool::with_provider(provider, model)));
1556        registry.register(Arc::new(prd::PrdTool::new()));
1557        registry.register(Arc::new(confirm_edit::ConfirmEditTool::new()));
1558        registry.register(Arc::new(confirm_multiedit::ConfirmMultiEditTool::new()));
1559        registry.register(Arc::new(undo::UndoTool));
1560    }
1561
1562    registry.register(Arc::new(invalid::InvalidTool::new()));
1563
1564    registry
1565}
1566
1567/// Start the heartbeat background task
1568/// Returns a JoinHandle that can be used to cancel the heartbeat
1569fn start_heartbeat(
1570    client: Client,
1571    server: String,
1572    token: Option<String>,
1573    heartbeat_state: HeartbeatState,
1574    processing: Arc<Mutex<HashSet<String>>>,
1575) -> JoinHandle<()> {
1576    tokio::spawn(async move {
1577        let mut consecutive_failures = 0u32;
1578        const MAX_FAILURES: u32 = 3;
1579        const HEARTBEAT_INTERVAL_SECS: u64 = 30;
1580
1581        let mut interval =
1582            tokio::time::interval(tokio::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
1583        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
1584
1585        loop {
1586            interval.tick().await;
1587
1588            // Update task count from processing set
1589            let active_count = processing.lock().await.len();
1590            heartbeat_state.set_task_count(active_count).await;
1591
1592            // Determine status based on active tasks
1593            let status = if active_count > 0 {
1594                WorkerStatus::Processing
1595            } else {
1596                WorkerStatus::Idle
1597            };
1598            heartbeat_state.set_status(status).await;
1599
1600            // Send heartbeat
1601            let url = format!(
1602                "{}/v1/opencode/workers/{}/heartbeat",
1603                server, heartbeat_state.worker_id
1604            );
1605            let mut req = client.post(&url);
1606
1607            if let Some(ref t) = token {
1608                req = req.bearer_auth(t);
1609            }
1610
1611            let status_str = heartbeat_state.status.lock().await.as_str().to_string();
1612            let payload = serde_json::json!({
1613                "worker_id": &heartbeat_state.worker_id,
1614                "agent_name": &heartbeat_state.agent_name,
1615                "status": status_str,
1616                "active_task_count": active_count,
1617            });
1618
1619            match req.json(&payload).send().await {
1620                Ok(res) => {
1621                    if res.status().is_success() {
1622                        consecutive_failures = 0;
1623                        tracing::debug!(
1624                            worker_id = %heartbeat_state.worker_id,
1625                            status = status_str,
1626                            active_tasks = active_count,
1627                            "Heartbeat sent successfully"
1628                        );
1629                    } else {
1630                        consecutive_failures += 1;
1631                        tracing::warn!(
1632                            worker_id = %heartbeat_state.worker_id,
1633                            status = %res.status(),
1634                            failures = consecutive_failures,
1635                            "Heartbeat failed"
1636                        );
1637                    }
1638                }
1639                Err(e) => {
1640                    consecutive_failures += 1;
1641                    tracing::warn!(
1642                        worker_id = %heartbeat_state.worker_id,
1643                        error = %e,
1644                        failures = consecutive_failures,
1645                        "Heartbeat request failed"
1646                    );
1647                }
1648            }
1649
1650            // Log error after 3 consecutive failures but do not terminate
1651            if consecutive_failures >= MAX_FAILURES {
1652                tracing::error!(
1653                    worker_id = %heartbeat_state.worker_id,
1654                    failures = consecutive_failures,
1655                    "Heartbeat failed {} consecutive times - worker will continue running and attempt reconnection via SSE loop",
1656                    MAX_FAILURES
1657                );
1658                // Reset counter to avoid spamming error logs
1659                consecutive_failures = 0;
1660            }
1661        }
1662    })
1663}