1use std::collections::{BTreeMap, HashMap};
19use std::sync::{Arc, Mutex};
20use std::time::{Duration, Instant};
21
22use serde::{Deserialize, Serialize};
23use serde_json::Value as JsonValue;
24use sha2::{Digest, Sha256};
25
26use crate::mcp::{call_mcp_tool_with_hint, VmMcpClientHandle};
27use crate::mcp_protocol::McpCacheHint;
28use crate::mcp_registry::{self, RegisteredMcpServer};
29use crate::value::VmError;
30
/// Default restart budget: a server is ejected once more than this many
/// restart attempts fall inside one restart window.
pub const DEFAULT_MAX_RESTARTS: u32 = 5;
/// Default length of the sliding window over which restart attempts are counted.
pub const DEFAULT_RESTART_WINDOW: Duration = Duration::from_mins(5);

/// Default number of consecutive failures at which the circuit breaker opens.
pub const DEFAULT_CIRCUIT_THRESHOLD: u32 = 5;
/// Default time the breaker stays open before it reports half-open.
pub const DEFAULT_CIRCUIT_RESET: Duration = Duration::from_secs(30);

/// Backoff used for the first restart attempt; later attempts double it.
pub const INITIAL_RESTART_BACKOFF: Duration = Duration::from_millis(100);
/// Upper bound applied to the computed restart backoff.
pub const MAX_RESTART_BACKOFF: Duration = Duration::from_secs(5);

/// Maximum number of cached responses kept per `(server, tool)` bucket.
pub const RESPONSE_CACHE_MAX_ENTRIES_PER_TOOL: usize = 64;
56
57#[derive(Debug)]
62struct SupervisionState {
63 restart_attempts: Vec<Instant>,
66 consecutive_failures: u32,
69 breaker_opens_until: Option<Instant>,
74 ejected: bool,
77 circuit_threshold: u32,
79 circuit_reset: Duration,
81 max_restarts: u32,
83 restart_window: Duration,
85}
86
87impl SupervisionState {
88 fn new(policy: SupervisionPolicy) -> Self {
89 Self {
90 restart_attempts: Vec::new(),
91 consecutive_failures: 0,
92 breaker_opens_until: None,
93 ejected: false,
94 circuit_threshold: policy.circuit_threshold,
95 circuit_reset: policy.circuit_reset,
96 max_restarts: policy.max_restarts,
97 restart_window: policy.restart_window,
98 }
99 }
100
101 fn breaker_state(&mut self, now: Instant) -> BreakerState {
105 match self.breaker_opens_until {
106 Some(deadline) if now < deadline => BreakerState::Open,
107 Some(_) => BreakerState::HalfOpen,
108 None => BreakerState::Closed,
109 }
110 }
111
112 fn record_success(&mut self) {
115 self.consecutive_failures = 0;
116 self.breaker_opens_until = None;
117 }
118
119 fn record_failure(&mut self, now: Instant) {
123 self.consecutive_failures = self.consecutive_failures.saturating_add(1);
124 if self.consecutive_failures >= self.circuit_threshold {
125 self.breaker_opens_until = Some(now + self.circuit_reset);
126 }
127 }
128
129 fn record_restart(&mut self, now: Instant) -> bool {
134 self.prune_restart_window(now);
135 self.restart_attempts.push(now);
136 if self.restart_attempts.len() as u32 > self.max_restarts {
137 self.ejected = true;
138 return false;
139 }
140 true
141 }
142
143 fn backoff_delay(&self) -> Duration {
146 let attempt = self.restart_attempts.len() as u32;
147 let exp = attempt.saturating_sub(1).min(6);
148 let mul = 1u64 << exp;
149 let nanos = INITIAL_RESTART_BACKOFF.as_nanos() as u64 * mul;
150 Duration::from_nanos(nanos).min(MAX_RESTART_BACKOFF)
151 }
152
153 fn prune_restart_window(&mut self, now: Instant) {
154 let window = self.restart_window;
155 self.restart_attempts
156 .retain(|t| now.duration_since(*t) <= window);
157 }
158
159 fn clear(&mut self) {
160 self.restart_attempts.clear();
161 self.consecutive_failures = 0;
162 self.breaker_opens_until = None;
163 self.ejected = false;
164 }
165}
166
167#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize)]
169#[serde(rename_all = "snake_case")]
170pub enum BreakerState {
171 Closed,
172 Open,
173 HalfOpen,
174}
175
176impl BreakerState {
177 pub fn as_str(self) -> &'static str {
178 match self {
179 BreakerState::Closed => "closed",
180 BreakerState::Open => "open",
181 BreakerState::HalfOpen => "half_open",
182 }
183 }
184}
185
186#[derive(Clone, Copy, Debug)]
190pub struct SupervisionPolicy {
191 pub circuit_threshold: u32,
192 pub circuit_reset: Duration,
193 pub max_restarts: u32,
194 pub restart_window: Duration,
195}
196
197impl Default for SupervisionPolicy {
198 fn default() -> Self {
199 Self {
200 circuit_threshold: DEFAULT_CIRCUIT_THRESHOLD,
201 circuit_reset: DEFAULT_CIRCUIT_RESET,
202 max_restarts: DEFAULT_MAX_RESTARTS,
203 restart_window: DEFAULT_RESTART_WINDOW,
204 }
205 }
206}
207
/// One cached tool response together with its freshness bookkeeping.
#[derive(Clone, Debug)]
struct CachedResponse {
    /// The response value handed back on a cache hit.
    payload: JsonValue,
    /// Insertion time; the oldest entry is evicted when a bucket is full.
    inserted_at: Instant,
    /// Instant from which the entry is treated as expired.
    expires_at: Instant,
    // Stored from the cache hint but not read anywhere in this file, hence
    // the lint suppression.
    #[allow(dead_code)]
    scope: Option<&'static str>,
}
223
/// Verdict returned by an [`AllowlistGuard`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AllowlistDecision {
    /// The server (or server/tool pair) may be used.
    Allow,
    /// The request is rejected; `reason` is embedded in the error message.
    Deny { reason: String },
}

