1use crate::cli::A2aArgs;
4use crate::provider::ProviderRegistry;
5use crate::session::Session;
6use crate::swarm::{DecompositionStrategy, SwarmConfig, SwarmExecutor};
7use crate::tui::swarm_view::SwarmEvent;
8use anyhow::Result;
9use futures::StreamExt;
10use reqwest::Client;
11use serde::Deserialize;
12use std::collections::HashMap;
13use std::collections::HashSet;
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::sync::{Mutex, mpsc};
17use tokio::task::JoinHandle;
18use tokio::time::Instant;
19
/// Coarse activity state advertised in worker heartbeats.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WorkerStatus {
    /// No task currently in flight.
    Idle,
    /// At least one task is being executed.
    Processing,
}

impl WorkerStatus {
    /// Wire-format label used when reporting status upstream.
    fn as_str(&self) -> &'static str {
        if matches!(self, WorkerStatus::Idle) {
            "idle"
        } else {
            "processing"
        }
    }
}
35
/// Shared, cloneable view of this worker's live state. Clones share the same
/// inner cells, so the heartbeat sender and the task-execution paths observe
/// each other's updates.
#[derive(Clone)]
struct HeartbeatState {
    // Stable id minted once at startup (see `generate_worker_id`).
    worker_id: String,
    // Human-readable agent name advertised to the server.
    agent_name: String,
    // tokio Mutex: these cells are read/written from async contexts.
    status: Arc<Mutex<WorkerStatus>>,
    // Count of tasks currently being processed.
    active_task_count: Arc<Mutex<usize>>,
}
44
45impl HeartbeatState {
46 fn new(worker_id: String, agent_name: String) -> Self {
47 Self {
48 worker_id,
49 agent_name,
50 status: Arc::new(Mutex::new(WorkerStatus::Idle)),
51 active_task_count: Arc::new(Mutex::new(0)),
52 }
53 }
54
55 async fn set_status(&self, status: WorkerStatus) {
56 *self.status.lock().await = status;
57 }
58
59 async fn set_task_count(&self, count: usize) {
60 *self.active_task_count.lock().await = count;
61 }
62}
63
/// Settings controlling whether and how cognition ("thought") state is
/// attached to worker heartbeats. Built from environment variables in
/// `from_env`.
#[derive(Clone, Debug)]
struct CognitionHeartbeatConfig {
    // Master switch; when false no thought state is shared upstream.
    enabled: bool,
    // Base URL of the local cognition service (trailing '/' stripped).
    source_base_url: String,
    // Whether to include a thought summary in heartbeats.
    include_thought_summary: bool,
    // Upper bound on summary length; floored at 120 by `from_env`.
    summary_max_chars: usize,
    // Timeout for cognition-service requests; floored at 250 ms by `from_env`.
    request_timeout_ms: u64,
}
72
73impl CognitionHeartbeatConfig {
74 fn from_env() -> Self {
75 let source_base_url = std::env::var("CODETETHER_WORKER_COGNITION_SOURCE_URL")
76 .unwrap_or_else(|_| "http://127.0.0.1:4096".to_string())
77 .trim_end_matches('/')
78 .to_string();
79
80 Self {
81 enabled: env_bool("CODETETHER_WORKER_COGNITION_SHARE_ENABLED", true),
82 source_base_url,
83 include_thought_summary: env_bool("CODETETHER_WORKER_COGNITION_INCLUDE_THOUGHTS", true),
84 summary_max_chars: env_usize("CODETETHER_WORKER_COGNITION_THOUGHT_MAX_CHARS", 480)
85 .max(120),
86 request_timeout_ms: env_u64("CODETETHER_WORKER_COGNITION_TIMEOUT_MS", 2_500).max(250),
87 }
88 }
89}
90
/// Wire shape of the cognition service's status response. All fields except
/// `running` are `#[serde(default)]` so responses that omit them still parse.
#[derive(Debug, Deserialize)]
struct CognitionStatusSnapshot {
    running: bool,
    #[serde(default)]
    last_tick_at: Option<String>,
    #[serde(default)]
    active_persona_count: usize,
    #[serde(default)]
    events_buffered: usize,
    #[serde(default)]
    snapshots_buffered: usize,
    #[serde(default)]
    loop_interval_ms: u64,
}
105
/// Wire shape of the cognition service's most recent thought snapshot.
#[derive(Debug, Deserialize)]
struct CognitionLatestSnapshot {
    generated_at: String,
    summary: String,
    // Free-form key/value payload; absent keys deserialize to an empty map.
    #[serde(default)]
    metadata: HashMap<String, serde_json::Value>,
}
113
114pub async fn run(args: A2aArgs) -> Result<()> {
116 let server = args.server.trim_end_matches('/');
117 let name = args
118 .name
119 .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
120 let worker_id = generate_worker_id();
121
122 let codebases: Vec<String> = args
123 .codebases
124 .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
125 .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);
126
127 tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
128 tracing::info!("Server: {}", server);
129 tracing::info!("Codebases: {:?}", codebases);
130
131 let client = Client::new();
132 let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
133 let cognition_heartbeat = CognitionHeartbeatConfig::from_env();
134 if cognition_heartbeat.enabled {
135 tracing::info!(
136 source = %cognition_heartbeat.source_base_url,
137 include_thoughts = cognition_heartbeat.include_thought_summary,
138 max_chars = cognition_heartbeat.summary_max_chars,
139 timeout_ms = cognition_heartbeat.request_timeout_ms,
140 "Cognition heartbeat sharing enabled (set CODETETHER_WORKER_COGNITION_SHARE_ENABLED=false to disable)"
141 );
142 } else {
143 tracing::warn!(
144 "Cognition heartbeat sharing disabled; worker thought state will not be shared upstream"
145 );
146 }
147
148 let auto_approve = match args.auto_approve.as_str() {
149 "all" => AutoApprove::All,
150 "safe" => AutoApprove::Safe,
151 _ => AutoApprove::None,
152 };
153
154 let heartbeat_state = HeartbeatState::new(worker_id.clone(), name.clone());
156
157 register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;
159
160 fetch_pending_tasks(
162 &client,
163 server,
164 &args.token,
165 &worker_id,
166 &processing,
167 &auto_approve,
168 )
169 .await?;
170
171 loop {
173 if let Err(e) =
175 register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await
176 {
177 tracing::warn!("Failed to re-register worker on reconnection: {}", e);
178 }
179
180 let heartbeat_handle = start_heartbeat(
182 client.clone(),
183 server.to_string(),
184 args.token.clone(),
185 heartbeat_state.clone(),
186 processing.clone(),
187 cognition_heartbeat.clone(),
188 );
189
190 match connect_stream(
191 &client,
192 server,
193 &args.token,
194 &worker_id,
195 &name,
196 &codebases,
197 &processing,
198 &auto_approve,
199 )
200 .await
201 {
202 Ok(()) => {
203 tracing::warn!("Stream ended, reconnecting...");
204 }
205 Err(e) => {
206 tracing::error!("Stream error: {}, reconnecting...", e);
207 }
208 }
209
210 heartbeat_handle.abort();
212 tracing::debug!("Heartbeat cancelled for reconnection");
213
214 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
215 }
216}
217
/// Mint a worker id of the form `wrk_<unix_ts>_<random u64 hex>` — unique
/// enough to distinguish worker processes across restarts.
fn generate_worker_id() -> String {
    format!(
        "wrk_{}_{:x}",
        chrono::Utc::now().timestamp(),
        rand::random::<u64>()
    )
}
225
/// Tool-approval policy for task execution, parsed from the `--auto-approve`
/// flag in `run` ("all", "safe", anything else => `None`).
#[derive(Debug, Clone, Copy)]
enum AutoApprove {
    // Approve every tool call without prompting.
    All,
    // Approve only calls deemed safe. NOTE(review): the actual "safe"
    // semantics live in `execute_session_with_policy`, not visible here —
    // confirm before relying on this description.
    Safe,
    // Never auto-approve.
    None,
}
232
/// Capability tags advertised to the server at registration time.
const WORKER_CAPABILITIES: &[&str] = &["ralph", "swarm", "rlm", "a2a", "mcp"];
235
/// Look up `key` on a task payload, preferring the nested `"task"` envelope
/// and falling back to the top level — both shapes occur on the wire.
fn task_value<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a serde_json::Value> {
    task.get("task")
        .and_then(|t| t.get(key))
        .or_else(|| task.get(key))
}
241
/// String-typed convenience wrapper over `task_value`.
fn task_str<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a str> {
    task_value(task, key).and_then(|v| v.as_str())
}
245
246fn task_metadata(task: &serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
247 task_value(task, "metadata")
248 .and_then(|m| m.as_object())
249 .cloned()
250 .unwrap_or_default()
251}
252
/// Normalize a model reference of the form `provider:model` into
/// `provider/model`. References that already contain `/` (or contain no `:`)
/// pass through unchanged; only the first `:` is rewritten.
fn model_ref_to_provider_model(model: &str) -> String {
    let colon_form = !model.contains('/') && model.contains(':');
    if colon_form {
        model.replacen(':', "/", 1)
    } else {
        model.to_string()
    }
}
263
/// Ordered provider preference list for a model tier.
///
/// Tiers: "fast"/"quick" favor low-latency providers, "heavy"/"deep" favor
/// strong-reasoning providers, anything else gets the balanced list.
fn provider_preferences_for_tier(model_tier: Option<&str>) -> &'static [&'static str] {
    const FAST: &[&str] = &[
        "google",
        "openai",
        "moonshotai",
        "zhipuai",
        "anthropic",
        "openrouter",
        "novita",
    ];
    const HEAVY: &[&str] = &[
        "anthropic",
        "openai",
        "google",
        "moonshotai",
        "zhipuai",
        "openrouter",
        "novita",
    ];
    const BALANCED: &[&str] = &[
        "openai",
        "anthropic",
        "google",
        "moonshotai",
        "zhipuai",
        "openrouter",
        "novita",
    ];
    match model_tier.unwrap_or("balanced") {
        "fast" | "quick" => FAST,
        "heavy" | "deep" => HEAVY,
        _ => BALANCED,
    }
}

