Skip to main content

codetether_agent/a2a/
worker.rs

1//! A2A Worker - connects to an A2A server to process tasks
2
3use crate::cli::A2aArgs;
4use crate::provider::ProviderRegistry;
5use crate::session::Session;
6use crate::swarm::{DecompositionStrategy, SwarmConfig, SwarmExecutor};
7use crate::tui::swarm_view::SwarmEvent;
8use anyhow::Result;
9use futures::StreamExt;
10use reqwest::Client;
11use serde::Deserialize;
12use std::collections::HashMap;
13use std::collections::HashSet;
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::sync::{Mutex, mpsc};
17use tokio::task::JoinHandle;
18use tokio::time::Instant;
19
/// Coarse worker state, reported alongside heartbeats.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WorkerStatus {
    /// Reported as `"idle"`.
    Idle,
    /// Reported as `"processing"`.
    Processing,
}

impl WorkerStatus {
    /// Lowercase string form of the status.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Idle => "idle",
            Self::Processing => "processing",
        }
    }
}
35
36/// Heartbeat state shared between the heartbeat task and the main worker
37#[derive(Clone)]
38struct HeartbeatState {
39    worker_id: String,
40    agent_name: String,
41    status: Arc<Mutex<WorkerStatus>>,
42    active_task_count: Arc<Mutex<usize>>,
43}
44
45impl HeartbeatState {
46    fn new(worker_id: String, agent_name: String) -> Self {
47        Self {
48            worker_id,
49            agent_name,
50            status: Arc::new(Mutex::new(WorkerStatus::Idle)),
51            active_task_count: Arc::new(Mutex::new(0)),
52        }
53    }
54
55    async fn set_status(&self, status: WorkerStatus) {
56        *self.status.lock().await = status;
57    }
58
59    async fn set_task_count(&self, count: usize) {
60        *self.active_task_count.lock().await = count;
61    }
62}
63
64#[derive(Clone, Debug)]
65struct CognitionHeartbeatConfig {
66    enabled: bool,
67    source_base_url: String,
68    include_thought_summary: bool,
69    summary_max_chars: usize,
70    request_timeout_ms: u64,
71}
72
73impl CognitionHeartbeatConfig {
74    fn from_env() -> Self {
75        let source_base_url = std::env::var("CODETETHER_WORKER_COGNITION_SOURCE_URL")
76            .unwrap_or_else(|_| "http://127.0.0.1:4096".to_string())
77            .trim_end_matches('/')
78            .to_string();
79
80        Self {
81            enabled: env_bool("CODETETHER_WORKER_COGNITION_SHARE_ENABLED", true),
82            source_base_url,
83            include_thought_summary: env_bool("CODETETHER_WORKER_COGNITION_INCLUDE_THOUGHTS", true),
84            summary_max_chars: env_usize("CODETETHER_WORKER_COGNITION_THOUGHT_MAX_CHARS", 480)
85                .max(120),
86            request_timeout_ms: env_u64("CODETETHER_WORKER_COGNITION_TIMEOUT_MS", 2_500).max(250),
87        }
88    }
89}
90
91#[derive(Debug, Deserialize)]
92struct CognitionStatusSnapshot {
93    running: bool,
94    #[serde(default)]
95    last_tick_at: Option<String>,
96    #[serde(default)]
97    active_persona_count: usize,
98    #[serde(default)]
99    events_buffered: usize,
100    #[serde(default)]
101    snapshots_buffered: usize,
102    #[serde(default)]
103    loop_interval_ms: u64,
104}
105
106#[derive(Debug, Deserialize)]
107struct CognitionLatestSnapshot {
108    generated_at: String,
109    summary: String,
110    #[serde(default)]
111    metadata: HashMap<String, serde_json::Value>,
112}
113
114/// Run the A2A worker
115pub async fn run(args: A2aArgs) -> Result<()> {
116    let server = args.server.trim_end_matches('/');
117    let name = args
118        .name
119        .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
120    let worker_id = generate_worker_id();
121
122    let codebases: Vec<String> = args
123        .codebases
124        .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
125        .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);
126
127    tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
128    tracing::info!("Server: {}", server);
129    tracing::info!("Codebases: {:?}", codebases);
130
131    let client = Client::new();
132    let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
133    let cognition_heartbeat = CognitionHeartbeatConfig::from_env();
134    if cognition_heartbeat.enabled {
135        tracing::info!(
136            source = %cognition_heartbeat.source_base_url,
137            include_thoughts = cognition_heartbeat.include_thought_summary,
138            max_chars = cognition_heartbeat.summary_max_chars,
139            timeout_ms = cognition_heartbeat.request_timeout_ms,
140            "Cognition heartbeat sharing enabled (set CODETETHER_WORKER_COGNITION_SHARE_ENABLED=false to disable)"
141        );
142    } else {
143        tracing::warn!(
144            "Cognition heartbeat sharing disabled; worker thought state will not be shared upstream"
145        );
146    }
147
148    let auto_approve = match args.auto_approve.as_str() {
149        "all" => AutoApprove::All,
150        "safe" => AutoApprove::Safe,
151        _ => AutoApprove::None,
152    };
153
154    // Create heartbeat state
155    let heartbeat_state = HeartbeatState::new(worker_id.clone(), name.clone());
156
157    // Register worker
158    register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;
159
160    // Fetch pending tasks
161    fetch_pending_tasks(
162        &client,
163        server,
164        &args.token,
165        &worker_id,
166        &processing,
167        &auto_approve,
168    )
169    .await?;
170
171    // Connect to SSE stream
172    loop {
173        // Re-register worker on each reconnection to report updated models/capabilities
174        if let Err(e) =
175            register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await
176        {
177            tracing::warn!("Failed to re-register worker on reconnection: {}", e);
178        }
179
180        // Start heartbeat task for this connection
181        let heartbeat_handle = start_heartbeat(
182            client.clone(),
183            server.to_string(),
184            args.token.clone(),
185            heartbeat_state.clone(),
186            processing.clone(),
187            cognition_heartbeat.clone(),
188        );
189
190        match connect_stream(
191            &client,
192            server,
193            &args.token,
194            &worker_id,
195            &name,
196            &codebases,
197            &processing,
198            &auto_approve,
199        )
200        .await
201        {
202            Ok(()) => {
203                tracing::warn!("Stream ended, reconnecting...");
204            }
205            Err(e) => {
206                tracing::error!("Stream error: {}, reconnecting...", e);
207            }
208        }
209
210        // Cancel heartbeat on disconnection
211        heartbeat_handle.abort();
212        tracing::debug!("Heartbeat cancelled for reconnection");
213
214        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
215    }
216}
217
218fn generate_worker_id() -> String {
219    format!(
220        "wrk_{}_{:x}",
221        chrono::Utc::now().timestamp(),
222        rand::random::<u64>()
223    )
224}
225
/// Auto-approval policy selected in `run` from `A2aArgs::auto_approve`.
///
/// `"all"` maps to `All`, `"safe"` to `Safe`, and any other value to `None`.
#[derive(Clone, Copy, Debug)]
enum AutoApprove {
    All,
    Safe,
    None,
}
232
/// Capability tags this worker advertises in its registration payload.
const WORKER_CAPABILITIES: &'static [&'static str] = &["ralph", "swarm", "rlm", "a2a", "mcp"];
235
236fn task_value<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a serde_json::Value> {
237    task.get("task")
238        .and_then(|t| t.get(key))
239        .or_else(|| task.get(key))
240}
241
242fn task_str<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a str> {
243    task_value(task, key).and_then(|v| v.as_str())
244}
245
246fn task_metadata(task: &serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
247    task_value(task, "metadata")
248        .and_then(|m| m.as_object())
249        .cloned()
250        .unwrap_or_default()
251}
252
/// Normalizes a model reference to the `provider/model` form.
///
/// A `provider:model` reference has its first `:` replaced by `/`. The reference is
/// returned unchanged in three cases:
/// - it already contains a `/`;
/// - it contains no `:`;
/// - the text before the first `:` does not look like a provider id (non-empty, only
///   ASCII letters, digits, `-` or `_`).
///
/// The last rule keeps bare model IDs that use `:` as a version separator intact, such
/// as `amazon.nova-micro-v1:0`. A prefixed form such as
/// `bedrock:amazon.nova-micro-v1:0` still becomes `bedrock/amazon.nova-micro-v1:0`.
fn model_ref_to_provider_model(model: &str) -> String {
    /// Heuristic: provider ids are plain identifiers and never contain `.`.
    fn looks_like_provider_id(candidate: &str) -> bool {
        !candidate.is_empty()
            && candidate
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
    }

    if model.contains('/') {
        return model.to_string();
    }
    match model.split_once(':') {
        Some((provider, rest)) if looks_like_provider_id(provider) => {
            format!("{provider}/{rest}")
        }
        _ => model.to_string(),
    }
}
263
/// Ordered provider preference list for a model tier.
///
/// `fast`/`quick` favour Google, `heavy`/`deep` favour Anthropic, and anything else
/// (including no tier) uses the balanced order that favours OpenAI. Matching is
/// case-sensitive.
fn provider_preferences_for_tier(model_tier: Option<&str>) -> &'static [&'static str] {
    const FAST: &[&str] = &[
        "google",
        "openai",
        "moonshotai",
        "zhipuai",
        "anthropic",
        "openrouter",
        "novita",
    ];
    const HEAVY: &[&str] = &[
        "anthropic",
        "openai",
        "google",
        "moonshotai",
        "zhipuai",
        "openrouter",
        "novita",
    ];
    const BALANCED: &[&str] = &[
        "openai",
        "anthropic",
        "google",
        "moonshotai",
        "zhipuai",
        "openrouter",
        "novita",
    ];

    match model_tier {
        Some("fast" | "quick") => FAST,
        Some("heavy" | "deep") => HEAVY,
        _ => BALANCED,
    }
}
295
296fn choose_provider_for_tier<'a>(providers: &'a [&'a str], model_tier: Option<&str>) -> &'a str {
297    for preferred in provider_preferences_for_tier(model_tier) {
298        if let Some(found) = providers.iter().copied().find(|p| *p == *preferred) {
299            return found;
300        }
301    }
302    if let Some(found) = providers.iter().copied().find(|p| *p == "zhipuai") {
303        return found;
304    }
305    providers[0]
306}
307
/// Default model id for `provider` at the given tier (`None` means balanced).
///
/// Only Anthropic, OpenAI and Google vary by tier. Unknown providers (and `zhipuai`)
/// get `glm-4.7`.
fn default_model_for_provider(provider: &str, model_tier: Option<&str>) -> String {
    let tier = model_tier.unwrap_or("balanced");
    let fast = matches!(tier, "fast" | "quick");
    let heavy = matches!(tier, "heavy" | "deep");

    let model = match provider {
        "moonshotai" => "kimi-k2.5",
        "anthropic" if fast => "claude-haiku-4-5",
        "anthropic" => "claude-sonnet-4-20250514",
        "openai" if fast => "gpt-4o-mini",
        "openai" if heavy => "o3",
        "openai" => "gpt-4o",
        "google" if fast => "gemini-2.5-flash",
        "google" => "gemini-2.5-pro",
        "zhipuai" => "glm-4.7",
        "openrouter" => "z-ai/glm-4.7",
        "novita" => "qwen/qwen3-coder-next",
        _ => "glm-4.7",
    };
    model.to_string()
}
342
/// Whether `agent_type` names a swarm-style agent.
///
/// Leading/trailing whitespace and ASCII case are ignored.
fn is_swarm_agent(agent_type: &str) -> bool {
    const SWARM_ALIASES: [&str; 3] = ["swarm", "parallel", "multi-agent"];
    let normalized = agent_type.trim().to_ascii_lowercase();
    SWARM_ALIASES.contains(&normalized.as_str())
}
349
350fn metadata_lookup<'a>(
351    metadata: &'a serde_json::Map<String, serde_json::Value>,
352    key: &str,
353) -> Option<&'a serde_json::Value> {
354    metadata
355        .get(key)
356        .or_else(|| {
357            metadata
358                .get("routing")
359                .and_then(|v| v.as_object())
360                .and_then(|obj| obj.get(key))
361        })
362        .or_else(|| {
363            metadata
364                .get("swarm")
365                .and_then(|v| v.as_object())
366                .and_then(|obj| obj.get(key))
367        })
368}
369
370fn metadata_str(
371    metadata: &serde_json::Map<String, serde_json::Value>,
372    keys: &[&str],
373) -> Option<String> {
374    for key in keys {
375        if let Some(value) = metadata_lookup(metadata, key).and_then(|v| v.as_str()) {
376            let trimmed = value.trim();
377            if !trimmed.is_empty() {
378                return Some(trimmed.to_string());
379            }
380        }
381    }
382    None
383}
384
385fn metadata_usize(
386    metadata: &serde_json::Map<String, serde_json::Value>,
387    keys: &[&str],
388) -> Option<usize> {
389    for key in keys {
390        if let Some(value) = metadata_lookup(metadata, key) {
391            if let Some(v) = value.as_u64() {
392                return usize::try_from(v).ok();
393            }
394            if let Some(v) = value.as_i64() {
395                if v >= 0 {
396                    return usize::try_from(v as u64).ok();
397                }
398            }
399            if let Some(v) = value.as_str() {
400                if let Ok(parsed) = v.trim().parse::<usize>() {
401                    return Some(parsed);
402                }
403            }
404        }
405    }
406    None
407}
408
409fn metadata_u64(
410    metadata: &serde_json::Map<String, serde_json::Value>,
411    keys: &[&str],
412) -> Option<u64> {
413    for key in keys {
414        if let Some(value) = metadata_lookup(metadata, key) {
415            if let Some(v) = value.as_u64() {
416                return Some(v);
417            }
418            if let Some(v) = value.as_i64() {
419                if v >= 0 {
420                    return Some(v as u64);
421                }
422            }
423            if let Some(v) = value.as_str() {
424                if let Ok(parsed) = v.trim().parse::<u64>() {
425                    return Some(parsed);
426                }
427            }
428        }
429    }
430    None
431}
432
433fn metadata_bool(
434    metadata: &serde_json::Map<String, serde_json::Value>,
435    keys: &[&str],
436) -> Option<bool> {
437    for key in keys {
438        if let Some(value) = metadata_lookup(metadata, key) {
439            if let Some(v) = value.as_bool() {
440                return Some(v);
441            }
442            if let Some(v) = value.as_str() {
443                match v.trim().to_ascii_lowercase().as_str() {
444                    "1" | "true" | "yes" | "on" => return Some(true),
445                    "0" | "false" | "no" | "off" => return Some(false),
446                    _ => {}
447                }
448            }
449        }
450    }
451    None
452}
453
454fn parse_swarm_strategy(
455    metadata: &serde_json::Map<String, serde_json::Value>,
456) -> DecompositionStrategy {
457    match metadata_str(
458        metadata,
459        &[
460            "decomposition_strategy",
461            "swarm_strategy",
462            "strategy",
463            "swarm_decomposition",
464        ],
465    )
466    .as_deref()
467    .map(|s| s.to_ascii_lowercase())
468    .as_deref()
469    {
470        Some("none") | Some("single") => DecompositionStrategy::None,
471        Some("domain") | Some("by_domain") => DecompositionStrategy::ByDomain,
472        Some("data") | Some("by_data") => DecompositionStrategy::ByData,
473        Some("stage") | Some("by_stage") => DecompositionStrategy::ByStage,
474        _ => DecompositionStrategy::Automatic,
475    }
476}
477
478async fn resolve_swarm_model(
479    explicit_model: Option<String>,
480    model_tier: Option<&str>,
481) -> Option<String> {
482    if let Some(model) = explicit_model {
483        if !model.trim().is_empty() {
484            return Some(model);
485        }
486    }
487
488    let registry = ProviderRegistry::from_vault().await.ok()?;
489    let providers = registry.list();
490    if providers.is_empty() {
491        return None;
492    }
493    let provider = choose_provider_for_tier(providers.as_slice(), model_tier);
494    let model = default_model_for_provider(provider, model_tier);
495    Some(format!("{}/{}", provider, model))
496}
497
498fn format_swarm_event_for_output(event: &SwarmEvent) -> Option<String> {
499    match event {
500        SwarmEvent::Started {
501            task,
502            total_subtasks,
503        } => Some(format!(
504            "[swarm] started task={} planned_subtasks={}",
505            task, total_subtasks
506        )),
507        SwarmEvent::StageComplete {
508            stage,
509            completed,
510            failed,
511        } => Some(format!(
512            "[swarm] stage={} completed={} failed={}",
513            stage, completed, failed
514        )),
515        SwarmEvent::SubTaskUpdate { id, status, .. } => Some(format!(
516            "[swarm] subtask id={} status={}",
517            &id.chars().take(8).collect::<String>(),
518            format!("{status:?}").to_ascii_lowercase()
519        )),
520        SwarmEvent::AgentToolCall {
521            subtask_id,
522            tool_name,
523        } => Some(format!(
524            "[swarm] subtask id={} tool={}",
525            &subtask_id.chars().take(8).collect::<String>(),
526            tool_name
527        )),
528        SwarmEvent::AgentError { subtask_id, error } => Some(format!(
529            "[swarm] subtask id={} error={}",
530            &subtask_id.chars().take(8).collect::<String>(),
531            error
532        )),
533        SwarmEvent::Complete { success, stats } => Some(format!(
534            "[swarm] complete success={} subtasks={} speedup={:.2}",
535            success,
536            stats.subagents_completed + stats.subagents_failed,
537            stats.speedup_factor
538        )),
539        SwarmEvent::Error(err) => Some(format!("[swarm] error message={}", err)),
540        _ => None,
541    }
542}
543
544async fn register_worker(
545    client: &Client,
546    server: &str,
547    token: &Option<String>,
548    worker_id: &str,
549    name: &str,
550    codebases: &[String],
551) -> Result<()> {
552    // Load ProviderRegistry and collect available models
553    let models = match load_provider_models().await {
554        Ok(m) => m,
555        Err(e) => {
556            tracing::warn!(
557                "Failed to load provider models: {}, proceeding without model info",
558                e
559            );
560            HashMap::new()
561        }
562    };
563
564    // Register via the workers/register endpoint
565    let mut req = client.post(format!("{}/v1/opencode/workers/register", server));
566
567    if let Some(t) = token {
568        req = req.bearer_auth(t);
569    }
570
571    // Flatten models HashMap into array of {id, name, provider, provider_id} objects
572    // matching the format expected by the A2A server's /models and /workers endpoints
573    let models_array: Vec<serde_json::Value> = models
574        .iter()
575        .flat_map(|(provider, model_ids)| {
576            model_ids.iter().map(move |model_id| {
577                serde_json::json!({
578                    "id": format!("{}/{}", provider, model_id),
579                    "name": model_id,
580                    "provider": provider,
581                    "provider_id": provider,
582                })
583            })
584        })
585        .collect();
586
587    tracing::info!(
588        "Registering worker with {} models from {} providers",
589        models_array.len(),
590        models.len()
591    );
592
593    let hostname = std::env::var("HOSTNAME")
594        .or_else(|_| std::env::var("COMPUTERNAME"))
595        .unwrap_or_else(|_| "unknown".to_string());
596
597    let res = req
598        .json(&serde_json::json!({
599            "worker_id": worker_id,
600            "name": name,
601            "capabilities": WORKER_CAPABILITIES,
602            "hostname": hostname,
603            "models": models_array,
604            "codebases": codebases,
605        }))
606        .send()
607        .await?;
608
609    if res.status().is_success() {
610        tracing::info!("Worker registered successfully");
611    } else {
612        tracing::warn!("Failed to register worker: {}", res.status());
613    }
614
615    Ok(())
616}
617
618/// Load ProviderRegistry and collect all available models grouped by provider.
619/// Tries Vault first, then falls back to config/env vars if Vault is unreachable.
620async fn load_provider_models() -> Result<HashMap<String, Vec<String>>> {
621    // Try Vault first
622    let registry = match ProviderRegistry::from_vault().await {
623        Ok(r) if !r.list().is_empty() => {
624            tracing::info!("Loaded {} providers from Vault", r.list().len());
625            r
626        }
627        Ok(_) => {
628            tracing::warn!("Vault returned 0 providers, falling back to config/env vars");
629            fallback_registry().await?
630        }
631        Err(e) => {
632            tracing::warn!("Vault unreachable ({}), falling back to config/env vars", e);
633            fallback_registry().await?
634        }
635    };
636
637    let mut models_by_provider: HashMap<String, Vec<String>> = HashMap::new();
638
639    for provider_name in registry.list() {
640        if let Some(provider) = registry.get(provider_name) {
641            match provider.list_models().await {
642                Ok(models) => {
643                    let model_ids: Vec<String> = models.into_iter().map(|m| m.id).collect();
644                    if !model_ids.is_empty() {
645                        tracing::debug!("Provider {}: {} models", provider_name, model_ids.len());
646                        models_by_provider.insert(provider_name.to_string(), model_ids);
647                    }
648                }
649                Err(e) => {
650                    tracing::debug!("Failed to list models for {}: {}", provider_name, e);
651                }
652            }
653        }
654    }
655
656    // If we still have 0, try the models.dev catalog as last resort
657    // NOTE: We list ALL models from the catalog (not just ones with verified API keys)
658    // because the worker should advertise what it can handle. The server handles routing.
659    if models_by_provider.is_empty() {
660        tracing::info!(
661            "No authenticated providers found, fetching models.dev catalog (all providers)"
662        );
663        if let Ok(catalog) = crate::provider::models::ModelCatalog::fetch().await {
664            // Use all_providers_with_models() to get every provider+model from catalog
665            // regardless of API key availability (Vault may be down)
666            for (provider_id, provider_info) in catalog.all_providers() {
667                let model_ids: Vec<String> = provider_info
668                    .models
669                    .values()
670                    .map(|m| m.id.clone())
671                    .collect();
672                if !model_ids.is_empty() {
673                    tracing::debug!(
674                        "Catalog provider {}: {} models",
675                        provider_id,
676                        model_ids.len()
677                    );
678                    models_by_provider.insert(provider_id.clone(), model_ids);
679                }
680            }
681            tracing::info!(
682                "Loaded {} providers with {} total models from catalog",
683                models_by_provider.len(),
684                models_by_provider.values().map(|v| v.len()).sum::<usize>()
685            );
686        }
687    }
688
689    Ok(models_by_provider)
690}
691
692/// Fallback: build a ProviderRegistry from config file + environment variables
693async fn fallback_registry() -> Result<ProviderRegistry> {
694    let config = crate::config::Config::load().await.unwrap_or_default();
695    ProviderRegistry::from_config(&config).await
696}
697
698async fn fetch_pending_tasks(
699    client: &Client,
700    server: &str,
701    token: &Option<String>,
702    worker_id: &str,
703    processing: &Arc<Mutex<HashSet<String>>>,
704    auto_approve: &AutoApprove,
705) -> Result<()> {
706    tracing::info!("Checking for pending tasks...");
707
708    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
709    if let Some(t) = token {
710        req = req.bearer_auth(t);
711    }
712
713    let res = req.send().await?;
714    if !res.status().is_success() {
715        return Ok(());
716    }
717
718    let data: serde_json::Value = res.json().await?;
719    // Handle both plain array response and {tasks: [...]} wrapper
720    let tasks = if let Some(arr) = data.as_array() {
721        arr.clone()
722    } else {
723        data["tasks"].as_array().cloned().unwrap_or_default()
724    };
725
726    tracing::info!("Found {} pending task(s)", tasks.len());
727
728    for task in tasks {
729        if let Some(id) = task["id"].as_str() {
730            let mut proc = processing.lock().await;
731            if !proc.contains(id) {
732                proc.insert(id.to_string());
733                drop(proc);
734
735                let task_id = id.to_string();
736                let client = client.clone();
737                let server = server.to_string();
738                let token = token.clone();
739                let worker_id = worker_id.to_string();
740                let auto_approve = *auto_approve;
741                let processing = processing.clone();
742
743                tokio::spawn(async move {
744                    if let Err(e) =
745                        handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await
746                    {
747                        tracing::error!("Task {} failed: {}", task_id, e);
748                    }
749                    processing.lock().await.remove(&task_id);
750                });
751            }
752        }
753    }
754
755    Ok(())
756}
757
758#[allow(clippy::too_many_arguments)]
759async fn connect_stream(
760    client: &Client,
761    server: &str,
762    token: &Option<String>,
763    worker_id: &str,
764    name: &str,
765    codebases: &[String],
766    processing: &Arc<Mutex<HashSet<String>>>,
767    auto_approve: &AutoApprove,
768) -> Result<()> {
769    let url = format!(
770        "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
771        server,
772        urlencoding::encode(name),
773        urlencoding::encode(worker_id)
774    );
775
776    let mut req = client
777        .get(&url)
778        .header("Accept", "text/event-stream")
779        .header("X-Worker-ID", worker_id)
780        .header("X-Agent-Name", name)
781        .header("X-Codebases", codebases.join(","));
782
783    if let Some(t) = token {
784        req = req.bearer_auth(t);
785    }
786
787    let res = req.send().await?;
788    if !res.status().is_success() {
789        anyhow::bail!("Failed to connect: {}", res.status());
790    }
791
792    tracing::info!("Connected to A2A server");
793
794    let mut stream = res.bytes_stream();
795    let mut buffer = String::new();
796    let mut poll_interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
797    poll_interval.tick().await; // consume the initial immediate tick
798
799    loop {
800        tokio::select! {
801            chunk = stream.next() => {
802                match chunk {
803                    Some(Ok(chunk)) => {
804                        buffer.push_str(&String::from_utf8_lossy(&chunk));
805
806                        // Process SSE events
807                        while let Some(pos) = buffer.find("\n\n") {
808                            let event_str = buffer[..pos].to_string();
809                            buffer = buffer[pos + 2..].to_string();
810
811                            if let Some(data_line) = event_str.lines().find(|l| l.starts_with("data:")) {
812                                let data = data_line.trim_start_matches("data:").trim();
813                                if data == "[DONE]" || data.is_empty() {
814                                    continue;
815                                }
816
817                                if let Ok(task) = serde_json::from_str::<serde_json::Value>(data) {
818                                    spawn_task_handler(
819                                        &task, client, server, token, worker_id,
820                                        processing, auto_approve,
821                                    ).await;
822                                }
823                            }
824                        }
825                    }
826                    Some(Err(e)) => {
827                        return Err(e.into());
828                    }
829                    None => {
830                        // Stream ended
831                        return Ok(());
832                    }
833                }
834            }
835            _ = poll_interval.tick() => {
836                // Periodic poll for pending tasks the SSE stream may have missed
837                if let Err(e) = poll_pending_tasks(
838                    client, server, token, worker_id, processing, auto_approve,
839                ).await {
840                    tracing::warn!("Periodic task poll failed: {}", e);
841                }
842            }
843        }
844    }
845}
846
847async fn spawn_task_handler(
848    task: &serde_json::Value,
849    client: &Client,
850    server: &str,
851    token: &Option<String>,
852    worker_id: &str,
853    processing: &Arc<Mutex<HashSet<String>>>,
854    auto_approve: &AutoApprove,
855) {
856    if let Some(id) = task
857        .get("task")
858        .and_then(|t| t["id"].as_str())
859        .or_else(|| task["id"].as_str())
860    {
861        let mut proc = processing.lock().await;
862        if !proc.contains(id) {
863            proc.insert(id.to_string());
864            drop(proc);
865
866            let task_id = id.to_string();
867            let task = task.clone();
868            let client = client.clone();
869            let server = server.to_string();
870            let token = token.clone();
871            let worker_id = worker_id.to_string();
872            let auto_approve = *auto_approve;
873            let processing_clone = processing.clone();
874
875            tokio::spawn(async move {
876                if let Err(e) =
877                    handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await
878                {
879                    tracing::error!("Task {} failed: {}", task_id, e);
880                }
881                processing_clone.lock().await.remove(&task_id);
882            });
883        }
884    }
885}
886
887async fn poll_pending_tasks(
888    client: &Client,
889    server: &str,
890    token: &Option<String>,
891    worker_id: &str,
892    processing: &Arc<Mutex<HashSet<String>>>,
893    auto_approve: &AutoApprove,
894) -> Result<()> {
895    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
896    if let Some(t) = token {
897        req = req.bearer_auth(t);
898    }
899
900    let res = req.send().await?;
901    if !res.status().is_success() {
902        return Ok(());
903    }
904
905    let data: serde_json::Value = res.json().await?;
906    let tasks = if let Some(arr) = data.as_array() {
907        arr.clone()
908    } else {
909        data["tasks"].as_array().cloned().unwrap_or_default()
910    };
911
912    if !tasks.is_empty() {
913        tracing::debug!("Poll found {} pending task(s)", tasks.len());
914    }
915
916    for task in &tasks {
917        spawn_task_handler(
918            task,
919            client,
920            server,
921            token,
922            worker_id,
923            processing,
924            auto_approve,
925        )
926        .await;
927    }
928
929    Ok(())
930}
931
932async fn handle_task(
933    client: &Client,
934    server: &str,
935    token: &Option<String>,
936    worker_id: &str,
937    task: &serde_json::Value,
938    auto_approve: AutoApprove,
939) -> Result<()> {
940    let task_id = task_str(task, "id").ok_or_else(|| anyhow::anyhow!("No task ID"))?;
941    let title = task_str(task, "title").unwrap_or("Untitled");
942
943    tracing::info!("Handling task: {} ({})", title, task_id);
944
945    // Claim the task
946    let mut req = client
947        .post(format!("{}/v1/worker/tasks/claim", server))
948        .header("X-Worker-ID", worker_id);
949    if let Some(t) = token {
950        req = req.bearer_auth(t);
951    }
952
953    let res = req
954        .json(&serde_json::json!({ "task_id": task_id }))
955        .send()
956        .await?;
957
958    if !res.status().is_success() {
959        let text = res.text().await?;
960        tracing::warn!("Failed to claim task: {}", text);
961        return Ok(());
962    }
963
964    tracing::info!("Claimed task: {}", task_id);
965
966    let metadata = task_metadata(task);
967    let resume_session_id = metadata
968        .get("resume_session_id")
969        .and_then(|v| v.as_str())
970        .map(|s| s.trim().to_string())
971        .filter(|s| !s.is_empty());
972    let complexity_hint = metadata_str(&metadata, &["complexity"]);
973    let model_tier = metadata_str(&metadata, &["model_tier", "tier"])
974        .map(|s| s.to_ascii_lowercase())
975        .or_else(|| {
976            complexity_hint.as_ref().map(|complexity| {
977                match complexity.to_ascii_lowercase().as_str() {
978                    "quick" => "fast".to_string(),
979                    "deep" => "heavy".to_string(),
980                    _ => "balanced".to_string(),
981                }
982            })
983        });
984    let worker_personality = metadata_str(
985        &metadata,
986        &["worker_personality", "personality", "agent_personality"],
987    );
988    let target_agent_name = metadata_str(&metadata, &["target_agent_name", "agent_name"]);
989    let raw_model = task_str(task, "model_ref")
990        .or_else(|| metadata_lookup(&metadata, "model_ref").and_then(|v| v.as_str()))
991        .or_else(|| task_str(task, "model"))
992        .or_else(|| metadata_lookup(&metadata, "model").and_then(|v| v.as_str()));
993    let selected_model = raw_model.map(model_ref_to_provider_model);
994
995    // Resume existing session when requested; fall back to a fresh session if missing.
996    let mut session = if let Some(ref sid) = resume_session_id {
997        match Session::load(sid).await {
998            Ok(existing) => {
999                tracing::info!("Resuming session {} for task {}", sid, task_id);
1000                existing
1001            }
1002            Err(e) => {
1003                tracing::warn!(
1004                    "Could not load session {} for task {} ({}), starting a new session",
1005                    sid,
1006                    task_id,
1007                    e
1008                );
1009                Session::new().await?
1010            }
1011        }
1012    } else {
1013        Session::new().await?
1014    };
1015
1016    let agent_type = task_str(task, "agent_type")
1017        .or_else(|| task_str(task, "agent"))
1018        .unwrap_or("build");
1019    session.agent = agent_type.to_string();
1020
1021    if let Some(model) = selected_model.clone() {
1022        session.metadata.model = Some(model);
1023    }
1024
1025    let prompt = task_str(task, "prompt")
1026        .or_else(|| task_str(task, "description"))
1027        .unwrap_or(title);
1028
1029    tracing::info!("Executing prompt: {}", prompt);
1030
1031    // Set up output streaming to forward progress to the server
1032    let stream_client = client.clone();
1033    let stream_server = server.to_string();
1034    let stream_token = token.clone();
1035    let stream_worker_id = worker_id.to_string();
1036    let stream_task_id = task_id.to_string();
1037
1038    let output_callback = move |output: String| {
1039        let c = stream_client.clone();
1040        let s = stream_server.clone();
1041        let t = stream_token.clone();
1042        let w = stream_worker_id.clone();
1043        let tid = stream_task_id.clone();
1044        tokio::spawn(async move {
1045            let mut req = c
1046                .post(format!("{}/v1/opencode/tasks/{}/output", s, tid))
1047                .header("X-Worker-ID", &w);
1048            if let Some(tok) = &t {
1049                req = req.bearer_auth(tok);
1050            }
1051            let _ = req
1052                .json(&serde_json::json!({
1053                    "worker_id": w,
1054                    "output": output,
1055                }))
1056                .send()
1057                .await;
1058        });
1059    };
1060
1061    // Execute swarm tasks via SwarmExecutor; all other agents use the standard session loop.
1062    let (status, result, error, session_id) = if is_swarm_agent(agent_type) {
1063        match execute_swarm_with_policy(
1064            &mut session,
1065            prompt,
1066            model_tier.as_deref(),
1067            selected_model,
1068            &metadata,
1069            complexity_hint.as_deref(),
1070            worker_personality.as_deref(),
1071            target_agent_name.as_deref(),
1072            Some(&output_callback),
1073        )
1074        .await
1075        {
1076            Ok((session_result, true)) => {
1077                tracing::info!("Swarm task completed successfully: {}", task_id);
1078                (
1079                    "completed",
1080                    Some(session_result.text),
1081                    None,
1082                    Some(session_result.session_id),
1083                )
1084            }
1085            Ok((session_result, false)) => {
1086                tracing::warn!("Swarm task completed with failures: {}", task_id);
1087                (
1088                    "failed",
1089                    Some(session_result.text),
1090                    Some("Swarm execution completed with failures".to_string()),
1091                    Some(session_result.session_id),
1092                )
1093            }
1094            Err(e) => {
1095                tracing::error!("Swarm task failed: {} - {}", task_id, e);
1096                ("failed", None, Some(format!("Error: {}", e)), None)
1097            }
1098        }
1099    } else {
1100        match execute_session_with_policy(
1101            &mut session,
1102            prompt,
1103            auto_approve,
1104            model_tier.as_deref(),
1105            Some(&output_callback),
1106        )
1107        .await
1108        {
1109            Ok(session_result) => {
1110                tracing::info!("Task completed successfully: {}", task_id);
1111                (
1112                    "completed",
1113                    Some(session_result.text),
1114                    None,
1115                    Some(session_result.session_id),
1116                )
1117            }
1118            Err(e) => {
1119                tracing::error!("Task failed: {} - {}", task_id, e);
1120                ("failed", None, Some(format!("Error: {}", e)), None)
1121            }
1122        }
1123    };
1124
1125    // Release the task with full details
1126    let mut req = client
1127        .post(format!("{}/v1/worker/tasks/release", server))
1128        .header("X-Worker-ID", worker_id);
1129    if let Some(t) = token {
1130        req = req.bearer_auth(t);
1131    }
1132
1133    req.json(&serde_json::json!({
1134        "task_id": task_id,
1135        "status": status,
1136        "result": result,
1137        "error": error,
1138        "session_id": session_id.unwrap_or_else(|| session.id.clone()),
1139    }))
1140    .send()
1141    .await?;
1142
1143    tracing::info!("Task released: {} with status: {}", task_id, status);
1144    Ok(())
1145}
1146
1147async fn execute_swarm_with_policy<F>(
1148    session: &mut Session,
1149    prompt: &str,
1150    model_tier: Option<&str>,
1151    explicit_model: Option<String>,
1152    metadata: &serde_json::Map<String, serde_json::Value>,
1153    complexity_hint: Option<&str>,
1154    worker_personality: Option<&str>,
1155    target_agent_name: Option<&str>,
1156    output_callback: Option<&F>,
1157) -> Result<(crate::session::SessionResult, bool)>
1158where
1159    F: Fn(String),
1160{
1161    use crate::provider::{ContentPart, Message, Role};
1162
1163    session.add_message(Message {
1164        role: Role::User,
1165        content: vec![ContentPart::Text {
1166            text: prompt.to_string(),
1167        }],
1168    });
1169
1170    if session.title.is_none() {
1171        session.generate_title().await?;
1172    }
1173
1174    let strategy = parse_swarm_strategy(metadata);
1175    let max_subagents = metadata_usize(
1176        metadata,
1177        &["swarm_max_subagents", "max_subagents", "subagents"],
1178    )
1179    .unwrap_or(10)
1180    .clamp(1, 100);
1181    let max_steps_per_subagent = metadata_usize(
1182        metadata,
1183        &[
1184            "swarm_max_steps_per_subagent",
1185            "max_steps_per_subagent",
1186            "max_steps",
1187        ],
1188    )
1189    .unwrap_or(50)
1190    .clamp(1, 200);
1191    let timeout_secs = metadata_u64(metadata, &["swarm_timeout_secs", "timeout_secs", "timeout"])
1192        .unwrap_or(600)
1193        .clamp(30, 3600);
1194    let parallel_enabled =
1195        metadata_bool(metadata, &["swarm_parallel_enabled", "parallel_enabled"]).unwrap_or(true);
1196
1197    let model = resolve_swarm_model(explicit_model, model_tier).await;
1198    if let Some(ref selected_model) = model {
1199        session.metadata.model = Some(selected_model.clone());
1200    }
1201
1202    if let Some(cb) = output_callback {
1203        cb(format!(
1204            "[swarm] routing complexity={} tier={} personality={} target_agent={}",
1205            complexity_hint.unwrap_or("standard"),
1206            model_tier.unwrap_or("balanced"),
1207            worker_personality.unwrap_or("auto"),
1208            target_agent_name.unwrap_or("auto")
1209        ));
1210        cb(format!(
1211            "[swarm] config strategy={:?} max_subagents={} max_steps={} timeout={}s tier={}",
1212            strategy,
1213            max_subagents,
1214            max_steps_per_subagent,
1215            timeout_secs,
1216            model_tier.unwrap_or("balanced")
1217        ));
1218    }
1219
1220    let swarm_config = SwarmConfig {
1221        max_subagents,
1222        max_steps_per_subagent,
1223        subagent_timeout_secs: timeout_secs,
1224        parallel_enabled,
1225        model,
1226        working_dir: session
1227            .metadata
1228            .directory
1229            .as_ref()
1230            .map(|p| p.to_string_lossy().to_string()),
1231        ..Default::default()
1232    };
1233
1234    let swarm_result = if output_callback.is_some() {
1235        let (event_tx, mut event_rx) = mpsc::channel(256);
1236        let executor = SwarmExecutor::new(swarm_config).with_event_tx(event_tx);
1237        let prompt_owned = prompt.to_string();
1238        let mut exec_handle =
1239            tokio::spawn(async move { executor.execute(&prompt_owned, strategy).await });
1240
1241        let mut final_result: Option<crate::swarm::SwarmResult> = None;
1242
1243        while final_result.is_none() {
1244            tokio::select! {
1245                maybe_event = event_rx.recv() => {
1246                    if let Some(event) = maybe_event {
1247                        if let Some(cb) = output_callback {
1248                            if let Some(line) = format_swarm_event_for_output(&event) {
1249                                cb(line);
1250                            }
1251                        }
1252                    }
1253                }
1254                join_result = &mut exec_handle => {
1255                    let joined = join_result.map_err(|e| anyhow::anyhow!("Swarm join failure: {}", e))?;
1256                    final_result = Some(joined?);
1257                }
1258            }
1259        }
1260
1261        while let Ok(event) = event_rx.try_recv() {
1262            if let Some(cb) = output_callback {
1263                if let Some(line) = format_swarm_event_for_output(&event) {
1264                    cb(line);
1265                }
1266            }
1267        }
1268
1269        final_result.ok_or_else(|| anyhow::anyhow!("Swarm execution returned no result"))?
1270    } else {
1271        SwarmExecutor::new(swarm_config)
1272            .execute(prompt, strategy)
1273            .await?
1274    };
1275
1276    let final_text = if swarm_result.result.trim().is_empty() {
1277        if swarm_result.success {
1278            "Swarm completed without textual output.".to_string()
1279        } else {
1280            "Swarm finished with failures and no textual output.".to_string()
1281        }
1282    } else {
1283        swarm_result.result.clone()
1284    };
1285
1286    session.add_message(Message {
1287        role: Role::Assistant,
1288        content: vec![ContentPart::Text {
1289            text: final_text.clone(),
1290        }],
1291    });
1292    session.save().await?;
1293
1294    Ok((
1295        crate::session::SessionResult {
1296            text: final_text,
1297            session_id: session.id.clone(),
1298        },
1299        swarm_result.success,
1300    ))
1301}
1302
1303/// Execute a session with the given auto-approve policy
1304/// Optionally streams output chunks via the callback
1305async fn execute_session_with_policy<F>(
1306    session: &mut Session,
1307    prompt: &str,
1308    auto_approve: AutoApprove,
1309    model_tier: Option<&str>,
1310    output_callback: Option<&F>,
1311) -> Result<crate::session::SessionResult>
1312where
1313    F: Fn(String),
1314{
1315    use crate::provider::{
1316        CompletionRequest, ContentPart, Message, ProviderRegistry, Role, parse_model_string,
1317    };
1318    use std::sync::Arc;
1319
1320    // Load provider registry from Vault
1321    let registry = ProviderRegistry::from_vault().await?;
1322    let providers = registry.list();
1323    tracing::info!("Available providers: {:?}", providers);
1324
1325    if providers.is_empty() {
1326        anyhow::bail!("No providers available. Configure API keys in HashiCorp Vault.");
1327    }
1328
1329    // Parse model string
1330    let (provider_name, model_id) = if let Some(ref model_str) = session.metadata.model {
1331        let (prov, model) = parse_model_string(model_str);
1332        if prov.is_some() {
1333            (prov.map(|s| s.to_string()), model.to_string())
1334        } else if providers.contains(&model) {
1335            (Some(model.to_string()), String::new())
1336        } else {
1337            (None, model.to_string())
1338        }
1339    } else {
1340        (None, String::new())
1341    };
1342
1343    let provider_slice = providers.as_slice();
1344    let provider_requested_but_unavailable = provider_name
1345        .as_deref()
1346        .map(|p| !providers.contains(&p))
1347        .unwrap_or(false);
1348
1349    // Determine which provider to use, preferring explicit request first, then model tier.
1350    let selected_provider = provider_name
1351        .as_deref()
1352        .filter(|p| providers.contains(p))
1353        .unwrap_or_else(|| choose_provider_for_tier(provider_slice, model_tier));
1354
1355    let provider = registry
1356        .get(selected_provider)
1357        .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;
1358
1359    // Add user message
1360    session.add_message(Message {
1361        role: Role::User,
1362        content: vec![ContentPart::Text {
1363            text: prompt.to_string(),
1364        }],
1365    });
1366
1367    // Generate title
1368    if session.title.is_none() {
1369        session.generate_title().await?;
1370    }
1371
1372    // Determine model. If a specific provider was requested but not available,
1373    // ignore that model id and fall back to the tier-based default model.
1374    let model = if !model_id.is_empty() && !provider_requested_but_unavailable {
1375        model_id
1376    } else {
1377        default_model_for_provider(selected_provider, model_tier)
1378    };
1379
1380    // Create tool registry with filtering based on auto-approve policy
1381    let tool_registry =
1382        create_filtered_registry(Arc::clone(&provider), model.clone(), auto_approve);
1383    let tool_definitions = tool_registry.definitions();
1384
1385    let temperature = if model.starts_with("kimi-k2") {
1386        Some(1.0)
1387    } else {
1388        Some(0.7)
1389    };
1390
1391    tracing::info!(
1392        "Using model: {} via provider: {} (tier: {:?})",
1393        model,
1394        selected_provider,
1395        model_tier
1396    );
1397    tracing::info!(
1398        "Available tools: {} (auto_approve: {:?})",
1399        tool_definitions.len(),
1400        auto_approve
1401    );
1402
1403    // Build system prompt
1404    let cwd = std::env::var("PWD")
1405        .map(std::path::PathBuf::from)
1406        .unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
1407    let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);
1408
1409    let mut final_output = String::new();
1410    let max_steps = 50;
1411
1412    for step in 1..=max_steps {
1413        tracing::info!(step = step, "Agent step starting");
1414
1415        // Build messages with system prompt first
1416        let mut messages = vec![Message {
1417            role: Role::System,
1418            content: vec![ContentPart::Text {
1419                text: system_prompt.clone(),
1420            }],
1421        }];
1422        messages.extend(session.messages.clone());
1423
1424        let request = CompletionRequest {
1425            messages,
1426            tools: tool_definitions.clone(),
1427            model: model.clone(),
1428            temperature,
1429            top_p: None,
1430            max_tokens: Some(8192),
1431            stop: Vec::new(),
1432        };
1433
1434        let response = provider.complete(request).await?;
1435
1436        crate::telemetry::TOKEN_USAGE.record_model_usage(
1437            &model,
1438            response.usage.prompt_tokens as u64,
1439            response.usage.completion_tokens as u64,
1440        );
1441
1442        // Extract tool calls
1443        let tool_calls: Vec<(String, String, serde_json::Value)> = response
1444            .message
1445            .content
1446            .iter()
1447            .filter_map(|part| {
1448                if let ContentPart::ToolCall {
1449                    id,
1450                    name,
1451                    arguments,
1452                } = part
1453                {
1454                    let args: serde_json::Value =
1455                        serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
1456                    Some((id.clone(), name.clone(), args))
1457                } else {
1458                    None
1459                }
1460            })
1461            .collect();
1462
1463        // Collect text output and stream it
1464        for part in &response.message.content {
1465            if let ContentPart::Text { text } = part {
1466                if !text.is_empty() {
1467                    final_output.push_str(text);
1468                    final_output.push('\n');
1469                    if let Some(cb) = output_callback {
1470                        cb(text.clone());
1471                    }
1472                }
1473            }
1474        }
1475
1476        // If no tool calls, we're done
1477        if tool_calls.is_empty() {
1478            session.add_message(response.message.clone());
1479            break;
1480        }
1481
1482        session.add_message(response.message.clone());
1483
1484        tracing::info!(
1485            step = step,
1486            num_tools = tool_calls.len(),
1487            "Executing tool calls"
1488        );
1489
1490        // Execute each tool call
1491        for (tool_id, tool_name, tool_input) in tool_calls {
1492            tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");
1493
1494            // Stream tool start event
1495            if let Some(cb) = output_callback {
1496                cb(format!("[tool:start:{}]", tool_name));
1497            }
1498
1499            // Check if tool is allowed based on auto-approve policy
1500            if !is_tool_allowed(&tool_name, auto_approve) {
1501                let msg = format!(
1502                    "Tool '{}' requires approval but auto-approve policy is {:?}",
1503                    tool_name, auto_approve
1504                );
1505                tracing::warn!(tool = %tool_name, "Tool blocked by auto-approve policy");
1506                session.add_message(Message {
1507                    role: Role::Tool,
1508                    content: vec![ContentPart::ToolResult {
1509                        tool_call_id: tool_id,
1510                        content: msg,
1511                    }],
1512                });
1513                continue;
1514            }
1515
1516            let content = if let Some(tool) = tool_registry.get(&tool_name) {
1517                let exec_result: Result<crate::tool::ToolResult> =
1518                    tool.execute(tool_input.clone()).await;
1519                match exec_result {
1520                    Ok(result) => {
1521                        tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
1522                        if let Some(cb) = output_callback {
1523                            let status = if result.success { "ok" } else { "err" };
1524                            cb(format!(
1525                                "[tool:{}:{}] {}",
1526                                tool_name,
1527                                status,
1528                                &result.output[..result.output.len().min(500)]
1529                            ));
1530                        }
1531                        result.output
1532                    }
1533                    Err(e) => {
1534                        tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
1535                        if let Some(cb) = output_callback {
1536                            cb(format!("[tool:{}:err] {}", tool_name, e));
1537                        }
1538                        format!("Error: {}", e)
1539                    }
1540                }
1541            } else {
1542                tracing::warn!(tool = %tool_name, "Tool not found");
1543                format!("Error: Unknown tool '{}'", tool_name)
1544            };
1545
1546            session.add_message(Message {
1547                role: Role::Tool,
1548                content: vec![ContentPart::ToolResult {
1549                    tool_call_id: tool_id,
1550                    content,
1551                }],
1552            });
1553        }
1554    }
1555
1556    session.save().await?;
1557
1558    Ok(crate::session::SessionResult {
1559        text: final_output.trim().to_string(),
1560        session_id: session.id.clone(),
1561    })
1562}
1563
1564/// Check if a tool is allowed based on the auto-approve policy
1565fn is_tool_allowed(tool_name: &str, auto_approve: AutoApprove) -> bool {
1566    match auto_approve {
1567        AutoApprove::All => true,
1568        AutoApprove::Safe | AutoApprove::None => is_safe_tool(tool_name),
1569    }
1570}
1571
/// Return `true` when `tool_name` is on the fixed list of "safe" (read-only)
/// tools. The comparison is exact and case-sensitive.
fn is_safe_tool(tool_name: &str) -> bool {
    matches!(
        tool_name,
        "read"
            | "list"
            | "glob"
            | "grep"
            | "codesearch"
            | "lsp"
            | "webfetch"
            | "websearch"
            | "todo_read"
            | "skill"
    )
}
1588
1589/// Create a filtered tool registry based on the auto-approve policy
1590fn create_filtered_registry(
1591    provider: Arc<dyn crate::provider::Provider>,
1592    model: String,
1593    auto_approve: AutoApprove,
1594) -> crate::tool::ToolRegistry {
1595    use crate::tool::*;
1596
1597    let mut registry = ToolRegistry::new();
1598
1599    // Always add safe tools
1600    registry.register(Arc::new(file::ReadTool::new()));
1601    registry.register(Arc::new(file::ListTool::new()));
1602    registry.register(Arc::new(file::GlobTool::new()));
1603    registry.register(Arc::new(search::GrepTool::new()));
1604    registry.register(Arc::new(lsp::LspTool::new()));
1605    registry.register(Arc::new(webfetch::WebFetchTool::new()));
1606    registry.register(Arc::new(websearch::WebSearchTool::new()));
1607    registry.register(Arc::new(codesearch::CodeSearchTool::new()));
1608    registry.register(Arc::new(todo::TodoReadTool::new()));
1609    registry.register(Arc::new(skill::SkillTool::new()));
1610
1611    // Add potentially dangerous tools only if auto_approve is All
1612    if matches!(auto_approve, AutoApprove::All) {
1613        registry.register(Arc::new(file::WriteTool::new()));
1614        registry.register(Arc::new(edit::EditTool::new()));
1615        registry.register(Arc::new(bash::BashTool::new()));
1616        registry.register(Arc::new(multiedit::MultiEditTool::new()));
1617        registry.register(Arc::new(patch::ApplyPatchTool::new()));
1618        registry.register(Arc::new(todo::TodoWriteTool::new()));
1619        registry.register(Arc::new(task::TaskTool::new()));
1620        registry.register(Arc::new(plan::PlanEnterTool::new()));
1621        registry.register(Arc::new(plan::PlanExitTool::new()));
1622        registry.register(Arc::new(rlm::RlmTool::new()));
1623        registry.register(Arc::new(ralph::RalphTool::with_provider(provider, model)));
1624        registry.register(Arc::new(prd::PrdTool::new()));
1625        registry.register(Arc::new(confirm_edit::ConfirmEditTool::new()));
1626        registry.register(Arc::new(confirm_multiedit::ConfirmMultiEditTool::new()));
1627        registry.register(Arc::new(undo::UndoTool));
1628    }
1629
1630    registry.register(Arc::new(invalid::InvalidTool::new()));
1631
1632    registry
1633}
1634
1635/// Start the heartbeat background task
1636/// Returns a JoinHandle that can be used to cancel the heartbeat
1637fn start_heartbeat(
1638    client: Client,
1639    server: String,
1640    token: Option<String>,
1641    heartbeat_state: HeartbeatState,
1642    processing: Arc<Mutex<HashSet<String>>>,
1643    cognition_config: CognitionHeartbeatConfig,
1644) -> JoinHandle<()> {
1645    tokio::spawn(async move {
1646        let mut consecutive_failures = 0u32;
1647        const MAX_FAILURES: u32 = 3;
1648        const HEARTBEAT_INTERVAL_SECS: u64 = 30;
1649        const COGNITION_RETRY_COOLDOWN_SECS: u64 = 300;
1650        let mut cognition_payload_disabled_until: Option<Instant> = None;
1651
1652        let mut interval =
1653            tokio::time::interval(tokio::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
1654        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
1655
1656        loop {
1657            interval.tick().await;
1658
1659            // Update task count from processing set
1660            let active_count = processing.lock().await.len();
1661            heartbeat_state.set_task_count(active_count).await;
1662
1663            // Determine status based on active tasks
1664            let status = if active_count > 0 {
1665                WorkerStatus::Processing
1666            } else {
1667                WorkerStatus::Idle
1668            };
1669            heartbeat_state.set_status(status).await;
1670
1671            // Send heartbeat
1672            let url = format!(
1673                "{}/v1/opencode/workers/{}/heartbeat",
1674                server, heartbeat_state.worker_id
1675            );
1676            let mut req = client.post(&url);
1677
1678            if let Some(ref t) = token {
1679                req = req.bearer_auth(t);
1680            }
1681
1682            let status_str = heartbeat_state.status.lock().await.as_str().to_string();
1683            let base_payload = serde_json::json!({
1684                "worker_id": &heartbeat_state.worker_id,
1685                "agent_name": &heartbeat_state.agent_name,
1686                "status": status_str,
1687                "active_task_count": active_count,
1688            });
1689            let mut payload = base_payload.clone();
1690            let mut included_cognition_payload = false;
1691            let cognition_payload_allowed = cognition_payload_disabled_until
1692                .map(|until| Instant::now() >= until)
1693                .unwrap_or(true);
1694
1695            if cognition_config.enabled
1696                && cognition_payload_allowed
1697                && let Some(cognition_payload) =
1698                    fetch_cognition_heartbeat_payload(&client, &cognition_config).await
1699                && let Some(obj) = payload.as_object_mut()
1700            {
1701                obj.insert("cognition".to_string(), cognition_payload);
1702                included_cognition_payload = true;
1703            }
1704
1705            match req.json(&payload).send().await {
1706                Ok(res) => {
1707                    if res.status().is_success() {
1708                        consecutive_failures = 0;
1709                        tracing::debug!(
1710                            worker_id = %heartbeat_state.worker_id,
1711                            status = status_str,
1712                            active_tasks = active_count,
1713                            "Heartbeat sent successfully"
1714                        );
1715                    } else if included_cognition_payload && res.status().is_client_error() {
1716                        tracing::warn!(
1717                            worker_id = %heartbeat_state.worker_id,
1718                            status = %res.status(),
1719                            "Heartbeat cognition payload rejected, retrying without cognition payload"
1720                        );
1721
1722                        let mut retry_req = client.post(&url);
1723                        if let Some(ref t) = token {
1724                            retry_req = retry_req.bearer_auth(t);
1725                        }
1726
1727                        match retry_req.json(&base_payload).send().await {
1728                            Ok(retry_res) if retry_res.status().is_success() => {
1729                                cognition_payload_disabled_until = Some(
1730                                    Instant::now()
1731                                        + Duration::from_secs(COGNITION_RETRY_COOLDOWN_SECS),
1732                                );
1733                                consecutive_failures = 0;
1734                                tracing::warn!(
1735                                    worker_id = %heartbeat_state.worker_id,
1736                                    retry_after_secs = COGNITION_RETRY_COOLDOWN_SECS,
1737                                    "Paused cognition heartbeat payload after schema rejection"
1738                                );
1739                            }
1740                            Ok(retry_res) => {
1741                                consecutive_failures += 1;
1742                                tracing::warn!(
1743                                    worker_id = %heartbeat_state.worker_id,
1744                                    status = %retry_res.status(),
1745                                    failures = consecutive_failures,
1746                                    "Heartbeat failed even after retry without cognition payload"
1747                                );
1748                            }
1749                            Err(e) => {
1750                                consecutive_failures += 1;
1751                                tracing::warn!(
1752                                    worker_id = %heartbeat_state.worker_id,
1753                                    error = %e,
1754                                    failures = consecutive_failures,
1755                                    "Heartbeat retry without cognition payload failed"
1756                                );
1757                            }
1758                        }
1759                    } else {
1760                        consecutive_failures += 1;
1761                        tracing::warn!(
1762                            worker_id = %heartbeat_state.worker_id,
1763                            status = %res.status(),
1764                            failures = consecutive_failures,
1765                            "Heartbeat failed"
1766                        );
1767                    }
1768                }
1769                Err(e) => {
1770                    consecutive_failures += 1;
1771                    tracing::warn!(
1772                        worker_id = %heartbeat_state.worker_id,
1773                        error = %e,
1774                        failures = consecutive_failures,
1775                        "Heartbeat request failed"
1776                    );
1777                }
1778            }
1779
1780            // Log error after 3 consecutive failures but do not terminate
1781            if consecutive_failures >= MAX_FAILURES {
1782                tracing::error!(
1783                    worker_id = %heartbeat_state.worker_id,
1784                    failures = consecutive_failures,
1785                    "Heartbeat failed {} consecutive times - worker will continue running and attempt reconnection via SSE loop",
1786                    MAX_FAILURES
1787                );
1788                // Reset counter to avoid spamming error logs
1789                consecutive_failures = 0;
1790            }
1791        }
1792    })
1793}
1794
1795async fn fetch_cognition_heartbeat_payload(
1796    client: &Client,
1797    config: &CognitionHeartbeatConfig,
1798) -> Option<serde_json::Value> {
1799    let status_url = format!("{}/v1/cognition/status", config.source_base_url);
1800    let status_res = tokio::time::timeout(
1801        Duration::from_millis(config.request_timeout_ms),
1802        client.get(status_url).send(),
1803    )
1804    .await
1805    .ok()?
1806    .ok()?;
1807
1808    if !status_res.status().is_success() {
1809        return None;
1810    }
1811
1812    let status: CognitionStatusSnapshot = status_res.json().await.ok()?;
1813    let mut payload = serde_json::json!({
1814        "running": status.running,
1815        "last_tick_at": status.last_tick_at,
1816        "active_persona_count": status.active_persona_count,
1817        "events_buffered": status.events_buffered,
1818        "snapshots_buffered": status.snapshots_buffered,
1819        "loop_interval_ms": status.loop_interval_ms,
1820    });
1821
1822    if config.include_thought_summary {
1823        let snapshot_url = format!("{}/v1/cognition/snapshots/latest", config.source_base_url);
1824        let snapshot_res = tokio::time::timeout(
1825            Duration::from_millis(config.request_timeout_ms),
1826            client.get(snapshot_url).send(),
1827        )
1828        .await
1829        .ok()
1830        .and_then(Result::ok);
1831
1832        if let Some(snapshot_res) = snapshot_res
1833            && snapshot_res.status().is_success()
1834            && let Ok(snapshot) = snapshot_res.json::<CognitionLatestSnapshot>().await
1835            && let Some(obj) = payload.as_object_mut()
1836        {
1837            obj.insert(
1838                "latest_snapshot_at".to_string(),
1839                serde_json::Value::String(snapshot.generated_at),
1840            );
1841            obj.insert(
1842                "latest_thought".to_string(),
1843                serde_json::Value::String(trim_for_heartbeat(
1844                    &snapshot.summary,
1845                    config.summary_max_chars,
1846                )),
1847            );
1848            if let Some(model) = snapshot
1849                .metadata
1850                .get("model")
1851                .and_then(serde_json::Value::as_str)
1852            {
1853                obj.insert(
1854                    "latest_thought_model".to_string(),
1855                    serde_json::Value::String(model.to_string()),
1856                );
1857            }
1858            if let Some(source) = snapshot
1859                .metadata
1860                .get("source")
1861                .and_then(serde_json::Value::as_str)
1862            {
1863                obj.insert(
1864                    "latest_thought_source".to_string(),
1865                    serde_json::Value::String(source.to_string()),
1866                );
1867            }
1868        }
1869    }
1870
1871    Some(payload)
1872}
1873
/// Shortens `input` for inclusion in a heartbeat payload.
///
/// Surrounding whitespace is stripped first, so it never counts against the
/// budget. If the trimmed text has at most `max_chars` characters, it is
/// returned unchanged. Otherwise it is cut to `max_chars` characters, any
/// whitespace left at the cut is dropped, and `"..."` is appended.
///
/// Characters are Unicode scalar values, so multi-byte text is never split
/// mid-character. The result may be up to `max_chars + 3` characters long.
fn trim_for_heartbeat(input: &str, max_chars: usize) -> String {
    let input = input.trim();

    // `nth(max_chars)` is `Some` only when there are more than `max_chars`
    // characters. Its byte index is exactly where the first `max_chars` end.
    match input.char_indices().nth(max_chars) {
        None => input.to_string(),
        Some((cut, _)) => {
            let head = input[..cut].trim_end();
            let mut trimmed = String::with_capacity(head.len() + 3);
            trimmed.push_str(head);
            trimmed.push_str("...");
            trimmed
        }
    }
}
1886
/// Reads environment variable `name` as a boolean flag.
///
/// Matching is case-insensitive and ignores surrounding whitespace.
/// - `1`, `true`, `yes`, `on` mean `true`.
/// - `0`, `false`, `no`, `off` mean `false`.
///
/// Returns `default` when the variable is unset, not valid Unicode, or holds
/// any other value.
fn env_bool(name: &str, default: bool) -> bool {
    let Ok(raw) = std::env::var(name) else {
        return default;
    };
    match raw.trim().to_ascii_lowercase().as_str() {
        "1" | "true" | "yes" | "on" => true,
        "0" | "false" | "no" | "off" => false,
        _ => default,
    }
}
1897
/// Reads environment variable `name` as a `usize`.
///
/// Surrounding whitespace is ignored. Returns `default` when the variable is
/// unset, not valid Unicode, or does not parse as a non-negative integer that
/// fits in `usize`.
fn env_usize(name: &str, default: usize) -> usize {
    std::env::var(name)
        .ok()
        .and_then(|raw| raw.trim().parse::<usize>().ok())
        .unwrap_or(default)
}
1904
/// Reads environment variable `name` as a `u64`.
///
/// Surrounding whitespace is ignored. Returns `default` when the variable is
/// unset, not valid Unicode, or does not parse as a non-negative integer that
/// fits in `u64`.
fn env_u64(name: &str, default: u64) -> u64 {
    std::env::var(name)
        .ok()
        .and_then(|raw| raw.trim().parse::<u64>().ok())
        .unwrap_or(default)
}