use crate::{
rpc_error, tool_error_to_envelope, McpServer, McpServerError, JSONRPC_INVALID_PARAMS,
JSONRPC_LEADER_DIED, JSONRPC_PARSE_ERROR, JSONRPC_SERVER_ERROR, JSONRPC_UNAUTHENTICATED,
MCP_LEADER_KEY_PREFIX,
};
use axum::{
body::Bytes,
extract::{DefaultBodyLimit, State},
http::{header, HeaderMap, StatusCode},
response::{IntoResponse, Response},
routing::post,
Json, Router,
};
use futures::{Stream, StreamExt as _};
use klieo_auth_common::Identity;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use tracing::{error, instrument, warn};
use tracing_opentelemetry::OpenTelemetrySpanExt as _;
// MAX_BODY_BYTES: request-body size cap (1 << 20 bytes = 1 MiB) installed via `DefaultBodyLimit` in `router_impl`.
// CANCEL_SUBJECT_PREFIX: prepended to a stream id to form the cancel subject in `wrap_with_cancel_fanout`.
const MAX_BODY_BYTES: usize = 1 << 20; const CANCEL_SUBJECT_PREFIX: &str = "klieo.mcp.cancel.";
/// Prefix prepended to a stream id to form ownership-registry keys.
const MCP_OWNERSHIP_KEY_PREFIX: &str = "mcp.";
/// Name of the HTTP header that carries the session UUID.
const MCP_SESSION_ID_HEADER: &str = "Mcp-Session-Id";
/// Name of the HTTP header a `GET /mcp` may send to request replay of events after that id.
const LAST_EVENT_ID_HEADER: &str = "Last-Event-Id";
/// Initial buffer capacity used when encoding a single SSE frame.
pub(crate) const SSE_FRAME_HINT_BYTES: usize = 512;
// NOTE(review): not referenced in the visible part of this file; presumably the
// maximum number of items accepted in a JSON-RPC batch — confirm against the dispatcher.
const MAX_BATCH_ITEMS: usize = 100;
/// Logs a session-mint race at error level and builds the matching HTTP 500 response.
///
/// * `raw_id` – JSON-RPC id to echo in the error envelope (cloned when present).
/// * `field` – name of the session slot that was found already populated; logged only.
/// * `attempted_session_id` – session the caller was operating on; logged only.
fn mint_session_race_500(
    raw_id: Option<&serde_json::Value>,
    field: &'static str,
    attempted_session_id: Option<uuid::Uuid>,
) -> Response {
    error!(
        target: "klieo::mcp::session",
        field,
        attempted_session_id = ?attempted_session_id,
        "session-mint race (concurrent initialize past guard)"
    );
    let envelope = rpc_error(
        raw_id.cloned(),
        JSONRPC_SERVER_ERROR,
        "internal: session mint race",
    );
    (StatusCode::INTERNAL_SERVER_ERROR, Json(envelope)).into_response()
}
impl McpServer {
    /// Builds the axum [`Router`] that serves this server's `/mcp` endpoint.
    pub fn router(self: &Arc<Self>) -> Router {
        router_impl(Arc::clone(self))
    }

    /// Binds `addr` and serves the router until `parent_cancel` is cancelled.
    ///
    /// # Errors
    /// Propagates the bind failure and any error returned by `axum::serve`.
    pub async fn serve_http(self: Arc<Self>, addr: SocketAddr) -> Result<(), McpServerError> {
        // Clone the token up front so the shutdown future owns its own handle.
        let shutdown = self.parent_cancel.clone();
        let listener = tokio::net::TcpListener::bind(addr).await?;
        let app = self.router();
        axum::serve(listener, app)
            .with_graceful_shutdown(async move { shutdown.cancelled().await })
            .await?;
        Ok(())
    }
}
/// Assembles the `/mcp` route (POST, GET, DELETE) with the request-body size cap
/// and the shared server state attached.
fn router_impl(server: Arc<McpServer>) -> Router {
    let mcp_methods = post(post_mcp).get(get_mcp).delete(delete_mcp);
    let body_cap = DefaultBodyLimit::max(MAX_BODY_BYTES);
    Router::new()
        .route("/mcp", mcp_methods)
        .layer(body_cap)
        .with_state(server)
}
/// Handles `POST /mcp`.
///
/// Processing order: adopt the caller's trace context, refuse during shutdown,
/// require a JSON content type, parse the body, authenticate/authorize, then
/// branch on the request shape: `initialize` mints a session, a method-less
/// object is treated as the client's reply to a server-initiated request, and
/// everything else is either upgraded to a streaming response or dispatched
/// and answered with a single JSON body.
#[instrument(
skip_all,
fields(
rpc.system = "klieo-mcp",
rpc.method = tracing::field::Empty,
http.request.method = "POST",
),
)]
async fn post_mcp(
State(server): State<Arc<McpServer>>,
headers: HeaderMap,
body: Bytes,
) -> Response {
// Parent this span on whatever trace context the request headers carry.
let parent_cx = klieo_core::extract_traceparent(&klieo_headers_from_axum(&headers));
tracing::Span::current().set_parent(parent_cx);
// Once the parent cancellation token has fired, answer 503 instead of starting new work.
if server.parent_cancel.is_cancelled() {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(rpc_error(
None,
JSONRPC_SERVER_ERROR,
"server shutting down",
)),
)
.into_response();
}
// Anything other than `application/json` is rejected with a bare 415.
if !content_type_is_json(&headers) {
return StatusCode::UNSUPPORTED_MEDIA_TYPE.into_response();
}
// Unparseable bodies get 400 plus a JSON-RPC parse-error envelope (no id is available yet).
let raw: serde_json::Value = match serde_json::from_slice(&body) {
Ok(v) => v,
Err(e) => {
warn!(error = %e, "rejected malformed JSON-RPC body");
return (
StatusCode::BAD_REQUEST,
Json(rpc_error(
None,
JSONRPC_PARSE_ERROR,
"malformed JSON-RPC body",
)),
)
.into_response();
}
};
// Fill in the `rpc.method` span field declared as Empty above.
if let Some(method) = raw.get("method").and_then(|m| m.as_str()) {
tracing::Span::current().record("rpc.method", method);
}
// `identity` is `None` when no authenticator is configured.
let identity = match enforce_authenticator(&server, &headers, &body, &raw).await {
Ok(identity) => identity,
Err(rejection) => return rejection,
};
let method = raw.get("method").and_then(|m| m.as_str()).unwrap_or("");
// `initialize` creates a fresh session and therefore never requires a session header.
if method == "initialize" {
return handle_initialize_post(&server, raw, identity).await;
}
// A session header is only demanded while the server-wide session table is
// non-empty; with an empty table the request proceeds without a session.
let session = if server.sessions.read().await.is_empty() {
None
} else {
match require_session(&headers, &server, identity.as_ref()).await {
Ok(s) => Some(s),
Err(rejection) => return rejection,
}
};
// A JSON object without `method` is routed as a reply to a server-initiated
// request, which needs a session to find the pending request on.
if raw.is_object() && raw.get("method").is_none() {
let Some(session) = session.as_ref() else {
warn!(
target: "mcp.http",
"POST /mcp body has no method and no session header; rejecting",
);
return (
StatusCode::BAD_REQUEST,
Json(rpc_error(None, JSONRPC_PARSE_ERROR, "missing method")),
)
.into_response();
};
return route_outbound_response(&server, session, raw).await;
}
// Record activity on the session before handling the request.
if let Some(session) = session.as_ref() {
touch_last_activity(&server, session);
}
// Streaming-capable methods return their own response; `None` means "not a streaming request".
if let Some(sse) = try_sse_upgrade(server.clone(), raw.clone(), identity).await {
return sse;
}
// Plain request/response path.
let resp = dispatch(&server, raw, session.as_ref()).await;
(StatusCode::OK, Json(resp)).into_response()
}
/// Marks `session` as active, passing the server's start reference to
/// `Session::mark_active`.
fn touch_last_activity(server: &Arc<McpServer>, session: &Arc<crate::session::Session>) {
    let epoch = server.server_start;
    session.mark_active(epoch);
}
/// Routes a method-less JSON-RPC object to the session's pending outbound requests.
///
/// Returns `202 Accepted` when the body has an integer `id` and the session has
/// an outbound handler wired; otherwise `400` with a "missing method" envelope.
async fn route_outbound_response(
    server: &Arc<McpServer>,
    session: &Arc<crate::session::Session>,
    raw: serde_json::Value,
) -> Response {
    let numeric_id = raw.get("id").and_then(serde_json::Value::as_i64);
    match (numeric_id, session.outbound.get()) {
        (Some(id), Some(pending)) => {
            pending.complete_pending(id, raw).await;
            touch_last_activity(server, session);
            StatusCode::ACCEPTED.into_response()
        }
        _ => {
            warn!(
                target: "mcp.http",
                "POST /mcp body has no method and no routable outbound; rejecting",
            );
            let envelope = rpc_error(None, JSONRPC_PARSE_ERROR, "missing method");
            (StatusCode::BAD_REQUEST, Json(envelope)).into_response()
        }
    }
}
/// Handles `GET /mcp`: opens the session's server-to-client SSE stream.
///
/// Requires a valid session header. Optionally replays buffered events after
/// `Last-Event-Id`, wires the session's outbound channel, and arranges for the
/// session to be removed once the response stream is dropped.
///
/// # Panics
/// Panics if the resolved session has no id (`session.id` is `None`).
#[instrument(
skip_all,
fields(
rpc.system = "klieo-mcp",
http.request.method = "GET",
),
)]
async fn get_mcp(State(server): State<Arc<McpServer>>, headers: HeaderMap) -> Response {
// Parent this span on whatever trace context the request headers carry.
let parent_cx = klieo_core::extract_traceparent(&klieo_headers_from_axum(&headers));
tracing::Span::current().set_parent(parent_cx);
// Refuse new streams once shutdown has begun.
if server.parent_cancel.is_cancelled() {
return StatusCode::SERVICE_UNAVAILABLE.into_response();
}
let identity = match enforce_authenticator_for_get(&server, &headers).await {
Ok(identity) => identity,
Err(rejection) => return rejection,
};
// Unlike POST, GET always requires a session.
let session = match require_session(&headers, &server, identity.as_ref()).await {
Ok(s) => s,
Err(rejection) => return rejection,
};
let session_id = session.id.expect("HTTP session always carries Some(uuid)");
// Events to resend before live delivery; empty when no Last-Event-Id header was sent.
let replay = match compute_replay_window(&headers, &session, &server) {
Ok(slice) => slice,
Err(rejection) => return rejection,
};
let (tx, rx) = crate::outbound_ring::bounded_ring::<(u64, std::sync::Arc<serde_json::Value>)>(
crate::outbound_sink::OUTBOUND_QUEUE_CAPACITY,
);
// `outbound_tx` is a set-once slot: if it is already populated this request gets a 500.
if session.outbound_tx.set(tx.clone()).is_err() {
return mint_session_race_500(None, "session.outbound_tx", session.id);
}
if let Err(rejection) = wire_session_outbound(&server, &session, tx) {
return rejection;
}
// The response stream signals (or drops) `cleanup_tx` when it is dropped, which
// wakes the cleanup task that removes the session.
let (cleanup_tx, cleanup_rx) = tokio::sync::oneshot::channel::<()>();
spawn_session_cleanup(server.clone(), session_id, cleanup_rx);
build_outbound_sse_response(session_id, replay, rx, cleanup_tx)
}
/// Works out which buffered events must be replayed for a reconnecting `GET /mcp`.
///
/// Returns an empty vector when no `Last-Event-Id` header is present. Otherwise
/// returns every buffered `(id, frame)` pair whose id is strictly greater than
/// the header value.
///
/// # Errors
/// * `400` – header is not valid text, not a `u64`, or not below the session's next event id.
/// * `501` – the server reports replay as disabled.
/// * `410` – the requested id is older than what the buffer still retains.
#[allow(clippy::result_large_err)]
fn compute_replay_window(
    headers: &HeaderMap,
    session: &crate::session::Session,
    server: &McpServer,
) -> Result<Vec<(u64, std::sync::Arc<serde_json::Value>)>, Response> {
    use std::sync::atomic::Ordering;

    let bad_request = |msg: &'static str| (StatusCode::BAD_REQUEST, msg).into_response();

    let header_value = match headers.get(LAST_EVENT_ID_HEADER) {
        Some(value) => value,
        None => return Ok(Vec::new()),
    };
    let header_text = match header_value.to_str() {
        Ok(text) => text,
        Err(_) => return Err(bad_request("Last-Event-Id must be ASCII u64")),
    };
    let last_id: u64 = match header_text.parse() {
        Ok(parsed) => parsed,
        Err(_) => return Err(bad_request("Last-Event-Id is not a valid u64")),
    };

    if !server.sse_replay_enabled() {
        let disabled = (
            StatusCode::NOT_IMPLEMENTED,
            "resume buffer disabled (with_sse_replay_capacity(0))",
        );
        return Err(disabled.into_response());
    }

    // `head` is the id the next event will receive, so valid resume points are below it.
    let head = session.next_event_id.load(Ordering::Relaxed);
    if head <= last_id {
        return Err(bad_request(
            "Last-Event-Id is ahead of the server's event sequence",
        ));
    }

    let retained = session.sse_replay_buffer.lock();
    let oldest_retained = match retained.front() {
        Some((id, _)) => *id,
        None => head,
    };
    // The client may resume from the event just before the oldest retained one;
    // anything earlier means events were lost.
    if last_id < oldest_retained.saturating_sub(1) {
        let envelope = rpc_error(
            None,
            JSONRPC_SERVER_ERROR,
            "resume gap; reconnect with fresh initialize",
        );
        return Err((StatusCode::GONE, Json(envelope)).into_response());
    }

    let mut window = Vec::new();
    for entry in retained.iter() {
        if entry.0 > last_id {
            window.push(entry.clone());
        }
    }
    Ok(window)
}
/// Handles `DELETE /mcp`: lets a client terminate its own session.
///
/// Responds `400` when the session header is missing or invalid, `404` when the
/// session is unknown or belongs to a different principal, and `204` once the
/// session has been removed, drained and its principal count decremented.
async fn delete_mcp(State(server): State<Arc<McpServer>>, headers: HeaderMap) -> Response {
    let identity = match enforce_authenticator_for_delete(&server, &headers).await {
        Ok(identity) => identity,
        Err(rejection) => return rejection,
    };
    let Some(id) = extract_session_id(&headers) else {
        return (StatusCode::BAD_REQUEST, "missing or invalid Mcp-Session-Id").into_response();
    };

    // Check existence and ownership under the read lock first; a foreign session
    // is reported exactly like a missing one.
    let caller_may_delete = {
        let sessions = server.sessions.read().await;
        match sessions.get(&id) {
            Some(session) => principal_matches(identity.as_ref(), session.principal.as_deref()),
            None => false,
        }
    };
    if !caller_may_delete {
        return StatusCode::NOT_FOUND.into_response();
    }

    // The session may have vanished between the two lock acquisitions.
    let removed = server.sessions.write().await.remove(&id);
    let Some(session) = removed else {
        return StatusCode::NOT_FOUND.into_response();
    };

    let principal = session.principal.clone();
    session.close_and_drain().await;
    server.decrement_principal_count(principal.as_deref()).await;
    metrics::counter!(
        "klieo_mcp_session_deleted_total",
        "reason" => "client_delete"
    )
    .increment(1);
    tracing::info!(
        target: "klieo::mcp::session",
        session_id = %id,
        "session deleted by client"
    );
    StatusCode::NO_CONTENT.into_response()
}
/// Spawns a task that waits on `cleanup_rx` and then tears down `session_id`.
///
/// The task proceeds whether the sender fired or was dropped. If the session is
/// still registered it is removed, drained, and its principal count decremented;
/// if it is already gone the task does nothing.
fn spawn_session_cleanup(
    server: Arc<McpServer>,
    session_id: uuid::Uuid,
    cleanup_rx: tokio::sync::oneshot::Receiver<()>,
) {
    tokio::spawn(async move {
        // Both `Ok` (explicit signal) and `Err` (sender dropped) mean "stream is gone".
        let _ = cleanup_rx.await;
        let removed = server.sessions.write().await.remove(&session_id);
        if let Some(session) = removed {
            let principal = session.principal.clone();
            session.close_and_drain().await;
            server.decrement_principal_count(principal.as_deref()).await;
        }
    });
}
/// Installs the outbound-request machinery on `session`, backed by `tx`.
///
/// Builds an `HttpFrameSink` holding a weak reference to the session, wraps it
/// in `OutboundRequests`, and stores that in the session's set-once `outbound`
/// slot. When `server.declare_sampling` is set, a `RootsCache` over the same
/// outbound handle is stored as well (an already-set cache is left untouched).
///
/// # Errors
/// Returns a 500 response if `session.outbound` was already populated.
#[allow(clippy::result_large_err)]
fn wire_session_outbound(
    server: &Arc<McpServer>,
    session: &Arc<crate::session::Session>,
    tx: crate::outbound_ring::RingSender<(u64, std::sync::Arc<serde_json::Value>)>,
) -> Result<(), Response> {
    use crate::outbound::OutboundRequests;
    use crate::outbound_sink::HttpFrameSink;
    use klieo_core::ServerOutbound;

    let frame_sink = HttpFrameSink::new(Arc::downgrade(session), tx, server.sse_replay_capacity);
    let frame_sink: Arc<dyn crate::OutboundFrameSink> = Arc::new(frame_sink);
    let requests = Arc::new(OutboundRequests::new(frame_sink));

    if session.outbound.set(requests.clone()).is_err() {
        return Err(mint_session_race_500(None, "session.outbound", session.id));
    }
    if server.declare_sampling {
        let outbound_dyn: Arc<dyn ServerOutbound> = requests.clone();
        let roots = Arc::new(crate::roots::RootsCache::new(outbound_dyn));
        // Best-effort: ignore the error if a cache was already installed.
        let _ = session.roots_cache.set(roots);
    }
    Ok(())
}
/// Encodes one SSE frame of the form `id: <event_id>\ndata: <json>\n\n`.
///
/// Returns `None` (after logging at error level) when `frame` cannot be
/// serialised; `session_id` is used only in that log event.
///
/// # Panics
/// Panics if writing the `id:` prefix into the in-memory buffer fails.
pub fn encode_sse_frame(
    event_id: u64,
    frame: &serde_json::Value,
    session_id: uuid::Uuid,
) -> Option<Bytes> {
    use bytes::BufMut as _;
    use std::fmt::Write as _;

    let mut out = bytes::BytesMut::with_capacity(SSE_FRAME_HINT_BYTES);
    write!(&mut out, "id: {event_id}\ndata: ").expect("BytesMut write is infallible");
    if let Err(err) = serde_json::to_writer((&mut out).writer(), frame) {
        tracing::error!(
            target: "klieo::mcp::sse",
            event_id,
            session_id = %session_id,
            error = %err,
            "outbound frame serialisation failed; skipping id",
        );
        return None;
    }
    // Blank line terminates the SSE event.
    out.extend_from_slice(b"\n\n");
    Some(out.freeze())
}
/// Builds the `text/event-stream` response for a session's outbound stream.
///
/// The body first yields every replayed event, then every event received from
/// `rx` until the channel ends. Events that fail to encode are skipped. The
/// stream is wrapped so that `cleanup_tx` is signalled when the body is dropped.
fn build_outbound_sse_response(
    session_id: uuid::Uuid,
    replay: Vec<(u64, std::sync::Arc<serde_json::Value>)>,
    mut rx: crate::outbound_ring::RingReceiver<(u64, std::sync::Arc<serde_json::Value>)>,
    cleanup_tx: tokio::sync::oneshot::Sender<()>,
) -> Response {
    let frames = async_stream::stream! {
        // Backlog first, so the client sees events in id order.
        for (replayed_id, payload) in replay {
            if let Some(encoded) = encode_sse_frame(replayed_id, &payload, session_id) {
                yield Ok::<Bytes, std::io::Error>(encoded);
            }
        }
        // Then live events.
        while let Some((live_id, payload)) = rx.recv().await {
            if let Some(encoded) = encode_sse_frame(live_id, &payload, session_id) {
                yield Ok::<Bytes, std::io::Error>(encoded);
            }
        }
    };
    let body = axum::body::Body::from_stream(GuardedSseStream::new(Box::pin(frames), cleanup_tx));
    let built = Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, "text/event-stream")
        .header(header::CACHE_CONTROL, "no-cache")
        .header(MCP_SESSION_ID_HEADER, session_id.to_string())
        .body(body);
    match built {
        Ok(response) => response,
        Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
    }
}
/// Stream adapter that sends on a oneshot channel when it is dropped.
///
/// Items are forwarded from `inner` unchanged.
struct GuardedSseStream<S> {
    /// The wrapped stream.
    inner: S,
    /// Sender fired on drop; taken out on drop.
    cleanup: Option<tokio::sync::oneshot::Sender<()>>,
}

