1use crate::bus::AgentBus;
4use crate::cli::A2aArgs;
5use crate::provider::ProviderRegistry;
6use crate::session::Session;
7use crate::swarm::{DecompositionStrategy, SwarmConfig, SwarmExecutor};
8use crate::tui::swarm_view::SwarmEvent;
9use anyhow::Result;
10use futures::StreamExt;
11use reqwest::Client;
12use serde::Deserialize;
13use std::collections::HashMap;
14use std::collections::HashSet;
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::sync::{Mutex, mpsc};
18use tokio::task::JoinHandle;
19use tokio::time::Instant;
20
/// Coarse activity state of this worker, shared via `HeartbeatState` and
/// serialized with `as_str` (presumably included in heartbeat payloads by
/// `start_heartbeat`, which is defined elsewhere — confirm there).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WorkerStatus {
    // Worker has no task in flight.
    Idle,
    // Worker is currently executing at least one task.
    Processing,
}
27
28impl WorkerStatus {
29 fn as_str(&self) -> &'static str {
30 match self {
31 WorkerStatus::Idle => "idle",
32 WorkerStatus::Processing => "processing",
33 }
34 }
35}
36
/// Cheaply-cloneable handle to this worker's identity and live activity
/// counters; clones share the same `Arc<Mutex<...>>` cells so the heartbeat
/// loop and task handlers observe each other's updates.
#[derive(Clone)]
struct HeartbeatState {
    // Stable id minted once at startup by `generate_worker_id`.
    worker_id: String,
    // Human-readable agent name (CLI `--name` or `codetether-<pid>`).
    agent_name: String,
    // Current activity state; not read in this chunk — presumably consumed
    // by `start_heartbeat` (defined elsewhere).
    status: Arc<Mutex<WorkerStatus>>,
    // Count of tasks currently being processed; same caveat as `status`.
    active_task_count: Arc<Mutex<usize>>,
}
45
46impl HeartbeatState {
47 fn new(worker_id: String, agent_name: String) -> Self {
48 Self {
49 worker_id,
50 agent_name,
51 status: Arc::new(Mutex::new(WorkerStatus::Idle)),
52 active_task_count: Arc::new(Mutex::new(0)),
53 }
54 }
55
56 async fn set_status(&self, status: WorkerStatus) {
57 *self.status.lock().await = status;
58 }
59
60 async fn set_task_count(&self, count: usize) {
61 *self.active_task_count.lock().await = count;
62 }
63}
64
/// Settings controlling whether and how this worker shares "cognition"
/// (thought-state) data alongside heartbeats; populated from environment
/// variables by `from_env`.
#[derive(Clone, Debug)]
struct CognitionHeartbeatConfig {
    // Master switch (CODETETHER_WORKER_COGNITION_SHARE_ENABLED, default true).
    enabled: bool,
    // Base URL of the local cognition source, trailing '/' stripped.
    source_base_url: String,
    // Whether thought summaries are included (default true).
    include_thought_summary: bool,
    // Upper bound on summary length in characters (min 120).
    summary_max_chars: usize,
    // HTTP timeout for cognition requests in milliseconds (min 250).
    request_timeout_ms: u64,
}
73
74impl CognitionHeartbeatConfig {
75 fn from_env() -> Self {
76 let source_base_url = std::env::var("CODETETHER_WORKER_COGNITION_SOURCE_URL")
77 .unwrap_or_else(|_| "http://127.0.0.1:4096".to_string())
78 .trim_end_matches('/')
79 .to_string();
80
81 Self {
82 enabled: env_bool("CODETETHER_WORKER_COGNITION_SHARE_ENABLED", true),
83 source_base_url,
84 include_thought_summary: env_bool("CODETETHER_WORKER_COGNITION_INCLUDE_THOUGHTS", true),
85 summary_max_chars: env_usize("CODETETHER_WORKER_COGNITION_THOUGHT_MAX_CHARS", 480)
86 .max(120),
87 request_timeout_ms: env_u64("CODETETHER_WORKER_COGNITION_TIMEOUT_MS", 2_500).max(250),
88 }
89 }
90}
91
/// Deserialization target for a cognition-loop status response.
/// NOTE(review): not referenced in this chunk — presumably fetched from
/// `source_base_url` by the heartbeat task; confirm against `start_heartbeat`.
#[derive(Debug, Deserialize)]
struct CognitionStatusSnapshot {
    // Whether the cognition loop is currently running.
    running: bool,
    #[serde(default)]
    last_tick_at: Option<String>,
    #[serde(default)]
    active_persona_count: usize,
    #[serde(default)]
    events_buffered: usize,
    #[serde(default)]
    snapshots_buffered: usize,
    #[serde(default)]
    loop_interval_ms: u64,
}
106
/// Deserialization target for the most recent cognition snapshot.
/// NOTE(review): not referenced in this chunk — presumably fetched alongside
/// `CognitionStatusSnapshot`; verify at the call site.
#[derive(Debug, Deserialize)]
struct CognitionLatestSnapshot {
    // RFC3339-style timestamp string — format not verified here.
    generated_at: String,
    summary: String,
    #[serde(default)]
    metadata: HashMap<String, serde_json::Value>,
}
114
/// Entry point for the A2A worker.
///
/// Registers with the server, drains any already-pending tasks, then loops
/// forever: (re-)register, start a heartbeat, hold an SSE task stream open
/// until it drops, cancel the heartbeat, sleep 5s, reconnect. Never returns
/// under normal operation.
pub async fn run(args: A2aArgs) -> Result<()> {
    // Normalize the server URL so path joins below don't produce "//".
    let server = args.server.trim_end_matches('/');
    let name = args
        .name
        .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
    let worker_id = generate_worker_id();

    // Comma-separated --codebases list; defaults to the current directory.
    // NOTE(review): `current_dir().unwrap()` panics if the cwd is unreadable.
    let codebases: Vec<String> = args
        .codebases
        .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
        .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);

    tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
    tracing::info!("Server: {}", server);
    tracing::info!("Codebases: {:?}", codebases);

    let client = Client::new();
    // Task ids currently being handled; used to dedupe stream + poll delivery.
    let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
    let cognition_heartbeat = CognitionHeartbeatConfig::from_env();
    if cognition_heartbeat.enabled {
        tracing::info!(
            source = %cognition_heartbeat.source_base_url,
            include_thoughts = cognition_heartbeat.include_thought_summary,
            max_chars = cognition_heartbeat.summary_max_chars,
            timeout_ms = cognition_heartbeat.request_timeout_ms,
            "Cognition heartbeat sharing enabled (set CODETETHER_WORKER_COGNITION_SHARE_ENABLED=false to disable)"
        );
    } else {
        tracing::warn!(
            "Cognition heartbeat sharing disabled; worker thought state will not be shared upstream"
        );
    }

    // Map the CLI flag to the approval policy; anything unrecognized is None.
    let auto_approve = match args.auto_approve.as_str() {
        "all" => AutoApprove::All,
        "safe" => AutoApprove::Safe,
        _ => AutoApprove::None,
    };

    let heartbeat_state = HeartbeatState::new(worker_id.clone(), name.clone());

    // Announce this worker's capabilities on the in-process agent bus.
    let bus = AgentBus::new().into_arc();
    {
        let handle = bus.handle(&worker_id);
        handle.announce_ready(WORKER_CAPABILITIES.iter().map(|s| s.to_string()).collect());
    }

    // Initial registration must succeed; later re-registrations are best-effort.
    register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;

    // Pick up tasks queued before this worker connected.
    fetch_pending_tasks(
        &client,
        server,
        &args.token,
        &worker_id,
        &processing,
        &auto_approve,
        &bus,
    )
    .await?;

    loop {
        if let Err(e) =
            register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await
        {
            tracing::warn!("Failed to re-register worker on reconnection: {}", e);
        }

        // Heartbeat lives only as long as the current stream connection.
        let heartbeat_handle = start_heartbeat(
            client.clone(),
            server.to_string(),
            args.token.clone(),
            heartbeat_state.clone(),
            processing.clone(),
            cognition_heartbeat.clone(),
        );

        // Blocks until the stream ends (Ok) or errors; either way we reconnect.
        match connect_stream(
            &client,
            server,
            &args.token,
            &worker_id,
            &name,
            &codebases,
            &processing,
            &auto_approve,
            &bus,
        )
        .await
        {
            Ok(()) => {
                tracing::warn!("Stream ended, reconnecting...");
            }
            Err(e) => {
                tracing::error!("Stream error: {}, reconnecting...", e);
            }
        }

        heartbeat_handle.abort();
        tracing::debug!("Heartbeat cancelled for reconnection");

        // Back off briefly before reconnecting.
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
    }
}
227
228fn generate_worker_id() -> String {
229 format!(
230 "wrk_{}_{:x}",
231 chrono::Utc::now().timestamp(),
232 rand::random::<u64>()
233 )
234}
235
/// Tool auto-approval policy parsed from the CLI `--auto-approve` flag in
/// `run` ("all" => All, "safe" => Safe, anything else => None). The policy's
/// effect is applied in `execute_session_with_policy` (not in this chunk).
#[derive(Debug, Clone, Copy)]
enum AutoApprove {
    All,
    Safe,
    None,
}
242
/// Default A2A server endpoint — presumably the fallback for the CLI
/// `--server` flag (usage not visible in this chunk).
pub const DEFAULT_A2A_SERVER_URL: &str = "https://api.codetether.run";

