Skip to main content

codetether_agent/a2a/
worker.rs

1//! A2A Worker - connects to an A2A server to process tasks
2
3use crate::bus::AgentBus;
4use crate::cli::A2aArgs;
5use crate::provider::ProviderRegistry;
6use crate::session::Session;
7use crate::swarm::{DecompositionStrategy, SwarmConfig, SwarmExecutor};
8use crate::tui::swarm_view::SwarmEvent;
9use anyhow::Result;
10use futures::StreamExt;
11use reqwest::Client;
12use serde::Deserialize;
13use std::collections::HashMap;
14use std::collections::HashSet;
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::sync::{Mutex, mpsc};
18use tokio::task::JoinHandle;
19use tokio::time::Instant;
20
/// Worker status reported to the server via heartbeat.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WorkerStatus {
    Idle,
    Processing,
}

impl WorkerStatus {
    /// Wire-format string used in heartbeat payloads.
    fn as_str(&self) -> &'static str {
        match self {
            Self::Idle => "idle",
            Self::Processing => "processing",
        }
    }
}
36
/// Heartbeat state shared between the heartbeat task and the main worker
#[derive(Clone)]
struct HeartbeatState {
    // Unique id generated at startup (see `generate_worker_id`).
    worker_id: String,
    // Agent display name (CLI-provided or "codetether-<pid>"; see `run`).
    agent_name: String,
    // Cloning the struct clones these `Arc`s, so every clone observes the
    // same values — that is how the heartbeat task sees updates made by
    // the worker side.
    status: Arc<Mutex<WorkerStatus>>,
    active_task_count: Arc<Mutex<usize>>,
}

impl HeartbeatState {
    /// Fresh state: status `Idle`, zero active tasks.
    fn new(worker_id: String, agent_name: String) -> Self {
        Self {
            worker_id,
            agent_name,
            status: Arc::new(Mutex::new(WorkerStatus::Idle)),
            active_task_count: Arc::new(Mutex::new(0)),
        }
    }

    /// Replace the current idle/processing status.
    async fn set_status(&self, status: WorkerStatus) {
        *self.status.lock().await = status;
    }

    /// Replace the active-task counter.
    async fn set_task_count(&self, count: usize) {
        *self.active_task_count.lock().await = count;
    }
}
64
65#[derive(Clone, Debug)]
66struct CognitionHeartbeatConfig {
67    enabled: bool,
68    source_base_url: String,
69    include_thought_summary: bool,
70    summary_max_chars: usize,
71    request_timeout_ms: u64,
72}
73
74impl CognitionHeartbeatConfig {
75    fn from_env() -> Self {
76        let source_base_url = std::env::var("CODETETHER_WORKER_COGNITION_SOURCE_URL")
77            .unwrap_or_else(|_| "http://127.0.0.1:4096".to_string())
78            .trim_end_matches('/')
79            .to_string();
80
81        Self {
82            enabled: env_bool("CODETETHER_WORKER_COGNITION_SHARE_ENABLED", true),
83            source_base_url,
84            include_thought_summary: env_bool("CODETETHER_WORKER_COGNITION_INCLUDE_THOUGHTS", true),
85            summary_max_chars: env_usize("CODETETHER_WORKER_COGNITION_THOUGHT_MAX_CHARS", 480)
86                .max(120),
87            request_timeout_ms: env_u64("CODETETHER_WORKER_COGNITION_TIMEOUT_MS", 2_500).max(250),
88        }
89    }
90}
91
/// Status payload polled from the local cognition service.
// NOTE(review): field semantics inferred from names; confirm against the
// cognition service's status endpoint schema. All `#[serde(default)]`
// fields tolerate being absent from the JSON payload.
#[derive(Debug, Deserialize)]
struct CognitionStatusSnapshot {
    // Whether the cognition loop reports itself as running.
    running: bool,
    #[serde(default)]
    last_tick_at: Option<String>,
    #[serde(default)]
    active_persona_count: usize,
    #[serde(default)]
    events_buffered: usize,
    #[serde(default)]
    snapshots_buffered: usize,
    #[serde(default)]
    loop_interval_ms: u64,
}
106
/// Most recent cognition snapshot fetched from the local cognition service.
#[derive(Debug, Deserialize)]
struct CognitionLatestSnapshot {
    // Timestamp string as produced by the service (treated as opaque text
    // here; format not validated).
    generated_at: String,
    // Human-readable thought summary.
    summary: String,
    // Arbitrary extra fields; defaults to an empty map when absent.
    #[serde(default)]
    metadata: HashMap<String, serde_json::Value>,
}
114
/// Run the A2A worker: register with the server, drain any already-pending
/// tasks, then hold an SSE connection open, re-registering and reconnecting
/// every time the stream drops. Never returns under normal operation.
pub async fn run(args: A2aArgs) -> Result<()> {
    let server = args.server.trim_end_matches('/');
    // Default the agent name to "codetether-<pid>" when not provided.
    let name = args
        .name
        .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
    let worker_id = generate_worker_id();

    // Comma-separated codebase list; defaults to the current directory.
    // NOTE(review): `current_dir().unwrap()` panics if the cwd was removed
    // or is inaccessible — consider a graceful fallback.
    let codebases: Vec<String> = args
        .codebases
        .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
        .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);

    tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
    tracing::info!("Server: {}", server);
    tracing::info!("Codebases: {:?}", codebases);

    let client = Client::new();
    // Task ids currently being handled; used to de-duplicate work between
    // the pending-task poll and the SSE stream (see `fetch_pending_tasks`).
    let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
    let cognition_heartbeat = CognitionHeartbeatConfig::from_env();
    if cognition_heartbeat.enabled {
        tracing::info!(
            source = %cognition_heartbeat.source_base_url,
            include_thoughts = cognition_heartbeat.include_thought_summary,
            max_chars = cognition_heartbeat.summary_max_chars,
            timeout_ms = cognition_heartbeat.request_timeout_ms,
            "Cognition heartbeat sharing enabled (set CODETETHER_WORKER_COGNITION_SHARE_ENABLED=false to disable)"
        );
    } else {
        tracing::warn!(
            "Cognition heartbeat sharing disabled; worker thought state will not be shared upstream"
        );
    }

    // Anything other than "all"/"safe" means no auto-approval.
    let auto_approve = match args.auto_approve.as_str() {
        "all" => AutoApprove::All,
        "safe" => AutoApprove::Safe,
        _ => AutoApprove::None,
    };

    // Create heartbeat state
    let heartbeat_state = HeartbeatState::new(worker_id.clone(), name.clone());

    // Create agent bus for in-process sub-agent communication
    let bus = AgentBus::new().into_arc();
    {
        let handle = bus.handle(&worker_id);
        handle.announce_ready(WORKER_CAPABILITIES.iter().map(|s| s.to_string()).collect());
    }

    // Register worker
    register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;

    // Fetch pending tasks
    fetch_pending_tasks(
        &client,
        server,
        &args.token,
        &worker_id,
        &processing,
        &auto_approve,
        &bus,
    )
    .await?;

    // Connect to SSE stream
    loop {
        // Re-register worker on each reconnection to report updated models/capabilities
        if let Err(e) =
            register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await
        {
            tracing::warn!("Failed to re-register worker on reconnection: {}", e);
        }

        // Start heartbeat task for this connection
        let heartbeat_handle = start_heartbeat(
            client.clone(),
            server.to_string(),
            args.token.clone(),
            heartbeat_state.clone(),
            processing.clone(),
            cognition_heartbeat.clone(),
        );

        // Blocks until the stream closes (Ok) or errors; either way we fall
        // through, tear down the heartbeat, and reconnect.
        match connect_stream(
            &client,
            server,
            &args.token,
            &worker_id,
            &name,
            &codebases,
            &processing,
            &auto_approve,
            &bus,
        )
        .await
        {
            Ok(()) => {
                tracing::warn!("Stream ended, reconnecting...");
            }
            Err(e) => {
                tracing::error!("Stream error: {}, reconnecting...", e);
            }
        }

        // Cancel heartbeat on disconnection
        heartbeat_handle.abort();
        tracing::debug!("Heartbeat cancelled for reconnection");

        // Brief fixed backoff before reconnecting.
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
    }
}
227
228fn generate_worker_id() -> String {
229    format!(
230        "wrk_{}_{:x}",
231        chrono::Utc::now().timestamp(),
232        rand::random::<u64>()
233    )
234}
235
/// Auto-approval policy for tool executions, parsed from the
/// `--auto-approve` CLI flag in `run` ("all" / "safe" / anything else).
#[derive(Debug, Clone, Copy)]
enum AutoApprove {
    // Approve every request without prompting.
    All,
    // Approve only requests considered safe.
    // NOTE(review): the safe-list policy itself is applied elsewhere
    // (e.g. in the task handler) — confirm there.
    Safe,
    // Never auto-approve.
    None,
}
242
/// Default A2A server URL when none is configured
pub const DEFAULT_A2A_SERVER_URL: &str = "https://api.codetether.run";