impl<S> GuardedSseStream<S> {
    /// Wraps `inner`, arming `cleanup` to fire on drop.
    fn new(inner: S, cleanup: tokio::sync::oneshot::Sender<()>) -> Self {
        let cleanup = Some(cleanup);
        Self { inner, cleanup }
    }
}

impl<S> Stream for GuardedSseStream<S>
where
    S: Stream + Unpin,
{
    type Item = S::Item;

    /// Delegates straight to the wrapped stream.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_next(cx)
    }
}

impl<S> Drop for GuardedSseStream<S> {
    fn drop(&mut self) {
        let Some(notify) = self.cleanup.take() else {
            return;
        };
        // A send error only means the receiver is already gone; nothing to do.
        let _ = notify.send(());
    }
}
/// Authenticates a POST request and authorizes its JSON-RPC method.
///
/// Returns `Ok(None)` when the server has no authenticator configured, and
/// `Ok(Some(identity))` when both steps pass. The method passed to
/// authorization is the body's `method` string, or `""` when absent.
///
/// # Errors
/// Returns a ready-made JSON-RPC "unauthenticated" response when either
/// authentication or method authorization fails.
async fn enforce_authenticator(
    server: &Arc<McpServer>,
    headers: &HeaderMap,
    body: &Bytes,
    raw: &serde_json::Value,
) -> Result<Option<Identity>, Response> {
    let auth = match server.authenticator() {
        Some(auth) => auth,
        None => return Ok(None),
    };

    let identity = auth
        .authenticate(&AxumHeaders(headers), body)
        .await
        .map_err(|e| {
            warn!(target: "mcp.auth", error = ?e, "authenticate failed");
            unauthenticated_response(raw, "authentication required")
        })?;

    let method = raw
        .get("method")
        .and_then(serde_json::Value::as_str)
        .unwrap_or("");
    if let Err(e) = auth.authorize_method(&identity, method).await {
        warn!(target: "mcp.auth", method = %method, error = ?e, "authorize_method failed");
        let denial = unauthenticated_response(raw, "method not authorized for principal");
        return Err(denial);
    }
    Ok(Some(identity))
}
/// Builds the auth-failure reply for POST: HTTP 200 carrying a JSON-RPC error
/// with code `JSONRPC_UNAUTHENTICATED`, echoing the request's `id` when present.
fn unauthenticated_response(raw: &serde_json::Value, message: &str) -> Response {
    let request_id = raw.get("id").cloned();
    let envelope = rpc_error(request_id, JSONRPC_UNAUTHENTICATED, message);
    (StatusCode::OK, Json(envelope)).into_response()
}
/// Authenticates a body-less request (`GET /mcp`) against the configured authenticator.
///
/// Returns `Ok(None)` when no authenticator is configured. Authentication is
/// run against the headers and an empty body; no method authorization happens.
///
/// # Errors
/// Returns a plain `401` response when authentication fails.
#[allow(clippy::result_large_err)]
async fn enforce_authenticator_for_get(
    server: &Arc<McpServer>,
    headers: &HeaderMap,
) -> Result<Option<Identity>, Response> {
    let auth = match server.authenticator() {
        Some(auth) => auth,
        None => return Ok(None),
    };
    let empty_body = Bytes::new();
    auth.authenticate(&AxumHeaders(headers), &empty_body)
        .await
        .map(Some)
        .map_err(|e| {
            warn!(target: "mcp.auth", error = ?e, "authenticate failed on GET /mcp");
            unauthenticated_response_for_get()
        })
}
/// Plain-text `401 Unauthorized` reply used by the body-less endpoints.
fn unauthenticated_response_for_get() -> Response {
    let status = StatusCode::UNAUTHORIZED;
    (status, "authentication required").into_response()
}
/// Authenticates a `DELETE /mcp` request.
///
/// Delegates to `enforce_authenticator_for_get`, so the result and the failure
/// response are exactly those of the GET path (including its log message).
#[allow(clippy::result_large_err)]
async fn enforce_authenticator_for_delete(
server: &Arc<McpServer>,
headers: &HeaderMap,
) -> Result<Option<Identity>, Response> {
enforce_authenticator_for_get(server, headers).await
}
/// Borrowed view of an axum `HeaderMap` exposed through the
/// `klieo_auth_common::Headers` trait.
struct AxumHeaders<'a>(&'a HeaderMap);