/// Shared allowlist callback. The first argument is the server name; the
/// second is the tool name, or `None` for server-level checks such as spawn
/// and discovery.
pub type AllowlistGuard = Arc<dyn Fn(&str, Option<&str>) -> AllowlistDecision + Send + Sync>;
238
/// Serializable snapshot of one hosted server, combining registry state with
/// this module's supervision and cache bookkeeping.
#[derive(Clone, Debug, Serialize)]
pub struct McpHostStatus {
    /// Server name as known to the registry.
    pub name: String,
    /// Copied from the registry status.
    pub active: bool,
    /// Copied from the registry status.
    pub lazy: bool,
    /// Copied from the registry status.
    pub ref_count: usize,
    /// Number of restart attempts currently recorded for the server.
    pub restart_count: u32,
    /// Failures seen since the last success.
    pub consecutive_failures: u32,
    /// Circuit-breaker state at the time of the snapshot.
    pub circuit: BreakerState,
    /// Whether the server exhausted its restart budget.
    pub ejected: bool,
    /// Total cached responses across all of the server's tools.
    pub cache_entries: usize,
}
255
256#[derive(Clone, Debug, Default, Deserialize)]
259pub struct SpawnOptions {
260 #[serde(default)]
264 pub lazy: bool,
265 #[serde(default)]
268 pub keep_alive_ms: Option<u64>,
269 #[serde(default)]
272 pub card: Option<String>,
273 #[serde(default)]
276 pub circuit_threshold: Option<u32>,
277 #[serde(default)]
278 pub circuit_reset_ms: Option<u64>,
279 #[serde(default)]
280 pub max_restarts: Option<u32>,
281 #[serde(default)]
282 pub restart_window_ms: Option<u64>,
283}
284
285impl SpawnOptions {
286 fn into_policy(self) -> (SupervisionPolicy, RegisteredMcpServerMeta) {
287 let default = SupervisionPolicy::default();
288 let policy = SupervisionPolicy {
289 circuit_threshold: self.circuit_threshold.unwrap_or(default.circuit_threshold),
290 circuit_reset: self
291 .circuit_reset_ms
292 .map(Duration::from_millis)
293 .unwrap_or(default.circuit_reset),
294 max_restarts: self.max_restarts.unwrap_or(default.max_restarts),
295 restart_window: self
296 .restart_window_ms
297 .map(Duration::from_millis)
298 .unwrap_or(default.restart_window),
299 };
300 let meta = RegisteredMcpServerMeta {
301 lazy: self.lazy,
302 keep_alive: self.keep_alive_ms.map(Duration::from_millis),
303 card: self.card,
304 };
305 (policy, meta)
306 }
307}
308
/// Registration metadata derived from `SpawnOptions`; each field is copied
/// into the `RegisteredMcpServer` handed to the registry by `spawn`.
struct RegisteredMcpServerMeta {
    /// When true, `spawn` does not start the server eagerly.
    lazy: bool,
    /// Keep-alive duration converted from `keep_alive_ms`.
    keep_alive: Option<Duration>,
    /// Passed through to the registry unchanged.
    card: Option<String>,
}
314
315struct HostInner {
318 supervision: HashMap<String, SupervisionState>,
322 response_cache: HashMap<(String, String), HashMap<String, CachedResponse>>,
326 allowlist: Option<AllowlistGuard>,
328 cache_hits: u64,
331 cache_misses: u64,
332}
333
334impl HostInner {
335 fn new() -> Self {
336 Self {
337 supervision: HashMap::new(),
338 response_cache: HashMap::new(),
339 allowlist: None,
340 cache_hits: 0,
341 cache_misses: 0,
342 }
343 }
344}
345
346static HOST: Mutex<Option<HostInner>> = Mutex::new(None);
347
348fn with_inner<F, R>(f: F) -> R
349where
350 F: FnOnce(&mut HostInner) -> R,
351{
352 let mut guard = HOST.lock().expect("mcp host mutex poisoned");
353 if guard.is_none() {
354 *guard = Some(HostInner::new());
355 }
356 f(guard.as_mut().expect("host inner just initialized"))
357}
358
359pub fn set_allowlist(guard: Option<AllowlistGuard>) {
361 with_inner(|inner| inner.allowlist = guard);
362}
363
364pub fn reset_for_tests() {
367 with_inner(|inner| {
368 inner.supervision.clear();
369 inner.response_cache.clear();
370 inner.allowlist = None;
371 inner.cache_hits = 0;
372 inner.cache_misses = 0;
373 });
374 mcp_registry::reset();
375}
376
377#[derive(Clone, Copy, Debug)]
381pub struct CacheStats {
382 pub hits: u64,
383 pub misses: u64,
384}
385
386pub fn cache_stats() -> CacheStats {
387 with_inner(|inner| CacheStats {
388 hits: inner.cache_hits,
389 misses: inner.cache_misses,
390 })
391}
392
393pub async fn spawn(spec: JsonValue, options: SpawnOptions) -> Result<String, VmError> {
397 let name = spec
398 .get("name")
399 .and_then(|v| v.as_str())
400 .ok_or_else(|| VmError::Runtime("mcp.spawn: spec must include a `name` field".into()))?
401 .to_string();
402 if name.is_empty() {
403 return Err(VmError::Runtime(
404 "mcp.spawn: spec.name must be a non-empty string".into(),
405 ));
406 }
407
408 if let Some(guard) = current_allowlist() {
409 if let AllowlistDecision::Deny { reason } = guard(&name, None) {
410 return Err(VmError::Runtime(format!(
411 "mcp.spawn({name}): denied by allowlist: {reason}"
412 )));
413 }
414 }
415
416 let (policy, meta) = options.into_policy();
417 mcp_registry::register_servers(vec![RegisteredMcpServer {
418 name: name.clone(),
419 spec: spec.clone(),
420 lazy: meta.lazy,
421 card: meta.card,
422 keep_alive: meta.keep_alive,
423 }]);
424
425 with_inner(|inner| {
426 inner
427 .supervision
428 .insert(name.clone(), SupervisionState::new(policy));
429 });
430
431 if !meta.lazy {
432 let _ = mcp_registry::ensure_active(&name).await.inspect_err(|_| {
435 with_inner(|inner| {
436 inner.supervision.remove(&name);
437 });
438 })?;
439 }
440
441 Ok(name)
442}
443
444pub fn stop(name: &str) -> Result<(), VmError> {
450 if !mcp_registry::is_registered(name) {
451 return Err(VmError::Runtime(format!(
452 "mcp.stop: no server named '{name}' is hosted"
453 )));
454 }
455 mcp_registry::release(name);
456 with_inner(|inner| {
457 inner.supervision.remove(name);
458 inner.response_cache.retain(|(s, _), _| s != name);
459 });
460 Ok(())
461}
462
463pub fn reload(name: &str) -> Result<(), VmError> {
468 if !mcp_registry::is_registered(name) {
469 return Err(VmError::Runtime(format!(
470 "mcp.reload: no server named '{name}' is hosted"
471 )));
472 }
473 mcp_registry::release(name);
474 with_inner(|inner| {
475 if let Some(state) = inner.supervision.get_mut(name) {
476 state.clear();
477 }
478 inner.response_cache.retain(|(s, _), _| s != name);
479 });
480 Ok(())
481}
482
483pub async fn tools(name: &str) -> Result<Vec<JsonValue>, VmError> {
488 let handle = ensure_or_restart(name).await?;
489 let result = supervised_call(name, || async {
490 handle.call("tools/list", serde_json::json!({})).await
491 })
492 .await?;
493
494 let mut tools = result
495 .get("tools")
496 .and_then(|t| t.as_array())
497 .cloned()
498 .unwrap_or_default();
499 for tool in tools.iter_mut() {
500 if let Some(obj) = tool.as_object_mut() {
501 obj.entry("_mcp_server")
502 .or_insert_with(|| JsonValue::String(name.to_string()));
503 }
504 }
505 Ok(tools)
506}
507
508pub async fn call(name: &str, tool: &str, args: JsonValue) -> Result<JsonValue, VmError> {
515 if let Some(guard) = current_allowlist() {
516 if let AllowlistDecision::Deny { reason } = guard(name, Some(tool)) {
517 return Err(VmError::Runtime(format!(
518 "mcp.call({name}/{tool}): denied by allowlist: {reason}"
519 )));
520 }
521 }
522
523 crate::call_budget::charge_mcp_call()?;
528
529 let now = Instant::now();
530 let args_hash = hash_args(&args);
531 if let Some(payload) = take_cache_hit(name, tool, &args_hash, now) {
532 return Ok(payload);
533 }
534 with_inner(|inner| inner.cache_misses = inner.cache_misses.saturating_add(1));
535
536 breaker_gate(name, now)?;
537
538 let handle = ensure_or_restart(name).await?;
539 let envelope_hint: Arc<Mutex<Option<McpCacheHint>>> = Arc::new(Mutex::new(None));
544 let hint_slot = Arc::clone(&envelope_hint);
545 let result = supervised_call(name, move || {
546 let handle = handle.clone();
547 let tool = tool.to_string();
548 let args = args.clone();
549 let hint_slot = Arc::clone(&hint_slot);
550 async move {
551 let (content, hint) = call_mcp_tool_with_hint(&handle, &tool, args).await?;
552 if let Ok(mut slot) = hint_slot.lock() {
553 *slot = hint;
554 }
555 Ok(content)
556 }
557 })
558 .await?;
559
560 let hint = envelope_hint.lock().ok().and_then(|slot| *slot);
561 if let Some(hint) = hint {
562 insert_cache(name, tool, &args_hash, &result, hint, now);
563 }
564
565 Ok(result)
566}
567
568pub async fn discover() -> Result<Vec<JsonValue>, VmError> {
572 let names: Vec<String> = mcp_registry::snapshot_status()
573 .into_iter()
574 .map(|s| s.name)
575 .collect();
576 let mut out: Vec<JsonValue> = Vec::new();
577 for name in names {
578 if let Some(guard) = current_allowlist() {
581 if matches!(guard(&name, None), AllowlistDecision::Deny { .. }) {
582 continue;
583 }
584 }
585 match tools(&name).await {
589 Ok(tools) => {
590 for tool in tools {
591 let tool_name = tool
592 .get("name")
593 .and_then(|v| v.as_str())
594 .unwrap_or("")
595 .to_string();
596 out.push(serde_json::json!({
597 "server": name,
598 "tool": tool_name,
599 "schema": tool,
600 }));
601 }
602 }
603 Err(err) => {
604 out.push(serde_json::json!({
605 "server": name,
606 "error": err.to_string(),
607 }));
608 }
609 }
610 }
611 Ok(out)
612}
613
614pub fn status() -> Vec<McpHostStatus> {
616 let registry: BTreeMap<String, mcp_registry::RegistryStatus> = mcp_registry::snapshot_status()
617 .into_iter()
618 .map(|s| (s.name.clone(), s))
619 .collect();
620 with_inner(|inner| {
621 let mut out = Vec::new();
622 let now = Instant::now();
623 for (name, reg) in ®istry {
624 let (restart_count, consecutive_failures, circuit, ejected) =
625 if let Some(state) = inner.supervision.get_mut(name) {
626 let st = state.breaker_state(now);
627 (
628 state.restart_attempts.len() as u32,
629 state.consecutive_failures,
630 st,
631 state.ejected,
632 )
633 } else {
634 (0, 0, BreakerState::Closed, false)
635 };
636 let cache_entries = inner
637 .response_cache
638 .iter()
639 .filter(|((s, _), _)| s == name)
640 .map(|(_, v)| v.len())
641 .sum();
642 out.push(McpHostStatus {
643 name: name.clone(),
644 active: reg.active,
645 lazy: reg.lazy,
646 ref_count: reg.ref_count,
647 restart_count,
648 consecutive_failures,
649 circuit,
650 ejected,
651 cache_entries,
652 });
653 }
654 out
655 })
656}
657
658fn current_allowlist() -> Option<AllowlistGuard> {
659 with_inner(|inner| inner.allowlist.clone())
660}
661
662fn breaker_gate(name: &str, now: Instant) -> Result<(), VmError> {
663 with_inner(|inner| {
664 let Some(state) = inner.supervision.get_mut(name) else {
665 return Ok(());
666 };
667 if state.ejected {
668 return Err(VmError::Runtime(format!(
669 "mcp.call({name}): server is ejected after exhausting its restart budget; call `harn.mcp.reload({name:?})` to clear"
670 )));
671 }
672 match state.breaker_state(now) {
673 BreakerState::Open => Err(VmError::Runtime(format!(
674 "mcp.call({name}): circuit breaker is open (last {n} consecutive failures); retry after the breaker resets",
675 n = state.consecutive_failures
676 ))),
677 BreakerState::Closed | BreakerState::HalfOpen => Ok(()),
680 }
681 })
682}
683
684async fn ensure_or_restart(name: &str) -> Result<VmMcpClientHandle, VmError> {
685 if let Some(handle) = mcp_registry::active_handle(name) {
687 return Ok(handle);
688 }
689
690 mcp_registry::ensure_active(name).await
696}
697
698async fn supervised_call<F, Fut>(name: &str, op: F) -> Result<JsonValue, VmError>
702where
703 F: Fn() -> Fut,
704 Fut: std::future::Future<Output = Result<JsonValue, VmError>>,
705{
706 let span = tracing::info_span!(
707 "harn.mcp.call",
708 otel.name = "harn.mcp.call",
709 harn.mcp.server = name,
710 );
711 let _enter = span.enter();
712
713 let first = op().await;
714 match first {
715 Ok(v) => {
716 with_inner(|inner| {
717 if let Some(state) = inner.supervision.get_mut(name) {
718 state.record_success();
719 }
720 });
721 Ok(v)
722 }
723 Err(err) => {
724 let now = Instant::now();
725 let (should_retry, backoff) = with_inner(|inner| {
726 let Some(state) = inner.supervision.get_mut(name) else {
727 return (false, Duration::ZERO);
728 };
729 state.record_failure(now);
730 if !looks_like_transport_failure(&err) {
735 return (false, Duration::ZERO);
736 }
737 let ok = state.record_restart(now);
738 if !ok {
739 return (false, Duration::ZERO);
740 }
741 (true, state.backoff_delay())
742 });
743 if !should_retry {
744 tracing::warn!(
745 server = name,
746 error = %err,
747 "harn.mcp.call: failure (no retry)"
748 );
749 return Err(err);
750 }
751
752 tracing::info!(
753 server = name,
754 error = %err,
755 backoff_ms = backoff.as_millis() as u64,
756 "harn.mcp.call: retrying after transport failure"
757 );
758
759 mcp_registry::release(name);
762 tokio::time::sleep(backoff).await;
763 let _handle = ensure_or_restart(name).await?;
764 let second = op().await;
765 match &second {
766 Ok(_) => with_inner(|inner| {
767 if let Some(state) = inner.supervision.get_mut(name) {
768 state.record_success();
769 }
770 }),
771 Err(err) => with_inner(|inner| {
772 if let Some(state) = inner.supervision.get_mut(name) {
773 state.record_failure(Instant::now());
774 }
775 tracing::warn!(
776 server = name,
777 error = %err,
778 "harn.mcp.call: second attempt failed"
779 );
780 }),
781 }
782 second
783 }
784 }
785}
786
787fn looks_like_transport_failure(err: &VmError) -> bool {
788 let text = err.to_string();
789 let needles = [
790 "server closed connection",
791 "disconnected",
792 "MCP read error",
793 "MCP write error",
794 "did not respond to",
795 "MCP flush error",
796 "connect",
797 ];
798 needles.iter().any(|n| text.contains(n))
799}
800
801fn hash_args(args: &JsonValue) -> String {
802 let mut hasher = Sha256::new();
803 let canonical = canonicalize_json(args);
804 hasher.update(canonical.as_bytes());
805 let digest = hasher.finalize();
806 let mut hex = String::with_capacity(digest.len() * 2);
807 for byte in digest {
808 use std::fmt::Write;
809 let _ = write!(&mut hex, "{byte:02x}");
810 }
811 hex
812}
813
814fn canonicalize_json(value: &JsonValue) -> String {
817 match value {
818 JsonValue::Object(map) => {
819 let mut sorted: Vec<(&String, &JsonValue)> = map.iter().collect();
820 sorted.sort_by(|a, b| a.0.cmp(b.0));
821 let body: Vec<String> = sorted
822 .into_iter()
823 .map(|(k, v)| {
824 format!(
825 "{}:{}",
826 serde_json::to_string(k).unwrap_or_default(),
827 canonicalize_json(v)
828 )
829 })
830 .collect();
831 format!("{{{}}}", body.join(","))
832 }
833 JsonValue::Array(items) => {
834 let body: Vec<String> = items.iter().map(canonicalize_json).collect();
835 format!("[{}]", body.join(","))
836 }
837 other => serde_json::to_string(other).unwrap_or_default(),
838 }
839}
840
841fn take_cache_hit(server: &str, tool: &str, args_hash: &str, now: Instant) -> Option<JsonValue> {
842 with_inner(|inner| {
843 let key = (server.to_string(), tool.to_string());
844 let entry = inner.response_cache.get_mut(&key)?;
845 let cached = entry.get(args_hash)?;
846 if now >= cached.expires_at {
847 entry.remove(args_hash);
848 return None;
849 }
850 let payload = cached.payload.clone();
851 inner.cache_hits = inner.cache_hits.saturating_add(1);
852 Some(payload)
853 })
854}
855
856fn insert_cache(
861 server: &str,
862 tool: &str,
863 args_hash: &str,
864 payload: &JsonValue,
865 hint: McpCacheHint,
866 now: Instant,
867) {
868 let Some(ttl_ms) = hint.ttl_ms else {
869 return;
870 };
871 if ttl_ms == 0 {
872 return;
873 }
874 let expires_at = now + Duration::from_millis(ttl_ms);
875 let cached = CachedResponse {
876 payload: payload.clone(),
877 inserted_at: now,
878 expires_at,
879 scope: hint.scope,
880 };
881 with_inner(|inner| {
882 let key = (server.to_string(), tool.to_string());
883 let bucket = inner.response_cache.entry(key).or_default();
884 if bucket.len() >= RESPONSE_CACHE_MAX_ENTRIES_PER_TOOL {
885 if let Some(oldest_key) = bucket
887 .iter()
888 .min_by_key(|(_, v)| v.inserted_at)
889 .map(|(k, _)| k.clone())
890 {
891 bucket.remove(&oldest_key);
892 }
893 }
894 bucket.insert(args_hash.to_string(), cached);
895 });
896}
897
/// Test helper: caches `payload` only when a cache hint can be extracted
/// from the payload itself.
#[cfg(test)]
fn insert_cache_if_hinted(
    server: &str,
    tool: &str,
    args_hash: &str,
    payload: &JsonValue,
    now: Instant,
) {
    let Some(hint) = McpCacheHint::from_result(payload) else {
        return;
    };
    insert_cache(server, tool, args_hash, payload, hint, now);
}
913
#[cfg(test)]
mod tests {
    use super::*;

