Skip to main content

codetether_agent/a2a/
worker.rs

1//! A2A Worker - connects to an A2A server to process tasks
2
3use crate::bus::AgentBus;
4use crate::cli::A2aArgs;
5use crate::provider::ProviderRegistry;
6use crate::session::Session;
7use crate::swarm::{DecompositionStrategy, SwarmConfig, SwarmExecutor};
8use crate::tui::swarm_view::SwarmEvent;
9use anyhow::Result;
10use futures::StreamExt;
11use reqwest::Client;
12use serde::Deserialize;
13use std::collections::HashMap;
14use std::collections::HashSet;
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::sync::{Mutex, mpsc};
18use tokio::task::JoinHandle;
19use tokio::time::Instant;
20
/// Worker status reported to the server in heartbeat payloads.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WorkerStatus {
    Idle,
    Processing,
}

impl WorkerStatus {
    /// Wire representation of this status ("idle" / "processing").
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Idle => "idle",
            Self::Processing => "processing",
        }
    }
}
36
37/// Heartbeat state shared between the heartbeat task and the main worker
38#[derive(Clone)]
39struct HeartbeatState {
40    worker_id: String,
41    agent_name: String,
42    status: Arc<Mutex<WorkerStatus>>,
43    active_task_count: Arc<Mutex<usize>>,
44}
45
46impl HeartbeatState {
47    fn new(worker_id: String, agent_name: String) -> Self {
48        Self {
49            worker_id,
50            agent_name,
51            status: Arc::new(Mutex::new(WorkerStatus::Idle)),
52            active_task_count: Arc::new(Mutex::new(0)),
53        }
54    }
55
56    async fn set_status(&self, status: WorkerStatus) {
57        *self.status.lock().await = status;
58    }
59
60    async fn set_task_count(&self, count: usize) {
61        *self.active_task_count.lock().await = count;
62    }
63}
64
65#[derive(Clone, Debug)]
66struct CognitionHeartbeatConfig {
67    enabled: bool,
68    source_base_url: String,
69    include_thought_summary: bool,
70    summary_max_chars: usize,
71    request_timeout_ms: u64,
72}
73
74impl CognitionHeartbeatConfig {
75    fn from_env() -> Self {
76        let source_base_url = std::env::var("CODETETHER_WORKER_COGNITION_SOURCE_URL")
77            .unwrap_or_else(|_| "http://127.0.0.1:4096".to_string())
78            .trim_end_matches('/')
79            .to_string();
80
81        Self {
82            enabled: env_bool("CODETETHER_WORKER_COGNITION_SHARE_ENABLED", true),
83            source_base_url,
84            include_thought_summary: env_bool("CODETETHER_WORKER_COGNITION_INCLUDE_THOUGHTS", true),
85            summary_max_chars: env_usize("CODETETHER_WORKER_COGNITION_THOUGHT_MAX_CHARS", 480)
86                .max(120),
87            request_timeout_ms: env_u64("CODETETHER_WORKER_COGNITION_TIMEOUT_MS", 2_500).max(250),
88        }
89    }
90}
91
/// Snapshot of a cognition service `status` response.
///
/// Deserialize-only; fields missing from the payload fall back to their
/// `Default` values via `#[serde(default)]`.
#[derive(Debug, Deserialize)]
struct CognitionStatusSnapshot {
    // Whether the cognition loop is currently running.
    running: bool,
    // NOTE(review): presumably an RFC 3339 timestamp string — confirm with producer.
    #[serde(default)]
    last_tick_at: Option<String>,
    #[serde(default)]
    active_persona_count: usize,
    #[serde(default)]
    events_buffered: usize,
    #[serde(default)]
    snapshots_buffered: usize,
    #[serde(default)]
    loop_interval_ms: u64,
}
106
/// Most recent cognition snapshot: a text summary plus free-form metadata.
#[derive(Debug, Deserialize)]
struct CognitionLatestSnapshot {
    // NOTE(review): presumably an RFC 3339 timestamp string — confirm with producer.
    generated_at: String,
    summary: String,
    // Arbitrary JSON key/value pairs attached to the snapshot.
    #[serde(default)]
    metadata: HashMap<String, serde_json::Value>,
}
114
115// Run the A2A worker
116pub async fn run(args: A2aArgs) -> Result<()> {
117    let server = args.server.trim_end_matches('/');
118    let name = args
119        .name
120        .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
121    let worker_id = generate_worker_id();
122
123    let codebases: Vec<String> = args
124        .codebases
125        .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
126        .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);
127
128    tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
129    tracing::info!("Server: {}", server);
130    tracing::info!("Codebases: {:?}", codebases);
131
132    let client = Client::new();
133    let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
134    let cognition_heartbeat = CognitionHeartbeatConfig::from_env();
135    if cognition_heartbeat.enabled {
136        tracing::info!(
137            source = %cognition_heartbeat.source_base_url,
138            include_thoughts = cognition_heartbeat.include_thought_summary,
139            max_chars = cognition_heartbeat.summary_max_chars,
140            timeout_ms = cognition_heartbeat.request_timeout_ms,
141            "Cognition heartbeat sharing enabled (set CODETETHER_WORKER_COGNITION_SHARE_ENABLED=false to disable)"
142        );
143    } else {
144        tracing::warn!(
145            "Cognition heartbeat sharing disabled; worker thought state will not be shared upstream"
146        );
147    }
148
149    let auto_approve = match args.auto_approve.as_str() {
150        "all" => AutoApprove::All,
151        "safe" => AutoApprove::Safe,
152        _ => AutoApprove::None,
153    };
154
155    // Create heartbeat state
156    let heartbeat_state = HeartbeatState::new(worker_id.clone(), name.clone());
157
158    // Create agent bus for in-process sub-agent communication
159    let bus = AgentBus::new().into_arc();
160    {
161        let handle = bus.handle(&worker_id);
162        handle.announce_ready(WORKER_CAPABILITIES.iter().map(|s| s.to_string()).collect());
163    }
164
165    // Register worker
166    register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;
167
168    // Fetch pending tasks
169    fetch_pending_tasks(
170        &client,
171        server,
172        &args.token,
173        &worker_id,
174        &processing,
175        &auto_approve,
176        &bus,
177    )
178    .await?;
179
180    // Connect to SSE stream
181    loop {
182        // Re-register worker on each reconnection to report updated models/capabilities
183        if let Err(e) =
184            register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await
185        {
186            tracing::warn!("Failed to re-register worker on reconnection: {}", e);
187        }
188
189        // Start heartbeat task for this connection
190        let heartbeat_handle = start_heartbeat(
191            client.clone(),
192            server.to_string(),
193            args.token.clone(),
194            heartbeat_state.clone(),
195            processing.clone(),
196            cognition_heartbeat.clone(),
197        );
198
199        match connect_stream(
200            &client,
201            server,
202            &args.token,
203            &worker_id,
204            &name,
205            &codebases,
206            &processing,
207            &auto_approve,
208            &bus,
209        )
210        .await
211        {
212            Ok(()) => {
213                tracing::warn!("Stream ended, reconnecting...");
214            }
215            Err(e) => {
216                tracing::error!("Stream error: {}, reconnecting...", e);
217            }
218        }
219
220        // Cancel heartbeat on disconnection
221        heartbeat_handle.abort();
222        tracing::debug!("Heartbeat cancelled for reconnection");
223
224        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
225    }
226}
227
228fn generate_worker_id() -> String {
229    format!(
230        "wrk_{}_{:x}",
231        chrono::Utc::now().timestamp(),
232        rand::random::<u64>()
233    )
234}
235
/// Auto-approval mode, parsed from the `--auto-approve` CLI value
/// ("all" / "safe" / anything else → `None`).
#[derive(Debug, Clone, Copy)]
enum AutoApprove {
    // CLI value "all".
    All,
    // CLI value "safe".
    Safe,
    // Fallback for any other CLI value.
    None,
}
242
/// Default A2A server URL when none is configured
pub const DEFAULT_A2A_SERVER_URL: &str = "https://api.codetether.run";

