Skip to main content

codetether_agent/a2a/
worker.rs

1//! A2A Worker - connects to an A2A server to process tasks
2
3use crate::bus::AgentBus;
4use crate::cli::A2aArgs;
5use crate::provider::ProviderRegistry;
6use crate::session::Session;
7use crate::swarm::{DecompositionStrategy, SwarmConfig, SwarmExecutor};
8use crate::tui::swarm_view::SwarmEvent;
9use anyhow::Result;
10use futures::StreamExt;
11use reqwest::Client;
12use serde::Deserialize;
13use std::collections::HashMap;
14use std::collections::HashSet;
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::sync::{Mutex, mpsc};
18use tokio::task::JoinHandle;
19use tokio::time::Instant;
20
/// Worker status reported to the server in heartbeat payloads.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WorkerStatus {
    /// No task is currently being handled.
    Idle,
    /// At least one task is in flight.
    Processing,
}

impl WorkerStatus {
    /// Wire representation of the status as embedded in heartbeats.
    fn as_str(&self) -> &'static str {
        match self {
            Self::Idle => "idle",
            Self::Processing => "processing",
        }
    }
}
36
/// Heartbeat state shared between the heartbeat task and the main worker.
///
/// Cloning is cheap: the mutable fields sit behind `Arc<Mutex<_>>`, so every
/// clone observes the same status and task count.
#[derive(Clone)]
struct HeartbeatState {
    /// Worker id generated at startup (see `generate_worker_id`).
    worker_id: String,
    /// Human-readable agent name reported to the server.
    agent_name: String,
    /// Current idle/processing status.
    status: Arc<Mutex<WorkerStatus>>,
    /// Number of tasks currently being processed.
    active_task_count: Arc<Mutex<usize>>,
}
45
impl HeartbeatState {
    /// Create a fresh state: status `Idle`, zero active tasks.
    fn new(worker_id: String, agent_name: String) -> Self {
        Self {
            worker_id,
            agent_name,
            status: Arc::new(Mutex::new(WorkerStatus::Idle)),
            active_task_count: Arc::new(Mutex::new(0)),
        }
    }

    /// Overwrite the shared worker status.
    async fn set_status(&self, status: WorkerStatus) {
        *self.status.lock().await = status;
    }

    /// Overwrite the shared count of in-flight tasks.
    async fn set_task_count(&self, count: usize) {
        *self.active_task_count.lock().await = count;
    }
}
64
/// Settings controlling how the worker shares local cognition state upstream,
/// read from `CODETETHER_WORKER_COGNITION_*` environment variables (see
/// `from_env`).
#[derive(Clone, Debug)]
struct CognitionHeartbeatConfig {
    /// Master switch; when false, cognition data is not shared upstream.
    enabled: bool,
    /// Base URL of the local cognition source (trailing slashes stripped).
    source_base_url: String,
    /// Whether to include a thought summary in shared state.
    include_thought_summary: bool,
    /// Upper bound on summary length; floored at 120 chars by `from_env`.
    summary_max_chars: usize,
    /// Request timeout for cognition calls; floored at 250 ms by `from_env`.
    request_timeout_ms: u64,
}
73
74impl CognitionHeartbeatConfig {
75    fn from_env() -> Self {
76        let source_base_url = std::env::var("CODETETHER_WORKER_COGNITION_SOURCE_URL")
77            .unwrap_or_else(|_| "http://127.0.0.1:4096".to_string())
78            .trim_end_matches('/')
79            .to_string();
80
81        Self {
82            enabled: env_bool("CODETETHER_WORKER_COGNITION_SHARE_ENABLED", true),
83            source_base_url,
84            include_thought_summary: env_bool("CODETETHER_WORKER_COGNITION_INCLUDE_THOUGHTS", true),
85            summary_max_chars: env_usize("CODETETHER_WORKER_COGNITION_THOUGHT_MAX_CHARS", 480)
86                .max(120),
87            request_timeout_ms: env_u64("CODETETHER_WORKER_COGNITION_TIMEOUT_MS", 2_500).max(250),
88        }
89    }
90}
91
/// Deserialized status snapshot of the local cognition service.
/// All fields except `running` default when absent from the JSON payload.
// NOTE(review): consumers are outside this view — presumably the heartbeat
// task polls this from `source_base_url`; confirm against `start_heartbeat`.
#[derive(Debug, Deserialize)]
struct CognitionStatusSnapshot {
    /// Whether the cognition loop reports itself as running.
    running: bool,
    #[serde(default)]
    last_tick_at: Option<String>,
    #[serde(default)]
    active_persona_count: usize,
    #[serde(default)]
    events_buffered: usize,
    #[serde(default)]
    snapshots_buffered: usize,
    #[serde(default)]
    loop_interval_ms: u64,
}
106
/// Deserialized "latest snapshot" payload from the cognition service:
/// a timestamped summary plus free-form metadata.
#[derive(Debug, Deserialize)]
struct CognitionLatestSnapshot {
    /// Timestamp string as produced by the cognition service (format not
    /// validated here).
    generated_at: String,
    /// Human-readable summary text.
    summary: String,
    /// Arbitrary extra fields; empty map when absent.
    #[serde(default)]
    metadata: HashMap<String, serde_json::Value>,
}
114
/// Run the A2A worker: register with the server, drain any already-pending
/// tasks, then hold an SSE task stream open — re-registering and restarting
/// the heartbeat task on every reconnect. Loops forever (reconnects every
/// 5 s after a stream drop); only early registration/poll errors return.
pub async fn run(args: A2aArgs) -> Result<()> {
    let server = args.server.trim_end_matches('/');
    // Display name defaults to "codetether-<pid>".
    let name = args
        .name
        .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
    let worker_id = generate_worker_id();

    // Comma-separated codebase list from the CLI; defaults to the CWD.
    let codebases: Vec<String> = args
        .codebases
        .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
        .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);

    tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
    tracing::info!("Server: {}", server);
    tracing::info!("Codebases: {:?}", codebases);

    let client = Client::new();
    // Ids of tasks currently being handled; shared with spawned handlers to
    // prevent double-processing.
    let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
    let cognition_heartbeat = CognitionHeartbeatConfig::from_env();
    if cognition_heartbeat.enabled {
        tracing::info!(
            source = %cognition_heartbeat.source_base_url,
            include_thoughts = cognition_heartbeat.include_thought_summary,
            max_chars = cognition_heartbeat.summary_max_chars,
            timeout_ms = cognition_heartbeat.request_timeout_ms,
            "Cognition heartbeat sharing enabled (set CODETETHER_WORKER_COGNITION_SHARE_ENABLED=false to disable)"
        );
    } else {
        tracing::warn!(
            "Cognition heartbeat sharing disabled; worker thought state will not be shared upstream"
        );
    }

    // Unrecognized CLI values fall back to no auto-approval.
    let auto_approve = match args.auto_approve.as_str() {
        "all" => AutoApprove::All,
        "safe" => AutoApprove::Safe,
        _ => AutoApprove::None,
    };

    // Create heartbeat state (shared with each per-connection heartbeat task)
    let heartbeat_state = HeartbeatState::new(worker_id.clone(), name.clone());

    // Create agent bus for in-process sub-agent communication
    let bus = AgentBus::new().into_arc();
    {
        let handle = bus.handle(&worker_id);
        handle.announce_ready(WORKER_CAPABILITIES.iter().map(|s| s.to_string()).collect());
    }

    // Register worker (initial registration failure aborts startup via `?`)
    register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;

    // Fetch tasks that were queued before we connected
    fetch_pending_tasks(
        &client,
        server,
        &args.token,
        &worker_id,
        &processing,
        &auto_approve,
        &bus,
    )
    .await?;

    // Connect to SSE stream; this loop never exits
    loop {
        // Re-register worker on each reconnection to report updated models/capabilities
        if let Err(e) =
            register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await
        {
            tracing::warn!("Failed to re-register worker on reconnection: {}", e);
        }

        // Start heartbeat task for this connection
        let heartbeat_handle = start_heartbeat(
            client.clone(),
            server.to_string(),
            args.token.clone(),
            heartbeat_state.clone(),
            processing.clone(),
            cognition_heartbeat.clone(),
        );

        match connect_stream(
            &client,
            server,
            &args.token,
            &worker_id,
            &name,
            &codebases,
            &processing,
            &auto_approve,
            &bus,
        )
        .await
        {
            Ok(()) => {
                tracing::warn!("Stream ended, reconnecting...");
            }
            Err(e) => {
                tracing::error!("Stream error: {}, reconnecting...", e);
            }
        }

        // Cancel heartbeat on disconnection; a fresh one starts next iteration
        heartbeat_handle.abort();
        tracing::debug!("Heartbeat cancelled for reconnection");

        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
    }
}
227
228fn generate_worker_id() -> String {
229    format!(
230        "wrk_{}_{:x}",
231        chrono::Utc::now().timestamp(),
232        rand::random::<u64>()
233    )
234}
235
/// How much tool-call approval the worker grants automatically.
/// Parsed from the CLI `auto_approve` string in `run` ("all" / "safe" /
/// anything else → `None`); enforcement happens in the task handlers.
#[derive(Debug, Clone, Copy)]
enum AutoApprove {
    /// Approve everything without prompting.
    All,
    /// Approve only actions classified as safe (classification is applied
    /// downstream, not in this file).
    Safe,
    /// Never auto-approve.
    None,
}

