1use std::collections::{BTreeMap, HashMap};
19use std::sync::{Arc, Mutex};
20use std::time::{Duration, Instant};
21
22use serde::{Deserialize, Serialize};
23use serde_json::Value as JsonValue;
24use sha2::{Digest, Sha256};
25
26use crate::mcp::{call_mcp_tool_with_hint, VmMcpClientHandle};
27use crate::mcp_protocol::McpCacheHint;
28use crate::mcp_registry::{self, RegisteredMcpServer};
29use crate::value::VmError;
30
/// Default cap on restart attempts inside one restart window; exceeding it
/// ejects the server (see `SupervisionState::record_restart`).
pub const DEFAULT_MAX_RESTARTS: u32 = 5;
/// Default length of the sliding window over which restart attempts are counted.
pub const DEFAULT_RESTART_WINDOW: Duration = Duration::from_mins(5);

/// Default number of consecutive failures that opens the circuit breaker.
pub const DEFAULT_CIRCUIT_THRESHOLD: u32 = 5;
/// Default time the breaker stays open before it is reported as half-open.
pub const DEFAULT_CIRCUIT_RESET: Duration = Duration::from_secs(30);

/// Backoff applied to the first restart attempt; it doubles with each further
/// attempt (see `SupervisionState::backoff_delay`).
pub const INITIAL_RESTART_BACKOFF: Duration = Duration::from_millis(100);
/// Upper bound applied to the computed restart backoff.
pub const MAX_RESTART_BACKOFF: Duration = Duration::from_secs(5);

