1use crate::cli::A2aArgs;
4use crate::provider::ProviderRegistry;
5use crate::session::Session;
6use crate::swarm::{DecompositionStrategy, SwarmConfig, SwarmExecutor};
7use crate::tui::swarm_view::SwarmEvent;
8use anyhow::Result;
9use futures::StreamExt;
10use reqwest::Client;
11use serde::Deserialize;
12use std::collections::HashMap;
13use std::collections::HashSet;
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::sync::{Mutex, mpsc};
17use tokio::task::JoinHandle;
18use tokio::time::Instant;
19
/// Coarse activity state a worker reports in its heartbeat.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WorkerStatus {
    Idle,
    Processing,
}

impl WorkerStatus {
    /// Wire representation of the status as sent upstream.
    fn as_str(&self) -> &'static str {
        match self {
            Self::Idle => "idle",
            Self::Processing => "processing",
        }
    }
}
35
36#[derive(Clone)]
38struct HeartbeatState {
39 worker_id: String,
40 agent_name: String,
41 status: Arc<Mutex<WorkerStatus>>,
42 active_task_count: Arc<Mutex<usize>>,
43}
44
45impl HeartbeatState {
46 fn new(worker_id: String, agent_name: String) -> Self {
47 Self {
48 worker_id,
49 agent_name,
50 status: Arc::new(Mutex::new(WorkerStatus::Idle)),
51 active_task_count: Arc::new(Mutex::new(0)),
52 }
53 }
54
55 async fn set_status(&self, status: WorkerStatus) {
56 *self.status.lock().await = status;
57 }
58
59 async fn set_task_count(&self, count: usize) {
60 *self.active_task_count.lock().await = count;
61 }
62}
63
64#[derive(Clone, Debug)]
65struct CognitionHeartbeatConfig {
66 enabled: bool,
67 source_base_url: String,
68 include_thought_summary: bool,
69 summary_max_chars: usize,
70 request_timeout_ms: u64,
71}
72
73impl CognitionHeartbeatConfig {
74 fn from_env() -> Self {
75 let source_base_url = std::env::var("CODETETHER_WORKER_COGNITION_SOURCE_URL")
76 .unwrap_or_else(|_| "http://127.0.0.1:4096".to_string())
77 .trim_end_matches('/')
78 .to_string();
79
80 Self {
81 enabled: env_bool("CODETETHER_WORKER_COGNITION_SHARE_ENABLED", true),
82 source_base_url,
83 include_thought_summary: env_bool("CODETETHER_WORKER_COGNITION_INCLUDE_THOUGHTS", true),
84 summary_max_chars: env_usize("CODETETHER_WORKER_COGNITION_THOUGHT_MAX_CHARS", 480)
85 .max(120),
86 request_timeout_ms: env_u64("CODETETHER_WORKER_COGNITION_TIMEOUT_MS", 2_500).max(250),
87 }
88 }
89}
90
/// Status payload deserialized from the local cognition service.
///
/// NOTE(review): fields are not read in this chunk — presumably consumed by
/// the heartbeat sender later in the file; confirm before changing the shape.
#[derive(Debug, Deserialize)]
struct CognitionStatusSnapshot {
    // Whether the cognition loop is currently running.
    running: bool,
    // All remaining fields tolerate absence via `#[serde(default)]` so older
    // cognition services still deserialize cleanly.
    #[serde(default)]
    last_tick_at: Option<String>,
    #[serde(default)]
    active_persona_count: usize,
    #[serde(default)]
    events_buffered: usize,
    #[serde(default)]
    snapshots_buffered: usize,
    #[serde(default)]
    loop_interval_ms: u64,
}
105
/// Most recent cognition snapshot (timestamp, human-readable summary, and
/// free-form metadata) fetched from the local cognition service.
///
/// NOTE(review): not read in this chunk — presumably forwarded in the
/// heartbeat payload; confirm at the use site.
#[derive(Debug, Deserialize)]
struct CognitionLatestSnapshot {
    generated_at: String,
    summary: String,
    #[serde(default)]
    metadata: HashMap<String, serde_json::Value>,
}
113
/// Entry point for the A2A worker: registers with the server, performs a
/// one-shot sweep of already-pending tasks, then enters an infinite
/// reconnect loop around the server's task stream.
///
/// Never returns under normal operation; only errors before the first
/// successful registration/sweep propagate to the caller.
pub async fn run(args: A2aArgs) -> Result<()> {
    let server = args.server.trim_end_matches('/');
    // Default name includes the pid so multiple workers on one host differ.
    let name = args
        .name
        .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
    let worker_id = generate_worker_id();

    // Comma-separated codebase list; defaults to the current directory.
    // NOTE(review): `current_dir().unwrap()` panics if the cwd is gone —
    // acceptable at startup, but worth confirming.
    let codebases: Vec<String> = args
        .codebases
        .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
        .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);

    tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
    tracing::info!("Server: {}", server);
    tracing::info!("Codebases: {:?}", codebases);

    let client = Client::new();
    // Task ids currently in flight; shared with the stream/poll handlers to
    // prevent double-claiming the same task from two code paths.
    let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
    let cognition_heartbeat = CognitionHeartbeatConfig::from_env();
    if cognition_heartbeat.enabled {
        tracing::info!(
            source = %cognition_heartbeat.source_base_url,
            include_thoughts = cognition_heartbeat.include_thought_summary,
            max_chars = cognition_heartbeat.summary_max_chars,
            timeout_ms = cognition_heartbeat.request_timeout_ms,
            "Cognition heartbeat sharing enabled (set CODETETHER_WORKER_COGNITION_SHARE_ENABLED=false to disable)"
        );
    } else {
        tracing::warn!(
            "Cognition heartbeat sharing disabled; worker thought state will not be shared upstream"
        );
    }

    // Any value other than "all"/"safe" falls back to no auto-approval.
    let auto_approve = match args.auto_approve.as_str() {
        "all" => AutoApprove::All,
        "safe" => AutoApprove::Safe,
        _ => AutoApprove::None,
    };

    let heartbeat_state = HeartbeatState::new(worker_id.clone(), name.clone());

    register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;

    // Drain tasks that queued up before this worker came online.
    fetch_pending_tasks(
        &client,
        server,
        &args.token,
        &worker_id,
        &processing,
        &auto_approve,
    )
    .await?;

    // Reconnect loop: re-register (best effort), run a heartbeat for the
    // lifetime of the stream, back off 5s between attempts.
    loop {
        if let Err(e) =
            register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await
        {
            // Non-fatal: streaming may still work with the prior registration.
            tracing::warn!("Failed to re-register worker on reconnection: {}", e);
        }

        let heartbeat_handle = start_heartbeat(
            client.clone(),
            server.to_string(),
            args.token.clone(),
            heartbeat_state.clone(),
            processing.clone(),
            cognition_heartbeat.clone(),
        );

        // Blocks until the stream ends or errors; both cases reconnect.
        match connect_stream(
            &client,
            server,
            &args.token,
            &worker_id,
            &name,
            &codebases,
            &processing,
            &auto_approve,
        )
        .await
        {
            Ok(()) => {
                tracing::warn!("Stream ended, reconnecting...");
            }
            Err(e) => {
                tracing::error!("Stream error: {}, reconnecting...", e);
            }
        }

        // Stop the per-connection heartbeat before the next attempt spawns
        // a fresh one.
        heartbeat_handle.abort();
        tracing::debug!("Heartbeat cancelled for reconnection");

        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
    }
}
217
218fn generate_worker_id() -> String {
219 format!(
220 "wrk_{}_{:x}",
221 chrono::Utc::now().timestamp(),
222 rand::random::<u64>()
223 )
224}
225
/// Tool-approval policy derived from the CLI `--auto-approve` flag in `run`
/// ("all" -> All, "safe" -> Safe, anything else -> None).
///
/// NOTE(review): enforcement happens in `execute_session_with_policy`, which
/// is not visible in this chunk — confirm semantics there.
#[derive(Debug, Clone, Copy)]
enum AutoApprove {
    All,
    Safe,
    None,
}
232
/// Default A2A server endpoint used when the CLI does not override it.
pub const DEFAULT_A2A_SERVER_URL: &str = "https://api.codetether.run";