impl klieo_auth_common::Headers for AxumHeaders<'_> {
    /// Returns the header's value as text, or `None` when the header is absent
    /// or `to_str` rejects its bytes.
    fn get(&self, name: &str) -> Option<&str> {
        let value = self.0.get(name)?;
        value.to_str().ok()
    }
}
/// Reads the session header and parses it as a UUID.
///
/// Returns `None` when the header is missing, not valid text, or not a UUID.
fn extract_session_id(headers: &HeaderMap) -> Option<uuid::Uuid> {
    headers
        .get(MCP_SESSION_ID_HEADER)
        .and_then(|value| value.to_str().ok())
        .and_then(|text| uuid::Uuid::parse_str(text).ok())
}
/// Resolves the session named by the request's session header for `caller`.
///
/// # Errors
/// * `400` – header missing or not a UUID.
/// * `404` with an "unknown session id" envelope – no such session, or the
///   session's principal does not match `caller` (the two cases are
///   indistinguishable to the client).
async fn require_session(
    headers: &HeaderMap,
    server: &McpServer,
    caller: Option<&Identity>,
) -> Result<std::sync::Arc<crate::session::Session>, Response> {
    let Some(id) = extract_session_id(headers) else {
        return Err((StatusCode::BAD_REQUEST, "missing or invalid Mcp-Session-Id").into_response());
    };
    let not_found = || {
        let envelope = rpc_error(None, JSONRPC_SERVER_ERROR, "unknown session id");
        (StatusCode::NOT_FOUND, Json(envelope)).into_response()
    };

    let registry = server.sessions.read().await;
    match registry.get(&id) {
        Some(session) if principal_matches(caller, session.principal.as_deref()) => {
            Ok(session.clone())
        }
        _ => Err(not_found()),
    }
}
/// Decides whether `caller` may act on a session bound to `session_principal`.
///
/// An unauthenticated caller matches only a principal-less session; an
/// authenticated caller matches only a session whose stored principal equals
/// the caller's identity string. Mixed cases never match.
fn principal_matches(caller: Option<&Identity>, session_principal: Option<&str>) -> bool {
    match caller {
        None => session_principal.is_none(),
        Some(identity) => {
            matches!(session_principal, Some(stored) if identity.as_str() == stored)
        }
    }
}
/// Handles an `initialize` POST: mints a new HTTP session, enforces the global
/// and per-principal session caps, dispatches the request against the new
/// session, and returns the result together with the session-id header.
///
/// Cap violations answer `503` with a JSON-RPC server-error envelope that
/// echoes the request id; the session is not registered in that case.
async fn handle_initialize_post(
server: &Arc<McpServer>,
raw: serde_json::Value,
caller: Option<Identity>,
) -> Response {
let session_id = uuid::Uuid::new_v4();
// The session is bound to the caller's identity string (None when unauthenticated).
let principal = caller.as_ref().map(|id| id.as_str().to_string());
let session = std::sync::Arc::new(crate::session::Session::new_http(
session_id,
principal,
server.server_start,
));
{
// Both write locks are held together so the cap checks, the counter bump and
// the insert happen as one step. Acquisition order here: sessions, then principal_counts.
let mut sessions = server.sessions.write().await;
let mut principal_counts = server.principal_counts.write().await;
if sessions.len() >= server.max_sessions {
// Release the locks before logging / emitting metrics.
drop(principal_counts);
drop(sessions);
tracing::warn!(
target: "klieo::mcp::session",
scope = "global",
cap = server.max_sessions,
"session cap reached; rejecting initialize"
);
metrics::counter!(
"klieo_mcp_session_cap_rejected_total",
"scope" => "global"
)
.increment(1);
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(rpc_error(
raw.get("id").cloned(),
JSONRPC_SERVER_ERROR,
"session cap reached",
)),
)
.into_response();
}
// The per-principal cap applies only to sessions that carry a principal.
if let Some(p) = session.principal.as_deref() {
let current = principal_counts.get(p).copied().unwrap_or(0);
if current >= server.max_sessions_per_principal {
drop(principal_counts);
drop(sessions);
tracing::warn!(
target: "klieo::mcp::session",
scope = "per_principal",
principal = p,
cap = server.max_sessions_per_principal,
"per-principal session cap reached; rejecting initialize"
);
metrics::counter!(
"klieo_mcp_session_cap_rejected_total",
"scope" => "per_principal"
)
.increment(1);
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(rpc_error(
raw.get("id").cloned(),
JSONRPC_SERVER_ERROR,
"per-principal session cap reached",
)),
)
.into_response();
}
*principal_counts.entry(p.to_string()).or_insert(0) += 1;
}
sessions.insert(session_id, session.clone());
}
// Start the idle reaper task if it has not been started yet.
ensure_idle_reaper(server).await;
let resp = dispatch(server, raw, Some(&session)).await;
touch_last_activity(server, &session);
// The client learns its session id from the response header.
(
StatusCode::OK,
[(MCP_SESSION_ID_HEADER, session_id.to_string())],
Json(resp),
)
.into_response()
}
/// Starts the idle-session reaper task the first time it is called.
///
/// The once-cell `idle_reaper_started` guards the spawn, so later calls find the
/// cell initialised and spawn nothing.
async fn ensure_idle_reaper(server: &Arc<McpServer>) {
    let spawn_reaper = || {
        let reaper_server = server.clone();
        async move {
            tokio::spawn(idle_reaper_loop(reaper_server));
        }
    };
    let _ = server.idle_reaper_started.get_or_init(spawn_reaper).await;
}
/// Background loop that evicts sessions idle for longer than
/// `server.session_idle_timeout`.
///
/// Runs once per `server.idle_reaper_tick` (missed ticks are skipped). A zero
/// timeout disables eviction for that tick. Each pass:
///
/// 1. collects the ids of idle sessions under the read lock;
/// 2. under the write lock, re-checks each candidate and removes it only if it
///    is still registered and still idle;
/// 3. drains each session *this pass actually removed* and decrements its
///    principal count.
///
/// Step 2 matters because a client `DELETE` or the SSE cleanup task can remove
/// the same session between the snapshot and the write lock. Those paths
/// already drain the session and decrement the principal count; doing it again
/// here would under-count the principal's live sessions. It also spares a
/// session that became active after the snapshot was taken.
async fn idle_reaper_loop(server: Arc<McpServer>) {
    use std::sync::atomic::Ordering;

    let mut interval = tokio::time::interval(server.idle_reaper_tick);
    interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
    loop {
        interval.tick().await;
        let timeout = server.session_idle_timeout;
        if timeout.is_zero() {
            continue;
        }

        // Saturate instead of truncating if the u128 millisecond count ever exceeds u64.
        let now_millis =
            u64::try_from(server.server_start.elapsed().as_millis()).unwrap_or(u64::MAX);
        let timeout_millis = u64::try_from(timeout.as_millis()).unwrap_or(u64::MAX);
        let is_idle = move |session: &Arc<crate::session::Session>| -> bool {
            let last = session.last_activity_millis.load(Ordering::Relaxed);
            now_millis.saturating_sub(last) > timeout_millis
        };

        // Cheap pass under the read lock: just pick candidates.
        let candidates: Vec<uuid::Uuid> = {
            let sessions = server.sessions.read().await;
            sessions
                .iter()
                .filter_map(|(id, session)| is_idle(session).then_some(*id))
                .collect()
        };
        if candidates.is_empty() {
            continue;
        }

        // Authoritative pass under the write lock: only sessions we remove
        // ourselves are ours to tear down.
        let evicted: Vec<(uuid::Uuid, Arc<crate::session::Session>)> = {
            let mut sessions = server.sessions.write().await;
            candidates
                .into_iter()
                .filter_map(|id| {
                    let still_idle = sessions.get(&id).map_or(false, |s| is_idle(s));
                    if still_idle {
                        sessions.remove(&id).map(|session| (id, session))
                    } else {
                        None
                    }
                })
                .collect()
        };

        for (id, session) in evicted {
            let principal = session.principal.clone();
            session.close_and_drain().await;
            server.decrement_principal_count(principal.as_deref()).await;
            metrics::counter!(
                "klieo_mcp_session_deleted_total",
                "reason" => "idle_timeout"
            )
            .increment(1);
            tracing::info!(
                target: "klieo::mcp::session",
                session_id = %id,
                "session evicted by idle reaper"
            );
        }
    }
}
/// Decides whether a request must be answered by one of the streaming/resume
/// handlers, and if so produces that response.
///
/// Returns `None` when the request is not one of the special methods (or is a
/// `tools/call` without `params._meta.progressToken`); the caller then
/// dispatches it normally.
///
/// * `tools/call` with a string/number progress token → `stream_tools_call`;
///   any other token type → `400` invalid-params.
/// * `klieo/run/resume` → `handle_run_resume`.
/// * `klieo/tools/resume` → params are validated as [`ResumeParams`], then
///   `stream_resume`; missing or invalid params → `400` invalid-params.
///
/// The progress token is cloned out of the request before the request is handed
/// over, so the request itself is moved rather than deep-copied.
#[instrument(skip_all, level = "debug")]
async fn try_sse_upgrade(
    server: Arc<McpServer>,
    raw: serde_json::Value,
    identity: Option<Identity>,
) -> Option<Response> {
    /// `400` + JSON-RPC invalid-params envelope echoing the request id.
    fn invalid_params(raw: &serde_json::Value, message: &str) -> Response {
        let envelope = rpc_error(raw.get("id").cloned(), JSONRPC_INVALID_PARAMS, message);
        (StatusCode::BAD_REQUEST, Json(envelope)).into_response()
    }

    let method = raw.get("method").and_then(|m| m.as_str())?;
    match method {
        "tools/call" => {
            // Own the token so the borrow of `raw` ends here and `raw` can be moved below.
            let token = raw.pointer("/params/_meta/progressToken")?.clone();
            if token.is_string() || token.is_number() {
                Some(stream_tools_call(server, raw, token, identity).await)
            } else {
                Some(invalid_params(
                    &raw,
                    "_meta.progressToken must be a string or number",
                ))
            }
        }
        "klieo/run/resume" => Some(handle_run_resume(server, raw, identity).await),
        "klieo/tools/resume" => {
            let Some(params) = raw.get("params").cloned() else {
                return Some(invalid_params(&raw, "klieo/tools/resume: missing params"));
            };
            let parsed: ResumeParams = match serde_json::from_value(params) {
                Ok(parsed) => parsed,
                Err(e) => {
                    return Some(invalid_params(
                        &raw,
                        &format!("klieo/tools/resume: invalid params: {e}"),
                    ));
                }
            };
            Some(stream_resume(server, raw, parsed, identity).await)
        }
        _ => None,
    }
}
/// Parameters of a `klieo/tools/resume` request.
///
/// Wire names are camelCase: `progressToken` and `lastEventId`.
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct ResumeParams {
    /// Token identifying the stream to resume; kept as raw JSON.
    progress_token: serde_json::Value,
    /// Id of the last event the client reports having received.
    last_event_id: u64,
}
/// Parameters of a `klieo/run/resume` request.
#[derive(serde::Deserialize)]
struct RunResumeParams {
// Ticket string looked up and claimed in the resume-ticket store.
ticket: String,
// The caller's approve/reject decision.
decision: RunResumeDecision,
}
/// Approval decision supplied with a `klieo/run/resume` request.
#[derive(serde::Deserialize)]
struct RunResumeDecision {
// `true` resumes as approved; `false` resumes as rejected.
approved: bool,
// Optional rejection reason; an empty string is used when it is absent.
#[serde(default)]
reason: Option<String>,
}
/// Client-facing message shared by every `klieo/run/resume` denial and
/// invalid-params reply; the specific reason is only logged.
const RUN_RESUME_DENY_MESSAGE: &str = "resume ticket invalid";
/// Client-facing message used when no ticket store is configured or the
/// ticket's workflow is not registered.
const RUN_RESUME_UNAVAILABLE_MESSAGE: &str = "resume unavailable";
async fn handle_run_resume(
server: Arc<McpServer>,
raw: serde_json::Value,
identity: Option<Identity>,
) -> Response {
let req_id = raw.get("id").cloned();
let parsed = match parse_run_resume_params(&raw, req_id.clone()) {
Ok(p) => p,
Err(resp) => return resp,
};
let claimed = match claim_resume_record(&server, &parsed.ticket, identity, req_id.clone()).await
{
Ok(rec) => rec,
Err(resp) => return resp,
};
drive_resume(&server, parsed.decision, claimed, req_id).await
}
/// Extracts and deserialises the `params` member of a `klieo/run/resume` request.
///
/// # Errors
/// Returns the invalid-params response when `params` is missing or does not
/// deserialise; the deserialisation error detail is discarded.
#[allow(clippy::result_large_err)]
fn parse_run_resume_params(
    raw: &serde_json::Value,
    req_id: Option<serde_json::Value>,
) -> Result<RunResumeParams, Response> {
    match raw.get("params") {
        None => Err(run_resume_invalid_params(req_id, "missing params")),
        Some(params) => serde_json::from_value::<RunResumeParams>(params.clone())
            .map_err(|_| run_resume_invalid_params(req_id, "invalid params")),
    }
}
/// Validates `ticket` for the caller and claims it from the resume-ticket store.
///
/// Steps, in order: require a configured store; require a non-anonymous
/// identity; peek the ticket; require that the ticket's principal equals the
/// caller; claim the ticket.
///
/// # Errors
/// * "unavailable" response – no ticket store configured.
/// * "denied" response – anonymous/missing caller, unknown ticket, principal
///   mismatch, store failure on peek or claim, or the claim returned nothing.
#[allow(clippy::result_large_err)]
async fn claim_resume_record(
    server: &McpServer,
    ticket: &str,
    identity: Option<Identity>,
    req_id: Option<serde_json::Value>,
) -> Result<crate::resume_ticket::ResumeTicketRecord, Response> {
    let store = match server.resume_ticket_store.as_ref() {
        Some(store) => store,
        None => return Err(run_resume_unavailable(req_id)),
    };
    let caller = match identity.as_ref() {
        Some(id) if !id.is_anonymous() => id,
        _ => return Err(run_resume_denied(req_id, "anonymous caller")),
    };

    // Peek first so the principal can be checked before the ticket is claimed.
    let record = match store.peek(ticket).await {
        Ok(Some(record)) => record,
        Ok(None) => return Err(run_resume_denied(req_id, "unknown ticket")),
        Err(err) => return Err(log_and_deny(err, req_id, "peek failure")),
    };
    if caller.as_str() != record.principal {
        return Err(run_resume_denied(req_id, "principal mismatch"));
    }

    match store.claim(ticket).await {
        Ok(Some(claimed)) => Ok(claimed),
        Ok(None) => Err(run_resume_denied(req_id, "claim lost race")),
        Err(err) => Err(log_and_deny(err, req_id, "claim failure")),
    }
}
async fn drive_resume(
server: &McpServer,
decision: RunResumeDecision,
record: crate::resume_ticket::ResumeTicketRecord,
req_id: Option<serde_json::Value>,
) -> Response {
let Some(handle) = server.workflow_resume_handles.get(&record.workflow_name) else {
tracing::warn!(
target: "klieo.mcp.resume",
workflow = %record.workflow_name,
"claimed ticket references an unregistered workflow",
);
return run_resume_unavailable(req_id);
};
let approval = if decision.approved {
klieo_core::checkpoint::ApprovalDecision::Approved
} else {
klieo_core::checkpoint::ApprovalDecision::Rejected {
reason: decision.reason.unwrap_or_default(),
}
};
let tenant_label = klieo_core::principal_hash(&record.principal);
match handle
.resume(record.checkpoint, approval, tenant_label)
.await
{
Ok(result) => (StatusCode::OK, Json(crate::rpc_ok(req_id, result))).into_response(),
Err(err) => {
tracing::warn!(
target: "klieo.mcp.resume",
workflow = %record.workflow_name,
error = %err,
"workflow resume failed",
);
run_resume_server_error(req_id)
}
}
}
/// Logs a ticket-store failure at warn level and converts it into the generic
/// denial response (which logs `log_reason` again at info level).
fn log_and_deny(
    err: crate::resume_ticket::TicketStoreError,
    req_id: Option<serde_json::Value>,
    log_reason: &str,
) -> Response {
    tracing::warn!(
        target: "klieo.mcp.resume",
        rpc_id = ?req_id,
        error = %err,
        reason = log_reason,
        "resume ticket-store op failed; denying fail-closed",
    );
    let denial = run_resume_denied(req_id, log_reason);
    denial
}
/// Builds the `klieo/run/resume` denial: HTTP 200 with an invalid-params
/// JSON-RPC error carrying the generic deny message. `log_reason` is logged at
/// info level and never sent to the client.
fn run_resume_denied(id: Option<serde_json::Value>, log_reason: &str) -> Response {
    tracing::info!(
        target: "klieo.mcp.resume",
        rpc_id = ?id,
        reason = log_reason,
        "klieo/run/resume denied",
    );
    let envelope = rpc_error(id, JSONRPC_INVALID_PARAMS, RUN_RESUME_DENY_MESSAGE);
    (StatusCode::OK, Json(envelope)).into_response()
}
/// Builds the "resume unavailable" reply: HTTP 200 with a JSON-RPC server error.
fn run_resume_unavailable(id: Option<serde_json::Value>) -> Response {
    let envelope = rpc_error(id, JSONRPC_SERVER_ERROR, RUN_RESUME_UNAVAILABLE_MESSAGE);
    (StatusCode::OK, Json(envelope)).into_response()
}
/// Builds the reply for malformed `klieo/run/resume` params: HTTP 400 with an
/// invalid-params JSON-RPC error carrying the generic deny message.
/// `log_reason` is logged at warn level and never sent to the client.
fn run_resume_invalid_params(id: Option<serde_json::Value>, log_reason: &str) -> Response {
    tracing::warn!(
        target: "klieo.mcp.resume",
        rpc_id = ?id,
        reason = log_reason,
        "klieo/run/resume params invalid",
    );
    let envelope = rpc_error(id, JSONRPC_INVALID_PARAMS, RUN_RESUME_DENY_MESSAGE);
    (StatusCode::BAD_REQUEST, Json(envelope)).into_response()
}
/// Builds the reply for a failed workflow resume: HTTP 200 with a JSON-RPC
/// server error.
fn run_resume_server_error(id: Option<serde_json::Value>) -> Response {
    let envelope = rpc_error(id, JSONRPC_SERVER_ERROR, "resume execution failed");
    (StatusCode::OK, Json(envelope)).into_response()
}
/// Copies the `traceparent` and `tracestate` request headers (when present and
/// readable as text) into a `klieo_core::Headers` value.
fn klieo_headers_from_axum(headers: &HeaderMap) -> klieo_core::Headers {
    let mut out = klieo_core::Headers::default();
    for name in ["traceparent", "tracestate"] {
        let text = headers.get(name).and_then(|value| value.to_str().ok());
        if let Some(text) = text {
            out.insert(name.into(), text.to_string());
        }
    }
    out
}
/// Returns `true` when the `Content-Type` header's media type is
/// `application/json` (case-insensitive), ignoring any `;`-separated
/// parameters and surrounding whitespace. A missing or unreadable header
/// yields `false`.
fn content_type_is_json(headers: &HeaderMap) -> bool {
    let Some(raw) = headers.get(header::CONTENT_TYPE) else {
        return false;
    };
    let Ok(text) = raw.to_str() else {
        return false;
    };
    let media_type = text.split(';').next().unwrap_or("").trim();
    media_type.eq_ignore_ascii_case("application/json")
}
// NOTE(review): not referenced in the visible part of this file; presumably the
// capacity of the per-request progress channel used by the streaming handlers — confirm.
const PROGRESS_CHANNEL_CAP: usize = 64;
fn progress_token_to_string(token: &serde_json::Value) -> String {
match token {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Number(n) => n.to_string(),
other => other.to_string(),
}
}
/// Returns `frame_bytes` with an `id: <id>\n` line prepended.
fn id_prefixed_frame(id: u64, frame_bytes: Bytes) -> Bytes {
    let id_line = format!("id: {id}\n");
    let mut out = Vec::with_capacity(id_line.len() + frame_bytes.len());
    out.extend_from_slice(id_line.as_bytes());
    out.extend_from_slice(&frame_bytes);
    Bytes::from(out)
}
/// Extracts the event id from a frame whose first line is `id: <u64>`.
///
/// Returns `None` when the bytes are not valid UTF-8, the buffer is empty, the
/// first line lacks the `id: ` prefix, or the remainder is not a `u64`.
fn parse_id_prefix(bytes: &Bytes) -> Option<u64> {
    let text = std::str::from_utf8(bytes).ok()?;
    text.lines()
        .next()
        .and_then(|first_line| first_line.strip_prefix("id: "))
        .and_then(|digits| digits.parse::<u64>().ok())
}
/// Records one frame into the resume buffer on a detached task, optionally
/// closing the stream afterwards.
///
/// Failures are logged at warn level and otherwise ignored; a failed record
/// does not prevent the close attempt when `close_after` is set.
///
/// The payload is moved straight into `record`: the task already owns it and
/// uses it exactly once, so no clone is needed.
fn spawn_record(
    buffer: Arc<dyn klieo_core::resume::ResumeBuffer>,
    stream_id: String,
    id: u64,
    payload: Bytes,
    close_after: bool,
) {
    tokio::spawn(async move {
        if let Err(e) = buffer.record(&stream_id, id, payload).await {
            tracing::warn!(
                target: "mcp.resume",
                stream_id,
                id,
                error = %e,
                "resume buffer record failed"
            );
        }
        if close_after {
            if let Err(e) = buffer.close(&stream_id).await {
                tracing::warn!(
                    target: "mcp.resume",
                    stream_id,
                    error = %e,
                    "resume buffer close failed"
                );
            }
        }
    });
}
/// Publishes `payload` on `subject` via `spawn_publish_bounded`, tagging the
/// call with the `"mcp.fanout"` label and attaching headers populated from the
/// current span's context.
fn spawn_publish(
    pubsub: Arc<dyn klieo_core::Pubsub>,
    subject: String,
    payload: Bytes,
    permits: Arc<tokio::sync::Semaphore>,
) {
    let span_context = tracing::Span::current().context();
    let mut trace_headers = klieo_core::Headers::default();
    klieo_core::inject_traceparent(&mut trace_headers, &span_context);

    klieo_core::cancel::spawn_publish_bounded(
        pubsub,
        subject,
        payload,
        "mcp.fanout",
        permits,
        trace_headers,
    );
}
/// Wraps a response stream in the guards that tie its lifetime to the request.
///
/// The returned `CancelOnDrop` holds: the stream wrapped in
/// `RegistryDeregisterOnDrop` for `stream_id`, a drop guard for
/// `request_cancel`, the pub/sub handle, the cancel subject for this stream and
/// the publish permits, plus the optional leader and ownership handles.
fn wrap_with_cancel_fanout<S>(
    server: &Arc<McpServer>,
    inner: S,
    stream_id: String,
    request_cancel: tokio_util::sync::CancellationToken,
    leader_handle: Option<klieo_core::LeaderHandle>,
    ownership_handle: Option<klieo_core::OwnershipHandle>,
) -> CancelOnDrop<klieo_core::cancel::RegistryDeregisterOnDrop<S>>
where
    S: Stream<Item = Result<Bytes, std::convert::Infallible>> + Send + Unpin + 'static,
{
    use klieo_core::cancel::RegistryDeregisterOnDrop;

    // Build the subject before `stream_id` is moved into the deregistration guard.
    let cancel_subject = format!("{CANCEL_SUBJECT_PREFIX}{stream_id}");
    let registry = server.cancel_registry().clone();
    CancelOnDrop {
        inner: RegistryDeregisterOnDrop::new(inner, registry, stream_id),
        _guard: request_cancel.drop_guard(),
        pubsub: server.pubsub.clone(),
        cancel_subject,
        permits: server.publish_permits.clone(),
        _leader: leader_handle,
        _ownership: ownership_handle,
    }
}
/// Attempts to claim leadership of `stream_id` with a heartbeat.
///
/// Returns `None` when no leader registry is configured, or when the claim
/// fails (the failure is logged at warn level and the stream proceeds without
/// a leader handle).
async fn try_claim_leader(
    server: &Arc<McpServer>,
    stream_id: &str,
    payload: Option<Bytes>,
    principal: Option<String>,
) -> Option<klieo_core::LeaderHandle> {
    let registry = server.leader_registry()?;
    let claim = registry
        .claim_with_heartbeat(
            format!("{MCP_LEADER_KEY_PREFIX}{stream_id}"),
            server.leader_ttl(),
            server.leader_heartbeat_interval(),
            payload,
            principal,
        )
        .await;
    match claim {
        Err(e) => {
            tracing::warn!(
                target: "mcp.leader",
                stream_id = %stream_id,
                error = %e,
                "leader claim failed; degrading to no-claim (orphan detection \
                 disabled for this stream)",
            );
            None
        }
        Ok(handle) => Some(handle),
    }
}
async fn try_claim_ownership(
server: &Arc<McpServer>,
stream_id: &str,
identity: &Option<Identity>,
) -> Result<Option<klieo_core::OwnershipHandle>, Response> {
let Some(registry) = server.ownership_registry() else {
return Ok(None);
};
let Some(identity) = identity.as_ref() else {
return Ok(None);
};
if identity.is_anonymous() {
return Ok(None);
}
let key = format!("{MCP_OWNERSHIP_KEY_PREFIX}{stream_id}");
let principal = identity.as_str().to_string();
match registry.claim_guarded(key, principal).await {
klieo_core::OwnershipClaim::Claimed(handle) => Ok(Some(handle)),
klieo_core::OwnershipClaim::Proceed => Ok(None),
klieo_core::OwnershipClaim::Unavailable => {
Err(stream_unavailable_response(serde_json::Value::Null))
}
_ => Err(stream_unavailable_response(serde_json::Value::Null)),
}
}
async fn enforce_owner(
server: &Arc<McpServer>,
stream_id: &str,
identity: &Option<Identity>,
req_id: &serde_json::Value,
) -> Result<(), Response> {
let Some(registry) = server.ownership_registry() else {
return Ok(());
};
let Some(identity) = identity.as_ref() else {
return Ok(());
};
if identity.is_anonymous() {
return Ok(());
}
let key = format!("{MCP_OWNERSHIP_KEY_PREFIX}{stream_id}");
match registry.check_owner(&key, identity.as_str()).await {
klieo_core::OwnershipCheck::Allowed => Ok(()),
klieo_core::OwnershipCheck::Denied => {
tracing::warn!(
target: "mcp.tenants",
stream_id = %stream_id,
principal = %identity.as_str(),
"ownership mismatch on klieo/tools/resume; denying as stream-not-found",
);
Err(stream_not_found_response(req_id.clone()))
}
klieo_core::OwnershipCheck::Unavailable => Err(stream_unavailable_response(req_id.clone())),
_ => Err(stream_unavailable_response(req_id.clone())),
}
}
/// Builds the "no buffered stream" reply: HTTP 200 with a JSON-RPC error using
/// code `JSONRPC_RESUME_BUFFER_NOT_FOUND`.
fn stream_not_found_response(req_id: serde_json::Value) -> Response {
    let envelope = rpc_error(
        Some(req_id),
        crate::JSONRPC_RESUME_BUFFER_NOT_FOUND,
        "no buffered stream for progressToken",
    );
    (StatusCode::OK, Json(envelope)).into_response()
}
/// Builds the "ownership store unavailable" reply: HTTP 503 with a JSON-RPC
/// server error.
fn stream_unavailable_response(req_id: serde_json::Value) -> Response {
    let envelope = rpc_error(
        Some(req_id),
        crate::JSONRPC_SERVER_ERROR,
        "ownership store unavailable; denied (strict tenant binding)",
    );
    (StatusCode::SERVICE_UNAVAILABLE, Json(envelope)).into_response()
}
/// Result of asking the leader registry whether a stream's leader is alive.
enum LeaderProbe {
// No leader registry is configured on the server.
NoRegistry,
// The registry reports the leader alive, or the probe itself failed
// (`probe_leader` treats probe errors as alive).
Alive,
// The registry reports the leader as not alive.
Dead,
}
/// Asks the leader registry whether the leader of `stream_id` is alive.
///
/// Returns `NoRegistry` when none is configured. A failing probe is logged at
/// warn level and reported as `Alive`.
async fn probe_leader(server: &Arc<McpServer>, stream_id: &str) -> LeaderProbe {
    let registry = match server.leader_registry() {
        Some(registry) => registry,
        None => return LeaderProbe::NoRegistry,
    };
    let key = format!("{MCP_LEADER_KEY_PREFIX}{stream_id}");
    let alive = registry.is_alive(&key).await.unwrap_or_else(|e| {
        tracing::warn!(
            target: "mcp.leader",
            stream_id = %stream_id,
            error = %e,
            "is_alive probe failed; treating as alive (fail-open per ADR-020)",
        );
        true
    });
    if alive {
        LeaderProbe::Alive
    } else {
        LeaderProbe::Dead
    }
}
/// Finds the highest event id recorded for `stream_id` by replaying the buffer
/// from id 0.
///
/// Returns `None` when the stream is not found, when the replay cannot be
/// opened (logged at warn level unless the error is `NotFound`), or when the
/// replay yields no events.
async fn max_event_id(
    buffer: &Arc<dyn klieo_core::resume::ResumeBuffer>,
    stream_id: &str,
) -> Option<u64> {
    use klieo_core::resume::ResumeError;

    let mut replay = match buffer.replay(stream_id, 0).await {
        Ok(stream) => stream,
        Err(err) => {
            // A missing stream is an expected outcome and is not logged.
            if !matches!(err, ResumeError::NotFound(_)) {
                tracing::warn!(
                    target: "mcp.leader",
                    stream_id = %stream_id,
                    error = %err,
                    "max_event_id replay failed; skipping orphan terminal write",
                );
            }
            return None;
        }
    };

    let mut highest: Option<u64> = None;
    while let Some((id, _)) = tokio_stream::StreamExt::next(&mut replay).await {
        highest = Some(highest.map_or(id, |seen| seen.max(id)));
    }
    highest
}
fn leader_died_envelope(id: serde_json::Value, stream_id: &str) -> serde_json::Value {
serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"error": {
"code": JSONRPC_LEADER_DIED,
"message": "stream leader died",
"data": { "stream_id": stream_id },
}
})
}
fn leader_died_sse_frame_bytes(stream_id: &str) -> Bytes {
let envelope = serde_json::json!({
"jsonrpc": "2.0",
"id": serde_json::Value::Null,
"error": {
"code": JSONRPC_LEADER_DIED,
"message": "stream leader died",
"data": { "stream_id": stream_id },
}
});
Bytes::from(format!("event: error\ndata: {}\n\n", envelope))
}
/// Appends a leader-died terminal frame to the resume buffer of `stream_id`
/// and closes the stream.
///
/// The frame is recorded at `max_event_id + 1`. Returns `false` (nothing
/// written) when the buffer has no readable frames or the record call fails.
/// A failing `close` is only logged: the frame is already recorded, so the
/// function still returns `true`.
async fn write_orphan_terminal_frame(
    buffer: &Arc<dyn klieo_core::resume::ResumeBuffer>,
    stream_id: &str,
) -> bool {
    let next_id = match max_event_id(buffer, stream_id).await {
        Some(max_id) => max_id + 1,
        None => return false,
    };
    let terminal_frame = leader_died_sse_frame_bytes(stream_id);
    let recorded = buffer.record(stream_id, next_id, terminal_frame).await;
    if let Err(err) = recorded {
        tracing::warn!(
            target: "mcp.leader",
            stream_id = %stream_id,
            next_id,
            error = %err,
            "orphan terminal record failed; skipping orphan write",
        );
        return false;
    }
    let closed = buffer.close(stream_id).await;
    if let Err(err) = closed {
        tracing::warn!(
            target: "mcp.leader",
            stream_id = %stream_id,
            error = %err,
            "orphan terminal close failed; resume buffer may retain stale stream",
        );
    }
    tracing::error!(
        target: "mcp.leader",
        stream_id = %stream_id,
        next_id,
        "stream leader died; emitted LEADER_DIED terminal frame at max+1",
    );
    true
}
/// Terminates an orphaned stream and builds the caller's reply.
///
/// When the terminal frame could be written, the caller receives a `200 OK`
/// carrying the leader-died envelope for `req_id`. Otherwise the outcome is
/// [`OrphanOutcome::Passthrough`] and the caller continues its normal path.
async fn terminate_orphan_mcp(
    buffer: &Arc<dyn klieo_core::resume::ResumeBuffer>,
    req_id: &serde_json::Value,
    stream_id: &str,
) -> OrphanOutcome {
    let frame_written = write_orphan_terminal_frame(buffer, stream_id).await;
    if !frame_written {
        return OrphanOutcome::Passthrough;
    }
    let body = Json(leader_died_envelope(req_id.clone(), stream_id));
    OrphanOutcome::Terminated((StatusCode::OK, body).into_response())
}
/// Outcome of handling a resume request whose stream leader is dead.
#[non_exhaustive]
pub enum OrphanOutcome {
/// This replica claimed leadership and re-ran the cached `tools/call`;
/// the wrapped response is the one produced by that re-invocation.
Reinvoked(Response),
/// A leader-died terminal frame was written; the wrapped response carries
/// the leader-died error envelope.
Terminated(Response),
/// Nothing was done (no leader registry, or the terminal frame could not
/// be written); the caller should continue with its normal path.
Passthrough,
}
/// Public entry point for the dead-leader orphan handler, compiled only with
/// the `test-fixtures` feature. Delegates unchanged to
/// `handle_dead_leader_orphan_mcp_impl`.
#[cfg(feature = "test-fixtures")]
pub async fn handle_dead_leader_orphan_mcp(
server: &Arc<McpServer>,
buffer: &Arc<dyn klieo_core::resume::ResumeBuffer>,
req_id: &serde_json::Value,
stream_id: &str,
) -> OrphanOutcome {
handle_dead_leader_orphan_mcp_impl(server, buffer, req_id, stream_id).await
}
/// Private entry point for the dead-leader orphan handler, compiled when the
/// `test-fixtures` feature is off. Delegates unchanged to
/// `handle_dead_leader_orphan_mcp_impl`.
#[cfg(not(feature = "test-fixtures"))]
async fn handle_dead_leader_orphan_mcp(
server: &Arc<McpServer>,
buffer: &Arc<dyn klieo_core::resume::ResumeBuffer>,
req_id: &serde_json::Value,
stream_id: &str,
) -> OrphanOutcome {
handle_dead_leader_orphan_mcp_impl(server, buffer, req_id, stream_id).await
}
/// Handles a stream whose leader is dead: either re-runs the cached
/// `tools/call` on this replica or terminates the stream.
///
/// Re-invocation happens only when all of the following hold, checked in this
/// order: the leader entry can be looked up, it carries a cached payload that
/// parses as JSON, the tool named in `/params/name` is reported idempotent by
/// the invoker, the entry's attempt counter is below
/// `server.max_failover_attempts()`, and the CAS claim of the leader key
/// succeeds. Any failed check falls back to `terminate_orphan_mcp`.
///
/// Returns `Passthrough` when no leader registry is configured.
///
/// NOTE(review): on `CasConflict` another replica has presumably taken over the
/// stream, yet this path still records a LEADER_DIED frame and closes the
/// buffer via `terminate_orphan_mcp` — confirm this cannot interfere with the
/// winner's re-invocation.
async fn handle_dead_leader_orphan_mcp_impl(
server: &Arc<McpServer>,
buffer: &Arc<dyn klieo_core::resume::ResumeBuffer>,
req_id: &serde_json::Value,
stream_id: &str,
) -> OrphanOutcome {
let Some(registry) = server.leader_registry() else {
return OrphanOutcome::Passthrough;
};
let key = format!("{MCP_LEADER_KEY_PREFIX}{stream_id}");
// The revision is kept so the later claim can be a compare-and-swap against
// exactly the entry inspected here.
let lookup = registry.lookup_entry_with_revision(&key).await;
let Some((entry, prior_rev)) = lookup_ok_or_log_mcp(&key, lookup) else {
return terminate_orphan_mcp(buffer, req_id, stream_id).await;
};
// Without the cached request body there is nothing to re-run.
let Some(payload_bytes) = entry.payload.clone() else {
tracing::debug!(
target: "mcp.failover",
stream_id = %stream_id,
"no cached payload on leader entry; emitting terminate frame",
);
return terminate_orphan_mcp(buffer, req_id, stream_id).await;
};
let parsed_body: serde_json::Value = match serde_json::from_slice(&payload_bytes) {
Ok(v) => v,
Err(e) => {
tracing::error!(
target: "mcp.failover",
stream_id = %stream_id,
error = %e,
"cached payload parse failed; emitting terminate frame",
);
return terminate_orphan_mcp(buffer, req_id, stream_id).await;
}
};
// A missing or non-string tool name becomes "", which is then subject to
// the same idempotency check as any other name.
let tool_name = parsed_body
.pointer("/params/name")
.and_then(|v| v.as_str())
.unwrap_or("");
if !server.invoker.is_tool_idempotent(tool_name) {
tracing::debug!(
target: "mcp.failover",
stream_id = %stream_id,
tool = %tool_name,
"tool not idempotent; emitting terminate frame",
);
return terminate_orphan_mcp(buffer, req_id, stream_id).await;
}
if entry.attempt >= server.max_failover_attempts() {
tracing::warn!(
target: "mcp.failover",
stream_id = %stream_id,
attempt = entry.attempt,
cap = server.max_failover_attempts(),
"failover attempt cap reached; emitting terminate frame",
);
return terminate_orphan_mcp(buffer, req_id, stream_id).await;
}
// Claim leadership conditionally on the revision read above.
let new_handle = match registry
.claim_with_attempt_cas_and_heartbeat(
key.clone(),
server.leader_ttl(),
server.leader_heartbeat_interval(),
prior_rev,
&entry,
)
.await
{
Ok(h) => h,
Err(klieo_core::BusError::CasConflict { .. }) => {
tracing::info!(
target: "mcp.failover",
stream_id = %stream_id,
"another follower won the CAS race; emitting terminate frame",
);
return terminate_orphan_mcp(buffer, req_id, stream_id).await;
}
Err(e) => {
tracing::warn!(
target: "mcp.failover",
stream_id = %stream_id,
error = %e,
"CAS claim failed; emitting terminate frame",
);
return terminate_orphan_mcp(buffer, req_id, stream_id).await;
}
};
// Record a marker frame in the buffer before the re-run starts emitting.
record_failover_marker_mcp(buffer, stream_id, &entry, &new_handle).await;
let progress_token = parsed_body
.pointer("/params/_meta/progressToken")
.cloned()
.unwrap_or(serde_json::Value::Null);
// The re-run executes under the principal stored on the leader entry, not
// under the identity of the request that triggered the failover.
let reinvoke_identity = entry.principal.as_ref().map(|p| Identity::new(p.clone()));
let resp = run_tools_call(
server.clone(),
parsed_body,
progress_token,
reinvoke_identity,
Some(new_handle),
)
.await;
OrphanOutcome::Reinvoked(resp)
}
/// Collapses a leader-entry lookup result into an `Option`.
///
/// A successful lookup is passed through unchanged (including `None` for a
/// missing entry); a lookup error is logged with `key` and treated as a
/// missing entry.
fn lookup_ok_or_log_mcp(
    key: &str,
    lookup: Result<Option<(klieo_core::LeaderEntry, klieo_core::Revision)>, klieo_core::BusError>,
) -> Option<(klieo_core::LeaderEntry, klieo_core::Revision)> {
    lookup.unwrap_or_else(|err| {
        tracing::warn!(
            target: "mcp.failover",
            key,
            error = %err,
            "leader entry lookup_with_revision failed; falling back to terminate",
        );
        None
    })
}
/// Records a "failover-reinvoke" marker frame in the resume buffer.
///
/// The marker goes at `max_event_id + 1` and carries the next attempt number
/// (`prior.attempt + 1`) plus the replica id of `new_handle`. This is best
/// effort: if the buffer has no readable frames nothing is written, and a
/// failing record call is only logged.
async fn record_failover_marker_mcp(
    buffer: &Arc<dyn klieo_core::resume::ResumeBuffer>,
    stream_id: &str,
    prior: &klieo_core::LeaderEntry,
    new_handle: &klieo_core::LeaderHandle,
) {
    let marker_id = match max_event_id(buffer, stream_id).await {
        Some(max) => max + 1,
        None => return,
    };
    let next_attempt = prior.attempt + 1;
    let marker = failover_reinvoke_sse_frame_bytes_mcp(
        stream_id,
        marker_id,
        next_attempt,
        new_handle.replica_id(),
    );
    let recorded = buffer.record(stream_id, marker_id, marker).await;
    if let Err(err) = recorded {
        tracing::warn!(
            target: "mcp.failover",
            stream_id = %stream_id,
            marker_id,
            error = %err,
            "failover-reinvoke marker record failed; continuing without marker",
        );
    }
}
/// Serialises the "failover-reinvoke" marker payload to bytes.
///
/// The payload is a JSON object with a `null` id, `event` set to
/// `"failover-reinvoke"`, the given `event_id`, and a `data` object holding
/// the attempt number and the replica that took over. A serialisation failure
/// yields an empty byte buffer.
///
/// NOTE(review): unlike `leader_died_sse_frame_bytes`, the output is bare JSON
/// without `event:`/`data:` SSE framing, and `data.stream_id` carries the
/// leader-key-prefixed id rather than the bare stream id — confirm both are
/// intended.
fn failover_reinvoke_sse_frame_bytes_mcp(
    stream_id: &str,
    event_id: u64,
    attempt: u32,
    new_replica_id: &str,
) -> Bytes {
    let data = serde_json::json!({
        "stream_id": format!("{MCP_LEADER_KEY_PREFIX}{stream_id}"),
        "attempt": attempt,
        "by_replica": new_replica_id,
    });
    let payload = serde_json::json!({
        "jsonrpc": "2.0",
        "id": serde_json::Value::Null,
        "event": "failover-reinvoke",
        "data": data,
        "event_id": event_id,
    });
    let encoded = serde_json::to_vec(&payload).unwrap_or_default();
    Bytes::from(encoded)
}
/// Joins the buffered replay of a stream with its live tail into one SSE body
/// stream.
///
/// * `replay` — buffered `(event id, payload)` pairs; each is emitted with its
///   id prefix and raises `max_replayed` to the highest id seen.
/// * `live_tail` — optional live subscription; when `None` the result is the
///   replay alone.
/// * `max_replayed` — shared high-water mark used to drop tail frames that the
///   replay already delivered.
/// * `stream_id` — used for log fields and span attributes only.
///
/// The tail is chained after the replay, so it is only polled once the replay
/// is exhausted. Tail messages that fail to arrive, lack a parseable id
/// prefix, or carry an id not greater than the replay high-water mark are
/// dropped. Every received tail message is acked; an ack failure is logged and
/// does not drop the frame.
fn combine_replay_and_tail(
    replay: Pin<Box<dyn Stream<Item = (u64, Bytes)> + Send>>,
    live_tail: Option<klieo_core::MsgStream>,
    max_replayed: Arc<AtomicU64>,
    stream_id: String,
) -> Pin<Box<dyn Stream<Item = Result<Bytes, std::convert::Infallible>> + Send>> {
    let max_for_replay = max_replayed.clone();
    let replay = replay.map(move |(id, payload)| {
        max_for_replay.fetch_max(id, std::sync::atomic::Ordering::SeqCst);
        Ok::<_, std::convert::Infallible>(id_prefixed_frame(id, payload))
    });
    let Some(tail_msgs) = live_tail else {
        return replay.boxed();
    };
    let max_for_tail = max_replayed;
    let stream_id_for_log = stream_id;
    let tail = tail_msgs.filter_map(move |msg_result| {
        let stream_id = stream_id_for_log.clone();
        let max_for_tail = max_for_tail.clone();
        async move {
            let msg = match msg_result {
                Ok(m) => m,
                Err(e) => {
                    tracing::warn!(
                        target: "mcp.fanout",
                        stream_id = %stream_id,
                        error = %e,
                        "live tail subscription stream error; skipping"
                    );
                    return None;
                }
            };
            let parent_cx = klieo_core::extract_traceparent(&msg.headers);
            let decode_span = tracing::info_span!(
                "tail_frame_decode",
                messaging.system = "klieo-bus",
                messaging.destination = %format!("klieo.mcp.progress.{stream_id}"),
                messaging.operation = "receive",
                klieo.stream_id = %stream_id,
            );
            decode_span.set_parent(parent_cx);
            let decode = async move {
                let payload = msg.payload.clone();
                if let Err(e) = msg.ack.ack().await {
                    tracing::warn!(
                        target: "mcp.fanout",
                        error = %e,
                        "ack failed; ephemeral consumer continues"
                    );
                }
                match parse_id_prefix(&payload) {
                    Some(id) if id > max_for_tail.load(std::sync::atomic::Ordering::SeqCst) => {
                        Some(Ok::<_, std::convert::Infallible>(payload))
                    }
                    // Already delivered by the replay.
                    Some(_) => None,
                    None => {
                        tracing::warn!(
                            target: "mcp.fanout",
                            payload_len = payload.len(),
                            "tail frame missing or unparseable id prefix; dropping",
                        );
                        None
                    }
                }
            };
            // The decode work awaits the ack, so the span must be attached to
            // the future rather than entered with a guard: a `Span::enter`
            // guard held across an `.await` stays entered while the task is
            // suspended and produces incorrect traces.
            tracing::Instrument::instrument(decode, decode_span).await
        }
    });
    replay.chain(tail).boxed()
}
/// Assigns the next event id to `frame`, hands it to the resume buffer and the
/// pubsub fan-out, and returns the id-prefixed frame for the HTTP response.
///
/// The id is the counter value after incrementing `next_id`, so the first
/// frame of a fresh counter gets id `1`. The un-prefixed frame is passed to
/// `spawn_record` (together with the `terminal` flag); the id-prefixed frame
/// is passed to `spawn_publish` on `publish_subject` and is also the return
/// value.
#[allow(clippy::too_many_arguments)]
fn emit_progress_frame(
    buffer: &Arc<dyn klieo_core::resume::ResumeBuffer>,
    pubsub: &Arc<dyn klieo_core::Pubsub>,
    permits: &Arc<tokio::sync::Semaphore>,
    publish_subject: &str,
    stream_id: &str,
    next_id: &AtomicU64,
    frame: Bytes,
    terminal: bool,
) -> Bytes {
    let previous = next_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
    let event_id = previous + 1;

    let buffered_copy = frame.clone();
    spawn_record(
        Arc::clone(buffer),
        stream_id.to_owned(),
        event_id,
        buffered_copy,
        terminal,
    );

    let stamped = id_prefixed_frame(event_id, frame);
    let published_copy = stamped.clone();
    spawn_publish(
        Arc::clone(pubsub),
        publish_subject.to_owned(),
        published_copy,
        Arc::clone(permits),
    );
    stamped
}
/// Shortest `_meta.parentAnchor` (in bytes) accepted by `parse_parent_anchor`.
const MIN_PARENT_ANCHOR_BYTES: usize = 8;
/// Longest `_meta.parentAnchor` (in bytes) accepted by `parse_parent_anchor`.
const MAX_PARENT_ANCHOR_BYTES: usize = 256;
/// Reports whether `byte` may appear in a parent anchor: ASCII letters and
/// digits plus `_`, `=`, `:` and `-`.
fn is_parent_anchor_byte(byte: u8) -> bool {
    matches!(
        byte,
        b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'_' | b'=' | b':' | b'-'
    )
}
fn parse_parent_anchor(raw: &serde_json::Value) -> Result<Option<String>, &'static str> {
let Some(value) = raw.pointer("/params/_meta/parentAnchor") else {
return Ok(None);
};
let anchor = value
.as_str()
.ok_or("_meta.parentAnchor must be a string")?;
if anchor.len() < MIN_PARENT_ANCHOR_BYTES || anchor.len() > MAX_PARENT_ANCHOR_BYTES {
return Err("_meta.parentAnchor length out of bounds");
}
if !anchor.bytes().all(is_parent_anchor_byte) {
return Err("_meta.parentAnchor must be a hash token");
}
Ok(Some(anchor.to_string()))
}
/// Unit tests for `parse_parent_anchor`.
#[cfg(test)]
mod parent_anchor_tests {
    use super::parse_parent_anchor;
    use serde_json::json;

