1use agentic_config::types::OrchestratorConfig;
6use anyhow::Context;
7use opencode_rs::Client;
8use opencode_rs::server::ManagedServer;
9use opencode_rs::server::ServerOptions;
10use opencode_rs::types::message::Message;
11use opencode_rs::types::message::Part;
12use opencode_rs::types::provider::ProviderListResponse;
13use std::collections::HashMap;
14use std::collections::HashSet;
15use std::sync::Arc;
16use std::sync::Mutex as StdMutex;
17use std::time::Duration;
18use tokio::sync::Mutex as AsyncMutex;
19use tokio::sync::RwLock;
20
21use crate::error::OrchestratorError;
22use crate::version;
23
/// Environment variable that marks a process as running inside an orchestrator-managed session.
pub const OPENCODE_ORCHESTRATOR_MANAGED_ENV: &str = "OPENCODE_ORCHESTRATOR_MANAGED";

/// Error message used to refuse starting an orchestrator server when
/// [`managed_guard_enabled`] reports that we are already inside a managed session.
pub const ORCHESTRATOR_MANAGED_GUARD_MESSAGE: &str = "ENV VAR OPENCODE_ORCHESTRATOR_MANAGED is set to 1. This most commonly happens when you're \
    in a nested orchestration session. Consult a human for assistance or try to accomplish your \
    task without the orchestration tools.";

/// Reports whether the nested-orchestration guard is active.
///
/// The guard is enabled when [`OPENCODE_ORCHESTRATOR_MANAGED_ENV`] is set to any value other
/// than `"0"` or blank. Surrounding whitespace is ignored for both checks, so `" 0 "` disables
/// the guard exactly like `"0"`. An unset (or non-UTF-8) variable leaves the guard disabled.
pub fn managed_guard_enabled() -> bool {
    std::env::var(OPENCODE_ORCHESTRATOR_MANAGED_ENV).is_ok_and(|value| guard_value_enables(&value))
}

/// Interprets a raw guard-variable value: blank or `"0"` (after trimming) means "off".
fn guard_value_enables(value: &str) -> bool {
    // Trim once so the "0" comparison and the emptiness check agree on whitespace handling.
    let value = value.trim();
    !value.is_empty() && value != "0"
}
39
40pub async fn init_with_retry<T, F, Fut>(mut f: F) -> anyhow::Result<T>
42where
43 F: FnMut(usize) -> Fut,
44 Fut: std::future::Future<Output = anyhow::Result<T>>,
45{
46 let mut last_err: Option<anyhow::Error> = None;
47
48 for attempt in 1..=2 {
49 tracing::info!(attempt, "orchestrator server lazy init attempt");
50 match f(attempt).await {
51 Ok(v) => {
52 if attempt > 1 {
53 tracing::info!(
54 attempt,
55 "orchestrator server lazy init succeeded after retry"
56 );
57 }
58 return Ok(v);
59 }
60 Err(e) => {
61 tracing::warn!(attempt, error = %e, "orchestrator server lazy init failed");
62 last_err = Some(e);
63 }
64 }
65 }
66
67 tracing::error!("orchestrator server lazy init exhausted retries");
68 match last_err {
70 Some(e) => Err(e),
71 None => anyhow::bail!("init_with_retry: unexpected empty error state"),
72 }
73}
74
/// Lookup key for per-model context limits: `(provider_id, model_id)`.
pub type ModelKey = (String, String);

/// Result of checking a cached server before handing it to a tool invocation.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ServerEntryState {
    /// The server passed its liveness checks and can be reused as-is.
    Healthy,
    /// The server failed a liveness check; `reason` describes which one.
    NeedsRecovery { reason: String },
}

/// How a cached server is recovered once it stops passing validation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecoveryMode {
    /// This process owns the server's child process and rebuilds it by starting a new one.
    Managed,
    /// The server has no attached child process; failures are reported instead of restarted.
    External,
}

