use agentic_config::types::OrchestratorConfig;
use anyhow::Context;
use opencode_rs::Client;
use opencode_rs::server::ManagedServer;
use opencode_rs::server::ServerOptions;
use opencode_rs::types::message::Message;
use opencode_rs::types::message::Part;
use opencode_rs::types::provider::ProviderListResponse;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::time::Duration;
use tokio::sync::Mutex as AsyncMutex;
use tokio::sync::RwLock;
use crate::error::OrchestratorError;
use crate::version;
/// Name of the environment variable consulted by [`managed_guard_enabled`].
pub const OPENCODE_ORCHESTRATOR_MANAGED_ENV: &str = "OPENCODE_ORCHESTRATOR_MANAGED";

/// Error text reported when a lazy start is refused because the managed guard is active.
pub const ORCHESTRATOR_MANAGED_GUARD_MESSAGE: &str = "ENV VAR OPENCODE_ORCHESTRATOR_MANAGED is set to 1. This most commonly happens when you're \
    in a nested orchestration session. Consult a human for assistance or try to accomplish your \
    task without the orchestration tools.";

/// Reports whether the managed-session guard is active.
///
/// The guard is active when [`OPENCODE_ORCHESTRATOR_MANAGED_ENV`] holds valid Unicode that is
/// neither exactly `"0"` nor empty / whitespace-only. An unset or non-Unicode variable leaves
/// the guard disabled.
pub fn managed_guard_enabled() -> bool {
    std::env::var(OPENCODE_ORCHESTRATOR_MANAGED_ENV)
        .is_ok_and(|value| !(value == "0" || value.trim().is_empty()))
}
pub async fn init_with_retry<T, F, Fut>(mut f: F) -> anyhow::Result<T>
where
F: FnMut(usize) -> Fut,
Fut: std::future::Future<Output = anyhow::Result<T>>,
{
let mut last_err: Option<anyhow::Error> = None;
for attempt in 1..=2 {
tracing::info!(attempt, "orchestrator server lazy init attempt");
match f(attempt).await {
Ok(v) => {
if attempt > 1 {
tracing::info!(
attempt,
"orchestrator server lazy init succeeded after retry"
);
}
return Ok(v);
}
Err(e) => {
tracing::warn!(attempt, error = %e, "orchestrator server lazy init failed");
last_err = Some(e);
}
}
}
tracing::error!("orchestrator server lazy init exhausted retries");
match last_err {
Some(e) => Err(e),
None => anyhow::bail!("init_with_retry: unexpected empty error state"),
}
}
/// Lookup key for per-model data: `(provider_id, model_id)`.
pub type ModelKey = (String, String);
/// Outcome of validating a cached server snapshot before handing it to a caller.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ServerEntryState {
/// The snapshot passed validation and can be used as-is.
Healthy,
/// The snapshot failed validation; `reason` is a human-readable explanation.
NeedsRecovery { reason: String },
}
/// How the handle reacts when a cached server snapshot is found unusable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecoveryMode {
/// The handle attempts to start a replacement server.
Managed,
/// The handle does not restart anything and reports the server as unavailable.
External,
}
/// Result of checking a command name against the configured allow / deny lists.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommandPolicyDecision {
    /// The command may run.
    Allowed,
    /// A non-empty allowlist exists and the command is not on it.
    DeniedByAllowlist,
    /// The command appears on the denylist.
    DeniedByDenylist,
}

