1use crate::bus::AgentBus;
4use crate::cli::A2aArgs;
5use crate::provider::ProviderRegistry;
6use crate::session::Session;
7use crate::swarm::{DecompositionStrategy, SwarmConfig, SwarmExecutor};
8use crate::tui::swarm_view::SwarmEvent;
9use anyhow::Result;
10use futures::StreamExt;
11use reqwest::Client;
12use serde::Deserialize;
13use std::collections::HashMap;
14use std::collections::HashSet;
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::sync::{Mutex, mpsc};
18use tokio::task::JoinHandle;
19use tokio::time::Instant;
20
/// Coarse activity state of this worker, reported via heartbeats.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkerStatus {
    /// No task is currently being executed.
    Idle,
    /// At least one task is currently in flight.
    Processing,
}
27
28impl WorkerStatus {
29 pub fn as_str(&self) -> &'static str {
30 match self {
31 WorkerStatus::Idle => "idle",
32 WorkerStatus::Processing => "processing",
33 }
34 }
35}
36
/// Cloneable handle on this worker's live state; the heartbeat loop reads
/// it while task handlers update it. Clones share the same `Arc`-backed
/// interior, so updates are visible everywhere.
#[derive(Clone)]
pub struct HeartbeatState {
    // Worker identifier; private, set once at construction.
    worker_id: String,
    pub agent_name: String,
    /// Current coarse status (idle/processing), shared across tasks.
    pub status: Arc<Mutex<WorkerStatus>>,
    /// Number of tasks currently being processed.
    pub active_task_count: Arc<Mutex<usize>>,
}
45
46impl HeartbeatState {
47 pub fn new(worker_id: String, agent_name: String) -> Self {
48 Self {
49 worker_id,
50 agent_name,
51 status: Arc::new(Mutex::new(WorkerStatus::Idle)),
52 active_task_count: Arc::new(Mutex::new(0)),
53 }
54 }
55
56 pub async fn set_status(&self, status: WorkerStatus) {
57 *self.status.lock().await = status;
58 }
59
60 pub async fn set_task_count(&self, count: usize) {
61 *self.active_task_count.lock().await = count;
62 }
63}
64
/// Settings controlling whether and how this worker shares its local
/// cognition state upstream with each heartbeat.
/// Populated from environment variables by [`CognitionHeartbeatConfig::from_env`].
#[derive(Clone, Debug)]
pub struct CognitionHeartbeatConfig {
    /// Master switch for cognition sharing.
    pub enabled: bool,
    /// Base URL of the local cognition source (trailing '/' stripped).
    pub source_base_url: String,
    /// Optional bearer token for the cognition source.
    pub token: Option<String>,
    /// Provider name reported alongside cognition data.
    pub provider_name: String,
    /// Heartbeat period in seconds (clamped to >= 5 by `from_env`).
    pub interval_secs: u64,
    /// Whether to attach a thought summary to heartbeats.
    pub include_thought_summary: bool,
    /// Maximum character length of the attached thought summary.
    pub summary_max_chars: usize,
    /// Per-request timeout in milliseconds (clamped to >= 250 by `from_env`).
    pub request_timeout_ms: u64,
}
76
77impl CognitionHeartbeatConfig {
78 pub fn from_env() -> Self {
79 let source_base_url = std::env::var("CODETETHER_WORKER_COGNITION_SOURCE_URL")
80 .unwrap_or_else(|_| "http://127.0.0.1:4096".to_string())
81 .trim_end_matches('/')
82 .to_string();
83
84 Self {
85 enabled: env_bool("CODETETHER_WORKER_COGNITION_SHARE_ENABLED", true),
86 source_base_url,
87 include_thought_summary: env_bool("CODETETHER_WORKER_COGNITION_INCLUDE_THOUGHTS", true),
88 summary_max_chars: env_usize("CODETETHER_WORKER_COGNITION_THOUGHT_MAX_CHARS", 480),
89 request_timeout_ms: env_u64("CODETETHER_WORKER_COGNITION_TIMEOUT_MS", 2_500).max(250),
90 interval_secs: env_u64("CODETETHER_WORKER_COGNITION_INTERVAL_SECS", 30).max(5),
91 provider_name: std::env::var("CODETETHER_WORKER_COGNITION_PROVIDER")
92 .unwrap_or_else(|_| "cognition".to_string()),
93 token: std::env::var("CODETETHER_WORKER_COGNITION_TOKEN").ok(),
94 }
95 }
96}
97
/// Wire shape of the cognition source's status response.
// NOTE(review): field names mirror the remote JSON; confirm against the
// cognition service's API before relying on individual fields.
#[derive(Debug, Deserialize)]
struct CognitionStatusSnapshot {
    running: bool,
    #[serde(default)]
    last_tick_at: Option<String>,
    #[serde(default)]
    active_persona_count: usize,
    #[serde(default)]
    events_buffered: usize,
    #[serde(default)]
    snapshots_buffered: usize,
    #[serde(default)]
    loop_interval_ms: u64,
}
112
/// Wire shape of the cognition source's latest-snapshot response.
#[derive(Debug, Deserialize)]
struct CognitionLatestSnapshot {
    // Timestamp string as sent by the source — presumably RFC3339; TODO confirm.
    generated_at: String,
    summary: String,
    #[serde(default)]
    metadata: HashMap<String, serde_json::Value>,
}
120
/// Runs the A2A worker loop against `args.server` without an attached
/// local worker server: register, drain queued tasks, then repeatedly
/// stream tasks over SSE, reconnecting with a 5s delay whenever the
/// stream ends or errors. Never returns under normal operation.
pub async fn run(args: A2aArgs) -> Result<()> {
    let server = args.server.trim_end_matches('/');
    // Worker name defaults to one derived from the process id.
    let name = args
        .name
        .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
    let worker_id = generate_worker_id();

    // Workspaces come from a comma-separated CLI flag, defaulting to the
    // current directory. NOTE(review): `current_dir().unwrap()` panics if
    // the cwd is missing/unreadable — tolerated at startup, but confirm.
    let codebases: Vec<String> = args
        .workspaces
        .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
        .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);

    tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
    tracing::info!("Server: {}", server);
    tracing::info!("Workspaces: {:?}", codebases);

    // Shared so the background workspace-sync task can update the list.
    let shared_codebases: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(codebases));

    let client = Client::new();
    // Ids of tasks currently being handled; guards against double-dispatch.
    let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
    let cognition_heartbeat = CognitionHeartbeatConfig::from_env();
    if cognition_heartbeat.enabled {
        tracing::info!(
            source = %cognition_heartbeat.source_base_url,
            include_thoughts = cognition_heartbeat.include_thought_summary,
            max_chars = cognition_heartbeat.summary_max_chars,
            timeout_ms = cognition_heartbeat.request_timeout_ms,
            "Cognition heartbeat sharing enabled (set CODETETHER_WORKER_COGNITION_SHARE_ENABLED=false to disable)"
        );
    } else {
        tracing::warn!(
            "Cognition heartbeat sharing disabled; worker thought state will not be shared upstream"
        );
    }

    // Unknown flag values fall back to the safest mode: no auto-approval.
    let auto_approve = match args.auto_approve.as_str() {
        "all" => AutoApprove::All,
        "safe" => AutoApprove::Safe,
        _ => AutoApprove::None,
    };

    let heartbeat_state = HeartbeatState::new(worker_id.clone(), name.clone());

    let bus = AgentBus::new().into_arc();

    // Mirror bus traffic to S3 in the background (best-effort sink).
    crate::bus::s3_sink::spawn_bus_s3_sink(bus.clone());

    {
        let handle = bus.handle(&worker_id);
        handle.announce_ready(worker_capabilities());
    }

    // The initial registration must succeed before entering the loop.
    {
        let codebases = shared_codebases.lock().await.clone();
        register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;
    }

    // Pick up tasks that were queued while this worker was offline.
    fetch_pending_tasks(
        &client,
        server,
        &args.token,
        &worker_id,
        &processing,
        &auto_approve,
        &bus,
    )
    .await?;

    // Detached background task keeping the advertised workspace list fresh.
    let _workspace_sync_handle = start_workspace_sync(
        client.clone(),
        server.to_string(),
        args.token.clone(),
        shared_codebases.clone(),
    );

    // Reconnect loop: re-register, start a per-connection heartbeat,
    // stream until the connection drops, then back off and retry.
    loop {
        let codebases = shared_codebases.lock().await.clone();

        // Re-registration failures are non-fatal; the stream may still work.
        if let Err(e) =
            register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await
        {
            tracing::warn!("Failed to re-register worker on reconnection: {}", e);
        }

        let heartbeat_handle = start_heartbeat(
            client.clone(),
            server.to_string(),
            args.token.clone(),
            heartbeat_state.clone(),
            processing.clone(),
            cognition_heartbeat.clone(),
        );

        // No CloudEvent notification channel in this mode (contrast with
        // `run_with_state`, which passes Some(rx)).
        match connect_stream(
            &client,
            server,
            &args.token,
            &worker_id,
            &name,
            &codebases,
            &processing,
            &auto_approve,
            &bus,
            None,
        )
        .await
        {
            Ok(()) => {
                tracing::warn!("Stream ended, reconnecting...");
            }
            Err(e) => {
                tracing::error!("Stream error: {}, reconnecting...", e);
            }
        }

        // Stop the per-connection heartbeat before reconnecting.
        heartbeat_handle.abort();
        tracing::debug!("Heartbeat cancelled for reconnection");

        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
    }
}
255
/// Same reconnect loop as [`run`], but wired to a local
/// `WorkerServerState` so a co-located worker HTTP server can observe the
/// worker id, heartbeat state and connection status, and can push task
/// notifications (CloudEvents) into the stream loop.
// NOTE(review): large overlap with `run`; consider extracting the shared
// setup/loop into a helper to keep the two paths from drifting.
pub async fn run_with_state(
    args: A2aArgs,
    server_state: crate::worker_server::WorkerServerState,
) -> Result<()> {
    let server = args.server.trim_end_matches('/');
    let name = args
        .name
        .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
    let worker_id = generate_worker_id();

    // Expose the generated worker id to the local server immediately.
    server_state.set_worker_id(worker_id.clone()).await;

    // Workspaces from the CLI flag, defaulting to the current directory.
    // NOTE(review): `current_dir().unwrap()` panics on unreadable cwd.
    let codebases: Vec<String> = args
        .workspaces
        .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
        .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);

    tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
    tracing::info!("Server: {}", server);
    tracing::info!("Workspaces: {:?}", codebases);

    let shared_codebases: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(codebases));

    let client = Client::new();
    // Ids of tasks currently being handled; guards against double-dispatch.
    let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
    let cognition_heartbeat = CognitionHeartbeatConfig::from_env();
    if cognition_heartbeat.enabled {
        tracing::info!(
            source = %cognition_heartbeat.source_base_url,
            include_thoughts = cognition_heartbeat.include_thought_summary,
            max_chars = cognition_heartbeat.summary_max_chars,
            timeout_ms = cognition_heartbeat.request_timeout_ms,
            "Cognition heartbeat sharing enabled (set CODETETHER_WORKER_COGNITION_SHARE_ENABLED=false to disable)"
        );
    } else {
        tracing::warn!(
            "Cognition heartbeat sharing disabled; worker thought state will not be shared upstream"
        );
    }

    // Unknown flag values fall back to the safest mode: no auto-approval.
    let auto_approve = match args.auto_approve.as_str() {
        "all" => AutoApprove::All,
        "safe" => AutoApprove::Safe,
        _ => AutoApprove::None,
    };

    let heartbeat_state = HeartbeatState::new(worker_id.clone(), name.clone());

    // Share the heartbeat state with the local server for status reporting.
    server_state
        .set_heartbeat_state(Arc::new(heartbeat_state.clone()))
        .await;

    let bus = AgentBus::new().into_arc();

    // Mirror bus traffic to S3 in the background (best-effort sink).
    crate::bus::s3_sink::spawn_bus_s3_sink(bus.clone());

    {
        let handle = bus.handle(&worker_id);
        handle.announce_ready(worker_capabilities());
    }

    // The initial registration must succeed before entering the loop.
    {
        let codebases = shared_codebases.lock().await.clone();
        register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;
    }

    server_state.set_connected(true).await;

    // Pick up tasks that were queued while this worker was offline.
    fetch_pending_tasks(
        &client,
        server,
        &args.token,
        &worker_id,
        &processing,
        &auto_approve,
        &bus,
    )
    .await?;

    // Detached background task keeping the advertised workspace list fresh.
    let _workspace_sync_handle = start_workspace_sync(
        client.clone(),
        server.to_string(),
        args.token.clone(),
        shared_codebases.clone(),
    );

    loop {
        let codebases = shared_codebases.lock().await.clone();

        // Fresh notification channel per connection; the local server uses
        // it to nudge the stream loop when a CloudEvent announces a task.
        let (task_notify_tx, task_notify_rx) = mpsc::channel::<String>(32);
        server_state
            .set_task_notification_channel(task_notify_tx)
            .await;

        server_state.set_connected(true).await;

        // Re-registration failures are non-fatal; the stream may still work.
        if let Err(e) =
            register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await
        {
            tracing::warn!("Failed to re-register worker on reconnection: {}", e);
        }

        let heartbeat_handle = start_heartbeat(
            client.clone(),
            server.to_string(),
            args.token.clone(),
            heartbeat_state.clone(),
            processing.clone(),
            cognition_heartbeat.clone(),
        );

        match connect_stream(
            &client,
            server,
            &args.token,
            &worker_id,
            &name,
            &codebases,
            &processing,
            &auto_approve,
            &bus,
            Some(task_notify_rx),
        )
        .await
        {
            Ok(()) => {
                tracing::warn!("Stream ended, reconnecting...");
            }
            Err(e) => {
                tracing::error!("Stream error: {}, reconnecting...", e);
            }
        }

        server_state.set_connected(false).await;

        // Stop the per-connection heartbeat before reconnecting.
        heartbeat_handle.abort();
        tracing::debug!("Heartbeat cancelled for reconnection");

        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
    }
}
418
419pub fn generate_worker_id() -> String {
420 format!(
421 "wrk_{}_{:x}",
422 chrono::Utc::now().timestamp(),
423 rand::random::<u64>()
424 )
425}
426
/// Tool-approval policy parsed from the `--auto-approve` CLI flag.
#[derive(Debug, Clone, Copy)]
enum AutoApprove {
    /// Approve every tool invocation automatically.
    All,
    /// Approve only invocations considered safe.
    Safe,
    /// Require explicit approval (default for unrecognized flag values).
    None,
}
433
/// Default public A2A server endpoint used when none is supplied.
pub const DEFAULT_A2A_SERVER_URL: &str = "https://api.codetether.run";

