#![deny(missing_docs)]
#![deny(rustdoc::broken_intra_doc_links)]
#[cfg(feature = "http")]
pub(crate) mod http;
pub(crate) mod outbound;
#[cfg(feature = "http")]
pub(crate) mod outbound_ring;
pub(crate) mod session;
pub(crate) mod workflow;
#[cfg(feature = "governor")]
pub(crate) mod governor;
#[cfg_attr(not(feature = "http"), allow(dead_code))]
pub(crate) mod resume_ticket;
pub mod outbound_sink;
pub use outbound_sink::{OutboundFrameSink, OutboundSinkError};
pub mod sampling;
pub use sampling::{
ModelHint, ModelPreferences, SamplingContent, SamplingMessage, SamplingRequest,
SamplingResponse,
};
pub mod roots;
pub use roots::Root;
pub mod outbound_ext;
pub use outbound_ext::McpOutboundExt;
#[cfg(all(feature = "http", feature = "test-fixtures"))]
pub use http::{handle_dead_leader_orphan_mcp, OrphanOutcome};
#[cfg(all(feature = "http", feature = "bench"))]
pub use http::encode_sse_frame;
#[cfg(feature = "bench")]
pub use outbound_sink::bench_stdio_sink;
/// Benchmark hook mirroring SSE replay filtering.
///
/// Returns, in input order, every entry whose id is strictly greater than
/// `since_id`; each payload `Arc` is cloned (a refcount bump, not a deep copy).
#[cfg(all(feature = "http", feature = "bench"))]
pub fn bench_filter_replay(
    entries: &[(u64, std::sync::Arc<serde_json::Value>)],
    since_id: u64,
) -> Vec<(u64, std::sync::Arc<serde_json::Value>)> {
    let mut replay = Vec::new();
    for (event_id, payload) in entries {
        if *event_id > since_id {
            replay.push((*event_id, std::sync::Arc::clone(payload)));
        }
    }
    replay
}
use async_trait::async_trait;
use klieo_core::agent::Agent;
use klieo_core::error::ToolError;
use klieo_core::llm::ToolDef;
use klieo_core::tool::{ToolCtx, ToolInvoker};
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
use tracing::warn;
/// JSON-RPC "parse error" code; sent when an inbound stdio frame is not valid JSON.
pub(crate) const JSONRPC_PARSE_ERROR: i64 = -32_700;
/// JSON-RPC "method not found" code.
const JSONRPC_METHOD_NOT_FOUND: i64 = -32_601;
/// JSON-RPC "invalid params" code.
#[cfg(feature = "http")]
pub(crate) const JSONRPC_INVALID_PARAMS: i64 = -32_602;
/// Generic server-side failure code.
pub(crate) const JSONRPC_SERVER_ERROR: i64 = -32_000;
/// The requested resume window is no longer buffered.
#[cfg(feature = "http")]
pub(crate) const JSONRPC_RESUME_BUFFER_EXPIRED: i64 = -32_011;
/// No buffered stream exists for the requested resume target.
#[cfg(feature = "http")]
pub(crate) const JSONRPC_RESUME_BUFFER_NOT_FOUND: i64 = -32_012;
/// The leader serving the request went away.
#[cfg(feature = "http")]
pub(crate) const JSONRPC_LEADER_DIED: i64 = -32_099;
/// Session-level conflict.
#[cfg(feature = "http")]
pub(crate) const JSONRPC_SESSION_CONFLICT: i64 = -32_002;
/// Caller is not authenticated.
#[cfg(feature = "http")]
pub(crate) const JSONRPC_UNAUTHENTICATED: i64 = -32_001;
// Compile-time guard: every JSONRPC_* code used by the HTTP transport must be
// distinct. A duplicate fails the build with "JSONRPC_* code collision".
#[cfg(feature = "http")]
const _: () = {
    const ALL_CODES: [i64; 9] = [
        JSONRPC_PARSE_ERROR,
        JSONRPC_METHOD_NOT_FOUND,
        JSONRPC_INVALID_PARAMS,
        JSONRPC_SERVER_ERROR,
        JSONRPC_UNAUTHENTICATED,
        JSONRPC_RESUME_BUFFER_EXPIRED,
        JSONRPC_RESUME_BUFFER_NOT_FOUND,
        JSONRPC_LEADER_DIED,
        JSONRPC_SESSION_CONFLICT,
    ];
    // Compare each code against every code that precedes it (const fn
    // contexts cannot use `for`, hence the manual `while` loops).
    let mut later = 1;
    while later < ALL_CODES.len() {
        let mut earlier = 0;
        while earlier < later {
            assert!(ALL_CODES[earlier] != ALL_CODES[later], "JSONRPC_* code collision");
            earlier += 1;
        }
        later += 1;
    }
};
/// Default leader-election TTL, used when `McpServerBuilder::with_leader_ttl`
/// is not called (five seconds).
pub const LEADER_TTL: std::time::Duration = std::time::Duration::from_millis(5_000);
/// Key prefix for MCP leader-election entries.
pub const MCP_LEADER_KEY_PREFIX: &str = "mcp.";
/// MCP protocol version identifier.
pub const MCP_PROTOCOL_VERSION: &str = "2025-03-26";
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum McpServerError {
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("json decode error: {0}")]
Json(#[from] serde_json::Error),
#[error("resume window expired (since_id={since_id})")]
ResumeBufferExpired {
since_id: u64,
},
#[error("no buffered stream for progressToken: {0}")]
ResumeBufferNotFound(String),
#[error("tool error: {0}")]
Tool(#[from] ToolError),
#[error("invalid subject token: {0}")]
InvalidSubject(#[source] klieo_core::BusError),
#[error("bus error: {0}")]
Bus(#[source] klieo_core::BusError),
#[error("outbound request timed out")]
OutboundTimeout,
#[error("client returned error: code={code} message={message}")]
ClientReturnedError {
code: i64,
message: String,
},
#[error("transport closed")]
TransportClosed,
#[error("outbound channel unsupported on this transport")]
OutboundUnsupported,
#[error("outbound serialisation failed: {0}")]
OutboundSerialisation(#[source] serde_json::Error),
#[error("failed to serialise sampling request: {0}")]
SamplingSerialise(serde_json::Error),
#[error("failed to deserialise sampling response: {0}")]
SamplingDeserialise(serde_json::Error),
}
impl From<klieo_core::ServerOutboundError> for McpServerError {
fn from(e: klieo_core::ServerOutboundError) -> Self {
use klieo_core::ServerOutboundError as E;
match e {
E::Timeout => McpServerError::OutboundTimeout,
E::PeerError { code, message } => McpServerError::ClientReturnedError { code, message },
E::TransportClosed => McpServerError::TransportClosed,
E::Unsupported => McpServerError::OutboundUnsupported,
E::Serialisation(err) => McpServerError::OutboundSerialisation(err),
_ => McpServerError::OutboundUnsupported,
}
}
}
/// Classifies bus errors: `BusError::Invalid` becomes `InvalidSubject`, and
/// every other bus error becomes `Bus`.
impl From<klieo_core::BusError> for McpServerError {
    fn from(e: klieo_core::BusError) -> Self {
        if matches!(e, klieo_core::BusError::Invalid(_)) {
            Self::InvalidSubject(e)
        } else {
            Self::Bus(e)
        }
    }
}
/// Reasons `McpServerBuilder::build` / `build_arc` can refuse a configuration.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum McpBuildError {
    /// No tools, agents or workflows were registered.
    #[error("at least one invoker required; call add_tools or add_agent before build")]
    NoInvokers,
    /// Two registered invokers expose the same tool name.
    #[error("duplicate tool name {0:?} across registered invokers")]
    DuplicateTool(String),
    /// `with_cancel_subscription` was requested but `build` (not `build_arc`) was called.
    #[error("with_cancel_subscription requires build_arc()")]
    CancelRequiresArc,
    /// A workflow was registered but `with_hitl` was never called.
    #[error("workflow registered without with_hitl(..); call with_hitl before build")]
    WorkflowWithoutHitl,
    /// A workflow was registered but `with_governor` was never called
    /// (only reported when the `governor` feature is enabled).
    #[error("workflow registered without with_governor(..); call with_governor before build")]
    WorkflowWithoutGovernor,
    /// The deployment profile rejected the configuration.
    #[error(transparent)]
    RegulatedProfile(#[from] klieo_core::ProfileViolation),
}
/// An MCP server exposing registered tools, agents and workflows.
///
/// Create one with `McpServer::builder()` or an `expose_*` shortcut, then serve
/// it with `serve_stdio` or `serve_with_streams`.
pub struct McpServer {
    /// Invoker answering tool calls (merged when several were registered).
    pub(crate) invoker: Arc<dyn ToolInvoker>,
    /// Produces the base `ToolCtx` for each invocation.
    tool_ctx_factory: ToolCtxFactory,
    /// Parent cancellation token supplied through the builder.
    pub(crate) parent_cancel: CancellationToken,
    /// Resume buffer (a no-op buffer unless one was configured).
    pub(crate) resume_buffer: Arc<dyn klieo_core::resume::ResumeBuffer>,
    /// Pubsub used for cancel signalling.
    pub(crate) pubsub: Arc<dyn klieo_core::Pubsub>,
    /// Cancellation handles keyed by string.
    pub(crate) cancel_registry: klieo_core::CancelRegistry<String>,
    /// Semaphore bounding concurrent publishes on the HTTP transport.
    #[cfg(feature = "http")]
    pub(crate) publish_permits: Arc<tokio::sync::Semaphore>,
    /// Present when leader election was enabled.
    pub(crate) leader_registry: Option<klieo_core::LeaderRegistry>,
    /// Present when tenant binding was enabled.
    pub(crate) ownership_registry: Option<klieo_core::OwnershipRegistry>,
    /// Present when a checkpoint KV was configured.
    #[cfg_attr(not(feature = "http"), allow(dead_code))]
    pub(crate) resume_ticket_store: Option<Arc<crate::resume_ticket::ResumeTicketStore>>,
    /// Resume handles for registered workflows, keyed by workflow name.
    #[cfg_attr(not(feature = "http"), allow(dead_code))]
    pub(crate) workflow_resume_handles:
        std::collections::HashMap<String, Arc<dyn crate::workflow::WorkflowResumeHandle>>,
    /// Optional request authenticator.
    pub(crate) authenticator: Option<Arc<dyn klieo_auth_common::Authenticator>>,
    /// Leader-election TTL in effect.
    pub(crate) leader_ttl: std::time::Duration,
    /// Leader heartbeat period in effect.
    pub(crate) leader_heartbeat_interval: std::time::Duration,
    /// Failover attempt cap in effect.
    pub(crate) max_failover_attempts: u32,
    /// Interval requested for the KV reaper, if any.
    pub(crate) kv_reaper_interval: Option<std::time::Duration>,
    /// Held (never read) so the reaper handle lives as long as the server.
    _kv_reaper: Option<klieo_core::KvReaperHandle>,
    /// Session state for the stdio transport, created lazily.
    pub(crate) stdio_session: tokio::sync::OnceCell<Arc<crate::session::Session>>,
    /// Live HTTP sessions keyed by session id.
    #[cfg(feature = "http")]
    pub(crate) sessions: Arc<
        tokio::sync::RwLock<std::collections::HashMap<uuid::Uuid, Arc<crate::session::Session>>>,
    >,
    /// Global HTTP session cap.
    #[cfg(feature = "http")]
    pub(crate) max_sessions: usize,
    /// Per-principal HTTP session cap.
    #[cfg(feature = "http")]
    pub(crate) max_sessions_per_principal: usize,
    /// SSE replay buffer capacity; zero disables replay.
    #[cfg(feature = "http")]
    pub(crate) sse_replay_capacity: usize,
    /// Session counts per principal; entries are removed when they reach zero.
    #[cfg(feature = "http")]
    pub(crate) principal_counts: Arc<tokio::sync::RwLock<std::collections::HashMap<String, usize>>>,
    /// Set once the idle-session reaper has been started.
    #[cfg(feature = "http")]
    pub(crate) idle_reaper_started: tokio::sync::OnceCell<()>,
    /// Whether client sampling support was declared.
    pub(crate) declare_sampling: bool,
    /// Writer installed by the first `serve_with_streams` call.
    pub(crate) stdout_writer: tokio::sync::OnceCell<crate::outbound::SharedWriter>,
    /// Capabilities reported by the connected client.
    pub(crate) client_caps: tokio::sync::Mutex<ClientCaps>,
    /// Idle timeout for HTTP sessions.
    #[cfg(feature = "http")]
    pub(crate) session_idle_timeout: std::time::Duration,
    /// Period between idle-reaper ticks.
    #[cfg(feature = "http")]
    pub(crate) idle_reaper_tick: std::time::Duration,
    /// Instant at which the server was built.
    #[cfg(feature = "http")]
    pub(crate) server_start: std::time::Instant,
}
/// Capabilities the connected client has advertised; all default to `false`.
#[derive(Debug, Default)]
pub(crate) struct ClientCaps {
    /// Whether the client supports the roots capability.
    pub(crate) roots_supported: bool,
}
/// Thread-safe factory producing a fresh `AgentContext` for each agent or
/// workflow run.
pub type AgentContextFactory =
    Arc<dyn Fn() -> klieo_core::agent::AgentContext + Sync + Send + 'static>;

/// Thread-safe factory producing the base `ToolCtx` for each tool invocation.
pub type ToolCtxFactory =
    Arc<dyn Fn() -> klieo_core::tool::ToolCtx + Sync + Send + 'static>;
/// Factory installed when `with_tool_ctx_factory` is not called; it delegates
/// to `noop_ctx`.
fn default_tool_ctx_factory() -> ToolCtxFactory {
    Arc::new(noop_ctx) as ToolCtxFactory
}
/// Default publish-semaphore size when `with_publish_concurrency` is not called.
#[cfg(feature = "http")]
const DEFAULT_PUBLISH_PERMITS: usize = 64;
/// Default HTTP session idle timeout (five minutes).
#[cfg(feature = "http")]
pub(crate) const DEFAULT_SESSION_IDLE_TIMEOUT: std::time::Duration =
    std::time::Duration::from_secs(5 * 60);
/// Default for `max_sessions` when `with_max_sessions` is not called.
#[cfg(feature = "http")]
pub(crate) const DEFAULT_MAX_SESSIONS: usize = 1024;
/// Divisor deriving the default per-principal cap from `max_sessions`.
#[cfg(feature = "http")]
pub(crate) const DEFAULT_MAX_SESSIONS_PER_PRINCIPAL_DIVISOR: usize = 16;
/// Default SSE replay capacity when `with_sse_replay_capacity` is not called.
#[cfg(feature = "http")]
pub(crate) const DEFAULT_SSE_REPLAY_CAPACITY: usize = 256;

/// Default per-principal session cap: `max_sessions / 16`, never below one.
#[cfg(feature = "http")]
pub(crate) fn default_max_sessions_per_principal(max_sessions: usize) -> usize {
    let share = max_sessions / DEFAULT_MAX_SESSIONS_PER_PRINCIPAL_DIVISOR;
    std::cmp::max(share, 1)
}

/// Default period between idle-reaper ticks.
#[cfg(feature = "http")]
pub(crate) const DEFAULT_IDLE_REAPER_TICK: std::time::Duration = std::time::Duration::from_secs(10);
/// Builder for `McpServer`.
///
/// Every setting is optional; unset values are replaced by defaults when
/// `build` or `build_arc` runs.
pub struct McpServerBuilder {
    /// Registered tool invokers (tools, agents, and materialised workflows).
    invokers: Vec<Arc<dyn ToolInvoker>>,
    /// Parent cancellation token handed to the server.
    parent_cancel: CancellationToken,
    /// Factory for per-invocation `ToolCtx` values.
    tool_ctx_factory: ToolCtxFactory,
    /// Optional resume buffer.
    resume_buffer: Option<Arc<dyn klieo_core::resume::ResumeBuffer>>,
    /// Optional pubsub.
    pubsub: Option<Arc<dyn klieo_core::Pubsub>>,
    /// Set by `with_cancel_subscription`; honoured only by `build_arc`.
    subscribe_cancels: bool,
    /// Publish semaphore size override.
    #[cfg(feature = "http")]
    publish_permits: Option<usize>,
    /// KV store for leader election.
    leader_kv: Option<Arc<dyn klieo_core::KvStore>>,
    /// KV store for tenant ownership binding.
    tenant_kv: Option<Arc<dyn klieo_core::KvStore>>,
    /// KV store backing workflow resume tickets.
    checkpoint_kv: Option<Arc<dyn klieo_core::KvStore>>,
    /// Whether tenant binding was requested in strict mode.
    tenant_strict: bool,
    /// Deployment profile validated at build time.
    profile: klieo_core::DeploymentProfile,
    /// Optional authenticator.
    authenticator: Option<Arc<dyn klieo_auth_common::Authenticator>>,
    /// Leader TTL override.
    leader_ttl: Option<std::time::Duration>,
    /// Leader heartbeat override.
    leader_heartbeat_interval: Option<std::time::Duration>,
    /// Failover attempt cap override.
    max_failover_attempts: Option<u32>,
    /// KV reaper interval; `None` disables the reaper.
    kv_reaper_interval: Option<std::time::Duration>,
    /// Whether client sampling support is declared.
    declare_sampling: bool,
    /// HTTP session idle-timeout override.
    #[cfg(feature = "http")]
    session_idle_timeout: Option<std::time::Duration>,
    /// Global HTTP session cap override.
    #[cfg(feature = "http")]
    max_sessions: Option<usize>,
    /// Per-principal HTTP session cap override.
    #[cfg(feature = "http")]
    max_sessions_per_principal: Option<usize>,
    /// SSE replay capacity override.
    #[cfg(feature = "http")]
    sse_replay_capacity: Option<usize>,
    /// Idle-reaper tick override (tests / test fixtures only).
    #[cfg(all(feature = "http", any(test, feature = "test-fixtures")))]
    idle_reaper_tick: Option<std::time::Duration>,
    /// HITL client/config required by workflows.
    hitl: Option<crate::workflow::HitlBundle>,
    /// Workflows awaiting materialisation at build time.
    pending_workflows: Vec<crate::workflow::WorkflowRegistration>,
    /// Governor configuration, if any.
    governor_bundle: Option<GovernorBundleHolder>,
}
/// Governor configuration passed from the builder to agent and workflow
/// invokers. With the `governor` feature it is `crate::governor::GovernorBundle`.
#[cfg(feature = "governor")]
pub(crate) type GovernorBundleHolder = crate::governor::GovernorBundle;
/// Data-less stand-in used when the `governor` feature is disabled; the
/// workflow materialiser discards it in that configuration.
#[cfg(not(feature = "governor"))]
#[derive(Clone)]
pub(crate) struct GovernorBundleHolder;
impl Default for McpServerBuilder {
fn default() -> Self {
Self::new()
}
}
/// Spawns the KV reaper when an interval is configured and at least one
/// KV-backed registry exists; otherwise returns `None`.
///
/// The reaper runs against one KV store. It uses the leader registry's store
/// when leader election is enabled, otherwise the ownership registry's store.
/// It sweeps the bucket of every configured registry (leader bucket first).
///
/// NOTE(review): when both registries are configured, the ownership bucket is
/// reaped on the *leader* registry's store. If leader election and tenant
/// binding were given different KV stores, the tenant store is never reaped.
/// Confirm that sharing one store is the expected deployment.
fn spawn_reaper_if_configured(
    interval: Option<std::time::Duration>,
    leader_registry: Option<&klieo_core::LeaderRegistry>,
    ownership_registry: Option<&klieo_core::OwnershipRegistry>,
    resume_buffer: &Arc<dyn klieo_core::resume::ResumeBuffer>,
) -> Option<klieo_core::KvReaperHandle> {
    let period = interval?;
    let (store, buckets) = match (leader_registry, ownership_registry) {
        (Some(leader), Some(ownership)) => {
            let buckets = vec![leader.bucket().to_string(), ownership.bucket().to_string()];
            (leader.kv().clone(), buckets)
        }
        (Some(leader), None) => {
            let buckets = vec![leader.bucket().to_string()];
            (leader.kv().clone(), buckets)
        }
        (None, Some(ownership)) => {
            let buckets = vec![ownership.bucket().to_string()];
            (ownership.kv().clone(), buckets)
        }
        (None, None) => return None,
    };
    Some(klieo_core::spawn_kv_reaper(
        store,
        Arc::clone(resume_buffer),
        buckets,
        period,
    ))
}
impl McpServerBuilder {
    /// Creates a builder with no invokers registered and every optional
    /// setting unset. `build` / `build_arc` replace unset values with defaults.
    #[must_use]
    pub fn new() -> Self {
        Self {
            invokers: Vec::new(),
            parent_cancel: CancellationToken::new(),
            tool_ctx_factory: default_tool_ctx_factory(),
            resume_buffer: None,
            pubsub: None,
            subscribe_cancels: false,
            #[cfg(feature = "http")]
            publish_permits: None,
            leader_kv: None,
            tenant_kv: None,
            checkpoint_kv: None,
            tenant_strict: false,
            profile: klieo_core::DeploymentProfile::Unprofiled,
            authenticator: None,
            leader_ttl: None,
            leader_heartbeat_interval: None,
            max_failover_attempts: None,
            kv_reaper_interval: None,
            declare_sampling: false,
            #[cfg(feature = "http")]
            session_idle_timeout: None,
            #[cfg(feature = "http")]
            max_sessions: None,
            #[cfg(feature = "http")]
            max_sessions_per_principal: None,
            #[cfg(feature = "http")]
            sse_replay_capacity: None,
            #[cfg(all(feature = "http", any(test, feature = "test-fixtures")))]
            idle_reaper_tick: None,
            hitl: None,
            pending_workflows: Vec::new(),
            governor_bundle: None,
        }
    }

    /// Sets the parent cancellation token stored on the server.
    /// When not called, a fresh `CancellationToken` is used.
    #[must_use]
    pub fn with_parent_cancel(mut self, parent_cancel: CancellationToken) -> Self {
        self.parent_cancel = parent_cancel;
        self
    }

    /// Sets the factory that produces the base `ToolCtx` for each invocation.
    /// When not called, a factory delegating to `noop_ctx` is used.
    #[must_use]
    pub fn with_tool_ctx_factory(mut self, factory: ToolCtxFactory) -> Self {
        self.tool_ctx_factory = factory;
        self
    }

    /// Sets the resume buffer. When not called,
    /// `klieo_core::resume::NoopResumeBuffer` is used.
    #[must_use]
    pub fn with_resume_buffer(
        mut self,
        buffer: std::sync::Arc<dyn klieo_core::resume::ResumeBuffer>,
    ) -> Self {
        self.resume_buffer = Some(buffer);
        self
    }

    /// Sets the pubsub used for cancel signalling. When not called, the pubsub
    /// of a fresh in-memory `klieo_bus_memory::MemoryBus` is used.
    #[must_use]
    pub fn with_pubsub(mut self, pubsub: std::sync::Arc<dyn klieo_core::Pubsub>) -> Self {
        self.pubsub = Some(pubsub);
        self
    }

    /// Requests a wildcard cancel subscriber on `klieo.mcp.cancel.>`.
    ///
    /// Only `build_arc` can honour this. `build` returns
    /// `McpBuildError::CancelRequiresArc`.
    #[must_use]
    pub fn with_cancel_subscription(mut self) -> Self {
        self.subscribe_cancels = true;
        self
    }

    /// Sets the number of permits in the server's publish semaphore
    /// (default `DEFAULT_PUBLISH_PERMITS`, i.e. 64).
    ///
    /// # Panics
    ///
    /// Panics if `permits` is zero. A semaphore created with no permits never
    /// grants one, so anything waiting on it would block forever. This matches
    /// the `> 0` checks in `with_max_sessions` and `with_max_sessions_per_principal`.
    #[cfg(feature = "http")]
    #[must_use]
    pub fn with_publish_concurrency(mut self, permits: usize) -> Self {
        assert!(permits > 0, "publish_concurrency must be > 0");
        self.publish_permits = Some(permits);
        self
    }

    /// Enables leader election backed by `kv`. The server uses bucket
    /// `klieo-leaders` and a random UUID v4 as this replica's id.
    #[must_use]
    pub fn with_leader_election(mut self, kv: Arc<dyn klieo_core::KvStore>) -> Self {
        self.leader_kv = Some(kv);
        self
    }

    /// Enables lenient tenant ownership binding backed by `kv` (bucket
    /// `klieo-tenants`). A deployment profile that requires strict binding
    /// upgrades this to strict at build time.
    #[must_use]
    pub fn with_tenant_binding(mut self, kv: Arc<dyn klieo_core::KvStore>) -> Self {
        self.tenant_kv = Some(kv);
        self.tenant_strict = false;
        self
    }

    /// Enables strict tenant ownership binding backed by `kv` (bucket `klieo-tenants`).
    #[must_use]
    pub fn with_tenant_binding_strict(mut self, kv: Arc<dyn klieo_core::KvStore>) -> Self {
        self.tenant_kv = Some(kv);
        self.tenant_strict = true;
        self
    }

    /// Sets the KV store backing the resume-ticket store. That store is
    /// handed to every registered workflow and kept on the server.
    #[must_use]
    pub fn with_checkpoint_kv(mut self, kv: Arc<dyn klieo_core::KvStore>) -> Self {
        self.checkpoint_kv = Some(kv);
        self
    }

    /// Sets the authenticator. Its `allows_anonymous()` answer is checked
    /// against the deployment profile at build time.
    #[must_use]
    pub fn with_authenticator(
        mut self,
        authenticator: Arc<dyn klieo_auth_common::Authenticator>,
    ) -> Self {
        self.authenticator = Some(authenticator);
        self
    }

    /// Sets the deployment profile (default `Unprofiled`). The profile is
    /// validated during build; violations surface as
    /// `McpBuildError::RegulatedProfile`.
    #[must_use]
    pub fn profile(mut self, profile: klieo_core::DeploymentProfile) -> Self {
        self.profile = profile;
        self
    }

    /// Overrides the leader-election TTL (default `LEADER_TTL`).
    #[must_use]
    pub fn with_leader_ttl(mut self, ttl: std::time::Duration) -> Self {
        self.leader_ttl = Some(ttl);
        self
    }

    /// Overrides the leader heartbeat period (default: half the leader TTL).
    /// NOTE(review): the value is not checked against the TTL here.
    #[must_use]
    pub fn with_leader_heartbeat_interval(mut self, interval: std::time::Duration) -> Self {
        self.leader_heartbeat_interval = Some(interval);
        self
    }

    /// Overrides the failover attempt cap (default `klieo_core::FAILOVER_ATTEMPT_CAP`).
    #[must_use]
    pub fn with_max_failover_attempts(mut self, cap: u32) -> Self {
        self.max_failover_attempts = Some(cap);
        self
    }

    /// Enables the background KV reaper with the given interval. The reaper
    /// only starts if leader election or tenant binding is also configured.
    #[must_use]
    pub fn with_kv_reaper(mut self, interval: std::time::Duration) -> Self {
        self.kv_reaper_interval = Some(interval);
        self
    }

    /// Declares client sampling support. Without it, the stdio transport does
    /// not set up the outbound-request table or the roots cache.
    #[must_use]
    pub fn with_client_sampling(mut self) -> Self {
        self.declare_sampling = true;
        self
    }

    /// Overrides the HTTP session idle timeout (default 300 seconds).
    #[cfg(feature = "http")]
    #[must_use]
    pub fn with_session_idle_timeout(mut self, ttl: std::time::Duration) -> Self {
        self.session_idle_timeout = Some(ttl);
        self
    }

    /// Overrides the global HTTP session cap (default 1024).
    ///
    /// # Panics
    ///
    /// Panics if `cap` is zero.
    #[cfg(feature = "http")]
    #[must_use]
    pub fn with_max_sessions(mut self, cap: usize) -> Self {
        assert!(cap > 0, "max_sessions must be > 0");
        self.max_sessions = Some(cap);
        self
    }

    /// Overrides the per-principal HTTP session cap (default
    /// `max_sessions / 16`, at least 1).
    ///
    /// # Panics
    ///
    /// Panics if `cap` is zero.
    #[cfg(feature = "http")]
    #[must_use]
    pub fn with_max_sessions_per_principal(mut self, cap: usize) -> Self {
        assert!(cap > 0, "max_sessions_per_principal must be > 0");
        self.max_sessions_per_principal = Some(cap);
        self
    }

    /// Overrides the SSE replay buffer capacity (default 256). Zero disables replay.
    #[cfg(feature = "http")]
    #[must_use]
    pub fn with_sse_replay_capacity(mut self, capacity: usize) -> Self {
        self.sse_replay_capacity = Some(capacity);
        self
    }

    /// Overrides the idle-reaper tick (default 10 seconds). Test / fixture builds only.
    #[cfg(all(feature = "http", any(test, feature = "test-fixtures")))]
    #[must_use]
    pub fn with_idle_reaper_tick(mut self, tick: std::time::Duration) -> Self {
        self.idle_reaper_tick = Some(tick);
        self
    }

    /// Registers a tool invoker. When several invokers are registered, they
    /// are merged at build time.
    #[must_use]
    pub fn add_tools(mut self, invoker: Arc<dyn ToolInvoker>) -> Self {
        self.invokers.push(invoker);
        self
    }

    /// Exposes `agent` as a tool named after `agent.name()`, using
    /// `input_schema` as its input schema.
    ///
    /// With the `governor` feature, the governor configured *at the time of
    /// this call* is captured. Call `with_governor` before this method;
    /// calling it afterwards does not affect this agent.
    #[must_use]
    pub fn add_agent_with_schema<A>(
        mut self,
        agent: A,
        input_schema: serde_json::Value,
        ctx_factory: AgentContextFactory,
    ) -> Self
    where
        A: Agent + 'static,
        A::Input: serde::de::DeserializeOwned + Send + 'static,
        A::Output: serde::Serialize + Send + 'static,
    {
        let name = agent.name().to_string();
        let invoker: Arc<dyn ToolInvoker> = Arc::new(AgentAsToolInvoker {
            agent: Arc::new(agent),
            name,
            input_schema,
            ctx_factory,
            #[cfg(feature = "governor")]
            governor: self.governor_bundle.clone(),
        });
        self.invokers.push(invoker);
        self
    }

    /// Like `add_agent_with_schema`, but derives the input schema from
    /// `A::Input` via `schemars`.
    ///
    /// # Panics
    ///
    /// Panics if the generated schema cannot be serialised to JSON.
    #[cfg(feature = "schemars")]
    #[must_use]
    pub fn add_agent<A>(self, agent: A, ctx_factory: AgentContextFactory) -> Self
    where
        A: Agent + 'static,
        A::Input: serde::de::DeserializeOwned + schemars::JsonSchema + Send + 'static,
        A::Output: serde::Serialize + Send + 'static,
    {
        let schema = serde_json::to_value(schemars::schema_for!(A::Input))
            .expect("schemars::Schema serialises to JSON via #[derive(Serialize)]");
        self.add_agent_with_schema(agent, schema, ctx_factory)
    }

    /// Configures the governor. It is required before build whenever
    /// workflows are registered. Agents added earlier via `add_agent*` do not
    /// pick it up.
    #[cfg(feature = "governor")]
    #[must_use]
    pub fn with_governor(
        mut self,
        governor: Arc<dyn klieo_ops::governor::Governor>,
        provider: klieo_ops::ProviderId,
    ) -> Self {
        self.governor_bundle = Some(crate::governor::GovernorBundle { governor, provider });
        self
    }

    /// Configures the HITL client and config. Required before build whenever
    /// workflows are registered.
    #[must_use]
    pub fn with_hitl(
        mut self,
        client: Arc<klieo_hitl_client::HitlClient>,
        cfg: Arc<klieo_hitl::HitlConfig>,
    ) -> Self {
        self.hitl = Some(crate::workflow::HitlBundle { client, cfg });
        self
    }

    /// Registers a workflow tool named after `agent.name()`.
    ///
    /// Materialisation is deferred to build time, so the workflow sees the HITL
    /// bundle, ticket store and governor configured anywhere on the builder.
    /// Only the agent's name and type `A` are kept; the `agent` value itself is
    /// dropped here.
    #[must_use]
    pub fn add_workflow_with_schema<A>(
        mut self,
        agent: A,
        system_prompt: impl Into<String>,
        input_schema: serde_json::Value,
        run_options: klieo_core::runtime::RunOptions,
        ctx_factory: AgentContextFactory,
    ) -> Self
    where
        A: Agent + 'static,
        A::Input: serde::de::DeserializeOwned + Send + 'static,
    {
        let name = agent.name().to_string();
        let prompt = system_prompt.into();
        drop(agent);
        let materialise: crate::workflow::WorkflowMaterialiser = Box::new(
            move |bundle: crate::workflow::HitlBundle,
                  ticket_store: Option<Arc<crate::resume_ticket::ResumeTicketStore>>,
                  governor_bundle: Option<crate::GovernorBundleHolder>| {
                #[cfg(not(feature = "governor"))]
                let _ = governor_bundle;
                let invoker = Arc::new(crate::workflow::WorkflowAsToolInvoker::<A>::new(
                    name.clone(),
                    prompt.clone(),
                    input_schema.clone(),
                    ctx_factory.clone(),
                    run_options.clone(),
                    bundle,
                    ticket_store,
                    #[cfg(feature = "governor")]
                    governor_bundle,
                ));
                crate::workflow::WorkflowMaterialisation {
                    name: name.clone(),
                    resume_handle: invoker.clone()
                        as Arc<dyn crate::workflow::WorkflowResumeHandle>,
                    invoker: invoker as Arc<dyn ToolInvoker>,
                }
            },
        );
        self.pending_workflows
            .push(crate::workflow::WorkflowRegistration { materialise });
        self
    }

    /// Like `add_workflow_with_schema`, but derives the input schema from
    /// `A::Input` via `schemars`.
    ///
    /// # Panics
    ///
    /// Panics if the generated schema cannot be serialised to JSON.
    #[cfg(feature = "schemars")]
    #[must_use]
    pub fn add_workflow<A>(
        self,
        agent: A,
        system_prompt: impl Into<String>,
        run_options: klieo_core::runtime::RunOptions,
        ctx_factory: AgentContextFactory,
    ) -> Self
    where
        A: Agent + 'static,
        A::Input: serde::de::DeserializeOwned + schemars::JsonSchema + Send + 'static,
    {
        let schema = serde_json::to_value(schemars::schema_for!(A::Input))
            .expect("schemars::Schema serialises to JSON via #[derive(Serialize)]");
        self.add_workflow_with_schema(agent, system_prompt, schema, run_options, ctx_factory)
    }

    /// Builds the server by value.
    ///
    /// # Errors
    ///
    /// Returns `McpBuildError::CancelRequiresArc` if `with_cancel_subscription`
    /// was requested, because the subscriber needs the `Arc` that `build_arc`
    /// creates. Otherwise it returns the same errors as `build_arc`.
    pub fn build(self) -> Result<McpServer, McpBuildError> {
        if self.subscribe_cancels {
            return Err(McpBuildError::CancelRequiresArc);
        }
        self.build_inner()
    }

    /// Shared build logic for `build` and `build_arc`.
    ///
    /// Errors:
    /// - `WorkflowWithoutHitl` / `WorkflowWithoutGovernor` when workflows lack
    ///   their prerequisites.
    /// - `DuplicateTool` when two workflows share a name.
    /// - `NoInvokers` when nothing was registered.
    /// - Whatever `MergedInvoker::new` reports.
    /// - `RegulatedProfile` when profile validation fails.
    ///
    /// May spawn the KV reaper task as a side effect.
    fn build_inner(mut self) -> Result<McpServer, McpBuildError> {
        let ticket_store = self
            .checkpoint_kv
            .clone()
            .map(|kv| Arc::new(crate::resume_ticket::ResumeTicketStore::new(kv)));
        let mut workflow_resume_handles: std::collections::HashMap<
            String,
            Arc<dyn crate::workflow::WorkflowResumeHandle>,
        > = std::collections::HashMap::new();
        if !self.pending_workflows.is_empty() {
            let bundle = self
                .hitl
                .clone()
                .ok_or(McpBuildError::WorkflowWithoutHitl)?;
            #[cfg(feature = "governor")]
            if self.governor_bundle.is_none() {
                return Err(McpBuildError::WorkflowWithoutGovernor);
            }
            let governor_bundle = self.governor_bundle.clone();
            let pending = std::mem::take(&mut self.pending_workflows);
            for reg in pending {
                let mat = (reg.materialise)(
                    bundle.clone(),
                    ticket_store.clone(),
                    governor_bundle.clone(),
                );
                // Resume handles are keyed by name, so a second workflow with
                // the same name would silently replace the first.
                if workflow_resume_handles
                    .insert(mat.name.clone(), mat.resume_handle)
                    .is_some()
                {
                    return Err(McpBuildError::DuplicateTool(mat.name));
                }
                self.invokers.push(mat.invoker);
            }
        }
        let invoker_count = self.invokers.len();
        if invoker_count == 0 {
            return Err(McpBuildError::NoInvokers);
        }
        let invoker: Arc<dyn ToolInvoker> = if invoker_count == 1 {
            self.invokers
                .into_iter()
                .next()
                .expect("invoker_count == 1 was checked above")
        } else {
            Arc::new(MergedInvoker::new(self.invokers)?)
        };
        #[cfg(feature = "http")]
        let permits = self.publish_permits.unwrap_or(DEFAULT_PUBLISH_PERMITS);
        #[cfg(feature = "http")]
        let max_sessions = self.max_sessions.unwrap_or(DEFAULT_MAX_SESSIONS);
        #[cfg(feature = "http")]
        let max_sessions_per_principal = self
            .max_sessions_per_principal
            .unwrap_or_else(|| default_max_sessions_per_principal(max_sessions));
        let leader_registry = self.leader_kv.map(|kv| {
            klieo_core::LeaderRegistry::new(
                kv,
                "klieo-leaders".into(),
                uuid::Uuid::new_v4().to_string(),
            )
        });
        let profile = self.profile;
        profile.validate(
            self.tenant_kv.is_some(),
            self.authenticator.as_ref().map(|a| a.allows_anonymous()),
        )?;
        // A profile that demands strict binding overrides a lenient
        // `with_tenant_binding` call.
        let tenant_strict = self.tenant_strict || profile.requires_strict_binding();
        let ownership_registry = self.tenant_kv.map(|kv| {
            let bucket = "klieo-tenants".into();
            if tenant_strict {
                klieo_core::OwnershipRegistry::new_strict(kv, bucket)
            } else {
                klieo_core::OwnershipRegistry::new(kv, bucket)
            }
        });
        if profile.requires_strict_binding() || profile.requires_named_principal() {
            tracing::warn!(
                target: "klieo.security",
                cwe = 639,
                "regulated multi-tenant profile active on this replica; \
                 cross-replica tenant isolation assumes ALL replicas run the \
                 same profile — a lenient peer reintroduces CWE-639. Fleet \
                 homogeneity is NOT verified by this replica."
            );
        }
        let resume_buffer = self
            .resume_buffer
            .unwrap_or_else(|| std::sync::Arc::new(klieo_core::resume::NoopResumeBuffer));
        let leader_ttl = self.leader_ttl.unwrap_or(LEADER_TTL);
        let leader_heartbeat_interval = self.leader_heartbeat_interval.unwrap_or(leader_ttl / 2);
        let max_failover_attempts = self
            .max_failover_attempts
            .unwrap_or(klieo_core::FAILOVER_ATTEMPT_CAP);
        let kv_reaper = spawn_reaper_if_configured(
            self.kv_reaper_interval,
            leader_registry.as_ref(),
            ownership_registry.as_ref(),
            &resume_buffer,
        );
        Ok(McpServer {
            invoker,
            tool_ctx_factory: self.tool_ctx_factory,
            parent_cancel: self.parent_cancel,
            resume_buffer,
            pubsub: self
                .pubsub
                .unwrap_or_else(|| klieo_bus_memory::MemoryBus::new().pubsub.clone()),
            cancel_registry: klieo_core::CancelRegistry::new(),
            #[cfg(feature = "http")]
            publish_permits: std::sync::Arc::new(tokio::sync::Semaphore::new(permits)),
            leader_registry,
            ownership_registry,
            resume_ticket_store: ticket_store,
            workflow_resume_handles,
            authenticator: self.authenticator,
            leader_ttl,
            leader_heartbeat_interval,
            max_failover_attempts,
            kv_reaper_interval: self.kv_reaper_interval,
            _kv_reaper: kv_reaper,
            stdio_session: tokio::sync::OnceCell::new(),
            #[cfg(feature = "http")]
            sessions: std::sync::Arc::new(tokio::sync::RwLock::new(
                std::collections::HashMap::new(),
            )),
            #[cfg(feature = "http")]
            max_sessions,
            #[cfg(feature = "http")]
            max_sessions_per_principal,
            #[cfg(feature = "http")]
            sse_replay_capacity: self
                .sse_replay_capacity
                .unwrap_or(DEFAULT_SSE_REPLAY_CAPACITY),
            #[cfg(feature = "http")]
            principal_counts: std::sync::Arc::new(tokio::sync::RwLock::new(
                std::collections::HashMap::new(),
            )),
            #[cfg(feature = "http")]
            idle_reaper_started: tokio::sync::OnceCell::new(),
            declare_sampling: self.declare_sampling,
            stdout_writer: tokio::sync::OnceCell::new(),
            client_caps: tokio::sync::Mutex::new(ClientCaps::default()),
            #[cfg(feature = "http")]
            session_idle_timeout: self
                .session_idle_timeout
                .unwrap_or(DEFAULT_SESSION_IDLE_TIMEOUT),
            #[cfg(all(feature = "http", any(test, feature = "test-fixtures")))]
            idle_reaper_tick: self.idle_reaper_tick.unwrap_or(DEFAULT_IDLE_REAPER_TICK),
            #[cfg(all(feature = "http", not(any(test, feature = "test-fixtures"))))]
            idle_reaper_tick: DEFAULT_IDLE_REAPER_TICK,
            #[cfg(feature = "http")]
            server_start: std::time::Instant::now(),
        })
    }

    /// Builds the server behind an `Arc`. If `with_cancel_subscription` was
    /// requested, this also spawns the wildcard cancel subscriber on
    /// `klieo.mcp.cancel.>`, feeding the server's cancel registry.
    ///
    /// # Errors
    ///
    /// Propagates every build error except `CancelRequiresArc`, which only
    /// `build` reports.
    pub fn build_arc(self) -> Result<std::sync::Arc<McpServer>, McpBuildError> {
        let spawn_subscriber = self.subscribe_cancels;
        let server = std::sync::Arc::new(self.build_inner()?);
        if spawn_subscriber {
            klieo_core::cancel::spawn_wildcard_cancel_subscriber(
                server.pubsub.clone(),
                "klieo.mcp.cancel.>".to_string(),
                "klieo.mcp.cancel.".to_string(),
                server.cancel_registry.clone(),
                "mcp.cancel",
            );
        }
        Ok(server)
    }
}
impl McpServer {
/// Starts building a server; equivalent to `McpServerBuilder::new()`.
pub fn builder() -> McpServerBuilder {
    McpServerBuilder::default()
}
/// The leader registry, present only when leader election was enabled.
pub fn leader_registry(&self) -> Option<&klieo_core::LeaderRegistry> {
    Option::as_ref(&self.leader_registry)
}
/// The tenant ownership registry, present only when tenant binding was enabled.
pub fn ownership_registry(&self) -> Option<&klieo_core::OwnershipRegistry> {
    Option::as_ref(&self.ownership_registry)
}
/// The configured authenticator, if any.
pub fn authenticator(&self) -> Option<&Arc<dyn klieo_auth_common::Authenticator>> {
    Option::as_ref(&self.authenticator)
}
/// Leader-election TTL in effect (the builder override, or `LEADER_TTL`).
pub fn leader_ttl(&self) -> std::time::Duration {
self.leader_ttl
}
/// Leader heartbeat period in effect (the builder override, or half the TTL).
pub fn leader_heartbeat_interval(&self) -> std::time::Duration {
self.leader_heartbeat_interval
}
/// Failover attempt cap in effect (the builder override, or
/// `klieo_core::FAILOVER_ATTEMPT_CAP`).
pub fn max_failover_attempts(&self) -> u32 {
self.max_failover_attempts
}
/// KV reaper interval requested on the builder. This is `Some` even if no
/// reaper was spawned because no KV-backed registry was configured.
pub fn kv_reaper_interval(&self) -> Option<std::time::Duration> {
self.kv_reaper_interval
}
/// Ids of all currently registered HTTP sessions, in map iteration order.
#[cfg(feature = "http")]
pub async fn session_ids(&self) -> Vec<uuid::Uuid> {
    let table = self.sessions.read().await;
    table.iter().map(|(session_id, _)| *session_id).collect()
}
/// Whether the HTTP session `id` is closed, or `None` if no such session is registered.
#[cfg(feature = "http")]
pub async fn is_session_closed_by_id(&self, id: uuid::Uuid) -> Option<bool> {
    let table = self.sessions.read().await;
    table.get(&id).map(|session| session.is_closed())
}
/// SSE replay is on unless its capacity was configured as zero.
#[cfg(feature = "http")]
pub(crate) fn sse_replay_enabled(&self) -> bool {
    self.sse_replay_capacity != 0
}
/// Releases one session slot held by `principal`.
///
/// The entry is removed once its count would reach zero. A `None` principal
/// or an unknown principal is a no-op.
#[cfg(feature = "http")]
pub(crate) async fn decrement_principal_count(&self, principal: Option<&str>) {
    let Some(owner) = principal else { return };
    let mut table = self.principal_counts.write().await;
    if let Some(held) = table.get_mut(owner) {
        if *held <= 1 {
            table.remove(owner);
        } else {
            *held -= 1;
        }
    }
}
/// The stdio session's outbound-request channel, once the stdio transport has
/// wired it (this requires client sampling to be declared).
pub fn outbound(&self) -> Option<Arc<dyn klieo_core::ServerOutbound>> {
    let session = self.stdio_session.get()?;
    let requests = session.outbound.get()?;
    Some(requests.clone() as Arc<dyn klieo_core::ServerOutbound>)
}
/// Test fixture: a snapshot of the client roots cached for an HTTP session.
/// Returns `None` if the session or its roots cache does not exist.
#[cfg(all(feature = "http", feature = "test-fixtures"))]
pub async fn client_roots_for_session(&self, session_id: uuid::Uuid) -> Option<Vec<Root>> {
    let table = self.sessions.read().await;
    table
        .get(&session_id)
        .and_then(|session| session.roots_cache.get())
        .map(|cache| cache.snapshot())
}
/// Test fixture: the outbound-request channel of an HTTP session, if the
/// session exists and has one wired.
#[cfg(all(feature = "http", feature = "test-fixtures"))]
pub async fn outbound_for_session(
    &self,
    session_id: uuid::Uuid,
) -> Option<Arc<dyn klieo_core::ServerOutbound>> {
    let table = self.sessions.read().await;
    table
        .get(&session_id)
        .and_then(|session| session.outbound.get())
        .map(|requests| requests.clone() as Arc<dyn klieo_core::ServerOutbound>)
}
/// Test fixture: the SSE replay capacity in effect (zero means replay is disabled).
#[cfg(all(feature = "http", feature = "test-fixtures"))]
pub fn sse_replay_capacity(&self) -> usize {
self.sse_replay_capacity
}
/// Test fixture: a copy of an HTTP session's SSE replay buffer as
/// `(event id, payload)` pairs, or `None` if the session is unknown.
#[cfg(all(feature = "http", feature = "test-fixtures"))]
pub async fn sse_replay_snapshot(
    &self,
    session_id: uuid::Uuid,
) -> Option<Vec<(u64, std::sync::Arc<serde_json::Value>)>> {
    let table = self.sessions.read().await;
    let session = table.get(&session_id)?;
    let entries = session.sse_replay_buffer.lock().iter().cloned().collect();
    Some(entries)
}
/// Test fixture: sends a synthetic notification of roughly `payload_bytes`
/// through an HTTP session's outbound channel.
///
/// Returns `Err(())` if the session is unknown, has no outbound channel, or
/// the send fails. The sessions lock is released before sending.
#[cfg(all(feature = "http", feature = "test-fixtures"))]
pub async fn emit_test_notification(
    &self,
    session_id: uuid::Uuid,
    method: &str,
    payload_bytes: usize,
) -> Result<(), ()> {
    let session = {
        let table = self.sessions.read().await;
        table.get(&session_id).cloned()
    };
    let channel = match session.as_ref().and_then(|s| s.outbound.get()) {
        Some(channel) => channel,
        None => return Err(()),
    };
    channel
        .send_notification_frame(method, payload_bytes)
        .await
        .map_err(|_| ())
}
/// A snapshot of the roots the stdio client has reported; empty if the roots
/// cache has not been wired.
pub fn client_roots(&self) -> Vec<Root> {
    let cache = self
        .stdio_session
        .get()
        .and_then(|session| session.roots_cache.get());
    match cache {
        Some(cache) => cache.snapshot(),
        None => Vec::new(),
    }
}
/// A watch receiver for stdio client root updates, or `None` if the roots
/// cache has not been wired.
pub fn subscribe_root_changes(&self) -> Option<tokio::sync::watch::Receiver<Vec<Root>>> {
    let session = self.stdio_session.get()?;
    let cache = session.roots_cache.get()?;
    Some(cache.subscribe())
}
pub fn expose_tools(invoker: Arc<dyn ToolInvoker>) -> Self {
Self::builder()
.add_tools(invoker)
.build()
.expect("single-invoker build cannot fail")
}
pub fn expose_agent_with_schema<A>(
agent: A,
input_schema: serde_json::Value,
ctx_factory: AgentContextFactory,
) -> Self
where
A: Agent + 'static,
A::Input: serde::de::DeserializeOwned + Send + 'static,
A::Output: serde::Serialize + Send + 'static,
{
Self::builder()
.add_agent_with_schema(agent, input_schema, ctx_factory)
.build()
.expect("single-invoker build cannot fail")
}
/// Shortcut: a server exposing a single agent whose input schema is derived via `schemars`.
///
/// # Panics
///
/// Panics if schema serialisation or the build fails.
#[cfg(feature = "schemars")]
pub fn expose_agent<A>(agent: A, ctx_factory: AgentContextFactory) -> Self
where
    A: Agent + 'static,
    A::Input: serde::de::DeserializeOwned + schemars::JsonSchema + Send + 'static,
    A::Output: serde::Serialize + Send + 'static,
{
    let builder = McpServerBuilder::new().add_agent(agent, ctx_factory);
    builder.build().expect("single-invoker build cannot fail")
}
#[cfg(not(feature = "governor"))]
pub fn expose_workflow_with_schema<A>(
agent: A,
system_prompt: impl Into<String>,
input_schema: serde_json::Value,
run_options: klieo_core::runtime::RunOptions,
hitl_client: Arc<klieo_hitl_client::HitlClient>,
hitl_cfg: Arc<klieo_hitl::HitlConfig>,
ctx_factory: AgentContextFactory,
) -> Result<Self, McpBuildError>
where
A: Agent + 'static,
A::Input: serde::de::DeserializeOwned + Send + 'static,
{
Self::builder()
.with_hitl(hitl_client, hitl_cfg)
.add_workflow_with_schema(agent, system_prompt, input_schema, run_options, ctx_factory)
.build()
}
/// Shortcut: a server exposing a single HITL workflow whose input schema is
/// derived via `schemars`. Only available without the `governor` feature.
///
/// # Errors
///
/// Any `McpBuildError` reported by `McpServerBuilder::build`.
#[cfg(all(feature = "schemars", not(feature = "governor")))]
pub fn expose_workflow<A>(
    agent: A,
    system_prompt: impl Into<String>,
    run_options: klieo_core::runtime::RunOptions,
    hitl_client: Arc<klieo_hitl_client::HitlClient>,
    hitl_cfg: Arc<klieo_hitl::HitlConfig>,
    ctx_factory: AgentContextFactory,
) -> Result<Self, McpBuildError>
where
    A: Agent + 'static,
    A::Input: serde::de::DeserializeOwned + schemars::JsonSchema + Send + 'static,
{
    let builder = McpServerBuilder::new().with_hitl(hitl_client, hitl_cfg);
    builder
        .add_workflow(agent, system_prompt, run_options, ctx_factory)
        .build()
}
/// The invoker that answers tool calls; a merged invoker when several were registered.
pub fn invoker(&self) -> &std::sync::Arc<dyn ToolInvoker> {
&self.invoker
}
/// The resume buffer (a `NoopResumeBuffer` unless one was configured).
pub fn resume_buffer(&self) -> &std::sync::Arc<dyn klieo_core::resume::ResumeBuffer> {
&self.resume_buffer
}
/// The pubsub used for cancel signalling (an in-memory bus unless one was configured).
pub fn pubsub(&self) -> &std::sync::Arc<dyn klieo_core::Pubsub> {
&self.pubsub
}
/// The cancellation registry that the `build_arc` cancel subscriber feeds.
/// Keys are strings; presumably progress tokens (the subscriber is given the
/// `klieo.mcp.cancel.` prefix to strip) — TODO confirm.
pub fn cancel_registry(&self) -> &klieo_core::CancelRegistry<String> {
&self.cancel_registry
}
/// Publishes a cancel signal for `progress_token` under the `klieo.mcp.cancel.` prefix.
///
/// # Errors
///
/// Bus failures, converted via `From<BusError>`: `InvalidSubject` for an
/// invalid token, `Bus` otherwise.
pub async fn publish_cancel(&self, progress_token: &str) -> Result<(), McpServerError> {
    let subject_prefix = "klieo.mcp.cancel.";
    klieo_core::cancel::publish_cancel_signal(&self.pubsub, subject_prefix, progress_token)
        .await?;
    Ok(())
}
/// Builds a `ToolCtx` from the configured factory.
///
/// The progress sender and cancellation token are always attached. The caller
/// principal and parent anchor are attached only when provided.
#[cfg(feature = "http")]
pub(crate) fn tool_ctx_with_progress(
    &self,
    progress: tokio::sync::broadcast::Sender<klieo_core::AgentEvent>,
    cancel: tokio_util::sync::CancellationToken,
    caller_principal: Option<String>,
    parent_anchor: Option<String>,
) -> klieo_core::tool::ToolCtx {
    let base = (self.tool_ctx_factory)()
        .with_progress(progress)
        .with_cancel(cancel);
    let scoped = match caller_principal {
        Some(principal) => base.with_caller_principal(principal),
        None => base,
    };
    match parent_anchor {
        Some(anchor) => scoped.with_parent_anchor(anchor),
        None => scoped,
    }
}
/// Serves JSON-RPC over the process's stdin/stdout until stdin reaches EOF.
///
/// # Errors
///
/// See `serve_with_streams`.
pub async fn serve_stdio(self: Arc<Self>) -> Result<(), McpServerError> {
    let input = tokio::io::stdin();
    let output: outbound::SharedWriter = Arc::new(Mutex::new(tokio::io::stdout()));
    self.serve_with_streams(input, output).await
}
/// Serves newline-delimited JSON-RPC read from `reader`, writing replies to `writer`.
///
/// Each non-blank line is processed on its own spawned task. The function
/// returns at EOF without awaiting tasks that are still in flight.
///
/// # Errors
///
/// Returns the I/O error if reading a line fails. Per tokio's `Lines` docs,
/// that presumably includes non-UTF-8 input.
pub async fn serve_with_streams<R>(
    self: Arc<Self>,
    reader: R,
    writer: outbound::SharedWriter,
) -> Result<(), McpServerError>
where
    R: tokio::io::AsyncRead + Unpin,
{
    let sink = self.ensure_stdout_writer_with(writer).await;
    self.ensure_outbound_and_roots().await;
    let mut frames = BufReader::new(reader).lines();
    while let Some(frame) = frames.next_line().await? {
        if !frame.trim().is_empty() {
            self.dispatch_stdio_line(frame, sink.clone());
        }
    }
    Ok(())
}
/// Processes one inbound line on a detached task. Failures are logged at warn
/// level rather than propagated.
fn dispatch_stdio_line(self: &Arc<Self>, line: String, writer: outbound::SharedWriter) {
    let this = Arc::clone(self);
    let job = async move {
        let outcome = this.process_stdio_line(&line, &writer).await;
        if let Err(error) = outcome {
            warn!(error = ?error, "stdio dispatch task failed");
        }
    };
    tokio::spawn(job);
}
async fn ensure_stdout_writer_with(
&self,
writer: outbound::SharedWriter,
) -> outbound::SharedWriter {
let _ = self.stdout_writer.set(writer);
self.stdout_writer
.get()
.expect("stdout_writer populated above")
.clone()
}
async fn ensure_outbound_and_roots(&self) {
if !self.declare_sampling {
return;
}
let writer = self
.stdout_writer
.get()
.expect("serve_with_streams primes stdout_writer before ensure_outbound_and_roots")
.clone();
let session = self
.stdio_session
.get_or_init(|| async { std::sync::Arc::new(crate::session::Session::new_stdio()) })
.await
.clone();
let outbound = session
.outbound
.get_or_init(|| async {
let sink: Arc<dyn OutboundFrameSink> =
Arc::new(crate::outbound_sink::StdioFrameSink::new(writer.clone()));
Arc::new(crate::outbound::OutboundRequests::new(sink))
})
.await
.clone();
let _ = session
.roots_cache
.get_or_init(|| async {
let outbound: Arc<dyn klieo_core::ServerOutbound> = outbound.clone();
Arc::new(crate::roots::RootsCache::new(outbound))
})
.await;
}
/// Handle one inbound stdio line.
///
/// - Malformed JSON: logged, and a parse-error envelope with a `null` id is written.
/// - Request (has `method` and `id`): dispatched, and the reply envelope is written.
/// - Notification (has `method`, no `id`): dispatched; nothing is written.
/// - Reply to a server-initiated request (integer `id` plus `result`/`error`): routed to the
///   outbound table.
/// - Anything else: logged and dropped.
///
/// # Errors
///
/// Only failures from `write_frame` (serialising or writing a reply) are returned.
async fn process_stdio_line(
    &self,
    line: &str,
    writer: &outbound::SharedWriter,
) -> Result<(), McpServerError> {
    let frame: serde_json::Value = match serde_json::from_str(line) {
        Err(error) => {
            warn!(error = %error, "rejected malformed JSON-RPC frame");
            let reply = rpc_error(None, JSONRPC_PARSE_ERROR, "malformed JSON-RPC frame");
            return write_frame(writer, &reply).await;
        }
        Ok(value) => value,
    };
    let session = self.stdio_session.get();
    match classify_inbound(&frame) {
        InboundKind::Request => {
            let reply = self.handle_jsonrpc(frame, session).await;
            write_frame(writer, &reply).await
        }
        InboundKind::Notification => {
            // Notifications never get a reply on the wire; discard whatever was built.
            let _ = self.handle_jsonrpc(frame, session).await;
            Ok(())
        }
        InboundKind::OutboundResponse(rpc_id) => {
            self.route_outbound_response(rpc_id, frame).await;
            Ok(())
        }
        InboundKind::Unparseable => {
            warn!("rejected inbound frame: no method and no id");
            Ok(())
        }
    }
}
/// Hand a client reply (`frame`) to the pending outbound request with JSON-RPC id `id`.
///
/// If the stdio session has no outbound table (e.g. sampling is not declared, so
/// `ensure_outbound_and_roots` never wired one), the frame is logged and dropped.
async fn route_outbound_response(&self, id: i64, frame: serde_json::Value) {
    let table = self
        .stdio_session
        .get()
        .and_then(|session| session.outbound.get());
    match table {
        Some(outbound) => {
            outbound.complete_pending(id, frame).await;
        }
        None => {
            warn!(
                rpc_id = id,
                "outbound response received but server has no outbound table wired"
            );
        }
    }
}
/// Test helper: decode one raw line and run it through `handle_jsonrpc` with the stdio session.
///
/// Malformed JSON yields the sanitised parse-error envelope instead of a dispatch.
#[cfg(test)]
async fn handle_line(&self, line: &str) -> serde_json::Value {
    match serde_json::from_str::<serde_json::Value>(line) {
        Ok(req) => self.handle_jsonrpc(req, self.stdio_session.get()).await,
        Err(e) => {
            warn!(error = %e, "rejected malformed JSON-RPC frame");
            rpc_error(None, JSONRPC_PARSE_ERROR, "malformed JSON-RPC frame")
        }
    }
}
/// Dispatch one decoded JSON-RPC message and build its reply envelope.
///
/// `session` is the caller's session, if any. The notification handlers use it to find the
/// roots cache.
///
/// Recognised requests: `initialize`, `ping`, `shutdown`, `tools/list`, `tools/call`.
/// Recognised notifications: `notifications/initialized`, `notifications/roots/list_changed`.
///
/// The two notification arms return `serde_json::Value::Null` because there is nothing to
/// send back. Every other arm returns a JSON-RPC envelope echoing the message's `id` (JSON
/// `null` when absent). Unknown methods yield a `JSONRPC_METHOD_NOT_FOUND` error envelope.
pub(crate) async fn handle_jsonrpc(
    &self,
    req: serde_json::Value,
    session: Option<&std::sync::Arc<crate::session::Session>>,
) -> serde_json::Value {
    let id = req.get("id").cloned();
    // A missing or non-string `method` falls through to method-not-found with an empty name.
    let method = req.get("method").and_then(|m| m.as_str()).unwrap_or("");
    match method {
        "initialize" => rpc_ok(id, self.handle_initialize(&req).await),
        "notifications/initialized" => {
            self.handle_initialized_notification(session).await;
            serde_json::Value::Null
        }
        "notifications/roots/list_changed" => {
            self.handle_roots_list_changed_notification(session);
            serde_json::Value::Null
        }
        // MCP liveness check: the receiver must answer promptly with an empty result.
        "ping" => rpc_ok(id, serde_json::json!({})),
        "shutdown" => rpc_ok(id, serde_json::Value::Null),
        "tools/list" => rpc_ok(id, self.tools_list()),
        "tools/call" => match self.tools_call(req.get("params")).await {
            Ok(v) => rpc_ok(id, v),
            Err(e) => tool_error_to_envelope(id, e),
        },
        other => {
            warn!(rpc_id = ?id, method = other, "method not found");
            rpc_error(
                id,
                JSONRPC_METHOD_NOT_FOUND,
                &format!("method not found: {other}"),
            )
        }
    }
}
/// Handle `initialize`: record the client's roots capability and return the initialize result.
///
/// `roots_supported` is set when the request carries `params.capabilities.roots` (any value).
/// The returned result advertises `sampling` only when the server declares sampling.
async fn handle_initialize(&self, req: &serde_json::Value) -> serde_json::Value {
    let advertises_roots = req.pointer("/params/capabilities/roots").is_some();
    // The guard is a statement temporary, so the lock is released right after the store.
    self.client_caps.lock().await.roots_supported = advertises_roots;
    match self.declare_sampling {
        true => initialize_result_with_sampling(),
        false => initialize_result_without_sampling(),
    }
}
/// Handle `notifications/initialized`: kick off the first roots fetch when applicable.
///
/// Runs only when the client advertised roots during `initialize` and `session` has a roots
/// cache. The refresh runs on a spawned task, and its failure is logged.
async fn handle_initialized_notification(
    &self,
    session: Option<&std::sync::Arc<crate::session::Session>>,
) {
    // The guard is a temporary of the condition, so the lock is released before spawning.
    if !self.client_caps.lock().await.roots_supported {
        return;
    }
    if let Some(cache) = session.and_then(|s| s.roots_cache.get()).cloned() {
        tokio::spawn(async move {
            if let Err(error) = cache.refresh().await {
                warn!(error = ?error, "initial roots/list fetch failed");
            }
        });
    }
}
/// Handle `notifications/roots/list_changed`: refresh the roots cache in the background.
///
/// This is a no-op when `session` has no roots cache. Unlike the `initialized` handler, it
/// does not consult the recorded `roots_supported` flag. A refresh failure is logged.
fn handle_roots_list_changed_notification(
    &self,
    session: Option<&std::sync::Arc<crate::session::Session>>,
) {
    if let Some(cache) = session.and_then(|s| s.roots_cache.get()).cloned() {
        tokio::spawn(async move {
            if let Err(error) = cache.refresh().await {
                warn!(error = ?error, "roots list_changed re-fetch failed");
            }
        });
    }
}
fn tools_list(&self) -> serde_json::Value {
let tools: Vec<serde_json::Value> = self
.invoker
.catalogue()
.iter()
.map(tool_def_to_mcp_descriptor)
.collect();
serde_json::json!({ "tools": tools })
}
/// Handle `tools/call`: invoke the named tool and wrap its output as MCP text content.
///
/// A missing `arguments` member is passed to the invoker as JSON `null`. The tool context
/// comes from `tool_ctx_factory` and carries a child of the server's parent cancel token.
/// On success the result holds one `text` content item containing the output's JSON
/// serialisation (`Value::to_string`), so a string output appears JSON-quoted.
///
/// # Errors
///
/// Returns `ToolError::InvalidArgs` when `params` or `params.name` is missing. Otherwise it
/// returns whatever the invoker returns.
async fn tools_call(
    &self,
    params: Option<&serde_json::Value>,
) -> Result<serde_json::Value, ToolError> {
    let Some(params) = params else {
        return Err(ToolError::InvalidArgs("missing params".into()));
    };
    let Some(name) = params.get("name").and_then(serde_json::Value::as_str) else {
        return Err(ToolError::InvalidArgs("missing tool name".into()));
    };
    let args = params.get("arguments").cloned().unwrap_or_default();
    let base_ctx = (self.tool_ctx_factory)();
    let ctx = base_ctx.with_cancel(self.parent_cancel.child_token());
    let output = self.invoker.invoke(name, args, ctx).await?;
    let text = output.to_string();
    Ok(serde_json::json!({
        "content": [
            { "type": "text", "text": text }
        ]
    }))
}
}
struct AgentAsToolInvoker<A>
where
A: Agent + 'static,
A::Input: serde::de::DeserializeOwned + Send + 'static,
A::Output: serde::Serialize + Send + 'static,
{
agent: Arc<A>,
name: String,
input_schema: serde_json::Value,
ctx_factory: AgentContextFactory,
#[cfg(feature = "governor")]
governor: Option<crate::governor::GovernorBundle>,
}
#[async_trait]
impl<A> ToolInvoker for AgentAsToolInvoker<A>
where
A: Agent + 'static,
A::Input: serde::de::DeserializeOwned + Send + 'static,
A::Output: serde::Serialize + Send + 'static,
{
fn catalogue(&self) -> Vec<ToolDef> {
vec![ToolDef::new(
self.name.clone(),
format!("klieo agent: {}", self.name),
self.input_schema.clone(),
)]
}
async fn invoke(
&self,
name: &str,
args: serde_json::Value,
tool_ctx: ToolCtx,
) -> Result<serde_json::Value, ToolError> {
if name != self.name {
return Err(ToolError::UnknownTool(name.into()));
}
let input: A::Input = serde_json::from_value(args).map_err(|e| {
warn!(agent = %self.name, error = %e, "decode of MCP tools/call args failed");
ToolError::InvalidArgs("arguments do not match inputSchema".into())
})?;
let mut ctx = (self.ctx_factory)();
ctx.cancel = tool_ctx.cancel.child_token();
ctx.progress = tool_ctx.progress.clone();
if let Some(principal) = tool_ctx.caller_principal.as_ref() {
ctx = ctx.with_tenant_label(klieo_core::principal_hash(principal.as_str()));
}
if let Some(anchor) = tool_ctx.parent_anchor.as_ref() {
ctx = ctx.with_parent_anchor(anchor.as_str().to_string());
}
#[cfg(feature = "governor")]
if let Some(bundle) = self.governor.as_ref() {
ctx = crate::governor::wrap_ctx_with_governor(ctx, bundle);
}
let output = self.agent.run(ctx, input).await.map_err(|e| {
warn!(agent = %self.name, error = %e, "exposed agent execution failed");
ToolError::Permanent("agent execution failed".into())
})?;
serde_json::to_value(output).map_err(|e| {
warn!(agent = %self.name, error = %e, "encode of agent output failed");
ToolError::Permanent("agent output not serialisable".into())
})
}
}
/// A single `ToolInvoker` surface that routes each tool name to one of several inner invokers.
struct MergedInvoker {
    /// Tool name to index into `inner` of the invoker that advertised it.
    routes: std::collections::HashMap<String, usize>,
    /// The inner invokers, in the order they were passed to `new`.
    inner: Vec<Arc<dyn ToolInvoker>>,
    /// All inner catalogues concatenated in `inner` order, computed once at construction.
    merged_catalogue: Vec<ToolDef>,
}
impl MergedInvoker {
    /// Merge the catalogues of `inner` and build the name-to-invoker routing table.
    ///
    /// # Errors
    ///
    /// Returns `McpBuildError::DuplicateTool` with the first tool name seen twice. This
    /// covers names repeated across invokers and names repeated within one invoker's
    /// catalogue.
    fn new(inner: Vec<Arc<dyn ToolInvoker>>) -> Result<Self, McpBuildError> {
        use std::collections::hash_map::Entry;

        let mut routes = std::collections::HashMap::<String, usize>::new();
        let mut merged_catalogue = Vec::<ToolDef>::new();
        for (slot, invoker) in inner.iter().enumerate() {
            for tool in invoker.catalogue() {
                match routes.entry(tool.name.clone()) {
                    Entry::Occupied(_) => return Err(McpBuildError::DuplicateTool(tool.name)),
                    Entry::Vacant(vacant) => {
                        vacant.insert(slot);
                    }
                }
                merged_catalogue.push(tool);
            }
        }
        Ok(Self {
            routes,
            inner,
            merged_catalogue,
        })
    }
}
#[async_trait]
impl ToolInvoker for MergedInvoker {
    /// The concatenated catalogue of every inner invoker.
    fn catalogue(&self) -> Vec<ToolDef> {
        self.merged_catalogue.clone()
    }

    /// Forward the call to the invoker that advertised `name`; unknown names are rejected.
    async fn invoke(
        &self,
        name: &str,
        args: serde_json::Value,
        ctx: ToolCtx,
    ) -> Result<serde_json::Value, ToolError> {
        let Some(&slot) = self.routes.get(name) else {
            return Err(ToolError::UnknownTool(name.into()));
        };
        self.inner[slot].invoke(name, args, ctx).await
    }

    /// Delegates to the owning invoker; unknown tools are reported as not idempotent.
    fn is_tool_idempotent(&self, name: &str) -> bool {
        self.routes
            .get(name)
            .is_some_and(|&slot| self.inner[slot].is_tool_idempotent(name))
    }

    /// Delegates to the owning invoker; unknown tools are reported as not redacting audit.
    fn tool_redacts_audit(&self, name: &str) -> bool {
        self.routes
            .get(name)
            .is_some_and(|&slot| self.inner[slot].tool_redacts_audit(name))
    }
}
pub(crate) fn tool_error_to_envelope(
id: Option<serde_json::Value>,
e: ToolError,
) -> serde_json::Value {
let stable_msg = match &e {
ToolError::UnknownTool(name) => {
warn!(rpc_id = ?id, tool = %name, "tools/call: unknown tool");
format!("unknown tool: {name}")
}
ToolError::InvalidArgs(reason) => {
warn!(rpc_id = ?id, reason = %reason, "tools/call: invalid args");
reason.clone()
}
_ => {
warn!(rpc_id = ?id, error = %e, "tools/call failed");
"tool invocation failed".into()
}
};
rpc_error(id, JSONRPC_SERVER_ERROR, &stable_msg)
}
fn tool_def_to_mcp_descriptor(def: &ToolDef) -> serde_json::Value {
serde_json::json!({
"name": def.name,
"description": def.description,
"inputSchema": def.json_schema,
})
}
/// `initialize` result that advertises the `sampling` capability.
fn initialize_result_with_sampling() -> serde_json::Value {
    initialize_result_inner(true)
}
/// `initialize` result without the `sampling` capability.
fn initialize_result_without_sampling() -> serde_json::Value {
    initialize_result_inner(false)
}
/// Build the `initialize` result.
///
/// The result contains the protocol version, a capabilities object (always `tools`; also
/// `sampling` when `with_sampling`), and the server name and crate version.
fn initialize_result_inner(with_sampling: bool) -> serde_json::Value {
    let capabilities = if with_sampling {
        serde_json::json!({ "tools": {}, "sampling": {} })
    } else {
        serde_json::json!({ "tools": {} })
    };
    let server_info = serde_json::json!({
        "name": "klieo-mcp-server",
        "version": env!("CARGO_PKG_VERSION"),
    });
    serde_json::json!({
        "protocolVersion": MCP_PROTOCOL_VERSION,
        "capabilities": capabilities,
        "serverInfo": server_info,
    })
}
/// Build a JSON-RPC 2.0 success envelope carrying `result`.
///
/// A `None` id serialises as JSON `null`.
pub(crate) fn rpc_ok(
    id: Option<serde_json::Value>,
    result: serde_json::Value,
) -> serde_json::Value {
    serde_json::json!({
        "jsonrpc": "2.0",
        "id": id,
        "result": result,
    })
}
/// Shape of an inbound JSON-RPC frame, as decided by `classify_inbound`.
#[derive(Debug)]
enum InboundKind {
    /// Has `method` and `id`: expects a reply.
    Request,
    /// Has `method` but no `id`: fire-and-forget.
    Notification,
    /// No `method`, an integer `id`, and a `result` or `error`: the client's reply to a
    /// server-initiated request with that id.
    OutboundResponse(i64),
    /// None of the above; dropped by the caller.
    Unparseable,
}
/// Classify an inbound frame by which JSON-RPC members it carries.
///
/// Any frame with `method` is a request (if it also has `id`) or a notification. Without
/// `method`, it is an outbound response only when `id` is an integer and `result` or
/// `error` is present. Everything else, including non-object values, is unparseable.
fn classify_inbound(value: &serde_json::Value) -> InboundKind {
    let id = value.get("id");
    if value.get("method").is_some() {
        return match id {
            Some(_) => InboundKind::Request,
            None => InboundKind::Notification,
        };
    }
    let carries_outcome = value.get("result").is_some() || value.get("error").is_some();
    match id.and_then(serde_json::Value::as_i64) {
        Some(rpc_id) if carries_outcome => InboundKind::OutboundResponse(rpc_id),
        _ => InboundKind::Unparseable,
    }
}
/// Serialise `envelope` and write it to `writer` as one newline-delimited frame, then flush.
///
/// Serialisation happens before the writer lock is taken. The JSON body and its trailing
/// `\n` delimiter are submitted in a single `write_all` call rather than two. This halves
/// the write calls per frame (each write on tokio's stdout is handed to a blocking thread).
/// It also removes the window where the body has been written but its delimiter has not.
///
/// # Errors
///
/// Returns serialisation failures and I/O errors from the write or flush, converted into
/// `McpServerError`.
async fn write_frame(
    writer: &outbound::SharedWriter,
    envelope: &serde_json::Value,
) -> Result<(), McpServerError> {
    let mut frame = serde_json::to_vec(envelope)?;
    frame.push(b'\n');
    let mut guard = writer.lock().await;
    guard.write_all(&frame).await?;
    guard.flush().await?;
    Ok(())
}
pub(crate) fn rpc_error(
id: Option<serde_json::Value>,
code: i64,
message: &str,
) -> serde_json::Value {
serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"error": { "code": code, "message": message }
})
}
/// Build a `ToolCtx` backed by a fresh in-memory bus.
///
/// Hidden from the rendered docs; presumably a hook for tests, going by its name.
#[doc(hidden)]
pub fn __test_noop_ctx() -> klieo_core::tool::ToolCtx {
    noop_ctx()
}
/// Build a `ToolCtx` over the pub/sub, KV, and jobs handles of a new `MemoryBus`.
fn noop_ctx() -> ToolCtx {
    let memory = klieo_bus_memory::MemoryBus::new();
    let (pubsub, kv, jobs) = (memory.pubsub, memory.kv, memory.jobs);
    ToolCtx::new(pubsub, kv, jobs)
}
#[cfg(test)]
mod tests {
use super::*;
use async_trait::async_trait;
use klieo_core::tool::Tool;
use std::sync::OnceLock;
struct EmptyInvoker;
#[async_trait]
impl ToolInvoker for EmptyInvoker {
fn catalogue(&self) -> Vec<ToolDef> {
Vec::new()
}
async fn invoke(
&self,
name: &str,
_args: serde_json::Value,
_ctx: ToolCtx,
) -> Result<serde_json::Value, ToolError> {
Err(ToolError::UnknownTool(name.into()))
}
}
struct Echo;
#[async_trait]
impl Tool for Echo {
fn name(&self) -> &str {
"echo"
}
fn description(&self) -> &str {
"echoes back its args"
}
fn json_schema(&self) -> &serde_json::Value {
static S: OnceLock<serde_json::Value> = OnceLock::new();
S.get_or_init(|| serde_json::json!({"type": "object"}))
}
async fn invoke(
&self,
args: serde_json::Value,
_ctx: ToolCtx,
) -> Result<serde_json::Value, ToolError> {
Ok(args)
}
}
struct OneToolInvoker;
#[async_trait]
impl ToolInvoker for OneToolInvoker {
fn catalogue(&self) -> Vec<ToolDef> {
vec![ToolDef::new(
"echo",
"echoes back its args",
serde_json::json!({"type": "object"}),
)]
}
async fn invoke(
&self,
name: &str,
args: serde_json::Value,
ctx: ToolCtx,
) -> Result<serde_json::Value, ToolError> {
if name == "echo" {
Echo.invoke(args, ctx).await
} else {
Err(ToolError::UnknownTool(name.into()))
}
}
}
#[tokio::test]
async fn initialize_returns_server_info() {
let server = McpServer::expose_tools(Arc::new(OneToolInvoker));
let resp = server
.handle_line(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}"#)
.await;
let info = resp["result"]["serverInfo"]["name"].as_str().unwrap();
assert_eq!(info, "klieo-mcp-server");
}
#[tokio::test]
async fn tools_list_surfaces_invoker_catalogue() {
let server = McpServer::expose_tools(Arc::new(OneToolInvoker));
let resp = server
.handle_line(r#"{"jsonrpc":"2.0","id":2,"method":"tools/list"}"#)
.await;
let tools = resp["result"]["tools"].as_array().unwrap();
assert_eq!(tools.len(), 1);
assert_eq!(tools[0]["name"], "echo");
}
#[tokio::test]
async fn tools_call_dispatches_to_invoker() {
let server = McpServer::expose_tools(Arc::new(OneToolInvoker));
let resp = server
.handle_line(
r#"{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"echo","arguments":{"hello":"world"}}}"#,
)
.await;
let text = resp["result"]["content"][0]["text"].as_str().unwrap();
assert!(text.contains("hello"));
assert!(text.contains("world"));
}
#[tokio::test]
async fn unknown_method_returns_method_not_found() {
let server = McpServer::expose_tools(Arc::new(OneToolInvoker));
let resp = server
.handle_line(r#"{"jsonrpc":"2.0","id":4,"method":"nope"}"#)
.await;
assert_eq!(resp["error"]["code"], JSONRPC_METHOD_NOT_FOUND);
}
#[tokio::test]
async fn tools_call_without_params_returns_server_error() {
let server = McpServer::expose_tools(Arc::new(OneToolInvoker));
let resp = server
.handle_line(r#"{"jsonrpc":"2.0","id":5,"method":"tools/call"}"#)
.await;
assert_eq!(resp["error"]["code"], JSONRPC_SERVER_ERROR);
assert!(resp["error"]["message"]
.as_str()
.unwrap()
.contains("missing params"));
}
#[tokio::test]
async fn tools_call_unknown_tool_surfaces_invoker_error() {
let server = McpServer::expose_tools(Arc::new(OneToolInvoker));
let resp = server
.handle_line(
r#"{"jsonrpc":"2.0","id":6,"method":"tools/call","params":{"name":"does-not-exist","arguments":{}}}"#,
)
.await;
assert_eq!(resp["error"]["code"], JSONRPC_SERVER_ERROR);
assert!(resp["error"]["message"]
.as_str()
.unwrap()
.contains("does-not-exist"));
}
#[tokio::test]
async fn malformed_frame_returns_sanitised_parse_error() {
let server = McpServer::expose_tools(Arc::new(OneToolInvoker));
let resp = server.handle_line("not json").await;
assert_eq!(resp["error"]["code"], JSONRPC_PARSE_ERROR);
let msg = resp["error"]["message"].as_str().unwrap();
assert_eq!(msg, "malformed JSON-RPC frame");
}
#[tokio::test]
async fn handle_jsonrpc_dispatches_initialize() {
let server = McpServer::builder()
.add_tools(Arc::new(EmptyInvoker))
.build()
.unwrap();
let req = serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {}
});
let resp = server.handle_jsonrpc(req, None).await;
assert_eq!(resp["jsonrpc"], "2.0");
assert_eq!(resp["id"], 1);
assert!(resp["result"].is_object());
}
#[tokio::test]
async fn handle_jsonrpc_returns_method_not_found_for_unknown() {
let server = McpServer::builder()
.add_tools(Arc::new(EmptyInvoker))
.build()
.unwrap();
let req = serde_json::json!({
"jsonrpc": "2.0",
"id": 7,
"method": "no_such_method"
});
let resp = server.handle_jsonrpc(req, None).await;
assert_eq!(resp["error"]["code"], JSONRPC_METHOD_NOT_FOUND);
assert_eq!(resp["id"], 7);
}
#[tokio::test]
async fn classify_inbound_recognises_request_shape() {
let frame = serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/list"
});
assert!(matches!(classify_inbound(&frame), InboundKind::Request));
}
#[tokio::test]
async fn classify_inbound_recognises_notification_shape() {
let frame = serde_json::json!({
"jsonrpc": "2.0",
"method": "notifications/initialized"
});
assert!(matches!(
classify_inbound(&frame),
InboundKind::Notification
));
}
#[tokio::test]
async fn classify_inbound_recognises_outbound_result_shape() {
let frame = serde_json::json!({
"jsonrpc": "2.0",
"id": 42,
"result": {"role": "assistant"}
});
match classify_inbound(&frame) {
InboundKind::OutboundResponse(id) => assert_eq!(id, 42),
other => panic!("expected OutboundResponse(42), got {other:?}"),
}
}
#[tokio::test]
async fn classify_inbound_recognises_outbound_error_shape() {
let frame = serde_json::json!({
"jsonrpc": "2.0",
"id": 7,
"error": {"code": -32601, "message": "Method not found"}
});
match classify_inbound(&frame) {
InboundKind::OutboundResponse(id) => assert_eq!(id, 7),
other => panic!("expected OutboundResponse(7), got {other:?}"),
}
}
#[tokio::test]
async fn classify_inbound_rejects_no_method_no_id() {
let frame = serde_json::json!({"jsonrpc": "2.0"});
assert!(matches!(classify_inbound(&frame), InboundKind::Unparseable));
}
#[tokio::test]
async fn classify_inbound_rejects_bare_id_without_payload() {
let frame = serde_json::json!({"jsonrpc": "2.0", "id": 9});
assert!(matches!(classify_inbound(&frame), InboundKind::Unparseable));
}
type CapturedBytes = std::sync::Arc<std::sync::Mutex<Vec<u8>>>;
fn duplex_writer() -> (outbound::SharedWriter, CapturedBytes) {
let buffer: CapturedBytes = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
let shared: outbound::SharedWriter = Arc::new(Mutex::new(BufferSink(buffer.clone())));
(shared, buffer)
}
struct BufferSink(CapturedBytes);
impl tokio::io::AsyncWrite for BufferSink {
fn poll_write(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
self.0
.lock()
.expect("BufferSink mutex poisoned in test")
.extend_from_slice(buf);
std::task::Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::task::Poll::Ready(Ok(()))
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::task::Poll::Ready(Ok(()))
}
}
fn captured_bytes(buffer: &CapturedBytes) -> Vec<u8> {
buffer
.lock()
.expect("captured-bytes mutex poisoned in test")
.clone()
}
#[tokio::test]
async fn process_stdio_line_writes_response_for_request() {
let server = McpServer::expose_tools(Arc::new(OneToolInvoker));
let (writer, buffer) = duplex_writer();
let request = r#"{"jsonrpc":"2.0","id":11,"method":"tools/list"}"#;
server
.process_stdio_line(request, &writer)
.await
.expect("stdio dispatch must not fail");
let bytes = captured_bytes(&buffer);
assert!(bytes.ends_with(b"\n"), "frames are newline-delimited");
let envelope: serde_json::Value =
serde_json::from_slice(bytes.trim_ascii_end()).expect("written frame must be JSON");
assert_eq!(envelope["id"], 11);
assert!(envelope["result"]["tools"].is_array());
}
#[tokio::test]
async fn process_stdio_line_drops_outbound_response_when_table_absent() {
let server = McpServer::expose_tools(Arc::new(OneToolInvoker));
assert!(
server
.stdio_session
.get()
.and_then(|s| s.outbound.get())
.is_none(),
"default-built server must not wire an outbound table"
);
let (writer, buffer) = duplex_writer();
let stray = r#"{"jsonrpc":"2.0","id":99,"result":{"role":"assistant"}}"#;
server
.process_stdio_line(stray, &writer)
.await
.expect("stray response must not break the loop");
assert!(
captured_bytes(&buffer).is_empty(),
"stray outbound responses must never produce wire output"
);
}
#[tokio::test]
async fn process_stdio_line_drops_notification_without_writing() {
let server = McpServer::expose_tools(Arc::new(OneToolInvoker));
let (writer, buffer) = duplex_writer();
let notification = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
server
.process_stdio_line(notification, &writer)
.await
.expect("notification dispatch must not fail");
assert!(
captured_bytes(&buffer).is_empty(),
"notifications must not produce wire output"
);
}
#[tokio::test]
async fn process_stdio_line_drops_unparseable_frame() {
let server = McpServer::expose_tools(Arc::new(OneToolInvoker));
let (writer, buffer) = duplex_writer();
let unparseable = r#"{"jsonrpc":"2.0"}"#;
server
.process_stdio_line(unparseable, &writer)
.await
.expect("unparseable frame must not break the loop");
assert!(
captured_bytes(&buffer).is_empty(),
"unparseable frames must not produce wire output"
);
}
#[tokio::test]
async fn process_stdio_line_writes_parse_error_for_malformed_json() {
let server = McpServer::expose_tools(Arc::new(OneToolInvoker));
let (writer, buffer) = duplex_writer();
server
.process_stdio_line("not json", &writer)
.await
.expect("parse-error path must not fail the loop");
let bytes = captured_bytes(&buffer);
let envelope: serde_json::Value =
serde_json::from_slice(bytes.trim_ascii_end()).expect("parse-error envelope is JSON");
assert_eq!(envelope["error"]["code"], JSONRPC_PARSE_ERROR);
assert_eq!(envelope["error"]["message"], "malformed JSON-RPC frame");
}
#[tokio::test]
async fn initialize_arm_records_roots_capability_when_advertised() {
let server = McpServer::expose_tools(Arc::new(OneToolInvoker));
let req = serde_json::json!({
"jsonrpc": "2.0",
"id": 400,
"method": "initialize",
"params": { "capabilities": { "roots": {} } }
});
server.handle_jsonrpc(req, None).await;
assert!(
server.client_caps.lock().await.roots_supported,
"initialize must record advertised roots capability"
);
}
#[tokio::test]
async fn initialize_arm_defaults_roots_unsupported_when_absent() {
let server = McpServer::expose_tools(Arc::new(OneToolInvoker));
let req = serde_json::json!({
"jsonrpc": "2.0",
"id": 401,
"method": "initialize",
"params": { "capabilities": {} }
});
server.handle_jsonrpc(req, None).await;
assert!(
!server.client_caps.lock().await.roots_supported,
"initialize must leave roots_supported=false when absent"
);
}
#[tokio::test]
async fn initialize_result_includes_sampling_when_flag_set() {
let payload = super::initialize_result_with_sampling();
assert!(
payload["capabilities"]["sampling"].is_object(),
"initialize_result_with_sampling must surface capabilities.sampling; got: {payload}"
);
}
#[tokio::test]
async fn initialize_result_omits_sampling_when_flag_unset() {
let payload = super::initialize_result_without_sampling();
assert!(
payload["capabilities"].get("sampling").is_none(),
"initialize_result_without_sampling must omit capabilities.sampling; got: {payload}"
);
}
#[tokio::test]
async fn initialized_notification_returns_null_value() {
let server = McpServer::expose_tools(Arc::new(OneToolInvoker));
let req = serde_json::json!({
"jsonrpc": "2.0",
"method": "notifications/initialized"
});
let resp = server.handle_jsonrpc(req, None).await;
assert!(
resp.is_null(),
"notifications/initialized must yield a Null sentinel; got: {resp}"
);
}
#[tokio::test]
async fn list_changed_notification_returns_null_value() {
let server = McpServer::expose_tools(Arc::new(OneToolInvoker));
let req = serde_json::json!({
"jsonrpc": "2.0",
"method": "notifications/roots/list_changed"
});
let resp = server.handle_jsonrpc(req, None).await;
assert!(
resp.is_null(),
"notifications/roots/list_changed must yield a Null sentinel; got: {resp}"
);
}
#[tokio::test]
async fn list_changed_when_cache_absent_is_noop() {
let server = McpServer::expose_tools(Arc::new(OneToolInvoker));
assert!(
server
.stdio_session
.get()
.and_then(|s| s.roots_cache.get())
.is_none(),
"default-built server must not wire a roots cache"
);
let req = serde_json::json!({
"jsonrpc": "2.0",
"method": "notifications/roots/list_changed"
});
let resp = server.handle_jsonrpc(req, None).await;
assert!(
resp.is_null(),
"cache-absent list_changed must still yield Null; got: {resp}"
);
}
#[cfg(feature = "http")]
#[tokio::test]
async fn tool_ctx_with_progress_threads_cancel() {
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
let server = Arc::new(
McpServer::builder()
.add_tools(Arc::new(EmptyInvoker))
.build()
.unwrap(),
);
let (tx, _rx) = tokio::sync::broadcast::channel::<klieo_core::AgentEvent>(8);
let token = CancellationToken::new();
let ctx = server.tool_ctx_with_progress(tx, token.clone(), None, None);
token.cancel();
assert!(ctx.cancel.is_cancelled());
}
mod expose_agent_tests {
use super::*;
use async_trait::async_trait;
use klieo_core::agent::{Agent, AgentContext};
use klieo_core::error::Error as KlieoError;
use klieo_core::llm::ToolDef;
use klieo_core::test_utils::fake_context;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Deserialize, Serialize)]
struct GreetIn {
who: String,
}
#[derive(Debug, Clone, Serialize)]
struct GreetOut {
greeting: String,
}
struct Greeter;
#[async_trait]
impl Agent for Greeter {
type Input = GreetIn;
type Output = GreetOut;
type Error = KlieoError;
fn name(&self) -> &str {
"greeter"
}
fn system_prompt(&self) -> &str {
""
}
fn tools(&self) -> &[ToolDef] {
&[]
}
async fn run(
&self,
_ctx: AgentContext,
input: GreetIn,
) -> Result<GreetOut, KlieoError> {
Ok(GreetOut {
greeting: format!("hello {}", input.who),
})
}
}
#[derive(Debug, Clone, Serialize)]
struct CancelObserveOut {
state: String,
}
struct CancelObserver;
#[async_trait]
impl Agent for CancelObserver {
type Input = serde_json::Value;
type Output = CancelObserveOut;
type Error = KlieoError;
fn name(&self) -> &str {
"cancel-observer"
}
fn system_prompt(&self) -> &str {
""
}
fn tools(&self) -> &[ToolDef] {
&[]
}
async fn run(
&self,
ctx: AgentContext,
_input: serde_json::Value,
) -> Result<CancelObserveOut, KlieoError> {
let state = if ctx.cancel.is_cancelled() {
"cancelled".into()
} else {
"ran".into()
};
Ok(CancelObserveOut { state })
}
}
fn fresh_ctx() -> AgentContext {
fake_context("greeter")
}
fn one_object_schema() -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {"who": {"type": "string"}},
"required": ["who"]
})
}
#[tokio::test]
async fn expose_agent_with_schema_lists_agent_as_single_tool() {
let server = McpServer::expose_agent_with_schema(
Greeter,
one_object_schema(),
Arc::new(fresh_ctx),
);
let resp = server
.handle_line(r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#)
.await;
let tools = resp["result"]["tools"].as_array().unwrap();
assert_eq!(tools.len(), 1);
assert_eq!(tools[0]["name"], "greeter");
assert_eq!(tools[0]["inputSchema"]["type"], "object");
}
#[tokio::test]
async fn expose_agent_with_schema_dispatches_tools_call_through_agent() {
let server = McpServer::expose_agent_with_schema(
Greeter,
one_object_schema(),
Arc::new(fresh_ctx),
);
let resp = server
.handle_line(
r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"greeter","arguments":{"who":"world"}}}"#,
)
.await;
let text = resp["result"]["content"][0]["text"].as_str().unwrap();
assert!(
text.contains(r#""greeting":"hello world""#),
"tools/call must return serialised agent output; got: {text}"
);
}
#[tokio::test]
async fn expose_agent_with_schema_rejects_unknown_tool_name() {
let server = McpServer::expose_agent_with_schema(
Greeter,
one_object_schema(),
Arc::new(fresh_ctx),
);
let resp = server
.handle_line(
r#"{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"not-greeter","arguments":{}}}"#,
)
.await;
assert_eq!(resp["error"]["code"], JSONRPC_SERVER_ERROR);
assert!(resp["error"]["message"]
.as_str()
.unwrap()
.contains("not-greeter"));
}
#[tokio::test]
async fn expose_agent_with_schema_rejects_malformed_args() {
let server = McpServer::expose_agent_with_schema(
Greeter,
one_object_schema(),
Arc::new(fresh_ctx),
);
let resp = server
.handle_line(
r#"{"jsonrpc":"2.0","id":4,"method":"tools/call","params":{"name":"greeter","arguments":{}}}"#,
)
.await;
assert_eq!(resp["error"]["code"], JSONRPC_SERVER_ERROR);
let msg = resp["error"]["message"].as_str().unwrap();
assert!(
msg.contains("arguments do not match inputSchema"),
"wire message must be the sanitised string; got: {msg}"
);
assert!(
!msg.contains("GreetIn") && !msg.contains("missing field"),
"internal decode detail must not leak: {msg}"
);
}
#[tokio::test]
async fn expose_agent_sanitises_run_error_on_wire() {
struct Failing;
#[async_trait]
impl Agent for Failing {
type Input = serde_json::Value;
type Output = serde_json::Value;
type Error = KlieoError;
fn name(&self) -> &str {
"failing"
}
fn system_prompt(&self) -> &str {
""
}
fn tools(&self) -> &[ToolDef] {
&[]
}
async fn run(
&self,
_ctx: AgentContext,
_input: serde_json::Value,
) -> Result<serde_json::Value, KlieoError> {
Err(KlieoError::BadResponse(
"internal: token=secret-abc upstream=https://provider/url".into(),
))
}
}
let server = McpServer::expose_agent_with_schema(
Failing,
serde_json::json!({}),
Arc::new(fresh_ctx),
);
let resp = server
.handle_line(
r#"{"jsonrpc":"2.0","id":99,"method":"tools/call","params":{"name":"failing","arguments":{}}}"#,
)
.await;
assert_eq!(resp["error"]["code"], JSONRPC_SERVER_ERROR);
let msg = resp["error"]["message"].as_str().unwrap();
assert!(
msg.contains("tool invocation failed"),
"wire message must contain the sanitised stable string; got: {msg}"
);
assert!(
!msg.contains("secret-abc") && !msg.contains("https://"),
"internal error detail must not leak: {msg}"
);
}
#[tokio::test]
async fn builder_propagates_parent_cancel_into_ctx() {
let parent = CancellationToken::new();
let server = McpServer::builder()
.with_parent_cancel(parent.clone())
.add_agent_with_schema(CancelObserver, serde_json::json!({}), Arc::new(fresh_ctx))
.build()
.unwrap();
let resp = server
.handle_line(
r#"{"jsonrpc":"2.0","id":200,"method":"tools/call","params":{"name":"cancel-observer","arguments":{}}}"#,
)
.await;
let text = resp["result"]["content"][0]["text"].as_str().unwrap();
assert!(
text.contains(r#""state":"ran""#),
"live parent token must produce live ctx.cancel; got: {text}"
);
parent.cancel();
let resp = server
.handle_line(
r#"{"jsonrpc":"2.0","id":201,"method":"tools/call","params":{"name":"cancel-observer","arguments":{}}}"#,
)
.await;
let text = resp["result"]["content"][0]["text"].as_str().unwrap();
assert!(
text.contains(r#""state":"cancelled""#),
"cancelled parent must propagate into ctx.cancel via child_token; got: {text}"
);
}
#[cfg(feature = "http")]
#[tokio::test]
async fn tool_ctx_with_progress_cancel_cascades_into_agent_context() {
let request_cancel = CancellationToken::new();
let server = Arc::new(
McpServer::builder()
.add_agent_with_schema(
CancelObserver,
serde_json::json!({}),
Arc::new(fresh_ctx),
)
.build()
.unwrap(),
);
let (tx, _rx) = tokio::sync::broadcast::channel::<klieo_core::AgentEvent>(8);
request_cancel.cancel();
let tool_ctx = server.tool_ctx_with_progress(tx, request_cancel, None, None);
let result = server
.invoker
.invoke("cancel-observer", serde_json::json!({}), tool_ctx)
.await
.unwrap();
let text = result.to_string();
assert!(
text.contains(r#""state":"cancelled""#),
"cancelled request token must cascade into AgentContext.cancel; got: {text}"
);
}
#[tokio::test]
async fn shim_ctor_uses_default_uncancelled_parent_token() {
let server = McpServer::expose_agent_with_schema(
CancelObserver,
serde_json::json!({}),
Arc::new(fresh_ctx),
);
let resp = server
.handle_line(
r#"{"jsonrpc":"2.0","id":202,"method":"tools/call","params":{"name":"cancel-observer","arguments":{}}}"#,
)
.await;
let text = resp["result"]["content"][0]["text"].as_str().unwrap();
assert!(
text.contains(r#""state":"ran""#),
"shim ctor must default to a never-cancelled parent token; got: {text}"
);
}
/// A tool call that carries a caller principal must record exactly one
/// `RunAttributed` episode labelled with `principal_hash(PRINCIPAL)`. The raw
/// principal must not appear in any recorded episode or in short-term history.
#[tokio::test]
async fn agent_as_tool_invoker_installs_tenant_label_from_caller_principal() {
    use klieo_core::test_utils::{noop_bus, FakeLlmClient, FakeLlmStep};

    const PRINCIPAL: &str = "alice@example.com";

    // Agent that drives one `run_steps` loop on a fixed thread id, so the
    // runtime records episodes and history for the probes below.
    struct EchoLoopAgent;

    #[async_trait]
    impl Agent for EchoLoopAgent {
        type Input = serde_json::Value;
        type Output = serde_json::Value;
        type Error = KlieoError;

        fn name(&self) -> &str {
            "echo-loop"
        }

        fn system_prompt(&self) -> &str {
            ""
        }

        fn tools(&self) -> &[ToolDef] {
            &[]
        }

        async fn run(
            &self,
            ctx: AgentContext,
            _input: serde_json::Value,
        ) -> Result<serde_json::Value, KlieoError> {
            let thread = klieo_core::ids::ThreadId::new("echo-loop-thread");
            let options = klieo_core::runtime::RunOptions::default();
            let reply = klieo_core::runtime::run_steps(&ctx, "", thread, options).await?;
            Ok(serde_json::Value::String(reply))
        }
    }

    // Seed a context whose LLM answers once with plain text. Keep handles to
    // its stores so they can be inspected after the run.
    let mut seed = fake_context("echo-loop");
    let scripted_llm =
        FakeLlmClient::new("fake").with_steps(vec![FakeLlmStep::Text("done".into())]);
    seed.llm = Arc::new(scripted_llm);
    let episodic = seed.episodic.clone();
    let short_term = seed.short_term.clone();
    let run_id = seed.run_id;

    // Hand the seeded context out exactly once; a second call is a test bug.
    let seeded = Arc::new(std::sync::Mutex::new(Some(seed)));
    let ctx_factory: AgentContextFactory = Arc::new(move || {
        seeded
            .lock()
            .unwrap()
            .take()
            .expect("ctx_factory called more than once")
    });

    let server = McpServer::builder()
        .add_agent_with_schema(
            EchoLoopAgent,
            serde_json::json!({"type": "object"}),
            ctx_factory,
        )
        .build()
        .unwrap();

    let (pubsub, _, kv, jobs) = noop_bus();
    let tool_ctx =
        klieo_core::tool::ToolCtx::new(pubsub, kv, jobs).with_caller_principal(PRINCIPAL.into());
    let _ = server
        .invoker
        .invoke("echo-loop", serde_json::json!({}), tool_ctx)
        .await
        .unwrap();

    // Exactly one attribution episode, carrying the hashed principal.
    let expected = klieo_core::principal_hash(PRINCIPAL);
    let episodes = episodic.replay(run_id).await.unwrap();
    let mut labels: Vec<&str> = Vec::new();
    for episode in &episodes {
        if let klieo_core::Episode::RunAttributed { tenant_label } = episode {
            labels.push(tenant_label.as_str());
        }
    }
    assert_eq!(
        labels,
        vec![expected.as_str()],
        "exactly one RunAttributed carrying principal_hash; got {episodes:?}",
    );

    // The raw principal must not survive in any serialised episode...
    for payload in episodes
        .iter()
        .map(|episode| serde_json::to_string(episode).unwrap())
    {
        assert!(
            !payload.contains(PRINCIPAL),
            "raw principal leaked into recorded episode: {payload}",
        );
    }

    // ...nor in the conversation history of the thread the agent ran on.
    // NOTE(review): a failed load collapses to an empty history here, which
    // would make this probe pass vacuously — confirm that is intended.
    let history = short_term
        .load(klieo_core::ids::ThreadId::new("echo-loop-thread"), 8192)
        .await
        .unwrap_or_default();
    for message in &history {
        assert!(
            !message.content.contains(PRINCIPAL),
            "principal leaked into short-term memory: {}",
            message.content
        );
    }
}
/// A tool call that carries a parent anchor must record exactly one
/// `RunOrigin` episode with the anchor verbatim, alongside exactly one
/// `RunAttributed`. The anchor must not appear in short-term history.
#[tokio::test]
async fn agent_as_tool_invoker_records_run_origin_from_parent_anchor() {
    use klieo_core::test_utils::{noop_bus, FakeLlmClient, FakeLlmStep};

    const PRINCIPAL: &str = "alice@example.com";
    const ANCHOR: &str = "sha256:deadbeefcafe0123";

    // Agent that drives one `run_steps` loop on a fixed thread id.
    struct EchoLoopAgent;

    #[async_trait]
    impl Agent for EchoLoopAgent {
        type Input = serde_json::Value;
        type Output = serde_json::Value;
        type Error = KlieoError;

        fn name(&self) -> &str {
            "echo-origin"
        }

        fn system_prompt(&self) -> &str {
            ""
        }

        fn tools(&self) -> &[ToolDef] {
            &[]
        }

        async fn run(
            &self,
            ctx: AgentContext,
            _input: serde_json::Value,
        ) -> Result<serde_json::Value, KlieoError> {
            let thread = klieo_core::ids::ThreadId::new("echo-origin-thread");
            let options = klieo_core::runtime::RunOptions::default();
            let reply = klieo_core::runtime::run_steps(&ctx, "", thread, options).await?;
            Ok(serde_json::Value::String(reply))
        }
    }

    let mut seed = fake_context("echo-origin");
    let scripted_llm =
        FakeLlmClient::new("fake").with_steps(vec![FakeLlmStep::Text("done".into())]);
    seed.llm = Arc::new(scripted_llm);
    let episodic = seed.episodic.clone();
    let short_term = seed.short_term.clone();
    let run_id = seed.run_id;

    // Hand the seeded context out exactly once.
    let seeded = Arc::new(std::sync::Mutex::new(Some(seed)));
    let ctx_factory: AgentContextFactory = Arc::new(move || {
        seeded
            .lock()
            .unwrap()
            .take()
            .expect("ctx_factory called more than once")
    });

    let server = McpServer::builder()
        .add_agent_with_schema(
            EchoLoopAgent,
            serde_json::json!({"type": "object"}),
            ctx_factory,
        )
        .build()
        .unwrap();

    let (pubsub, _, kv, jobs) = noop_bus();
    let tool_ctx = klieo_core::tool::ToolCtx::new(pubsub, kv, jobs)
        .with_caller_principal(PRINCIPAL.into())
        .with_parent_anchor(ANCHOR.into());
    let _ = server
        .invoker
        .invoke("echo-origin", serde_json::json!({}), tool_ctx)
        .await
        .unwrap();

    let episodes = episodic.replay(run_id).await.unwrap();

    // Gather every recorded origin anchor and count the attribution episodes.
    let mut anchors: Vec<&str> = Vec::new();
    let mut attributed = 0usize;
    for episode in &episodes {
        match episode {
            klieo_core::Episode::RunOrigin { parent_anchor } => {
                anchors.push(parent_anchor.as_str());
            }
            klieo_core::Episode::RunAttributed { .. } => attributed += 1,
            _ => {}
        }
    }
    assert_eq!(
        anchors,
        vec![ANCHOR],
        "exactly one RunOrigin carrying the verbatim anchor; got {episodes:?}",
    );
    assert_eq!(
        attributed, 1,
        "RunOrigin co-emitted with exactly one RunAttributed; got {episodes:?}",
    );

    // NOTE(review): a failed load collapses to an empty history here, which
    // would make this probe pass vacuously — confirm that is intended.
    let history = short_term
        .load(klieo_core::ids::ThreadId::new("echo-origin-thread"), 8192)
        .await
        .unwrap_or_default();
    for message in &history {
        assert!(
            !message.content.contains(ANCHOR),
            "anchor leaked into short-term memory: {}",
            message.content
        );
    }
}
/// Building with nothing registered must fail with `NoInvokers`.
#[tokio::test]
async fn builder_with_no_invokers_returns_no_invokers_error() {
    let outcome = McpServer::builder().build();
    let is_no_invokers = matches!(outcome, Err(McpBuildError::NoInvokers));
    assert!(is_no_invokers);
}
/// `with_client_sampling()` must turn on the server's `declare_sampling` flag.
#[tokio::test]
async fn with_client_sampling_sets_capability_flag() {
    let built = McpServer::builder()
        .with_client_sampling()
        .add_agent_with_schema(Greeter, one_object_schema(), Arc::new(fresh_ctx))
        .build();
    let server = built.unwrap();
    let declared = server.declare_sampling;
    assert!(
        declared,
        "with_client_sampling() must set declare_sampling=true on the built server"
    );
}
/// Without `with_client_sampling()` the server must not declare sampling.
#[tokio::test]
async fn default_builder_does_not_declare_sampling() {
    let built = McpServer::builder()
        .add_agent_with_schema(Greeter, one_object_schema(), Arc::new(fresh_ctx))
        .build();
    let server = built.unwrap();
    let declared = server.declare_sampling;
    assert!(
        !declared,
        "default builder must leave declare_sampling=false"
    );
}
/// A default-built HTTP server uses a 300 s idle timeout and starts with an
/// empty session table.
#[cfg(feature = "http")]
#[tokio::test]
async fn default_session_idle_timeout_is_5min() {
    let five_minutes = std::time::Duration::from_secs(5 * 60);
    let server = McpServer::builder()
        .add_agent_with_schema(Greeter, one_object_schema(), Arc::new(fresh_ctx))
        .build()
        .unwrap();
    assert_eq!(
        server.session_idle_timeout, five_minutes,
        "default session idle timeout must be 5 minutes"
    );
    let no_sessions = server.sessions.read().await.is_empty();
    assert!(
        no_sessions,
        "HTTP server must hold zero sessions before any initialize POST"
    );
}
/// An explicit idle timeout replaces the default.
#[cfg(feature = "http")]
#[tokio::test]
async fn with_session_idle_timeout_overrides_default() {
    let requested = std::time::Duration::from_secs(42);
    let server = McpServer::builder()
        .add_agent_with_schema(Greeter, one_object_schema(), Arc::new(fresh_ctx))
        .with_session_idle_timeout(requested)
        .build()
        .unwrap();
    assert_eq!(
        server.session_idle_timeout, requested,
        "with_session_idle_timeout must override the default"
    );
}
/// A zero idle timeout is stored as-is rather than rejected or replaced.
#[cfg(feature = "http")]
#[tokio::test]
async fn zero_duration_records_disabled_watchdog() {
    let disabled = std::time::Duration::ZERO;
    let server = McpServer::builder()
        .add_agent_with_schema(Greeter, one_object_schema(), Arc::new(fresh_ctx))
        .with_session_idle_timeout(disabled)
        .build()
        .unwrap();
    assert_eq!(
        server.session_idle_timeout, disabled,
        "Duration::ZERO must thread through to record disabled-watchdog intent"
    );
}
/// Pins the default global session cap.
#[cfg(feature = "http")]
#[test]
fn default_max_sessions_is_1024() {
    let expected = 1_024;
    assert_eq!(DEFAULT_MAX_SESSIONS, expected);
}
/// An explicit global session cap replaces the default.
#[cfg(feature = "http")]
#[tokio::test]
async fn with_max_sessions_overrides_default() {
    let built = McpServer::builder()
        .add_agent_with_schema(Greeter, one_object_schema(), Arc::new(fresh_ctx))
        .with_max_sessions(64)
        .build();
    let server = built.unwrap();
    assert_eq!(
        server.max_sessions, 64,
        "with_max_sessions must override the default cap"
    );
}
/// A zero global session cap is rejected with a panic from the setter itself.
#[cfg(feature = "http")]
#[test]
#[should_panic(expected = "max_sessions must be > 0")]
fn with_max_sessions_panics_on_zero() {
    let builder = McpServer::builder();
    let _ = builder.with_max_sessions(0);
}
/// Pins the divisor used to derive the default per-principal session cap.
#[cfg(feature = "http")]
#[test]
fn default_divisor_is_sixteen() {
    let expected = 16;
    assert_eq!(DEFAULT_MAX_SESSIONS_PER_PRINCIPAL_DIVISOR, expected);
}
/// The per-principal cap scales with the global cap (checked at 1024 and 32).
#[cfg(feature = "http")]
#[test]
fn default_per_principal_derives_from_max_sessions() {
    for (global_cap, per_principal) in [(1024, 64), (32, 2)] {
        assert_eq!(default_max_sessions_per_principal(global_cap), per_principal);
    }
}
/// Small global caps (0, 15, 16) all yield a per-principal cap of 1.
#[cfg(feature = "http")]
#[test]
fn default_per_principal_floors_at_one() {
    for global_cap in [0, 15, 16] {
        assert_eq!(default_max_sessions_per_principal(global_cap), 1);
    }
}
/// An explicit per-principal cap replaces the derived default.
#[cfg(feature = "http")]
#[tokio::test]
async fn with_max_sessions_per_principal_overrides_default() {
    let built = McpServer::builder()
        .add_agent_with_schema(Greeter, one_object_schema(), Arc::new(fresh_ctx))
        .with_max_sessions(1024)
        .with_max_sessions_per_principal(8)
        .build();
    let server = built.unwrap();
    assert_eq!(
        server.max_sessions_per_principal, 8,
        "with_max_sessions_per_principal must override the default sub-cap"
    );
}
/// A zero per-principal cap is rejected with a panic from the setter itself.
#[cfg(feature = "http")]
#[test]
#[should_panic(expected = "max_sessions_per_principal must be > 0")]
fn with_max_sessions_per_principal_panics_on_zero() {
    let builder = McpServer::builder();
    let _ = builder.with_max_sessions_per_principal(0);
}
/// Pins the default SSE replay buffer capacity.
#[cfg(feature = "http")]
#[test]
fn default_sse_replay_capacity_is_256() {
    let expected = 256;
    assert_eq!(DEFAULT_SSE_REPLAY_CAPACITY, expected);
}
/// A non-zero replay capacity is stored and reported as enabled. A capacity of
/// zero is stored and reported as disabled.
#[cfg(feature = "http")]
#[tokio::test]
async fn with_sse_replay_capacity_overrides_default() {
    let enabled = McpServer::builder()
        .add_agent_with_schema(Greeter, one_object_schema(), Arc::new(fresh_ctx))
        .with_sse_replay_capacity(8)
        .build()
        .unwrap();
    assert_eq!(enabled.sse_replay_capacity, 8);
    assert!(enabled.sse_replay_enabled());

    let disabled = McpServer::builder()
        .add_agent_with_schema(Greeter, one_object_schema(), Arc::new(fresh_ctx))
        .with_sse_replay_capacity(0)
        .build()
        .unwrap();
    assert_eq!(disabled.sse_replay_capacity, 0);
    assert!(!disabled.sse_replay_enabled());
}
/// Two agents registered on one builder both appear in `tools/list`, and each
/// one can be called by name through `tools/call`.
#[tokio::test]
async fn builder_supports_multi_agent_dispatch() {
    let server = McpServer::builder()
        .add_agent_with_schema(Greeter, one_object_schema(), Arc::new(fresh_ctx))
        .add_agent_with_schema(CancelObserver, serde_json::json!({}), Arc::new(fresh_ctx))
        .build()
        .unwrap();

    let listing = server
        .handle_line(r#"{"jsonrpc":"2.0","id":300,"method":"tools/list"}"#)
        .await;
    let tools = listing["result"]["tools"].as_array().unwrap();
    let names: Vec<&str> = tools
        .iter()
        .map(|tool| tool["name"].as_str().unwrap())
        .collect();
    assert_eq!(tools.len(), 2);
    assert!(names.contains(&"greeter"));
    assert!(names.contains(&"cancel-observer"));

    let greeted = server
        .handle_line(
            r#"{"jsonrpc":"2.0","id":301,"method":"tools/call","params":{"name":"greeter","arguments":{"who":"multi"}}}"#,
        )
        .await;
    let greeted_text = greeted["result"]["content"][0]["text"].as_str().unwrap();
    assert!(greeted_text.contains(r#""greeting":"hello multi""#));

    let observed = server
        .handle_line(
            r#"{"jsonrpc":"2.0","id":302,"method":"tools/call","params":{"name":"cancel-observer","arguments":{}}}"#,
        )
        .await;
    let observed_text = observed["result"]["content"][0]["text"].as_str().unwrap();
    assert!(observed_text.contains(r#""state":"ran""#));
}
/// A parent token set on the builder, then cancelled, must be observed as
/// cancelled by an agent registered afterwards.
#[tokio::test]
async fn builder_parent_cancel_propagates_into_every_agent() {
    let shutdown = CancellationToken::new();
    // Order matters here: the parent token is set before agents are added.
    let server = McpServer::builder()
        .with_parent_cancel(shutdown.clone())
        .add_agent_with_schema(CancelObserver, serde_json::json!({}), Arc::new(fresh_ctx))
        .add_tools(Arc::new(super::OneToolInvoker))
        .build()
        .unwrap();
    shutdown.cancel();

    let reply = server
        .handle_line(
            r#"{"jsonrpc":"2.0","id":303,"method":"tools/call","params":{"name":"cancel-observer","arguments":{}}}"#,
        )
        .await;
    let body = reply["result"]["content"][0]["text"].as_str().unwrap();
    assert!(
        body.contains(r#""state":"cancelled""#),
        "builder-level parent_cancel must reach every add_agent_* invoker; got: {}",
        body
    );
}
#[tokio::test]
async fn builder_multi_agent_unknown_tool_returns_error() {
let server = McpServer::builder()
.add_agent_with_schema(Greeter, one_object_schema(), Arc::new(fresh_ctx))
.add_agent_with_schema(CancelObserver, serde_json::json!({}), Arc::new(fresh_ctx))
.build()
.unwrap();
let resp = server
.handle_line(
r#"{"jsonrpc":"2.0","id":304,"method":"tools/call","params":{"name":"no-such-agent","arguments":{}}}"#,
)
.await;
assert_eq!(resp["error"]["code"], JSONRPC_SERVER_ERROR);
assert!(
resp["error"]["message"]
.as_str()
.unwrap()
.contains("no-such-agent"),
"UnknownTool error must reference the requested name"
);
}
/// Registering the same agent twice must be rejected by `build()` with
/// `DuplicateTool`, and the error must name the tool that collided.
#[tokio::test]
async fn builder_build_returns_duplicate_tool_error() {
    let result = McpServer::builder()
        .add_agent_with_schema(Greeter, one_object_schema(), Arc::new(fresh_ctx))
        .add_agent_with_schema(Greeter, one_object_schema(), Arc::new(fresh_ctx))
        .build();
    // Match exhaustively so an unexpected outcome is reported, not discarded.
    // The equality check against "greeter" already covers the old non-empty
    // check, which was redundant.
    match &result {
        Err(McpBuildError::DuplicateTool(name)) => assert_eq!(
            name, "greeter",
            "DuplicateTool must name the colliding tool"
        ),
        Err(other) => panic!("expected DuplicateTool error, got {other:?}"),
        Ok(_) => panic!("expected DuplicateTool error, but build() succeeded"),
    }
}
/// After merging, `MergedInvoker::tool_redacts_audit` must report each tool's
/// own redaction flag. A name that no invoker owns reports `false`.
#[test]
fn merged_invoker_forwards_tool_redacts_audit_per_owner() {
    use klieo_core::test_utils::FakeToolInvoker;

    let pii_owner: Arc<dyn ToolInvoker> = Arc::new(
        FakeToolInvoker::new().with_redacting_tool("claimant_lookup", "handles PII", Ok),
    );
    let plain_owner: Arc<dyn ToolInvoker> =
        Arc::new(FakeToolInvoker::new().with_tool("echo", "plain", Ok));
    let merged = MergedInvoker::new(vec![pii_owner, plain_owner])
        .expect("distinct tool names must merge without DuplicateTool");

    // (tool name, expected redaction flag, failure explanation)
    let cases: [(&str, bool, &str); 3] = [
        (
            "claimant_lookup",
            true,
            "a PII-flagged tool's redaction must survive the merge; \
             default-false here would record raw PII (fail-open)",
        ),
        (
            "echo",
            false,
            "an unflagged tool must not be reported as redacting",
        ),
        (
            "no-such-tool",
            false,
            "an unrouted name hits the None arm and must default to false",
        ),
    ];
    for (tool, expect_redacted, why) in cases {
        assert!(merged.tool_redacts_audit(tool) == expect_redacted, "{}", why);
    }
}
/// Calling `build()` on a builder with a cancel subscription must fail with
/// `CancelRequiresArc`.
#[test]
fn mcp_builder_build_returns_cancel_requires_arc_error() {
    let builder = McpServer::builder()
        .add_tools(Arc::new(OneToolInvoker))
        .with_cancel_subscription();
    let result = builder.build();
    let rejected = matches!(result, Err(McpBuildError::CancelRequiresArc));
    assert!(
        rejected,
        "build() with cancel subscription must return CancelRequiresArc",
    );
}
#[tokio::test]
async fn expose_agent_sanitises_encode_error_on_wire() {
struct NonSerialisable;
impl serde::Serialize for NonSerialisable {
fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
Err(serde::ser::Error::custom(
"internal: token=secret-encode-abc upstream=https://provider/encode",
))
}
}
struct EncodeFailing;
#[async_trait]
impl Agent for EncodeFailing {
type Input = serde_json::Value;
type Output = NonSerialisable;
type Error = KlieoError;
fn name(&self) -> &str {
"encode-failing"
}
fn system_prompt(&self) -> &str {
""
}
fn tools(&self) -> &[ToolDef] {
&[]
}
async fn run(
&self,
_ctx: AgentContext,
_input: serde_json::Value,
) -> Result<NonSerialisable, KlieoError> {
Ok(NonSerialisable)
}
}
let server = McpServer::expose_agent_with_schema(
EncodeFailing,
serde_json::json!({}),
Arc::new(fresh_ctx),
);
let resp = server
.handle_line(
r#"{"jsonrpc":"2.0","id":100,"method":"tools/call","params":{"name":"encode-failing","arguments":{}}}"#,
)
.await;
assert_eq!(resp["error"]["code"], JSONRPC_SERVER_ERROR);
let msg = resp["error"]["message"].as_str().unwrap();
assert!(
msg.contains("tool invocation failed"),
"wire message must contain the sanitised stable string; got: {msg}"
);
assert!(
!msg.contains("secret-encode-abc") && !msg.contains("https://"),
"internal encode-error detail must not leak: {msg}"
);
}
/// The configured `ToolCtxFactory` is called once for each `tools/call`
/// request: two requests, two calls.
#[tokio::test]
async fn tool_ctx_factory_invoked_per_request() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    let calls = Arc::new(AtomicUsize::new(0));
    let factory: ToolCtxFactory = {
        let calls = Arc::clone(&calls);
        Arc::new(move || {
            calls.fetch_add(1, Ordering::SeqCst);
            default_tool_ctx_factory()()
        })
    };
    let server = McpServer::builder()
        .with_tool_ctx_factory(factory)
        .add_tools(Arc::new(super::OneToolInvoker))
        .build()
        .unwrap();

    let request = serde_json::json!({
        "jsonrpc": "2.0", "id": 1, "method": "tools/call",
        "params": { "name": "echo", "arguments": {"x": 1} }
    });
    for _ in 0..2 {
        server.handle_jsonrpc(request.clone(), None).await;
    }
    assert_eq!(calls.load(Ordering::SeqCst), 2);
}
/// If the per-request `ToolCtx` has a progress sender, the `AgentContext`
/// given to `Agent::run` must also have one.
#[tokio::test]
async fn agent_as_tool_invoker_propagates_progress_to_agent_context() {
    use klieo_core::AgentEvent;
    use std::sync::Mutex;
    use tokio::sync::broadcast;

    // Outer `Option` records whether `run` executed at all. The inner one is
    // the `ctx.progress` value `run` observed.
    type ProgressSlot = Arc<Mutex<Option<Option<broadcast::Sender<AgentEvent>>>>>;

    struct CapturingAgent {
        captured: ProgressSlot,
    }

    #[async_trait]
    impl Agent for CapturingAgent {
        type Input = serde_json::Value;
        type Output = serde_json::Value;
        type Error = KlieoError;

        fn name(&self) -> &str {
            "capturing"
        }

        fn system_prompt(&self) -> &str {
            ""
        }

        fn tools(&self) -> &[ToolDef] {
            &[]
        }

        async fn run(
            &self,
            ctx: AgentContext,
            _input: serde_json::Value,
        ) -> Result<serde_json::Value, KlieoError> {
            let observed = ctx.progress.clone();
            *self.captured.lock().unwrap() = Some(observed);
            Ok(serde_json::json!({}))
        }
    }

    let captured: ProgressSlot = Arc::new(Mutex::new(None));
    let agent = CapturingAgent {
        captured: captured.clone(),
    };

    // Keep the receiver bound so the broadcast channel stays open.
    let (progress_tx, _progress_rx) = broadcast::channel::<AgentEvent>(16);
    let factory_tx = progress_tx.clone();
    let factory: ToolCtxFactory = Arc::new(move || {
        let bus = klieo_bus_memory::MemoryBus::new();
        klieo_core::tool::ToolCtx::new(bus.pubsub, bus.kv, bus.jobs)
            .with_progress(factory_tx.clone())
    });

    let server = McpServer::builder()
        .with_tool_ctx_factory(factory)
        .add_agent_with_schema(agent, serde_json::json!({}), Arc::new(fresh_ctx))
        .build()
        .unwrap();

    let request = serde_json::json!({
        "jsonrpc": "2.0", "id": 1, "method": "tools/call",
        "params": { "name": "capturing", "arguments": {} }
    });
    let reply = server.handle_jsonrpc(request, None).await;
    assert!(
        reply["result"].is_object(),
        "tools/call must succeed; got: {}",
        reply
    );

    let observed_progress = captured
        .lock()
        .unwrap()
        .clone()
        .expect("Agent::run was never invoked");
    assert!(
        observed_progress.is_some(),
        "AgentContext.progress was None despite ToolCtx.progress=Some"
    );
}
#[cfg(feature = "schemars")]
mod auto_derive {
    // Exposing an agent without an explicit schema must publish an
    // `inputSchema` generated by schemars from the agent's input type.
    use super::*;
    use schemars::JsonSchema;

    /// Input type whose JSON schema is derived by schemars.
    #[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
    struct DerivedIn {
        who: String,
    }

    /// Output type returned by `DerivedGreeter`.
    #[derive(Debug, Clone, Serialize)]
    struct DerivedOut {
        greeting: String,
    }

    /// Agent registered through `expose_agent`, which takes no schema argument.
    struct DerivedGreeter;

    #[async_trait]
    impl Agent for DerivedGreeter {
        type Input = DerivedIn;
        type Output = DerivedOut;
        type Error = KlieoError;

        fn name(&self) -> &str {
            "derived-greeter"
        }

        fn system_prompt(&self) -> &str {
            ""
        }

        fn tools(&self) -> &[ToolDef] {
            &[]
        }

        async fn run(
            &self,
            _ctx: AgentContext,
            input: DerivedIn,
        ) -> Result<DerivedOut, KlieoError> {
            let greeting = format!("hi {}", input.who);
            Ok(DerivedOut { greeting })
        }
    }

    #[tokio::test]
    async fn expose_agent_auto_derives_schema_via_schemars() {
        let server = McpServer::expose_agent(DerivedGreeter, Arc::new(fresh_ctx));
        let listing = server
            .handle_line(r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#)
            .await;
        let schema = &listing["result"]["tools"][0]["inputSchema"];
        let has_who = schema["properties"]["who"].is_object();
        assert!(
            has_who,
            "derived schema must include the `who` field; got: {}",
            schema
        );
    }
}
}
#[cfg(not(feature = "governor"))]
mod expose_workflow_tests {
use super::*;
use async_trait::async_trait;
use chrono::Utc;
use klieo_core::agent::{Agent, AgentContext};
use klieo_core::error::Error as KlieoError;
use klieo_core::llm::Message;
use klieo_core::runtime::{ReviewPolicy, RunOptions};
use klieo_core::test_utils::{fake_context, fake_kv, FakeLlmClient, FakeLlmStep};
use klieo_core::ToolDef;
use klieo_hitl::HitlConfig;
use klieo_hitl_client::HitlClient;
use secrecy::SecretString;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::time::Duration;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Bucket name passed to both `HitlConfig::new` and
/// `RunOptions::with_checkpoint_bucket` in these tests.
const CHECKPOINT_BUCKET: &str = "klieo.run-checkpoints";
/// Workspace id used for the HITL config and the mocked HITL item bodies.
const WORKSPACE_ID: &str = "ws-test";
/// Text the fake LLM emits. Suspend-path tests assert it never reaches the
/// peer, so any leak of conversation/checkpoint content is detectable.
const PLANTED_SENTINEL: &str = "PLANT-SENTINEL-9F7C";
/// Tool input for the workflow tests; matches `input_schema()`.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct WorkflowIn {
    // Only deserialised, never read by the test agent.
    #[allow(dead_code)]
    payload: String,
}

/// Output type for `WorkflowAgent`. The workflow path never produces it.
#[derive(Clone, Debug, Serialize)]
struct UnusedOut;

/// Agent whose only purpose is to carry a tool name into the workflow path.
struct WorkflowAgent {
    name: &'static str,
}
#[async_trait]
impl Agent for WorkflowAgent {
    type Input = WorkflowIn;
    type Output = UnusedOut;
    type Error = KlieoError;

    fn name(&self) -> &str {
        self.name
    }

    fn system_prompt(&self) -> &str {
        ""
    }

    fn tools(&self) -> &[ToolDef] {
        &[]
    }

    /// Always errors: the exposed-workflow path is expected to drive the
    /// runtime itself rather than call `Agent::run`.
    async fn run(
        &self,
        _ctx: AgentContext,
        _input: WorkflowIn,
    ) -> Result<UnusedOut, KlieoError> {
        let reason = "workflow path must not call Agent::run";
        Err(KlieoError::BadResponse(reason.into()))
    }
}
/// Review policy that asks for a pause the first time it is consulted. The
/// flag records whether that pause has already been requested.
struct PauseOnce(std::sync::atomic::AtomicBool);

impl PauseOnce {
    /// Creates a policy that has not paused yet.
    fn new() -> Self {
        let not_yet_paused = std::sync::atomic::AtomicBool::default();
        Self(not_yet_paused)
    }
}
#[async_trait]
impl ReviewPolicy for PauseOnce {
    /// Returns a pause reason on the first call and `None` on every later call.
    /// The reason text is a sentinel that tests assert never reaches the peer.
    async fn should_pause_for_approval(
        &self,
        _step: u32,
        _message: &Message,
    ) -> Result<Option<String>, KlieoError> {
        let already_paused = self.0.swap(true, std::sync::atomic::Ordering::SeqCst);
        let reason = if already_paused {
            None
        } else {
            Some(String::from("policy reason that MUST NOT leak to peer"))
        };
        Ok(reason)
    }
}
/// Builds a fake agent context whose LLM replays `steps` and whose KV store is
/// a fresh fake.
fn workflow_ctx_with(steps: Vec<FakeLlmStep>) -> AgentContext {
    let mut context = fake_context("workflow-test");
    let scripted = FakeLlmClient::new("fake").with_steps(steps);
    context.llm = Arc::new(scripted);
    context.kv = fake_kv();
    context
}
/// JSON body for a mocked HITL item with the given id and state, version 1,
/// in the test workspace.
fn item_json(id: &str, state: &str) -> serde_json::Value {
    let decision_context = json!({"subject_ref":"x","run_id":"r","payload_hash_hex":"h"});
    json!({
        "id": id,
        "workspace_id": WORKSPACE_ID,
        "state": state,
        "version": 1,
        "escalation_count": 0,
        "decision_context": decision_context,
        "reviewer": null,
        "updated_at": "2026-06-18T00:00:00Z"
    })
}
/// HITL config for the test workspace and checkpoint bucket, with the given
/// poll timeout.
fn hitl_cfg(poll_timeout: Duration) -> HitlConfig {
    // NOTE(review): presumably the poll interval — confirm against
    // `HitlConfig::new`'s third parameter.
    let interval = Duration::from_millis(1);
    HitlConfig::new(WORKSPACE_ID, CHECKPOINT_BUCKET, interval, poll_timeout)
}
/// Run options that pause once for review (via `PauseOnce`) and checkpoint
/// into `CHECKPOINT_BUCKET`.
fn gated_run_options() -> RunOptions {
    let policy = Arc::new(PauseOnce::new());
    let options = RunOptions::default().with_review_policy(policy);
    options.with_checkpoint_bucket(CHECKPOINT_BUCKET)
}

/// Default run options: no review policy and no checkpoint bucket.
fn plain_run_options() -> RunOptions {
    <RunOptions as Default>::default()
}
/// Wraps a single pre-built context in a factory. The first call returns it;
/// any later call panics.
fn one_shot_ctx_factory(ctx: AgentContext) -> AgentContextFactory {
    // The closure owns the mutex directly; no shared handle is needed.
    let cell = std::sync::Mutex::new(Some(ctx));
    Arc::new(move || {
        let mut guard = cell.lock().unwrap();
        guard.take().expect("ctx_factory called more than once")
    })
}
/// JSON schema for `WorkflowIn`: an object with a required string `payload`.
fn input_schema() -> serde_json::Value {
    let properties = json!({"payload": {"type": "string"}});
    json!({
        "type": "object",
        "properties": properties,
        "required": ["payload"]
    })
}
/// A workflow with no review policy returns the LLM text directly, and no
/// HITL item is ever submitted (the POST mock expects zero calls).
#[tokio::test]
async fn invoke_happy_path_returns_text_without_hitl_traffic() {
    let hitl = MockServer::start().await;
    let forbid_submit = Mock::given(method("POST"))
        .and(path("/api/v1/hitl/items"))
        .respond_with(ResponseTemplate::new(500))
        .expect(0);
    forbid_submit.mount(&hitl).await;

    let ctx = workflow_ctx_with(vec![FakeLlmStep::Text("workflow done".into())]);
    let token = SecretString::from(String::from("tok"));
    let client = Arc::new(HitlClient::new(hitl.uri(), token));
    let cfg = Arc::new(hitl_cfg(Duration::from_secs(1)));

    let server = McpServer::expose_workflow_with_schema(
        WorkflowAgent { name: "wf-happy" },
        "you are a workflow",
        input_schema(),
        plain_run_options(),
        client,
        cfg,
        one_shot_ctx_factory(ctx),
    )
    .unwrap();

    let reply = server
        .handle_line(
            r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"wf-happy","arguments":{"payload":"hi"}}}"#,
        )
        .await;
    let body = reply["result"]["content"][0]["text"].as_str().unwrap();
    assert!(
        body.contains("workflow done"),
        "tools/call must return run_with_hitl text body; got: {}",
        body
    );
}
/// A gated run whose HITL item stays `awaiting` returns a `suspended`
/// envelope carrying only the fixed wire reason. Neither the policy's own
/// reason nor the planted LLM text may reach the peer.
#[tokio::test]
async fn invoke_suspend_path_redacts_reason_and_drops_checkpoint() {
    let hitl = MockServer::start().await;
    // Submission creates an item in state `awaiting`...
    Mock::given(method("POST"))
        .and(path("/api/v1/hitl/items"))
        .respond_with(
            ResponseTemplate::new(201).set_body_json(item_json("item-suspend", "awaiting")),
        )
        .mount(&hitl)
        .await;
    // ...and every lookup of that item still reports `awaiting`.
    Mock::given(method("GET"))
        .and(path("/api/v1/hitl/items/item-suspend"))
        .respond_with(
            ResponseTemplate::new(200).set_body_json(item_json("item-suspend", "awaiting")),
        )
        .mount(&hitl)
        .await;

    let ctx = workflow_ctx_with(vec![FakeLlmStep::Text(PLANTED_SENTINEL.into())]);
    let token = SecretString::from(String::from("tok"));
    let client = Arc::new(HitlClient::new(hitl.uri(), token));
    let cfg = Arc::new(hitl_cfg(Duration::from_millis(5)));

    let server = McpServer::expose_workflow_with_schema(
        WorkflowAgent {
            name: "wf-suspend",
        },
        "you are a suspending workflow",
        input_schema(),
        gated_run_options(),
        client,
        cfg,
        one_shot_ctx_factory(ctx),
    )
    .unwrap();

    let reply = server
        .handle_line(
            r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"wf-suspend","arguments":{"payload":"please-approve"}}}"#,
        )
        .await;
    let body = reply["result"]["content"][0]["text"].as_str().unwrap();
    assert!(
        body.contains(r#""status":"suspended""#),
        "suspend response must carry status=suspended; got: {}",
        body
    );
    assert!(
        body.contains("workflow suspended for human review"),
        "suspend response must carry the safe wire reason; got: {}",
        body
    );
    assert!(
        !body.contains("policy reason that MUST NOT leak"),
        "raw ReviewPolicy reason leaked to peer: {}",
        body
    );
    assert!(
        !body.contains(PLANTED_SENTINEL),
        "checkpoint/conversation bytes leaked to peer: {}",
        body
    );
}
/// A HITL submit rejected with 403 surfaces as a JSON-RPC error whose message
/// is the fixed sanitised text. The upstream response body must not leak.
#[tokio::test]
async fn invoke_hitl_submit_failure_maps_to_sanitised_tool_error() {
    let hitl = MockServer::start().await;
    let rejection = ResponseTemplate::new(403).set_body_string("forbidden: token=xyz");
    Mock::given(method("POST"))
        .and(path("/api/v1/hitl/items"))
        .respond_with(rejection)
        .mount(&hitl)
        .await;

    let ctx = workflow_ctx_with(vec![FakeLlmStep::Text("never reached".into())]);
    let token = SecretString::from(String::from("tok"));
    let client = Arc::new(HitlClient::new(hitl.uri(), token));
    let cfg = Arc::new(hitl_cfg(Duration::from_secs(1)));

    let server = McpServer::expose_workflow_with_schema(
        WorkflowAgent { name: "wf-err" },
        "",
        input_schema(),
        gated_run_options(),
        client,
        cfg,
        one_shot_ctx_factory(ctx),
    )
    .unwrap();

    let reply = server
        .handle_line(
            r#"{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"wf-err","arguments":{"payload":"x"}}}"#,
        )
        .await;
    assert!(
        reply.get("error").is_some(),
        "submit failure must surface as JSON-RPC error; got: {}",
        reply
    );
    let wire_message = reply["error"]["message"].as_str().unwrap();
    assert!(
        wire_message.contains("tool invocation failed"),
        "wire message must be the sanitised stable string; got: {}",
        wire_message
    );
    let leaked = wire_message.contains("token=xyz") || wire_message.contains("forbidden");
    assert!(
        !leaked,
        "internal HitlClientError detail leaked: {}",
        wire_message
    );
}
/// A workflow call that carries a caller principal must record exactly one
/// `RunAttributed` episode with the hashed principal. The raw principal must
/// not appear in episodes or short-term history.
#[tokio::test]
async fn caller_principal_does_not_enter_run_state() {
    use klieo_core::test_utils::noop_bus;

    const PRINCIPAL: &str = "alice-NEVER-IN-MEMORY@x";

    let hitl = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/api/v1/hitl/items"))
        .respond_with(ResponseTemplate::new(500))
        .expect(0)
        .mount(&hitl)
        .await;

    let ctx = workflow_ctx_with(vec![FakeLlmStep::Text("done".into())]);
    let short_term = ctx.short_term.clone();
    let episodic = ctx.episodic.clone();
    let run_id = ctx.run_id;

    let token = SecretString::from(String::from("tok"));
    let client = Arc::new(HitlClient::new(hitl.uri(), token));
    let cfg = Arc::new(hitl_cfg(Duration::from_secs(1)));
    let built = McpServer::builder()
        .with_hitl(client, cfg)
        .add_workflow_with_schema(
            WorkflowAgent { name: "wf-no-leak" },
            "you are a workflow",
            input_schema(),
            plain_run_options(),
            one_shot_ctx_factory(ctx),
        )
        .build()
        .unwrap();
    let server = Arc::new(built);

    let (pubsub, _, kv, jobs) = noop_bus();
    let tool_ctx =
        klieo_core::tool::ToolCtx::new(pubsub, kv, jobs).with_caller_principal(PRINCIPAL.into());
    let _result = server
        .invoker
        .invoke("wf-no-leak", json!({"payload": "hi"}), tool_ctx)
        .await
        .unwrap();

    // NOTE(review): "wf-no-leak:any" may not be the thread id the workflow
    // actually used. Combined with `unwrap_or_default`, this probe can pass
    // vacuously — confirm the real thread id.
    let history = short_term
        .load(klieo_core::ids::ThreadId::new("wf-no-leak:any"), 8192)
        .await
        .unwrap_or_default();
    for message in &history {
        assert!(
            !message.content.contains(PRINCIPAL),
            "principal leaked into short-term memory: {}",
            message.content
        );
    }

    let episodes = episodic
        .replay(run_id)
        .await
        .expect("episodic replay must succeed");
    let expected_label = klieo_core::principal_hash(PRINCIPAL);
    let mut attributed_labels: Vec<&str> = Vec::new();
    for episode in &episodes {
        if let klieo_core::Episode::RunAttributed { tenant_label } = episode {
            attributed_labels.push(tenant_label.as_str());
        }
    }
    assert_eq!(
        attributed_labels,
        vec![expected_label.as_str()],
        "exactly one RunAttributed carrying principal_hash; got {episodes:?}",
    );

    for payload in episodes
        .iter()
        .map(|episode| serde_json::to_string(episode).expect("episode serialises"))
    {
        assert!(
            !payload.contains(PRINCIPAL),
            "raw principal leaked into recorded episode: {payload}",
        );
    }
}
/// Without a caller principal, no `RunAttributed` episode may be recorded.
#[tokio::test]
async fn no_principal_yields_no_run_attributed_episode() {
    use klieo_core::test_utils::noop_bus;

    let hitl = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/api/v1/hitl/items"))
        .respond_with(ResponseTemplate::new(500))
        .expect(0)
        .mount(&hitl)
        .await;

    let ctx = workflow_ctx_with(vec![FakeLlmStep::Text("done".into())]);
    let episodic = ctx.episodic.clone();
    let run_id = ctx.run_id;

    let token = SecretString::from(String::from("tok"));
    let client = Arc::new(HitlClient::new(hitl.uri(), token));
    let cfg = Arc::new(hitl_cfg(Duration::from_secs(1)));
    let built = McpServer::builder()
        .with_hitl(client, cfg)
        .add_workflow_with_schema(
            WorkflowAgent { name: "wf-anon" },
            "you are a workflow",
            input_schema(),
            plain_run_options(),
            one_shot_ctx_factory(ctx),
        )
        .build()
        .unwrap();
    let server = Arc::new(built);

    let (pubsub, _, kv, jobs) = noop_bus();
    let anonymous_ctx = klieo_core::tool::ToolCtx::new(pubsub, kv, jobs);
    let _ = server
        .invoker
        .invoke("wf-anon", json!({"payload": "hi"}), anonymous_ctx)
        .await
        .unwrap();

    let episodes = episodic
        .replay(run_id)
        .await
        .expect("episodic replay must succeed");
    let mut attributed_count = 0usize;
    for episode in &episodes {
        if matches!(episode, klieo_core::Episode::RunAttributed { .. }) {
            attributed_count += 1;
        }
    }
    assert_eq!(
        attributed_count, 0,
        "RunAttributed must not appear without a caller_principal; got {episodes:?}",
    );
}
/// Sends one gated `wf-suspend` call through the builder API against a HITL
/// mock whose item stays `awaiting`. Returns the raw tool result.
///
/// `with_kv` adds a checkpoint KV store to the builder; `with_principal` adds
/// the caller principal `alice@x` to the tool context.
async fn run_suspend_with(
    with_kv: bool,
    with_principal: bool,
) -> serde_json::Value {
    use klieo_core::test_utils::noop_bus;

    let hitl = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/api/v1/hitl/items"))
        .respond_with(
            ResponseTemplate::new(201).set_body_json(item_json("item-suspend", "awaiting")),
        )
        .mount(&hitl)
        .await;
    Mock::given(method("GET"))
        .and(path("/api/v1/hitl/items/item-suspend"))
        .respond_with(
            ResponseTemplate::new(200).set_body_json(item_json("item-suspend", "awaiting")),
        )
        .mount(&hitl)
        .await;

    let ctx = workflow_ctx_with(vec![FakeLlmStep::Text(PLANTED_SENTINEL.into())]);
    let token = SecretString::from(String::from("tok"));
    let client = Arc::new(HitlClient::new(hitl.uri(), token));
    let cfg = Arc::new(hitl_cfg(Duration::from_millis(5)));

    let base = McpServer::builder()
        .with_hitl(client, cfg)
        .add_workflow_with_schema(
            WorkflowAgent { name: "wf-suspend" },
            "",
            input_schema(),
            gated_run_options(),
            one_shot_ctx_factory(ctx),
        );
    let builder = if with_kv {
        base.with_checkpoint_kv(fake_kv())
    } else {
        base
    };
    let server = Arc::new(builder.build().unwrap());

    let (pubsub, _, kv, jobs) = noop_bus();
    let plain_ctx = klieo_core::tool::ToolCtx::new(pubsub, kv, jobs);
    let tool_ctx = if with_principal {
        plain_ctx.with_caller_principal("alice@x".into())
    } else {
        plain_ctx
    };

    server
        .invoker
        .invoke(
            "wf-suspend",
            json!({"payload": "please-approve"}),
            tool_ctx,
        )
        .await
        .unwrap()
}
/// With both a checkpoint KV and a caller principal, a suspend issues a
/// non-empty ticket. Neither the LLM text nor the policy reason may leak.
#[tokio::test]
async fn suspend_with_kv_and_principal_issues_ticket() {
    let envelope = run_suspend_with(true, true).await;
    let serialised = envelope.to_string();
    assert_eq!(envelope["status"], "suspended");

    let ticket = envelope["ticket"].as_str().expect("ticket present");
    assert!(!ticket.is_empty(), "issued ticket must be non-empty");

    assert!(
        !serialised.contains(PLANTED_SENTINEL),
        "checkpoint bytes leaked: {}",
        serialised
    );
    assert!(
        !serialised.contains("policy reason that MUST NOT leak"),
        "raw policy reason leaked: {}",
        serialised
    );
}
/// Without a checkpoint KV, a suspend returns an envelope with no `ticket`
/// field, and the LLM text still does not leak.
#[tokio::test]
async fn suspend_without_kv_falls_back_to_no_ticket_envelope() {
    let envelope = run_suspend_with(false, true).await;
    assert_eq!(envelope["status"], "suspended");

    let has_ticket = envelope.get("ticket").is_some();
    assert!(
        !has_ticket,
        "no checkpoint KV must yield slice-1 envelope (no ticket field)",
    );

    let serialised = envelope.to_string();
    assert!(!serialised.contains(PLANTED_SENTINEL));
}
/// Without a caller principal, a suspend returns an envelope with no `ticket`
/// field, even when a checkpoint KV is configured.
#[tokio::test]
async fn suspend_without_principal_falls_back_to_no_ticket_envelope() {
    let envelope = run_suspend_with(true, false).await;
    assert_eq!(envelope["status"], "suspended");

    let has_ticket = envelope.get("ticket").is_some();
    assert!(
        !has_ticket,
        "no caller principal must yield slice-1 envelope (no ticket field)",
    );

    let serialised = envelope.to_string();
    assert!(!serialised.contains(PLANTED_SENTINEL));
}
/// A persisted ticket records its owning principal, which differs from a
/// foreign principal. The owner can claim it once; a second claim returns
/// `None`.
///
/// NOTE(review): no claim is actually attempted as the foreign principal
/// here. The "denial" in the messages below refers to the authz comparison
/// that `assert_ne!` stands in for — consider driving the real resume path.
#[tokio::test]
async fn principal_b_cannot_consume_principal_a_ticket() {
    use crate::resume_ticket::{ResumeTicketRecord, ResumeTicketStore};

    let store = ResumeTicketStore::new(fake_kv());
    let token = ResumeTicketStore::mint_token();

    let checkpoint_doc = serde_json::json!({
        "run_id": klieo_core::ids::RunId::new(),
        "step_index": 1,
        "thread_id": "t-idor",
        "messages": [],
        "pending_tool_calls": null,
        "created_at": "2026-06-18T00:00:00Z",
    });
    let checkpoint = serde_json::from_value(checkpoint_doc).unwrap();
    let owned_by_alice = ResumeTicketRecord {
        principal: "alice@x".into(),
        workflow_name: "wf".into(),
        checkpoint,
        created_at: Utc::now(),
    };
    store.persist(&token, &owned_by_alice).await.unwrap();

    let stored = store.peek(&token).await.unwrap().expect("ticket present");
    let intruder = "mallory@x";
    assert_ne!(
        stored.principal, intruder,
        "fixture must seed a distinct principal so the authz arm engages"
    );

    let first_claim = store.claim(&token).await.unwrap();
    assert!(
        first_claim.is_some(),
        "after a foreign-principal denial the rightful owner can still resume"
    );
    let second_claim = store.claim(&token).await.unwrap();
    assert!(second_claim.is_none(), "the now-consumed ticket cannot be reused");
}
#[tokio::test]
async fn concurrent_resume_runs_exactly_once() {
use crate::resume_ticket::{ResumeTicketRecord, ResumeTicketStore};
let store = Arc::new(ResumeTicketStore::new(fake_kv()));
let token = ResumeTicketStore::mint_token();
let cp_json = serde_json::json!({
"run_id": klieo_core::ids::RunId::new(),
"step_index": 1,
"thread_id": "t-conc",
"messages": [],
"pending_tool_calls": null,
"created_at": "2026-06-18T00:00:00Z",
});
let checkpoint = serde_json::from_value(cp_json).unwrap();
let record = ResumeTicketRecord {
principal: "alice@x".into(),
workflow_name: "wf".into(),
checkpoint,
created_at: Utc::now(),
};
store.persist(&token, &record).await.unwrap();
let racers: Vec<_> = (0..8)
.map(|_| {
let store = store.clone();
let token = token.clone();
tokio::spawn(async move { store.claim(&token).await })
})
.collect();
let mut winners = 0usize;
for handle in racers {
if handle.await.unwrap().unwrap().is_some() {
winners += 1;
}
}
assert_eq!(
winners, 1,
"concurrent ticket consumption must run exactly once; got {winners}"
);
}
/// Resuming a suspended workflow with `ApprovalDecision::Approved` must drive
/// the run to completion. The scripted fake LLM's text reply (`"approved"`) is
/// returned as the resume result.
#[tokio::test]
async fn approve_resume_drives_run_to_completion() {
    use klieo_core::checkpoint::{ApprovalDecision, RunCheckpoint};
    use std::sync::Mutex;

    let ctx = workflow_ctx_with(vec![FakeLlmStep::Text("approved".into())]);
    let checkpoint: RunCheckpoint = serde_json::from_value(serde_json::json!({
        "run_id": ctx.run_id,
        "step_index": 1,
        "thread_id": "t-resume-approve",
        "messages": [],
        "pending_tool_calls": null,
        "created_at": "2026-06-18T00:00:00Z",
    }))
    .unwrap();

    let hitl = crate::workflow::HitlBundle {
        client: Arc::new(HitlClient::new(
            "http://unused".to_string(),
            SecretString::from("tok".to_string()),
        )),
        cfg: Arc::new(hitl_cfg(Duration::from_secs(1))),
    };

    // The factory hands out the prepared context exactly once; a second call panics.
    let slot = Arc::new(Mutex::new(Some(ctx)));
    let ctx_factory: AgentContextFactory =
        Arc::new(move || slot.lock().unwrap().take().expect("ctx_factory drained"));

    let handle: Arc<dyn crate::workflow::WorkflowResumeHandle> =
        Arc::new(crate::workflow::WorkflowAsToolInvoker::<WorkflowAgent>::new(
            "wf-resume".into(),
            "".into(),
            input_schema(),
            ctx_factory,
            plain_run_options(),
            hitl,
            None,
            #[cfg(feature = "governor")]
            None,
        ));

    let outcome = handle
        .resume(checkpoint, ApprovalDecision::Approved, "hashed-tenant".into())
        .await
        .unwrap();
    assert_eq!(outcome, serde_json::Value::String("approved".into()));
}
/// Resuming with `ApprovalDecision::Rejected` must surface the operator's reason
/// to the model. After the resume completes, the short-term history of the
/// checkpoint's thread must hold a `Role::Tool` message containing the reason.
#[tokio::test]
async fn reject_resume_feeds_reason_back_to_model() {
    use klieo_core::checkpoint::{ApprovalDecision, RunCheckpoint};
    use klieo_core::llm::Role;
    use std::sync::Mutex;

    const THREAD: &str = "t-resume-reject";
    const REASON: &str = "BAD-IDEA-XYZ";

    let ctx = workflow_ctx_with(vec![FakeLlmStep::Text("acknowledged".into())]);
    // Keep a handle on the short-term store before `ctx` moves into the factory.
    let short_term_for_probe = ctx.short_term.clone();
    let checkpoint: RunCheckpoint = serde_json::from_value(serde_json::json!({
        "run_id": ctx.run_id,
        "step_index": 1,
        "thread_id": THREAD,
        "messages": [],
        "pending_tool_calls": null,
        "created_at": "2026-06-18T00:00:00Z",
    }))
    .unwrap();

    let hitl = crate::workflow::HitlBundle {
        client: Arc::new(HitlClient::new(
            "http://unused".to_string(),
            SecretString::from("tok".to_string()),
        )),
        cfg: Arc::new(hitl_cfg(Duration::from_secs(1))),
    };

    // The factory hands out the prepared context exactly once; a second call panics.
    let slot = Arc::new(Mutex::new(Some(ctx)));
    let ctx_factory: AgentContextFactory =
        Arc::new(move || slot.lock().unwrap().take().expect("ctx_factory drained"));

    let handle: Arc<dyn crate::workflow::WorkflowResumeHandle> =
        Arc::new(crate::workflow::WorkflowAsToolInvoker::<WorkflowAgent>::new(
            "wf-reject".into(),
            "".into(),
            input_schema(),
            ctx_factory,
            plain_run_options(),
            hitl,
            None,
            #[cfg(feature = "governor")]
            None,
        ));

    let decision = ApprovalDecision::Rejected {
        reason: REASON.into(),
    };
    let _ = handle
        .resume(checkpoint, decision, "hashed-tenant".into())
        .await
        .unwrap();

    let history = short_term_for_probe
        .load(klieo_core::ids::ThreadId::new(THREAD), 8192)
        .await
        .unwrap();
    let rejection_seen = history
        .iter()
        .filter(|message| message.role == Role::Tool)
        .any(|message| message.content.contains(REASON));
    assert!(
        rejection_seen,
        "the model must see the operator's rejection reason on resume"
    );
}
/// Registering a workflow through `add_workflow_with_schema` without calling
/// `with_hitl` must make `build()` fail with `McpBuildError::WorkflowWithoutHitl`.
#[test]
fn builder_rejects_workflow_without_hitl() {
    let ctx_factory: AgentContextFactory = Arc::new(|| fake_context("guard-test"));
    let builder = McpServer::builder().add_workflow_with_schema(
        WorkflowAgent { name: "wf-guard" },
        "",
        input_schema(),
        plain_run_options(),
        ctx_factory,
    );
    let err = match builder.build() {
        Ok(_) => panic!("workflow without with_hitl must fail build"),
        Err(build_err) => build_err,
    };
    assert!(
        matches!(err, McpBuildError::WorkflowWithoutHitl),
        "expected WorkflowWithoutHitl, got: {err:?}"
    );
}
}
/// The resume-buffer error variants must render their `Display` messages with
/// the offending `since_id` / progress token embedded.
#[test]
fn resume_errors_render_messages() {
    let cases = [
        (
            McpServerError::ResumeBufferExpired { since_id: 7 },
            "resume window expired (since_id=7)",
        ),
        (
            McpServerError::ResumeBufferNotFound("tok".into()),
            "no buffered stream for progressToken: tok",
        ),
    ];
    for (error, expected) in cases {
        assert_eq!(error.to_string(), expected);
    }
}
/// `ServerOutboundError::Serialisation` must convert into
/// `McpServerError::OutboundSerialisation`, and the underlying serde error must
/// stay reachable through `Error::source`.
#[test]
fn from_server_outbound_serialisation_maps_to_outbound_serialisation() {
    use klieo_core::ServerOutboundError;
    use std::error::Error;

    // `{invalid}` is not valid JSON, so parsing it yields a real serde_json error.
    let parse_failure = serde_json::from_str::<serde_json::Value>("{invalid}").unwrap_err();
    let mcp_err: McpServerError = ServerOutboundError::Serialisation(parse_failure).into();

    let mapped = matches!(mcp_err, McpServerError::OutboundSerialisation(_));
    assert!(
        mapped,
        "ServerOutboundError::Serialisation must map to McpServerError::OutboundSerialisation; got {mcp_err:?}"
    );
    let has_source = mcp_err.source().is_some();
    assert!(
        has_source,
        "McpServerError::OutboundSerialisation must expose source via #[source]"
    );
}
/// Test authenticator that accepts every request as the named identity
/// `"alice"`, ignoring headers and payload. It does not override
/// `allows_anonymous`.
struct NamedAuthn;
#[async_trait]
impl klieo_auth_common::Authenticator for NamedAuthn {
    async fn authenticate(
        &self,
        _headers: &dyn klieo_auth_common::Headers,
        _payload: &[u8],
    ) -> Result<klieo_auth_common::Identity, klieo_auth_common::AuthError> {
        let named = klieo_auth_common::Identity::new("alice");
        Ok(named)
    }
}
/// Under `DeploymentProfile::RegulatedMultiTenant`, a build with a named
/// authenticator but no tenant KV binding must fail closed with
/// `ProfileViolation::MissingTenantKv`.
#[test]
fn regulated_without_tenant_kv_fails_closed() {
    let err = McpServer::builder()
        .add_tools(Arc::new(OneToolInvoker))
        .with_authenticator(Arc::new(NamedAuthn))
        .profile(klieo_core::DeploymentProfile::RegulatedMultiTenant)
        .build()
        .err()
        .expect("must fail closed");
    // Report the actual variant on failure, matching the other builder tests.
    assert!(
        matches!(
            err,
            McpBuildError::RegulatedProfile(klieo_core::ProfileViolation::MissingTenantKv)
        ),
        "expected RegulatedProfile(MissingTenantKv), got: {err:?}"
    );
}
/// Under `DeploymentProfile::RegulatedMultiTenant`, a build with a tenant KV
/// binding but no authenticator must fail closed with
/// `ProfileViolation::AnonymousAuth`.
#[test]
fn regulated_without_authenticator_fails_closed() {
    let kv = klieo_bus_memory::MemoryBus::new().kv;
    let err = McpServer::builder()
        .add_tools(Arc::new(OneToolInvoker))
        .with_tenant_binding(kv)
        .profile(klieo_core::DeploymentProfile::RegulatedMultiTenant)
        .build()
        .err()
        .expect("must fail closed");
    // Report the actual variant on failure, matching the other builder tests.
    assert!(
        matches!(
            err,
            McpBuildError::RegulatedProfile(klieo_core::ProfileViolation::AnonymousAuth)
        ),
        "expected RegulatedProfile(AnonymousAuth), got: {err:?}"
    );
}
/// Test authenticator that resolves every request to
/// `Identity::anonymous()` and reports `allows_anonymous() == true`.
struct AnonAuthn;
#[async_trait]
impl klieo_auth_common::Authenticator for AnonAuthn {
    async fn authenticate(
        &self,
        _headers: &dyn klieo_auth_common::Headers,
        _payload: &[u8],
    ) -> Result<klieo_auth_common::Identity, klieo_auth_common::AuthError> {
        let nobody = klieo_auth_common::Identity::anonymous();
        Ok(nobody)
    }

    fn allows_anonymous(&self) -> bool {
        !false
    }
}
/// Under `DeploymentProfile::RegulatedMultiTenant`, wiring an authenticator that
/// allows anonymous callers must fail closed with
/// `ProfileViolation::AnonymousAuth`, even when a tenant KV binding is present.
#[test]
fn regulated_with_anonymous_authenticator_fails_closed() {
    let kv = klieo_bus_memory::MemoryBus::new().kv;
    let err = McpServer::builder()
        .add_tools(Arc::new(OneToolInvoker))
        .with_tenant_binding(kv)
        .with_authenticator(Arc::new(AnonAuthn))
        .profile(klieo_core::DeploymentProfile::RegulatedMultiTenant)
        .build()
        .err()
        .expect("must fail closed");
    // Report the actual variant on failure, matching the other builder tests.
    assert!(
        matches!(
            err,
            McpBuildError::RegulatedProfile(klieo_core::ProfileViolation::AnonymousAuth)
        ),
        "expected RegulatedProfile(AnonymousAuth), got: {err:?}"
    );
}
/// A regulated-profile build with a named authenticator and a tenant KV must
/// succeed, and its ownership registry must be in strict mode.
#[test]
fn regulated_forces_strict_over_lenient_binding() {
    let kv = klieo_bus_memory::MemoryBus::new().kv;
    let server = McpServer::builder()
        .add_tools(Arc::new(OneToolInvoker))
        .with_tenant_binding(kv)
        .with_authenticator(Arc::new(NamedAuthn))
        .profile(klieo_core::DeploymentProfile::RegulatedMultiTenant)
        .build()
        .expect("regulated build with named auth + kv must succeed");
    let strict = server
        .ownership_registry
        .as_ref()
        .map(|registry| registry.is_strict());
    assert_eq!(strict, Some(true));
}
/// Without a deployment profile, `with_tenant_binding` must leave the ownership
/// registry in its non-strict (lenient) mode.
#[test]
fn unprofiled_keeps_lenient_binding() {
    let kv = klieo_bus_memory::MemoryBus::new().kv;
    let built = McpServer::builder()
        .add_tools(Arc::new(OneToolInvoker))
        .with_tenant_binding(kv)
        .with_authenticator(Arc::new(NamedAuthn))
        .build();
    let server = built.expect("unprofiled build ok");
    let strict = server
        .ownership_registry
        .as_ref()
        .map(|registry| registry.is_strict());
    assert_eq!(strict, Some(false));
}
}
#[cfg(test)]
#[cfg(feature = "http")]
mod jsonrpc_const_tests {
    use super::*;

    /// Every `JSONRPC_*` error code must be distinct. The `[i64; 9]` annotation
    /// also forces each constant to be an `i64` at compile time.
    #[test]
    fn jsonrpc_constants_are_i64_and_unique_at_runtime() {
        let codes: [i64; 9] = [
            JSONRPC_PARSE_ERROR,
            JSONRPC_METHOD_NOT_FOUND,
            JSONRPC_INVALID_PARAMS,
            JSONRPC_SERVER_ERROR,
            JSONRPC_UNAUTHENTICATED,
            JSONRPC_RESUME_BUFFER_EXPIRED,
            JSONRPC_RESUME_BUFFER_NOT_FOUND,
            JSONRPC_LEADER_DIED,
            JSONRPC_SESSION_CONFLICT,
        ];
        // `insert` returns false for a value already present, so the filter
        // keeps exactly the repeated codes.
        let mut seen = std::collections::HashSet::new();
        let duplicates: Vec<i64> = codes
            .iter()
            .copied()
            .filter(|code| !seen.insert(*code))
            .collect();
        assert!(
            duplicates.is_empty(),
            "JSONRPC_* codes must be unique; found duplicates: {duplicates:?}"
        );
        // Sanity-check the detector itself: a code it just recorded must be
        // reported as already present.
        assert!(
            seen.contains(&JSONRPC_PARSE_ERROR),
            "duplicate detection logic broken"
        );
    }
}