/// Default A2A server URL when none is configured
pub const DEFAULT_A2A_SERVER_URL: &str = "https://api.codetether.run";

/// Capabilities of the codetether-agent worker, advertised at registration
/// and announced on the in-process agent bus.
const WORKER_CAPABILITIES: &[&str] = &["ralph", "swarm", "rlm", "a2a", "mcp"];
248
249fn task_value<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a serde_json::Value> {
250    task.get("task")
251        .and_then(|t| t.get(key))
252        .or_else(|| task.get(key))
253}
254
255fn task_str<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a str> {
256    task_value(task, key).and_then(|v| v.as_str())
257}
258
259fn task_metadata(task: &serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
260    task_value(task, "metadata")
261        .and_then(|m| m.as_object())
262        .cloned()
263        .unwrap_or_default()
264}
265
/// Convert a "provider:model" reference into "provider/model" form.
///
/// Conversion is skipped when:
/// - the string already contains '/' (already in provider/model form), or
/// - the segment before the first ':' contains a '.', which marks a bare
///   versioned model id such as "amazon.nova-micro-v1:0" rather than a
///   provider prefix (the provider ids used in this file contain no dots).
///
/// The second guard fixes the previous behavior, which mangled bare
/// versioned ids ("amazon.nova-micro-v1:0" → "amazon.nova-micro-v1/0")
/// despite the comment promising otherwise.
fn model_ref_to_provider_model(model: &str) -> String {
    if model.contains('/') {
        return model.to_string();
    }
    match model.split_once(':') {
        // Dot-free prefix => treat as a provider name and rewrite the first ':'.
        Some((prefix, _)) if !prefix.contains('.') => model.replacen(':', "/", 1),
        _ => model.to_string(),
    }
}
276
/// Ordered provider preference list for a routing tier.
///
/// "fast"/"quick" and "heavy"/"deep" each get their own ordering; any other
/// value (including `None`) maps to the balanced default.
fn provider_preferences_for_tier(model_tier: Option<&str>) -> &'static [&'static str] {
    const FAST: &[&str] = &[
        "zai",
        "openai",
        "github-copilot",
        "moonshotai",
        "openrouter",
        "novita",
        "google",
        "anthropic",
    ];
    const HEAVY: &[&str] = &[
        "zai",
        "anthropic",
        "openai",
        "github-copilot",
        "moonshotai",
        "openrouter",
        "novita",
        "google",
    ];
    const BALANCED: &[&str] = &[
        "zai",
        "openai",
        "github-copilot",
        "anthropic",
        "moonshotai",
        "openrouter",
        "novita",
        "google",
    ];
    match model_tier.unwrap_or("balanced") {
        "fast" | "quick" => FAST,
        "heavy" | "deep" => HEAVY,
        _ => BALANCED,
    }
}

