// codetether_agent/a2a/worker.rs
1//! A2A Worker - connects to an A2A server to process tasks
2
3use crate::bus::AgentBus;
4use crate::cli::A2aArgs;
5use crate::provider::ProviderRegistry;
6use crate::session::Session;
7use crate::swarm::{DecompositionStrategy, SwarmConfig, SwarmExecutor};
8use crate::tui::swarm_view::SwarmEvent;
9use anyhow::Result;
10use futures::StreamExt;
11use reqwest::Client;
12use serde::Deserialize;
13use std::collections::HashMap;
14use std::collections::HashSet;
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::sync::{Mutex, mpsc};
18use tokio::task::JoinHandle;
19use tokio::time::Instant;
20
/// Worker status reported to the server in heartbeats.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkerStatus {
    Idle,
    Processing,
}

impl WorkerStatus {
    /// Wire representation of the status ("idle" / "processing").
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Idle => "idle",
            Self::Processing => "processing",
        }
    }
}
36
/// Heartbeat state shared between the heartbeat task and the main worker
#[derive(Clone)]
pub struct HeartbeatState {
    // Stable worker id assigned at startup; private — presumably included in
    // heartbeat payloads by start_heartbeat (not visible here) — TODO confirm.
    worker_id: String,
    /// Human-readable worker name (also exposed to the HTTP server via Arc in run_with_state).
    pub agent_name: String,
    /// Current idle/processing status; behind Arc<Mutex> so clones observe updates.
    pub status: Arc<Mutex<WorkerStatus>>,
    /// Number of tasks currently in flight, shared the same way as `status`.
    pub active_task_count: Arc<Mutex<usize>>,
}
45
46impl HeartbeatState {
47    pub fn new(worker_id: String, agent_name: String) -> Self {
48        Self {
49            worker_id,
50            agent_name,
51            status: Arc::new(Mutex::new(WorkerStatus::Idle)),
52            active_task_count: Arc::new(Mutex::new(0)),
53        }
54    }
55
56    pub async fn set_status(&self, status: WorkerStatus) {
57        *self.status.lock().await = status;
58    }
59
60    pub async fn set_task_count(&self, count: usize) {
61        *self.active_task_count.lock().await = count;
62    }
63}
64
/// Configuration for sharing local cognition state in heartbeats.
/// Populated from CODETETHER_WORKER_COGNITION_* env vars by `from_env`.
#[derive(Clone, Debug)]
struct CognitionHeartbeatConfig {
    // Whether cognition sharing is on at all (default: true).
    enabled: bool,
    // Base URL of the local cognition service, trailing '/' stripped
    // (default: http://127.0.0.1:4096).
    source_base_url: String,
    // Include a thought summary in heartbeats (default: true).
    include_thought_summary: bool,
    // Max characters of thought summary to send (default 480, floor 120).
    summary_max_chars: usize,
    // HTTP timeout for cognition requests in ms (default 2500, floor 250).
    request_timeout_ms: u64,
}
73
74impl CognitionHeartbeatConfig {
75    fn from_env() -> Self {
76        let source_base_url = std::env::var("CODETETHER_WORKER_COGNITION_SOURCE_URL")
77            .unwrap_or_else(|_| "http://127.0.0.1:4096".to_string())
78            .trim_end_matches('/')
79            .to_string();
80
81        Self {
82            enabled: env_bool("CODETETHER_WORKER_COGNITION_SHARE_ENABLED", true),
83            source_base_url,
84            include_thought_summary: env_bool("CODETETHER_WORKER_COGNITION_INCLUDE_THOUGHTS", true),
85            summary_max_chars: env_usize("CODETETHER_WORKER_COGNITION_THOUGHT_MAX_CHARS", 480)
86                .max(120),
87            request_timeout_ms: env_u64("CODETETHER_WORKER_COGNITION_TIMEOUT_MS", 2_500).max(250),
88        }
89    }
90}
91
/// Deserialization target for a cognition-service status response.
/// NOTE(review): field semantics inferred from names only — the endpoint that
/// produces this is not visible in this chunk; confirm against the cognition
/// service / start_heartbeat.
#[derive(Debug, Deserialize)]
struct CognitionStatusSnapshot {
    running: bool,
    #[serde(default)]
    last_tick_at: Option<String>,
    #[serde(default)]
    active_persona_count: usize,
    #[serde(default)]
    events_buffered: usize,
    #[serde(default)]
    snapshots_buffered: usize,
    #[serde(default)]
    loop_interval_ms: u64,
}
106
/// Deserialization target for the most recent cognition snapshot.
/// `generated_at` and `summary` are required; `metadata` defaults to empty.
/// NOTE(review): producing endpoint not visible here — confirm in start_heartbeat.
#[derive(Debug, Deserialize)]
struct CognitionLatestSnapshot {
    generated_at: String,
    summary: String,
    #[serde(default)]
    metadata: HashMap<String, serde_json::Value>,
}
114
/// Run the A2A worker standalone: register with the server, drain pending
/// tasks once, then hold an SSE stream inside an endless reconnect loop.
///
/// Returns `Err` only if the initial registration or the initial
/// pending-task fetch fails; once the loop is entered it never exits.
pub async fn run(args: A2aArgs) -> Result<()> {
    let server = args.server.trim_end_matches('/');
    // Worker display name defaults to "codetether-<pid>".
    let name = args
        .name
        .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
    let worker_id = generate_worker_id();

    // Comma-separated --codebases list, or the current directory by default.
    // NOTE(review): current_dir().unwrap() panics if the cwd is unreadable — confirm acceptable.
    let codebases: Vec<String> = args
        .codebases
        .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
        .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);

    tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
    tracing::info!("Server: {}", server);
    tracing::info!("Codebases: {:?}", codebases);

    let client = Client::new();
    // IDs of tasks currently being processed (guards against double-claiming).
    let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
    let cognition_heartbeat = CognitionHeartbeatConfig::from_env();
    if cognition_heartbeat.enabled {
        tracing::info!(
            source = %cognition_heartbeat.source_base_url,
            include_thoughts = cognition_heartbeat.include_thought_summary,
            max_chars = cognition_heartbeat.summary_max_chars,
            timeout_ms = cognition_heartbeat.request_timeout_ms,
            "Cognition heartbeat sharing enabled (set CODETETHER_WORKER_COGNITION_SHARE_ENABLED=false to disable)"
        );
    } else {
        tracing::warn!(
            "Cognition heartbeat sharing disabled; worker thought state will not be shared upstream"
        );
    }

    // Anything other than "all"/"safe" (including the default) means no auto-approval.
    let auto_approve = match args.auto_approve.as_str() {
        "all" => AutoApprove::All,
        "safe" => AutoApprove::Safe,
        _ => AutoApprove::None,
    };

    // Create heartbeat state
    let heartbeat_state = HeartbeatState::new(worker_id.clone(), name.clone());

    // Create agent bus for in-process sub-agent communication
    let bus = AgentBus::new().into_arc();
    {
        let handle = bus.handle(&worker_id);
        handle.announce_ready(WORKER_CAPABILITIES.iter().map(|s| s.to_string()).collect());
    }

    // Register worker
    register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;

    // Fetch pending tasks
    fetch_pending_tasks(
        &client,
        server,
        &args.token,
        &worker_id,
        &processing,
        &auto_approve,
        &bus,
    )
    .await?;

    // Connect to SSE stream
    loop {
        // Re-register worker on each reconnection to report updated models/capabilities
        if let Err(e) =
            register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await
        {
            tracing::warn!("Failed to re-register worker on reconnection: {}", e);
        }

        // Start heartbeat task for this connection
        let heartbeat_handle = start_heartbeat(
            client.clone(),
            server.to_string(),
            args.token.clone(),
            heartbeat_state.clone(),
            processing.clone(),
            cognition_heartbeat.clone(),
        );

        match connect_stream(
            &client,
            server,
            &args.token,
            &worker_id,
            &name,
            &codebases,
            &processing,
            &auto_approve,
            &bus,
            None, // No task notification channel in simple run mode
        )
        .await
        {
            Ok(()) => {
                tracing::warn!("Stream ended, reconnecting...");
            }
            Err(e) => {
                tracing::error!("Stream error: {}, reconnecting...", e);
            }
        }

        // Cancel heartbeat on disconnection
        heartbeat_handle.abort();
        tracing::debug!("Heartbeat cancelled for reconnection");

        // Back off briefly before reconnecting.
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
    }
}
228
/// Run the A2A worker with shared state for HTTP server integration
/// This variant accepts a WorkerServerState to communicate with the HTTP server
///
/// Mirrors `run`, additionally publishing the worker id, heartbeat state,
/// connection status, and a per-connection task-notification channel into
/// `server_state`. Returns `Err` only if the initial registration or
/// pending-task fetch fails; the reconnect loop never exits.
pub async fn run_with_state(
    args: A2aArgs,
    server_state: crate::worker_server::WorkerServerState,
) -> Result<()> {
    let server = args.server.trim_end_matches('/');
    // Worker display name defaults to "codetether-<pid>".
    let name = args
        .name
        .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
    let worker_id = generate_worker_id();

    // Share worker_id with HTTP server
    server_state.set_worker_id(worker_id.clone()).await;

    // Comma-separated --codebases list, or the current directory by default.
    // NOTE(review): current_dir().unwrap() panics if the cwd is unreadable — confirm acceptable.
    let codebases: Vec<String> = args
        .codebases
        .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
        .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);

    tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
    tracing::info!("Server: {}", server);
    tracing::info!("Codebases: {:?}", codebases);

    let client = Client::new();
    // IDs of tasks currently being processed (guards against double-claiming).
    let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
    let cognition_heartbeat = CognitionHeartbeatConfig::from_env();
    if cognition_heartbeat.enabled {
        tracing::info!(
            source = %cognition_heartbeat.source_base_url,
            include_thoughts = cognition_heartbeat.include_thought_summary,
            max_chars = cognition_heartbeat.summary_max_chars,
            timeout_ms = cognition_heartbeat.request_timeout_ms,
            "Cognition heartbeat sharing enabled (set CODETETHER_WORKER_COGNITION_SHARE_ENABLED=false to disable)"
        );
    } else {
        tracing::warn!(
            "Cognition heartbeat sharing disabled; worker thought state will not be shared upstream"
        );
    }

    // Anything other than "all"/"safe" (including the default) means no auto-approval.
    let auto_approve = match args.auto_approve.as_str() {
        "all" => AutoApprove::All,
        "safe" => AutoApprove::Safe,
        _ => AutoApprove::None,
    };

    // Create heartbeat state
    let heartbeat_state = HeartbeatState::new(worker_id.clone(), name.clone());

    // Share heartbeat state with HTTP server
    server_state
        .set_heartbeat_state(Arc::new(heartbeat_state.clone()))
        .await;

    // Create agent bus for in-process sub-agent communication
    let bus = AgentBus::new().into_arc();
    {
        let handle = bus.handle(&worker_id);
        handle.announce_ready(WORKER_CAPABILITIES.iter().map(|s| s.to_string()).collect());
    }

    // Register worker
    register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;

    // Mark as connected
    server_state.set_connected(true).await;

    // Fetch pending tasks before entering reconnection loop
    fetch_pending_tasks(
        &client,
        server,
        &args.token,
        &worker_id,
        &processing,
        &auto_approve,
        &bus,
    )
    .await?;

    // Connect to SSE stream
    loop {
        // Create task notification channel for CloudEvent-triggered task execution
        // Recreate on each reconnection since the receiver is moved into connect_stream
        let (task_notify_tx, task_notify_rx) = mpsc::channel::<String>(32);
        server_state
            .set_task_notification_channel(task_notify_tx)
            .await;

        // Mark as connected on each reconnection
        server_state.set_connected(true).await;

        // Re-register worker on each reconnection to report updated models/capabilities
        if let Err(e) =
            register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await
        {
            tracing::warn!("Failed to re-register worker on reconnection: {}", e);
        }

        // Start heartbeat task for this connection
        let heartbeat_handle = start_heartbeat(
            client.clone(),
            server.to_string(),
            args.token.clone(),
            heartbeat_state.clone(),
            processing.clone(),
            cognition_heartbeat.clone(),
        );

        match connect_stream(
            &client,
            server,
            &args.token,
            &worker_id,
            &name,
            &codebases,
            &processing,
            &auto_approve,
            &bus,
            Some(task_notify_rx),
        )
        .await
        {
            Ok(()) => {
                tracing::warn!("Stream ended, reconnecting...");
            }
            Err(e) => {
                tracing::error!("Stream error: {}, reconnecting...", e);
            }
        }

        // Mark as disconnected
        server_state.set_connected(false).await;

        // Cancel heartbeat on disconnection
        heartbeat_handle.abort();
        tracing::debug!("Heartbeat cancelled for reconnection");

        // Back off briefly before reconnecting.
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
    }
}
370
371fn generate_worker_id() -> String {
372    format!(
373        "wrk_{}_{:x}",
374        chrono::Utc::now().timestamp(),
375        rand::random::<u64>()
376    )
377}
378
/// Auto-approval policy parsed from the `--auto-approve` CLI flag
/// ("all" / "safe"; anything else maps to `None`).
/// NOTE(review): what each level gates is decided downstream (connect_stream /
/// task processing, not fully visible here) — confirm before documenting further.
#[derive(Debug, Clone, Copy)]
enum AutoApprove {
    All,
    Safe,
    None,
}
385
/// Default A2A server URL when none is configured
pub const DEFAULT_A2A_SERVER_URL: &str = "https://api.codetether.run";

