Skip to main content

codetether_agent/a2a/
worker.rs

1//! A2A Worker - connects to an A2A server to process tasks
2
3use crate::cli::A2aArgs;
4use crate::provider::ProviderRegistry;
5use crate::session::Session;
6use crate::swarm::{DecompositionStrategy, SwarmConfig, SwarmExecutor};
7use crate::tui::swarm_view::SwarmEvent;
8use anyhow::Result;
9use futures::StreamExt;
10use reqwest::Client;
11use serde::Deserialize;
12use std::collections::HashMap;
13use std::collections::HashSet;
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::sync::{Mutex, mpsc};
17use tokio::task::JoinHandle;
18use tokio::time::Instant;
19
/// Worker lifecycle status reported in heartbeats.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WorkerStatus {
    Idle,
    Processing,
}

impl WorkerStatus {
    /// Wire representation of the status ("idle" / "processing").
    fn as_str(&self) -> &'static str {
        match self {
            Self::Idle => "idle",
            Self::Processing => "processing",
        }
    }
}
35
/// Heartbeat state shared between the heartbeat task and the main worker
#[derive(Clone)]
struct HeartbeatState {
    /// Stable id generated at startup (see `generate_worker_id`).
    worker_id: String,
    /// Human-readable agent name reported alongside the id.
    agent_name: String,
    /// Current status; `Arc<Mutex<..>>` so the heartbeat task and the
    /// worker loop can share and update it concurrently.
    status: Arc<Mutex<WorkerStatus>>,
    /// Number of tasks currently being processed.
    active_task_count: Arc<Mutex<usize>>,
}
44
45impl HeartbeatState {
46    fn new(worker_id: String, agent_name: String) -> Self {
47        Self {
48            worker_id,
49            agent_name,
50            status: Arc::new(Mutex::new(WorkerStatus::Idle)),
51            active_task_count: Arc::new(Mutex::new(0)),
52        }
53    }
54
55    async fn set_status(&self, status: WorkerStatus) {
56        *self.status.lock().await = status;
57    }
58
59    async fn set_task_count(&self, count: usize) {
60        *self.active_task_count.lock().await = count;
61    }
62}
63
/// Settings controlling whether (and how) cognition state is attached to
/// worker heartbeats. Populated from environment variables in `from_env`.
#[derive(Clone, Debug)]
struct CognitionHeartbeatConfig {
    /// Master switch (CODETETHER_WORKER_COGNITION_SHARE_ENABLED).
    enabled: bool,
    /// Base URL of the cognition source service, trailing '/' stripped.
    source_base_url: String,
    /// Whether to include a thought summary (…_INCLUDE_THOUGHTS).
    include_thought_summary: bool,
    /// Cap on summary length in characters; floored at 120 in `from_env`.
    summary_max_chars: usize,
    /// Request timeout in milliseconds; floored at 250 in `from_env`.
    request_timeout_ms: u64,
}
72
73impl CognitionHeartbeatConfig {
74    fn from_env() -> Self {
75        let source_base_url = std::env::var("CODETETHER_WORKER_COGNITION_SOURCE_URL")
76            .unwrap_or_else(|_| "http://127.0.0.1:4096".to_string())
77            .trim_end_matches('/')
78            .to_string();
79
80        Self {
81            enabled: env_bool("CODETETHER_WORKER_COGNITION_SHARE_ENABLED", true),
82            source_base_url,
83            include_thought_summary: env_bool("CODETETHER_WORKER_COGNITION_INCLUDE_THOUGHTS", true),
84            summary_max_chars: env_usize("CODETETHER_WORKER_COGNITION_THOUGHT_MAX_CHARS", 480)
85                .max(120),
86            request_timeout_ms: env_u64("CODETETHER_WORKER_COGNITION_TIMEOUT_MS", 2_500).max(250),
87        }
88    }
89}
90
/// Deserialized status payload from the cognition service.
/// NOTE(review): the endpoint schema is defined outside this file — field
/// meanings below are inferred from names; confirm against the service.
#[derive(Debug, Deserialize)]
struct CognitionStatusSnapshot {
    running: bool,
    #[serde(default)]
    last_tick_at: Option<String>,
    #[serde(default)]
    active_persona_count: usize,
    #[serde(default)]
    events_buffered: usize,
    #[serde(default)]
    snapshots_buffered: usize,
    #[serde(default)]
    loop_interval_ms: u64,
}
105
/// Deserialized "latest snapshot" payload from the cognition service.
/// NOTE(review): schema defined by the cognition endpoint, not visible
/// here — field meanings inferred from names; confirm against the service.
#[derive(Debug, Deserialize)]
struct CognitionLatestSnapshot {
    generated_at: String,
    summary: String,
    /// Free-form extra fields; absent metadata deserializes to an empty map.
    #[serde(default)]
    metadata: HashMap<String, serde_json::Value>,
}
113
/// Run the A2A worker
///
/// Registers this process as a worker with the A2A server, drains any
/// already-pending tasks once, then loops forever: re-register, start a
/// heartbeat task, and consume the server's SSE task stream, reconnecting
/// after a 5s pause whenever the stream ends or errors.
///
/// # Errors
///
/// Fails only if the initial registration or the initial pending-task fetch
/// fails; errors inside the reconnect loop are logged and retried.
pub async fn run(args: A2aArgs) -> Result<()> {
    let server = args.server.trim_end_matches('/');
    // Display name defaults to "codetether-<pid>".
    let name = args
        .name
        .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
    let worker_id = generate_worker_id();

    // Comma-separated --codebases list; defaults to the current directory.
    // NOTE(review): current_dir().unwrap() panics if the cwd is unreadable —
    // consider a graceful fallback.
    let codebases: Vec<String> = args
        .codebases
        .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
        .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);

    tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
    tracing::info!("Server: {}", server);
    tracing::info!("Codebases: {:?}", codebases);

    let client = Client::new();
    // Task ids currently being handled, so a task arriving from both the
    // SSE stream and a pending-task poll is only processed once.
    let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
    let cognition_heartbeat = CognitionHeartbeatConfig::from_env();
    if cognition_heartbeat.enabled {
        tracing::info!(
            source = %cognition_heartbeat.source_base_url,
            include_thoughts = cognition_heartbeat.include_thought_summary,
            max_chars = cognition_heartbeat.summary_max_chars,
            timeout_ms = cognition_heartbeat.request_timeout_ms,
            "Cognition heartbeat sharing enabled (set CODETETHER_WORKER_COGNITION_SHARE_ENABLED=false to disable)"
        );
    } else {
        tracing::warn!(
            "Cognition heartbeat sharing disabled; worker thought state will not be shared upstream"
        );
    }

    // Unrecognized --auto-approve values fall back to no auto-approval.
    let auto_approve = match args.auto_approve.as_str() {
        "all" => AutoApprove::All,
        "safe" => AutoApprove::Safe,
        _ => AutoApprove::None,
    };

    // Create heartbeat state
    let heartbeat_state = HeartbeatState::new(worker_id.clone(), name.clone());

    // Register worker
    register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;

    // Fetch pending tasks
    fetch_pending_tasks(
        &client,
        server,
        &args.token,
        &worker_id,
        &processing,
        &auto_approve,
    )
    .await?;

    // Connect to SSE stream
    loop {
        // Re-register worker on each reconnection to report updated models/capabilities
        if let Err(e) =
            register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await
        {
            tracing::warn!("Failed to re-register worker on reconnection: {}", e);
        }

        // Start heartbeat task for this connection
        let heartbeat_handle = start_heartbeat(
            client.clone(),
            server.to_string(),
            args.token.clone(),
            heartbeat_state.clone(),
            processing.clone(),
            cognition_heartbeat.clone(),
        );

        match connect_stream(
            &client,
            server,
            &args.token,
            &worker_id,
            &name,
            &codebases,
            &processing,
            &auto_approve,
        )
        .await
        {
            Ok(()) => {
                tracing::warn!("Stream ended, reconnecting...");
            }
            Err(e) => {
                tracing::error!("Stream error: {}, reconnecting...", e);
            }
        }

        // Cancel heartbeat on disconnection
        heartbeat_handle.abort();
        tracing::debug!("Heartbeat cancelled for reconnection");

        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
    }
}
217
218fn generate_worker_id() -> String {
219    format!(
220        "wrk_{}_{:x}",
221        chrono::Utc::now().timestamp(),
222        rand::random::<u64>()
223    )
224}
225
/// Tool-call auto-approval policy, parsed from the `--auto-approve` CLI
/// value in `run` ("all" / "safe"; anything else maps to `None`).
#[derive(Debug, Clone, Copy)]
enum AutoApprove {
    /// CLI value "all": approve everything.
    All,
    /// CLI value "safe": approve only safe operations — enforcement lives
    /// in `handle_task`, outside this view.
    Safe,
    /// No auto-approval (default for unrecognized values).
    None,
}
232
/// Default A2A server URL when none is configured
pub const DEFAULT_A2A_SERVER_URL: &str = "https://api.codetether.run";

