codetether_agent/a2a/worker.rs

//! A2A Worker - connects to an A2A server to process tasks

use crate::cli::A2aArgs;
use crate::provider::ProviderRegistry;
use crate::session::Session;
use crate::swarm::{DecompositionStrategy, SwarmConfig, SwarmExecutor};
use crate::tui::swarm_view::SwarmEvent;
use anyhow::Result;
use futures::StreamExt;
use reqwest::Client;
use serde::Deserialize;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, mpsc};
use tokio::task::JoinHandle;
use tokio::time::Instant;

/// Worker status for heartbeat
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WorkerStatus {
    Idle,
    Processing,
}

impl WorkerStatus {
    fn as_str(&self) -> &'static str {
        match self {
            WorkerStatus::Idle => "idle",
            WorkerStatus::Processing => "processing",
        }
    }
}

/// Heartbeat state shared between the heartbeat task and the main worker
#[derive(Clone)]
struct HeartbeatState {
    worker_id: String,
    agent_name: String,
    status: Arc<Mutex<WorkerStatus>>,
    active_task_count: Arc<Mutex<usize>>,
}

impl HeartbeatState {
    fn new(worker_id: String, agent_name: String) -> Self {
        Self {
            worker_id,
            agent_name,
            status: Arc::new(Mutex::new(WorkerStatus::Idle)),
            active_task_count: Arc::new(Mutex::new(0)),
        }
    }

    async fn set_status(&self, status: WorkerStatus) {
        *self.status.lock().await = status;
    }

    async fn set_task_count(&self, count: usize) {
        *self.active_task_count.lock().await = count;
    }
}

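/// Configuration for sharing local cognition status with heartbeats,
/// read from the CODETETHER_WORKER_COGNITION_* environment variables.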
#[derive(Clone, Debug)]
struct CognitionHeartbeatConfig {
    enabled: bool,
    source_base_url: String,
    include_thought_summary: bool,
    summary_max_chars: usize,
    request_timeout_ms: u64,
}

impl CognitionHeartbeatConfig {
    fn from_env() -> Self {
        let source_base_url = std::env::var("CODETETHER_WORKER_COGNITION_SOURCE_URL")
            .unwrap_or_else(|_| "http://127.0.0.1:4096".to_string())
            .trim_end_matches('/')
            .to_string();

        Self {
            enabled: env_bool("CODETETHER_WORKER_COGNITION_SHARE_ENABLED", true),
            source_base_url,
            include_thought_summary: env_bool("CODETETHER_WORKER_COGNITION_INCLUDE_THOUGHTS", true),
            summary_max_chars: env_usize("CODETETHER_WORKER_COGNITION_THOUGHT_MAX_CHARS", 480)
                .max(120),
            request_timeout_ms: env_u64("CODETETHER_WORKER_COGNITION_TIMEOUT_MS", 2_500).max(250),
        }
    }
}

#[derive(Debug, Deserialize)]
struct CognitionStatusSnapshot {
    running: bool,
    #[serde(default)]
    last_tick_at: Option<String>,
    #[serde(default)]
    active_persona_count: usize,
    #[serde(default)]
    events_buffered: usize,
    #[serde(default)]
    snapshots_buffered: usize,
    #[serde(default)]
    loop_interval_ms: u64,
}

#[derive(Debug, Deserialize)]
struct CognitionLatestSnapshot {
    generated_at: String,
    summary: String,
    #[serde(default)]
    metadata: HashMap<String, serde_json::Value>,
}

/// Run the A2A worker
pub async fn run(args: A2aArgs) -> Result<()> {
    let server = args.server.trim_end_matches('/');
    let name = args
        .name
        .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
    let worker_id = generate_worker_id();

    let codebases: Vec<String> = args
        .codebases
        .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
        .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);

    tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
    tracing::info!("Server: {}", server);
    tracing::info!("Codebases: {:?}", codebases);

    let client = Client::new();
    let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
    let cognition_heartbeat = CognitionHeartbeatConfig::from_env();
    if cognition_heartbeat.enabled {
        tracing::info!(
            source = %cognition_heartbeat.source_base_url,
            include_thoughts = cognition_heartbeat.include_thought_summary,
            max_chars = cognition_heartbeat.summary_max_chars,
            timeout_ms = cognition_heartbeat.request_timeout_ms,
            "Cognition heartbeat sharing enabled (set CODETETHER_WORKER_COGNITION_SHARE_ENABLED=false to disable)"
        );
    } else {
        tracing::warn!(
            "Cognition heartbeat sharing disabled; worker thought state will not be shared upstream"
        );
    }

    let auto_approve = match args.auto_approve.as_str() {
        "all" => AutoApprove::All,
        "safe" => AutoApprove::Safe,
        _ => AutoApprove::None,
    };

    // Create heartbeat state
    let heartbeat_state = HeartbeatState::new(worker_id.clone(), name.clone());

    // Register worker
    register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;

    // Fetch pending tasks
    fetch_pending_tasks(
        &client,
        server,
        &args.token,
        &worker_id,
        &processing,
        &auto_approve,
    )
    .await?;

    // Connect to SSE stream
    loop {
        // Re-register worker on each reconnection to report updated models/capabilities
        if let Err(e) =
            register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await
        {
            tracing::warn!("Failed to re-register worker on reconnection: {}", e);
        }

        // Start heartbeat task for this connection
        let heartbeat_handle = start_heartbeat(
            client.clone(),
            server.to_string(),
            args.token.clone(),
            heartbeat_state.clone(),
            processing.clone(),
            cognition_heartbeat.clone(),
        );

        match connect_stream(
            &client,
            server,
            &args.token,
            &worker_id,
            &name,
            &codebases,
            &processing,
            &auto_approve,
        )
        .await
        {
            Ok(()) => {
                tracing::warn!("Stream ended, reconnecting...");
            }
            Err(e) => {
                tracing::error!("Stream error: {}, reconnecting...", e);
            }
        }

        // Cancel heartbeat on disconnection
        heartbeat_handle.abort();
        tracing::debug!("Heartbeat cancelled for reconnection");

        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
    }
}

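/// Generate a unique worker ID from the current timestamp and a random suffix.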
fn generate_worker_id() -> String {
    format!(
        "wrk_{}_{:x}",
        chrono::Utc::now().timestamp(),
        rand::random::<u64>()
    )
}

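/// Tool auto-approval policy derived from the worker's `auto_approve` argument.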
#[derive(Debug, Clone, Copy)]
enum AutoApprove {
    All,
    Safe,
    None,
}

/// Capabilities of the codetether-agent worker
const WORKER_CAPABILITIES: &[&str] = &["ralph", "swarm", "rlm", "a2a", "mcp"];

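/// Look up a field on a task payload, checking the nested `task` object first
/// and falling back to the top level.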
fn task_value<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a serde_json::Value> {
    task.get("task")
        .and_then(|t| t.get(key))
        .or_else(|| task.get(key))
}

fn task_str<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a str> {
    task_value(task, key).and_then(|v| v.as_str())
}

fn task_metadata(task: &serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
    task_value(task, "metadata")
        .and_then(|m| m.as_object())
        .cloned()
        .unwrap_or_default()
}

fn model_ref_to_provider_model(model: &str) -> String {
    // Convert "provider:model" to "provider/model" format, but only if
    // there is no '/' already present. Model IDs like "amazon.nova-micro-v1:0"
    // contain colons as version separators and must NOT be converted.
    if !model.contains('/') && model.contains(':') {
        model.replacen(':', "/", 1)
    } else {
        model.to_string()
    }
}

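/// Ordered provider preference list for a model tier
/// ("fast"/"quick", "heavy"/"deep", or balanced by default).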
fn provider_preferences_for_tier(model_tier: Option<&str>) -> &'static [&'static str] {
    match model_tier.unwrap_or("balanced") {
        "fast" | "quick" => &[
            "openai",
            "github-copilot",
            "moonshotai",
            "zhipuai",
            "openrouter",
            "novita",
            "google",
            "anthropic",
        ],
        "heavy" | "deep" => &[
            "anthropic",
            "openai",
            "github-copilot",
            "moonshotai",
            "zhipuai",
            "openrouter",
            "novita",
            "google",
        ],
        _ => &[
            "openai",
            "github-copilot",
            "anthropic",
            "moonshotai",
            "zhipuai",
            "openrouter",
            "novita",
            "google",
        ],
    }
}

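/// Pick the first preferred provider for the tier that is actually available,
/// falling back to "zhipuai" and finally to the first configured provider.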
fn choose_provider_for_tier<'a>(providers: &'a [&'a str], model_tier: Option<&str>) -> &'a str {
    for preferred in provider_preferences_for_tier(model_tier) {
        if let Some(found) = providers.iter().copied().find(|p| *p == *preferred) {
            return found;
        }
    }
    if let Some(found) = providers.iter().copied().find(|p| *p == "zhipuai") {
        return found;
    }
    providers[0]
}

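/// Default model ID for a provider at the requested tier.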
fn default_model_for_provider(provider: &str, model_tier: Option<&str>) -> String {
    match model_tier.unwrap_or("balanced") {
        "fast" | "quick" => match provider {
            "moonshotai" => "kimi-k2.5".to_string(),
            "anthropic" => "claude-haiku-4-5".to_string(),
            "openai" => "gpt-4o-mini".to_string(),
            "google" => "gemini-2.5-flash".to_string(),
            "zhipuai" => "glm-4.7".to_string(),
            "openrouter" => "z-ai/glm-4.7".to_string(),
            "novita" => "qwen/qwen3-coder-next".to_string(),
            _ => "glm-4.7".to_string(),
        },
        "heavy" | "deep" => match provider {
            "moonshotai" => "kimi-k2.5".to_string(),
            "anthropic" => "claude-sonnet-4-20250514".to_string(),
            "openai" => "o3".to_string(),
            "google" => "gemini-2.5-pro".to_string(),
            "zhipuai" => "glm-4.7".to_string(),
            "openrouter" => "z-ai/glm-4.7".to_string(),
            "novita" => "qwen/qwen3-coder-next".to_string(),
            _ => "glm-4.7".to_string(),
        },
        _ => match provider {
            "moonshotai" => "kimi-k2.5".to_string(),
            "anthropic" => "claude-sonnet-4-20250514".to_string(),
            "openai" => "gpt-4o".to_string(),
            "google" => "gemini-2.5-pro".to_string(),
            "zhipuai" => "glm-4.7".to_string(),
            "openrouter" => "z-ai/glm-4.7".to_string(),
            "novita" => "qwen/qwen3-coder-next".to_string(),
            _ => "glm-4.7".to_string(),
        },
    }
}

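/// Whether the task's agent type should be executed by the swarm executor.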
fn is_swarm_agent(agent_type: &str) -> bool {
    matches!(
        agent_type.trim().to_ascii_lowercase().as_str(),
        "swarm" | "parallel" | "multi-agent"
    )
}

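/// Look up a metadata key at the top level, then under `routing`, then under `swarm`.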
fn metadata_lookup<'a>(
    metadata: &'a serde_json::Map<String, serde_json::Value>,
    key: &str,
) -> Option<&'a serde_json::Value> {
    metadata
        .get(key)
        .or_else(|| {
            metadata
                .get("routing")
                .and_then(|v| v.as_object())
                .and_then(|obj| obj.get(key))
        })
        .or_else(|| {
            metadata
                .get("swarm")
                .and_then(|v| v.as_object())
                .and_then(|obj| obj.get(key))
        })
}

fn metadata_str(
    metadata: &serde_json::Map<String, serde_json::Value>,
    keys: &[&str],
) -> Option<String> {
    for key in keys {
        if let Some(value) = metadata_lookup(metadata, key).and_then(|v| v.as_str()) {
            let trimmed = value.trim();
            if !trimmed.is_empty() {
                return Some(trimmed.to_string());
            }
        }
    }
    None
}

fn metadata_usize(
    metadata: &serde_json::Map<String, serde_json::Value>,
    keys: &[&str],
) -> Option<usize> {
    for key in keys {
        if let Some(value) = metadata_lookup(metadata, key) {
            if let Some(v) = value.as_u64() {
                return usize::try_from(v).ok();
            }
            if let Some(v) = value.as_i64() {
                if v >= 0 {
                    return usize::try_from(v as u64).ok();
                }
            }
            if let Some(v) = value.as_str() {
                if let Ok(parsed) = v.trim().parse::<usize>() {
                    return Some(parsed);
                }
            }
        }
    }
    None
}

fn metadata_u64(
    metadata: &serde_json::Map<String, serde_json::Value>,
    keys: &[&str],
) -> Option<u64> {
    for key in keys {
        if let Some(value) = metadata_lookup(metadata, key) {
            if let Some(v) = value.as_u64() {
                return Some(v);
            }
            if let Some(v) = value.as_i64() {
                if v >= 0 {
                    return Some(v as u64);
                }
            }
            if let Some(v) = value.as_str() {
                if let Ok(parsed) = v.trim().parse::<u64>() {
                    return Some(parsed);
                }
            }
        }
    }
    None
}

fn metadata_bool(
    metadata: &serde_json::Map<String, serde_json::Value>,
    keys: &[&str],
) -> Option<bool> {
    for key in keys {
        if let Some(value) = metadata_lookup(metadata, key) {
            if let Some(v) = value.as_bool() {
                return Some(v);
            }
            if let Some(v) = value.as_str() {
                match v.trim().to_ascii_lowercase().as_str() {
                    "1" | "true" | "yes" | "on" => return Some(true),
                    "0" | "false" | "no" | "off" => return Some(false),
                    _ => {}
                }
            }
        }
    }
    None
}

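/// Parse the swarm decomposition strategy from task metadata, defaulting to Automatic.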
fn parse_swarm_strategy(
    metadata: &serde_json::Map<String, serde_json::Value>,
) -> DecompositionStrategy {
    match metadata_str(
        metadata,
        &[
            "decomposition_strategy",
            "swarm_strategy",
            "strategy",
            "swarm_decomposition",
        ],
    )
    .as_deref()
    .map(|s| s.to_ascii_lowercase())
    .as_deref()
    {
        Some("none") | Some("single") => DecompositionStrategy::None,
        Some("domain") | Some("by_domain") => DecompositionStrategy::ByDomain,
        Some("data") | Some("by_data") => DecompositionStrategy::ByData,
        Some("stage") | Some("by_stage") => DecompositionStrategy::ByStage,
        _ => DecompositionStrategy::Automatic,
    }
}

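/// Resolve the model for a swarm run: prefer an explicit model, otherwise pick a
/// provider/model pair based on the tier and the providers available in Vault.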
async fn resolve_swarm_model(
    explicit_model: Option<String>,
    model_tier: Option<&str>,
) -> Option<String> {
    if let Some(model) = explicit_model {
        if !model.trim().is_empty() {
            return Some(model);
        }
    }

    let registry = ProviderRegistry::from_vault().await.ok()?;
    let providers = registry.list();
    if providers.is_empty() {
        return None;
    }
    let provider = choose_provider_for_tier(providers.as_slice(), model_tier);
    let model = default_model_for_provider(provider, model_tier);
    Some(format!("{}/{}", provider, model))
}

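/// Render a swarm event as a single `[swarm] ...` output line, if it is worth streaming.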
fn format_swarm_event_for_output(event: &SwarmEvent) -> Option<String> {
    match event {
        SwarmEvent::Started {
            task,
            total_subtasks,
        } => Some(format!(
            "[swarm] started task={} planned_subtasks={}",
            task, total_subtasks
        )),
        SwarmEvent::StageComplete {
            stage,
            completed,
            failed,
        } => Some(format!(
            "[swarm] stage={} completed={} failed={}",
            stage, completed, failed
        )),
        SwarmEvent::SubTaskUpdate { id, status, .. } => Some(format!(
            "[swarm] subtask id={} status={}",
            &id.chars().take(8).collect::<String>(),
            format!("{status:?}").to_ascii_lowercase()
        )),
        SwarmEvent::AgentToolCall {
            subtask_id,
            tool_name,
        } => Some(format!(
            "[swarm] subtask id={} tool={}",
            &subtask_id.chars().take(8).collect::<String>(),
            tool_name
        )),
        SwarmEvent::AgentError { subtask_id, error } => Some(format!(
            "[swarm] subtask id={} error={}",
            &subtask_id.chars().take(8).collect::<String>(),
            error
        )),
        SwarmEvent::Complete { success, stats } => Some(format!(
            "[swarm] complete success={} subtasks={} speedup={:.2}",
            success,
            stats.subagents_completed + stats.subagents_failed,
            stats.speedup_factor
        )),
        SwarmEvent::Error(err) => Some(format!("[swarm] error message={}", err)),
        _ => None,
    }
}

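/// Register (or re-register) this worker with the A2A server, advertising
/// capabilities, hostname, codebases, and the models it can serve.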
async fn register_worker(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
) -> Result<()> {
    // Load ProviderRegistry and collect available models
    let models = match load_provider_models().await {
        Ok(m) => m,
        Err(e) => {
            tracing::warn!(
                "Failed to load provider models: {}, proceeding without model info",
                e
            );
            HashMap::new()
        }
    };

    // Register via the workers/register endpoint
    let mut req = client.post(format!("{}/v1/opencode/workers/register", server));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    // Flatten models HashMap into array of {id, name, provider, provider_id} objects
    // matching the format expected by the A2A server's /models and /workers endpoints
    let models_array: Vec<serde_json::Value> = models
        .iter()
        .flat_map(|(provider, model_ids)| {
            model_ids.iter().map(move |model_id| {
                serde_json::json!({
                    "id": format!("{}/{}", provider, model_id),
                    "name": model_id,
                    "provider": provider,
                    "provider_id": provider,
                })
            })
        })
        .collect();

    tracing::info!(
        "Registering worker with {} models from {} providers",
        models_array.len(),
        models.len()
    );

    let hostname = std::env::var("HOSTNAME")
        .or_else(|_| std::env::var("COMPUTERNAME"))
        .unwrap_or_else(|_| "unknown".to_string());

    let res = req
        .json(&serde_json::json!({
            "worker_id": worker_id,
            "name": name,
            "capabilities": WORKER_CAPABILITIES,
            "hostname": hostname,
            "models": models_array,
            "codebases": codebases,
        }))
        .send()
        .await?;

    if res.status().is_success() {
        tracing::info!("Worker registered successfully");
    } else {
        tracing::warn!("Failed to register worker: {}", res.status());
    }

    Ok(())
}

/// Load ProviderRegistry and collect all available models grouped by provider.
/// Tries Vault first, then falls back to config/env vars if Vault is unreachable.
async fn load_provider_models() -> Result<HashMap<String, Vec<String>>> {
    // Try Vault first
    let registry = match ProviderRegistry::from_vault().await {
        Ok(r) if !r.list().is_empty() => {
            tracing::info!("Loaded {} providers from Vault", r.list().len());
            r
        }
        Ok(_) => {
            tracing::warn!("Vault returned 0 providers, falling back to config/env vars");
            fallback_registry().await?
        }
        Err(e) => {
            tracing::warn!("Vault unreachable ({}), falling back to config/env vars", e);
            fallback_registry().await?
        }
    };

    let mut models_by_provider: HashMap<String, Vec<String>> = HashMap::new();

    for provider_name in registry.list() {
        if let Some(provider) = registry.get(provider_name) {
            match provider.list_models().await {
                Ok(models) => {
                    let model_ids: Vec<String> = models.into_iter().map(|m| m.id).collect();
                    if !model_ids.is_empty() {
                        tracing::debug!("Provider {}: {} models", provider_name, model_ids.len());
                        models_by_provider.insert(provider_name.to_string(), model_ids);
                    }
                }
                Err(e) => {
                    tracing::debug!("Failed to list models for {}: {}", provider_name, e);
                }
            }
        }
    }

    // If we still have 0, try the models.dev catalog as last resort
    // NOTE: We list ALL models from the catalog (not just ones with verified API keys)
    // because the worker should advertise what it can handle. The server handles routing.
    if models_by_provider.is_empty() {
        tracing::info!(
            "No authenticated providers found, fetching models.dev catalog (all providers)"
        );
        if let Ok(catalog) = crate::provider::models::ModelCatalog::fetch().await {
            // Use all_providers() to get every provider+model from the catalog
            // regardless of API key availability (Vault may be down)
            for (provider_id, provider_info) in catalog.all_providers() {
                let model_ids: Vec<String> = provider_info
                    .models
                    .values()
                    .map(|m| m.id.clone())
                    .collect();
                if !model_ids.is_empty() {
                    tracing::debug!(
                        "Catalog provider {}: {} models",
                        provider_id,
                        model_ids.len()
                    );
                    models_by_provider.insert(provider_id.clone(), model_ids);
                }
            }
            tracing::info!(
                "Loaded {} providers with {} total models from catalog",
                models_by_provider.len(),
                models_by_provider.values().map(|v| v.len()).sum::<usize>()
            );
        }
    }

    Ok(models_by_provider)
}

/// Fallback: build a ProviderRegistry from config file + environment variables
async fn fallback_registry() -> Result<ProviderRegistry> {
    let config = crate::config::Config::load().await.unwrap_or_default();
    ProviderRegistry::from_config(&config).await
}

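/// Fetch tasks already pending on the server and spawn handlers for any
/// that are not currently being processed.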
async fn fetch_pending_tasks(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
) -> Result<()> {
    tracing::info!("Checking for pending tasks...");

    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        return Ok(());
    }

    let data: serde_json::Value = res.json().await?;
    // Handle both plain array response and {tasks: [...]} wrapper
    let tasks = if let Some(arr) = data.as_array() {
        arr.clone()
    } else {
        data["tasks"].as_array().cloned().unwrap_or_default()
    };

    tracing::info!("Found {} pending task(s)", tasks.len());

    for task in tasks {
        if let Some(id) = task["id"].as_str() {
            let mut proc = processing.lock().await;
            if !proc.contains(id) {
                proc.insert(id.to_string());
                drop(proc);

                let task_id = id.to_string();
                let client = client.clone();
                let server = server.to_string();
                let token = token.clone();
                let worker_id = worker_id.to_string();
                let auto_approve = *auto_approve;
                let processing = processing.clone();

                tokio::spawn(async move {
                    if let Err(e) =
                        handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await
                    {
                        tracing::error!("Task {} failed: {}", task_id, e);
                    }
                    processing.lock().await.remove(&task_id);
                });
            }
        }
    }

    Ok(())
}

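/// Connect to the server's SSE task stream and dispatch incoming tasks,
/// polling for pending tasks periodically as a fallback.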
#[allow(clippy::too_many_arguments)]
async fn connect_stream(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
) -> Result<()> {
    let url = format!(
        "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
        server,
        urlencoding::encode(name),
        urlencoding::encode(worker_id)
    );

    let mut req = client
        .get(&url)
        .header("Accept", "text/event-stream")
        .header("X-Worker-ID", worker_id)
        .header("X-Agent-Name", name)
        .header("X-Codebases", codebases.join(","));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        anyhow::bail!("Failed to connect: {}", res.status());
    }

    tracing::info!("Connected to A2A server");

    let mut stream = res.bytes_stream();
    let mut buffer = String::new();
    let mut poll_interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
    poll_interval.tick().await; // consume the initial immediate tick

    loop {
        tokio::select! {
            chunk = stream.next() => {
                match chunk {
                    Some(Ok(chunk)) => {
                        buffer.push_str(&String::from_utf8_lossy(&chunk));

                        // Process SSE events
                        while let Some(pos) = buffer.find("\n\n") {
                            let event_str = buffer[..pos].to_string();
                            buffer = buffer[pos + 2..].to_string();

                            if let Some(data_line) = event_str.lines().find(|l| l.starts_with("data:")) {
                                let data = data_line.trim_start_matches("data:").trim();
                                if data == "[DONE]" || data.is_empty() {
                                    continue;
                                }

                                if let Ok(task) = serde_json::from_str::<serde_json::Value>(data) {
                                    spawn_task_handler(
                                        &task, client, server, token, worker_id,
                                        processing, auto_approve,
                                    ).await;
                                }
                            }
                        }
                    }
                    Some(Err(e)) => {
                        return Err(e.into());
                    }
                    None => {
                        // Stream ended
                        return Ok(());
                    }
                }
            }
            _ = poll_interval.tick() => {
                // Periodic poll for pending tasks the SSE stream may have missed
                if let Err(e) = poll_pending_tasks(
                    client, server, token, worker_id, processing, auto_approve,
                ).await {
                    tracing::warn!("Periodic task poll failed: {}", e);
                }
            }
        }
    }
}

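/// Spawn a background handler for a task if it is not already being processed.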
async fn spawn_task_handler(
    task: &serde_json::Value,
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
) {
    if let Some(id) = task
        .get("task")
        .and_then(|t| t["id"].as_str())
        .or_else(|| task["id"].as_str())
    {
        let mut proc = processing.lock().await;
        if !proc.contains(id) {
            proc.insert(id.to_string());
            drop(proc);

            let task_id = id.to_string();
            let task = task.clone();
            let client = client.clone();
            let server = server.to_string();
            let token = token.clone();
            let worker_id = worker_id.to_string();
            let auto_approve = *auto_approve;
            let processing_clone = processing.clone();

            tokio::spawn(async move {
                if let Err(e) =
                    handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await
                {
                    tracing::error!("Task {} failed: {}", task_id, e);
                }
                processing_clone.lock().await.remove(&task_id);
            });
        }
    }
}

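/// Poll the server for pending tasks missed by the SSE stream.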
async fn poll_pending_tasks(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
) -> Result<()> {
    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        return Ok(());
    }

    let data: serde_json::Value = res.json().await?;
    let tasks = if let Some(arr) = data.as_array() {
        arr.clone()
    } else {
        data["tasks"].as_array().cloned().unwrap_or_default()
    };

    if !tasks.is_empty() {
        tracing::debug!("Poll found {} pending task(s)", tasks.len());
    }

    for task in &tasks {
        spawn_task_handler(
            task,
            client,
            server,
            token,
            worker_id,
            processing,
            auto_approve,
        )
        .await;
    }

    Ok(())
}

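/// Claim a task, execute it (swarm or standard session), stream progress output,
/// and release it back to the server with the final status.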
async fn handle_task(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    task: &serde_json::Value,
    auto_approve: AutoApprove,
) -> Result<()> {
    let task_id = task_str(task, "id").ok_or_else(|| anyhow::anyhow!("No task ID"))?;
    let title = task_str(task, "title").unwrap_or("Untitled");

    tracing::info!("Handling task: {} ({})", title, task_id);

    // Claim the task
    let mut req = client
        .post(format!("{}/v1/worker/tasks/claim", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req
        .json(&serde_json::json!({ "task_id": task_id }))
        .send()
        .await?;

    if !res.status().is_success() {
        let text = res.text().await?;
        tracing::warn!("Failed to claim task: {}", text);
        return Ok(());
    }

    tracing::info!("Claimed task: {}", task_id);

    let metadata = task_metadata(task);
    let resume_session_id = metadata
        .get("resume_session_id")
        .and_then(|v| v.as_str())
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty());
    let complexity_hint = metadata_str(&metadata, &["complexity"]);
    let model_tier = metadata_str(&metadata, &["model_tier", "tier"])
        .map(|s| s.to_ascii_lowercase())
        .or_else(|| {
            complexity_hint.as_ref().map(|complexity| {
                match complexity.to_ascii_lowercase().as_str() {
                    "quick" => "fast".to_string(),
                    "deep" => "heavy".to_string(),
                    _ => "balanced".to_string(),
                }
            })
        });
    let worker_personality = metadata_str(
        &metadata,
        &["worker_personality", "personality", "agent_personality"],
    );
    let target_agent_name = metadata_str(&metadata, &["target_agent_name", "agent_name"]);
    let raw_model = task_str(task, "model_ref")
        .or_else(|| metadata_lookup(&metadata, "model_ref").and_then(|v| v.as_str()))
        .or_else(|| task_str(task, "model"))
        .or_else(|| metadata_lookup(&metadata, "model").and_then(|v| v.as_str()));
    let selected_model = raw_model.map(model_ref_to_provider_model);

    // Resume existing session when requested; fall back to a fresh session if missing.
    let mut session = if let Some(ref sid) = resume_session_id {
        match Session::load(sid).await {
            Ok(existing) => {
                tracing::info!("Resuming session {} for task {}", sid, task_id);
                existing
            }
            Err(e) => {
                tracing::warn!(
                    "Could not load session {} for task {} ({}), starting a new session",
                    sid,
                    task_id,
                    e
                );
                Session::new().await?
            }
        }
    } else {
        Session::new().await?
    };

    let agent_type = task_str(task, "agent_type")
        .or_else(|| task_str(task, "agent"))
        .unwrap_or("build");
    session.agent = agent_type.to_string();

    if let Some(model) = selected_model.clone() {
        session.metadata.model = Some(model);
    }

    let prompt = task_str(task, "prompt")
        .or_else(|| task_str(task, "description"))
        .unwrap_or(title);

    tracing::info!("Executing prompt: {}", prompt);

    // Set up output streaming to forward progress to the server
    let stream_client = client.clone();
    let stream_server = server.to_string();
    let stream_token = token.clone();
    let stream_worker_id = worker_id.to_string();
    let stream_task_id = task_id.to_string();

    let output_callback = move |output: String| {
        let c = stream_client.clone();
        let s = stream_server.clone();
        let t = stream_token.clone();
        let w = stream_worker_id.clone();
        let tid = stream_task_id.clone();
        tokio::spawn(async move {
            let mut req = c
                .post(format!("{}/v1/opencode/tasks/{}/output", s, tid))
                .header("X-Worker-ID", &w);
            if let Some(tok) = &t {
                req = req.bearer_auth(tok);
            }
            let _ = req
                .json(&serde_json::json!({
                    "worker_id": w,
                    "output": output,
                }))
                .send()
                .await;
        });
    };

    // Execute swarm tasks via SwarmExecutor; all other agents use the standard session loop.
    let (status, result, error, session_id) = if is_swarm_agent(agent_type) {
        match execute_swarm_with_policy(
            &mut session,
            prompt,
            model_tier.as_deref(),
            selected_model,
            &metadata,
            complexity_hint.as_deref(),
            worker_personality.as_deref(),
            target_agent_name.as_deref(),
            Some(&output_callback),
        )
        .await
        {
            Ok((session_result, true)) => {
                tracing::info!("Swarm task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Ok((session_result, false)) => {
                tracing::warn!("Swarm task completed with failures: {}", task_id);
                (
                    "failed",
                    Some(session_result.text),
                    Some("Swarm execution completed with failures".to_string()),
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Swarm task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    } else {
        match execute_session_with_policy(
            &mut session,
            prompt,
            auto_approve,
            model_tier.as_deref(),
            Some(&output_callback),
        )
        .await
        {
            Ok(session_result) => {
                tracing::info!("Task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    };

    // Release the task with full details
    let mut req = client
        .post(format!("{}/v1/worker/tasks/release", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    req.json(&serde_json::json!({
        "task_id": task_id,
        "status": status,
        "result": result,
        "error": error,
        "session_id": session_id.unwrap_or_else(|| session.id.clone()),
    }))
    .send()
    .await?;

    tracing::info!("Task released: {} with status: {}", task_id, status);
    Ok(())
}

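/// Execute a swarm task inside the given session, honoring metadata overrides for
/// strategy, subagent limits, timeouts, and model selection. Returns the session
/// result and whether the swarm run succeeded.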
async fn execute_swarm_with_policy<F>(
    session: &mut Session,
    prompt: &str,
    model_tier: Option<&str>,
    explicit_model: Option<String>,
    metadata: &serde_json::Map<String, serde_json::Value>,
    complexity_hint: Option<&str>,
    worker_personality: Option<&str>,
    target_agent_name: Option<&str>,
    output_callback: Option<&F>,
) -> Result<(crate::session::SessionResult, bool)>
where
    F: Fn(String),
{
    use crate::provider::{ContentPart, Message, Role};

    session.add_message(Message {
        role: Role::User,
        content: vec![ContentPart::Text {
            text: prompt.to_string(),
        }],
    });

    if session.title.is_none() {
        session.generate_title().await?;
    }

    let strategy = parse_swarm_strategy(metadata);
    let max_subagents = metadata_usize(
        metadata,
        &["swarm_max_subagents", "max_subagents", "subagents"],
    )
    .unwrap_or(10)
    .clamp(1, 100);
    let max_steps_per_subagent = metadata_usize(
        metadata,
        &[
            "swarm_max_steps_per_subagent",
            "max_steps_per_subagent",
            "max_steps",
        ],
    )
    .unwrap_or(50)
    .clamp(1, 200);
    let timeout_secs = metadata_u64(metadata, &["swarm_timeout_secs", "timeout_secs", "timeout"])
        .unwrap_or(600)
        .clamp(30, 3600);
    let parallel_enabled =
        metadata_bool(metadata, &["swarm_parallel_enabled", "parallel_enabled"]).unwrap_or(true);

    let model = resolve_swarm_model(explicit_model, model_tier).await;
    if let Some(ref selected_model) = model {
        session.metadata.model = Some(selected_model.clone());
    }

    if let Some(cb) = output_callback {
        cb(format!(
            "[swarm] routing complexity={} tier={} personality={} target_agent={}",
            complexity_hint.unwrap_or("standard"),
            model_tier.unwrap_or("balanced"),
            worker_personality.unwrap_or("auto"),
            target_agent_name.unwrap_or("auto")
        ));
        cb(format!(
            "[swarm] config strategy={:?} max_subagents={} max_steps={} timeout={}s tier={}",
            strategy,
            max_subagents,
            max_steps_per_subagent,
            timeout_secs,
            model_tier.unwrap_or("balanced")
        ));
    }

    let swarm_config = SwarmConfig {
        max_subagents,
        max_steps_per_subagent,
        subagent_timeout_secs: timeout_secs,
        parallel_enabled,
        model,
        working_dir: session
            .metadata
            .directory
            .as_ref()
            .map(|p| p.to_string_lossy().to_string()),
        ..Default::default()
    };

    let swarm_result = if output_callback.is_some() {
        let (event_tx, mut event_rx) = mpsc::channel(256);
        let executor = SwarmExecutor::new(swarm_config).with_event_tx(event_tx);
        let prompt_owned = prompt.to_string();
        let mut exec_handle =
            tokio::spawn(async move { executor.execute(&prompt_owned, strategy).await });

        let mut final_result: Option<crate::swarm::SwarmResult> = None;

        while final_result.is_none() {
            tokio::select! {
                maybe_event = event_rx.recv() => {
                    if let Some(event) = maybe_event {
                        if let Some(cb) = output_callback {
                            if let Some(line) = format_swarm_event_for_output(&event) {
                                cb(line);
                            }
                        }
                    }
                }
                join_result = &mut exec_handle => {
                    let joined = join_result.map_err(|e| anyhow::anyhow!("Swarm join failure: {}", e))?;
                    final_result = Some(joined?);
                }
            }
        }

        while let Ok(event) = event_rx.try_recv() {
            if let Some(cb) = output_callback {
                if let Some(line) = format_swarm_event_for_output(&event) {
                    cb(line);
                }
            }
        }

        final_result.ok_or_else(|| anyhow::anyhow!("Swarm execution returned no result"))?
    } else {
        SwarmExecutor::new(swarm_config)
            .execute(prompt, strategy)
            .await?
    };

    let final_text = if swarm_result.result.trim().is_empty() {
        if swarm_result.success {
            "Swarm completed without textual output.".to_string()
        } else {
            "Swarm finished with failures and no textual output.".to_string()
        }
    } else {
        swarm_result.result.clone()
    };

    session.add_message(Message {
        role: Role::Assistant,
        content: vec![ContentPart::Text {
            text: final_text.clone(),
        }],
    });
    session.save().await?;

    Ok((
        crate::session::SessionResult {
            text: final_text,
            session_id: session.id.clone(),
        },
        swarm_result.success,
    ))
}

/// Execute a session with the given auto-approve policy
/// Optionally streams output chunks via the callback
async fn execute_session_with_policy<F>(
    session: &mut Session,
    prompt: &str,
    auto_approve: AutoApprove,
    model_tier: Option<&str>,
    output_callback: Option<&F>,
) -> Result<crate::session::SessionResult>
where
    F: Fn(String),
{
    use crate::provider::{
        CompletionRequest, ContentPart, Message, ProviderRegistry, Role, parse_model_string,
    };
    use std::sync::Arc;

    // Load provider registry from Vault
    let registry = ProviderRegistry::from_vault().await?;
    let providers = registry.list();
    tracing::info!("Available providers: {:?}", providers);

    if providers.is_empty() {
        anyhow::bail!("No providers available. Configure API keys in HashiCorp Vault.");
    }

    // Parse model string
    let (provider_name, model_id) = if let Some(ref model_str) = session.metadata.model {
        let (prov, model) = parse_model_string(model_str);
        if prov.is_some() {
            (prov.map(|s| s.to_string()), model.to_string())
        } else if providers.contains(&model) {
            (Some(model.to_string()), String::new())
        } else {
            (None, model.to_string())
        }
    } else {
        (None, String::new())
    };

    let provider_slice = providers.as_slice();
    let provider_requested_but_unavailable = provider_name
        .as_deref()
        .map(|p| !providers.contains(&p))
        .unwrap_or(false);

    // Determine which provider to use, preferring explicit request first, then model tier.
    let selected_provider = provider_name
        .as_deref()
        .filter(|p| providers.contains(p))
        .unwrap_or_else(|| choose_provider_for_tier(provider_slice, model_tier));

    let provider = registry
        .get(selected_provider)
        .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;

    // Add user message
    session.add_message(Message {
        role: Role::User,
        content: vec![ContentPart::Text {
            text: prompt.to_string(),
        }],
    });

    // Generate title
    if session.title.is_none() {
        session.generate_title().await?;
    }

    // Determine model. If a specific provider was requested but not available,
    // ignore that model id and fall back to the tier-based default model.
    let model = if !model_id.is_empty() && !provider_requested_but_unavailable {
        model_id
    } else {
        default_model_for_provider(selected_provider, model_tier)
    };

    // Create tool registry with filtering based on auto-approve policy
    let tool_registry =
        create_filtered_registry(Arc::clone(&provider), model.clone(), auto_approve);
    let tool_definitions = tool_registry.definitions();

    let temperature = if model.starts_with("kimi-k2") {
        Some(1.0)
    } else {
        Some(0.7)
    };

    tracing::info!(
        "Using model: {} via provider: {} (tier: {:?})",
        model,
        selected_provider,
        model_tier
    );
    tracing::info!(
        "Available tools: {} (auto_approve: {:?})",
        tool_definitions.len(),
        auto_approve
    );

    // Build system prompt
    let cwd = std::env::var("PWD")
        .map(std::path::PathBuf::from)
        .unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
    let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);

    let mut final_output = String::new();
    let max_steps = 50;

    for step in 1..=max_steps {
        tracing::info!(step = step, "Agent step starting");

        // Build messages with system prompt first
        let mut messages = vec![Message {
            role: Role::System,
            content: vec![ContentPart::Text {
                text: system_prompt.clone(),
            }],
        }];
        messages.extend(session.messages.clone());

        let request = CompletionRequest {
            messages,
            tools: tool_definitions.clone(),
            model: model.clone(),
            temperature,
            top_p: None,
            max_tokens: Some(8192),
            stop: Vec::new(),
        };

        let response = provider.complete(request).await?;

        crate::telemetry::TOKEN_USAGE.record_model_usage(
            &model,
            response.usage.prompt_tokens as u64,
            response.usage.completion_tokens as u64,
        );

        // Extract tool calls
        let tool_calls: Vec<(String, String, serde_json::Value)> = response
            .message
            .content
            .iter()
            .filter_map(|part| {
                if let ContentPart::ToolCall {
                    id,
                    name,
                    arguments,
                } = part
                {
                    let args: serde_json::Value =
                        serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
                    Some((id.clone(), name.clone(), args))
                } else {
                    None
                }
            })
            .collect();

        // Collect text output and stream it
        for part in &response.message.content {
            if let ContentPart::Text { text } = part {
                if !text.is_empty() {
                    final_output.push_str(text);
                    final_output.push('\n');
                    if let Some(cb) = output_callback {
                        cb(text.clone());
                    }
                }
            }
        }

        // If no tool calls, we're done
        if tool_calls.is_empty() {
            session.add_message(response.message.clone());
            break;
        }

        session.add_message(response.message.clone());

        tracing::info!(
            step = step,
            num_tools = tool_calls.len(),
            "Executing tool calls"
        );

        // Execute each tool call
        for (tool_id, tool_name, tool_input) in tool_calls {
            tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");

            // Stream tool start event
            if let Some(cb) = output_callback {
                cb(format!("[tool:start:{}]", tool_name));
            }

            // Check if tool is allowed based on auto-approve policy
            if !is_tool_allowed(&tool_name, auto_approve) {
                let msg = format!(
                    "Tool '{}' requires approval but auto-approve policy is {:?}",
                    tool_name, auto_approve
                );
                tracing::warn!(tool = %tool_name, "Tool blocked by auto-approve policy");
                session.add_message(Message {
                    role: Role::Tool,
                    content: vec![ContentPart::ToolResult {
                        tool_call_id: tool_id,
                        content: msg,
                    }],
                });
                continue;
            }

            let content = if let Some(tool) = tool_registry.get(&tool_name) {
                let exec_result: Result<crate::tool::ToolResult> =
                    tool.execute(tool_input.clone()).await;
                match exec_result {
                    Ok(result) => {
                        tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
                        if let Some(cb) = output_callback {
                            let status = if result.success { "ok" } else { "err" };
                            // Truncate on a char boundary; byte slicing could panic on multi-byte UTF-8 output
                            let preview: String = result.output.chars().take(500).collect();
                            cb(format!("[tool:{}:{}] {}", tool_name, status, preview));
                        }
1534                        result.output
1535                    }
1536                    Err(e) => {
1537                        tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
1538                        if let Some(cb) = output_callback {
1539                            cb(format!("[tool:{}:err] {}", tool_name, e));
1540                        }
1541                        format!("Error: {}", e)
1542                    }
1543                }
1544            } else {
1545                tracing::warn!(tool = %tool_name, "Tool not found");
1546                format!("Error: Unknown tool '{}'", tool_name)
1547            };
1548
1549            session.add_message(Message {
1550                role: Role::Tool,
1551                content: vec![ContentPart::ToolResult {
1552                    tool_call_id: tool_id,
1553                    content,
1554                }],
1555            });
1556        }
1557    }
1558
1559    session.save().await?;
1560
1561    Ok(crate::session::SessionResult {
1562        text: final_output.trim().to_string(),
1563        session_id: session.id.clone(),
1564    })
1565}
1566
1567/// Check if a tool is allowed based on the auto-approve policy
1568fn is_tool_allowed(tool_name: &str, auto_approve: AutoApprove) -> bool {
1569    match auto_approve {
1570        AutoApprove::All => true,
1571        AutoApprove::Safe | AutoApprove::None => is_safe_tool(tool_name),
1572    }
1573}
1574
1575/// Check if a tool is considered "safe" (read-only)
1576fn is_safe_tool(tool_name: &str) -> bool {
1577    let safe_tools = [
1578        "read",
1579        "list",
1580        "glob",
1581        "grep",
1582        "codesearch",
1583        "lsp",
1584        "webfetch",
1585        "websearch",
1586        "todo_read",
1587        "skill",
1588    ];
1589    safe_tools.contains(&tool_name)
1590}
1591
1592/// Create a filtered tool registry based on the auto-approve policy
1593fn create_filtered_registry(
1594    provider: Arc<dyn crate::provider::Provider>,
1595    model: String,
1596    auto_approve: AutoApprove,
1597) -> crate::tool::ToolRegistry {
1598    use crate::tool::*;
1599
1600    let mut registry = ToolRegistry::new();
1601
1602    // Always add safe tools
1603    registry.register(Arc::new(file::ReadTool::new()));
1604    registry.register(Arc::new(file::ListTool::new()));
1605    registry.register(Arc::new(file::GlobTool::new()));
1606    registry.register(Arc::new(search::GrepTool::new()));
1607    registry.register(Arc::new(lsp::LspTool::new()));
1608    registry.register(Arc::new(webfetch::WebFetchTool::new()));
1609    registry.register(Arc::new(websearch::WebSearchTool::new()));
1610    registry.register(Arc::new(codesearch::CodeSearchTool::new()));
1611    registry.register(Arc::new(todo::TodoReadTool::new()));
1612    registry.register(Arc::new(skill::SkillTool::new()));
1613
1614    // Add potentially dangerous tools only if auto_approve is All
1615    if matches!(auto_approve, AutoApprove::All) {
1616        registry.register(Arc::new(file::WriteTool::new()));
1617        registry.register(Arc::new(edit::EditTool::new()));
1618        registry.register(Arc::new(bash::BashTool::new()));
1619        registry.register(Arc::new(multiedit::MultiEditTool::new()));
1620        registry.register(Arc::new(patch::ApplyPatchTool::new()));
1621        registry.register(Arc::new(todo::TodoWriteTool::new()));
1622        registry.register(Arc::new(task::TaskTool::new()));
1623        registry.register(Arc::new(plan::PlanEnterTool::new()));
1624        registry.register(Arc::new(plan::PlanExitTool::new()));
1625        registry.register(Arc::new(rlm::RlmTool::new()));
1626        registry.register(Arc::new(ralph::RalphTool::with_provider(provider, model)));
1627        registry.register(Arc::new(prd::PrdTool::new()));
1628        registry.register(Arc::new(confirm_edit::ConfirmEditTool::new()));
1629        registry.register(Arc::new(confirm_multiedit::ConfirmMultiEditTool::new()));
1630        registry.register(Arc::new(undo::UndoTool));
1631    }
1632
1633    registry.register(Arc::new(invalid::InvalidTool::new()));
1634
1635    registry
1636}
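
// Note: under `Safe`/`None` the write/edit/bash family is never registered at all, so the
// `is_tool_allowed` check earlier in this file presumably acts as a second gate rather than
// the only one, while `InvalidTool` is registered under every policy so unrecognized tool
// names still get a structured response instead of a silent failure.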
1637
1638/// Start the heartbeat background task
1639/// Returns a JoinHandle that can be used to cancel the heartbeat
1640fn start_heartbeat(
1641    client: Client,
1642    server: String,
1643    token: Option<String>,
1644    heartbeat_state: HeartbeatState,
1645    processing: Arc<Mutex<HashSet<String>>>,
1646    cognition_config: CognitionHeartbeatConfig,
1647) -> JoinHandle<()> {
1648    tokio::spawn(async move {
1649        let mut consecutive_failures = 0u32;
1650        const MAX_FAILURES: u32 = 3;
1651        const HEARTBEAT_INTERVAL_SECS: u64 = 30;
1652        const COGNITION_RETRY_COOLDOWN_SECS: u64 = 300;
1653        let mut cognition_payload_disabled_until: Option<Instant> = None;
1654
1655        let mut interval =
1656            tokio::time::interval(tokio::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
1657        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
1658
1659        loop {
1660            interval.tick().await;
1661
1662            // Update task count from processing set
1663            let active_count = processing.lock().await.len();
1664            heartbeat_state.set_task_count(active_count).await;
1665
1666            // Determine status based on active tasks
1667            let status = if active_count > 0 {
1668                WorkerStatus::Processing
1669            } else {
1670                WorkerStatus::Idle
1671            };
1672            heartbeat_state.set_status(status).await;
1673
1674            // Send heartbeat
1675            let url = format!(
1676                "{}/v1/opencode/workers/{}/heartbeat",
1677                server, heartbeat_state.worker_id
1678            );
1679            let mut req = client.post(&url);
1680
1681            if let Some(ref t) = token {
1682                req = req.bearer_auth(t);
1683            }
1684
1685            let status_str = heartbeat_state.status.lock().await.as_str().to_string();
1686            let base_payload = serde_json::json!({
1687                "worker_id": &heartbeat_state.worker_id,
1688                "agent_name": &heartbeat_state.agent_name,
1689                "status": status_str,
1690                "active_task_count": active_count,
1691            });
1692            let mut payload = base_payload.clone();
1693            let mut included_cognition_payload = false;
1694            let cognition_payload_allowed = cognition_payload_disabled_until
1695                .map(|until| Instant::now() >= until)
1696                .unwrap_or(true);
1697
1698            if cognition_config.enabled
1699                && cognition_payload_allowed
1700                && let Some(cognition_payload) =
1701                    fetch_cognition_heartbeat_payload(&client, &cognition_config).await
1702                && let Some(obj) = payload.as_object_mut()
1703            {
1704                obj.insert("cognition".to_string(), cognition_payload);
1705                included_cognition_payload = true;
1706            }
1707
1708            match req.json(&payload).send().await {
1709                Ok(res) => {
1710                    if res.status().is_success() {
1711                        consecutive_failures = 0;
1712                        tracing::debug!(
1713                            worker_id = %heartbeat_state.worker_id,
1714                            status = status_str,
1715                            active_tasks = active_count,
1716                            "Heartbeat sent successfully"
1717                        );
1718                    } else if included_cognition_payload && res.status().is_client_error() {
1719                        tracing::warn!(
1720                            worker_id = %heartbeat_state.worker_id,
1721                            status = %res.status(),
1722                            "Heartbeat cognition payload rejected, retrying without cognition payload"
1723                        );
1724
1725                        let mut retry_req = client.post(&url);
1726                        if let Some(ref t) = token {
1727                            retry_req = retry_req.bearer_auth(t);
1728                        }
1729
1730                        match retry_req.json(&base_payload).send().await {
1731                            Ok(retry_res) if retry_res.status().is_success() => {
1732                                cognition_payload_disabled_until = Some(
1733                                    Instant::now()
1734                                        + Duration::from_secs(COGNITION_RETRY_COOLDOWN_SECS),
1735                                );
1736                                consecutive_failures = 0;
1737                                tracing::warn!(
1738                                    worker_id = %heartbeat_state.worker_id,
1739                                    retry_after_secs = COGNITION_RETRY_COOLDOWN_SECS,
1740                                    "Paused cognition heartbeat payload after schema rejection"
1741                                );
1742                            }
1743                            Ok(retry_res) => {
1744                                consecutive_failures += 1;
1745                                tracing::warn!(
1746                                    worker_id = %heartbeat_state.worker_id,
1747                                    status = %retry_res.status(),
1748                                    failures = consecutive_failures,
1749                                    "Heartbeat failed even after retry without cognition payload"
1750                                );
1751                            }
1752                            Err(e) => {
1753                                consecutive_failures += 1;
1754                                tracing::warn!(
1755                                    worker_id = %heartbeat_state.worker_id,
1756                                    error = %e,
1757                                    failures = consecutive_failures,
1758                                    "Heartbeat retry without cognition payload failed"
1759                                );
1760                            }
1761                        }
1762                    } else {
1763                        consecutive_failures += 1;
1764                        tracing::warn!(
1765                            worker_id = %heartbeat_state.worker_id,
1766                            status = %res.status(),
1767                            failures = consecutive_failures,
1768                            "Heartbeat failed"
1769                        );
1770                    }
1771                }
1772                Err(e) => {
1773                    consecutive_failures += 1;
1774                    tracing::warn!(
1775                        worker_id = %heartbeat_state.worker_id,
1776                        error = %e,
1777                        failures = consecutive_failures,
1778                        "Heartbeat request failed"
1779                    );
1780                }
1781            }
1782
1783            // Log error after 3 consecutive failures but do not terminate
1784            if consecutive_failures >= MAX_FAILURES {
1785                tracing::error!(
1786                    worker_id = %heartbeat_state.worker_id,
1787                    failures = consecutive_failures,
1788                    "Heartbeat failed {} consecutive times - worker will continue running and attempt reconnection via SSE loop",
1789                    MAX_FAILURES
1790                );
1791                // Reset counter to avoid spamming error logs
1792                consecutive_failures = 0;
1793            }
1794        }
1795    })
1796}
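
// Illustrative shape of the heartbeat request assembled above (all values invented):
//
//   POST {server}/v1/opencode/workers/{worker_id}/heartbeat
//   {
//     "worker_id": "worker-01",
//     "agent_name": "worker-01",
//     "status": "processing",          // "idle" when no tasks are active
//     "active_task_count": 2,
//     "cognition": { ... }             // only when sharing is enabled, the local cognition
//   }                                  // endpoint answered in time, and it is not cooling down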
1797
1798async fn fetch_cognition_heartbeat_payload(
1799    client: &Client,
1800    config: &CognitionHeartbeatConfig,
1801) -> Option<serde_json::Value> {
1802    let status_url = format!("{}/v1/cognition/status", config.source_base_url);
1803    let status_res = tokio::time::timeout(
1804        Duration::from_millis(config.request_timeout_ms),
1805        client.get(status_url).send(),
1806    )
1807    .await
1808    .ok()?
1809    .ok()?;
1810
1811    if !status_res.status().is_success() {
1812        return None;
1813    }
1814
1815    let status: CognitionStatusSnapshot = status_res.json().await.ok()?;
1816    let mut payload = serde_json::json!({
1817        "running": status.running,
1818        "last_tick_at": status.last_tick_at,
1819        "active_persona_count": status.active_persona_count,
1820        "events_buffered": status.events_buffered,
1821        "snapshots_buffered": status.snapshots_buffered,
1822        "loop_interval_ms": status.loop_interval_ms,
1823    });
1824
1825    if config.include_thought_summary {
1826        let snapshot_url = format!("{}/v1/cognition/snapshots/latest", config.source_base_url);
1827        let snapshot_res = tokio::time::timeout(
1828            Duration::from_millis(config.request_timeout_ms),
1829            client.get(snapshot_url).send(),
1830        )
1831        .await
1832        .ok()
1833        .and_then(Result::ok);
1834
1835        if let Some(snapshot_res) = snapshot_res
1836            && snapshot_res.status().is_success()
1837            && let Ok(snapshot) = snapshot_res.json::<CognitionLatestSnapshot>().await
1838            && let Some(obj) = payload.as_object_mut()
1839        {
1840            obj.insert(
1841                "latest_snapshot_at".to_string(),
1842                serde_json::Value::String(snapshot.generated_at),
1843            );
1844            obj.insert(
1845                "latest_thought".to_string(),
1846                serde_json::Value::String(trim_for_heartbeat(
1847                    &snapshot.summary,
1848                    config.summary_max_chars,
1849                )),
1850            );
1851            if let Some(model) = snapshot
1852                .metadata
1853                .get("model")
1854                .and_then(serde_json::Value::as_str)
1855            {
1856                obj.insert(
1857                    "latest_thought_model".to_string(),
1858                    serde_json::Value::String(model.to_string()),
1859                );
1860            }
1861            if let Some(source) = snapshot
1862                .metadata
1863                .get("source")
1864                .and_then(serde_json::Value::as_str)
1865            {
1866                obj.insert(
1867                    "latest_thought_source".to_string(),
1868                    serde_json::Value::String(source.to_string()),
1869                );
1870            }
1871        }
1872    }
1873
1874    Some(payload)
1875}
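
// Illustrative cognition payload built above from GET /v1/cognition/status and, when thought
// sharing is enabled, GET /v1/cognition/snapshots/latest (all values invented):
//
//   {
//     "running": true,
//     "last_tick_at": "2025-01-01T00:00:00Z",
//     "active_persona_count": 3,
//     "events_buffered": 12,
//     "snapshots_buffered": 4,
//     "loop_interval_ms": 1000,
//     "latest_snapshot_at": "2025-01-01T00:00:00Z",   // only when a snapshot was fetched
//     "latest_thought": "…",                          // trimmed to summary_max_chars
//     "latest_thought_model": "…",                    // only when present in snapshot metadata
//     "latest_thought_source": "…"                    // only when present in snapshot metadata
//   }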
1876
1877fn trim_for_heartbeat(input: &str, max_chars: usize) -> String {
1878    if input.chars().count() <= max_chars {
1879        return input.trim().to_string();
1880    }
1881
1882    let mut trimmed = String::with_capacity(max_chars + 3);
1883    for ch in input.chars().take(max_chars) {
1884        trimmed.push(ch);
1885    }
1886    trimmed.push_str("...");
1887    trimmed.trim().to_string()
1888}
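
// A small sketch (not part of the original file) of the truncation behaviour: short inputs
// are only whitespace-trimmed, long inputs are cut after `max_chars` characters and suffixed
// with "...".
#[cfg(test)]
mod trim_for_heartbeat_tests {
    use super::*;

    #[test]
    fn short_input_is_returned_trimmed() {
        assert_eq!(trim_for_heartbeat("  brief thought  ", 120), "brief thought");
    }

    #[test]
    fn long_input_is_truncated_with_ellipsis() {
        assert_eq!(trim_for_heartbeat("hello world", 5), "hello...");
    }
}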
1889
1890fn env_bool(name: &str, default: bool) -> bool {
1891    std::env::var(name)
1892        .ok()
1893        .and_then(|v| match v.to_ascii_lowercase().as_str() {
1894            "1" | "true" | "yes" | "on" => Some(true),
1895            "0" | "false" | "no" | "off" => Some(false),
1896            _ => None,
1897        })
1898        .unwrap_or(default)
1899}
1900
1901fn env_usize(name: &str, default: usize) -> usize {
1902    std::env::var(name)
1903        .ok()
1904        .and_then(|v| v.parse::<usize>().ok())
1905        .unwrap_or(default)
1906}
1907
1908fn env_u64(name: &str, default: u64) -> u64 {
1909    std::env::var(name)
1910        .ok()
1911        .and_then(|v| v.parse::<u64>().ok())
1912        .unwrap_or(default)
1913}