/// Capability tags advertised to the server in `register_worker`.
const WORKER_CAPABILITIES: &[&str] = &["ralph", "swarm", "rlm", "a2a", "mcp"];
238
239fn task_value<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a serde_json::Value> {
240 task.get("task")
241 .and_then(|t| t.get(key))
242 .or_else(|| task.get(key))
243}
244
245fn task_str<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a str> {
246 task_value(task, key).and_then(|v| v.as_str())
247}
248
249fn task_metadata(task: &serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
250 task_value(task, "metadata")
251 .and_then(|m| m.as_object())
252 .cloned()
253 .unwrap_or_default()
254}
255
/// Normalizes a model reference to `provider/model` form: a ref like
/// `openai:gpt-4o` has its first `:` replaced with `/`; refs that already
/// contain a `/` (or no `:` at all) pass through unchanged.
fn model_ref_to_provider_model(model: &str) -> String {
    let colon_style = model.contains(':') && !model.contains('/');
    if colon_style {
        model.replacen(':', "/", 1)
    } else {
        model.to_string()
    }
}
266
/// Ordered provider preference lists per model tier; earlier entries win in
/// [`choose_provider_for_tier`]. Unknown/absent tiers use the `balanced` list.
fn provider_preferences_for_tier(model_tier: Option<&str>) -> &'static [&'static str] {
    match model_tier.unwrap_or("balanced") {
        "fast" | "quick" => &[
            "openai",
            "github-copilot",
            "moonshotai",
            "zhipuai",
            "openrouter",
            "novita",
            "google",
            "anthropic",
        ],
        "heavy" | "deep" => &[
            "anthropic",
            "openai",
            "github-copilot",
            "moonshotai",
            "zhipuai",
            "openrouter",
            "novita",
            "google",
        ],
        _ => &[
            "openai",
            "github-copilot",
            "anthropic",
            "moonshotai",
            "zhipuai",
            "openrouter",
            "novita",
            "google",
        ],
    }
}