/// Capabilities of the codetether-agent worker, sent to the server at
/// registration and announced on the in-process agent bus.
const WORKER_CAPABILITIES: &[&str] = &["ralph", "swarm", "rlm", "a2a", "mcp"];
391
392fn task_value<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a serde_json::Value> {
393    task.get("task")
394        .and_then(|t| t.get(key))
395        .or_else(|| task.get(key))
396}
397
398fn task_str<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a str> {
399    task_value(task, key).and_then(|v| v.as_str())
400}
401
402fn task_metadata(task: &serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
403    task_value(task, "metadata")
404        .and_then(|m| m.as_object())
405        .cloned()
406        .unwrap_or_default()
407}
408
/// Convert a "provider:model" reference into "provider/model" form.
///
/// Only the first colon is rewritten, and only when the string contains no
/// '/' yet — ids that are already "provider/model" (e.g.
/// "bedrock/amazon.nova-micro-v1:0", where the colon is a version separator)
/// pass through unchanged.
fn model_ref_to_provider_model(model: &str) -> String {
    let uses_colon_separator = !model.contains('/') && model.contains(':');
    if uses_colon_separator {
        model.replacen(':', "/", 1)
    } else {
        model.to_owned()
    }
}
419
/// Ordered provider preference list for a model tier.
///
/// Tiers: "fast"/"quick", "heavy"/"deep", everything else (including `None`)
/// is treated as "balanced".
fn provider_preferences_for_tier(model_tier: Option<&str>) -> &'static [&'static str] {
    const FAST: &[&str] = &[
        "zai",
        "openai",
        "github-copilot",
        "moonshotai",
        "openrouter",
        "novita",
        "google",
        "anthropic",
    ];
    const HEAVY: &[&str] = &[
        "zai",
        "anthropic",
        "openai",
        "github-copilot",
        "moonshotai",
        "openrouter",
        "novita",
        "google",
    ];
    const BALANCED: &[&str] = &[
        "zai",
        "openai",
        "github-copilot",
        "anthropic",
        "moonshotai",
        "openrouter",
        "novita",
        "google",
    ];

    match model_tier.unwrap_or("balanced") {
        "fast" | "quick" => FAST,
        "heavy" | "deep" => HEAVY,
        _ => BALANCED,
    }
}
454
455fn choose_provider_for_tier<'a>(providers: &'a [&'a str], model_tier: Option<&str>) -> &'a str {
456    for preferred in provider_preferences_for_tier(model_tier) {
457        if let Some(found) = providers.iter().copied().find(|p| *p == *preferred) {
458            return found;
459        }
460    }
461    if let Some(found) = providers.iter().copied().find(|p| *p == "zai") {
462        return found;
463    }
464    providers[0]
465}
466
/// Default model id for a provider at a given tier.
///
/// Tiers: "fast"/"quick" (cheap/low-latency), "heavy"/"deep" (strongest),
/// anything else — including `None` — is "balanced". Unknown providers get
/// "glm-5" at every tier.
fn default_model_for_provider(provider: &str, model_tier: Option<&str>) -> String {
    let tier = model_tier.unwrap_or("balanced");
    let fast = matches!(tier, "fast" | "quick");
    let heavy = matches!(tier, "heavy" | "deep");

    let model = match provider {
        // Tier-independent choices.
        "moonshotai" => "kimi-k2.5",
        "zhipuai" | "zai" => "glm-5",
        "openrouter" => "z-ai/glm-5",
        "novita" => "qwen/qwen3-coder-next",
        // Tier-dependent choices.
        "anthropic" => {
            if fast {
                "claude-haiku-4-5"
            } else {
                "claude-sonnet-4-20250514"
            }
        }
        "openai" => {
            if fast {
                "gpt-4o-mini"
            } else if heavy {
                "o3"
            } else {
                "gpt-4o"
            }
        }
        "google" => {
            if fast {
                "gemini-2.5-flash"
            } else {
                "gemini-2.5-pro"
            }
        }
        "bedrock" => {
            if heavy {
                "us.anthropic.claude-sonnet-4-20250514-v1:0"
            } else {
                "amazon.nova-lite-v1:0"
            }
        }
        _ => "glm-5",
    };
    model.to_string()
}
504
/// Whether the model family is known to work best at temperature 1
/// (case-insensitive substring match on the model id).
fn prefers_temperature_one(model: &str) -> bool {
    let lower = model.to_ascii_lowercase();
    ["kimi-k2", "glm-", "minimax"]
        .iter()
        .any(|needle| lower.contains(needle))
}
509
/// Whether the agent type names a multi-agent/swarm execution mode
/// (whitespace-trimmed, case-insensitive comparison).
fn is_swarm_agent(agent_type: &str) -> bool {
    let normalized = agent_type.trim().to_ascii_lowercase();
    normalized == "swarm" || normalized == "parallel" || normalized == "multi-agent"
}
516
517fn metadata_lookup<'a>(
518    metadata: &'a serde_json::Map<String, serde_json::Value>,
519    key: &str,
520) -> Option<&'a serde_json::Value> {
521    metadata
522        .get(key)
523        .or_else(|| {
524            metadata
525                .get("routing")
526                .and_then(|v| v.as_object())
527                .and_then(|obj| obj.get(key))
528        })
529        .or_else(|| {
530            metadata
531                .get("swarm")
532                .and_then(|v| v.as_object())
533                .and_then(|obj| obj.get(key))
534        })
535}
536
537fn metadata_str(
538    metadata: &serde_json::Map<String, serde_json::Value>,
539    keys: &[&str],
540) -> Option<String> {
541    for key in keys {
542        if let Some(value) = metadata_lookup(metadata, key).and_then(|v| v.as_str()) {
543            let trimmed = value.trim();
544            if !trimmed.is_empty() {
545                return Some(trimmed.to_string());
546            }
547        }
548    }
549    None
550}
551
552fn metadata_usize(
553    metadata: &serde_json::Map<String, serde_json::Value>,
554    keys: &[&str],
555) -> Option<usize> {
556    for key in keys {
557        if let Some(value) = metadata_lookup(metadata, key) {
558            if let Some(v) = value.as_u64() {
559                return usize::try_from(v).ok();
560            }
561            if let Some(v) = value.as_i64() {
562                if v >= 0 {
563                    return usize::try_from(v as u64).ok();
564                }
565            }
566            if let Some(v) = value.as_str() {
567                if let Ok(parsed) = v.trim().parse::<usize>() {
568                    return Some(parsed);
569                }
570            }
571        }
572    }
573    None
574}
575
576fn metadata_u64(
577    metadata: &serde_json::Map<String, serde_json::Value>,
578    keys: &[&str],
579) -> Option<u64> {
580    for key in keys {
581        if let Some(value) = metadata_lookup(metadata, key) {
582            if let Some(v) = value.as_u64() {
583                return Some(v);
584            }
585            if let Some(v) = value.as_i64() {
586                if v >= 0 {
587                    return Some(v as u64);
588                }
589            }
590            if let Some(v) = value.as_str() {
591                if let Ok(parsed) = v.trim().parse::<u64>() {
592                    return Some(parsed);
593                }
594            }
595        }
596    }
597    None
598}
599
600fn metadata_bool(
601    metadata: &serde_json::Map<String, serde_json::Value>,
602    keys: &[&str],
603) -> Option<bool> {
604    for key in keys {
605        if let Some(value) = metadata_lookup(metadata, key) {
606            if let Some(v) = value.as_bool() {
607                return Some(v);
608            }
609            if let Some(v) = value.as_str() {
610                match v.trim().to_ascii_lowercase().as_str() {
611                    "1" | "true" | "yes" | "on" => return Some(true),
612                    "0" | "false" | "no" | "off" => return Some(false),
613                    _ => {}
614                }
615            }
616        }
617    }
618    None
619}
620
621fn parse_swarm_strategy(
622    metadata: &serde_json::Map<String, serde_json::Value>,
623) -> DecompositionStrategy {
624    match metadata_str(
625        metadata,
626        &[
627            "decomposition_strategy",
628            "swarm_strategy",
629            "strategy",
630            "swarm_decomposition",
631        ],
632    )
633    .as_deref()
634    .map(|s| s.to_ascii_lowercase())
635    .as_deref()
636    {
637        Some("none") | Some("single") => DecompositionStrategy::None,
638        Some("domain") | Some("by_domain") => DecompositionStrategy::ByDomain,
639        Some("data") | Some("by_data") => DecompositionStrategy::ByData,
640        Some("stage") | Some("by_stage") => DecompositionStrategy::ByStage,
641        _ => DecompositionStrategy::Automatic,
642    }
643}
644
645async fn resolve_swarm_model(
646    explicit_model: Option<String>,
647    model_tier: Option<&str>,
648) -> Option<String> {
649    if let Some(model) = explicit_model {
650        if !model.trim().is_empty() {
651            return Some(model);
652        }
653    }
654
655    let registry = ProviderRegistry::from_vault().await.ok()?;
656    let providers = registry.list();
657    if providers.is_empty() {
658        return None;
659    }
660    let provider = choose_provider_for_tier(providers.as_slice(), model_tier);
661    let model = default_model_for_provider(provider, model_tier);
662    Some(format!("{}/{}", provider, model))
663}
664
665fn format_swarm_event_for_output(event: &SwarmEvent) -> Option<String> {
666    match event {
667        SwarmEvent::Started {
668            task,
669            total_subtasks,
670        } => Some(format!(
671            "[swarm] started task={} planned_subtasks={}",
672            task, total_subtasks
673        )),
674        SwarmEvent::StageComplete {
675            stage,
676            completed,
677            failed,
678        } => Some(format!(
679            "[swarm] stage={} completed={} failed={}",
680            stage, completed, failed
681        )),
682        SwarmEvent::SubTaskUpdate { id, status, .. } => Some(format!(
683            "[swarm] subtask id={} status={}",
684            &id.chars().take(8).collect::<String>(),
685            format!("{status:?}").to_ascii_lowercase()
686        )),
687        SwarmEvent::AgentToolCall {
688            subtask_id,
689            tool_name,
690        } => Some(format!(
691            "[swarm] subtask id={} tool={}",
692            &subtask_id.chars().take(8).collect::<String>(),
693            tool_name
694        )),
695        SwarmEvent::AgentError { subtask_id, error } => Some(format!(
696            "[swarm] subtask id={} error={}",
697            &subtask_id.chars().take(8).collect::<String>(),
698            error
699        )),
700        SwarmEvent::Complete { success, stats } => Some(format!(
701            "[swarm] complete success={} subtasks={} speedup={:.2}",
702            success,
703            stats.subagents_completed + stats.subagents_failed,
704            stats.speedup_factor
705        )),
706        SwarmEvent::Error(err) => Some(format!("[swarm] error message={}", err)),
707        _ => None,
708    }
709}
710
/// Register (or re-register) this worker with the A2A server.
///
/// Loads the provider/model catalog best-effort (registration proceeds with
/// an empty model list on failure), then POSTs worker metadata to
/// `/v1/opencode/workers/register`. A non-success HTTP status is only logged
/// as a warning; `Err` is returned solely for transport failures.
async fn register_worker(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
) -> Result<()> {
    // Load ProviderRegistry and collect available models
    let models = match load_provider_models().await {
        Ok(m) => m,
        Err(e) => {
            tracing::warn!(
                "Failed to load provider models: {}, proceeding without model info",
                e
            );
            HashMap::new()
        }
    };

    // Register via the workers/register endpoint
    let mut req = client.post(format!("{}/v1/opencode/workers/register", server));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    // Flatten models HashMap into array of model objects with pricing data
    // matching the format expected by the A2A server's /models and /workers endpoints
    let models_array: Vec<serde_json::Value> = models
        .iter()
        .flat_map(|(provider, model_infos)| {
            model_infos.iter().map(move |m| {
                let mut obj = serde_json::json!({
                    "id": format!("{}/{}", provider, m.id),
                    "name": &m.id,
                    "provider": provider,
                    "provider_id": provider,
                });
                // Pricing fields are optional; only emit them when known.
                if let Some(input_cost) = m.input_cost_per_million {
                    obj["input_cost_per_million"] = serde_json::json!(input_cost);
                }
                if let Some(output_cost) = m.output_cost_per_million {
                    obj["output_cost_per_million"] = serde_json::json!(output_cost);
                }
                obj
            })
        })
        .collect();

    tracing::info!(
        "Registering worker with {} models from {} providers",
        models_array.len(),
        models.len()
    );

    // HOSTNAME (unix) or COMPUTERNAME (Windows), else "unknown".
    let hostname = std::env::var("HOSTNAME")
        .or_else(|_| std::env::var("COMPUTERNAME"))
        .unwrap_or_else(|_| "unknown".to_string());

    let res = req
        .json(&serde_json::json!({
            "worker_id": worker_id,
            "name": name,
            "capabilities": WORKER_CAPABILITIES,
            "hostname": hostname,
            "models": models_array,
            "codebases": codebases,
        }))
        .send()
        .await?;

    if res.status().is_success() {
        tracing::info!("Worker registered successfully");
    } else {
        tracing::warn!("Failed to register worker: {}", res.status());
    }

    Ok(())
}
791
/// Load ProviderRegistry and collect all available models grouped by provider.
/// Tries Vault first, then falls back to config/env vars if Vault is unreachable.
/// Returns ModelInfo structs (with pricing data when available).
///
/// Resolution order:
/// 1. Vault registry (if it yields at least one provider);
/// 2. config-file/env-var registry (`fallback_registry`);
/// 3. if neither yields any models, the models.dev catalog for ALL providers.
async fn load_provider_models() -> Result<HashMap<String, Vec<crate::provider::ModelInfo>>> {
    // Try Vault first
    let registry = match ProviderRegistry::from_vault().await {
        Ok(r) if !r.list().is_empty() => {
            tracing::info!("Loaded {} providers from Vault", r.list().len());
            r
        }
        Ok(_) => {
            tracing::warn!("Vault returned 0 providers, falling back to config/env vars");
            fallback_registry().await?
        }
        Err(e) => {
            tracing::warn!("Vault unreachable ({}), falling back to config/env vars", e);
            fallback_registry().await?
        }
    };

    // Fetch the models.dev catalog for pricing data enrichment
    // (best effort — enrichment is skipped when the fetch fails).
    let catalog = crate::provider::models::ModelCatalog::fetch().await.ok();

    // Map provider IDs to their catalog equivalents (some differ)
    let catalog_alias = |pid: &str| -> String {
        match pid {
            "bedrock" => "amazon-bedrock".to_string(),
            "novita" => "novita-ai".to_string(),
            _ => pid.to_string(),
        }
    };

    let mut models_by_provider: HashMap<String, Vec<crate::provider::ModelInfo>> = HashMap::new();

    for provider_name in registry.list() {
        if let Some(provider) = registry.get(provider_name) {
            match provider.list_models().await {
                Ok(models) => {
                    let enriched: Vec<crate::provider::ModelInfo> = models
                        .into_iter()
                        .map(|mut m| {
                            // Enrich with catalog pricing if the provider didn't set it
                            if m.input_cost_per_million.is_none()
                                || m.output_cost_per_million.is_none()
                            {
                                if let Some(ref cat) = catalog {
                                    let cat_pid = catalog_alias(provider_name);
                                    if let Some(prov_info) = cat.get_provider(&cat_pid) {
                                        // Try exact match first, then strip "us." prefix (bedrock uses us.vendor.model format)
                                        let model_info =
                                            prov_info.models.get(&m.id).or_else(|| {
                                                m.id.strip_prefix("us.").and_then(|stripped| {
                                                    prov_info.models.get(stripped)
                                                })
                                            });
                                        if let Some(model_info) = model_info {
                                            if let Some(ref cost) = model_info.cost {
                                                // Fill only the missing side(s); never overwrite
                                                // pricing the provider already reported.
                                                if m.input_cost_per_million.is_none() {
                                                    m.input_cost_per_million = Some(cost.input);
                                                }
                                                if m.output_cost_per_million.is_none() {
                                                    m.output_cost_per_million = Some(cost.output);
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            m
                        })
                        .collect();
                    if !enriched.is_empty() {
                        tracing::debug!("Provider {}: {} models", provider_name, enriched.len());
                        models_by_provider.insert(provider_name.to_string(), enriched);
                    }
                }
                Err(e) => {
                    // A single provider failing to list models is non-fatal.
                    tracing::debug!("Failed to list models for {}: {}", provider_name, e);
                }
            }
        }
    }

    // If we still have 0, try the models.dev catalog as last resort
    // NOTE: We list ALL models from the catalog (not just ones with verified API keys)
    // because the worker should advertise what it can handle. The server handles routing.
    if models_by_provider.is_empty() {
        tracing::info!(
            "No authenticated providers found, fetching models.dev catalog (all providers)"
        );
        if let Ok(cat) = crate::provider::models::ModelCatalog::fetch().await {
            // Use all_providers() to get every provider+model from catalog
            // regardless of API key availability (Vault may be down)
            for (provider_id, provider_info) in cat.all_providers() {
                let model_infos: Vec<crate::provider::ModelInfo> = provider_info
                    .models
                    .values()
                    .map(|m| cat.to_model_info(m, provider_id))
                    .collect();
                if !model_infos.is_empty() {
                    tracing::debug!(
                        "Catalog provider {}: {} models",
                        provider_id,
                        model_infos.len()
                    );
                    models_by_provider.insert(provider_id.clone(), model_infos);
                }
            }
            tracing::info!(
                "Loaded {} providers with {} total models from catalog",
                models_by_provider.len(),
                models_by_provider.values().map(|v| v.len()).sum::<usize>()
            );
        }
    }

    Ok(models_by_provider)
}
910
911/// Fallback: build a ProviderRegistry from config file + environment variables
912async fn fallback_registry() -> Result<ProviderRegistry> {
913    let config = crate::config::Config::load().await.unwrap_or_default();
914    ProviderRegistry::from_config(&config).await
915}
916
917async fn fetch_pending_tasks(
918    client: &Client,
919    server: &str,
920    token: &Option<String>,
921    worker_id: &str,
922    processing: &Arc<Mutex<HashSet<String>>>,
923    auto_approve: &AutoApprove,
924    bus: &Arc<AgentBus>,
925) -> Result<()> {
926    tracing::info!("Checking for pending tasks...");
927
928    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
929    if let Some(t) = token {
930        req = req.bearer_auth(t);
931    }
932
933    let res = req.send().await?;
934    if !res.status().is_success() {
935        return Ok(());
936    }
937
938    let data: serde_json::Value = res.json().await?;
939    // Handle both plain array response and {tasks: [...]} wrapper
940    let tasks = if let Some(arr) = data.as_array() {
941        arr.clone()
942    } else {
943        data["tasks"].as_array().cloned().unwrap_or_default()
944    };
945
946    tracing::info!("Found {} pending task(s)", tasks.len());
947
948    for task in tasks {
949        if let Some(id) = task["id"].as_str() {
950            let mut proc = processing.lock().await;
951            if !proc.contains(id) {
952                proc.insert(id.to_string());
953                drop(proc);
954
955                let task_id = id.to_string();
956                let client = client.clone();
957                let server = server.to_string();
958                let token = token.clone();
959                let worker_id = worker_id.to_string();
960                let auto_approve = *auto_approve;
961                let processing = processing.clone();
962                let bus = bus.clone();
963
964                tokio::spawn(async move {
965                    if let Err(e) = handle_task(
966                        &client,
967                        &server,
968                        &token,
969                        &worker_id,
970                        &task,
971                        auto_approve,
972                        &bus,
973                    )
974                    .await
975                    {
976                        tracing::error!("Task {} failed: {}", task_id, e);
977                    }
978                    processing.lock().await.remove(&task_id);
979                });
980            }
981        }
982    }
983
984    Ok(())
985}
986
/// Connect to the server's SSE task stream and dispatch tasks as they arrive.
///
/// Runs until the stream ends (`Ok`) or a transport error occurs (`Err`);
/// the caller is responsible for reconnecting. Three event sources are
/// multiplexed in the main loop:
/// - task notifications over `task_notify_rx` (CloudEvents), when provided,
/// - SSE `data:` frames carrying task JSON, handed to `spawn_task_handler`,
/// - a 30-second interval that re-polls for tasks the stream may have missed.
#[allow(clippy::too_many_arguments)]
async fn connect_stream(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
    bus: &Arc<AgentBus>,
    task_notify_rx: Option<mpsc::Receiver<String>>,
) -> Result<()> {
    let url = format!(
        "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
        server,
        urlencoding::encode(name),
        urlencoding::encode(worker_id)
    );

    // Identify this worker and its codebases via headers so the server can
    // route appropriate tasks onto the stream.
    let mut req = client
        .get(&url)
        .header("Accept", "text/event-stream")
        .header("X-Worker-ID", worker_id)
        .header("X-Agent-Name", name)
        .header("X-Codebases", codebases.join(","));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        anyhow::bail!("Failed to connect: {}", res.status());
    }

    tracing::info!("Connected to A2A server");

    let mut stream = res.bytes_stream();
    // Accumulates partial SSE frames until a blank-line terminator arrives.
    let mut buffer = String::new();
    let mut poll_interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
    poll_interval.tick().await; // consume the initial immediate tick

    // Pin the optional receiver so we can use it in the loop
    let mut task_notify_rx = task_notify_rx;

    loop {
        tokio::select! {
            // Handle task notification from CloudEvent (Knative Eventing)
            // Only if the channel was provided (i.e., running with HTTP server)
            task_id = async {
                if let Some(ref mut rx) = task_notify_rx {
                    rx.recv().await
                } else {
                    // Never ready when None - use pending to skip this branch
                    futures::future::pending().await
                }
            } => {
                if let Some(task_id) = task_id {
                    tracing::info!("Received task notification via CloudEvent: {}", task_id);
                    // Immediately poll for and process this task
                    if let Err(e) = poll_pending_tasks(
                        client, server, token, worker_id, processing, auto_approve, bus,
                    ).await {
                        tracing::warn!("Task notification poll failed: {}", e);
                    }
                }
            }
            chunk = stream.next() => {
                match chunk {
                    Some(Ok(chunk)) => {
                        // Bytes may split UTF-8/frames arbitrarily; append and
                        // parse only complete frames below.
                        buffer.push_str(&String::from_utf8_lossy(&chunk));

                        // Process SSE events (frames are separated by "\n\n")
                        while let Some(pos) = buffer.find("\n\n") {
                            let event_str = buffer[..pos].to_string();
                            buffer = buffer[pos + 2..].to_string();

                            if let Some(data_line) = event_str.lines().find(|l| l.starts_with("data:")) {
                                let data = data_line.trim_start_matches("data:").trim();
                                if data == "[DONE]" || data.is_empty() {
                                    continue;
                                }

                                // Each data frame carries a JSON task payload;
                                // unparseable frames are silently dropped.
                                if let Ok(task) = serde_json::from_str::<serde_json::Value>(data) {
                                    spawn_task_handler(
                                        &task, client, server, token, worker_id,
                                        processing, auto_approve, bus,
                                    ).await;
                                }
                            }
                        }
                    }
                    Some(Err(e)) => {
                        return Err(e.into());
                    }
                    None => {
                        // Stream ended
                        return Ok(());
                    }
                }
            }
            _ = poll_interval.tick() => {
                // Periodic poll for pending tasks the SSE stream may have missed
                if let Err(e) = poll_pending_tasks(
                    client, server, token, worker_id, processing, auto_approve, bus,
                ).await {
                    tracing::warn!("Periodic task poll failed: {}", e);
                }
            }
        }
    }
}
1100
1101async fn spawn_task_handler(
1102    task: &serde_json::Value,
1103    client: &Client,
1104    server: &str,
1105    token: &Option<String>,
1106    worker_id: &str,
1107    processing: &Arc<Mutex<HashSet<String>>>,
1108    auto_approve: &AutoApprove,
1109    bus: &Arc<AgentBus>,
1110) {
1111    if let Some(id) = task
1112        .get("task")
1113        .and_then(|t| t["id"].as_str())
1114        .or_else(|| task["id"].as_str())
1115    {
1116        let mut proc = processing.lock().await;
1117        if !proc.contains(id) {
1118            proc.insert(id.to_string());
1119            drop(proc);
1120
1121            let task_id = id.to_string();
1122            let task = task.clone();
1123            let client = client.clone();
1124            let server = server.to_string();
1125            let token = token.clone();
1126            let worker_id = worker_id.to_string();
1127            let auto_approve = *auto_approve;
1128            let processing_clone = processing.clone();
1129            let bus = bus.clone();
1130
1131            tokio::spawn(async move {
1132                if let Err(e) = handle_task(
1133                    &client,
1134                    &server,
1135                    &token,
1136                    &worker_id,
1137                    &task,
1138                    auto_approve,
1139                    &bus,
1140                )
1141                .await
1142                {
1143                    tracing::error!("Task {} failed: {}", task_id, e);
1144                }
1145                processing_clone.lock().await.remove(&task_id);
1146            });
1147        }
1148    }
1149}
1150
1151async fn poll_pending_tasks(
1152    client: &Client,
1153    server: &str,
1154    token: &Option<String>,
1155    worker_id: &str,
1156    processing: &Arc<Mutex<HashSet<String>>>,
1157    auto_approve: &AutoApprove,
1158    bus: &Arc<AgentBus>,
1159) -> Result<()> {
1160    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
1161    if let Some(t) = token {
1162        req = req.bearer_auth(t);
1163    }
1164
1165    let res = req.send().await?;
1166    if !res.status().is_success() {
1167        return Ok(());
1168    }
1169
1170    let data: serde_json::Value = res.json().await?;
1171    let tasks = if let Some(arr) = data.as_array() {
1172        arr.clone()
1173    } else {
1174        data["tasks"].as_array().cloned().unwrap_or_default()
1175    };
1176
1177    if !tasks.is_empty() {
1178        tracing::debug!("Poll found {} pending task(s)", tasks.len());
1179    }
1180
1181    for task in &tasks {
1182        spawn_task_handler(
1183            task,
1184            client,
1185            server,
1186            token,
1187            worker_id,
1188            processing,
1189            auto_approve,
1190            bus,
1191        )
1192        .await;
1193    }
1194
1195    Ok(())
1196}
1197
/// Claim, execute, and release a single task from the A2A server.
///
/// Flow: claim the task (skipping it if another worker won the race), read
/// routing hints from the task's metadata, build or resume a `Session`, run
/// either the swarm executor or the standard session loop depending on the
/// task's agent type, and finally POST the status/result back via the
/// release endpoint. Intermediate output is streamed to the server through
/// a fire-and-forget callback.
async fn handle_task(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    task: &serde_json::Value,
    auto_approve: AutoApprove,
    bus: &Arc<AgentBus>,
) -> Result<()> {
    let task_id = task_str(task, "id").ok_or_else(|| anyhow::anyhow!("No task ID"))?;
    let title = task_str(task, "title").unwrap_or("Untitled");

    tracing::info!("Handling task: {} ({})", title, task_id);

    // Claim the task
    let mut req = client
        .post(format!("{}/v1/worker/tasks/claim", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req
        .json(&serde_json::json!({ "task_id": task_id }))
        .send()
        .await?;

    if !res.status().is_success() {
        let status = res.status();
        let text = res.text().await?;
        // 409 Conflict means another worker claimed it first — not an error.
        if status == reqwest::StatusCode::CONFLICT {
            tracing::debug!(task_id, "Task already claimed by another worker, skipping");
        } else {
            tracing::warn!(task_id, %status, "Failed to claim task: {}", text);
        }
        return Ok(());
    }

    tracing::info!("Claimed task: {}", task_id);

    // Routing hints live in the task's metadata map.
    let metadata = task_metadata(task);
    // Non-empty resume_session_id requests continuation of an existing session.
    let resume_session_id = metadata
        .get("resume_session_id")
        .and_then(|v| v.as_str())
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty());
    let complexity_hint = metadata_str(&metadata, &["complexity"]);
    // Explicit tier wins; otherwise derive one from the complexity hint
    // (quick -> fast, deep -> heavy, anything else -> balanced).
    let model_tier = metadata_str(&metadata, &["model_tier", "tier"])
        .map(|s| s.to_ascii_lowercase())
        .or_else(|| {
            complexity_hint.as_ref().map(|complexity| {
                match complexity.to_ascii_lowercase().as_str() {
                    "quick" => "fast".to_string(),
                    "deep" => "heavy".to_string(),
                    _ => "balanced".to_string(),
                }
            })
        });
    let worker_personality = metadata_str(
        &metadata,
        &["worker_personality", "personality", "agent_personality"],
    );
    let target_agent_name = metadata_str(&metadata, &["target_agent_name", "agent_name"]);
    // Model may appear on the task itself or in metadata, under either
    // `model_ref` (preferred) or `model`.
    let raw_model = task_str(task, "model_ref")
        .or_else(|| metadata_lookup(&metadata, "model_ref").and_then(|v| v.as_str()))
        .or_else(|| task_str(task, "model"))
        .or_else(|| metadata_lookup(&metadata, "model").and_then(|v| v.as_str()));
    let selected_model = raw_model.map(model_ref_to_provider_model);

    // Resume existing session when requested; fall back to a fresh session if missing.
    let mut session = if let Some(ref sid) = resume_session_id {
        match Session::load(sid).await {
            Ok(existing) => {
                tracing::info!("Resuming session {} for task {}", sid, task_id);
                existing
            }
            Err(e) => {
                tracing::warn!(
                    "Could not load session {} for task {} ({}), starting a new session",
                    sid,
                    task_id,
                    e
                );
                Session::new().await?
            }
        }
    } else {
        Session::new().await?
    };

    // Agent type defaults to "build" when the task does not specify one.
    let agent_type = task_str(task, "agent_type")
        .or_else(|| task_str(task, "agent"))
        .unwrap_or("build");
    session.agent = agent_type.to_string();

    if let Some(model) = selected_model.clone() {
        session.metadata.model = Some(model);
    }

    // Prompt falls back to the description, then the title.
    let prompt = task_str(task, "prompt")
        .or_else(|| task_str(task, "description"))
        .unwrap_or(title);

    tracing::info!("Executing prompt: {}", prompt);

    // Set up output streaming to forward progress to the server
    let stream_client = client.clone();
    let stream_server = server.to_string();
    let stream_token = token.clone();
    let stream_worker_id = worker_id.to_string();
    let stream_task_id = task_id.to_string();

    // Fire-and-forget POST per output chunk; delivery failures are ignored so
    // streaming problems never abort task execution.
    let output_callback = move |output: String| {
        let c = stream_client.clone();
        let s = stream_server.clone();
        let t = stream_token.clone();
        let w = stream_worker_id.clone();
        let tid = stream_task_id.clone();
        tokio::spawn(async move {
            let mut req = c
                .post(format!("{}/v1/opencode/tasks/{}/output", s, tid))
                .header("X-Worker-ID", &w);
            if let Some(tok) = &t {
                req = req.bearer_auth(tok);
            }
            let _ = req
                .json(&serde_json::json!({
                    "worker_id": w,
                    "output": output,
                }))
                .send()
                .await;
        });
    };

    // Execute swarm tasks via SwarmExecutor; all other agents use the standard session loop.
    let (status, result, error, session_id) = if is_swarm_agent(agent_type) {
        match execute_swarm_with_policy(
            &mut session,
            prompt,
            model_tier.as_deref(),
            selected_model,
            &metadata,
            complexity_hint.as_deref(),
            worker_personality.as_deref(),
            target_agent_name.as_deref(),
            Some(bus),
            Some(&output_callback),
        )
        .await
        {
            Ok((session_result, true)) => {
                tracing::info!("Swarm task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            // Swarm ran to completion but one or more subagents failed.
            Ok((session_result, false)) => {
                tracing::warn!("Swarm task completed with failures: {}", task_id);
                (
                    "failed",
                    Some(session_result.text),
                    Some("Swarm execution completed with failures".to_string()),
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Swarm task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    } else {
        match execute_session_with_policy(
            &mut session,
            prompt,
            auto_approve,
            model_tier.as_deref(),
            Some(&output_callback),
        )
        .await
        {
            Ok(session_result) => {
                tracing::info!("Task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    };

    // Release the task with full details
    let mut req = client
        .post(format!("{}/v1/worker/tasks/release", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    req.json(&serde_json::json!({
        "task_id": task_id,
        "status": status,
        "result": result,
        "error": error,
        // Prefer the executor-reported session id; fall back to our session.
        "session_id": session_id.unwrap_or_else(|| session.id.clone()),
    }))
    .send()
    .await?;

    tracing::info!("Task released: {} with status: {}", task_id, status);
    Ok(())
}
1419
/// Run a swarm-style task via `SwarmExecutor` and persist the result on the
/// session.
///
/// Returns the final `SessionResult` plus a bool reporting overall swarm
/// success. Tuning knobs (strategy, subagent count, per-subagent step limit,
/// timeout, parallelism) are read from task `metadata` with defaults and
/// clamped to safe ranges. When `output_callback` is supplied, routing info
/// and live swarm events are streamed through it while the executor runs.
async fn execute_swarm_with_policy<F>(
    session: &mut Session,
    prompt: &str,
    model_tier: Option<&str>,
    explicit_model: Option<String>,
    metadata: &serde_json::Map<String, serde_json::Value>,
    complexity_hint: Option<&str>,
    worker_personality: Option<&str>,
    target_agent_name: Option<&str>,
    bus: Option<&Arc<AgentBus>>,
    output_callback: Option<&F>,
) -> Result<(crate::session::SessionResult, bool)>
where
    F: Fn(String),
{
    use crate::provider::{ContentPart, Message, Role};

    // Record the incoming prompt on the session before executing.
    session.add_message(Message {
        role: Role::User,
        content: vec![ContentPart::Text {
            text: prompt.to_string(),
        }],
    });

    if session.title.is_none() {
        session.generate_title().await?;
    }

    // Swarm tuning knobs from metadata, with defaults, clamped to safe ranges.
    let strategy = parse_swarm_strategy(metadata);
    let max_subagents = metadata_usize(
        metadata,
        &["swarm_max_subagents", "max_subagents", "subagents"],
    )
    .unwrap_or(10)
    .clamp(1, 100);
    let max_steps_per_subagent = metadata_usize(
        metadata,
        &[
            "swarm_max_steps_per_subagent",
            "max_steps_per_subagent",
            "max_steps",
        ],
    )
    .unwrap_or(50)
    .clamp(1, 200);
    let timeout_secs = metadata_u64(metadata, &["swarm_timeout_secs", "timeout_secs", "timeout"])
        .unwrap_or(600)
        .clamp(30, 3600);
    let parallel_enabled =
        metadata_bool(metadata, &["swarm_parallel_enabled", "parallel_enabled"]).unwrap_or(true);

    // Resolve the model (explicit request first, then tier default) and pin
    // the choice on the session so it is persisted.
    let model = resolve_swarm_model(explicit_model, model_tier).await;
    if let Some(ref selected_model) = model {
        session.metadata.model = Some(selected_model.clone());
    }

    // Surface the routing decision and effective config to the output stream.
    if let Some(cb) = output_callback {
        cb(format!(
            "[swarm] routing complexity={} tier={} personality={} target_agent={}",
            complexity_hint.unwrap_or("standard"),
            model_tier.unwrap_or("balanced"),
            worker_personality.unwrap_or("auto"),
            target_agent_name.unwrap_or("auto")
        ));
        cb(format!(
            "[swarm] config strategy={:?} max_subagents={} max_steps={} timeout={}s tier={}",
            strategy,
            max_subagents,
            max_steps_per_subagent,
            timeout_secs,
            model_tier.unwrap_or("balanced")
        ));
    }

    let swarm_config = SwarmConfig {
        max_subagents,
        max_steps_per_subagent,
        subagent_timeout_secs: timeout_secs,
        parallel_enabled,
        model,
        working_dir: session
            .metadata
            .directory
            .as_ref()
            .map(|p| p.to_string_lossy().to_string()),
        ..Default::default()
    };

    // With a callback: run the executor on a spawned task and pump its event
    // channel concurrently so progress streams in real time. Without one,
    // just run it inline.
    let swarm_result = if output_callback.is_some() {
        let (event_tx, mut event_rx) = mpsc::channel(256);
        let mut executor = SwarmExecutor::new(swarm_config).with_event_tx(event_tx);
        if let Some(bus) = bus {
            executor = executor.with_bus(Arc::clone(bus));
        }
        let prompt_owned = prompt.to_string();
        let mut exec_handle =
            tokio::spawn(async move { executor.execute(&prompt_owned, strategy).await });

        let mut final_result: Option<crate::swarm::SwarmResult> = None;

        // Forward events until the executor's join handle completes.
        while final_result.is_none() {
            tokio::select! {
                maybe_event = event_rx.recv() => {
                    if let Some(event) = maybe_event {
                        if let Some(cb) = output_callback {
                            if let Some(line) = format_swarm_event_for_output(&event) {
                                cb(line);
                            }
                        }
                    }
                }
                join_result = &mut exec_handle => {
                    let joined = join_result.map_err(|e| anyhow::anyhow!("Swarm join failure: {}", e))?;
                    final_result = Some(joined?);
                }
            }
        }

        // Drain any events still buffered when the executor finished.
        while let Ok(event) = event_rx.try_recv() {
            if let Some(cb) = output_callback {
                if let Some(line) = format_swarm_event_for_output(&event) {
                    cb(line);
                }
            }
        }

        final_result.ok_or_else(|| anyhow::anyhow!("Swarm execution returned no result"))?
    } else {
        let mut executor = SwarmExecutor::new(swarm_config);
        if let Some(bus) = bus {
            executor = executor.with_bus(Arc::clone(bus));
        }
        executor.execute(prompt, strategy).await?
    };

    // Guarantee some textual output to report, even on an empty swarm result.
    let final_text = if swarm_result.result.trim().is_empty() {
        if swarm_result.success {
            "Swarm completed without textual output.".to_string()
        } else {
            "Swarm finished with failures and no textual output.".to_string()
        }
    } else {
        swarm_result.result.clone()
    };

    // Persist the assistant's final answer on the session before returning.
    session.add_message(Message {
        role: Role::Assistant,
        content: vec![ContentPart::Text {
            text: final_text.clone(),
        }],
    });
    session.save().await?;

    Ok((
        crate::session::SessionResult {
            text: final_text,
            session_id: session.id.clone(),
        },
        swarm_result.success,
    ))
}
1581
1582/// Execute a session with the given auto-approve policy
1583/// Optionally streams output chunks via the callback
1584async fn execute_session_with_policy<F>(
1585    session: &mut Session,
1586    prompt: &str,
1587    auto_approve: AutoApprove,
1588    model_tier: Option<&str>,
1589    output_callback: Option<&F>,
1590) -> Result<crate::session::SessionResult>
1591where
1592    F: Fn(String),
1593{
1594    use crate::provider::{
1595        CompletionRequest, ContentPart, Message, ProviderRegistry, Role, parse_model_string,
1596    };
1597    use std::sync::Arc;
1598
1599    // Load provider registry from Vault
1600    let registry = ProviderRegistry::from_vault().await?;
1601    let providers = registry.list();
1602    tracing::info!("Available providers: {:?}", providers);
1603
1604    if providers.is_empty() {
1605        anyhow::bail!("No providers available. Configure API keys in HashiCorp Vault.");
1606    }
1607
1608    // Parse model string
1609    let (provider_name, model_id) = if let Some(ref model_str) = session.metadata.model {
1610        let (prov, model) = parse_model_string(model_str);
1611        let prov = prov.map(|p| if p == "zhipuai" { "zai" } else { p });
1612        if prov.is_some() {
1613            (prov.map(|s| s.to_string()), model.to_string())
1614        } else if providers.contains(&model) {
1615            (Some(model.to_string()), String::new())
1616        } else {
1617            (None, model.to_string())
1618        }
1619    } else {
1620        (None, String::new())
1621    };
1622
1623    let provider_slice = providers.as_slice();
1624    let provider_requested_but_unavailable = provider_name
1625        .as_deref()
1626        .map(|p| !providers.contains(&p))
1627        .unwrap_or(false);
1628
1629    // Determine which provider to use, preferring explicit request first, then model tier.
1630    let selected_provider = provider_name
1631        .as_deref()
1632        .filter(|p| providers.contains(p))
1633        .unwrap_or_else(|| choose_provider_for_tier(provider_slice, model_tier));
1634
1635    let provider = registry
1636        .get(selected_provider)
1637        .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;
1638
1639    // Add user message
1640    session.add_message(Message {
1641        role: Role::User,
1642        content: vec![ContentPart::Text {
1643            text: prompt.to_string(),
1644        }],
1645    });
1646
1647    // Generate title
1648    if session.title.is_none() {
1649        session.generate_title().await?;
1650    }
1651
1652    // Determine model. If a specific provider was requested but not available,
1653    // ignore that model id and fall back to the tier-based default model.
1654    let model = if !model_id.is_empty() && !provider_requested_but_unavailable {
1655        model_id
1656    } else {
1657        default_model_for_provider(selected_provider, model_tier)
1658    };
1659
1660    // Create tool registry with filtering based on auto-approve policy
1661    let tool_registry =
1662        create_filtered_registry(Arc::clone(&provider), model.clone(), auto_approve);
1663    let tool_definitions = tool_registry.definitions();
1664
1665    let temperature = if prefers_temperature_one(&model) {
1666        Some(1.0)
1667    } else {
1668        Some(0.7)
1669    };
1670
1671    tracing::info!(
1672        "Using model: {} via provider: {} (tier: {:?})",
1673        model,
1674        selected_provider,
1675        model_tier
1676    );
1677    tracing::info!(
1678        "Available tools: {} (auto_approve: {:?})",
1679        tool_definitions.len(),
1680        auto_approve
1681    );
1682
1683    // Build system prompt
1684    let cwd = std::env::var("PWD")
1685        .map(std::path::PathBuf::from)
1686        .unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
1687    let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);
1688
1689    let mut final_output = String::new();
1690    let max_steps = 50;
1691
1692    for step in 1..=max_steps {
1693        tracing::info!(step = step, "Agent step starting");
1694
1695        // Build messages with system prompt first
1696        let mut messages = vec![Message {
1697            role: Role::System,
1698            content: vec![ContentPart::Text {
1699                text: system_prompt.clone(),
1700            }],
1701        }];
1702        messages.extend(session.messages.clone());
1703
1704        let request = CompletionRequest {
1705            messages,
1706            tools: tool_definitions.clone(),
1707            model: model.clone(),
1708            temperature,
1709            top_p: None,
1710            max_tokens: Some(8192),
1711            stop: Vec::new(),
1712        };
1713
1714        let response = provider.complete(request).await?;
1715
1716        crate::telemetry::TOKEN_USAGE.record_model_usage(
1717            &model,
1718            response.usage.prompt_tokens as u64,
1719            response.usage.completion_tokens as u64,
1720        );
1721
1722        // Extract tool calls
1723        let tool_calls: Vec<(String, String, serde_json::Value)> = response
1724            .message
1725            .content
1726            .iter()
1727            .filter_map(|part| {
1728                if let ContentPart::ToolCall {
1729                    id,
1730                    name,
1731                    arguments,
1732                } = part
1733                {
1734                    let args: serde_json::Value =
1735                        serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
1736                    Some((id.clone(), name.clone(), args))
1737                } else {
1738                    None
1739                }
1740            })
1741            .collect();
1742
1743        // Collect text output and stream it
1744        for part in &response.message.content {
1745            if let ContentPart::Text { text } = part {
1746                if !text.is_empty() {
1747                    final_output.push_str(text);
1748                    final_output.push('\n');
1749                    if let Some(cb) = output_callback {
1750                        cb(text.clone());
1751                    }
1752                }
1753            }
1754        }
1755
1756        // If no tool calls, we're done
1757        if tool_calls.is_empty() {
1758            session.add_message(response.message.clone());
1759            break;
1760        }
1761
1762        session.add_message(response.message.clone());
1763
1764        tracing::info!(
1765            step = step,
1766            num_tools = tool_calls.len(),
1767            "Executing tool calls"
1768        );
1769
1770        // Execute each tool call
1771        for (tool_id, tool_name, tool_input) in tool_calls {
1772            tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");
1773
1774            // Stream tool start event
1775            if let Some(cb) = output_callback {
1776                cb(format!("[tool:start:{}]", tool_name));
1777            }
1778
1779            // Check if tool is allowed based on auto-approve policy
1780            if !is_tool_allowed(&tool_name, auto_approve) {
1781                let msg = format!(
1782                    "Tool '{}' requires approval but auto-approve policy is {:?}",
1783                    tool_name, auto_approve
1784                );
1785                tracing::warn!(tool = %tool_name, "Tool blocked by auto-approve policy");
1786                session.add_message(Message {
1787                    role: Role::Tool,
1788                    content: vec![ContentPart::ToolResult {
1789                        tool_call_id: tool_id,
1790                        content: msg,
1791                    }],
1792                });
1793                continue;
1794            }
1795
1796            let content = if let Some(tool) = tool_registry.get(&tool_name) {
1797                let exec_result: Result<crate::tool::ToolResult> =
1798                    tool.execute(tool_input.clone()).await;
1799                match exec_result {
1800                    Ok(result) => {
1801                        tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
1802                        if let Some(cb) = output_callback {
1803                            let status = if result.success { "ok" } else { "err" };
1804                            cb(format!(
1805                                "[tool:{}:{}] {}",
1806                                tool_name,
1807                                status,
1808                                &result.output[..result.output.len().min(500)]
1809                            ));
1810                        }
1811                        result.output
1812                    }
1813                    Err(e) => {
1814                        tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
1815                        if let Some(cb) = output_callback {
1816                            cb(format!("[tool:{}:err] {}", tool_name, e));
1817                        }
1818                        format!("Error: {}", e)
1819                    }
1820                }
1821            } else {
1822                tracing::warn!(tool = %tool_name, "Tool not found");
1823                format!("Error: Unknown tool '{}'", tool_name)
1824            };
1825
1826            session.add_message(Message {
1827                role: Role::Tool,
1828                content: vec![ContentPart::ToolResult {
1829                    tool_call_id: tool_id,
1830                    content,
1831                }],
1832            });
1833        }
1834    }
1835
1836    session.save().await?;
1837
1838    Ok(crate::session::SessionResult {
1839        text: final_output.trim().to_string(),
1840        session_id: session.id.clone(),
1841    })
1842}
1843
1844/// Check if a tool is allowed based on the auto-approve policy
1845fn is_tool_allowed(tool_name: &str, auto_approve: AutoApprove) -> bool {
1846    match auto_approve {
1847        AutoApprove::All => true,
1848        AutoApprove::Safe | AutoApprove::None => is_safe_tool(tool_name),
1849    }
1850}
1851
/// Check if a tool is considered "safe" (read-only).
///
/// Safe tools only inspect state (files, search, web, todos, skills) and are
/// therefore allowed under every auto-approve policy.
fn is_safe_tool(tool_name: &str) -> bool {
    matches!(
        tool_name,
        "read"
            | "list"
            | "glob"
            | "grep"
            | "codesearch"
            | "lsp"
            | "webfetch"
            | "websearch"
            | "todo_read"
            | "skill"
    )
}
1868
1869/// Create a filtered tool registry based on the auto-approve policy
1870fn create_filtered_registry(
1871    provider: Arc<dyn crate::provider::Provider>,
1872    model: String,
1873    auto_approve: AutoApprove,
1874) -> crate::tool::ToolRegistry {
1875    use crate::tool::*;
1876
1877    let mut registry = ToolRegistry::new();
1878
1879    // Always add safe tools
1880    registry.register(Arc::new(file::ReadTool::new()));
1881    registry.register(Arc::new(file::ListTool::new()));
1882    registry.register(Arc::new(file::GlobTool::new()));
1883    registry.register(Arc::new(search::GrepTool::new()));
1884    registry.register(Arc::new(lsp::LspTool::new()));
1885    registry.register(Arc::new(webfetch::WebFetchTool::new()));
1886    registry.register(Arc::new(websearch::WebSearchTool::new()));
1887    registry.register(Arc::new(codesearch::CodeSearchTool::new()));
1888    registry.register(Arc::new(todo::TodoReadTool::new()));
1889    registry.register(Arc::new(skill::SkillTool::new()));
1890
1891    // Add potentially dangerous tools only if auto_approve is All
1892    if matches!(auto_approve, AutoApprove::All) {
1893        registry.register(Arc::new(file::WriteTool::new()));
1894        registry.register(Arc::new(edit::EditTool::new()));
1895        registry.register(Arc::new(bash::BashTool::new()));
1896        registry.register(Arc::new(multiedit::MultiEditTool::new()));
1897        registry.register(Arc::new(patch::ApplyPatchTool::new()));
1898        registry.register(Arc::new(todo::TodoWriteTool::new()));
1899        registry.register(Arc::new(task::TaskTool::new()));
1900        registry.register(Arc::new(plan::PlanEnterTool::new()));
1901        registry.register(Arc::new(plan::PlanExitTool::new()));
1902        registry.register(Arc::new(rlm::RlmTool::new(Arc::clone(&provider), model.clone())));
1903        registry.register(Arc::new(ralph::RalphTool::with_provider(provider, model)));
1904        registry.register(Arc::new(prd::PrdTool::new()));
1905        registry.register(Arc::new(confirm_edit::ConfirmEditTool::new()));
1906        registry.register(Arc::new(confirm_multiedit::ConfirmMultiEditTool::new()));
1907        registry.register(Arc::new(undo::UndoTool));
1908        registry.register(Arc::new(mcp_bridge::McpBridgeTool::new()));
1909    }
1910
1911    registry.register(Arc::new(invalid::InvalidTool::new()));
1912
1913    registry
1914}
1915
/// Start the heartbeat background task.
///
/// Every `HEARTBEAT_INTERVAL_SECS` the task reports this worker's status
/// (`idle`/`processing`) and active task count to
/// `{server}/v1/opencode/workers/{worker_id}/heartbeat`, with an optional
/// bearer `token`. When cognition reporting is enabled, a `cognition` object
/// is attached to the payload; if the server rejects that enriched payload
/// with a 4xx, the heartbeat is resent without it and the cognition payload
/// is suppressed for `COGNITION_RETRY_COOLDOWN_SECS`.
///
/// Returns a JoinHandle that can be used to cancel the heartbeat. The task
/// never exits on its own: after `MAX_FAILURES` consecutive failures it logs
/// an error and resets the counter, but keeps looping so the worker can
/// recover.
fn start_heartbeat(
    client: Client,
    server: String,
    token: Option<String>,
    heartbeat_state: HeartbeatState,
    processing: Arc<Mutex<HashSet<String>>>,
    cognition_config: CognitionHeartbeatConfig,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        let mut consecutive_failures = 0u32;
        const MAX_FAILURES: u32 = 3;
        const HEARTBEAT_INTERVAL_SECS: u64 = 30;
        const COGNITION_RETRY_COOLDOWN_SECS: u64 = 300;
        // While set, the cognition payload is omitted until this deadline passes.
        let mut cognition_payload_disabled_until: Option<Instant> = None;

        let mut interval =
            tokio::time::interval(tokio::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
        // Skip (rather than burst) any ticks missed while a send was in flight.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        loop {
            interval.tick().await;

            // Update task count from processing set
            let active_count = processing.lock().await.len();
            heartbeat_state.set_task_count(active_count).await;

            // Determine status based on active tasks
            let status = if active_count > 0 {
                WorkerStatus::Processing
            } else {
                WorkerStatus::Idle
            };
            heartbeat_state.set_status(status).await;

            // Send heartbeat
            let url = format!(
                "{}/v1/opencode/workers/{}/heartbeat",
                server, heartbeat_state.worker_id
            );
            let mut req = client.post(&url);

            if let Some(ref t) = token {
                req = req.bearer_auth(t);
            }

            let status_str = heartbeat_state.status.lock().await.as_str().to_string();
            // Kept separate from `payload` so a retry can fall back to it verbatim.
            let base_payload = serde_json::json!({
                "worker_id": &heartbeat_state.worker_id,
                "agent_name": &heartbeat_state.agent_name,
                "status": status_str,
                "active_task_count": active_count,
            });
            let mut payload = base_payload.clone();
            let mut included_cognition_payload = false;
            let cognition_payload_allowed = cognition_payload_disabled_until
                .map(|until| Instant::now() >= until)
                .unwrap_or(true);

            // Attach the cognition section only when enabled, not in cooldown,
            // and the local cognition service actually produced a snapshot.
            if cognition_config.enabled
                && cognition_payload_allowed
                && let Some(cognition_payload) =
                    fetch_cognition_heartbeat_payload(&client, &cognition_config).await
                && let Some(obj) = payload.as_object_mut()
            {
                obj.insert("cognition".to_string(), cognition_payload);
                included_cognition_payload = true;
            }

            match req.json(&payload).send().await {
                Ok(res) => {
                    if res.status().is_success() {
                        consecutive_failures = 0;
                        tracing::debug!(
                            worker_id = %heartbeat_state.worker_id,
                            status = status_str,
                            active_tasks = active_count,
                            "Heartbeat sent successfully"
                        );
                    } else if included_cognition_payload && res.status().is_client_error() {
                        // A 4xx while the cognition payload was attached most
                        // likely means the server rejected the extended schema:
                        // retry once with the plain base payload.
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            "Heartbeat cognition payload rejected, retrying without cognition payload"
                        );

                        let mut retry_req = client.post(&url);
                        if let Some(ref t) = token {
                            retry_req = retry_req.bearer_auth(t);
                        }

                        match retry_req.json(&base_payload).send().await {
                            Ok(retry_res) if retry_res.status().is_success() => {
                                // Plain payload worked, so pause the cognition
                                // section for a cooldown instead of failing on
                                // every subsequent tick.
                                cognition_payload_disabled_until = Some(
                                    Instant::now()
                                        + Duration::from_secs(COGNITION_RETRY_COOLDOWN_SECS),
                                );
                                consecutive_failures = 0;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    retry_after_secs = COGNITION_RETRY_COOLDOWN_SECS,
                                    "Paused cognition heartbeat payload after schema rejection"
                                );
                            }
                            Ok(retry_res) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    status = %retry_res.status(),
                                    failures = consecutive_failures,
                                    "Heartbeat failed even after retry without cognition payload"
                                );
                            }
                            Err(e) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    error = %e,
                                    failures = consecutive_failures,
                                    "Heartbeat retry without cognition payload failed"
                                );
                            }
                        }
                    } else {
                        consecutive_failures += 1;
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            failures = consecutive_failures,
                            "Heartbeat failed"
                        );
                    }
                }
                Err(e) => {
                    consecutive_failures += 1;
                    tracing::warn!(
                        worker_id = %heartbeat_state.worker_id,
                        error = %e,
                        failures = consecutive_failures,
                        "Heartbeat request failed"
                    );
                }
            }

            // Log error after 3 consecutive failures but do not terminate
            if consecutive_failures >= MAX_FAILURES {
                tracing::error!(
                    worker_id = %heartbeat_state.worker_id,
                    failures = consecutive_failures,
                    "Heartbeat failed {} consecutive times - worker will continue running and attempt reconnection via SSE loop",
                    MAX_FAILURES
                );
                // Reset counter to avoid spamming error logs
                consecutive_failures = 0;
            }
        }
    })
}
2075
/// Query the local cognition service and build the optional `cognition`
/// section of a heartbeat payload.
///
/// Best-effort by design: any timeout, transport error, non-2xx response, or
/// deserialization failure yields `None` so that a broken cognition endpoint
/// never blocks the heartbeat itself. Both requests are bounded by
/// `config.request_timeout_ms`.
async fn fetch_cognition_heartbeat_payload(
    client: &Client,
    config: &CognitionHeartbeatConfig,
) -> Option<serde_json::Value> {
    let status_url = format!("{}/v1/cognition/status", config.source_base_url);
    // Outer `?` handles the timeout, inner `?` the reqwest result.
    let status_res = tokio::time::timeout(
        Duration::from_millis(config.request_timeout_ms),
        client.get(status_url).send(),
    )
    .await
    .ok()?
    .ok()?;

    if !status_res.status().is_success() {
        return None;
    }

    let status: CognitionStatusSnapshot = status_res.json().await.ok()?;
    let mut payload = serde_json::json!({
        "running": status.running,
        "last_tick_at": status.last_tick_at,
        "active_persona_count": status.active_persona_count,
        "events_buffered": status.events_buffered,
        "snapshots_buffered": status.snapshots_buffered,
        "loop_interval_ms": status.loop_interval_ms,
    });

    if config.include_thought_summary {
        // Optionally enrich with the most recent thought snapshot; failures
        // here simply leave the status-only payload intact.
        let snapshot_url = format!("{}/v1/cognition/snapshots/latest", config.source_base_url);
        let snapshot_res = tokio::time::timeout(
            Duration::from_millis(config.request_timeout_ms),
            client.get(snapshot_url).send(),
        )
        .await
        .ok()
        .and_then(Result::ok);

        if let Some(snapshot_res) = snapshot_res
            && snapshot_res.status().is_success()
            && let Ok(snapshot) = snapshot_res.json::<CognitionLatestSnapshot>().await
            && let Some(obj) = payload.as_object_mut()
        {
            obj.insert(
                "latest_snapshot_at".to_string(),
                serde_json::Value::String(snapshot.generated_at),
            );
            // Thought text is truncated so the heartbeat stays small.
            obj.insert(
                "latest_thought".to_string(),
                serde_json::Value::String(trim_for_heartbeat(
                    &snapshot.summary,
                    config.summary_max_chars,
                )),
            );
            if let Some(model) = snapshot
                .metadata
                .get("model")
                .and_then(serde_json::Value::as_str)
            {
                obj.insert(
                    "latest_thought_model".to_string(),
                    serde_json::Value::String(model.to_string()),
                );
            }
            if let Some(source) = snapshot
                .metadata
                .get("source")
                .and_then(serde_json::Value::as_str)
            {
                obj.insert(
                    "latest_thought_source".to_string(),
                    serde_json::Value::String(source.to_string()),
                );
            }
        }
    }

    Some(payload)
}
2154
/// Limit `input` to at most `max_chars` characters (Unicode scalar values)
/// for inclusion in a heartbeat payload, appending "..." when truncation
/// happened. The result is whitespace-trimmed at both ends either way.
fn trim_for_heartbeat(input: &str, max_chars: usize) -> String {
    if input.chars().count() <= max_chars {
        return input.trim().to_string();
    }
    let truncated: String = input.chars().take(max_chars).collect();
    format!("{truncated}...").trim().to_string()
}
2167
/// Read a boolean flag from the environment variable `name`.
///
/// Recognizes "1"/"true"/"yes"/"on" as true and "0"/"false"/"no"/"off" as
/// false (case-insensitive); anything else — or an unset variable — yields
/// `default`.
fn env_bool(name: &str, default: bool) -> bool {
    let Ok(raw) = std::env::var(name) else {
        return default;
    };
    match raw.to_ascii_lowercase().as_str() {
        "1" | "true" | "yes" | "on" => true,
        "0" | "false" | "no" | "off" => false,
        _ => default,
    }
}
2178
/// Read a `usize` from the environment variable `name`, falling back to
/// `default` when the variable is unset or does not parse.
fn env_usize(name: &str, default: usize) -> usize {
    match std::env::var(name) {
        Ok(raw) => raw.parse().unwrap_or(default),
        Err(_) => default,
    }
}
2185
/// Read a `u64` from the environment variable `name`, falling back to
/// `default` when the variable is unset or does not parse.
fn env_u64(name: &str, default: u64) -> u64 {
    match std::env::var(name) {
        Ok(raw) => raw.parse().unwrap_or(default),
        Err(_) => default,
    }
}