/// Capabilities of the codetether-agent worker, advertised to the server
/// at registration time (see `register_worker`).
const WORKER_CAPABILITIES: &[&str] = &["ralph", "swarm", "rlm", "a2a", "mcp"];
238
239fn task_value<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a serde_json::Value> {
240    task.get("task")
241        .and_then(|t| t.get(key))
242        .or_else(|| task.get(key))
243}
244
245fn task_str<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a str> {
246    task_value(task, key).and_then(|v| v.as_str())
247}
248
249fn task_metadata(task: &serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
250    task_value(task, "metadata")
251        .and_then(|m| m.as_object())
252        .cloned()
253        .unwrap_or_default()
254}
255
/// Normalize a "provider:model" reference to "provider/model".
///
/// Only the FIRST ':' is rewritten, and only when the ref has no '/' yet:
/// full refs like "bedrock/amazon.nova-micro-v1:0" already contain '/',
/// and their remaining colons are version separators that must be kept.
fn model_ref_to_provider_model(model: &str) -> String {
    if model.contains('/') || !model.contains(':') {
        return model.to_string();
    }
    model.replacen(':', "/", 1)
}
266
/// Ordered provider preference list for a given model tier.
///
/// Tiers: "fast"/"quick" favor low-latency providers, "heavy"/"deep" put
/// strong-reasoning providers first, and anything else (including `None`)
/// uses the balanced ordering.
fn provider_preferences_for_tier(model_tier: Option<&str>) -> &'static [&'static str] {
    match model_tier.unwrap_or("balanced") {
        "fast" | "quick" => &[
            "openai",
            "github-copilot",
            "moonshotai",
            "zhipuai",
            "openrouter",
            "novita",
            "google",
            "anthropic",
        ],
        "heavy" | "deep" => &[
            "anthropic",
            "openai",
            "github-copilot",
            "moonshotai",
            "zhipuai",
            "openrouter",
            "novita",
            "google",
        ],
        _ => &[
            "openai",
            "github-copilot",
            "anthropic",
            "moonshotai",
            "zhipuai",
            "openrouter",
            "novita",
            "google",
        ],
    }
}