    /// Builds a minimal request whose `_meta.parentAnchor` is `anchor`.
    fn req_with_anchor(anchor: serde_json::Value) -> serde_json::Value {
        json!({ "params": { "name": "t", "_meta": { "parentAnchor": anchor } } })
    }

    /// Asserts that every candidate anchor is refused.
    fn assert_all_rejected(candidates: Vec<serde_json::Value>) {
        for candidate in candidates {
            let request = req_with_anchor(candidate);
            assert!(parse_parent_anchor(&request).is_err());
        }
    }

    #[test]
    fn absent_anchor_is_ok_none() {
        let without_meta = json!({ "params": { "name": "t" } });
        let with_empty_meta = json!({ "params": { "name": "t", "_meta": {} } });
        for request in [without_meta, with_empty_meta] {
            assert_eq!(parse_parent_anchor(&request), Ok(None));
        }
    }

    #[test]
    fn valid_hash_token_is_accepted_verbatim() {
        let token = "sha256:0123abcd_ef-ABCD=";
        let request = req_with_anchor(json!(token));
        assert_eq!(parse_parent_anchor(&request), Ok(Some(token.to_string())));
    }

    #[test]
    fn non_string_anchor_is_rejected() {
        assert_all_rejected(vec![json!(42), json!(["a"])]);
    }