/// Pick the highest-preference provider from `providers` for the given tier,
/// falling back to "zai" if present and otherwise the first entry.
///
/// Panics if `providers` is empty; callers guard against that.
fn choose_provider_for_tier<'a>(providers: &'a [&'a str], model_tier: Option<&str>) -> &'a str {
    for &candidate in provider_preferences_for_tier(model_tier) {
        for &available in providers {
            if available == candidate {
                return available;
            }
        }
    }
    // Nothing from the preference list is available; prefer "zai" anyway,
    // else just take whatever is listed first.
    providers
        .iter()
        .copied()
        .find(|&p| p == "zai")
        .unwrap_or(providers[0])
}
323
/// Default model id to request from `provider` for a routing tier.
/// Unknown providers and unknown tiers fall back to "glm-5" and the
/// balanced tier respectively.
fn default_model_for_provider(provider: &str, model_tier: Option<&str>) -> String {
    let tier = model_tier.unwrap_or("balanced");
    let fast = matches!(tier, "fast" | "quick");
    let heavy = matches!(tier, "heavy" | "deep");
    // Providers whose default does not vary by tier get a single arm;
    // tier-sensitive ones use guards (fast first, then heavy, then balanced).
    let model = match provider {
        "moonshotai" => "kimi-k2.5",
        "anthropic" if fast => "claude-haiku-4-5",
        "anthropic" => "claude-sonnet-4-20250514",
        "openai" if fast => "gpt-4o-mini",
        "openai" if heavy => "o3",
        "openai" => "gpt-4o",
        "google" if fast => "gemini-2.5-flash",
        "google" => "gemini-2.5-pro",
        "zhipuai" | "zai" => "glm-5",
        "openrouter" => "z-ai/glm-5",
        "novita" => "qwen/qwen3-coder-next",
        "bedrock" if heavy => "us.anthropic.claude-sonnet-4-20250514-v1:0",
        "bedrock" => "amazon.nova-lite-v1:0",
        _ => "glm-5",
    };
    model.to_string()
}
361
/// Whether `model` belongs to a family matched here as preferring
/// temperature 1. Matching is a case-insensitive substring test.
fn prefers_temperature_one(model: &str) -> bool {
    let id = model.to_ascii_lowercase();
    ["kimi-k2", "glm-", "minimax"]
        .iter()
        .any(|family| id.contains(family))
}
366
/// True when `agent_type` names the swarm/parallel executor
/// (case-insensitive, surrounding whitespace ignored).
fn is_swarm_agent(agent_type: &str) -> bool {
    let normalized = agent_type.trim().to_ascii_lowercase();
    normalized == "swarm" || normalized == "parallel" || normalized == "multi-agent"
}
373
374fn metadata_lookup<'a>(
375    metadata: &'a serde_json::Map<String, serde_json::Value>,
376    key: &str,
377) -> Option<&'a serde_json::Value> {
378    metadata
379        .get(key)
380        .or_else(|| {
381            metadata
382                .get("routing")
383                .and_then(|v| v.as_object())
384                .and_then(|obj| obj.get(key))
385        })
386        .or_else(|| {
387            metadata
388                .get("swarm")
389                .and_then(|v| v.as_object())
390                .and_then(|obj| obj.get(key))
391        })
392}
393
394fn metadata_str(
395    metadata: &serde_json::Map<String, serde_json::Value>,
396    keys: &[&str],
397) -> Option<String> {
398    for key in keys {
399        if let Some(value) = metadata_lookup(metadata, key).and_then(|v| v.as_str()) {
400            let trimmed = value.trim();
401            if !trimmed.is_empty() {
402                return Some(trimmed.to_string());
403            }
404        }
405    }
406    None
407}
408
409fn metadata_usize(
410    metadata: &serde_json::Map<String, serde_json::Value>,
411    keys: &[&str],
412) -> Option<usize> {
413    for key in keys {
414        if let Some(value) = metadata_lookup(metadata, key) {
415            if let Some(v) = value.as_u64() {
416                return usize::try_from(v).ok();
417            }
418            if let Some(v) = value.as_i64() {
419                if v >= 0 {
420                    return usize::try_from(v as u64).ok();
421                }
422            }
423            if let Some(v) = value.as_str() {
424                if let Ok(parsed) = v.trim().parse::<usize>() {
425                    return Some(parsed);
426                }
427            }
428        }
429    }
430    None
431}
432
433fn metadata_u64(
434    metadata: &serde_json::Map<String, serde_json::Value>,
435    keys: &[&str],
436) -> Option<u64> {
437    for key in keys {
438        if let Some(value) = metadata_lookup(metadata, key) {
439            if let Some(v) = value.as_u64() {
440                return Some(v);
441            }
442            if let Some(v) = value.as_i64() {
443                if v >= 0 {
444                    return Some(v as u64);
445                }
446            }
447            if let Some(v) = value.as_str() {
448                if let Ok(parsed) = v.trim().parse::<u64>() {
449                    return Some(parsed);
450                }
451            }
452        }
453    }
454    None
455}
456
457fn metadata_bool(
458    metadata: &serde_json::Map<String, serde_json::Value>,
459    keys: &[&str],
460) -> Option<bool> {
461    for key in keys {
462        if let Some(value) = metadata_lookup(metadata, key) {
463            if let Some(v) = value.as_bool() {
464                return Some(v);
465            }
466            if let Some(v) = value.as_str() {
467                match v.trim().to_ascii_lowercase().as_str() {
468                    "1" | "true" | "yes" | "on" => return Some(true),
469                    "0" | "false" | "no" | "off" => return Some(false),
470                    _ => {}
471                }
472            }
473        }
474    }
475    None
476}
477
478fn parse_swarm_strategy(
479    metadata: &serde_json::Map<String, serde_json::Value>,
480) -> DecompositionStrategy {
481    match metadata_str(
482        metadata,
483        &[
484            "decomposition_strategy",
485            "swarm_strategy",
486            "strategy",
487            "swarm_decomposition",
488        ],
489    )
490    .as_deref()
491    .map(|s| s.to_ascii_lowercase())
492    .as_deref()
493    {
494        Some("none") | Some("single") => DecompositionStrategy::None,
495        Some("domain") | Some("by_domain") => DecompositionStrategy::ByDomain,
496        Some("data") | Some("by_data") => DecompositionStrategy::ByData,
497        Some("stage") | Some("by_stage") => DecompositionStrategy::ByStage,
498        _ => DecompositionStrategy::Automatic,
499    }
500}
501
502async fn resolve_swarm_model(
503    explicit_model: Option<String>,
504    model_tier: Option<&str>,
505) -> Option<String> {
506    if let Some(model) = explicit_model {
507        if !model.trim().is_empty() {
508            return Some(model);
509        }
510    }
511
512    let registry = ProviderRegistry::from_vault().await.ok()?;
513    let providers = registry.list();
514    if providers.is_empty() {
515        return None;
516    }
517    let provider = choose_provider_for_tier(providers.as_slice(), model_tier);
518    let model = default_model_for_provider(provider, model_tier);
519    Some(format!("{}/{}", provider, model))
520}
521
522fn format_swarm_event_for_output(event: &SwarmEvent) -> Option<String> {
523    match event {
524        SwarmEvent::Started {
525            task,
526            total_subtasks,
527        } => Some(format!(
528            "[swarm] started task={} planned_subtasks={}",
529            task, total_subtasks
530        )),
531        SwarmEvent::StageComplete {
532            stage,
533            completed,
534            failed,
535        } => Some(format!(
536            "[swarm] stage={} completed={} failed={}",
537            stage, completed, failed
538        )),
539        SwarmEvent::SubTaskUpdate { id, status, .. } => Some(format!(
540            "[swarm] subtask id={} status={}",
541            &id.chars().take(8).collect::<String>(),
542            format!("{status:?}").to_ascii_lowercase()
543        )),
544        SwarmEvent::AgentToolCall {
545            subtask_id,
546            tool_name,
547        } => Some(format!(
548            "[swarm] subtask id={} tool={}",
549            &subtask_id.chars().take(8).collect::<String>(),
550            tool_name
551        )),
552        SwarmEvent::AgentError { subtask_id, error } => Some(format!(
553            "[swarm] subtask id={} error={}",
554            &subtask_id.chars().take(8).collect::<String>(),
555            error
556        )),
557        SwarmEvent::Complete { success, stats } => Some(format!(
558            "[swarm] complete success={} subtasks={} speedup={:.2}",
559            success,
560            stats.subagents_completed + stats.subagents_failed,
561            stats.speedup_factor
562        )),
563        SwarmEvent::Error(err) => Some(format!("[swarm] error message={}", err)),
564        _ => None,
565    }
566}
567
/// Register (or re-register) this worker with the A2A server, advertising
/// its capabilities, hostname, codebases, and the models it can serve (with
/// pricing when known). A non-success response is logged but not treated as
/// an error — `run` retries registration on every stream reconnection.
async fn register_worker(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
) -> Result<()> {
    // Load ProviderRegistry and collect available models; registration
    // proceeds with an empty model list if this fails.
    let models = match load_provider_models().await {
        Ok(m) => m,
        Err(e) => {
            tracing::warn!(
                "Failed to load provider models: {}, proceeding without model info",
                e
            );
            HashMap::new()
        }
    };

    // Register via the workers/register endpoint
    let mut req = client.post(format!("{}/v1/opencode/workers/register", server));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    // Flatten models HashMap into array of model objects with pricing data
    // matching the format expected by the A2A server's /models and /workers endpoints
    let models_array: Vec<serde_json::Value> = models
        .iter()
        .flat_map(|(provider, model_infos)| {
            model_infos.iter().map(move |m| {
                let mut obj = serde_json::json!({
                    "id": format!("{}/{}", provider, m.id),
                    "name": &m.id,
                    "provider": provider,
                    "provider_id": provider,
                });
                // Pricing fields are optional; only emit them when known.
                if let Some(input_cost) = m.input_cost_per_million {
                    obj["input_cost_per_million"] = serde_json::json!(input_cost);
                }
                if let Some(output_cost) = m.output_cost_per_million {
                    obj["output_cost_per_million"] = serde_json::json!(output_cost);
                }
                obj
            })
        })
        .collect();

    tracing::info!(
        "Registering worker with {} models from {} providers",
        models_array.len(),
        models.len()
    );

    // HOSTNAME (unix) / COMPUTERNAME (Windows); "unknown" when neither is set.
    let hostname = std::env::var("HOSTNAME")
        .or_else(|_| std::env::var("COMPUTERNAME"))
        .unwrap_or_else(|_| "unknown".to_string());

    let res = req
        .json(&serde_json::json!({
            "worker_id": worker_id,
            "name": name,
            "capabilities": WORKER_CAPABILITIES,
            "hostname": hostname,
            "models": models_array,
            "codebases": codebases,
        }))
        .send()
        .await?;

    // Non-success registration is non-fatal; only transport errors (above)
    // propagate to the caller.
    if res.status().is_success() {
        tracing::info!("Worker registered successfully");
    } else {
        tracing::warn!("Failed to register worker: {}", res.status());
    }

    Ok(())
}
648
/// Load ProviderRegistry and collect all available models grouped by provider.
/// Tries Vault first, then falls back to config/env vars if Vault is unreachable.
/// Returns ModelInfo structs (with pricing data when available).
///
/// Resolution order:
/// 1. Vault-backed registry, when it yields at least one provider;
/// 2. config file / environment variables (`fallback_registry`);
/// 3. the models.dev catalog — every provider and model, regardless of
///    credentials — as a last resort when the registries list no models.
async fn load_provider_models() -> Result<HashMap<String, Vec<crate::provider::ModelInfo>>> {
    // Try Vault first
    let registry = match ProviderRegistry::from_vault().await {
        Ok(r) if !r.list().is_empty() => {
            tracing::info!("Loaded {} providers from Vault", r.list().len());
            r
        }
        Ok(_) => {
            tracing::warn!("Vault returned 0 providers, falling back to config/env vars");
            fallback_registry().await?
        }
        Err(e) => {
            tracing::warn!("Vault unreachable ({}), falling back to config/env vars", e);
            fallback_registry().await?
        }
    };

    // Fetch the models.dev catalog for pricing data enrichment
    let catalog = crate::provider::models::ModelCatalog::fetch().await.ok();

    // Map provider IDs to their catalog equivalents (some differ)
    let catalog_alias = |pid: &str| -> String {
        match pid {
            "bedrock" => "amazon-bedrock".to_string(),
            "novita" => "novita-ai".to_string(),
            _ => pid.to_string(),
        }
    };

    let mut models_by_provider: HashMap<String, Vec<crate::provider::ModelInfo>> = HashMap::new();

    for provider_name in registry.list() {
        if let Some(provider) = registry.get(provider_name) {
            match provider.list_models().await {
                Ok(models) => {
                    let enriched: Vec<crate::provider::ModelInfo> = models
                        .into_iter()
                        .map(|mut m| {
                            // Enrich with catalog pricing if the provider didn't set it
                            if m.input_cost_per_million.is_none()
                                || m.output_cost_per_million.is_none()
                            {
                                if let Some(ref cat) = catalog {
                                    let cat_pid = catalog_alias(provider_name);
                                    if let Some(prov_info) = cat.get_provider(&cat_pid) {
                                        // Try exact match first, then strip "us." prefix (bedrock uses us.vendor.model format)
                                        let model_info =
                                            prov_info.models.get(&m.id).or_else(|| {
                                                m.id.strip_prefix("us.").and_then(|stripped| {
                                                    prov_info.models.get(stripped)
                                                })
                                            });
                                        if let Some(model_info) = model_info {
                                            if let Some(ref cost) = model_info.cost {
                                                // Fill gaps only; never overwrite
                                                // provider-supplied pricing.
                                                if m.input_cost_per_million.is_none() {
                                                    m.input_cost_per_million = Some(cost.input);
                                                }
                                                if m.output_cost_per_million.is_none() {
                                                    m.output_cost_per_million = Some(cost.output);
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            m
                        })
                        .collect();
                    if !enriched.is_empty() {
                        tracing::debug!("Provider {}: {} models", provider_name, enriched.len());
                        models_by_provider.insert(provider_name.to_string(), enriched);
                    }
                }
                Err(e) => {
                    // Per-provider listing failures are non-fatal.
                    tracing::debug!("Failed to list models for {}: {}", provider_name, e);
                }
            }
        }
    }

    // If we still have 0, try the models.dev catalog as last resort
    // NOTE: We list ALL models from the catalog (not just ones with verified API keys)
    // because the worker should advertise what it can handle. The server handles routing.
    if models_by_provider.is_empty() {
        tracing::info!(
            "No authenticated providers found, fetching models.dev catalog (all providers)"
        );
        if let Ok(cat) = crate::provider::models::ModelCatalog::fetch().await {
            // Use all_providers_with_models() to get every provider+model from catalog
            // regardless of API key availability (Vault may be down)
            for (provider_id, provider_info) in cat.all_providers() {
                let model_infos: Vec<crate::provider::ModelInfo> = provider_info
                    .models
                    .values()
                    .map(|m| cat.to_model_info(m, provider_id))
                    .collect();
                if !model_infos.is_empty() {
                    tracing::debug!(
                        "Catalog provider {}: {} models",
                        provider_id,
                        model_infos.len()
                    );
                    models_by_provider.insert(provider_id.clone(), model_infos);
                }
            }
            tracing::info!(
                "Loaded {} providers with {} total models from catalog",
                models_by_provider.len(),
                models_by_provider.values().map(|v| v.len()).sum::<usize>()
            );
        }
    }

    Ok(models_by_provider)
}
767
/// Fallback: build a ProviderRegistry from config file + environment variables
///
/// Used when Vault is unreachable or yields no providers; a config-load
/// failure falls back to `Config::default()` rather than erroring.
async fn fallback_registry() -> Result<ProviderRegistry> {
    let config = crate::config::Config::load().await.unwrap_or_default();
    ProviderRegistry::from_config(&config).await
}
773
/// One-shot sweep of the server's pending-task queue.
///
/// Each task whose id is not already in `processing` is claimed (id inserted
/// into the set) and handled on a detached `tokio::spawn`; the id is removed
/// when the handler finishes, success or failure. Non-success HTTP responses
/// are treated as "no pending tasks"; only transport/JSON errors propagate.
async fn fetch_pending_tasks(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
    bus: &Arc<AgentBus>,
) -> Result<()> {
    tracing::info!("Checking for pending tasks...");

    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        return Ok(());
    }

    let data: serde_json::Value = res.json().await?;
    // Handle both plain array response and {tasks: [...]} wrapper
    let tasks = if let Some(arr) = data.as_array() {
        arr.clone()
    } else {
        data["tasks"].as_array().cloned().unwrap_or_default()
    };

    tracing::info!("Found {} pending task(s)", tasks.len());

    for task in tasks {
        if let Some(id) = task["id"].as_str() {
            let mut proc = processing.lock().await;
            if !proc.contains(id) {
                // Claim the id while holding the lock so concurrent sweeps or
                // stream events cannot double-process the same task.
                proc.insert(id.to_string());
                drop(proc);

                // Clone everything the detached handler must own.
                let task_id = id.to_string();
                let client = client.clone();
                let server = server.to_string();
                let token = token.clone();
                let worker_id = worker_id.to_string();
                let auto_approve = *auto_approve;
                let processing = processing.clone();
                let bus = bus.clone();

                tokio::spawn(async move {
                    if let Err(e) = handle_task(
                        &client,
                        &server,
                        &token,
                        &worker_id,
                        &task,
                        auto_approve,
                        &bus,
                    )
                    .await
                    {
                        tracing::error!("Task {} failed: {}", task_id, e);
                    }
                    // Release the claim whether the task succeeded or failed.
                    processing.lock().await.remove(&task_id);
                });
            }
        }
    }

    Ok(())
}
843
/// Hold a long-lived SSE connection to the server's task stream.
///
/// Incoming `data:` events are parsed as JSON tasks and dispatched via
/// `spawn_task_handler`; a 30-second interval additionally polls for pending
/// tasks the stream may have missed. Returns `Ok(())` when the server ends
/// the stream and `Err` on connect/transport failures — the caller (`run`)
/// reconnects in either case.
#[allow(clippy::too_many_arguments)]
async fn connect_stream(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
    bus: &Arc<AgentBus>,
) -> Result<()> {
    let url = format!(
        "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
        server,
        urlencoding::encode(name),
        urlencoding::encode(worker_id)
    );

    let mut req = client
        .get(&url)
        .header("Accept", "text/event-stream")
        .header("X-Worker-ID", worker_id)
        .header("X-Agent-Name", name)
        .header("X-Codebases", codebases.join(","));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        anyhow::bail!("Failed to connect: {}", res.status());
    }

    tracing::info!("Connected to A2A server");

    let mut stream = res.bytes_stream();
    // Accumulates raw chunk text until a complete SSE event ("\n\n") arrives.
    let mut buffer = String::new();
    let mut poll_interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
    poll_interval.tick().await; // consume the initial immediate tick

    loop {
        tokio::select! {
            chunk = stream.next() => {
                match chunk {
                    Some(Ok(chunk)) => {
                        buffer.push_str(&String::from_utf8_lossy(&chunk));

                        // Process SSE events (blank-line delimited)
                        while let Some(pos) = buffer.find("\n\n") {
                            let event_str = buffer[..pos].to_string();
                            buffer = buffer[pos + 2..].to_string();

                            if let Some(data_line) = event_str.lines().find(|l| l.starts_with("data:")) {
                                let data = data_line.trim_start_matches("data:").trim();
                                // "[DONE]" is a sentinel, not a task — skip it.
                                if data == "[DONE]" || data.is_empty() {
                                    continue;
                                }

                                // Payloads that are not valid JSON are silently ignored.
                                if let Ok(task) = serde_json::from_str::<serde_json::Value>(data) {
                                    spawn_task_handler(
                                        &task, client, server, token, worker_id,
                                        processing, auto_approve, bus,
                                    ).await;
                                }
                            }
                        }
                    }
                    Some(Err(e)) => {
                        return Err(e.into());
                    }
                    None => {
                        // Stream ended
                        return Ok(());
                    }
                }
            }
            _ = poll_interval.tick() => {
                // Periodic poll for pending tasks the SSE stream may have missed
                if let Err(e) = poll_pending_tasks(
                    client, server, token, worker_id, processing, auto_approve, bus,
                ).await {
                    tracing::warn!("Periodic task poll failed: {}", e);
                }
            }
        }
    }
}
933
934async fn spawn_task_handler(
935    task: &serde_json::Value,
936    client: &Client,
937    server: &str,
938    token: &Option<String>,
939    worker_id: &str,
940    processing: &Arc<Mutex<HashSet<String>>>,
941    auto_approve: &AutoApprove,
942    bus: &Arc<AgentBus>,
943) {
944    if let Some(id) = task
945        .get("task")
946        .and_then(|t| t["id"].as_str())
947        .or_else(|| task["id"].as_str())
948    {
949        let mut proc = processing.lock().await;
950        if !proc.contains(id) {
951            proc.insert(id.to_string());
952            drop(proc);
953
954            let task_id = id.to_string();
955            let task = task.clone();
956            let client = client.clone();
957            let server = server.to_string();
958            let token = token.clone();
959            let worker_id = worker_id.to_string();
960            let auto_approve = *auto_approve;
961            let processing_clone = processing.clone();
962            let bus = bus.clone();
963
964            tokio::spawn(async move {
965                if let Err(e) = handle_task(
966                    &client,
967                    &server,
968                    &token,
969                    &worker_id,
970                    &task,
971                    auto_approve,
972                    &bus,
973                )
974                .await
975                {
976                    tracing::error!("Task {} failed: {}", task_id, e);
977                }
978                processing_clone.lock().await.remove(&task_id);
979            });
980        }
981    }
982}
983
984async fn poll_pending_tasks(
985    client: &Client,
986    server: &str,
987    token: &Option<String>,
988    worker_id: &str,
989    processing: &Arc<Mutex<HashSet<String>>>,
990    auto_approve: &AutoApprove,
991    bus: &Arc<AgentBus>,
992) -> Result<()> {
993    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
994    if let Some(t) = token {
995        req = req.bearer_auth(t);
996    }
997
998    let res = req.send().await?;
999    if !res.status().is_success() {
1000        return Ok(());
1001    }
1002
1003    let data: serde_json::Value = res.json().await?;
1004    let tasks = if let Some(arr) = data.as_array() {
1005        arr.clone()
1006    } else {
1007        data["tasks"].as_array().cloned().unwrap_or_default()
1008    };
1009
1010    if !tasks.is_empty() {
1011        tracing::debug!("Poll found {} pending task(s)", tasks.len());
1012    }
1013
1014    for task in &tasks {
1015        spawn_task_handler(
1016            task,
1017            client,
1018            server,
1019            token,
1020            worker_id,
1021            processing,
1022            auto_approve,
1023            bus,
1024        )
1025        .await;
1026    }
1027
1028    Ok(())
1029}
1030
/// Claim, execute, and release a single task end-to-end.
///
/// Flow: claim the task on the server (skipping silently if another worker
/// won the race), resolve routing hints from task metadata (model tier,
/// personality, session resume), run it through either the swarm executor or
/// the standard session loop — streaming output chunks back to the server —
/// and finally release the task with its status, result, and session id.
///
/// Returns `Ok(())` even when the task itself fails (failure is reported via
/// the release payload); only transport errors talking to the server
/// propagate as `Err`.
async fn handle_task(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    task: &serde_json::Value,
    auto_approve: AutoApprove,
    bus: &Arc<AgentBus>,
) -> Result<()> {
    let task_id = task_str(task, "id").ok_or_else(|| anyhow::anyhow!("No task ID"))?;
    let title = task_str(task, "title").unwrap_or("Untitled");

    tracing::info!("Handling task: {} ({})", title, task_id);

    // Claim the task
    let mut req = client
        .post(format!("{}/v1/worker/tasks/claim", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req
        .json(&serde_json::json!({ "task_id": task_id }))
        .send()
        .await?;

    if !res.status().is_success() {
        let status = res.status();
        let text = res.text().await?;
        // 409 Conflict means another worker claimed it first — expected under
        // multiple workers, so only a debug log. Anything else is unexpected.
        if status == reqwest::StatusCode::CONFLICT {
            tracing::debug!(task_id, "Task already claimed by another worker, skipping");
        } else {
            tracing::warn!(task_id, %status, "Failed to claim task: {}", text);
        }
        return Ok(());
    }

    tracing::info!("Claimed task: {}", task_id);

    // Routing hints from task metadata. Each accessor tolerates absence; all
    // of these are optional knobs.
    let metadata = task_metadata(task);
    // Non-empty session id requests resuming an existing session.
    let resume_session_id = metadata
        .get("resume_session_id")
        .and_then(|v| v.as_str())
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty());
    let complexity_hint = metadata_str(&metadata, &["complexity"]);
    // Explicit tier wins; otherwise derive one from the complexity hint
    // (quick -> fast, deep -> heavy, anything else -> balanced).
    let model_tier = metadata_str(&metadata, &["model_tier", "tier"])
        .map(|s| s.to_ascii_lowercase())
        .or_else(|| {
            complexity_hint.as_ref().map(|complexity| {
                match complexity.to_ascii_lowercase().as_str() {
                    "quick" => "fast".to_string(),
                    "deep" => "heavy".to_string(),
                    _ => "balanced".to_string(),
                }
            })
        });
    let worker_personality = metadata_str(
        &metadata,
        &["worker_personality", "personality", "agent_personality"],
    );
    let target_agent_name = metadata_str(&metadata, &["target_agent_name", "agent_name"]);
    // Model resolution order: task.model_ref, metadata.model_ref, task.model,
    // metadata.model — first hit wins.
    let raw_model = task_str(task, "model_ref")
        .or_else(|| metadata_lookup(&metadata, "model_ref").and_then(|v| v.as_str()))
        .or_else(|| task_str(task, "model"))
        .or_else(|| metadata_lookup(&metadata, "model").and_then(|v| v.as_str()));
    let selected_model = raw_model.map(model_ref_to_provider_model);

    // Resume existing session when requested; fall back to a fresh session if missing.
    let mut session = if let Some(ref sid) = resume_session_id {
        match Session::load(sid).await {
            Ok(existing) => {
                tracing::info!("Resuming session {} for task {}", sid, task_id);
                existing
            }
            Err(e) => {
                tracing::warn!(
                    "Could not load session {} for task {} ({}), starting a new session",
                    sid,
                    task_id,
                    e
                );
                Session::new().await?
            }
        }
    } else {
        Session::new().await?
    };

    // Agent type defaults to "build" when the task specifies none.
    let agent_type = task_str(task, "agent_type")
        .or_else(|| task_str(task, "agent"))
        .unwrap_or("build");
    session.agent = agent_type.to_string();

    if let Some(model) = selected_model.clone() {
        session.metadata.model = Some(model);
    }

    // Prompt falls back to description, then to the title.
    let prompt = task_str(task, "prompt")
        .or_else(|| task_str(task, "description"))
        .unwrap_or(title);

    tracing::info!("Executing prompt: {}", prompt);

    // Set up output streaming to forward progress to the server
    let stream_client = client.clone();
    let stream_server = server.to_string();
    let stream_token = token.clone();
    let stream_worker_id = worker_id.to_string();
    let stream_task_id = task_id.to_string();

    // Fire-and-forget: each output chunk is POSTed on its own spawned task and
    // delivery failures are deliberately ignored (best-effort progress feed).
    let output_callback = move |output: String| {
        let c = stream_client.clone();
        let s = stream_server.clone();
        let t = stream_token.clone();
        let w = stream_worker_id.clone();
        let tid = stream_task_id.clone();
        tokio::spawn(async move {
            let mut req = c
                .post(format!("{}/v1/opencode/tasks/{}/output", s, tid))
                .header("X-Worker-ID", &w);
            if let Some(tok) = &t {
                req = req.bearer_auth(tok);
            }
            let _ = req
                .json(&serde_json::json!({
                    "worker_id": w,
                    "output": output,
                }))
                .send()
                .await;
        });
    };

    // Execute swarm tasks via SwarmExecutor; all other agents use the standard session loop.
    let (status, result, error, session_id) = if is_swarm_agent(agent_type) {
        match execute_swarm_with_policy(
            &mut session,
            prompt,
            model_tier.as_deref(),
            selected_model,
            &metadata,
            complexity_hint.as_deref(),
            worker_personality.as_deref(),
            target_agent_name.as_deref(),
            Some(bus),
            Some(&output_callback),
        )
        .await
        {
            // The bool distinguishes "swarm ran to completion" from
            // "swarm ran but some subagents failed".
            Ok((session_result, true)) => {
                tracing::info!("Swarm task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Ok((session_result, false)) => {
                tracing::warn!("Swarm task completed with failures: {}", task_id);
                (
                    "failed",
                    Some(session_result.text),
                    Some("Swarm execution completed with failures".to_string()),
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Swarm task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    } else {
        match execute_session_with_policy(
            &mut session,
            prompt,
            auto_approve,
            model_tier.as_deref(),
            Some(&output_callback),
        )
        .await
        {
            Ok(session_result) => {
                tracing::info!("Task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    };

    // Release the task with full details
    let mut req = client
        .post(format!("{}/v1/worker/tasks/release", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    req.json(&serde_json::json!({
        "task_id": task_id,
        "status": status,
        "result": result,
        "error": error,
        // If execution produced no session id (hard error path), report the
        // session we created/resumed anyway so the server can correlate.
        "session_id": session_id.unwrap_or_else(|| session.id.clone()),
    }))
    .send()
    .await?;

    tracing::info!("Task released: {} with status: {}", task_id, status);
    Ok(())
}
1252
1253async fn execute_swarm_with_policy<F>(
1254    session: &mut Session,
1255    prompt: &str,
1256    model_tier: Option<&str>,
1257    explicit_model: Option<String>,
1258    metadata: &serde_json::Map<String, serde_json::Value>,
1259    complexity_hint: Option<&str>,
1260    worker_personality: Option<&str>,
1261    target_agent_name: Option<&str>,
1262    bus: Option<&Arc<AgentBus>>,
1263    output_callback: Option<&F>,
1264) -> Result<(crate::session::SessionResult, bool)>
1265where
1266    F: Fn(String),
1267{
1268    use crate::provider::{ContentPart, Message, Role};
1269
1270    session.add_message(Message {
1271        role: Role::User,
1272        content: vec![ContentPart::Text {
1273            text: prompt.to_string(),
1274        }],
1275    });
1276
1277    if session.title.is_none() {
1278        session.generate_title().await?;
1279    }
1280
1281    let strategy = parse_swarm_strategy(metadata);
1282    let max_subagents = metadata_usize(
1283        metadata,
1284        &["swarm_max_subagents", "max_subagents", "subagents"],
1285    )
1286    .unwrap_or(10)
1287    .clamp(1, 100);
1288    let max_steps_per_subagent = metadata_usize(
1289        metadata,
1290        &[
1291            "swarm_max_steps_per_subagent",
1292            "max_steps_per_subagent",
1293            "max_steps",
1294        ],
1295    )
1296    .unwrap_or(50)
1297    .clamp(1, 200);
1298    let timeout_secs = metadata_u64(metadata, &["swarm_timeout_secs", "timeout_secs", "timeout"])
1299        .unwrap_or(600)
1300        .clamp(30, 3600);
1301    let parallel_enabled =
1302        metadata_bool(metadata, &["swarm_parallel_enabled", "parallel_enabled"]).unwrap_or(true);
1303
1304    let model = resolve_swarm_model(explicit_model, model_tier).await;
1305    if let Some(ref selected_model) = model {
1306        session.metadata.model = Some(selected_model.clone());
1307    }
1308
1309    if let Some(cb) = output_callback {
1310        cb(format!(
1311            "[swarm] routing complexity={} tier={} personality={} target_agent={}",
1312            complexity_hint.unwrap_or("standard"),
1313            model_tier.unwrap_or("balanced"),
1314            worker_personality.unwrap_or("auto"),
1315            target_agent_name.unwrap_or("auto")
1316        ));
1317        cb(format!(
1318            "[swarm] config strategy={:?} max_subagents={} max_steps={} timeout={}s tier={}",
1319            strategy,
1320            max_subagents,
1321            max_steps_per_subagent,
1322            timeout_secs,
1323            model_tier.unwrap_or("balanced")
1324        ));
1325    }
1326
1327    let swarm_config = SwarmConfig {
1328        max_subagents,
1329        max_steps_per_subagent,
1330        subagent_timeout_secs: timeout_secs,
1331        parallel_enabled,
1332        model,
1333        working_dir: session
1334            .metadata
1335            .directory
1336            .as_ref()
1337            .map(|p| p.to_string_lossy().to_string()),
1338        ..Default::default()
1339    };
1340
1341    let swarm_result = if output_callback.is_some() {
1342        let (event_tx, mut event_rx) = mpsc::channel(256);
1343        let mut executor = SwarmExecutor::new(swarm_config).with_event_tx(event_tx);
1344        if let Some(bus) = bus {
1345            executor = executor.with_bus(Arc::clone(bus));
1346        }
1347        let prompt_owned = prompt.to_string();
1348        let mut exec_handle =
1349            tokio::spawn(async move { executor.execute(&prompt_owned, strategy).await });
1350
1351        let mut final_result: Option<crate::swarm::SwarmResult> = None;
1352
1353        while final_result.is_none() {
1354            tokio::select! {
1355                maybe_event = event_rx.recv() => {
1356                    if let Some(event) = maybe_event {
1357                        if let Some(cb) = output_callback {
1358                            if let Some(line) = format_swarm_event_for_output(&event) {
1359                                cb(line);
1360                            }
1361                        }
1362                    }
1363                }
1364                join_result = &mut exec_handle => {
1365                    let joined = join_result.map_err(|e| anyhow::anyhow!("Swarm join failure: {}", e))?;
1366                    final_result = Some(joined?);
1367                }
1368            }
1369        }
1370
1371        while let Ok(event) = event_rx.try_recv() {
1372            if let Some(cb) = output_callback {
1373                if let Some(line) = format_swarm_event_for_output(&event) {
1374                    cb(line);
1375                }
1376            }
1377        }
1378
1379        final_result.ok_or_else(|| anyhow::anyhow!("Swarm execution returned no result"))?
1380    } else {
1381        let mut executor = SwarmExecutor::new(swarm_config);
1382        if let Some(bus) = bus {
1383            executor = executor.with_bus(Arc::clone(bus));
1384        }
1385        executor.execute(prompt, strategy).await?
1386    };
1387
1388    let final_text = if swarm_result.result.trim().is_empty() {
1389        if swarm_result.success {
1390            "Swarm completed without textual output.".to_string()
1391        } else {
1392            "Swarm finished with failures and no textual output.".to_string()
1393        }
1394    } else {
1395        swarm_result.result.clone()
1396    };
1397
1398    session.add_message(Message {
1399        role: Role::Assistant,
1400        content: vec![ContentPart::Text {
1401            text: final_text.clone(),
1402        }],
1403    });
1404    session.save().await?;
1405
1406    Ok((
1407        crate::session::SessionResult {
1408            text: final_text,
1409            session_id: session.id.clone(),
1410        },
1411        swarm_result.success,
1412    ))
1413}
1414
1415/// Execute a session with the given auto-approve policy
1416/// Optionally streams output chunks via the callback
1417async fn execute_session_with_policy<F>(
1418    session: &mut Session,
1419    prompt: &str,
1420    auto_approve: AutoApprove,
1421    model_tier: Option<&str>,
1422    output_callback: Option<&F>,
1423) -> Result<crate::session::SessionResult>
1424where
1425    F: Fn(String),
1426{
1427    use crate::provider::{
1428        CompletionRequest, ContentPart, Message, ProviderRegistry, Role, parse_model_string,
1429    };
1430    use std::sync::Arc;
1431
1432    // Load provider registry from Vault
1433    let registry = ProviderRegistry::from_vault().await?;
1434    let providers = registry.list();
1435    tracing::info!("Available providers: {:?}", providers);
1436
1437    if providers.is_empty() {
1438        anyhow::bail!("No providers available. Configure API keys in HashiCorp Vault.");
1439    }
1440
1441    // Parse model string
1442    let (provider_name, model_id) = if let Some(ref model_str) = session.metadata.model {
1443        let (prov, model) = parse_model_string(model_str);
1444        let prov = prov.map(|p| if p == "zhipuai" { "zai" } else { p });
1445        if prov.is_some() {
1446            (prov.map(|s| s.to_string()), model.to_string())
1447        } else if providers.contains(&model) {
1448            (Some(model.to_string()), String::new())
1449        } else {
1450            (None, model.to_string())
1451        }
1452    } else {
1453        (None, String::new())
1454    };
1455
1456    let provider_slice = providers.as_slice();
1457    let provider_requested_but_unavailable = provider_name
1458        .as_deref()
1459        .map(|p| !providers.contains(&p))
1460        .unwrap_or(false);
1461
1462    // Determine which provider to use, preferring explicit request first, then model tier.
1463    let selected_provider = provider_name
1464        .as_deref()
1465        .filter(|p| providers.contains(p))
1466        .unwrap_or_else(|| choose_provider_for_tier(provider_slice, model_tier));
1467
1468    let provider = registry
1469        .get(selected_provider)
1470        .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;
1471
1472    // Add user message
1473    session.add_message(Message {
1474        role: Role::User,
1475        content: vec![ContentPart::Text {
1476            text: prompt.to_string(),
1477        }],
1478    });
1479
1480    // Generate title
1481    if session.title.is_none() {
1482        session.generate_title().await?;
1483    }
1484
1485    // Determine model. If a specific provider was requested but not available,
1486    // ignore that model id and fall back to the tier-based default model.
1487    let model = if !model_id.is_empty() && !provider_requested_but_unavailable {
1488        model_id
1489    } else {
1490        default_model_for_provider(selected_provider, model_tier)
1491    };
1492
1493    // Create tool registry with filtering based on auto-approve policy
1494    let tool_registry =
1495        create_filtered_registry(Arc::clone(&provider), model.clone(), auto_approve);
1496    let tool_definitions = tool_registry.definitions();
1497
1498    let temperature = if prefers_temperature_one(&model) {
1499        Some(1.0)
1500    } else {
1501        Some(0.7)
1502    };
1503
1504    tracing::info!(
1505        "Using model: {} via provider: {} (tier: {:?})",
1506        model,
1507        selected_provider,
1508        model_tier
1509    );
1510    tracing::info!(
1511        "Available tools: {} (auto_approve: {:?})",
1512        tool_definitions.len(),
1513        auto_approve
1514    );
1515
1516    // Build system prompt
1517    let cwd = std::env::var("PWD")
1518        .map(std::path::PathBuf::from)
1519        .unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
1520    let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);
1521
1522    let mut final_output = String::new();
1523    let max_steps = 50;
1524
1525    for step in 1..=max_steps {
1526        tracing::info!(step = step, "Agent step starting");
1527
1528        // Build messages with system prompt first
1529        let mut messages = vec![Message {
1530            role: Role::System,
1531            content: vec![ContentPart::Text {
1532                text: system_prompt.clone(),
1533            }],
1534        }];
1535        messages.extend(session.messages.clone());
1536
1537        let request = CompletionRequest {
1538            messages,
1539            tools: tool_definitions.clone(),
1540            model: model.clone(),
1541            temperature,
1542            top_p: None,
1543            max_tokens: Some(8192),
1544            stop: Vec::new(),
1545        };
1546
1547        let response = provider.complete(request).await?;
1548
1549        crate::telemetry::TOKEN_USAGE.record_model_usage(
1550            &model,
1551            response.usage.prompt_tokens as u64,
1552            response.usage.completion_tokens as u64,
1553        );
1554
1555        // Extract tool calls
1556        let tool_calls: Vec<(String, String, serde_json::Value)> = response
1557            .message
1558            .content
1559            .iter()
1560            .filter_map(|part| {
1561                if let ContentPart::ToolCall {
1562                    id,
1563                    name,
1564                    arguments,
1565                } = part
1566                {
1567                    let args: serde_json::Value =
1568                        serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
1569                    Some((id.clone(), name.clone(), args))
1570                } else {
1571                    None
1572                }
1573            })
1574            .collect();
1575
1576        // Collect text output and stream it
1577        for part in &response.message.content {
1578            if let ContentPart::Text { text } = part {
1579                if !text.is_empty() {
1580                    final_output.push_str(text);
1581                    final_output.push('\n');
1582                    if let Some(cb) = output_callback {
1583                        cb(text.clone());
1584                    }
1585                }
1586            }
1587        }
1588
1589        // If no tool calls, we're done
1590        if tool_calls.is_empty() {
1591            session.add_message(response.message.clone());
1592            break;
1593        }
1594
1595        session.add_message(response.message.clone());
1596
1597        tracing::info!(
1598            step = step,
1599            num_tools = tool_calls.len(),
1600            "Executing tool calls"
1601        );
1602
1603        // Execute each tool call
1604        for (tool_id, tool_name, tool_input) in tool_calls {
1605            tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");
1606
1607            // Stream tool start event
1608            if let Some(cb) = output_callback {
1609                cb(format!("[tool:start:{}]", tool_name));
1610            }
1611
1612            // Check if tool is allowed based on auto-approve policy
1613            if !is_tool_allowed(&tool_name, auto_approve) {
1614                let msg = format!(
1615                    "Tool '{}' requires approval but auto-approve policy is {:?}",
1616                    tool_name, auto_approve
1617                );
1618                tracing::warn!(tool = %tool_name, "Tool blocked by auto-approve policy");
1619                session.add_message(Message {
1620                    role: Role::Tool,
1621                    content: vec![ContentPart::ToolResult {
1622                        tool_call_id: tool_id,
1623                        content: msg,
1624                    }],
1625                });
1626                continue;
1627            }
1628
1629            let content = if let Some(tool) = tool_registry.get(&tool_name) {
1630                let exec_result: Result<crate::tool::ToolResult> =
1631                    tool.execute(tool_input.clone()).await;
1632                match exec_result {
1633                    Ok(result) => {
1634                        tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
1635                        if let Some(cb) = output_callback {
1636                            let status = if result.success { "ok" } else { "err" };
1637                            cb(format!(
1638                                "[tool:{}:{}] {}",
1639                                tool_name,
1640                                status,
1641                                &result.output[..result.output.len().min(500)]
1642                            ));
1643                        }
1644                        result.output
1645                    }
1646                    Err(e) => {
1647                        tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
1648                        if let Some(cb) = output_callback {
1649                            cb(format!("[tool:{}:err] {}", tool_name, e));
1650                        }
1651                        format!("Error: {}", e)
1652                    }
1653                }
1654            } else {
1655                tracing::warn!(tool = %tool_name, "Tool not found");
1656                format!("Error: Unknown tool '{}'", tool_name)
1657            };
1658
1659            session.add_message(Message {
1660                role: Role::Tool,
1661                content: vec![ContentPart::ToolResult {
1662                    tool_call_id: tool_id,
1663                    content,
1664                }],
1665            });
1666        }
1667    }
1668
1669    session.save().await?;
1670
1671    Ok(crate::session::SessionResult {
1672        text: final_output.trim().to_string(),
1673        session_id: session.id.clone(),
1674    })
1675}
1676
1677/// Check if a tool is allowed based on the auto-approve policy
1678fn is_tool_allowed(tool_name: &str, auto_approve: AutoApprove) -> bool {
1679    match auto_approve {
1680        AutoApprove::All => true,
1681        AutoApprove::Safe | AutoApprove::None => is_safe_tool(tool_name),
1682    }
1683}
1684
/// Whether `tool_name` names a read-only ("safe") tool that may run without
/// explicit approval.
fn is_safe_tool(tool_name: &str) -> bool {
    matches!(
        tool_name,
        "read"
            | "list"
            | "glob"
            | "grep"
            | "codesearch"
            | "lsp"
            | "webfetch"
            | "websearch"
            | "todo_read"
            | "skill"
    )
}
1701
/// Create a filtered tool registry based on the auto-approve policy.
///
/// Read-only tools are always registered. Mutating tools (file writes, shell,
/// patches, etc.) are registered only under `AutoApprove::All`; under `Safe`
/// or `None` they are simply absent from the registry, so the model never
/// sees their definitions. `provider` and `model` are forwarded to the Ralph
/// tool, which drives its own completions.
fn create_filtered_registry(
    provider: Arc<dyn crate::provider::Provider>,
    model: String,
    auto_approve: AutoApprove,
) -> crate::tool::ToolRegistry {
    use crate::tool::*;

    let mut registry = ToolRegistry::new();

    // Always add safe tools (read-only; see `is_safe_tool` for the matching list)
    registry.register(Arc::new(file::ReadTool::new()));
    registry.register(Arc::new(file::ListTool::new()));
    registry.register(Arc::new(file::GlobTool::new()));
    registry.register(Arc::new(search::GrepTool::new()));
    registry.register(Arc::new(lsp::LspTool::new()));
    registry.register(Arc::new(webfetch::WebFetchTool::new()));
    registry.register(Arc::new(websearch::WebSearchTool::new()));
    registry.register(Arc::new(codesearch::CodeSearchTool::new()));
    registry.register(Arc::new(todo::TodoReadTool::new()));
    registry.register(Arc::new(skill::SkillTool::new()));

    // Add potentially dangerous tools only if auto_approve is All
    if matches!(auto_approve, AutoApprove::All) {
        registry.register(Arc::new(file::WriteTool::new()));
        registry.register(Arc::new(edit::EditTool::new()));
        registry.register(Arc::new(bash::BashTool::new()));
        registry.register(Arc::new(multiedit::MultiEditTool::new()));
        registry.register(Arc::new(patch::ApplyPatchTool::new()));
        registry.register(Arc::new(todo::TodoWriteTool::new()));
        registry.register(Arc::new(task::TaskTool::new()));
        registry.register(Arc::new(plan::PlanEnterTool::new()));
        registry.register(Arc::new(plan::PlanExitTool::new()));
        registry.register(Arc::new(rlm::RlmTool::new()));
        // Ralph needs its own provider/model handle to run completions.
        registry.register(Arc::new(ralph::RalphTool::with_provider(provider, model)));
        registry.register(Arc::new(prd::PrdTool::new()));
        registry.register(Arc::new(confirm_edit::ConfirmEditTool::new()));
        registry.register(Arc::new(confirm_multiedit::ConfirmMultiEditTool::new()));
        registry.register(Arc::new(undo::UndoTool));
        registry.register(Arc::new(mcp_bridge::McpBridgeTool::new()));
    }

    // Catch-all handler for unrecognized tool names; always registered.
    registry.register(Arc::new(invalid::InvalidTool::new()));

    registry
}
1748
/// Start the heartbeat background task.
///
/// Every 30 seconds the spawned task derives the worker status
/// (`processing` when the `processing` set is non-empty, otherwise `idle`),
/// mirrors it into `heartbeat_state`, and POSTs it to
/// `{server}/v1/opencode/workers/{worker_id}/heartbeat`, optionally enriched
/// with a cognition payload. Failures are logged but never terminate the loop.
///
/// Returns a JoinHandle that can be used to cancel the heartbeat.
fn start_heartbeat(
    client: Client,
    server: String,
    token: Option<String>,
    heartbeat_state: HeartbeatState,
    processing: Arc<Mutex<HashSet<String>>>,
    cognition_config: CognitionHeartbeatConfig,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        let mut consecutive_failures = 0u32;
        const MAX_FAILURES: u32 = 3;
        const HEARTBEAT_INTERVAL_SECS: u64 = 30;
        const COGNITION_RETRY_COOLDOWN_SECS: u64 = 300;
        // While `Some(t)`, the cognition payload is omitted until `t` has passed.
        // Set after the server rejects an enriched payload with a 4xx status.
        let mut cognition_payload_disabled_until: Option<Instant> = None;

        let mut interval =
            tokio::time::interval(tokio::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
        // Skip (rather than burst) any ticks missed while a send was in flight.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        loop {
            interval.tick().await;

            // Update task count from processing set
            let active_count = processing.lock().await.len();
            heartbeat_state.set_task_count(active_count).await;

            // Determine status based on active tasks
            let status = if active_count > 0 {
                WorkerStatus::Processing
            } else {
                WorkerStatus::Idle
            };
            heartbeat_state.set_status(status).await;

            // Send heartbeat
            let url = format!(
                "{}/v1/opencode/workers/{}/heartbeat",
                server, heartbeat_state.worker_id
            );
            let mut req = client.post(&url);

            if let Some(ref t) = token {
                req = req.bearer_auth(t);
            }

            let status_str = heartbeat_state.status.lock().await.as_str().to_string();
            // `base_payload` is kept separate so a retry can be sent without the
            // cognition extension if the server rejects the enriched payload.
            let base_payload = serde_json::json!({
                "worker_id": &heartbeat_state.worker_id,
                "agent_name": &heartbeat_state.agent_name,
                "status": status_str,
                "active_task_count": active_count,
            });
            let mut payload = base_payload.clone();
            let mut included_cognition_payload = false;
            let cognition_payload_allowed = cognition_payload_disabled_until
                .map(|until| Instant::now() >= until)
                .unwrap_or(true);

            // Best-effort enrichment: if the cognition fetch fails or returns
            // None, the base payload is sent unchanged.
            if cognition_config.enabled
                && cognition_payload_allowed
                && let Some(cognition_payload) =
                    fetch_cognition_heartbeat_payload(&client, &cognition_config).await
                && let Some(obj) = payload.as_object_mut()
            {
                obj.insert("cognition".to_string(), cognition_payload);
                included_cognition_payload = true;
            }

            match req.json(&payload).send().await {
                Ok(res) => {
                    if res.status().is_success() {
                        consecutive_failures = 0;
                        tracing::debug!(
                            worker_id = %heartbeat_state.worker_id,
                            status = status_str,
                            active_tasks = active_count,
                            "Heartbeat sent successfully"
                        );
                    } else if included_cognition_payload && res.status().is_client_error() {
                        // A 4xx on an enriched payload is treated as a schema
                        // rejection of the cognition extension: retry once with
                        // the base payload only.
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            "Heartbeat cognition payload rejected, retrying without cognition payload"
                        );

                        let mut retry_req = client.post(&url);
                        if let Some(ref t) = token {
                            retry_req = retry_req.bearer_auth(t);
                        }

                        match retry_req.json(&base_payload).send().await {
                            Ok(retry_res) if retry_res.status().is_success() => {
                                // The base payload was accepted, so the cognition
                                // extension was the problem: pause it for a while.
                                cognition_payload_disabled_until = Some(
                                    Instant::now()
                                        + Duration::from_secs(COGNITION_RETRY_COOLDOWN_SECS),
                                );
                                consecutive_failures = 0;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    retry_after_secs = COGNITION_RETRY_COOLDOWN_SECS,
                                    "Paused cognition heartbeat payload after schema rejection"
                                );
                            }
                            Ok(retry_res) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    status = %retry_res.status(),
                                    failures = consecutive_failures,
                                    "Heartbeat failed even after retry without cognition payload"
                                );
                            }
                            Err(e) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    error = %e,
                                    failures = consecutive_failures,
                                    "Heartbeat retry without cognition payload failed"
                                );
                            }
                        }
                    } else {
                        consecutive_failures += 1;
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            failures = consecutive_failures,
                            "Heartbeat failed"
                        );
                    }
                }
                Err(e) => {
                    consecutive_failures += 1;
                    tracing::warn!(
                        worker_id = %heartbeat_state.worker_id,
                        error = %e,
                        failures = consecutive_failures,
                        "Heartbeat request failed"
                    );
                }
            }

            // Log error after 3 consecutive failures but do not terminate
            if consecutive_failures >= MAX_FAILURES {
                tracing::error!(
                    worker_id = %heartbeat_state.worker_id,
                    failures = consecutive_failures,
                    "Heartbeat failed {} consecutive times - worker will continue running and attempt reconnection via SSE loop",
                    MAX_FAILURES
                );
                // Reset counter to avoid spamming error logs
                consecutive_failures = 0;
            }
        }
    })
}
1908
/// Query the local cognition service and build the `cognition` JSON object
/// that `start_heartbeat` embeds in the heartbeat payload.
///
/// Returns `None` if the status request times out (per
/// `config.request_timeout_ms`), fails at the transport level, responds with a
/// non-success status, or cannot be deserialized; the caller then sends the
/// heartbeat without cognition data. The optional latest-thought snapshot is
/// equally best-effort: any failure there simply omits the snapshot fields.
async fn fetch_cognition_heartbeat_payload(
    client: &Client,
    config: &CognitionHeartbeatConfig,
) -> Option<serde_json::Value> {
    let status_url = format!("{}/v1/cognition/status", config.source_base_url);
    // Outer `?` handles the timeout, inner `?` the transport error.
    let status_res = tokio::time::timeout(
        Duration::from_millis(config.request_timeout_ms),
        client.get(status_url).send(),
    )
    .await
    .ok()?
    .ok()?;

    if !status_res.status().is_success() {
        return None;
    }

    let status: CognitionStatusSnapshot = status_res.json().await.ok()?;
    let mut payload = serde_json::json!({
        "running": status.running,
        "last_tick_at": status.last_tick_at,
        "active_persona_count": status.active_persona_count,
        "events_buffered": status.events_buffered,
        "snapshots_buffered": status.snapshots_buffered,
        "loop_interval_ms": status.loop_interval_ms,
    });

    if config.include_thought_summary {
        let snapshot_url = format!("{}/v1/cognition/snapshots/latest", config.source_base_url);
        // Collapse timeout and transport errors into a single Option.
        let snapshot_res = tokio::time::timeout(
            Duration::from_millis(config.request_timeout_ms),
            client.get(snapshot_url).send(),
        )
        .await
        .ok()
        .and_then(Result::ok);

        if let Some(snapshot_res) = snapshot_res
            && snapshot_res.status().is_success()
            && let Ok(snapshot) = snapshot_res.json::<CognitionLatestSnapshot>().await
            && let Some(obj) = payload.as_object_mut()
        {
            obj.insert(
                "latest_snapshot_at".to_string(),
                serde_json::Value::String(snapshot.generated_at),
            );
            // Summary is capped so the heartbeat body stays small.
            obj.insert(
                "latest_thought".to_string(),
                serde_json::Value::String(trim_for_heartbeat(
                    &snapshot.summary,
                    config.summary_max_chars,
                )),
            );
            // "model" and "source" are optional metadata keys; only forward
            // them when present and string-valued.
            if let Some(model) = snapshot
                .metadata
                .get("model")
                .and_then(serde_json::Value::as_str)
            {
                obj.insert(
                    "latest_thought_model".to_string(),
                    serde_json::Value::String(model.to_string()),
                );
            }
            if let Some(source) = snapshot
                .metadata
                .get("source")
                .and_then(serde_json::Value::as_str)
            {
                obj.insert(
                    "latest_thought_source".to_string(),
                    serde_json::Value::String(source.to_string()),
                );
            }
        }
    }

    Some(payload)
}
1987
/// Trim surrounding whitespace and cap the result at `max_chars` characters,
/// appending `"..."` when truncation occurred.
///
/// Whitespace is stripped *before* the budget is applied, so `max_chars`
/// counts only meaningful characters. (The previous version trimmed after
/// truncating, which let leading whitespace consume the budget and could
/// return a bare `"..."` for all-whitespace input.)
fn trim_for_heartbeat(input: &str, max_chars: usize) -> String {
    let trimmed = input.trim();
    if trimmed.chars().count() <= max_chars {
        return trimmed.to_string();
    }

    // Truncate by chars (not bytes) so multi-byte UTF-8 is never split.
    let mut out: String = trimmed.chars().take(max_chars).collect();
    out.push_str("...");
    out
}
2000
/// Read the environment variable `name` as a boolean flag.
///
/// Accepts `1`/`true`/`yes`/`on` and `0`/`false`/`no`/`off`
/// (case-insensitive); any other value — or an unset variable —
/// yields `default`.
fn env_bool(name: &str, default: bool) -> bool {
    match std::env::var(name) {
        Ok(raw) => match raw.to_ascii_lowercase().as_str() {
            "1" | "true" | "yes" | "on" => true,
            "0" | "false" | "no" | "off" => false,
            _ => default,
        },
        Err(_) => default,
    }
}
2011
/// Read the environment variable `name` and parse it as `usize`,
/// falling back to `default` when unset or unparseable.
fn env_usize(name: &str, default: usize) -> usize {
    match std::env::var(name) {
        Ok(raw) => raw.parse().unwrap_or(default),
        Err(_) => default,
    }
}
2018
/// Read the environment variable `name` and parse it as `u64`,
/// falling back to `default` when unset or unparseable.
fn env_u64(name: &str, default: u64) -> u64 {
    let Ok(raw) = std::env::var(name) else {
        return default;
    };
    raw.parse().unwrap_or(default)
}