Skip to main content

codetether_agent/a2a/
worker.rs

1//! A2A Worker - connects to an A2A server to process tasks
2
3use crate::cli::A2aArgs;
4use crate::provider::ProviderRegistry;
5use crate::session::Session;
6use crate::swarm::{DecompositionStrategy, SwarmConfig, SwarmExecutor};
7use crate::tui::swarm_view::SwarmEvent;
8use anyhow::Result;
9use futures::StreamExt;
10use reqwest::Client;
11use serde::Deserialize;
12use std::collections::HashMap;
13use std::collections::HashSet;
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::sync::{Mutex, mpsc};
17use tokio::task::JoinHandle;
18use tokio::time::Instant;
19
/// Worker status reported to the server in heartbeat payloads.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WorkerStatus {
    Idle,
    Processing,
}

impl WorkerStatus {
    /// Stable wire representation of the status ("idle" / "processing").
    fn as_str(&self) -> &'static str {
        match self {
            Self::Idle => "idle",
            Self::Processing => "processing",
        }
    }
}
35
36/// Heartbeat state shared between the heartbeat task and the main worker
37#[derive(Clone)]
38struct HeartbeatState {
39    worker_id: String,
40    agent_name: String,
41    status: Arc<Mutex<WorkerStatus>>,
42    active_task_count: Arc<Mutex<usize>>,
43}
44
45impl HeartbeatState {
46    fn new(worker_id: String, agent_name: String) -> Self {
47        Self {
48            worker_id,
49            agent_name,
50            status: Arc::new(Mutex::new(WorkerStatus::Idle)),
51            active_task_count: Arc::new(Mutex::new(0)),
52        }
53    }
54
55    async fn set_status(&self, status: WorkerStatus) {
56        *self.status.lock().await = status;
57    }
58
59    async fn set_task_count(&self, count: usize) {
60        *self.active_task_count.lock().await = count;
61    }
62}
63
/// Configuration for sharing local cognition-loop state upstream as part of
/// worker heartbeats. Built from `CODETETHER_WORKER_COGNITION_*` environment
/// variables by [`CognitionHeartbeatConfig::from_env`] and handed to the
/// heartbeat task in `run`.
#[derive(Clone, Debug)]
struct CognitionHeartbeatConfig {
    // Master switch; when false, run() only logs that sharing is disabled.
    enabled: bool,
    // Base URL of the local cognition service, normalized without trailing '/'.
    source_base_url: String,
    // Whether to include a truncated thought summary in shared state.
    include_thought_summary: bool,
    // Max characters of thought summary to share (from_env floors this at 120).
    summary_max_chars: usize,
    // Timeout for requests to the cognition source (from_env floors this at 250 ms).
    request_timeout_ms: u64,
}
72
73impl CognitionHeartbeatConfig {
74    fn from_env() -> Self {
75        let source_base_url = std::env::var("CODETETHER_WORKER_COGNITION_SOURCE_URL")
76            .unwrap_or_else(|_| "http://127.0.0.1:4096".to_string())
77            .trim_end_matches('/')
78            .to_string();
79
80        Self {
81            enabled: env_bool("CODETETHER_WORKER_COGNITION_SHARE_ENABLED", true),
82            source_base_url,
83            include_thought_summary: env_bool("CODETETHER_WORKER_COGNITION_INCLUDE_THOUGHTS", true),
84            summary_max_chars: env_usize("CODETETHER_WORKER_COGNITION_THOUGHT_MAX_CHARS", 480)
85                .max(120),
86            request_timeout_ms: env_u64("CODETETHER_WORKER_COGNITION_TIMEOUT_MS", 2_500).max(250),
87        }
88    }
89}
90
/// Deserialized status snapshot from the local cognition service.
/// Every field except `running` is optional on the wire (`serde(default)`).
// NOTE(review): not referenced in this chunk; presumably consumed by the
// heartbeat/cognition code elsewhere in the file — confirm before removing.
#[derive(Debug, Deserialize)]
struct CognitionStatusSnapshot {
    // Whether the cognition loop is currently running.
    running: bool,
    // Timestamp of the last loop tick, if any; format not validated here.
    #[serde(default)]
    last_tick_at: Option<String>,
    // Number of personas currently active in the cognition loop.
    #[serde(default)]
    active_persona_count: usize,
    // Buffered-but-unprocessed event count.
    #[serde(default)]
    events_buffered: usize,
    // Buffered snapshot count.
    #[serde(default)]
    snapshots_buffered: usize,
    // Cognition loop tick interval in milliseconds.
    #[serde(default)]
    loop_interval_ms: u64,
}
105
/// Deserialized "latest snapshot" payload from the local cognition service:
/// a generated-at timestamp, a free-form summary, and arbitrary metadata.
// NOTE(review): not referenced in this chunk; presumably consumed by the
// heartbeat code elsewhere in the file — confirm before removing.
#[derive(Debug, Deserialize)]
struct CognitionLatestSnapshot {
    // When the snapshot was generated (string as sent by the service).
    generated_at: String,
    // Human-readable thought/state summary.
    summary: String,
    // Open-ended metadata map; keys/values are service-defined.
    #[serde(default)]
    metadata: HashMap<String, serde_json::Value>,
}
113
/// Run the A2A worker.
///
/// Startup: registers with the server, then drains any tasks already queued
/// as `pending`. Steady state: an infinite reconnect loop that re-registers,
/// starts a heartbeat task, and consumes the server's SSE task stream until
/// it drops, then sleeps 5s and retries. Only startup errors (initial
/// registration / pending-task fetch) propagate as `Err`; the loop itself
/// never returns.
pub async fn run(args: A2aArgs) -> Result<()> {
    // Normalize the server URL so paths can be appended safely.
    let server = args.server.trim_end_matches('/');
    let name = args
        .name
        .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
    let worker_id = generate_worker_id();

    // Comma-separated --codebases list; defaults to the current directory.
    // NOTE(review): current_dir().unwrap() panics if the cwd is unavailable
    // (e.g. deleted) — consider a graceful fallback.
    let codebases: Vec<String> = args
        .codebases
        .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
        .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);

    tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
    tracing::info!("Server: {}", server);
    tracing::info!("Codebases: {:?}", codebases);

    let client = Client::new();
    // Set of task ids currently being handled; shared with spawned handlers
    // so the same task is never processed twice concurrently.
    let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
    let cognition_heartbeat = CognitionHeartbeatConfig::from_env();
    if cognition_heartbeat.enabled {
        tracing::info!(
            source = %cognition_heartbeat.source_base_url,
            include_thoughts = cognition_heartbeat.include_thought_summary,
            max_chars = cognition_heartbeat.summary_max_chars,
            timeout_ms = cognition_heartbeat.request_timeout_ms,
            "Cognition heartbeat sharing enabled (set CODETETHER_WORKER_COGNITION_SHARE_ENABLED=false to disable)"
        );
    } else {
        tracing::warn!(
            "Cognition heartbeat sharing disabled; worker thought state will not be shared upstream"
        );
    }

    // Map the CLI flag onto the approval policy; unrecognized values mean None.
    let auto_approve = match args.auto_approve.as_str() {
        "all" => AutoApprove::All,
        "safe" => AutoApprove::Safe,
        _ => AutoApprove::None,
    };

    // Create heartbeat state
    let heartbeat_state = HeartbeatState::new(worker_id.clone(), name.clone());

    // Register worker
    // NOTE(review): the loop below re-registers on its first iteration too,
    // so the very first connection registers twice (harmless but redundant).
    register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;

    // Fetch pending tasks
    fetch_pending_tasks(
        &client,
        server,
        &args.token,
        &worker_id,
        &processing,
        &auto_approve,
    )
    .await?;

    // Connect to SSE stream
    loop {
        // Re-register worker on each reconnection to report updated models/capabilities
        if let Err(e) =
            register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await
        {
            tracing::warn!("Failed to re-register worker on reconnection: {}", e);
        }

        // Start heartbeat task for this connection
        let heartbeat_handle = start_heartbeat(
            client.clone(),
            server.to_string(),
            args.token.clone(),
            heartbeat_state.clone(),
            processing.clone(),
            cognition_heartbeat.clone(),
        );

        match connect_stream(
            &client,
            server,
            &args.token,
            &worker_id,
            &name,
            &codebases,
            &processing,
            &auto_approve,
        )
        .await
        {
            Ok(()) => {
                tracing::warn!("Stream ended, reconnecting...");
            }
            Err(e) => {
                tracing::error!("Stream error: {}, reconnecting...", e);
            }
        }

        // Cancel heartbeat on disconnection
        heartbeat_handle.abort();
        tracing::debug!("Heartbeat cancelled for reconnection");

        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
    }
}
217
218fn generate_worker_id() -> String {
219    format!(
220        "wrk_{}_{:x}",
221        chrono::Utc::now().timestamp(),
222        rand::random::<u64>()
223    )
224}
225
/// Tool-approval policy parsed from the `--auto-approve` CLI value in `run`
/// ("all" / "safe"; anything else maps to `None`).
// NOTE(review): the exact semantics of each variant live in handle_task,
// outside this chunk — confirm before documenting them further.
#[derive(Debug, Clone, Copy)]
enum AutoApprove {
    // Approve every tool call automatically.
    All,
    // Approve only tool calls classified as safe.
    Safe,
    // Approve nothing automatically (the default).
    None,
}
232
/// Capabilities of the codetether-agent worker, advertised to the A2A server
/// in the registration payload built by `register_worker`.
const WORKER_CAPABILITIES: &[&str] = &["ralph", "swarm", "rlm", "a2a", "mcp"];
235
236fn task_value<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a serde_json::Value> {
237    task.get("task")
238        .and_then(|t| t.get(key))
239        .or_else(|| task.get(key))
240}
241
242fn task_str<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a str> {
243    task_value(task, key).and_then(|v| v.as_str())
244}
245
246fn task_metadata(task: &serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
247    task_value(task, "metadata")
248        .and_then(|m| m.as_object())
249        .cloned()
250        .unwrap_or_default()
251}
252
/// Converts a `provider:model` reference into `provider/model` form.
///
/// Only the first ':' is rewritten, and only when the reference contains no
/// '/' already. Fully-qualified IDs such as "bedrock/amazon.nova-micro-v1:0"
/// use ':' as a version separator after the '/' and are returned unchanged.
fn model_ref_to_provider_model(model: &str) -> String {
    if model.contains('/') {
        return model.to_string();
    }
    match model.split_once(':') {
        Some((provider, rest)) => format!("{provider}/{rest}"),
        None => model.to_string(),
    }
}
263
/// Preference-ordered provider lists per model tier.
///
/// Unknown or absent tiers use the "balanced" ordering; "fast"/"quick" favor
/// cheap low-latency providers, "heavy"/"deep" favor stronger models.
fn provider_preferences_for_tier(model_tier: Option<&str>) -> &'static [&'static str] {
    match model_tier.unwrap_or("balanced") {
        "fast" | "quick" => &[
            "openai",
            "github-copilot",
            "moonshotai",
            "zhipuai",
            "openrouter",
            "novita",
            "google",
            "anthropic",
        ],
        "heavy" | "deep" => &[
            "anthropic",
            "openai",
            "github-copilot",
            "moonshotai",
            "zhipuai",
            "openrouter",
            "novita",
            "google",
        ],
        _ => &[
            "openai",
            "github-copilot",
            "anthropic",
            "moonshotai",
            "zhipuai",
            "openrouter",
            "novita",
            "google",
        ],
    }
}