impl RecoveryMode {
    /// Lowercase label used for the `mode` field in structured log events.
    fn as_str(self) -> &'static str {
        match self {
            RecoveryMode::External => "external",
            RecoveryMode::Managed => "managed",
        }
    }
}
98
99enum HandleState {
100 Empty,
101 Ready {
102 snapshot: Arc<OrchestratorServer>,
103 mode: RecoveryMode,
104 },
105 Stale {
106 snapshot: Arc<OrchestratorServer>,
107 mode: RecoveryMode,
108 reason: String,
109 },
110 Failed {
111 mode: RecoveryMode,
112 base_url: Option<String>,
113 error: String,
114 },
115}
116
117const TOOL_ENTRY_HEALTH_PROBE_TIMEOUT: Duration = Duration::from_secs(5);
118
119pub struct OrchestratorServerHandle {
121 state: AsyncMutex<HandleState>,
122}
123
124impl Default for OrchestratorServerHandle {
125 fn default() -> Self {
126 Self::new()
127 }
128}
129
130impl OrchestratorServerHandle {
131 #[must_use]
132 pub fn new() -> Self {
133 Self {
134 state: AsyncMutex::new(HandleState::Empty),
135 }
136 }
137
138 pub async fn acquire(&self) -> anyhow::Result<Arc<OrchestratorServer>> {
143 self.get_or_recover_with(OrchestratorServer::start_lazy)
144 .await
145 }
146
147 async fn get_or_recover_with<F, Fut>(
148 &self,
149 mut start: F,
150 ) -> anyhow::Result<Arc<OrchestratorServer>>
151 where
152 F: FnMut() -> Fut,
153 Fut: std::future::Future<Output = anyhow::Result<OrchestratorServer>>,
154 {
155 loop {
156 let ready_snapshot = {
157 let mut state = self.state.lock().await;
158
159 match &mut *state {
160 HandleState::Empty => {
161 tracing::info!(
162 "orchestrator server missing cached snapshot; starting embedded server"
163 );
164
165 match start().await {
166 Ok(server) => {
167 let rebuilt_mode = if server.is_managed() {
168 RecoveryMode::Managed
169 } else {
170 RecoveryMode::External
171 };
172 let rebuilt = Arc::new(server);
173 trace_state_transition(
174 "Empty",
175 "Ready",
176 "initialization",
177 rebuilt_mode,
178 Some(rebuilt.base_url()),
179 );
180 *state = HandleState::Ready {
181 snapshot: Arc::clone(&rebuilt),
182 mode: rebuilt_mode,
183 };
184 return Ok(rebuilt);
185 }
186 Err(error) => {
187 let reason = error.to_string();
188 trace_state_transition(
189 "Empty",
190 "Failed",
191 &reason,
192 RecoveryMode::Managed,
193 None,
194 );
195 *state = HandleState::Failed {
196 mode: RecoveryMode::Managed,
197 base_url: None,
198 error: reason,
199 };
200 return Err(error);
201 }
202 }
203 }
204 HandleState::Ready { snapshot, mode } => Some((Arc::clone(snapshot), *mode)),
205 HandleState::Stale {
206 snapshot,
207 mode,
208 reason,
209 } => match mode {
210 RecoveryMode::Managed => {
211 let stale_reason = reason.clone();
212 match start().await {
213 Ok(server) => {
214 let rebuilt_mode = if server.is_managed() {
215 RecoveryMode::Managed
216 } else {
217 RecoveryMode::External
218 };
219 let rebuilt = Arc::new(server);
220 trace_state_transition(
221 "Stale",
222 "Ready",
223 &stale_reason,
224 rebuilt_mode,
225 Some(rebuilt.base_url()),
226 );
227 *state = HandleState::Ready {
228 snapshot: Arc::clone(&rebuilt),
229 mode: rebuilt_mode,
230 };
231 return Ok(rebuilt);
232 }
233 Err(error) => {
234 let failure = error.to_string();
235 trace_state_transition(
236 "Stale",
237 "Failed",
238 &failure,
239 *mode,
240 Some(snapshot.base_url()),
241 );
242 *state = HandleState::Failed {
243 mode: *mode,
244 base_url: Some(snapshot.base_url().to_string()),
245 error: failure,
246 };
247 return Err(error);
248 }
249 }
250 }
251 RecoveryMode::External => {
252 let base_url = snapshot.base_url().to_string();
253 let stale_reason = reason.clone();
254 trace_state_transition(
255 "Stale",
256 "Failed",
257 &stale_reason,
258 *mode,
259 Some(&base_url),
260 );
261 *state = HandleState::Failed {
262 mode: *mode,
263 base_url: Some(base_url.clone()),
264 error: stale_reason.clone(),
265 };
266 return Err(external_unavailable(Some(base_url), stale_reason));
267 }
268 },
269 HandleState::Failed {
270 mode,
271 base_url,
272 error,
273 } => match mode {
274 RecoveryMode::Managed => match start().await {
275 Ok(server) => {
276 let rebuilt_mode = if server.is_managed() {
277 RecoveryMode::Managed
278 } else {
279 RecoveryMode::External
280 };
281 let rebuilt = Arc::new(server);
282 trace_state_transition(
283 "Failed",
284 "Ready",
285 error,
286 rebuilt_mode,
287 Some(rebuilt.base_url()),
288 );
289 *state = HandleState::Ready {
290 snapshot: Arc::clone(&rebuilt),
291 mode: rebuilt_mode,
292 };
293 return Ok(rebuilt);
294 }
295 Err(start_error) => {
296 let failure = start_error.to_string();
297 error.clone_from(&failure);
298 return Err(start_error);
299 }
300 },
301 RecoveryMode::External => {
302 return Err(external_unavailable(base_url.clone(), error.clone()));
303 }
304 },
305 }
306 };
307
308 let Some((snapshot, mode)) = ready_snapshot else {
309 continue;
310 };
311
312 let validation = snapshot.validate_for_tool_entry().await?;
313
314 let mut state = self.state.lock().await;
315 let HandleState::Ready {
316 snapshot: current,
317 mode: current_mode,
318 } = &*state
319 else {
320 continue;
321 };
322
323 if !Arc::ptr_eq(current, &snapshot) || *current_mode != mode {
324 continue;
325 }
326
327 match validation {
328 ServerEntryState::Healthy => return Ok(snapshot),
329 ServerEntryState::NeedsRecovery { reason } => {
330 trace_cache_invalidated(&reason, mode, Some(snapshot.base_url()));
331
332 match mode {
333 RecoveryMode::Managed => {
334 tracing::warn!(reason = %reason, "cached orchestrator server failed liveness check; rebuilding");
335 trace_state_transition(
336 "Ready",
337 "Stale",
338 &reason,
339 mode,
340 Some(snapshot.base_url()),
341 );
342 *state = HandleState::Stale {
343 snapshot: Arc::clone(&snapshot),
344 mode,
345 reason: reason.clone(),
346 };
347
348 match start().await {
349 Ok(server) => {
350 let rebuilt_mode = if server.is_managed() {
351 RecoveryMode::Managed
352 } else {
353 RecoveryMode::External
354 };
355 let rebuilt = Arc::new(server);
356 trace_state_transition(
357 "Stale",
358 "Ready",
359 &reason,
360 rebuilt_mode,
361 Some(rebuilt.base_url()),
362 );
363 *state = HandleState::Ready {
364 snapshot: Arc::clone(&rebuilt),
365 mode: rebuilt_mode,
366 };
367 return Ok(rebuilt);
368 }
369 Err(error) => {
370 let failure = error.to_string();
371 trace_state_transition(
372 "Stale",
373 "Failed",
374 &failure,
375 mode,
376 Some(snapshot.base_url()),
377 );
378 *state = HandleState::Failed {
379 mode,
380 base_url: Some(snapshot.base_url().to_string()),
381 error: failure,
382 };
383 return Err(error);
384 }
385 }
386 }
387 RecoveryMode::External => {
388 let base_url = snapshot.base_url().to_string();
389 trace_state_transition(
390 "Ready",
391 "Failed",
392 &reason,
393 mode,
394 Some(&base_url),
395 );
396 *state = HandleState::Failed {
397 mode,
398 base_url: Some(base_url.clone()),
399 error: reason.clone(),
400 };
401 return Err(external_unavailable(Some(base_url), reason));
402 }
403 }
404 }
405 }
406 }
407 }
408
409 #[cfg(any(test, feature = "test-support"))]
410 #[must_use]
411 pub fn from_server_unshared(server: OrchestratorServer) -> Self {
412 let mode = if server.is_managed() {
413 RecoveryMode::Managed
414 } else {
415 RecoveryMode::External
416 };
417
418 Self {
419 state: AsyncMutex::new(HandleState::Ready {
420 snapshot: Arc::new(server),
421 mode,
422 }),
423 }
424 }
425
426 #[cfg(any(test, feature = "test-support"))]
427 pub async fn acquire_or_recover_with<F, Fut>(
428 &self,
429 start: F,
430 ) -> anyhow::Result<Arc<OrchestratorServer>>
431 where
432 F: FnMut() -> Fut,
433 Fut: std::future::Future<Output = anyhow::Result<OrchestratorServer>>,
434 {
435 self.get_or_recover_with(start).await
436 }
437}
438
439fn trace_cache_invalidated(reason: &str, mode: RecoveryMode, base_url: Option<&str>) {
440 if let Some(base_url) = base_url {
441 tracing::info!(
442 event = "cache_invalidated",
443 reason = %reason,
444 mode = mode.as_str(),
445 base_url = %base_url,
446 );
447 } else {
448 tracing::info!(
449 event = "cache_invalidated",
450 reason = %reason,
451 mode = mode.as_str(),
452 );
453 }
454}
455
456fn trace_state_transition(
457 from: &'static str,
458 to: &'static str,
459 reason: &str,
460 mode: RecoveryMode,
461 base_url: Option<&str>,
462) {
463 if let Some(base_url) = base_url {
464 tracing::info!(
465 event = "state_transition",
466 from,
467 to,
468 reason = %reason,
469 mode = mode.as_str(),
470 base_url = %base_url,
471 );
472 } else {
473 tracing::info!(
474 event = "state_transition",
475 from,
476 to,
477 reason = %reason,
478 mode = mode.as_str(),
479 );
480 }
481}
482
483fn external_unavailable(base_url: Option<String>, reason: String) -> anyhow::Error {
484 OrchestratorError::ExternalServerUnavailable {
485 base_url: base_url.unwrap_or_else(|| "<unknown>".to_string()),
486 reason,
487 }
488 .into()
489}
490
/// An OpenCode server the orchestrator talks to, plus the client and settings used with it.
///
/// When this process launched `opencode serve` itself, the child handle is kept in
/// `managed_server`; otherwise that slot is `None` and the server is treated as external.
pub struct OrchestratorServer {
    /// Child-process handle for a server this process started; `None` for external servers.
    /// Wrapped in a std mutex so it can be inspected or taken through `&self`.
    managed_server: StdMutex<Option<ManagedServer>>,
    /// HTTP client configured with `base_url` (and, for embedded servers, the working directory).
    client: Client,
    /// Context-window size per `(provider_id, model_id)`; may be empty if loading them failed.
    model_context_limits: HashMap<ModelKey, u64>,
    /// Server root URL with any trailing `/` removed.
    base_url: String,
    /// Orchestrator settings (session deadline, inactivity timeout, compaction threshold).
    config: OrchestratorConfig,
    /// Shared set of session IDs. Presumably this records sessions spawned through this
    /// server; it is populated outside this view — TODO confirm.
    spawned_sessions: Arc<RwLock<HashSet<String>>>,
}
507
508impl OrchestratorServer {
509 #[allow(clippy::allow_attributes, dead_code)]
518 pub async fn start() -> anyhow::Result<Arc<Self>> {
519 Ok(Arc::new(Self::start_impl().await?))
520 }
521
522 pub async fn start_lazy() -> anyhow::Result<Self> {
532 Self::start_lazy_with_config(None).await
533 }
534
535 pub async fn start_lazy_with_config(config_json: Option<String>) -> anyhow::Result<Self> {
546 if managed_guard_enabled() {
547 anyhow::bail!(ORCHESTRATOR_MANAGED_GUARD_MESSAGE);
548 }
549
550 init_with_retry(|_attempt| {
551 let cfg = config_json.clone();
552 async move { Self::start_impl_with_config(cfg).await }
553 })
554 .await
555 }
556
557 async fn start_impl() -> anyhow::Result<Self> {
559 let cwd = std::env::current_dir().context("Failed to resolve current directory")?;
560
561 let config = match agentic_config::loader::load_merged(&cwd) {
563 Ok(loaded) => {
564 for w in &loaded.warnings {
565 tracing::warn!("{w}");
566 }
567 loaded.config.orchestrator
568 }
569 Err(e) => {
570 tracing::warn!("Failed to load config, using defaults: {e}");
571 OrchestratorConfig::default()
572 }
573 };
574
575 let launcher_config = version::resolve_launcher_config(&cwd)
576 .context("Failed to resolve OpenCode launcher configuration")?;
577
578 tracing::info!(
579 binary = %launcher_config.binary,
580 launcher_args = ?launcher_config.launcher_args,
581 expected_version = %version::PINNED_OPENCODE_VERSION,
582 "starting embedded opencode serve (pinned stable)"
583 );
584
585 let opts = ServerOptions::default()
586 .binary(&launcher_config.binary)
587 .launcher_args(launcher_config.launcher_args)
588 .directory(cwd.clone());
589
590 let managed = ManagedServer::start(opts)
591 .await
592 .context("Failed to start embedded `opencode serve`")?;
593
594 let base_url = managed.url().to_string().trim_end_matches('/').to_string();
596
597 let client = Client::builder()
598 .base_url(&base_url)
599 .directory(cwd.to_string_lossy().to_string())
600 .build()
601 .context("Failed to build opencode-rs HTTP client")?;
602
603 let health = client
604 .misc()
605 .health()
606 .await
607 .context("Failed to fetch /global/health for version validation")?;
608
609 version::validate_exact_version(health.version.as_deref()).with_context(|| {
610 format!(
611 "Embedded OpenCode server did not match pinned stable v{} (binary={})",
612 version::PINNED_OPENCODE_VERSION,
613 launcher_config.binary
614 )
615 })?;
616
617 let model_context_limits = Self::load_model_limits(&client).await.unwrap_or_else(|e| {
619 tracing::warn!("Failed to load model limits: {}", e);
620 HashMap::new()
621 });
622
623 tracing::info!("Loaded {} model context limits", model_context_limits.len());
624
625 Ok(Self {
626 managed_server: StdMutex::new(Some(managed)),
627 client,
628 model_context_limits,
629 base_url,
630 config,
631 spawned_sessions: Arc::new(RwLock::new(HashSet::new())),
632 })
633 }
634
635 async fn start_impl_with_config(config_json: Option<String>) -> anyhow::Result<Self> {
637 let cwd = std::env::current_dir().context("Failed to resolve current directory")?;
638
639 let config = match agentic_config::loader::load_merged(&cwd) {
641 Ok(loaded) => {
642 for w in &loaded.warnings {
643 tracing::warn!("{w}");
644 }
645 loaded.config.orchestrator
646 }
647 Err(e) => {
648 tracing::warn!("Failed to load config, using defaults: {e}");
649 OrchestratorConfig::default()
650 }
651 };
652
653 let launcher_config = version::resolve_launcher_config(&cwd)
654 .context("Failed to resolve OpenCode launcher configuration")?;
655
656 tracing::info!(
657 binary = %launcher_config.binary,
658 launcher_args = ?launcher_config.launcher_args,
659 expected_version = %version::PINNED_OPENCODE_VERSION,
660 config_injected = config_json.is_some(),
661 "starting embedded opencode serve (pinned stable)"
662 );
663
664 let mut opts = ServerOptions::default()
665 .binary(&launcher_config.binary)
666 .launcher_args(launcher_config.launcher_args)
667 .directory(cwd.clone());
668
669 if let Some(cfg) = config_json {
671 opts = opts.config_json(cfg);
672 }
673
674 let managed = ManagedServer::start(opts)
675 .await
676 .context("Failed to start embedded `opencode serve`")?;
677
678 let base_url = managed.url().to_string().trim_end_matches('/').to_string();
680
681 let client = Client::builder()
682 .base_url(&base_url)
683 .directory(cwd.to_string_lossy().to_string())
684 .build()
685 .context("Failed to build opencode-rs HTTP client")?;
686
687 let health = client
688 .misc()
689 .health()
690 .await
691 .context("Failed to fetch /global/health for version validation")?;
692
693 version::validate_exact_version(health.version.as_deref()).with_context(|| {
694 format!(
695 "Embedded OpenCode server did not match pinned stable v{} (binary={})",
696 version::PINNED_OPENCODE_VERSION,
697 launcher_config.binary
698 )
699 })?;
700
701 let model_context_limits = Self::load_model_limits(&client).await.unwrap_or_else(|e| {
703 tracing::warn!("Failed to load model limits: {}", e);
704 HashMap::new()
705 });
706
707 tracing::info!("Loaded {} model context limits", model_context_limits.len());
708
709 Ok(Self {
710 managed_server: StdMutex::new(Some(managed)),
711 client,
712 model_context_limits,
713 base_url,
714 config,
715 spawned_sessions: Arc::new(RwLock::new(HashSet::new())),
716 })
717 }
718
719 pub fn client(&self) -> &Client {
721 &self.client
722 }
723
724 #[allow(clippy::allow_attributes, dead_code)]
726 pub fn base_url(&self) -> &str {
727 &self.base_url
728 }
729
730 pub fn context_limit(&self, provider_id: &str, model_id: &str) -> Option<u64> {
732 self.model_context_limits
733 .get(&(provider_id.to_string(), model_id.to_string()))
734 .copied()
735 }
736
737 pub fn session_deadline(&self) -> Duration {
739 Duration::from_secs(self.config.session_deadline_secs)
740 }
741
742 pub fn inactivity_timeout(&self) -> Duration {
744 Duration::from_secs(self.config.inactivity_timeout_secs)
745 }
746
747 pub fn compaction_threshold(&self) -> f64 {
749 self.config.compaction_threshold
750 }
751
752 pub fn spawned_sessions(&self) -> &Arc<RwLock<HashSet<String>>> {
754 &self.spawned_sessions
755 }
756
757 fn managed_server_lock(&self) -> std::sync::MutexGuard<'_, Option<ManagedServer>> {
758 self.managed_server
759 .lock()
760 .unwrap_or_else(std::sync::PoisonError::into_inner)
761 }
762
763 fn is_managed(&self) -> bool {
764 self.managed_server_lock().is_some()
765 }
766
767 async fn validate_for_tool_entry(&self) -> anyhow::Result<ServerEntryState> {
768 self.validate_for_tool_entry_with_timeout(TOOL_ENTRY_HEALTH_PROBE_TIMEOUT)
769 .await
770 }
771
772 async fn validate_for_tool_entry_with_timeout(
773 &self,
774 health_probe_timeout: Duration,
775 ) -> anyhow::Result<ServerEntryState> {
776 if self.is_managed() {
777 let is_running = {
778 let mut managed = self.managed_server_lock();
779 managed
780 .as_mut()
781 .is_some_and(opencode_rs::server::ManagedServer::is_running)
782 };
783
784 if !is_running {
785 return Ok(ServerEntryState::NeedsRecovery {
786 reason: "managed child is no longer running".to_string(),
787 });
788 }
789 }
790
791 match tokio::time::timeout(health_probe_timeout, self.client.misc().health()).await {
792 Ok(Ok(health)) if health.healthy => Ok(ServerEntryState::Healthy),
793 Ok(Ok(_health)) => Ok(ServerEntryState::NeedsRecovery {
794 reason: "/global/health reported unhealthy".to_string(),
795 }),
796 Ok(Err(error)) => Ok(ServerEntryState::NeedsRecovery {
797 reason: format!("/global/health probe failed: {error}"),
798 }),
799 Err(_elapsed) => Ok(ServerEntryState::NeedsRecovery {
800 reason: format!("/global/health probe timed out after {health_probe_timeout:?}"),
801 }),
802 }
803 }
804
805 async fn load_model_limits(client: &Client) -> anyhow::Result<HashMap<ModelKey, u64>> {
807 let resp: ProviderListResponse = client.providers().list().await?;
808 let mut limits = HashMap::new();
809
810 for provider in resp.all {
811 for (model_id, model) in provider.models {
812 if let Some(limit) = model.limit.as_ref().and_then(|l| l.context) {
813 limits.insert((provider.id.clone(), model_id), limit);
814 }
815 }
816 }
817
818 Ok(limits)
819 }
820
821 pub fn extract_assistant_text(messages: &[Message]) -> Option<String> {
823 let assistant_msg = messages.iter().rev().find(|m| m.info.role == "assistant")?;
825
826 let text: String = assistant_msg
828 .parts
829 .iter()
830 .filter_map(|p| {
831 if let Part::Text { text, .. } = p {
832 Some(text.as_str())
833 } else {
834 None
835 }
836 })
837 .collect::<Vec<_>>()
838 .join("");
839
840 if text.trim().is_empty() {
841 None
842 } else {
843 Some(text)
844 }
845 }
846}
847
#[cfg(any(test, feature = "test-support"))]
#[allow(dead_code, clippy::allow_attributes)]
impl OrchestratorServer {
    /// Wraps `client` in a shared snapshot with no managed child process.
    pub fn from_client(
        client: Client,
        base_url: impl Into<String>,
        mode: RecoveryMode,
    ) -> Arc<Self> {
        let server = Self::from_client_unshared(client, base_url, mode);
        Arc::new(server)
    }