impl CommandPolicyDecision {
    /// Returns `true` only for [`CommandPolicyDecision::Allowed`].
    #[must_use]
    pub fn is_allowed(self) -> bool {
        self == Self::Allowed
    }
}
impl RecoveryMode {
/// Lower-case label for this mode, used as the `mode` field of trace events.
fn as_str(self) -> &'static str {
match self {
Self::Managed => "managed",
Self::External => "external",
}
}
}
/// Cached state held by [`OrchestratorServerHandle`].
enum HandleState {
/// No server has been started yet.
Empty,
/// A server snapshot is cached and is validated before each hand-out.
Ready {
snapshot: Arc<OrchestratorServer>,
mode: RecoveryMode,
},
/// The cached snapshot failed validation; `reason` records why.
Stale {
snapshot: Arc<OrchestratorServer>,
mode: RecoveryMode,
reason: String,
},
/// The last start or validation failed; `base_url` is the previous server URL when known.
Failed {
mode: RecoveryMode,
base_url: Option<String>,
error: String,
},
}
/// Upper bound on the `/global/health` probe issued by `validate_for_tool_entry`.
const TOOL_ENTRY_HEALTH_PROBE_TIMEOUT: Duration = Duration::from_secs(5);
/// Lazily started, self-recovering cache of a shared [`OrchestratorServer`] snapshot.
pub struct OrchestratorServerHandle {
// Async mutex because the guard is held across `.await` while a server is being started.
state: AsyncMutex<HandleState>,
}
impl Default for OrchestratorServerHandle {
/// Equivalent to [`OrchestratorServerHandle::new`].
fn default() -> Self {
Self::new()
}
}
impl OrchestratorServerHandle {
/// Creates a handle with no cached server; the first acquire starts one.
#[must_use]
pub fn new() -> Self {
Self {
state: AsyncMutex::new(HandleState::Empty),
}
}
/// Returns a validated server snapshot, starting or rebuilding it with
/// `OrchestratorServer::start_lazy` when required.
///
/// # Errors
///
/// Propagates any error from the underlying start / recovery logic.
pub async fn acquire(&self) -> anyhow::Result<Arc<OrchestratorServer>> {
self.get_or_recover_with(OrchestratorServer::start_lazy)
.await
}
async fn get_or_recover_with<F, Fut>(
&self,
mut start: F,
) -> anyhow::Result<Arc<OrchestratorServer>>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = anyhow::Result<OrchestratorServer>>,
{
loop {
let ready_snapshot = {
let mut state = self.state.lock().await;
match &mut *state {
HandleState::Empty => {
tracing::info!(
"orchestrator server missing cached snapshot; starting embedded server"
);
match start().await {
Ok(server) => {
let rebuilt_mode = if server.is_managed() {
RecoveryMode::Managed
} else {
RecoveryMode::External
};
let rebuilt = Arc::new(server);
trace_state_transition(
"Empty",
"Ready",
"initialization",
rebuilt_mode,
Some(rebuilt.base_url()),
);
*state = HandleState::Ready {
snapshot: Arc::clone(&rebuilt),
mode: rebuilt_mode,
};
return Ok(rebuilt);
}
Err(error) => {
let reason = error.to_string();
trace_state_transition(
"Empty",
"Failed",
&reason,
RecoveryMode::Managed,
None,
);
*state = HandleState::Failed {
mode: RecoveryMode::Managed,
base_url: None,
error: reason,
};
return Err(error);
}
}
}
HandleState::Ready { snapshot, mode } => Some((Arc::clone(snapshot), *mode)),
HandleState::Stale {
snapshot,
mode,
reason,
} => match mode {
RecoveryMode::Managed => {
let stale_reason = reason.clone();
match start().await {
Ok(server) => {
let rebuilt_mode = if server.is_managed() {
RecoveryMode::Managed
} else {
RecoveryMode::External
};
let rebuilt = Arc::new(server);
trace_state_transition(
"Stale",
"Ready",
&stale_reason,
rebuilt_mode,
Some(rebuilt.base_url()),
);
*state = HandleState::Ready {
snapshot: Arc::clone(&rebuilt),
mode: rebuilt_mode,
};
return Ok(rebuilt);
}
Err(error) => {
let failure = error.to_string();
trace_state_transition(
"Stale",
"Failed",
&failure,
*mode,
Some(snapshot.base_url()),
);
*state = HandleState::Failed {
mode: *mode,
base_url: Some(snapshot.base_url().to_string()),
error: failure,
};
return Err(error);
}
}
}
RecoveryMode::External => {
let base_url = snapshot.base_url().to_string();
let stale_reason = reason.clone();
trace_state_transition(
"Stale",
"Failed",
&stale_reason,
*mode,
Some(&base_url),
);
*state = HandleState::Failed {
mode: *mode,
base_url: Some(base_url.clone()),
error: stale_reason.clone(),
};
return Err(external_unavailable(Some(base_url), stale_reason));
}
},
HandleState::Failed {
mode,
base_url,
error,
} => match mode {
RecoveryMode::Managed => match start().await {
Ok(server) => {
let rebuilt_mode = if server.is_managed() {
RecoveryMode::Managed
} else {
RecoveryMode::External
};
let rebuilt = Arc::new(server);
trace_state_transition(
"Failed",
"Ready",
error,
rebuilt_mode,
Some(rebuilt.base_url()),
);
*state = HandleState::Ready {
snapshot: Arc::clone(&rebuilt),
mode: rebuilt_mode,
};
return Ok(rebuilt);
}
Err(start_error) => {
let failure = start_error.to_string();
error.clone_from(&failure);
return Err(start_error);
}
},
RecoveryMode::External => {
return Err(external_unavailable(base_url.clone(), error.clone()));
}
},
}
};
let Some((snapshot, mode)) = ready_snapshot else {
continue;
};
let validation = snapshot.validate_for_tool_entry().await?;
let mut state = self.state.lock().await;
let HandleState::Ready {
snapshot: current,
mode: current_mode,
} = &*state
else {
continue;
};
if !Arc::ptr_eq(current, &snapshot) || *current_mode != mode {
continue;
}
match validation {
ServerEntryState::Healthy => return Ok(snapshot),
ServerEntryState::NeedsRecovery { reason } => {
trace_cache_invalidated(&reason, mode, Some(snapshot.base_url()));
match mode {
RecoveryMode::Managed => {
tracing::warn!(reason = %reason, "cached orchestrator server failed liveness check; rebuilding");
trace_state_transition(
"Ready",
"Stale",
&reason,
mode,
Some(snapshot.base_url()),
);
*state = HandleState::Stale {
snapshot: Arc::clone(&snapshot),
mode,
reason: reason.clone(),
};
match start().await {
Ok(server) => {
let rebuilt_mode = if server.is_managed() {
RecoveryMode::Managed
} else {
RecoveryMode::External
};
let rebuilt = Arc::new(server);
trace_state_transition(
"Stale",
"Ready",
&reason,
rebuilt_mode,
Some(rebuilt.base_url()),
);
*state = HandleState::Ready {
snapshot: Arc::clone(&rebuilt),
mode: rebuilt_mode,
};
return Ok(rebuilt);
}
Err(error) => {
let failure = error.to_string();
trace_state_transition(
"Stale",
"Failed",
&failure,
mode,
Some(snapshot.base_url()),
);
*state = HandleState::Failed {
mode,
base_url: Some(snapshot.base_url().to_string()),
error: failure,
};
return Err(error);
}
}
}
RecoveryMode::External => {
let base_url = snapshot.base_url().to_string();
trace_state_transition(
"Ready",
"Failed",
&reason,
mode,
Some(&base_url),
);
*state = HandleState::Failed {
mode,
base_url: Some(base_url.clone()),
error: reason.clone(),
};
return Err(external_unavailable(Some(base_url), reason));
}
}
}
}
}
}
/// Builds a handle whose cache already holds `server` in the `Ready` state.
///
/// The recovery mode is derived from whether `server` reports itself as managed.
#[cfg(any(test, feature = "test-support"))]
#[must_use]
pub fn from_server_unshared(server: OrchestratorServer) -> Self {
    let mode = if server.is_managed() {
        RecoveryMode::Managed
    } else {
        RecoveryMode::External
    };
    let initial = HandleState::Ready {
        snapshot: Arc::new(server),
        mode,
    };
    Self {
        state: AsyncMutex::new(initial),
    }
}
/// Test-support entry point: same as acquiring, but with a caller-supplied `start` function
/// used whenever a server must be started or rebuilt.
#[cfg(any(test, feature = "test-support"))]
pub async fn acquire_or_recover_with<F, Fut>(
&self,
start: F,
) -> anyhow::Result<Arc<OrchestratorServer>>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = anyhow::Result<OrchestratorServer>>,
{
self.get_or_recover_with(start).await
}
}
/// Emits a `cache_invalidated` info event; `base_url` is attached only when known.
fn trace_cache_invalidated(reason: &str, mode: RecoveryMode, base_url: Option<&str>) {
    let mode = mode.as_str();
    match base_url {
        Some(url) => {
            tracing::info!(
                event = "cache_invalidated",
                reason = %reason,
                mode,
                base_url = %url,
            );
        }
        None => {
            tracing::info!(
                event = "cache_invalidated",
                reason = %reason,
                mode,
            );
        }
    }
}
/// Emits a `state_transition` info event describing a move between handle states.
///
/// `from` / `to` are state names, `reason` explains the move, and `base_url` is attached only
/// when known.
fn trace_state_transition(
    from: &'static str,
    to: &'static str,
    reason: &str,
    mode: RecoveryMode,
    base_url: Option<&str>,
) {
    let mode = mode.as_str();
    match base_url {
        Some(url) => {
            tracing::info!(
                event = "state_transition",
                from,
                to,
                reason = %reason,
                mode,
                base_url = %url,
            );
        }
        None => {
            tracing::info!(
                event = "state_transition",
                from,
                to,
                reason = %reason,
                mode,
            );
        }
    }
}
fn external_unavailable(base_url: Option<String>, reason: String) -> anyhow::Error {
OrchestratorError::ExternalServerUnavailable {
base_url: base_url.unwrap_or_else(|| "<unknown>".to_string()),
reason,
}
.into()
}
/// A running OpenCode server together with the client and configuration used to talk to it.
pub struct OrchestratorServer {
// `Some` while this snapshot owns a managed child server; `None` otherwise.
managed_server: StdMutex<Option<ManagedServer>>,
client: Client,
// Context-window limits keyed by `(provider_id, model_id)`.
model_context_limits: HashMap<ModelKey, u64>,
// Server URL; every constructor in this file strips trailing `/`.
base_url: String,
config: OrchestratorConfig,
spawned_sessions: Arc<RwLock<HashSet<String>>>,
}
impl OrchestratorServer {
/// Decides whether `command` may run under the configured allow / deny lists.
///
/// List entries are trimmed and blank entries ignored; `command` itself is compared
/// verbatim (case-sensitive, untrimmed). The denylist is checked first. The allowlist only
/// restricts when it contains at least one non-blank entry.
pub fn command_policy_decision(&self, command: &str) -> CommandPolicyDecision {
    /// Yields the trimmed, non-blank entries of a policy list.
    fn effective_entries(entries: &[String]) -> impl Iterator<Item = &str> {
        entries
            .iter()
            .map(String::as_str)
            .map(str::trim)
            .filter(|entry| !entry.is_empty())
    }

    let policy = &self.config.commands;

    if effective_entries(&policy.deny).any(|entry| entry == command) {
        return CommandPolicyDecision::DeniedByDenylist;
    }

    let mut allowed = effective_entries(&policy.allow).peekable();
    let allowlist_active = allowed.peek().is_some();
    if allowlist_active && !allowed.any(|entry| entry == command) {
        return CommandPolicyDecision::DeniedByAllowlist;
    }

    CommandPolicyDecision::Allowed
}
/// Convenience wrapper: `true` when `command_policy_decision` yields `Allowed`.
pub fn is_command_allowed(&self, command: &str) -> bool {
self.command_policy_decision(command).is_allowed()
}
/// Starts a server eagerly via `start_impl` and wraps it in an `Arc`.
///
/// Unlike the lazy start path, this performs no managed-guard check and no retry.
#[allow(clippy::allow_attributes, dead_code)]
pub async fn start() -> anyhow::Result<Arc<Self>> {
Ok(Arc::new(Self::start_impl().await?))
}
/// Starts a server without injecting any configuration JSON.
///
/// See `start_lazy_with_config` for the guard and retry behavior.
pub async fn start_lazy() -> anyhow::Result<Self> {
Self::start_lazy_with_config(None).await
}
/// Starts a server, optionally passing `config_json` to the server options, retrying once on
/// failure via `init_with_retry`.
///
/// # Errors
///
/// Fails immediately with `ORCHESTRATOR_MANAGED_GUARD_MESSAGE` when the managed guard is
/// enabled; otherwise returns the error of the last failed start attempt.
pub async fn start_lazy_with_config(config_json: Option<String>) -> anyhow::Result<Self> {
if managed_guard_enabled() {
anyhow::bail!(ORCHESTRATOR_MANAGED_GUARD_MESSAGE);
}
init_with_retry(|_attempt| {
// Each attempt needs its own owned copy because the future consumes it.
let cfg = config_json.clone();
async move { Self::start_impl_with_config(cfg).await }
})
.await
}
/// Starts the embedded server with no injected configuration.
///
/// This is `start_impl_with_config(None)`: the two start paths previously duplicated the
/// whole launch / version-check / model-limit sequence line for line. Delegating keeps them
/// from drifting apart; the only observable difference is that the startup log event now
/// also carries `config_injected = false`.
///
/// # Errors
///
/// Propagates every error from `start_impl_with_config`.
async fn start_impl() -> anyhow::Result<Self> {
    Self::start_impl_with_config(None).await
}
/// Launches the embedded `opencode serve` process in the current directory and builds an
/// [`OrchestratorServer`] around it.
///
/// Steps: load the merged orchestrator config (falling back to defaults with a warning),
/// resolve the launcher, start the managed server (passing `config_json` when provided),
/// build the HTTP client, check the reported version with `version::validate_exact_version`,
/// and load model context limits (a failure there is logged and yields an empty map).
///
/// # Errors
///
/// Fails when the current directory or launcher cannot be resolved, the server or client
/// cannot be created, the health endpoint cannot be fetched, or version validation fails.
async fn start_impl_with_config(config_json: Option<String>) -> anyhow::Result<Self> {
    let cwd = std::env::current_dir().context("Failed to resolve current directory")?;

    let config = agentic_config::loader::load_merged(&cwd).map_or_else(
        |e| {
            tracing::warn!("Failed to load config, using defaults: {e}");
            OrchestratorConfig::default()
        },
        |loaded| {
            for w in &loaded.warnings {
                tracing::warn!("{w}");
            }
            loaded.config.orchestrator
        },
    );

    let launcher_config = version::resolve_launcher_config(&cwd)
        .context("Failed to resolve OpenCode launcher configuration")?;
    tracing::info!(
        binary = %launcher_config.binary,
        launcher_args = ?launcher_config.launcher_args,
        expected_version = %version::PINNED_OPENCODE_VERSION,
        config_injected = config_json.is_some(),
        "starting embedded opencode serve (pinned stable)"
    );

    let base_opts = ServerOptions::default()
        .binary(&launcher_config.binary)
        .launcher_args(launcher_config.launcher_args)
        .directory(cwd.clone());
    let opts = match config_json {
        Some(cfg) => base_opts.config_json(cfg),
        None => base_opts,
    };

    let managed = ManagedServer::start(opts)
        .await
        .context("Failed to start embedded `opencode serve`")?;
    let base_url = managed.url().to_string().trim_end_matches('/').to_string();

    let client = Client::builder()
        .base_url(&base_url)
        .directory(cwd.to_string_lossy().to_string())
        .build()
        .context("Failed to build opencode-rs HTTP client")?;

    let health = client
        .misc()
        .health()
        .await
        .context("Failed to fetch /global/health for version validation")?;
    version::validate_exact_version(health.version.as_deref()).with_context(|| {
        format!(
            "Embedded OpenCode server did not match pinned stable v{} (binary={})",
            version::PINNED_OPENCODE_VERSION,
            launcher_config.binary
        )
    })?;

    // Missing limits are not fatal: the server is still usable without them.
    let model_context_limits = match Self::load_model_limits(&client).await {
        Ok(limits) => limits,
        Err(e) => {
            tracing::warn!("Failed to load model limits: {}", e);
            HashMap::new()
        }
    };
    tracing::info!("Loaded {} model context limits", model_context_limits.len());

    Ok(Self {
        managed_server: StdMutex::new(Some(managed)),
        client,
        model_context_limits,
        base_url,
        config,
        spawned_sessions: Arc::new(RwLock::new(HashSet::new())),
    })
}
/// HTTP client bound to this server.
pub fn client(&self) -> &Client {
&self.client
}
/// Base URL of this server as stored at construction time.
#[allow(clippy::allow_attributes, dead_code)]
pub fn base_url(&self) -> &str {
&self.base_url
}
pub fn context_limit(&self, provider_id: &str, model_id: &str) -> Option<u64> {
self.model_context_limits
.get(&(provider_id.to_string(), model_id.to_string()))
.copied()
}
/// Session deadline, built from the configured `session_deadline_secs`.
pub fn session_deadline(&self) -> Duration {
Duration::from_secs(self.config.session_deadline_secs)
}
/// Inactivity timeout, built from the configured `inactivity_timeout_secs`.
pub fn inactivity_timeout(&self) -> Duration {
Duration::from_secs(self.config.inactivity_timeout_secs)
}
/// Configured compaction threshold, returned unchanged.
pub fn compaction_threshold(&self) -> f64 {
self.config.compaction_threshold
}
/// Shared set of session identifiers tracked by this server snapshot.
pub fn spawned_sessions(&self) -> &Arc<RwLock<HashSet<String>>> {
&self.spawned_sessions
}
/// Locks the managed-server slot, recovering the guard even if the mutex was poisoned.
fn managed_server_lock(&self) -> std::sync::MutexGuard<'_, Option<ManagedServer>> {
self.managed_server
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
/// `true` while a managed server is attached to this snapshot.
fn is_managed(&self) -> bool {
self.managed_server_lock().is_some()
}
/// Validates this snapshot using the default `TOOL_ENTRY_HEALTH_PROBE_TIMEOUT`.
async fn validate_for_tool_entry(&self) -> anyhow::Result<ServerEntryState> {
self.validate_for_tool_entry_with_timeout(TOOL_ENTRY_HEALTH_PROBE_TIMEOUT)
.await
}
async fn validate_for_tool_entry_with_timeout(
&self,
health_probe_timeout: Duration,
) -> anyhow::Result<ServerEntryState> {
if self.is_managed() {
let is_running = {
let mut managed = self.managed_server_lock();
managed
.as_mut()
.is_some_and(opencode_rs::server::ManagedServer::is_running)
};
if !is_running {
return Ok(ServerEntryState::NeedsRecovery {
reason: "managed child is no longer running".to_string(),
});
}
}
match tokio::time::timeout(health_probe_timeout, self.client.misc().health()).await {
Ok(Ok(health)) if health.healthy => Ok(ServerEntryState::Healthy),
Ok(Ok(_health)) => Ok(ServerEntryState::NeedsRecovery {
reason: "/global/health reported unhealthy".to_string(),
}),
Ok(Err(error)) => Ok(ServerEntryState::NeedsRecovery {
reason: format!("/global/health probe failed: {error}"),
}),
Err(_elapsed) => Ok(ServerEntryState::NeedsRecovery {
reason: format!("/global/health probe timed out after {health_probe_timeout:?}"),
}),
}
}
/// Fetches the provider list and collects the context limit of every model that reports one.
///
/// # Errors
///
/// Returns the client error when the provider list cannot be fetched.
async fn load_model_limits(client: &Client) -> anyhow::Result<HashMap<ModelKey, u64>> {
    let providers: ProviderListResponse = client.providers().list().await?;

    let mut limits = HashMap::new();
    for provider in providers.all {
        let provider_id = provider.id;
        for (model_id, model) in provider.models {
            // Models without a context limit are simply left out of the map.
            let Some(context) = model.limit.as_ref().and_then(|l| l.context) else {
                continue;
            };
            limits.insert((provider_id.clone(), model_id), context);
        }
    }
    Ok(limits)
}
/// Concatenates the text parts of the most recent assistant message.
///
/// Returns `None` when there is no assistant message, or when its combined text is empty or
/// whitespace-only. Non-text parts are ignored and the text is returned untrimmed.
pub fn extract_assistant_text(messages: &[Message]) -> Option<String> {
    let latest = messages
        .iter()
        .rev()
        .find(|message| message.info.role == "assistant")?;

    let combined: String = latest
        .parts
        .iter()
        .filter_map(|part| match part {
            Part::Text { text, .. } => Some(text.as_str()),
            _ => None,
        })
        .collect();

    if combined.trim().is_empty() {
        return None;
    }
    Some(combined)
}
}
/// Constructors and helpers that build [`OrchestratorServer`] values without launching a
/// real server. Only compiled for tests or with the `test-support` feature.
#[cfg(any(test, feature = "test-support"))]
#[allow(dead_code, clippy::allow_attributes)]
impl OrchestratorServer {
    /// Shared (`Arc`) variant of [`Self::from_client_unshared`].
    pub fn from_client(
        client: Client,
        base_url: impl Into<String>,
        mode: RecoveryMode,
    ) -> Arc<Self> {
        let server = Self::from_client_unshared(client, base_url, mode);
        Arc::new(server)
    }