    // Tests that touch the global host state take this lock so they do not
    // interleave with each other.
    static TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires the test lock, recovering it if an earlier test panicked
    /// while holding it.
    fn lock() -> std::sync::MutexGuard<'static, ()> {
        TEST_LOCK.lock().unwrap_or_else(|p| p.into_inner())
    }

    /// The breaker stays closed below the threshold, opens at it, and reports
    /// half-open once the reset period has elapsed.
    #[test]
    fn supervision_breaker_opens_after_threshold() {
        let _g = lock();
        let mut state = SupervisionState::new(SupervisionPolicy {
            circuit_threshold: 3,
            circuit_reset: Duration::from_millis(100),
            ..SupervisionPolicy::default()
        });
        let t0 = Instant::now();
        assert_eq!(state.breaker_state(t0), BreakerState::Closed);
        state.record_failure(t0);
        state.record_failure(t0);
        assert_eq!(state.breaker_state(t0), BreakerState::Closed);
        state.record_failure(t0);
        assert_eq!(state.breaker_state(t0), BreakerState::Open);
        // 200ms is past the 100ms reset, so the breaker is half-open.
        assert_eq!(
            state.breaker_state(t0 + Duration::from_millis(200)),
            BreakerState::HalfOpen
        );
    }

    /// With a budget of two restarts, the third attempt inside the window
    /// ejects the server.
    #[test]
    fn supervision_restart_budget_ejects_after_n_attempts() {
        let _g = lock();
        let mut state = SupervisionState::new(SupervisionPolicy {
            max_restarts: 2,
            restart_window: Duration::from_mins(1),
            ..SupervisionPolicy::default()
        });
        let t = Instant::now();
        assert!(state.record_restart(t));
        assert!(state.record_restart(t));
        assert!(!state.record_restart(t));
        assert!(state.ejected);
    }

    /// Backoff grows with each recorded restart and never exceeds the cap.
    #[test]
    fn supervision_backoff_grows_exponentially_then_caps() {
        let _g = lock();
        let mut state = SupervisionState::new(SupervisionPolicy::default());
        let t = Instant::now();
        state.record_restart(t);
        let d1 = state.backoff_delay();
        state.record_restart(t);
        let d2 = state.backoff_delay();
        state.record_restart(t);
        let d3 = state.backoff_delay();
        assert!(
            d2 > d1,
            "second backoff ({d2:?}) should exceed first ({d1:?})"
        );
        assert!(d3 > d2);
        for _ in 0..16 {
            state.record_restart(t);
        }
        assert!(state.backoff_delay() <= MAX_RESTART_BACKOFF);
    }

    /// Canonical form is independent of object key order.
    #[test]
    fn canonical_json_sorts_object_keys() {
        let a = canonicalize_json(&serde_json::json!({"b": 1, "a": 2}));
        let b = canonicalize_json(&serde_json::json!({"a": 2, "b": 1}));
        assert_eq!(a, b);
    }

    /// The argument hash is independent of object key order.
    #[test]
    fn hash_args_is_stable_across_key_order() {
        let h1 = hash_args(&serde_json::json!({"x": 1, "y": [1, 2]}));
        let h2 = hash_args(&serde_json::json!({"y": [1, 2], "x": 1}));
        assert_eq!(h1, h2);
    }

    /// A hinted payload hits while fresh and misses after its TTL.
    #[test]
    fn cache_insert_and_take_respects_ttl() {
        let _g = lock();
        reset_for_tests();
        let payload = serde_json::json!({
            "ttlMs": 100,
            "cacheScope": "private",
            "value": 1
        });
        let now = Instant::now();
        insert_cache_if_hinted("srv", "ping", "deadbeef", &payload, now);
        let hit = take_cache_hit("srv", "ping", "deadbeef", now);
        assert!(hit.is_some(), "fresh entry should hit");
        let stale = take_cache_hit("srv", "ping", "deadbeef", now + Duration::from_millis(200));
        assert!(stale.is_none(), "expired entry should miss");
    }

    /// A payload without a cache hint is never stored.
    #[test]
    fn cache_skips_payload_without_hint() {
        let _g = lock();
        reset_for_tests();
        insert_cache_if_hinted(
            "srv",
            "ping",
            "h",
            &serde_json::json!({"value": 1}),
            Instant::now(),
        );
        assert!(take_cache_hit("srv", "ping", "h", Instant::now()).is_none());
    }

    /// `call` is rejected by the allowlist before any server is contacted.
    #[test]
    fn allowlist_denies_disallowed_tool() {
        let _g = lock();
        reset_for_tests();
        set_allowlist(Some(Arc::new(|server, tool| {
            if server == "github" && tool == Some("delete_repo") {
                AllowlistDecision::Deny {
                    reason: "destructive tool blocked".into(),
                }
            } else {
                AllowlistDecision::Allow
            }
        })));
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let err = runtime
            .block_on(call("github", "delete_repo", serde_json::json!({})))
            .unwrap_err();
        assert!(err.to_string().contains("denied by allowlist"));
        // Leave no guard behind for other tests.
        set_allowlist(None);
    }

    /// Stopping an unknown server is an error that names the server.
    #[test]
    fn stop_unregistered_server_errors() {
        let _g = lock();
        reset_for_tests();
        let err = stop("nope").unwrap_err();
        assert!(err.to_string().contains("no server named 'nope'"));
    }

    /// A success clears the failure streak and the breaker deadline.
    #[test]
    fn supervision_record_success_resets_counters() {
        let _g = lock();
        let mut state = SupervisionState::new(SupervisionPolicy::default());
        let t = Instant::now();
        state.record_failure(t);
        state.record_failure(t);
        state.record_success();
        assert_eq!(state.consecutive_failures, 0);
        assert!(state.breaker_opens_until.is_none());
    }

    /// Known transport error messages are classified as such; a tool-level
    /// rejection is not.
    #[test]
    fn looks_like_transport_failure_matches_common_errors() {
        let cases = [
            "MCP: server closed connection",
            "MCP: server did not respond to 'tools/call' within 60s",
            "MCP write error: broken pipe",
            "MCP client is disconnected",
        ];
        for msg in cases {
            assert!(
                looks_like_transport_failure(&VmError::Runtime(msg.into())),
                "expected {msg:?} to be classified as transport failure"
            );
        }
        assert!(
            !looks_like_transport_failure(&VmError::Runtime(
                "tool 'foo' rejected arguments".into()
            )),
            "tool-level errors must not trigger an auto-restart"
        );
    }
}
1093}