    #[test]
    fn empty_or_short_anchor_is_rejected() {
        assert_all_rejected(vec![json!(""), json!("abc")]);
    }

    #[test]
    fn oversize_anchor_is_rejected() {
        let huge = "a".repeat(257);
        assert_all_rejected(vec![json!(huge)]);
    }

    #[test]
    fn freeform_or_pii_shaped_anchor_is_rejected() {
        assert_all_rejected(vec![json!("alice@example.com"), json!("hello world")]);
    }
}
/// Spawns the tool invocation for a streaming `tools/call` and returns the
/// pieces the response stream needs.
///
/// The tool name and arguments are read from `raw.params` (missing values
/// become `""` and JSON `null`). A broadcast channel of capacity
/// `PROGRESS_CHANNEL_CAP` is created for progress events; its sender goes into
/// the tool context and its receiver is returned. Only a non-anonymous
/// `identity` yields a principal, and `parent_anchor` is discarded unless a
/// principal is present.
///
/// Returns `(request id, progress receiver, join handle of the invocation)`;
/// the request id is JSON `null` when `raw` has no `id`.
fn spawn_progress_stream_task(
    server: &Arc<McpServer>,
    raw: &serde_json::Value,
    cancel: tokio_util::sync::CancellationToken,
    identity: Option<&Identity>,
    parent_anchor: Option<String>,
) -> (
    serde_json::Value,
    tokio::sync::broadcast::Receiver<klieo_core::AgentEvent>,
    tokio::task::JoinHandle<Result<serde_json::Value, klieo_core::error::ToolError>>,
) {
    let params = raw.get("params");
    let req_id = raw.get("id").cloned().unwrap_or(serde_json::Value::Null);
    let tool_name = params
        .and_then(|p| p.get("name"))
        .and_then(|n| n.as_str())
        .unwrap_or("")
        .to_string();
    let tool_args = params
        .and_then(|p| p.get("arguments"))
        .cloned()
        .unwrap_or(serde_json::Value::Null);

    let (progress_tx, progress_rx) =
        tokio::sync::broadcast::channel::<klieo_core::AgentEvent>(PROGRESS_CHANNEL_CAP);

    let principal = match identity {
        Some(id) if !id.is_anonymous() => Some(id.as_str().to_string()),
        _ => None,
    };
    // The anchor is only honoured for requests that carry a principal.
    let parent_anchor = principal.as_ref().and(parent_anchor);

    let tool_ctx = server.tool_ctx_with_progress(progress_tx, cancel, principal, parent_anchor);
    let invoker = server.invoker.clone();
    let invocation =
        tokio::spawn(async move { invoker.invoke(&tool_name, tool_args, tool_ctx).await });
    (req_id, progress_rx, invocation)
}
/// Handles `klieo/tools/resume`: replays buffered frames after
/// `params.last_event_id` and, unless the stream is already terminal, follows
/// with the live tail.
///
/// Steps, in order:
/// 1. Derive the stream id from the progress token and reject tokens that are
///    not valid bus-subject segments (`JSONRPC_INVALID_PARAMS`).
/// 2. If the leader registry reports the leader dead, run the dead-leader
///    handler; a re-invoked or terminated outcome is returned directly.
/// 3. Enforce stream ownership for `identity`.
/// 4. Open the replay; expired / not-found / backend failures are answered
///    with a JSON-RPC error on HTTP `200 OK`.
/// 5. Subscribe to the live progress subject unless the buffer is terminal. A
///    failed terminal check is treated as "live"; a failed subscribe degrades
///    to a replay-only response.
/// 6. Register the request's cancel token and return the combined stream as
///    `text/event-stream`.
///
/// NOTE(review): step 2 runs before the ownership check in step 3, so the
/// failover/terminate path is reachable by a request that has not yet passed
/// `enforce_owner` — confirm this ordering is intended.
#[instrument(
    skip_all,
    fields(
        rpc.system = "klieo-mcp",
        rpc.method = "klieo/tools/resume",
        klieo.stream_id = tracing::field::Empty,
    ),
)]
async fn stream_resume(
    server: Arc<McpServer>,
    raw: serde_json::Value,
    params: ResumeParams,
    identity: Option<Identity>,
) -> Response {
    let req_id = raw.get("id").cloned().unwrap_or(serde_json::Value::Null);
    let stream_id = progress_token_to_string(&params.progress_token);
    if let Err(e) = klieo_core::validate_subject_token(&stream_id) {
        tracing::warn!(
            target: "mcp.resume",
            error = %e,
            "rejected progressToken: invalid subject segment",
        );
        return (
            StatusCode::OK,
            Json(rpc_error(
                Some(req_id),
                JSONRPC_INVALID_PARAMS,
                "progressToken contains reserved bus-subject metacharacters",
            )),
        )
            .into_response();
    }
    tracing::Span::current().record("klieo.stream_id", stream_id.as_str());
    let request_cancel = server.parent_cancel.child_token();
    let buffer = server.resume_buffer.clone();
    let pubsub = server.pubsub.clone();
    let since = params.last_event_id;
    if let LeaderProbe::Dead = probe_leader(&server, &stream_id).await {
        match handle_dead_leader_orphan_mcp(&server, &buffer, &req_id, &stream_id).await {
            OrphanOutcome::Reinvoked(resp) | OrphanOutcome::Terminated(resp) => return resp,
            OrphanOutcome::Passthrough => {}
        }
    }
    if let Err(resp) = enforce_owner(&server, &stream_id, &identity, &req_id).await {
        return resp;
    }
    let replay_stream = match buffer.replay(&stream_id, since).await {
        Ok(s) => s,
        Err(klieo_core::resume::ResumeError::Expired { since_id }) => {
            return (
                StatusCode::OK,
                Json(rpc_error(
                    Some(req_id),
                    crate::JSONRPC_RESUME_BUFFER_EXPIRED,
                    &format!("resume window expired (since_id={since_id})"),
                )),
            )
                .into_response();
        }
        Err(klieo_core::resume::ResumeError::NotFound(_)) => {
            // Same reply as an ownership mismatch, built by the shared helper.
            return stream_not_found_response(req_id);
        }
        Err(klieo_core::resume::ResumeError::Backend(e)) => {
            tracing::warn!(
                target: "mcp.resume",
                stream_id = %stream_id,
                error = %e,
                "resume backend error"
            );
            return (
                StatusCode::OK,
                Json(rpc_error(
                    Some(req_id),
                    JSONRPC_SERVER_ERROR,
                    "resume backend unavailable",
                )),
            )
                .into_response();
        }
        // Catch-all for error variants not matched above.
        Err(_) => {
            return (
                StatusCode::OK,
                Json(rpc_error(
                    Some(req_id),
                    JSONRPC_SERVER_ERROR,
                    "resume backend error",
                )),
            )
                .into_response();
        }
    };
    let already_closed = match buffer.is_terminal(&stream_id).await {
        Ok(v) => v,
        Err(e) => {
            tracing::warn!(
                target: "mcp.resume",
                stream_id = %stream_id,
                error = %e,
                "is_terminal check failed; assuming live"
            );
            false
        }
    };
    let subject = format!("klieo.mcp.progress.{}", stream_id);
    let live_tail = if already_closed {
        None
    } else {
        // A fresh durable name per request, so each resume gets its own consumer.
        let durable = klieo_core::DurableName::new(format!("klieo-eph-{}", uuid::Uuid::new_v4()));
        match pubsub.subscribe(&subject, durable).await {
            Ok(s) => Some(s),
            Err(e) => {
                tracing::warn!(
                    target: "mcp.resume",
                    subject = %subject,
                    error = %e,
                    "live tail subscribe failed; returning replay-only response"
                );
                None
            }
        }
    };
    // Seeded with `since` so tail frames at or below the resume point are dropped
    // even when the replay yields nothing.
    let max_replayed = Arc::new(AtomicU64::new(since));
    let combined =
        combine_replay_and_tail(replay_stream, live_tail, max_replayed, stream_id.clone());
    server
        .cancel_registry()
        .register(stream_id.clone(), request_cancel.clone());
    let guarded = wrap_with_cancel_fanout(&server, combined, stream_id, request_cancel, None, None);
    axum::response::Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, "text/event-stream")
        .header("Cache-Control", "no-cache")
        .body(axum::body::Body::from_stream(guarded))
        .unwrap()
}
/// Instrumented entry point for a streaming `tools/call`.
///
/// Opens the tracing span (with an initially empty `klieo.stream_id` field)
/// and delegates to `run_tools_call` without an existing leader handle, so
/// leadership is claimed as part of the call.
#[instrument(
    skip_all,
    fields(
        rpc.system = "klieo-mcp",
        rpc.method = "tools/call",
        klieo.stream_id = tracing::field::Empty,
    ),
)]
async fn stream_tools_call(
    server: Arc<McpServer>,
    raw: serde_json::Value,
    progress_token: serde_json::Value,
    identity: Option<Identity>,
) -> Response {
    run_tools_call(server, raw, progress_token, identity, None).await
}
/// Runs a `tools/call` as a server-sent-event stream.
///
/// `existing_leader` is `Some` when the caller has already claimed leadership
/// for this stream (the failover re-invocation path); otherwise leadership is
/// claimed here, caching the serialised request body and the non-anonymous
/// principal alongside the claim.
///
/// The response body yields one frame per progress event, a "lagged" frame
/// whenever the broadcast receiver reports skipped events, and finally a
/// terminal `result` frame. Every frame goes through `emit_progress_frame`.
///
/// Invalid progress tokens and malformed `_meta.parentAnchor` values are
/// answered with `JSONRPC_INVALID_PARAMS` on HTTP `200 OK` before any work is
/// started.
///
/// NOTE(review): the cancel token is registered before `try_claim_ownership`;
/// when that claim fails the function returns without reaching
/// `wrap_with_cancel_fanout` — confirm the registry entry is cleaned up (or
/// cannot shadow another caller's token) on that path.
///
/// NOTE(review): the event-id counter always starts at 0, including on
/// re-invocation over a buffer that already holds frames — confirm id reuse is
/// handled downstream.
#[allow(clippy::too_many_lines)]
async fn run_tools_call(
server: Arc<McpServer>,
raw: serde_json::Value,
progress_token: serde_json::Value,
identity: Option<Identity>,
existing_leader: Option<klieo_core::LeaderHandle>,
) -> Response {
let stream_id = progress_token_to_string(&progress_token);
if let Err(e) = klieo_core::validate_subject_token(&stream_id) {
tracing::warn!(
target: "mcp.tools_call",
error = %e,
"rejected progressToken: invalid subject segment",
);
let req_id = raw.get("id").cloned().unwrap_or(serde_json::Value::Null);
return (
StatusCode::OK,
Json(rpc_error(
Some(req_id),
JSONRPC_INVALID_PARAMS,
"progressToken contains reserved bus-subject metacharacters",
)),
)
.into_response();
}
let parent_anchor = match parse_parent_anchor(&raw) {
Ok(anchor) => anchor,
Err(message) => {
tracing::warn!(
target: "mcp.tools_call",
reason = message,
"rejected tools/call: malformed _meta.parentAnchor",
);
let req_id = raw.get("id").cloned().unwrap_or(serde_json::Value::Null);
return (
StatusCode::OK,
Json(rpc_error(Some(req_id), JSONRPC_INVALID_PARAMS, message)),
)
.into_response();
}
};
tracing::Span::current().record("klieo.stream_id", stream_id.as_str());
// Child of the server-wide token: cancelling the parent cancels this request.
let request_cancel = server.parent_cancel.child_token();
server
.cancel_registry()
.register(stream_id.clone(), request_cancel.clone());
let leader_handle = match existing_leader {
Some(handle) => Some(handle),
None => {
// Cache the request body with the leader claim so the call can be
// re-run from it; a serialisation failure only disables that cache.
let payload_bytes_for_failover = match serde_json::to_vec(&raw) {
Ok(v) => Some(Bytes::from(v)),
Err(e) => {
tracing::warn!(
target: "mcp.failover",
stream_id = %stream_id,
error = %e,
"tools/call body serialise for failover failed; \
proceeding without cached payload",
);
None
}
};
let principal_for_failover = identity
.as_ref()
.filter(|id| !id.is_anonymous())
.map(|id| id.as_str().to_string());
try_claim_leader(
&server,
&stream_id,
payload_bytes_for_failover,
principal_for_failover,
)
.await
}
};
let ownership_handle = match try_claim_ownership(&server, &stream_id, &identity).await {
Ok(handle) => handle,
Err(resp) => return resp,
};
let (req_id, rx, invoke_handle) = spawn_progress_stream_task(
&server,
&raw,
request_cancel.clone(),
identity.as_ref(),
parent_anchor,
);
let next_id = Arc::new(AtomicU64::new(0));
let buffer = server.resume_buffer.clone();
let pubsub = server.pubsub.clone();
let publish_subject = format!("klieo.mcp.progress.{stream_id}");
let publish_permits = server.publish_permits.clone();
let request_cancel_for_stream = request_cancel.clone();
let stream_id_for_stream = stream_id.clone();
let body_stream = async_stream::stream! {
let mut rx = rx;
let invoke_handle = invoke_handle;
tokio::pin!(invoke_handle);
tokio::task::yield_now().await;
loop {
tokio::select! {
// `biased` polls the branches in written order: cancellation first,
// then progress events, then tool completion.
biased;
_ = request_cancel_for_stream.cancelled() => {
break;
}
recv_result = rx.recv() => {
match recv_result {
Ok(event) => {
let frame = match progress_event(&progress_token, &event) {
Ok(f) => f,
// The error type is uninhabited, so this arm is unreachable.
Err(e) => match e {},
};
yield Ok::<_, std::convert::Infallible>(emit_progress_frame(
&buffer, &pubsub, &publish_permits,
&publish_subject, &stream_id_for_stream,
&next_id, frame, false,
));
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
let frame = match lagged_event(&progress_token, n) {
Ok(f) => f,
Err(e) => match e {},
};
yield Ok::<_, std::convert::Infallible>(emit_progress_frame(
&buffer, &pubsub, &publish_permits,
&publish_subject, &stream_id_for_stream,
&next_id, frame, false,
));
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
break;
}
}
}
result = &mut invoke_handle => {
// The tool finished: drain whatever progress is still queued before
// emitting the terminal result frame.
loop {
match rx.try_recv() {
Ok(event) => {
let frame = match progress_event(&progress_token, &event) {
Ok(f) => f,
Err(e) => match e {},
};
yield Ok::<_, std::convert::Infallible>(emit_progress_frame(
&buffer, &pubsub, &publish_permits,
&publish_subject, &stream_id_for_stream,
&next_id, frame, false,
));
}
Err(tokio::sync::broadcast::error::TryRecvError::Lagged(n)) => {
let frame = match lagged_event(&progress_token, n) {
Ok(f) => f,
Err(e) => match e {},
};
yield Ok::<_, std::convert::Infallible>(emit_progress_frame(
&buffer, &pubsub, &publish_permits,
&publish_subject, &stream_id_for_stream,
&next_id, frame, false,
));
}
Err(_) => break,
}
}
// A join error is reported to the client as a permanent tool error.
let outcome = result
.unwrap_or_else(|_| Err(klieo_core::error::ToolError::Permanent(
"invoke task panicked".into()
)));
let req_id_opt = if req_id == serde_json::Value::Null {
None
} else {
Some(req_id.clone())
};
let frame = match result_event(req_id_opt, outcome) {
Ok(f) => f,
Err(e) => match e {},
};
yield Ok::<_, std::convert::Infallible>(emit_progress_frame(
&buffer, &pubsub, &publish_permits,
&publish_subject, &stream_id_for_stream,
&next_id, frame, true,
));
return;
}
}
}
// Reached only via `break` (cancellation or a closed progress channel):
// wait for the invocation to end, drain remaining progress, then emit the
// terminal result frame.
let result = invoke_handle.await
.unwrap_or_else(|_| Err(klieo_core::error::ToolError::Permanent(
"invoke task panicked".into()
)));
loop {
match rx.try_recv() {
Ok(event) => {
let frame = match progress_event(&progress_token, &event) {
Ok(f) => f,
Err(e) => match e {},
};
yield Ok::<_, std::convert::Infallible>(emit_progress_frame(
&buffer, &pubsub, &publish_permits,
&publish_subject, &stream_id_for_stream,
&next_id, frame, false,
));
}
Err(tokio::sync::broadcast::error::TryRecvError::Lagged(n)) => {
let frame = match lagged_event(&progress_token, n) {
Ok(f) => f,
Err(e) => match e {},
};
yield Ok::<_, std::convert::Infallible>(emit_progress_frame(
&buffer, &pubsub, &publish_permits,
&publish_subject, &stream_id_for_stream,
&next_id, frame, false,
));
}
Err(_) => break,
}
}
let req_id_opt = if req_id == serde_json::Value::Null {
None
} else {
Some(req_id)
};
let frame = match result_event(req_id_opt, result) {
Ok(f) => f,
Err(e) => match e {},
};
yield Ok::<_, std::convert::Infallible>(emit_progress_frame(
&buffer, &pubsub, &publish_permits,
&publish_subject, &stream_id_for_stream,
&next_id, frame, true,
));
};
// The leader and ownership handles are handed to the wrapper together with
// the body stream.
let guarded = wrap_with_cancel_fanout(
&server,
body_stream.boxed(),
stream_id,
request_cancel,
leader_handle,
ownership_handle,
);
axum::response::Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "text/event-stream")
.header("Cache-Control", "no-cache")
.body(axum::body::Body::from_stream(guarded))
.unwrap()
}
/// Renders one agent event as an SSE `progress` frame.
///
/// The data line is a JSON-RPC `notifications/progress` notification carrying
/// `token` as `progressToken` and the serialised `event` as `data`. The
/// function never fails; the `Result` exists only for its signature.
fn progress_event(
    token: &serde_json::Value,
    event: &klieo_core::AgentEvent,
) -> Result<axum::body::Bytes, std::convert::Infallible> {
    let params = serde_json::json!({
        "progressToken": token,
        "data": event,
    });
    let notification = serde_json::json!({
        "jsonrpc": "2.0",
        "method": "notifications/progress",
        "params": params,
    });
    let frame = format!("event: progress\ndata: {notification}\n\n");
    Ok(axum::body::Bytes::from(frame))
}
/// Renders an SSE `progress` frame telling the client that `skipped` events
/// were dropped.
///
/// The notification has the same shape as a regular progress frame, with
/// `data` set to `{ "kind": "lagged", "message": ... }`. The function never
/// fails; the `Result` exists only for its signature.
fn lagged_event(
    token: &serde_json::Value,
    skipped: u64,
) -> Result<axum::body::Bytes, std::convert::Infallible> {
    let message = format!("lagged: skipped {} events", skipped);
    let params = serde_json::json!({
        "progressToken": token,
        "data": { "kind": "lagged", "message": message }
    });
    let notification = serde_json::json!({
        "jsonrpc": "2.0",
        "method": "notifications/progress",
        "params": params,
    });
    let frame = format!("event: progress\ndata: {notification}\n\n");
    Ok(axum::body::Bytes::from(frame))
}
/// Renders the final outcome of a tool call as an SSE `result` frame.
///
/// A successful value is stringified and wrapped as a single text content
/// item inside a JSON-RPC result; a tool error is converted with
/// `tool_error_to_envelope`. `id` is echoed in either envelope. The function
/// never fails; the `Result` exists only for its signature.
fn result_event(
    id: Option<serde_json::Value>,
    outcome: Result<serde_json::Value, klieo_core::error::ToolError>,
) -> Result<axum::body::Bytes, std::convert::Infallible> {
    let envelope = match outcome {
        Err(tool_error) => tool_error_to_envelope(id, tool_error),
        Ok(value) => {
            let text = value.to_string();
            let result = serde_json::json!({
                "content": [{ "type": "text", "text": text }]
            });
            serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": result })
        }
    };
    let frame = format!("event: result\ndata: {envelope}\n\n");
    Ok(axum::body::Bytes::from(frame))
}
/// Stream wrapper that forwards items from `inner` and, when dropped,
/// publishes a cancel notification on `cancel_subject`.
struct CancelOnDrop<S> {
// The wrapped stream; polled directly by the `Stream` impl.
inner: S,
// Held only for its drop behaviour. Presumably cancels the request's token
// when this wrapper is dropped (tokio_util `DropGuard`) — verify at the
// construction site.
_guard: tokio_util::sync::DropGuard,
// Bus handle used by `Drop` to publish the cancel notification.
pubsub: Arc<dyn klieo_core::Pubsub>,
// Subject the cancel notification is published on; taken (left empty) in `Drop`.
cancel_subject: String,
// Semaphore handed to the drop-time publish helper.
permits: Arc<tokio::sync::Semaphore>,
// Leader claim kept for as long as the stream exists; never read.
_leader: Option<klieo_core::LeaderHandle>,
// Ownership claim kept for as long as the stream exists; never read.
_ownership: Option<klieo_core::OwnershipHandle>,
}
/// `CancelOnDrop` is a transparent stream: every poll is forwarded to the
/// wrapped stream and its items are passed through unchanged.
impl<S: futures::Stream + Unpin> futures::Stream for CancelOnDrop<S> {
    type Item = S::Item;

    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<S::Item>> {
        let this = self.get_mut();
        let inner = std::pin::Pin::new(&mut this.inner);
        inner.poll_next(cx)
    }
}
/// Dropping the wrapper publishes a cancel notification for the stream.
impl<S> Drop for CancelOnDrop<S> {
    fn drop(&mut self) {
        // The subject is moved out (leaving an empty string) instead of cloned;
        // the wrapper is being destroyed anyway.
        let subject = std::mem::take(&mut self.cancel_subject);

        // Carry the current OpenTelemetry context along with the cancel message.
        let mut headers = klieo_core::Headers::default();
        let current_cx = opentelemetry::Context::current();
        klieo_core::inject_traceparent(&mut headers, &current_cx);

        klieo_core::cancel::spawn_drop_publish(
            self.pubsub.clone(),
            subject,
            "mcp.cancel",
            Some(self.permits.clone()),
            headers,
        );
    }
}
/// Dispatches a parsed JSON-RPC body: a single request goes straight to
/// `handle_jsonrpc`, a batch (JSON array) is processed item by item.
///
/// Batches larger than `MAX_BATCH_ITEMS` are rejected as a whole with a single
/// error envelope. Batch items are handled sequentially; once the server's
/// parent cancel token is cancelled, each remaining item is answered with a
/// "server shutting down" error instead of being executed.
///
/// The span's `rpc.method` field is recorded only when the top-level value has
/// a string `method`.
#[instrument(
    skip_all,
    fields(
        rpc.system = "klieo-mcp",
        rpc.method = tracing::field::Empty,
    ),
)]
async fn dispatch(
    server: &McpServer,
    raw: serde_json::Value,
    session: Option<&std::sync::Arc<crate::session::Session>>,
) -> serde_json::Value {
    if let Some(method) = raw.get("method").and_then(|m| m.as_str()) {
        tracing::Span::current().record("rpc.method", method);
    }

    let items = match raw {
        serde_json::Value::Array(items) => items,
        single => return server.handle_jsonrpc(single, session).await,
    };

    let batch_len = items.len();
    if batch_len > MAX_BATCH_ITEMS {
        warn!(
            items = batch_len,
            max = MAX_BATCH_ITEMS,
            "rejected oversized JSON-RPC batch"
        );
        return rpc_error(None, JSONRPC_SERVER_ERROR, "batch size exceeds limit");
    }

    let mut responses = Vec::with_capacity(batch_len);
    for item in items {
        let response = if server.parent_cancel.is_cancelled() {
            rpc_error(
                item.get("id").cloned(),
                JSONRPC_SERVER_ERROR,
                "server shutting down",
            )
        } else {
            server.handle_jsonrpc(item, session).await
        };
        responses.push(response);
    }
    serde_json::Value::Array(responses)
}
// Router-level tests for how POST /mcp treats bodies that are JSON-RPC
// responses (client replies to server-initiated requests) and for session
// activity tracking.
#[cfg(test)]
mod post_body_classification_tests {
use super::*;
use crate::outbound::OutboundRequests;
use crate::{OutboundFrameSink, OutboundSinkError};
use async_trait::async_trait;
use axum::body::{to_bytes, Body};
use axum::http::{header, Method, Request, StatusCode};
use klieo_core::error::ToolError;
use klieo_core::llm::ToolDef;
use klieo_core::tool::{ToolCtx, ToolInvoker};
use serde_json::{json, Value};
use tokio::sync::Mutex as AsyncMutex;
use tower::ServiceExt;
// Tool invoker with an empty catalogue; every invocation is an unknown tool.
struct NoopInvoker;
#[async_trait]
impl ToolInvoker for NoopInvoker {
fn catalogue(&self) -> Vec<ToolDef> {
Vec::new()
}
async fn invoke(
&self,
name: &str,
_args: Value,
_ctx: ToolCtx,
) -> Result<Value, ToolError> {
Err(ToolError::UnknownTool(name.into()))
}
}
// Builds a server with the no-op invoker and otherwise default settings.
fn server() -> Arc<McpServer> {
McpServer::builder()
.add_tools(Arc::new(NoopInvoker))
.build_arc()
.unwrap()
}
// Outbound sink that stores every frame it is asked to send.
struct LocalCapturingSink {
frames: AsyncMutex<Vec<Value>>,
}
#[async_trait]
impl OutboundFrameSink for LocalCapturingSink {
async fn send_frame(&self, frame: std::sync::Arc<Value>) -> Result<(), OutboundSinkError> {
self.frames.lock().await.push((*frame).clone());
Ok(())
}
}
// Sends `initialize` and returns the session id from the response header.
async fn post_init(server: &Arc<McpServer>) -> String {
let body = json!({
"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {}
});
let req = Request::builder()
.method(Method::POST)
.uri("/mcp")
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(serde_json::to_vec(&body).unwrap()))
.unwrap();
let resp = server.router().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
resp.headers()
.get(MCP_SESSION_ID_HEADER)
.expect("initialize must echo Mcp-Session-Id")
.to_str()
.unwrap()
.to_string()
}
// Builds a POST whose body is a JSON-RPC *response* (has `result`, no
// `method`) for the given session.
fn outbound_response_request(session_id: &str, id: i64) -> Request<Body> {
let body = json!({"jsonrpc": "2.0", "id": id, "result": {"x": 1}});
Request::builder()
.method(Method::POST)
.uri("/mcp")
.header(header::CONTENT_TYPE, "application/json")
.header(MCP_SESSION_ID_HEADER, session_id)
.body(Body::from(serde_json::to_vec(&body).unwrap()))
.unwrap()
}
// A response body posted to a session without an outbound primitive must be
// refused with 400 and a parse-error envelope.
#[tokio::test]
async fn post_outbound_response_with_no_outbound_wired_returns_400() {
let server = server();
let session_id = post_init(&server).await;
let session_uuid = uuid::Uuid::parse_str(&session_id).expect("session id is a uuid");
{
let sessions = server.sessions.read().await;
let session = sessions
.get(&session_uuid)
.expect("post_init inserts the session into the registry");
assert!(
session.outbound.get().is_none(),
"outbound must be unset on a plain HTTP server"
);
}
let resp = server
.router()
.oneshot(outbound_response_request(&session_id, 42))
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let bytes = to_bytes(resp.into_body(), 1 << 16).await.unwrap();
let envelope: Value = serde_json::from_slice(&bytes).unwrap();
assert_eq!(envelope["error"]["code"], JSONRPC_PARSE_ERROR);
}
// With an outbound primitive wired, a posted response must be accepted (202)
// and resolve the pending outbound request with its `result`.
#[tokio::test]
async fn post_outbound_response_routes_when_outbound_wired() {
let server = server();
let session_id = post_init(&server).await;
let sink: Arc<dyn OutboundFrameSink> = Arc::new(LocalCapturingSink {
frames: AsyncMutex::new(Vec::new()),
});
let outbound = Arc::new(OutboundRequests::new(sink));
let session_uuid = uuid::Uuid::parse_str(&session_id).expect("session id is a uuid");
let session = {
let sessions = server.sessions.read().await;
sessions
.get(&session_uuid)
.expect("post_init inserts the session into the registry")
.clone()
};
if session.outbound.set(outbound.clone()).is_err() {
panic!("session.outbound OnceCell must be empty for this test");
}
let call_handle = {
let outbound = outbound.clone();
tokio::spawn(async move {
use klieo_core::ServerOutbound;
outbound
.outbound_request(
"custom/method",
Value::Null,
std::time::Duration::from_secs(2),
)
.await
})
};
// Give the spawned request time to be issued before the reply is posted.
// The reply below uses id 1, presumably the first id the outbound
// primitive allocates — verify against `OutboundRequests`.
tokio::time::sleep(std::time::Duration::from_millis(30)).await;
let resp = server
.router()
.oneshot(outbound_response_request(&session_id, 1))
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::ACCEPTED);
let result = tokio::time::timeout(std::time::Duration::from_secs(2), call_handle)
.await
.expect("outbound_request did not resolve")
.expect("task panicked")
.expect("outbound_request returned error");
assert_eq!(result["x"], 1);
}
// A successful POST on an existing session must advance its
// `last_activity_millis` timestamp.
#[tokio::test]
async fn post_updates_last_activity() {
let server = server();
let session_id = post_init(&server).await;
let session_uuid = uuid::Uuid::parse_str(&session_id).expect("session id is a uuid");
let session = {
let sessions = server.sessions.read().await;
sessions
.get(&session_uuid)
.expect("post_init inserts the session into the registry")
.clone()
};
let before = session
.last_activity_millis
.load(std::sync::atomic::Ordering::Relaxed);
// Sleep so the second timestamp is strictly later at millisecond resolution.
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
let body = json!({"jsonrpc": "2.0", "id": 2, "method": "tools/list"});
let req = Request::builder()
.method(Method::POST)
.uri("/mcp")
.header(header::CONTENT_TYPE, "application/json")
.header(MCP_SESSION_ID_HEADER, &session_id)
.body(Body::from(serde_json::to_vec(&body).unwrap()))
.unwrap();
let resp = server.router().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let after = session
.last_activity_millis
.load(std::sync::atomic::Ordering::Relaxed);
assert!(
after > before,
"session last_activity must advance on successful POST"
);
}
}
#[cfg(test)]
mod get_outbound_wiring_tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{header, Method, Request, StatusCode};
    use klieo_core::error::ToolError;
    use klieo_core::llm::ToolDef;
    use klieo_core::tool::{ToolCtx, ToolInvoker};
    use serde_json::{json, Value};
    use tower::ServiceExt;

    /// Tool invoker with an empty catalogue that rejects every call as unknown.
    struct NoopInvoker;

    #[async_trait::async_trait]
    impl ToolInvoker for NoopInvoker {
        fn catalogue(&self) -> Vec<ToolDef> {
            vec![]
        }

        async fn invoke(
            &self,
            name: &str,
            _args: Value,
            _ctx: ToolCtx,
        ) -> Result<Value, ToolError> {
            let unknown = ToolError::UnknownTool(name.into());
            Err(unknown)
        }
    }

    /// Builds a server over [`NoopInvoker`] with `with_client_sampling` applied.
    fn server_with_sampling() -> Arc<McpServer> {
        let builder = McpServer::builder()
            .add_tools(Arc::new(NoopInvoker))
            .with_client_sampling();
        builder.build_arc().unwrap()
    }

    /// Builds a server over [`NoopInvoker`] without `with_client_sampling`.
    fn server_without_sampling() -> Arc<McpServer> {
        let builder = McpServer::builder().add_tools(Arc::new(NoopInvoker));
        builder.build_arc().unwrap()
    }

    /// POSTs an `initialize` request and returns the session id echoed in the
    /// `Mcp-Session-Id` response header.
    async fn post_init(server: &Arc<McpServer>) -> String {
        let payload = json!({
            "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {}
        });
        let request = Request::builder()
            .method(Method::POST)
            .uri("/mcp")
            .header(header::CONTENT_TYPE, "application/json")
            .body(Body::from(serde_json::to_vec(&payload).unwrap()))
            .unwrap();
        let response = server.router().oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let echoed = response
            .headers()
            .get(MCP_SESSION_ID_HEADER)
            .expect("initialize echoes session id");
        echoed.to_str().unwrap().to_owned()
    }

    /// Issues `GET /mcp` carrying the given session id and returns only the
    /// response status.
    async fn get_mcp_with_session(server: &Arc<McpServer>, session_id: &str) -> StatusCode {
        let request = Request::builder()
            .method(Method::GET)
            .uri("/mcp")
            .header(MCP_SESSION_ID_HEADER, session_id)
            .body(Body::empty())
            .unwrap();
        server.router().oneshot(request).await.unwrap().status()
    }

    /// `outbound` is unset right after `initialize` and set once a GET succeeds.
    #[tokio::test]
    async fn http_get_wires_outbound_primitive() {
        let server = server_with_sampling();
        let session_id = post_init(&server).await;
        let registry_key = uuid::Uuid::parse_str(&session_id).expect("session id is a uuid");
        // The read guard is a temporary of this statement and is dropped at its end.
        let session = server
            .sessions
            .read()
            .await
            .get(&registry_key)
            .expect("post_init inserts the session into the registry")
            .clone();

        assert!(
            session.outbound.get().is_none(),
            "outbound must be unset before GET"
        );
        let get_status = get_mcp_with_session(&server, &session_id).await;
        assert_eq!(get_status, StatusCode::OK);
        assert!(
            session.outbound.get().is_some(),
            "GET must populate the outbound primitive"
        );
    }

    /// On a server built with client sampling, a GET populates `roots_cache`.
    #[tokio::test]
    async fn http_get_wires_roots_cache_when_sampling_declared() {
        let server = server_with_sampling();
        let session_id = post_init(&server).await;
        let registry_key = uuid::Uuid::parse_str(&session_id).expect("session id is a uuid");
        let session = server
            .sessions
            .read()
            .await
            .get(&registry_key)
            .expect("post_init inserts the session into the registry")
            .clone();

        assert!(session.roots_cache.get().is_none());
        let get_status = get_mcp_with_session(&server, &session_id).await;
        assert_eq!(get_status, StatusCode::OK);
        assert!(
            session.roots_cache.get().is_some(),
            "with_client_sampling + GET must populate roots_cache"
        );
    }

    /// On a server built without client sampling, a GET still sets `outbound`
    /// but leaves `roots_cache` empty.
    #[tokio::test]
    async fn http_get_skips_roots_cache_when_sampling_absent() {
        let server = server_without_sampling();
        let session_id = post_init(&server).await;
        let registry_key = uuid::Uuid::parse_str(&session_id).expect("session id is a uuid");
        let session = server
            .sessions
            .read()
            .await
            .get(&registry_key)
            .expect("post_init inserts the session into the registry")
            .clone();

        let get_status = get_mcp_with_session(&server, &session_id).await;
        assert_eq!(get_status, StatusCode::OK);
        assert!(
            session.outbound.get().is_some(),
            "outbound is wired regardless of sampling"
        );
        assert!(
            session.roots_cache.get().is_none(),
            "roots_cache is gated on declare_sampling"
        );
    }

    /// Dropping the GET response while an outbound request is pending must
    /// resolve that request with `ServerOutboundError::TransportClosed`.
    #[tokio::test]
    async fn disconnect_drains_pending_outbound() {
        use klieo_core::{ServerOutbound, ServerOutboundError};
        use std::time::Duration;

        let server = server_with_sampling();
        let session_id = post_init(&server).await;

        // Built inline rather than through `get_mcp_with_session` because the
        // response itself has to stay alive until the explicit `drop` below.
        let get_request = Request::builder()
            .method(Method::GET)
            .uri("/mcp")
            .header(MCP_SESSION_ID_HEADER, &session_id)
            .body(Body::empty())
            .unwrap();
        let resp = server.router().oneshot(get_request).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let registry_key = uuid::Uuid::parse_str(&session_id).expect("session id is a uuid");
        let outbound = server
            .sessions
            .read()
            .await
            .get(&registry_key)
            .and_then(|session| session.outbound.get().cloned())
            .expect("GET must populate outbound primitive");

        let pending_outbound = outbound.clone();
        let call_handle = tokio::spawn(async move {
            pending_outbound
                .outbound_request("custom/method", Value::Null, Duration::from_secs(5))
                .await
        });

        // Give the spawned request a moment to start before disconnecting.
        tokio::time::sleep(Duration::from_millis(30)).await;
        drop(resp);

        let outcome = tokio::time::timeout(Duration::from_secs(2), call_handle)
            .await
            .expect("outbound_request did not resolve within 2s of disconnect")
            .expect("task panicked");
        assert!(
            matches!(outcome, Err(ServerOutboundError::TransportClosed)),
            "disconnect must surface as TransportClosed; got {outcome:?}"
        );
    }
}
#[cfg(test)]
mod idle_watchdog_tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{header, Method, Request, StatusCode};
    use klieo_core::error::ToolError;
    use klieo_core::llm::ToolDef;
    use klieo_core::tool::{ToolCtx, ToolInvoker};
    use serde_json::{json, Value};
    use std::time::{Duration, Instant};
    use tower::ServiceExt;

    /// Tool invoker with an empty catalogue that rejects every call as unknown.
    struct NoopInvoker;

    #[async_trait::async_trait]
    impl ToolInvoker for NoopInvoker {
        fn catalogue(&self) -> Vec<ToolDef> {
            vec![]
        }

        async fn invoke(
            &self,
            name: &str,
            _args: Value,
            _ctx: ToolCtx,
        ) -> Result<Value, ToolError> {
            let unknown = ToolError::UnknownTool(name.into());
            Err(unknown)
        }
    }

    /// POSTs an `initialize` request and returns the `Mcp-Session-Id` response
    /// header parsed as a UUID.
    async fn post_init(server: &Arc<McpServer>) -> uuid::Uuid {
        let payload = json!({
            "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {}
        });
        let request = Request::builder()
            .method(Method::POST)
            .uri("/mcp")
            .header(header::CONTENT_TYPE, "application/json")
            .body(Body::from(serde_json::to_vec(&payload).unwrap()))
            .unwrap();
        let response = server.router().oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let header_text = response
            .headers()
            .get(MCP_SESSION_ID_HEADER)
            .and_then(|value| value.to_str().ok())
            .expect("Mcp-Session-Id header on initialize");
        uuid::Uuid::parse_str(header_text).expect("Mcp-Session-Id parses as UUID")
    }

    /// Polls `is_session_closed_by_id` every 20 ms until it yields `None`,
    /// panicking if that has not happened once 2 seconds have elapsed.
    async fn wait_for_session_evicted(server: &Arc<McpServer>, id: uuid::Uuid) {
        let deadline = Instant::now() + Duration::from_secs(2);
        while server.is_session_closed_by_id(id).await.is_some() {
            assert!(
                Instant::now() < deadline,
                "session {id} never evicted by idle reaper"
            );
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
    }

    /// With a 150 ms idle timeout and a 50 ms reaper tick, an untouched session
    /// must stop being reported by `is_session_closed_by_id`.
    #[tokio::test]
    async fn idle_timeout_fires_after_inactivity() {
        let builder = McpServer::builder()
            .add_tools(Arc::new(NoopInvoker))
            .with_session_idle_timeout(Duration::from_millis(150))
            .with_idle_reaper_tick(Duration::from_millis(50));
        let server = builder.build_arc().unwrap();

        let id = post_init(&server).await;
        let state_after_init = server.is_session_closed_by_id(id).await;
        assert_eq!(state_after_init, Some(false));
        wait_for_session_evicted(&server, id).await;
    }

    /// With a `Duration::ZERO` idle timeout the session is still reported as
    /// open after 200 ms, i.e. several 50 ms reaper ticks.
    #[tokio::test]
    async fn zero_timeout_disables_watchdog() {
        let builder = McpServer::builder()
            .add_tools(Arc::new(NoopInvoker))
            .with_session_idle_timeout(Duration::ZERO)
            .with_idle_reaper_tick(Duration::from_millis(50));
        let server = builder.build_arc().unwrap();

        let id = post_init(&server).await;
        tokio::time::sleep(Duration::from_millis(200)).await;
        let state_after_wait = server.is_session_closed_by_id(id).await;
        assert_eq!(
            state_after_wait,
            Some(false),
            "Duration::ZERO must skip eviction"
        );
    }
}
#[cfg(test)]
mod mint_race_tests {
    use super::*;
    use axum::body::to_bytes;
    use axum::http::StatusCode;