    /// Shared (`Arc`) variant of [`Self::from_client_unshared_with_config`].
    pub fn from_client_with_config(
        client: Client,
        base_url: impl Into<String>,
        mode: RecoveryMode,
        config: OrchestratorConfig,
    ) -> Arc<Self> {
        let server = Self::from_client_unshared_with_config(client, base_url, mode, config);
        Arc::new(server)
    }

    /// Builds a server around `client` using the default configuration.
    pub fn from_client_unshared(
        client: Client,
        base_url: impl Into<String>,
        mode: RecoveryMode,
    ) -> Self {
        let config = OrchestratorConfig::default();
        Self::from_client_unshared_with_config(client, base_url, mode, config)
    }

    /// Builds a server around `client` with no managed child attached.
    ///
    /// `_mode` is accepted but not stored: no managed server is attached, so the snapshot
    /// always reports itself as unmanaged. Trailing `/` is stripped from `base_url`.
    pub fn from_client_unshared_with_config(
        client: Client,
        base_url: impl Into<String>,
        _mode: RecoveryMode,
        config: OrchestratorConfig,
    ) -> Self {
        let base_url: String = base_url.into();
        Self {
            managed_server: StdMutex::new(None),
            client,
            model_context_limits: HashMap::new(),
            base_url: base_url.trim_end_matches('/').to_string(),
            config,
            spawned_sessions: Arc::new(RwLock::new(HashSet::new())),
        }
    }