/// Maximum number of cached responses kept per `(server, tool)` bucket.
pub const RESPONSE_CACHE_MAX_ENTRIES_PER_TOOL: usize = 64;
56
57#[derive(Debug)]
62struct SupervisionState {
63 restart_attempts: Vec<Instant>,
66 consecutive_failures: u32,
69 breaker_opens_until: Option<Instant>,
74 ejected: bool,
77 circuit_threshold: u32,
79 circuit_reset: Duration,
81 max_restarts: u32,
83 restart_window: Duration,
85}
86
87impl SupervisionState {
88 fn new(policy: SupervisionPolicy) -> Self {
89 Self {
90 restart_attempts: Vec::new(),
91 consecutive_failures: 0,
92 breaker_opens_until: None,
93 ejected: false,
94 circuit_threshold: policy.circuit_threshold,
95 circuit_reset: policy.circuit_reset,
96 max_restarts: policy.max_restarts,
97 restart_window: policy.restart_window,
98 }
99 }
100
101 fn breaker_state(&mut self, now: Instant) -> BreakerState {
105 match self.breaker_opens_until {
106 Some(deadline) if now < deadline => BreakerState::Open,
107 Some(_) => BreakerState::HalfOpen,
108 None => BreakerState::Closed,
109 }
110 }
111
112 fn record_success(&mut self) {
115 self.consecutive_failures = 0;
116 self.breaker_opens_until = None;
117 }
118
119 fn record_failure(&mut self, now: Instant) {
123 self.consecutive_failures = self.consecutive_failures.saturating_add(1);
124 if self.consecutive_failures >= self.circuit_threshold {
125 self.breaker_opens_until = Some(now + self.circuit_reset);
126 }
127 }
128
129 fn record_restart(&mut self, now: Instant) -> bool {
134 self.prune_restart_window(now);
135 self.restart_attempts.push(now);
136 if self.restart_attempts.len() as u32 > self.max_restarts {
137 self.ejected = true;
138 return false;
139 }
140 true
141 }
142
143 fn backoff_delay(&self) -> Duration {
146 let attempt = self.restart_attempts.len() as u32;
147 let exp = attempt.saturating_sub(1).min(6);
148 let mul = 1u64 << exp;
149 let nanos = INITIAL_RESTART_BACKOFF.as_nanos() as u64 * mul;
150 Duration::from_nanos(nanos).min(MAX_RESTART_BACKOFF)
151 }
152
153 fn prune_restart_window(&mut self, now: Instant) {
154 let window = self.restart_window;
155 self.restart_attempts
156 .retain(|t| now.duration_since(*t) <= window);
157 }
158
159 fn clear(&mut self) {
160 self.restart_attempts.clear();
161 self.consecutive_failures = 0;
162 self.breaker_opens_until = None;
163 self.ejected = false;
164 }
165}
166
/// Circuit-breaker state of a hosted server, as computed by
/// `SupervisionState::breaker_state`. Serialized in snake_case.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum BreakerState {
    /// No breaker deadline is set.
    Closed,
    /// The breaker tripped and its reset deadline has not been reached yet.
    Open,
    /// The reset deadline has passed but no success has been recorded since.
    HalfOpen,
}
175
176impl BreakerState {
177 pub fn as_str(self) -> &'static str {
178 match self {
179 BreakerState::Closed => "closed",
180 BreakerState::Open => "open",
181 BreakerState::HalfOpen => "half_open",
182 }
183 }
184}
185
/// Tunable supervision limits for one hosted server; copied into
/// `SupervisionState` when the server is spawned.
#[derive(Clone, Copy, Debug)]
pub struct SupervisionPolicy {
    /// Consecutive failures needed to open the circuit breaker.
    pub circuit_threshold: u32,
    /// How long the breaker stays open once tripped.
    pub circuit_reset: Duration,
    /// Restart attempts tolerated inside `restart_window` before ejection.
    pub max_restarts: u32,
    /// Sliding window over which restart attempts are counted.
    pub restart_window: Duration,
}
196
197impl Default for SupervisionPolicy {
198 fn default() -> Self {
199 Self {
200 circuit_threshold: DEFAULT_CIRCUIT_THRESHOLD,
201 circuit_reset: DEFAULT_CIRCUIT_RESET,
202 max_restarts: DEFAULT_MAX_RESTARTS,
203 restart_window: DEFAULT_RESTART_WINDOW,
204 }
205 }
206}
207
/// One cached tool response together with its freshness bookkeeping.
#[derive(Clone, Debug)]
struct CachedResponse {
    /// The tool result returned to callers on a cache hit.
    payload: JsonValue,
    /// When the entry was stored; used to pick the oldest entry for eviction.
    inserted_at: Instant,
    /// Instant from which the entry counts as expired.
    expires_at: Instant,
    /// Cache scope copied from the server's hint. It is stored but not read
    /// anywhere in this file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    scope: Option<&'static str>,
}
223
/// Verdict returned by an `AllowlistGuard`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AllowlistDecision {
    /// The server (or tool) may be used.
    Allow,
    /// The request is refused; `reason` is embedded in the error message
    /// returned to the caller.
    Deny { reason: String },
}
232
/// Shared allowlist callback. It receives the server name and, for tool
/// calls, the tool name (`None` when only the server itself is being checked).
pub type AllowlistGuard = Arc<dyn Fn(&str, Option<&str>) -> AllowlistDecision + Send + Sync>;
238
/// Snapshot of one hosted server, combining registry status with this
/// module's supervision and cache state (see `status`).
#[derive(Clone, Debug, Serialize)]
pub struct McpHostStatus {
    /// Server name as registered.
    pub name: String,
    /// `active` flag reported by the registry.
    pub active: bool,
    /// `lazy` flag reported by the registry.
    pub lazy: bool,
    /// Reference count reported by the registry.
    pub ref_count: usize,
    /// Number of restart attempts currently tracked for the server.
    pub restart_count: u32,
    /// Failures recorded since the last successful call.
    pub consecutive_failures: u32,
    /// Circuit-breaker state at the time of the snapshot.
    pub circuit: BreakerState,
    /// Whether the server exhausted its restart budget.
    pub ejected: bool,
    /// Total cached responses across all of the server's tools.
    pub cache_entries: usize,
}
255
/// Optional settings accepted by `spawn`. Every field may be omitted when
/// deserializing.
#[derive(Clone, Debug, Default, Deserialize)]
pub struct SpawnOptions {
    /// When `true`, `spawn` registers the server without activating it.
    #[serde(default)]
    pub lazy: bool,
    /// Keep-alive duration in milliseconds, forwarded to the registry.
    #[serde(default)]
    pub keep_alive_ms: Option<u64>,
    /// Optional card value forwarded to the registry unchanged.
    #[serde(default)]
    pub card: Option<String>,
    /// Overrides `DEFAULT_CIRCUIT_THRESHOLD`.
    #[serde(default)]
    pub circuit_threshold: Option<u32>,
    /// Overrides `DEFAULT_CIRCUIT_RESET`, in milliseconds.
    #[serde(default)]
    pub circuit_reset_ms: Option<u64>,
    /// Overrides `DEFAULT_MAX_RESTARTS`.
    #[serde(default)]
    pub max_restarts: Option<u32>,
    /// Overrides `DEFAULT_RESTART_WINDOW`, in milliseconds.
    #[serde(default)]
    pub restart_window_ms: Option<u64>,
}
284
285impl SpawnOptions {
286 fn into_policy(self) -> (SupervisionPolicy, RegisteredMcpServerMeta) {
287 let default = SupervisionPolicy::default();
288 let policy = SupervisionPolicy {
289 circuit_threshold: self.circuit_threshold.unwrap_or(default.circuit_threshold),
290 circuit_reset: self
291 .circuit_reset_ms
292 .map(Duration::from_millis)
293 .unwrap_or(default.circuit_reset),
294 max_restarts: self.max_restarts.unwrap_or(default.max_restarts),
295 restart_window: self
296 .restart_window_ms
297 .map(Duration::from_millis)
298 .unwrap_or(default.restart_window),
299 };
300 let meta = RegisteredMcpServerMeta {
301 lazy: self.lazy,
302 keep_alive: self.keep_alive_ms.map(Duration::from_millis),
303 card: self.card,
304 };
305 (policy, meta)
306 }
307}
308
/// Registry-facing part of `SpawnOptions`, produced by
/// `SpawnOptions::into_policy` and copied into `RegisteredMcpServer` by `spawn`.
struct RegisteredMcpServerMeta {
    /// Skip eager activation at spawn time.
    lazy: bool,
    /// Keep-alive duration converted from `keep_alive_ms`.
    keep_alive: Option<Duration>,
    /// Card value passed through unchanged.
    card: Option<String>,
}
314
/// Process-wide host state guarded by the `HOST` mutex.
struct HostInner {
    /// Supervision state keyed by server name.
    supervision: HashMap<String, SupervisionState>,
    /// Response cache: `(server, tool)` -> argument hash -> cached response.
    response_cache: HashMap<(String, String), HashMap<String, CachedResponse>>,
    /// Optional allowlist consulted by `spawn`, `call` and `discover`.
    allowlist: Option<AllowlistGuard>,
    /// Number of calls answered from the response cache.
    cache_hits: u64,
    /// Number of calls that were not answered from the response cache.
    cache_misses: u64,
}
333
334impl HostInner {
335 fn new() -> Self {
336 Self {
337 supervision: HashMap::new(),
338 response_cache: HashMap::new(),
339 allowlist: None,
340 cache_hits: 0,
341 cache_misses: 0,
342 }
343 }
344}
345
346static HOST: Mutex<Option<HostInner>> = Mutex::new(None);
347
348fn with_inner<F, R>(f: F) -> R
349where
350 F: FnOnce(&mut HostInner) -> R,
351{
352 let mut guard = HOST.lock().expect("mcp host mutex poisoned");
353 if guard.is_none() {
354 *guard = Some(HostInner::new());
355 }
356 f(guard.as_mut().expect("host inner just initialized"))
357}
358
359pub fn set_allowlist(guard: Option<AllowlistGuard>) {
361 with_inner(|inner| inner.allowlist = guard);
362}
363
364pub fn reset_for_tests() {
367 with_inner(|inner| {
368 inner.supervision.clear();
369 inner.response_cache.clear();
370 inner.allowlist = None;
371 inner.cache_hits = 0;
372 inner.cache_misses = 0;
373 });
374 mcp_registry::reset();
375}
376
/// Response-cache counters as returned by `cache_stats`.
#[derive(Clone, Copy, Debug)]
pub struct CacheStats {
    /// Calls answered from the cache.
    pub hits: u64,
    /// Calls that were not answered from the cache.
    pub misses: u64,
}
385
386pub fn cache_stats() -> CacheStats {
387 with_inner(|inner| CacheStats {
388 hits: inner.cache_hits,
389 misses: inner.cache_misses,
390 })
391}
392
393pub async fn spawn(spec: JsonValue, options: SpawnOptions) -> Result<String, VmError> {
397 let name = spec
398 .get("name")
399 .and_then(|v| v.as_str())
400 .ok_or_else(|| VmError::Runtime("mcp.spawn: spec must include a `name` field".into()))?
401 .to_string();
402 if name.is_empty() {
403 return Err(VmError::Runtime(
404 "mcp.spawn: spec.name must be a non-empty string".into(),
405 ));
406 }
407
408 if let Some(guard) = current_allowlist() {
409 if let AllowlistDecision::Deny { reason } = guard(&name, None) {
410 return Err(VmError::Runtime(format!(
411 "mcp.spawn({name}): denied by allowlist: {reason}"
412 )));
413 }
414 }
415
416 let (policy, meta) = options.into_policy();
417 mcp_registry::register_servers(vec![RegisteredMcpServer {
418 name: name.clone(),
419 spec: spec.clone(),
420 lazy: meta.lazy,
421 card: meta.card,
422 keep_alive: meta.keep_alive,
423 }]);
424
425 with_inner(|inner| {
426 inner
427 .supervision
428 .insert(name.clone(), SupervisionState::new(policy));
429 });
430
431 if !meta.lazy {
432 let _ = mcp_registry::ensure_active(&name).await.inspect_err(|_| {
435 with_inner(|inner| {
436 inner.supervision.remove(&name);
437 });
438 })?;
439 }
440
441 Ok(name)
442}
443
444pub fn stop(name: &str) -> Result<(), VmError> {
450 if !mcp_registry::is_registered(name) {
451 return Err(VmError::Runtime(format!(
452 "mcp.stop: no server named '{name}' is hosted"
453 )));
454 }
455 mcp_registry::release(name);
456 with_inner(|inner| {
457 inner.supervision.remove(name);
458 inner.response_cache.retain(|(s, _), _| s != name);
459 });
460 Ok(())
461}
462
463pub fn reload(name: &str) -> Result<(), VmError> {
468 if !mcp_registry::is_registered(name) {
469 return Err(VmError::Runtime(format!(
470 "mcp.reload: no server named '{name}' is hosted"
471 )));
472 }
473 mcp_registry::release(name);
474 with_inner(|inner| {
475 if let Some(state) = inner.supervision.get_mut(name) {
476 state.clear();
477 }
478 inner.response_cache.retain(|(s, _), _| s != name);
479 });
480 Ok(())
481}
482
483pub async fn tools(name: &str) -> Result<Vec<JsonValue>, VmError> {
488 let handle = ensure_or_restart(name).await?;
489 let result = supervised_call(name, || async {
490 handle.call("tools/list", serde_json::json!({})).await
491 })
492 .await?;
493
494 let mut tools = result
495 .get("tools")
496 .and_then(|t| t.as_array())
497 .cloned()
498 .unwrap_or_default();
499 for tool in tools.iter_mut() {
500 if let Some(obj) = tool.as_object_mut() {
501 obj.entry("_mcp_server")
502 .or_insert_with(|| JsonValue::String(name.to_string()));
503 }
504 }
505 let security_policy = crate::security::current_policy();
510 if security_policy.pin_mcp_schemas && !security_policy.server_is_trusted(name) {
511 for tool in tools.iter_mut() {
512 let hash = crate::security::tool_schema_hash(tool);
513 let tool_name = tool
514 .get("name")
515 .and_then(|v| v.as_str())
516 .unwrap_or_default()
517 .to_string();
518 if tool_name.is_empty() {
519 continue;
520 }
521 if crate::security::pin_and_detect_change(name, &tool_name, &hash) {
522 if let Some(obj) = tool.as_object_mut() {
523 obj.insert("_schema_changed".to_string(), JsonValue::Bool(true));
524 }
525 }
526 }
527 }
528 Ok(tools)
529}
530
531pub async fn call(name: &str, tool: &str, args: JsonValue) -> Result<JsonValue, VmError> {
538 if let Some(guard) = current_allowlist() {
539 if let AllowlistDecision::Deny { reason } = guard(name, Some(tool)) {
540 return Err(VmError::Runtime(format!(
541 "mcp.call({name}/{tool}): denied by allowlist: {reason}"
542 )));
543 }
544 }
545
546 crate::call_budget::charge_mcp_call()?;
551
552 let now = Instant::now();
553 let args_hash = hash_args(&args);
554 if let Some(payload) = take_cache_hit(name, tool, &args_hash, now) {
555 return Ok(payload);
556 }
557 with_inner(|inner| inner.cache_misses = inner.cache_misses.saturating_add(1));
558
559 breaker_gate(name, now)?;
560
561 let handle = ensure_or_restart(name).await?;
562 let envelope_hint: Arc<Mutex<Option<McpCacheHint>>> = Arc::new(Mutex::new(None));
567 let hint_slot = Arc::clone(&envelope_hint);
568 let result = supervised_call(name, move || {
569 let handle = handle.clone();
570 let tool = tool.to_string();
571 let args = args.clone();
572 let hint_slot = Arc::clone(&hint_slot);
573 async move {
574 let (content, hint) = call_mcp_tool_with_hint(&handle, &tool, args).await?;
575 if let Ok(mut slot) = hint_slot.lock() {
576 *slot = hint;
577 }
578 Ok(content)
579 }
580 })
581 .await?;
582
583 let hint = envelope_hint.lock().ok().and_then(|slot| *slot);
584 if let Some(hint) = hint {
585 insert_cache(name, tool, &args_hash, &result, hint, now);
586 }
587
588 Ok(result)
589}
590
591pub async fn discover() -> Result<Vec<JsonValue>, VmError> {
595 let names: Vec<String> = mcp_registry::snapshot_status()
596 .into_iter()
597 .map(|s| s.name)
598 .collect();
599 let mut out: Vec<JsonValue> = Vec::new();
600 for name in names {
601 if let Some(guard) = current_allowlist() {
604 if matches!(guard(&name, None), AllowlistDecision::Deny { .. }) {
605 continue;
606 }
607 }
608 match tools(&name).await {
612 Ok(tools) => {
613 for tool in tools {
614 let tool_name = tool
615 .get("name")
616 .and_then(|v| v.as_str())
617 .unwrap_or("")
618 .to_string();
619 out.push(serde_json::json!({
620 "server": name,
621 "tool": tool_name,
622 "schema": tool,
623 }));
624 }
625 }
626 Err(err) => {
627 out.push(serde_json::json!({
628 "server": name,
629 "error": err.to_string(),
630 }));
631 }
632 }
633 }
634 Ok(out)
635}
636
637pub fn status() -> Vec<McpHostStatus> {
639 let registry: BTreeMap<String, mcp_registry::RegistryStatus> = mcp_registry::snapshot_status()
640 .into_iter()
641 .map(|s| (s.name.clone(), s))
642 .collect();
643 with_inner(|inner| {
644 let mut out = Vec::new();
645 let now = Instant::now();
646 for (name, reg) in ®istry {
647 let (restart_count, consecutive_failures, circuit, ejected) =
648 if let Some(state) = inner.supervision.get_mut(name) {
649 let st = state.breaker_state(now);
650 (
651 state.restart_attempts.len() as u32,
652 state.consecutive_failures,
653 st,
654 state.ejected,
655 )
656 } else {
657 (0, 0, BreakerState::Closed, false)
658 };
659 let cache_entries = inner
660 .response_cache
661 .iter()
662 .filter(|((s, _), _)| s == name)
663 .map(|(_, v)| v.len())
664 .sum();
665 out.push(McpHostStatus {
666 name: name.clone(),
667 active: reg.active,
668 lazy: reg.lazy,
669 ref_count: reg.ref_count,
670 restart_count,
671 consecutive_failures,
672 circuit,
673 ejected,
674 cache_entries,
675 });
676 }
677 out
678 })
679}
680
681fn current_allowlist() -> Option<AllowlistGuard> {
682 with_inner(|inner| inner.allowlist.clone())
683}
684
685fn breaker_gate(name: &str, now: Instant) -> Result<(), VmError> {
686 with_inner(|inner| {
687 let Some(state) = inner.supervision.get_mut(name) else {
688 return Ok(());
689 };
690 if state.ejected {
691 return Err(VmError::Runtime(format!(
692 "mcp.call({name}): server is ejected after exhausting its restart budget; call `harn.mcp.reload({name:?})` to clear"
693 )));
694 }
695 match state.breaker_state(now) {
696 BreakerState::Open => Err(VmError::Runtime(format!(
697 "mcp.call({name}): circuit breaker is open (last {n} consecutive failures); retry after the breaker resets",
698 n = state.consecutive_failures
699 ))),
700 BreakerState::Closed | BreakerState::HalfOpen => Ok(()),
703 }
704 })
705}
706
707async fn ensure_or_restart(name: &str) -> Result<VmMcpClientHandle, VmError> {
708 if let Some(handle) = mcp_registry::active_handle(name) {
710 return Ok(handle);
711 }
712
713 mcp_registry::ensure_active(name).await
719}
720
721async fn supervised_call<F, Fut>(name: &str, op: F) -> Result<JsonValue, VmError>
725where
726 F: Fn() -> Fut,
727 Fut: std::future::Future<Output = Result<JsonValue, VmError>>,
728{
729 let span = tracing::info_span!(
730 "harn.mcp.call",
731 otel.name = "harn.mcp.call",
732 harn.mcp.server = name,
733 );
734 let _enter = span.enter();
735
736 let first = op().await;
737 match first {
738 Ok(v) => {
739 with_inner(|inner| {
740 if let Some(state) = inner.supervision.get_mut(name) {
741 state.record_success();
742 }
743 });
744 Ok(v)
745 }
746 Err(err) => {
747 let now = Instant::now();
748 let (should_retry, backoff) = with_inner(|inner| {
749 let Some(state) = inner.supervision.get_mut(name) else {
750 return (false, Duration::ZERO);
751 };
752 state.record_failure(now);
753 if !looks_like_transport_failure(&err) {
758 return (false, Duration::ZERO);
759 }
760 let ok = state.record_restart(now);
761 if !ok {
762 return (false, Duration::ZERO);
763 }
764 (true, state.backoff_delay())
765 });
766 if !should_retry {
767 tracing::warn!(
768 server = name,
769 error = %err,
770 "harn.mcp.call: failure (no retry)"
771 );
772 return Err(err);
773 }
774
775 tracing::info!(
776 server = name,
777 error = %err,
778 backoff_ms = backoff.as_millis() as u64,
779 "harn.mcp.call: retrying after transport failure"
780 );
781
782 mcp_registry::release(name);
785 tokio::time::sleep(backoff).await;
786 let _handle = ensure_or_restart(name).await?;
787 let second = op().await;
788 match &second {
789 Ok(_) => with_inner(|inner| {
790 if let Some(state) = inner.supervision.get_mut(name) {
791 state.record_success();
792 }
793 }),
794 Err(err) => with_inner(|inner| {
795 if let Some(state) = inner.supervision.get_mut(name) {
796 state.record_failure(Instant::now());
797 }
798 tracing::warn!(
799 server = name,
800 error = %err,
801 "harn.mcp.call: second attempt failed"
802 );
803 }),
804 }
805 second
806 }
807 }
808}
809
810fn looks_like_transport_failure(err: &VmError) -> bool {
811 let text = err.to_string();
812 let needles = [
813 "server closed connection",
814 "disconnected",
815 "MCP read error",
816 "MCP write error",
817 "did not respond to",
818 "MCP flush error",
819 "connect",
820 ];
821 needles.iter().any(|n| text.contains(n))
822}
823
824fn hash_args(args: &JsonValue) -> String {
825 let mut hasher = Sha256::new();
826 let canonical = canonicalize_json(args);
827 hasher.update(canonical.as_bytes());
828 let digest = hasher.finalize();
829 let mut hex = String::with_capacity(digest.len() * 2);
830 for byte in digest {
831 use std::fmt::Write;
832 let _ = write!(&mut hex, "{byte:02x}");
833 }
834 hex
835}
836
837fn canonicalize_json(value: &JsonValue) -> String {
840 match value {
841 JsonValue::Object(map) => {
842 let mut sorted: Vec<(&String, &JsonValue)> = map.iter().collect();
843 sorted.sort_by(|a, b| a.0.cmp(b.0));
844 let body: Vec<String> = sorted
845 .into_iter()
846 .map(|(k, v)| {
847 format!(
848 "{}:{}",
849 serde_json::to_string(k).unwrap_or_default(),
850 canonicalize_json(v)
851 )
852 })
853 .collect();
854 format!("{{{}}}", body.join(","))
855 }
856 JsonValue::Array(items) => {
857 let body: Vec<String> = items.iter().map(canonicalize_json).collect();
858 format!("[{}]", body.join(","))
859 }
860 other => serde_json::to_string(other).unwrap_or_default(),
861 }
862}
863
864fn take_cache_hit(server: &str, tool: &str, args_hash: &str, now: Instant) -> Option<JsonValue> {
865 with_inner(|inner| {
866 let key = (server.to_string(), tool.to_string());
867 let entry = inner.response_cache.get_mut(&key)?;
868 let cached = entry.get(args_hash)?;
869 if now >= cached.expires_at {
870 entry.remove(args_hash);
871 return None;
872 }
873 let payload = cached.payload.clone();
874 inner.cache_hits = inner.cache_hits.saturating_add(1);
875 Some(payload)
876 })
877}
878
879fn insert_cache(
884 server: &str,
885 tool: &str,
886 args_hash: &str,
887 payload: &JsonValue,
888 hint: McpCacheHint,
889 now: Instant,
890) {
891 let Some(ttl_ms) = hint.ttl_ms else {
892 return;
893 };
894 if ttl_ms == 0 {
895 return;
896 }
897 let expires_at = now + Duration::from_millis(ttl_ms);
898 let cached = CachedResponse {
899 payload: payload.clone(),
900 inserted_at: now,
901 expires_at,
902 scope: hint.scope,
903 };
904 with_inner(|inner| {
905 let key = (server.to_string(), tool.to_string());
906 let bucket = inner.response_cache.entry(key).or_default();
907 if bucket.len() >= RESPONSE_CACHE_MAX_ENTRIES_PER_TOOL {
908 if let Some(oldest_key) = bucket
910 .iter()
911 .min_by_key(|(_, v)| v.inserted_at)
912 .map(|(k, _)| k.clone())
913 {
914 bucket.remove(&oldest_key);
915 }
916 }
917 bucket.insert(args_hash.to_string(), cached);
918 });
919}
920
/// Test helper: caches `payload` only when `McpCacheHint::from_result` finds a
/// cache hint in it.
#[cfg(test)]
fn insert_cache_if_hinted(
    server: &str,
    tool: &str,
    args_hash: &str,
    payload: &JsonValue,
    now: Instant,
) {
    let Some(hint) = McpCacheHint::from_result(payload) else {
        return;
    };
    insert_cache(server, tool, args_hash, payload, hint, now);
}
936
/// Unit tests for the supervision state machine, argument hashing, the
/// response cache and the allowlist gate.
#[cfg(test)]
mod tests {
    use super::*;