/// Picks the first configured provider that appears in the tier's preference
/// list, then falls back to `"zhipuai"` if configured, and finally to the
/// first configured provider.
///
/// # Panics
/// Panics with a descriptive message if `providers` is empty. Callers (see
/// `resolve_swarm_model`) must check for emptiness first; the previous
/// `providers[0]` indexing panicked with an opaque out-of-bounds message.
fn choose_provider_for_tier<'a>(providers: &'a [&'a str], model_tier: Option<&str>) -> &'a str {
    for &preferred in provider_preferences_for_tier(model_tier) {
        if providers.contains(&preferred) {
            return preferred;
        }
    }
    // Defensive fallback; currently shadowed by the loop above because every
    // preference list contains "zhipuai", but kept in case the lists change.
    if providers.contains(&"zhipuai") {
        return "zhipuai";
    }
    providers
        .first()
        .copied()
        .expect("choose_provider_for_tier requires a non-empty provider list")
}
313
/// Default model for a provider at a given tier ("fast"/"quick",
/// "heavy"/"deep", otherwise balanced). Unknown providers default to
/// "glm-4.7" regardless of tier.
fn default_model_for_provider(provider: &str, model_tier: Option<&str>) -> String {
    let tier = model_tier.unwrap_or("balanced");
    let fast = matches!(tier, "fast" | "quick");
    let heavy = matches!(tier, "heavy" | "deep");

    // Several providers use the same model at every tier; the rest pick
    // per-tier.
    let model = match provider {
        "moonshotai" => "kimi-k2.5",
        "zhipuai" => "glm-4.7",
        "openrouter" => "z-ai/glm-4.7",
        "novita" => "qwen/qwen3-coder-next",
        "google" => {
            if fast {
                "gemini-2.5-flash"
            } else {
                "gemini-2.5-pro"
            }
        }
        "anthropic" => {
            if fast {
                "claude-haiku-4-5"
            } else {
                "claude-sonnet-4-20250514"
            }
        }
        "openai" => {
            if fast {
                "gpt-4o-mini"
            } else if heavy {
                "o3"
            } else {
                "gpt-4o"
            }
        }
        _ => "glm-4.7",
    };
    model.to_string()
}
348
/// True when the agent type (case-insensitive, whitespace-trimmed) selects
/// the multi-agent swarm execution path.
fn is_swarm_agent(agent_type: &str) -> bool {
    let normalized = agent_type.trim().to_ascii_lowercase();
    normalized == "swarm" || normalized == "parallel" || normalized == "multi-agent"
}
355
356fn metadata_lookup<'a>(
357 metadata: &'a serde_json::Map<String, serde_json::Value>,
358 key: &str,
359) -> Option<&'a serde_json::Value> {
360 metadata
361 .get(key)
362 .or_else(|| {
363 metadata
364 .get("routing")
365 .and_then(|v| v.as_object())
366 .and_then(|obj| obj.get(key))
367 })
368 .or_else(|| {
369 metadata
370 .get("swarm")
371 .and_then(|v| v.as_object())
372 .and_then(|obj| obj.get(key))
373 })
374}
375
376fn metadata_str(
377 metadata: &serde_json::Map<String, serde_json::Value>,
378 keys: &[&str],
379) -> Option<String> {
380 for key in keys {
381 if let Some(value) = metadata_lookup(metadata, key).and_then(|v| v.as_str()) {
382 let trimmed = value.trim();
383 if !trimmed.is_empty() {
384 return Some(trimmed.to_string());
385 }
386 }
387 }
388 None
389}
390
391fn metadata_usize(
392 metadata: &serde_json::Map<String, serde_json::Value>,
393 keys: &[&str],
394) -> Option<usize> {
395 for key in keys {
396 if let Some(value) = metadata_lookup(metadata, key) {
397 if let Some(v) = value.as_u64() {
398 return usize::try_from(v).ok();
399 }
400 if let Some(v) = value.as_i64() {
401 if v >= 0 {
402 return usize::try_from(v as u64).ok();
403 }
404 }
405 if let Some(v) = value.as_str() {
406 if let Ok(parsed) = v.trim().parse::<usize>() {
407 return Some(parsed);
408 }
409 }
410 }
411 }
412 None
413}
414
415fn metadata_u64(
416 metadata: &serde_json::Map<String, serde_json::Value>,
417 keys: &[&str],
418) -> Option<u64> {
419 for key in keys {
420 if let Some(value) = metadata_lookup(metadata, key) {
421 if let Some(v) = value.as_u64() {
422 return Some(v);
423 }
424 if let Some(v) = value.as_i64() {
425 if v >= 0 {
426 return Some(v as u64);
427 }
428 }
429 if let Some(v) = value.as_str() {
430 if let Ok(parsed) = v.trim().parse::<u64>() {
431 return Some(parsed);
432 }
433 }
434 }
435 }
436 None
437}
438
439fn metadata_bool(
440 metadata: &serde_json::Map<String, serde_json::Value>,
441 keys: &[&str],
442) -> Option<bool> {
443 for key in keys {
444 if let Some(value) = metadata_lookup(metadata, key) {
445 if let Some(v) = value.as_bool() {
446 return Some(v);
447 }
448 if let Some(v) = value.as_str() {
449 match v.trim().to_ascii_lowercase().as_str() {
450 "1" | "true" | "yes" | "on" => return Some(true),
451 "0" | "false" | "no" | "off" => return Some(false),
452 _ => {}
453 }
454 }
455 }
456 }
457 None
458}
459
460fn parse_swarm_strategy(
461 metadata: &serde_json::Map<String, serde_json::Value>,
462) -> DecompositionStrategy {
463 match metadata_str(
464 metadata,
465 &[
466 "decomposition_strategy",
467 "swarm_strategy",
468 "strategy",
469 "swarm_decomposition",
470 ],
471 )
472 .as_deref()
473 .map(|s| s.to_ascii_lowercase())
474 .as_deref()
475 {
476 Some("none") | Some("single") => DecompositionStrategy::None,
477 Some("domain") | Some("by_domain") => DecompositionStrategy::ByDomain,
478 Some("data") | Some("by_data") => DecompositionStrategy::ByData,
479 Some("stage") | Some("by_stage") => DecompositionStrategy::ByStage,
480 _ => DecompositionStrategy::Automatic,
481 }
482}
483
484async fn resolve_swarm_model(
485 explicit_model: Option<String>,
486 model_tier: Option<&str>,
487) -> Option<String> {
488 if let Some(model) = explicit_model {
489 if !model.trim().is_empty() {
490 return Some(model);
491 }
492 }
493
494 let registry = ProviderRegistry::from_vault().await.ok()?;
495 let providers = registry.list();
496 if providers.is_empty() {
497 return None;
498 }
499 let provider = choose_provider_for_tier(providers.as_slice(), model_tier);
500 let model = default_model_for_provider(provider, model_tier);
501 Some(format!("{}/{}", provider, model))
502}
503
504fn format_swarm_event_for_output(event: &SwarmEvent) -> Option<String> {
505 match event {
506 SwarmEvent::Started {
507 task,
508 total_subtasks,
509 } => Some(format!(
510 "[swarm] started task={} planned_subtasks={}",
511 task, total_subtasks
512 )),
513 SwarmEvent::StageComplete {
514 stage,
515 completed,
516 failed,
517 } => Some(format!(
518 "[swarm] stage={} completed={} failed={}",
519 stage, completed, failed
520 )),
521 SwarmEvent::SubTaskUpdate { id, status, .. } => Some(format!(
522 "[swarm] subtask id={} status={}",
523 &id.chars().take(8).collect::<String>(),
524 format!("{status:?}").to_ascii_lowercase()
525 )),
526 SwarmEvent::AgentToolCall {
527 subtask_id,
528 tool_name,
529 } => Some(format!(
530 "[swarm] subtask id={} tool={}",
531 &subtask_id.chars().take(8).collect::<String>(),
532 tool_name
533 )),
534 SwarmEvent::AgentError { subtask_id, error } => Some(format!(
535 "[swarm] subtask id={} error={}",
536 &subtask_id.chars().take(8).collect::<String>(),
537 error
538 )),
539 SwarmEvent::Complete { success, stats } => Some(format!(
540 "[swarm] complete success={} subtasks={} speedup={:.2}",
541 success,
542 stats.subagents_completed + stats.subagents_failed,
543 stats.speedup_factor
544 )),
545 SwarmEvent::Error(err) => Some(format!("[swarm] error message={}", err)),
546 _ => None,
547 }
548}
549
/// Registers (or re-registers) this worker with the A2A server, advertising
/// its capabilities, hostname, codebases, and the models it can serve.
///
/// Deliberately lenient: model discovery failure and non-2xx responses are
/// logged but still return `Ok(())` so the caller's reconnect loop keeps
/// running with the prior registration.
async fn register_worker(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
) -> Result<()> {
    // Best effort: register with an empty model map rather than failing.
    let models = match load_provider_models().await {
        Ok(m) => m,
        Err(e) => {
            tracing::warn!(
                "Failed to load provider models: {}, proceeding without model info",
                e
            );
            HashMap::new()
        }
    };

    let mut req = client.post(format!("{}/v1/opencode/workers/register", server));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    // Flatten provider -> models into one array of objects keyed by
    // "provider/model", attaching per-million-token costs when known.
    let models_array: Vec<serde_json::Value> = models
        .iter()
        .flat_map(|(provider, model_infos)| {
            model_infos.iter().map(move |m| {
                let mut obj = serde_json::json!({
                    "id": format!("{}/{}", provider, m.id),
                    "name": &m.id,
                    "provider": provider,
                    "provider_id": provider,
                });
                if let Some(input_cost) = m.input_cost_per_million {
                    obj["input_cost_per_million"] = serde_json::json!(input_cost);
                }
                if let Some(output_cost) = m.output_cost_per_million {
                    obj["output_cost_per_million"] = serde_json::json!(output_cost);
                }
                obj
            })
        })
        .collect();

    tracing::info!(
        "Registering worker with {} models from {} providers",
        models_array.len(),
        models.len()
    );

    // HOSTNAME is the Unix convention, COMPUTERNAME the Windows one.
    let hostname = std::env::var("HOSTNAME")
        .or_else(|_| std::env::var("COMPUTERNAME"))
        .unwrap_or_else(|_| "unknown".to_string());

    let res = req
        .json(&serde_json::json!({
            "worker_id": worker_id,
            "name": name,
            "capabilities": WORKER_CAPABILITIES,
            "hostname": hostname,
            "models": models_array,
            "codebases": codebases,
        }))
        .send()
        .await?;

    if res.status().is_success() {
        tracing::info!("Worker registered successfully");
    } else {
        // Non-fatal by design; see function docs.
        tracing::warn!("Failed to register worker: {}", res.status());
    }

    Ok(())
}
630
/// Builds the provider -> models map advertised at registration.
///
/// Resolution order:
/// 1. Providers from Vault; if Vault is empty or unreachable, the local
///    config/env fallback registry.
/// 2. Each provider's live model list, enriched with per-token costs from the
///    models.dev catalog when the provider itself reports none.
/// 3. If no provider yields any models, the full models.dev catalog (all
///    providers) so the server still sees a usable model inventory.
async fn load_provider_models() -> Result<HashMap<String, Vec<crate::provider::ModelInfo>>> {
    let registry = match ProviderRegistry::from_vault().await {
        Ok(r) if !r.list().is_empty() => {
            tracing::info!("Loaded {} providers from Vault", r.list().len());
            r
        }
        Ok(_) => {
            tracing::warn!("Vault returned 0 providers, falling back to config/env vars");
            fallback_registry().await?
        }
        Err(e) => {
            tracing::warn!("Vault unreachable ({}), falling back to config/env vars", e);
            fallback_registry().await?
        }
    };

    // Catalog fetch is optional; cost enrichment is simply skipped when the
    // catalog is unavailable.
    let catalog = crate::provider::models::ModelCatalog::fetch().await.ok();

    // Local provider ids that differ from their models.dev catalog ids.
    let catalog_alias = |pid: &str| -> String {
        match pid {
            "bedrock" => "amazon-bedrock".to_string(),
            "novita" => "novita-ai".to_string(),
            _ => pid.to_string(),
        }
    };

    let mut models_by_provider: HashMap<String, Vec<crate::provider::ModelInfo>> = HashMap::new();

    for provider_name in registry.list() {
        if let Some(provider) = registry.get(provider_name) {
            match provider.list_models().await {
                Ok(models) => {
                    let enriched: Vec<crate::provider::ModelInfo> = models
                        .into_iter()
                        .map(|mut m| {
                            // Only fill costs the provider did not report.
                            if m.input_cost_per_million.is_none()
                                || m.output_cost_per_million.is_none()
                            {
                                if let Some(ref cat) = catalog {
                                    let cat_pid = catalog_alias(provider_name);
                                    if let Some(prov_info) = cat.get_provider(&cat_pid) {
                                        // Bedrock ids may carry a "us." region
                                        // prefix the catalog lacks; retry with
                                        // the prefix stripped.
                                        let model_info =
                                            prov_info.models.get(&m.id).or_else(|| {
                                                m.id.strip_prefix("us.").and_then(|stripped| {
                                                    prov_info.models.get(stripped)
                                                })
                                            });
                                        if let Some(model_info) = model_info {
                                            if let Some(ref cost) = model_info.cost {
                                                if m.input_cost_per_million.is_none() {
                                                    m.input_cost_per_million = Some(cost.input);
                                                }
                                                if m.output_cost_per_million.is_none() {
                                                    m.output_cost_per_million = Some(cost.output);
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            m
                        })
                        .collect();
                    if !enriched.is_empty() {
                        tracing::debug!("Provider {}: {} models", provider_name, enriched.len());
                        models_by_provider.insert(provider_name.to_string(), enriched);
                    }
                }
                Err(e) => {
                    // A single provider failing to enumerate is not fatal.
                    tracing::debug!("Failed to list models for {}: {}", provider_name, e);
                }
            }
        }
    }

    if models_by_provider.is_empty() {
        // Last resort: advertise the public catalog even with no credentials.
        // NOTE(review): this refetches the catalog rather than reusing
        // `catalog` above — presumably to retry a transient failure; confirm.
        tracing::info!(
            "No authenticated providers found, fetching models.dev catalog (all providers)"
        );
        if let Ok(cat) = crate::provider::models::ModelCatalog::fetch().await {
            for (provider_id, provider_info) in cat.all_providers() {
                let model_infos: Vec<crate::provider::ModelInfo> = provider_info
                    .models
                    .values()
                    .map(|m| cat.to_model_info(m, provider_id))
                    .collect();
                if !model_infos.is_empty() {
                    tracing::debug!(
                        "Catalog provider {}: {} models",
                        provider_id,
                        model_infos.len()
                    );
                    models_by_provider.insert(provider_id.clone(), model_infos);
                }
            }
            tracing::info!(
                "Loaded {} providers with {} total models from catalog",
                models_by_provider.len(),
                models_by_provider.values().map(|v| v.len()).sum::<usize>()
            );
        }
    }

    Ok(models_by_provider)
}
749
/// Builds a provider registry from the local config file / environment, used
/// when Vault is empty or unreachable. A failed config load falls back to
/// the default (empty) config rather than erroring.
async fn fallback_registry() -> Result<ProviderRegistry> {
    let config = crate::config::Config::load().await.unwrap_or_default();
    ProviderRegistry::from_config(&config).await
}
755
756async fn fetch_pending_tasks(
757 client: &Client,
758 server: &str,
759 token: &Option<String>,
760 worker_id: &str,
761 processing: &Arc<Mutex<HashSet<String>>>,
762 auto_approve: &AutoApprove,
763) -> Result<()> {
764 tracing::info!("Checking for pending tasks...");
765
766 let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
767 if let Some(t) = token {
768 req = req.bearer_auth(t);
769 }
770
771 let res = req.send().await?;
772 if !res.status().is_success() {
773 return Ok(());
774 }
775
776 let data: serde_json::Value = res.json().await?;
777 let tasks = if let Some(arr) = data.as_array() {
779 arr.clone()
780 } else {
781 data["tasks"].as_array().cloned().unwrap_or_default()
782 };
783
784 tracing::info!("Found {} pending task(s)", tasks.len());
785
786 for task in tasks {
787 if let Some(id) = task["id"].as_str() {
788 let mut proc = processing.lock().await;
789 if !proc.contains(id) {
790 proc.insert(id.to_string());
791 drop(proc);
792
793 let task_id = id.to_string();
794 let client = client.clone();
795 let server = server.to_string();
796 let token = token.clone();
797 let worker_id = worker_id.to_string();
798 let auto_approve = *auto_approve;
799 let processing = processing.clone();
800
801 tokio::spawn(async move {
802 if let Err(e) =
803 handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await
804 {
805 tracing::error!("Task {} failed: {}", task_id, e);
806 }
807 processing.lock().await.remove(&task_id);
808 });
809 }
810 }
811 }
812
813 Ok(())
814}
815
/// Opens the server's SSE task stream and dispatches incoming task payloads
/// until the stream ends cleanly (`Ok`) or fails (`Err`); the caller's loop
/// reconnects in both cases.
///
/// While connected, a 30-second interval also polls the pending-task
/// endpoint as a safety net for tasks that never arrive on the stream.
#[allow(clippy::too_many_arguments)]
async fn connect_stream(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
) -> Result<()> {
    let url = format!(
        "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
        server,
        urlencoding::encode(name),
        urlencoding::encode(worker_id)
    );

    let mut req = client
        .get(&url)
        .header("Accept", "text/event-stream")
        .header("X-Worker-ID", worker_id)
        .header("X-Agent-Name", name)
        .header("X-Codebases", codebases.join(","));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        anyhow::bail!("Failed to connect: {}", res.status());
    }

    tracing::info!("Connected to A2A server");

    let mut stream = res.bytes_stream();
    // Accumulates partial SSE frames until the blank-line terminator arrives.
    let mut buffer = String::new();
    let mut poll_interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
    // Consume the interval's immediate first tick so the fallback poll does
    // not fire right after the startup sweep.
    poll_interval.tick().await;

    loop {
        tokio::select! {
            chunk = stream.next() => {
                match chunk {
                    Some(Ok(chunk)) => {
                        buffer.push_str(&String::from_utf8_lossy(&chunk));

                        // SSE events are delimited by a blank line ("\n\n");
                        // there may be several complete events per chunk.
                        while let Some(pos) = buffer.find("\n\n") {
                            let event_str = buffer[..pos].to_string();
                            buffer = buffer[pos + 2..].to_string();

                            if let Some(data_line) = event_str.lines().find(|l| l.starts_with("data:")) {
                                let data = data_line.trim_start_matches("data:").trim();
                                if data == "[DONE]" || data.is_empty() {
                                    continue;
                                }

                                // Non-JSON payloads (e.g. keepalives) are ignored.
                                if let Ok(task) = serde_json::from_str::<serde_json::Value>(data) {
                                    spawn_task_handler(
                                        &task, client, server, token, worker_id,
                                        processing, auto_approve,
                                    ).await;
                                }
                            }
                        }
                    }
                    Some(Err(e)) => {
                        return Err(e.into());
                    }
                    None => {
                        // Server closed the stream cleanly.
                        return Ok(());
                    }
                }
            }
            _ = poll_interval.tick() => {
                // Fallback poll; failures are logged, never fatal.
                if let Err(e) = poll_pending_tasks(
                    client, server, token, worker_id, processing, auto_approve,
                ).await {
                    tracing::warn!("Periodic task poll failed: {}", e);
                }
            }
        }
    }
}
904
905async fn spawn_task_handler(
906 task: &serde_json::Value,
907 client: &Client,
908 server: &str,
909 token: &Option<String>,
910 worker_id: &str,
911 processing: &Arc<Mutex<HashSet<String>>>,
912 auto_approve: &AutoApprove,
913) {
914 if let Some(id) = task
915 .get("task")
916 .and_then(|t| t["id"].as_str())
917 .or_else(|| task["id"].as_str())
918 {
919 let mut proc = processing.lock().await;
920 if !proc.contains(id) {
921 proc.insert(id.to_string());
922 drop(proc);
923
924 let task_id = id.to_string();
925 let task = task.clone();
926 let client = client.clone();
927 let server = server.to_string();
928 let token = token.clone();
929 let worker_id = worker_id.to_string();
930 let auto_approve = *auto_approve;
931 let processing_clone = processing.clone();
932
933 tokio::spawn(async move {
934 if let Err(e) =
935 handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await
936 {
937 tracing::error!("Task {} failed: {}", task_id, e);
938 }
939 processing_clone.lock().await.remove(&task_id);
940 });
941 }
942 }
943}
944
945async fn poll_pending_tasks(
946 client: &Client,
947 server: &str,
948 token: &Option<String>,
949 worker_id: &str,
950 processing: &Arc<Mutex<HashSet<String>>>,
951 auto_approve: &AutoApprove,
952) -> Result<()> {
953 let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
954 if let Some(t) = token {
955 req = req.bearer_auth(t);
956 }
957
958 let res = req.send().await?;
959 if !res.status().is_success() {
960 return Ok(());
961 }
962
963 let data: serde_json::Value = res.json().await?;
964 let tasks = if let Some(arr) = data.as_array() {
965 arr.clone()
966 } else {
967 data["tasks"].as_array().cloned().unwrap_or_default()
968 };
969
970 if !tasks.is_empty() {
971 tracing::debug!("Poll found {} pending task(s)", tasks.len());
972 }
973
974 for task in &tasks {
975 spawn_task_handler(
976 task,
977 client,
978 server,
979 token,
980 worker_id,
981 processing,
982 auto_approve,
983 )
984 .await;
985 }
986
987 Ok(())
988}
989
/// Claims a task from the server, executes it (swarm or single-session path),
/// streams incremental output back, and finally releases the task with a
/// terminal status.
///
/// Claim conflicts (409) and other claim failures return `Ok(())` — the task
/// simply belongs to another worker. Errors after a successful claim are
/// reported via the release call's `status`/`error` fields.
async fn handle_task(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    task: &serde_json::Value,
    auto_approve: AutoApprove,
) -> Result<()> {
    let task_id = task_str(task, "id").ok_or_else(|| anyhow::anyhow!("No task ID"))?;
    let title = task_str(task, "title").unwrap_or("Untitled");

    tracing::info!("Handling task: {} ({})", title, task_id);

    // --- Claim -----------------------------------------------------------
    let mut req = client
        .post(format!("{}/v1/worker/tasks/claim", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req
        .json(&serde_json::json!({ "task_id": task_id }))
        .send()
        .await?;

    if !res.status().is_success() {
        let status = res.status();
        let text = res.text().await?;
        if status == reqwest::StatusCode::CONFLICT {
            // Another worker got there first — expected under contention.
            tracing::debug!(task_id, "Task already claimed by another worker, skipping");
        } else {
            tracing::warn!(task_id, %status, "Failed to claim task: {}", text);
        }
        return Ok(());
    }

    tracing::info!("Claimed task: {}", task_id);

    // --- Routing hints from task metadata --------------------------------
    let metadata = task_metadata(task);
    // `resume_session_id` is read directly (not via metadata_lookup), so it
    // must sit at the top level of the metadata object.
    let resume_session_id = metadata
        .get("resume_session_id")
        .and_then(|v| v.as_str())
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty());
    let complexity_hint = metadata_str(&metadata, &["complexity"]);
    // Explicit tier wins; otherwise derive one from the complexity hint
    // (quick -> fast, deep -> heavy, else balanced).
    let model_tier = metadata_str(&metadata, &["model_tier", "tier"])
        .map(|s| s.to_ascii_lowercase())
        .or_else(|| {
            complexity_hint.as_ref().map(|complexity| {
                match complexity.to_ascii_lowercase().as_str() {
                    "quick" => "fast".to_string(),
                    "deep" => "heavy".to_string(),
                    _ => "balanced".to_string(),
                }
            })
        });
    let worker_personality = metadata_str(
        &metadata,
        &["worker_personality", "personality", "agent_personality"],
    );
    let target_agent_name = metadata_str(&metadata, &["target_agent_name", "agent_name"]);
    // Model resolution precedence: model_ref (task, then metadata), then
    // model (task, then metadata).
    let raw_model = task_str(task, "model_ref")
        .or_else(|| metadata_lookup(&metadata, "model_ref").and_then(|v| v.as_str()))
        .or_else(|| task_str(task, "model"))
        .or_else(|| metadata_lookup(&metadata, "model").and_then(|v| v.as_str()));
    let selected_model = raw_model.map(model_ref_to_provider_model);

    // --- Session setup (resume if requested, else fresh) ------------------
    let mut session = if let Some(ref sid) = resume_session_id {
        match Session::load(sid).await {
            Ok(existing) => {
                tracing::info!("Resuming session {} for task {}", sid, task_id);
                existing
            }
            Err(e) => {
                // A stale/unknown session id degrades to a fresh session
                // rather than failing the task.
                tracing::warn!(
                    "Could not load session {} for task {} ({}), starting a new session",
                    sid,
                    task_id,
                    e
                );
                Session::new().await?
            }
        }
    } else {
        Session::new().await?
    };

    let agent_type = task_str(task, "agent_type")
        .or_else(|| task_str(task, "agent"))
        .unwrap_or("build");
    session.agent = agent_type.to_string();

    if let Some(model) = selected_model.clone() {
        session.metadata.model = Some(model);
    }

    // Prompt precedence: prompt, then description, then the task title.
    let prompt = task_str(task, "prompt")
        .or_else(|| task_str(task, "description"))
        .unwrap_or(title);

    tracing::info!("Executing prompt: {}", prompt);

    // --- Output streaming callback ----------------------------------------
    // Each output line is posted fire-and-forget from its own spawned task;
    // send failures are intentionally ignored.
    let stream_client = client.clone();
    let stream_server = server.to_string();
    let stream_token = token.clone();
    let stream_worker_id = worker_id.to_string();
    let stream_task_id = task_id.to_string();

    let output_callback = move |output: String| {
        let c = stream_client.clone();
        let s = stream_server.clone();
        let t = stream_token.clone();
        let w = stream_worker_id.clone();
        let tid = stream_task_id.clone();
        tokio::spawn(async move {
            let mut req = c
                .post(format!("{}/v1/opencode/tasks/{}/output", s, tid))
                .header("X-Worker-ID", &w);
            if let Some(tok) = &t {
                req = req.bearer_auth(tok);
            }
            let _ = req
                .json(&serde_json::json!({
                    "worker_id": w,
                    "output": output,
                }))
                .send()
                .await;
        });
    };

    // --- Execution: swarm path vs. single-session path --------------------
    let (status, result, error, session_id) = if is_swarm_agent(agent_type) {
        match execute_swarm_with_policy(
            &mut session,
            prompt,
            model_tier.as_deref(),
            selected_model,
            &metadata,
            complexity_hint.as_deref(),
            worker_personality.as_deref(),
            target_agent_name.as_deref(),
            Some(&output_callback),
        )
        .await
        {
            Ok((session_result, true)) => {
                tracing::info!("Swarm task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            // A partially-failed swarm still returns its text but reports
            // status "failed" with an explanatory error.
            Ok((session_result, false)) => {
                tracing::warn!("Swarm task completed with failures: {}", task_id);
                (
                    "failed",
                    Some(session_result.text),
                    Some("Swarm execution completed with failures".to_string()),
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Swarm task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    } else {
        match execute_session_with_policy(
            &mut session,
            prompt,
            auto_approve,
            model_tier.as_deref(),
            Some(&output_callback),
        )
        .await
        {
            Ok(session_result) => {
                tracing::info!("Task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    };

    // --- Release with terminal status -------------------------------------
    let mut req = client
        .post(format!("{}/v1/worker/tasks/release", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    req.json(&serde_json::json!({
        "task_id": task_id,
        "status": status,
        "result": result,
        "error": error,
        // Fall back to the local session id when execution produced none.
        "session_id": session_id.unwrap_or_else(|| session.id.clone()),
    }))
    .send()
    .await?;

    tracing::info!("Task released: {} with status: {}", task_id, status);
    Ok(())
}
1209
/// Executes `prompt` as a multi-subagent swarm run attached to `session`.
///
/// Swarm tuning knobs (decomposition strategy, subagent count, per-subagent
/// step limit, timeout, parallelism) are read from the task `metadata` map,
/// accepting several key aliases each; out-of-range values are clamped to
/// operational bounds. `complexity_hint`, `worker_personality` and
/// `target_agent_name` are only echoed through `output_callback` for
/// observability. Progress events stream through `output_callback` when one
/// is provided.
///
/// Returns the saved session result plus a `bool` carrying the swarm's own
/// success flag.
///
/// # Errors
/// Fails on title generation, swarm execution/join failure, or session save.
async fn execute_swarm_with_policy<F>(
    session: &mut Session,
    prompt: &str,
    model_tier: Option<&str>,
    explicit_model: Option<String>,
    metadata: &serde_json::Map<String, serde_json::Value>,
    complexity_hint: Option<&str>,
    worker_personality: Option<&str>,
    target_agent_name: Option<&str>,
    output_callback: Option<&F>,
) -> Result<(crate::session::SessionResult, bool)>
where
    F: Fn(String),
{
    use crate::provider::{ContentPart, Message, Role};

    // Record the incoming prompt on the session before any execution starts.
    session.add_message(Message {
        role: Role::User,
        content: vec![ContentPart::Text {
            text: prompt.to_string(),
        }],
    });

    if session.title.is_none() {
        session.generate_title().await?;
    }

    // Resolve swarm parameters from metadata; each knob accepts legacy key
    // aliases and is clamped to sane bounds.
    let strategy = parse_swarm_strategy(metadata);
    let max_subagents = metadata_usize(
        metadata,
        &["swarm_max_subagents", "max_subagents", "subagents"],
    )
    .unwrap_or(10)
    .clamp(1, 100);
    let max_steps_per_subagent = metadata_usize(
        metadata,
        &[
            "swarm_max_steps_per_subagent",
            "max_steps_per_subagent",
            "max_steps",
        ],
    )
    .unwrap_or(50)
    .clamp(1, 200);
    let timeout_secs = metadata_u64(metadata, &["swarm_timeout_secs", "timeout_secs", "timeout"])
        .unwrap_or(600)
        .clamp(30, 3600);
    let parallel_enabled =
        metadata_bool(metadata, &["swarm_parallel_enabled", "parallel_enabled"]).unwrap_or(true);

    // Persist the chosen model on the session so it is visible to callers.
    let model = resolve_swarm_model(explicit_model, model_tier).await;
    if let Some(ref selected_model) = model {
        session.metadata.model = Some(selected_model.clone());
    }

    // Echo routing and configuration decisions to the output stream.
    if let Some(cb) = output_callback {
        cb(format!(
            "[swarm] routing complexity={} tier={} personality={} target_agent={}",
            complexity_hint.unwrap_or("standard"),
            model_tier.unwrap_or("balanced"),
            worker_personality.unwrap_or("auto"),
            target_agent_name.unwrap_or("auto")
        ));
        cb(format!(
            "[swarm] config strategy={:?} max_subagents={} max_steps={} timeout={}s tier={}",
            strategy,
            max_subagents,
            max_steps_per_subagent,
            timeout_secs,
            model_tier.unwrap_or("balanced")
        ));
    }

    let swarm_config = SwarmConfig {
        max_subagents,
        max_steps_per_subagent,
        subagent_timeout_secs: timeout_secs,
        parallel_enabled,
        model,
        working_dir: session
            .metadata
            .directory
            .as_ref()
            .map(|p| p.to_string_lossy().to_string()),
        ..Default::default()
    };

    // With a callback, run the executor on a spawned task and stream its
    // events concurrently; without one, execute inline.
    let swarm_result = if output_callback.is_some() {
        let (event_tx, mut event_rx) = mpsc::channel(256);
        let executor = SwarmExecutor::new(swarm_config).with_event_tx(event_tx);
        let prompt_owned = prompt.to_string();
        let mut exec_handle =
            tokio::spawn(async move { executor.execute(&prompt_owned, strategy).await });

        let mut final_result: Option<crate::swarm::SwarmResult> = None;

        // Forward events and watch for completion concurrently; the loop
        // ends once the executor task joins.
        while final_result.is_none() {
            tokio::select! {
                maybe_event = event_rx.recv() => {
                    if let Some(event) = maybe_event {
                        if let Some(cb) = output_callback {
                            if let Some(line) = format_swarm_event_for_output(&event) {
                                cb(line);
                            }
                        }
                    }
                }
                join_result = &mut exec_handle => {
                    let joined = join_result.map_err(|e| anyhow::anyhow!("Swarm join failure: {}", e))?;
                    final_result = Some(joined?);
                }
            }
        }

        // Flush events that were still queued when the executor finished.
        while let Ok(event) = event_rx.try_recv() {
            if let Some(cb) = output_callback {
                if let Some(line) = format_swarm_event_for_output(&event) {
                    cb(line);
                }
            }
        }

        final_result.ok_or_else(|| anyhow::anyhow!("Swarm execution returned no result"))?
    } else {
        SwarmExecutor::new(swarm_config)
            .execute(prompt, strategy)
            .await?
    };

    // An empty swarm result still yields a descriptive assistant message.
    let final_text = if swarm_result.result.trim().is_empty() {
        if swarm_result.success {
            "Swarm completed without textual output.".to_string()
        } else {
            "Swarm finished with failures and no textual output.".to_string()
        }
    } else {
        swarm_result.result.clone()
    };

    session.add_message(Message {
        role: Role::Assistant,
        content: vec![ContentPart::Text {
            text: final_text.clone(),
        }],
    });
    session.save().await?;

    Ok((
        crate::session::SessionResult {
            text: final_text,
            session_id: session.id.clone(),
        },
        swarm_result.success,
    ))
}
1365
1366async fn execute_session_with_policy<F>(
1369 session: &mut Session,
1370 prompt: &str,
1371 auto_approve: AutoApprove,
1372 model_tier: Option<&str>,
1373 output_callback: Option<&F>,
1374) -> Result<crate::session::SessionResult>
1375where
1376 F: Fn(String),
1377{
1378 use crate::provider::{
1379 CompletionRequest, ContentPart, Message, ProviderRegistry, Role, parse_model_string,
1380 };
1381 use std::sync::Arc;
1382
1383 let registry = ProviderRegistry::from_vault().await?;
1385 let providers = registry.list();
1386 tracing::info!("Available providers: {:?}", providers);
1387
1388 if providers.is_empty() {
1389 anyhow::bail!("No providers available. Configure API keys in HashiCorp Vault.");
1390 }
1391
1392 let (provider_name, model_id) = if let Some(ref model_str) = session.metadata.model {
1394 let (prov, model) = parse_model_string(model_str);
1395 if prov.is_some() {
1396 (prov.map(|s| s.to_string()), model.to_string())
1397 } else if providers.contains(&model) {
1398 (Some(model.to_string()), String::new())
1399 } else {
1400 (None, model.to_string())
1401 }
1402 } else {
1403 (None, String::new())
1404 };
1405
1406 let provider_slice = providers.as_slice();
1407 let provider_requested_but_unavailable = provider_name
1408 .as_deref()
1409 .map(|p| !providers.contains(&p))
1410 .unwrap_or(false);
1411
1412 let selected_provider = provider_name
1414 .as_deref()
1415 .filter(|p| providers.contains(p))
1416 .unwrap_or_else(|| choose_provider_for_tier(provider_slice, model_tier));
1417
1418 let provider = registry
1419 .get(selected_provider)
1420 .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;
1421
1422 session.add_message(Message {
1424 role: Role::User,
1425 content: vec![ContentPart::Text {
1426 text: prompt.to_string(),
1427 }],
1428 });
1429
1430 if session.title.is_none() {
1432 session.generate_title().await?;
1433 }
1434
1435 let model = if !model_id.is_empty() && !provider_requested_but_unavailable {
1438 model_id
1439 } else {
1440 default_model_for_provider(selected_provider, model_tier)
1441 };
1442
1443 let tool_registry =
1445 create_filtered_registry(Arc::clone(&provider), model.clone(), auto_approve);
1446 let tool_definitions = tool_registry.definitions();
1447
1448 let temperature = if model.starts_with("kimi-k2") {
1449 Some(1.0)
1450 } else {
1451 Some(0.7)
1452 };
1453
1454 tracing::info!(
1455 "Using model: {} via provider: {} (tier: {:?})",
1456 model,
1457 selected_provider,
1458 model_tier
1459 );
1460 tracing::info!(
1461 "Available tools: {} (auto_approve: {:?})",
1462 tool_definitions.len(),
1463 auto_approve
1464 );
1465
1466 let cwd = std::env::var("PWD")
1468 .map(std::path::PathBuf::from)
1469 .unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
1470 let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);
1471
1472 let mut final_output = String::new();
1473 let max_steps = 50;
1474
1475 for step in 1..=max_steps {
1476 tracing::info!(step = step, "Agent step starting");
1477
1478 let mut messages = vec![Message {
1480 role: Role::System,
1481 content: vec![ContentPart::Text {
1482 text: system_prompt.clone(),
1483 }],
1484 }];
1485 messages.extend(session.messages.clone());
1486
1487 let request = CompletionRequest {
1488 messages,
1489 tools: tool_definitions.clone(),
1490 model: model.clone(),
1491 temperature,
1492 top_p: None,
1493 max_tokens: Some(8192),
1494 stop: Vec::new(),
1495 };
1496
1497 let response = provider.complete(request).await?;
1498
1499 crate::telemetry::TOKEN_USAGE.record_model_usage(
1500 &model,
1501 response.usage.prompt_tokens as u64,
1502 response.usage.completion_tokens as u64,
1503 );
1504
1505 let tool_calls: Vec<(String, String, serde_json::Value)> = response
1507 .message
1508 .content
1509 .iter()
1510 .filter_map(|part| {
1511 if let ContentPart::ToolCall {
1512 id,
1513 name,
1514 arguments,
1515 } = part
1516 {
1517 let args: serde_json::Value =
1518 serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
1519 Some((id.clone(), name.clone(), args))
1520 } else {
1521 None
1522 }
1523 })
1524 .collect();
1525
1526 for part in &response.message.content {
1528 if let ContentPart::Text { text } = part {
1529 if !text.is_empty() {
1530 final_output.push_str(text);
1531 final_output.push('\n');
1532 if let Some(cb) = output_callback {
1533 cb(text.clone());
1534 }
1535 }
1536 }
1537 }
1538
1539 if tool_calls.is_empty() {
1541 session.add_message(response.message.clone());
1542 break;
1543 }
1544
1545 session.add_message(response.message.clone());
1546
1547 tracing::info!(
1548 step = step,
1549 num_tools = tool_calls.len(),
1550 "Executing tool calls"
1551 );
1552
1553 for (tool_id, tool_name, tool_input) in tool_calls {
1555 tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");
1556
1557 if let Some(cb) = output_callback {
1559 cb(format!("[tool:start:{}]", tool_name));
1560 }
1561
1562 if !is_tool_allowed(&tool_name, auto_approve) {
1564 let msg = format!(
1565 "Tool '{}' requires approval but auto-approve policy is {:?}",
1566 tool_name, auto_approve
1567 );
1568 tracing::warn!(tool = %tool_name, "Tool blocked by auto-approve policy");
1569 session.add_message(Message {
1570 role: Role::Tool,
1571 content: vec![ContentPart::ToolResult {
1572 tool_call_id: tool_id,
1573 content: msg,
1574 }],
1575 });
1576 continue;
1577 }
1578
1579 let content = if let Some(tool) = tool_registry.get(&tool_name) {
1580 let exec_result: Result<crate::tool::ToolResult> =
1581 tool.execute(tool_input.clone()).await;
1582 match exec_result {
1583 Ok(result) => {
1584 tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
1585 if let Some(cb) = output_callback {
1586 let status = if result.success { "ok" } else { "err" };
1587 cb(format!(
1588 "[tool:{}:{}] {}",
1589 tool_name,
1590 status,
1591 &result.output[..result.output.len().min(500)]
1592 ));
1593 }
1594 result.output
1595 }
1596 Err(e) => {
1597 tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
1598 if let Some(cb) = output_callback {
1599 cb(format!("[tool:{}:err] {}", tool_name, e));
1600 }
1601 format!("Error: {}", e)
1602 }
1603 }
1604 } else {
1605 tracing::warn!(tool = %tool_name, "Tool not found");
1606 format!("Error: Unknown tool '{}'", tool_name)
1607 };
1608
1609 session.add_message(Message {
1610 role: Role::Tool,
1611 content: vec![ContentPart::ToolResult {
1612 tool_call_id: tool_id,
1613 content,
1614 }],
1615 });
1616 }
1617 }
1618
1619 session.save().await?;
1620
1621 Ok(crate::session::SessionResult {
1622 text: final_output.trim().to_string(),
1623 session_id: session.id.clone(),
1624 })
1625}
1626
1627fn is_tool_allowed(tool_name: &str, auto_approve: AutoApprove) -> bool {
1629 match auto_approve {
1630 AutoApprove::All => true,
1631 AutoApprove::Safe | AutoApprove::None => is_safe_tool(tool_name),
1632 }
1633}
1634
/// Returns `true` for tools considered read-only (no filesystem writes or
/// command execution), which are exempt from approval gating.
fn is_safe_tool(tool_name: &str) -> bool {
    matches!(
        tool_name,
        "read"
            | "list"
            | "glob"
            | "grep"
            | "codesearch"
            | "lsp"
            | "webfetch"
            | "websearch"
            | "todo_read"
            | "skill"
    )
}
1651
/// Builds the tool registry for a task, filtered by the approval policy.
///
/// Read-only tools are always registered. Mutating/executing tools (write,
/// edit, bash, patch, todo-write, task, plan, rlm, ralph, prd, confirm-edit,
/// undo, mcp bridge) are added only under `AutoApprove::All`; `provider` and
/// `model` are handed to the ralph tool. The invalid-tool shim is always
/// present so unknown tool calls produce a structured error.
fn create_filtered_registry(
    provider: Arc<dyn crate::provider::Provider>,
    model: String,
    auto_approve: AutoApprove,
) -> crate::tool::ToolRegistry {
    use crate::tool::*;

    let mut registry = ToolRegistry::new();

    // Safe (read-only) tools — always available regardless of policy.
    registry.register(Arc::new(file::ReadTool::new()));
    registry.register(Arc::new(file::ListTool::new()));
    registry.register(Arc::new(file::GlobTool::new()));
    registry.register(Arc::new(search::GrepTool::new()));
    registry.register(Arc::new(lsp::LspTool::new()));
    registry.register(Arc::new(webfetch::WebFetchTool::new()));
    registry.register(Arc::new(websearch::WebSearchTool::new()));
    registry.register(Arc::new(codesearch::CodeSearchTool::new()));
    registry.register(Arc::new(todo::TodoReadTool::new()));
    registry.register(Arc::new(skill::SkillTool::new()));

    // Mutating/executing tools — only when everything is auto-approved.
    if matches!(auto_approve, AutoApprove::All) {
        registry.register(Arc::new(file::WriteTool::new()));
        registry.register(Arc::new(edit::EditTool::new()));
        registry.register(Arc::new(bash::BashTool::new()));
        registry.register(Arc::new(multiedit::MultiEditTool::new()));
        registry.register(Arc::new(patch::ApplyPatchTool::new()));
        registry.register(Arc::new(todo::TodoWriteTool::new()));
        registry.register(Arc::new(task::TaskTool::new()));
        registry.register(Arc::new(plan::PlanEnterTool::new()));
        registry.register(Arc::new(plan::PlanExitTool::new()));
        registry.register(Arc::new(rlm::RlmTool::new()));
        registry.register(Arc::new(ralph::RalphTool::with_provider(provider, model)));
        registry.register(Arc::new(prd::PrdTool::new()));
        registry.register(Arc::new(confirm_edit::ConfirmEditTool::new()));
        registry.register(Arc::new(confirm_multiedit::ConfirmMultiEditTool::new()));
        registry.register(Arc::new(undo::UndoTool));
        registry.register(Arc::new(mcp_bridge::McpBridgeTool::new()));
    }

    // Catch-all shim for unknown tool invocations.
    registry.register(Arc::new(invalid::InvalidTool::new()));

    registry
}
1698
/// Spawns the periodic worker-heartbeat task and returns its handle.
///
/// Every 30s the task posts worker id, agent name, status (idle/processing,
/// derived from the shared `processing` set) and the active task count. When
/// cognition sharing is enabled, a cognition snapshot is attached to the
/// payload; if the server rejects that extended payload with a 4xx, the
/// heartbeat is retried without it and cognition data is paused for a
/// cooldown period.
///
/// The task never exits on failure: after `MAX_FAILURES` consecutive
/// failures it logs an error, resets the counter, and keeps going (the SSE
/// loop is responsible for reconnection).
fn start_heartbeat(
    client: Client,
    server: String,
    token: Option<String>,
    heartbeat_state: HeartbeatState,
    processing: Arc<Mutex<HashSet<String>>>,
    cognition_config: CognitionHeartbeatConfig,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        let mut consecutive_failures = 0u32;
        const MAX_FAILURES: u32 = 3;
        const HEARTBEAT_INTERVAL_SECS: u64 = 30;
        const COGNITION_RETRY_COOLDOWN_SECS: u64 = 300;
        // While set, cognition payloads are suppressed until this instant.
        let mut cognition_payload_disabled_until: Option<Instant> = None;

        let mut interval =
            tokio::time::interval(tokio::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
        // Skip (rather than burst) ticks missed while a send was in flight.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        loop {
            interval.tick().await;

            // Derive current status from the number of in-flight tasks and
            // mirror it into the shared heartbeat state.
            let active_count = processing.lock().await.len();
            heartbeat_state.set_task_count(active_count).await;

            let status = if active_count > 0 {
                WorkerStatus::Processing
            } else {
                WorkerStatus::Idle
            };
            heartbeat_state.set_status(status).await;

            let url = format!(
                "{}/v1/opencode/workers/{}/heartbeat",
                server, heartbeat_state.worker_id
            );
            let mut req = client.post(&url);

            if let Some(ref t) = token {
                req = req.bearer_auth(t);
            }

            let status_str = heartbeat_state.status.lock().await.as_str().to_string();
            // `base_payload` is kept separate so the retry path can resend
            // the heartbeat without the cognition extension.
            let base_payload = serde_json::json!({
                "worker_id": &heartbeat_state.worker_id,
                "agent_name": &heartbeat_state.agent_name,
                "status": status_str,
                "active_task_count": active_count,
            });
            let mut payload = base_payload.clone();
            let mut included_cognition_payload = false;
            // Respect any cooldown set after a previous schema rejection.
            let cognition_payload_allowed = cognition_payload_disabled_until
                .map(|until| Instant::now() >= until)
                .unwrap_or(true);

            // Attach cognition data best-effort; a fetch failure simply
            // leaves the base payload untouched.
            if cognition_config.enabled
                && cognition_payload_allowed
                && let Some(cognition_payload) =
                    fetch_cognition_heartbeat_payload(&client, &cognition_config).await
                && let Some(obj) = payload.as_object_mut()
            {
                obj.insert("cognition".to_string(), cognition_payload);
                included_cognition_payload = true;
            }

            match req.json(&payload).send().await {
                Ok(res) => {
                    if res.status().is_success() {
                        consecutive_failures = 0;
                        tracing::debug!(
                            worker_id = %heartbeat_state.worker_id,
                            status = status_str,
                            active_tasks = active_count,
                            "Heartbeat sent successfully"
                        );
                    } else if included_cognition_payload && res.status().is_client_error() {
                        // A 4xx with cognition attached likely means the
                        // server does not understand the extended schema:
                        // retry immediately with the bare payload.
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            "Heartbeat cognition payload rejected, retrying without cognition payload"
                        );

                        let mut retry_req = client.post(&url);
                        if let Some(ref t) = token {
                            retry_req = retry_req.bearer_auth(t);
                        }

                        match retry_req.json(&base_payload).send().await {
                            Ok(retry_res) if retry_res.status().is_success() => {
                                // Bare payload worked: pause cognition data
                                // for the cooldown window.
                                cognition_payload_disabled_until = Some(
                                    Instant::now()
                                        + Duration::from_secs(COGNITION_RETRY_COOLDOWN_SECS),
                                );
                                consecutive_failures = 0;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    retry_after_secs = COGNITION_RETRY_COOLDOWN_SECS,
                                    "Paused cognition heartbeat payload after schema rejection"
                                );
                            }
                            Ok(retry_res) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    status = %retry_res.status(),
                                    failures = consecutive_failures,
                                    "Heartbeat failed even after retry without cognition payload"
                                );
                            }
                            Err(e) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    error = %e,
                                    failures = consecutive_failures,
                                    "Heartbeat retry without cognition payload failed"
                                );
                            }
                        }
                    } else {
                        consecutive_failures += 1;
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            failures = consecutive_failures,
                            "Heartbeat failed"
                        );
                    }
                }
                Err(e) => {
                    consecutive_failures += 1;
                    tracing::warn!(
                        worker_id = %heartbeat_state.worker_id,
                        error = %e,
                        failures = consecutive_failures,
                        "Heartbeat request failed"
                    );
                }
            }

            if consecutive_failures >= MAX_FAILURES {
                tracing::error!(
                    worker_id = %heartbeat_state.worker_id,
                    failures = consecutive_failures,
                    "Heartbeat failed {} consecutive times - worker will continue running and attempt reconnection via SSE loop",
                    MAX_FAILURES
                );
                // Reset instead of exiting: the worker stays alive and the
                // SSE loop handles reconnection.
                consecutive_failures = 0;
            }
        }
    })
}
1858
/// Fetches a cognition status snapshot from the local cognition service and
/// shapes it into the heartbeat `cognition` payload.
///
/// Returns `None` on any timeout, transport error, non-success status, or
/// decode failure — the heartbeat must never fail just because cognition
/// data is unavailable. When `include_thought_summary` is set, the latest
/// thought (trimmed to `summary_max_chars`) plus its model/source metadata
/// is attached best-effort.
async fn fetch_cognition_heartbeat_payload(
    client: &Client,
    config: &CognitionHeartbeatConfig,
) -> Option<serde_json::Value> {
    // Status endpoint: bounded by the configured request timeout; any
    // failure along the chain bails out with None.
    let status_url = format!("{}/v1/cognition/status", config.source_base_url);
    let status_res = tokio::time::timeout(
        Duration::from_millis(config.request_timeout_ms),
        client.get(status_url).send(),
    )
    .await
    .ok()?
    .ok()?;

    if !status_res.status().is_success() {
        return None;
    }

    let status: CognitionStatusSnapshot = status_res.json().await.ok()?;
    let mut payload = serde_json::json!({
        "running": status.running,
        "last_tick_at": status.last_tick_at,
        "active_persona_count": status.active_persona_count,
        "events_buffered": status.events_buffered,
        "snapshots_buffered": status.snapshots_buffered,
        "loop_interval_ms": status.loop_interval_ms,
    });

    if config.include_thought_summary {
        // Latest-snapshot endpoint is optional enrichment: failures here
        // still return the base status payload.
        let snapshot_url = format!("{}/v1/cognition/snapshots/latest", config.source_base_url);
        let snapshot_res = tokio::time::timeout(
            Duration::from_millis(config.request_timeout_ms),
            client.get(snapshot_url).send(),
        )
        .await
        .ok()
        .and_then(Result::ok);

        if let Some(snapshot_res) = snapshot_res
            && snapshot_res.status().is_success()
            && let Ok(snapshot) = snapshot_res.json::<CognitionLatestSnapshot>().await
            && let Some(obj) = payload.as_object_mut()
        {
            obj.insert(
                "latest_snapshot_at".to_string(),
                serde_json::Value::String(snapshot.generated_at),
            );
            obj.insert(
                "latest_thought".to_string(),
                serde_json::Value::String(trim_for_heartbeat(
                    &snapshot.summary,
                    config.summary_max_chars,
                )),
            );
            // Model/source are free-form metadata keys; only attach them
            // when present and string-valued.
            if let Some(model) = snapshot
                .metadata
                .get("model")
                .and_then(serde_json::Value::as_str)
            {
                obj.insert(
                    "latest_thought_model".to_string(),
                    serde_json::Value::String(model.to_string()),
                );
            }
            if let Some(source) = snapshot
                .metadata
                .get("source")
                .and_then(serde_json::Value::as_str)
            {
                obj.insert(
                    "latest_thought_source".to_string(),
                    serde_json::Value::String(source.to_string()),
                );
            }
        }
    }

    Some(payload)
}
1937
/// Whitespace-trims `input`, truncating it to at most `max_chars` characters
/// (plus a `...` suffix) when it is longer.
///
/// Single pass: reads at most `max_chars + 1` characters instead of counting
/// the whole string first, which matters for large thought summaries.
fn trim_for_heartbeat(input: &str, max_chars: usize) -> String {
    let mut chars = input.chars();
    // Take the first `max_chars` characters; the iterator keeps its position
    // so one more `next()` tells us whether anything was cut off.
    let head: String = chars.by_ref().take(max_chars).collect();

    if chars.next().is_none() {
        // Whole input fits within the limit.
        input.trim().to_string()
    } else {
        let truncated = format!("{}...", head);
        truncated.trim().to_string()
    }
}
1950
/// Reads a boolean from the environment variable `name`, accepting the usual
/// truthy/falsy spellings (case-insensitive). Unset or unrecognized values
/// yield `default`.
fn env_bool(name: &str, default: bool) -> bool {
    let Ok(raw) = std::env::var(name) else {
        return default;
    };
    match raw.to_ascii_lowercase().as_str() {
        "1" | "true" | "yes" | "on" => true,
        "0" | "false" | "no" | "off" => false,
        _ => default,
    }
}
1961
/// Reads a `usize` from the environment variable `name`; unset or unparsable
/// values yield `default`.
fn env_usize(name: &str, default: usize) -> usize {
    match std::env::var(name) {
        Ok(raw) => raw.parse::<usize>().unwrap_or(default),
        Err(_) => default,
    }
}
1968
/// Reads a `u64` from the environment variable `name`; unset or unparsable
/// values yield `default`.
fn env_u64(name: &str, default: u64) -> u64 {
    match std::env::var(name) {
        Ok(raw) => raw.parse::<u64>().unwrap_or(default),
        Err(_) => default,
    }
}