1use std::collections::{BTreeMap, HashMap};
19use std::sync::{Arc, Mutex};
20use std::time::{Duration, Instant};
21
22use serde::{Deserialize, Serialize};
23use serde_json::Value as JsonValue;
24use sha2::{Digest, Sha256};
25
26use crate::mcp::{call_mcp_tool_with_hint, VmMcpClientHandle};
27use crate::mcp_protocol::McpCacheHint;
28use crate::mcp_registry::{self, RegisteredMcpServer};
29use crate::value::VmError;
30
/// Restarts tolerated inside [`DEFAULT_RESTART_WINDOW`]; one more ejects the server.
pub const DEFAULT_MAX_RESTARTS: u32 = 5;
/// Sliding window (five minutes) over which restart attempts are counted.
pub const DEFAULT_RESTART_WINDOW: Duration = Duration::from_secs(5 * 60);

/// Consecutive failures after which the circuit breaker opens.
pub const DEFAULT_CIRCUIT_THRESHOLD: u32 = 5;
/// How long an opened breaker stays open before a half-open probe is allowed.
pub const DEFAULT_CIRCUIT_RESET: Duration = Duration::from_secs(30);

/// Backoff before the first automatic restart; doubled for each further
/// restart recorded in the window.
pub const INITIAL_RESTART_BACKOFF: Duration = Duration::from_millis(100);
/// Upper bound on the restart backoff.
pub const MAX_RESTART_BACKOFF: Duration = Duration::from_secs(5);