    /// Builds a managed server around `managed` using the default configuration.
    pub fn from_managed_for_testing(
        managed: ManagedServer,
        client: Client,
        base_url: impl Into<String>,
    ) -> Self {
        let config = OrchestratorConfig::default();
        Self::from_managed_for_testing_with_config(managed, client, base_url, config)
    }

    /// Builds a server that owns `managed`. Trailing `/` is stripped from `base_url`.
    pub fn from_managed_for_testing_with_config(
        managed: ManagedServer,
        client: Client,
        base_url: impl Into<String>,
        config: OrchestratorConfig,
    ) -> Self {
        let base_url: String = base_url.into();
        Self {
            managed_server: StdMutex::new(Some(managed)),
            client,
            model_context_limits: HashMap::new(),
            base_url: base_url.trim_end_matches('/').to_string(),
            config,
            spawned_sessions: Arc::new(RwLock::new(HashSet::new())),
        }
    }

    /// Detaches the managed server from this snapshot and stops it.
    ///
    /// # Errors
    ///
    /// Fails when no managed server is attached, or when stopping it fails.
    pub async fn stop_managed_for_testing(&self) -> anyhow::Result<()> {
        // Take the server out under the lock, then release the lock before awaiting.
        let detached = {
            let mut slot = self.managed_server_lock();
            slot.take()
        };
        let Some(managed) = detached else {
            anyhow::bail!("no managed server is attached to this snapshot");
        };
        managed.stop().await.map_err(Into::into)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use agentic_config::types::OrchestratorCommandsConfig;
use serial_test::serial;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::time::Duration;
use std::time::Instant;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener;
use tokio::process::Command;
use tokio::sync::Notify;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;
/// Captures the managed-guard environment variable on creation and restores it on drop.
struct ManagedEnvGuard {
    previous: Option<std::ffi::OsString>,
}

impl ManagedEnvGuard {
    /// Records the variable's current value (or its absence).
    fn new() -> Self {
        let previous = std::env::var_os(OPENCODE_ORCHESTRATOR_MANAGED_ENV);
        Self { previous }
    }
}

impl Drop for ManagedEnvGuard {
    /// Puts the variable back to the value seen by [`ManagedEnvGuard::new`].
    fn drop(&mut self) {
        // NOTE(review): mutating the process environment is only sound while no other thread
        // touches it; the visible env test is annotated `#[serial(env)]` — confirm all users
        // of this guard are.
        if let Some(value) = self.previous.as_ref() {
            unsafe {
                std::env::set_var(OPENCODE_ORCHESTRATOR_MANAGED_ENV, value);
            }
        } else {
            unsafe {
                std::env::remove_var(OPENCODE_ORCHESTRATOR_MANAGED_ENV);
            }
        }
    }
}
/// Starts a mock server whose `GET /global/health` reports healthy with the pinned version.
async fn health_mock_server() -> MockServer {
    let healthy_body = serde_json::json!({
        "healthy": true,
        "version": version::PINNED_OPENCODE_VERSION,
    });
    let server = MockServer::start().await;
    Mock::given(method("GET"))
        .and(path("/global/health"))
        .respond_with(ResponseTemplate::new(200).set_body_json(healthy_body))
        .mount(&server)
        .await;
    server
}
/// Builds a client for `base_url` with a 5-second timeout; panics if the build fails.
fn test_client(base_url: &str) -> Client {
    let builder = opencode_rs::ClientBuilder::new()
        .base_url(base_url)
        .timeout_secs(5);
    builder.build().unwrap()
}
/// Builds an unmanaged server for `base_url` with the default configuration.
fn external_server(base_url: &str) -> OrchestratorServer {
    let client = test_client(base_url);
    OrchestratorServer::from_client_unshared(client, base_url, RecoveryMode::External)
}
/// Builds an unmanaged server for `base_url` using the given configuration.
fn external_server_with_config(
    base_url: &str,
    config: OrchestratorConfig,
) -> OrchestratorServer {
    let client = test_client(base_url);
    OrchestratorServer::from_client_unshared_with_config(
        client,
        base_url,
        RecoveryMode::External,
        config,
    )
}
/// Spawns `sh -c "exit 0"`, waits for it to finish, and returns the finished child handle.
async fn exited_child() -> tokio::process::Child {
    let mut command = Command::new("sh");
    command.arg("-c").arg("exit 0");
    let mut child = command.spawn().unwrap();
    let _status = child.wait().await.unwrap();
    child
}
/// Builds a managed server whose child process has already exited.
async fn managed_server_with_exited_child(base_url: &str) -> OrchestratorServer {
    let child = exited_child().await;
    let managed = ManagedServer::from_child_for_testing(child, base_url, 9);
    let client = test_client(base_url);
    OrchestratorServer::from_managed_for_testing(managed, client, base_url)
}
/// With an empty allowlist only the denylist restricts commands.
#[test]
fn command_policy_allows_all_when_allowlist_is_empty() {
    let commands = OrchestratorCommandsConfig {
        allow: vec![],
        deny: vec!["blocked".into()],
    };
    let config = OrchestratorConfig {
        commands,
        ..OrchestratorConfig::default()
    };
    let server = external_server_with_config("http://127.0.0.1:9", config);

    let expectations = [
        ("plan", CommandPolicyDecision::Allowed),
        ("blocked", CommandPolicyDecision::DeniedByDenylist),
    ];
    for (command, expected) in expectations {
        assert_eq!(server.command_policy_decision(command), expected);
    }
}
/// A command present on both lists is denied, and padded allow entries are trimmed.
#[test]
fn command_policy_trims_entries_and_deny_wins() {
    let commands = OrchestratorCommandsConfig {
        allow: vec![" plan ".into()],
        deny: vec!["plan".into()],
    };
    let config = OrchestratorConfig {
        commands,
        ..OrchestratorConfig::default()
    };
    let server = external_server_with_config("http://127.0.0.1:9", config);

    let decision = server.command_policy_decision("plan");
    assert_eq!(decision, CommandPolicyDecision::DeniedByDenylist);
}
/// Allowlist matching distinguishes `Plan` from `plan`.
#[test]
fn command_policy_matching_is_case_sensitive() {
    let commands = OrchestratorCommandsConfig {
        allow: vec!["Plan".into()],
        deny: vec!["blocked".into()],
    };
    let config = OrchestratorConfig {
        commands,
        ..OrchestratorConfig::default()
    };
    let server = external_server_with_config("http://127.0.0.1:9", config);

    let exact_case = server.command_policy_decision("Plan");
    assert_eq!(exact_case, CommandPolicyDecision::Allowed);

    let lower_case = server.command_policy_decision("plan");
    assert_eq!(lower_case, CommandPolicyDecision::DeniedByAllowlist);

    assert!(server.is_command_allowed("Plan"));
    assert!(!server.is_command_allowed("plan"));
}
/// Minimal TCP test server that accepts health requests but withholds every response until
/// `release` is called, so tests can observe requests that are in flight at the same time.
struct BlockingHealthServer {
base_url: String,
// Number of connections whose request bytes have been read so far.
started_requests: Arc<AtomicUsize>,
started_notify: Arc<Notify>,
// Set once by `release`; connections respond as soon as they see it.
released: Arc<AtomicBool>,
release_notify: Arc<Notify>,
task: tokio::task::JoinHandle<()>,
}
impl BlockingHealthServer {
/// Binds to an ephemeral local port and serves exactly `expected_requests` connections.
async fn start(expected_requests: usize) -> Self {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let started_requests = Arc::new(AtomicUsize::new(0));
let started_notify = Arc::new(Notify::new());
let released = Arc::new(AtomicBool::new(false));
let release_notify = Arc::new(Notify::new());
let body = format!(
r#"{{"healthy":true,"version":"{}"}}"#,
version::PINNED_OPENCODE_VERSION
);
// Canned raw HTTP response shared by every connection.
let response = Arc::new(format!(
"HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
body.len(),
body
));
let task = tokio::spawn({
let started_requests = Arc::clone(&started_requests);
let started_notify = Arc::clone(&started_notify);
let released = Arc::clone(&released);
let release_notify = Arc::clone(&release_notify);
let response = Arc::clone(&response);
async move {
let mut connections = Vec::with_capacity(expected_requests);
for _ in 0..expected_requests {
let (mut stream, _addr) = listener.accept().await.unwrap();
let started_requests = Arc::clone(&started_requests);
let started_notify = Arc::clone(&started_notify);
let released = Arc::clone(&released);
let release_notify = Arc::clone(&release_notify);
let response = Arc::clone(&response);
// Each connection is handled on its own task so requests can be pending concurrently.
connections.push(tokio::spawn(async move {
let mut request = [0_u8; 1024];
let _read = stream.read(&mut request).await.unwrap();
started_requests.fetch_add(1, Ordering::SeqCst);
started_notify.notify_waiters();
loop {
// The `notified` future is created before the flag is checked; presumably this avoids
// missing a `release` that lands between the check and the await.
let notified = release_notify.notified();
if released.load(Ordering::SeqCst) {
break;
}
notified.await;
}
stream.write_all(response.as_bytes()).await.unwrap();
stream.shutdown().await.unwrap();
}));
}
for connection in connections {
connection.await.unwrap();
}
}
});
Self {
base_url: format!("http://{addr}"),
started_requests,
started_notify,
released,
release_notify,
task,
}
}
/// Waits (at most one second, then panics) until `expected_requests` requests have arrived.
async fn wait_for_requests(&self, expected_requests: usize) {
tokio::time::timeout(Duration::from_secs(1), async {
while self.started_requests.load(Ordering::SeqCst) < expected_requests {
self.started_notify.notified().await;
}
})
.await
.unwrap();
}
/// Lets every pending and future connection send its response.
fn release(&self) {
self.released.store(true, Ordering::SeqCst);
self.release_notify.notify_waiters();
}
}
impl Drop for BlockingHealthServer {
/// Releases any blocked connections and aborts the accept task.
fn drop(&mut self) {
self.release();
self.task.abort();
}
}
/// A first-attempt success must not trigger a second call.
#[tokio::test]
async fn init_with_retry_succeeds_on_first_attempt() {
    let calls = AtomicUsize::new(0);
    let outcome = init_with_retry(|_| {
        let call_index = calls.fetch_add(1, Ordering::SeqCst);
        async move {
            assert_eq!(call_index, 0, "should only be called once on success");
            Ok(42)
        }
    })
    .await;
    let value: u32 = outcome.unwrap();
    assert_eq!(value, 42);
    assert_eq!(calls.load(Ordering::SeqCst), 1);
}
/// A failure on the first attempt is retried, and the second attempt's value is returned.
#[tokio::test]
async fn init_with_retry_retries_once_and_succeeds() {
    let calls = AtomicUsize::new(0);
    let outcome = init_with_retry(|_| {
        let call_index = calls.fetch_add(1, Ordering::SeqCst);
        async move {
            if call_index == 0 {
                anyhow::bail!("fail first");
            }
            Ok(42)
        }
    })
    .await;
    let value: u32 = outcome.unwrap();
    assert_eq!(value, 42);
    assert_eq!(calls.load(Ordering::SeqCst), 2);
}
/// After two failed attempts the last error is surfaced and no third attempt is made.
#[tokio::test]
async fn init_with_retry_fails_after_two_attempts() {
    let calls = AtomicUsize::new(0);
    let outcome = init_with_retry::<(), _, _>(|_| {
        calls.fetch_add(1, Ordering::SeqCst);
        async { anyhow::bail!("always fail") }
    })
    .await;
    let failure = outcome.unwrap_err();
    assert!(failure.to_string().contains("always fail"));
    assert_eq!(calls.load(Ordering::SeqCst), 2);
}
/// Two concurrent acquires on an empty handle must result in exactly one start, and both
/// callers must receive the same snapshot.
#[tokio::test]
async fn handle_serializes_initialization_and_reuses_snapshot() {
let mock = health_mock_server().await;
let base_url = mock.uri();
let handle = Arc::new(OrchestratorServerHandle::new());
let starts = Arc::new(AtomicUsize::new(0));
let first = {
let handle = Arc::clone(&handle);
let starts = Arc::clone(&starts);
let base_url = base_url.clone();
tokio::spawn(async move {
handle
.get_or_recover_with(|| {
let starts = Arc::clone(&starts);
let base_url = base_url.clone();
async move {
starts.fetch_add(1, Ordering::SeqCst);
// Slow start: leaves time for the second task to queue up behind the handle's mutex.
tokio::time::sleep(Duration::from_millis(50)).await;
Ok(external_server(&base_url))
}
})
.await
})
};
let second = {
let handle = Arc::clone(&handle);
let starts = Arc::clone(&starts);
let base_url = base_url.clone();
tokio::spawn(async move {
handle
.get_or_recover_with(|| {
let starts = Arc::clone(&starts);
let base_url = base_url.clone();
async move {
starts.fetch_add(1, Ordering::SeqCst);
Ok(external_server(&base_url))
}
})
.await
})
};
let first = first.await.unwrap().unwrap();
let second = second.await.unwrap().unwrap();
assert_eq!(starts.load(Ordering::SeqCst), 1);
assert!(Arc::ptr_eq(&first, &second));
}
/// Validating an unmanaged server must hit `/global/health` and report `Healthy`.
#[tokio::test]
async fn validate_for_tool_entry_uses_health_for_external_server() {
    let mock = health_mock_server().await;
    let server = external_server(&mock.uri());

    let state = server.validate_for_tool_entry().await.unwrap();
    assert_eq!(state, ServerEntryState::Healthy);

    let requests = mock.received_requests().await.unwrap();
    let probed_health = requests
        .iter()
        .any(|request| request.url.path() == "/global/health");
    assert!(probed_health, "expected /global/health request");
}
/// A health endpoint slower than the probe timeout yields `NeedsRecovery` with a timeout
/// reason instead of blocking.
#[tokio::test]
async fn validate_for_tool_entry_times_out_health_probe() {
    let slow_healthy = ResponseTemplate::new(200)
        .set_delay(Duration::from_secs(30))
        .set_body_json(serde_json::json!({
            "healthy": true,
            "version": version::PINNED_OPENCODE_VERSION,
        }));
    let mock = MockServer::start().await;
    Mock::given(method("GET"))
        .and(path("/global/health"))
        .respond_with(slow_healthy)
        .mount(&mock)
        .await;

    let server = external_server(&mock.uri());
    let probe_timeout = Duration::from_millis(25);
    let state = server
        .validate_for_tool_entry_with_timeout(probe_timeout)
        .await
        .unwrap();

    let expected = ServerEntryState::NeedsRecovery {
        reason: "/global/health probe timed out after 25ms".to_string(),
    };
    assert_eq!(state, expected);
}
/// A managed server whose child has exited is reported as needing recovery.
#[tokio::test]
async fn validate_for_tool_entry_short_circuits_dead_managed_server() {
    let server = managed_server_with_exited_child("http://127.0.0.1:9").await;

    let state = server.validate_for_tool_entry().await.unwrap();

    let expected = ServerEntryState::NeedsRecovery {
        reason: "managed child is no longer running".to_string(),
    };
    assert_eq!(state, expected);
}
/// Three concurrent acquires against a healthy cached snapshot must all have their health
/// probes in flight at once, and must all return the same snapshot.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn handle_allows_concurrent_healthy_acquires_without_serializing_validation() {
let health = BlockingHealthServer::start(3).await;
let handle = Arc::new(OrchestratorServerHandle::from_server_unshared(
external_server(&health.base_url),
));
let started_at = Instant::now();
let tasks = (0..3)
.map(|_| {
let handle = Arc::clone(&handle);
tokio::spawn(async move { handle.acquire().await })
})
.collect::<Vec<_>>();
// All three probes must reach the server before any response is released.
health.wait_for_requests(3).await;
tokio::time::sleep(Duration::from_millis(75)).await;
health.release();
let mut snapshots = Vec::with_capacity(tasks.len());
for task in tasks {
snapshots.push(task.await.unwrap().unwrap());
}
assert!(
started_at.elapsed() < Duration::from_millis(250),
"healthy acquires should overlap rather than serialize"
);
assert!(Arc::ptr_eq(&snapshots[0], &snapshots[1]));
assert!(Arc::ptr_eq(&snapshots[1], &snapshots[2]));
}
/// When the cached managed snapshot is dead, three concurrent acquires must trigger exactly
/// one rebuild and all receive the rebuilt snapshot.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn handle_single_flights_concurrent_stale_acquires() {
let stale = Arc::new(managed_server_with_exited_child("http://127.0.0.1:9").await);
let handle = Arc::new(OrchestratorServerHandle {
state: AsyncMutex::new(HandleState::Ready {
snapshot: Arc::clone(&stale),
mode: RecoveryMode::Managed,
}),
});
let mock = health_mock_server().await;
let base_url = mock.uri();
let starts = Arc::new(AtomicUsize::new(0));
let tasks = (0..3)
.map(|_| {
let handle = Arc::clone(&handle);
let starts = Arc::clone(&starts);
let base_url = base_url.clone();
tokio::spawn(async move {
handle
.get_or_recover_with(|| {
let starts = Arc::clone(&starts);
let base_url = base_url.clone();
async move {
starts.fetch_add(1, Ordering::SeqCst);
tokio::time::sleep(Duration::from_millis(50)).await;
Ok(external_server(&base_url))
}
})
.await
})
})
.collect::<Vec<_>>();
let mut snapshots = Vec::with_capacity(tasks.len());
for task in tasks {
snapshots.push(task.await.unwrap().unwrap());
}
assert_eq!(starts.load(Ordering::SeqCst), 1);
assert!(!Arc::ptr_eq(&stale, &snapshots[0]));
assert!(Arc::ptr_eq(&snapshots[0], &snapshots[1]));
assert!(Arc::ptr_eq(&snapshots[1], &snapshots[2]));
}
/// If the cached snapshot is swapped while a validation is in flight, the acquire must
/// discard its result and return the replacement snapshot instead.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn handle_retries_if_cache_changes_while_validating() {
let old_health = BlockingHealthServer::start(1).await;
let original = Arc::new(external_server(&old_health.base_url));
let handle = Arc::new(OrchestratorServerHandle {
state: AsyncMutex::new(HandleState::Ready {
snapshot: Arc::clone(&original),
mode: RecoveryMode::External,
}),
});
let replacement_mock = health_mock_server().await;
let replacement = Arc::new(external_server(&replacement_mock.uri()));
let acquire = {
let handle = Arc::clone(&handle);
tokio::spawn(async move {
handle
.acquire_or_recover_with(|| async { anyhow::bail!("should not rebuild") })
.await
})
};
// Wait until the validation probe for the original snapshot is pending.
old_health.wait_for_requests(1).await;
{
// Being able to take the lock here shows the probe runs without holding it.
let mut state = tokio::time::timeout(Duration::from_millis(100), handle.state.lock())
.await
.expect("validation should not hold the handle mutex");
*state = HandleState::Ready {
snapshot: Arc::clone(&replacement),
mode: RecoveryMode::External,
};
}
old_health.release();
let snapshot = acquire.await.unwrap().unwrap();
assert!(!Arc::ptr_eq(&snapshot, &original));
assert!(Arc::ptr_eq(&snapshot, &replacement));
}
/// Recovering from a stale managed snapshot yields a new snapshot while the
/// previously held `Arc` keeps reporting its original base URL.
#[tokio::test]
async fn handle_rebuilds_without_invalidating_held_snapshot() {
    let stale_url = "http://127.0.0.1:9";
    let stale = Arc::new(managed_server_with_exited_child(stale_url).await);
    let handle = OrchestratorServerHandle {
        state: AsyncMutex::new(HandleState::Ready {
            snapshot: Arc::clone(&stale),
            mode: RecoveryMode::Managed,
        }),
    };

    // `mock` stays bound so the replacement URL remains served during recovery.
    let mock = health_mock_server().await;
    let replacement_url = mock.uri();
    let rebuild_calls = Arc::new(AtomicUsize::new(0));

    let rebuilt = handle
        .get_or_recover_with(|| {
            let calls = Arc::clone(&rebuild_calls);
            let url = replacement_url.clone();
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok(external_server(&url))
            }
        })
        .await
        .unwrap();

    assert_eq!(rebuild_calls.load(Ordering::SeqCst), 1);
    assert!(!Arc::ptr_eq(&stale, &rebuilt));
    // The snapshot held from before the rebuild is untouched.
    assert_eq!(stale.base_url(), stale_url);
    assert_eq!(rebuilt.base_url(), replacement_url.trim_end_matches('/'));
}
/// The guard reports disabled when the variable is absent from the environment.
#[test]
#[serial(env)]
fn managed_guard_disabled_when_env_not_set() {
    // NOTE(review): presumably restores the variable's prior value on drop —
    // confirm against `ManagedEnvGuard`.
    let _restore_env = ManagedEnvGuard::new();
    // SAFETY: tests tagged `#[serial(env)]` do not run concurrently with each
    // other; assumes no untagged test touches the environment meanwhile.
    unsafe {
        std::env::remove_var(OPENCODE_ORCHESTRATOR_MANAGED_ENV);
    }
    assert!(!managed_guard_enabled());
}
/// The guard reports enabled when the variable is set to `"1"`.
#[test]
#[serial(env)]
fn managed_guard_enabled_when_env_is_1() {
    // NOTE(review): presumably restores the variable's prior value on drop —
    // confirm against `ManagedEnvGuard`.
    let _restore_env = ManagedEnvGuard::new();
    // SAFETY: tests tagged `#[serial(env)]` do not run concurrently with each
    // other; assumes no untagged test touches the environment meanwhile.
    unsafe {
        std::env::set_var(OPENCODE_ORCHESTRATOR_MANAGED_ENV, "1");
    }
    assert!(managed_guard_enabled());
}
/// The guard reports disabled when the variable is set to `"0"`.
#[test]
#[serial(env)]
fn managed_guard_disabled_when_env_is_0() {
    // NOTE(review): presumably restores the variable's prior value on drop —
    // confirm against `ManagedEnvGuard`.
    let _restore_env = ManagedEnvGuard::new();
    // SAFETY: tests tagged `#[serial(env)]` do not run concurrently with each
    // other; assumes no untagged test touches the environment meanwhile.
    unsafe {
        std::env::set_var(OPENCODE_ORCHESTRATOR_MANAGED_ENV, "0");
    }
    assert!(!managed_guard_enabled());
}
/// The guard reports disabled when the variable is set to the empty string.
#[test]
#[serial(env)]
fn managed_guard_disabled_when_env_is_empty() {
    // NOTE(review): presumably restores the variable's prior value on drop —
    // confirm against `ManagedEnvGuard`.
    let _restore_env = ManagedEnvGuard::new();
    // SAFETY: tests tagged `#[serial(env)]` do not run concurrently with each
    // other; assumes no untagged test touches the environment meanwhile.
    unsafe {
        std::env::set_var(OPENCODE_ORCHESTRATOR_MANAGED_ENV, "");
    }
    assert!(!managed_guard_enabled());
}
/// The guard reports disabled when the variable holds only a space character.
#[test]
#[serial(env)]
fn managed_guard_disabled_when_env_is_whitespace() {
    // NOTE(review): presumably restores the variable's prior value on drop —
    // confirm against `ManagedEnvGuard`.
    let _restore_env = ManagedEnvGuard::new();
    // SAFETY: tests tagged `#[serial(env)]` do not run concurrently with each
    // other; assumes no untagged test touches the environment meanwhile.
    unsafe {
        std::env::set_var(OPENCODE_ORCHESTRATOR_MANAGED_ENV, " ");
    }
    assert!(!managed_guard_enabled());
}
/// The guard reports enabled for a non-empty value other than `"0"`; here `"true"`.
#[test]
#[serial(env)]
fn managed_guard_enabled_when_env_is_truthy() {
    // NOTE(review): presumably restores the variable's prior value on drop —
    // confirm against `ManagedEnvGuard`.
    let _restore_env = ManagedEnvGuard::new();
    // SAFETY: tests tagged `#[serial(env)]` do not run concurrently with each
    // other; assumes no untagged test touches the environment meanwhile.
    unsafe {
        std::env::set_var(OPENCODE_ORCHESTRATOR_MANAGED_ENV, "true");
    }
    assert!(managed_guard_enabled());
}
/// With the managed env var set to `"1"`, a handle that already wraps a server
/// still hands it out, while acquiring on a brand-new handle fails with the
/// guard message.
#[tokio::test]
#[serial(env)]
async fn recursion_guard_only_blocks_real_startup_paths() {
    // NOTE(review): presumably restores the variable's prior value on drop —
    // confirm against `ManagedEnvGuard`.
    let _restore_env = ManagedEnvGuard::new();
    // SAFETY: tests tagged `#[serial(env)]` do not run concurrently with each
    // other; assumes no untagged test touches the environment meanwhile.
    unsafe {
        std::env::set_var(OPENCODE_ORCHESTRATOR_MANAGED_ENV, "1");
    }

    // Reuse path: the rebuild closure would fail the test if it were invoked.
    let mock = health_mock_server().await;
    let existing_handle =
        OrchestratorServerHandle::from_server_unshared(external_server(&mock.uri()));
    let reused = existing_handle
        .get_or_recover_with(|| async { anyhow::bail!("should not start") })
        .await
        .unwrap();
    let mock_url = mock.uri();
    assert_eq!(reused.base_url(), mock_url.trim_end_matches('/'));

    // Fresh-startup path: the acquire has to be rejected.
    let fresh_handle = OrchestratorServerHandle::new();
    let Err(err) = fresh_handle.acquire().await else {
        panic!("expected recursion guard to block fresh startup")
    };
    assert!(err.to_string().contains(ORCHESTRATOR_MANAGED_GUARD_MESSAGE));
}
/// Two consecutive acquires against an external server at `127.0.0.1:9` must
/// both fail with the same "External OpenCode server unavailable" error, and
/// neither may invoke the rebuild closure.
#[tokio::test]
async fn external_failure_becomes_sticky_and_typed() {
    // Assumes nothing answers health checks on port 9 — TODO confirm for CI hosts.
    let base_url = "http://127.0.0.1:9";
    let server = OrchestratorServer::from_client_unshared(
        test_client(base_url),
        base_url,
        RecoveryMode::External,
    );
    let handle = OrchestratorServerHandle::from_server_unshared(server);
    let starts = AtomicUsize::new(0);

    // Both attempts run before either result is inspected.
    let first_attempt = handle
        .acquire_or_recover_with(|| {
            starts.fetch_add(1, Ordering::SeqCst);
            async { anyhow::bail!("should not rebuild external servers") }
        })
        .await;
    let second_attempt = handle
        .acquire_or_recover_with(|| {
            starts.fetch_add(1, Ordering::SeqCst);
            async { anyhow::bail!("should not rebuild external servers") }
        })
        .await;

    let Err(first) = first_attempt else {
        panic!("expected typed external failure on first acquire")
    };
    let Err(second) = second_attempt else {
        panic!("expected sticky external failure on second acquire")
    };

    // The rebuild closure was never called.
    assert_eq!(starts.load(Ordering::SeqCst), 0);
    let first_text = first.to_string();
    assert!(
        first_text.contains("External OpenCode server unavailable"),
        "expected typed external failure, got: {first}"
    );
    assert_eq!(first_text, second.to_string());
}
}