    /// Builds an unmanaged snapshot around `client`.
    ///
    /// `_mode` is accepted for call-site symmetry but is not stored; the effective mode is
    /// derived from whether a child process is attached (here: none).
    pub fn from_client_unshared(
        client: Client,
        base_url: impl Into<String>,
        _mode: RecoveryMode,
    ) -> Self {
        Self::assemble_for_testing(None, client, base_url.into())
    }

    /// Builds a snapshot that owns `managed`, so it is treated as a managed server.
    pub fn from_managed_for_testing(
        managed: ManagedServer,
        client: Client,
        base_url: impl Into<String>,
    ) -> Self {
        Self::assemble_for_testing(Some(managed), client, base_url.into())
    }

    /// Shared test constructor: default config, no model limits, no tracked sessions, and
    /// trailing slashes stripped from `base_url`.
    fn assemble_for_testing(
        managed: Option<ManagedServer>,
        client: Client,
        base_url: String,
    ) -> Self {
        Self {
            managed_server: StdMutex::new(managed),
            client,
            model_context_limits: HashMap::new(),
            base_url: base_url.trim_end_matches('/').to_owned(),
            config: OrchestratorConfig::default(),
            spawned_sessions: Arc::new(RwLock::new(HashSet::new())),
        }
    }

    /// Detaches and stops the managed child process.
    ///
    /// # Errors
    ///
    /// Fails if no managed server is attached, or if stopping it fails.
    pub async fn stop_managed_for_testing(&self) -> anyhow::Result<()> {
        // Take the child out inside a scope so the std mutex guard is gone before `.await`.
        let detached = {
            let mut slot = self.managed_server_lock();
            slot.take()
        };

        let Some(managed) = detached else {
            anyhow::bail!("no managed server is attached to this snapshot");
        };

        managed.stop().await.map_err(Into::into)
    }
}
913
914#[cfg(test)]
915mod tests {
916 use super::*;
917 use serial_test::serial;
918 use std::sync::Arc;
919 use std::sync::atomic::AtomicBool;
920 use std::sync::atomic::AtomicUsize;
921 use std::sync::atomic::Ordering;
922 use std::time::Duration;
923 use std::time::Instant;
924 use tokio::io::AsyncReadExt;
925 use tokio::io::AsyncWriteExt;
926 use tokio::net::TcpListener;
927 use tokio::process::Command;
928 use tokio::sync::Notify;
929 use wiremock::Mock;
930 use wiremock::MockServer;
931 use wiremock::ResponseTemplate;
932 use wiremock::matchers::method;
933 use wiremock::matchers::path;
934
935 struct ManagedEnvGuard {
936 previous: Option<std::ffi::OsString>,
937 }
938
939 impl ManagedEnvGuard {
940 fn new() -> Self {
941 Self {
942 previous: std::env::var_os(OPENCODE_ORCHESTRATOR_MANAGED_ENV),
943 }
944 }
945 }
946
947 impl Drop for ManagedEnvGuard {
948 fn drop(&mut self) {
949 match &self.previous {
950 Some(value) => unsafe {
952 std::env::set_var(OPENCODE_ORCHESTRATOR_MANAGED_ENV, value);
953 },
954 None => unsafe {
956 std::env::remove_var(OPENCODE_ORCHESTRATOR_MANAGED_ENV);
957 },
958 }
959 }
960 }
961
962 async fn health_mock_server() -> MockServer {
963 let mock = MockServer::start().await;
964 Mock::given(method("GET"))
965 .and(path("/global/health"))
966 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
967 "healthy": true,
968 "version": version::PINNED_OPENCODE_VERSION,
969 })))
970 .mount(&mock)
971 .await;
972 mock
973 }
974
975 fn test_client(base_url: &str) -> Client {
976 opencode_rs::ClientBuilder::new()
977 .base_url(base_url)
978 .timeout_secs(5)
979 .build()
980 .unwrap()
981 }
982
983 fn external_server(base_url: &str) -> OrchestratorServer {
984 OrchestratorServer::from_client_unshared(
985 test_client(base_url),
986 base_url,
987 RecoveryMode::External,
988 )
989 }
990
991 async fn exited_child() -> tokio::process::Child {
992 let mut child = Command::new("sh").arg("-c").arg("exit 0").spawn().unwrap();
993 let _status = child.wait().await.unwrap();
994 child
995 }
996
997 async fn managed_server_with_exited_child(base_url: &str) -> OrchestratorServer {
998 let managed = ManagedServer::from_child_for_testing(exited_child().await, base_url, 9);
999 OrchestratorServer::from_managed_for_testing(managed, test_client(base_url), base_url)
1000 }
1001
1002 struct BlockingHealthServer {
1003 base_url: String,
1004 started_requests: Arc<AtomicUsize>,
1005 started_notify: Arc<Notify>,
1006 released: Arc<AtomicBool>,
1007 release_notify: Arc<Notify>,
1008 task: tokio::task::JoinHandle<()>,
1009 }
1010
1011 impl BlockingHealthServer {
1012 async fn start(expected_requests: usize) -> Self {
1013 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1014 let addr = listener.local_addr().unwrap();
1015 let started_requests = Arc::new(AtomicUsize::new(0));
1016 let started_notify = Arc::new(Notify::new());
1017 let released = Arc::new(AtomicBool::new(false));
1018 let release_notify = Arc::new(Notify::new());
1019 let body = format!(
1020 r#"{{"healthy":true,"version":"{}"}}"#,
1021 version::PINNED_OPENCODE_VERSION
1022 );
1023 let response = Arc::new(format!(
1024 "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
1025 body.len(),
1026 body
1027 ));
1028
1029 let task = tokio::spawn({
1030 let started_requests = Arc::clone(&started_requests);
1031 let started_notify = Arc::clone(&started_notify);
1032 let released = Arc::clone(&released);
1033 let release_notify = Arc::clone(&release_notify);
1034 let response = Arc::clone(&response);
1035
1036 async move {
1037 let mut connections = Vec::with_capacity(expected_requests);
1038
1039 for _ in 0..expected_requests {
1040 let (mut stream, _addr) = listener.accept().await.unwrap();
1041 let started_requests = Arc::clone(&started_requests);
1042 let started_notify = Arc::clone(&started_notify);
1043 let released = Arc::clone(&released);
1044 let release_notify = Arc::clone(&release_notify);
1045 let response = Arc::clone(&response);
1046
1047 connections.push(tokio::spawn(async move {
1048 let mut request = [0_u8; 1024];
1049 let _read = stream.read(&mut request).await.unwrap();
1050 started_requests.fetch_add(1, Ordering::SeqCst);
1051 started_notify.notify_waiters();
1052
1053 loop {
1054 let notified = release_notify.notified();
1055 if released.load(Ordering::SeqCst) {
1056 break;
1057 }
1058 notified.await;
1059 }
1060
1061 stream.write_all(response.as_bytes()).await.unwrap();
1062 stream.shutdown().await.unwrap();
1063 }));
1064 }
1065
1066 for connection in connections {
1067 connection.await.unwrap();
1068 }
1069 }
1070 });
1071
1072 Self {
1073 base_url: format!("http://{addr}"),
1074 started_requests,
1075 started_notify,
1076 released,
1077 release_notify,
1078 task,
1079 }
1080 }
1081
1082 async fn wait_for_requests(&self, expected_requests: usize) {
1083 tokio::time::timeout(Duration::from_secs(1), async {
1084 while self.started_requests.load(Ordering::SeqCst) < expected_requests {
1085 self.started_notify.notified().await;
1086 }
1087 })
1088 .await
1089 .unwrap();
1090 }
1091
1092 fn release(&self) {
1093 self.released.store(true, Ordering::SeqCst);
1094 self.release_notify.notify_waiters();
1095 }
1096 }
1097
1098 impl Drop for BlockingHealthServer {
1099 fn drop(&mut self) {
1100 self.release();
1101 self.task.abort();
1102 }
1103 }
1104
1105 #[tokio::test]
1106 async fn init_with_retry_succeeds_on_first_attempt() {
1107 let attempts = AtomicUsize::new(0);
1108
1109 let result: u32 = init_with_retry(|_| {
1110 let n = attempts.fetch_add(1, Ordering::SeqCst);
1111 async move {
1112 assert_eq!(n, 0, "should only be called once on success");
1114 Ok(42)
1115 }
1116 })
1117 .await
1118 .unwrap();
1119
1120 assert_eq!(result, 42);
1121 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1122 }
1123
1124 #[tokio::test]
1125 async fn init_with_retry_retries_once_and_succeeds() {
1126 let attempts = AtomicUsize::new(0);
1127
1128 let result: u32 = init_with_retry(|_| {
1129 let n = attempts.fetch_add(1, Ordering::SeqCst);
1130 async move {
1131 if n == 0 {
1132 anyhow::bail!("fail first");
1133 }
1134 Ok(42)
1135 }
1136 })
1137 .await
1138 .unwrap();
1139
1140 assert_eq!(result, 42);
1141 assert_eq!(attempts.load(Ordering::SeqCst), 2);
1142 }
1143
1144 #[tokio::test]
1145 async fn init_with_retry_fails_after_two_attempts() {
1146 let attempts = AtomicUsize::new(0);
1147
1148 let err = init_with_retry::<(), _, _>(|_| {
1149 attempts.fetch_add(1, Ordering::SeqCst);
1150 async { anyhow::bail!("always fail") }
1151 })
1152 .await
1153 .unwrap_err();
1154
1155 assert!(err.to_string().contains("always fail"));
1156 assert_eq!(attempts.load(Ordering::SeqCst), 2);
1157 }
1158
1159 #[tokio::test]
1160 async fn handle_serializes_initialization_and_reuses_snapshot() {
1161 let mock = health_mock_server().await;
1162 let base_url = mock.uri();
1163 let handle = Arc::new(OrchestratorServerHandle::new());
1164 let starts = Arc::new(AtomicUsize::new(0));
1165
1166 let first = {
1167 let handle = Arc::clone(&handle);
1168 let starts = Arc::clone(&starts);
1169 let base_url = base_url.clone();
1170 tokio::spawn(async move {
1171 handle
1172 .get_or_recover_with(|| {
1173 let starts = Arc::clone(&starts);
1174 let base_url = base_url.clone();
1175 async move {
1176 starts.fetch_add(1, Ordering::SeqCst);
1177 tokio::time::sleep(Duration::from_millis(50)).await;
1178 Ok(external_server(&base_url))
1179 }
1180 })
1181 .await
1182 })
1183 };
1184
1185 let second = {
1186 let handle = Arc::clone(&handle);
1187 let starts = Arc::clone(&starts);
1188 let base_url = base_url.clone();
1189 tokio::spawn(async move {
1190 handle
1191 .get_or_recover_with(|| {
1192 let starts = Arc::clone(&starts);
1193 let base_url = base_url.clone();
1194 async move {
1195 starts.fetch_add(1, Ordering::SeqCst);
1196 Ok(external_server(&base_url))
1197 }
1198 })
1199 .await
1200 })
1201 };
1202
1203 let first = first.await.unwrap().unwrap();
1204 let second = second.await.unwrap().unwrap();
1205
1206 assert_eq!(starts.load(Ordering::SeqCst), 1);
1207 assert!(Arc::ptr_eq(&first, &second));
1208 }
1209
1210 #[tokio::test]
1211 async fn validate_for_tool_entry_uses_health_for_external_server() {
1212 let mock = health_mock_server().await;
1213 let server = external_server(&mock.uri());
1214
1215 let state = server.validate_for_tool_entry().await.unwrap();
1216
1217 assert_eq!(state, ServerEntryState::Healthy);
1218 let requests = mock.received_requests().await.unwrap();
1219 assert!(
1220 requests
1221 .iter()
1222 .any(|request| request.url.path() == "/global/health"),
1223 "expected /global/health request"
1224 );
1225 }
1226
1227 #[tokio::test]
1228 async fn validate_for_tool_entry_times_out_health_probe() {
1229 let mock = MockServer::start().await;
1230 Mock::given(method("GET"))
1231 .and(path("/global/health"))
1232 .respond_with(
1233 ResponseTemplate::new(200)
1234 .set_delay(Duration::from_secs(30))
1235 .set_body_json(serde_json::json!({
1236 "healthy": true,
1237 "version": version::PINNED_OPENCODE_VERSION,
1238 })),
1239 )
1240 .mount(&mock)
1241 .await;
1242 let server = external_server(&mock.uri());
1243
1244 let state = server
1245 .validate_for_tool_entry_with_timeout(Duration::from_millis(25))
1246 .await
1247 .unwrap();
1248
1249 assert_eq!(
1250 state,
1251 ServerEntryState::NeedsRecovery {
1252 reason: "/global/health probe timed out after 25ms".to_string(),
1253 }
1254 );
1255 }
1256
1257 #[tokio::test]
1258 async fn validate_for_tool_entry_short_circuits_dead_managed_server() {
1259 let server = managed_server_with_exited_child("http://127.0.0.1:9").await;
1260
1261 let state = server.validate_for_tool_entry().await.unwrap();
1262
1263 assert_eq!(
1264 state,
1265 ServerEntryState::NeedsRecovery {
1266 reason: "managed child is no longer running".to_string(),
1267 }
1268 );
1269 }
1270
1271 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1272 async fn handle_allows_concurrent_healthy_acquires_without_serializing_validation() {
1273 let health = BlockingHealthServer::start(3).await;
1274 let handle = Arc::new(OrchestratorServerHandle::from_server_unshared(
1275 external_server(&health.base_url),
1276 ));
1277
1278 let started_at = Instant::now();
1279 let tasks = (0..3)
1280 .map(|_| {
1281 let handle = Arc::clone(&handle);
1282 tokio::spawn(async move { handle.acquire().await })
1283 })
1284 .collect::<Vec<_>>();
1285
1286 health.wait_for_requests(3).await;
1287 tokio::time::sleep(Duration::from_millis(75)).await;
1288 health.release();
1289
1290 let mut snapshots = Vec::with_capacity(tasks.len());
1291 for task in tasks {
1292 snapshots.push(task.await.unwrap().unwrap());
1293 }
1294
1295 assert!(
1296 started_at.elapsed() < Duration::from_millis(250),
1297 "healthy acquires should overlap rather than serialize"
1298 );
1299 assert!(Arc::ptr_eq(&snapshots[0], &snapshots[1]));
1300 assert!(Arc::ptr_eq(&snapshots[1], &snapshots[2]));
1301 }
1302
1303 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1304 async fn handle_single_flights_concurrent_stale_acquires() {
1305 let stale = Arc::new(managed_server_with_exited_child("http://127.0.0.1:9").await);
1306 let handle = Arc::new(OrchestratorServerHandle {
1307 state: AsyncMutex::new(HandleState::Ready {
1308 snapshot: Arc::clone(&stale),
1309 mode: RecoveryMode::Managed,
1310 }),
1311 });
1312 let mock = health_mock_server().await;
1313 let base_url = mock.uri();
1314 let starts = Arc::new(AtomicUsize::new(0));
1315
1316 let tasks = (0..3)
1317 .map(|_| {
1318 let handle = Arc::clone(&handle);
1319 let starts = Arc::clone(&starts);
1320 let base_url = base_url.clone();
1321 tokio::spawn(async move {
1322 handle
1323 .get_or_recover_with(|| {
1324 let starts = Arc::clone(&starts);
1325 let base_url = base_url.clone();
1326 async move {
1327 starts.fetch_add(1, Ordering::SeqCst);
1328 tokio::time::sleep(Duration::from_millis(50)).await;
1329 Ok(external_server(&base_url))
1330 }
1331 })
1332 .await
1333 })
1334 })
1335 .collect::<Vec<_>>();
1336
1337 let mut snapshots = Vec::with_capacity(tasks.len());
1338 for task in tasks {
1339 snapshots.push(task.await.unwrap().unwrap());
1340 }
1341
1342 assert_eq!(starts.load(Ordering::SeqCst), 1);
1343 assert!(!Arc::ptr_eq(&stale, &snapshots[0]));
1344 assert!(Arc::ptr_eq(&snapshots[0], &snapshots[1]));
1345 assert!(Arc::ptr_eq(&snapshots[1], &snapshots[2]));
1346 }
1347
1348 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1349 async fn handle_retries_if_cache_changes_while_validating() {
1350 let old_health = BlockingHealthServer::start(1).await;
1351 let original = Arc::new(external_server(&old_health.base_url));
1352 let handle = Arc::new(OrchestratorServerHandle {
1353 state: AsyncMutex::new(HandleState::Ready {
1354 snapshot: Arc::clone(&original),
1355 mode: RecoveryMode::External,
1356 }),
1357 });
1358 let replacement_mock = health_mock_server().await;
1359 let replacement = Arc::new(external_server(&replacement_mock.uri()));
1360
1361 let acquire = {
1362 let handle = Arc::clone(&handle);
1363 tokio::spawn(async move {
1364 handle
1365 .acquire_or_recover_with(|| async { anyhow::bail!("should not rebuild") })
1366 .await
1367 })
1368 };
1369
1370 old_health.wait_for_requests(1).await;
1371
1372 {
1373 let mut state = tokio::time::timeout(Duration::from_millis(100), handle.state.lock())
1374 .await
1375 .expect("validation should not hold the handle mutex");
1376 *state = HandleState::Ready {
1377 snapshot: Arc::clone(&replacement),
1378 mode: RecoveryMode::External,
1379 };
1380 }
1381
1382 old_health.release();
1383
1384 let snapshot = acquire.await.unwrap().unwrap();
1385
1386 assert!(!Arc::ptr_eq(&snapshot, &original));
1387 assert!(Arc::ptr_eq(&snapshot, &replacement));
1388 }
1389
1390 #[tokio::test]
1391 async fn handle_rebuilds_without_invalidating_held_snapshot() {
1392 let stale = Arc::new(managed_server_with_exited_child("http://127.0.0.1:9").await);
1393 let handle = OrchestratorServerHandle {
1394 state: AsyncMutex::new(HandleState::Ready {
1395 snapshot: Arc::clone(&stale),
1396 mode: RecoveryMode::Managed,
1397 }),
1398 };
1399 let mock = health_mock_server().await;
1400 let base_url = mock.uri();
1401 let starts = Arc::new(AtomicUsize::new(0));
1402
1403 let rebuilt = handle
1404 .get_or_recover_with(|| {
1405 let starts = Arc::clone(&starts);
1406 let base_url = base_url.clone();
1407 async move {
1408 starts.fetch_add(1, Ordering::SeqCst);
1409 Ok(external_server(&base_url))
1410 }
1411 })
1412 .await
1413 .unwrap();
1414
1415 assert_eq!(starts.load(Ordering::SeqCst), 1);
1416 assert!(!Arc::ptr_eq(&stale, &rebuilt));
1417 assert_eq!(stale.base_url(), "http://127.0.0.1:9");
1418 assert_eq!(rebuilt.base_url(), base_url.trim_end_matches('/'));
1419 }
1420
1421 #[test]
1422 #[serial(env)]
1423 fn managed_guard_disabled_when_env_not_set() {
1424 let _env = ManagedEnvGuard::new();
1425 unsafe { std::env::remove_var(OPENCODE_ORCHESTRATOR_MANAGED_ENV) };
1428 assert!(!managed_guard_enabled());
1429 }
1430
1431 #[test]
1432 #[serial(env)]
1433 fn managed_guard_enabled_when_env_is_1() {
1434 let _env = ManagedEnvGuard::new();
1435 unsafe { std::env::set_var(OPENCODE_ORCHESTRATOR_MANAGED_ENV, "1") };
1437 assert!(managed_guard_enabled());
1438 }
1439
1440 #[test]
1441 #[serial(env)]
1442 fn managed_guard_disabled_when_env_is_0() {
1443 let _env = ManagedEnvGuard::new();
1444 unsafe { std::env::set_var(OPENCODE_ORCHESTRATOR_MANAGED_ENV, "0") };
1446 assert!(!managed_guard_enabled());
1447 }
1448
1449 #[test]
1450 #[serial(env)]
1451 fn managed_guard_disabled_when_env_is_empty() {
1452 let _env = ManagedEnvGuard::new();
1453 unsafe { std::env::set_var(OPENCODE_ORCHESTRATOR_MANAGED_ENV, "") };
1455 assert!(!managed_guard_enabled());
1456 }
1457
1458 #[test]
1459 #[serial(env)]
1460 fn managed_guard_disabled_when_env_is_whitespace() {
1461 let _env = ManagedEnvGuard::new();
1462 unsafe { std::env::set_var(OPENCODE_ORCHESTRATOR_MANAGED_ENV, " ") };
1464 assert!(!managed_guard_enabled());
1465 }
1466
1467 #[test]
1468 #[serial(env)]
1469 fn managed_guard_enabled_when_env_is_truthy() {
1470 let _env = ManagedEnvGuard::new();
1471 unsafe { std::env::set_var(OPENCODE_ORCHESTRATOR_MANAGED_ENV, "true") };
1473 assert!(managed_guard_enabled());
1474 }
1475
1476 #[tokio::test]
1477 #[serial(env)]
1478 async fn recursion_guard_only_blocks_real_startup_paths() {
1479 let _env = ManagedEnvGuard::new();
1480 unsafe { std::env::set_var(OPENCODE_ORCHESTRATOR_MANAGED_ENV, "1") };
1482
1483 let mock = health_mock_server().await;
1484 let handle = OrchestratorServerHandle::from_server_unshared(external_server(&mock.uri()));
1485 let reused = handle
1486 .get_or_recover_with(|| async { anyhow::bail!("should not start") })
1487 .await
1488 .unwrap();
1489 assert_eq!(reused.base_url(), mock.uri().trim_end_matches('/'));
1490
1491 let fresh_handle = OrchestratorServerHandle::new();
1492 let err = match fresh_handle.acquire().await {
1493 Ok(_server) => panic!("expected recursion guard to block fresh startup"),
1494 Err(error) => error,
1495 };
1496 assert!(err.to_string().contains(ORCHESTRATOR_MANAGED_GUARD_MESSAGE));
1497 }
1498
1499 #[tokio::test]
1500 async fn external_failure_becomes_sticky_and_typed() {
1501 let handle = OrchestratorServerHandle::from_server_unshared(
1502 OrchestratorServer::from_client_unshared(
1503 test_client("http://127.0.0.1:9"),
1504 "http://127.0.0.1:9",
1505 RecoveryMode::External,
1506 ),
1507 );
1508 let starts = AtomicUsize::new(0);
1509
1510 let first = handle
1511 .acquire_or_recover_with(|| {
1512 starts.fetch_add(1, Ordering::SeqCst);
1513 async { anyhow::bail!("should not rebuild external servers") }
1514 })
1515 .await;
1516 let second = handle
1517 .acquire_or_recover_with(|| {
1518 starts.fetch_add(1, Ordering::SeqCst);
1519 async { anyhow::bail!("should not rebuild external servers") }
1520 })
1521 .await;
1522
1523 let first = match first {
1524 Ok(_snapshot) => panic!("expected typed external failure on first acquire"),
1525 Err(error) => error,
1526 };
1527 let second = match second {
1528 Ok(_snapshot) => panic!("expected sticky external failure on second acquire"),
1529 Err(error) => error,
1530 };
1531
1532 assert_eq!(starts.load(Ordering::SeqCst), 0);
1533 assert!(
1534 first
1535 .to_string()
1536 .contains("External OpenCode server unavailable"),
1537 "expected typed external failure, got: {first}"
1538 );
1539 assert_eq!(first.to_string(), second.to_string());
1540 }
1541}