/// Capabilities of the codetether-agent worker; advertised both in the
/// registration payload and on the in-process agent bus.
const WORKER_CAPABILITIES: &[&str] = &["ralph", "swarm", "rlm", "a2a", "mcp"];
248
249fn task_value<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a serde_json::Value> {
250    task.get("task")
251        .and_then(|t| t.get(key))
252        .or_else(|| task.get(key))
253}
254
255fn task_str<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a str> {
256    task_value(task, key).and_then(|v| v.as_str())
257}
258
259fn task_metadata(task: &serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
260    task_value(task, "metadata")
261        .and_then(|m| m.as_object())
262        .cloned()
263        .unwrap_or_default()
264}
265
/// Normalize a `provider:model` reference to `provider/model`.
///
/// Only the first ':' is rewritten, and only when the string does not
/// already contain a '/': bedrock-style refs such as
/// `bedrock:amazon.nova-micro-v1:0` keep their trailing version colon, and
/// refs already in `provider/model` form pass through unchanged.
fn model_ref_to_provider_model(model: &str) -> String {
    let already_pathlike = model.contains('/');
    if already_pathlike || !model.contains(':') {
        model.to_string()
    } else {
        model.replacen(':', "/", 1)
    }
}
276
/// Ordered provider preference list for a routing tier.
///
/// `zai` always ranks first; the heavy tier promotes `anthropic`, the fast
/// tier pushes it to the back. Unknown or missing tiers get the balanced
/// ordering.
fn provider_preferences_for_tier(model_tier: Option<&str>) -> &'static [&'static str] {
    const FAST: &[&str] = &[
        "zai",
        "openai",
        "github-copilot",
        "moonshotai",
        "openrouter",
        "novita",
        "google",
        "anthropic",
    ];
    const HEAVY: &[&str] = &[
        "zai",
        "anthropic",
        "openai",
        "github-copilot",
        "moonshotai",
        "openrouter",
        "novita",
        "google",
    ];
    const BALANCED: &[&str] = &[
        "zai",
        "openai",
        "github-copilot",
        "anthropic",
        "moonshotai",
        "openrouter",
        "novita",
        "google",
    ];

    match model_tier.unwrap_or("balanced") {
        "fast" | "quick" => FAST,
        "heavy" | "deep" => HEAVY,
        _ => BALANCED,
    }
}
311
312fn choose_provider_for_tier<'a>(providers: &'a [&'a str], model_tier: Option<&str>) -> &'a str {
313    for preferred in provider_preferences_for_tier(model_tier) {
314        if let Some(found) = providers.iter().copied().find(|p| *p == *preferred) {
315            return found;
316        }
317    }
318    if let Some(found) = providers.iter().copied().find(|p| *p == "zai") {
319        return found;
320    }
321    providers[0]
322}
323
/// Default model id for a provider at a given routing tier.
///
/// Most providers serve the same model at every tier; `anthropic`,
/// `openai`, `google`, and `bedrock` switch between lighter and heavier
/// models depending on whether the tier is fast-ish or heavy-ish. Unknown
/// providers default to `glm-5`.
fn default_model_for_provider(provider: &str, model_tier: Option<&str>) -> String {
    let tier = model_tier.unwrap_or("balanced");
    let fast = matches!(tier, "fast" | "quick");
    let heavy = matches!(tier, "heavy" | "deep");

    let model = match provider {
        // Tier-independent providers.
        "moonshotai" => "kimi-k2.5",
        "zhipuai" | "zai" => "glm-5",
        "openrouter" => "z-ai/glm-5",
        "novita" => "qwen/qwen3-coder-next",
        // Tier-sensitive providers.
        "anthropic" => {
            if fast {
                "claude-haiku-4-5"
            } else {
                "claude-sonnet-4-20250514"
            }
        }
        "openai" => {
            if fast {
                "gpt-4o-mini"
            } else if heavy {
                "o3"
            } else {
                "gpt-4o"
            }
        }
        "google" => {
            if fast {
                "gemini-2.5-flash"
            } else {
                "gemini-2.5-pro"
            }
        }
        "bedrock" => {
            if heavy {
                "us.anthropic.claude-sonnet-4-20250514-v1:0"
            } else {
                "amazon.nova-lite-v1:0"
            }
        }
        _ => "glm-5",
    };
    model.to_string()
}
361
/// True when the agent type names the swarm (multi-agent) execution mode.
/// Matching is case-insensitive and ignores surrounding whitespace.
fn is_swarm_agent(agent_type: &str) -> bool {
    let normalized = agent_type.trim().to_ascii_lowercase();
    ["swarm", "parallel", "multi-agent"].contains(&normalized.as_str())
}
368
369fn metadata_lookup<'a>(
370    metadata: &'a serde_json::Map<String, serde_json::Value>,
371    key: &str,
372) -> Option<&'a serde_json::Value> {
373    metadata
374        .get(key)
375        .or_else(|| {
376            metadata
377                .get("routing")
378                .and_then(|v| v.as_object())
379                .and_then(|obj| obj.get(key))
380        })
381        .or_else(|| {
382            metadata
383                .get("swarm")
384                .and_then(|v| v.as_object())
385                .and_then(|obj| obj.get(key))
386        })
387}
388
389fn metadata_str(
390    metadata: &serde_json::Map<String, serde_json::Value>,
391    keys: &[&str],
392) -> Option<String> {
393    for key in keys {
394        if let Some(value) = metadata_lookup(metadata, key).and_then(|v| v.as_str()) {
395            let trimmed = value.trim();
396            if !trimmed.is_empty() {
397                return Some(trimmed.to_string());
398            }
399        }
400    }
401    None
402}
403
404fn metadata_usize(
405    metadata: &serde_json::Map<String, serde_json::Value>,
406    keys: &[&str],
407) -> Option<usize> {
408    for key in keys {
409        if let Some(value) = metadata_lookup(metadata, key) {
410            if let Some(v) = value.as_u64() {
411                return usize::try_from(v).ok();
412            }
413            if let Some(v) = value.as_i64() {
414                if v >= 0 {
415                    return usize::try_from(v as u64).ok();
416                }
417            }
418            if let Some(v) = value.as_str() {
419                if let Ok(parsed) = v.trim().parse::<usize>() {
420                    return Some(parsed);
421                }
422            }
423        }
424    }
425    None
426}
427
428fn metadata_u64(
429    metadata: &serde_json::Map<String, serde_json::Value>,
430    keys: &[&str],
431) -> Option<u64> {
432    for key in keys {
433        if let Some(value) = metadata_lookup(metadata, key) {
434            if let Some(v) = value.as_u64() {
435                return Some(v);
436            }
437            if let Some(v) = value.as_i64() {
438                if v >= 0 {
439                    return Some(v as u64);
440                }
441            }
442            if let Some(v) = value.as_str() {
443                if let Ok(parsed) = v.trim().parse::<u64>() {
444                    return Some(parsed);
445                }
446            }
447        }
448    }
449    None
450}
451
452fn metadata_bool(
453    metadata: &serde_json::Map<String, serde_json::Value>,
454    keys: &[&str],
455) -> Option<bool> {
456    for key in keys {
457        if let Some(value) = metadata_lookup(metadata, key) {
458            if let Some(v) = value.as_bool() {
459                return Some(v);
460            }
461            if let Some(v) = value.as_str() {
462                match v.trim().to_ascii_lowercase().as_str() {
463                    "1" | "true" | "yes" | "on" => return Some(true),
464                    "0" | "false" | "no" | "off" => return Some(false),
465                    _ => {}
466                }
467            }
468        }
469    }
470    None
471}
472
473fn parse_swarm_strategy(
474    metadata: &serde_json::Map<String, serde_json::Value>,
475) -> DecompositionStrategy {
476    match metadata_str(
477        metadata,
478        &[
479            "decomposition_strategy",
480            "swarm_strategy",
481            "strategy",
482            "swarm_decomposition",
483        ],
484    )
485    .as_deref()
486    .map(|s| s.to_ascii_lowercase())
487    .as_deref()
488    {
489        Some("none") | Some("single") => DecompositionStrategy::None,
490        Some("domain") | Some("by_domain") => DecompositionStrategy::ByDomain,
491        Some("data") | Some("by_data") => DecompositionStrategy::ByData,
492        Some("stage") | Some("by_stage") => DecompositionStrategy::ByStage,
493        _ => DecompositionStrategy::Automatic,
494    }
495}
496
497async fn resolve_swarm_model(
498    explicit_model: Option<String>,
499    model_tier: Option<&str>,
500) -> Option<String> {
501    if let Some(model) = explicit_model {
502        if !model.trim().is_empty() {
503            return Some(model);
504        }
505    }
506
507    let registry = ProviderRegistry::from_vault().await.ok()?;
508    let providers = registry.list();
509    if providers.is_empty() {
510        return None;
511    }
512    let provider = choose_provider_for_tier(providers.as_slice(), model_tier);
513    let model = default_model_for_provider(provider, model_tier);
514    Some(format!("{}/{}", provider, model))
515}
516
517fn format_swarm_event_for_output(event: &SwarmEvent) -> Option<String> {
518    match event {
519        SwarmEvent::Started {
520            task,
521            total_subtasks,
522        } => Some(format!(
523            "[swarm] started task={} planned_subtasks={}",
524            task, total_subtasks
525        )),
526        SwarmEvent::StageComplete {
527            stage,
528            completed,
529            failed,
530        } => Some(format!(
531            "[swarm] stage={} completed={} failed={}",
532            stage, completed, failed
533        )),
534        SwarmEvent::SubTaskUpdate { id, status, .. } => Some(format!(
535            "[swarm] subtask id={} status={}",
536            &id.chars().take(8).collect::<String>(),
537            format!("{status:?}").to_ascii_lowercase()
538        )),
539        SwarmEvent::AgentToolCall {
540            subtask_id,
541            tool_name,
542        } => Some(format!(
543            "[swarm] subtask id={} tool={}",
544            &subtask_id.chars().take(8).collect::<String>(),
545            tool_name
546        )),
547        SwarmEvent::AgentError { subtask_id, error } => Some(format!(
548            "[swarm] subtask id={} error={}",
549            &subtask_id.chars().take(8).collect::<String>(),
550            error
551        )),
552        SwarmEvent::Complete { success, stats } => Some(format!(
553            "[swarm] complete success={} subtasks={} speedup={:.2}",
554            success,
555            stats.subagents_completed + stats.subagents_failed,
556            stats.speedup_factor
557        )),
558        SwarmEvent::Error(err) => Some(format!("[swarm] error message={}", err)),
559        _ => None,
560    }
561}
562
/// Register (or re-register) this worker with the A2A server, advertising
/// its capabilities, hostname, codebases, and every model it can serve.
///
/// A non-success HTTP status is logged but tolerated; only transport-level
/// errors from the HTTP client produce an `Err`.
async fn register_worker(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
) -> Result<()> {
    // Load ProviderRegistry and collect available models; registration
    // proceeds with an empty model list if this fails.
    let models = match load_provider_models().await {
        Ok(m) => m,
        Err(e) => {
            tracing::warn!(
                "Failed to load provider models: {}, proceeding without model info",
                e
            );
            HashMap::new()
        }
    };

    // Register via the workers/register endpoint
    let mut req = client.post(format!("{}/v1/opencode/workers/register", server));

    // Bearer auth is optional; attach only when a token was provided.
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    // Flatten models HashMap into array of model objects with pricing data
    // matching the format expected by the A2A server's /models and /workers endpoints
    let models_array: Vec<serde_json::Value> = models
        .iter()
        .flat_map(|(provider, model_infos)| {
            model_infos.iter().map(move |m| {
                let mut obj = serde_json::json!({
                    "id": format!("{}/{}", provider, m.id),
                    "name": &m.id,
                    "provider": provider,
                    "provider_id": provider,
                });
                // Pricing fields are optional; omitted when unknown.
                if let Some(input_cost) = m.input_cost_per_million {
                    obj["input_cost_per_million"] = serde_json::json!(input_cost);
                }
                if let Some(output_cost) = m.output_cost_per_million {
                    obj["output_cost_per_million"] = serde_json::json!(output_cost);
                }
                obj
            })
        })
        .collect();

    tracing::info!(
        "Registering worker with {} models from {} providers",
        models_array.len(),
        models.len()
    );

    // HOSTNAME is the unix convention, COMPUTERNAME the Windows one.
    let hostname = std::env::var("HOSTNAME")
        .or_else(|_| std::env::var("COMPUTERNAME"))
        .unwrap_or_else(|_| "unknown".to_string());

    let res = req
        .json(&serde_json::json!({
            "worker_id": worker_id,
            "name": name,
            "capabilities": WORKER_CAPABILITIES,
            "hostname": hostname,
            "models": models_array,
            "codebases": codebases,
        }))
        .send()
        .await?;

    // Registration rejection is non-fatal: the worker keeps running and will
    // retry on the next reconnection.
    if res.status().is_success() {
        tracing::info!("Worker registered successfully");
    } else {
        tracing::warn!("Failed to register worker: {}", res.status());
    }

    Ok(())
}
643
/// Load ProviderRegistry and collect all available models grouped by provider.
/// Tries Vault first, then falls back to config/env vars if Vault is unreachable.
/// Returns ModelInfo structs (with pricing data when available).
///
/// Resolution order:
/// 1. Vault-backed registry (if it yields at least one provider);
/// 2. config-file/env-var registry (`fallback_registry`);
/// 3. as a last resort, the public models.dev catalog (all providers,
///    regardless of credentials).
async fn load_provider_models() -> Result<HashMap<String, Vec<crate::provider::ModelInfo>>> {
    // Try Vault first
    let registry = match ProviderRegistry::from_vault().await {
        Ok(r) if !r.list().is_empty() => {
            tracing::info!("Loaded {} providers from Vault", r.list().len());
            r
        }
        Ok(_) => {
            tracing::warn!("Vault returned 0 providers, falling back to config/env vars");
            fallback_registry().await?
        }
        Err(e) => {
            tracing::warn!("Vault unreachable ({}), falling back to config/env vars", e);
            fallback_registry().await?
        }
    };

    // Fetch the models.dev catalog for pricing data enrichment
    // (best-effort: a fetch failure simply means no enrichment).
    let catalog = crate::provider::models::ModelCatalog::fetch().await.ok();

    // Map provider IDs to their catalog equivalents (some differ)
    let catalog_alias = |pid: &str| -> String {
        match pid {
            "bedrock" => "amazon-bedrock".to_string(),
            "novita" => "novita-ai".to_string(),
            _ => pid.to_string(),
        }
    };

    let mut models_by_provider: HashMap<String, Vec<crate::provider::ModelInfo>> = HashMap::new();

    for provider_name in registry.list() {
        if let Some(provider) = registry.get(provider_name) {
            match provider.list_models().await {
                Ok(models) => {
                    let enriched: Vec<crate::provider::ModelInfo> = models
                        .into_iter()
                        .map(|mut m| {
                            // Enrich with catalog pricing if the provider didn't set it
                            if m.input_cost_per_million.is_none()
                                || m.output_cost_per_million.is_none()
                            {
                                if let Some(ref cat) = catalog {
                                    let cat_pid = catalog_alias(provider_name);
                                    if let Some(prov_info) = cat.get_provider(&cat_pid) {
                                        // Try exact match first, then strip "us." prefix (bedrock uses us.vendor.model format)
                                        let model_info =
                                            prov_info.models.get(&m.id).or_else(|| {
                                                m.id.strip_prefix("us.").and_then(|stripped| {
                                                    prov_info.models.get(stripped)
                                                })
                                            });
                                        if let Some(model_info) = model_info {
                                            if let Some(ref cost) = model_info.cost {
                                                // Provider-supplied pricing wins;
                                                // catalog only fills the gaps.
                                                if m.input_cost_per_million.is_none() {
                                                    m.input_cost_per_million = Some(cost.input);
                                                }
                                                if m.output_cost_per_million.is_none() {
                                                    m.output_cost_per_million = Some(cost.output);
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            m
                        })
                        .collect();
                    // Providers with no listable models are omitted entirely.
                    if !enriched.is_empty() {
                        tracing::debug!("Provider {}: {} models", provider_name, enriched.len());
                        models_by_provider.insert(provider_name.to_string(), enriched);
                    }
                }
                Err(e) => {
                    tracing::debug!("Failed to list models for {}: {}", provider_name, e);
                }
            }
        }
    }

    // If we still have 0, try the models.dev catalog as last resort
    // NOTE: We list ALL models from the catalog (not just ones with verified API keys)
    // because the worker should advertise what it can handle. The server handles routing.
    if models_by_provider.is_empty() {
        tracing::info!(
            "No authenticated providers found, fetching models.dev catalog (all providers)"
        );
        if let Ok(cat) = crate::provider::models::ModelCatalog::fetch().await {
            // Use all_providers_with_models() to get every provider+model from catalog
            // regardless of API key availability (Vault may be down)
            for (provider_id, provider_info) in cat.all_providers() {
                let model_infos: Vec<crate::provider::ModelInfo> = provider_info
                    .models
                    .values()
                    .map(|m| cat.to_model_info(m, provider_id))
                    .collect();
                if !model_infos.is_empty() {
                    tracing::debug!(
                        "Catalog provider {}: {} models",
                        provider_id,
                        model_infos.len()
                    );
                    models_by_provider.insert(provider_id.clone(), model_infos);
                }
            }
            tracing::info!(
                "Loaded {} providers with {} total models from catalog",
                models_by_provider.len(),
                models_by_provider.values().map(|v| v.len()).sum::<usize>()
            );
        }
    }

    Ok(models_by_provider)
}
762
763/// Fallback: build a ProviderRegistry from config file + environment variables
764async fn fallback_registry() -> Result<ProviderRegistry> {
765    let config = crate::config::Config::load().await.unwrap_or_default();
766    ProviderRegistry::from_config(&config).await
767}
768
/// Fetch tasks already queued on the server and spawn a handler for each one
/// not currently being processed.
///
/// Non-success HTTP responses are silently ignored (`Ok(())`); only
/// transport/deserialization errors propagate. Each claimed task id is
/// inserted into `processing` before its handler is spawned and removed when
/// the handler finishes, so concurrent fetches never double-handle a task.
async fn fetch_pending_tasks(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
    bus: &Arc<AgentBus>,
) -> Result<()> {
    tracing::info!("Checking for pending tasks...");

    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    // A rejected request is treated as "no pending tasks".
    if !res.status().is_success() {
        return Ok(());
    }

    let data: serde_json::Value = res.json().await?;
    // Handle both plain array response and {tasks: [...]} wrapper
    let tasks = if let Some(arr) = data.as_array() {
        arr.clone()
    } else {
        data["tasks"].as_array().cloned().unwrap_or_default()
    };

    tracing::info!("Found {} pending task(s)", tasks.len());

    for task in tasks {
        if let Some(id) = task["id"].as_str() {
            // Claim the id under the lock, then release it before spawning so
            // handlers never run while the set is held.
            let mut proc = processing.lock().await;
            if !proc.contains(id) {
                proc.insert(id.to_string());
                drop(proc);

                // Clone everything the detached handler task needs to own.
                let task_id = id.to_string();
                let client = client.clone();
                let server = server.to_string();
                let token = token.clone();
                let worker_id = worker_id.to_string();
                let auto_approve = *auto_approve;
                let processing = processing.clone();
                let bus = bus.clone();

                tokio::spawn(async move {
                    if let Err(e) = handle_task(
                        &client,
                        &server,
                        &token,
                        &worker_id,
                        &task,
                        auto_approve,
                        &bus,
                    )
                    .await
                    {
                        tracing::error!("Task {} failed: {}", task_id, e);
                    }
                    // Always release the claim, even on failure.
                    processing.lock().await.remove(&task_id);
                });
            }
        }
    }

    Ok(())
}
838
/// Hold an SSE connection to the server's task stream, spawning a handler
/// for each task event received.
///
/// Also runs a 30-second fallback poll for pending tasks the stream may have
/// missed. Returns `Ok(())` when the server closes the stream and `Err` on
/// connection or read failures; the caller reconnects in both cases.
#[allow(clippy::too_many_arguments)]
async fn connect_stream(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
    bus: &Arc<AgentBus>,
) -> Result<()> {
    let url = format!(
        "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
        server,
        urlencoding::encode(name),
        urlencoding::encode(worker_id)
    );

    // Worker identity is sent both as query params and headers.
    let mut req = client
        .get(&url)
        .header("Accept", "text/event-stream")
        .header("X-Worker-ID", worker_id)
        .header("X-Agent-Name", name)
        .header("X-Codebases", codebases.join(","));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        anyhow::bail!("Failed to connect: {}", res.status());
    }

    tracing::info!("Connected to A2A server");

    let mut stream = res.bytes_stream();
    // Accumulates partial SSE frames until a blank-line terminator arrives.
    let mut buffer = String::new();
    let mut poll_interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
    poll_interval.tick().await; // consume the initial immediate tick

    loop {
        tokio::select! {
            chunk = stream.next() => {
                match chunk {
                    Some(Ok(chunk)) => {
                        buffer.push_str(&String::from_utf8_lossy(&chunk));

                        // Process SSE events: each event ends with "\n\n".
                        while let Some(pos) = buffer.find("\n\n") {
                            let event_str = buffer[..pos].to_string();
                            buffer = buffer[pos + 2..].to_string();

                            if let Some(data_line) = event_str.lines().find(|l| l.starts_with("data:")) {
                                let data = data_line.trim_start_matches("data:").trim();
                                // Skip keep-alive / terminator frames.
                                if data == "[DONE]" || data.is_empty() {
                                    continue;
                                }

                                // Unparseable payloads are silently dropped.
                                if let Ok(task) = serde_json::from_str::<serde_json::Value>(data) {
                                    spawn_task_handler(
                                        &task, client, server, token, worker_id,
                                        processing, auto_approve, bus,
                                    ).await;
                                }
                            }
                        }
                    }
                    Some(Err(e)) => {
                        return Err(e.into());
                    }
                    None => {
                        // Stream ended cleanly; caller will reconnect.
                        return Ok(());
                    }
                }
            }
            _ = poll_interval.tick() => {
                // Periodic poll for pending tasks the SSE stream may have missed
                if let Err(e) = poll_pending_tasks(
                    client, server, token, worker_id, processing, auto_approve, bus,
                ).await {
                    tracing::warn!("Periodic task poll failed: {}", e);
                }
            }
        }
    }
}
928
929async fn spawn_task_handler(
930    task: &serde_json::Value,
931    client: &Client,
932    server: &str,
933    token: &Option<String>,
934    worker_id: &str,
935    processing: &Arc<Mutex<HashSet<String>>>,
936    auto_approve: &AutoApprove,
937    bus: &Arc<AgentBus>,
938) {
939    if let Some(id) = task
940        .get("task")
941        .and_then(|t| t["id"].as_str())
942        .or_else(|| task["id"].as_str())
943    {
944        let mut proc = processing.lock().await;
945        if !proc.contains(id) {
946            proc.insert(id.to_string());
947            drop(proc);
948
949            let task_id = id.to_string();
950            let task = task.clone();
951            let client = client.clone();
952            let server = server.to_string();
953            let token = token.clone();
954            let worker_id = worker_id.to_string();
955            let auto_approve = *auto_approve;
956            let processing_clone = processing.clone();
957            let bus = bus.clone();
958
959            tokio::spawn(async move {
960                if let Err(e) = handle_task(
961                    &client,
962                    &server,
963                    &token,
964                    &worker_id,
965                    &task,
966                    auto_approve,
967                    &bus,
968                )
969                .await
970                {
971                    tracing::error!("Task {} failed: {}", task_id, e);
972                }
973                processing_clone.lock().await.remove(&task_id);
974            });
975        }
976    }
977}
978
979async fn poll_pending_tasks(
980    client: &Client,
981    server: &str,
982    token: &Option<String>,
983    worker_id: &str,
984    processing: &Arc<Mutex<HashSet<String>>>,
985    auto_approve: &AutoApprove,
986    bus: &Arc<AgentBus>,
987) -> Result<()> {
988    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
989    if let Some(t) = token {
990        req = req.bearer_auth(t);
991    }
992
993    let res = req.send().await?;
994    if !res.status().is_success() {
995        return Ok(());
996    }
997
998    let data: serde_json::Value = res.json().await?;
999    let tasks = if let Some(arr) = data.as_array() {
1000        arr.clone()
1001    } else {
1002        data["tasks"].as_array().cloned().unwrap_or_default()
1003    };
1004
1005    if !tasks.is_empty() {
1006        tracing::debug!("Poll found {} pending task(s)", tasks.len());
1007    }
1008
1009    for task in &tasks {
1010        spawn_task_handler(
1011            task,
1012            client,
1013            server,
1014            token,
1015            worker_id,
1016            processing,
1017            auto_approve,
1018            bus,
1019        )
1020        .await;
1021    }
1022
1023    Ok(())
1024}
1025
/// Claim, execute, and release a single task end-to-end.
///
/// Flow:
/// 1. Claim the task via `/v1/worker/tasks/claim`; a 409 CONFLICT means
///    another worker won the race and we return `Ok(())` without doing work.
/// 2. Resolve routing metadata from the payload: resume-session id,
///    complexity/tier hints, worker personality, and an explicit model.
/// 3. Run the prompt through the swarm executor (for swarm agents) or the
///    standard session loop, streaming progress chunks back to the server.
/// 4. Release the task with status/result/error/session_id.
///
/// # Errors
/// Propagates transport errors from the claim/release HTTP calls and
/// session-creation failures. Execution failures do NOT produce an `Err`;
/// they are reported through the released `status`/`error` fields.
async fn handle_task(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    task: &serde_json::Value,
    auto_approve: AutoApprove,
    bus: &Arc<AgentBus>,
) -> Result<()> {
    let task_id = task_str(task, "id").ok_or_else(|| anyhow::anyhow!("No task ID"))?;
    let title = task_str(task, "title").unwrap_or("Untitled");

    tracing::info!("Handling task: {} ({})", title, task_id);

    // Claim the task
    let mut req = client
        .post(format!("{}/v1/worker/tasks/claim", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req
        .json(&serde_json::json!({ "task_id": task_id }))
        .send()
        .await?;

    if !res.status().is_success() {
        let status = res.status();
        let text = res.text().await?;
        if status == reqwest::StatusCode::CONFLICT {
            // Lost the claim race: benign, another worker owns this task.
            tracing::debug!(task_id, "Task already claimed by another worker, skipping");
        } else {
            tracing::warn!(task_id, %status, "Failed to claim task: {}", text);
        }
        return Ok(());
    }

    tracing::info!("Claimed task: {}", task_id);

    let metadata = task_metadata(task);
    // Optional session to continue; blank values are treated as absent.
    let resume_session_id = metadata
        .get("resume_session_id")
        .and_then(|v| v.as_str())
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty());
    let complexity_hint = metadata_str(&metadata, &["complexity"]);
    // An explicit tier wins; otherwise map the complexity hint onto a tier
    // ("quick" -> fast, "deep" -> heavy, anything else -> balanced).
    let model_tier = metadata_str(&metadata, &["model_tier", "tier"])
        .map(|s| s.to_ascii_lowercase())
        .or_else(|| {
            complexity_hint.as_ref().map(|complexity| {
                match complexity.to_ascii_lowercase().as_str() {
                    "quick" => "fast".to_string(),
                    "deep" => "heavy".to_string(),
                    _ => "balanced".to_string(),
                }
            })
        });
    let worker_personality = metadata_str(
        &metadata,
        &["worker_personality", "personality", "agent_personality"],
    );
    let target_agent_name = metadata_str(&metadata, &["target_agent_name", "agent_name"]);
    // Model resolution order: task.model_ref, metadata.model_ref, task.model,
    // then metadata.model.
    let raw_model = task_str(task, "model_ref")
        .or_else(|| metadata_lookup(&metadata, "model_ref").and_then(|v| v.as_str()))
        .or_else(|| task_str(task, "model"))
        .or_else(|| metadata_lookup(&metadata, "model").and_then(|v| v.as_str()));
    let selected_model = raw_model.map(model_ref_to_provider_model);

    // Resume existing session when requested; fall back to a fresh session if missing.
    let mut session = if let Some(ref sid) = resume_session_id {
        match Session::load(sid).await {
            Ok(existing) => {
                tracing::info!("Resuming session {} for task {}", sid, task_id);
                existing
            }
            Err(e) => {
                tracing::warn!(
                    "Could not load session {} for task {} ({}), starting a new session",
                    sid,
                    task_id,
                    e
                );
                Session::new().await?
            }
        }
    } else {
        Session::new().await?
    };

    let agent_type = task_str(task, "agent_type")
        .or_else(|| task_str(task, "agent"))
        .unwrap_or("build");
    session.agent = agent_type.to_string();

    if let Some(model) = selected_model.clone() {
        session.metadata.model = Some(model);
    }

    // Prompt fallback chain: prompt -> description -> title.
    let prompt = task_str(task, "prompt")
        .or_else(|| task_str(task, "description"))
        .unwrap_or(title);

    tracing::info!("Executing prompt: {}", prompt);

    // Set up output streaming to forward progress to the server
    let stream_client = client.clone();
    let stream_server = server.to_string();
    let stream_token = token.clone();
    let stream_worker_id = worker_id.to_string();
    let stream_task_id = task_id.to_string();

    // Each chunk is posted fire-and-forget from a detached task; delivery
    // failures are intentionally ignored (`let _ =`).
    let output_callback = move |output: String| {
        let c = stream_client.clone();
        let s = stream_server.clone();
        let t = stream_token.clone();
        let w = stream_worker_id.clone();
        let tid = stream_task_id.clone();
        tokio::spawn(async move {
            let mut req = c
                .post(format!("{}/v1/opencode/tasks/{}/output", s, tid))
                .header("X-Worker-ID", &w);
            if let Some(tok) = &t {
                req = req.bearer_auth(tok);
            }
            let _ = req
                .json(&serde_json::json!({
                    "worker_id": w,
                    "output": output,
                }))
                .send()
                .await;
        });
    };

    // Execute swarm tasks via SwarmExecutor; all other agents use the standard session loop.
    let (status, result, error, session_id) = if is_swarm_agent(agent_type) {
        match execute_swarm_with_policy(
            &mut session,
            prompt,
            model_tier.as_deref(),
            selected_model,
            &metadata,
            complexity_hint.as_deref(),
            worker_personality.as_deref(),
            target_agent_name.as_deref(),
            Some(bus),
            Some(&output_callback),
        )
        .await
        {
            Ok((session_result, true)) => {
                tracing::info!("Swarm task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Ok((session_result, false)) => {
                // Swarm ran to completion but reported internal failures.
                tracing::warn!("Swarm task completed with failures: {}", task_id);
                (
                    "failed",
                    Some(session_result.text),
                    Some("Swarm execution completed with failures".to_string()),
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Swarm task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    } else {
        match execute_session_with_policy(
            &mut session,
            prompt,
            auto_approve,
            model_tier.as_deref(),
            Some(&output_callback),
        )
        .await
        {
            Ok(session_result) => {
                tracing::info!("Task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    };

    // Release the task with full details
    let mut req = client
        .post(format!("{}/v1/worker/tasks/release", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    // `session_id` falls back to the live session's id when the executor did
    // not report one (e.g. on an error result).
    req.json(&serde_json::json!({
        "task_id": task_id,
        "status": status,
        "result": result,
        "error": error,
        "session_id": session_id.unwrap_or_else(|| session.id.clone()),
    }))
    .send()
    .await?;

    tracing::info!("Task released: {} with status: {}", task_id, status);
    Ok(())
}
1247
/// Run a prompt through the `SwarmExecutor` and record the outcome in `session`.
///
/// Swarm limits (strategy, subagent count, steps, timeout, parallelism) are
/// read from task metadata and clamped to operational ranges. When an
/// `output_callback` is supplied, swarm events are streamed through it while
/// the executor runs on a spawned task. The user prompt and the swarm's final
/// text are appended to the session, which is saved before returning.
///
/// Returns the session result plus a success flag; `false` means the swarm
/// completed but reported failures (the caller maps that to a failed task).
async fn execute_swarm_with_policy<F>(
    session: &mut Session,
    prompt: &str,
    model_tier: Option<&str>,
    explicit_model: Option<String>,
    metadata: &serde_json::Map<String, serde_json::Value>,
    complexity_hint: Option<&str>,
    worker_personality: Option<&str>,
    target_agent_name: Option<&str>,
    bus: Option<&Arc<AgentBus>>,
    output_callback: Option<&F>,
) -> Result<(crate::session::SessionResult, bool)>
where
    F: Fn(String),
{
    use crate::provider::{ContentPart, Message, Role};

    session.add_message(Message {
        role: Role::User,
        content: vec![ContentPart::Text {
            text: prompt.to_string(),
        }],
    });

    if session.title.is_none() {
        session.generate_title().await?;
    }

    // Metadata-driven limits, each clamped to a sane range with a default.
    let strategy = parse_swarm_strategy(metadata);
    let max_subagents = metadata_usize(
        metadata,
        &["swarm_max_subagents", "max_subagents", "subagents"],
    )
    .unwrap_or(10)
    .clamp(1, 100);
    let max_steps_per_subagent = metadata_usize(
        metadata,
        &[
            "swarm_max_steps_per_subagent",
            "max_steps_per_subagent",
            "max_steps",
        ],
    )
    .unwrap_or(50)
    .clamp(1, 200);
    let timeout_secs = metadata_u64(metadata, &["swarm_timeout_secs", "timeout_secs", "timeout"])
        .unwrap_or(600)
        .clamp(30, 3600);
    let parallel_enabled =
        metadata_bool(metadata, &["swarm_parallel_enabled", "parallel_enabled"]).unwrap_or(true);

    // Record the resolved model on the session so it is persisted with it.
    let model = resolve_swarm_model(explicit_model, model_tier).await;
    if let Some(ref selected_model) = model {
        session.metadata.model = Some(selected_model.clone());
    }

    // Announce routing/config up front so the task's output stream shows how
    // the swarm was parameterized.
    if let Some(cb) = output_callback {
        cb(format!(
            "[swarm] routing complexity={} tier={} personality={} target_agent={}",
            complexity_hint.unwrap_or("standard"),
            model_tier.unwrap_or("balanced"),
            worker_personality.unwrap_or("auto"),
            target_agent_name.unwrap_or("auto")
        ));
        cb(format!(
            "[swarm] config strategy={:?} max_subagents={} max_steps={} timeout={}s tier={}",
            strategy,
            max_subagents,
            max_steps_per_subagent,
            timeout_secs,
            model_tier.unwrap_or("balanced")
        ));
    }

    let swarm_config = SwarmConfig {
        max_subagents,
        max_steps_per_subagent,
        subagent_timeout_secs: timeout_secs,
        parallel_enabled,
        model,
        working_dir: session
            .metadata
            .directory
            .as_ref()
            .map(|p| p.to_string_lossy().to_string()),
        ..Default::default()
    };

    // With a callback: run the executor on its own task and forward events
    // from the channel until the executor's JoinHandle resolves.
    let swarm_result = if output_callback.is_some() {
        let (event_tx, mut event_rx) = mpsc::channel(256);
        let mut executor = SwarmExecutor::new(swarm_config).with_event_tx(event_tx);
        if let Some(bus) = bus {
            executor = executor.with_bus(Arc::clone(bus));
        }
        let prompt_owned = prompt.to_string();
        let mut exec_handle =
            tokio::spawn(async move { executor.execute(&prompt_owned, strategy).await });

        let mut final_result: Option<crate::swarm::SwarmResult> = None;

        while final_result.is_none() {
            tokio::select! {
                maybe_event = event_rx.recv() => {
                    if let Some(event) = maybe_event {
                        if let Some(cb) = output_callback {
                            if let Some(line) = format_swarm_event_for_output(&event) {
                                cb(line);
                            }
                        }
                    }
                }
                join_result = &mut exec_handle => {
                    let joined = join_result.map_err(|e| anyhow::anyhow!("Swarm join failure: {}", e))?;
                    final_result = Some(joined?);
                }
            }
        }

        // Drain any events that were queued between completion and now.
        while let Ok(event) = event_rx.try_recv() {
            if let Some(cb) = output_callback {
                if let Some(line) = format_swarm_event_for_output(&event) {
                    cb(line);
                }
            }
        }

        final_result.ok_or_else(|| anyhow::anyhow!("Swarm execution returned no result"))?
    } else {
        // No callback: run the executor inline without an event channel.
        let mut executor = SwarmExecutor::new(swarm_config);
        if let Some(bus) = bus {
            executor = executor.with_bus(Arc::clone(bus));
        }
        executor.execute(prompt, strategy).await?
    };

    // Guarantee non-empty result text so the release payload is meaningful.
    let final_text = if swarm_result.result.trim().is_empty() {
        if swarm_result.success {
            "Swarm completed without textual output.".to_string()
        } else {
            "Swarm finished with failures and no textual output.".to_string()
        }
    } else {
        swarm_result.result.clone()
    };

    session.add_message(Message {
        role: Role::Assistant,
        content: vec![ContentPart::Text {
            text: final_text.clone(),
        }],
    });
    session.save().await?;

    Ok((
        crate::session::SessionResult {
            text: final_text,
            session_id: session.id.clone(),
        },
        swarm_result.success,
    ))
}
1409
1410/// Execute a session with the given auto-approve policy
1411/// Optionally streams output chunks via the callback
1412async fn execute_session_with_policy<F>(
1413    session: &mut Session,
1414    prompt: &str,
1415    auto_approve: AutoApprove,
1416    model_tier: Option<&str>,
1417    output_callback: Option<&F>,
1418) -> Result<crate::session::SessionResult>
1419where
1420    F: Fn(String),
1421{
1422    use crate::provider::{
1423        CompletionRequest, ContentPart, Message, ProviderRegistry, Role, parse_model_string,
1424    };
1425    use std::sync::Arc;
1426
1427    // Load provider registry from Vault
1428    let registry = ProviderRegistry::from_vault().await?;
1429    let providers = registry.list();
1430    tracing::info!("Available providers: {:?}", providers);
1431
1432    if providers.is_empty() {
1433        anyhow::bail!("No providers available. Configure API keys in HashiCorp Vault.");
1434    }
1435
1436    // Parse model string
1437    let (provider_name, model_id) = if let Some(ref model_str) = session.metadata.model {
1438        let (prov, model) = parse_model_string(model_str);
1439        let prov = prov.map(|p| if p == "zhipuai" { "zai" } else { p });
1440        if prov.is_some() {
1441            (prov.map(|s| s.to_string()), model.to_string())
1442        } else if providers.contains(&model) {
1443            (Some(model.to_string()), String::new())
1444        } else {
1445            (None, model.to_string())
1446        }
1447    } else {
1448        (None, String::new())
1449    };
1450
1451    let provider_slice = providers.as_slice();
1452    let provider_requested_but_unavailable = provider_name
1453        .as_deref()
1454        .map(|p| !providers.contains(&p))
1455        .unwrap_or(false);
1456
1457    // Determine which provider to use, preferring explicit request first, then model tier.
1458    let selected_provider = provider_name
1459        .as_deref()
1460        .filter(|p| providers.contains(p))
1461        .unwrap_or_else(|| choose_provider_for_tier(provider_slice, model_tier));
1462
1463    let provider = registry
1464        .get(selected_provider)
1465        .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;
1466
1467    // Add user message
1468    session.add_message(Message {
1469        role: Role::User,
1470        content: vec![ContentPart::Text {
1471            text: prompt.to_string(),
1472        }],
1473    });
1474
1475    // Generate title
1476    if session.title.is_none() {
1477        session.generate_title().await?;
1478    }
1479
1480    // Determine model. If a specific provider was requested but not available,
1481    // ignore that model id and fall back to the tier-based default model.
1482    let model = if !model_id.is_empty() && !provider_requested_but_unavailable {
1483        model_id
1484    } else {
1485        default_model_for_provider(selected_provider, model_tier)
1486    };
1487
1488    // Create tool registry with filtering based on auto-approve policy
1489    let tool_registry =
1490        create_filtered_registry(Arc::clone(&provider), model.clone(), auto_approve);
1491    let tool_definitions = tool_registry.definitions();
1492
1493    let temperature = if model.contains("kimi-k2") {
1494        Some(1.0)
1495    } else {
1496        Some(0.7)
1497    };
1498
1499    tracing::info!(
1500        "Using model: {} via provider: {} (tier: {:?})",
1501        model,
1502        selected_provider,
1503        model_tier
1504    );
1505    tracing::info!(
1506        "Available tools: {} (auto_approve: {:?})",
1507        tool_definitions.len(),
1508        auto_approve
1509    );
1510
1511    // Build system prompt
1512    let cwd = std::env::var("PWD")
1513        .map(std::path::PathBuf::from)
1514        .unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
1515    let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);
1516
1517    let mut final_output = String::new();
1518    let max_steps = 50;
1519
1520    for step in 1..=max_steps {
1521        tracing::info!(step = step, "Agent step starting");
1522
1523        // Build messages with system prompt first
1524        let mut messages = vec![Message {
1525            role: Role::System,
1526            content: vec![ContentPart::Text {
1527                text: system_prompt.clone(),
1528            }],
1529        }];
1530        messages.extend(session.messages.clone());
1531
1532        let request = CompletionRequest {
1533            messages,
1534            tools: tool_definitions.clone(),
1535            model: model.clone(),
1536            temperature,
1537            top_p: None,
1538            max_tokens: Some(8192),
1539            stop: Vec::new(),
1540        };
1541
1542        let response = provider.complete(request).await?;
1543
1544        crate::telemetry::TOKEN_USAGE.record_model_usage(
1545            &model,
1546            response.usage.prompt_tokens as u64,
1547            response.usage.completion_tokens as u64,
1548        );
1549
1550        // Extract tool calls
1551        let tool_calls: Vec<(String, String, serde_json::Value)> = response
1552            .message
1553            .content
1554            .iter()
1555            .filter_map(|part| {
1556                if let ContentPart::ToolCall {
1557                    id,
1558                    name,
1559                    arguments,
1560                } = part
1561                {
1562                    let args: serde_json::Value =
1563                        serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
1564                    Some((id.clone(), name.clone(), args))
1565                } else {
1566                    None
1567                }
1568            })
1569            .collect();
1570
1571        // Collect text output and stream it
1572        for part in &response.message.content {
1573            if let ContentPart::Text { text } = part {
1574                if !text.is_empty() {
1575                    final_output.push_str(text);
1576                    final_output.push('\n');
1577                    if let Some(cb) = output_callback {
1578                        cb(text.clone());
1579                    }
1580                }
1581            }
1582        }
1583
1584        // If no tool calls, we're done
1585        if tool_calls.is_empty() {
1586            session.add_message(response.message.clone());
1587            break;
1588        }
1589
1590        session.add_message(response.message.clone());
1591
1592        tracing::info!(
1593            step = step,
1594            num_tools = tool_calls.len(),
1595            "Executing tool calls"
1596        );
1597
1598        // Execute each tool call
1599        for (tool_id, tool_name, tool_input) in tool_calls {
1600            tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");
1601
1602            // Stream tool start event
1603            if let Some(cb) = output_callback {
1604                cb(format!("[tool:start:{}]", tool_name));
1605            }
1606
1607            // Check if tool is allowed based on auto-approve policy
1608            if !is_tool_allowed(&tool_name, auto_approve) {
1609                let msg = format!(
1610                    "Tool '{}' requires approval but auto-approve policy is {:?}",
1611                    tool_name, auto_approve
1612                );
1613                tracing::warn!(tool = %tool_name, "Tool blocked by auto-approve policy");
1614                session.add_message(Message {
1615                    role: Role::Tool,
1616                    content: vec![ContentPart::ToolResult {
1617                        tool_call_id: tool_id,
1618                        content: msg,
1619                    }],
1620                });
1621                continue;
1622            }
1623
1624            let content = if let Some(tool) = tool_registry.get(&tool_name) {
1625                let exec_result: Result<crate::tool::ToolResult> =
1626                    tool.execute(tool_input.clone()).await;
1627                match exec_result {
1628                    Ok(result) => {
1629                        tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
1630                        if let Some(cb) = output_callback {
1631                            let status = if result.success { "ok" } else { "err" };
1632                            cb(format!(
1633                                "[tool:{}:{}] {}",
1634                                tool_name,
1635                                status,
1636                                &result.output[..result.output.len().min(500)]
1637                            ));
1638                        }
1639                        result.output
1640                    }
1641                    Err(e) => {
1642                        tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
1643                        if let Some(cb) = output_callback {
1644                            cb(format!("[tool:{}:err] {}", tool_name, e));
1645                        }
1646                        format!("Error: {}", e)
1647                    }
1648                }
1649            } else {
1650                tracing::warn!(tool = %tool_name, "Tool not found");
1651                format!("Error: Unknown tool '{}'", tool_name)
1652            };
1653
1654            session.add_message(Message {
1655                role: Role::Tool,
1656                content: vec![ContentPart::ToolResult {
1657                    tool_call_id: tool_id,
1658                    content,
1659                }],
1660            });
1661        }
1662    }
1663
1664    session.save().await?;
1665
1666    Ok(crate::session::SessionResult {
1667        text: final_output.trim().to_string(),
1668        session_id: session.id.clone(),
1669    })
1670}
1671
1672/// Check if a tool is allowed based on the auto-approve policy
1673fn is_tool_allowed(tool_name: &str, auto_approve: AutoApprove) -> bool {
1674    match auto_approve {
1675        AutoApprove::All => true,
1676        AutoApprove::Safe | AutoApprove::None => is_safe_tool(tool_name),
1677    }
1678}
1679
/// Check if a tool is considered "safe" (read-only)
fn is_safe_tool(tool_name: &str) -> bool {
    // Read-only / side-effect-free tools that never modify the workspace.
    matches!(
        tool_name,
        "read"
            | "list"
            | "glob"
            | "grep"
            | "codesearch"
            | "lsp"
            | "webfetch"
            | "websearch"
            | "todo_read"
            | "skill"
    )
}
1696
/// Create a filtered tool registry based on the auto-approve policy
///
/// Read-only tools are always registered. Tools that can mutate the
/// workspace or execute commands (write/edit/bash/patch/...) are registered
/// only when the policy is `AutoApprove::All`, so restricted sessions never
/// even see them in the advertised tool definitions.
///
/// `provider` and `model` are handed to the ralph tool, which issues its own
/// completions.
fn create_filtered_registry(
    provider: Arc<dyn crate::provider::Provider>,
    model: String,
    auto_approve: AutoApprove,
) -> crate::tool::ToolRegistry {
    use crate::tool::*;

    let mut registry = ToolRegistry::new();

    // Always add safe tools
    registry.register(Arc::new(file::ReadTool::new()));
    registry.register(Arc::new(file::ListTool::new()));
    registry.register(Arc::new(file::GlobTool::new()));
    registry.register(Arc::new(search::GrepTool::new()));
    registry.register(Arc::new(lsp::LspTool::new()));
    registry.register(Arc::new(webfetch::WebFetchTool::new()));
    registry.register(Arc::new(websearch::WebSearchTool::new()));
    registry.register(Arc::new(codesearch::CodeSearchTool::new()));
    registry.register(Arc::new(todo::TodoReadTool::new()));
    registry.register(Arc::new(skill::SkillTool::new()));

    // Add potentially dangerous tools only if auto_approve is All
    if matches!(auto_approve, AutoApprove::All) {
        registry.register(Arc::new(file::WriteTool::new()));
        registry.register(Arc::new(edit::EditTool::new()));
        registry.register(Arc::new(bash::BashTool::new()));
        registry.register(Arc::new(multiedit::MultiEditTool::new()));
        registry.register(Arc::new(patch::ApplyPatchTool::new()));
        registry.register(Arc::new(todo::TodoWriteTool::new()));
        registry.register(Arc::new(task::TaskTool::new()));
        registry.register(Arc::new(plan::PlanEnterTool::new()));
        registry.register(Arc::new(plan::PlanExitTool::new()));
        registry.register(Arc::new(rlm::RlmTool::new()));
        registry.register(Arc::new(ralph::RalphTool::with_provider(provider, model)));
        registry.register(Arc::new(prd::PrdTool::new()));
        registry.register(Arc::new(confirm_edit::ConfirmEditTool::new()));
        registry.register(Arc::new(confirm_multiedit::ConfirmMultiEditTool::new()));
        registry.register(Arc::new(undo::UndoTool));
        registry.register(Arc::new(mcp_bridge::McpBridgeTool::new()));
    }

    // NOTE(review): presumably a catch-all for invalid/unknown tool calls —
    // confirm InvalidTool's contract before relying on it.
    registry.register(Arc::new(invalid::InvalidTool::new()));

    registry
}
1743
/// Start the heartbeat background task.
///
/// Spawns a detached Tokio task that, every 30 seconds, POSTs the worker's
/// status ("idle"/"processing") and active task count to
/// `{server}/v1/opencode/workers/{worker_id}/heartbeat`. When cognition
/// reporting is enabled, a "cognition" object fetched from the local
/// cognition service is attached to the payload; if the server rejects an
/// enriched payload with a 4xx, the heartbeat is resent without it and the
/// cognition section is paused for a 5-minute cooldown. The loop never exits
/// on its own: after 3 consecutive failures it logs an error, resets the
/// counter, and keeps trying.
///
/// Returns a JoinHandle that can be used to cancel the heartbeat.
fn start_heartbeat(
    client: Client,
    server: String,
    token: Option<String>,
    heartbeat_state: HeartbeatState,
    processing: Arc<Mutex<HashSet<String>>>,
    cognition_config: CognitionHeartbeatConfig,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        let mut consecutive_failures = 0u32;
        const MAX_FAILURES: u32 = 3;
        const HEARTBEAT_INTERVAL_SECS: u64 = 30;
        const COGNITION_RETRY_COOLDOWN_SECS: u64 = 300;
        // While set, cognition enrichment is skipped until this deadline passes.
        let mut cognition_payload_disabled_until: Option<Instant> = None;

        let mut interval =
            tokio::time::interval(tokio::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
        // Skip (rather than burst) ticks missed while a slow request was in flight.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        loop {
            interval.tick().await;

            // Update task count from processing set
            let active_count = processing.lock().await.len();
            heartbeat_state.set_task_count(active_count).await;

            // Determine status based on active tasks
            let status = if active_count > 0 {
                WorkerStatus::Processing
            } else {
                WorkerStatus::Idle
            };
            heartbeat_state.set_status(status).await;

            // Send heartbeat
            let url = format!(
                "{}/v1/opencode/workers/{}/heartbeat",
                server, heartbeat_state.worker_id
            );
            let mut req = client.post(&url);

            if let Some(ref t) = token {
                req = req.bearer_auth(t);
            }

            let status_str = heartbeat_state.status.lock().await.as_str().to_string();
            // base_payload is kept separate so the 4xx retry below can resend
            // the heartbeat without the cognition enrichment.
            let base_payload = serde_json::json!({
                "worker_id": &heartbeat_state.worker_id,
                "agent_name": &heartbeat_state.agent_name,
                "status": status_str,
                "active_task_count": active_count,
            });
            let mut payload = base_payload.clone();
            let mut included_cognition_payload = false;
            // Enrichment is allowed when no cooldown is set or it has expired.
            let cognition_payload_allowed = cognition_payload_disabled_until
                .map(|until| Instant::now() >= until)
                .unwrap_or(true);

            if cognition_config.enabled
                && cognition_payload_allowed
                && let Some(cognition_payload) =
                    fetch_cognition_heartbeat_payload(&client, &cognition_config).await
                && let Some(obj) = payload.as_object_mut()
            {
                obj.insert("cognition".to_string(), cognition_payload);
                included_cognition_payload = true;
            }

            match req.json(&payload).send().await {
                Ok(res) => {
                    if res.status().is_success() {
                        consecutive_failures = 0;
                        tracing::debug!(
                            worker_id = %heartbeat_state.worker_id,
                            status = status_str,
                            active_tasks = active_count,
                            "Heartbeat sent successfully"
                        );
                    } else if included_cognition_payload && res.status().is_client_error() {
                        // A 4xx on an enriched payload is treated as a schema
                        // rejection: retry once without the cognition section.
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            "Heartbeat cognition payload rejected, retrying without cognition payload"
                        );

                        let mut retry_req = client.post(&url);
                        if let Some(ref t) = token {
                            retry_req = retry_req.bearer_auth(t);
                        }

                        match retry_req.json(&base_payload).send().await {
                            Ok(retry_res) if retry_res.status().is_success() => {
                                // Bare payload succeeded: pause enrichment for
                                // the cooldown window before trying it again.
                                cognition_payload_disabled_until = Some(
                                    Instant::now()
                                        + Duration::from_secs(COGNITION_RETRY_COOLDOWN_SECS),
                                );
                                consecutive_failures = 0;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    retry_after_secs = COGNITION_RETRY_COOLDOWN_SECS,
                                    "Paused cognition heartbeat payload after schema rejection"
                                );
                            }
                            Ok(retry_res) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    status = %retry_res.status(),
                                    failures = consecutive_failures,
                                    "Heartbeat failed even after retry without cognition payload"
                                );
                            }
                            Err(e) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    error = %e,
                                    failures = consecutive_failures,
                                    "Heartbeat retry without cognition payload failed"
                                );
                            }
                        }
                    } else {
                        consecutive_failures += 1;
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            failures = consecutive_failures,
                            "Heartbeat failed"
                        );
                    }
                }
                Err(e) => {
                    consecutive_failures += 1;
                    tracing::warn!(
                        worker_id = %heartbeat_state.worker_id,
                        error = %e,
                        failures = consecutive_failures,
                        "Heartbeat request failed"
                    );
                }
            }

            // Log error after 3 consecutive failures but do not terminate
            if consecutive_failures >= MAX_FAILURES {
                tracing::error!(
                    worker_id = %heartbeat_state.worker_id,
                    failures = consecutive_failures,
                    "Heartbeat failed {} consecutive times - worker will continue running and attempt reconnection via SSE loop",
                    MAX_FAILURES
                );
                // Reset counter to avoid spamming error logs
                consecutive_failures = 0;
            }
        }
    })
}
1903
1904async fn fetch_cognition_heartbeat_payload(
1905    client: &Client,
1906    config: &CognitionHeartbeatConfig,
1907) -> Option<serde_json::Value> {
1908    let status_url = format!("{}/v1/cognition/status", config.source_base_url);
1909    let status_res = tokio::time::timeout(
1910        Duration::from_millis(config.request_timeout_ms),
1911        client.get(status_url).send(),
1912    )
1913    .await
1914    .ok()?
1915    .ok()?;
1916
1917    if !status_res.status().is_success() {
1918        return None;
1919    }
1920
1921    let status: CognitionStatusSnapshot = status_res.json().await.ok()?;
1922    let mut payload = serde_json::json!({
1923        "running": status.running,
1924        "last_tick_at": status.last_tick_at,
1925        "active_persona_count": status.active_persona_count,
1926        "events_buffered": status.events_buffered,
1927        "snapshots_buffered": status.snapshots_buffered,
1928        "loop_interval_ms": status.loop_interval_ms,
1929    });
1930
1931    if config.include_thought_summary {
1932        let snapshot_url = format!("{}/v1/cognition/snapshots/latest", config.source_base_url);
1933        let snapshot_res = tokio::time::timeout(
1934            Duration::from_millis(config.request_timeout_ms),
1935            client.get(snapshot_url).send(),
1936        )
1937        .await
1938        .ok()
1939        .and_then(Result::ok);
1940
1941        if let Some(snapshot_res) = snapshot_res
1942            && snapshot_res.status().is_success()
1943            && let Ok(snapshot) = snapshot_res.json::<CognitionLatestSnapshot>().await
1944            && let Some(obj) = payload.as_object_mut()
1945        {
1946            obj.insert(
1947                "latest_snapshot_at".to_string(),
1948                serde_json::Value::String(snapshot.generated_at),
1949            );
1950            obj.insert(
1951                "latest_thought".to_string(),
1952                serde_json::Value::String(trim_for_heartbeat(
1953                    &snapshot.summary,
1954                    config.summary_max_chars,
1955                )),
1956            );
1957            if let Some(model) = snapshot
1958                .metadata
1959                .get("model")
1960                .and_then(serde_json::Value::as_str)
1961            {
1962                obj.insert(
1963                    "latest_thought_model".to_string(),
1964                    serde_json::Value::String(model.to_string()),
1965                );
1966            }
1967            if let Some(source) = snapshot
1968                .metadata
1969                .get("source")
1970                .and_then(serde_json::Value::as_str)
1971            {
1972                obj.insert(
1973                    "latest_thought_source".to_string(),
1974                    serde_json::Value::String(source.to_string()),
1975                );
1976            }
1977        }
1978    }
1979
1980    Some(payload)
1981}
1982
/// Trim surrounding whitespace from `input` and cap it at `max_chars`
/// characters, appending `"..."` when truncation occurred.
///
/// Fix: the previous version counted characters on the *untrimmed* input, so
/// a value whose trimmed form fit within `max_chars` could still be
/// truncated (e.g. `"  hello  "` with a limit of 7 became `"hello..."`).
/// Trimming first makes the limit apply to the text actually reported. Any
/// whitespace left at the cut point is stripped before the ellipsis.
fn trim_for_heartbeat(input: &str, max_chars: usize) -> String {
    let trimmed = input.trim();
    if trimmed.chars().count() <= max_chars {
        return trimmed.to_string();
    }

    // char-based (not byte-based) truncation keeps multi-byte UTF-8 intact.
    let cut: String = trimmed.chars().take(max_chars).collect();
    format!("{}...", cut.trim_end())
}
1995
/// Read a boolean flag from the environment variable `name`.
///
/// Accepts (case-insensitively) "1"/"true"/"yes"/"on" as true and
/// "0"/"false"/"no"/"off" as false; an unset variable or any other value
/// yields `default`.
fn env_bool(name: &str, default: bool) -> bool {
    match std::env::var(name) {
        Ok(raw) => match raw.to_ascii_lowercase().as_str() {
            "1" | "true" | "yes" | "on" => true,
            "0" | "false" | "no" | "off" => false,
            _ => default,
        },
        Err(_) => default,
    }
}
2006
/// Read a `usize` from the environment variable `name`, falling back to
/// `default` when the variable is unset or does not parse as a `usize`.
fn env_usize(name: &str, default: usize) -> usize {
    match std::env::var(name) {
        Ok(raw) => raw.parse::<usize>().unwrap_or(default),
        Err(_) => default,
    }
}
2013
/// Read a `u64` from the environment variable `name`, falling back to
/// `default` when the variable is unset or does not parse as a `u64`.
fn env_u64(name: &str, default: u64) -> u64 {
    match std::env::var(name) {
        Ok(raw) => raw.parse::<u64>().unwrap_or(default),
        Err(_) => default,
    }
}