    /// Splits a response into its status code and its body parsed as JSON.
    async fn extract_envelope(resp: Response) -> (StatusCode, serde_json::Value) {
        let (parts, body) = resp.into_parts();
        let raw = to_bytes(body, usize::MAX)
            .await
            .expect("response body collects");
        let parsed = serde_json::from_slice(&raw).expect("response body is JSON-RPC envelope");
        (parts.status, parsed)
    }

    /// The race response is a 500 whose envelope echoes the request id and
    /// carries the fixed server-error code and message.
    #[tokio::test]
    async fn mint_session_race_500_returns_500_with_stable_wire_envelope() {
        let request_id = serde_json::json!(7);
        let attempted = uuid::Uuid::from_u128(0x0123_4567_89ab_cdef_0123_4567_89ab_cdef);

        let (status, envelope) = extract_envelope(mint_session_race_500(
            Some(&request_id),
            "active_session",
            Some(attempted),
        ))
        .await;

        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(envelope["jsonrpc"], "2.0");
        assert_eq!(envelope["id"], 7);
        let error = &envelope["error"];
        assert_eq!(error["code"], JSONRPC_SERVER_ERROR);
        assert_eq!(
            error["message"], "internal: session mint race",
            "stable wire message must not drift — operators key alerting on this string"
        );
    }