/// Pick the first tier-preferred provider that is present in `providers`.
///
/// Falls back to "zhipuai" if available (currently unreachable, since
/// "zhipuai" appears in every preference list — kept defensively in case
/// the lists change), then to the first configured provider.
///
/// # Panics
///
/// Panics when `providers` is empty; callers must check first, as
/// `resolve_swarm_model` does.
fn choose_provider_for_tier<'a>(providers: &'a [&'a str], model_tier: Option<&str>) -> &'a str {
    for preferred in provider_preferences_for_tier(model_tier).iter().copied() {
        if providers.contains(&preferred) {
            return preferred;
        }
    }
    if providers.contains(&"zhipuai") {
        return "zhipuai";
    }
    providers[0]
}
307
/// Default model id for a provider at a given tier.
///
/// Only anthropic, openai, and google vary by tier; the remaining providers
/// (and the unknown-provider fallback) use the same model everywhere.
fn default_model_for_provider(provider: &str, model_tier: Option<&str>) -> String {
    let tier = model_tier.unwrap_or("balanced");
    let fast = matches!(tier, "fast" | "quick");
    let heavy = matches!(tier, "heavy" | "deep");

    let model = match provider {
        "moonshotai" => "kimi-k2.5",
        "anthropic" => {
            if fast {
                "claude-haiku-4-5"
            } else {
                "claude-sonnet-4-20250514"
            }
        }
        "openai" => {
            if fast {
                "gpt-4o-mini"
            } else if heavy {
                "o3"
            } else {
                "gpt-4o"
            }
        }
        "google" => {
            if fast {
                "gemini-2.5-flash"
            } else {
                "gemini-2.5-pro"
            }
        }
        "zhipuai" => "glm-4.7",
        "openrouter" => "z-ai/glm-4.7",
        "novita" => "qwen/qwen3-coder-next",
        _ => "glm-4.7",
    };
    model.to_string()
}
342
/// Whether an agent-type string (case-insensitive, whitespace-tolerant)
/// should be routed through the swarm executor.
fn is_swarm_agent(agent_type: &str) -> bool {
    let normalized = agent_type.trim().to_ascii_lowercase();
    ["swarm", "parallel", "multi-agent"].contains(&normalized.as_str())
}
349
350fn metadata_lookup<'a>(
351 metadata: &'a serde_json::Map<String, serde_json::Value>,
352 key: &str,
353) -> Option<&'a serde_json::Value> {
354 metadata
355 .get(key)
356 .or_else(|| {
357 metadata
358 .get("routing")
359 .and_then(|v| v.as_object())
360 .and_then(|obj| obj.get(key))
361 })
362 .or_else(|| {
363 metadata
364 .get("swarm")
365 .and_then(|v| v.as_object())
366 .and_then(|obj| obj.get(key))
367 })
368}
369
370fn metadata_str(
371 metadata: &serde_json::Map<String, serde_json::Value>,
372 keys: &[&str],
373) -> Option<String> {
374 for key in keys {
375 if let Some(value) = metadata_lookup(metadata, key).and_then(|v| v.as_str()) {
376 let trimmed = value.trim();
377 if !trimmed.is_empty() {
378 return Some(trimmed.to_string());
379 }
380 }
381 }
382 None
383}
384
385fn metadata_usize(
386 metadata: &serde_json::Map<String, serde_json::Value>,
387 keys: &[&str],
388) -> Option<usize> {
389 for key in keys {
390 if let Some(value) = metadata_lookup(metadata, key) {
391 if let Some(v) = value.as_u64() {
392 return usize::try_from(v).ok();
393 }
394 if let Some(v) = value.as_i64() {
395 if v >= 0 {
396 return usize::try_from(v as u64).ok();
397 }
398 }
399 if let Some(v) = value.as_str() {
400 if let Ok(parsed) = v.trim().parse::<usize>() {
401 return Some(parsed);
402 }
403 }
404 }
405 }
406 None
407}
408
409fn metadata_u64(
410 metadata: &serde_json::Map<String, serde_json::Value>,
411 keys: &[&str],
412) -> Option<u64> {
413 for key in keys {
414 if let Some(value) = metadata_lookup(metadata, key) {
415 if let Some(v) = value.as_u64() {
416 return Some(v);
417 }
418 if let Some(v) = value.as_i64() {
419 if v >= 0 {
420 return Some(v as u64);
421 }
422 }
423 if let Some(v) = value.as_str() {
424 if let Ok(parsed) = v.trim().parse::<u64>() {
425 return Some(parsed);
426 }
427 }
428 }
429 }
430 None
431}
432
433fn metadata_bool(
434 metadata: &serde_json::Map<String, serde_json::Value>,
435 keys: &[&str],
436) -> Option<bool> {
437 for key in keys {
438 if let Some(value) = metadata_lookup(metadata, key) {
439 if let Some(v) = value.as_bool() {
440 return Some(v);
441 }
442 if let Some(v) = value.as_str() {
443 match v.trim().to_ascii_lowercase().as_str() {
444 "1" | "true" | "yes" | "on" => return Some(true),
445 "0" | "false" | "no" | "off" => return Some(false),
446 _ => {}
447 }
448 }
449 }
450 }
451 None
452}
453
454fn parse_swarm_strategy(
455 metadata: &serde_json::Map<String, serde_json::Value>,
456) -> DecompositionStrategy {
457 match metadata_str(
458 metadata,
459 &[
460 "decomposition_strategy",
461 "swarm_strategy",
462 "strategy",
463 "swarm_decomposition",
464 ],
465 )
466 .as_deref()
467 .map(|s| s.to_ascii_lowercase())
468 .as_deref()
469 {
470 Some("none") | Some("single") => DecompositionStrategy::None,
471 Some("domain") | Some("by_domain") => DecompositionStrategy::ByDomain,
472 Some("data") | Some("by_data") => DecompositionStrategy::ByData,
473 Some("stage") | Some("by_stage") => DecompositionStrategy::ByStage,
474 _ => DecompositionStrategy::Automatic,
475 }
476}
477
478async fn resolve_swarm_model(
479 explicit_model: Option<String>,
480 model_tier: Option<&str>,
481) -> Option<String> {
482 if let Some(model) = explicit_model {
483 if !model.trim().is_empty() {
484 return Some(model);
485 }
486 }
487
488 let registry = ProviderRegistry::from_vault().await.ok()?;
489 let providers = registry.list();
490 if providers.is_empty() {
491 return None;
492 }
493 let provider = choose_provider_for_tier(providers.as_slice(), model_tier);
494 let model = default_model_for_provider(provider, model_tier);
495 Some(format!("{}/{}", provider, model))
496}
497
498fn format_swarm_event_for_output(event: &SwarmEvent) -> Option<String> {
499 match event {
500 SwarmEvent::Started {
501 task,
502 total_subtasks,
503 } => Some(format!(
504 "[swarm] started task={} planned_subtasks={}",
505 task, total_subtasks
506 )),
507 SwarmEvent::StageComplete {
508 stage,
509 completed,
510 failed,
511 } => Some(format!(
512 "[swarm] stage={} completed={} failed={}",
513 stage, completed, failed
514 )),
515 SwarmEvent::SubTaskUpdate { id, status, .. } => Some(format!(
516 "[swarm] subtask id={} status={}",
517 &id.chars().take(8).collect::<String>(),
518 format!("{status:?}").to_ascii_lowercase()
519 )),
520 SwarmEvent::AgentToolCall {
521 subtask_id,
522 tool_name,
523 } => Some(format!(
524 "[swarm] subtask id={} tool={}",
525 &subtask_id.chars().take(8).collect::<String>(),
526 tool_name
527 )),
528 SwarmEvent::AgentError { subtask_id, error } => Some(format!(
529 "[swarm] subtask id={} error={}",
530 &subtask_id.chars().take(8).collect::<String>(),
531 error
532 )),
533 SwarmEvent::Complete { success, stats } => Some(format!(
534 "[swarm] complete success={} subtasks={} speedup={:.2}",
535 success,
536 stats.subagents_completed + stats.subagents_failed,
537 stats.speedup_factor
538 )),
539 SwarmEvent::Error(err) => Some(format!("[swarm] error message={}", err)),
540 _ => None,
541 }
542}
543
/// Register (or re-register) this worker with the A2A server, advertising
/// capabilities, discovered models, hostname, and codebases.
///
/// Registration failure is logged but not returned as an error — the caller
/// keeps running and retries on the next reconnect cycle.
async fn register_worker(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
) -> Result<()> {
    // Model discovery is best-effort; register with an empty map on failure.
    let models = match load_provider_models().await {
        Ok(m) => m,
        Err(e) => {
            tracing::warn!(
                "Failed to load provider models: {}, proceeding without model info",
                e
            );
            HashMap::new()
        }
    };

    let mut req = client.post(format!("{}/v1/opencode/workers/register", server));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    // Flatten {provider -> [model ids]} into the server's flat model list.
    let models_array: Vec<serde_json::Value> = models
        .iter()
        .flat_map(|(provider, model_ids)| {
            model_ids.iter().map(move |model_id| {
                serde_json::json!({
                    "id": format!("{}/{}", provider, model_id),
                    "name": model_id,
                    "provider": provider,
                    "provider_id": provider,
                })
            })
        })
        .collect();

    tracing::info!(
        "Registering worker with {} models from {} providers",
        models_array.len(),
        models.len()
    );

    // HOSTNAME (unix-ish) or COMPUTERNAME (Windows); "unknown" otherwise.
    let hostname = std::env::var("HOSTNAME")
        .or_else(|_| std::env::var("COMPUTERNAME"))
        .unwrap_or_else(|_| "unknown".to_string());

    let res = req
        .json(&serde_json::json!({
            "worker_id": worker_id,
            "name": name,
            "capabilities": WORKER_CAPABILITIES,
            "hostname": hostname,
            "models": models_array,
            "codebases": codebases,
        }))
        .send()
        .await?;

    if res.status().is_success() {
        tracing::info!("Worker registered successfully");
    } else {
        // Non-fatal by design; see the doc comment above.
        tracing::warn!("Failed to register worker: {}", res.status());
    }

    Ok(())
}
617
/// Build the {provider -> model ids} map advertised at registration.
///
/// Source order: Vault-configured providers, then the local config/env
/// fallback registry, and finally the public models.dev catalog when no
/// authenticated provider yields any models.
async fn load_provider_models() -> Result<HashMap<String, Vec<String>>> {
    let registry = match ProviderRegistry::from_vault().await {
        Ok(r) if !r.list().is_empty() => {
            tracing::info!("Loaded {} providers from Vault", r.list().len());
            r
        }
        Ok(_) => {
            tracing::warn!("Vault returned 0 providers, falling back to config/env vars");
            fallback_registry().await?
        }
        Err(e) => {
            tracing::warn!("Vault unreachable ({}), falling back to config/env vars", e);
            fallback_registry().await?
        }
    };

    let mut models_by_provider: HashMap<String, Vec<String>> = HashMap::new();

    // Query each provider; per-provider failures are logged and skipped,
    // and providers with zero models are omitted.
    for provider_name in registry.list() {
        if let Some(provider) = registry.get(provider_name) {
            match provider.list_models().await {
                Ok(models) => {
                    let model_ids: Vec<String> = models.into_iter().map(|m| m.id).collect();
                    if !model_ids.is_empty() {
                        tracing::debug!("Provider {}: {} models", provider_name, model_ids.len());
                        models_by_provider.insert(provider_name.to_string(), model_ids);
                    }
                }
                Err(e) => {
                    tracing::debug!("Failed to list models for {}: {}", provider_name, e);
                }
            }
        }
    }

    // Last resort: advertise the full public catalog (catalog fetch failure
    // is silently ignored, leaving the map empty).
    if models_by_provider.is_empty() {
        tracing::info!(
            "No authenticated providers found, fetching models.dev catalog (all providers)"
        );
        if let Ok(catalog) = crate::provider::models::ModelCatalog::fetch().await {
            for (provider_id, provider_info) in catalog.all_providers() {
                let model_ids: Vec<String> = provider_info
                    .models
                    .values()
                    .map(|m| m.id.clone())
                    .collect();
                if !model_ids.is_empty() {
                    tracing::debug!(
                        "Catalog provider {}: {} models",
                        provider_id,
                        model_ids.len()
                    );
                    models_by_provider.insert(provider_id.clone(), model_ids);
                }
            }
            tracing::info!(
                "Loaded {} providers with {} total models from catalog",
                models_by_provider.len(),
                models_by_provider.values().map(|v| v.len()).sum::<usize>()
            );
        }
    }

    Ok(models_by_provider)
}
691
/// Build a provider registry from the local config file (or defaults plus
/// environment variables) when Vault is unavailable or empty.
async fn fallback_registry() -> Result<ProviderRegistry> {
    let config = crate::config::Config::load().await.unwrap_or_default();
    ProviderRegistry::from_config(&config).await
}
697
698async fn fetch_pending_tasks(
699 client: &Client,
700 server: &str,
701 token: &Option<String>,
702 worker_id: &str,
703 processing: &Arc<Mutex<HashSet<String>>>,
704 auto_approve: &AutoApprove,
705) -> Result<()> {
706 tracing::info!("Checking for pending tasks...");
707
708 let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
709 if let Some(t) = token {
710 req = req.bearer_auth(t);
711 }
712
713 let res = req.send().await?;
714 if !res.status().is_success() {
715 return Ok(());
716 }
717
718 let data: serde_json::Value = res.json().await?;
719 let tasks = if let Some(arr) = data.as_array() {
721 arr.clone()
722 } else {
723 data["tasks"].as_array().cloned().unwrap_or_default()
724 };
725
726 tracing::info!("Found {} pending task(s)", tasks.len());
727
728 for task in tasks {
729 if let Some(id) = task["id"].as_str() {
730 let mut proc = processing.lock().await;
731 if !proc.contains(id) {
732 proc.insert(id.to_string());
733 drop(proc);
734
735 let task_id = id.to_string();
736 let client = client.clone();
737 let server = server.to_string();
738 let token = token.clone();
739 let worker_id = worker_id.to_string();
740 let auto_approve = *auto_approve;
741 let processing = processing.clone();
742
743 tokio::spawn(async move {
744 if let Err(e) =
745 handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await
746 {
747 tracing::error!("Task {} failed: {}", task_id, e);
748 }
749 processing.lock().await.remove(&task_id);
750 });
751 }
752 }
753 }
754
755 Ok(())
756}
757
/// Hold an SSE connection to the server's task stream and dispatch incoming
/// task events until the stream ends (`Ok`) or errors (`Err`); the caller
/// reconnects in either case.
///
/// A 30-second fallback poll of the pending-task queue runs concurrently in
/// case stream events are missed.
#[allow(clippy::too_many_arguments)]
async fn connect_stream(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
) -> Result<()> {
    let url = format!(
        "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
        server,
        urlencoding::encode(name),
        urlencoding::encode(worker_id)
    );

    let mut req = client
        .get(&url)
        .header("Accept", "text/event-stream")
        .header("X-Worker-ID", worker_id)
        .header("X-Agent-Name", name)
        .header("X-Codebases", codebases.join(","));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        anyhow::bail!("Failed to connect: {}", res.status());
    }

    tracing::info!("Connected to A2A server");

    let mut stream = res.bytes_stream();
    // Accumulates raw bytes until a full SSE event (terminated by "\n\n")
    // is available.
    let mut buffer = String::new();
    let mut poll_interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
    // interval's first tick fires immediately — consume it so the first
    // fallback poll happens 30s from now, not right away.
    poll_interval.tick().await; loop {
        tokio::select! {
            chunk = stream.next() => {
                match chunk {
                    Some(Ok(chunk)) => {
                        buffer.push_str(&String::from_utf8_lossy(&chunk));

                        // Drain every complete SSE event currently buffered.
                        while let Some(pos) = buffer.find("\n\n") {
                            let event_str = buffer[..pos].to_string();
                            buffer = buffer[pos + 2..].to_string();

                            if let Some(data_line) = event_str.lines().find(|l| l.starts_with("data:")) {
                                let data = data_line.trim_start_matches("data:").trim();
                                if data == "[DONE]" || data.is_empty() {
                                    continue;
                                }

                                // Unparseable payloads are silently dropped.
                                if let Ok(task) = serde_json::from_str::<serde_json::Value>(data) {
                                    spawn_task_handler(
                                        &task, client, server, token, worker_id,
                                        processing, auto_approve,
                                    ).await;
                                }
                            }
                        }
                    }
                    Some(Err(e)) => {
                        return Err(e.into());
                    }
                    None => {
                        // Server closed the stream cleanly; caller reconnects.
                        return Ok(());
                    }
                }
            }
            _ = poll_interval.tick() => {
                // Fallback poll; failures are logged but don't drop the stream.
                if let Err(e) = poll_pending_tasks(
                    client, server, token, worker_id, processing, auto_approve,
                ).await {
                    tracing::warn!("Periodic task poll failed: {}", e);
                }
            }
        }
    }
}
846
847async fn spawn_task_handler(
848 task: &serde_json::Value,
849 client: &Client,
850 server: &str,
851 token: &Option<String>,
852 worker_id: &str,
853 processing: &Arc<Mutex<HashSet<String>>>,
854 auto_approve: &AutoApprove,
855) {
856 if let Some(id) = task
857 .get("task")
858 .and_then(|t| t["id"].as_str())
859 .or_else(|| task["id"].as_str())
860 {
861 let mut proc = processing.lock().await;
862 if !proc.contains(id) {
863 proc.insert(id.to_string());
864 drop(proc);
865
866 let task_id = id.to_string();
867 let task = task.clone();
868 let client = client.clone();
869 let server = server.to_string();
870 let token = token.clone();
871 let worker_id = worker_id.to_string();
872 let auto_approve = *auto_approve;
873 let processing_clone = processing.clone();
874
875 tokio::spawn(async move {
876 if let Err(e) =
877 handle_task(&client, &server, &token, &worker_id, &task, auto_approve).await
878 {
879 tracing::error!("Task {} failed: {}", task_id, e);
880 }
881 processing_clone.lock().await.remove(&task_id);
882 });
883 }
884 }
885}
886
887async fn poll_pending_tasks(
888 client: &Client,
889 server: &str,
890 token: &Option<String>,
891 worker_id: &str,
892 processing: &Arc<Mutex<HashSet<String>>>,
893 auto_approve: &AutoApprove,
894) -> Result<()> {
895 let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
896 if let Some(t) = token {
897 req = req.bearer_auth(t);
898 }
899
900 let res = req.send().await?;
901 if !res.status().is_success() {
902 return Ok(());
903 }
904
905 let data: serde_json::Value = res.json().await?;
906 let tasks = if let Some(arr) = data.as_array() {
907 arr.clone()
908 } else {
909 data["tasks"].as_array().cloned().unwrap_or_default()
910 };
911
912 if !tasks.is_empty() {
913 tracing::debug!("Poll found {} pending task(s)", tasks.len());
914 }
915
916 for task in &tasks {
917 spawn_task_handler(
918 task,
919 client,
920 server,
921 token,
922 worker_id,
923 processing,
924 auto_approve,
925 )
926 .await;
927 }
928
929 Ok(())
930}
931
/// Claim, execute, and release one task end-to-end.
///
/// Flow: claim the task on the server (quietly bail if another worker won
/// the race) → read routing hints from metadata → resume or create a
/// session → execute via the swarm executor (swarm-type agents) or the
/// regular session executor → release the task with status/result/error and
/// the session id. Incremental output lines are forwarded to the task's
/// output endpoint by fire-and-forget spawned requests.
async fn handle_task(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    task: &serde_json::Value,
    auto_approve: AutoApprove,
) -> Result<()> {
    let task_id = task_str(task, "id").ok_or_else(|| anyhow::anyhow!("No task ID"))?;
    let title = task_str(task, "title").unwrap_or("Untitled");

    tracing::info!("Handling task: {} ({})", title, task_id);

    // --- Claim phase ---
    let mut req = client
        .post(format!("{}/v1/worker/tasks/claim", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req
        .json(&serde_json::json!({ "task_id": task_id }))
        .send()
        .await?;

    if !res.status().is_success() {
        // Losing the claim race (or any claim rejection) is not our error.
        let text = res.text().await?;
        tracing::warn!("Failed to claim task: {}", text);
        return Ok(());
    }

    tracing::info!("Claimed task: {}", task_id);

    // --- Routing hints from task metadata ---
    let metadata = task_metadata(task);
    let resume_session_id = metadata
        .get("resume_session_id")
        .and_then(|v| v.as_str())
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty());
    let complexity_hint = metadata_str(&metadata, &["complexity"]);
    // Explicit tier wins; otherwise derive one from the complexity hint.
    let model_tier = metadata_str(&metadata, &["model_tier", "tier"])
        .map(|s| s.to_ascii_lowercase())
        .or_else(|| {
            complexity_hint.as_ref().map(|complexity| {
                match complexity.to_ascii_lowercase().as_str() {
                    "quick" => "fast".to_string(),
                    "deep" => "heavy".to_string(),
                    _ => "balanced".to_string(),
                }
            })
        });
    let worker_personality = metadata_str(
        &metadata,
        &["worker_personality", "personality", "agent_personality"],
    );
    let target_agent_name = metadata_str(&metadata, &["target_agent_name", "agent_name"]);
    // Model may arrive as "model_ref" or "model", on the task or in
    // metadata; normalized to "provider/model" form below.
    let raw_model = task_str(task, "model_ref")
        .or_else(|| metadata_lookup(&metadata, "model_ref").and_then(|v| v.as_str()))
        .or_else(|| task_str(task, "model"))
        .or_else(|| metadata_lookup(&metadata, "model").and_then(|v| v.as_str()));
    let selected_model = raw_model.map(model_ref_to_provider_model);

    // --- Session setup: resume if requested, else start fresh ---
    let mut session = if let Some(ref sid) = resume_session_id {
        match Session::load(sid).await {
            Ok(existing) => {
                tracing::info!("Resuming session {} for task {}", sid, task_id);
                existing
            }
            Err(e) => {
                // A stale/missing session id degrades to a new session.
                tracing::warn!(
                    "Could not load session {} for task {} ({}), starting a new session",
                    sid,
                    task_id,
                    e
                );
                Session::new().await?
            }
        }
    } else {
        Session::new().await?
    };

    let agent_type = task_str(task, "agent_type")
        .or_else(|| task_str(task, "agent"))
        .unwrap_or("build");
    session.agent = agent_type.to_string();

    if let Some(model) = selected_model.clone() {
        session.metadata.model = Some(model);
    }

    // Prompt falls back to description, then to the title.
    let prompt = task_str(task, "prompt")
        .or_else(|| task_str(task, "description"))
        .unwrap_or(title);

    tracing::info!("Executing prompt: {}", prompt);

    // --- Incremental output forwarding (one detached request per line) ---
    let stream_client = client.clone();
    let stream_server = server.to_string();
    let stream_token = token.clone();
    let stream_worker_id = worker_id.to_string();
    let stream_task_id = task_id.to_string();

    let output_callback = move |output: String| {
        let c = stream_client.clone();
        let s = stream_server.clone();
        let t = stream_token.clone();
        let w = stream_worker_id.clone();
        let tid = stream_task_id.clone();
        tokio::spawn(async move {
            let mut req = c
                .post(format!("{}/v1/opencode/tasks/{}/output", s, tid))
                .header("X-Worker-ID", &w);
            if let Some(tok) = &t {
                req = req.bearer_auth(tok);
            }
            // Best-effort: output-posting failures are deliberately ignored.
            let _ = req
                .json(&serde_json::json!({
                    "worker_id": w,
                    "output": output,
                }))
                .send()
                .await;
        });
    };

    // --- Execution: swarm path vs. regular session path ---
    let (status, result, error, session_id) = if is_swarm_agent(agent_type) {
        match execute_swarm_with_policy(
            &mut session,
            prompt,
            model_tier.as_deref(),
            selected_model,
            &metadata,
            complexity_hint.as_deref(),
            worker_personality.as_deref(),
            target_agent_name.as_deref(),
            Some(&output_callback),
        )
        .await
        {
            Ok((session_result, true)) => {
                tracing::info!("Swarm task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Ok((session_result, false)) => {
                // The swarm ran to the end but some subtasks failed.
                tracing::warn!("Swarm task completed with failures: {}", task_id);
                (
                    "failed",
                    Some(session_result.text),
                    Some("Swarm execution completed with failures".to_string()),
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Swarm task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    } else {
        match execute_session_with_policy(
            &mut session,
            prompt,
            auto_approve,
            model_tier.as_deref(),
            Some(&output_callback),
        )
        .await
        {
            Ok(session_result) => {
                tracing::info!("Task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    };

    // --- Release phase: report the final outcome back to the server ---
    let mut req = client
        .post(format!("{}/v1/worker/tasks/release", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    req.json(&serde_json::json!({
        "task_id": task_id,
        "status": status,
        "result": result,
        "error": error,
        "session_id": session_id.unwrap_or_else(|| session.id.clone()),
    }))
    .send()
    .await?;

    tracing::info!("Task released: {} with status: {}", task_id, status);
    Ok(())
}
1146
1147async fn execute_swarm_with_policy<F>(
1148 session: &mut Session,
1149 prompt: &str,
1150 model_tier: Option<&str>,
1151 explicit_model: Option<String>,
1152 metadata: &serde_json::Map<String, serde_json::Value>,
1153 complexity_hint: Option<&str>,
1154 worker_personality: Option<&str>,
1155 target_agent_name: Option<&str>,
1156 output_callback: Option<&F>,
1157) -> Result<(crate::session::SessionResult, bool)>
1158where
1159 F: Fn(String),
1160{
1161 use crate::provider::{ContentPart, Message, Role};
1162
1163 session.add_message(Message {
1164 role: Role::User,
1165 content: vec![ContentPart::Text {
1166 text: prompt.to_string(),
1167 }],
1168 });
1169
1170 if session.title.is_none() {
1171 session.generate_title().await?;
1172 }
1173
1174 let strategy = parse_swarm_strategy(metadata);
1175 let max_subagents = metadata_usize(
1176 metadata,
1177 &["swarm_max_subagents", "max_subagents", "subagents"],
1178 )
1179 .unwrap_or(10)
1180 .clamp(1, 100);
1181 let max_steps_per_subagent = metadata_usize(
1182 metadata,
1183 &[
1184 "swarm_max_steps_per_subagent",
1185 "max_steps_per_subagent",
1186 "max_steps",
1187 ],
1188 )
1189 .unwrap_or(50)
1190 .clamp(1, 200);
1191 let timeout_secs = metadata_u64(metadata, &["swarm_timeout_secs", "timeout_secs", "timeout"])
1192 .unwrap_or(600)
1193 .clamp(30, 3600);
1194 let parallel_enabled =
1195 metadata_bool(metadata, &["swarm_parallel_enabled", "parallel_enabled"]).unwrap_or(true);
1196
1197 let model = resolve_swarm_model(explicit_model, model_tier).await;
1198 if let Some(ref selected_model) = model {
1199 session.metadata.model = Some(selected_model.clone());
1200 }
1201
1202 if let Some(cb) = output_callback {
1203 cb(format!(
1204 "[swarm] routing complexity={} tier={} personality={} target_agent={}",
1205 complexity_hint.unwrap_or("standard"),
1206 model_tier.unwrap_or("balanced"),
1207 worker_personality.unwrap_or("auto"),
1208 target_agent_name.unwrap_or("auto")
1209 ));
1210 cb(format!(
1211 "[swarm] config strategy={:?} max_subagents={} max_steps={} timeout={}s tier={}",
1212 strategy,
1213 max_subagents,
1214 max_steps_per_subagent,
1215 timeout_secs,
1216 model_tier.unwrap_or("balanced")
1217 ));
1218 }
1219
1220 let swarm_config = SwarmConfig {
1221 max_subagents,
1222 max_steps_per_subagent,
1223 subagent_timeout_secs: timeout_secs,
1224 parallel_enabled,
1225 model,
1226 working_dir: session
1227 .metadata
1228 .directory
1229 .as_ref()
1230 .map(|p| p.to_string_lossy().to_string()),
1231 ..Default::default()
1232 };
1233
1234 let swarm_result = if output_callback.is_some() {
1235 let (event_tx, mut event_rx) = mpsc::channel(256);
1236 let executor = SwarmExecutor::new(swarm_config).with_event_tx(event_tx);
1237 let prompt_owned = prompt.to_string();
1238 let mut exec_handle =
1239 tokio::spawn(async move { executor.execute(&prompt_owned, strategy).await });
1240
1241 let mut final_result: Option<crate::swarm::SwarmResult> = None;
1242
1243 while final_result.is_none() {
1244 tokio::select! {
1245 maybe_event = event_rx.recv() => {
1246 if let Some(event) = maybe_event {
1247 if let Some(cb) = output_callback {
1248 if let Some(line) = format_swarm_event_for_output(&event) {
1249 cb(line);
1250 }
1251 }
1252 }
1253 }
1254 join_result = &mut exec_handle => {
1255 let joined = join_result.map_err(|e| anyhow::anyhow!("Swarm join failure: {}", e))?;
1256 final_result = Some(joined?);
1257 }
1258 }
1259 }
1260
1261 while let Ok(event) = event_rx.try_recv() {
1262 if let Some(cb) = output_callback {
1263 if let Some(line) = format_swarm_event_for_output(&event) {
1264 cb(line);
1265 }
1266 }
1267 }
1268
1269 final_result.ok_or_else(|| anyhow::anyhow!("Swarm execution returned no result"))?
1270 } else {
1271 SwarmExecutor::new(swarm_config)
1272 .execute(prompt, strategy)
1273 .await?
1274 };
1275
1276 let final_text = if swarm_result.result.trim().is_empty() {
1277 if swarm_result.success {
1278 "Swarm completed without textual output.".to_string()
1279 } else {
1280 "Swarm finished with failures and no textual output.".to_string()
1281 }
1282 } else {
1283 swarm_result.result.clone()
1284 };
1285
1286 session.add_message(Message {
1287 role: Role::Assistant,
1288 content: vec![ContentPart::Text {
1289 text: final_text.clone(),
1290 }],
1291 });
1292 session.save().await?;
1293
1294 Ok((
1295 crate::session::SessionResult {
1296 text: final_text,
1297 session_id: session.id.clone(),
1298 },
1299 swarm_result.success,
1300 ))
1301}
1302
1303async fn execute_session_with_policy<F>(
1306 session: &mut Session,
1307 prompt: &str,
1308 auto_approve: AutoApprove,
1309 model_tier: Option<&str>,
1310 output_callback: Option<&F>,
1311) -> Result<crate::session::SessionResult>
1312where
1313 F: Fn(String),
1314{
1315 use crate::provider::{
1316 CompletionRequest, ContentPart, Message, ProviderRegistry, Role, parse_model_string,
1317 };
1318 use std::sync::Arc;
1319
1320 let registry = ProviderRegistry::from_vault().await?;
1322 let providers = registry.list();
1323 tracing::info!("Available providers: {:?}", providers);
1324
1325 if providers.is_empty() {
1326 anyhow::bail!("No providers available. Configure API keys in HashiCorp Vault.");
1327 }
1328
1329 let (provider_name, model_id) = if let Some(ref model_str) = session.metadata.model {
1331 let (prov, model) = parse_model_string(model_str);
1332 if prov.is_some() {
1333 (prov.map(|s| s.to_string()), model.to_string())
1334 } else if providers.contains(&model) {
1335 (Some(model.to_string()), String::new())
1336 } else {
1337 (None, model.to_string())
1338 }
1339 } else {
1340 (None, String::new())
1341 };
1342
1343 let provider_slice = providers.as_slice();
1344 let provider_requested_but_unavailable = provider_name
1345 .as_deref()
1346 .map(|p| !providers.contains(&p))
1347 .unwrap_or(false);
1348
1349 let selected_provider = provider_name
1351 .as_deref()
1352 .filter(|p| providers.contains(p))
1353 .unwrap_or_else(|| choose_provider_for_tier(provider_slice, model_tier));
1354
1355 let provider = registry
1356 .get(selected_provider)
1357 .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;
1358
1359 session.add_message(Message {
1361 role: Role::User,
1362 content: vec![ContentPart::Text {
1363 text: prompt.to_string(),
1364 }],
1365 });
1366
1367 if session.title.is_none() {
1369 session.generate_title().await?;
1370 }
1371
1372 let model = if !model_id.is_empty() && !provider_requested_but_unavailable {
1375 model_id
1376 } else {
1377 default_model_for_provider(selected_provider, model_tier)
1378 };
1379
1380 let tool_registry =
1382 create_filtered_registry(Arc::clone(&provider), model.clone(), auto_approve);
1383 let tool_definitions = tool_registry.definitions();
1384
1385 let temperature = if model.starts_with("kimi-k2") {
1386 Some(1.0)
1387 } else {
1388 Some(0.7)
1389 };
1390
1391 tracing::info!(
1392 "Using model: {} via provider: {} (tier: {:?})",
1393 model,
1394 selected_provider,
1395 model_tier
1396 );
1397 tracing::info!(
1398 "Available tools: {} (auto_approve: {:?})",
1399 tool_definitions.len(),
1400 auto_approve
1401 );
1402
1403 let cwd = std::env::var("PWD")
1405 .map(std::path::PathBuf::from)
1406 .unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
1407 let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);
1408
1409 let mut final_output = String::new();
1410 let max_steps = 50;
1411
1412 for step in 1..=max_steps {
1413 tracing::info!(step = step, "Agent step starting");
1414
1415 let mut messages = vec![Message {
1417 role: Role::System,
1418 content: vec![ContentPart::Text {
1419 text: system_prompt.clone(),
1420 }],
1421 }];
1422 messages.extend(session.messages.clone());
1423
1424 let request = CompletionRequest {
1425 messages,
1426 tools: tool_definitions.clone(),
1427 model: model.clone(),
1428 temperature,
1429 top_p: None,
1430 max_tokens: Some(8192),
1431 stop: Vec::new(),
1432 };
1433
1434 let response = provider.complete(request).await?;
1435
1436 crate::telemetry::TOKEN_USAGE.record_model_usage(
1437 &model,
1438 response.usage.prompt_tokens as u64,
1439 response.usage.completion_tokens as u64,
1440 );
1441
1442 let tool_calls: Vec<(String, String, serde_json::Value)> = response
1444 .message
1445 .content
1446 .iter()
1447 .filter_map(|part| {
1448 if let ContentPart::ToolCall {
1449 id,
1450 name,
1451 arguments,
1452 } = part
1453 {
1454 let args: serde_json::Value =
1455 serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
1456 Some((id.clone(), name.clone(), args))
1457 } else {
1458 None
1459 }
1460 })
1461 .collect();
1462
1463 for part in &response.message.content {
1465 if let ContentPart::Text { text } = part {
1466 if !text.is_empty() {
1467 final_output.push_str(text);
1468 final_output.push('\n');
1469 if let Some(cb) = output_callback {
1470 cb(text.clone());
1471 }
1472 }
1473 }
1474 }
1475
1476 if tool_calls.is_empty() {
1478 session.add_message(response.message.clone());
1479 break;
1480 }
1481
1482 session.add_message(response.message.clone());
1483
1484 tracing::info!(
1485 step = step,
1486 num_tools = tool_calls.len(),
1487 "Executing tool calls"
1488 );
1489
1490 for (tool_id, tool_name, tool_input) in tool_calls {
1492 tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");
1493
1494 if let Some(cb) = output_callback {
1496 cb(format!("[tool:start:{}]", tool_name));
1497 }
1498
1499 if !is_tool_allowed(&tool_name, auto_approve) {
1501 let msg = format!(
1502 "Tool '{}' requires approval but auto-approve policy is {:?}",
1503 tool_name, auto_approve
1504 );
1505 tracing::warn!(tool = %tool_name, "Tool blocked by auto-approve policy");
1506 session.add_message(Message {
1507 role: Role::Tool,
1508 content: vec![ContentPart::ToolResult {
1509 tool_call_id: tool_id,
1510 content: msg,
1511 }],
1512 });
1513 continue;
1514 }
1515
1516 let content = if let Some(tool) = tool_registry.get(&tool_name) {
1517 let exec_result: Result<crate::tool::ToolResult> =
1518 tool.execute(tool_input.clone()).await;
1519 match exec_result {
1520 Ok(result) => {
1521 tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
1522 if let Some(cb) = output_callback {
1523 let status = if result.success { "ok" } else { "err" };
1524 cb(format!(
1525 "[tool:{}:{}] {}",
1526 tool_name,
1527 status,
1528 &result.output[..result.output.len().min(500)]
1529 ));
1530 }
1531 result.output
1532 }
1533 Err(e) => {
1534 tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
1535 if let Some(cb) = output_callback {
1536 cb(format!("[tool:{}:err] {}", tool_name, e));
1537 }
1538 format!("Error: {}", e)
1539 }
1540 }
1541 } else {
1542 tracing::warn!(tool = %tool_name, "Tool not found");
1543 format!("Error: Unknown tool '{}'", tool_name)
1544 };
1545
1546 session.add_message(Message {
1547 role: Role::Tool,
1548 content: vec![ContentPart::ToolResult {
1549 tool_call_id: tool_id,
1550 content,
1551 }],
1552 });
1553 }
1554 }
1555
1556 session.save().await?;
1557
1558 Ok(crate::session::SessionResult {
1559 text: final_output.trim().to_string(),
1560 session_id: session.id.clone(),
1561 })
1562}
1563
1564fn is_tool_allowed(tool_name: &str, auto_approve: AutoApprove) -> bool {
1566 match auto_approve {
1567 AutoApprove::All => true,
1568 AutoApprove::Safe | AutoApprove::None => is_safe_tool(tool_name),
1569 }
1570}
1571
/// Classifies a tool as read-only / low-risk for the auto-approve policy.
///
/// Anything not on this allowlist is treated as mutating or side-effecting
/// and is only runnable under `AutoApprove::All`.
fn is_safe_tool(tool_name: &str) -> bool {
    matches!(
        tool_name,
        "read"
            | "list"
            | "glob"
            | "grep"
            | "codesearch"
            | "lsp"
            | "webfetch"
            | "websearch"
            | "todo_read"
            | "skill"
    )
}
1588
/// Builds the tool registry for an agent session, gated by the auto-approve
/// policy.
///
/// Read-only tools are always registered. Mutating / side-effecting tools
/// (write, edit, bash, patch, ...) are added only under `AutoApprove::All`;
/// `Safe` and `None` both receive the same read-only set here — the
/// per-call gate is `is_tool_allowed`.
fn create_filtered_registry(
    provider: Arc<dyn crate::provider::Provider>,
    model: String,
    auto_approve: AutoApprove,
) -> crate::tool::ToolRegistry {
    use crate::tool::*;

    let mut registry = ToolRegistry::new();

    // Read-only / low-risk tools: always available regardless of policy.
    registry.register(Arc::new(file::ReadTool::new()));
    registry.register(Arc::new(file::ListTool::new()));
    registry.register(Arc::new(file::GlobTool::new()));
    registry.register(Arc::new(search::GrepTool::new()));
    registry.register(Arc::new(lsp::LspTool::new()));
    registry.register(Arc::new(webfetch::WebFetchTool::new()));
    registry.register(Arc::new(websearch::WebSearchTool::new()));
    registry.register(Arc::new(codesearch::CodeSearchTool::new()));
    registry.register(Arc::new(todo::TodoReadTool::new()));
    registry.register(Arc::new(skill::SkillTool::new()));

    // Mutating tools: only when every call is pre-approved.
    if matches!(auto_approve, AutoApprove::All) {
        registry.register(Arc::new(file::WriteTool::new()));
        registry.register(Arc::new(edit::EditTool::new()));
        registry.register(Arc::new(bash::BashTool::new()));
        registry.register(Arc::new(multiedit::MultiEditTool::new()));
        registry.register(Arc::new(patch::ApplyPatchTool::new()));
        registry.register(Arc::new(todo::TodoWriteTool::new()));
        registry.register(Arc::new(task::TaskTool::new()));
        registry.register(Arc::new(plan::PlanEnterTool::new()));
        registry.register(Arc::new(plan::PlanExitTool::new()));
        registry.register(Arc::new(rlm::RlmTool::new()));
        // The ralph tool is the only one handed the active provider/model —
        // presumably so it can issue its own completions; TODO confirm
        // against `ralph::RalphTool::with_provider`.
        registry.register(Arc::new(ralph::RalphTool::with_provider(provider, model)));
        registry.register(Arc::new(prd::PrdTool::new()));
        registry.register(Arc::new(confirm_edit::ConfirmEditTool::new()));
        registry.register(Arc::new(confirm_multiedit::ConfirmMultiEditTool::new()));
        registry.register(Arc::new(undo::UndoTool));
    }

    // Always registered — NOTE(review): name suggests a catch-all handler
    // for invalid/unknown tool invocations; confirm against
    // `invalid::InvalidTool`.
    registry.register(Arc::new(invalid::InvalidTool::new()));

    registry
}
1634
/// Spawns a background task that POSTs a heartbeat for this worker to
/// `{server}/v1/opencode/workers/{worker_id}/heartbeat` every 30 seconds.
///
/// The payload always carries worker id, agent name, status, and the number
/// of in-flight tasks (derived from the shared `processing` set). When
/// cognition sharing is enabled, the payload is enriched with a snapshot
/// fetched from the local cognition service; if the server rejects the
/// enriched payload with a 4xx status, the heartbeat is retried without it
/// and enrichment is paused for a cooldown period.
///
/// The task never exits on failure: after `MAX_FAILURES` consecutive
/// failures it logs an error, resets the counter, and keeps trying
/// (reconnection is handled by the SSE loop elsewhere).
fn start_heartbeat(
    client: Client,
    server: String,
    token: Option<String>,
    heartbeat_state: HeartbeatState,
    processing: Arc<Mutex<HashSet<String>>>,
    cognition_config: CognitionHeartbeatConfig,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        let mut consecutive_failures = 0u32;
        const MAX_FAILURES: u32 = 3;
        const HEARTBEAT_INTERVAL_SECS: u64 = 30;
        const COGNITION_RETRY_COOLDOWN_SECS: u64 = 300;
        // When set, cognition enrichment stays disabled until this instant.
        let mut cognition_payload_disabled_until: Option<Instant> = None;

        let mut interval =
            tokio::time::interval(tokio::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
        // Skip (rather than burst) ticks missed while a send was in flight.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        loop {
            interval.tick().await;

            // Derive worker status from the number of in-flight task ids.
            // The lock is released immediately; only the count is kept.
            let active_count = processing.lock().await.len();
            heartbeat_state.set_task_count(active_count).await;

            let status = if active_count > 0 {
                WorkerStatus::Processing
            } else {
                WorkerStatus::Idle
            };
            heartbeat_state.set_status(status).await;

            let url = format!(
                "{}/v1/opencode/workers/{}/heartbeat",
                server, heartbeat_state.worker_id
            );
            let mut req = client.post(&url);

            if let Some(ref t) = token {
                req = req.bearer_auth(t);
            }

            let status_str = heartbeat_state.status.lock().await.as_str().to_string();
            // `base_payload` is kept separate so the retry path below can
            // resend the heartbeat without the cognition section.
            let base_payload = serde_json::json!({
                "worker_id": &heartbeat_state.worker_id,
                "agent_name": &heartbeat_state.agent_name,
                "status": status_str,
                "active_task_count": active_count,
            });
            let mut payload = base_payload.clone();
            let mut included_cognition_payload = false;
            let cognition_payload_allowed = cognition_payload_disabled_until
                .map(|until| Instant::now() >= until)
                .unwrap_or(true);

            // Enrich with cognition data when enabled, not in cooldown, and
            // the local cognition service answered.
            if cognition_config.enabled
                && cognition_payload_allowed
                && let Some(cognition_payload) =
                    fetch_cognition_heartbeat_payload(&client, &cognition_config).await
                && let Some(obj) = payload.as_object_mut()
            {
                obj.insert("cognition".to_string(), cognition_payload);
                included_cognition_payload = true;
            }

            match req.json(&payload).send().await {
                Ok(res) => {
                    if res.status().is_success() {
                        consecutive_failures = 0;
                        tracing::debug!(
                            worker_id = %heartbeat_state.worker_id,
                            status = status_str,
                            active_tasks = active_count,
                            "Heartbeat sent successfully"
                        );
                    } else if included_cognition_payload && res.status().is_client_error() {
                        // A 4xx on an enriched payload is treated as a schema
                        // rejection of the cognition section: retry bare.
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            "Heartbeat cognition payload rejected, retrying without cognition payload"
                        );

                        let mut retry_req = client.post(&url);
                        if let Some(ref t) = token {
                            retry_req = retry_req.bearer_auth(t);
                        }

                        match retry_req.json(&base_payload).send().await {
                            Ok(retry_res) if retry_res.status().is_success() => {
                                // Bare payload worked: pause enrichment for
                                // the cooldown window instead of failing on
                                // every subsequent tick.
                                cognition_payload_disabled_until = Some(
                                    Instant::now()
                                        + Duration::from_secs(COGNITION_RETRY_COOLDOWN_SECS),
                                );
                                consecutive_failures = 0;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    retry_after_secs = COGNITION_RETRY_COOLDOWN_SECS,
                                    "Paused cognition heartbeat payload after schema rejection"
                                );
                            }
                            Ok(retry_res) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    status = %retry_res.status(),
                                    failures = consecutive_failures,
                                    "Heartbeat failed even after retry without cognition payload"
                                );
                            }
                            Err(e) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    error = %e,
                                    failures = consecutive_failures,
                                    "Heartbeat retry without cognition payload failed"
                                );
                            }
                        }
                    } else {
                        consecutive_failures += 1;
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            failures = consecutive_failures,
                            "Heartbeat failed"
                        );
                    }
                }
                Err(e) => {
                    consecutive_failures += 1;
                    tracing::warn!(
                        worker_id = %heartbeat_state.worker_id,
                        error = %e,
                        failures = consecutive_failures,
                        "Heartbeat request failed"
                    );
                }
            }

            if consecutive_failures >= MAX_FAILURES {
                tracing::error!(
                    worker_id = %heartbeat_state.worker_id,
                    failures = consecutive_failures,
                    "Heartbeat failed {} consecutive times - worker will continue running and attempt reconnection via SSE loop",
                    MAX_FAILURES
                );
                // Reset so the error is logged once per failure streak,
                // not on every subsequent tick.
                consecutive_failures = 0;
            }
        }
    })
}
1794
/// Fetches a cognition snapshot from the local cognition service and shapes
/// it into the JSON object embedded in heartbeats.
///
/// Queries `{source_base_url}/v1/cognition/status`; when
/// `include_thought_summary` is set it additionally queries
/// `/v1/cognition/snapshots/latest` and merges the latest thought (truncated
/// via `trim_for_heartbeat`) plus optional model/source metadata into the
/// payload. Every request is bounded by `request_timeout_ms`.
///
/// Returns `None` on any timeout, transport error, non-success status, or
/// deserialization failure of the status endpoint — the heartbeat then goes
/// out without a cognition section. Snapshot-fetch failures are softer: the
/// status-derived payload is still returned without the thought fields.
async fn fetch_cognition_heartbeat_payload(
    client: &Client,
    config: &CognitionHeartbeatConfig,
) -> Option<serde_json::Value> {
    let status_url = format!("{}/v1/cognition/status", config.source_base_url);
    // Outer `timeout` guards the whole send; `.ok()?` collapses both the
    // timeout and the transport error into `None`.
    let status_res = tokio::time::timeout(
        Duration::from_millis(config.request_timeout_ms),
        client.get(status_url).send(),
    )
    .await
    .ok()?
    .ok()?;

    if !status_res.status().is_success() {
        return None;
    }

    let status: CognitionStatusSnapshot = status_res.json().await.ok()?;
    payload_note: {
        // no-op label removed: see payload construction below
        break 'payload_note;
    }
    let mut payload = serde_json::json!({
        "running": status.running,
        "last_tick_at": status.last_tick_at,
        "active_persona_count": status.active_persona_count,
        "events_buffered": status.events_buffered,
        "snapshots_buffered": status.snapshots_buffered,
        "loop_interval_ms": status.loop_interval_ms,
    });

    if config.include_thought_summary {
        let snapshot_url = format!("{}/v1/cognition/snapshots/latest", config.source_base_url);
        // Best-effort: timeout/transport errors fold to `None` and the
        // snapshot fields are simply omitted.
        let snapshot_res = tokio::time::timeout(
            Duration::from_millis(config.request_timeout_ms),
            client.get(snapshot_url).send(),
        )
        .await
        .ok()
        .and_then(Result::ok);

        if let Some(snapshot_res) = snapshot_res
            && snapshot_res.status().is_success()
            && let Ok(snapshot) = snapshot_res.json::<CognitionLatestSnapshot>().await
            && let Some(obj) = payload.as_object_mut()
        {
            obj.insert(
                "latest_snapshot_at".to_string(),
                serde_json::Value::String(snapshot.generated_at),
            );
            obj.insert(
                "latest_thought".to_string(),
                serde_json::Value::String(trim_for_heartbeat(
                    &snapshot.summary,
                    config.summary_max_chars,
                )),
            );
            // `metadata` appears to be a JSON map with optional string
            // entries — NOTE(review): confirm against
            // `CognitionLatestSnapshot`'s definition.
            if let Some(model) = snapshot
                .metadata
                .get("model")
                .and_then(serde_json::Value::as_str)
            {
                obj.insert(
                    "latest_thought_model".to_string(),
                    serde_json::Value::String(model.to_string()),
                );
            }
            if let Some(source) = snapshot
                .metadata
                .get("source")
                .and_then(serde_json::Value::as_str)
            {
                obj.insert(
                    "latest_thought_source".to_string(),
                    serde_json::Value::String(source.to_string()),
                );
            }
        }
    }

    Some(payload)
}
1873
/// Prepares free-form text for a heartbeat payload: whitespace-trims it and,
/// when it exceeds `max_chars` characters, truncates it and appends `"..."`.
///
/// Truncation counts Unicode scalar values (`char`s), not bytes, so
/// multi-byte text is never split mid-character.
fn trim_for_heartbeat(input: &str, max_chars: usize) -> String {
    if input.chars().count() <= max_chars {
        input.trim().to_string()
    } else {
        let truncated: String = input.chars().take(max_chars).collect();
        format!("{truncated}...").trim().to_string()
    }
}
1886
/// Reads a boolean flag from the environment.
///
/// Accepts (case-insensitively) `1`/`true`/`yes`/`on` as true and
/// `0`/`false`/`no`/`off` as false; any other value, or an unset variable,
/// yields `default`.
fn env_bool(name: &str, default: bool) -> bool {
    let Ok(raw) = std::env::var(name) else {
        return default;
    };
    match raw.to_ascii_lowercase().as_str() {
        "1" | "true" | "yes" | "on" => true,
        "0" | "false" | "no" | "off" => false,
        _ => default,
    }
}
1897
/// Reads `name` from the environment and parses it as `usize`; returns
/// `default` when the variable is unset or not a valid non-negative integer.
fn env_usize(name: &str, default: usize) -> usize {
    match std::env::var(name) {
        Ok(raw) => raw.parse().unwrap_or(default),
        Err(_) => default,
    }
}
1904
/// Reads `name` from the environment and parses it as `u64`; returns
/// `default` when the variable is unset or not a valid non-negative integer.
fn env_u64(name: &str, default: u64) -> u64 {
    std::env::var(name)
        .ok()
        .map_or(default, |raw| raw.parse().unwrap_or(default))
}