/// Maximum cached responses kept per `(server, tool)` pair before eviction.
pub const RESPONSE_CACHE_MAX_ENTRIES_PER_TOOL: usize = 64;
56
/// Per-server supervision bookkeeping: circuit-breaker state plus the restart
/// budget, configured from a [`SupervisionPolicy`].
#[derive(Debug)]
struct SupervisionState {
    /// Timestamps of recorded restarts; pruned to `restart_window` each time a
    /// new restart is recorded.
    restart_attempts: Vec<Instant>,
    /// Failures since the last success; reset by `record_success` / `clear`.
    consecutive_failures: u32,
    /// When set, the breaker is open until this instant and half-open after it.
    breaker_opens_until: Option<Instant>,
    /// Set once the restart budget is exhausted; reset by `clear`.
    ejected: bool,
    /// Consecutive failures that open the breaker.
    circuit_threshold: u32,
    /// How long the breaker stays open once tripped.
    circuit_reset: Duration,
    /// Restarts allowed within `restart_window` before the server is ejected.
    max_restarts: u32,
    /// Sliding window used when counting restarts.
    restart_window: Duration,
}
86
87impl SupervisionState {
88 fn new(policy: SupervisionPolicy) -> Self {
89 Self {
90 restart_attempts: Vec::new(),
91 consecutive_failures: 0,
92 breaker_opens_until: None,
93 ejected: false,
94 circuit_threshold: policy.circuit_threshold,
95 circuit_reset: policy.circuit_reset,
96 max_restarts: policy.max_restarts,
97 restart_window: policy.restart_window,
98 }
99 }
100
101 fn breaker_state(&mut self, now: Instant) -> BreakerState {
105 match self.breaker_opens_until {
106 Some(deadline) if now < deadline => BreakerState::Open,
107 Some(_) => BreakerState::HalfOpen,
108 None => BreakerState::Closed,
109 }
110 }
111
112 fn record_success(&mut self) {
115 self.consecutive_failures = 0;
116 self.breaker_opens_until = None;
117 }
118
119 fn record_failure(&mut self, now: Instant) {
123 self.consecutive_failures = self.consecutive_failures.saturating_add(1);
124 if self.consecutive_failures >= self.circuit_threshold {
125 self.breaker_opens_until = Some(now + self.circuit_reset);
126 }
127 }
128
129 fn record_restart(&mut self, now: Instant) -> bool {
134 self.prune_restart_window(now);
135 self.restart_attempts.push(now);
136 if self.restart_attempts.len() as u32 > self.max_restarts {
137 self.ejected = true;
138 return false;
139 }
140 true
141 }
142
143 fn backoff_delay(&self) -> Duration {
146 let attempt = self.restart_attempts.len() as u32;
147 let exp = attempt.saturating_sub(1).min(6);
148 let mul = 1u64 << exp;
149 let nanos = INITIAL_RESTART_BACKOFF.as_nanos() as u64 * mul;
150 Duration::from_nanos(nanos).min(MAX_RESTART_BACKOFF)
151 }
152
153 fn prune_restart_window(&mut self, now: Instant) {
154 let window = self.restart_window;
155 self.restart_attempts
156 .retain(|t| now.duration_since(*t) <= window);
157 }
158
159 fn clear(&mut self) {
160 self.restart_attempts.clear();
161 self.consecutive_failures = 0;
162 self.breaker_opens_until = None;
163 self.ejected = false;
164 }
165}
166
/// Circuit-breaker state of a supervised server, as reported by `status()` and
/// checked before each `call`. Serialized in `snake_case` (e.g. `"half_open"`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum BreakerState {
    /// No failure streak has tripped the breaker; calls go through.
    Closed,
    /// Tripped and still inside the reset period; `call` is rejected.
    Open,
    /// The reset period has elapsed. Calls are let through again; a success
    /// closes the breaker and a failure re-opens it.
    HalfOpen,
}
175
176impl BreakerState {
177 pub fn as_str(self) -> &'static str {
178 match self {
179 BreakerState::Closed => "closed",
180 BreakerState::Open => "open",
181 BreakerState::HalfOpen => "half_open",
182 }
183 }
184}
185
/// Tunable supervision limits for one hosted server. Built from
/// [`SpawnOptions`]; unset options fall back to [`SupervisionPolicy::default`].
#[derive(Clone, Copy, Debug)]
pub struct SupervisionPolicy {
    /// Consecutive failures that trip the circuit breaker.
    pub circuit_threshold: u32,
    /// How long a tripped breaker stays open before allowing a half-open probe.
    pub circuit_reset: Duration,
    /// Restarts tolerated within `restart_window`; one more ejects the server.
    pub max_restarts: u32,
    /// Sliding window over which restarts are counted.
    pub restart_window: Duration,
}
196
197impl Default for SupervisionPolicy {
198 fn default() -> Self {
199 Self {
200 circuit_threshold: DEFAULT_CIRCUIT_THRESHOLD,
201 circuit_reset: DEFAULT_CIRCUIT_RESET,
202 max_restarts: DEFAULT_MAX_RESTARTS,
203 restart_window: DEFAULT_RESTART_WINDOW,
204 }
205 }
206}
207
/// One cached tool response plus the timestamps used for expiry and eviction.
#[derive(Clone, Debug)]
struct CachedResponse {
    /// The tool result handed back to callers on a cache hit.
    payload: JsonValue,
    /// When the entry was stored; a full bucket evicts its oldest entry first.
    inserted_at: Instant,
    /// The entry counts as expired at or after this instant.
    expires_at: Instant,
    // Copied from the server's cache hint, but nothing in this file reads it
    // yet, hence the `dead_code` allowance. NOTE(review): presumably kept for
    // future scope-aware caching; confirm before relying on it.
    #[allow(dead_code)]
    scope: Option<&'static str>,
}
223
/// Verdict returned by an [`AllowlistGuard`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AllowlistDecision {
    /// The server/tool may be used.
    Allow,
    /// The server/tool is blocked; `reason` is embedded in the resulting error.
    Deny { reason: String },
}
232
/// Caller-installed policy hook consulted before MCP servers are used.
///
/// In this module it is invoked as `guard(server, None)` when spawning a server
/// or listing it in `discover`. It is invoked as `guard(server, Some(tool))`
/// before each `call`. Because it is an `Arc`, the guard is cloned out of the
/// host lock and invoked without holding it.
pub type AllowlistGuard = Arc<dyn Fn(&str, Option<&str>) -> AllowlistDecision + Send + Sync>;
238
/// Snapshot of one registered MCP server, as returned by `status()`.
#[derive(Clone, Debug, Serialize)]
pub struct McpHostStatus {
    /// Registered server name.
    pub name: String,
    /// Transport name reported by the registry (e.g. `"http"`).
    pub transport: String,
    /// Endpoint URL, when the registry reports one.
    pub url: Option<String>,
    /// Whether the registry reports a live client for this server.
    pub active: bool,
    /// Whether the server was registered as lazy.
    pub lazy: bool,
    /// Reference count reported by the registry.
    pub ref_count: usize,
    /// Restart timestamps currently recorded by supervision. The list is only
    /// pruned to the window when a new restart is recorded.
    pub restart_count: u32,
    /// Current failure streak.
    pub consecutive_failures: u32,
    /// Breaker state at the time of the snapshot.
    pub circuit: BreakerState,
    /// Whether the server exhausted its restart budget.
    pub ejected: bool,
    /// Number of cached responses held for this server across all tools.
    pub cache_entries: usize,
    /// Identity string from the identity store. `status()` only fills this for
    /// active `"http"` servers with a URL.
    pub display_identity: Option<String>,
}
261
/// Options accepted by `spawn`. Every field is optional when deserializing.
#[derive(Clone, Debug, Default, Deserialize)]
pub struct SpawnOptions {
    /// When `true`, `spawn` only registers the server and does not start it.
    #[serde(default)]
    pub lazy: bool,
    /// Keep-alive in milliseconds, forwarded to the registry as a `Duration`.
    #[serde(default)]
    pub keep_alive_ms: Option<u64>,
    /// Opaque card string forwarded to the registry.
    #[serde(default)]
    pub card: Option<String>,
    /// Overrides [`SupervisionPolicy::circuit_threshold`].
    #[serde(default)]
    pub circuit_threshold: Option<u32>,
    /// Overrides [`SupervisionPolicy::circuit_reset`] (milliseconds).
    #[serde(default)]
    pub circuit_reset_ms: Option<u64>,
    /// Overrides [`SupervisionPolicy::max_restarts`].
    #[serde(default)]
    pub max_restarts: Option<u32>,
    /// Overrides [`SupervisionPolicy::restart_window`] (milliseconds).
    #[serde(default)]
    pub restart_window_ms: Option<u64>,
}
290
291impl SpawnOptions {
292 fn into_policy(self) -> (SupervisionPolicy, RegisteredMcpServerMeta) {
293 let default = SupervisionPolicy::default();
294 let policy = SupervisionPolicy {
295 circuit_threshold: self.circuit_threshold.unwrap_or(default.circuit_threshold),
296 circuit_reset: self
297 .circuit_reset_ms
298 .map(Duration::from_millis)
299 .unwrap_or(default.circuit_reset),
300 max_restarts: self.max_restarts.unwrap_or(default.max_restarts),
301 restart_window: self
302 .restart_window_ms
303 .map(Duration::from_millis)
304 .unwrap_or(default.restart_window),
305 };
306 let meta = RegisteredMcpServerMeta {
307 lazy: self.lazy,
308 keep_alive: self.keep_alive_ms.map(Duration::from_millis),
309 card: self.card,
310 };
311 (policy, meta)
312 }
313}
314
/// Registry-facing half of [`SpawnOptions`], produced by `into_policy`. `spawn`
/// copies these fields into the `RegisteredMcpServer` it registers.
struct RegisteredMcpServerMeta {
    /// Register without starting the server.
    lazy: bool,
    /// Keep-alive duration forwarded to the registry.
    keep_alive: Option<Duration>,
    /// Opaque card forwarded to the registry.
    card: Option<String>,
}
320
/// Mutable state shared by all hosted MCP servers; lives inside [`HOST`].
struct HostInner {
    /// Supervision record per server name (inserted by `spawn`).
    supervision: HashMap<String, SupervisionState>,
    /// Cached tool responses, keyed by `(server, tool)` and then by the
    /// canonical-args hash.
    response_cache: HashMap<(String, String), HashMap<String, CachedResponse>>,
    /// Optional allowlist hook; `None` lets everything through.
    allowlist: Option<AllowlistGuard>,
    /// Calls answered from `response_cache`.
    cache_hits: u64,
    /// Calls that missed the cache (counted before the breaker check).
    cache_misses: u64,
}
339
340impl HostInner {
341 fn new() -> Self {
342 Self {
343 supervision: HashMap::new(),
344 response_cache: HashMap::new(),
345 allowlist: None,
346 cache_hits: 0,
347 cache_misses: 0,
348 }
349 }
350}
351
/// Process-wide host state. It starts as `None` and is created on first access
/// by `with_inner`, which is the only accessor in this module.
static HOST: Mutex<Option<HostInner>> = Mutex::new(None);
353
354fn with_inner<F, R>(f: F) -> R
355where
356 F: FnOnce(&mut HostInner) -> R,
357{
358 let mut guard = HOST.lock().expect("mcp host mutex poisoned");
359 if guard.is_none() {
360 *guard = Some(HostInner::new());
361 }
362 f(guard.as_mut().expect("host inner just initialized"))
363}
364
365pub fn set_allowlist(guard: Option<AllowlistGuard>) {
367 with_inner(|inner| inner.allowlist = guard);
368}
369
370pub fn reset_for_tests() {
373 with_inner(|inner| {
374 inner.supervision.clear();
375 inner.response_cache.clear();
376 inner.allowlist = None;
377 inner.cache_hits = 0;
378 inner.cache_misses = 0;
379 });
380 mcp_registry::reset();
381}
382
/// Response-cache counters since start-up (or the last `reset_for_tests`).
#[derive(Clone, Copy, Debug)]
pub struct CacheStats {
    /// Calls answered from the cache.
    pub hits: u64,
    /// Calls that had to reach the server (or were rejected after the lookup).
    pub misses: u64,
}
391
392pub fn cache_stats() -> CacheStats {
393 with_inner(|inner| CacheStats {
394 hits: inner.cache_hits,
395 misses: inner.cache_misses,
396 })
397}
398
399pub async fn spawn(spec: JsonValue, options: SpawnOptions) -> Result<String, VmError> {
403 let name = spec
404 .get("name")
405 .and_then(|v| v.as_str())
406 .ok_or_else(|| VmError::Runtime("mcp.spawn: spec must include a `name` field".into()))?
407 .to_string();
408 if name.is_empty() {
409 return Err(VmError::Runtime(
410 "mcp.spawn: spec.name must be a non-empty string".into(),
411 ));
412 }
413
414 if let Some(guard) = current_allowlist() {
415 if let AllowlistDecision::Deny { reason } = guard(&name, None) {
416 return Err(VmError::Runtime(format!(
417 "mcp.spawn({name}): denied by allowlist: {reason}"
418 )));
419 }
420 }
421
422 let (policy, meta) = options.into_policy();
423 mcp_registry::register_servers(vec![RegisteredMcpServer {
424 name: name.clone(),
425 spec: spec.clone(),
426 lazy: meta.lazy,
427 card: meta.card,
428 keep_alive: meta.keep_alive,
429 }]);
430
431 with_inner(|inner| {
432 inner
433 .supervision
434 .insert(name.clone(), SupervisionState::new(policy));
435 });
436
437 if !meta.lazy {
438 let _ = mcp_registry::ensure_active(&name).await.inspect_err(|_| {
441 with_inner(|inner| {
442 inner.supervision.remove(&name);
443 });
444 })?;
445 }
446
447 Ok(name)
448}
449
450pub fn stop(name: &str) -> Result<(), VmError> {
456 if !mcp_registry::is_registered(name) {
457 return Err(VmError::Runtime(format!(
458 "mcp.stop: no server named '{name}' is hosted"
459 )));
460 }
461 mcp_registry::release(name);
462 with_inner(|inner| {
463 inner.supervision.remove(name);
464 inner.response_cache.retain(|(s, _), _| s != name);
465 });
466 Ok(())
467}
468
469pub fn reload(name: &str) -> Result<(), VmError> {
474 if !mcp_registry::is_registered(name) {
475 return Err(VmError::Runtime(format!(
476 "mcp.reload: no server named '{name}' is hosted"
477 )));
478 }
479 mcp_registry::release(name);
480 with_inner(|inner| {
481 if let Some(state) = inner.supervision.get_mut(name) {
482 state.clear();
483 }
484 inner.response_cache.retain(|(s, _), _| s != name);
485 });
486 Ok(())
487}
488
489pub async fn tools(name: &str) -> Result<Vec<JsonValue>, VmError> {
494 let handle = ensure_or_restart(name).await?;
495 let result = supervised_call(name, || async {
496 handle.call("tools/list", serde_json::json!({})).await
497 })
498 .await?;
499
500 let mut tools = result
501 .get("tools")
502 .and_then(|t| t.as_array())
503 .cloned()
504 .unwrap_or_default();
505 for tool in tools.iter_mut() {
506 if let Some(obj) = tool.as_object_mut() {
507 obj.entry("_mcp_server")
508 .or_insert_with(|| JsonValue::String(name.to_string()));
509 }
510 }
511 let security_policy = crate::security::current_policy();
516 if security_policy.pin_mcp_schemas && !security_policy.server_is_trusted(name) {
517 for tool in tools.iter_mut() {
518 let hash = crate::security::tool_schema_hash(tool);
519 let tool_name = tool
520 .get("name")
521 .and_then(|v| v.as_str())
522 .unwrap_or_default()
523 .to_string();
524 if tool_name.is_empty() {
525 continue;
526 }
527 if crate::security::pin_and_detect_change(name, &tool_name, &hash) {
528 if let Some(obj) = tool.as_object_mut() {
529 obj.insert("_schema_changed".to_string(), JsonValue::Bool(true));
530 }
531 }
532 }
533 }
534 Ok(tools)
535}
536
537pub async fn call(name: &str, tool: &str, args: JsonValue) -> Result<JsonValue, VmError> {
544 if let Some(guard) = current_allowlist() {
545 if let AllowlistDecision::Deny { reason } = guard(name, Some(tool)) {
546 return Err(VmError::Runtime(format!(
547 "mcp.call({name}/{tool}): denied by allowlist: {reason}"
548 )));
549 }
550 }
551
552 crate::call_budget::charge_mcp_call()?;
557
558 let now = Instant::now();
559 let args_hash = hash_args(&args);
560 if let Some(payload) = take_cache_hit(name, tool, &args_hash, now) {
561 return Ok(payload);
562 }
563 with_inner(|inner| inner.cache_misses = inner.cache_misses.saturating_add(1));
564
565 breaker_gate(name, now)?;
566
567 let handle = ensure_or_restart(name).await?;
568 let envelope_hint: Arc<Mutex<Option<McpCacheHint>>> = Arc::new(Mutex::new(None));
573 let hint_slot = Arc::clone(&envelope_hint);
574 let result = supervised_call(name, move || {
575 let handle = handle.clone();
576 let tool = tool.to_string();
577 let args = args.clone();
578 let hint_slot = Arc::clone(&hint_slot);
579 async move {
580 let (content, hint) = call_mcp_tool_with_hint(&handle, &tool, args).await?;
581 if let Ok(mut slot) = hint_slot.lock() {
582 *slot = hint;
583 }
584 Ok(content)
585 }
586 })
587 .await?;
588
589 let hint = envelope_hint.lock().ok().and_then(|slot| *slot);
590 if let Some(hint) = hint {
591 insert_cache(name, tool, &args_hash, &result, hint, now);
592 }
593
594 Ok(result)
595}
596
597pub async fn discover() -> Result<Vec<JsonValue>, VmError> {
601 let names: Vec<String> = mcp_registry::snapshot_status()
602 .into_iter()
603 .map(|s| s.name)
604 .collect();
605 let mut out: Vec<JsonValue> = Vec::new();
606 for name in names {
607 if let Some(guard) = current_allowlist() {
610 if matches!(guard(&name, None), AllowlistDecision::Deny { .. }) {
611 continue;
612 }
613 }
614 match tools(&name).await {
618 Ok(tools) => {
619 for tool in tools {
620 let tool_name = tool
621 .get("name")
622 .and_then(|v| v.as_str())
623 .unwrap_or("")
624 .to_string();
625 out.push(serde_json::json!({
626 "server": name,
627 "tool": tool_name,
628 "schema": tool,
629 }));
630 }
631 }
632 Err(err) => {
633 out.push(serde_json::json!({
634 "server": name,
635 "error": err.to_string(),
636 }));
637 }
638 }
639 }
640 Ok(out)
641}
642
643pub async fn status() -> Vec<McpHostStatus> {
645 let registry: BTreeMap<String, mcp_registry::RegistryStatus> = mcp_registry::snapshot_status()
646 .into_iter()
647 .map(|s| (s.name.clone(), s))
648 .collect();
649 let mut statuses = with_inner(|inner| {
650 let mut out = Vec::new();
651 let now = Instant::now();
652 for (name, reg) in ®istry {
653 let (restart_count, consecutive_failures, circuit, ejected) =
654 if let Some(state) = inner.supervision.get_mut(name) {
655 let st = state.breaker_state(now);
656 (
657 state.restart_attempts.len() as u32,
658 state.consecutive_failures,
659 st,
660 state.ejected,
661 )
662 } else {
663 (0, 0, BreakerState::Closed, false)
664 };
665 let cache_entries = inner
666 .response_cache
667 .iter()
668 .filter(|((s, _), _)| s == name)
669 .map(|(_, v)| v.len())
670 .sum();
671 out.push(McpHostStatus {
672 name: name.clone(),
673 transport: reg.transport.clone(),
674 url: reg.url.clone(),
675 active: reg.active,
676 lazy: reg.lazy,
677 ref_count: reg.ref_count,
678 restart_count,
679 consecutive_failures,
680 circuit,
681 ejected,
682 cache_entries,
683 display_identity: None,
684 });
685 }
686 out
687 });
688 for status in &mut statuses {
689 if !status.active || status.transport != "http" {
690 continue;
691 }
692 let Some(url) = status.url.as_deref() else {
693 continue;
694 };
695 status.display_identity = crate::mcp_identity::display_identity_from_store(url, None).await;
696 }
697 statuses
698}
699
700fn current_allowlist() -> Option<AllowlistGuard> {
701 with_inner(|inner| inner.allowlist.clone())
702}
703
704fn breaker_gate(name: &str, now: Instant) -> Result<(), VmError> {
705 with_inner(|inner| {
706 let Some(state) = inner.supervision.get_mut(name) else {
707 return Ok(());
708 };
709 if state.ejected {
710 return Err(VmError::Runtime(format!(
711 "mcp.call({name}): server is ejected after exhausting its restart budget; call `harn.mcp.reload({name:?})` to clear"
712 )));
713 }
714 match state.breaker_state(now) {
715 BreakerState::Open => Err(VmError::Runtime(format!(
716 "mcp.call({name}): circuit breaker is open (last {n} consecutive failures); retry after the breaker resets",
717 n = state.consecutive_failures
718 ))),
719 BreakerState::Closed | BreakerState::HalfOpen => Ok(()),
722 }
723 })
724}
725
726async fn ensure_or_restart(name: &str) -> Result<VmMcpClientHandle, VmError> {
727 if let Some(handle) = mcp_registry::active_handle(name) {
729 return Ok(handle);
730 }
731
732 mcp_registry::ensure_active(name).await
738}
739
740async fn supervised_call<F, Fut>(name: &str, op: F) -> Result<JsonValue, VmError>
744where
745 F: Fn() -> Fut,
746 Fut: std::future::Future<Output = Result<JsonValue, VmError>>,
747{
748 let span = tracing::info_span!(
749 "harn.mcp.call",
750 otel.name = "harn.mcp.call",
751 harn.mcp.server = name,
752 );
753 let _enter = span.enter();
754
755 let first = op().await;
756 match first {
757 Ok(v) => {
758 with_inner(|inner| {
759 if let Some(state) = inner.supervision.get_mut(name) {
760 state.record_success();
761 }
762 });
763 Ok(v)
764 }
765 Err(err) => {
766 let now = Instant::now();
767 let (should_retry, backoff) = with_inner(|inner| {
768 let Some(state) = inner.supervision.get_mut(name) else {
769 return (false, Duration::ZERO);
770 };
771 state.record_failure(now);
772 if !looks_like_transport_failure(&err) {
777 return (false, Duration::ZERO);
778 }
779 let ok = state.record_restart(now);
780 if !ok {
781 return (false, Duration::ZERO);
782 }
783 (true, state.backoff_delay())
784 });
785 if !should_retry {
786 tracing::warn!(
787 server = name,
788 error = %err,
789 "harn.mcp.call: failure (no retry)"
790 );
791 return Err(err);
792 }
793
794 tracing::info!(
795 server = name,
796 error = %err,
797 backoff_ms = backoff.as_millis() as u64,
798 "harn.mcp.call: retrying after transport failure"
799 );
800
801 mcp_registry::release(name);
804 tokio::time::sleep(backoff).await;
805 let _handle = ensure_or_restart(name).await?;
806 let second = op().await;
807 match &second {
808 Ok(_) => with_inner(|inner| {
809 if let Some(state) = inner.supervision.get_mut(name) {
810 state.record_success();
811 }
812 }),
813 Err(err) => with_inner(|inner| {
814 if let Some(state) = inner.supervision.get_mut(name) {
815 state.record_failure(Instant::now());
816 }
817 tracing::warn!(
818 server = name,
819 error = %err,
820 "harn.mcp.call: second attempt failed"
821 );
822 }),
823 }
824 second
825 }
826 }
827}
828
829fn looks_like_transport_failure(err: &VmError) -> bool {
830 let text = err.to_string();
831 let needles = [
832 "server closed connection",
833 "disconnected",
834 "MCP read error",
835 "MCP write error",
836 "did not respond to",
837 "MCP flush error",
838 "connect",
839 ];
840 needles.iter().any(|n| text.contains(n))
841}
842
843fn hash_args(args: &JsonValue) -> String {
844 let mut hasher = Sha256::new();
845 let canonical = canonicalize_json(args);
846 hasher.update(canonical.as_bytes());
847 let digest = hasher.finalize();
848 let mut hex = String::with_capacity(digest.len() * 2);
849 for byte in digest {
850 use std::fmt::Write;
851 let _ = write!(&mut hex, "{byte:02x}");
852 }
853 hex
854}
855
856fn canonicalize_json(value: &JsonValue) -> String {
859 match value {
860 JsonValue::Object(map) => {
861 let mut sorted: Vec<(&String, &JsonValue)> = map.iter().collect();
862 sorted.sort_by(|a, b| a.0.cmp(b.0));
863 let body: Vec<String> = sorted
864 .into_iter()
865 .map(|(k, v)| {
866 format!(
867 "{}:{}",
868 serde_json::to_string(k).unwrap_or_default(),
869 canonicalize_json(v)
870 )
871 })
872 .collect();
873 format!("{{{}}}", body.join(","))
874 }
875 JsonValue::Array(items) => {
876 let body: Vec<String> = items.iter().map(canonicalize_json).collect();
877 format!("[{}]", body.join(","))
878 }
879 other => serde_json::to_string(other).unwrap_or_default(),
880 }
881}
882
883fn take_cache_hit(server: &str, tool: &str, args_hash: &str, now: Instant) -> Option<JsonValue> {
884 with_inner(|inner| {
885 let key = (server.to_string(), tool.to_string());
886 let entry = inner.response_cache.get_mut(&key)?;
887 let cached = entry.get(args_hash)?;
888 if now >= cached.expires_at {
889 entry.remove(args_hash);
890 return None;
891 }
892 let payload = cached.payload.clone();
893 inner.cache_hits = inner.cache_hits.saturating_add(1);
894 Some(payload)
895 })
896}
897
898fn insert_cache(
903 server: &str,
904 tool: &str,
905 args_hash: &str,
906 payload: &JsonValue,
907 hint: McpCacheHint,
908 now: Instant,
909) {
910 let Some(ttl_ms) = hint.ttl_ms else {
911 return;
912 };
913 if ttl_ms == 0 {
914 return;
915 }
916 let expires_at = now + Duration::from_millis(ttl_ms);
917 let cached = CachedResponse {
918 payload: payload.clone(),
919 inserted_at: now,
920 expires_at,
921 scope: hint.scope,
922 };
923 with_inner(|inner| {
924 let key = (server.to_string(), tool.to_string());
925 let bucket = inner.response_cache.entry(key).or_default();
926 if bucket.len() >= RESPONSE_CACHE_MAX_ENTRIES_PER_TOOL {
927 if let Some(oldest_key) = bucket
929 .iter()
930 .min_by_key(|(_, v)| v.inserted_at)
931 .map(|(k, _)| k.clone())
932 {
933 bucket.remove(&oldest_key);
934 }
935 }
936 bucket.insert(args_hash.to_string(), cached);
937 });
938}
939
/// Test-only shortcut: derives the cache hint from `payload` itself via
/// `McpCacheHint::from_result`, and caches the payload only if a hint is found.
#[cfg(test)]
fn insert_cache_if_hinted(
    server: &str,
    tool: &str,
    args_hash: &str,
    payload: &JsonValue,
    now: Instant,
) {
    let Some(hint) = McpCacheHint::from_result(payload) else {
        return;
    };
    insert_cache(server, tool, args_hash, payload, hint, now);
}
955
/// Unit tests for supervision, caching, allowlisting and error classification.
#[cfg(test)]
mod tests {
    use super::*;