    /// Without a request id or attempted session id the response is still a
    /// 500 with the same error, and the envelope id is JSON `null`.
    #[tokio::test]
    async fn mint_session_race_500_handles_none_attempted_session_id() {
        let (status, envelope) =
            extract_envelope(mint_session_race_500(None, "outbound", None)).await;

        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        let error = &envelope["error"];
        assert_eq!(error["code"], JSONRPC_SERVER_ERROR);
        assert_eq!(error["message"], "internal: session mint race");
        assert_eq!(envelope["id"], serde_json::Value::Null);
    }

    /// The race message must differ from the `session already active` message.
    #[tokio::test]
    async fn mint_session_race_500_distinct_message_from_session_already_active() {
        let (_status, envelope) =
            extract_envelope(mint_session_race_500(None, "active_session", None)).await;

        assert_ne!(
            envelope["error"]["message"], "session already active",
            "race-500 path must remain distinct from the .get().is_some() 409 path"
        );
    }
}
#[cfg(test)]
mod principal_matches_tests {
    use super::*;

    /// No caller identity and no session principal is accepted.
    #[test]
    fn both_none_passes() {
        let accepted = principal_matches(None, None);
        assert!(accepted);
    }

    /// A caller whose identity equals the session principal is accepted.
    #[test]
    fn matching_some_passes() {
        let accepted = principal_matches(Some(&Identity::new("alice")), Some("alice"));
        assert!(accepted);
    }