/// Capabilities of the codetether-agent worker.
/// Advertised at registration (`register_worker`) and announced on the
/// in-process agent bus at startup (`run`).
const WORKER_CAPABILITIES: &[&str] = &["ralph", "swarm", "rlm", "a2a", "mcp"];
248
249fn task_value<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a serde_json::Value> {
250    task.get("task")
251        .and_then(|t| t.get(key))
252        .or_else(|| task.get(key))
253}
254
255fn task_str<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a str> {
256    task_value(task, key).and_then(|v| v.as_str())
257}
258
259fn task_metadata(task: &serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
260    task_value(task, "metadata")
261        .and_then(|m| m.as_object())
262        .cloned()
263        .unwrap_or_default()
264}
265
/// Normalize a model reference to `provider/model` form by rewriting the
/// first ':' to '/', but only when the string contains no '/' yet.
///
/// Refs that already use a path form (e.g. "bedrock/amazon.nova-micro-v1:0")
/// are returned untouched, which is how versioned bedrock-style ids keep
/// their ':' separator.
// NOTE(review): a *bare* id with a colon and no slash (e.g.
// "amazon.nova-micro-v1:0" without a provider prefix) WILL be rewritten —
// callers are expected to pass fully-qualified refs in that case.
fn model_ref_to_provider_model(model: &str) -> String {
    if model.contains('/') {
        return model.to_string();
    }
    match model.split_once(':') {
        Some((provider, rest)) => format!("{provider}/{rest}"),
        None => model.to_string(),
    }
}
276
/// Ordered provider preference list for a model tier.
///
/// Recognized tiers: "fast"/"quick" and "heavy"/"deep"; anything else
/// (including `None`) uses the balanced ordering.
fn provider_preferences_for_tier(model_tier: Option<&str>) -> &'static [&'static str] {
    const FAST: &[&str] = &[
        "zai",
        "openai",
        "github-copilot",
        "moonshotai",
        "openrouter",
        "novita",
        "google",
        "anthropic",
    ];
    const HEAVY: &[&str] = &[
        "zai",
        "anthropic",
        "openai",
        "github-copilot",
        "moonshotai",
        "openrouter",
        "novita",
        "google",
    ];
    const BALANCED: &[&str] = &[
        "zai",
        "openai",
        "github-copilot",
        "anthropic",
        "moonshotai",
        "openrouter",
        "novita",
        "google",
    ];

    match model_tier.unwrap_or("balanced") {
        "fast" | "quick" => FAST,
        "heavy" | "deep" => HEAVY,
        _ => BALANCED,
    }
}
311
312fn choose_provider_for_tier<'a>(providers: &'a [&'a str], model_tier: Option<&str>) -> &'a str {
313    for preferred in provider_preferences_for_tier(model_tier) {
314        if let Some(found) = providers.iter().copied().find(|p| *p == *preferred) {
315            return found;
316        }
317    }
318    if let Some(found) = providers.iter().copied().find(|p| *p == "zai") {
319        return found;
320    }
321    providers[0]
322}
323
/// Default model id for a provider at the given tier.
///
/// Tiers: "fast"/"quick", "heavy"/"deep"; anything else is balanced.
/// Unknown providers default to "glm-5" at every tier.
fn default_model_for_provider(provider: &str, model_tier: Option<&str>) -> String {
    let tier = model_tier.unwrap_or("balanced");
    let fast = matches!(tier, "fast" | "quick");
    let heavy = matches!(tier, "heavy" | "deep");

    let model = match provider {
        // Tier-independent defaults.
        "moonshotai" => "kimi-k2.5",
        "zhipuai" | "zai" => "glm-5",
        "openrouter" => "z-ai/glm-5",
        "novita" => "qwen/qwen3-coder-next",
        // Tier-dependent defaults.
        "anthropic" if fast => "claude-haiku-4-5",
        "anthropic" => "claude-sonnet-4-20250514",
        "openai" if fast => "gpt-4o-mini",
        "openai" if heavy => "o3",
        "openai" => "gpt-4o",
        "google" if fast => "gemini-2.5-flash",
        "google" => "gemini-2.5-pro",
        "bedrock" if heavy => "us.anthropic.claude-sonnet-4-20250514-v1:0",
        "bedrock" => "amazon.nova-lite-v1:0",
        _ => "glm-5",
    };
    model.to_string()
}
361
/// Whether the model name (case-insensitively) matches a family that this
/// worker runs at temperature 1.0 (kimi-k2 / glm- / minimax).
fn prefers_temperature_one(model: &str) -> bool {
    const MARKERS: [&str; 3] = ["kimi-k2", "glm-", "minimax"];
    let lower = model.to_ascii_lowercase();
    MARKERS.iter().any(|marker| lower.contains(marker))
}
368
/// True when the agent type names a swarm-style (multi-agent) execution
/// mode. Matching is trimmed and case-insensitive.
fn is_swarm_agent(agent_type: &str) -> bool {
    let normalized = agent_type.trim().to_ascii_lowercase();
    normalized == "swarm" || normalized == "parallel" || normalized == "multi-agent"
}
375
376fn metadata_lookup<'a>(
377    metadata: &'a serde_json::Map<String, serde_json::Value>,
378    key: &str,
379) -> Option<&'a serde_json::Value> {
380    metadata
381        .get(key)
382        .or_else(|| {
383            metadata
384                .get("routing")
385                .and_then(|v| v.as_object())
386                .and_then(|obj| obj.get(key))
387        })
388        .or_else(|| {
389            metadata
390                .get("swarm")
391                .and_then(|v| v.as_object())
392                .and_then(|obj| obj.get(key))
393        })
394}
395
396fn metadata_str(
397    metadata: &serde_json::Map<String, serde_json::Value>,
398    keys: &[&str],
399) -> Option<String> {
400    for key in keys {
401        if let Some(value) = metadata_lookup(metadata, key).and_then(|v| v.as_str()) {
402            let trimmed = value.trim();
403            if !trimmed.is_empty() {
404                return Some(trimmed.to_string());
405            }
406        }
407    }
408    None
409}
410
411fn metadata_usize(
412    metadata: &serde_json::Map<String, serde_json::Value>,
413    keys: &[&str],
414) -> Option<usize> {
415    for key in keys {
416        if let Some(value) = metadata_lookup(metadata, key) {
417            if let Some(v) = value.as_u64() {
418                return usize::try_from(v).ok();
419            }
420            if let Some(v) = value.as_i64() {
421                if v >= 0 {
422                    return usize::try_from(v as u64).ok();
423                }
424            }
425            if let Some(v) = value.as_str() {
426                if let Ok(parsed) = v.trim().parse::<usize>() {
427                    return Some(parsed);
428                }
429            }
430        }
431    }
432    None
433}
434
435fn metadata_u64(
436    metadata: &serde_json::Map<String, serde_json::Value>,
437    keys: &[&str],
438) -> Option<u64> {
439    for key in keys {
440        if let Some(value) = metadata_lookup(metadata, key) {
441            if let Some(v) = value.as_u64() {
442                return Some(v);
443            }
444            if let Some(v) = value.as_i64() {
445                if v >= 0 {
446                    return Some(v as u64);
447                }
448            }
449            if let Some(v) = value.as_str() {
450                if let Ok(parsed) = v.trim().parse::<u64>() {
451                    return Some(parsed);
452                }
453            }
454        }
455    }
456    None
457}
458
459fn metadata_bool(
460    metadata: &serde_json::Map<String, serde_json::Value>,
461    keys: &[&str],
462) -> Option<bool> {
463    for key in keys {
464        if let Some(value) = metadata_lookup(metadata, key) {
465            if let Some(v) = value.as_bool() {
466                return Some(v);
467            }
468            if let Some(v) = value.as_str() {
469                match v.trim().to_ascii_lowercase().as_str() {
470                    "1" | "true" | "yes" | "on" => return Some(true),
471                    "0" | "false" | "no" | "off" => return Some(false),
472                    _ => {}
473                }
474            }
475        }
476    }
477    None
478}
479
480fn parse_swarm_strategy(
481    metadata: &serde_json::Map<String, serde_json::Value>,
482) -> DecompositionStrategy {
483    match metadata_str(
484        metadata,
485        &[
486            "decomposition_strategy",
487            "swarm_strategy",
488            "strategy",
489            "swarm_decomposition",
490        ],
491    )
492    .as_deref()
493    .map(|s| s.to_ascii_lowercase())
494    .as_deref()
495    {
496        Some("none") | Some("single") => DecompositionStrategy::None,
497        Some("domain") | Some("by_domain") => DecompositionStrategy::ByDomain,
498        Some("data") | Some("by_data") => DecompositionStrategy::ByData,
499        Some("stage") | Some("by_stage") => DecompositionStrategy::ByStage,
500        _ => DecompositionStrategy::Automatic,
501    }
502}
503
504async fn resolve_swarm_model(
505    explicit_model: Option<String>,
506    model_tier: Option<&str>,
507) -> Option<String> {
508    if let Some(model) = explicit_model {
509        if !model.trim().is_empty() {
510            return Some(model);
511        }
512    }
513
514    let registry = ProviderRegistry::from_vault().await.ok()?;
515    let providers = registry.list();
516    if providers.is_empty() {
517        return None;
518    }
519    let provider = choose_provider_for_tier(providers.as_slice(), model_tier);
520    let model = default_model_for_provider(provider, model_tier);
521    Some(format!("{}/{}", provider, model))
522}
523
524fn format_swarm_event_for_output(event: &SwarmEvent) -> Option<String> {
525    match event {
526        SwarmEvent::Started {
527            task,
528            total_subtasks,
529        } => Some(format!(
530            "[swarm] started task={} planned_subtasks={}",
531            task, total_subtasks
532        )),
533        SwarmEvent::StageComplete {
534            stage,
535            completed,
536            failed,
537        } => Some(format!(
538            "[swarm] stage={} completed={} failed={}",
539            stage, completed, failed
540        )),
541        SwarmEvent::SubTaskUpdate { id, status, .. } => Some(format!(
542            "[swarm] subtask id={} status={}",
543            &id.chars().take(8).collect::<String>(),
544            format!("{status:?}").to_ascii_lowercase()
545        )),
546        SwarmEvent::AgentToolCall {
547            subtask_id,
548            tool_name,
549        } => Some(format!(
550            "[swarm] subtask id={} tool={}",
551            &subtask_id.chars().take(8).collect::<String>(),
552            tool_name
553        )),
554        SwarmEvent::AgentError { subtask_id, error } => Some(format!(
555            "[swarm] subtask id={} error={}",
556            &subtask_id.chars().take(8).collect::<String>(),
557            error
558        )),
559        SwarmEvent::Complete { success, stats } => Some(format!(
560            "[swarm] complete success={} subtasks={} speedup={:.2}",
561            success,
562            stats.subagents_completed + stats.subagents_failed,
563            stats.speedup_factor
564        )),
565        SwarmEvent::Error(err) => Some(format!("[swarm] error message={}", err)),
566        _ => None,
567    }
568}
569
/// Register (or re-register) this worker with the A2A server, advertising
/// its capabilities, hostname, codebases, and flattened model catalog.
///
/// Only transport errors propagate as `Err`; a non-2xx response is logged
/// and tolerated so the worker can still attempt to stream tasks.
async fn register_worker(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
) -> Result<()> {
    // Load ProviderRegistry and collect available models
    let models = match load_provider_models().await {
        Ok(m) => m,
        Err(e) => {
            tracing::warn!(
                "Failed to load provider models: {}, proceeding without model info",
                e
            );
            HashMap::new()
        }
    };

    // Register via the workers/register endpoint
    let mut req = client.post(format!("{}/v1/opencode/workers/register", server));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    // Flatten models HashMap into array of model objects with pricing data
    // matching the format expected by the A2A server's /models and /workers endpoints
    let models_array: Vec<serde_json::Value> = models
        .iter()
        .flat_map(|(provider, model_infos)| {
            model_infos.iter().map(move |m| {
                let mut obj = serde_json::json!({
                    "id": format!("{}/{}", provider, m.id),
                    "name": &m.id,
                    "provider": provider,
                    "provider_id": provider,
                });
                // Pricing fields are optional: omitted entirely (rather than
                // sent as null) when unknown.
                if let Some(input_cost) = m.input_cost_per_million {
                    obj["input_cost_per_million"] = serde_json::json!(input_cost);
                }
                if let Some(output_cost) = m.output_cost_per_million {
                    obj["output_cost_per_million"] = serde_json::json!(output_cost);
                }
                obj
            })
        })
        .collect();

    tracing::info!(
        "Registering worker with {} models from {} providers",
        models_array.len(),
        models.len()
    );

    // HOSTNAME is typical on Unix shells, COMPUTERNAME on Windows.
    let hostname = std::env::var("HOSTNAME")
        .or_else(|_| std::env::var("COMPUTERNAME"))
        .unwrap_or_else(|_| "unknown".to_string());

    let res = req
        .json(&serde_json::json!({
            "worker_id": worker_id,
            "name": name,
            "capabilities": WORKER_CAPABILITIES,
            "hostname": hostname,
            "models": models_array,
            "codebases": codebases,
        }))
        .send()
        .await?;

    if res.status().is_success() {
        tracing::info!("Worker registered successfully");
    } else {
        // Deliberately non-fatal; see the function doc.
        tracing::warn!("Failed to register worker: {}", res.status());
    }

    Ok(())
}
650
/// Load ProviderRegistry and collect all available models grouped by provider.
/// Tries Vault first, then falls back to config/env vars if Vault is unreachable.
/// Returns ModelInfo structs (with pricing data when available).
///
/// Resolution order:
/// 1. Vault-backed registry (if it yields at least one provider)
/// 2. config file + env vars (`fallback_registry`)
/// 3. the public models.dev catalog as a last resort (all providers,
///    regardless of API-key availability)
async fn load_provider_models() -> Result<HashMap<String, Vec<crate::provider::ModelInfo>>> {
    // Try Vault first
    let registry = match ProviderRegistry::from_vault().await {
        Ok(r) if !r.list().is_empty() => {
            tracing::info!("Loaded {} providers from Vault", r.list().len());
            r
        }
        Ok(_) => {
            tracing::warn!("Vault returned 0 providers, falling back to config/env vars");
            fallback_registry().await?
        }
        Err(e) => {
            tracing::warn!("Vault unreachable ({}), falling back to config/env vars", e);
            fallback_registry().await?
        }
    };

    // Fetch the models.dev catalog for pricing data enrichment
    // (best-effort: enrichment is simply skipped when the fetch fails).
    let catalog = crate::provider::models::ModelCatalog::fetch().await.ok();

    // Map provider IDs to their catalog equivalents (some differ)
    let catalog_alias = |pid: &str| -> String {
        match pid {
            "bedrock" => "amazon-bedrock".to_string(),
            "novita" => "novita-ai".to_string(),
            _ => pid.to_string(),
        }
    };

    let mut models_by_provider: HashMap<String, Vec<crate::provider::ModelInfo>> = HashMap::new();

    for provider_name in registry.list() {
        if let Some(provider) = registry.get(provider_name) {
            match provider.list_models().await {
                Ok(models) => {
                    let enriched: Vec<crate::provider::ModelInfo> = models
                        .into_iter()
                        .map(|mut m| {
                            // Enrich with catalog pricing if the provider didn't set it
                            if m.input_cost_per_million.is_none()
                                || m.output_cost_per_million.is_none()
                            {
                                if let Some(ref cat) = catalog {
                                    let cat_pid = catalog_alias(provider_name);
                                    if let Some(prov_info) = cat.get_provider(&cat_pid) {
                                        // Try exact match first, then strip "us." prefix (bedrock uses us.vendor.model format)
                                        let model_info =
                                            prov_info.models.get(&m.id).or_else(|| {
                                                m.id.strip_prefix("us.").and_then(|stripped| {
                                                    prov_info.models.get(stripped)
                                                })
                                            });
                                        if let Some(model_info) = model_info {
                                            if let Some(ref cost) = model_info.cost {
                                                // Only fill in fields the provider left unset.
                                                if m.input_cost_per_million.is_none() {
                                                    m.input_cost_per_million = Some(cost.input);
                                                }
                                                if m.output_cost_per_million.is_none() {
                                                    m.output_cost_per_million = Some(cost.output);
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            m
                        })
                        .collect();
                    if !enriched.is_empty() {
                        tracing::debug!("Provider {}: {} models", provider_name, enriched.len());
                        models_by_provider.insert(provider_name.to_string(), enriched);
                    }
                }
                Err(e) => {
                    // A provider failing to list models is non-fatal; skip it.
                    tracing::debug!("Failed to list models for {}: {}", provider_name, e);
                }
            }
        }
    }

    // If we still have 0, try the models.dev catalog as last resort
    // NOTE: We list ALL models from the catalog (not just ones with verified API keys)
    // because the worker should advertise what it can handle. The server handles routing.
    if models_by_provider.is_empty() {
        tracing::info!(
            "No authenticated providers found, fetching models.dev catalog (all providers)"
        );
        if let Ok(cat) = crate::provider::models::ModelCatalog::fetch().await {
            // Use all_providers_with_models() to get every provider+model from catalog
            // regardless of API key availability (Vault may be down)
            for (provider_id, provider_info) in cat.all_providers() {
                let model_infos: Vec<crate::provider::ModelInfo> = provider_info
                    .models
                    .values()
                    .map(|m| cat.to_model_info(m, provider_id))
                    .collect();
                if !model_infos.is_empty() {
                    tracing::debug!(
                        "Catalog provider {}: {} models",
                        provider_id,
                        model_infos.len()
                    );
                    models_by_provider.insert(provider_id.clone(), model_infos);
                }
            }
            tracing::info!(
                "Loaded {} providers with {} total models from catalog",
                models_by_provider.len(),
                models_by_provider.values().map(|v| v.len()).sum::<usize>()
            );
        }
    }

    Ok(models_by_provider)
}
769
770/// Fallback: build a ProviderRegistry from config file + environment variables
771async fn fallback_registry() -> Result<ProviderRegistry> {
772    let config = crate::config::Config::load().await.unwrap_or_default();
773    ProviderRegistry::from_config(&config).await
774}
775
/// One-shot poll of `GET /v1/opencode/tasks?status=pending`, spawning a
/// detached handler task for every pending task not already in-flight.
///
/// A non-success HTTP status is treated as "nothing pending" and returns
/// `Ok(())`; only transport/deserialization errors propagate.
async fn fetch_pending_tasks(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
    bus: &Arc<AgentBus>,
) -> Result<()> {
    tracing::info!("Checking for pending tasks...");

    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        // Deliberate best-effort: a failed poll is not an error for the caller.
        return Ok(());
    }

    let data: serde_json::Value = res.json().await?;
    // Handle both plain array response and {tasks: [...]} wrapper
    let tasks = if let Some(arr) = data.as_array() {
        arr.clone()
    } else {
        data["tasks"].as_array().cloned().unwrap_or_default()
    };

    tracing::info!("Found {} pending task(s)", tasks.len());

    for task in tasks {
        // Tasks without a string "id" are skipped entirely.
        if let Some(id) = task["id"].as_str() {
            let mut proc = processing.lock().await;
            if !proc.contains(id) {
                // Insert into the in-flight set *before* spawning so a
                // concurrent poll or SSE delivery of the same id is skipped;
                // the guard is dropped early to avoid holding it across awaits.
                proc.insert(id.to_string());
                drop(proc);

                let task_id = id.to_string();
                let client = client.clone();
                let server = server.to_string();
                let token = token.clone();
                let worker_id = worker_id.to_string();
                let auto_approve = *auto_approve;
                let processing = processing.clone();
                let bus = bus.clone();

                tokio::spawn(async move {
                    if let Err(e) = handle_task(
                        &client,
                        &server,
                        &token,
                        &worker_id,
                        &task,
                        auto_approve,
                        &bus,
                    )
                    .await
                    {
                        tracing::error!("Task {} failed: {}", task_id, e);
                    }
                    // Always release the in-flight marker, success or failure.
                    processing.lock().await.remove(&task_id);
                });
            }
        }
    }

    Ok(())
}
845
/// Hold an SSE connection to the server's task stream, dispatching each
/// `data:` event as a task, while also polling for pending tasks every 30s
/// in case the stream missed any.
///
/// Returns `Ok(())` when the server closes the stream and `Err` on a
/// transport failure or non-success connect status; the caller (`run`)
/// reconnects in either case.
#[allow(clippy::too_many_arguments)]
async fn connect_stream(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
    bus: &Arc<AgentBus>,
) -> Result<()> {
    let url = format!(
        "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
        server,
        urlencoding::encode(name),
        urlencoding::encode(worker_id)
    );

    // Identity is sent both in the query string (above) and as headers.
    let mut req = client
        .get(&url)
        .header("Accept", "text/event-stream")
        .header("X-Worker-ID", worker_id)
        .header("X-Agent-Name", name)
        .header("X-Codebases", codebases.join(","));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        anyhow::bail!("Failed to connect: {}", res.status());
    }

    tracing::info!("Connected to A2A server");

    let mut stream = res.bytes_stream();
    // Accumulates partial SSE frames; events are terminated by "\n\n".
    let mut buffer = String::new();
    let mut poll_interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
    poll_interval.tick().await; // consume the initial immediate tick

    loop {
        tokio::select! {
            chunk = stream.next() => {
                match chunk {
                    Some(Ok(chunk)) => {
                        // Lossy decode: a multi-byte UTF-8 char split across
                        // chunks would be mangled.
                        // NOTE(review): confirm payloads are ASCII-safe JSON.
                        buffer.push_str(&String::from_utf8_lossy(&chunk));

                        // Process SSE events
                        while let Some(pos) = buffer.find("\n\n") {
                            let event_str = buffer[..pos].to_string();
                            buffer = buffer[pos + 2..].to_string();

                            if let Some(data_line) = event_str.lines().find(|l| l.starts_with("data:")) {
                                let data = data_line.trim_start_matches("data:").trim();
                                if data == "[DONE]" || data.is_empty() {
                                    continue;
                                }

                                // Unparseable events are silently dropped.
                                if let Ok(task) = serde_json::from_str::<serde_json::Value>(data) {
                                    spawn_task_handler(
                                        &task, client, server, token, worker_id,
                                        processing, auto_approve, bus,
                                    ).await;
                                }
                            }
                        }
                    }
                    Some(Err(e)) => {
                        return Err(e.into());
                    }
                    None => {
                        // Stream ended
                        return Ok(());
                    }
                }
            }
            _ = poll_interval.tick() => {
                // Periodic poll for pending tasks the SSE stream may have missed
                if let Err(e) = poll_pending_tasks(
                    client, server, token, worker_id, processing, auto_approve, bus,
                ).await {
                    tracing::warn!("Periodic task poll failed: {}", e);
                }
            }
        }
    }
}
935
936async fn spawn_task_handler(
937    task: &serde_json::Value,
938    client: &Client,
939    server: &str,
940    token: &Option<String>,
941    worker_id: &str,
942    processing: &Arc<Mutex<HashSet<String>>>,
943    auto_approve: &AutoApprove,
944    bus: &Arc<AgentBus>,
945) {
946    if let Some(id) = task
947        .get("task")
948        .and_then(|t| t["id"].as_str())
949        .or_else(|| task["id"].as_str())
950    {
951        let mut proc = processing.lock().await;
952        if !proc.contains(id) {
953            proc.insert(id.to_string());
954            drop(proc);
955
956            let task_id = id.to_string();
957            let task = task.clone();
958            let client = client.clone();
959            let server = server.to_string();
960            let token = token.clone();
961            let worker_id = worker_id.to_string();
962            let auto_approve = *auto_approve;
963            let processing_clone = processing.clone();
964            let bus = bus.clone();
965
966            tokio::spawn(async move {
967                if let Err(e) = handle_task(
968                    &client,
969                    &server,
970                    &token,
971                    &worker_id,
972                    &task,
973                    auto_approve,
974                    &bus,
975                )
976                .await
977                {
978                    tracing::error!("Task {} failed: {}", task_id, e);
979                }
980                processing_clone.lock().await.remove(&task_id);
981            });
982        }
983    }
984}
985
986async fn poll_pending_tasks(
987    client: &Client,
988    server: &str,
989    token: &Option<String>,
990    worker_id: &str,
991    processing: &Arc<Mutex<HashSet<String>>>,
992    auto_approve: &AutoApprove,
993    bus: &Arc<AgentBus>,
994) -> Result<()> {
995    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
996    if let Some(t) = token {
997        req = req.bearer_auth(t);
998    }
999
1000    let res = req.send().await?;
1001    if !res.status().is_success() {
1002        return Ok(());
1003    }
1004
1005    let data: serde_json::Value = res.json().await?;
1006    let tasks = if let Some(arr) = data.as_array() {
1007        arr.clone()
1008    } else {
1009        data["tasks"].as_array().cloned().unwrap_or_default()
1010    };
1011
1012    if !tasks.is_empty() {
1013        tracing::debug!("Poll found {} pending task(s)", tasks.len());
1014    }
1015
1016    for task in &tasks {
1017        spawn_task_handler(
1018            task,
1019            client,
1020            server,
1021            token,
1022            worker_id,
1023            processing,
1024            auto_approve,
1025            bus,
1026        )
1027        .await;
1028    }
1029
1030    Ok(())
1031}
1032
/// Claim, execute, and release one task end-to-end.
///
/// Flow:
/// 1. POST a claim for the task; a 409 Conflict means another worker won the
///    race and the task is skipped (still returning `Ok`).
/// 2. Pull routing hints out of the task metadata: session to resume,
///    complexity, model tier, worker personality, target agent, and an
///    explicitly requested model.
/// 3. Execute via the swarm executor (for swarm agents) or the standard
///    session loop, streaming incremental output back to the server.
/// 4. Release the task with the final status/result/error and session id.
///
/// Task failure is reported through the release payload, not the return
/// value; `Err` is reserved for transport-level problems.
async fn handle_task(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    task: &serde_json::Value,
    auto_approve: AutoApprove,
    bus: &Arc<AgentBus>,
) -> Result<()> {
    let task_id = task_str(task, "id").ok_or_else(|| anyhow::anyhow!("No task ID"))?;
    let title = task_str(task, "title").unwrap_or("Untitled");

    tracing::info!("Handling task: {} ({})", title, task_id);

    // Claim the task
    let mut req = client
        .post(format!("{}/v1/worker/tasks/claim", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req
        .json(&serde_json::json!({ "task_id": task_id }))
        .send()
        .await?;

    if !res.status().is_success() {
        let status = res.status();
        let text = res.text().await?;
        // 409 is the expected loss of the claim race; anything else is logged
        // louder. Either way the task is skipped, not failed.
        if status == reqwest::StatusCode::CONFLICT {
            tracing::debug!(task_id, "Task already claimed by another worker, skipping");
        } else {
            tracing::warn!(task_id, %status, "Failed to claim task: {}", text);
        }
        return Ok(());
    }

    tracing::info!("Claimed task: {}", task_id);

    // Routing hints supplied by the task's metadata object.
    let metadata = task_metadata(task);
    let resume_session_id = metadata
        .get("resume_session_id")
        .and_then(|v| v.as_str())
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty());
    let complexity_hint = metadata_str(&metadata, &["complexity"]);
    // An explicit tier wins; otherwise derive one from the complexity hint:
    // quick -> fast, deep -> heavy, anything else -> balanced.
    let model_tier = metadata_str(&metadata, &["model_tier", "tier"])
        .map(|s| s.to_ascii_lowercase())
        .or_else(|| {
            complexity_hint.as_ref().map(|complexity| {
                match complexity.to_ascii_lowercase().as_str() {
                    "quick" => "fast".to_string(),
                    "deep" => "heavy".to_string(),
                    _ => "balanced".to_string(),
                }
            })
        });
    let worker_personality = metadata_str(
        &metadata,
        &["worker_personality", "personality", "agent_personality"],
    );
    let target_agent_name = metadata_str(&metadata, &["target_agent_name", "agent_name"]);
    // The model may be set at the task level or in metadata, under either
    // "model_ref" or "model"; first match wins.
    let raw_model = task_str(task, "model_ref")
        .or_else(|| metadata_lookup(&metadata, "model_ref").and_then(|v| v.as_str()))
        .or_else(|| task_str(task, "model"))
        .or_else(|| metadata_lookup(&metadata, "model").and_then(|v| v.as_str()));
    let selected_model = raw_model.map(model_ref_to_provider_model);

    // Resume existing session when requested; fall back to a fresh session if missing.
    let mut session = if let Some(ref sid) = resume_session_id {
        match Session::load(sid).await {
            Ok(existing) => {
                tracing::info!("Resuming session {} for task {}", sid, task_id);
                existing
            }
            Err(e) => {
                tracing::warn!(
                    "Could not load session {} for task {} ({}), starting a new session",
                    sid,
                    task_id,
                    e
                );
                Session::new().await?
            }
        }
    } else {
        Session::new().await?
    };

    // Agent type may arrive as "agent_type" or "agent"; defaults to "build".
    let agent_type = task_str(task, "agent_type")
        .or_else(|| task_str(task, "agent"))
        .unwrap_or("build");
    session.agent = agent_type.to_string();

    if let Some(model) = selected_model.clone() {
        session.metadata.model = Some(model);
    }

    // Prompt precedence: explicit prompt, then description, then the title.
    let prompt = task_str(task, "prompt")
        .or_else(|| task_str(task, "description"))
        .unwrap_or(title);

    tracing::info!("Executing prompt: {}", prompt);

    // Set up output streaming to forward progress to the server
    let stream_client = client.clone();
    let stream_server = server.to_string();
    let stream_token = token.clone();
    let stream_worker_id = worker_id.to_string();
    let stream_task_id = task_id.to_string();

    let output_callback = move |output: String| {
        let c = stream_client.clone();
        let s = stream_server.clone();
        let t = stream_token.clone();
        let w = stream_worker_id.clone();
        let tid = stream_task_id.clone();
        tokio::spawn(async move {
            let mut req = c
                .post(format!("{}/v1/opencode/tasks/{}/output", s, tid))
                .header("X-Worker-ID", &w);
            if let Some(tok) = &t {
                req = req.bearer_auth(tok);
            }
            // Fire-and-forget: progress posts are best-effort and errors
            // are deliberately ignored.
            let _ = req
                .json(&serde_json::json!({
                    "worker_id": w,
                    "output": output,
                }))
                .send()
                .await;
        });
    };

    // Execute swarm tasks via SwarmExecutor; all other agents use the standard session loop.
    let (status, result, error, session_id) = if is_swarm_agent(agent_type) {
        match execute_swarm_with_policy(
            &mut session,
            prompt,
            model_tier.as_deref(),
            selected_model,
            &metadata,
            complexity_hint.as_deref(),
            worker_personality.as_deref(),
            target_agent_name.as_deref(),
            Some(bus),
            Some(&output_callback),
        )
        .await
        {
            // The bool distinguishes "swarm finished" from "swarm succeeded".
            Ok((session_result, true)) => {
                tracing::info!("Swarm task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Ok((session_result, false)) => {
                tracing::warn!("Swarm task completed with failures: {}", task_id);
                (
                    "failed",
                    Some(session_result.text),
                    Some("Swarm execution completed with failures".to_string()),
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Swarm task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    } else {
        match execute_session_with_policy(
            &mut session,
            prompt,
            auto_approve,
            model_tier.as_deref(),
            Some(&output_callback),
        )
        .await
        {
            Ok(session_result) => {
                tracing::info!("Task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    };

    // Release the task with full details
    let mut req = client
        .post(format!("{}/v1/worker/tasks/release", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    req.json(&serde_json::json!({
        "task_id": task_id,
        "status": status,
        "result": result,
        "error": error,
        // When execution produced no session id (error path), report the
        // session we created/resumed so the server can still link it.
        "session_id": session_id.unwrap_or_else(|| session.id.clone()),
    }))
    .send()
    .await?;

    tracing::info!("Task released: {} with status: {}", task_id, status);
    Ok(())
}
1254
/// Run a prompt through the `SwarmExecutor` and record it in the session.
///
/// Swarm tuning (decomposition strategy, subagent/step limits, timeout,
/// parallelism) is read from task metadata with clamped defaults. When an
/// `output_callback` is supplied, the executor runs on a spawned task so
/// swarm events can be forwarded as progress lines concurrently. The user
/// prompt and the swarm's final text are appended to the session, which is
/// then saved.
///
/// Returns the session result together with a flag indicating whether the
/// swarm reported overall success.
async fn execute_swarm_with_policy<F>(
    session: &mut Session,
    prompt: &str,
    model_tier: Option<&str>,
    explicit_model: Option<String>,
    metadata: &serde_json::Map<String, serde_json::Value>,
    complexity_hint: Option<&str>,
    worker_personality: Option<&str>,
    target_agent_name: Option<&str>,
    bus: Option<&Arc<AgentBus>>,
    output_callback: Option<&F>,
) -> Result<(crate::session::SessionResult, bool)>
where
    F: Fn(String),
{
    use crate::provider::{ContentPart, Message, Role};

    // Record the user prompt in the session before execution starts.
    session.add_message(Message {
        role: Role::User,
        content: vec![ContentPart::Text {
            text: prompt.to_string(),
        }],
    });

    if session.title.is_none() {
        session.generate_title().await?;
    }

    // Metadata knobs, each accepted under several aliases and clamped to a
    // sane range.
    let strategy = parse_swarm_strategy(metadata);
    let max_subagents = metadata_usize(
        metadata,
        &["swarm_max_subagents", "max_subagents", "subagents"],
    )
    .unwrap_or(10)
    .clamp(1, 100);
    let max_steps_per_subagent = metadata_usize(
        metadata,
        &[
            "swarm_max_steps_per_subagent",
            "max_steps_per_subagent",
            "max_steps",
        ],
    )
    .unwrap_or(50)
    .clamp(1, 200);
    let timeout_secs = metadata_u64(metadata, &["swarm_timeout_secs", "timeout_secs", "timeout"])
        .unwrap_or(600)
        .clamp(30, 3600);
    let parallel_enabled =
        metadata_bool(metadata, &["swarm_parallel_enabled", "parallel_enabled"]).unwrap_or(true);

    // Resolve the model from the explicit selection and/or tier, and mirror
    // the choice back onto the session metadata.
    let model = resolve_swarm_model(explicit_model, model_tier).await;
    if let Some(ref selected_model) = model {
        session.metadata.model = Some(selected_model.clone());
    }

    // Emit the effective routing/config as progress lines for observability.
    if let Some(cb) = output_callback {
        cb(format!(
            "[swarm] routing complexity={} tier={} personality={} target_agent={}",
            complexity_hint.unwrap_or("standard"),
            model_tier.unwrap_or("balanced"),
            worker_personality.unwrap_or("auto"),
            target_agent_name.unwrap_or("auto")
        ));
        cb(format!(
            "[swarm] config strategy={:?} max_subagents={} max_steps={} timeout={}s tier={}",
            strategy,
            max_subagents,
            max_steps_per_subagent,
            timeout_secs,
            model_tier.unwrap_or("balanced")
        ));
    }

    let swarm_config = SwarmConfig {
        max_subagents,
        max_steps_per_subagent,
        subagent_timeout_secs: timeout_secs,
        parallel_enabled,
        model,
        working_dir: session
            .metadata
            .directory
            .as_ref()
            .map(|p| p.to_string_lossy().to_string()),
        ..Default::default()
    };

    let swarm_result = if output_callback.is_some() {
        // With a callback, run the executor on its own task and forward swarm
        // events to the callback while it runs.
        let (event_tx, mut event_rx) = mpsc::channel(256);
        let mut executor = SwarmExecutor::new(swarm_config).with_event_tx(event_tx);
        if let Some(bus) = bus {
            executor = executor.with_bus(Arc::clone(bus));
        }
        let prompt_owned = prompt.to_string();
        let mut exec_handle =
            tokio::spawn(async move { executor.execute(&prompt_owned, strategy).await });

        let mut final_result: Option<crate::swarm::SwarmResult> = None;

        // Pump events until the executor task completes.
        while final_result.is_none() {
            tokio::select! {
                maybe_event = event_rx.recv() => {
                    if let Some(event) = maybe_event {
                        if let Some(cb) = output_callback {
                            if let Some(line) = format_swarm_event_for_output(&event) {
                                cb(line);
                            }
                        }
                    }
                }
                join_result = &mut exec_handle => {
                    let joined = join_result.map_err(|e| anyhow::anyhow!("Swarm join failure: {}", e))?;
                    final_result = Some(joined?);
                }
            }
        }

        // Drain events that arrived between executor completion and loop exit.
        while let Ok(event) = event_rx.try_recv() {
            if let Some(cb) = output_callback {
                if let Some(line) = format_swarm_event_for_output(&event) {
                    cb(line);
                }
            }
        }

        final_result.ok_or_else(|| anyhow::anyhow!("Swarm execution returned no result"))?
    } else {
        // No callback: run the executor inline.
        let mut executor = SwarmExecutor::new(swarm_config);
        if let Some(bus) = bus {
            executor = executor.with_bus(Arc::clone(bus));
        }
        executor.execute(prompt, strategy).await?
    };

    // Guarantee a non-empty result message for the caller.
    let final_text = if swarm_result.result.trim().is_empty() {
        if swarm_result.success {
            "Swarm completed without textual output.".to_string()
        } else {
            "Swarm finished with failures and no textual output.".to_string()
        }
    } else {
        swarm_result.result.clone()
    };

    session.add_message(Message {
        role: Role::Assistant,
        content: vec![ContentPart::Text {
            text: final_text.clone(),
        }],
    });
    session.save().await?;

    Ok((
        crate::session::SessionResult {
            text: final_text,
            session_id: session.id.clone(),
        },
        swarm_result.success,
    ))
}
1416
1417/// Execute a session with the given auto-approve policy
1418/// Optionally streams output chunks via the callback
1419async fn execute_session_with_policy<F>(
1420    session: &mut Session,
1421    prompt: &str,
1422    auto_approve: AutoApprove,
1423    model_tier: Option<&str>,
1424    output_callback: Option<&F>,
1425) -> Result<crate::session::SessionResult>
1426where
1427    F: Fn(String),
1428{
1429    use crate::provider::{
1430        CompletionRequest, ContentPart, Message, ProviderRegistry, Role, parse_model_string,
1431    };
1432    use std::sync::Arc;
1433
1434    // Load provider registry from Vault
1435    let registry = ProviderRegistry::from_vault().await?;
1436    let providers = registry.list();
1437    tracing::info!("Available providers: {:?}", providers);
1438
1439    if providers.is_empty() {
1440        anyhow::bail!("No providers available. Configure API keys in HashiCorp Vault.");
1441    }
1442
1443    // Parse model string
1444    let (provider_name, model_id) = if let Some(ref model_str) = session.metadata.model {
1445        let (prov, model) = parse_model_string(model_str);
1446        let prov = prov.map(|p| if p == "zhipuai" { "zai" } else { p });
1447        if prov.is_some() {
1448            (prov.map(|s| s.to_string()), model.to_string())
1449        } else if providers.contains(&model) {
1450            (Some(model.to_string()), String::new())
1451        } else {
1452            (None, model.to_string())
1453        }
1454    } else {
1455        (None, String::new())
1456    };
1457
1458    let provider_slice = providers.as_slice();
1459    let provider_requested_but_unavailable = provider_name
1460        .as_deref()
1461        .map(|p| !providers.contains(&p))
1462        .unwrap_or(false);
1463
1464    // Determine which provider to use, preferring explicit request first, then model tier.
1465    let selected_provider = provider_name
1466        .as_deref()
1467        .filter(|p| providers.contains(p))
1468        .unwrap_or_else(|| choose_provider_for_tier(provider_slice, model_tier));
1469
1470    let provider = registry
1471        .get(selected_provider)
1472        .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;
1473
1474    // Add user message
1475    session.add_message(Message {
1476        role: Role::User,
1477        content: vec![ContentPart::Text {
1478            text: prompt.to_string(),
1479        }],
1480    });
1481
1482    // Generate title
1483    if session.title.is_none() {
1484        session.generate_title().await?;
1485    }
1486
1487    // Determine model. If a specific provider was requested but not available,
1488    // ignore that model id and fall back to the tier-based default model.
1489    let model = if !model_id.is_empty() && !provider_requested_but_unavailable {
1490        model_id
1491    } else {
1492        default_model_for_provider(selected_provider, model_tier)
1493    };
1494
1495    // Create tool registry with filtering based on auto-approve policy
1496    let tool_registry =
1497        create_filtered_registry(Arc::clone(&provider), model.clone(), auto_approve);
1498    let tool_definitions = tool_registry.definitions();
1499
1500    let temperature = if prefers_temperature_one(&model) {
1501        Some(1.0)
1502    } else {
1503        Some(0.7)
1504    };
1505
1506    tracing::info!(
1507        "Using model: {} via provider: {} (tier: {:?})",
1508        model,
1509        selected_provider,
1510        model_tier
1511    );
1512    tracing::info!(
1513        "Available tools: {} (auto_approve: {:?})",
1514        tool_definitions.len(),
1515        auto_approve
1516    );
1517
1518    // Build system prompt
1519    let cwd = std::env::var("PWD")
1520        .map(std::path::PathBuf::from)
1521        .unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
1522    let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);
1523
1524    let mut final_output = String::new();
1525    let max_steps = 50;
1526
1527    for step in 1..=max_steps {
1528        tracing::info!(step = step, "Agent step starting");
1529
1530        // Build messages with system prompt first
1531        let mut messages = vec![Message {
1532            role: Role::System,
1533            content: vec![ContentPart::Text {
1534                text: system_prompt.clone(),
1535            }],
1536        }];
1537        messages.extend(session.messages.clone());
1538
1539        let request = CompletionRequest {
1540            messages,
1541            tools: tool_definitions.clone(),
1542            model: model.clone(),
1543            temperature,
1544            top_p: None,
1545            max_tokens: Some(8192),
1546            stop: Vec::new(),
1547        };
1548
1549        let response = provider.complete(request).await?;
1550
1551        crate::telemetry::TOKEN_USAGE.record_model_usage(
1552            &model,
1553            response.usage.prompt_tokens as u64,
1554            response.usage.completion_tokens as u64,
1555        );
1556
1557        // Extract tool calls
1558        let tool_calls: Vec<(String, String, serde_json::Value)> = response
1559            .message
1560            .content
1561            .iter()
1562            .filter_map(|part| {
1563                if let ContentPart::ToolCall {
1564                    id,
1565                    name,
1566                    arguments,
1567                } = part
1568                {
1569                    let args: serde_json::Value =
1570                        serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
1571                    Some((id.clone(), name.clone(), args))
1572                } else {
1573                    None
1574                }
1575            })
1576            .collect();
1577
1578        // Collect text output and stream it
1579        for part in &response.message.content {
1580            if let ContentPart::Text { text } = part {
1581                if !text.is_empty() {
1582                    final_output.push_str(text);
1583                    final_output.push('\n');
1584                    if let Some(cb) = output_callback {
1585                        cb(text.clone());
1586                    }
1587                }
1588            }
1589        }
1590
1591        // If no tool calls, we're done
1592        if tool_calls.is_empty() {
1593            session.add_message(response.message.clone());
1594            break;
1595        }
1596
1597        session.add_message(response.message.clone());
1598
1599        tracing::info!(
1600            step = step,
1601            num_tools = tool_calls.len(),
1602            "Executing tool calls"
1603        );
1604
1605        // Execute each tool call
1606        for (tool_id, tool_name, tool_input) in tool_calls {
1607            tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");
1608
1609            // Stream tool start event
1610            if let Some(cb) = output_callback {
1611                cb(format!("[tool:start:{}]", tool_name));
1612            }
1613
1614            // Check if tool is allowed based on auto-approve policy
1615            if !is_tool_allowed(&tool_name, auto_approve) {
1616                let msg = format!(
1617                    "Tool '{}' requires approval but auto-approve policy is {:?}",
1618                    tool_name, auto_approve
1619                );
1620                tracing::warn!(tool = %tool_name, "Tool blocked by auto-approve policy");
1621                session.add_message(Message {
1622                    role: Role::Tool,
1623                    content: vec![ContentPart::ToolResult {
1624                        tool_call_id: tool_id,
1625                        content: msg,
1626                    }],
1627                });
1628                continue;
1629            }
1630
1631            let content = if let Some(tool) = tool_registry.get(&tool_name) {
1632                let exec_result: Result<crate::tool::ToolResult> =
1633                    tool.execute(tool_input.clone()).await;
1634                match exec_result {
1635                    Ok(result) => {
1636                        tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
1637                        if let Some(cb) = output_callback {
1638                            let status = if result.success { "ok" } else { "err" };
1639                            cb(format!(
1640                                "[tool:{}:{}] {}",
1641                                tool_name,
1642                                status,
1643                                &result.output[..result.output.len().min(500)]
1644                            ));
1645                        }
1646                        result.output
1647                    }
1648                    Err(e) => {
1649                        tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
1650                        if let Some(cb) = output_callback {
1651                            cb(format!("[tool:{}:err] {}", tool_name, e));
1652                        }
1653                        format!("Error: {}", e)
1654                    }
1655                }
1656            } else {
1657                tracing::warn!(tool = %tool_name, "Tool not found");
1658                format!("Error: Unknown tool '{}'", tool_name)
1659            };
1660
1661            session.add_message(Message {
1662                role: Role::Tool,
1663                content: vec![ContentPart::ToolResult {
1664                    tool_call_id: tool_id,
1665                    content,
1666                }],
1667            });
1668        }
1669    }
1670
1671    session.save().await?;
1672
1673    Ok(crate::session::SessionResult {
1674        text: final_output.trim().to_string(),
1675        session_id: session.id.clone(),
1676    })
1677}
1678
1679/// Check if a tool is allowed based on the auto-approve policy
1680fn is_tool_allowed(tool_name: &str, auto_approve: AutoApprove) -> bool {
1681    match auto_approve {
1682        AutoApprove::All => true,
1683        AutoApprove::Safe | AutoApprove::None => is_safe_tool(tool_name),
1684    }
1685}
1686
/// Check whether a tool is considered "safe", i.e. read-only with respect
/// to the workspace.
fn is_safe_tool(tool_name: &str) -> bool {
    matches!(
        tool_name,
        "read"
            | "list"
            | "glob"
            | "grep"
            | "codesearch"
            | "lsp"
            | "webfetch"
            | "websearch"
            | "todo_read"
            | "skill"
    )
}
1703
/// Create a filtered tool registry based on the auto-approve policy
///
/// Read-only tools are registered unconditionally; tools that can modify the
/// workspace or run commands are registered only when the policy is
/// `AutoApprove::All`. The `provider` and `model` are consumed by the one
/// tool that needs them (`ralph`).
fn create_filtered_registry(
    provider: Arc<dyn crate::provider::Provider>,
    model: String,
    auto_approve: AutoApprove,
) -> crate::tool::ToolRegistry {
    use crate::tool::*;

    let mut registry = ToolRegistry::new();

    // Always add safe tools
    registry.register(Arc::new(file::ReadTool::new()));
    registry.register(Arc::new(file::ListTool::new()));
    registry.register(Arc::new(file::GlobTool::new()));
    registry.register(Arc::new(search::GrepTool::new()));
    registry.register(Arc::new(lsp::LspTool::new()));
    registry.register(Arc::new(webfetch::WebFetchTool::new()));
    registry.register(Arc::new(websearch::WebSearchTool::new()));
    registry.register(Arc::new(codesearch::CodeSearchTool::new()));
    registry.register(Arc::new(todo::TodoReadTool::new()));
    registry.register(Arc::new(skill::SkillTool::new()));

    // Add potentially dangerous tools only if auto_approve is All
    if matches!(auto_approve, AutoApprove::All) {
        registry.register(Arc::new(file::WriteTool::new()));
        registry.register(Arc::new(edit::EditTool::new()));
        registry.register(Arc::new(bash::BashTool::new()));
        registry.register(Arc::new(multiedit::MultiEditTool::new()));
        registry.register(Arc::new(patch::ApplyPatchTool::new()));
        registry.register(Arc::new(todo::TodoWriteTool::new()));
        registry.register(Arc::new(task::TaskTool::new()));
        registry.register(Arc::new(plan::PlanEnterTool::new()));
        registry.register(Arc::new(plan::PlanExitTool::new()));
        registry.register(Arc::new(rlm::RlmTool::new()));
        // The only tool constructed with the provider/model pair.
        registry.register(Arc::new(ralph::RalphTool::with_provider(provider, model)));
        registry.register(Arc::new(prd::PrdTool::new()));
        registry.register(Arc::new(confirm_edit::ConfirmEditTool::new()));
        registry.register(Arc::new(confirm_multiedit::ConfirmMultiEditTool::new()));
        registry.register(Arc::new(undo::UndoTool));
        registry.register(Arc::new(mcp_bridge::McpBridgeTool::new()));
    }

    // NOTE(review): registered under every policy — presumably a fallback for
    // malformed/unknown tool calls; confirm against the invalid module.
    registry.register(Arc::new(invalid::InvalidTool::new()));

    registry
}
1750
1751/// Start the heartbeat background task
1752/// Returns a JoinHandle that can be used to cancel the heartbeat
1753fn start_heartbeat(
1754    client: Client,
1755    server: String,
1756    token: Option<String>,
1757    heartbeat_state: HeartbeatState,
1758    processing: Arc<Mutex<HashSet<String>>>,
1759    cognition_config: CognitionHeartbeatConfig,
1760) -> JoinHandle<()> {
1761    tokio::spawn(async move {
1762        let mut consecutive_failures = 0u32;
1763        const MAX_FAILURES: u32 = 3;
1764        const HEARTBEAT_INTERVAL_SECS: u64 = 30;
1765        const COGNITION_RETRY_COOLDOWN_SECS: u64 = 300;
1766        let mut cognition_payload_disabled_until: Option<Instant> = None;
1767
1768        let mut interval =
1769            tokio::time::interval(tokio::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
1770        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
1771
1772        loop {
1773            interval.tick().await;
1774
1775            // Update task count from processing set
1776            let active_count = processing.lock().await.len();
1777            heartbeat_state.set_task_count(active_count).await;
1778
1779            // Determine status based on active tasks
1780            let status = if active_count > 0 {
1781                WorkerStatus::Processing
1782            } else {
1783                WorkerStatus::Idle
1784            };
1785            heartbeat_state.set_status(status).await;
1786
1787            // Send heartbeat
1788            let url = format!(
1789                "{}/v1/opencode/workers/{}/heartbeat",
1790                server, heartbeat_state.worker_id
1791            );
1792            let mut req = client.post(&url);
1793
1794            if let Some(ref t) = token {
1795                req = req.bearer_auth(t);
1796            }
1797
1798            let status_str = heartbeat_state.status.lock().await.as_str().to_string();
1799            let base_payload = serde_json::json!({
1800                "worker_id": &heartbeat_state.worker_id,
1801                "agent_name": &heartbeat_state.agent_name,
1802                "status": status_str,
1803                "active_task_count": active_count,
1804            });
1805            let mut payload = base_payload.clone();
1806            let mut included_cognition_payload = false;
1807            let cognition_payload_allowed = cognition_payload_disabled_until
1808                .map(|until| Instant::now() >= until)
1809                .unwrap_or(true);
1810
1811            if cognition_config.enabled
1812                && cognition_payload_allowed
1813                && let Some(cognition_payload) =
1814                    fetch_cognition_heartbeat_payload(&client, &cognition_config).await
1815                && let Some(obj) = payload.as_object_mut()
1816            {
1817                obj.insert("cognition".to_string(), cognition_payload);
1818                included_cognition_payload = true;
1819            }
1820
1821            match req.json(&payload).send().await {
1822                Ok(res) => {
1823                    if res.status().is_success() {
1824                        consecutive_failures = 0;
1825                        tracing::debug!(
1826                            worker_id = %heartbeat_state.worker_id,
1827                            status = status_str,
1828                            active_tasks = active_count,
1829                            "Heartbeat sent successfully"
1830                        );
1831                    } else if included_cognition_payload && res.status().is_client_error() {
1832                        tracing::warn!(
1833                            worker_id = %heartbeat_state.worker_id,
1834                            status = %res.status(),
1835                            "Heartbeat cognition payload rejected, retrying without cognition payload"
1836                        );
1837
1838                        let mut retry_req = client.post(&url);
1839                        if let Some(ref t) = token {
1840                            retry_req = retry_req.bearer_auth(t);
1841                        }
1842
1843                        match retry_req.json(&base_payload).send().await {
1844                            Ok(retry_res) if retry_res.status().is_success() => {
1845                                cognition_payload_disabled_until = Some(
1846                                    Instant::now()
1847                                        + Duration::from_secs(COGNITION_RETRY_COOLDOWN_SECS),
1848                                );
1849                                consecutive_failures = 0;
1850                                tracing::warn!(
1851                                    worker_id = %heartbeat_state.worker_id,
1852                                    retry_after_secs = COGNITION_RETRY_COOLDOWN_SECS,
1853                                    "Paused cognition heartbeat payload after schema rejection"
1854                                );
1855                            }
1856                            Ok(retry_res) => {
1857                                consecutive_failures += 1;
1858                                tracing::warn!(
1859                                    worker_id = %heartbeat_state.worker_id,
1860                                    status = %retry_res.status(),
1861                                    failures = consecutive_failures,
1862                                    "Heartbeat failed even after retry without cognition payload"
1863                                );
1864                            }
1865                            Err(e) => {
1866                                consecutive_failures += 1;
1867                                tracing::warn!(
1868                                    worker_id = %heartbeat_state.worker_id,
1869                                    error = %e,
1870                                    failures = consecutive_failures,
1871                                    "Heartbeat retry without cognition payload failed"
1872                                );
1873                            }
1874                        }
1875                    } else {
1876                        consecutive_failures += 1;
1877                        tracing::warn!(
1878                            worker_id = %heartbeat_state.worker_id,
1879                            status = %res.status(),
1880                            failures = consecutive_failures,
1881                            "Heartbeat failed"
1882                        );
1883                    }
1884                }
1885                Err(e) => {
1886                    consecutive_failures += 1;
1887                    tracing::warn!(
1888                        worker_id = %heartbeat_state.worker_id,
1889                        error = %e,
1890                        failures = consecutive_failures,
1891                        "Heartbeat request failed"
1892                    );
1893                }
1894            }
1895
1896            // Log error after 3 consecutive failures but do not terminate
1897            if consecutive_failures >= MAX_FAILURES {
1898                tracing::error!(
1899                    worker_id = %heartbeat_state.worker_id,
1900                    failures = consecutive_failures,
1901                    "Heartbeat failed {} consecutive times - worker will continue running and attempt reconnection via SSE loop",
1902                    MAX_FAILURES
1903                );
1904                // Reset counter to avoid spamming error logs
1905                consecutive_failures = 0;
1906            }
1907        }
1908    })
1909}
1910
1911async fn fetch_cognition_heartbeat_payload(
1912    client: &Client,
1913    config: &CognitionHeartbeatConfig,
1914) -> Option<serde_json::Value> {
1915    let status_url = format!("{}/v1/cognition/status", config.source_base_url);
1916    let status_res = tokio::time::timeout(
1917        Duration::from_millis(config.request_timeout_ms),
1918        client.get(status_url).send(),
1919    )
1920    .await
1921    .ok()?
1922    .ok()?;
1923
1924    if !status_res.status().is_success() {
1925        return None;
1926    }
1927
1928    let status: CognitionStatusSnapshot = status_res.json().await.ok()?;
1929    let mut payload = serde_json::json!({
1930        "running": status.running,
1931        "last_tick_at": status.last_tick_at,
1932        "active_persona_count": status.active_persona_count,
1933        "events_buffered": status.events_buffered,
1934        "snapshots_buffered": status.snapshots_buffered,
1935        "loop_interval_ms": status.loop_interval_ms,
1936    });
1937
1938    if config.include_thought_summary {
1939        let snapshot_url = format!("{}/v1/cognition/snapshots/latest", config.source_base_url);
1940        let snapshot_res = tokio::time::timeout(
1941            Duration::from_millis(config.request_timeout_ms),
1942            client.get(snapshot_url).send(),
1943        )
1944        .await
1945        .ok()
1946        .and_then(Result::ok);
1947
1948        if let Some(snapshot_res) = snapshot_res
1949            && snapshot_res.status().is_success()
1950            && let Ok(snapshot) = snapshot_res.json::<CognitionLatestSnapshot>().await
1951            && let Some(obj) = payload.as_object_mut()
1952        {
1953            obj.insert(
1954                "latest_snapshot_at".to_string(),
1955                serde_json::Value::String(snapshot.generated_at),
1956            );
1957            obj.insert(
1958                "latest_thought".to_string(),
1959                serde_json::Value::String(trim_for_heartbeat(
1960                    &snapshot.summary,
1961                    config.summary_max_chars,
1962                )),
1963            );
1964            if let Some(model) = snapshot
1965                .metadata
1966                .get("model")
1967                .and_then(serde_json::Value::as_str)
1968            {
1969                obj.insert(
1970                    "latest_thought_model".to_string(),
1971                    serde_json::Value::String(model.to_string()),
1972                );
1973            }
1974            if let Some(source) = snapshot
1975                .metadata
1976                .get("source")
1977                .and_then(serde_json::Value::as_str)
1978            {
1979                obj.insert(
1980                    "latest_thought_source".to_string(),
1981                    serde_json::Value::String(source.to_string()),
1982                );
1983            }
1984        }
1985    }
1986
1987    Some(payload)
1988}
1989
/// Trim a summary string for inclusion in a heartbeat payload.
///
/// Whitespace is stripped *before* counting, so padding does not consume
/// the character budget (the original counted the raw input, letting
/// leading/trailing whitespace displace real content). If the trimmed text
/// still exceeds `max_chars` characters it is truncated on a char boundary
/// and suffixed with `...`, with any whitespace at the cut removed so the
/// ellipsis attaches to real text.
fn trim_for_heartbeat(input: &str, max_chars: usize) -> String {
    let trimmed = input.trim();
    if trimmed.chars().count() <= max_chars {
        return trimmed.to_string();
    }

    // Truncate by chars (not bytes) to stay on UTF-8 boundaries.
    let truncated: String = trimmed.chars().take(max_chars).collect();
    format!("{}...", truncated.trim_end())
}
2002
/// Read a boolean flag from the environment variable `name`.
///
/// Accepts (case-insensitively) `1`/`true`/`yes`/`on` as true and
/// `0`/`false`/`no`/`off` as false; an unset variable or any other value
/// falls back to `default`.
fn env_bool(name: &str, default: bool) -> bool {
    match std::env::var(name) {
        Ok(raw) => match raw.to_ascii_lowercase().as_str() {
            "1" | "true" | "yes" | "on" => true,
            "0" | "false" | "no" | "off" => false,
            // Unrecognized spellings are ignored rather than guessed at.
            _ => default,
        },
        Err(_) => default,
    }
}
2013
/// Read a `usize` from the environment variable `name`, falling back to
/// `default` when the variable is unset or does not parse.
fn env_usize(name: &str, default: usize) -> usize {
    // Outer Ok = variable present, inner Ok = valid number.
    match std::env::var(name).map(|raw| raw.parse::<usize>()) {
        Ok(Ok(value)) => value,
        _ => default,
    }
}
2020
/// Read a `u64` from the environment variable `name`, falling back to
/// `default` when the variable is unset or does not parse.
fn env_u64(name: &str, default: u64) -> u64 {
    // Outer Ok = variable present, inner Ok = valid number.
    match std::env::var(name).map(|raw| raw.parse::<u64>()) {
        Ok(Ok(value)) => value,
        _ => default,
    }
}