/// Capabilities every worker advertises at registration time; extended at
/// runtime (e.g. "knative") by `worker_capabilities`.
const BASE_WORKER_CAPABILITIES: &[&str] = &[
    "ralph", "swarm", "rlm", "a2a", "mcp", "grpc", "grpc-web", "jsonrpc",
];
441
442fn worker_capabilities() -> Vec<String> {
443 let mut capabilities: Vec<String> = BASE_WORKER_CAPABILITIES
444 .iter()
445 .map(|capability| capability.to_string())
446 .collect();
447
448 let is_knative = std::env::var("KNATIVE_SERVICE")
449 .map(|value| {
450 let normalized = value.trim().to_lowercase();
451 normalized == "1" || normalized == "true" || normalized == "yes"
452 })
453 .unwrap_or(false);
454 if is_knative {
455 capabilities.push("knative".to_string());
456 }
457
458 capabilities
459}
460
461fn task_value<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a serde_json::Value> {
462 task.get("task")
463 .and_then(|t| t.get(key))
464 .or_else(|| task.get(key))
465}
466
467fn task_str<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a str> {
468 task_value(task, key).and_then(|v| v.as_str())
469}
470
471fn task_metadata(task: &serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
472 task_value(task, "metadata")
473 .and_then(|m| m.as_object())
474 .cloned()
475 .unwrap_or_default()
476}
477
/// Normalizes a model reference like `provider:model` to `provider/model`.
///
/// A reference that already contains a '/' (or has no ':') passes through
/// unchanged; only the first ':' is rewritten.
fn model_ref_to_provider_model(model: &str) -> String {
    let needs_rewrite = model.contains(':') && !model.contains('/');
    if needs_rewrite {
        model.replacen(':', "/", 1)
    } else {
        model.to_string()
    }
}
488
/// Ordered provider preference list for a routing tier.
///
/// `"fast"`/`"quick"` and `"heavy"`/`"deep"` select dedicated orderings;
/// any other (or missing) tier falls back to the balanced ordering.
fn provider_preferences_for_tier(model_tier: Option<&str>) -> &'static [&'static str] {
    const FAST: &[&str] = &[
        "zai", "openai", "github-copilot", "moonshotai", "openrouter", "novita", "google",
        "anthropic",
    ];
    const HEAVY: &[&str] = &[
        "zai", "anthropic", "openai", "github-copilot", "moonshotai", "openrouter", "novita",
        "google",
    ];
    const BALANCED: &[&str] = &[
        "zai", "openai", "github-copilot", "anthropic", "moonshotai", "openrouter", "novita",
        "google",
    ];

    match model_tier.unwrap_or("balanced") {
        "fast" | "quick" => FAST,
        "heavy" | "deep" => HEAVY,
        _ => BALANCED,
    }
}
523
524fn choose_provider_for_tier<'a>(providers: &'a [&'a str], model_tier: Option<&str>) -> &'a str {
525 for preferred in provider_preferences_for_tier(model_tier) {
526 if let Some(found) = providers.iter().copied().find(|p| *p == *preferred) {
527 return found;
528 }
529 }
530 if let Some(found) = providers.iter().copied().find(|p| *p == "zai") {
531 return found;
532 }
533 providers[0]
534}
535
/// Default model id for a provider at a given routing tier.
///
/// Tiers: `"fast"`/`"quick"` pick cheap/low-latency models,
/// `"heavy"`/`"deep"` pick the strongest, anything else (including
/// `None`) uses the balanced defaults. Unknown providers get "glm-5".
fn default_model_for_provider(provider: &str, model_tier: Option<&str>) -> String {
    let tier = model_tier.unwrap_or("balanced");
    // Single tuple match: tier-specific arms come before the tier-agnostic
    // fallbacks for the same provider.
    let model = match (tier, provider) {
        (_, "moonshotai") => "kimi-k2.5",
        ("fast" | "quick", "anthropic") => "claude-haiku-4-5",
        (_, "anthropic") => "claude-sonnet-4-20250514",
        ("fast" | "quick", "openai") => "gpt-4o-mini",
        ("heavy" | "deep", "openai") => "o3",
        (_, "openai") => "gpt-4o",
        ("fast" | "quick", "google") => "gemini-2.5-flash",
        (_, "google") => "gemini-2.5-pro",
        (_, "zhipuai" | "zai") => "glm-5",
        (_, "openrouter") => "z-ai/glm-5",
        (_, "novita") => "qwen/qwen3-coder-next",
        ("heavy" | "deep", "bedrock") => "us.anthropic.claude-sonnet-4-20250514-v1:0",
        (_, "bedrock") => "amazon.nova-lite-v1:0",
        _ => "glm-5",
    };
    model.to_string()
}
573
/// Whether the model family expects temperature pinned to 1.0
/// (Kimi K2, GLM and MiniMax families, matched case-insensitively).
fn prefers_temperature_one(model: &str) -> bool {
    let normalized = model.to_ascii_lowercase();
    ["kimi-k2", "glm-", "minimax"]
        .iter()
        .any(|needle| normalized.contains(needle))
}
578
/// True when the agent type names the multi-agent swarm executor
/// ("swarm", "parallel" or "multi-agent", case/whitespace-insensitive).
fn is_swarm_agent(agent_type: &str) -> bool {
    let normalized = agent_type.trim().to_ascii_lowercase();
    normalized == "swarm" || normalized == "parallel" || normalized == "multi-agent"
}
585
586fn metadata_lookup<'a>(
587 metadata: &'a serde_json::Map<String, serde_json::Value>,
588 key: &str,
589) -> Option<&'a serde_json::Value> {
590 metadata
591 .get(key)
592 .or_else(|| {
593 metadata
594 .get("routing")
595 .and_then(|v| v.as_object())
596 .and_then(|obj| obj.get(key))
597 })
598 .or_else(|| {
599 metadata
600 .get("swarm")
601 .and_then(|v| v.as_object())
602 .and_then(|obj| obj.get(key))
603 })
604}
605
606fn metadata_str(
607 metadata: &serde_json::Map<String, serde_json::Value>,
608 keys: &[&str],
609) -> Option<String> {
610 for key in keys {
611 if let Some(value) = metadata_lookup(metadata, key).and_then(|v| v.as_str()) {
612 let trimmed = value.trim();
613 if !trimmed.is_empty() {
614 return Some(trimmed.to_string());
615 }
616 }
617 }
618 None
619}
620
621fn metadata_usize(
622 metadata: &serde_json::Map<String, serde_json::Value>,
623 keys: &[&str],
624) -> Option<usize> {
625 for key in keys {
626 if let Some(value) = metadata_lookup(metadata, key) {
627 if let Some(v) = value.as_u64() {
628 return usize::try_from(v).ok();
629 }
630 if let Some(v) = value.as_i64() {
631 if v >= 0 {
632 return usize::try_from(v as u64).ok();
633 }
634 }
635 if let Some(v) = value.as_str() {
636 if let Ok(parsed) = v.trim().parse::<usize>() {
637 return Some(parsed);
638 }
639 }
640 }
641 }
642 None
643}
644
645fn metadata_u64(
646 metadata: &serde_json::Map<String, serde_json::Value>,
647 keys: &[&str],
648) -> Option<u64> {
649 for key in keys {
650 if let Some(value) = metadata_lookup(metadata, key) {
651 if let Some(v) = value.as_u64() {
652 return Some(v);
653 }
654 if let Some(v) = value.as_i64() {
655 if v >= 0 {
656 return Some(v as u64);
657 }
658 }
659 if let Some(v) = value.as_str() {
660 if let Ok(parsed) = v.trim().parse::<u64>() {
661 return Some(parsed);
662 }
663 }
664 }
665 }
666 None
667}
668
669fn metadata_bool(
670 metadata: &serde_json::Map<String, serde_json::Value>,
671 keys: &[&str],
672) -> Option<bool> {
673 for key in keys {
674 if let Some(value) = metadata_lookup(metadata, key) {
675 if let Some(v) = value.as_bool() {
676 return Some(v);
677 }
678 if let Some(v) = value.as_str() {
679 match v.trim().to_ascii_lowercase().as_str() {
680 "1" | "true" | "yes" | "on" => return Some(true),
681 "0" | "false" | "no" | "off" => return Some(false),
682 _ => {}
683 }
684 }
685 }
686 }
687 None
688}
689
690fn parse_swarm_strategy(
691 metadata: &serde_json::Map<String, serde_json::Value>,
692) -> DecompositionStrategy {
693 match metadata_str(
694 metadata,
695 &[
696 "decomposition_strategy",
697 "swarm_strategy",
698 "strategy",
699 "swarm_decomposition",
700 ],
701 )
702 .as_deref()
703 .map(|s| s.to_ascii_lowercase())
704 .as_deref()
705 {
706 Some("none") | Some("single") => DecompositionStrategy::None,
707 Some("domain") | Some("by_domain") => DecompositionStrategy::ByDomain,
708 Some("data") | Some("by_data") => DecompositionStrategy::ByData,
709 Some("stage") | Some("by_stage") => DecompositionStrategy::ByStage,
710 _ => DecompositionStrategy::Automatic,
711 }
712}
713
714async fn resolve_swarm_model(
715 explicit_model: Option<String>,
716 model_tier: Option<&str>,
717) -> Option<String> {
718 if let Some(model) = explicit_model {
719 if !model.trim().is_empty() {
720 return Some(model);
721 }
722 }
723
724 let registry = ProviderRegistry::from_vault().await.ok()?;
725 let providers = registry.list();
726 if providers.is_empty() {
727 return None;
728 }
729 let provider = choose_provider_for_tier(providers.as_slice(), model_tier);
730 let model = default_model_for_provider(provider, model_tier);
731 Some(format!("{}/{}", provider, model))
732}
733
734fn format_swarm_event_for_output(event: &SwarmEvent) -> Option<String> {
735 match event {
736 SwarmEvent::Started {
737 task,
738 total_subtasks,
739 } => Some(format!(
740 "[swarm] started task={} planned_subtasks={}",
741 task, total_subtasks
742 )),
743 SwarmEvent::StageComplete {
744 stage,
745 completed,
746 failed,
747 } => Some(format!(
748 "[swarm] stage={} completed={} failed={}",
749 stage, completed, failed
750 )),
751 SwarmEvent::SubTaskUpdate { id, status, .. } => Some(format!(
752 "[swarm] subtask id={} status={}",
753 &id.chars().take(8).collect::<String>(),
754 format!("{status:?}").to_ascii_lowercase()
755 )),
756 SwarmEvent::AgentToolCall {
757 subtask_id,
758 tool_name,
759 } => Some(format!(
760 "[swarm] subtask id={} tool={}",
761 &subtask_id.chars().take(8).collect::<String>(),
762 tool_name
763 )),
764 SwarmEvent::AgentError { subtask_id, error } => Some(format!(
765 "[swarm] subtask id={} error={}",
766 &subtask_id.chars().take(8).collect::<String>(),
767 error
768 )),
769 SwarmEvent::Complete { success, stats } => Some(format!(
770 "[swarm] complete success={} subtasks={} speedup={:.2}",
771 success,
772 stats.subagents_completed + stats.subagents_failed,
773 stats.speedup_factor
774 )),
775 SwarmEvent::Error(err) => Some(format!("[swarm] error message={}", err)),
776 _ => None,
777 }
778}
779
/// Registers (or re-registers) this worker with the A2A server,
/// advertising capabilities, discovered provider models, hostname,
/// workspaces, and the built-in agent catalogue.
///
/// A non-2xx registration response only logs a warning and returns
/// `Ok(())` (the worker retries on each reconnect cycle); only
/// transport-level errors propagate as `Err`.
pub async fn register_worker(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
) -> Result<()> {
    // Model discovery is best-effort: registration proceeds without it.
    let models = match load_provider_models().await {
        Ok(m) => m,
        Err(e) => {
            tracing::warn!(
                "Failed to load provider models: {}, proceeding without model info",
                e
            );
            HashMap::new()
        }
    };

    let mut req = client.post(format!("{}/v1/agent/workers/register", server));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    // Flatten provider -> models into one array of objects keyed as
    // "provider/model", attaching per-million token costs when known.
    let models_array: Vec<serde_json::Value> = models
        .iter()
        .flat_map(|(provider, model_infos)| {
            model_infos.iter().map(move |m| {
                let mut obj = serde_json::json!({
                    "id": format!("{}/{}", provider, m.id),
                    "name": &m.id,
                    "provider": provider,
                    "provider_id": provider,
                });
                if let Some(input_cost) = m.input_cost_per_million {
                    obj["input_cost_per_million"] = serde_json::json!(input_cost);
                }
                if let Some(output_cost) = m.output_cost_per_million {
                    obj["output_cost_per_million"] = serde_json::json!(output_cost);
                }
                obj
            })
        })
        .collect();

    tracing::info!(
        "Registering worker with {} models from {} providers",
        models_array.len(),
        models.len()
    );

    // HOSTNAME (unix) / COMPUTERNAME (windows); "unknown" as last resort.
    let hostname = std::env::var("HOSTNAME")
        .or_else(|_| std::env::var("COMPUTERNAME"))
        .unwrap_or_else(|_| "unknown".to_string());
    let k8s_node_name = std::env::var("K8S_NODE_NAME")
        .ok()
        .map(|value| value.trim().to_string())
        .filter(|value| !value.is_empty());

    // Advertise the built-in agent catalogue so the server can route by name.
    let registry = crate::agent::AgentRegistry::with_builtins();
    let agent_defs: Vec<serde_json::Value> = registry
        .list()
        .iter()
        .map(|info| {
            serde_json::json!({
                "name": info.name,
                "description": info.description,
                "mode": format!("{:?}", info.mode).to_lowercase(),
                "native": info.native,
                "hidden": info.hidden,
                "model": info.model,
                "temperature": info.temperature,
                "top_p": info.top_p,
                "max_steps": info.max_steps,
            })
        })
        .collect();

    let res = req
        .json(&serde_json::json!({
            "worker_id": worker_id,
            "name": name,
            "capabilities": worker_capabilities(),
            "hostname": hostname,
            "k8s_node_name": k8s_node_name,
            "models": models_array,
            "workspaces": codebases,
            "agents": agent_defs,
        }))
        .send()
        .await?;

    if res.status().is_success() {
        tracing::info!("Worker registered successfully");
    } else {
        // Tolerated: registration is retried on every reconnect cycle.
        tracing::warn!("Failed to register worker: {}", res.status());
    }

    Ok(())
}
886
/// Loads every reachable provider's model list, keyed by provider name.
///
/// Prefers Vault-sourced credentials; falls back to local config/env when
/// Vault is unreachable or yields zero providers. Providers whose model
/// listing fails or comes back empty are silently omitted from the map.
async fn load_provider_models() -> Result<HashMap<String, Vec<crate::provider::ModelInfo>>> {
    let registry = match ProviderRegistry::from_vault().await {
        Ok(r) if !r.list().is_empty() => {
            tracing::info!("Loaded {} providers from Vault", r.list().len());
            r
        }
        Ok(_) => {
            tracing::warn!("Vault returned 0 providers, falling back to config/env vars");
            fallback_registry().await?
        }
        Err(e) => {
            tracing::warn!("Vault unreachable ({}), falling back to config/env vars", e);
            fallback_registry().await?
        }
    };

    let mut models_by_provider: HashMap<String, Vec<crate::provider::ModelInfo>> = HashMap::new();

    for provider_name in registry.list() {
        if let Some(provider) = registry.get(provider_name) {
            match provider.list_models().await {
                Ok(models) => {
                    if !models.is_empty() {
                        tracing::debug!("Provider {}: {} models", provider_name, models.len());
                        models_by_provider.insert(provider_name.to_string(), models);
                    }
                }
                Err(e) => {
                    // Listing failures are expected for misconfigured
                    // providers; log at debug and keep going.
                    tracing::debug!("Failed to list models for {}: {}", provider_name, e);
                }
            }
        }
    }

    Ok(models_by_provider)
}
927
/// Builds a provider registry from the local config (or env vars) when
/// Vault is unavailable or empty; an unloadable config file falls back to
/// `Config::default()`.
async fn fallback_registry() -> Result<ProviderRegistry> {
    let config = crate::config::Config::load().await.unwrap_or_default();
    ProviderRegistry::from_config(&config).await
}
933
/// One-shot sweep of the server's pending task queue, run at startup.
///
/// Each not-yet-reserved task id is added to `processing` and handled on
/// its own spawned tokio task; a non-success HTTP status is treated as
/// "nothing to do" rather than an error.
// NOTE(review): this largely duplicates `poll_pending_tasks` +
// `spawn_task_handler` (which read the id from `task.id` as well as the
// top level); consider consolidating the two paths.
async fn fetch_pending_tasks(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
    bus: &Arc<AgentBus>,
) -> Result<()> {
    tracing::info!("Checking for pending tasks...");

    let mut req = client.get(format!("{}/v1/agent/tasks?status=pending", server));
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        return Ok(());
    }

    let data: serde_json::Value = res.json().await?;
    // Accept both a bare array and a `{"tasks": [...]}` envelope.
    let tasks = if let Some(arr) = data.as_array() {
        arr.clone()
    } else {
        data["tasks"].as_array().cloned().unwrap_or_default()
    };

    tracing::info!("Found {} pending task(s)", tasks.len());

    for task in tasks {
        if let Some(id) = task["id"].as_str() {
            let mut proc = processing.lock().await;
            if !proc.contains(id) {
                // Reserve the id before releasing the lock so concurrent
                // dispatchers skip it.
                proc.insert(id.to_string());
                drop(proc);

                let task_id = id.to_string();
                let client = client.clone();
                let server = server.to_string();
                let token = token.clone();
                let worker_id = worker_id.to_string();
                let auto_approve = *auto_approve;
                let processing = processing.clone();
                let bus = bus.clone();

                tokio::spawn(async move {
                    if let Err(e) = handle_task(
                        &client,
                        &server,
                        &token,
                        &worker_id,
                        &task,
                        auto_approve,
                        &bus,
                    )
                    .await
                    {
                        tracing::error!("Task {} failed: {}", task_id, e);
                    }
                    // Always release the reservation, success or failure.
                    processing.lock().await.remove(&task_id);
                });
            }
        }
    }

    Ok(())
}
1003
/// Opens the SSE task stream and processes events until it ends or fails.
///
/// Drives three concurrent sources via `select!`:
/// 1. CloudEvent task notifications from `task_notify_rx` (if provided) —
///    each one triggers an immediate pending-task poll;
/// 2. the SSE byte stream — frames are buffered, split on blank lines,
///    and each `data:` JSON payload is dispatched to a task handler;
/// 3. a 30s fallback poll in case SSE events are missed.
///
/// Returns `Ok(())` when the server closes the stream and `Err` on
/// transport failures; the caller reconnects in both cases.
// NOTE(review): frames are split on "\n\n" only — a server emitting
// "\r\n\r\n" separators would not be parsed; confirm against the server.
#[allow(clippy::too_many_arguments)]
async fn connect_stream(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
    bus: &Arc<AgentBus>,
    task_notify_rx: Option<mpsc::Receiver<String>>,
) -> Result<()> {
    let url = format!(
        "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
        server,
        urlencoding::encode(name),
        urlencoding::encode(worker_id)
    );

    let mut req = client
        .get(&url)
        .header("Accept", "text/event-stream")
        .header("X-Worker-ID", worker_id)
        .header("X-Agent-Name", name)
        .header("X-Workspaces", codebases.join(","));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        anyhow::bail!("Failed to connect: {}", res.status());
    }

    tracing::info!("Connected to A2A server");

    let mut stream = res.bytes_stream();
    let mut buffer = String::new();
    let mut poll_interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
    // Consume the interval's immediate first tick so the fallback poll
    // waits a full period before first firing.
    poll_interval.tick().await; let mut task_notify_rx = task_notify_rx;

    loop {
        tokio::select! {
            // Branch 1: CloudEvent notification. When no channel was
            // provided, this future pends forever and never wins.
            task_id = async {
                if let Some(ref mut rx) = task_notify_rx {
                    rx.recv().await
                } else {
                    futures::future::pending().await
                }
            } => {
                if let Some(task_id) = task_id {
                    tracing::info!("Received task notification via CloudEvent: {}", task_id);
                    if let Err(e) = poll_pending_tasks(
                        client, server, token, worker_id, processing, auto_approve, bus,
                    ).await {
                        tracing::warn!("Task notification poll failed: {}", e);
                    }
                }
            }
            // Branch 2: next chunk of the SSE response body.
            chunk = stream.next() => {
                match chunk {
                    Some(Ok(chunk)) => {
                        buffer.push_str(&String::from_utf8_lossy(&chunk));

                        // SSE frames are separated by a blank line.
                        while let Some(pos) = buffer.find("\n\n") {
                            let event_str = buffer[..pos].to_string();
                            buffer = buffer[pos + 2..].to_string();

                            if let Some(data_line) = event_str.lines().find(|l| l.starts_with("data:")) {
                                let data = data_line.trim_start_matches("data:").trim();
                                if data == "[DONE]" || data.is_empty() {
                                    continue;
                                }

                                // Non-JSON payloads are silently ignored.
                                if let Ok(task) = serde_json::from_str::<serde_json::Value>(data) {
                                    spawn_task_handler(
                                        &task, client, server, token, worker_id,
                                        processing, auto_approve, bus,
                                    ).await;
                                }
                            }
                        }
                    }
                    Some(Err(e)) => {
                        return Err(e.into());
                    }
                    None => {
                        // Server closed the stream cleanly.
                        return Ok(());
                    }
                }
            }
            // Branch 3: periodic safety-net poll for missed events.
            _ = poll_interval.tick() => {
                if let Err(e) = poll_pending_tasks(
                    client, server, token, worker_id, processing, auto_approve, bus,
                ).await {
                    tracing::warn!("Periodic task poll failed: {}", e);
                }
            }
        }
    }
}
1117
/// Dispatches one task envelope to `handle_task` on its own tokio task,
/// unless its id is already reserved in `processing` (prevents duplicate
/// handling when the same task arrives via both SSE and polling).
///
/// The id is read from the nested `task.id` first, then the top level;
/// the reservation is released when handling finishes.
async fn spawn_task_handler(
    task: &serde_json::Value,
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
    bus: &Arc<AgentBus>,
) {
    if let Some(id) = task
        .get("task")
        .and_then(|t| t["id"].as_str())
        .or_else(|| task["id"].as_str())
    {
        let mut proc = processing.lock().await;
        if !proc.contains(id) {
            // Reserve the id before releasing the lock so concurrent
            // dispatchers skip it.
            proc.insert(id.to_string());
            drop(proc);

            // Clone everything the detached task needs to own.
            let task_id = id.to_string();
            let task = task.clone();
            let client = client.clone();
            let server = server.to_string();
            let token = token.clone();
            let worker_id = worker_id.to_string();
            let auto_approve = *auto_approve;
            let processing_clone = processing.clone();
            let bus = bus.clone();

            tokio::spawn(async move {
                if let Err(e) = handle_task(
                    &client,
                    &server,
                    &token,
                    &worker_id,
                    &task,
                    auto_approve,
                    &bus,
                )
                .await
                {
                    tracing::error!("Task {} failed: {}", task_id, e);
                }
                // Always release the reservation, success or failure.
                processing_clone.lock().await.remove(&task_id);
            });
        }
    }
}
1167
/// Fetches the server's pending task list and dispatches each entry via
/// [`spawn_task_handler`] (which skips ids already being processed).
///
/// Non-success HTTP statuses are treated as "no tasks" rather than errors;
/// only transport failures propagate.
async fn poll_pending_tasks(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
    bus: &Arc<AgentBus>,
) -> Result<()> {
    let mut req = client.get(format!("{}/v1/agent/tasks?status=pending", server));
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        return Ok(());
    }

    let data: serde_json::Value = res.json().await?;
    // Accept both a bare array and a `{"tasks": [...]}` envelope.
    let tasks = if let Some(arr) = data.as_array() {
        arr.clone()
    } else {
        data["tasks"].as_array().cloned().unwrap_or_default()
    };

    if !tasks.is_empty() {
        tracing::debug!("Poll found {} pending task(s)", tasks.len());
    }

    for task in &tasks {
        spawn_task_handler(
            task,
            client,
            server,
            token,
            worker_id,
            processing,
            auto_approve,
            bus,
        )
        .await;
    }

    Ok(())
}
1214
/// Claim one pending task from the server, execute it, and release it back
/// with a terminal status.
///
/// Flow: claim via `POST /v1/worker/tasks/claim` (HTTP 409 means another
/// worker won the race and is treated as a benign skip), resolve session /
/// agent / model hints from the task's metadata, run either the swarm
/// executor or a plain agent session, then report the outcome via
/// `POST /v1/worker/tasks/release`.
///
/// Intermediate output is streamed both onto the in-process `bus` and to the
/// server's task-output endpoint through `output_callback`.
async fn handle_task(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    task: &serde_json::Value,
    auto_approve: AutoApprove,
    bus: &Arc<AgentBus>,
) -> Result<()> {
    let task_id = task_str(task, "id").ok_or_else(|| anyhow::anyhow!("No task ID"))?;
    let title = task_str(task, "title").unwrap_or("Untitled");

    tracing::info!("Handling task: {} ({})", title, task_id);

    // --- Claim phase: only one worker may win the task. ---
    let mut req = client
        .post(format!("{}/v1/worker/tasks/claim", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req
        .json(&serde_json::json!({ "task_id": task_id }))
        .send()
        .await?;

    if !res.status().is_success() {
        let status = res.status();
        let text = res.text().await?;
        if status == reqwest::StatusCode::CONFLICT {
            // 409 = lost the claim race; not an error for this worker.
            tracing::debug!(task_id, "Task already claimed by another worker, skipping");
        } else {
            tracing::warn!(task_id, %status, "Failed to claim task: {}", text);
        }
        return Ok(());
    }

    tracing::info!("Claimed task: {}", task_id);

    // --- Extract routing hints from task metadata. ---
    let metadata = task_metadata(task);
    let resume_session_id = metadata
        .get("resume_session_id")
        .and_then(|v| v.as_str())
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty());
    let complexity_hint = metadata_str(&metadata, &["complexity"]);
    // Explicit tier wins; otherwise derive one from the complexity hint
    // (quick -> fast, deep -> heavy, anything else -> balanced).
    let model_tier = metadata_str(&metadata, &["model_tier", "tier"])
        .map(|s| s.to_ascii_lowercase())
        .or_else(|| {
            complexity_hint.as_ref().map(|complexity| {
                match complexity.to_ascii_lowercase().as_str() {
                    "quick" => "fast".to_string(),
                    "deep" => "heavy".to_string(),
                    _ => "balanced".to_string(),
                }
            })
        });
    let worker_personality = metadata_str(
        &metadata,
        &["worker_personality", "personality", "agent_personality"],
    );
    let target_agent_name = metadata_str(&metadata, &["target_agent_name", "agent_name"]);
    // Model can come from several task fields, checked in priority order.
    let raw_model = task_str(task, "model_ref")
        .or_else(|| metadata_lookup(&metadata, "model_ref").and_then(|v| v.as_str()))
        .or_else(|| task_str(task, "model"))
        .or_else(|| metadata_lookup(&metadata, "model").and_then(|v| v.as_str()));
    let selected_model = raw_model.map(model_ref_to_provider_model);

    // Resume an existing session when requested; fall back to a fresh one
    // if the stored session cannot be loaded.
    let mut session = if let Some(ref sid) = resume_session_id {
        match Session::load(sid).await {
            Ok(existing) => {
                tracing::info!("Resuming session {} for task {}", sid, task_id);
                existing
            }
            Err(e) => {
                tracing::warn!(
                    "Could not load session {} for task {} ({}), starting a new session",
                    sid,
                    task_id,
                    e
                );
                Session::new().await?
            }
        }
    } else {
        Session::new().await?
    };

    // Normalize the agent type: swarm agents pass through, only
    // "build"/"plan" are accepted as primary agents, everything else falls
    // back to "build".
    let raw_agent = task_str(task, "agent_type")
        .or_else(|| task_str(task, "agent"))
        .unwrap_or("build");
    let agent_type = if is_swarm_agent(raw_agent) {
        raw_agent
    } else {
        match raw_agent {
            "build" | "plan" => raw_agent,
            other => {
                tracing::info!(
                    "Agent \"{}\" is not a primary agent, falling back to \"build\"",
                    other
                );
                "build"
            }
        }
    };
    session.agent = agent_type.to_string();

    if let Some(model) = selected_model.clone() {
        session.metadata.model = Some(model);
    }

    let prompt = task_str(task, "prompt")
        .or_else(|| task_str(task, "description"))
        .unwrap_or(title);

    tracing::info!("Executing prompt: {}", prompt);

    // --- Output streaming: each chunk is broadcast on the bus and posted
    // to the server's task-output endpoint (fire-and-forget). ---
    let stream_client = client.clone();
    let stream_server = server.to_string();
    let stream_token = token.clone();
    let stream_worker_id = worker_id.to_string();
    let stream_task_id = task_id.to_string();
    let stream_bus = Arc::clone(bus);

    let output_callback: Arc<dyn Fn(String) + Send + Sync + 'static> =
        Arc::new(move |output: String| {
            let c = stream_client.clone();
            let s = stream_server.clone();
            let t = stream_token.clone();
            let w = stream_worker_id.clone();
            let tid = stream_task_id.clone();

            let bus_handle = stream_bus.handle("task-output");
            bus_handle.send(
                format!("task.{}", tid),
                crate::bus::BusMessage::TaskUpdate {
                    task_id: tid.clone(),
                    state: crate::a2a::types::TaskState::Working,
                    message: Some(output.clone()),
                },
            );

            // HTTP upload runs on its own task so the callback never blocks
            // the executing agent; errors are deliberately ignored.
            tokio::spawn(async move {
                let mut req = c
                    .post(format!("{}/v1/agent/tasks/{}/output", s, tid))
                    .header("X-Worker-ID", &w);
                if let Some(tok) = &t {
                    req = req.bearer_auth(tok);
                }
                let _ = req
                    .json(&serde_json::json!({
                        "worker_id": w,
                        "output": output,
                    }))
                    .send()
                    .await;
            });
        });

    // --- Execution phase: swarm vs plain session, folded into a common
    // (status, result, error, session_id) tuple for release. ---
    let (status, result, error, session_id) = if is_swarm_agent(agent_type) {
        match execute_swarm_with_policy(
            &mut session,
            prompt,
            model_tier.as_deref(),
            selected_model,
            &metadata,
            complexity_hint.as_deref(),
            worker_personality.as_deref(),
            target_agent_name.as_deref(),
            Some(bus),
            Some(Arc::clone(&output_callback)),
        )
        .await
        {
            Ok((session_result, true)) => {
                tracing::info!("Swarm task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Ok((session_result, false)) => {
                // Swarm finished but reported partial failure: keep the text
                // output, mark the task failed.
                tracing::warn!("Swarm task completed with failures: {}", task_id);
                (
                    "failed",
                    Some(session_result.text),
                    Some("Swarm execution completed with failures".to_string()),
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Swarm task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    } else {
        match execute_session_with_policy(
            &mut session,
            prompt,
            auto_approve,
            model_tier.as_deref(),
            Some(Arc::clone(&output_callback)),
        )
        .await
        {
            Ok(session_result) => {
                tracing::info!("Task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    };

    // --- Release phase: report the terminal state back to the server. ---
    let mut req = client
        .post(format!("{}/v1/worker/tasks/release", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    req.json(&serde_json::json!({
        "task_id": task_id,
        "status": status,
        "result": result,
        "error": error,
        // Fall back to the local session id when execution produced none.
        "session_id": session_id.unwrap_or_else(|| session.id.clone()),
    }))
    .send()
    .await?;

    tracing::info!("Task released: {} with status: {}", task_id, status);
    Ok(())
}
1466
/// Run `prompt` through the swarm executor, recording the user prompt and
/// the final swarm output on `session`.
///
/// Swarm tuning (strategy, subagent/step limits, timeout, parallelism) is
/// read from task `metadata` and clamped to sane ranges. `complexity_hint`,
/// `worker_personality`, and `target_agent_name` are only echoed to the
/// `output_callback` for observability.
///
/// Returns the session result plus a bool indicating whether the swarm
/// reported overall success.
async fn execute_swarm_with_policy(
    session: &mut Session,
    prompt: &str,
    model_tier: Option<&str>,
    explicit_model: Option<String>,
    metadata: &serde_json::Map<String, serde_json::Value>,
    complexity_hint: Option<&str>,
    worker_personality: Option<&str>,
    target_agent_name: Option<&str>,
    bus: Option<&Arc<AgentBus>>,
    output_callback: Option<Arc<dyn Fn(String) + Send + Sync + 'static>>,
) -> Result<(crate::session::SessionResult, bool)> {
    use crate::provider::{ContentPart, Message, Role};

    session.add_message(Message {
        role: Role::User,
        content: vec![ContentPart::Text {
            text: prompt.to_string(),
        }],
    });

    if session.title.is_none() {
        session.generate_title().await?;
    }

    // Pull swarm tuning knobs from metadata, clamped to sane ranges.
    let strategy = parse_swarm_strategy(metadata);
    let max_subagents = metadata_usize(
        metadata,
        &["swarm_max_subagents", "max_subagents", "subagents"],
    )
    .unwrap_or(10)
    .clamp(1, 100);
    let max_steps_per_subagent = metadata_usize(
        metadata,
        &[
            "swarm_max_steps_per_subagent",
            "max_steps_per_subagent",
            "max_steps",
        ],
    )
    .unwrap_or(50)
    .clamp(1, 200);
    let timeout_secs = metadata_u64(metadata, &["swarm_timeout_secs", "timeout_secs", "timeout"])
        .unwrap_or(600)
        .clamp(30, 3600);
    let parallel_enabled =
        metadata_bool(metadata, &["swarm_parallel_enabled", "parallel_enabled"]).unwrap_or(true);

    // Explicit model wins over tier-based resolution; record the choice on
    // the session so it persists.
    let model = resolve_swarm_model(explicit_model, model_tier).await;
    if let Some(ref selected_model) = model {
        session.metadata.model = Some(selected_model.clone());
    }

    // Echo the effective routing/config to the task output stream.
    if let Some(ref cb) = output_callback {
        cb(format!(
            "[swarm] routing complexity={} tier={} personality={} target_agent={}",
            complexity_hint.unwrap_or("standard"),
            model_tier.unwrap_or("balanced"),
            worker_personality.unwrap_or("auto"),
            target_agent_name.unwrap_or("auto")
        ));
        cb(format!(
            "[swarm] config strategy={:?} max_subagents={} max_steps={} timeout={}s tier={}",
            strategy,
            max_subagents,
            max_steps_per_subagent,
            timeout_secs,
            model_tier.unwrap_or("balanced")
        ));
    }

    let swarm_config = SwarmConfig {
        max_subagents,
        max_steps_per_subagent,
        subagent_timeout_secs: timeout_secs,
        parallel_enabled,
        model,
        working_dir: session
            .metadata
            .directory
            .as_ref()
            .map(|p| p.to_string_lossy().to_string()),
        ..Default::default()
    };

    let swarm_result = if output_callback.is_some() {
        // With a callback, run the executor on its own task and pump its
        // progress events into the callback until the executor finishes.
        let (event_tx, mut event_rx) = mpsc::channel(256);
        let mut executor = SwarmExecutor::new(swarm_config).with_event_tx(event_tx);
        if let Some(bus) = bus {
            executor = executor.with_bus(Arc::clone(bus));
        }
        let prompt_owned = prompt.to_string();
        let mut exec_handle =
            tokio::spawn(async move { executor.execute(&prompt_owned, strategy).await });

        let mut final_result: Option<crate::swarm::SwarmResult> = None;

        while final_result.is_none() {
            tokio::select! {
                maybe_event = event_rx.recv() => {
                    if let Some(event) = maybe_event {
                        if let Some(ref cb) = output_callback {
                            if let Some(line) = format_swarm_event_for_output(&event) {
                                cb(line);
                            }
                        }
                    }
                }
                join_result = &mut exec_handle => {
                    let joined = join_result.map_err(|e| anyhow::anyhow!("Swarm join failure: {}", e))?;
                    final_result = Some(joined?);
                }
            }
        }

        // Drain events that were still queued when the executor finished.
        while let Ok(event) = event_rx.try_recv() {
            if let Some(ref cb) = output_callback {
                if let Some(line) = format_swarm_event_for_output(&event) {
                    cb(line);
                }
            }
        }

        final_result.ok_or_else(|| anyhow::anyhow!("Swarm execution returned no result"))?
    } else {
        // No callback: run the executor inline.
        let mut executor = SwarmExecutor::new(swarm_config);
        if let Some(bus) = bus {
            executor = executor.with_bus(Arc::clone(bus));
        }
        executor.execute(prompt, strategy).await?
    };

    // Never record an empty assistant message; substitute a placeholder
    // describing the (lack of) textual output.
    let final_text = if swarm_result.result.trim().is_empty() {
        if swarm_result.success {
            "Swarm completed without textual output.".to_string()
        } else {
            "Swarm finished with failures and no textual output.".to_string()
        }
    } else {
        swarm_result.result.clone()
    };

    session.add_message(Message {
        role: Role::Assistant,
        content: vec![ContentPart::Text {
            text: final_text.clone(),
        }],
    });
    session.save().await?;

    Ok((
        crate::session::SessionResult {
            text: final_text,
            session_id: session.id.clone(),
        },
        swarm_result.success,
    ))
}
1625
/// Execute `prompt` as a plain (non-swarm) agent loop on `session`.
///
/// Resolves a provider and model from the session's stored model string,
/// falling back to tier-based defaults, then runs up to 50 completion
/// rounds, executing tool calls between rounds. Tools blocked by the
/// `auto_approve` policy are refused with an explanatory tool result
/// instead of being executed. The session is saved before returning.
async fn execute_session_with_policy(
    session: &mut Session,
    prompt: &str,
    auto_approve: AutoApprove,
    model_tier: Option<&str>,
    output_callback: Option<Arc<dyn Fn(String) + Send + Sync + 'static>>,
) -> Result<crate::session::SessionResult> {
    use crate::provider::{
        CompletionRequest, ContentPart, Message, ProviderRegistry, Role, parse_model_string,
    };
    use std::sync::Arc;

    let registry = ProviderRegistry::from_vault().await?;
    let providers = registry.list();
    tracing::info!("Available providers: {:?}", providers);

    if providers.is_empty() {
        anyhow::bail!("No providers available. Configure API keys in HashiCorp Vault.");
    }

    // Parse "provider/model" out of the session's model string. Three cases:
    // explicit provider prefix, a bare name that is itself a known provider,
    // or a bare model with no provider.
    let (provider_name, model_id) = if let Some(ref model_str) = session.metadata.model {
        let (prov, model) = parse_model_string(model_str);
        // "zhipuai" is an alias for the "zai" provider.
        let prov = prov.map(|p| if p == "zhipuai" { "zai" } else { p });
        if prov.is_some() {
            (prov.map(|s| s.to_string()), model.to_string())
        } else if providers.contains(&model) {
            (Some(model.to_string()), String::new())
        } else {
            (None, model.to_string())
        }
    } else {
        (None, String::new())
    };

    let provider_slice = providers.as_slice();
    // Remember whether an explicitly requested provider is missing, so the
    // model id tied to it is not used with a different provider below.
    let provider_requested_but_unavailable = provider_name
        .as_deref()
        .map(|p| !providers.contains(&p))
        .unwrap_or(false);

    let selected_provider = provider_name
        .as_deref()
        .filter(|p| providers.contains(p))
        .unwrap_or_else(|| choose_provider_for_tier(provider_slice, model_tier));

    let provider = registry
        .get(selected_provider)
        .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;

    session.add_message(Message {
        role: Role::User,
        content: vec![ContentPart::Text {
            text: prompt.to_string(),
        }],
    });

    if session.title.is_none() {
        session.generate_title().await?;
    }

    // Use the requested model only if its provider is actually available;
    // otherwise fall back to the tier default for the selected provider.
    let model = if !model_id.is_empty() && !provider_requested_but_unavailable {
        model_id
    } else {
        default_model_for_provider(selected_provider, model_tier)
    };

    let tool_registry = create_filtered_registry(
        Arc::clone(&provider),
        model.clone(),
        auto_approve,
        output_callback.clone(),
    );
    let tool_definitions = tool_registry.definitions();

    // Some models are tuned for temperature 1.0; everything else gets 0.7.
    let temperature = if prefers_temperature_one(&model) {
        Some(1.0)
    } else {
        Some(0.7)
    };

    tracing::info!(
        "Using model: {} via provider: {} (tier: {:?})",
        model,
        selected_provider,
        model_tier
    );
    tracing::info!(
        "Available tools: {} (auto_approve: {:?})",
        tool_definitions.len(),
        auto_approve
    );

    // PWD (when set) preserves the logical, possibly-symlinked path; fall
    // back to the canonical current dir.
    let cwd = std::env::var("PWD")
        .map(std::path::PathBuf::from)
        .unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
    let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);

    let mut final_output = String::new();
    let max_steps = 50;

    // Main agent loop: request a completion, surface text output, execute
    // any tool calls, and stop when a round produces no tool calls.
    for step in 1..=max_steps {
        tracing::info!(step = step, "Agent step starting");

        // The system prompt is re-prepended each round; the session itself
        // only stores user/assistant/tool messages.
        let mut messages = vec![Message {
            role: Role::System,
            content: vec![ContentPart::Text {
                text: system_prompt.clone(),
            }],
        }];
        messages.extend(session.messages.clone());

        let request = CompletionRequest {
            messages,
            tools: tool_definitions.clone(),
            model: model.clone(),
            temperature,
            top_p: None,
            max_tokens: Some(8192),
            stop: Vec::new(),
        };

        let response = provider.complete(request).await?;

        crate::telemetry::TOKEN_USAGE.record_model_usage(
            &model,
            response.usage.prompt_tokens as u64,
            response.usage.completion_tokens as u64,
        );

        // Collect tool calls, tolerating malformed argument JSON by
        // substituting an empty object.
        let tool_calls: Vec<(String, String, serde_json::Value)> = response
            .message
            .content
            .iter()
            .filter_map(|part| {
                if let ContentPart::ToolCall {
                    id,
                    name,
                    arguments,
                    ..
                } = part
                {
                    let args: serde_json::Value =
                        serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
                    Some((id.clone(), name.clone(), args))
                } else {
                    None
                }
            })
            .collect();

        // Accumulate and stream any text the model produced this round.
        for part in &response.message.content {
            if let ContentPart::Text { text } = part {
                if !text.is_empty() {
                    final_output.push_str(text);
                    final_output.push('\n');
                    if let Some(ref cb) = output_callback {
                        cb(text.clone());
                    }
                }
            }
        }

        // No tool calls means the model is done.
        if tool_calls.is_empty() {
            session.add_message(response.message.clone());
            break;
        }

        session.add_message(response.message.clone());

        tracing::info!(
            step = step,
            num_tools = tool_calls.len(),
            "Executing tool calls"
        );

        for (tool_id, tool_name, tool_input) in tool_calls {
            tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");

            if let Some(ref cb) = output_callback {
                cb(format!("[tool:start:{}]", tool_name));
            }

            // Policy gate: blocked tools get an explanatory tool result so
            // the model can adapt, rather than a hard failure.
            if !is_tool_allowed(&tool_name, auto_approve) {
                let msg = format!(
                    "Tool '{}' requires approval but auto-approve policy is {:?}",
                    tool_name, auto_approve
                );
                tracing::warn!(tool = %tool_name, "Tool blocked by auto-approve policy");
                session.add_message(Message {
                    role: Role::Tool,
                    content: vec![ContentPart::ToolResult {
                        tool_call_id: tool_id,
                        content: msg,
                    }],
                });
                continue;
            }

            let content = if let Some(tool) = tool_registry.get(&tool_name) {
                let exec_result: Result<crate::tool::ToolResult> =
                    tool.execute(tool_input.clone()).await;
                match exec_result {
                    Ok(result) => {
                        tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
                        if let Some(ref cb) = output_callback {
                            let status = if result.success { "ok" } else { "err" };
                            // Stream at most the first 500 bytes of output.
                            cb(format!(
                                "[tool:{}:{}] {}",
                                tool_name,
                                status,
                                &result.output[..result.output.len().min(500)]
                            ));
                        }
                        result.output
                    }
                    Err(e) => {
                        tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
                        if let Some(ref cb) = output_callback {
                            cb(format!("[tool:{}:err] {}", tool_name, e));
                        }
                        format!("Error: {}", e)
                    }
                }
            } else {
                tracing::warn!(tool = %tool_name, "Tool not found");
                format!("Error: Unknown tool '{}'", tool_name)
            };

            session.add_message(Message {
                role: Role::Tool,
                content: vec![ContentPart::ToolResult {
                    tool_call_id: tool_id,
                    content,
                }],
            });
        }
    }

    session.save().await?;

    Ok(crate::session::SessionResult {
        text: final_output.trim().to_string(),
        session_id: session.id.clone(),
    })
}
1889
1890fn is_tool_allowed(tool_name: &str, auto_approve: AutoApprove) -> bool {
1892 match auto_approve {
1893 AutoApprove::All => true,
1894 AutoApprove::Safe | AutoApprove::None => is_safe_tool(tool_name),
1895 }
1896}
1897
/// Return true when `tool_name` is on the read-only allow-list — tools that
/// only inspect state (read/search/browse) and never mutate the workspace.
fn is_safe_tool(tool_name: &str) -> bool {
    matches!(
        tool_name,
        "read"
            | "list"
            | "glob"
            | "grep"
            | "codesearch"
            | "lsp"
            | "webfetch"
            | "websearch"
            | "todo_read"
            | "skill"
    )
}
1914
/// Build the tool registry for a worker session, filtered by approval policy.
///
/// Read-only tools (read/list/glob/grep/lsp/webfetch/websearch/codesearch/
/// todo_read/skill) are always registered. Mutating and executing tools
/// (write, bash, patch, multiedit, ...) are added only under
/// `AutoApprove::All`. `provider` and `model` are forwarded to tools that
/// spawn their own completions (rlm, ralph), and `completion_callback`, when
/// present, lets the `go` tool stream its output.
fn create_filtered_registry(
    provider: Arc<dyn crate::provider::Provider>,
    model: String,
    auto_approve: AutoApprove,
    completion_callback: Option<Arc<dyn Fn(String) + Send + Sync + 'static>>,
) -> crate::tool::ToolRegistry {
    use crate::tool::*;

    let mut registry = ToolRegistry::new();

    // Always-available, read-only tools.
    registry.register(Arc::new(file::ReadTool::new()));
    registry.register(Arc::new(file::ListTool::new()));
    registry.register(Arc::new(file::GlobTool::new()));
    registry.register(Arc::new(search::GrepTool::new()));
    registry.register(Arc::new(lsp::LspTool::new()));
    registry.register(Arc::new(webfetch::WebFetchTool::new()));
    registry.register(Arc::new(websearch::WebSearchTool::new()));
    registry.register(Arc::new(codesearch::CodeSearchTool::new()));
    registry.register(Arc::new(todo::TodoReadTool::new()));
    registry.register(Arc::new(skill::SkillTool::new()));

    // Mutating/executing tools: gated behind full auto-approve.
    if matches!(auto_approve, AutoApprove::All) {
        registry.register(Arc::new(file::WriteTool::new()));
        registry.register(Arc::new(advanced_edit::AdvancedEditTool::new()));
        registry.register(Arc::new(bash::BashTool::new()));
        registry.register(Arc::new(multiedit::MultiEditTool::new()));
        registry.register(Arc::new(patch::ApplyPatchTool::new()));
        registry.register(Arc::new(todo::TodoWriteTool::new()));
        registry.register(Arc::new(task::TaskTool::new()));
        registry.register(Arc::new(plan::PlanEnterTool::new()));
        registry.register(Arc::new(plan::PlanExitTool::new()));
        // rlm/ralph run their own completions against the same provider/model.
        registry.register(Arc::new(rlm::RlmTool::new(
            Arc::clone(&provider),
            model.clone(),
        )));
        registry.register(Arc::new(ralph::RalphTool::with_provider(provider, model)));
        registry.register(Arc::new(prd::PrdTool::new()));
        // The go tool streams output through the callback when one exists.
        if let Some(cb) = completion_callback {
            registry.register(Arc::new(go::GoTool::with_callback(cb)));
        } else {
            registry.register(Arc::new(go::GoTool::new()));
        }
        registry.register(Arc::new(confirm_edit::ConfirmEditTool::new()));
        registry.register(Arc::new(confirm_multiedit::ConfirmMultiEditTool::new()));
        registry.register(Arc::new(undo::UndoTool));
        registry.register(Arc::new(mcp_bridge::McpBridgeTool::new()));
    }

    // Fallback handler for tool names the model invents.
    registry.register(Arc::new(invalid::InvalidTool::new()));

    registry
}
1972
/// Spawn the periodic worker heartbeat loop and return its task handle.
///
/// Every 30 seconds this reports the worker's status and active task count
/// to `{server}/v1/agent/workers/{id}/heartbeat`, optionally enriched with a
/// cognition payload fetched from the local cognition service. If the server
/// rejects the enriched payload with a 4xx, the heartbeat is immediately
/// retried without it and cognition reporting is paused for a 5-minute
/// cooldown. Heartbeat failures are logged but never terminate the worker.
pub fn start_heartbeat(
    client: Client,
    server: String,
    token: Option<String>,
    heartbeat_state: HeartbeatState,
    processing: Arc<Mutex<HashSet<String>>>,
    cognition_config: CognitionHeartbeatConfig,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        let mut consecutive_failures = 0u32;
        const MAX_FAILURES: u32 = 3;
        const HEARTBEAT_INTERVAL_SECS: u64 = 30;
        const COGNITION_RETRY_COOLDOWN_SECS: u64 = 300;
        // While set and in the future, cognition payloads are suppressed.
        let mut cognition_payload_disabled_until: Option<Instant> = None;

        let mut interval =
            tokio::time::interval(tokio::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        loop {
            interval.tick().await;

            // Refresh shared worker state from the in-flight task set.
            let active_count = processing.lock().await.len();
            heartbeat_state.set_task_count(active_count).await;

            let status = if active_count > 0 {
                WorkerStatus::Processing
            } else {
                WorkerStatus::Idle
            };
            heartbeat_state.set_status(status).await;

            let url = format!(
                "{}/v1/agent/workers/{}/heartbeat",
                server, heartbeat_state.worker_id
            );
            let mut req = client.post(&url);

            if let Some(ref t) = token {
                req = req.bearer_auth(t);
            }

            // base_payload is kept separate so a retry can be sent without
            // the cognition extension.
            let status_str = heartbeat_state.status.lock().await.as_str().to_string();
            let base_payload = serde_json::json!({
                "worker_id": &heartbeat_state.worker_id,
                "agent_name": &heartbeat_state.agent_name,
                "status": status_str,
                "active_task_count": active_count,
            });
            let mut payload = base_payload.clone();
            let mut included_cognition_payload = false;
            let cognition_payload_allowed = cognition_payload_disabled_until
                .map(|until| Instant::now() >= until)
                .unwrap_or(true);

            // Attach the cognition payload only when enabled, not in
            // cooldown, and the fetch succeeds.
            if cognition_config.enabled
                && cognition_payload_allowed
                && let Some(cognition_payload) =
                    fetch_cognition_heartbeat_payload(&client, &cognition_config).await
                && let Some(obj) = payload.as_object_mut()
            {
                obj.insert("cognition".to_string(), cognition_payload);
                included_cognition_payload = true;
            }

            match req.json(&payload).send().await {
                Ok(res) => {
                    if res.status().is_success() {
                        consecutive_failures = 0;
                        tracing::debug!(
                            worker_id = %heartbeat_state.worker_id,
                            status = status_str,
                            active_tasks = active_count,
                            "Heartbeat sent successfully"
                        );
                    } else if included_cognition_payload && res.status().is_client_error() {
                        // A 4xx on the enriched payload likely means the
                        // server does not understand the cognition schema:
                        // retry the bare payload once.
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            "Heartbeat cognition payload rejected, retrying without cognition payload"
                        );

                        let mut retry_req = client.post(&url);
                        if let Some(ref t) = token {
                            retry_req = retry_req.bearer_auth(t);
                        }

                        match retry_req.json(&base_payload).send().await {
                            Ok(retry_res) if retry_res.status().is_success() => {
                                // Bare payload worked: pause cognition
                                // enrichment for the cooldown window.
                                cognition_payload_disabled_until = Some(
                                    Instant::now()
                                        + Duration::from_secs(COGNITION_RETRY_COOLDOWN_SECS),
                                );
                                consecutive_failures = 0;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    retry_after_secs = COGNITION_RETRY_COOLDOWN_SECS,
                                    "Paused cognition heartbeat payload after schema rejection"
                                );
                            }
                            Ok(retry_res) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    status = %retry_res.status(),
                                    failures = consecutive_failures,
                                    "Heartbeat failed even after retry without cognition payload"
                                );
                            }
                            Err(e) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    error = %e,
                                    failures = consecutive_failures,
                                    "Heartbeat retry without cognition payload failed"
                                );
                            }
                        }
                    } else {
                        consecutive_failures += 1;
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            failures = consecutive_failures,
                            "Heartbeat failed"
                        );
                    }
                }
                Err(e) => {
                    consecutive_failures += 1;
                    tracing::warn!(
                        worker_id = %heartbeat_state.worker_id,
                        error = %e,
                        failures = consecutive_failures,
                        "Heartbeat request failed"
                    );
                }
            }

            // Escalate to an error log after repeated failures, but keep the
            // loop running — reconnection is handled by the SSE loop.
            if consecutive_failures >= MAX_FAILURES {
                tracing::error!(
                    worker_id = %heartbeat_state.worker_id,
                    failures = consecutive_failures,
                    "Heartbeat failed {} consecutive times - worker will continue running and attempt reconnection via SSE loop",
                    MAX_FAILURES
                );
                consecutive_failures = 0;
            }
        }
    })
}
2132
/// Query the local cognition service and assemble the optional heartbeat
/// extension payload.
///
/// Both requests are bounded by `config.request_timeout_ms`. Returns `None`
/// on any timeout, transport, HTTP, or decode failure so the heartbeat can
/// proceed without cognition data. The latest-thought enrichment is only
/// attempted when `config.include_thought_summary` is set, and its failures
/// degrade to the base status payload rather than `None`.
async fn fetch_cognition_heartbeat_payload(
    client: &Client,
    config: &CognitionHeartbeatConfig,
) -> Option<serde_json::Value> {
    let status_url = format!("{}/v1/cognition/status", config.source_base_url);
    // Outer ok()? drops timeouts, inner ok()? drops transport errors.
    let status_res = tokio::time::timeout(
        Duration::from_millis(config.request_timeout_ms),
        client.get(status_url).send(),
    )
    .await
    .ok()?
    .ok()?;

    if !status_res.status().is_success() {
        return None;
    }

    let status: CognitionStatusSnapshot = status_res.json().await.ok()?;
    let mut payload = serde_json::json!({
        "running": status.running,
        "last_tick_at": status.last_tick_at,
        "active_persona_count": status.active_persona_count,
        "events_buffered": status.events_buffered,
        "snapshots_buffered": status.snapshots_buffered,
        "loop_interval_ms": status.loop_interval_ms,
    });

    if config.include_thought_summary {
        let snapshot_url = format!("{}/v1/cognition/snapshots/latest", config.source_base_url);
        let snapshot_res = tokio::time::timeout(
            Duration::from_millis(config.request_timeout_ms),
            client.get(snapshot_url).send(),
        )
        .await
        .ok()
        .and_then(Result::ok);

        // Best-effort enrichment: every condition must hold or the snapshot
        // fields are simply omitted.
        if let Some(snapshot_res) = snapshot_res
            && snapshot_res.status().is_success()
            && let Ok(snapshot) = snapshot_res.json::<CognitionLatestSnapshot>().await
            && let Some(obj) = payload.as_object_mut()
        {
            obj.insert(
                "latest_snapshot_at".to_string(),
                serde_json::Value::String(snapshot.generated_at),
            );
            obj.insert(
                "latest_thought".to_string(),
                serde_json::Value::String(trim_for_heartbeat(
                    &snapshot.summary,
                    config.summary_max_chars,
                )),
            );
            // Optional provenance fields from the snapshot metadata.
            if let Some(model) = snapshot
                .metadata
                .get("model")
                .and_then(serde_json::Value::as_str)
            {
                obj.insert(
                    "latest_thought_model".to_string(),
                    serde_json::Value::String(model.to_string()),
                );
            }
            if let Some(source) = snapshot
                .metadata
                .get("source")
                .and_then(serde_json::Value::as_str)
            {
                obj.insert(
                    "latest_thought_source".to_string(),
                    serde_json::Value::String(source.to_string()),
                );
            }
        }
    }

    Some(payload)
}
2211
/// Trim surrounding whitespace from `input` and truncate it to at most
/// `max_chars` characters (appending "..." when truncated) for inclusion in
/// a heartbeat payload.
///
/// Fix over the previous version: the character budget is now applied to the
/// *trimmed* text. Previously the raw string was counted first, so
/// whitespace padding both consumed the budget (truncating content that
/// would have fit) and could leak interior padding into the truncated
/// output (e.g. `"   hi   "` with a budget of 6 produced `"hi ..."`).
fn trim_for_heartbeat(input: &str, max_chars: usize) -> String {
    let trimmed = input.trim();
    if trimmed.chars().count() <= max_chars {
        return trimmed.to_string();
    }

    // char-based truncation keeps multi-byte UTF-8 sequences intact.
    let truncated: String = trimmed.chars().take(max_chars).collect();
    // Drop whitespace exposed at the cut point before adding the ellipsis.
    format!("{}...", truncated.trim_end())
}
2224
/// Read a boolean from the environment variable `name`, accepting the usual
/// truthy ("1"/"true"/"yes"/"on") and falsy ("0"/"false"/"no"/"off")
/// spellings case-insensitively. Unset or unrecognized values yield
/// `default`.
fn env_bool(name: &str, default: bool) -> bool {
    match std::env::var(name) {
        Ok(raw) => match raw.to_ascii_lowercase().as_str() {
            "1" | "true" | "yes" | "on" => true,
            "0" | "false" | "no" | "off" => false,
            _ => default,
        },
        Err(_) => default,
    }
}
2235
/// Read a `usize` from the environment variable `name`; unset or unparsable
/// values yield `default`.
fn env_usize(name: &str, default: usize) -> usize {
    match std::env::var(name) {
        Ok(raw) => raw.parse().unwrap_or(default),
        Err(_) => default,
    }
}
2242
/// Read a `u64` from the environment variable `name`; unset or unparsable
/// values yield `default`.
fn env_u64(name: &str, default: u64) -> u64 {
    std::env::var(name).map_or(default, |raw| raw.parse().unwrap_or(default))
}
2249
2250fn start_workspace_sync(
2254 client: Client,
2255 server: String,
2256 token: Option<String>,
2257 shared_codebases: Arc<Mutex<Vec<String>>>,
2258) -> JoinHandle<()> {
2259 tokio::spawn(async move {
2260 const POLL_INTERVAL_SECS: u64 = 60;
2261 let mut interval = tokio::time::interval(Duration::from_secs(POLL_INTERVAL_SECS));
2262 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
2263 interval.tick().await; loop {
2266 interval.tick().await;
2267 if let Err(e) =
2268 sync_workspaces_from_server(&client, &server, &token, &shared_codebases).await
2269 {
2270 tracing::warn!("Workspace sync failed: {}", e);
2271 }
2272 }
2273 })
2274}
2275
2276async fn sync_workspaces_from_server(
2280 client: &Client,
2281 server: &str,
2282 token: &Option<String>,
2283 shared_codebases: &Arc<Mutex<Vec<String>>>,
2284) -> Result<()> {
2285 let mut req = client.get(format!("{}/v1/agent/workspaces", server));
2286 if let Some(t) = token {
2287 req = req.bearer_auth(t);
2288 }
2289
2290 let res = req.send().await?;
2291 if !res.status().is_success() {
2292 tracing::debug!(
2293 status = %res.status(),
2294 "Workspace sync: server returned non-success, skipping"
2295 );
2296 return Ok(());
2297 }
2298
2299 let data: serde_json::Value = res.json().await?;
2300
2301 let entries = data["workspaces"]
2303 .as_array()
2304 .or_else(|| data["codebases"].as_array())
2305 .cloned()
2306 .unwrap_or_default();
2307
2308 let mut new_paths: Vec<String> = Vec::new();
2309 {
2310 let current = shared_codebases.lock().await;
2311 for entry in &entries {
2312 let path = match entry["path"].as_str().filter(|p| !p.is_empty()) {
2313 Some(p) => p,
2314 None => continue,
2315 };
2316 if std::path::Path::new(path).exists() && !current.iter().any(|c| c.as_str() == path) {
2319 new_paths.push(path.to_string());
2320 }
2321 }
2322 }
2323
2324 if !new_paths.is_empty() {
2325 let mut current = shared_codebases.lock().await;
2326 for path in &new_paths {
2327 tracing::info!(
2328 path = %path,
2329 "Workspace sync: auto-discovered local path, adding to codebases"
2330 );
2331 current.push(path.clone());
2332 }
2333 tracing::info!(
2334 added = new_paths.len(),
2335 total = current.len(),
2336 "Workspace sync complete -- new paths take effect on next reconnect"
2337 );
2338 } else {
2339 tracing::debug!("Workspace sync: no new local paths found");
2340 }
2341
2342 Ok(())
2343}