    // Serializes tests that touch the process-wide host state (`HOST`) or that
    // need a stable clock; tests run on parallel threads by default.
    static TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `TEST_LOCK`, ignoring poisoning left by a previously failed test.
    fn lock() -> std::sync::MutexGuard<'static, ()> {
        TEST_LOCK.lock().unwrap_or_else(|p| p.into_inner())
    }

    // The breaker stays closed below the threshold, opens at it, and goes
    // half-open once `circuit_reset` has elapsed.
    #[test]
    fn supervision_breaker_opens_after_threshold() {
        let _g = lock();
        let mut state = SupervisionState::new(SupervisionPolicy {
            circuit_threshold: 3,
            circuit_reset: Duration::from_millis(100),
            ..SupervisionPolicy::default()
        });
        let t0 = Instant::now();
        assert_eq!(state.breaker_state(t0), BreakerState::Closed);
        state.record_failure(t0);
        state.record_failure(t0);
        assert_eq!(state.breaker_state(t0), BreakerState::Closed);
        state.record_failure(t0);
        assert_eq!(state.breaker_state(t0), BreakerState::Open);
        assert_eq!(
            state.breaker_state(t0 + Duration::from_millis(200)),
            BreakerState::HalfOpen
        );
    }

    // With `max_restarts: 2`, the third restart inside the window ejects.
    #[test]
    fn supervision_restart_budget_ejects_after_n_attempts() {
        let _g = lock();
        let mut state = SupervisionState::new(SupervisionPolicy {
            max_restarts: 2,
            restart_window: Duration::from_mins(1),
            ..SupervisionPolicy::default()
        });
        let t = Instant::now();
        assert!(state.record_restart(t));
        assert!(state.record_restart(t));
        assert!(!state.record_restart(t));
        assert!(state.ejected);
    }

    // Backoff strictly grows for the first few restarts and never exceeds
    // `MAX_RESTART_BACKOFF`, even after many attempts.
    #[test]
    fn supervision_backoff_grows_exponentially_then_caps() {
        let _g = lock();
        let mut state = SupervisionState::new(SupervisionPolicy::default());
        let t = Instant::now();
        state.record_restart(t);
        let d1 = state.backoff_delay();
        state.record_restart(t);
        let d2 = state.backoff_delay();
        state.record_restart(t);
        let d3 = state.backoff_delay();
        assert!(
            d2 > d1,
            "second backoff ({d2:?}) should exceed first ({d1:?})"
        );
        assert!(d3 > d2);
        for _ in 0..16 {
            state.record_restart(t);
        }
        assert!(state.backoff_delay() <= MAX_RESTART_BACKOFF);
    }

    // Key order must not affect the canonical encoding.
    #[test]
    fn canonical_json_sorts_object_keys() {
        let a = canonicalize_json(&serde_json::json!({"b": 1, "a": 2}));
        let b = canonicalize_json(&serde_json::json!({"a": 2, "b": 1}));
        assert_eq!(a, b);
    }

    // ...and therefore must not affect the cache key either.
    #[test]
    fn hash_args_is_stable_across_key_order() {
        let h1 = hash_args(&serde_json::json!({"x": 1, "y": [1, 2]}));
        let h2 = hash_args(&serde_json::json!({"y": [1, 2], "x": 1}));
        assert_eq!(h1, h2);
    }

    // A hinted payload is served while fresh and dropped once its TTL has passed.
    #[test]
    fn cache_insert_and_take_respects_ttl() {
        let _g = lock();
        reset_for_tests();
        let payload = serde_json::json!({
            "ttlMs": 100,
            "cacheScope": "private",
            "value": 1
        });
        let now = Instant::now();
        insert_cache_if_hinted("srv", "ping", "deadbeef", &payload, now);
        let hit = take_cache_hit("srv", "ping", "deadbeef", now);
        assert!(hit.is_some(), "fresh entry should hit");
        let stale = take_cache_hit("srv", "ping", "deadbeef", now + Duration::from_millis(200));
        assert!(stale.is_none(), "expired entry should miss");
    }

    // Payloads without a cache hint are never cached.
    #[test]
    fn cache_skips_payload_without_hint() {
        let _g = lock();
        reset_for_tests();
        insert_cache_if_hinted(
            "srv",
            "ping",
            "h",
            &serde_json::json!({"value": 1}),
            Instant::now(),
        );
        assert!(take_cache_hit("srv", "ping", "h", Instant::now()).is_none());
    }

    // `call` consults the allowlist before any other work, so no server needs
    // to be registered for the denial to surface.
    #[test]
    fn allowlist_denies_disallowed_tool() {
        let _g = lock();
        reset_for_tests();
        set_allowlist(Some(Arc::new(|server, tool| {
            if server == "github" && tool == Some("delete_repo") {
                AllowlistDecision::Deny {
                    reason: "destructive tool blocked".into(),
                }
            } else {
                AllowlistDecision::Allow
            }
        })));
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let err = runtime
            .block_on(call("github", "delete_repo", serde_json::json!({})))
            .unwrap_err();
        assert!(err.to_string().contains("denied by allowlist"));
        set_allowlist(None);
    }

    // Stopping an unknown server is an error, not a no-op.
    #[test]
    fn stop_unregistered_server_errors() {
        let _g = lock();
        reset_for_tests();
        let err = stop("nope").unwrap_err();
        assert!(err.to_string().contains("no server named 'nope'"));
    }

    // A success clears the failure streak and any breaker deadline.
    #[test]
    fn supervision_record_success_resets_counters() {
        let _g = lock();
        let mut state = SupervisionState::new(SupervisionPolicy::default());
        let t = Instant::now();
        state.record_failure(t);
        state.record_failure(t);
        state.record_success();
        assert_eq!(state.consecutive_failures, 0);
        assert!(state.breaker_opens_until.is_none());
    }

    // Known transport error texts trigger a restart; tool-level errors do not.
    #[test]
    fn looks_like_transport_failure_matches_common_errors() {
        let cases = [
            "MCP: server closed connection",
            "MCP: server did not respond to 'tools/call' within 60s",
            "MCP write error: broken pipe",
            "MCP client is disconnected",
        ];
        for msg in cases {
            assert!(
                looks_like_transport_failure(&VmError::Runtime(msg.into())),
                "expected {msg:?} to be classified as transport failure"
            );
        }
        assert!(
            !looks_like_transport_failure(&VmError::Runtime(
                "tool 'foo' rejected arguments".into()
            )),
            "tool-level errors must not trigger an auto-restart"
        );
    }
}