/// Picks the best provider from `providers` for the given tier.
///
/// Returns the first entry of the tier's preference list that is available,
/// falling back to "zhipuai" if present, then to the first provider given.
/// An empty `providers` slice returns the static "zhipuai" default instead
/// of panicking (the previous `providers[0]` indexed out of bounds).
fn choose_provider_for_tier<'a>(providers: &'a [&'a str], model_tier: Option<&str>) -> &'a str {
    for &preferred in provider_preferences_for_tier(model_tier) {
        if providers.contains(&preferred) {
            return preferred;
        }
    }
    if providers.contains(&"zhipuai") {
        return "zhipuai";
    }
    providers.first().copied().unwrap_or("zhipuai")
}
310
/// Default model id for a provider at the requested tier.
///
/// Most providers pin one model regardless of tier; only anthropic, openai,
/// and google vary by tier ("fast"/"quick", "heavy"/"deep", else balanced).
/// Unknown providers fall back to "glm-4.7".
fn default_model_for_provider(provider: &str, model_tier: Option<&str>) -> String {
    let tier = model_tier.unwrap_or("balanced");
    let fast = matches!(tier, "fast" | "quick");
    let heavy = matches!(tier, "heavy" | "deep");

    let model = match provider {
        "moonshotai" => "kimi-k2.5",
        "zhipuai" => "glm-4.7",
        "openrouter" => "z-ai/glm-4.7",
        "novita" => "qwen/qwen3-coder-next",
        "anthropic" => {
            if fast {
                "claude-haiku-4-5"
            } else {
                "claude-sonnet-4-20250514"
            }
        }
        "openai" => {
            if fast {
                "gpt-4o-mini"
            } else if heavy {
                "o3"
            } else {
                "gpt-4o"
            }
        }
        "google" => {
            if fast {
                "gemini-2.5-flash"
            } else {
                "gemini-2.5-pro"
            }
        }
        _ => "glm-4.7",
    };
    model.to_string()
}
345
/// True when the agent type names a swarm-style (multi-agent) run.
/// Matching ignores case and surrounding whitespace.
fn is_swarm_agent(agent_type: &str) -> bool {
    let normalized = agent_type.trim().to_ascii_lowercase();
    normalized == "swarm" || normalized == "parallel" || normalized == "multi-agent"
}
352
353fn metadata_lookup<'a>(
354    metadata: &'a serde_json::Map<String, serde_json::Value>,
355    key: &str,
356) -> Option<&'a serde_json::Value> {
357    metadata
358        .get(key)
359        .or_else(|| {
360            metadata
361                .get("routing")
362                .and_then(|v| v.as_object())
363                .and_then(|obj| obj.get(key))
364        })
365        .or_else(|| {
366            metadata
367                .get("swarm")
368                .and_then(|v| v.as_object())
369                .and_then(|obj| obj.get(key))
370        })
371}
372
373fn metadata_str(
374    metadata: &serde_json::Map<String, serde_json::Value>,
375    keys: &[&str],
376) -> Option<String> {
377    for key in keys {
378        if let Some(value) = metadata_lookup(metadata, key).and_then(|v| v.as_str()) {
379            let trimmed = value.trim();
380            if !trimmed.is_empty() {
381                return Some(trimmed.to_string());
382            }
383        }
384    }
385    None
386}
387
388fn metadata_usize(
389    metadata: &serde_json::Map<String, serde_json::Value>,
390    keys: &[&str],
391) -> Option<usize> {
392    for key in keys {
393        if let Some(value) = metadata_lookup(metadata, key) {
394            if let Some(v) = value.as_u64() {
395                return usize::try_from(v).ok();
396            }
397            if let Some(v) = value.as_i64() {
398                if v >= 0 {
399                    return usize::try_from(v as u64).ok();
400                }
401            }
402            if let Some(v) = value.as_str() {
403                if let Ok(parsed) = v.trim().parse::<usize>() {
404                    return Some(parsed);
405                }
406            }
407        }
408    }
409    None
410}
411
412fn metadata_u64(
413    metadata: &serde_json::Map<String, serde_json::Value>,
414    keys: &[&str],
415) -> Option<u64> {
416    for key in keys {
417        if let Some(value) = metadata_lookup(metadata, key) {
418            if let Some(v) = value.as_u64() {
419                return Some(v);
420            }
421            if let Some(v) = value.as_i64() {
422                if v >= 0 {
423                    return Some(v as u64);
424                }
425            }
426            if let Some(v) = value.as_str() {
427                if let Ok(parsed) = v.trim().parse::<u64>() {
428                    return Some(parsed);
429                }
430            }
431        }
432    }
433    None
434}
435
436fn metadata_bool(
437    metadata: &serde_json::Map<String, serde_json::Value>,
438    keys: &[&str],
439) -> Option<bool> {
440    for key in keys {
441        if let Some(value) = metadata_lookup(metadata, key) {
442            if let Some(v) = value.as_bool() {
443                return Some(v);
444            }
445            if let Some(v) = value.as_str() {
446                match v.trim().to_ascii_lowercase().as_str() {
447                    "1" | "true" | "yes" | "on" => return Some(true),
448                    "0" | "false" | "no" | "off" => return Some(false),
449                    _ => {}
450                }
451            }
452        }
453    }
454    None
455}
456
457fn parse_swarm_strategy(
458    metadata: &serde_json::Map<String, serde_json::Value>,
459) -> DecompositionStrategy {
460    match metadata_str(
461        metadata,
462        &[
463            "decomposition_strategy",
464            "swarm_strategy",
465            "strategy",
466            "swarm_decomposition",
467        ],
468    )
469    .as_deref()
470    .map(|s| s.to_ascii_lowercase())
471    .as_deref()
472    {
473        Some("none") | Some("single") => DecompositionStrategy::None,
474        Some("domain") | Some("by_domain") => DecompositionStrategy::ByDomain,
475        Some("data") | Some("by_data") => DecompositionStrategy::ByData,
476        Some("stage") | Some("by_stage") => DecompositionStrategy::ByStage,
477        _ => DecompositionStrategy::Automatic,
478    }
479}
480
481async fn resolve_swarm_model(
482    explicit_model: Option<String>,
483    model_tier: Option<&str>,
484) -> Option<String> {
485    if let Some(model) = explicit_model {
486        if !model.trim().is_empty() {
487            return Some(model);
488        }
489    }
490
491    let registry = ProviderRegistry::from_vault().await.ok()?;
492    let providers = registry.list();
493    if providers.is_empty() {
494        return None;
495    }
496    let provider = choose_provider_for_tier(providers.as_slice(), model_tier);
497    let model = default_model_for_provider(provider, model_tier);
498    Some(format!("{}/{}", provider, model))
499}
500
501fn format_swarm_event_for_output(event: &SwarmEvent) -> Option<String> {
502    match event {
503        SwarmEvent::Started {
504            task,
505            total_subtasks,
506        } => Some(format!(
507            "[swarm] started task={} planned_subtasks={}",
508            task, total_subtasks
509        )),
510        SwarmEvent::StageComplete {
511            stage,
512            completed,
513            failed,
514        } => Some(format!(
515            "[swarm] stage={} completed={} failed={}",
516            stage, completed, failed
517        )),
518        SwarmEvent::SubTaskUpdate { id, status, .. } => Some(format!(
519            "[swarm] subtask id={} status={}",
520            &id.chars().take(8).collect::<String>(),
521            format!("{status:?}").to_ascii_lowercase()
522        )),
523        SwarmEvent::AgentToolCall {
524            subtask_id,
525            tool_name,
526        } => Some(format!(
527            "[swarm] subtask id={} tool={}",
528            &subtask_id.chars().take(8).collect::<String>(),
529            tool_name
530        )),
531        SwarmEvent::AgentError { subtask_id, error } => Some(format!(
532            "[swarm] subtask id={} error={}",
533            &subtask_id.chars().take(8).collect::<String>(),
534            error
535        )),
536        SwarmEvent::Complete { success, stats } => Some(format!(
537            "[swarm] complete success={} subtasks={} speedup={:.2}",
538            success,
539            stats.subagents_completed + stats.subagents_failed,
540            stats.speedup_factor
541        )),
542        SwarmEvent::Error(err) => Some(format!("[swarm] error message={}", err)),
543        _ => None,
544    }
545}
546
/// Registers (or re-registers) this worker with the A2A server.
///
/// Collects the available models per provider (best-effort — registration
/// proceeds without model info on failure), flattens them into the payload
/// shape the server's /models and /workers endpoints expect, and POSTs to
/// `/v1/opencode/workers/register`. A non-success HTTP status is logged but
/// not treated as an error, so callers always get `Ok` unless the request
/// itself fails.
async fn register_worker(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
) -> Result<()> {
    // Load ProviderRegistry and collect available models
    let models = match load_provider_models().await {
        Ok(m) => m,
        Err(e) => {
            tracing::warn!(
                "Failed to load provider models: {}, proceeding without model info",
                e
            );
            HashMap::new()
        }
    };

    // Register via the workers/register endpoint
    let mut req = client.post(format!("{}/v1/opencode/workers/register", server));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    // Flatten models HashMap into array of model objects with pricing data
    // matching the format expected by the A2A server's /models and /workers endpoints
    let models_array: Vec<serde_json::Value> = models
        .iter()
        .flat_map(|(provider, model_infos)| {
            model_infos.iter().map(move |m| {
                let mut obj = serde_json::json!({
                    "id": format!("{}/{}", provider, m.id),
                    "name": &m.id,
                    "provider": provider,
                    "provider_id": provider,
                });
                // Pricing fields are optional; only emit them when known.
                if let Some(input_cost) = m.input_cost_per_million {
                    obj["input_cost_per_million"] = serde_json::json!(input_cost);
                }
                if let Some(output_cost) = m.output_cost_per_million {
                    obj["output_cost_per_million"] = serde_json::json!(output_cost);
                }
                obj
            })
        })
        .collect();

    tracing::info!(
        "Registering worker with {} models from {} providers",
        models_array.len(),
        models.len()
    );

    // HOSTNAME is the Unix convention, COMPUTERNAME the Windows one.
    let hostname = std::env::var("HOSTNAME")
        .or_else(|_| std::env::var("COMPUTERNAME"))
        .unwrap_or_else(|_| "unknown".to_string());

    let res = req
        .json(&serde_json::json!({
            "worker_id": worker_id,
            "name": name,
            "capabilities": WORKER_CAPABILITIES,
            "hostname": hostname,
            "models": models_array,
            "codebases": codebases,
        }))
        .send()
        .await?;

    if res.status().is_success() {
        tracing::info!("Worker registered successfully");
    } else {
        // Registration failure is non-fatal; the reconnect loop retries it.
        tracing::warn!("Failed to register worker: {}", res.status());
    }

    Ok(())
}
627
/// Load ProviderRegistry and collect all available models grouped by provider.
/// Tries Vault first, then falls back to config/env vars if Vault is unreachable.
/// Returns ModelInfo structs (with pricing data when available).
///
/// Resolution order:
/// 1. Vault-backed registry (if it yields at least one provider);
/// 2. config-file/env-var registry via `fallback_registry`;
/// 3. if neither produces any models, the public models.dev catalog,
///    listing every provider regardless of API-key availability.
/// Per-provider `list_models` failures are logged at debug level and skipped.
async fn load_provider_models() -> Result<HashMap<String, Vec<crate::provider::ModelInfo>>> {
    // Try Vault first
    let registry = match ProviderRegistry::from_vault().await {
        Ok(r) if !r.list().is_empty() => {
            tracing::info!("Loaded {} providers from Vault", r.list().len());
            r
        }
        Ok(_) => {
            tracing::warn!("Vault returned 0 providers, falling back to config/env vars");
            fallback_registry().await?
        }
        Err(e) => {
            tracing::warn!("Vault unreachable ({}), falling back to config/env vars", e);
            fallback_registry().await?
        }
    };

    // Fetch the models.dev catalog for pricing data enrichment
    // (best-effort: None simply skips enrichment below).
    let catalog = crate::provider::models::ModelCatalog::fetch().await.ok();

    // Map provider IDs to their catalog equivalents (some differ)
    let catalog_alias = |pid: &str| -> String {
        match pid {
            "bedrock" => "amazon-bedrock".to_string(),
            "novita" => "novita-ai".to_string(),
            _ => pid.to_string(),
        }
    };

    let mut models_by_provider: HashMap<String, Vec<crate::provider::ModelInfo>> = HashMap::new();

    for provider_name in registry.list() {
        if let Some(provider) = registry.get(provider_name) {
            match provider.list_models().await {
                Ok(models) => {
                    let enriched: Vec<crate::provider::ModelInfo> = models
                        .into_iter()
                        .map(|mut m| {
                            // Enrich with catalog pricing if the provider didn't set it
                            if m.input_cost_per_million.is_none()
                                || m.output_cost_per_million.is_none()
                            {
                                if let Some(ref cat) = catalog {
                                    let cat_pid = catalog_alias(provider_name);
                                    if let Some(prov_info) = cat.get_provider(&cat_pid) {
                                        // Try exact match first, then strip "us." prefix (bedrock uses us.vendor.model format)
                                        let model_info =
                                            prov_info.models.get(&m.id).or_else(|| {
                                                m.id.strip_prefix("us.").and_then(|stripped| {
                                                    prov_info.models.get(stripped)
                                                })
                                            });
                                        if let Some(model_info) = model_info {
                                            if let Some(ref cost) = model_info.cost {
                                                // Catalog fills only the gaps; provider-set
                                                // pricing always wins.
                                                if m.input_cost_per_million.is_none() {
                                                    m.input_cost_per_million = Some(cost.input);
                                                }
                                                if m.output_cost_per_million.is_none() {
                                                    m.output_cost_per_million = Some(cost.output);
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            m
                        })
                        .collect();
                    if !enriched.is_empty() {
                        tracing::debug!("Provider {}: {} models", provider_name, enriched.len());
                        models_by_provider.insert(provider_name.to_string(), enriched);
                    }
                }
                Err(e) => {
                    tracing::debug!("Failed to list models for {}: {}", provider_name, e);
                }
            }
        }
    }

    // If we still have 0, try the models.dev catalog as last resort
    // NOTE: We list ALL models from the catalog (not just ones with verified API keys)
    // because the worker should advertise what it can handle. The server handles routing.
    if models_by_provider.is_empty() {
        tracing::info!(
            "No authenticated providers found, fetching models.dev catalog (all providers)"
        );
        if let Ok(cat) = crate::provider::models::ModelCatalog::fetch().await {
            // Use all_providers_with_models() to get every provider+model from catalog
            // regardless of API key availability (Vault may be down)
            for (provider_id, provider_info) in cat.all_providers() {
                let model_infos: Vec<crate::provider::ModelInfo> = provider_info
                    .models
                    .values()
                    .map(|m| cat.to_model_info(m, provider_id))
                    .collect();
                if !model_infos.is_empty() {
                    tracing::debug!(
                        "Catalog provider {}: {} models",
                        provider_id,
                        model_infos.len()
                    );
                    models_by_provider.insert(provider_id.clone(), model_infos);
                }
            }
            tracing::info!(
                "Loaded {} providers with {} total models from catalog",
                models_by_provider.len(),
                models_by_provider.values().map(|v| v.len()).sum::<usize>()
            );
        }
    }

    Ok(models_by_provider)
}
746
/// Fallback: build a ProviderRegistry from config file + environment variables.
/// Used by `load_provider_models` when Vault is unreachable or empty.
async fn fallback_registry() -> Result<ProviderRegistry> {
    // A missing or invalid config file degrades to defaults rather than failing.
    let config = crate::config::Config::load().await.unwrap_or_default();
    ProviderRegistry::from_config(&config).await
}
752
/// Fetches tasks queued as `pending` on the server and spawns a detached
/// handler for each one this worker is not already processing.
///
/// Used at startup to pick up tasks created before the SSE stream connects.
/// Non-success HTTP statuses are treated as "no pending tasks"; the response
/// body may be either a bare JSON array or a `{ "tasks": [...] }` wrapper.
async fn fetch_pending_tasks(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
) -> Result<()> {
    tracing::info!("Checking for pending tasks...");

    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        // Non-success is deliberately swallowed: no pending work, not an error.
        return Ok(());
    }

    let data: serde_json::Value = res.json().await?;
    // Handle both plain array response and {tasks: [...]} wrapper
    let tasks = if let Some(arr) = data.as_array() {
        arr.clone()
    } else {
        data["tasks"].as_array().cloned().unwrap_or_default()
    };

    tracing::info!("Found {} pending task(s)", tasks.len());

    for task in tasks {
        if let Some(id) = task["id"].as_str() {
            // Claim the task id while holding the lock so a concurrent poll
            // cannot double-spawn a handler for the same task.
            let mut proc = processing.lock().await;
            if !proc.contains(id) {
                proc.insert(id.to_string());
                drop(proc);

                let task_id = id.to_string();
                let client = client.clone();
                let server = server.to_string();
                let token = token.clone();
                let worker_id = worker_id.to_string();
                let auto_approve = *auto_approve;
                let processing = processing.clone();

                tokio::spawn(async move {
                    if let Err(e) =
                        handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await
                    {
                        tracing::error!("Task {} failed: {}", task_id, e);
                    }
                    // Release the claim whether the task succeeded or failed.
                    processing.lock().await.remove(&task_id);
                });
            }
        }
    }

    Ok(())
}
812
/// Connects to the server's SSE task stream and dispatches incoming tasks.
///
/// Parses `text/event-stream` frames by hand (events are separated by a
/// blank line; payload lives on `data:` lines) and spawns a handler per
/// task. A 30-second interval also re-polls for pending tasks the stream
/// may have missed. Returns `Ok(())` when the server ends the stream and
/// `Err` on connection or transport failure; the caller reconnects either way.
#[allow(clippy::too_many_arguments)]
async fn connect_stream(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
) -> Result<()> {
    let url = format!(
        "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
        server,
        urlencoding::encode(name),
        urlencoding::encode(worker_id)
    );

    // Identity is carried both in the query string and in headers.
    let mut req = client
        .get(&url)
        .header("Accept", "text/event-stream")
        .header("X-Worker-ID", worker_id)
        .header("X-Agent-Name", name)
        .header("X-Codebases", codebases.join(","));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        anyhow::bail!("Failed to connect: {}", res.status());
    }

    tracing::info!("Connected to A2A server");

    let mut stream = res.bytes_stream();
    // Accumulates partial SSE frames across chunk boundaries.
    let mut buffer = String::new();
    let mut poll_interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
    poll_interval.tick().await; // consume the initial immediate tick

    loop {
        tokio::select! {
            chunk = stream.next() => {
                match chunk {
                    Some(Ok(chunk)) => {
                        // NOTE(review): lossy UTF-8 decode may corrupt a
                        // multi-byte char split across chunks — confirm the
                        // server only sends ASCII-safe JSON payloads.
                        buffer.push_str(&String::from_utf8_lossy(&chunk));

                        // Process SSE events
                        while let Some(pos) = buffer.find("\n\n") {
                            let event_str = buffer[..pos].to_string();
                            buffer = buffer[pos + 2..].to_string();

                            if let Some(data_line) = event_str.lines().find(|l| l.starts_with("data:")) {
                                let data = data_line.trim_start_matches("data:").trim();
                                if data == "[DONE]" || data.is_empty() {
                                    continue;
                                }

                                // Unparseable payloads are silently ignored.
                                if let Ok(task) = serde_json::from_str::<serde_json::Value>(data) {
                                    spawn_task_handler(
                                        &task, client, server, token, worker_id,
                                        processing, auto_approve,
                                    ).await;
                                }
                            }
                        }
                    }
                    Some(Err(e)) => {
                        return Err(e.into());
                    }
                    None => {
                        // Stream ended
                        return Ok(());
                    }
                }
            }
            _ = poll_interval.tick() => {
                // Periodic poll for pending tasks the SSE stream may have missed
                if let Err(e) = poll_pending_tasks(
                    client, server, token, worker_id, processing, auto_approve,
                ).await {
                    tracing::warn!("Periodic task poll failed: {}", e);
                }
            }
        }
    }
}
901
902async fn spawn_task_handler(
903    task: &serde_json::Value,
904    client: &Client,
905    server: &str,
906    token: &Option<String>,
907    worker_id: &str,
908    processing: &Arc<Mutex<HashSet<String>>>,
909    auto_approve: &AutoApprove,
910) {
911    if let Some(id) = task
912        .get("task")
913        .and_then(|t| t["id"].as_str())
914        .or_else(|| task["id"].as_str())
915    {
916        let mut proc = processing.lock().await;
917        if !proc.contains(id) {
918            proc.insert(id.to_string());
919            drop(proc);
920
921            let task_id = id.to_string();
922            let task = task.clone();
923            let client = client.clone();
924            let server = server.to_string();
925            let token = token.clone();
926            let worker_id = worker_id.to_string();
927            let auto_approve = *auto_approve;
928            let processing_clone = processing.clone();
929
930            tokio::spawn(async move {
931                if let Err(e) =
932                    handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await
933                {
934                    tracing::error!("Task {} failed: {}", task_id, e);
935                }
936                processing_clone.lock().await.remove(&task_id);
937            });
938        }
939    }
940}
941
942async fn poll_pending_tasks(
943    client: &Client,
944    server: &str,
945    token: &Option<String>,
946    worker_id: &str,
947    processing: &Arc<Mutex<HashSet<String>>>,
948    auto_approve: &AutoApprove,
949) -> Result<()> {
950    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
951    if let Some(t) = token {
952        req = req.bearer_auth(t);
953    }
954
955    let res = req.send().await?;
956    if !res.status().is_success() {
957        return Ok(());
958    }
959
960    let data: serde_json::Value = res.json().await?;
961    let tasks = if let Some(arr) = data.as_array() {
962        arr.clone()
963    } else {
964        data["tasks"].as_array().cloned().unwrap_or_default()
965    };
966
967    if !tasks.is_empty() {
968        tracing::debug!("Poll found {} pending task(s)", tasks.len());
969    }
970
971    for task in &tasks {
972        spawn_task_handler(
973            task,
974            client,
975            server,
976            token,
977            worker_id,
978            processing,
979            auto_approve,
980        )
981        .await;
982    }
983
984    Ok(())
985}
986
/// Claim, execute, and release one task end-to-end.
///
/// Flow: claim the task on the server (skipping silently if another worker
/// already holds it) → read routing hints from task metadata (model tier,
/// complexity, personality, target agent, resume-session id) → resume or
/// create a session → execute via either the swarm executor or the standard
/// session loop, streaming output back to the server → release the task with
/// the final status/result/error.
///
/// Claim failures return `Ok(())` rather than an error so the worker keeps
/// listening for other tasks.
async fn handle_task(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    task: &serde_json::Value,
    auto_approve: AutoApprove,
) -> Result<()> {
    let task_id = task_str(task, "id").ok_or_else(|| anyhow::anyhow!("No task ID"))?;
    let title = task_str(task, "title").unwrap_or("Untitled");

    tracing::info!("Handling task: {} ({})", title, task_id);

    // Claim the task so no other worker picks it up.
    let mut req = client
        .post(format!("{}/v1/worker/tasks/claim", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req
        .json(&serde_json::json!({ "task_id": task_id }))
        .send()
        .await?;

    if !res.status().is_success() {
        let status = res.status();
        let text = res.text().await?;
        if status == reqwest::StatusCode::CONFLICT {
            // 409 means another worker won the claim race — expected, not an error.
            tracing::debug!(task_id, "Task already claimed by another worker, skipping");
        } else {
            tracing::warn!(task_id, %status, "Failed to claim task: {}", text);
        }
        return Ok(());
    }

    tracing::info!("Claimed task: {}", task_id);

    // Routing hints live in task metadata; all of them are optional.
    let metadata = task_metadata(task);
    let resume_session_id = metadata
        .get("resume_session_id")
        .and_then(|v| v.as_str())
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty());
    let complexity_hint = metadata_str(&metadata, &["complexity"]);
    // An explicit tier wins; otherwise derive one from the complexity hint
    // (quick → fast, deep → heavy, anything else → balanced).
    let model_tier = metadata_str(&metadata, &["model_tier", "tier"])
        .map(|s| s.to_ascii_lowercase())
        .or_else(|| {
            complexity_hint.as_ref().map(|complexity| {
                match complexity.to_ascii_lowercase().as_str() {
                    "quick" => "fast".to_string(),
                    "deep" => "heavy".to_string(),
                    _ => "balanced".to_string(),
                }
            })
        });
    let worker_personality = metadata_str(
        &metadata,
        &["worker_personality", "personality", "agent_personality"],
    );
    let target_agent_name = metadata_str(&metadata, &["target_agent_name", "agent_name"]);
    // The model may be set at the task level or in metadata, under either
    // "model_ref" or "model" — first match wins.
    let raw_model = task_str(task, "model_ref")
        .or_else(|| metadata_lookup(&metadata, "model_ref").and_then(|v| v.as_str()))
        .or_else(|| task_str(task, "model"))
        .or_else(|| metadata_lookup(&metadata, "model").and_then(|v| v.as_str()));
    let selected_model = raw_model.map(model_ref_to_provider_model);

    // Resume existing session when requested; fall back to a fresh session if missing.
    let mut session = if let Some(ref sid) = resume_session_id {
        match Session::load(sid).await {
            Ok(existing) => {
                tracing::info!("Resuming session {} for task {}", sid, task_id);
                existing
            }
            Err(e) => {
                tracing::warn!(
                    "Could not load session {} for task {} ({}), starting a new session",
                    sid,
                    task_id,
                    e
                );
                Session::new().await?
            }
        }
    } else {
        Session::new().await?
    };

    // Agent type defaults to "build" when the task specifies none.
    let agent_type = task_str(task, "agent_type")
        .or_else(|| task_str(task, "agent"))
        .unwrap_or("build");
    session.agent = agent_type.to_string();

    if let Some(model) = selected_model.clone() {
        session.metadata.model = Some(model);
    }

    // The prompt falls back to the description, then to the title.
    let prompt = task_str(task, "prompt")
        .or_else(|| task_str(task, "description"))
        .unwrap_or(title);

    tracing::info!("Executing prompt: {}", prompt);

    // Set up output streaming to forward progress to the server. Each chunk
    // is posted fire-and-forget on its own spawned task; send errors are
    // deliberately ignored so streaming never fails the task itself.
    let stream_client = client.clone();
    let stream_server = server.to_string();
    let stream_token = token.clone();
    let stream_worker_id = worker_id.to_string();
    let stream_task_id = task_id.to_string();

    let output_callback = move |output: String| {
        let c = stream_client.clone();
        let s = stream_server.clone();
        let t = stream_token.clone();
        let w = stream_worker_id.clone();
        let tid = stream_task_id.clone();
        tokio::spawn(async move {
            let mut req = c
                .post(format!("{}/v1/opencode/tasks/{}/output", s, tid))
                .header("X-Worker-ID", &w);
            if let Some(tok) = &t {
                req = req.bearer_auth(tok);
            }
            let _ = req
                .json(&serde_json::json!({
                    "worker_id": w,
                    "output": output,
                }))
                .send()
                .await;
        });
    };

    // Execute swarm tasks via SwarmExecutor; all other agents use the standard session loop.
    let (status, result, error, session_id) = if is_swarm_agent(agent_type) {
        match execute_swarm_with_policy(
            &mut session,
            prompt,
            model_tier.as_deref(),
            selected_model,
            &metadata,
            complexity_hint.as_deref(),
            worker_personality.as_deref(),
            target_agent_name.as_deref(),
            Some(&output_callback),
        )
        .await
        {
            Ok((session_result, true)) => {
                tracing::info!("Swarm task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Ok((session_result, false)) => {
                // The swarm ran to completion but reported partial failure.
                tracing::warn!("Swarm task completed with failures: {}", task_id);
                (
                    "failed",
                    Some(session_result.text),
                    Some("Swarm execution completed with failures".to_string()),
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Swarm task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    } else {
        match execute_session_with_policy(
            &mut session,
            prompt,
            auto_approve,
            model_tier.as_deref(),
            Some(&output_callback),
        )
        .await
        {
            Ok(session_result) => {
                tracing::info!("Task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    };

    // Release the task with full details; use the executor-reported session id
    // when available, otherwise fall back to the local session's id.
    let mut req = client
        .post(format!("{}/v1/worker/tasks/release", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    req.json(&serde_json::json!({
        "task_id": task_id,
        "status": status,
        "result": result,
        "error": error,
        "session_id": session_id.unwrap_or_else(|| session.id.clone()),
    }))
    .send()
    .await?;

    tracing::info!("Task released: {} with status: {}", task_id, status);
    Ok(())
}
1206
/// Execute a swarm task via [`SwarmExecutor`] and fold the result back into
/// the session transcript.
///
/// Swarm tuning knobs (strategy, subagent count, per-subagent step limit,
/// timeout, parallelism) are read from task metadata with clamped defaults.
/// When an output callback is supplied, the executor runs on a spawned task
/// and its progress events are streamed through the callback as they arrive.
///
/// Returns the session result plus the swarm's success flag; the session is
/// saved before returning.
async fn execute_swarm_with_policy<F>(
    session: &mut Session,
    prompt: &str,
    model_tier: Option<&str>,
    explicit_model: Option<String>,
    metadata: &serde_json::Map<String, serde_json::Value>,
    complexity_hint: Option<&str>,
    worker_personality: Option<&str>,
    target_agent_name: Option<&str>,
    output_callback: Option<&F>,
) -> Result<(crate::session::SessionResult, bool)>
where
    F: Fn(String),
{
    use crate::provider::{ContentPart, Message, Role};

    // Record the prompt in the transcript before executing.
    session.add_message(Message {
        role: Role::User,
        content: vec![ContentPart::Text {
            text: prompt.to_string(),
        }],
    });

    if session.title.is_none() {
        session.generate_title().await?;
    }

    // Tuning knobs from metadata, each clamped to a sane range.
    let strategy = parse_swarm_strategy(metadata);
    let max_subagents = metadata_usize(
        metadata,
        &["swarm_max_subagents", "max_subagents", "subagents"],
    )
    .unwrap_or(10)
    .clamp(1, 100);
    let max_steps_per_subagent = metadata_usize(
        metadata,
        &[
            "swarm_max_steps_per_subagent",
            "max_steps_per_subagent",
            "max_steps",
        ],
    )
    .unwrap_or(50)
    .clamp(1, 200);
    let timeout_secs = metadata_u64(metadata, &["swarm_timeout_secs", "timeout_secs", "timeout"])
        .unwrap_or(600)
        .clamp(30, 3600);
    let parallel_enabled =
        metadata_bool(metadata, &["swarm_parallel_enabled", "parallel_enabled"]).unwrap_or(true);

    // Explicit model wins over tier-based resolution; record the choice on the session.
    let model = resolve_swarm_model(explicit_model, model_tier).await;
    if let Some(ref selected_model) = model {
        session.metadata.model = Some(selected_model.clone());
    }

    // Echo the resolved routing/config so the server-side log shows how the
    // swarm was parameterized.
    if let Some(cb) = output_callback {
        cb(format!(
            "[swarm] routing complexity={} tier={} personality={} target_agent={}",
            complexity_hint.unwrap_or("standard"),
            model_tier.unwrap_or("balanced"),
            worker_personality.unwrap_or("auto"),
            target_agent_name.unwrap_or("auto")
        ));
        cb(format!(
            "[swarm] config strategy={:?} max_subagents={} max_steps={} timeout={}s tier={}",
            strategy,
            max_subagents,
            max_steps_per_subagent,
            timeout_secs,
            model_tier.unwrap_or("balanced")
        ));
    }

    let swarm_config = SwarmConfig {
        max_subagents,
        max_steps_per_subagent,
        subagent_timeout_secs: timeout_secs,
        parallel_enabled,
        model,
        working_dir: session
            .metadata
            .directory
            .as_ref()
            .map(|p| p.to_string_lossy().to_string()),
        ..Default::default()
    };

    let swarm_result = if output_callback.is_some() {
        // Streaming path: run the executor on its own task and pump its event
        // channel through the callback until the executor finishes.
        let (event_tx, mut event_rx) = mpsc::channel(256);
        let executor = SwarmExecutor::new(swarm_config).with_event_tx(event_tx);
        let prompt_owned = prompt.to_string();
        let mut exec_handle =
            tokio::spawn(async move { executor.execute(&prompt_owned, strategy).await });

        let mut final_result: Option<crate::swarm::SwarmResult> = None;

        // Race event delivery against executor completion; a channel close
        // (None) is simply ignored so only the join arm ends the loop.
        while final_result.is_none() {
            tokio::select! {
                maybe_event = event_rx.recv() => {
                    if let Some(event) = maybe_event {
                        if let Some(cb) = output_callback {
                            if let Some(line) = format_swarm_event_for_output(&event) {
                                cb(line);
                            }
                        }
                    }
                }
                join_result = &mut exec_handle => {
                    let joined = join_result.map_err(|e| anyhow::anyhow!("Swarm join failure: {}", e))?;
                    final_result = Some(joined?);
                }
            }
        }

        // Drain any progress events that were still queued when the executor
        // finished, so nothing is lost from the streamed log.
        while let Ok(event) = event_rx.try_recv() {
            if let Some(cb) = output_callback {
                if let Some(line) = format_swarm_event_for_output(&event) {
                    cb(line);
                }
            }
        }

        final_result.ok_or_else(|| anyhow::anyhow!("Swarm execution returned no result"))?
    } else {
        // No streaming requested: run the executor inline.
        SwarmExecutor::new(swarm_config)
            .execute(prompt, strategy)
            .await?
    };

    // Guarantee a non-empty result text so the release payload is meaningful.
    let final_text = if swarm_result.result.trim().is_empty() {
        if swarm_result.success {
            "Swarm completed without textual output.".to_string()
        } else {
            "Swarm finished with failures and no textual output.".to_string()
        }
    } else {
        swarm_result.result.clone()
    };

    session.add_message(Message {
        role: Role::Assistant,
        content: vec![ContentPart::Text {
            text: final_text.clone(),
        }],
    });
    session.save().await?;

    Ok((
        crate::session::SessionResult {
            text: final_text,
            session_id: session.id.clone(),
        },
        swarm_result.success,
    ))
}
1362
1363/// Execute a session with the given auto-approve policy
1364/// Optionally streams output chunks via the callback
1365async fn execute_session_with_policy<F>(
1366    session: &mut Session,
1367    prompt: &str,
1368    auto_approve: AutoApprove,
1369    model_tier: Option<&str>,
1370    output_callback: Option<&F>,
1371) -> Result<crate::session::SessionResult>
1372where
1373    F: Fn(String),
1374{
1375    use crate::provider::{
1376        CompletionRequest, ContentPart, Message, ProviderRegistry, Role, parse_model_string,
1377    };
1378    use std::sync::Arc;
1379
1380    // Load provider registry from Vault
1381    let registry = ProviderRegistry::from_vault().await?;
1382    let providers = registry.list();
1383    tracing::info!("Available providers: {:?}", providers);
1384
1385    if providers.is_empty() {
1386        anyhow::bail!("No providers available. Configure API keys in HashiCorp Vault.");
1387    }
1388
1389    // Parse model string
1390    let (provider_name, model_id) = if let Some(ref model_str) = session.metadata.model {
1391        let (prov, model) = parse_model_string(model_str);
1392        if prov.is_some() {
1393            (prov.map(|s| s.to_string()), model.to_string())
1394        } else if providers.contains(&model) {
1395            (Some(model.to_string()), String::new())
1396        } else {
1397            (None, model.to_string())
1398        }
1399    } else {
1400        (None, String::new())
1401    };
1402
1403    let provider_slice = providers.as_slice();
1404    let provider_requested_but_unavailable = provider_name
1405        .as_deref()
1406        .map(|p| !providers.contains(&p))
1407        .unwrap_or(false);
1408
1409    // Determine which provider to use, preferring explicit request first, then model tier.
1410    let selected_provider = provider_name
1411        .as_deref()
1412        .filter(|p| providers.contains(p))
1413        .unwrap_or_else(|| choose_provider_for_tier(provider_slice, model_tier));
1414
1415    let provider = registry
1416        .get(selected_provider)
1417        .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;
1418
1419    // Add user message
1420    session.add_message(Message {
1421        role: Role::User,
1422        content: vec![ContentPart::Text {
1423            text: prompt.to_string(),
1424        }],
1425    });
1426
1427    // Generate title
1428    if session.title.is_none() {
1429        session.generate_title().await?;
1430    }
1431
1432    // Determine model. If a specific provider was requested but not available,
1433    // ignore that model id and fall back to the tier-based default model.
1434    let model = if !model_id.is_empty() && !provider_requested_but_unavailable {
1435        model_id
1436    } else {
1437        default_model_for_provider(selected_provider, model_tier)
1438    };
1439
1440    // Create tool registry with filtering based on auto-approve policy
1441    let tool_registry =
1442        create_filtered_registry(Arc::clone(&provider), model.clone(), auto_approve);
1443    let tool_definitions = tool_registry.definitions();
1444
1445    let temperature = if model.starts_with("kimi-k2") {
1446        Some(1.0)
1447    } else {
1448        Some(0.7)
1449    };
1450
1451    tracing::info!(
1452        "Using model: {} via provider: {} (tier: {:?})",
1453        model,
1454        selected_provider,
1455        model_tier
1456    );
1457    tracing::info!(
1458        "Available tools: {} (auto_approve: {:?})",
1459        tool_definitions.len(),
1460        auto_approve
1461    );
1462
1463    // Build system prompt
1464    let cwd = std::env::var("PWD")
1465        .map(std::path::PathBuf::from)
1466        .unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
1467    let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);
1468
1469    let mut final_output = String::new();
1470    let max_steps = 50;
1471
1472    for step in 1..=max_steps {
1473        tracing::info!(step = step, "Agent step starting");
1474
1475        // Build messages with system prompt first
1476        let mut messages = vec![Message {
1477            role: Role::System,
1478            content: vec![ContentPart::Text {
1479                text: system_prompt.clone(),
1480            }],
1481        }];
1482        messages.extend(session.messages.clone());
1483
1484        let request = CompletionRequest {
1485            messages,
1486            tools: tool_definitions.clone(),
1487            model: model.clone(),
1488            temperature,
1489            top_p: None,
1490            max_tokens: Some(8192),
1491            stop: Vec::new(),
1492        };
1493
1494        let response = provider.complete(request).await?;
1495
1496        crate::telemetry::TOKEN_USAGE.record_model_usage(
1497            &model,
1498            response.usage.prompt_tokens as u64,
1499            response.usage.completion_tokens as u64,
1500        );
1501
1502        // Extract tool calls
1503        let tool_calls: Vec<(String, String, serde_json::Value)> = response
1504            .message
1505            .content
1506            .iter()
1507            .filter_map(|part| {
1508                if let ContentPart::ToolCall {
1509                    id,
1510                    name,
1511                    arguments,
1512                } = part
1513                {
1514                    let args: serde_json::Value =
1515                        serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
1516                    Some((id.clone(), name.clone(), args))
1517                } else {
1518                    None
1519                }
1520            })
1521            .collect();
1522
1523        // Collect text output and stream it
1524        for part in &response.message.content {
1525            if let ContentPart::Text { text } = part {
1526                if !text.is_empty() {
1527                    final_output.push_str(text);
1528                    final_output.push('\n');
1529                    if let Some(cb) = output_callback {
1530                        cb(text.clone());
1531                    }
1532                }
1533            }
1534        }
1535
1536        // If no tool calls, we're done
1537        if tool_calls.is_empty() {
1538            session.add_message(response.message.clone());
1539            break;
1540        }
1541
1542        session.add_message(response.message.clone());
1543
1544        tracing::info!(
1545            step = step,
1546            num_tools = tool_calls.len(),
1547            "Executing tool calls"
1548        );
1549
1550        // Execute each tool call
1551        for (tool_id, tool_name, tool_input) in tool_calls {
1552            tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");
1553
1554            // Stream tool start event
1555            if let Some(cb) = output_callback {
1556                cb(format!("[tool:start:{}]", tool_name));
1557            }
1558
1559            // Check if tool is allowed based on auto-approve policy
1560            if !is_tool_allowed(&tool_name, auto_approve) {
1561                let msg = format!(
1562                    "Tool '{}' requires approval but auto-approve policy is {:?}",
1563                    tool_name, auto_approve
1564                );
1565                tracing::warn!(tool = %tool_name, "Tool blocked by auto-approve policy");
1566                session.add_message(Message {
1567                    role: Role::Tool,
1568                    content: vec![ContentPart::ToolResult {
1569                        tool_call_id: tool_id,
1570                        content: msg,
1571                    }],
1572                });
1573                continue;
1574            }
1575
1576            let content = if let Some(tool) = tool_registry.get(&tool_name) {
1577                let exec_result: Result<crate::tool::ToolResult> =
1578                    tool.execute(tool_input.clone()).await;
1579                match exec_result {
1580                    Ok(result) => {
1581                        tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
1582                        if let Some(cb) = output_callback {
1583                            let status = if result.success { "ok" } else { "err" };
1584                            cb(format!(
1585                                "[tool:{}:{}] {}",
1586                                tool_name,
1587                                status,
1588                                &result.output[..result.output.len().min(500)]
1589                            ));
1590                        }
1591                        result.output
1592                    }
1593                    Err(e) => {
1594                        tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
1595                        if let Some(cb) = output_callback {
1596                            cb(format!("[tool:{}:err] {}", tool_name, e));
1597                        }
1598                        format!("Error: {}", e)
1599                    }
1600                }
1601            } else {
1602                tracing::warn!(tool = %tool_name, "Tool not found");
1603                format!("Error: Unknown tool '{}'", tool_name)
1604            };
1605
1606            session.add_message(Message {
1607                role: Role::Tool,
1608                content: vec![ContentPart::ToolResult {
1609                    tool_call_id: tool_id,
1610                    content,
1611                }],
1612            });
1613        }
1614    }
1615
1616    session.save().await?;
1617
1618    Ok(crate::session::SessionResult {
1619        text: final_output.trim().to_string(),
1620        session_id: session.id.clone(),
1621    })
1622}
1623
1624/// Check if a tool is allowed based on the auto-approve policy
1625fn is_tool_allowed(tool_name: &str, auto_approve: AutoApprove) -> bool {
1626    match auto_approve {
1627        AutoApprove::All => true,
1628        AutoApprove::Safe | AutoApprove::None => is_safe_tool(tool_name),
1629    }
1630}
1631
/// Check if a tool is considered "safe" (read-only).
///
/// Safe tools never mutate the workspace or run arbitrary commands, so they
/// may execute regardless of the auto-approve policy. Matching is exact and
/// case-sensitive.
fn is_safe_tool(tool_name: &str) -> bool {
    matches!(
        tool_name,
        "read"
            | "list"
            | "glob"
            | "grep"
            | "codesearch"
            | "lsp"
            | "webfetch"
            | "websearch"
            | "todo_read"
            | "skill"
    )
}
1648
1649/// Create a filtered tool registry based on the auto-approve policy
1650fn create_filtered_registry(
1651    provider: Arc<dyn crate::provider::Provider>,
1652    model: String,
1653    auto_approve: AutoApprove,
1654) -> crate::tool::ToolRegistry {
1655    use crate::tool::*;
1656
1657    let mut registry = ToolRegistry::new();
1658
1659    // Always add safe tools
1660    registry.register(Arc::new(file::ReadTool::new()));
1661    registry.register(Arc::new(file::ListTool::new()));
1662    registry.register(Arc::new(file::GlobTool::new()));
1663    registry.register(Arc::new(search::GrepTool::new()));
1664    registry.register(Arc::new(lsp::LspTool::new()));
1665    registry.register(Arc::new(webfetch::WebFetchTool::new()));
1666    registry.register(Arc::new(websearch::WebSearchTool::new()));
1667    registry.register(Arc::new(codesearch::CodeSearchTool::new()));
1668    registry.register(Arc::new(todo::TodoReadTool::new()));
1669    registry.register(Arc::new(skill::SkillTool::new()));
1670
1671    // Add potentially dangerous tools only if auto_approve is All
1672    if matches!(auto_approve, AutoApprove::All) {
1673        registry.register(Arc::new(file::WriteTool::new()));
1674        registry.register(Arc::new(edit::EditTool::new()));
1675        registry.register(Arc::new(bash::BashTool::new()));
1676        registry.register(Arc::new(multiedit::MultiEditTool::new()));
1677        registry.register(Arc::new(patch::ApplyPatchTool::new()));
1678        registry.register(Arc::new(todo::TodoWriteTool::new()));
1679        registry.register(Arc::new(task::TaskTool::new()));
1680        registry.register(Arc::new(plan::PlanEnterTool::new()));
1681        registry.register(Arc::new(plan::PlanExitTool::new()));
1682        registry.register(Arc::new(rlm::RlmTool::new()));
1683        registry.register(Arc::new(ralph::RalphTool::with_provider(provider, model)));
1684        registry.register(Arc::new(prd::PrdTool::new()));
1685        registry.register(Arc::new(confirm_edit::ConfirmEditTool::new()));
1686        registry.register(Arc::new(confirm_multiedit::ConfirmMultiEditTool::new()));
1687        registry.register(Arc::new(undo::UndoTool));
1688    }
1689
1690    registry.register(Arc::new(invalid::InvalidTool::new()));
1691
1692    registry
1693}
1694
/// Start the heartbeat background task
/// Returns a JoinHandle that can be used to cancel the heartbeat
///
/// The spawned task POSTs to `{server}/v1/opencode/workers/{worker_id}/heartbeat`
/// every 30 seconds, reporting the worker status (`idle`/`processing`) and the
/// active task count taken from the shared `processing` set. When cognition
/// reporting is enabled, a best-effort `cognition` section is attached to the
/// payload. If the server rejects an enriched payload with a 4xx status, the
/// heartbeat is retried once without the cognition section, and on success the
/// section is suppressed for a 5-minute cooldown. Failures are counted and
/// logged, but the task never terminates on its own.
fn start_heartbeat(
    client: Client,
    server: String,
    token: Option<String>,
    heartbeat_state: HeartbeatState,
    processing: Arc<Mutex<HashSet<String>>>,
    cognition_config: CognitionHeartbeatConfig,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        // Count of back-to-back failed heartbeats; reset after logging at
        // MAX_FAILURES so the error log is not spammed every tick.
        let mut consecutive_failures = 0u32;
        const MAX_FAILURES: u32 = 3;
        const HEARTBEAT_INTERVAL_SECS: u64 = 30;
        const COGNITION_RETRY_COOLDOWN_SECS: u64 = 300;
        // While Some, the cognition section is omitted until this deadline
        // passes (set after the server rejected the cognition schema).
        let mut cognition_payload_disabled_until: Option<Instant> = None;

        let mut interval =
            tokio::time::interval(tokio::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
        // Skip (rather than burst-fire) ticks missed while a heartbeat was slow.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        loop {
            interval.tick().await;

            // Update task count from processing set
            let active_count = processing.lock().await.len();
            heartbeat_state.set_task_count(active_count).await;

            // Determine status based on active tasks
            let status = if active_count > 0 {
                WorkerStatus::Processing
            } else {
                WorkerStatus::Idle
            };
            heartbeat_state.set_status(status).await;

            // Send heartbeat
            let url = format!(
                "{}/v1/opencode/workers/{}/heartbeat",
                server, heartbeat_state.worker_id
            );
            let mut req = client.post(&url);

            if let Some(ref t) = token {
                req = req.bearer_auth(t);
            }

            let status_str = heartbeat_state.status.lock().await.as_str().to_string();
            // Keep an unmodified copy of the base payload so a retry without
            // the cognition section can reuse it verbatim.
            let base_payload = serde_json::json!({
                "worker_id": &heartbeat_state.worker_id,
                "agent_name": &heartbeat_state.agent_name,
                "status": status_str,
                "active_task_count": active_count,
            });
            let mut payload = base_payload.clone();
            let mut included_cognition_payload = false;
            // Cognition payload is allowed when no cooldown is active or the
            // cooldown deadline has already passed.
            let cognition_payload_allowed = cognition_payload_disabled_until
                .map(|until| Instant::now() >= until)
                .unwrap_or(true);

            if cognition_config.enabled
                && cognition_payload_allowed
                && let Some(cognition_payload) =
                    fetch_cognition_heartbeat_payload(&client, &cognition_config).await
                && let Some(obj) = payload.as_object_mut()
            {
                obj.insert("cognition".to_string(), cognition_payload);
                included_cognition_payload = true;
            }

            match req.json(&payload).send().await {
                Ok(res) => {
                    if res.status().is_success() {
                        consecutive_failures = 0;
                        tracing::debug!(
                            worker_id = %heartbeat_state.worker_id,
                            status = status_str,
                            active_tasks = active_count,
                            "Heartbeat sent successfully"
                        );
                    } else if included_cognition_payload && res.status().is_client_error() {
                        // A 4xx on an enriched payload likely means the server
                        // does not accept the cognition schema: retry once with
                        // the plain base payload.
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            "Heartbeat cognition payload rejected, retrying without cognition payload"
                        );

                        let mut retry_req = client.post(&url);
                        if let Some(ref t) = token {
                            retry_req = retry_req.bearer_auth(t);
                        }

                        match retry_req.json(&base_payload).send().await {
                            Ok(retry_res) if retry_res.status().is_success() => {
                                // Plain payload worked: pause the cognition
                                // section for a cooldown before trying again.
                                cognition_payload_disabled_until = Some(
                                    Instant::now()
                                        + Duration::from_secs(COGNITION_RETRY_COOLDOWN_SECS),
                                );
                                consecutive_failures = 0;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    retry_after_secs = COGNITION_RETRY_COOLDOWN_SECS,
                                    "Paused cognition heartbeat payload after schema rejection"
                                );
                            }
                            Ok(retry_res) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    status = %retry_res.status(),
                                    failures = consecutive_failures,
                                    "Heartbeat failed even after retry without cognition payload"
                                );
                            }
                            Err(e) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    error = %e,
                                    failures = consecutive_failures,
                                    "Heartbeat retry without cognition payload failed"
                                );
                            }
                        }
                    } else {
                        consecutive_failures += 1;
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            failures = consecutive_failures,
                            "Heartbeat failed"
                        );
                    }
                }
                Err(e) => {
                    consecutive_failures += 1;
                    tracing::warn!(
                        worker_id = %heartbeat_state.worker_id,
                        error = %e,
                        failures = consecutive_failures,
                        "Heartbeat request failed"
                    );
                }
            }

            // Log error after 3 consecutive failures but do not terminate
            if consecutive_failures >= MAX_FAILURES {
                tracing::error!(
                    worker_id = %heartbeat_state.worker_id,
                    failures = consecutive_failures,
                    "Heartbeat failed {} consecutive times - worker will continue running and attempt reconnection via SSE loop",
                    MAX_FAILURES
                );
                // Reset counter to avoid spamming error logs
                consecutive_failures = 0;
            }
        }
    })
}
1854
1855async fn fetch_cognition_heartbeat_payload(
1856    client: &Client,
1857    config: &CognitionHeartbeatConfig,
1858) -> Option<serde_json::Value> {
1859    let status_url = format!("{}/v1/cognition/status", config.source_base_url);
1860    let status_res = tokio::time::timeout(
1861        Duration::from_millis(config.request_timeout_ms),
1862        client.get(status_url).send(),
1863    )
1864    .await
1865    .ok()?
1866    .ok()?;
1867
1868    if !status_res.status().is_success() {
1869        return None;
1870    }
1871
1872    let status: CognitionStatusSnapshot = status_res.json().await.ok()?;
1873    let mut payload = serde_json::json!({
1874        "running": status.running,
1875        "last_tick_at": status.last_tick_at,
1876        "active_persona_count": status.active_persona_count,
1877        "events_buffered": status.events_buffered,
1878        "snapshots_buffered": status.snapshots_buffered,
1879        "loop_interval_ms": status.loop_interval_ms,
1880    });
1881
1882    if config.include_thought_summary {
1883        let snapshot_url = format!("{}/v1/cognition/snapshots/latest", config.source_base_url);
1884        let snapshot_res = tokio::time::timeout(
1885            Duration::from_millis(config.request_timeout_ms),
1886            client.get(snapshot_url).send(),
1887        )
1888        .await
1889        .ok()
1890        .and_then(Result::ok);
1891
1892        if let Some(snapshot_res) = snapshot_res
1893            && snapshot_res.status().is_success()
1894            && let Ok(snapshot) = snapshot_res.json::<CognitionLatestSnapshot>().await
1895            && let Some(obj) = payload.as_object_mut()
1896        {
1897            obj.insert(
1898                "latest_snapshot_at".to_string(),
1899                serde_json::Value::String(snapshot.generated_at),
1900            );
1901            obj.insert(
1902                "latest_thought".to_string(),
1903                serde_json::Value::String(trim_for_heartbeat(
1904                    &snapshot.summary,
1905                    config.summary_max_chars,
1906                )),
1907            );
1908            if let Some(model) = snapshot
1909                .metadata
1910                .get("model")
1911                .and_then(serde_json::Value::as_str)
1912            {
1913                obj.insert(
1914                    "latest_thought_model".to_string(),
1915                    serde_json::Value::String(model.to_string()),
1916                );
1917            }
1918            if let Some(source) = snapshot
1919                .metadata
1920                .get("source")
1921                .and_then(serde_json::Value::as_str)
1922            {
1923                obj.insert(
1924                    "latest_thought_source".to_string(),
1925                    serde_json::Value::String(source.to_string()),
1926                );
1927            }
1928        }
1929    }
1930
1931    Some(payload)
1932}
1933
/// Truncate `input` to at most `max_chars` characters for a heartbeat payload,
/// appending `"..."` when content was actually dropped.
///
/// Surrounding whitespace is trimmed *before* the length check, so padding
/// alone never triggers truncation (the previous version counted the raw
/// input's characters, turning e.g. `"  hello  "` with `max_chars = 7` into
/// `"hello..."` even though the content fit). Truncation counts `char`s
/// (Unicode scalar values), not bytes, so multi-byte text is never split
/// mid-character.
fn trim_for_heartbeat(input: &str, max_chars: usize) -> String {
    let trimmed = input.trim();
    if trimmed.chars().count() <= max_chars {
        return trimmed.to_string();
    }

    // Keep the first `max_chars` characters, drop any trailing whitespace the
    // cut exposed, and mark the truncation with an ellipsis.
    let truncated: String = trimmed.chars().take(max_chars).collect();
    format!("{}...", truncated.trim_end())
}
1946
/// Read `name` from the environment as a boolean flag.
///
/// Accepts `1`/`true`/`yes`/`on` and `0`/`false`/`no`/`off` (case-insensitive);
/// any other value, or an unset variable, yields `default`.
fn env_bool(name: &str, default: bool) -> bool {
    match std::env::var(name) {
        Ok(raw) => match raw.to_ascii_lowercase().as_str() {
            "1" | "true" | "yes" | "on" => true,
            "0" | "false" | "no" | "off" => false,
            _ => default,
        },
        Err(_) => default,
    }
}
1957
/// Read `name` from the environment as a `usize`, falling back to `default`
/// when the variable is unset or does not parse.
fn env_usize(name: &str, default: usize) -> usize {
    match std::env::var(name) {
        Ok(raw) => raw.parse().unwrap_or(default),
        Err(_) => default,
    }
}
1964
/// Read `name` from the environment as a `u64`, falling back to `default`
/// when the variable is unset or does not parse.
fn env_u64(name: &str, default: u64) -> u64 {
    let Ok(raw) = std::env::var(name) else {
        return default;
    };
    raw.parse().unwrap_or(default)
}