/// Capability tags advertised on registration and bus-ready announcements.
const WORKER_CAPABILITIES: &[&str] = &["ralph", "swarm", "rlm", "a2a", "mcp"];
248
249fn task_value<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a serde_json::Value> {
250 task.get("task")
251 .and_then(|t| t.get(key))
252 .or_else(|| task.get(key))
253}
254
255fn task_str<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a str> {
256 task_value(task, key).and_then(|v| v.as_str())
257}
258
259fn task_metadata(task: &serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
260 task_value(task, "metadata")
261 .and_then(|m| m.as_object())
262 .cloned()
263 .unwrap_or_default()
264}
265
/// Normalizes a model reference to `provider/model` form: a bare
/// `provider:model` ref (no '/' present) has its first ':' rewritten to '/';
/// anything already containing '/' — or with no ':' at all — passes through
/// unchanged (so e.g. Bedrock ids like `bedrock/foo:0` keep their colon).
fn model_ref_to_provider_model(model: &str) -> String {
    let needs_rewrite = model.contains(':') && !model.contains('/');
    if needs_rewrite {
        model.replacen(':', "/", 1)
    } else {
        model.to_string()
    }
}
276
/// Ordered provider preference list for a model tier. `None` and unknown
/// tiers fall back to the "balanced" ordering; "zai" leads every list.
fn provider_preferences_for_tier(model_tier: Option<&str>) -> &'static [&'static str] {
    const FAST: &[&str] = &[
        "zai",
        "openai",
        "github-copilot",
        "moonshotai",
        "openrouter",
        "novita",
        "google",
        "anthropic",
    ];
    const HEAVY: &[&str] = &[
        "zai",
        "anthropic",
        "openai",
        "github-copilot",
        "moonshotai",
        "openrouter",
        "novita",
        "google",
    ];
    const BALANCED: &[&str] = &[
        "zai",
        "openai",
        "github-copilot",
        "anthropic",
        "moonshotai",
        "openrouter",
        "novita",
        "google",
    ];

    match model_tier.unwrap_or("balanced") {
        "fast" | "quick" => FAST,
        "heavy" | "deep" => HEAVY,
        _ => BALANCED,
    }
}
311
312fn choose_provider_for_tier<'a>(providers: &'a [&'a str], model_tier: Option<&str>) -> &'a str {
313 for preferred in provider_preferences_for_tier(model_tier) {
314 if let Some(found) = providers.iter().copied().find(|p| *p == *preferred) {
315 return found;
316 }
317 }
318 if let Some(found) = providers.iter().copied().find(|p| *p == "zai") {
319 return found;
320 }
321 providers[0]
322}
323
/// Default model id for a provider at a given tier. Unknown providers map to
/// "glm-5"; `None` and unrecognized tiers behave as "balanced".
fn default_model_for_provider(provider: &str, model_tier: Option<&str>) -> String {
    #[derive(Clone, Copy)]
    enum Tier {
        Fast,
        Heavy,
        Balanced,
    }
    let tier = match model_tier.unwrap_or("balanced") {
        "fast" | "quick" => Tier::Fast,
        "heavy" | "deep" => Tier::Heavy,
        _ => Tier::Balanced,
    };

    // Keyed by provider first; only a few providers vary by tier.
    let model = match provider {
        "moonshotai" => "kimi-k2.5",
        "anthropic" => match tier {
            Tier::Fast => "claude-haiku-4-5",
            _ => "claude-sonnet-4-20250514",
        },
        "openai" => match tier {
            Tier::Fast => "gpt-4o-mini",
            Tier::Heavy => "o3",
            Tier::Balanced => "gpt-4o",
        },
        "google" => match tier {
            Tier::Fast => "gemini-2.5-flash",
            _ => "gemini-2.5-pro",
        },
        "zhipuai" | "zai" => "glm-5",
        "openrouter" => "z-ai/glm-5",
        "novita" => "qwen/qwen3-coder-next",
        "bedrock" => match tier {
            Tier::Heavy => "us.anthropic.claude-sonnet-4-20250514-v1:0",
            _ => "amazon.nova-lite-v1:0",
        },
        _ => "glm-5",
    };
    model.to_string()
}
361
/// True for model families that work best at temperature 1.0 (Kimi K2,
/// GLM, MiniMax); matching is case-insensitive substring search.
fn prefers_temperature_one(model: &str) -> bool {
    let normalized = model.to_ascii_lowercase();
    ["kimi-k2", "glm-", "minimax"]
        .iter()
        .any(|marker| normalized.contains(marker))
}
366
/// True when the agent type requests multi-agent swarm execution.
/// Comparison is whitespace- and case-insensitive.
fn is_swarm_agent(agent_type: &str) -> bool {
    let normalized = agent_type.trim().to_ascii_lowercase();
    normalized == "swarm" || normalized == "parallel" || normalized == "multi-agent"
}
373
374fn metadata_lookup<'a>(
375 metadata: &'a serde_json::Map<String, serde_json::Value>,
376 key: &str,
377) -> Option<&'a serde_json::Value> {
378 metadata
379 .get(key)
380 .or_else(|| {
381 metadata
382 .get("routing")
383 .and_then(|v| v.as_object())
384 .and_then(|obj| obj.get(key))
385 })
386 .or_else(|| {
387 metadata
388 .get("swarm")
389 .and_then(|v| v.as_object())
390 .and_then(|obj| obj.get(key))
391 })
392}
393
394fn metadata_str(
395 metadata: &serde_json::Map<String, serde_json::Value>,
396 keys: &[&str],
397) -> Option<String> {
398 for key in keys {
399 if let Some(value) = metadata_lookup(metadata, key).and_then(|v| v.as_str()) {
400 let trimmed = value.trim();
401 if !trimmed.is_empty() {
402 return Some(trimmed.to_string());
403 }
404 }
405 }
406 None
407}
408
409fn metadata_usize(
410 metadata: &serde_json::Map<String, serde_json::Value>,
411 keys: &[&str],
412) -> Option<usize> {
413 for key in keys {
414 if let Some(value) = metadata_lookup(metadata, key) {
415 if let Some(v) = value.as_u64() {
416 return usize::try_from(v).ok();
417 }
418 if let Some(v) = value.as_i64() {
419 if v >= 0 {
420 return usize::try_from(v as u64).ok();
421 }
422 }
423 if let Some(v) = value.as_str() {
424 if let Ok(parsed) = v.trim().parse::<usize>() {
425 return Some(parsed);
426 }
427 }
428 }
429 }
430 None
431}
432
433fn metadata_u64(
434 metadata: &serde_json::Map<String, serde_json::Value>,
435 keys: &[&str],
436) -> Option<u64> {
437 for key in keys {
438 if let Some(value) = metadata_lookup(metadata, key) {
439 if let Some(v) = value.as_u64() {
440 return Some(v);
441 }
442 if let Some(v) = value.as_i64() {
443 if v >= 0 {
444 return Some(v as u64);
445 }
446 }
447 if let Some(v) = value.as_str() {
448 if let Ok(parsed) = v.trim().parse::<u64>() {
449 return Some(parsed);
450 }
451 }
452 }
453 }
454 None
455}
456
457fn metadata_bool(
458 metadata: &serde_json::Map<String, serde_json::Value>,
459 keys: &[&str],
460) -> Option<bool> {
461 for key in keys {
462 if let Some(value) = metadata_lookup(metadata, key) {
463 if let Some(v) = value.as_bool() {
464 return Some(v);
465 }
466 if let Some(v) = value.as_str() {
467 match v.trim().to_ascii_lowercase().as_str() {
468 "1" | "true" | "yes" | "on" => return Some(true),
469 "0" | "false" | "no" | "off" => return Some(false),
470 _ => {}
471 }
472 }
473 }
474 }
475 None
476}
477
478fn parse_swarm_strategy(
479 metadata: &serde_json::Map<String, serde_json::Value>,
480) -> DecompositionStrategy {
481 match metadata_str(
482 metadata,
483 &[
484 "decomposition_strategy",
485 "swarm_strategy",
486 "strategy",
487 "swarm_decomposition",
488 ],
489 )
490 .as_deref()
491 .map(|s| s.to_ascii_lowercase())
492 .as_deref()
493 {
494 Some("none") | Some("single") => DecompositionStrategy::None,
495 Some("domain") | Some("by_domain") => DecompositionStrategy::ByDomain,
496 Some("data") | Some("by_data") => DecompositionStrategy::ByData,
497 Some("stage") | Some("by_stage") => DecompositionStrategy::ByStage,
498 _ => DecompositionStrategy::Automatic,
499 }
500}
501
502async fn resolve_swarm_model(
503 explicit_model: Option<String>,
504 model_tier: Option<&str>,
505) -> Option<String> {
506 if let Some(model) = explicit_model {
507 if !model.trim().is_empty() {
508 return Some(model);
509 }
510 }
511
512 let registry = ProviderRegistry::from_vault().await.ok()?;
513 let providers = registry.list();
514 if providers.is_empty() {
515 return None;
516 }
517 let provider = choose_provider_for_tier(providers.as_slice(), model_tier);
518 let model = default_model_for_provider(provider, model_tier);
519 Some(format!("{}/{}", provider, model))
520}
521
522fn format_swarm_event_for_output(event: &SwarmEvent) -> Option<String> {
523 match event {
524 SwarmEvent::Started {
525 task,
526 total_subtasks,
527 } => Some(format!(
528 "[swarm] started task={} planned_subtasks={}",
529 task, total_subtasks
530 )),
531 SwarmEvent::StageComplete {
532 stage,
533 completed,
534 failed,
535 } => Some(format!(
536 "[swarm] stage={} completed={} failed={}",
537 stage, completed, failed
538 )),
539 SwarmEvent::SubTaskUpdate { id, status, .. } => Some(format!(
540 "[swarm] subtask id={} status={}",
541 &id.chars().take(8).collect::<String>(),
542 format!("{status:?}").to_ascii_lowercase()
543 )),
544 SwarmEvent::AgentToolCall {
545 subtask_id,
546 tool_name,
547 } => Some(format!(
548 "[swarm] subtask id={} tool={}",
549 &subtask_id.chars().take(8).collect::<String>(),
550 tool_name
551 )),
552 SwarmEvent::AgentError { subtask_id, error } => Some(format!(
553 "[swarm] subtask id={} error={}",
554 &subtask_id.chars().take(8).collect::<String>(),
555 error
556 )),
557 SwarmEvent::Complete { success, stats } => Some(format!(
558 "[swarm] complete success={} subtasks={} speedup={:.2}",
559 success,
560 stats.subagents_completed + stats.subagents_failed,
561 stats.speedup_factor
562 )),
563 SwarmEvent::Error(err) => Some(format!("[swarm] error message={}", err)),
564 _ => None,
565 }
566}
567
/// Registers (or re-registers) this worker with the server, advertising its
/// capabilities, hostname, codebases, and the models it can serve.
///
/// Best-effort: model loading failures and non-success HTTP responses are
/// logged but do not fail the call; only transport errors return `Err`.
async fn register_worker(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
) -> Result<()> {
    // Model inventory is optional — registration proceeds without it.
    let models = match load_provider_models().await {
        Ok(m) => m,
        Err(e) => {
            tracing::warn!(
                "Failed to load provider models: {}, proceeding without model info",
                e
            );
            HashMap::new()
        }
    };

    let mut req = client.post(format!("{}/v1/opencode/workers/register", server));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    // Flatten provider -> models into one array of objects keyed by
    // "provider/model" ids; cost fields are included only when known.
    let models_array: Vec<serde_json::Value> = models
        .iter()
        .flat_map(|(provider, model_infos)| {
            model_infos.iter().map(move |m| {
                let mut obj = serde_json::json!({
                    "id": format!("{}/{}", provider, m.id),
                    "name": &m.id,
                    "provider": provider,
                    "provider_id": provider,
                });
                if let Some(input_cost) = m.input_cost_per_million {
                    obj["input_cost_per_million"] = serde_json::json!(input_cost);
                }
                if let Some(output_cost) = m.output_cost_per_million {
                    obj["output_cost_per_million"] = serde_json::json!(output_cost);
                }
                obj
            })
        })
        .collect();

    tracing::info!(
        "Registering worker with {} models from {} providers",
        models_array.len(),
        models.len()
    );

    // HOSTNAME on Unix-likes, COMPUTERNAME on Windows.
    let hostname = std::env::var("HOSTNAME")
        .or_else(|_| std::env::var("COMPUTERNAME"))
        .unwrap_or_else(|_| "unknown".to_string());

    let res = req
        .json(&serde_json::json!({
            "worker_id": worker_id,
            "name": name,
            "capabilities": WORKER_CAPABILITIES,
            "hostname": hostname,
            "models": models_array,
            "codebases": codebases,
        }))
        .send()
        .await?;

    if res.status().is_success() {
        tracing::info!("Worker registered successfully");
    } else {
        tracing::warn!("Failed to register worker: {}", res.status());
    }

    Ok(())
}
648
/// Builds the provider -> models inventory advertised at registration.
///
/// Resolution order: Vault registry, then config/env fallback; each
/// provider's models are enriched with per-million token costs from the
/// models.dev catalog when the provider itself reports none. If no
/// authenticated provider yields any models, the full public catalog is
/// returned instead.
async fn load_provider_models() -> Result<HashMap<String, Vec<crate::provider::ModelInfo>>> {
    let registry = match ProviderRegistry::from_vault().await {
        Ok(r) if !r.list().is_empty() => {
            tracing::info!("Loaded {} providers from Vault", r.list().len());
            r
        }
        Ok(_) => {
            tracing::warn!("Vault returned 0 providers, falling back to config/env vars");
            fallback_registry().await?
        }
        Err(e) => {
            tracing::warn!("Vault unreachable ({}), falling back to config/env vars", e);
            fallback_registry().await?
        }
    };

    // Pricing catalog is optional; enrichment is skipped when unavailable.
    let catalog = crate::provider::models::ModelCatalog::fetch().await.ok();

    // Our provider ids differ from the catalog's for a couple of providers.
    let catalog_alias = |pid: &str| -> String {
        match pid {
            "bedrock" => "amazon-bedrock".to_string(),
            "novita" => "novita-ai".to_string(),
            _ => pid.to_string(),
        }
    };

    let mut models_by_provider: HashMap<String, Vec<crate::provider::ModelInfo>> = HashMap::new();

    for provider_name in registry.list() {
        if let Some(provider) = registry.get(provider_name) {
            match provider.list_models().await {
                Ok(models) => {
                    // Fill in missing cost fields from the catalog; Bedrock
                    // cross-region ids get their "us." prefix stripped for
                    // a second lookup attempt.
                    let enriched: Vec<crate::provider::ModelInfo> = models
                        .into_iter()
                        .map(|mut m| {
                            if m.input_cost_per_million.is_none()
                                || m.output_cost_per_million.is_none()
                            {
                                if let Some(ref cat) = catalog {
                                    let cat_pid = catalog_alias(provider_name);
                                    if let Some(prov_info) = cat.get_provider(&cat_pid) {
                                        let model_info =
                                            prov_info.models.get(&m.id).or_else(|| {
                                                m.id.strip_prefix("us.").and_then(|stripped| {
                                                    prov_info.models.get(stripped)
                                                })
                                            });
                                        if let Some(model_info) = model_info {
                                            if let Some(ref cost) = model_info.cost {
                                                if m.input_cost_per_million.is_none() {
                                                    m.input_cost_per_million = Some(cost.input);
                                                }
                                                if m.output_cost_per_million.is_none() {
                                                    m.output_cost_per_million = Some(cost.output);
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            m
                        })
                        .collect();
                    if !enriched.is_empty() {
                        tracing::debug!("Provider {}: {} models", provider_name, enriched.len());
                        models_by_provider.insert(provider_name.to_string(), enriched);
                    }
                }
                Err(e) => {
                    tracing::debug!("Failed to list models for {}: {}", provider_name, e);
                }
            }
        }
    }

    // Last resort: advertise the whole public catalog rather than nothing.
    if models_by_provider.is_empty() {
        tracing::info!(
            "No authenticated providers found, fetching models.dev catalog (all providers)"
        );
        if let Ok(cat) = crate::provider::models::ModelCatalog::fetch().await {
            for (provider_id, provider_info) in cat.all_providers() {
                let model_infos: Vec<crate::provider::ModelInfo> = provider_info
                    .models
                    .values()
                    .map(|m| cat.to_model_info(m, provider_id))
                    .collect();
                if !model_infos.is_empty() {
                    tracing::debug!(
                        "Catalog provider {}: {} models",
                        provider_id,
                        model_infos.len()
                    );
                    models_by_provider.insert(provider_id.clone(), model_infos);
                }
            }
            tracing::info!(
                "Loaded {} providers with {} total models from catalog",
                models_by_provider.len(),
                models_by_provider.values().map(|v| v.len()).sum::<usize>()
            );
        }
    }

    Ok(models_by_provider)
}
767
/// Builds a provider registry from on-disk config / environment variables,
/// used when Vault is unreachable or reports no providers.
async fn fallback_registry() -> Result<ProviderRegistry> {
    // A missing/broken config degrades to defaults rather than failing here.
    let config = crate::config::Config::load().await.unwrap_or_default();
    ProviderRegistry::from_config(&config).await
}
773
774async fn fetch_pending_tasks(
775 client: &Client,
776 server: &str,
777 token: &Option<String>,
778 worker_id: &str,
779 processing: &Arc<Mutex<HashSet<String>>>,
780 auto_approve: &AutoApprove,
781 bus: &Arc<AgentBus>,
782) -> Result<()> {
783 tracing::info!("Checking for pending tasks...");
784
785 let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
786 if let Some(t) = token {
787 req = req.bearer_auth(t);
788 }
789
790 let res = req.send().await?;
791 if !res.status().is_success() {
792 return Ok(());
793 }
794
795 let data: serde_json::Value = res.json().await?;
796 let tasks = if let Some(arr) = data.as_array() {
798 arr.clone()
799 } else {
800 data["tasks"].as_array().cloned().unwrap_or_default()
801 };
802
803 tracing::info!("Found {} pending task(s)", tasks.len());
804
805 for task in tasks {
806 if let Some(id) = task["id"].as_str() {
807 let mut proc = processing.lock().await;
808 if !proc.contains(id) {
809 proc.insert(id.to_string());
810 drop(proc);
811
812 let task_id = id.to_string();
813 let client = client.clone();
814 let server = server.to_string();
815 let token = token.clone();
816 let worker_id = worker_id.to_string();
817 let auto_approve = *auto_approve;
818 let processing = processing.clone();
819 let bus = bus.clone();
820
821 tokio::spawn(async move {
822 if let Err(e) = handle_task(
823 &client,
824 &server,
825 &token,
826 &worker_id,
827 &task,
828 auto_approve,
829 &bus,
830 )
831 .await
832 {
833 tracing::error!("Task {} failed: {}", task_id, e);
834 }
835 processing.lock().await.remove(&task_id);
836 });
837 }
838 }
839 }
840
841 Ok(())
842}
843
#[allow(clippy::too_many_arguments)]
/// Holds an SSE connection to the server's task stream, spawning a handler
/// for each task event, and polls the pending-task endpoint every 30s as a
/// safety net for events missed while the stream is degraded.
///
/// Returns `Ok(())` when the server closes the stream and `Err` on transport
/// failure; in both cases the caller (`run`) reconnects.
async fn connect_stream(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
    bus: &Arc<AgentBus>,
) -> Result<()> {
    let url = format!(
        "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
        server,
        urlencoding::encode(name),
        urlencoding::encode(worker_id)
    );

    let mut req = client
        .get(&url)
        .header("Accept", "text/event-stream")
        .header("X-Worker-ID", worker_id)
        .header("X-Agent-Name", name)
        .header("X-Codebases", codebases.join(","));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        anyhow::bail!("Failed to connect: {}", res.status());
    }

    tracing::info!("Connected to A2A server");

    let mut stream = res.bytes_stream();
    // Accumulates partial SSE frames until a blank-line delimiter arrives.
    let mut buffer = String::new();
    let mut poll_interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
    // tokio intervals fire immediately on the first tick; consume it so the
    // first poll happens 30s in, not instantly after `fetch_pending_tasks`.
    poll_interval.tick().await; loop {
        tokio::select! {
            chunk = stream.next() => {
                match chunk {
                    Some(Ok(chunk)) => {
                        buffer.push_str(&String::from_utf8_lossy(&chunk));

                        // SSE events are separated by a blank line ("\n\n").
                        while let Some(pos) = buffer.find("\n\n") {
                            let event_str = buffer[..pos].to_string();
                            buffer = buffer[pos + 2..].to_string();

                            if let Some(data_line) = event_str.lines().find(|l| l.starts_with("data:")) {
                                let data = data_line.trim_start_matches("data:").trim();
                                if data == "[DONE]" || data.is_empty() {
                                    continue;
                                }

                                // Malformed JSON payloads are silently skipped.
                                if let Ok(task) = serde_json::from_str::<serde_json::Value>(data) {
                                    spawn_task_handler(
                                        &task, client, server, token, worker_id,
                                        processing, auto_approve, bus,
                                    ).await;
                                }
                            }
                        }
                    }
                    Some(Err(e)) => {
                        return Err(e.into());
                    }
                    None => {
                        // Server closed the stream cleanly.
                        return Ok(());
                    }
                }
            }
            _ = poll_interval.tick() => {
                // Belt-and-braces poll; failures are logged, not fatal.
                if let Err(e) = poll_pending_tasks(
                    client, server, token, worker_id, processing, auto_approve, bus,
                ).await {
                    tracing::warn!("Periodic task poll failed: {}", e);
                }
            }
        }
    }
}
933
/// Deduplicates and dispatches one task: if its id (nested `task.task.id` or
/// top-level `id`) is not already in `processing`, mark it and run
/// `handle_task` on a spawned tokio task, removing the id when it finishes.
///
/// The insert happens under the lock BEFORE spawning, and the guard is
/// dropped before the spawn, so concurrent stream/poll deliveries of the
/// same task cannot double-handle it.
async fn spawn_task_handler(
    task: &serde_json::Value,
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
    bus: &Arc<AgentBus>,
) {
    if let Some(id) = task
        .get("task")
        .and_then(|t| t["id"].as_str())
        .or_else(|| task["id"].as_str())
    {
        let mut proc = processing.lock().await;
        if !proc.contains(id) {
            proc.insert(id.to_string());
            drop(proc);

            // Owned copies for the 'static spawned future.
            let task_id = id.to_string();
            let task = task.clone();
            let client = client.clone();
            let server = server.to_string();
            let token = token.clone();
            let worker_id = worker_id.to_string();
            let auto_approve = *auto_approve;
            let processing_clone = processing.clone();
            let bus = bus.clone();

            tokio::spawn(async move {
                if let Err(e) = handle_task(
                    &client,
                    &server,
                    &token,
                    &worker_id,
                    &task,
                    auto_approve,
                    &bus,
                )
                .await
                {
                    tracing::error!("Task {} failed: {}", task_id, e);
                }
                // Always release the dedup slot, success or failure.
                processing.lock().await.remove(&task_id);
            });
        }
    }
}
983
984async fn poll_pending_tasks(
985 client: &Client,
986 server: &str,
987 token: &Option<String>,
988 worker_id: &str,
989 processing: &Arc<Mutex<HashSet<String>>>,
990 auto_approve: &AutoApprove,
991 bus: &Arc<AgentBus>,
992) -> Result<()> {
993 let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
994 if let Some(t) = token {
995 req = req.bearer_auth(t);
996 }
997
998 let res = req.send().await?;
999 if !res.status().is_success() {
1000 return Ok(());
1001 }
1002
1003 let data: serde_json::Value = res.json().await?;
1004 let tasks = if let Some(arr) = data.as_array() {
1005 arr.clone()
1006 } else {
1007 data["tasks"].as_array().cloned().unwrap_or_default()
1008 };
1009
1010 if !tasks.is_empty() {
1011 tracing::debug!("Poll found {} pending task(s)", tasks.len());
1012 }
1013
1014 for task in &tasks {
1015 spawn_task_handler(
1016 task,
1017 client,
1018 server,
1019 token,
1020 worker_id,
1021 processing,
1022 auto_approve,
1023 bus,
1024 )
1025 .await;
1026 }
1027
1028 Ok(())
1029}
1030
/// Full lifecycle for one task: claim it on the server, build/resume a
/// session configured from task metadata, execute it (swarm or single
/// session), stream output back, and finally release the task with its
/// terminal status.
///
/// Claim conflicts (409) and other claim failures return `Ok(())` — the task
/// belongs to another worker. Execution failures are reported via the
/// release payload, not as `Err`.
async fn handle_task(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    task: &serde_json::Value,
    auto_approve: AutoApprove,
    bus: &Arc<AgentBus>,
) -> Result<()> {
    let task_id = task_str(task, "id").ok_or_else(|| anyhow::anyhow!("No task ID"))?;
    let title = task_str(task, "title").unwrap_or("Untitled");

    tracing::info!("Handling task: {} ({})", title, task_id);

    // --- Claim ---------------------------------------------------------
    let mut req = client
        .post(format!("{}/v1/worker/tasks/claim", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req
        .json(&serde_json::json!({ "task_id": task_id }))
        .send()
        .await?;

    if !res.status().is_success() {
        let status = res.status();
        let text = res.text().await?;
        if status == reqwest::StatusCode::CONFLICT {
            // Another worker won the race; expected, not an error.
            tracing::debug!(task_id, "Task already claimed by another worker, skipping");
        } else {
            tracing::warn!(task_id, %status, "Failed to claim task: {}", text);
        }
        return Ok(());
    }

    tracing::info!("Claimed task: {}", task_id);

    // --- Extract routing hints from metadata ---------------------------
    let metadata = task_metadata(task);
    let resume_session_id = metadata
        .get("resume_session_id")
        .and_then(|v| v.as_str())
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty());
    let complexity_hint = metadata_str(&metadata, &["complexity"]);
    // Explicit tier wins; otherwise derive it from the complexity hint.
    let model_tier = metadata_str(&metadata, &["model_tier", "tier"])
        .map(|s| s.to_ascii_lowercase())
        .or_else(|| {
            complexity_hint.as_ref().map(|complexity| {
                match complexity.to_ascii_lowercase().as_str() {
                    "quick" => "fast".to_string(),
                    "deep" => "heavy".to_string(),
                    _ => "balanced".to_string(),
                }
            })
        });
    let worker_personality = metadata_str(
        &metadata,
        &["worker_personality", "personality", "agent_personality"],
    );
    let target_agent_name = metadata_str(&metadata, &["target_agent_name", "agent_name"]);
    // Model can arrive under four spellings; first hit wins.
    let raw_model = task_str(task, "model_ref")
        .or_else(|| metadata_lookup(&metadata, "model_ref").and_then(|v| v.as_str()))
        .or_else(|| task_str(task, "model"))
        .or_else(|| metadata_lookup(&metadata, "model").and_then(|v| v.as_str()));
    let selected_model = raw_model.map(model_ref_to_provider_model);

    // --- Build or resume the session -----------------------------------
    let mut session = if let Some(ref sid) = resume_session_id {
        match Session::load(sid).await {
            Ok(existing) => {
                tracing::info!("Resuming session {} for task {}", sid, task_id);
                existing
            }
            Err(e) => {
                // A stale/unknown session id degrades to a fresh session.
                tracing::warn!(
                    "Could not load session {} for task {} ({}), starting a new session",
                    sid,
                    task_id,
                    e
                );
                Session::new().await?
            }
        }
    } else {
        Session::new().await?
    };

    let agent_type = task_str(task, "agent_type")
        .or_else(|| task_str(task, "agent"))
        .unwrap_or("build");
    session.agent = agent_type.to_string();

    if let Some(model) = selected_model.clone() {
        session.metadata.model = Some(model);
    }

    // Prompt fallback chain: prompt -> description -> title.
    let prompt = task_str(task, "prompt")
        .or_else(|| task_str(task, "description"))
        .unwrap_or(title);

    tracing::info!("Executing prompt: {}", prompt);

    // --- Output streaming callback --------------------------------------
    // Owned copies so the callback (and each spawned POST) is 'static.
    let stream_client = client.clone();
    let stream_server = server.to_string();
    let stream_token = token.clone();
    let stream_worker_id = worker_id.to_string();
    let stream_task_id = task_id.to_string();

    // Fire-and-forget POST of each output chunk; send errors are ignored.
    let output_callback = move |output: String| {
        let c = stream_client.clone();
        let s = stream_server.clone();
        let t = stream_token.clone();
        let w = stream_worker_id.clone();
        let tid = stream_task_id.clone();
        tokio::spawn(async move {
            let mut req = c
                .post(format!("{}/v1/opencode/tasks/{}/output", s, tid))
                .header("X-Worker-ID", &w);
            if let Some(tok) = &t {
                req = req.bearer_auth(tok);
            }
            let _ = req
                .json(&serde_json::json!({
                    "worker_id": w,
                    "output": output,
                }))
                .send()
                .await;
        });
    };

    // --- Execute (swarm vs. single session) ------------------------------
    let (status, result, error, session_id) = if is_swarm_agent(agent_type) {
        match execute_swarm_with_policy(
            &mut session,
            prompt,
            model_tier.as_deref(),
            selected_model,
            &metadata,
            complexity_hint.as_deref(),
            worker_personality.as_deref(),
            target_agent_name.as_deref(),
            Some(bus),
            Some(&output_callback),
        )
        .await
        {
            Ok((session_result, true)) => {
                tracing::info!("Swarm task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Ok((session_result, false)) => {
                // Swarm ran to completion but some subtasks failed.
                tracing::warn!("Swarm task completed with failures: {}", task_id);
                (
                    "failed",
                    Some(session_result.text),
                    Some("Swarm execution completed with failures".to_string()),
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Swarm task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    } else {
        match execute_session_with_policy(
            &mut session,
            prompt,
            auto_approve,
            model_tier.as_deref(),
            Some(&output_callback),
        )
        .await
        {
            Ok(session_result) => {
                tracing::info!("Task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    };

    // --- Release with terminal status ------------------------------------
    let mut req = client
        .post(format!("{}/v1/worker/tasks/release", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    req.json(&serde_json::json!({
        "task_id": task_id,
        "status": status,
        "result": result,
        "error": error,
        "session_id": session_id.unwrap_or_else(|| session.id.clone()),
    }))
    .send()
    .await?;

    tracing::info!("Task released: {} with status: {}", task_id, status);
    Ok(())
}
1252
1253async fn execute_swarm_with_policy<F>(
1254 session: &mut Session,
1255 prompt: &str,
1256 model_tier: Option<&str>,
1257 explicit_model: Option<String>,
1258 metadata: &serde_json::Map<String, serde_json::Value>,
1259 complexity_hint: Option<&str>,
1260 worker_personality: Option<&str>,
1261 target_agent_name: Option<&str>,
1262 bus: Option<&Arc<AgentBus>>,
1263 output_callback: Option<&F>,
1264) -> Result<(crate::session::SessionResult, bool)>
1265where
1266 F: Fn(String),
1267{
1268 use crate::provider::{ContentPart, Message, Role};
1269
1270 session.add_message(Message {
1271 role: Role::User,
1272 content: vec![ContentPart::Text {
1273 text: prompt.to_string(),
1274 }],
1275 });
1276
1277 if session.title.is_none() {
1278 session.generate_title().await?;
1279 }
1280
1281 let strategy = parse_swarm_strategy(metadata);
1282 let max_subagents = metadata_usize(
1283 metadata,
1284 &["swarm_max_subagents", "max_subagents", "subagents"],
1285 )
1286 .unwrap_or(10)
1287 .clamp(1, 100);
1288 let max_steps_per_subagent = metadata_usize(
1289 metadata,
1290 &[
1291 "swarm_max_steps_per_subagent",
1292 "max_steps_per_subagent",
1293 "max_steps",
1294 ],
1295 )
1296 .unwrap_or(50)
1297 .clamp(1, 200);
1298 let timeout_secs = metadata_u64(metadata, &["swarm_timeout_secs", "timeout_secs", "timeout"])
1299 .unwrap_or(600)
1300 .clamp(30, 3600);
1301 let parallel_enabled =
1302 metadata_bool(metadata, &["swarm_parallel_enabled", "parallel_enabled"]).unwrap_or(true);
1303
1304 let model = resolve_swarm_model(explicit_model, model_tier).await;
1305 if let Some(ref selected_model) = model {
1306 session.metadata.model = Some(selected_model.clone());
1307 }
1308
1309 if let Some(cb) = output_callback {
1310 cb(format!(
1311 "[swarm] routing complexity={} tier={} personality={} target_agent={}",
1312 complexity_hint.unwrap_or("standard"),
1313 model_tier.unwrap_or("balanced"),
1314 worker_personality.unwrap_or("auto"),
1315 target_agent_name.unwrap_or("auto")
1316 ));
1317 cb(format!(
1318 "[swarm] config strategy={:?} max_subagents={} max_steps={} timeout={}s tier={}",
1319 strategy,
1320 max_subagents,
1321 max_steps_per_subagent,
1322 timeout_secs,
1323 model_tier.unwrap_or("balanced")
1324 ));
1325 }
1326
1327 let swarm_config = SwarmConfig {
1328 max_subagents,
1329 max_steps_per_subagent,
1330 subagent_timeout_secs: timeout_secs,
1331 parallel_enabled,
1332 model,
1333 working_dir: session
1334 .metadata
1335 .directory
1336 .as_ref()
1337 .map(|p| p.to_string_lossy().to_string()),
1338 ..Default::default()
1339 };
1340
1341 let swarm_result = if output_callback.is_some() {
1342 let (event_tx, mut event_rx) = mpsc::channel(256);
1343 let mut executor = SwarmExecutor::new(swarm_config).with_event_tx(event_tx);
1344 if let Some(bus) = bus {
1345 executor = executor.with_bus(Arc::clone(bus));
1346 }
1347 let prompt_owned = prompt.to_string();
1348 let mut exec_handle =
1349 tokio::spawn(async move { executor.execute(&prompt_owned, strategy).await });
1350
1351 let mut final_result: Option<crate::swarm::SwarmResult> = None;
1352
1353 while final_result.is_none() {
1354 tokio::select! {
1355 maybe_event = event_rx.recv() => {
1356 if let Some(event) = maybe_event {
1357 if let Some(cb) = output_callback {
1358 if let Some(line) = format_swarm_event_for_output(&event) {
1359 cb(line);
1360 }
1361 }
1362 }
1363 }
1364 join_result = &mut exec_handle => {
1365 let joined = join_result.map_err(|e| anyhow::anyhow!("Swarm join failure: {}", e))?;
1366 final_result = Some(joined?);
1367 }
1368 }
1369 }
1370
1371 while let Ok(event) = event_rx.try_recv() {
1372 if let Some(cb) = output_callback {
1373 if let Some(line) = format_swarm_event_for_output(&event) {
1374 cb(line);
1375 }
1376 }
1377 }
1378
1379 final_result.ok_or_else(|| anyhow::anyhow!("Swarm execution returned no result"))?
1380 } else {
1381 let mut executor = SwarmExecutor::new(swarm_config);
1382 if let Some(bus) = bus {
1383 executor = executor.with_bus(Arc::clone(bus));
1384 }
1385 executor.execute(prompt, strategy).await?
1386 };
1387
1388 let final_text = if swarm_result.result.trim().is_empty() {
1389 if swarm_result.success {
1390 "Swarm completed without textual output.".to_string()
1391 } else {
1392 "Swarm finished with failures and no textual output.".to_string()
1393 }
1394 } else {
1395 swarm_result.result.clone()
1396 };
1397
1398 session.add_message(Message {
1399 role: Role::Assistant,
1400 content: vec![ContentPart::Text {
1401 text: final_text.clone(),
1402 }],
1403 });
1404 session.save().await?;
1405
1406 Ok((
1407 crate::session::SessionResult {
1408 text: final_text,
1409 session_id: session.id.clone(),
1410 },
1411 swarm_result.success,
1412 ))
1413}
1414
1415async fn execute_session_with_policy<F>(
1418 session: &mut Session,
1419 prompt: &str,
1420 auto_approve: AutoApprove,
1421 model_tier: Option<&str>,
1422 output_callback: Option<&F>,
1423) -> Result<crate::session::SessionResult>
1424where
1425 F: Fn(String),
1426{
1427 use crate::provider::{
1428 CompletionRequest, ContentPart, Message, ProviderRegistry, Role, parse_model_string,
1429 };
1430 use std::sync::Arc;
1431
1432 let registry = ProviderRegistry::from_vault().await?;
1434 let providers = registry.list();
1435 tracing::info!("Available providers: {:?}", providers);
1436
1437 if providers.is_empty() {
1438 anyhow::bail!("No providers available. Configure API keys in HashiCorp Vault.");
1439 }
1440
1441 let (provider_name, model_id) = if let Some(ref model_str) = session.metadata.model {
1443 let (prov, model) = parse_model_string(model_str);
1444 let prov = prov.map(|p| if p == "zhipuai" { "zai" } else { p });
1445 if prov.is_some() {
1446 (prov.map(|s| s.to_string()), model.to_string())
1447 } else if providers.contains(&model) {
1448 (Some(model.to_string()), String::new())
1449 } else {
1450 (None, model.to_string())
1451 }
1452 } else {
1453 (None, String::new())
1454 };
1455
1456 let provider_slice = providers.as_slice();
1457 let provider_requested_but_unavailable = provider_name
1458 .as_deref()
1459 .map(|p| !providers.contains(&p))
1460 .unwrap_or(false);
1461
1462 let selected_provider = provider_name
1464 .as_deref()
1465 .filter(|p| providers.contains(p))
1466 .unwrap_or_else(|| choose_provider_for_tier(provider_slice, model_tier));
1467
1468 let provider = registry
1469 .get(selected_provider)
1470 .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;
1471
1472 session.add_message(Message {
1474 role: Role::User,
1475 content: vec![ContentPart::Text {
1476 text: prompt.to_string(),
1477 }],
1478 });
1479
1480 if session.title.is_none() {
1482 session.generate_title().await?;
1483 }
1484
1485 let model = if !model_id.is_empty() && !provider_requested_but_unavailable {
1488 model_id
1489 } else {
1490 default_model_for_provider(selected_provider, model_tier)
1491 };
1492
1493 let tool_registry =
1495 create_filtered_registry(Arc::clone(&provider), model.clone(), auto_approve);
1496 let tool_definitions = tool_registry.definitions();
1497
1498 let temperature = if prefers_temperature_one(&model) {
1499 Some(1.0)
1500 } else {
1501 Some(0.7)
1502 };
1503
1504 tracing::info!(
1505 "Using model: {} via provider: {} (tier: {:?})",
1506 model,
1507 selected_provider,
1508 model_tier
1509 );
1510 tracing::info!(
1511 "Available tools: {} (auto_approve: {:?})",
1512 tool_definitions.len(),
1513 auto_approve
1514 );
1515
1516 let cwd = std::env::var("PWD")
1518 .map(std::path::PathBuf::from)
1519 .unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
1520 let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);
1521
1522 let mut final_output = String::new();
1523 let max_steps = 50;
1524
1525 for step in 1..=max_steps {
1526 tracing::info!(step = step, "Agent step starting");
1527
1528 let mut messages = vec![Message {
1530 role: Role::System,
1531 content: vec![ContentPart::Text {
1532 text: system_prompt.clone(),
1533 }],
1534 }];
1535 messages.extend(session.messages.clone());
1536
1537 let request = CompletionRequest {
1538 messages,
1539 tools: tool_definitions.clone(),
1540 model: model.clone(),
1541 temperature,
1542 top_p: None,
1543 max_tokens: Some(8192),
1544 stop: Vec::new(),
1545 };
1546
1547 let response = provider.complete(request).await?;
1548
1549 crate::telemetry::TOKEN_USAGE.record_model_usage(
1550 &model,
1551 response.usage.prompt_tokens as u64,
1552 response.usage.completion_tokens as u64,
1553 );
1554
1555 let tool_calls: Vec<(String, String, serde_json::Value)> = response
1557 .message
1558 .content
1559 .iter()
1560 .filter_map(|part| {
1561 if let ContentPart::ToolCall {
1562 id,
1563 name,
1564 arguments,
1565 } = part
1566 {
1567 let args: serde_json::Value =
1568 serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
1569 Some((id.clone(), name.clone(), args))
1570 } else {
1571 None
1572 }
1573 })
1574 .collect();
1575
1576 for part in &response.message.content {
1578 if let ContentPart::Text { text } = part {
1579 if !text.is_empty() {
1580 final_output.push_str(text);
1581 final_output.push('\n');
1582 if let Some(cb) = output_callback {
1583 cb(text.clone());
1584 }
1585 }
1586 }
1587 }
1588
1589 if tool_calls.is_empty() {
1591 session.add_message(response.message.clone());
1592 break;
1593 }
1594
1595 session.add_message(response.message.clone());
1596
1597 tracing::info!(
1598 step = step,
1599 num_tools = tool_calls.len(),
1600 "Executing tool calls"
1601 );
1602
1603 for (tool_id, tool_name, tool_input) in tool_calls {
1605 tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");
1606
1607 if let Some(cb) = output_callback {
1609 cb(format!("[tool:start:{}]", tool_name));
1610 }
1611
1612 if !is_tool_allowed(&tool_name, auto_approve) {
1614 let msg = format!(
1615 "Tool '{}' requires approval but auto-approve policy is {:?}",
1616 tool_name, auto_approve
1617 );
1618 tracing::warn!(tool = %tool_name, "Tool blocked by auto-approve policy");
1619 session.add_message(Message {
1620 role: Role::Tool,
1621 content: vec![ContentPart::ToolResult {
1622 tool_call_id: tool_id,
1623 content: msg,
1624 }],
1625 });
1626 continue;
1627 }
1628
1629 let content = if let Some(tool) = tool_registry.get(&tool_name) {
1630 let exec_result: Result<crate::tool::ToolResult> =
1631 tool.execute(tool_input.clone()).await;
1632 match exec_result {
1633 Ok(result) => {
1634 tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
1635 if let Some(cb) = output_callback {
1636 let status = if result.success { "ok" } else { "err" };
1637 cb(format!(
1638 "[tool:{}:{}] {}",
1639 tool_name,
1640 status,
1641 &result.output[..result.output.len().min(500)]
1642 ));
1643 }
1644 result.output
1645 }
1646 Err(e) => {
1647 tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
1648 if let Some(cb) = output_callback {
1649 cb(format!("[tool:{}:err] {}", tool_name, e));
1650 }
1651 format!("Error: {}", e)
1652 }
1653 }
1654 } else {
1655 tracing::warn!(tool = %tool_name, "Tool not found");
1656 format!("Error: Unknown tool '{}'", tool_name)
1657 };
1658
1659 session.add_message(Message {
1660 role: Role::Tool,
1661 content: vec![ContentPart::ToolResult {
1662 tool_call_id: tool_id,
1663 content,
1664 }],
1665 });
1666 }
1667 }
1668
1669 session.save().await?;
1670
1671 Ok(crate::session::SessionResult {
1672 text: final_output.trim().to_string(),
1673 session_id: session.id.clone(),
1674 })
1675}
1676
1677fn is_tool_allowed(tool_name: &str, auto_approve: AutoApprove) -> bool {
1679 match auto_approve {
1680 AutoApprove::All => true,
1681 AutoApprove::Safe | AutoApprove::None => is_safe_tool(tool_name),
1682 }
1683}
1684
/// Allowlist of read-only / side-effect-free tools that may run without
/// explicit approval.
fn is_safe_tool(tool_name: &str) -> bool {
    matches!(
        tool_name,
        "read"
            | "list"
            | "glob"
            | "grep"
            | "codesearch"
            | "lsp"
            | "webfetch"
            | "websearch"
            | "todo_read"
            | "skill"
    )
}
1701
/// Build the tool registry for a worker session, filtered by the
/// auto-approve policy.
///
/// Read-only / introspection tools are always registered. Mutating and
/// executing tools (write, edit, bash, patch, etc.) are added only when the
/// policy is `AutoApprove::All`. `provider` and `model` are passed into
/// `RalphTool::with_provider`; the other tools take no construction args.
///
/// NOTE(review): registration order is preserved as-is — it may determine
/// the order of `definitions()` presented to the model; confirm before
/// reordering.
fn create_filtered_registry(
    provider: Arc<dyn crate::provider::Provider>,
    model: String,
    auto_approve: AutoApprove,
) -> crate::tool::ToolRegistry {
    use crate::tool::*;

    let mut registry = ToolRegistry::new();

    // Read-only tools: always available regardless of policy.
    registry.register(Arc::new(file::ReadTool::new()));
    registry.register(Arc::new(file::ListTool::new()));
    registry.register(Arc::new(file::GlobTool::new()));
    registry.register(Arc::new(search::GrepTool::new()));
    registry.register(Arc::new(lsp::LspTool::new()));
    registry.register(Arc::new(webfetch::WebFetchTool::new()));
    registry.register(Arc::new(websearch::WebSearchTool::new()));
    registry.register(Arc::new(codesearch::CodeSearchTool::new()));
    registry.register(Arc::new(todo::TodoReadTool::new()));
    registry.register(Arc::new(skill::SkillTool::new()));

    // Mutating / executing tools: only under full auto-approval.
    if matches!(auto_approve, AutoApprove::All) {
        registry.register(Arc::new(file::WriteTool::new()));
        registry.register(Arc::new(edit::EditTool::new()));
        registry.register(Arc::new(bash::BashTool::new()));
        registry.register(Arc::new(multiedit::MultiEditTool::new()));
        registry.register(Arc::new(patch::ApplyPatchTool::new()));
        registry.register(Arc::new(todo::TodoWriteTool::new()));
        registry.register(Arc::new(task::TaskTool::new()));
        registry.register(Arc::new(plan::PlanEnterTool::new()));
        registry.register(Arc::new(plan::PlanExitTool::new()));
        registry.register(Arc::new(rlm::RlmTool::new()));
        registry.register(Arc::new(ralph::RalphTool::with_provider(provider, model)));
        registry.register(Arc::new(prd::PrdTool::new()));
        registry.register(Arc::new(confirm_edit::ConfirmEditTool::new()));
        registry.register(Arc::new(confirm_multiedit::ConfirmMultiEditTool::new()));
        registry.register(Arc::new(undo::UndoTool));
        registry.register(Arc::new(mcp_bridge::McpBridgeTool::new()));
    }

    // Always registered last. NOTE(review): presumably a catch-all handler
    // for malformed/unknown tool invocations — confirm against `invalid`'s
    // implementation.
    registry.register(Arc::new(invalid::InvalidTool::new()));

    registry
}
1748
/// Spawn the periodic heartbeat loop for this worker and return its handle.
///
/// Every `HEARTBEAT_INTERVAL_SECS` the spawned task:
/// 1. derives `idle`/`processing` status from the shared `processing` set
///    and mirrors it into `heartbeat_state`,
/// 2. optionally attaches a cognition payload fetched from the local
///    cognition service (when enabled and not in cooldown),
/// 3. POSTs the heartbeat to `{server}/v1/opencode/workers/{id}/heartbeat`,
///    using `token` as a bearer credential when present.
///
/// If the server answers a cognition-enriched heartbeat with a 4xx, the loop
/// retries once without the cognition payload and, on success, pauses
/// cognition enrichment for `COGNITION_RETRY_COOLDOWN_SECS`. Failures are
/// counted but never stop the loop; after `MAX_FAILURES` consecutive
/// failures it logs an error and resets the counter.
fn start_heartbeat(
    client: Client,
    server: String,
    token: Option<String>,
    heartbeat_state: HeartbeatState,
    processing: Arc<Mutex<HashSet<String>>>,
    cognition_config: CognitionHeartbeatConfig,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        let mut consecutive_failures = 0u32;
        const MAX_FAILURES: u32 = 3;
        const HEARTBEAT_INTERVAL_SECS: u64 = 30;
        const COGNITION_RETRY_COOLDOWN_SECS: u64 = 300;
        // While set, cognition enrichment is suppressed until the deadline.
        let mut cognition_payload_disabled_until: Option<Instant> = None;

        let mut interval =
            tokio::time::interval(tokio::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
        // Skip (rather than burst) ticks missed while a send was slow.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        loop {
            interval.tick().await;

            // Mirror the current task load into the shared heartbeat state.
            let active_count = processing.lock().await.len();
            heartbeat_state.set_task_count(active_count).await;

            let status = if active_count > 0 {
                WorkerStatus::Processing
            } else {
                WorkerStatus::Idle
            };
            heartbeat_state.set_status(status).await;

            let url = format!(
                "{}/v1/opencode/workers/{}/heartbeat",
                server, heartbeat_state.worker_id
            );
            let mut req = client.post(&url);

            if let Some(ref t) = token {
                req = req.bearer_auth(t);
            }

            let status_str = heartbeat_state.status.lock().await.as_str().to_string();
            // `base_payload` is kept separate so the retry path can resend
            // the heartbeat without the cognition enrichment.
            let base_payload = serde_json::json!({
                "worker_id": &heartbeat_state.worker_id,
                "agent_name": &heartbeat_state.agent_name,
                "status": status_str,
                "active_task_count": active_count,
            });
            let mut payload = base_payload.clone();
            let mut included_cognition_payload = false;
            // Cognition is allowed when no cooldown deadline is pending.
            let cognition_payload_allowed = cognition_payload_disabled_until
                .map(|until| Instant::now() >= until)
                .unwrap_or(true);

            if cognition_config.enabled
                && cognition_payload_allowed
                && let Some(cognition_payload) =
                    fetch_cognition_heartbeat_payload(&client, &cognition_config).await
                && let Some(obj) = payload.as_object_mut()
            {
                obj.insert("cognition".to_string(), cognition_payload);
                included_cognition_payload = true;
            }

            match req.json(&payload).send().await {
                Ok(res) => {
                    if res.status().is_success() {
                        consecutive_failures = 0;
                        tracing::debug!(
                            worker_id = %heartbeat_state.worker_id,
                            status = status_str,
                            active_tasks = active_count,
                            "Heartbeat sent successfully"
                        );
                    } else if included_cognition_payload && res.status().is_client_error() {
                        // 4xx on an enriched heartbeat: assume the server
                        // rejected the cognition schema and retry bare.
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            "Heartbeat cognition payload rejected, retrying without cognition payload"
                        );

                        let mut retry_req = client.post(&url);
                        if let Some(ref t) = token {
                            retry_req = retry_req.bearer_auth(t);
                        }

                        match retry_req.json(&base_payload).send().await {
                            Ok(retry_res) if retry_res.status().is_success() => {
                                // Bare retry succeeded: the cognition payload
                                // was the problem, so pause it for a while.
                                cognition_payload_disabled_until = Some(
                                    Instant::now()
                                        + Duration::from_secs(COGNITION_RETRY_COOLDOWN_SECS),
                                );
                                consecutive_failures = 0;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    retry_after_secs = COGNITION_RETRY_COOLDOWN_SECS,
                                    "Paused cognition heartbeat payload after schema rejection"
                                );
                            }
                            Ok(retry_res) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    status = %retry_res.status(),
                                    failures = consecutive_failures,
                                    "Heartbeat failed even after retry without cognition payload"
                                );
                            }
                            Err(e) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    error = %e,
                                    failures = consecutive_failures,
                                    "Heartbeat retry without cognition payload failed"
                                );
                            }
                        }
                    } else {
                        consecutive_failures += 1;
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            failures = consecutive_failures,
                            "Heartbeat failed"
                        );
                    }
                }
                Err(e) => {
                    consecutive_failures += 1;
                    tracing::warn!(
                        worker_id = %heartbeat_state.worker_id,
                        error = %e,
                        failures = consecutive_failures,
                        "Heartbeat request failed"
                    );
                }
            }

            // Escalate to an error log, but never abort the loop: recovery
            // is delegated to the SSE reconnection path.
            if consecutive_failures >= MAX_FAILURES {
                tracing::error!(
                    worker_id = %heartbeat_state.worker_id,
                    failures = consecutive_failures,
                    "Heartbeat failed {} consecutive times - worker will continue running and attempt reconnection via SSE loop",
                    MAX_FAILURES
                );
                consecutive_failures = 0;
            }
        }
    })
}
1908
1909async fn fetch_cognition_heartbeat_payload(
1910 client: &Client,
1911 config: &CognitionHeartbeatConfig,
1912) -> Option<serde_json::Value> {
1913 let status_url = format!("{}/v1/cognition/status", config.source_base_url);
1914 let status_res = tokio::time::timeout(
1915 Duration::from_millis(config.request_timeout_ms),
1916 client.get(status_url).send(),
1917 )
1918 .await
1919 .ok()?
1920 .ok()?;
1921
1922 if !status_res.status().is_success() {
1923 return None;
1924 }
1925
1926 let status: CognitionStatusSnapshot = status_res.json().await.ok()?;
1927 let mut payload = serde_json::json!({
1928 "running": status.running,
1929 "last_tick_at": status.last_tick_at,
1930 "active_persona_count": status.active_persona_count,
1931 "events_buffered": status.events_buffered,
1932 "snapshots_buffered": status.snapshots_buffered,
1933 "loop_interval_ms": status.loop_interval_ms,
1934 });
1935
1936 if config.include_thought_summary {
1937 let snapshot_url = format!("{}/v1/cognition/snapshots/latest", config.source_base_url);
1938 let snapshot_res = tokio::time::timeout(
1939 Duration::from_millis(config.request_timeout_ms),
1940 client.get(snapshot_url).send(),
1941 )
1942 .await
1943 .ok()
1944 .and_then(Result::ok);
1945
1946 if let Some(snapshot_res) = snapshot_res
1947 && snapshot_res.status().is_success()
1948 && let Ok(snapshot) = snapshot_res.json::<CognitionLatestSnapshot>().await
1949 && let Some(obj) = payload.as_object_mut()
1950 {
1951 obj.insert(
1952 "latest_snapshot_at".to_string(),
1953 serde_json::Value::String(snapshot.generated_at),
1954 );
1955 obj.insert(
1956 "latest_thought".to_string(),
1957 serde_json::Value::String(trim_for_heartbeat(
1958 &snapshot.summary,
1959 config.summary_max_chars,
1960 )),
1961 );
1962 if let Some(model) = snapshot
1963 .metadata
1964 .get("model")
1965 .and_then(serde_json::Value::as_str)
1966 {
1967 obj.insert(
1968 "latest_thought_model".to_string(),
1969 serde_json::Value::String(model.to_string()),
1970 );
1971 }
1972 if let Some(source) = snapshot
1973 .metadata
1974 .get("source")
1975 .and_then(serde_json::Value::as_str)
1976 {
1977 obj.insert(
1978 "latest_thought_source".to_string(),
1979 serde_json::Value::String(source.to_string()),
1980 );
1981 }
1982 }
1983 }
1984
1985 Some(payload)
1986}
1987
/// Shorten `input` for inclusion in a heartbeat payload.
///
/// Inputs of at most `max_chars` characters are returned trimmed; longer
/// inputs are cut to `max_chars` characters, suffixed with `"..."`, and then
/// trimmed (so the ellipsis survives but leading whitespace does not).
fn trim_for_heartbeat(input: &str, max_chars: usize) -> String {
    if input.chars().count() <= max_chars {
        return input.trim().to_string();
    }

    let truncated: String = input.chars().take(max_chars).collect();
    format!("{truncated}...").trim().to_string()
}
2000
/// Read `name` from the environment as a boolean flag.
///
/// Accepts (case-insensitively) `1/true/yes/on` as `true` and
/// `0/false/no/off` as `false`; an unset variable or any other value yields
/// `default`.
fn env_bool(name: &str, default: bool) -> bool {
    let Ok(raw) = std::env::var(name) else {
        return default;
    };
    match raw.to_ascii_lowercase().as_str() {
        "1" | "true" | "yes" | "on" => true,
        "0" | "false" | "no" | "off" => false,
        _ => default,
    }
}
2011
/// Read `name` from the environment as a `usize`, returning `default` when
/// the variable is unset or does not parse.
fn env_usize(name: &str, default: usize) -> usize {
    match std::env::var(name) {
        Ok(raw) => raw.parse::<usize>().unwrap_or(default),
        Err(_) => default,
    }
}
2018
/// Read `name` from the environment as a `u64`, returning `default` when
/// the variable is unset or does not parse.
fn env_u64(name: &str, default: u64) -> u64 {
    std::env::var(name)
        .ok()
        .map_or(default, |raw| raw.parse::<u64>().unwrap_or(default))
}