1use crate::bus::AgentBus;
4use crate::cli::A2aArgs;
5use crate::provider::ProviderRegistry;
6use crate::session::Session;
7use crate::swarm::{DecompositionStrategy, SwarmConfig, SwarmExecutor};
8use crate::tui::swarm_view::SwarmEvent;
9use anyhow::Result;
10use futures::StreamExt;
11use reqwest::Client;
12use serde::Deserialize;
13use std::collections::HashMap;
14use std::collections::HashSet;
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::sync::{Mutex, mpsc};
18use tokio::task::JoinHandle;
19use tokio::time::Instant;
20
/// Coarse activity state a worker reports in heartbeats.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkerStatus {
    /// No task is currently being executed.
    Idle,
    /// At least one task is being executed.
    Processing,
}
27
28impl WorkerStatus {
29 pub fn as_str(&self) -> &'static str {
30 match self {
31 WorkerStatus::Idle => "idle",
32 WorkerStatus::Processing => "processing",
33 }
34 }
35}
36
/// Shared, clonable per-worker state read by the heartbeat loop.
///
/// The `Arc<Mutex<…>>` fields let the task-execution side update values
/// while the heartbeat task reads them concurrently; `Clone` shares the
/// same underlying state.
#[derive(Clone)]
pub struct HeartbeatState {
    // Stable worker identifier, set once at construction.
    // NOTE(review): not read anywhere in this chunk — presumably used by
    // the heartbeat sender defined elsewhere; confirm.
    worker_id: String,
    /// Human-readable agent name announced to the server.
    pub agent_name: String,
    /// Current idle/processing state.
    pub status: Arc<Mutex<WorkerStatus>>,
    /// Number of tasks currently being processed.
    pub active_task_count: Arc<Mutex<usize>>,
}
45
46impl HeartbeatState {
47 pub fn new(worker_id: String, agent_name: String) -> Self {
48 Self {
49 worker_id,
50 agent_name,
51 status: Arc::new(Mutex::new(WorkerStatus::Idle)),
52 active_task_count: Arc::new(Mutex::new(0)),
53 }
54 }
55
56 pub async fn set_status(&self, status: WorkerStatus) {
57 *self.status.lock().await = status;
58 }
59
60 pub async fn set_task_count(&self, count: usize) {
61 *self.active_task_count.lock().await = count;
62 }
63}
64
/// Settings controlling how the worker shares its cognition ("thought")
/// state upstream, read from `CODETETHER_WORKER_COGNITION_*` env vars.
#[derive(Clone, Debug)]
struct CognitionHeartbeatConfig {
    /// Master switch for cognition sharing.
    enabled: bool,
    /// Base URL of the local cognition service (trailing `/` stripped).
    source_base_url: String,
    /// Whether thought summaries are included.
    include_thought_summary: bool,
    /// Upper bound on summary length in chars (clamped to >= 120 in `from_env`).
    summary_max_chars: usize,
    /// Request timeout in milliseconds (clamped to >= 250 in `from_env`).
    request_timeout_ms: u64,
}
73
74impl CognitionHeartbeatConfig {
75 fn from_env() -> Self {
76 let source_base_url = std::env::var("CODETETHER_WORKER_COGNITION_SOURCE_URL")
77 .unwrap_or_else(|_| "http://127.0.0.1:4096".to_string())
78 .trim_end_matches('/')
79 .to_string();
80
81 Self {
82 enabled: env_bool("CODETETHER_WORKER_COGNITION_SHARE_ENABLED", true),
83 source_base_url,
84 include_thought_summary: env_bool("CODETETHER_WORKER_COGNITION_INCLUDE_THOUGHTS", true),
85 summary_max_chars: env_usize("CODETETHER_WORKER_COGNITION_THOUGHT_MAX_CHARS", 480)
86 .max(120),
87 request_timeout_ms: env_u64("CODETETHER_WORKER_COGNITION_TIMEOUT_MS", 2_500).max(250),
88 }
89 }
90}
91
/// Status payload deserialized from the cognition service.
/// NOTE(review): field meanings inferred from names; presumably mirrors
/// the cognition server's status response — confirm against that service.
#[derive(Debug, Deserialize)]
struct CognitionStatusSnapshot {
    // Whether the cognition loop is currently running.
    running: bool,
    #[serde(default)]
    last_tick_at: Option<String>,
    #[serde(default)]
    active_persona_count: usize,
    #[serde(default)]
    events_buffered: usize,
    #[serde(default)]
    snapshots_buffered: usize,
    #[serde(default)]
    loop_interval_ms: u64,
}
106
/// Most recent cognition snapshot: generation timestamp, a textual
/// summary, and free-form metadata (defaults to empty when absent).
/// NOTE(review): `generated_at` format is assumed to be a timestamp
/// string — verify against the cognition service.
#[derive(Debug, Deserialize)]
struct CognitionLatestSnapshot {
    generated_at: String,
    summary: String,
    #[serde(default)]
    metadata: HashMap<String, serde_json::Value>,
}
114
/// Runs the standalone A2A worker loop: register with the server, drain
/// any already-pending tasks, then stream tasks over SSE, reconnecting
/// forever with a 5-second backoff.
///
/// # Errors
/// Returns an error only if the initial registration or the initial
/// pending-task fetch fails; once the reconnect loop starts, failures
/// are logged and retried instead.
pub async fn run(args: A2aArgs) -> Result<()> {
    let server = args.server.trim_end_matches('/');
    // Default the agent name to a PID-derived name when none was given.
    let name = args
        .name
        .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
    let worker_id = generate_worker_id();

    // Comma-separated codebase list; falls back to the current directory.
    let codebases: Vec<String> = args
        .codebases
        .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
        .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);

    tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
    tracing::info!("Server: {}", server);
    tracing::info!("Codebases: {:?}", codebases);

    let client = Client::new();
    // IDs of tasks currently in flight, used to avoid double-claiming.
    let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
    let cognition_heartbeat = CognitionHeartbeatConfig::from_env();
    if cognition_heartbeat.enabled {
        tracing::info!(
            source = %cognition_heartbeat.source_base_url,
            include_thoughts = cognition_heartbeat.include_thought_summary,
            max_chars = cognition_heartbeat.summary_max_chars,
            timeout_ms = cognition_heartbeat.request_timeout_ms,
            "Cognition heartbeat sharing enabled (set CODETETHER_WORKER_COGNITION_SHARE_ENABLED=false to disable)"
        );
    } else {
        tracing::warn!(
            "Cognition heartbeat sharing disabled; worker thought state will not be shared upstream"
        );
    }

    // Unrecognized values fall back to no auto-approval.
    let auto_approve = match args.auto_approve.as_str() {
        "all" => AutoApprove::All,
        "safe" => AutoApprove::Safe,
        _ => AutoApprove::None,
    };

    let heartbeat_state = HeartbeatState::new(worker_id.clone(), name.clone());

    // Announce this worker's capabilities on the in-process agent bus.
    let bus = AgentBus::new().into_arc();
    {
        let handle = bus.handle(&worker_id);
        handle.announce_ready(WORKER_CAPABILITIES.iter().map(|s| s.to_string()).collect());
    }

    register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;

    // Pick up tasks that were queued before the stream connects.
    fetch_pending_tasks(
        &client,
        server,
        &args.token,
        &worker_id,
        &processing,
        &auto_approve,
        &bus,
    )
    .await?;

    // Reconnect loop: re-register, start a heartbeat, stream until drop.
    loop {
        if let Err(e) =
            register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await
        {
            tracing::warn!("Failed to re-register worker on reconnection: {}", e);
        }

        let heartbeat_handle = start_heartbeat(
            client.clone(),
            server.to_string(),
            args.token.clone(),
            heartbeat_state.clone(),
            processing.clone(),
            cognition_heartbeat.clone(),
        );

        match connect_stream(
            &client,
            server,
            &args.token,
            &worker_id,
            &name,
            &codebases,
            &processing,
            &auto_approve,
            &bus,
            // No CloudEvent notification channel in standalone mode.
            None,
        )
        .await
        {
            Ok(()) => {
                tracing::warn!("Stream ended, reconnecting...");
            }
            Err(e) => {
                tracing::error!("Stream error: {}, reconnecting...", e);
            }
        }

        // Stop the heartbeat task; a fresh one starts on the next attempt.
        heartbeat_handle.abort();
        tracing::debug!("Heartbeat cancelled for reconnection");

        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
    }
}
228
/// Same reconnect loop as [`run`], but wired to a local
/// `WorkerServerState` so an embedded server can observe the worker id,
/// heartbeat state, and connection status, and can push CloudEvent task
/// notifications into the stream loop.
///
/// # Errors
/// As with [`run`], only the initial registration or the initial
/// pending-task fetch can return an error; later failures are logged
/// and retried.
pub async fn run_with_state(
    args: A2aArgs,
    server_state: crate::worker_server::WorkerServerState,
) -> Result<()> {
    let server = args.server.trim_end_matches('/');
    // Default the agent name to a PID-derived name when none was given.
    let name = args
        .name
        .unwrap_or_else(|| format!("codetether-{}", std::process::id()));
    let worker_id = generate_worker_id();

    // Expose the generated worker id to the local server.
    server_state.set_worker_id(worker_id.clone()).await;

    // Comma-separated codebase list; falls back to the current directory.
    let codebases: Vec<String> = args
        .codebases
        .map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
        .unwrap_or_else(|| vec![std::env::current_dir().unwrap().display().to_string()]);

    tracing::info!("Starting A2A worker: {} ({})", name, worker_id);
    tracing::info!("Server: {}", server);
    tracing::info!("Codebases: {:?}", codebases);

    let client = Client::new();
    // IDs of tasks currently in flight, used to avoid double-claiming.
    let processing = Arc::new(Mutex::new(HashSet::<String>::new()));
    let cognition_heartbeat = CognitionHeartbeatConfig::from_env();
    if cognition_heartbeat.enabled {
        tracing::info!(
            source = %cognition_heartbeat.source_base_url,
            include_thoughts = cognition_heartbeat.include_thought_summary,
            max_chars = cognition_heartbeat.summary_max_chars,
            timeout_ms = cognition_heartbeat.request_timeout_ms,
            "Cognition heartbeat sharing enabled (set CODETETHER_WORKER_COGNITION_SHARE_ENABLED=false to disable)"
        );
    } else {
        tracing::warn!(
            "Cognition heartbeat sharing disabled; worker thought state will not be shared upstream"
        );
    }

    // Unrecognized values fall back to no auto-approval.
    let auto_approve = match args.auto_approve.as_str() {
        "all" => AutoApprove::All,
        "safe" => AutoApprove::Safe,
        _ => AutoApprove::None,
    };

    let heartbeat_state = HeartbeatState::new(worker_id.clone(), name.clone());

    // Share the heartbeat state with the local server.
    server_state
        .set_heartbeat_state(Arc::new(heartbeat_state.clone()))
        .await;

    // Announce this worker's capabilities on the in-process agent bus.
    let bus = AgentBus::new().into_arc();
    {
        let handle = bus.handle(&worker_id);
        handle.announce_ready(WORKER_CAPABILITIES.iter().map(|s| s.to_string()).collect());
    }

    register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await?;

    server_state.set_connected(true).await;

    // Pick up tasks that were queued before the stream connects.
    fetch_pending_tasks(
        &client,
        server,
        &args.token,
        &worker_id,
        &processing,
        &auto_approve,
        &bus,
    )
    .await?;

    // Reconnect loop: fresh notification channel + heartbeat per attempt.
    loop {
        // Channel for CloudEvent task notifications from the local server;
        // recreated each iteration because the receiver is moved into
        // `connect_stream`.
        let (task_notify_tx, task_notify_rx) = mpsc::channel::<String>(32);
        server_state
            .set_task_notification_channel(task_notify_tx)
            .await;

        server_state.set_connected(true).await;

        if let Err(e) =
            register_worker(&client, server, &args.token, &worker_id, &name, &codebases).await
        {
            tracing::warn!("Failed to re-register worker on reconnection: {}", e);
        }

        let heartbeat_handle = start_heartbeat(
            client.clone(),
            server.to_string(),
            args.token.clone(),
            heartbeat_state.clone(),
            processing.clone(),
            cognition_heartbeat.clone(),
        );

        match connect_stream(
            &client,
            server,
            &args.token,
            &worker_id,
            &name,
            &codebases,
            &processing,
            &auto_approve,
            &bus,
            Some(task_notify_rx),
        )
        .await
        {
            Ok(()) => {
                tracing::warn!("Stream ended, reconnecting...");
            }
            Err(e) => {
                tracing::error!("Stream error: {}, reconnecting...", e);
            }
        }

        server_state.set_connected(false).await;

        // Stop the heartbeat task; a fresh one starts on the next attempt.
        heartbeat_handle.abort();
        tracing::debug!("Heartbeat cancelled for reconnection");

        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
    }
}
370
371fn generate_worker_id() -> String {
372 format!(
373 "wrk_{}_{:x}",
374 chrono::Utc::now().timestamp(),
375 rand::random::<u64>()
376 )
377}
378
/// Tool-approval policy parsed from the `--auto-approve` CLI argument
/// ("all", "safe", anything else maps to `None`).
#[derive(Debug, Clone, Copy)]
enum AutoApprove {
    /// Approve every action automatically.
    All,
    /// Approve only actions considered safe — exact semantics are
    /// applied where tasks are executed (not visible in this chunk).
    Safe,
    /// Approve nothing automatically.
    None,
}
385
/// Default A2A server URL used when none is supplied.
pub const DEFAULT_A2A_SERVER_URL: &str = "https://api.codetether.run";