/// Pick the best available provider for `model_tier`.
///
/// Returns the first provider from the tier's preference list that is
/// present in `providers`; if none of the preferred providers are
/// available, falls back to the first provider given.
///
/// # Panics
///
/// Panics if `providers` is empty (callers guard with `is_empty` first).
fn choose_provider_for_tier<'a>(providers: &'a [&'a str], model_tier: Option<&str>) -> &'a str {
    for preferred in provider_preferences_for_tier(model_tier) {
        if let Some(found) = providers.iter().copied().find(|p| *p == *preferred) {
            return found;
        }
    }
    // The old dedicated "zhipuai" fallback was dead code: every preference
    // list contains "zhipuai", so the loop above would already have
    // returned it. Fall straight back to the first available provider.
    providers[0]
}
313
/// Default model id for a provider at a given tier.
///
/// Tier-specific picks (fast/quick and heavy/deep) are listed first; the
/// remaining providers use the same model across every tier. Unknown
/// providers fall back to "glm-4.7".
fn default_model_for_provider(provider: &str, model_tier: Option<&str>) -> String {
    let tier = model_tier.unwrap_or("balanced");
    let model = match (tier, provider) {
        // Tier-sensitive providers.
        ("fast" | "quick", "anthropic") => "claude-haiku-4-5",
        ("fast" | "quick", "openai") => "gpt-4o-mini",
        ("fast" | "quick", "google") => "gemini-2.5-flash",
        ("heavy" | "deep", "anthropic") => "claude-sonnet-4-20250514",
        ("heavy" | "deep", "openai") => "o3",
        ("heavy" | "deep", "google") => "gemini-2.5-pro",
        // Balanced tier for the tier-sensitive providers.
        (_, "anthropic") => "claude-sonnet-4-20250514",
        (_, "openai") => "gpt-4o",
        (_, "google") => "gemini-2.5-pro",
        // Tier-independent providers.
        (_, "moonshotai") => "kimi-k2.5",
        (_, "openrouter") => "z-ai/glm-4.7",
        (_, "novita") => "qwen/qwen3-coder-next",
        // Everything else (including "zhipuai") uses glm-4.7.
        _ => "glm-4.7",
    };
    model.to_string()
}
348
/// True when `agent_type` names a swarm-style (multi-agent) run.
/// Matching is case-insensitive and ignores surrounding whitespace.
fn is_swarm_agent(agent_type: &str) -> bool {
    let normalized = agent_type.trim().to_ascii_lowercase();
    normalized == "swarm" || normalized == "parallel" || normalized == "multi-agent"
}
355
356fn metadata_lookup<'a>(
357    metadata: &'a serde_json::Map<String, serde_json::Value>,
358    key: &str,
359) -> Option<&'a serde_json::Value> {
360    metadata
361        .get(key)
362        .or_else(|| {
363            metadata
364                .get("routing")
365                .and_then(|v| v.as_object())
366                .and_then(|obj| obj.get(key))
367        })
368        .or_else(|| {
369            metadata
370                .get("swarm")
371                .and_then(|v| v.as_object())
372                .and_then(|obj| obj.get(key))
373        })
374}
375
376fn metadata_str(
377    metadata: &serde_json::Map<String, serde_json::Value>,
378    keys: &[&str],
379) -> Option<String> {
380    for key in keys {
381        if let Some(value) = metadata_lookup(metadata, key).and_then(|v| v.as_str()) {
382            let trimmed = value.trim();
383            if !trimmed.is_empty() {
384                return Some(trimmed.to_string());
385            }
386        }
387    }
388    None
389}
390
391fn metadata_usize(
392    metadata: &serde_json::Map<String, serde_json::Value>,
393    keys: &[&str],
394) -> Option<usize> {
395    for key in keys {
396        if let Some(value) = metadata_lookup(metadata, key) {
397            if let Some(v) = value.as_u64() {
398                return usize::try_from(v).ok();
399            }
400            if let Some(v) = value.as_i64() {
401                if v >= 0 {
402                    return usize::try_from(v as u64).ok();
403                }
404            }
405            if let Some(v) = value.as_str() {
406                if let Ok(parsed) = v.trim().parse::<usize>() {
407                    return Some(parsed);
408                }
409            }
410        }
411    }
412    None
413}
414
415fn metadata_u64(
416    metadata: &serde_json::Map<String, serde_json::Value>,
417    keys: &[&str],
418) -> Option<u64> {
419    for key in keys {
420        if let Some(value) = metadata_lookup(metadata, key) {
421            if let Some(v) = value.as_u64() {
422                return Some(v);
423            }
424            if let Some(v) = value.as_i64() {
425                if v >= 0 {
426                    return Some(v as u64);
427                }
428            }
429            if let Some(v) = value.as_str() {
430                if let Ok(parsed) = v.trim().parse::<u64>() {
431                    return Some(parsed);
432                }
433            }
434        }
435    }
436    None
437}
438
439fn metadata_bool(
440    metadata: &serde_json::Map<String, serde_json::Value>,
441    keys: &[&str],
442) -> Option<bool> {
443    for key in keys {
444        if let Some(value) = metadata_lookup(metadata, key) {
445            if let Some(v) = value.as_bool() {
446                return Some(v);
447            }
448            if let Some(v) = value.as_str() {
449                match v.trim().to_ascii_lowercase().as_str() {
450                    "1" | "true" | "yes" | "on" => return Some(true),
451                    "0" | "false" | "no" | "off" => return Some(false),
452                    _ => {}
453                }
454            }
455        }
456    }
457    None
458}
459
460fn parse_swarm_strategy(
461    metadata: &serde_json::Map<String, serde_json::Value>,
462) -> DecompositionStrategy {
463    match metadata_str(
464        metadata,
465        &[
466            "decomposition_strategy",
467            "swarm_strategy",
468            "strategy",
469            "swarm_decomposition",
470        ],
471    )
472    .as_deref()
473    .map(|s| s.to_ascii_lowercase())
474    .as_deref()
475    {
476        Some("none") | Some("single") => DecompositionStrategy::None,
477        Some("domain") | Some("by_domain") => DecompositionStrategy::ByDomain,
478        Some("data") | Some("by_data") => DecompositionStrategy::ByData,
479        Some("stage") | Some("by_stage") => DecompositionStrategy::ByStage,
480        _ => DecompositionStrategy::Automatic,
481    }
482}
483
484async fn resolve_swarm_model(
485    explicit_model: Option<String>,
486    model_tier: Option<&str>,
487) -> Option<String> {
488    if let Some(model) = explicit_model {
489        if !model.trim().is_empty() {
490            return Some(model);
491        }
492    }
493
494    let registry = ProviderRegistry::from_vault().await.ok()?;
495    let providers = registry.list();
496    if providers.is_empty() {
497        return None;
498    }
499    let provider = choose_provider_for_tier(providers.as_slice(), model_tier);
500    let model = default_model_for_provider(provider, model_tier);
501    Some(format!("{}/{}", provider, model))
502}
503
504fn format_swarm_event_for_output(event: &SwarmEvent) -> Option<String> {
505    match event {
506        SwarmEvent::Started {
507            task,
508            total_subtasks,
509        } => Some(format!(
510            "[swarm] started task={} planned_subtasks={}",
511            task, total_subtasks
512        )),
513        SwarmEvent::StageComplete {
514            stage,
515            completed,
516            failed,
517        } => Some(format!(
518            "[swarm] stage={} completed={} failed={}",
519            stage, completed, failed
520        )),
521        SwarmEvent::SubTaskUpdate { id, status, .. } => Some(format!(
522            "[swarm] subtask id={} status={}",
523            &id.chars().take(8).collect::<String>(),
524            format!("{status:?}").to_ascii_lowercase()
525        )),
526        SwarmEvent::AgentToolCall {
527            subtask_id,
528            tool_name,
529        } => Some(format!(
530            "[swarm] subtask id={} tool={}",
531            &subtask_id.chars().take(8).collect::<String>(),
532            tool_name
533        )),
534        SwarmEvent::AgentError { subtask_id, error } => Some(format!(
535            "[swarm] subtask id={} error={}",
536            &subtask_id.chars().take(8).collect::<String>(),
537            error
538        )),
539        SwarmEvent::Complete { success, stats } => Some(format!(
540            "[swarm] complete success={} subtasks={} speedup={:.2}",
541            success,
542            stats.subagents_completed + stats.subagents_failed,
543            stats.speedup_factor
544        )),
545        SwarmEvent::Error(err) => Some(format!("[swarm] error message={}", err)),
546        _ => None,
547    }
548}
549
/// Register (or re-register) this worker with the A2A server, advertising
/// its capabilities, codebases, and the models it can serve (with pricing
/// where known). A non-2xx response is logged but not treated as an error.
async fn register_worker(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
) -> Result<()> {
    // Load ProviderRegistry and collect available models
    let models = match load_provider_models().await {
        Ok(m) => m,
        Err(e) => {
            tracing::warn!(
                "Failed to load provider models: {}, proceeding without model info",
                e
            );
            HashMap::new()
        }
    };

    // Register via the workers/register endpoint
    let mut req = client.post(format!("{}/v1/opencode/workers/register", server));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    // Flatten models HashMap into array of model objects with pricing data
    // matching the format expected by the A2A server's /models and /workers endpoints
    let models_array: Vec<serde_json::Value> = models
        .iter()
        .flat_map(|(provider, model_infos)| {
            model_infos.iter().map(move |m| {
                let mut obj = serde_json::json!({
                    "id": format!("{}/{}", provider, m.id),
                    "name": &m.id,
                    "provider": provider,
                    "provider_id": provider,
                });
                // Pricing keys are only present when the costs are known.
                if let Some(input_cost) = m.input_cost_per_million {
                    obj["input_cost_per_million"] = serde_json::json!(input_cost);
                }
                if let Some(output_cost) = m.output_cost_per_million {
                    obj["output_cost_per_million"] = serde_json::json!(output_cost);
                }
                obj
            })
        })
        .collect();

    tracing::info!(
        "Registering worker with {} models from {} providers",
        models_array.len(),
        models.len()
    );

    // HOSTNAME (unix-style) or COMPUTERNAME (Windows), else "unknown".
    let hostname = std::env::var("HOSTNAME")
        .or_else(|_| std::env::var("COMPUTERNAME"))
        .unwrap_or_else(|_| "unknown".to_string());

    let res = req
        .json(&serde_json::json!({
            "worker_id": worker_id,
            "name": name,
            "capabilities": WORKER_CAPABILITIES,
            "hostname": hostname,
            "models": models_array,
            "codebases": codebases,
        }))
        .send()
        .await?;

    if res.status().is_success() {
        tracing::info!("Worker registered successfully");
    } else {
        // Deliberately non-fatal: the worker still attempts to stream tasks.
        tracing::warn!("Failed to register worker: {}", res.status());
    }

    Ok(())
}
630
/// Load ProviderRegistry and collect all available models grouped by provider.
/// Tries Vault first, then falls back to config/env vars if Vault is unreachable
/// or empty; if no provider yields any models, falls back to the models.dev
/// catalog. Returns ModelInfo structs (with pricing data when available).
async fn load_provider_models() -> Result<HashMap<String, Vec<crate::provider::ModelInfo>>> {
    // Try Vault first
    let registry = match ProviderRegistry::from_vault().await {
        Ok(r) if !r.list().is_empty() => {
            tracing::info!("Loaded {} providers from Vault", r.list().len());
            r
        }
        Ok(_) => {
            tracing::warn!("Vault returned 0 providers, falling back to config/env vars");
            fallback_registry().await?
        }
        Err(e) => {
            tracing::warn!("Vault unreachable ({}), falling back to config/env vars", e);
            fallback_registry().await?
        }
    };

    // Fetch the models.dev catalog for pricing data enrichment
    let catalog = crate::provider::models::ModelCatalog::fetch().await.ok();

    // Map provider IDs to their catalog equivalents (some differ)
    let catalog_alias = |pid: &str| -> String {
        match pid {
            "bedrock" => "amazon-bedrock".to_string(),
            "novita" => "novita-ai".to_string(),
            _ => pid.to_string(),
        }
    };

    let mut models_by_provider: HashMap<String, Vec<crate::provider::ModelInfo>> = HashMap::new();

    for provider_name in registry.list() {
        if let Some(provider) = registry.get(provider_name) {
            match provider.list_models().await {
                Ok(models) => {
                    let enriched: Vec<crate::provider::ModelInfo> = models
                        .into_iter()
                        .map(|mut m| {
                            // Enrich with catalog pricing if the provider didn't set it
                            if m.input_cost_per_million.is_none()
                                || m.output_cost_per_million.is_none()
                            {
                                if let Some(ref cat) = catalog {
                                    let cat_pid = catalog_alias(provider_name);
                                    if let Some(prov_info) = cat.get_provider(&cat_pid) {
                                        // Try exact match first, then strip "us." prefix (bedrock uses us.vendor.model format)
                                        let model_info =
                                            prov_info.models.get(&m.id).or_else(|| {
                                                m.id.strip_prefix("us.").and_then(|stripped| {
                                                    prov_info.models.get(stripped)
                                                })
                                            });
                                        if let Some(model_info) = model_info {
                                            if let Some(ref cost) = model_info.cost {
                                                // Never overwrite provider-supplied pricing.
                                                if m.input_cost_per_million.is_none() {
                                                    m.input_cost_per_million = Some(cost.input);
                                                }
                                                if m.output_cost_per_million.is_none() {
                                                    m.output_cost_per_million = Some(cost.output);
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            m
                        })
                        .collect();
                    // Providers that report zero models are omitted entirely.
                    if !enriched.is_empty() {
                        tracing::debug!("Provider {}: {} models", provider_name, enriched.len());
                        models_by_provider.insert(provider_name.to_string(), enriched);
                    }
                }
                Err(e) => {
                    tracing::debug!("Failed to list models for {}: {}", provider_name, e);
                }
            }
        }
    }

    // If we still have 0, try the models.dev catalog as last resort
    // NOTE: We list ALL models from the catalog (not just ones with verified API keys)
    // because the worker should advertise what it can handle. The server handles routing.
    if models_by_provider.is_empty() {
        tracing::info!(
            "No authenticated providers found, fetching models.dev catalog (all providers)"
        );
        if let Ok(cat) = crate::provider::models::ModelCatalog::fetch().await {
            // Walk every provider+model from the catalog regardless of API
            // key availability (Vault may be down).
            for (provider_id, provider_info) in cat.all_providers() {
                let model_infos: Vec<crate::provider::ModelInfo> = provider_info
                    .models
                    .values()
                    .map(|m| cat.to_model_info(m, provider_id))
                    .collect();
                if !model_infos.is_empty() {
                    tracing::debug!(
                        "Catalog provider {}: {} models",
                        provider_id,
                        model_infos.len()
                    );
                    models_by_provider.insert(provider_id.clone(), model_infos);
                }
            }
            tracing::info!(
                "Loaded {} providers with {} total models from catalog",
                models_by_provider.len(),
                models_by_provider.values().map(|v| v.len()).sum::<usize>()
            );
        }
    }

    Ok(models_by_provider)
}
749
750/// Fallback: build a ProviderRegistry from config file + environment variables
751async fn fallback_registry() -> Result<ProviderRegistry> {
752    let config = crate::config::Config::load().await.unwrap_or_default();
753    ProviderRegistry::from_config(&config).await
754}
755
/// One-shot poll of `/v1/opencode/tasks?status=pending`, spawning a handler
/// task for each pending task not already being processed.
///
/// A non-2xx server response is swallowed (returns `Ok(())`) so startup and
/// periodic polls never kill the worker; transport/JSON errors propagate.
async fn fetch_pending_tasks(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
) -> Result<()> {
    tracing::info!("Checking for pending tasks...");

    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        return Ok(());
    }

    let data: serde_json::Value = res.json().await?;
    // Handle both plain array response and {tasks: [...]} wrapper
    let tasks = if let Some(arr) = data.as_array() {
        arr.clone()
    } else {
        data["tasks"].as_array().cloned().unwrap_or_default()
    };

    tracing::info!("Found {} pending task(s)", tasks.len());

    for task in tasks {
        if let Some(id) = task["id"].as_str() {
            // Claim the id before spawning so the SSE stream or a concurrent
            // poll cannot start a second handler for the same task.
            let mut proc = processing.lock().await;
            if !proc.contains(id) {
                proc.insert(id.to_string());
                drop(proc);

                let task_id = id.to_string();
                let client = client.clone();
                let server = server.to_string();
                let token = token.clone();
                let worker_id = worker_id.to_string();
                let auto_approve = *auto_approve;
                let processing = processing.clone();

                tokio::spawn(async move {
                    if let Err(e) =
                        handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await
                    {
                        tracing::error!("Task {} failed: {}", task_id, e);
                    }
                    // Release the claim whether the task succeeded or failed.
                    processing.lock().await.remove(&task_id);
                });
            }
        }
    }

    Ok(())
}
815
/// Connect to the server's SSE task stream and dispatch each received task.
///
/// Alongside the stream, a 30s interval re-polls the pending-task endpoint
/// to catch tasks the stream may have missed. Returns `Ok(())` when the
/// server closes the stream, or an error on connect/transport failure; the
/// caller reconnects in both cases.
#[allow(clippy::too_many_arguments)]
async fn connect_stream(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
) -> Result<()> {
    let url = format!(
        "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
        server,
        urlencoding::encode(name),
        urlencoding::encode(worker_id)
    );

    let mut req = client
        .get(&url)
        .header("Accept", "text/event-stream")
        .header("X-Worker-ID", worker_id)
        .header("X-Agent-Name", name)
        .header("X-Codebases", codebases.join(","));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        anyhow::bail!("Failed to connect: {}", res.status());
    }

    tracing::info!("Connected to A2A server");

    let mut stream = res.bytes_stream();
    // Accumulates partial SSE frames until a blank-line terminator arrives.
    let mut buffer = String::new();
    let mut poll_interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
    poll_interval.tick().await; // consume the initial immediate tick

    loop {
        tokio::select! {
            chunk = stream.next() => {
                match chunk {
                    Some(Ok(chunk)) => {
                        buffer.push_str(&String::from_utf8_lossy(&chunk));

                        // Process SSE events (each frame ends with a blank line)
                        while let Some(pos) = buffer.find("\n\n") {
                            let event_str = buffer[..pos].to_string();
                            buffer = buffer[pos + 2..].to_string();

                            if let Some(data_line) = event_str.lines().find(|l| l.starts_with("data:")) {
                                let data = data_line.trim_start_matches("data:").trim();
                                if data == "[DONE]" || data.is_empty() {
                                    continue;
                                }

                                // Non-JSON payloads are silently ignored.
                                if let Ok(task) = serde_json::from_str::<serde_json::Value>(data) {
                                    spawn_task_handler(
                                        &task, client, server, token, worker_id,
                                        processing, auto_approve,
                                    ).await;
                                }
                            }
                        }
                    }
                    Some(Err(e)) => {
                        return Err(e.into());
                    }
                    None => {
                        // Stream ended
                        return Ok(());
                    }
                }
            }
            _ = poll_interval.tick() => {
                // Periodic poll for pending tasks the SSE stream may have missed
                if let Err(e) = poll_pending_tasks(
                    client, server, token, worker_id, processing, auto_approve,
                ).await {
                    tracing::warn!("Periodic task poll failed: {}", e);
                }
            }
        }
    }
}
904
905async fn spawn_task_handler(
906    task: &serde_json::Value,
907    client: &Client,
908    server: &str,
909    token: &Option<String>,
910    worker_id: &str,
911    processing: &Arc<Mutex<HashSet<String>>>,
912    auto_approve: &AutoApprove,
913) {
914    if let Some(id) = task
915        .get("task")
916        .and_then(|t| t["id"].as_str())
917        .or_else(|| task["id"].as_str())
918    {
919        let mut proc = processing.lock().await;
920        if !proc.contains(id) {
921            proc.insert(id.to_string());
922            drop(proc);
923
924            let task_id = id.to_string();
925            let task = task.clone();
926            let client = client.clone();
927            let server = server.to_string();
928            let token = token.clone();
929            let worker_id = worker_id.to_string();
930            let auto_approve = *auto_approve;
931            let processing_clone = processing.clone();
932
933            tokio::spawn(async move {
934                if let Err(e) =
935                    handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await
936                {
937                    tracing::error!("Task {} failed: {}", task_id, e);
938                }
939                processing_clone.lock().await.remove(&task_id);
940            });
941        }
942    }
943}
944
945async fn poll_pending_tasks(
946    client: &Client,
947    server: &str,
948    token: &Option<String>,
949    worker_id: &str,
950    processing: &Arc<Mutex<HashSet<String>>>,
951    auto_approve: &AutoApprove,
952) -> Result<()> {
953    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
954    if let Some(t) = token {
955        req = req.bearer_auth(t);
956    }
957
958    let res = req.send().await?;
959    if !res.status().is_success() {
960        return Ok(());
961    }
962
963    let data: serde_json::Value = res.json().await?;
964    let tasks = if let Some(arr) = data.as_array() {
965        arr.clone()
966    } else {
967        data["tasks"].as_array().cloned().unwrap_or_default()
968    };
969
970    if !tasks.is_empty() {
971        tracing::debug!("Poll found {} pending task(s)", tasks.len());
972    }
973
974    for task in &tasks {
975        spawn_task_handler(
976            task,
977            client,
978            server,
979            token,
980            worker_id,
981            processing,
982            auto_approve,
983        )
984        .await;
985    }
986
987    Ok(())
988}
989
990async fn handle_task(
991    client: &Client,
992    server: &str,
993    token: &Option<String>,
994    worker_id: &str,
995    task: &serde_json::Value,
996    auto_approve: AutoApprove,
997) -> Result<()> {
998    let task_id = task_str(task, "id").ok_or_else(|| anyhow::anyhow!("No task ID"))?;
999    let title = task_str(task, "title").unwrap_or("Untitled");
1000
1001    tracing::info!("Handling task: {} ({})", title, task_id);
1002
1003    // Claim the task
1004    let mut req = client
1005        .post(format!("{}/v1/worker/tasks/claim", server))
1006        .header("X-Worker-ID", worker_id);
1007    if let Some(t) = token {
1008        req = req.bearer_auth(t);
1009    }
1010
1011    let res = req
1012        .json(&serde_json::json!({ "task_id": task_id }))
1013        .send()
1014        .await?;
1015
1016    if !res.status().is_success() {
1017        let status = res.status();
1018        let text = res.text().await?;
1019        if status == reqwest::StatusCode::CONFLICT {
1020            tracing::debug!(task_id, "Task already claimed by another worker, skipping");
1021        } else {
1022            tracing::warn!(task_id, %status, "Failed to claim task: {}", text);
1023        }
1024        return Ok(());
1025    }
1026
1027    tracing::info!("Claimed task: {}", task_id);
1028
1029    let metadata = task_metadata(task);
1030    let resume_session_id = metadata
1031        .get("resume_session_id")
1032        .and_then(|v| v.as_str())
1033        .map(|s| s.trim().to_string())
1034        .filter(|s| !s.is_empty());
1035    let complexity_hint = metadata_str(&metadata, &["complexity"]);
1036    let model_tier = metadata_str(&metadata, &["model_tier", "tier"])
1037        .map(|s| s.to_ascii_lowercase())
1038        .or_else(|| {
1039            complexity_hint.as_ref().map(|complexity| {
1040                match complexity.to_ascii_lowercase().as_str() {
1041                    "quick" => "fast".to_string(),
1042                    "deep" => "heavy".to_string(),
1043                    _ => "balanced".to_string(),
1044                }
1045            })
1046        });
1047    let worker_personality = metadata_str(
1048        &metadata,
1049        &["worker_personality", "personality", "agent_personality"],
1050    );
1051    let target_agent_name = metadata_str(&metadata, &["target_agent_name", "agent_name"]);
1052    let raw_model = task_str(task, "model_ref")
1053        .or_else(|| metadata_lookup(&metadata, "model_ref").and_then(|v| v.as_str()))
1054        .or_else(|| task_str(task, "model"))
1055        .or_else(|| metadata_lookup(&metadata, "model").and_then(|v| v.as_str()));
1056    let selected_model = raw_model.map(model_ref_to_provider_model);
1057
1058    // Resume existing session when requested; fall back to a fresh session if missing.
1059    let mut session = if let Some(ref sid) = resume_session_id {
1060        match Session::load(sid).await {
1061            Ok(existing) => {
1062                tracing::info!("Resuming session {} for task {}", sid, task_id);
1063                existing
1064            }
1065            Err(e) => {
1066                tracing::warn!(
1067                    "Could not load session {} for task {} ({}), starting a new session",
1068                    sid,
1069                    task_id,
1070                    e
1071                );
1072                Session::new().await?
1073            }
1074        }
1075    } else {
1076        Session::new().await?
1077    };
1078
1079    let agent_type = task_str(task, "agent_type")
1080        .or_else(|| task_str(task, "agent"))
1081        .unwrap_or("build");
1082    session.agent = agent_type.to_string();
1083
1084    if let Some(model) = selected_model.clone() {
1085        session.metadata.model = Some(model);
1086    }
1087
1088    let prompt = task_str(task, "prompt")
1089        .or_else(|| task_str(task, "description"))
1090        .unwrap_or(title);
1091
1092    tracing::info!("Executing prompt: {}", prompt);
1093
1094    // Set up output streaming to forward progress to the server
1095    let stream_client = client.clone();
1096    let stream_server = server.to_string();
1097    let stream_token = token.clone();
1098    let stream_worker_id = worker_id.to_string();
1099    let stream_task_id = task_id.to_string();
1100
1101    let output_callback = move |output: String| {
1102        let c = stream_client.clone();
1103        let s = stream_server.clone();
1104        let t = stream_token.clone();
1105        let w = stream_worker_id.clone();
1106        let tid = stream_task_id.clone();
1107        tokio::spawn(async move {
1108            let mut req = c
1109                .post(format!("{}/v1/opencode/tasks/{}/output", s, tid))
1110                .header("X-Worker-ID", &w);
1111            if let Some(tok) = &t {
1112                req = req.bearer_auth(tok);
1113            }
1114            let _ = req
1115                .json(&serde_json::json!({
1116                    "worker_id": w,
1117                    "output": output,
1118                }))
1119                .send()
1120                .await;
1121        });
1122    };
1123
1124    // Execute swarm tasks via SwarmExecutor; all other agents use the standard session loop.
1125    let (status, result, error, session_id) = if is_swarm_agent(agent_type) {
1126        match execute_swarm_with_policy(
1127            &mut session,
1128            prompt,
1129            model_tier.as_deref(),
1130            selected_model,
1131            &metadata,
1132            complexity_hint.as_deref(),
1133            worker_personality.as_deref(),
1134            target_agent_name.as_deref(),
1135            Some(&output_callback),
1136        )
1137        .await
1138        {
1139            Ok((session_result, true)) => {
1140                tracing::info!("Swarm task completed successfully: {}", task_id);
1141                (
1142                    "completed",
1143                    Some(session_result.text),
1144                    None,
1145                    Some(session_result.session_id),
1146                )
1147            }
1148            Ok((session_result, false)) => {
1149                tracing::warn!("Swarm task completed with failures: {}", task_id);
1150                (
1151                    "failed",
1152                    Some(session_result.text),
1153                    Some("Swarm execution completed with failures".to_string()),
1154                    Some(session_result.session_id),
1155                )
1156            }
1157            Err(e) => {
1158                tracing::error!("Swarm task failed: {} - {}", task_id, e);
1159                ("failed", None, Some(format!("Error: {}", e)), None)
1160            }
1161        }
1162    } else {
1163        match execute_session_with_policy(
1164            &mut session,
1165            prompt,
1166            auto_approve,
1167            model_tier.as_deref(),
1168            Some(&output_callback),
1169        )
1170        .await
1171        {
1172            Ok(session_result) => {
1173                tracing::info!("Task completed successfully: {}", task_id);
1174                (
1175                    "completed",
1176                    Some(session_result.text),
1177                    None,
1178                    Some(session_result.session_id),
1179                )
1180            }
1181            Err(e) => {
1182                tracing::error!("Task failed: {} - {}", task_id, e);
1183                ("failed", None, Some(format!("Error: {}", e)), None)
1184            }
1185        }
1186    };
1187
1188    // Release the task with full details
1189    let mut req = client
1190        .post(format!("{}/v1/worker/tasks/release", server))
1191        .header("X-Worker-ID", worker_id);
1192    if let Some(t) = token {
1193        req = req.bearer_auth(t);
1194    }
1195
1196    req.json(&serde_json::json!({
1197        "task_id": task_id,
1198        "status": status,
1199        "result": result,
1200        "error": error,
1201        "session_id": session_id.unwrap_or_else(|| session.id.clone()),
1202    }))
1203    .send()
1204    .await?;
1205
1206    tracing::info!("Task released: {} with status: {}", task_id, status);
1207    Ok(())
1208}
1209
1210async fn execute_swarm_with_policy<F>(
1211    session: &mut Session,
1212    prompt: &str,
1213    model_tier: Option<&str>,
1214    explicit_model: Option<String>,
1215    metadata: &serde_json::Map<String, serde_json::Value>,
1216    complexity_hint: Option<&str>,
1217    worker_personality: Option<&str>,
1218    target_agent_name: Option<&str>,
1219    output_callback: Option<&F>,
1220) -> Result<(crate::session::SessionResult, bool)>
1221where
1222    F: Fn(String),
1223{
1224    use crate::provider::{ContentPart, Message, Role};
1225
1226    session.add_message(Message {
1227        role: Role::User,
1228        content: vec![ContentPart::Text {
1229            text: prompt.to_string(),
1230        }],
1231    });
1232
1233    if session.title.is_none() {
1234        session.generate_title().await?;
1235    }
1236
1237    let strategy = parse_swarm_strategy(metadata);
1238    let max_subagents = metadata_usize(
1239        metadata,
1240        &["swarm_max_subagents", "max_subagents", "subagents"],
1241    )
1242    .unwrap_or(10)
1243    .clamp(1, 100);
1244    let max_steps_per_subagent = metadata_usize(
1245        metadata,
1246        &[
1247            "swarm_max_steps_per_subagent",
1248            "max_steps_per_subagent",
1249            "max_steps",
1250        ],
1251    )
1252    .unwrap_or(50)
1253    .clamp(1, 200);
1254    let timeout_secs = metadata_u64(metadata, &["swarm_timeout_secs", "timeout_secs", "timeout"])
1255        .unwrap_or(600)
1256        .clamp(30, 3600);
1257    let parallel_enabled =
1258        metadata_bool(metadata, &["swarm_parallel_enabled", "parallel_enabled"]).unwrap_or(true);
1259
1260    let model = resolve_swarm_model(explicit_model, model_tier).await;
1261    if let Some(ref selected_model) = model {
1262        session.metadata.model = Some(selected_model.clone());
1263    }
1264
1265    if let Some(cb) = output_callback {
1266        cb(format!(
1267            "[swarm] routing complexity={} tier={} personality={} target_agent={}",
1268            complexity_hint.unwrap_or("standard"),
1269            model_tier.unwrap_or("balanced"),
1270            worker_personality.unwrap_or("auto"),
1271            target_agent_name.unwrap_or("auto")
1272        ));
1273        cb(format!(
1274            "[swarm] config strategy={:?} max_subagents={} max_steps={} timeout={}s tier={}",
1275            strategy,
1276            max_subagents,
1277            max_steps_per_subagent,
1278            timeout_secs,
1279            model_tier.unwrap_or("balanced")
1280        ));
1281    }
1282
1283    let swarm_config = SwarmConfig {
1284        max_subagents,
1285        max_steps_per_subagent,
1286        subagent_timeout_secs: timeout_secs,
1287        parallel_enabled,
1288        model,
1289        working_dir: session
1290            .metadata
1291            .directory
1292            .as_ref()
1293            .map(|p| p.to_string_lossy().to_string()),
1294        ..Default::default()
1295    };
1296
1297    let swarm_result = if output_callback.is_some() {
1298        let (event_tx, mut event_rx) = mpsc::channel(256);
1299        let executor = SwarmExecutor::new(swarm_config).with_event_tx(event_tx);
1300        let prompt_owned = prompt.to_string();
1301        let mut exec_handle =
1302            tokio::spawn(async move { executor.execute(&prompt_owned, strategy).await });
1303
1304        let mut final_result: Option<crate::swarm::SwarmResult> = None;
1305
1306        while final_result.is_none() {
1307            tokio::select! {
1308                maybe_event = event_rx.recv() => {
1309                    if let Some(event) = maybe_event {
1310                        if let Some(cb) = output_callback {
1311                            if let Some(line) = format_swarm_event_for_output(&event) {
1312                                cb(line);
1313                            }
1314                        }
1315                    }
1316                }
1317                join_result = &mut exec_handle => {
1318                    let joined = join_result.map_err(|e| anyhow::anyhow!("Swarm join failure: {}", e))?;
1319                    final_result = Some(joined?);
1320                }
1321            }
1322        }
1323
1324        while let Ok(event) = event_rx.try_recv() {
1325            if let Some(cb) = output_callback {
1326                if let Some(line) = format_swarm_event_for_output(&event) {
1327                    cb(line);
1328                }
1329            }
1330        }
1331
1332        final_result.ok_or_else(|| anyhow::anyhow!("Swarm execution returned no result"))?
1333    } else {
1334        SwarmExecutor::new(swarm_config)
1335            .execute(prompt, strategy)
1336            .await?
1337    };
1338
1339    let final_text = if swarm_result.result.trim().is_empty() {
1340        if swarm_result.success {
1341            "Swarm completed without textual output.".to_string()
1342        } else {
1343            "Swarm finished with failures and no textual output.".to_string()
1344        }
1345    } else {
1346        swarm_result.result.clone()
1347    };
1348
1349    session.add_message(Message {
1350        role: Role::Assistant,
1351        content: vec![ContentPart::Text {
1352            text: final_text.clone(),
1353        }],
1354    });
1355    session.save().await?;
1356
1357    Ok((
1358        crate::session::SessionResult {
1359            text: final_text,
1360            session_id: session.id.clone(),
1361        },
1362        swarm_result.success,
1363    ))
1364}
1365
1366/// Execute a session with the given auto-approve policy
1367/// Optionally streams output chunks via the callback
1368async fn execute_session_with_policy<F>(
1369    session: &mut Session,
1370    prompt: &str,
1371    auto_approve: AutoApprove,
1372    model_tier: Option<&str>,
1373    output_callback: Option<&F>,
1374) -> Result<crate::session::SessionResult>
1375where
1376    F: Fn(String),
1377{
1378    use crate::provider::{
1379        CompletionRequest, ContentPart, Message, ProviderRegistry, Role, parse_model_string,
1380    };
1381    use std::sync::Arc;
1382
1383    // Load provider registry from Vault
1384    let registry = ProviderRegistry::from_vault().await?;
1385    let providers = registry.list();
1386    tracing::info!("Available providers: {:?}", providers);
1387
1388    if providers.is_empty() {
1389        anyhow::bail!("No providers available. Configure API keys in HashiCorp Vault.");
1390    }
1391
1392    // Parse model string
1393    let (provider_name, model_id) = if let Some(ref model_str) = session.metadata.model {
1394        let (prov, model) = parse_model_string(model_str);
1395        if prov.is_some() {
1396            (prov.map(|s| s.to_string()), model.to_string())
1397        } else if providers.contains(&model) {
1398            (Some(model.to_string()), String::new())
1399        } else {
1400            (None, model.to_string())
1401        }
1402    } else {
1403        (None, String::new())
1404    };
1405
1406    let provider_slice = providers.as_slice();
1407    let provider_requested_but_unavailable = provider_name
1408        .as_deref()
1409        .map(|p| !providers.contains(&p))
1410        .unwrap_or(false);
1411
1412    // Determine which provider to use, preferring explicit request first, then model tier.
1413    let selected_provider = provider_name
1414        .as_deref()
1415        .filter(|p| providers.contains(p))
1416        .unwrap_or_else(|| choose_provider_for_tier(provider_slice, model_tier));
1417
1418    let provider = registry
1419        .get(selected_provider)
1420        .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;
1421
1422    // Add user message
1423    session.add_message(Message {
1424        role: Role::User,
1425        content: vec![ContentPart::Text {
1426            text: prompt.to_string(),
1427        }],
1428    });
1429
1430    // Generate title
1431    if session.title.is_none() {
1432        session.generate_title().await?;
1433    }
1434
1435    // Determine model. If a specific provider was requested but not available,
1436    // ignore that model id and fall back to the tier-based default model.
1437    let model = if !model_id.is_empty() && !provider_requested_but_unavailable {
1438        model_id
1439    } else {
1440        default_model_for_provider(selected_provider, model_tier)
1441    };
1442
1443    // Create tool registry with filtering based on auto-approve policy
1444    let tool_registry =
1445        create_filtered_registry(Arc::clone(&provider), model.clone(), auto_approve);
1446    let tool_definitions = tool_registry.definitions();
1447
1448    let temperature = if model.starts_with("kimi-k2") {
1449        Some(1.0)
1450    } else {
1451        Some(0.7)
1452    };
1453
1454    tracing::info!(
1455        "Using model: {} via provider: {} (tier: {:?})",
1456        model,
1457        selected_provider,
1458        model_tier
1459    );
1460    tracing::info!(
1461        "Available tools: {} (auto_approve: {:?})",
1462        tool_definitions.len(),
1463        auto_approve
1464    );
1465
1466    // Build system prompt
1467    let cwd = std::env::var("PWD")
1468        .map(std::path::PathBuf::from)
1469        .unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
1470    let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);
1471
1472    let mut final_output = String::new();
1473    let max_steps = 50;
1474
1475    for step in 1..=max_steps {
1476        tracing::info!(step = step, "Agent step starting");
1477
1478        // Build messages with system prompt first
1479        let mut messages = vec![Message {
1480            role: Role::System,
1481            content: vec![ContentPart::Text {
1482                text: system_prompt.clone(),
1483            }],
1484        }];
1485        messages.extend(session.messages.clone());
1486
1487        let request = CompletionRequest {
1488            messages,
1489            tools: tool_definitions.clone(),
1490            model: model.clone(),
1491            temperature,
1492            top_p: None,
1493            max_tokens: Some(8192),
1494            stop: Vec::new(),
1495        };
1496
1497        let response = provider.complete(request).await?;
1498
1499        crate::telemetry::TOKEN_USAGE.record_model_usage(
1500            &model,
1501            response.usage.prompt_tokens as u64,
1502            response.usage.completion_tokens as u64,
1503        );
1504
1505        // Extract tool calls
1506        let tool_calls: Vec<(String, String, serde_json::Value)> = response
1507            .message
1508            .content
1509            .iter()
1510            .filter_map(|part| {
1511                if let ContentPart::ToolCall {
1512                    id,
1513                    name,
1514                    arguments,
1515                } = part
1516                {
1517                    let args: serde_json::Value =
1518                        serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
1519                    Some((id.clone(), name.clone(), args))
1520                } else {
1521                    None
1522                }
1523            })
1524            .collect();
1525
1526        // Collect text output and stream it
1527        for part in &response.message.content {
1528            if let ContentPart::Text { text } = part {
1529                if !text.is_empty() {
1530                    final_output.push_str(text);
1531                    final_output.push('\n');
1532                    if let Some(cb) = output_callback {
1533                        cb(text.clone());
1534                    }
1535                }
1536            }
1537        }
1538
1539        // If no tool calls, we're done
1540        if tool_calls.is_empty() {
1541            session.add_message(response.message.clone());
1542            break;
1543        }
1544
1545        session.add_message(response.message.clone());
1546
1547        tracing::info!(
1548            step = step,
1549            num_tools = tool_calls.len(),
1550            "Executing tool calls"
1551        );
1552
1553        // Execute each tool call
1554        for (tool_id, tool_name, tool_input) in tool_calls {
1555            tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");
1556
1557            // Stream tool start event
1558            if let Some(cb) = output_callback {
1559                cb(format!("[tool:start:{}]", tool_name));
1560            }
1561
1562            // Check if tool is allowed based on auto-approve policy
1563            if !is_tool_allowed(&tool_name, auto_approve) {
1564                let msg = format!(
1565                    "Tool '{}' requires approval but auto-approve policy is {:?}",
1566                    tool_name, auto_approve
1567                );
1568                tracing::warn!(tool = %tool_name, "Tool blocked by auto-approve policy");
1569                session.add_message(Message {
1570                    role: Role::Tool,
1571                    content: vec![ContentPart::ToolResult {
1572                        tool_call_id: tool_id,
1573                        content: msg,
1574                    }],
1575                });
1576                continue;
1577            }
1578
1579            let content = if let Some(tool) = tool_registry.get(&tool_name) {
1580                let exec_result: Result<crate::tool::ToolResult> =
1581                    tool.execute(tool_input.clone()).await;
1582                match exec_result {
1583                    Ok(result) => {
1584                        tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
1585                        if let Some(cb) = output_callback {
1586                            let status = if result.success { "ok" } else { "err" };
1587                            cb(format!(
1588                                "[tool:{}:{}] {}",
1589                                tool_name,
1590                                status,
1591                                &result.output[..result.output.len().min(500)]
1592                            ));
1593                        }
1594                        result.output
1595                    }
1596                    Err(e) => {
1597                        tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
1598                        if let Some(cb) = output_callback {
1599                            cb(format!("[tool:{}:err] {}", tool_name, e));
1600                        }
1601                        format!("Error: {}", e)
1602                    }
1603                }
1604            } else {
1605                tracing::warn!(tool = %tool_name, "Tool not found");
1606                format!("Error: Unknown tool '{}'", tool_name)
1607            };
1608
1609            session.add_message(Message {
1610                role: Role::Tool,
1611                content: vec![ContentPart::ToolResult {
1612                    tool_call_id: tool_id,
1613                    content,
1614                }],
1615            });
1616        }
1617    }
1618
1619    session.save().await?;
1620
1621    Ok(crate::session::SessionResult {
1622        text: final_output.trim().to_string(),
1623        session_id: session.id.clone(),
1624    })
1625}
1626
1627/// Check if a tool is allowed based on the auto-approve policy
1628fn is_tool_allowed(tool_name: &str, auto_approve: AutoApprove) -> bool {
1629    match auto_approve {
1630        AutoApprove::All => true,
1631        AutoApprove::Safe | AutoApprove::None => is_safe_tool(tool_name),
1632    }
1633}
1634
/// Report whether `tool_name` is on the read-only ("safe") allow-list.
///
/// Safe tools never mutate the workspace or run commands, so they are
/// permitted under every auto-approve policy.
fn is_safe_tool(tool_name: &str) -> bool {
    matches!(
        tool_name,
        "read"
            | "list"
            | "glob"
            | "grep"
            | "codesearch"
            | "lsp"
            | "webfetch"
            | "websearch"
            | "todo_read"
            | "skill"
    )
}
1651
/// Create a filtered tool registry based on the auto-approve policy.
///
/// Read-only tools are always registered; tools that can mutate the workspace
/// or execute commands are added only under `AutoApprove::All`. The
/// `provider` and `model` arguments are consumed solely by the Ralph tool's
/// constructor.
fn create_filtered_registry(
    provider: Arc<dyn crate::provider::Provider>,
    model: String,
    auto_approve: AutoApprove,
) -> crate::tool::ToolRegistry {
    use crate::tool::*;

    // Always add safe tools (read-only; keep this set in sync with `is_safe_tool`)
    let mut registry = ToolRegistry::new();

    registry.register(Arc::new(file::ReadTool::new()));
    registry.register(Arc::new(file::ListTool::new()));
    registry.register(Arc::new(file::GlobTool::new()));
    registry.register(Arc::new(search::GrepTool::new()));
    registry.register(Arc::new(lsp::LspTool::new()));
    registry.register(Arc::new(webfetch::WebFetchTool::new()));
    registry.register(Arc::new(websearch::WebSearchTool::new()));
    registry.register(Arc::new(codesearch::CodeSearchTool::new()));
    registry.register(Arc::new(todo::TodoReadTool::new()));
    registry.register(Arc::new(skill::SkillTool::new()));

    // Add potentially dangerous tools (writes, edits, shell, task spawning)
    // only if auto_approve is All
    if matches!(auto_approve, AutoApprove::All) {
        registry.register(Arc::new(file::WriteTool::new()));
        registry.register(Arc::new(edit::EditTool::new()));
        registry.register(Arc::new(bash::BashTool::new()));
        registry.register(Arc::new(multiedit::MultiEditTool::new()));
        registry.register(Arc::new(patch::ApplyPatchTool::new()));
        registry.register(Arc::new(todo::TodoWriteTool::new()));
        registry.register(Arc::new(task::TaskTool::new()));
        registry.register(Arc::new(plan::PlanEnterTool::new()));
        registry.register(Arc::new(plan::PlanExitTool::new()));
        registry.register(Arc::new(rlm::RlmTool::new()));
        // Ralph is the only tool that needs the provider/model pair.
        registry.register(Arc::new(ralph::RalphTool::with_provider(provider, model)));
        registry.register(Arc::new(prd::PrdTool::new()));
        registry.register(Arc::new(confirm_edit::ConfirmEditTool::new()));
        registry.register(Arc::new(confirm_multiedit::ConfirmMultiEditTool::new()));
        registry.register(Arc::new(undo::UndoTool));
        registry.register(Arc::new(mcp_bridge::McpBridgeTool::new()));
    }

    // NOTE(review): registered unconditionally under every policy — presumably
    // a catch-all that reports malformed/unknown tool invocations; confirm.
    registry.register(Arc::new(invalid::InvalidTool::new()));

    registry
}
1698
/// Start the heartbeat background task.
///
/// Spawns a Tokio task that, every 30 seconds:
/// 1. reads the in-flight task count from `processing` and mirrors it (plus a
///    derived idle/processing status) into `heartbeat_state`;
/// 2. POSTs a JSON heartbeat to
///    `{server}/v1/opencode/workers/{worker_id}/heartbeat`, bearer-authenticated
///    with `token` when one is provided;
/// 3. when `cognition_config.enabled`, attaches a `cognition` object fetched via
///    `fetch_cognition_heartbeat_payload`. If the server rejects a heartbeat
///    that carried the cognition payload with a 4xx status, the heartbeat is
///    retried once without it; a successful retry suppresses the cognition
///    payload for a 300-second cooldown.
///
/// The loop never terminates on failure: after 3 consecutive failed heartbeats
/// it logs an error and resets the counter (reconnection is handled by the SSE
/// loop elsewhere).
///
/// Returns a JoinHandle that can be used to cancel the heartbeat.
fn start_heartbeat(
    client: Client,
    server: String,
    token: Option<String>,
    heartbeat_state: HeartbeatState,
    processing: Arc<Mutex<HashSet<String>>>,
    cognition_config: CognitionHeartbeatConfig,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        let mut consecutive_failures = 0u32;
        const MAX_FAILURES: u32 = 3;
        const HEARTBEAT_INTERVAL_SECS: u64 = 30;
        const COGNITION_RETRY_COOLDOWN_SECS: u64 = 300;
        // While set to a future Instant, the cognition payload is omitted from
        // heartbeats (entered after the server rejects its schema).
        let mut cognition_payload_disabled_until: Option<Instant> = None;

        let mut interval =
            tokio::time::interval(tokio::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
        // If a tick is delayed (e.g. slow request), skip the backlog instead of
        // firing a burst of catch-up heartbeats.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        loop {
            interval.tick().await;

            // Update task count from processing set
            let active_count = processing.lock().await.len();
            heartbeat_state.set_task_count(active_count).await;

            // Determine status based on active tasks
            let status = if active_count > 0 {
                WorkerStatus::Processing
            } else {
                WorkerStatus::Idle
            };
            heartbeat_state.set_status(status).await;

            // Send heartbeat
            let url = format!(
                "{}/v1/opencode/workers/{}/heartbeat",
                server, heartbeat_state.worker_id
            );
            let mut req = client.post(&url);

            if let Some(ref t) = token {
                req = req.bearer_auth(t);
            }

            // Re-read the shared status so the wire value reflects whatever is
            // currently stored (normally the status set just above).
            let status_str = heartbeat_state.status.lock().await.as_str().to_string();
            // `base_payload` is kept pristine so a cognition-free retry can be
            // sent if the enriched payload is rejected below.
            let base_payload = serde_json::json!({
                "worker_id": &heartbeat_state.worker_id,
                "agent_name": &heartbeat_state.agent_name,
                "status": status_str,
                "active_task_count": active_count,
            });
            let mut payload = base_payload.clone();
            let mut included_cognition_payload = false;
            // Allowed when no cooldown is active, or the cooldown has elapsed.
            let cognition_payload_allowed = cognition_payload_disabled_until
                .map(|until| Instant::now() >= until)
                .unwrap_or(true);

            // Attach the cognition payload only if enabled, not cooling down,
            // and the local cognition endpoints answered in time.
            if cognition_config.enabled
                && cognition_payload_allowed
                && let Some(cognition_payload) =
                    fetch_cognition_heartbeat_payload(&client, &cognition_config).await
                && let Some(obj) = payload.as_object_mut()
            {
                obj.insert("cognition".to_string(), cognition_payload);
                included_cognition_payload = true;
            }

            match req.json(&payload).send().await {
                Ok(res) => {
                    if res.status().is_success() {
                        consecutive_failures = 0;
                        tracing::debug!(
                            worker_id = %heartbeat_state.worker_id,
                            status = status_str,
                            active_tasks = active_count,
                            "Heartbeat sent successfully"
                        );
                    } else if included_cognition_payload && res.status().is_client_error() {
                        // A 4xx on an enriched heartbeat is treated as the server
                        // rejecting the cognition schema: retry bare, and on
                        // success pause cognition payloads for the cooldown.
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            "Heartbeat cognition payload rejected, retrying without cognition payload"
                        );

                        let mut retry_req = client.post(&url);
                        if let Some(ref t) = token {
                            retry_req = retry_req.bearer_auth(t);
                        }

                        match retry_req.json(&base_payload).send().await {
                            Ok(retry_res) if retry_res.status().is_success() => {
                                cognition_payload_disabled_until = Some(
                                    Instant::now()
                                        + Duration::from_secs(COGNITION_RETRY_COOLDOWN_SECS),
                                );
                                consecutive_failures = 0;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    retry_after_secs = COGNITION_RETRY_COOLDOWN_SECS,
                                    "Paused cognition heartbeat payload after schema rejection"
                                );
                            }
                            Ok(retry_res) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    status = %retry_res.status(),
                                    failures = consecutive_failures,
                                    "Heartbeat failed even after retry without cognition payload"
                                );
                            }
                            Err(e) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    error = %e,
                                    failures = consecutive_failures,
                                    "Heartbeat retry without cognition payload failed"
                                );
                            }
                        }
                    } else {
                        consecutive_failures += 1;
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            failures = consecutive_failures,
                            "Heartbeat failed"
                        );
                    }
                }
                Err(e) => {
                    consecutive_failures += 1;
                    tracing::warn!(
                        worker_id = %heartbeat_state.worker_id,
                        error = %e,
                        failures = consecutive_failures,
                        "Heartbeat request failed"
                    );
                }
            }

            // Log error after 3 consecutive failures but do not terminate
            if consecutive_failures >= MAX_FAILURES {
                tracing::error!(
                    worker_id = %heartbeat_state.worker_id,
                    failures = consecutive_failures,
                    "Heartbeat failed {} consecutive times - worker will continue running and attempt reconnection via SSE loop",
                    MAX_FAILURES
                );
                // Reset counter to avoid spamming error logs
                consecutive_failures = 0;
            }
        }
    })
}
1858
1859async fn fetch_cognition_heartbeat_payload(
1860    client: &Client,
1861    config: &CognitionHeartbeatConfig,
1862) -> Option<serde_json::Value> {
1863    let status_url = format!("{}/v1/cognition/status", config.source_base_url);
1864    let status_res = tokio::time::timeout(
1865        Duration::from_millis(config.request_timeout_ms),
1866        client.get(status_url).send(),
1867    )
1868    .await
1869    .ok()?
1870    .ok()?;
1871
1872    if !status_res.status().is_success() {
1873        return None;
1874    }
1875
1876    let status: CognitionStatusSnapshot = status_res.json().await.ok()?;
1877    let mut payload = serde_json::json!({
1878        "running": status.running,
1879        "last_tick_at": status.last_tick_at,
1880        "active_persona_count": status.active_persona_count,
1881        "events_buffered": status.events_buffered,
1882        "snapshots_buffered": status.snapshots_buffered,
1883        "loop_interval_ms": status.loop_interval_ms,
1884    });
1885
1886    if config.include_thought_summary {
1887        let snapshot_url = format!("{}/v1/cognition/snapshots/latest", config.source_base_url);
1888        let snapshot_res = tokio::time::timeout(
1889            Duration::from_millis(config.request_timeout_ms),
1890            client.get(snapshot_url).send(),
1891        )
1892        .await
1893        .ok()
1894        .and_then(Result::ok);
1895
1896        if let Some(snapshot_res) = snapshot_res
1897            && snapshot_res.status().is_success()
1898            && let Ok(snapshot) = snapshot_res.json::<CognitionLatestSnapshot>().await
1899            && let Some(obj) = payload.as_object_mut()
1900        {
1901            obj.insert(
1902                "latest_snapshot_at".to_string(),
1903                serde_json::Value::String(snapshot.generated_at),
1904            );
1905            obj.insert(
1906                "latest_thought".to_string(),
1907                serde_json::Value::String(trim_for_heartbeat(
1908                    &snapshot.summary,
1909                    config.summary_max_chars,
1910                )),
1911            );
1912            if let Some(model) = snapshot
1913                .metadata
1914                .get("model")
1915                .and_then(serde_json::Value::as_str)
1916            {
1917                obj.insert(
1918                    "latest_thought_model".to_string(),
1919                    serde_json::Value::String(model.to_string()),
1920                );
1921            }
1922            if let Some(source) = snapshot
1923                .metadata
1924                .get("source")
1925                .and_then(serde_json::Value::as_str)
1926            {
1927                obj.insert(
1928                    "latest_thought_source".to_string(),
1929                    serde_json::Value::String(source.to_string()),
1930                );
1931            }
1932        }
1933    }
1934
1935    Some(payload)
1936}
1937
/// Trim whitespace from `input` and truncate it to at most `max_chars`
/// characters, appending `"..."` when truncation occurred.
///
/// Fixes two defects in the previous version: the final `.trim()` ran *after*
/// `"..."` was appended, so it could never remove whitespace immediately
/// before the ellipsis (the string ends in `'.'`), producing e.g. `"hello ..."`;
/// and surrounding whitespace counted against the `max_chars` budget on the
/// truncation path while being stripped on the short path. Trimming first and
/// dropping trailing whitespace before appending the ellipsis makes both paths
/// consistent.
fn trim_for_heartbeat(input: &str, max_chars: usize) -> String {
    let trimmed = input.trim();
    if trimmed.chars().count() <= max_chars {
        return trimmed.to_string();
    }

    // Take the character budget, then drop any whitespace left dangling at the
    // cut point so the ellipsis attaches directly to the last word.
    let mut out: String = trimmed.chars().take(max_chars).collect();
    // `trim_end()` ends on a char boundary, so truncating to its byte length is safe.
    out.truncate(out.trim_end().len());
    out.push_str("...");
    out
}
1950
/// Read a boolean from the environment variable `name`.
///
/// Accepts (case-insensitively) `1`/`true`/`yes`/`on` as true and
/// `0`/`false`/`no`/`off` as false; an unset variable or any other value
/// yields `default`.
fn env_bool(name: &str, default: bool) -> bool {
    match std::env::var(name) {
        Ok(raw) => match raw.to_ascii_lowercase().as_str() {
            "1" | "true" | "yes" | "on" => true,
            "0" | "false" | "no" | "off" => false,
            _ => default,
        },
        Err(_) => default,
    }
}
1961
/// Read a `usize` from the environment variable `name`.
///
/// Returns `default` when the variable is unset or does not parse as `usize`.
fn env_usize(name: &str, default: usize) -> usize {
    match std::env::var(name) {
        Ok(raw) => raw.parse::<usize>().unwrap_or(default),
        Err(_) => default,
    }
}
1968
/// Read a `u64` from the environment variable `name`.
///
/// Returns `default` when the variable is unset or does not parse as `u64`.
fn env_u64(name: &str, default: u64) -> u64 {
    let Ok(raw) = std::env::var(name) else {
        return default;
    };
    raw.parse::<u64>().unwrap_or(default)
}