    // Serializes tests that take it, since several of them reset and mutate
    // the process-wide `HOST` state.
    static TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires the test lock, recovering the guard if an earlier test
    /// panicked while holding it.
    fn lock() -> std::sync::MutexGuard<'static, ()> {
        TEST_LOCK.lock().unwrap_or_else(|p| p.into_inner())
    }

    /// The breaker stays closed below the threshold, opens at the threshold,
    /// and reports half-open once the reset interval has elapsed.
    #[test]
    fn supervision_breaker_opens_after_threshold() {
        let _g = lock();
        let mut state = SupervisionState::new(SupervisionPolicy {
            circuit_threshold: 3,
            circuit_reset: Duration::from_millis(100),
            ..SupervisionPolicy::default()
        });
        let t0 = Instant::now();
        assert_eq!(state.breaker_state(t0), BreakerState::Closed);
        state.record_failure(t0);
        state.record_failure(t0);
        assert_eq!(state.breaker_state(t0), BreakerState::Closed);
        state.record_failure(t0);
        assert_eq!(state.breaker_state(t0), BreakerState::Open);
        assert_eq!(
            state.breaker_state(t0 + Duration::from_millis(200)),
            BreakerState::HalfOpen
        );
    }

    /// With `max_restarts: 2`, the third restart inside the window is refused
    /// and ejects the server.
    #[test]
    fn supervision_restart_budget_ejects_after_n_attempts() {
        let _g = lock();
        let mut state = SupervisionState::new(SupervisionPolicy {
            max_restarts: 2,
            restart_window: Duration::from_mins(1),
            ..SupervisionPolicy::default()
        });
        let t = Instant::now();
        assert!(state.record_restart(t));
        assert!(state.record_restart(t));
        assert!(!state.record_restart(t));
        assert!(state.ejected);
    }

    /// Backoff grows with each recorded restart and never exceeds
    /// `MAX_RESTART_BACKOFF`.
    #[test]
    fn supervision_backoff_grows_exponentially_then_caps() {
        let _g = lock();
        let mut state = SupervisionState::new(SupervisionPolicy::default());
        let t = Instant::now();
        state.record_restart(t);
        let d1 = state.backoff_delay();
        state.record_restart(t);
        let d2 = state.backoff_delay();
        state.record_restart(t);
        let d3 = state.backoff_delay();
        assert!(
            d2 > d1,
            "second backoff ({d2:?}) should exceed first ({d1:?})"
        );
        assert!(d3 > d2);
        // Many more attempts must still respect the cap.
        for _ in 0..16 {
            state.record_restart(t);
        }
        assert!(state.backoff_delay() <= MAX_RESTART_BACKOFF);
    }

    /// Objects that differ only in key order canonicalize to the same string.
    #[test]
    fn canonical_json_sorts_object_keys() {
        let a = canonicalize_json(&serde_json::json!({"b": 1, "a": 2}));
        let b = canonicalize_json(&serde_json::json!({"a": 2, "b": 1}));
        assert_eq!(a, b);
    }

    /// The argument hash does not depend on object key order.
    #[test]
    fn hash_args_is_stable_across_key_order() {
        let h1 = hash_args(&serde_json::json!({"x": 1, "y": [1, 2]}));
        let h2 = hash_args(&serde_json::json!({"y": [1, 2], "x": 1}));
        assert_eq!(h1, h2);
    }

    /// A hinted payload is served while fresh and misses after its TTL.
    #[test]
    fn cache_insert_and_take_respects_ttl() {
        let _g = lock();
        reset_for_tests();
        let payload = serde_json::json!({
            "ttlMs": 100,
            "cacheScope": "private",
            "value": 1
        });
        let now = Instant::now();
        insert_cache_if_hinted("srv", "ping", "deadbeef", &payload, now);
        let hit = take_cache_hit("srv", "ping", "deadbeef", now);
        assert!(hit.is_some(), "fresh entry should hit");
        let stale = take_cache_hit("srv", "ping", "deadbeef", now + Duration::from_millis(200));
        assert!(stale.is_none(), "expired entry should miss");
    }

    /// A payload without a cache hint is never stored.
    #[test]
    fn cache_skips_payload_without_hint() {
        let _g = lock();
        reset_for_tests();
        insert_cache_if_hinted(
            "srv",
            "ping",
            "h",
            &serde_json::json!({"value": 1}),
            Instant::now(),
        );
        assert!(take_cache_hit("srv", "ping", "h", Instant::now()).is_none());
    }

    /// `call` is rejected up front when the allowlist denies the tool.
    #[test]
    fn allowlist_denies_disallowed_tool() {
        let _g = lock();
        reset_for_tests();
        set_allowlist(Some(Arc::new(|server, tool| {
            if server == "github" && tool == Some("delete_repo") {
                AllowlistDecision::Deny {
                    reason: "destructive tool blocked".into(),
                }
            } else {
                AllowlistDecision::Allow
            }
        })));
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let err = runtime
            .block_on(call("github", "delete_repo", serde_json::json!({})))
            .unwrap_err();
        assert!(err.to_string().contains("denied by allowlist"));
        // Leave no allowlist behind for other tests.
        set_allowlist(None);
    }

    /// Stopping a server that was never registered is an error.
    #[test]
    fn stop_unregistered_server_errors() {
        let _g = lock();
        reset_for_tests();
        let err = stop("nope").unwrap_err();
        assert!(err.to_string().contains("no server named 'nope'"));
    }

    /// A success clears the failure counter and any breaker deadline.
    #[test]
    fn supervision_record_success_resets_counters() {
        let _g = lock();
        let mut state = SupervisionState::new(SupervisionPolicy::default());
        let t = Instant::now();
        state.record_failure(t);
        state.record_failure(t);
        state.record_success();
        assert_eq!(state.consecutive_failures, 0);
        assert!(state.breaker_opens_until.is_none());
    }

    /// Typical transport error messages are classified as such, while a
    /// tool-level rejection is not.
    #[test]
    fn looks_like_transport_failure_matches_common_errors() {
        let cases = [
            "MCP: server closed connection",
            "MCP: server did not respond to 'tools/call' within 60s",
            "MCP write error: broken pipe",
            "MCP client is disconnected",
        ];
        for msg in cases {
            assert!(
                looks_like_transport_failure(&VmError::Runtime(msg.into())),
                "expected {msg:?} to be classified as transport failure"
            );
        }
        assert!(
            !looks_like_transport_failure(&VmError::Runtime(
                "tool 'foo' rejected arguments".into()
            )),
            "tool-level errors must not trigger an auto-restart"
        );
    }
}