    /// A caller whose identity differs from the session principal is rejected.
    #[test]
    fn mismatched_some_fails() {
        let accepted = principal_matches(Some(&Identity::new("alice")), Some("bob"));
        assert!(!accepted);
    }

    /// A caller with an identity is rejected when the session has no principal.
    #[test]
    fn caller_authenticated_session_anonymous_fails() {
        let accepted = principal_matches(Some(&Identity::new("alice")), None);
        assert!(!accepted);
    }

    /// A caller without an identity is rejected when the session has a principal.
    #[test]
    fn caller_anonymous_session_authenticated_fails() {
        let accepted = principal_matches(None, Some("alice"));
        assert!(!accepted);
    }
}
#[cfg(test)]
mod run_resume_authz_tests {
    use super::*;
    use crate::resume_ticket::{ResumeTicketRecord, ResumeTicketStore};
    use async_trait::async_trait;
    use axum::body::to_bytes;
    use klieo_bus_memory::MemoryKv;
    use klieo_core::error::ToolError;
    use klieo_core::llm::ToolDef;
    use klieo_core::tool::{ToolCtx, ToolInvoker};
    use serde_json::Value;

    /// Tool invoker with an empty catalogue that rejects every call as unknown.
    struct NoopInvoker;

    #[async_trait]
    impl ToolInvoker for NoopInvoker {
        fn catalogue(&self) -> Vec<ToolDef> {
            vec![]
        }

        async fn invoke(
            &self,
            name: &str,
            _args: Value,
            _ctx: ToolCtx,
        ) -> Result<Value, ToolError> {
            let unknown = ToolError::UnknownTool(name.into());
            Err(unknown)
        }
    }

    /// Builds a ticket record for workflow `wf` owned by `principal`, with a
    /// checkpoint deserialized from a fixed JSON fixture.
    fn seeded_record(principal: &str) -> ResumeTicketRecord {
        let checkpoint = serde_json::from_value(serde_json::json!({
            "run_id": klieo_core::ids::RunId::new(),
            "step_index": 1,
            "thread_id": "t-authz",
            "messages": [],
            "pending_tool_calls": null,
            "created_at": "2026-06-18T00:00:00Z",
        }))
        .unwrap();
        ResumeTicketRecord {
            principal: principal.into(),
            workflow_name: "wf".into(),
            checkpoint,
            created_at: chrono::Utc::now(),
        }
    }

    /// Builds a server over [`NoopInvoker`] using `kv` as its checkpoint KV.
    fn server_over(kv: Arc<MemoryKv>) -> Arc<McpServer> {
        let builder = McpServer::builder()
            .add_tools(Arc::new(NoopInvoker))
            .with_checkpoint_kv(kv);
        builder.build_arc().unwrap()
    }

    /// Collects the response body and returns its `error.message` string.
    async fn deny_message(resp: Response) -> String {
        let raw = to_bytes(resp.into_body(), usize::MAX)
            .await
            .expect("deny body collects");
        let envelope: serde_json::Value =
            serde_json::from_slice(&raw).expect("deny body is JSON-RPC");
        let message = envelope["error"]["message"]
            .as_str()
            .expect("deny envelope carries an error message");
        message.to_owned()
    }

    /// An anonymous claim is denied and leaves the ticket claimable by its owner.
    #[tokio::test]
    async fn anonymous_caller_is_denied_and_ticket_survives() {
        let kv = Arc::new(MemoryKv::new());
        let server = server_over(Arc::clone(&kv));
        let store = ResumeTicketStore::new(kv);
        let token = ResumeTicketStore::mint_token();
        let record = seeded_record("alice@x");
        store.persist(&token, &record).await.unwrap();

        let denied = claim_resume_record(&server, &token, Some(Identity::anonymous()), None)
            .await
            .expect_err("anonymous caller must be denied");
        assert_eq!(deny_message(denied).await, RUN_RESUME_DENY_MESSAGE);

        let owner_claim =
            claim_resume_record(&server, &token, Some(Identity::new("alice@x")), None).await;
        assert!(
            owner_claim.is_ok(),
            "an anonymous denial must not consume the ticket"
        );
    }

    /// A claim by a different principal is denied, and the owner can still
    /// claim the ticket afterwards.
    #[tokio::test]
    async fn foreign_principal_is_denied_and_owner_still_claims() {
        let kv = Arc::new(MemoryKv::new());
        let server = server_over(Arc::clone(&kv));
        let store = ResumeTicketStore::new(kv);
        let token = ResumeTicketStore::mint_token();
        let record = seeded_record("alice@x");
        store.persist(&token, &record).await.unwrap();

        let denied = claim_resume_record(&server, &token, Some(Identity::new("mallory@x")), None)
            .await
            .expect_err("foreign principal must be denied (IDOR)");
        assert_eq!(deny_message(denied).await, RUN_RESUME_DENY_MESSAGE);

        let claimed = claim_resume_record(&server, &token, Some(Identity::new("alice@x")), None)
            .await
            .expect("rightful owner claims after the foreign denial");
        assert_eq!(claimed.principal, "alice@x");
    }

    /// A token that was never persisted is denied with the same deny message.
    #[tokio::test]
    async fn unknown_ticket_is_denied_with_opaque_message() {
        let server = server_over(Arc::new(MemoryKv::new()));

        let denied =
            claim_resume_record(&server, "no-such-token", Some(Identity::new("alice@x")), None)
                .await
                .expect_err("unknown ticket must be denied");
        assert_eq!(deny_message(denied).await, RUN_RESUME_DENY_MESSAGE);
    }

    /// The owner's first claim succeeds; a second claim with the same token is
    /// denied.
    #[tokio::test]
    async fn owner_claims_exactly_once_replay_loses_race() {
        let kv = Arc::new(MemoryKv::new());
        let server = server_over(Arc::clone(&kv));
        let store = ResumeTicketStore::new(kv);
        let token = ResumeTicketStore::mint_token();
        let record = seeded_record("alice@x");
        store.persist(&token, &record).await.unwrap();

        let first_claim =
            claim_resume_record(&server, &token, Some(Identity::new("alice@x")), None).await;
        assert!(first_claim.is_ok(), "the first claim by the owner succeeds");

        let replayed = claim_resume_record(&server, &token, Some(Identity::new("alice@x")), None)
            .await
            .expect_err("a replayed claim must lose the race");
        assert_eq!(deny_message(replayed).await, RUN_RESUME_DENY_MESSAGE);
    }

    /// Resume handle stub: returns `{"resumed": true}` or, when `fail` is set,
    /// a permanent tool error.
    struct StubResumeHandle {
        fail: bool,
    }

    #[async_trait]
    impl crate::workflow::WorkflowResumeHandle for StubResumeHandle {
        async fn resume(
            &self,
            _checkpoint: klieo_core::checkpoint::RunCheckpoint,
            _decision: klieo_core::checkpoint::ApprovalDecision,
            _tenant_label: String,
        ) -> Result<Value, ToolError> {
            if !self.fail {
                return Ok(serde_json::json!({ "resumed": true }));
            }
            Err(ToolError::Permanent("resume blew up".into()))
        }
    }

    /// Builds a server with a [`StubResumeHandle`] registered under `wf`.
    fn server_with_handle(fail: bool) -> McpServer {
        let handle: Arc<dyn crate::workflow::WorkflowResumeHandle> =
            Arc::new(StubResumeHandle { fail });
        let mut server = McpServer::builder()
            .add_tools(Arc::new(NoopInvoker))
            .build()
            .unwrap();
        server
            .workflow_resume_handles
            .insert("wf".to_string(), handle);
        server
    }

    /// Collects the response body and parses it as JSON.
    async fn result_value(resp: Response) -> serde_json::Value {
        let raw = to_bytes(resp.into_body(), usize::MAX)
            .await
            .expect("body collects");
        serde_json::from_slice(&raw).expect("body is JSON-RPC")
    }

    /// With no handle registered for the record's workflow, `drive_resume`
    /// answers with the unavailable message.
    #[tokio::test]
    async fn drive_resume_unregistered_workflow_is_unavailable() {
        let server = McpServer::builder()
            .add_tools(Arc::new(NoopInvoker))
            .build()
            .unwrap();

        let resp = drive_resume(
            &server,
            RunResumeDecision {
                approved: true,
                reason: None,
            },
            seeded_record("alice@x"),
            None,
        )
        .await;
        assert_eq!(deny_message(resp).await, RUN_RESUME_UNAVAILABLE_MESSAGE);
    }

    /// With a succeeding handle registered, the response's `result` carries
    /// the handle's output.
    #[tokio::test]
    async fn drive_resume_dispatches_to_registered_handle() {
        let server = server_with_handle(false);

        let resp = drive_resume(
            &server,
            RunResumeDecision {
                approved: true,
                reason: None,
            },
            seeded_record("alice@x"),
            None,
        )
        .await;
        let body = result_value(resp).await;
        assert_eq!(
            body["result"]["resumed"],
            serde_json::Value::Bool(true),
            "approved resume returns the handle's result envelope; got {body}"
        );
    }

    /// With a failing handle registered, the response carries the
    /// `resume execution failed` error message.
    #[tokio::test]
    async fn drive_resume_handle_error_yields_server_error() {
        let server = server_with_handle(true);

        let resp = drive_resume(
            &server,
            RunResumeDecision {
                approved: false,
                reason: Some("operator rejected".into()),
            },
            seeded_record("alice@x"),
            None,
        )
        .await;
        assert_eq!(deny_message(resp).await, "resume execution failed");
    }
}