1use crate::cli::A2aArgs;
4use crate::provider::ProviderRegistry;
5use crate::session::Session;
6use crate::swarm::{DecompositionStrategy, SwarmConfig, SwarmExecutor};
7use crate::tui::swarm_view::SwarmEvent;
8use anyhow::Result;
9use futures::StreamExt;
10use reqwest::Client;
11use serde::Deserialize;
12use std::collections::HashMap;
13use std::collections::HashSet;
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::sync::{Mutex, mpsc};
17use tokio::task::JoinHandle;
18use tokio::time::Instant;
19
/// Coarse activity state of this worker, reported upstream via heartbeats
/// (see `HeartbeatState::set_status`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WorkerStatus {
    // No task currently being handled.
    Idle,
    // A task is currently being executed.
    Processing,
}
26
27impl WorkerStatus {
28 fn as_str(&self) -> &'static str {
29 match self {
30 WorkerStatus::Idle => "idle",
31 WorkerStatus::Processing => "processing",
32 }
33 }
34}
35
/// Shared mutable state published by the heartbeat task.
///
/// Cloning is cheap: the mutable fields sit behind `Arc<Mutex<..>>`, so task
/// handlers and the heartbeat loop observe the same values.
#[derive(Clone)]
struct HeartbeatState {
    worker_id: String,
    agent_name: String,
    // Current coarse status; written via `set_status`.
    status: Arc<Mutex<WorkerStatus>>,
    // Number of tasks currently in flight; written via `set_task_count`.
    active_task_count: Arc<Mutex<usize>>,
}
44
45impl HeartbeatState {
46 fn new(worker_id: String, agent_name: String) -> Self {
47 Self {
48 worker_id,
49 agent_name,
50 status: Arc::new(Mutex::new(WorkerStatus::Idle)),
51 active_task_count: Arc::new(Mutex::new(0)),
52 }
53 }
54
55 async fn set_status(&self, status: WorkerStatus) {
56 *self.status.lock().await = status;
57 }
58
59 async fn set_task_count(&self, count: usize) {
60 *self.active_task_count.lock().await = count;
61 }
62}
63
/// Settings controlling whether and how the worker shares local "cognition"
/// state in its heartbeats. Populated from environment variables by
/// `from_env`.
#[derive(Clone, Debug)]
struct CognitionHeartbeatConfig {
    enabled: bool,
    // Base URL of the local cognition service; trailing slashes stripped.
    source_base_url: String,
    include_thought_summary: bool,
    // Max length of the shared thought summary; floored at 120 in `from_env`.
    summary_max_chars: usize,
    // HTTP timeout for cognition requests; floored at 250ms in `from_env`.
    request_timeout_ms: u64,
}
72
73impl CognitionHeartbeatConfig {
74 fn from_env() -> Self {
75 let source_base_url = std::env::var("CODETETHER_WORKER_COGNITION_SOURCE_URL")
76 .unwrap_or_else(|_| "http://127.0.0.1:4096".to_string())
77 .trim_end_matches('/')
78 .to_string();
79
80 Self {
81 enabled: env_bool("CODETETHER_WORKER_COGNITION_SHARE_ENABLED", true),
82 source_base_url,
83 include_thought_summary: env_bool("CODETETHER_WORKER_COGNITION_INCLUDE_THOUGHTS", true),
84 summary_max_chars: env_usize("CODETETHER_WORKER_COGNITION_THOUGHT_MAX_CHARS", 480)
85 .max(120),
86 request_timeout_ms: env_u64("CODETETHER_WORKER_COGNITION_TIMEOUT_MS", 2_500).max(250),
87 }
88 }
89}
90
/// Deserialized response from the cognition service's status endpoint.
/// All fields except `running` default to zero/`None` when absent.
#[derive(Debug, Deserialize)]
struct CognitionStatusSnapshot {
    running: bool,
    #[serde(default)]
    last_tick_at: Option<String>,
    #[serde(default)]
    active_persona_count: usize,
    #[serde(default)]
    events_buffered: usize,
    #[serde(default)]
    snapshots_buffered: usize,
    #[serde(default)]
    loop_interval_ms: u64,
}
105
/// Deserialized "latest snapshot" payload from the cognition service;
/// `metadata` defaults to an empty map when absent.
#[derive(Debug, Deserialize)]
struct CognitionLatestSnapshot {
    generated_at: String,
    summary: String,
    #[serde(default)]
    metadata: HashMap<String, serde_json::Value>,
}
113
114pub async fn run(args: A2aArgs) -> Result<()> {
116 let server = args.server.trim_end_matches('/');
117 let name = args
118 .name
119 .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
120 let worker_id = generate_worker_id();
121
122 let codebases: Vec<String> = args
123 .codebases
124 .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
125 .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);
126
127 tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
128 tracing::info!("Server: {}", server);
129 tracing::info!("Codebases: {:?}", codebases);
130
131 let client = Client::new();
132 let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
133 let cognition_heartbeat = CognitionHeartbeatConfig::from_env();
134 if cognition_heartbeat.enabled {
135 tracing::info!(
136 source = %cognition_heartbeat.source_base_url,
137 include_thoughts = cognition_heartbeat.include_thought_summary,
138 max_chars = cognition_heartbeat.summary_max_chars,
139 timeout_ms = cognition_heartbeat.request_timeout_ms,
140 "Cognition heartbeat sharing enabled (set CODETETHER_WORKER_COGNITION_SHARE_ENABLED=false to disable)"
141 );
142 } else {
143 tracing::warn!(
144 "Cognition heartbeat sharing disabled; worker thought state will not be shared upstream"
145 );
146 }
147
148 let auto_approve = match args.auto_approve.as_str() {
149 "all" => AutoApprove::All,
150 "safe" => AutoApprove::Safe,
151 _ => AutoApprove::None,
152 };
153
154 let heartbeat_state = HeartbeatState::new(worker_id.clone(), name.clone());
156
157 register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;
159
160 fetch_pending_tasks(
162 &client,
163 server,
164 &args.token,
165 &worker_id,
166 &processing,
167 &auto_approve,
168 )
169 .await?;
170
171 loop {
173 if let Err(e) =
175 register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await
176 {
177 tracing::warn!("Failed to re-register worker on reconnection: {}", e);
178 }
179
180 let heartbeat_handle = start_heartbeat(
182 client.clone(),
183 server.to_string(),
184 args.token.clone(),
185 heartbeat_state.clone(),
186 processing.clone(),
187 cognition_heartbeat.clone(),
188 );
189
190 match connect_stream(
191 &client,
192 server,
193 &args.token,
194 &worker_id,
195 &name,
196 &codebases,
197 &processing,
198 &auto_approve,
199 )
200 .await
201 {
202 Ok(()) => {
203 tracing::warn!("Stream ended, reconnecting...");
204 }
205 Err(e) => {
206 tracing::error!("Stream error: {}, reconnecting...", e);
207 }
208 }
209
210 heartbeat_handle.abort();
212 tracing::debug!("Heartbeat cancelled for reconnection");
213
214 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
215 }
216}
217
218fn generate_worker_id() -> String {
219 format!(
220 "wrk_{}_{:x}",
221 chrono::Utc::now().timestamp(),
222 rand::random::<u64>()
223 )
224}
225
/// Tool auto-approval policy parsed from the `--auto-approve` CLI flag
/// ("all" / "safe"; anything else maps to `None`).
#[derive(Debug, Clone, Copy)]
enum AutoApprove {
    All,
    Safe,
    None,
}
232
/// Capability tags advertised to the server at registration time.
const WORKER_CAPABILITIES: &[&str] = &["ralph", "swarm", "rlm", "a2a", "mcp"];
235
236fn task_value<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a serde_json::Value> {
237 task.get("task")
238 .and_then(|t| t.get(key))
239 .or_else(|| task.get(key))
240}
241
242fn task_str<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a str> {
243 task_value(task, key).and_then(|v| v.as_str())
244}
245
246fn task_metadata(task: &serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
247 task_value(task, "metadata")
248 .and_then(|m| m.as_object())
249 .cloned()
250 .unwrap_or_default()
251}
252
/// Normalize a model reference to `provider/model` form.
///
/// A bare `provider:model` reference (contains a colon but no slash) has its
/// first colon rewritten to a slash; anything else passes through unchanged.
fn model_ref_to_provider_model(model: &str) -> String {
    let colon_form = model.contains(':') && !model.contains('/');
    if colon_form {
        model.replacen(':', "/", 1)
    } else {
        model.to_string()
    }
}
263
/// Ordered provider preference list for a model tier.
///
/// `fast`/`quick` favors cheap, low-latency providers; `heavy`/`deep` puts
/// anthropic first; everything else gets the balanced ordering.
fn provider_preferences_for_tier(model_tier: Option<&str>) -> &'static [&'static str] {
    const FAST: &[&str] = &[
        "openai",
        "github-copilot",
        "moonshotai",
        "zhipuai",
        "openrouter",
        "novita",
        "google",
        "anthropic",
    ];
    const HEAVY: &[&str] = &[
        "anthropic",
        "openai",
        "github-copilot",
        "moonshotai",
        "zhipuai",
        "openrouter",
        "novita",
        "google",
    ];
    const BALANCED: &[&str] = &[
        "openai",
        "github-copilot",
        "anthropic",
        "moonshotai",
        "zhipuai",
        "openrouter",
        "novita",
        "google",
    ];

    match model_tier.unwrap_or("balanced") {
        "fast" | "quick" => FAST,
        "heavy" | "deep" => HEAVY,
        _ => BALANCED,
    }
}
298
299fn choose_provider_for_tier<'a>(providers: &'a [&'a str], model_tier: Option<&str>) -> &'a str {
300 for preferred in provider_preferences_for_tier(model_tier) {
301 if let Some(found) = providers.iter().copied().find(|p| *p == *preferred) {
302 return found;
303 }
304 }
305 if let Some(found) = providers.iter().copied().find(|p| *p == "zhipuai") {
306 return found;
307 }
308 providers[0]
309}
310
/// Default model id for a provider at a given tier.
///
/// Several providers use one model regardless of tier; only anthropic,
/// openai, and google actually vary between fast/heavy/balanced.
fn default_model_for_provider(provider: &str, model_tier: Option<&str>) -> String {
    let tier = model_tier.unwrap_or("balanced");
    let fast = matches!(tier, "fast" | "quick");
    let heavy = matches!(tier, "heavy" | "deep");

    let model = match provider {
        "moonshotai" => "kimi-k2.5",
        "anthropic" => {
            if fast {
                "claude-haiku-4-5"
            } else {
                "claude-sonnet-4-20250514"
            }
        }
        "openai" => {
            if fast {
                "gpt-4o-mini"
            } else if heavy {
                "o3"
            } else {
                "gpt-4o"
            }
        }
        "google" => {
            if fast {
                "gemini-2.5-flash"
            } else {
                "gemini-2.5-pro"
            }
        }
        "zhipuai" => "glm-4.7",
        "openrouter" => "z-ai/glm-4.7",
        "novita" => "qwen/qwen3-coder-next",
        _ => "glm-4.7",
    };
    model.to_string()
}
345
/// Whether an agent type string (case/whitespace-insensitive) selects the
/// parallel swarm execution path.
fn is_swarm_agent(agent_type: &str) -> bool {
    let normalized = agent_type.trim().to_ascii_lowercase();
    normalized == "swarm" || normalized == "parallel" || normalized == "multi-agent"
}
352
353fn metadata_lookup<'a>(
354 metadata: &'a serde_json::Map<String, serde_json::Value>,
355 key: &str,
356) -> Option<&'a serde_json::Value> {
357 metadata
358 .get(key)
359 .or_else(|| {
360 metadata
361 .get("routing")
362 .and_then(|v| v.as_object())
363 .and_then(|obj| obj.get(key))
364 })
365 .or_else(|| {
366 metadata
367 .get("swarm")
368 .and_then(|v| v.as_object())
369 .and_then(|obj| obj.get(key))
370 })
371}
372
373fn metadata_str(
374 metadata: &serde_json::Map<String, serde_json::Value>,
375 keys: &[&str],
376) -> Option<String> {
377 for key in keys {
378 if let Some(value) = metadata_lookup(metadata, key).and_then(|v| v.as_str()) {
379 let trimmed = value.trim();
380 if !trimmed.is_empty() {
381 return Some(trimmed.to_string());
382 }
383 }
384 }
385 None
386}
387
388fn metadata_usize(
389 metadata: &serde_json::Map<String, serde_json::Value>,
390 keys: &[&str],
391) -> Option<usize> {
392 for key in keys {
393 if let Some(value) = metadata_lookup(metadata, key) {
394 if let Some(v) = value.as_u64() {
395 return usize::try_from(v).ok();
396 }
397 if let Some(v) = value.as_i64() {
398 if v >= 0 {
399 return usize::try_from(v as u64).ok();
400 }
401 }
402 if let Some(v) = value.as_str() {
403 if let Ok(parsed) = v.trim().parse::<usize>() {
404 return Some(parsed);
405 }
406 }
407 }
408 }
409 None
410}
411
412fn metadata_u64(
413 metadata: &serde_json::Map<String, serde_json::Value>,
414 keys: &[&str],
415) -> Option<u64> {
416 for key in keys {
417 if let Some(value) = metadata_lookup(metadata, key) {
418 if let Some(v) = value.as_u64() {
419 return Some(v);
420 }
421 if let Some(v) = value.as_i64() {
422 if v >= 0 {
423 return Some(v as u64);
424 }
425 }
426 if let Some(v) = value.as_str() {
427 if let Ok(parsed) = v.trim().parse::<u64>() {
428 return Some(parsed);
429 }
430 }
431 }
432 }
433 None
434}
435
436fn metadata_bool(
437 metadata: &serde_json::Map<String, serde_json::Value>,
438 keys: &[&str],
439) -> Option<bool> {
440 for key in keys {
441 if let Some(value) = metadata_lookup(metadata, key) {
442 if let Some(v) = value.as_bool() {
443 return Some(v);
444 }
445 if let Some(v) = value.as_str() {
446 match v.trim().to_ascii_lowercase().as_str() {
447 "1" | "true" | "yes" | "on" => return Some(true),
448 "0" | "false" | "no" | "off" => return Some(false),
449 _ => {}
450 }
451 }
452 }
453 }
454 None
455}
456
457fn parse_swarm_strategy(
458 metadata: &serde_json::Map<String, serde_json::Value>,
459) -> DecompositionStrategy {
460 match metadata_str(
461 metadata,
462 &[
463 "decomposition_strategy",
464 "swarm_strategy",
465 "strategy",
466 "swarm_decomposition",
467 ],
468 )
469 .as_deref()
470 .map(|s| s.to_ascii_lowercase())
471 .as_deref()
472 {
473 Some("none") | Some("single") => DecompositionStrategy::None,
474 Some("domain") | Some("by_domain") => DecompositionStrategy::ByDomain,
475 Some("data") | Some("by_data") => DecompositionStrategy::ByData,
476 Some("stage") | Some("by_stage") => DecompositionStrategy::ByStage,
477 _ => DecompositionStrategy::Automatic,
478 }
479}
480
481async fn resolve_swarm_model(
482 explicit_model: Option<String>,
483 model_tier: Option<&str>,
484) -> Option<String> {
485 if let Some(model) = explicit_model {
486 if !model.trim().is_empty() {
487 return Some(model);
488 }
489 }
490
491 let registry = ProviderRegistry::from_vault().await.ok()?;
492 let providers = registry.list();
493 if providers.is_empty() {
494 return None;
495 }
496 let provider = choose_provider_for_tier(providers.as_slice(), model_tier);
497 let model = default_model_for_provider(provider, model_tier);
498 Some(format!("{}/{}", provider, model))
499}
500
501fn format_swarm_event_for_output(event: &SwarmEvent) -> Option<String> {
502 match event {
503 SwarmEvent::Started {
504 task,
505 total_subtasks,
506 } => Some(format!(
507 "[swarm] started task={} planned_subtasks={}",
508 task, total_subtasks
509 )),
510 SwarmEvent::StageComplete {
511 stage,
512 completed,
513 failed,
514 } => Some(format!(
515 "[swarm] stage={} completed={} failed={}",
516 stage, completed, failed
517 )),
518 SwarmEvent::SubTaskUpdate { id, status, .. } => Some(format!(
519 "[swarm] subtask id={} status={}",
520 &id.chars().take(8).collect::<String>(),
521 format!("{status:?}").to_ascii_lowercase()
522 )),
523 SwarmEvent::AgentToolCall {
524 subtask_id,
525 tool_name,
526 } => Some(format!(
527 "[swarm] subtask id={} tool={}",
528 &subtask_id.chars().take(8).collect::<String>(),
529 tool_name
530 )),
531 SwarmEvent::AgentError { subtask_id, error } => Some(format!(
532 "[swarm] subtask id={} error={}",
533 &subtask_id.chars().take(8).collect::<String>(),
534 error
535 )),
536 SwarmEvent::Complete { success, stats } => Some(format!(
537 "[swarm] complete success={} subtasks={} speedup={:.2}",
538 success,
539 stats.subagents_completed + stats.subagents_failed,
540 stats.speedup_factor
541 )),
542 SwarmEvent::Error(err) => Some(format!("[swarm] error message={}", err)),
543 _ => None,
544 }
545}
546
/// Register (or re-register) this worker with the server, advertising its
/// capabilities, hostname, codebases, and every model available from the
/// configured providers.
///
/// A non-2xx registration response is only logged; `Err` is returned solely
/// for transport-level failures from `send()`.
async fn register_worker(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
) -> Result<()> {
    // Model discovery is best-effort: register with an empty list on failure.
    let models = match load_provider_models().await {
        Ok(m) => m,
        Err(e) => {
            tracing::warn!(
                "Failed to load provider models: {}, proceeding without model info",
                e
            );
            HashMap::new()
        }
    };

    let mut req = client.post(format!("{}/v1/opencode/workers/register", server));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    // Flatten {provider -> [models]} into a single array of descriptors,
    // attaching per-million-token costs only when known.
    let models_array: Vec<serde_json::Value> = models
        .iter()
        .flat_map(|(provider, model_infos)| {
            model_infos.iter().map(move |m| {
                let mut obj = serde_json::json!({
                    "id": format!("{}/{}", provider, m.id),
                    "name": &m.id,
                    "provider": provider,
                    "provider_id": provider,
                });
                if let Some(input_cost) = m.input_cost_per_million {
                    obj["input_cost_per_million"] = serde_json::json!(input_cost);
                }
                if let Some(output_cost) = m.output_cost_per_million {
                    obj["output_cost_per_million"] = serde_json::json!(output_cost);
                }
                obj
            })
        })
        .collect();

    tracing::info!(
        "Registering worker with {} models from {} providers",
        models_array.len(),
        models.len()
    );

    // HOSTNAME is the unix convention; COMPUTERNAME covers Windows.
    let hostname = std::env::var("HOSTNAME")
        .or_else(|_| std::env::var("COMPUTERNAME"))
        .unwrap_or_else(|_| "unknown".to_string());

    let res = req
        .json(&serde_json::json!({
            "worker_id": worker_id,
            "name": name,
            "capabilities": WORKER_CAPABILITIES,
            "hostname": hostname,
            "models": models_array,
            "codebases": codebases,
        }))
        .send()
        .await?;

    if res.status().is_success() {
        tracing::info!("Worker registered successfully");
    } else {
        // Non-fatal: the caller retries registration on every reconnect.
        tracing::warn!("Failed to register worker: {}", res.status());
    }

    Ok(())
}
627
628async fn load_provider_models() -> Result<HashMap<String, Vec<crate::provider::ModelInfo>>> {
632 let registry = match ProviderRegistry::from_vault().await {
634 Ok(r) if !r.list().is_empty() => {
635 tracing::info!("Loaded {} providers from Vault", r.list().len());
636 r
637 }
638 Ok(_) => {
639 tracing::warn!("Vault returned 0 providers, falling back to config/env vars");
640 fallback_registry().await?
641 }
642 Err(e) => {
643 tracing::warn!("Vault unreachable ({}), falling back to config/env vars", e);
644 fallback_registry().await?
645 }
646 };
647
648 let catalog = crate::provider::models::ModelCatalog::fetch().await.ok();
650
651 let catalog_alias = |pid: &str| -> String {
653 match pid {
654 "bedrock" => "amazon-bedrock".to_string(),
655 "novita" => "novita-ai".to_string(),
656 _ => pid.to_string(),
657 }
658 };
659
660 let mut models_by_provider: HashMap<String, Vec<crate::provider::ModelInfo>> = HashMap::new();
661
662 for provider_name in registry.list() {
663 if let Some(provider) = registry.get(provider_name) {
664 match provider.list_models().await {
665 Ok(models) => {
666 let enriched: Vec<crate::provider::ModelInfo> = models
667 .into_iter()
668 .map(|mut m| {
669 if m.input_cost_per_million.is_none()
671 || m.output_cost_per_million.is_none()
672 {
673 if let Some(ref cat) = catalog {
674 let cat_pid = catalog_alias(provider_name);
675 if let Some(prov_info) = cat.get_provider(&cat_pid) {
676 let model_info =
678 prov_info.models.get(&m.id).or_else(|| {
679 m.id.strip_prefix("us.").and_then(|stripped| {
680 prov_info.models.get(stripped)
681 })
682 });
683 if let Some(model_info) = model_info {
684 if let Some(ref cost) = model_info.cost {
685 if m.input_cost_per_million.is_none() {
686 m.input_cost_per_million = Some(cost.input);
687 }
688 if m.output_cost_per_million.is_none() {
689 m.output_cost_per_million = Some(cost.output);
690 }
691 }
692 }
693 }
694 }
695 }
696 m
697 })
698 .collect();
699 if !enriched.is_empty() {
700 tracing::debug!("Provider {}: {} models", provider_name, enriched.len());
701 models_by_provider.insert(provider_name.to_string(), enriched);
702 }
703 }
704 Err(e) => {
705 tracing::debug!("Failed to list models for {}: {}", provider_name, e);
706 }
707 }
708 }
709 }
710
711 if models_by_provider.is_empty() {
715 tracing::info!(
716 "No authenticated providers found, fetching models.dev catalog (all providers)"
717 );
718 if let Ok(cat) = crate::provider::models::ModelCatalog::fetch().await {
719 for (provider_id, provider_info) in cat.all_providers() {
722 let model_infos: Vec<crate::provider::ModelInfo> = provider_info
723 .models
724 .values()
725 .map(|m| cat.to_model_info(m, provider_id))
726 .collect();
727 if !model_infos.is_empty() {
728 tracing::debug!(
729 "Catalog provider {}: {} models",
730 provider_id,
731 model_infos.len()
732 );
733 models_by_provider.insert(provider_id.clone(), model_infos);
734 }
735 }
736 tracing::info!(
737 "Loaded {} providers with {} total models from catalog",
738 models_by_provider.len(),
739 models_by_provider.values().map(|v| v.len()).sum::<usize>()
740 );
741 }
742 }
743
744 Ok(models_by_provider)
745}
746
747async fn fallback_registry() -> Result<ProviderRegistry> {
749 let config = crate::config::Config::load().await.unwrap_or_default();
750 ProviderRegistry::from_config(&config).await
751}
752
753async fn fetch_pending_tasks(
754 client: &Client,
755 server: &str,
756 token: &Option<String>,
757 worker_id: &str,
758 processing: &Arc<Mutex<HashSet<String>>>,
759 auto_approve: &AutoApprove,
760) -> Result<()> {
761 tracing::info!("Checking for pending tasks...");
762
763 let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
764 if let Some(t) = token {
765 req = req.bearer_auth(t);
766 }
767
768 let res = req.send().await?;
769 if !res.status().is_success() {
770 return Ok(());
771 }
772
773 let data: serde_json::Value = res.json().await?;
774 let tasks = if let Some(arr) = data.as_array() {
776 arr.clone()
777 } else {
778 data["tasks"].as_array().cloned().unwrap_or_default()
779 };
780
781 tracing::info!("Found {} pending task(s)", tasks.len());
782
783 for task in tasks {
784 if let Some(id) = task["id"].as_str() {
785 let mut proc = processing.lock().await;
786 if !proc.contains(id) {
787 proc.insert(id.to_string());
788 drop(proc);
789
790 let task_id = id.to_string();
791 let client = client.clone();
792 let server = server.to_string();
793 let token = token.clone();
794 let worker_id = worker_id.to_string();
795 let auto_approve = *auto_approve;
796 let processing = processing.clone();
797
798 tokio::spawn(async move {
799 if let Err(e) =
800 handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await
801 {
802 tracing::error!("Task {} failed: {}", task_id, e);
803 }
804 processing.lock().await.remove(&task_id);
805 });
806 }
807 }
808 }
809
810 Ok(())
811}
812
/// Hold an SSE connection to the server's task stream, dispatching each task
/// event to `spawn_task_handler`, with a 30s fallback poll for tasks the
/// stream may have missed.
///
/// Returns `Ok(())` when the server closes the stream and `Err` on connect
/// or transport failures; the caller reconnects in both cases.
#[allow(clippy::too_many_arguments)]
async fn connect_stream(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
) -> Result<()> {
    let url = format!(
        "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
        server,
        urlencoding::encode(name),
        urlencoding::encode(worker_id)
    );

    let mut req = client
        .get(&url)
        .header("Accept", "text/event-stream")
        .header("X-Worker-ID", worker_id)
        .header("X-Agent-Name", name)
        .header("X-Codebases", codebases.join(","));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        anyhow::bail!("Failed to connect: {}", res.status());
    }

    tracing::info!("Connected to A2A server");

    let mut stream = res.bytes_stream();
    // Accumulates partial SSE data; events are terminated by a blank line.
    let mut buffer = String::new();
    let mut poll_interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
    // `interval` fires immediately on the first tick; consume it so the
    // first fallback poll happens only after a full 30s period.
    poll_interval.tick().await;
    loop {
        tokio::select! {
            chunk = stream.next() => {
                match chunk {
                    Some(Ok(chunk)) => {
                        buffer.push_str(&String::from_utf8_lossy(&chunk));

                        // Drain every complete ("\n\n"-terminated) event
                        // currently buffered.
                        while let Some(pos) = buffer.find("\n\n") {
                            let event_str = buffer[..pos].to_string();
                            buffer = buffer[pos + 2..].to_string();

                            if let Some(data_line) = event_str.lines().find(|l| l.starts_with("data:")) {
                                let data = data_line.trim_start_matches("data:").trim();
                                // "[DONE]" and empty payloads are sentinels.
                                if data == "[DONE]" || data.is_empty() {
                                    continue;
                                }

                                // Non-JSON payloads are silently skipped.
                                if let Ok(task) = serde_json::from_str::<serde_json::Value>(data) {
                                    spawn_task_handler(
                                        &task, client, server, token, worker_id,
                                        processing, auto_approve,
                                    ).await;
                                }
                            }
                        }
                    }
                    Some(Err(e)) => {
                        return Err(e.into());
                    }
                    None => {
                        // Server closed the stream cleanly; caller reconnects.
                        return Ok(());
                    }
                }
            }
            _ = poll_interval.tick() => {
                // Safety net for missed stream events.
                if let Err(e) = poll_pending_tasks(
                    client, server, token, worker_id, processing, auto_approve,
                ).await {
                    tracing::warn!("Periodic task poll failed: {}", e);
                }
            }
        }
    }
}
901
902async fn spawn_task_handler(
903 task: &serde_json::Value,
904 client: &Client,
905 server: &str,
906 token: &Option<String>,
907 worker_id: &str,
908 processing: &Arc<Mutex<HashSet<String>>>,
909 auto_approve: &AutoApprove,
910) {
911 if let Some(id) = task
912 .get("task")
913 .and_then(|t| t["id"].as_str())
914 .or_else(|| task["id"].as_str())
915 {
916 let mut proc = processing.lock().await;
917 if !proc.contains(id) {
918 proc.insert(id.to_string());
919 drop(proc);
920
921 let task_id = id.to_string();
922 let task = task.clone();
923 let client = client.clone();
924 let server = server.to_string();
925 let token = token.clone();
926 let worker_id = worker_id.to_string();
927 let auto_approve = *auto_approve;
928 let processing_clone = processing.clone();
929
930 tokio::spawn(async move {
931 if let Err(e) =
932 handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await
933 {
934 tracing::error!("Task {} failed: {}", task_id, e);
935 }
936 processing_clone.lock().await.remove(&task_id);
937 });
938 }
939 }
940}
941
942async fn poll_pending_tasks(
943 client: &Client,
944 server: &str,
945 token: &Option<String>,
946 worker_id: &str,
947 processing: &Arc<Mutex<HashSet<String>>>,
948 auto_approve: &AutoApprove,
949) -> Result<()> {
950 let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
951 if let Some(t) = token {
952 req = req.bearer_auth(t);
953 }
954
955 let res = req.send().await?;
956 if !res.status().is_success() {
957 return Ok(());
958 }
959
960 let data: serde_json::Value = res.json().await?;
961 let tasks = if let Some(arr) = data.as_array() {
962 arr.clone()
963 } else {
964 data["tasks"].as_array().cloned().unwrap_or_default()
965 };
966
967 if !tasks.is_empty() {
968 tracing::debug!("Poll found {} pending task(s)", tasks.len());
969 }
970
971 for task in &tasks {
972 spawn_task_handler(
973 task,
974 client,
975 server,
976 token,
977 worker_id,
978 processing,
979 auto_approve,
980 )
981 .await;
982 }
983
984 Ok(())
985}
986
/// Full lifecycle for one task: claim it on the server, execute it in a
/// (possibly resumed) session — or via the swarm path for swarm agent
/// types — while streaming output back, then release it with a final status.
///
/// Returns `Ok(())` even when the task itself fails (failure is reported via
/// the release call); `Err` only signals transport-level problems.
async fn handle_task(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    task: &serde_json::Value,
    auto_approve: AutoApprove,
) -> Result<()> {
    let task_id = task_str(task, "id").ok_or_else(|| anyhow::anyhow!("No task ID"))?;
    let title = task_str(task, "title").unwrap_or("Untitled");

    tracing::info!("Handling task: {} ({})", title, task_id);

    // --- Claim the task; a 409 means another worker got there first. ---
    let mut req = client
        .post(format!("{}/v1/worker/tasks/claim", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req
        .json(&serde_json::json!({ "task_id": task_id }))
        .send()
        .await?;

    if !res.status().is_success() {
        let status = res.status();
        let text = res.text().await?;
        if status == reqwest::StatusCode::CONFLICT {
            tracing::debug!(task_id, "Task already claimed by another worker, skipping");
        } else {
            tracing::warn!(task_id, %status, "Failed to claim task: {}", text);
        }
        return Ok(());
    }

    tracing::info!("Claimed task: {}", task_id);

    // --- Extract routing hints from task metadata. ---
    let metadata = task_metadata(task);
    // Session to resume, if the caller supplied a non-blank id.
    let resume_session_id = metadata
        .get("resume_session_id")
        .and_then(|v| v.as_str())
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty());
    let complexity_hint = metadata_str(&metadata, &["complexity"]);
    // Explicit tier wins; otherwise derive one from the complexity hint
    // (quick -> fast, deep -> heavy, anything else -> balanced).
    let model_tier = metadata_str(&metadata, &["model_tier", "tier"])
        .map(|s| s.to_ascii_lowercase())
        .or_else(|| {
            complexity_hint.as_ref().map(|complexity| {
                match complexity.to_ascii_lowercase().as_str() {
                    "quick" => "fast".to_string(),
                    "deep" => "heavy".to_string(),
                    _ => "balanced".to_string(),
                }
            })
        });
    let worker_personality = metadata_str(
        &metadata,
        &["worker_personality", "personality", "agent_personality"],
    );
    let target_agent_name = metadata_str(&metadata, &["target_agent_name", "agent_name"]);
    // A model may be given at task level or in metadata, under "model_ref"
    // or "model"; normalize provider:model to provider/model form.
    let raw_model = task_str(task, "model_ref")
        .or_else(|| metadata_lookup(&metadata, "model_ref").and_then(|v| v.as_str()))
        .or_else(|| task_str(task, "model"))
        .or_else(|| metadata_lookup(&metadata, "model").and_then(|v| v.as_str()));
    let selected_model = raw_model.map(model_ref_to_provider_model);

    // --- Resume the requested session, or start fresh on any failure. ---
    let mut session = if let Some(ref sid) = resume_session_id {
        match Session::load(sid).await {
            Ok(existing) => {
                tracing::info!("Resuming session {} for task {}", sid, task_id);
                existing
            }
            Err(e) => {
                tracing::warn!(
                    "Could not load session {} for task {} ({}), starting a new session",
                    sid,
                    task_id,
                    e
                );
                Session::new().await?
            }
        }
    } else {
        Session::new().await?
    };

    let agent_type = task_str(task, "agent_type")
        .or_else(|| task_str(task, "agent"))
        .unwrap_or("build");
    session.agent = agent_type.to_string();

    if let Some(model) = selected_model.clone() {
        session.metadata.model = Some(model);
    }

    // The prompt falls back to the description, then the title.
    let prompt = task_str(task, "prompt")
        .or_else(|| task_str(task, "description"))
        .unwrap_or(title);

    tracing::info!("Executing prompt: {}", prompt);

    // --- Callback that streams incremental output back to the server. ---
    let stream_client = client.clone();
    let stream_server = server.to_string();
    let stream_token = token.clone();
    let stream_worker_id = worker_id.to_string();
    let stream_task_id = task_id.to_string();

    // Fire-and-forget: each output chunk is posted on its own spawned task
    // and send errors are ignored, so slow uploads never block execution.
    let output_callback = move |output: String| {
        let c = stream_client.clone();
        let s = stream_server.clone();
        let t = stream_token.clone();
        let w = stream_worker_id.clone();
        let tid = stream_task_id.clone();
        tokio::spawn(async move {
            let mut req = c
                .post(format!("{}/v1/opencode/tasks/{}/output", s, tid))
                .header("X-Worker-ID", &w);
            if let Some(tok) = &t {
                req = req.bearer_auth(tok);
            }
            let _ = req
                .json(&serde_json::json!({
                    "worker_id": w,
                    "output": output,
                }))
                .send()
                .await;
        });
    };

    // --- Execute: swarm agent types take the parallel swarm path. ---
    let (status, result, error, session_id) = if is_swarm_agent(agent_type) {
        match execute_swarm_with_policy(
            &mut session,
            prompt,
            model_tier.as_deref(),
            selected_model,
            &metadata,
            complexity_hint.as_deref(),
            worker_personality.as_deref(),
            target_agent_name.as_deref(),
            Some(&output_callback),
        )
        .await
        {
            Ok((session_result, true)) => {
                tracing::info!("Swarm task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Ok((session_result, false)) => {
                tracing::warn!("Swarm task completed with failures: {}", task_id);
                (
                    "failed",
                    Some(session_result.text),
                    Some("Swarm execution completed with failures".to_string()),
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Swarm task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    } else {
        match execute_session_with_policy(
            &mut session,
            prompt,
            auto_approve,
            model_tier.as_deref(),
            Some(&output_callback),
        )
        .await
        {
            Ok(session_result) => {
                tracing::info!("Task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    };

    // --- Release the task with the final status/result on the server. ---
    let mut req = client
        .post(format!("{}/v1/worker/tasks/release", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    req.json(&serde_json::json!({
        "task_id": task_id,
        "status": status,
        "result": result,
        "error": error,
        "session_id": session_id.unwrap_or_else(|| session.id.clone()),
    }))
    .send()
    .await?;

    tracing::info!("Task released: {} with status: {}", task_id, status);
    Ok(())
}
1206
/// Executes `prompt` through the multi-agent swarm pipeline and records the
/// outcome (user prompt + final swarm text) in `session`.
///
/// Swarm tuning is read from task `metadata` with clamped defaults:
/// decomposition strategy, max subagents (1..=100, default 10), max steps per
/// subagent (1..=200, default 50), timeout (30..=3600s, default 600), and
/// parallelism (default on). When `output_callback` is supplied, the executor
/// runs on a spawned task so progress events can be streamed to the callback
/// while it executes.
///
/// Returns the session result together with the swarm's success flag
/// (`false` means the swarm finished but reported failures).
///
/// # Errors
/// Propagates title generation, executor join/execution, and session save
/// failures.
async fn execute_swarm_with_policy<F>(
    session: &mut Session,
    prompt: &str,
    model_tier: Option<&str>,
    explicit_model: Option<String>,
    metadata: &serde_json::Map<String, serde_json::Value>,
    complexity_hint: Option<&str>,
    worker_personality: Option<&str>,
    target_agent_name: Option<&str>,
    output_callback: Option<&F>,
) -> Result<(crate::session::SessionResult, bool)>
where
    F: Fn(String),
{
    use crate::provider::{ContentPart, Message, Role};

    // Record the incoming prompt before any routing/configuration work.
    session.add_message(Message {
        role: Role::User,
        content: vec![ContentPart::Text {
            text: prompt.to_string(),
        }],
    });

    // Lazily derive a session title on first use.
    if session.title.is_none() {
        session.generate_title().await?;
    }

    // Clamp metadata-supplied knobs into safe operating ranges; each knob
    // accepts several alias keys for caller convenience.
    let strategy = parse_swarm_strategy(metadata);
    let max_subagents = metadata_usize(
        metadata,
        &["swarm_max_subagents", "max_subagents", "subagents"],
    )
    .unwrap_or(10)
    .clamp(1, 100);
    let max_steps_per_subagent = metadata_usize(
        metadata,
        &[
            "swarm_max_steps_per_subagent",
            "max_steps_per_subagent",
            "max_steps",
        ],
    )
    .unwrap_or(50)
    .clamp(1, 200);
    let timeout_secs = metadata_u64(metadata, &["swarm_timeout_secs", "timeout_secs", "timeout"])
        .unwrap_or(600)
        .clamp(30, 3600);
    let parallel_enabled =
        metadata_bool(metadata, &["swarm_parallel_enabled", "parallel_enabled"]).unwrap_or(true);

    // Explicit model wins over tier-based resolution; record the choice on
    // the session so it is visible to later steps/persistence.
    let model = resolve_swarm_model(explicit_model, model_tier).await;
    if let Some(ref selected_model) = model {
        session.metadata.model = Some(selected_model.clone());
    }

    // Surface the routing decision and effective config to the caller's
    // output stream (purely informational).
    if let Some(cb) = output_callback {
        cb(format!(
            "[swarm] routing complexity={} tier={} personality={} target_agent={}",
            complexity_hint.unwrap_or("standard"),
            model_tier.unwrap_or("balanced"),
            worker_personality.unwrap_or("auto"),
            target_agent_name.unwrap_or("auto")
        ));
        cb(format!(
            "[swarm] config strategy={:?} max_subagents={} max_steps={} timeout={}s tier={}",
            strategy,
            max_subagents,
            max_steps_per_subagent,
            timeout_secs,
            model_tier.unwrap_or("balanced")
        ));
    }

    let swarm_config = SwarmConfig {
        max_subagents,
        max_steps_per_subagent,
        subagent_timeout_secs: timeout_secs,
        parallel_enabled,
        model,
        working_dir: session
            .metadata
            .directory
            .as_ref()
            .map(|p| p.to_string_lossy().to_string()),
        ..Default::default()
    };

    let swarm_result = if output_callback.is_some() {
        // Streamed mode: run the executor on its own task and pump its event
        // channel into the callback until the join handle resolves.
        let (event_tx, mut event_rx) = mpsc::channel(256);
        let executor = SwarmExecutor::new(swarm_config).with_event_tx(event_tx);
        let prompt_owned = prompt.to_string();
        let mut exec_handle =
            tokio::spawn(async move { executor.execute(&prompt_owned, strategy).await });

        let mut final_result: Option<crate::swarm::SwarmResult> = None;

        while final_result.is_none() {
            tokio::select! {
                maybe_event = event_rx.recv() => {
                    if let Some(event) = maybe_event {
                        if let Some(cb) = output_callback {
                            if let Some(line) = format_swarm_event_for_output(&event) {
                                cb(line);
                            }
                        }
                    }
                }
                join_result = &mut exec_handle => {
                    let joined = join_result.map_err(|e| anyhow::anyhow!("Swarm join failure: {}", e))?;
                    final_result = Some(joined?);
                }
            }
        }

        // Drain events that were already queued when the executor finished,
        // so no progress lines are silently dropped.
        while let Ok(event) = event_rx.try_recv() {
            if let Some(cb) = output_callback {
                if let Some(line) = format_swarm_event_for_output(&event) {
                    cb(line);
                }
            }
        }

        final_result.ok_or_else(|| anyhow::anyhow!("Swarm execution returned no result"))?
    } else {
        // No callback: run inline on this task.
        SwarmExecutor::new(swarm_config)
            .execute(prompt, strategy)
            .await?
    };

    // Substitute a human-readable placeholder when the swarm produced no text,
    // so the persisted assistant message is never empty.
    let final_text = if swarm_result.result.trim().is_empty() {
        if swarm_result.success {
            "Swarm completed without textual output.".to_string()
        } else {
            "Swarm finished with failures and no textual output.".to_string()
        }
    } else {
        swarm_result.result.clone()
    };

    session.add_message(Message {
        role: Role::Assistant,
        content: vec![ContentPart::Text {
            text: final_text.clone(),
        }],
    });
    session.save().await?;

    Ok((
        crate::session::SessionResult {
            text: final_text,
            session_id: session.id.clone(),
        },
        swarm_result.success,
    ))
}
1362
1363async fn execute_session_with_policy<F>(
1366 session: &mut Session,
1367 prompt: &str,
1368 auto_approve: AutoApprove,
1369 model_tier: Option<&str>,
1370 output_callback: Option<&F>,
1371) -> Result<crate::session::SessionResult>
1372where
1373 F: Fn(String),
1374{
1375 use crate::provider::{
1376 CompletionRequest, ContentPart, Message, ProviderRegistry, Role, parse_model_string,
1377 };
1378 use std::sync::Arc;
1379
1380 let registry = ProviderRegistry::from_vault().await?;
1382 let providers = registry.list();
1383 tracing::info!("Available providers: {:?}", providers);
1384
1385 if providers.is_empty() {
1386 anyhow::bail!("No providers available. Configure API keys in HashiCorp Vault.");
1387 }
1388
1389 let (provider_name, model_id) = if let Some(ref model_str) = session.metadata.model {
1391 let (prov, model) = parse_model_string(model_str);
1392 if prov.is_some() {
1393 (prov.map(|s| s.to_string()), model.to_string())
1394 } else if providers.contains(&model) {
1395 (Some(model.to_string()), String::new())
1396 } else {
1397 (None, model.to_string())
1398 }
1399 } else {
1400 (None, String::new())
1401 };
1402
1403 let provider_slice = providers.as_slice();
1404 let provider_requested_but_unavailable = provider_name
1405 .as_deref()
1406 .map(|p| !providers.contains(&p))
1407 .unwrap_or(false);
1408
1409 let selected_provider = provider_name
1411 .as_deref()
1412 .filter(|p| providers.contains(p))
1413 .unwrap_or_else(|| choose_provider_for_tier(provider_slice, model_tier));
1414
1415 let provider = registry
1416 .get(selected_provider)
1417 .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;
1418
1419 session.add_message(Message {
1421 role: Role::User,
1422 content: vec![ContentPart::Text {
1423 text: prompt.to_string(),
1424 }],
1425 });
1426
1427 if session.title.is_none() {
1429 session.generate_title().await?;
1430 }
1431
1432 let model = if !model_id.is_empty() && !provider_requested_but_unavailable {
1435 model_id
1436 } else {
1437 default_model_for_provider(selected_provider, model_tier)
1438 };
1439
1440 let tool_registry =
1442 create_filtered_registry(Arc::clone(&provider), model.clone(), auto_approve);
1443 let tool_definitions = tool_registry.definitions();
1444
1445 let temperature = if model.starts_with("kimi-k2") {
1446 Some(1.0)
1447 } else {
1448 Some(0.7)
1449 };
1450
1451 tracing::info!(
1452 "Using model: {} via provider: {} (tier: {:?})",
1453 model,
1454 selected_provider,
1455 model_tier
1456 );
1457 tracing::info!(
1458 "Available tools: {} (auto_approve: {:?})",
1459 tool_definitions.len(),
1460 auto_approve
1461 );
1462
1463 let cwd = std::env::var("PWD")
1465 .map(std::path::PathBuf::from)
1466 .unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
1467 let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);
1468
1469 let mut final_output = String::new();
1470 let max_steps = 50;
1471
1472 for step in 1..=max_steps {
1473 tracing::info!(step = step, "Agent step starting");
1474
1475 let mut messages = vec![Message {
1477 role: Role::System,
1478 content: vec![ContentPart::Text {
1479 text: system_prompt.clone(),
1480 }],
1481 }];
1482 messages.extend(session.messages.clone());
1483
1484 let request = CompletionRequest {
1485 messages,
1486 tools: tool_definitions.clone(),
1487 model: model.clone(),
1488 temperature,
1489 top_p: None,
1490 max_tokens: Some(8192),
1491 stop: Vec::new(),
1492 };
1493
1494 let response = provider.complete(request).await?;
1495
1496 crate::telemetry::TOKEN_USAGE.record_model_usage(
1497 &model,
1498 response.usage.prompt_tokens as u64,
1499 response.usage.completion_tokens as u64,
1500 );
1501
1502 let tool_calls: Vec<(String, String, serde_json::Value)> = response
1504 .message
1505 .content
1506 .iter()
1507 .filter_map(|part| {
1508 if let ContentPart::ToolCall {
1509 id,
1510 name,
1511 arguments,
1512 } = part
1513 {
1514 let args: serde_json::Value =
1515 serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
1516 Some((id.clone(), name.clone(), args))
1517 } else {
1518 None
1519 }
1520 })
1521 .collect();
1522
1523 for part in &response.message.content {
1525 if let ContentPart::Text { text } = part {
1526 if !text.is_empty() {
1527 final_output.push_str(text);
1528 final_output.push('\n');
1529 if let Some(cb) = output_callback {
1530 cb(text.clone());
1531 }
1532 }
1533 }
1534 }
1535
1536 if tool_calls.is_empty() {
1538 session.add_message(response.message.clone());
1539 break;
1540 }
1541
1542 session.add_message(response.message.clone());
1543
1544 tracing::info!(
1545 step = step,
1546 num_tools = tool_calls.len(),
1547 "Executing tool calls"
1548 );
1549
1550 for (tool_id, tool_name, tool_input) in tool_calls {
1552 tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");
1553
1554 if let Some(cb) = output_callback {
1556 cb(format!("[tool:start:{}]", tool_name));
1557 }
1558
1559 if !is_tool_allowed(&tool_name, auto_approve) {
1561 let msg = format!(
1562 "Tool '{}' requires approval but auto-approve policy is {:?}",
1563 tool_name, auto_approve
1564 );
1565 tracing::warn!(tool = %tool_name, "Tool blocked by auto-approve policy");
1566 session.add_message(Message {
1567 role: Role::Tool,
1568 content: vec![ContentPart::ToolResult {
1569 tool_call_id: tool_id,
1570 content: msg,
1571 }],
1572 });
1573 continue;
1574 }
1575
1576 let content = if let Some(tool) = tool_registry.get(&tool_name) {
1577 let exec_result: Result<crate::tool::ToolResult> =
1578 tool.execute(tool_input.clone()).await;
1579 match exec_result {
1580 Ok(result) => {
1581 tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
1582 if let Some(cb) = output_callback {
1583 let status = if result.success { "ok" } else { "err" };
1584 cb(format!(
1585 "[tool:{}:{}] {}",
1586 tool_name,
1587 status,
1588 &result.output[..result.output.len().min(500)]
1589 ));
1590 }
1591 result.output
1592 }
1593 Err(e) => {
1594 tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
1595 if let Some(cb) = output_callback {
1596 cb(format!("[tool:{}:err] {}", tool_name, e));
1597 }
1598 format!("Error: {}", e)
1599 }
1600 }
1601 } else {
1602 tracing::warn!(tool = %tool_name, "Tool not found");
1603 format!("Error: Unknown tool '{}'", tool_name)
1604 };
1605
1606 session.add_message(Message {
1607 role: Role::Tool,
1608 content: vec![ContentPart::ToolResult {
1609 tool_call_id: tool_id,
1610 content,
1611 }],
1612 });
1613 }
1614 }
1615
1616 session.save().await?;
1617
1618 Ok(crate::session::SessionResult {
1619 text: final_output.trim().to_string(),
1620 session_id: session.id.clone(),
1621 })
1622}
1623
1624fn is_tool_allowed(tool_name: &str, auto_approve: AutoApprove) -> bool {
1626 match auto_approve {
1627 AutoApprove::All => true,
1628 AutoApprove::Safe | AutoApprove::None => is_safe_tool(tool_name),
1629 }
1630}
1631
/// Returns `true` for read-only / low-risk tools that never mutate the
/// workspace (reads, searches, lookups), `false` for everything else.
fn is_safe_tool(tool_name: &str) -> bool {
    matches!(
        tool_name,
        "read"
            | "list"
            | "glob"
            | "grep"
            | "codesearch"
            | "lsp"
            | "webfetch"
            | "websearch"
            | "todo_read"
            | "skill"
    )
}
1648
/// Builds the tool registry available to a worker session.
///
/// Read-only tools are always registered; mutating / high-risk tools
/// (write, edit, bash, patch, ...) are added only when the policy is
/// `AutoApprove::All`. `provider` and `model` are consumed solely by
/// `RalphTool`, which is constructed with its own provider handle.
fn create_filtered_registry(
    provider: Arc<dyn crate::provider::Provider>,
    model: String,
    auto_approve: AutoApprove,
) -> crate::tool::ToolRegistry {
    use crate::tool::*;

    let mut registry = ToolRegistry::new();

    // Read-only tools: available under every policy. This set should stay
    // in sync with the names accepted by `is_safe_tool`.
    registry.register(Arc::new(file::ReadTool::new()));
    registry.register(Arc::new(file::ListTool::new()));
    registry.register(Arc::new(file::GlobTool::new()));
    registry.register(Arc::new(search::GrepTool::new()));
    registry.register(Arc::new(lsp::LspTool::new()));
    registry.register(Arc::new(webfetch::WebFetchTool::new()));
    registry.register(Arc::new(websearch::WebSearchTool::new()));
    registry.register(Arc::new(codesearch::CodeSearchTool::new()));
    registry.register(Arc::new(todo::TodoReadTool::new()));
    registry.register(Arc::new(skill::SkillTool::new()));

    // Mutating tools: registered only under full auto-approval, so a
    // Safe/None worker cannot even see them in its tool definitions.
    if matches!(auto_approve, AutoApprove::All) {
        registry.register(Arc::new(file::WriteTool::new()));
        registry.register(Arc::new(edit::EditTool::new()));
        registry.register(Arc::new(bash::BashTool::new()));
        registry.register(Arc::new(multiedit::MultiEditTool::new()));
        registry.register(Arc::new(patch::ApplyPatchTool::new()));
        registry.register(Arc::new(todo::TodoWriteTool::new()));
        registry.register(Arc::new(task::TaskTool::new()));
        registry.register(Arc::new(plan::PlanEnterTool::new()));
        registry.register(Arc::new(plan::PlanExitTool::new()));
        registry.register(Arc::new(rlm::RlmTool::new()));
        registry.register(Arc::new(ralph::RalphTool::with_provider(provider, model)));
        registry.register(Arc::new(prd::PrdTool::new()));
        registry.register(Arc::new(confirm_edit::ConfirmEditTool::new()));
        registry.register(Arc::new(confirm_multiedit::ConfirmMultiEditTool::new()));
        registry.register(Arc::new(undo::UndoTool));
    }

    // Always registered last; presumably a catch-all for malformed or
    // unknown tool invocations — confirm in `tool::invalid`.
    registry.register(Arc::new(invalid::InvalidTool::new()));

    registry
}
1694
/// Spawns the background heartbeat task for this worker and returns its
/// join handle.
///
/// Every 30s the task recomputes worker status from the `processing` set,
/// POSTs a heartbeat to the server, and (when enabled and not in cooldown)
/// enriches the payload with a cognition snapshot fetched from the local
/// cognition service. If the server rejects the enriched payload with a 4xx,
/// the task retries once without the cognition payload and, on success,
/// pauses cognition enrichment for a 5-minute cooldown. The task never
/// exits on failure: after 3 consecutive failures it logs an error and
/// resets the counter.
fn start_heartbeat(
    client: Client,
    server: String,
    token: Option<String>,
    heartbeat_state: HeartbeatState,
    processing: Arc<Mutex<HashSet<String>>>,
    cognition_config: CognitionHeartbeatConfig,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        let mut consecutive_failures = 0u32;
        const MAX_FAILURES: u32 = 3;
        const HEARTBEAT_INTERVAL_SECS: u64 = 30;
        const COGNITION_RETRY_COOLDOWN_SECS: u64 = 300;
        // While set and in the future, cognition enrichment is suppressed.
        let mut cognition_payload_disabled_until: Option<Instant> = None;

        let mut interval =
            tokio::time::interval(tokio::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
        // Skip (rather than burst) ticks missed while a send was slow.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        loop {
            interval.tick().await;

            // Derive status from the live task set and mirror it into the
            // shared heartbeat state. The lock is released immediately.
            let active_count = processing.lock().await.len();
            heartbeat_state.set_task_count(active_count).await;

            let status = if active_count > 0 {
                WorkerStatus::Processing
            } else {
                WorkerStatus::Idle
            };
            heartbeat_state.set_status(status).await;

            let url = format!(
                "{}/v1/opencode/workers/{}/heartbeat",
                server, heartbeat_state.worker_id
            );
            let mut req = client.post(&url);

            if let Some(ref t) = token {
                req = req.bearer_auth(t);
            }

            let status_str = heartbeat_state.status.lock().await.as_str().to_string();
            // `base_payload` is kept separately so the retry path can resend
            // without the cognition field.
            let base_payload = serde_json::json!({
                "worker_id": &heartbeat_state.worker_id,
                "agent_name": &heartbeat_state.agent_name,
                "status": status_str,
                "active_task_count": active_count,
            });
            let mut payload = base_payload.clone();
            let mut included_cognition_payload = false;
            let cognition_payload_allowed = cognition_payload_disabled_until
                .map(|until| Instant::now() >= until)
                .unwrap_or(true);

            // Best-effort enrichment: any fetch failure simply leaves the
            // base payload untouched.
            if cognition_config.enabled
                && cognition_payload_allowed
                && let Some(cognition_payload) =
                    fetch_cognition_heartbeat_payload(&client, &cognition_config).await
                && let Some(obj) = payload.as_object_mut()
            {
                obj.insert("cognition".to_string(), cognition_payload);
                included_cognition_payload = true;
            }

            match req.json(&payload).send().await {
                Ok(res) => {
                    if res.status().is_success() {
                        consecutive_failures = 0;
                        tracing::debug!(
                            worker_id = %heartbeat_state.worker_id,
                            status = status_str,
                            active_tasks = active_count,
                            "Heartbeat sent successfully"
                        );
                    } else if included_cognition_payload && res.status().is_client_error() {
                        // 4xx with cognition attached: assume the server
                        // rejected the cognition schema and retry bare.
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            "Heartbeat cognition payload rejected, retrying without cognition payload"
                        );

                        let mut retry_req = client.post(&url);
                        if let Some(ref t) = token {
                            retry_req = retry_req.bearer_auth(t);
                        }

                        match retry_req.json(&base_payload).send().await {
                            Ok(retry_res) if retry_res.status().is_success() => {
                                // Bare payload worked: cognition was the
                                // problem, so pause enrichment for a while.
                                cognition_payload_disabled_until = Some(
                                    Instant::now()
                                        + Duration::from_secs(COGNITION_RETRY_COOLDOWN_SECS),
                                );
                                consecutive_failures = 0;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    retry_after_secs = COGNITION_RETRY_COOLDOWN_SECS,
                                    "Paused cognition heartbeat payload after schema rejection"
                                );
                            }
                            Ok(retry_res) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    status = %retry_res.status(),
                                    failures = consecutive_failures,
                                    "Heartbeat failed even after retry without cognition payload"
                                );
                            }
                            Err(e) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    error = %e,
                                    failures = consecutive_failures,
                                    "Heartbeat retry without cognition payload failed"
                                );
                            }
                        }
                    } else {
                        consecutive_failures += 1;
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            failures = consecutive_failures,
                            "Heartbeat failed"
                        );
                    }
                }
                Err(e) => {
                    consecutive_failures += 1;
                    tracing::warn!(
                        worker_id = %heartbeat_state.worker_id,
                        error = %e,
                        failures = consecutive_failures,
                        "Heartbeat request failed"
                    );
                }
            }

            // Deliberately non-fatal: log loudly, reset, and keep beating so
            // the worker can reconnect when the server returns.
            if consecutive_failures >= MAX_FAILURES {
                tracing::error!(
                    worker_id = %heartbeat_state.worker_id,
                    failures = consecutive_failures,
                    "Heartbeat failed {} consecutive times - worker will continue running and attempt reconnection via SSE loop",
                    MAX_FAILURES
                );
                consecutive_failures = 0;
            }
        }
    })
}
1854
/// Fetches a cognition status snapshot from the local cognition service and
/// shapes it into a JSON object suitable for heartbeat enrichment.
///
/// Best-effort by design: any timeout, transport error, non-2xx status, or
/// decode failure yields `None` rather than an error. When
/// `include_thought_summary` is set, a second request fetches the latest
/// snapshot and, if successful, adds a timestamp, a char-trimmed thought
/// summary, and (when present in the snapshot metadata) the model/source
/// that produced it.
async fn fetch_cognition_heartbeat_payload(
    client: &Client,
    config: &CognitionHeartbeatConfig,
) -> Option<serde_json::Value> {
    let status_url = format!("{}/v1/cognition/status", config.source_base_url);
    // Outer `.ok()?` handles the timeout; inner handles the request error.
    let status_res = tokio::time::timeout(
        Duration::from_millis(config.request_timeout_ms),
        client.get(status_url).send(),
    )
    .await
    .ok()?
    .ok()?;

    if !status_res.status().is_success() {
        return None;
    }

    let status: CognitionStatusSnapshot = status_res.json().await.ok()?;
    let mut payload = serde_json::json!({
        "running": status.running,
        "last_tick_at": status.last_tick_at,
        "active_persona_count": status.active_persona_count,
        "events_buffered": status.events_buffered,
        "snapshots_buffered": status.snapshots_buffered,
        "loop_interval_ms": status.loop_interval_ms,
    });

    if config.include_thought_summary {
        let snapshot_url = format!("{}/v1/cognition/snapshots/latest", config.source_base_url);
        // Snapshot fetch failures are tolerated: the status payload is still
        // returned without the thought fields.
        let snapshot_res = tokio::time::timeout(
            Duration::from_millis(config.request_timeout_ms),
            client.get(snapshot_url).send(),
        )
        .await
        .ok()
        .and_then(Result::ok);

        if let Some(snapshot_res) = snapshot_res
            && snapshot_res.status().is_success()
            && let Ok(snapshot) = snapshot_res.json::<CognitionLatestSnapshot>().await
            && let Some(obj) = payload.as_object_mut()
        {
            obj.insert(
                "latest_snapshot_at".to_string(),
                serde_json::Value::String(snapshot.generated_at),
            );
            // Keep the thought short so heartbeats stay small.
            obj.insert(
                "latest_thought".to_string(),
                serde_json::Value::String(trim_for_heartbeat(
                    &snapshot.summary,
                    config.summary_max_chars,
                )),
            );
            if let Some(model) = snapshot
                .metadata
                .get("model")
                .and_then(serde_json::Value::as_str)
            {
                obj.insert(
                    "latest_thought_model".to_string(),
                    serde_json::Value::String(model.to_string()),
                );
            }
            if let Some(source) = snapshot
                .metadata
                .get("source")
                .and_then(serde_json::Value::as_str)
            {
                obj.insert(
                    "latest_thought_source".to_string(),
                    serde_json::Value::String(source.to_string()),
                );
            }
        }
    }

    Some(payload)
}
1933
/// Trims surrounding whitespace and caps `input` at `max_chars` characters,
/// appending `"..."` when it was cut.
///
/// Single-pass: `char_indices().nth(max_chars)` probes for the byte offset
/// of the (max_chars + 1)-th character instead of first counting every char
/// (`chars().count()`) and then re-walking with `take`, and the truncated
/// branch builds the result in one allocation (the old trailing
/// `trim().to_string()` allocated a second string). Behavior is unchanged:
/// the ellipsis is never whitespace, so only leading whitespace of the kept
/// prefix can be trimmed.
fn trim_for_heartbeat(input: &str, max_chars: usize) -> String {
    match input.char_indices().nth(max_chars) {
        // At most `max_chars` characters: return the whole input, trimmed.
        None => input.trim().to_string(),
        // `cut` is the byte offset of the first character past the limit —
        // always a valid char boundary, so slicing cannot panic.
        Some((cut, _)) => {
            let mut out = String::with_capacity(cut + 3);
            out.push_str(input[..cut].trim_start());
            out.push_str("...");
            out
        }
    }
}
1946
/// Reads a boolean from the environment variable `name`.
///
/// Accepts (case-insensitively) 1/true/yes/on and 0/false/no/off; any other
/// value — or an unset/unreadable variable — yields `default`.
fn env_bool(name: &str, default: bool) -> bool {
    let Ok(raw) = std::env::var(name) else {
        return default;
    };
    match raw.to_ascii_lowercase().as_str() {
        "1" | "true" | "yes" | "on" => true,
        "0" | "false" | "no" | "off" => false,
        _ => default,
    }
}
1957
/// Parses environment variable `name` into `T`, falling back to `default`
/// when the variable is unset, unreadable, or fails to parse.
///
/// Shared by the typed wrappers below; the previous code duplicated this
/// read-and-parse chain per integer type.
fn env_parse<T: std::str::FromStr>(name: &str, default: T) -> T {
    std::env::var(name)
        .ok()
        .and_then(|v| v.parse::<T>().ok())
        .unwrap_or(default)
}

/// Reads a `usize` from the environment, with `default` as fallback.
fn env_usize(name: &str, default: usize) -> usize {
    env_parse(name, default)
}

/// Reads a `u64` from the environment, with `default` as fallback.
fn env_u64(name: &str, default: u64) -> u64 {
    env_parse(name, default)
}