/// Capability tags announced at registration and on the agent bus.
const WORKER_CAPABILITIES: &[&str] = &["ralph", "swarm", "rlm", "a2a", "mcp"];
391
392fn task_value<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a serde_json::Value> {
393 task.get("task")
394 .and_then(|t| t.get(key))
395 .or_else(|| task.get(key))
396}
397
398fn task_str<'a>(task: &'a serde_json::Value, key: &str) -> Option<&'a str> {
399 task_value(task, key).and_then(|v| v.as_str())
400}
401
402fn task_metadata(task: &serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
403 task_value(task, "metadata")
404 .and_then(|m| m.as_object())
405 .cloned()
406 .unwrap_or_default()
407}
408
/// Normalizes a model reference: a bare `provider:model` string becomes
/// `provider/model`; anything already containing `/` is left untouched.
fn model_ref_to_provider_model(model: &str) -> String {
    if model.contains('/') {
        return model.to_string();
    }
    // Only the first ':' separates provider from model id.
    match model.split_once(':') {
        Some((provider, rest)) => format!("{provider}/{rest}"),
        None => model.to_string(),
    }
}
419
/// Ordered provider preference list for a model tier.
/// Missing or unrecognized tiers use the balanced ordering.
fn provider_preferences_for_tier(model_tier: Option<&str>) -> &'static [&'static str] {
    const FAST: &[&str] = &[
        "zai",
        "openai",
        "github-copilot",
        "moonshotai",
        "openrouter",
        "novita",
        "google",
        "anthropic",
    ];
    const HEAVY: &[&str] = &[
        "zai",
        "anthropic",
        "openai",
        "github-copilot",
        "moonshotai",
        "openrouter",
        "novita",
        "google",
    ];
    const BALANCED: &[&str] = &[
        "zai",
        "openai",
        "github-copilot",
        "anthropic",
        "moonshotai",
        "openrouter",
        "novita",
        "google",
    ];
    match model_tier.unwrap_or("balanced") {
        "fast" | "quick" => FAST,
        "heavy" | "deep" => HEAVY,
        _ => BALANCED,
    }
}
454
455fn choose_provider_for_tier<'a>(providers: &'a [&'a str], model_tier: Option<&str>) -> &'a str {
456 for preferred in provider_preferences_for_tier(model_tier) {
457 if let Some(found) = providers.iter().copied().find(|p| *p == *preferred) {
458 return found;
459 }
460 }
461 if let Some(found) = providers.iter().copied().find(|p| *p == "zai") {
462 return found;
463 }
464 providers[0]
465}
466
/// Default model id for `provider` at the given tier ("fast"/"quick",
/// "heavy"/"deep", or balanced otherwise). Unknown providers map to
/// "glm-5".
fn default_model_for_provider(provider: &str, model_tier: Option<&str>) -> String {
    let tier = model_tier.unwrap_or("balanced");
    let model = match (tier, provider) {
        // Providers whose pick does not depend on the tier.
        (_, "moonshotai") => "kimi-k2.5",
        (_, "zhipuai") | (_, "zai") => "glm-5",
        (_, "openrouter") => "z-ai/glm-5",
        (_, "novita") => "qwen/qwen3-coder-next",
        // Fast tier.
        ("fast" | "quick", "anthropic") => "claude-haiku-4-5",
        ("fast" | "quick", "openai") => "gpt-4o-mini",
        ("fast" | "quick", "google") => "gemini-2.5-flash",
        ("fast" | "quick", "bedrock") => "amazon.nova-lite-v1:0",
        // Heavy tier.
        ("heavy" | "deep", "anthropic") => "claude-sonnet-4-20250514",
        ("heavy" | "deep", "openai") => "o3",
        ("heavy" | "deep", "google") => "gemini-2.5-pro",
        ("heavy" | "deep", "bedrock") => "us.anthropic.claude-sonnet-4-20250514-v1:0",
        // Balanced (default) tier.
        (_, "anthropic") => "claude-sonnet-4-20250514",
        (_, "openai") => "gpt-4o",
        (_, "google") => "gemini-2.5-pro",
        (_, "bedrock") => "amazon.nova-lite-v1:0",
        // Unknown provider at any tier.
        _ => "glm-5",
    };
    model.to_string()
}
504
/// Whether `model` belongs to a family matched here by case-insensitive
/// substring ("kimi-k2", "glm-", "minimax").
fn prefers_temperature_one(model: &str) -> bool {
    let lower = model.to_ascii_lowercase();
    ["kimi-k2", "glm-", "minimax"]
        .iter()
        .any(|family| lower.contains(*family))
}
509
/// True when the agent type names a multi-agent swarm run
/// ("swarm", "parallel", or "multi-agent"); trims whitespace and
/// ignores ASCII case.
fn is_swarm_agent(agent_type: &str) -> bool {
    let normalized = agent_type.trim().to_ascii_lowercase();
    normalized == "swarm" || normalized == "parallel" || normalized == "multi-agent"
}
516
517fn metadata_lookup<'a>(
518 metadata: &'a serde_json::Map<String, serde_json::Value>,
519 key: &str,
520) -> Option<&'a serde_json::Value> {
521 metadata
522 .get(key)
523 .or_else(|| {
524 metadata
525 .get("routing")
526 .and_then(|v| v.as_object())
527 .and_then(|obj| obj.get(key))
528 })
529 .or_else(|| {
530 metadata
531 .get("swarm")
532 .and_then(|v| v.as_object())
533 .and_then(|obj| obj.get(key))
534 })
535}
536
537fn metadata_str(
538 metadata: &serde_json::Map<String, serde_json::Value>,
539 keys: &[&str],
540) -> Option<String> {
541 for key in keys {
542 if let Some(value) = metadata_lookup(metadata, key).and_then(|v| v.as_str()) {
543 let trimmed = value.trim();
544 if !trimmed.is_empty() {
545 return Some(trimmed.to_string());
546 }
547 }
548 }
549 None
550}
551
552fn metadata_usize(
553 metadata: &serde_json::Map<String, serde_json::Value>,
554 keys: &[&str],
555) -> Option<usize> {
556 for key in keys {
557 if let Some(value) = metadata_lookup(metadata, key) {
558 if let Some(v) = value.as_u64() {
559 return usize::try_from(v).ok();
560 }
561 if let Some(v) = value.as_i64() {
562 if v >= 0 {
563 return usize::try_from(v as u64).ok();
564 }
565 }
566 if let Some(v) = value.as_str() {
567 if let Ok(parsed) = v.trim().parse::<usize>() {
568 return Some(parsed);
569 }
570 }
571 }
572 }
573 None
574}
575
576fn metadata_u64(
577 metadata: &serde_json::Map<String, serde_json::Value>,
578 keys: &[&str],
579) -> Option<u64> {
580 for key in keys {
581 if let Some(value) = metadata_lookup(metadata, key) {
582 if let Some(v) = value.as_u64() {
583 return Some(v);
584 }
585 if let Some(v) = value.as_i64() {
586 if v >= 0 {
587 return Some(v as u64);
588 }
589 }
590 if let Some(v) = value.as_str() {
591 if let Ok(parsed) = v.trim().parse::<u64>() {
592 return Some(parsed);
593 }
594 }
595 }
596 }
597 None
598}
599
600fn metadata_bool(
601 metadata: &serde_json::Map<String, serde_json::Value>,
602 keys: &[&str],
603) -> Option<bool> {
604 for key in keys {
605 if let Some(value) = metadata_lookup(metadata, key) {
606 if let Some(v) = value.as_bool() {
607 return Some(v);
608 }
609 if let Some(v) = value.as_str() {
610 match v.trim().to_ascii_lowercase().as_str() {
611 "1" | "true" | "yes" | "on" => return Some(true),
612 "0" | "false" | "no" | "off" => return Some(false),
613 _ => {}
614 }
615 }
616 }
617 }
618 None
619}
620
621fn parse_swarm_strategy(
622 metadata: &serde_json::Map<String, serde_json::Value>,
623) -> DecompositionStrategy {
624 match metadata_str(
625 metadata,
626 &[
627 "decomposition_strategy",
628 "swarm_strategy",
629 "strategy",
630 "swarm_decomposition",
631 ],
632 )
633 .as_deref()
634 .map(|s| s.to_ascii_lowercase())
635 .as_deref()
636 {
637 Some("none") | Some("single") => DecompositionStrategy::None,
638 Some("domain") | Some("by_domain") => DecompositionStrategy::ByDomain,
639 Some("data") | Some("by_data") => DecompositionStrategy::ByData,
640 Some("stage") | Some("by_stage") => DecompositionStrategy::ByStage,
641 _ => DecompositionStrategy::Automatic,
642 }
643}
644
645async fn resolve_swarm_model(
646 explicit_model: Option<String>,
647 model_tier: Option<&str>,
648) -> Option<String> {
649 if let Some(model) = explicit_model {
650 if !model.trim().is_empty() {
651 return Some(model);
652 }
653 }
654
655 let registry = ProviderRegistry::from_vault().await.ok()?;
656 let providers = registry.list();
657 if providers.is_empty() {
658 return None;
659 }
660 let provider = choose_provider_for_tier(providers.as_slice(), model_tier);
661 let model = default_model_for_provider(provider, model_tier);
662 Some(format!("{}/{}", provider, model))
663}
664
665fn format_swarm_event_for_output(event: &SwarmEvent) -> Option<String> {
666 match event {
667 SwarmEvent::Started {
668 task,
669 total_subtasks,
670 } => Some(format!(
671 "[swarm] started task={} planned_subtasks={}",
672 task, total_subtasks
673 )),
674 SwarmEvent::StageComplete {
675 stage,
676 completed,
677 failed,
678 } => Some(format!(
679 "[swarm] stage={} completed={} failed={}",
680 stage, completed, failed
681 )),
682 SwarmEvent::SubTaskUpdate { id, status, .. } => Some(format!(
683 "[swarm] subtask id={} status={}",
684 &id.chars().take(8).collect::<String>(),
685 format!("{status:?}").to_ascii_lowercase()
686 )),
687 SwarmEvent::AgentToolCall {
688 subtask_id,
689 tool_name,
690 } => Some(format!(
691 "[swarm] subtask id={} tool={}",
692 &subtask_id.chars().take(8).collect::<String>(),
693 tool_name
694 )),
695 SwarmEvent::AgentError { subtask_id, error } => Some(format!(
696 "[swarm] subtask id={} error={}",
697 &subtask_id.chars().take(8).collect::<String>(),
698 error
699 )),
700 SwarmEvent::Complete { success, stats } => Some(format!(
701 "[swarm] complete success={} subtasks={} speedup={:.2}",
702 success,
703 stats.subagents_completed + stats.subagents_failed,
704 stats.speedup_factor
705 )),
706 SwarmEvent::Error(err) => Some(format!("[swarm] error message={}", err)),
707 _ => None,
708 }
709}
710
/// Registers (or re-registers) this worker with the A2A server,
/// advertising its capabilities, hostname, codebases, and the models
/// available from configured providers.
///
/// A non-2xx response from the server is only logged — registration is
/// best-effort; transport errors are returned.
async fn register_worker(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
) -> Result<()> {
    // Model discovery is best-effort; registration proceeds without it.
    let models = match load_provider_models().await {
        Ok(m) => m,
        Err(e) => {
            tracing::warn!(
                "Failed to load provider models: {}, proceeding without model info",
                e
            );
            HashMap::new()
        }
    };

    let mut req = client.post(format!("{}/v1/opencode/workers/register", server));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    // Flatten provider -> models into one array of model descriptors,
    // attaching per-million-token costs only when known.
    let models_array: Vec<serde_json::Value> = models
        .iter()
        .flat_map(|(provider, model_infos)| {
            model_infos.iter().map(move |m| {
                let mut obj = serde_json::json!({
                    "id": format!("{}/{}", provider, m.id),
                    "name": &m.id,
                    "provider": provider,
                    "provider_id": provider,
                });
                if let Some(input_cost) = m.input_cost_per_million {
                    obj["input_cost_per_million"] = serde_json::json!(input_cost);
                }
                if let Some(output_cost) = m.output_cost_per_million {
                    obj["output_cost_per_million"] = serde_json::json!(output_cost);
                }
                obj
            })
        })
        .collect();

    tracing::info!(
        "Registering worker with {} models from {} providers",
        models_array.len(),
        models.len()
    );

    // HOSTNAME on Unix-like systems, COMPUTERNAME on Windows.
    let hostname = std::env::var("HOSTNAME")
        .or_else(|_| std::env::var("COMPUTERNAME"))
        .unwrap_or_else(|_| "unknown".to_string());

    let res = req
        .json(&serde_json::json!({
            "worker_id": worker_id,
            "name": name,
            "capabilities": WORKER_CAPABILITIES,
            "hostname": hostname,
            "models": models_array,
            "codebases": codebases,
        }))
        .send()
        .await?;

    if res.status().is_success() {
        tracing::info!("Worker registered successfully");
    } else {
        tracing::warn!("Failed to register worker: {}", res.status());
    }

    Ok(())
}
791
/// Collects the models offered by each configured provider, keyed by
/// provider name.
///
/// Resolution order: Vault-backed registry first, then the config/env
/// fallback. Missing cost data is enriched from the models.dev catalog
/// when that catalog can be fetched. If no provider yields any models,
/// the full public catalog is returned instead so the server still sees
/// model metadata.
async fn load_provider_models() -> Result<HashMap<String, Vec<crate::provider::ModelInfo>>> {
    let registry = match ProviderRegistry::from_vault().await {
        Ok(r) if !r.list().is_empty() => {
            tracing::info!("Loaded {} providers from Vault", r.list().len());
            r
        }
        Ok(_) => {
            tracing::warn!("Vault returned 0 providers, falling back to config/env vars");
            fallback_registry().await?
        }
        Err(e) => {
            tracing::warn!("Vault unreachable ({}), falling back to config/env vars", e);
            fallback_registry().await?
        }
    };

    // Catalog is optional and used only to fill in missing cost metadata.
    let catalog = crate::provider::models::ModelCatalog::fetch().await.ok();

    // Some local provider ids differ from the catalog's naming scheme.
    let catalog_alias = |pid: &str| -> String {
        match pid {
            "bedrock" => "amazon-bedrock".to_string(),
            "novita" => "novita-ai".to_string(),
            _ => pid.to_string(),
        }
    };

    let mut models_by_provider: HashMap<String, Vec<crate::provider::ModelInfo>> = HashMap::new();

    for provider_name in registry.list() {
        if let Some(provider) = registry.get(provider_name) {
            match provider.list_models().await {
                Ok(models) => {
                    let enriched: Vec<crate::provider::ModelInfo> = models
                        .into_iter()
                        .map(|mut m| {
                            // Fill missing costs from the catalog entry,
                            // trying the raw model id and then the id with
                            // a "us." (region-style) prefix stripped.
                            if m.input_cost_per_million.is_none()
                                || m.output_cost_per_million.is_none()
                            {
                                if let Some(ref cat) = catalog {
                                    let cat_pid = catalog_alias(provider_name);
                                    if let Some(prov_info) = cat.get_provider(&cat_pid) {
                                        let model_info =
                                            prov_info.models.get(&m.id).or_else(|| {
                                                m.id.strip_prefix("us.").and_then(|stripped| {
                                                    prov_info.models.get(stripped)
                                                })
                                            });
                                        if let Some(model_info) = model_info {
                                            if let Some(ref cost) = model_info.cost {
                                                if m.input_cost_per_million.is_none() {
                                                    m.input_cost_per_million = Some(cost.input);
                                                }
                                                if m.output_cost_per_million.is_none() {
                                                    m.output_cost_per_million = Some(cost.output);
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            m
                        })
                        .collect();
                    // Providers with zero models are omitted entirely.
                    if !enriched.is_empty() {
                        tracing::debug!("Provider {}: {} models", provider_name, enriched.len());
                        models_by_provider.insert(provider_name.to_string(), enriched);
                    }
                }
                Err(e) => {
                    tracing::debug!("Failed to list models for {}: {}", provider_name, e);
                }
            }
        }
    }

    // Last resort: advertise the whole public catalog so the server has
    // model metadata even with no authenticated providers.
    if models_by_provider.is_empty() {
        tracing::info!(
            "No authenticated providers found, fetching models.dev catalog (all providers)"
        );
        if let Ok(cat) = crate::provider::models::ModelCatalog::fetch().await {
            for (provider_id, provider_info) in cat.all_providers() {
                let model_infos: Vec<crate::provider::ModelInfo> = provider_info
                    .models
                    .values()
                    .map(|m| cat.to_model_info(m, provider_id))
                    .collect();
                if !model_infos.is_empty() {
                    tracing::debug!(
                        "Catalog provider {}: {} models",
                        provider_id,
                        model_infos.len()
                    );
                    models_by_provider.insert(provider_id.clone(), model_infos);
                }
            }
            tracing::info!(
                "Loaded {} providers with {} total models from catalog",
                models_by_provider.len(),
                models_by_provider.values().map(|v| v.len()).sum::<usize>()
            );
        }
    }

    Ok(models_by_provider)
}
910
911async fn fallback_registry() -> Result<ProviderRegistry> {
913 let config = crate::config::Config::load().await.unwrap_or_default();
914 ProviderRegistry::from_config(&config).await
915}
916
917async fn fetch_pending_tasks(
918 client: &Client,
919 server: &str,
920 token: &Option<String>,
921 worker_id: &str,
922 processing: &Arc<Mutex<HashSet<String>>>,
923 auto_approve: &AutoApprove,
924 bus: &Arc<AgentBus>,
925) -> Result<()> {
926 tracing::info!("Checking for pending tasks...");
927
928 let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
929 if let Some(t) = token {
930 req = req.bearer_auth(t);
931 }
932
933 let res = req.send().await?;
934 if !res.status().is_success() {
935 return Ok(());
936 }
937
938 let data: serde_json::Value = res.json().await?;
939 let tasks = if let Some(arr) = data.as_array() {
941 arr.clone()
942 } else {
943 data["tasks"].as_array().cloned().unwrap_or_default()
944 };
945
946 tracing::info!("Found {} pending task(s)", tasks.len());
947
948 for task in tasks {
949 if let Some(id) = task["id"].as_str() {
950 let mut proc = processing.lock().await;
951 if !proc.contains(id) {
952 proc.insert(id.to_string());
953 drop(proc);
954
955 let task_id = id.to_string();
956 let client = client.clone();
957 let server = server.to_string();
958 let token = token.clone();
959 let worker_id = worker_id.to_string();
960 let auto_approve = *auto_approve;
961 let processing = processing.clone();
962 let bus = bus.clone();
963
964 tokio::spawn(async move {
965 if let Err(e) = handle_task(
966 &client,
967 &server,
968 &token,
969 &worker_id,
970 &task,
971 auto_approve,
972 &bus,
973 )
974 .await
975 {
976 tracing::error!("Task {} failed: {}", task_id, e);
977 }
978 processing.lock().await.remove(&task_id);
979 });
980 }
981 }
982 }
983
984 Ok(())
985}
986
/// Opens the SSE task stream and processes it until it ends or errors.
///
/// Runs three concurrent activities via `select!`:
/// 1. CloudEvent task notifications (when `task_notify_rx` is provided)
///    trigger an immediate pending-task poll;
/// 2. SSE chunks are buffered and parsed into task payloads (events are
///    separated by a blank line, payload on a `data:` line);
/// 3. a 30-second fallback poll catches tasks the stream missed.
///
/// Returns `Ok(())` when the server closes the stream and `Err` on
/// connect/transport failure; the caller reconnects in both cases.
#[allow(clippy::too_many_arguments)]
async fn connect_stream(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    name: &str,
    codebases: &[String],
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
    bus: &Arc<AgentBus>,
    task_notify_rx: Option<mpsc::Receiver<String>>,
) -> Result<()> {
    let url = format!(
        "{}/v1/worker/tasks/stream?agent_name={}&worker_id={}",
        server,
        urlencoding::encode(name),
        urlencoding::encode(worker_id)
    );

    let mut req = client
        .get(&url)
        .header("Accept", "text/event-stream")
        .header("X-Worker-ID", worker_id)
        .header("X-Agent-Name", name)
        .header("X-Codebases", codebases.join(","));

    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        anyhow::bail!("Failed to connect: {}", res.status());
    }

    tracing::info!("Connected to A2A server");

    let mut stream = res.bytes_stream();
    // Accumulates partial SSE data across chunk boundaries.
    let mut buffer = String::new();
    let mut poll_interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
    // Swallow the interval's immediate first tick so the first fallback
    // poll happens 30s from now, not instantly.
    poll_interval.tick().await;
    let mut task_notify_rx = task_notify_rx;

    loop {
        tokio::select! {
            // Branch 1: CloudEvent notification from the local server.
            task_id = async {
                if let Some(ref mut rx) = task_notify_rx {
                    rx.recv().await
                } else {
                    // No channel configured: this branch never completes.
                    futures::future::pending().await
                }
            } => {
                if let Some(task_id) = task_id {
                    tracing::info!("Received task notification via CloudEvent: {}", task_id);
                    if let Err(e) = poll_pending_tasks(
                        client, server, token, worker_id, processing, auto_approve, bus,
                    ).await {
                        tracing::warn!("Task notification poll failed: {}", e);
                    }
                }
            }
            // Branch 2: next chunk from the SSE byte stream.
            chunk = stream.next() => {
                match chunk {
                    Some(Ok(chunk)) => {
                        buffer.push_str(&String::from_utf8_lossy(&chunk));

                        // SSE events are delimited by a blank line.
                        while let Some(pos) = buffer.find("\n\n") {
                            let event_str = buffer[..pos].to_string();
                            buffer = buffer[pos + 2..].to_string();

                            if let Some(data_line) = event_str.lines().find(|l| l.starts_with("data:")) {
                                let data = data_line.trim_start_matches("data:").trim();
                                if data == "[DONE]" || data.is_empty() {
                                    continue;
                                }

                                // Non-JSON payloads are silently ignored.
                                if let Ok(task) = serde_json::from_str::<serde_json::Value>(data) {
                                    spawn_task_handler(
                                        &task, client, server, token, worker_id,
                                        processing, auto_approve, bus,
                                    ).await;
                                }
                            }
                        }
                    }
                    Some(Err(e)) => {
                        return Err(e.into());
                    }
                    None => {
                        // Server closed the stream cleanly.
                        return Ok(());
                    }
                }
            }
            // Branch 3: periodic fallback poll.
            _ = poll_interval.tick() => {
                if let Err(e) = poll_pending_tasks(
                    client, server, token, worker_id, processing, auto_approve, bus,
                ).await {
                    tracing::warn!("Periodic task poll failed: {}", e);
                }
            }
        }
    }
}
1100
1101async fn spawn_task_handler(
1102 task: &serde_json::Value,
1103 client: &Client,
1104 server: &str,
1105 token: &Option<String>,
1106 worker_id: &str,
1107 processing: &Arc<Mutex<HashSet<String>>>,
1108 auto_approve: &AutoApprove,
1109 bus: &Arc<AgentBus>,
1110) {
1111 if let Some(id) = task
1112 .get("task")
1113 .and_then(|t| t["id"].as_str())
1114 .or_else(|| task["id"].as_str())
1115 {
1116 let mut proc = processing.lock().await;
1117 if !proc.contains(id) {
1118 proc.insert(id.to_string());
1119 drop(proc);
1120
1121 let task_id = id.to_string();
1122 let task = task.clone();
1123 let client = client.clone();
1124 let server = server.to_string();
1125 let token = token.clone();
1126 let worker_id = worker_id.to_string();
1127 let auto_approve = *auto_approve;
1128 let processing_clone = processing.clone();
1129 let bus = bus.clone();
1130
1131 tokio::spawn(async move {
1132 if let Err(e) = handle_task(
1133 &client,
1134 &server,
1135 &token,
1136 &worker_id,
1137 &task,
1138 auto_approve,
1139 &bus,
1140 )
1141 .await
1142 {
1143 tracing::error!("Task {} failed: {}", task_id, e);
1144 }
1145 processing_clone.lock().await.remove(&task_id);
1146 });
1147 }
1148 }
1149}
1150
/// Polls the pending-task endpoint once and spawns a handler for each
/// returned task (deduplication happens inside `spawn_task_handler`).
///
/// A non-success HTTP status is treated as "nothing pending"; only
/// transport/deserialization errors are returned.
async fn poll_pending_tasks(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    processing: &Arc<Mutex<HashSet<String>>>,
    auto_approve: &AutoApprove,
    bus: &Arc<AgentBus>,
) -> Result<()> {
    let mut req = client.get(format!("{}/v1/opencode/tasks?status=pending", server));
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req.send().await?;
    if !res.status().is_success() {
        return Ok(());
    }

    // The endpoint may return either a bare array or `{"tasks": [...]}`.
    let data: serde_json::Value = res.json().await?;
    let tasks = if let Some(arr) = data.as_array() {
        arr.clone()
    } else {
        data["tasks"].as_array().cloned().unwrap_or_default()
    };

    if !tasks.is_empty() {
        tracing::debug!("Poll found {} pending task(s)", tasks.len());
    }

    for task in &tasks {
        spawn_task_handler(
            task,
            client,
            server,
            token,
            worker_id,
            processing,
            auto_approve,
            bus,
        )
        .await;
    }

    Ok(())
}
1197
/// Claims a single pending task from the coordinator, executes it (via the
/// swarm executor for swarm agent types, otherwise a single-session agent
/// loop), and reports the outcome back through the release endpoint.
///
/// A failed claim is not an error for the caller: HTTP 409 means another
/// worker won the race, and any other non-success status is logged; both
/// return `Ok(())` without executing the task.
async fn handle_task(
    client: &Client,
    server: &str,
    token: &Option<String>,
    worker_id: &str,
    task: &serde_json::Value,
    auto_approve: AutoApprove,
    bus: &Arc<AgentBus>,
) -> Result<()> {
    let task_id = task_str(task, "id").ok_or_else(|| anyhow::anyhow!("No task ID"))?;
    let title = task_str(task, "title").unwrap_or("Untitled");

    tracing::info!("Handling task: {} ({})", title, task_id);

    // Attempt to claim the task so no other worker processes it concurrently.
    let mut req = client
        .post(format!("{}/v1/worker/tasks/claim", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    let res = req
        .json(&serde_json::json!({ "task_id": task_id }))
        .send()
        .await?;

    if !res.status().is_success() {
        let status = res.status();
        let text = res.text().await?;
        if status == reqwest::StatusCode::CONFLICT {
            // 409: lost the claim race — benign, skip quietly.
            tracing::debug!(task_id, "Task already claimed by another worker, skipping");
        } else {
            tracing::warn!(task_id, %status, "Failed to claim task: {}", text);
        }
        return Ok(());
    }

    tracing::info!("Claimed task: {}", task_id);

    // Pull execution hints out of the task's metadata map. Empty/whitespace
    // resume ids are treated as absent.
    let metadata = task_metadata(task);
    let resume_session_id = metadata
        .get("resume_session_id")
        .and_then(|v| v.as_str())
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty());
    let complexity_hint = metadata_str(&metadata, &["complexity"]);
    // Explicit tier wins; otherwise derive one from the complexity hint
    // (quick -> fast, deep -> heavy, anything else -> balanced).
    let model_tier = metadata_str(&metadata, &["model_tier", "tier"])
        .map(|s| s.to_ascii_lowercase())
        .or_else(|| {
            complexity_hint.as_ref().map(|complexity| {
                match complexity.to_ascii_lowercase().as_str() {
                    "quick" => "fast".to_string(),
                    "deep" => "heavy".to_string(),
                    _ => "balanced".to_string(),
                }
            })
        });
    let worker_personality = metadata_str(
        &metadata,
        &["worker_personality", "personality", "agent_personality"],
    );
    let target_agent_name = metadata_str(&metadata, &["target_agent_name", "agent_name"]);
    // Model selection precedence: task.model_ref, metadata.model_ref,
    // task.model, metadata.model.
    let raw_model = task_str(task, "model_ref")
        .or_else(|| metadata_lookup(&metadata, "model_ref").and_then(|v| v.as_str()))
        .or_else(|| task_str(task, "model"))
        .or_else(|| metadata_lookup(&metadata, "model").and_then(|v| v.as_str()));
    let selected_model = raw_model.map(model_ref_to_provider_model);

    // Resume the requested session when possible; fall back to a fresh
    // session if loading fails (e.g. unknown id) rather than failing the task.
    let mut session = if let Some(ref sid) = resume_session_id {
        match Session::load(sid).await {
            Ok(existing) => {
                tracing::info!("Resuming session {} for task {}", sid, task_id);
                existing
            }
            Err(e) => {
                tracing::warn!(
                    "Could not load session {} for task {} ({}), starting a new session",
                    sid,
                    task_id,
                    e
                );
                Session::new().await?
            }
        }
    } else {
        Session::new().await?
    };

    let agent_type = task_str(task, "agent_type")
        .or_else(|| task_str(task, "agent"))
        .unwrap_or("build");
    session.agent = agent_type.to_string();

    if let Some(model) = selected_model.clone() {
        session.metadata.model = Some(model);
    }

    // The prompt falls back to the description, then the title, so a claimed
    // task always has something to execute.
    let prompt = task_str(task, "prompt")
        .or_else(|| task_str(task, "description"))
        .unwrap_or(title);

    tracing::info!("Executing prompt: {}", prompt);

    // Owned clones for the streaming closure below, which outlives this
    // function's borrows once spawned.
    let stream_client = client.clone();
    let stream_server = server.to_string();
    let stream_token = token.clone();
    let stream_worker_id = worker_id.to_string();
    let stream_task_id = task_id.to_string();

    // Fire-and-forget output streaming: each emitted chunk is POSTed to the
    // task's output endpoint on its own spawned task; send errors are ignored.
    let output_callback = move |output: String| {
        let c = stream_client.clone();
        let s = stream_server.clone();
        let t = stream_token.clone();
        let w = stream_worker_id.clone();
        let tid = stream_task_id.clone();
        tokio::spawn(async move {
            let mut req = c
                .post(format!("{}/v1/opencode/tasks/{}/output", s, tid))
                .header("X-Worker-ID", &w);
            if let Some(tok) = &t {
                req = req.bearer_auth(tok);
            }
            let _ = req
                .json(&serde_json::json!({
                    "worker_id": w,
                    "output": output,
                }))
                .send()
                .await;
        });
    };

    // Execute, collapsing all outcomes into the (status, result, error,
    // session_id) tuple reported back to the coordinator. A swarm run that
    // finishes but reports partial failure is released as "failed" while
    // still carrying its textual result.
    let (status, result, error, session_id) = if is_swarm_agent(agent_type) {
        match execute_swarm_with_policy(
            &mut session,
            prompt,
            model_tier.as_deref(),
            selected_model,
            &metadata,
            complexity_hint.as_deref(),
            worker_personality.as_deref(),
            target_agent_name.as_deref(),
            Some(bus),
            Some(&output_callback),
        )
        .await
        {
            Ok((session_result, true)) => {
                tracing::info!("Swarm task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Ok((session_result, false)) => {
                tracing::warn!("Swarm task completed with failures: {}", task_id);
                (
                    "failed",
                    Some(session_result.text),
                    Some("Swarm execution completed with failures".to_string()),
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Swarm task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    } else {
        match execute_session_with_policy(
            &mut session,
            prompt,
            auto_approve,
            model_tier.as_deref(),
            Some(&output_callback),
        )
        .await
        {
            Ok(session_result) => {
                tracing::info!("Task completed successfully: {}", task_id);
                (
                    "completed",
                    Some(session_result.text),
                    None,
                    Some(session_result.session_id),
                )
            }
            Err(e) => {
                tracing::error!("Task failed: {} - {}", task_id, e);
                ("failed", None, Some(format!("Error: {}", e)), None)
            }
        }
    };

    // Release the task with the final status. NOTE(review): the release
    // response status is not checked — a non-2xx here goes unnoticed; confirm
    // the coordinator tolerates that.
    let mut req = client
        .post(format!("{}/v1/worker/tasks/release", server))
        .header("X-Worker-ID", worker_id);
    if let Some(t) = token {
        req = req.bearer_auth(t);
    }

    req.json(&serde_json::json!({
        "task_id": task_id,
        "status": status,
        "result": result,
        "error": error,
        "session_id": session_id.unwrap_or_else(|| session.id.clone()),
    }))
    .send()
    .await?;

    tracing::info!("Task released: {} with status: {}", task_id, status);
    Ok(())
}
1419
/// Runs `prompt` through the swarm executor, configured from task metadata,
/// recording the conversation into `session`.
///
/// Returns the session result together with a boolean success flag from the
/// swarm run: `Ok((_, false))` means the swarm finished but reported
/// failures, which is distinct from `Err` (the run itself failed).
///
/// When `output_callback` is provided, swarm events are streamed through it
/// as formatted lines while the executor runs on a spawned task; without a
/// callback the executor runs inline.
async fn execute_swarm_with_policy<F>(
    session: &mut Session,
    prompt: &str,
    model_tier: Option<&str>,
    explicit_model: Option<String>,
    metadata: &serde_json::Map<String, serde_json::Value>,
    complexity_hint: Option<&str>,
    worker_personality: Option<&str>,
    target_agent_name: Option<&str>,
    bus: Option<&Arc<AgentBus>>,
    output_callback: Option<&F>,
) -> Result<(crate::session::SessionResult, bool)>
where
    F: Fn(String),
{
    use crate::provider::{ContentPart, Message, Role};

    // Record the user prompt in the session before executing.
    session.add_message(Message {
        role: Role::User,
        content: vec![ContentPart::Text {
            text: prompt.to_string(),
        }],
    });

    if session.title.is_none() {
        session.generate_title().await?;
    }

    // Swarm tuning knobs come from task metadata, each clamped to a sane
    // range with a default when absent.
    let strategy = parse_swarm_strategy(metadata);
    let max_subagents = metadata_usize(
        metadata,
        &["swarm_max_subagents", "max_subagents", "subagents"],
    )
    .unwrap_or(10)
    .clamp(1, 100);
    let max_steps_per_subagent = metadata_usize(
        metadata,
        &[
            "swarm_max_steps_per_subagent",
            "max_steps_per_subagent",
            "max_steps",
        ],
    )
    .unwrap_or(50)
    .clamp(1, 200);
    let timeout_secs = metadata_u64(metadata, &["swarm_timeout_secs", "timeout_secs", "timeout"])
        .unwrap_or(600)
        .clamp(30, 3600);
    let parallel_enabled =
        metadata_bool(metadata, &["swarm_parallel_enabled", "parallel_enabled"]).unwrap_or(true);

    // Resolve the model (explicit override wins over tier-based choice) and
    // remember the choice on the session for later inspection.
    let model = resolve_swarm_model(explicit_model, model_tier).await;
    if let Some(ref selected_model) = model {
        session.metadata.model = Some(selected_model.clone());
    }

    // Surface routing decisions and effective config to the task output
    // stream so operators can see how the swarm was configured.
    if let Some(cb) = output_callback {
        cb(format!(
            "[swarm] routing complexity={} tier={} personality={} target_agent={}",
            complexity_hint.unwrap_or("standard"),
            model_tier.unwrap_or("balanced"),
            worker_personality.unwrap_or("auto"),
            target_agent_name.unwrap_or("auto")
        ));
        cb(format!(
            "[swarm] config strategy={:?} max_subagents={} max_steps={} timeout={}s tier={}",
            strategy,
            max_subagents,
            max_steps_per_subagent,
            timeout_secs,
            model_tier.unwrap_or("balanced")
        ));
    }

    let swarm_config = SwarmConfig {
        max_subagents,
        max_steps_per_subagent,
        subagent_timeout_secs: timeout_secs,
        parallel_enabled,
        model,
        working_dir: session
            .metadata
            .directory
            .as_ref()
            .map(|p| p.to_string_lossy().to_string()),
        ..Default::default()
    };

    let swarm_result = if output_callback.is_some() {
        // Streaming path: the executor runs on its own task and emits events
        // over a bounded channel; we forward formatted events to the callback
        // until the executor's join handle resolves.
        let (event_tx, mut event_rx) = mpsc::channel(256);
        let mut executor = SwarmExecutor::new(swarm_config).with_event_tx(event_tx);
        if let Some(bus) = bus {
            executor = executor.with_bus(Arc::clone(bus));
        }
        let prompt_owned = prompt.to_string();
        let mut exec_handle =
            tokio::spawn(async move { executor.execute(&prompt_owned, strategy).await });

        let mut final_result: Option<crate::swarm::SwarmResult> = None;

        while final_result.is_none() {
            tokio::select! {
                maybe_event = event_rx.recv() => {
                    if let Some(event) = maybe_event {
                        if let Some(cb) = output_callback {
                            if let Some(line) = format_swarm_event_for_output(&event) {
                                cb(line);
                            }
                        }
                    }
                }
                join_result = &mut exec_handle => {
                    let joined = join_result.map_err(|e| anyhow::anyhow!("Swarm join failure: {}", e))?;
                    final_result = Some(joined?);
                }
            }
        }

        // Drain events that arrived between the executor finishing and the
        // select loop exiting, so no output lines are dropped.
        while let Ok(event) = event_rx.try_recv() {
            if let Some(cb) = output_callback {
                if let Some(line) = format_swarm_event_for_output(&event) {
                    cb(line);
                }
            }
        }

        final_result.ok_or_else(|| anyhow::anyhow!("Swarm execution returned no result"))?
    } else {
        // No callback: run the executor inline without event plumbing.
        let mut executor = SwarmExecutor::new(swarm_config);
        if let Some(bus) = bus {
            executor = executor.with_bus(Arc::clone(bus));
        }
        executor.execute(prompt, strategy).await?
    };

    // Guarantee a non-empty assistant message even when the swarm produced
    // no textual output.
    let final_text = if swarm_result.result.trim().is_empty() {
        if swarm_result.success {
            "Swarm completed without textual output.".to_string()
        } else {
            "Swarm finished with failures and no textual output.".to_string()
        }
    } else {
        swarm_result.result.clone()
    };

    session.add_message(Message {
        role: Role::Assistant,
        content: vec![ContentPart::Text {
            text: final_text.clone(),
        }],
    });
    session.save().await?;

    Ok((
        crate::session::SessionResult {
            text: final_text,
            session_id: session.id.clone(),
        },
        swarm_result.success,
    ))
}
1581
/// Runs a single-agent tool-use loop for `prompt` against a provider chosen
/// from the vault-backed registry, honoring the session's stored model
/// preference and the worker's auto-approve policy.
///
/// The loop runs at most 50 completion steps; each step replays the system
/// prompt plus the full session history, executes any returned tool calls
/// (gated by `is_tool_allowed`), and stops on the first response without
/// tool calls. Assistant text is accumulated into the returned
/// `SessionResult` and streamed through `output_callback` when provided.
async fn execute_session_with_policy<F>(
    session: &mut Session,
    prompt: &str,
    auto_approve: AutoApprove,
    model_tier: Option<&str>,
    output_callback: Option<&F>,
) -> Result<crate::session::SessionResult>
where
    F: Fn(String),
{
    use crate::provider::{
        CompletionRequest, ContentPart, Message, ProviderRegistry, Role, parse_model_string,
    };
    use std::sync::Arc;

    let registry = ProviderRegistry::from_vault().await?;
    let providers = registry.list();
    tracing::info!("Available providers: {:?}", providers);

    if providers.is_empty() {
        anyhow::bail!("No providers available. Configure API keys in HashiCorp Vault.");
    }

    // Parse the session's stored model string into (provider, model).
    // "zhipuai" is aliased to "zai". If the string carried no provider
    // prefix but names a known provider, treat it as a bare provider choice;
    // otherwise keep just the model id and pick the provider later.
    let (provider_name, model_id) = if let Some(ref model_str) = session.metadata.model {
        let (prov, model) = parse_model_string(model_str);
        let prov = prov.map(|p| if p == "zhipuai" { "zai" } else { p });
        if prov.is_some() {
            (prov.map(|s| s.to_string()), model.to_string())
        } else if providers.contains(&model) {
            (Some(model.to_string()), String::new())
        } else {
            (None, model.to_string())
        }
    } else {
        (None, String::new())
    };

    let provider_slice = providers.as_slice();
    // Track whether an explicitly requested provider is missing, so we don't
    // pair its model id with a different fallback provider below.
    let provider_requested_but_unavailable = provider_name
        .as_deref()
        .map(|p| !providers.contains(&p))
        .unwrap_or(false);

    // Use the requested provider only if it's actually available; otherwise
    // choose one appropriate for the requested tier.
    let selected_provider = provider_name
        .as_deref()
        .filter(|p| providers.contains(p))
        .unwrap_or_else(|| choose_provider_for_tier(provider_slice, model_tier));

    let provider = registry
        .get(selected_provider)
        .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;

    // Record the user prompt before executing.
    session.add_message(Message {
        role: Role::User,
        content: vec![ContentPart::Text {
            text: prompt.to_string(),
        }],
    });

    if session.title.is_none() {
        session.generate_title().await?;
    }

    // Keep the explicit model id only when its provider is usable; otherwise
    // fall back to the selected provider's tier default.
    let model = if !model_id.is_empty() && !provider_requested_but_unavailable {
        model_id
    } else {
        default_model_for_provider(selected_provider, model_tier)
    };

    let tool_registry =
        create_filtered_registry(Arc::clone(&provider), model.clone(), auto_approve);
    let tool_definitions = tool_registry.definitions();

    // Some models expect temperature 1.0; everything else gets 0.7.
    let temperature = if prefers_temperature_one(&model) {
        Some(1.0)
    } else {
        Some(0.7)
    };

    tracing::info!(
        "Using model: {} via provider: {} (tier: {:?})",
        model,
        selected_provider,
        model_tier
    );
    tracing::info!(
        "Available tools: {} (auto_approve: {:?})",
        tool_definitions.len(),
        auto_approve
    );

    // Prefer $PWD (preserves symlinked paths) over the canonicalized
    // current_dir; fall back to an empty path if both fail.
    let cwd = std::env::var("PWD")
        .map(std::path::PathBuf::from)
        .unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
    let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);

    let mut final_output = String::new();
    let max_steps = 50;

    for step in 1..=max_steps {
        tracing::info!(step = step, "Agent step starting");

        // Rebuild the message list each step: system prompt first, then the
        // full session history (which grows as tool results are appended).
        let mut messages = vec![Message {
            role: Role::System,
            content: vec![ContentPart::Text {
                text: system_prompt.clone(),
            }],
        }];
        messages.extend(session.messages.clone());

        let request = CompletionRequest {
            messages,
            tools: tool_definitions.clone(),
            model: model.clone(),
            temperature,
            top_p: None,
            max_tokens: Some(8192),
            stop: Vec::new(),
        };

        let response = provider.complete(request).await?;

        crate::telemetry::TOKEN_USAGE.record_model_usage(
            &model,
            response.usage.prompt_tokens as u64,
            response.usage.completion_tokens as u64,
        );

        // Extract tool calls; malformed argument JSON degrades to an empty
        // object rather than failing the step.
        let tool_calls: Vec<(String, String, serde_json::Value)> = response
            .message
            .content
            .iter()
            .filter_map(|part| {
                if let ContentPart::ToolCall {
                    id,
                    name,
                    arguments,
                } = part
                {
                    let args: serde_json::Value =
                        serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
                    Some((id.clone(), name.clone(), args))
                } else {
                    None
                }
            })
            .collect();

        // Accumulate and stream any assistant text from this step.
        for part in &response.message.content {
            if let ContentPart::Text { text } = part {
                if !text.is_empty() {
                    final_output.push_str(text);
                    final_output.push('\n');
                    if let Some(cb) = output_callback {
                        cb(text.clone());
                    }
                }
            }
        }

        // No tool calls means the model is done; record the message and stop.
        if tool_calls.is_empty() {
            session.add_message(response.message.clone());
            break;
        }

        session.add_message(response.message.clone());

        tracing::info!(
            step = step,
            num_tools = tool_calls.len(),
            "Executing tool calls"
        );

        for (tool_id, tool_name, tool_input) in tool_calls {
            tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");

            if let Some(cb) = output_callback {
                cb(format!("[tool:start:{}]", tool_name));
            }

            // Policy gate: blocked tools get an explanatory tool result so
            // the model can adapt, rather than aborting the whole task.
            if !is_tool_allowed(&tool_name, auto_approve) {
                let msg = format!(
                    "Tool '{}' requires approval but auto-approve policy is {:?}",
                    tool_name, auto_approve
                );
                tracing::warn!(tool = %tool_name, "Tool blocked by auto-approve policy");
                session.add_message(Message {
                    role: Role::Tool,
                    content: vec![ContentPart::ToolResult {
                        tool_call_id: tool_id,
                        content: msg,
                    }],
                });
                continue;
            }

            // Execute the tool; both execution errors and unknown tool names
            // become error strings fed back to the model as tool results.
            let content = if let Some(tool) = tool_registry.get(&tool_name) {
                let exec_result: Result<crate::tool::ToolResult> =
                    tool.execute(tool_input.clone()).await;
                match exec_result {
                    Ok(result) => {
                        tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
                        if let Some(cb) = output_callback {
                            let status = if result.success { "ok" } else { "err" };
                            // NOTE(review): byte-index truncation — panics if
                            // byte 500 is not a UTF-8 char boundary; confirm
                            // tool output is ASCII or switch to char-aware
                            // truncation.
                            cb(format!(
                                "[tool:{}:{}] {}",
                                tool_name,
                                status,
                                &result.output[..result.output.len().min(500)]
                            ));
                        }
                        result.output
                    }
                    Err(e) => {
                        tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
                        if let Some(cb) = output_callback {
                            cb(format!("[tool:{}:err] {}", tool_name, e));
                        }
                        format!("Error: {}", e)
                    }
                }
            } else {
                tracing::warn!(tool = %tool_name, "Tool not found");
                format!("Error: Unknown tool '{}'", tool_name)
            };

            session.add_message(Message {
                role: Role::Tool,
                content: vec![ContentPart::ToolResult {
                    tool_call_id: tool_id,
                    content,
                }],
            });
        }
    }

    session.save().await?;

    Ok(crate::session::SessionResult {
        text: final_output.trim().to_string(),
        session_id: session.id.clone(),
    })
}
1843
1844fn is_tool_allowed(tool_name: &str, auto_approve: AutoApprove) -> bool {
1846 match auto_approve {
1847 AutoApprove::All => true,
1848 AutoApprove::Safe | AutoApprove::None => is_safe_tool(tool_name),
1849 }
1850}
1851
/// Reports whether `tool_name` belongs to the read-only tool set (file
/// reads, search, LSP, web lookups, todo reading, skills) that never needs
/// explicit approval.
fn is_safe_tool(tool_name: &str) -> bool {
    matches!(
        tool_name,
        "read"
            | "list"
            | "glob"
            | "grep"
            | "codesearch"
            | "lsp"
            | "webfetch"
            | "websearch"
            | "todo_read"
            | "skill"
    )
}
1868
1869fn create_filtered_registry(
1871 provider: Arc<dyn crate::provider::Provider>,
1872 model: String,
1873 auto_approve: AutoApprove,
1874) -> crate::tool::ToolRegistry {
1875 use crate::tool::*;
1876
1877 let mut registry = ToolRegistry::new();
1878
1879 registry.register(Arc::new(file::ReadTool::new()));
1881 registry.register(Arc::new(file::ListTool::new()));
1882 registry.register(Arc::new(file::GlobTool::new()));
1883 registry.register(Arc::new(search::GrepTool::new()));
1884 registry.register(Arc::new(lsp::LspTool::new()));
1885 registry.register(Arc::new(webfetch::WebFetchTool::new()));
1886 registry.register(Arc::new(websearch::WebSearchTool::new()));
1887 registry.register(Arc::new(codesearch::CodeSearchTool::new()));
1888 registry.register(Arc::new(todo::TodoReadTool::new()));
1889 registry.register(Arc::new(skill::SkillTool::new()));
1890
1891 if matches!(auto_approve, AutoApprove::All) {
1893 registry.register(Arc::new(file::WriteTool::new()));
1894 registry.register(Arc::new(edit::EditTool::new()));
1895 registry.register(Arc::new(bash::BashTool::new()));
1896 registry.register(Arc::new(multiedit::MultiEditTool::new()));
1897 registry.register(Arc::new(patch::ApplyPatchTool::new()));
1898 registry.register(Arc::new(todo::TodoWriteTool::new()));
1899 registry.register(Arc::new(task::TaskTool::new()));
1900 registry.register(Arc::new(plan::PlanEnterTool::new()));
1901 registry.register(Arc::new(plan::PlanExitTool::new()));
1902 registry.register(Arc::new(rlm::RlmTool::new(Arc::clone(&provider), model.clone())));
1903 registry.register(Arc::new(ralph::RalphTool::with_provider(provider, model)));
1904 registry.register(Arc::new(prd::PrdTool::new()));
1905 registry.register(Arc::new(confirm_edit::ConfirmEditTool::new()));
1906 registry.register(Arc::new(confirm_multiedit::ConfirmMultiEditTool::new()));
1907 registry.register(Arc::new(undo::UndoTool));
1908 registry.register(Arc::new(mcp_bridge::McpBridgeTool::new()));
1909 }
1910
1911 registry.register(Arc::new(invalid::InvalidTool::new()));
1912
1913 registry
1914}
1915
/// Spawns the background heartbeat loop for this worker and returns its
/// join handle.
///
/// Every 30 seconds the loop refreshes the shared `HeartbeatState` (status
/// and active task count derived from the `processing` set) and POSTs it to
/// `/v1/opencode/workers/{id}/heartbeat`. When cognition sharing is enabled,
/// a best-effort cognition payload is attached; if the server rejects a
/// heartbeat carrying that payload with a 4xx, the heartbeat is retried
/// without it and cognition attachment is paused for a 5-minute cooldown.
/// The loop never exits on failure: after 3 consecutive failures it logs an
/// error and resets the counter, leaving reconnection to the SSE loop.
fn start_heartbeat(
    client: Client,
    server: String,
    token: Option<String>,
    heartbeat_state: HeartbeatState,
    processing: Arc<Mutex<HashSet<String>>>,
    cognition_config: CognitionHeartbeatConfig,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        let mut consecutive_failures = 0u32;
        const MAX_FAILURES: u32 = 3;
        const HEARTBEAT_INTERVAL_SECS: u64 = 30;
        const COGNITION_RETRY_COOLDOWN_SECS: u64 = 300;
        // While set, cognition payloads are suppressed until this instant.
        let mut cognition_payload_disabled_until: Option<Instant> = None;

        let mut interval =
            tokio::time::interval(tokio::time::Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
        // Skip (rather than burst) ticks missed while a send was in flight.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        loop {
            interval.tick().await;

            // Derive status from the live processing set each tick.
            let active_count = processing.lock().await.len();
            heartbeat_state.set_task_count(active_count).await;

            let status = if active_count > 0 {
                WorkerStatus::Processing
            } else {
                WorkerStatus::Idle
            };
            heartbeat_state.set_status(status).await;

            let url = format!(
                "{}/v1/opencode/workers/{}/heartbeat",
                server, heartbeat_state.worker_id
            );
            let mut req = client.post(&url);

            if let Some(ref t) = token {
                req = req.bearer_auth(t);
            }

            let status_str = heartbeat_state.status.lock().await.as_str().to_string();
            // base_payload is kept separate so a cognition-schema rejection
            // can be retried with the plain payload below.
            let base_payload = serde_json::json!({
                "worker_id": &heartbeat_state.worker_id,
                "agent_name": &heartbeat_state.agent_name,
                "status": status_str,
                "active_task_count": active_count,
            });
            let mut payload = base_payload.clone();
            let mut included_cognition_payload = false;
            let cognition_payload_allowed = cognition_payload_disabled_until
                .map(|until| Instant::now() >= until)
                .unwrap_or(true);

            // Attach the cognition payload only when enabled, not cooling
            // down, and the local cognition service actually responded.
            if cognition_config.enabled
                && cognition_payload_allowed
                && let Some(cognition_payload) =
                    fetch_cognition_heartbeat_payload(&client, &cognition_config).await
                && let Some(obj) = payload.as_object_mut()
            {
                obj.insert("cognition".to_string(), cognition_payload);
                included_cognition_payload = true;
            }

            match req.json(&payload).send().await {
                Ok(res) => {
                    if res.status().is_success() {
                        consecutive_failures = 0;
                        tracing::debug!(
                            worker_id = %heartbeat_state.worker_id,
                            status = status_str,
                            active_tasks = active_count,
                            "Heartbeat sent successfully"
                        );
                    } else if included_cognition_payload && res.status().is_client_error() {
                        // A 4xx on a cognition-carrying heartbeat is treated
                        // as a schema rejection: retry once without the
                        // cognition section.
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            "Heartbeat cognition payload rejected, retrying without cognition payload"
                        );

                        let mut retry_req = client.post(&url);
                        if let Some(ref t) = token {
                            retry_req = retry_req.bearer_auth(t);
                        }

                        match retry_req.json(&base_payload).send().await {
                            Ok(retry_res) if retry_res.status().is_success() => {
                                // Plain payload worked: blame the cognition
                                // section and pause it for the cooldown.
                                cognition_payload_disabled_until = Some(
                                    Instant::now()
                                        + Duration::from_secs(COGNITION_RETRY_COOLDOWN_SECS),
                                );
                                consecutive_failures = 0;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    retry_after_secs = COGNITION_RETRY_COOLDOWN_SECS,
                                    "Paused cognition heartbeat payload after schema rejection"
                                );
                            }
                            Ok(retry_res) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    status = %retry_res.status(),
                                    failures = consecutive_failures,
                                    "Heartbeat failed even after retry without cognition payload"
                                );
                            }
                            Err(e) => {
                                consecutive_failures += 1;
                                tracing::warn!(
                                    worker_id = %heartbeat_state.worker_id,
                                    error = %e,
                                    failures = consecutive_failures,
                                    "Heartbeat retry without cognition payload failed"
                                );
                            }
                        }
                    } else {
                        consecutive_failures += 1;
                        tracing::warn!(
                            worker_id = %heartbeat_state.worker_id,
                            status = %res.status(),
                            failures = consecutive_failures,
                            "Heartbeat failed"
                        );
                    }
                }
                Err(e) => {
                    consecutive_failures += 1;
                    tracing::warn!(
                        worker_id = %heartbeat_state.worker_id,
                        error = %e,
                        failures = consecutive_failures,
                        "Heartbeat request failed"
                    );
                }
            }

            // Escalate to error-level after repeated failures, then reset
            // the counter — the loop deliberately keeps running.
            if consecutive_failures >= MAX_FAILURES {
                tracing::error!(
                    worker_id = %heartbeat_state.worker_id,
                    failures = consecutive_failures,
                    "Heartbeat failed {} consecutive times - worker will continue running and attempt reconnection via SSE loop",
                    MAX_FAILURES
                );
                consecutive_failures = 0;
            }
        }
    })
}
2075
/// Best-effort fetch of cognition state from the local cognition service,
/// shaped into a JSON object for embedding in a heartbeat payload.
///
/// Queries `/v1/cognition/status` (and, when `include_thought_summary` is
/// set, `/v1/cognition/snapshots/latest`) under `request_timeout_ms`.
/// Returns `None` on any timeout, transport error, non-success status, or
/// deserialization failure — callers treat a missing payload as "no
/// cognition data this tick", never as an error. Snapshot failures are
/// likewise swallowed: the status-only payload is still returned.
async fn fetch_cognition_heartbeat_payload(
    client: &Client,
    config: &CognitionHeartbeatConfig,
) -> Option<serde_json::Value> {
    let status_url = format!("{}/v1/cognition/status", config.source_base_url);
    let status_res = tokio::time::timeout(
        Duration::from_millis(config.request_timeout_ms),
        client.get(status_url).send(),
    )
    .await
    .ok()?
    .ok()?;

    if !status_res.status().is_success() {
        return None;
    }

    let status: CognitionStatusSnapshot = status_res.json().await.ok()?;
    let mut payload = serde_json::json!({
        "running": status.running,
        "last_tick_at": status.last_tick_at,
        "active_persona_count": status.active_persona_count,
        "events_buffered": status.events_buffered,
        "snapshots_buffered": status.snapshots_buffered,
        "loop_interval_ms": status.loop_interval_ms,
    });

    if config.include_thought_summary {
        let snapshot_url = format!("{}/v1/cognition/snapshots/latest", config.source_base_url);
        let snapshot_res = tokio::time::timeout(
            Duration::from_millis(config.request_timeout_ms),
            client.get(snapshot_url).send(),
        )
        .await
        .ok()
        .and_then(Result::ok);

        // Enrich the payload with the latest thought only if every step of
        // the snapshot fetch succeeded; the summary is trimmed to the
        // configured character budget.
        if let Some(snapshot_res) = snapshot_res
            && snapshot_res.status().is_success()
            && let Ok(snapshot) = snapshot_res.json::<CognitionLatestSnapshot>().await
            && let Some(obj) = payload.as_object_mut()
        {
            obj.insert(
                "latest_snapshot_at".to_string(),
                serde_json::Value::String(snapshot.generated_at),
            );
            obj.insert(
                "latest_thought".to_string(),
                serde_json::Value::String(trim_for_heartbeat(
                    &snapshot.summary,
                    config.summary_max_chars,
                )),
            );
            // Optional metadata fields: only copied when present and stringy.
            if let Some(model) = snapshot
                .metadata
                .get("model")
                .and_then(serde_json::Value::as_str)
            {
                obj.insert(
                    "latest_thought_model".to_string(),
                    serde_json::Value::String(model.to_string()),
                );
            }
            if let Some(source) = snapshot
                .metadata
                .get("source")
                .and_then(serde_json::Value::as_str)
            {
                obj.insert(
                    "latest_thought_source".to_string(),
                    serde_json::Value::String(source.to_string()),
                );
            }
        }
    }

    Some(payload)
}
2154
/// Trims `input` to at most `max_chars` characters for a heartbeat payload,
/// appending `...` when the text had to be truncated.
///
/// Surrounding whitespace is stripped *before* measuring. The previous
/// implementation counted the un-trimmed input, so text that only exceeded
/// the limit because of whitespace padding was truncated and suffixed with
/// an ellipsis even though its real content fit within `max_chars`.
fn trim_for_heartbeat(input: &str, max_chars: usize) -> String {
    let trimmed = input.trim();
    if trimmed.chars().count() <= max_chars {
        return trimmed.to_string();
    }

    // Truncate on a char boundary (chars, not bytes, so multi-byte UTF-8 is
    // never split), drop any whitespace exposed at the cut point, then mark
    // the truncation.
    let truncated: String = trimmed.chars().take(max_chars).collect();
    format!("{}...", truncated.trim_end())
}
2167
/// Reads a boolean from the environment variable `name`.
///
/// Accepts (case-insensitively) `1`/`true`/`yes`/`on` as `true` and
/// `0`/`false`/`no`/`off` as `false`; any other value — or an unset
/// variable — yields `default`.
fn env_bool(name: &str, default: bool) -> bool {
    let Ok(raw) = std::env::var(name) else {
        return default;
    };
    match raw.to_ascii_lowercase().as_str() {
        "1" | "true" | "yes" | "on" => true,
        "0" | "false" | "no" | "off" => false,
        _ => default,
    }
}
2178
/// Reads a `usize` from the environment variable `name`, returning
/// `default` when the variable is unset or fails to parse.
fn env_usize(name: &str, default: usize) -> usize {
    match std::env::var(name) {
        Ok(raw) => raw.parse::<usize>().unwrap_or(default),
        Err(_) => default,
    }
}
2185
/// Reads a `u64` from the environment variable `name`, returning `default`
/// when the variable is unset or fails to parse.
fn env_u64(name: &str, default: u64) -> u64 {
    match std::env::var(name) {
        Ok(raw) => raw.parse::<u64>().unwrap_or(default),
        Err(_) => default,
    }
}