use std::convert::Infallible;
use std::sync::Arc;
use std::time::Duration;
use axum::{
body::Bytes,
extract::{Json, State},
http::{HeaderMap, HeaderValue, StatusCode},
response::{
IntoResponse, Response,
sse::{Event, KeepAlive, Sse},
},
};
use futures::{StreamExt, stream};
use objectiveai_sdk::mcp::{
JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse,
initialize_result::{
Implementation, InitializeResult, ResourcesCapability, ServerCapabilities,
ToolsCapability,
},
resource::ReadResourceRequestParams,
tool::{CallToolRequestParams, ContentBlock, TextContent},
};
use tokio::sync::broadcast;
use crate::AppState;
use crate::session::{CallToolError, ReadResourceError, Session};
use crate::session_manager::SessionManager;
use crate::upstream::BadInit;
/// Protocol version reported in every `initialize` result produced by this module.
const PROTOCOL_VERSION: &str = "2025-06-18";
/// `params.protocolVersion` values accepted on `initialize`; any other string is
/// rejected with an invalid-request error.
const ACCEPTED_PROTOCOL_VERSIONS: &[&str] = &["2025-06-18", "2025-11-25"];
// Standard JSON-RPC 2.0 error codes used by the response helpers below.
const PARSE_ERROR: i64 = -32700;
const INVALID_REQUEST: i64 = -32600;
const METHOD_NOT_FOUND: i64 = -32601;
const INVALID_PARAMS: i64 = -32602;
const INTERNAL_ERROR: i64 = -32603;
/// Error code returned when an in-flight request is aborted via `notifications/cancelled`.
const REQUEST_CANCELLED: i64 = -32800;
/// HTTP header that carries the MCP session id on requests and `initialize` responses.
const SESSION_ID_HEADER: &str = "Mcp-Session-Id";
/// Keep-alive interval for the `GET` server-to-client SSE stream.
const SSE_KEEP_ALIVE: Duration = Duration::from_secs(15);
pub async fn handle_post(
State(state): State<AppState>,
headers: HeaderMap,
body: Bytes,
) -> Response {
if let Err(resp) = require_streamable_http_accept(&headers) {
return resp;
}
let body: serde_json::Value = match serde_json::from_slice(&body) {
Ok(v) => v,
Err(e) => return parse_error_response(format!("invalid JSON: {e}")),
};
if body.get("id").is_none() {
if let Ok(notification) =
serde_json::from_value::<JsonRpcNotification>(body)
{
if notification.method == "notifications/cancelled" {
handle_cancelled_notification(&state, &headers, ¬ification);
}
}
return StatusCode::ACCEPTED.into_response();
}
let request: JsonRpcRequest = match serde_json::from_value(body) {
Ok(r) => r,
Err(e) => return parse_error_response(format!("invalid JSON-RPC envelope: {e}")),
};
match request.method.as_str() {
"initialize" => handle_initialize(&state, &headers, request).await,
"ping" => handle_ping(request),
"tools/list" => handle_tools_list(&state.sessions, &headers, request).await,
"tools/call" => handle_tools_call(&state.sessions, &headers, request).await,
"resources/list" => handle_resources_list(&state.sessions, &headers, request).await,
"resources/read" => handle_resources_read(&state.sessions, &headers, request).await,
other => method_not_found_response(request.id, other),
}
}
fn handle_cancelled_notification(
state: &AppState,
headers: &HeaderMap,
notification: &JsonRpcNotification,
) {
let session_id = match headers
.get(SESSION_ID_HEADER)
.and_then(|v| v.to_str().ok())
{
Some(s) => s,
None => return,
};
let session = match state.sessions.get(session_id) {
Some(s) => s,
None => return,
};
let request_id = match notification
.params
.as_ref()
.and_then(|p| p.get("requestId"))
{
Some(id) => id,
None => return,
};
let cancelled = session.cancel_in_flight(request_id);
tracing::debug!(
session = %session_id,
request_id = %request_id,
cancelled,
"notifications/cancelled received",
);
}
pub async fn handle_delete(
State(state): State<AppState>,
headers: HeaderMap,
) -> Response {
let session_id = match extract_session_id(&headers) {
Ok(id) => id,
Err(resp) => return resp,
};
match state.sessions.remove(&session_id) {
Some(_) => StatusCode::OK.into_response(),
None => (StatusCode::NOT_FOUND, "unknown session").into_response(),
}
}
pub async fn handle_get(
State(state): State<AppState>,
headers: HeaderMap,
) -> Response {
let session_id = match extract_session_id(&headers) {
Ok(id) => id,
Err(resp) => return resp,
};
let session = match state.sessions.get(&session_id) {
Some(s) => s,
None => return (StatusCode::NOT_FOUND, "unknown session").into_response(),
};
let rx = session.outbound.subscribe();
let stream = stream::unfold(rx, |mut rx| async move {
loop {
match rx.recv().await {
Ok(notification) => {
let event = match Event::default().json_data(¬ification) {
Ok(e) => e,
Err(e) => {
tracing::warn!(error = %e, "failed to encode SSE event");
continue;
}
};
return Some((Ok::<_, Infallible>(event), rx));
}
Err(broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!(skipped = n, "SSE consumer lagged; dropped notifications");
continue;
}
Err(broadcast::error::RecvError::Closed) => return None,
}
}
});
Sse::new(stream)
.keep_alive(KeepAlive::new().interval(SSE_KEEP_ALIVE))
.into_response()
}
async fn handle_initialize(
state: &AppState,
headers: &HeaderMap,
request: JsonRpcRequest,
) -> Response {
match request.params.as_ref().and_then(|p| p.get("protocolVersion")) {
Some(v) => match v.as_str() {
Some(version) if ACCEPTED_PROTOCOL_VERSIONS.contains(&version) => {
}
Some(other) => {
return invalid_request_response(
request.id,
format!(
"unsupported protocolVersion {other:?}; this proxy accepts {ACCEPTED_PROTOCOL_VERSIONS:?}",
),
);
}
None => {
return invalid_params_response(
request.id,
"params.protocolVersion must be a string".into(),
);
}
},
None => {
return invalid_params_response(
request.id,
"params.protocolVersion is required".into(),
);
}
}
let provided_session_id = headers
.get(SESSION_ID_HEADER)
.and_then(|v| v.to_str().ok())
.map(str::to_owned);
if let Some(sid) = &provided_session_id {
if let Some(session) = state.sessions.get(sid) {
let new_id = state.sessions.mint_id(&session.payload);
let _ = (new_id, session);
return ok_response_resume_sse(request.id);
}
let connections_with_headers = match state.sessions.decode_session_id(sid) {
Some(payload) => {
match crate::upstream::reconnect_from_payload(&state.client, &payload).await {
Ok(pairs) => pairs,
Err(e @ BadInit::UpstreamConnectFailed { .. }) => {
return internal_error_response(request.id, e.to_string());
}
Err(e) => {
return internal_error_response(request.id, e.to_string());
}
}
}
None => {
return (
StatusCode::UNAUTHORIZED,
format!("Unauthorized: Session not found ({sid:?})"),
)
.into_response();
}
};
let _ = state.sessions.add(connections_with_headers);
ok_response_resume_sse(request.id)
} else {
let connections_with_headers = match crate::upstream::connect_all_fresh(&state.client, headers).await {
Ok(pairs) => pairs,
Err(e @ (BadInit::NotUtf8 { .. } | BadInit::NotJson { .. })) => {
return invalid_request_response(request.id, e.to_string());
}
Err(e @ BadInit::UpstreamConnectFailed { .. }) => {
return internal_error_response(request.id, e.to_string());
}
};
let session_id = state.sessions.add(connections_with_headers);
ok_response_fresh_sse(request.id, session_id)
}
}
fn ok_response_fresh_sse(
request_id: serde_json::Value,
session_id: String,
) -> Response {
let result = InitializeResult {
protocol_version: PROTOCOL_VERSION.into(),
capabilities: server_capabilities(),
server_info: server_info(),
instructions: None,
_meta: None,
};
let body: JsonRpcResponse<InitializeResult> = JsonRpcResponse::Success {
jsonrpc: "2.0".into(),
id: request_id,
result,
};
let payload = match serde_json::to_string(&body) {
Ok(s) => s,
Err(e) => {
return internal_error_response(
serde_json::Value::Null,
format!("failed to serialize InitializeResult: {e}"),
);
}
};
let header_value = match HeaderValue::from_str(&session_id) {
Ok(v) => v,
Err(_) => {
return internal_error_response(
serde_json::Value::Null,
format!("session id is not a valid header value: {session_id}"),
);
}
};
let stream = stream::once(async move {
Ok::<sse_stream::Sse, Infallible>(sse_stream::Sse::default().data(payload))
});
let body_stream = sse_stream::SseBody::new(stream);
let mut response = Response::new(axum::body::Body::new(body_stream));
*response.status_mut() = StatusCode::OK;
let h = response.headers_mut();
h.insert(SESSION_ID_HEADER, header_value);
h.insert(
axum::http::header::CONTENT_TYPE,
HeaderValue::from_static("text/event-stream"),
);
h.insert(
axum::http::header::CACHE_CONTROL,
HeaderValue::from_static("no-cache"),
);
response
}
fn ok_response_resume_sse(request_id: serde_json::Value) -> Response {
let priming = sse_stream::Sse::default()
.data("")
.id("0")
.retry_duration(Duration::from_millis(3000));
let result = InitializeResult {
protocol_version: PROTOCOL_VERSION.into(),
capabilities: server_capabilities(),
server_info: server_info(),
instructions: None,
_meta: None,
};
let body: JsonRpcResponse<InitializeResult> = JsonRpcResponse::Success {
jsonrpc: "2.0".into(),
id: request_id,
result,
};
let payload = match serde_json::to_string(&body) {
Ok(s) => s,
Err(e) => {
return internal_error_response(
serde_json::Value::Null,
format!("failed to serialize InitializeResult: {e}"),
);
}
};
let result_event = sse_stream::Sse::default().data(payload);
let stream = stream::iter(vec![
Ok::<sse_stream::Sse, Infallible>(priming),
Ok(result_event),
])
.chain(stream::pending::<Result<sse_stream::Sse, Infallible>>());
let body_stream = sse_stream::SseBody::new(stream);
let mut response = Response::new(axum::body::Body::new(body_stream));
*response.status_mut() = StatusCode::OK;
let h = response.headers_mut();
h.insert(
axum::http::header::CONTENT_TYPE,
HeaderValue::from_static("text/event-stream"),
);
h.insert(
axum::http::header::CACHE_CONTROL,
HeaderValue::from_static("no-cache"),
);
response
}
/// Answers a JSON-RPC `ping` with an empty result object and HTTP 200.
fn handle_ping(request: JsonRpcRequest) -> Response {
    let pong = serde_json::json!({});
    let reply: JsonRpcResponse<serde_json::Value> = JsonRpcResponse::Success {
        result: pong,
        id: request.id,
        jsonrpc: "2.0".into(),
    };
    (StatusCode::OK, Json(reply)).into_response()
}
/// Handles `tools/list` by asking the caller's session to list its tools.
///
/// Session lookup failures produce the HTTP-level rejections from
/// `extract_session_id` / `unknown_session_response`; an error from `list_tools`
/// becomes a JSON-RPC internal error.
async fn handle_tools_list(
    sessions: &SessionManager,
    headers: &HeaderMap,
    request: JsonRpcRequest,
) -> Response {
    let session_id = match extract_session_id(headers) {
        Err(rejection) => return rejection,
        Ok(id) => id,
    };
    match sessions.get(&session_id) {
        None => unknown_session_response(),
        Some(session) => match session.list_tools().await {
            Err(e) => internal_error_response(request.id, format!("list_tools: {e}")),
            Ok(result) => {
                let reply = JsonRpcResponse::Success {
                    jsonrpc: "2.0".into(),
                    id: request.id,
                    result,
                };
                (StatusCode::OK, Json(reply)).into_response()
            }
        },
    }
}
async fn handle_tools_call(
sessions: &SessionManager,
headers: &HeaderMap,
request: JsonRpcRequest,
) -> Response {
let session_id = match extract_session_id(headers) {
Ok(id) => id,
Err(resp) => return resp,
};
let session = match sessions.get(&session_id) {
Some(s) => s,
None => return unknown_session_response(),
};
let params: CallToolRequestParams = match request.params.clone() {
Some(v) => match serde_json::from_value(v) {
Ok(p) => p,
Err(e) => {
return invalid_params_response(
request.id,
format!("tools/call params: {e}"),
);
}
},
None => return invalid_params_response(request.id, "missing params".into()),
};
let token = session.register_in_flight(&request.id);
let _guard = InFlightGuard {
session: Arc::clone(&session),
id: request.id.clone(),
};
let result = tokio::select! {
biased;
_ = token.cancelled() => {
return cancelled_response(request.id);
}
result = session.call_tool(¶ms) => result,
};
match result {
Ok(mut result) => {
let pending = session.drain_notifications().await;
if !pending.is_empty() {
let mut prefixed = Vec::with_capacity(2 + pending.len() + result.content.len());
prefixed.push(ContentBlock::Text(TextContent {
text: SYSTEM_REMINDER_PREFIX.to_string(),
annotations: None,
_meta: None,
}));
prefixed.extend(pending);
prefixed.push(ContentBlock::Text(TextContent {
text: SYSTEM_REMINDER_SUFFIX.to_string(),
annotations: None,
_meta: None,
}));
prefixed.append(&mut result.content);
result.content = prefixed;
}
let body = JsonRpcResponse::Success {
jsonrpc: "2.0".into(),
id: request.id,
result,
};
(StatusCode::OK, Json(body)).into_response()
}
Err(CallToolError::ToolNotFound(name)) => {
method_not_found_response(request.id, &format!("tool: {name}"))
}
Err(CallToolError::Upstream(e)) => {
internal_error_response(request.id, format!("upstream call_tool: {e}"))
}
}
}
// Framing text placed around content blocks queued via `/notify` when they are
// prepended to a successful `tools/call` result (see `handle_tools_call`).
const SYSTEM_REMINDER_PREFIX: &str =
    "<system-reminder>\nThe user sent a new message while you were working:\n";
const SYSTEM_REMINDER_SUFFIX: &str = "\n\n</system-reminder>";
/// Queues content blocks for the caller's session (`/notify`).
///
/// The body must be a JSON array of content blocks; they are enqueued on the session
/// and surface with the next successful `tools/call` result. Responds
/// `202 Accepted`, or a parse-error response if the body does not decode.
pub async fn handle_notify(
    State(state): State<AppState>,
    headers: HeaderMap,
    body: Bytes,
) -> Response {
    let session_id = match extract_session_id(&headers) {
        Err(rejection) => return rejection,
        Ok(id) => id,
    };
    let session = if let Some(found) = state.sessions.get(&session_id) {
        found
    } else {
        return unknown_session_response();
    };
    let decoded: Result<Vec<ContentBlock>, _> = serde_json::from_slice(&body);
    match decoded {
        Err(e) => parse_error_response(format!("invalid /notify body: {e}")),
        Ok(blocks) => {
            session.enqueue_notifications(blocks).await;
            StatusCode::ACCEPTED.into_response()
        }
    }
}
/// Handles `resources/list` by asking the caller's session to list its resources.
///
/// Session lookup failures produce HTTP-level rejections; an error from
/// `list_resources` becomes a JSON-RPC internal error.
async fn handle_resources_list(
    sessions: &SessionManager,
    headers: &HeaderMap,
    request: JsonRpcRequest,
) -> Response {
    let session_id = match extract_session_id(headers) {
        Err(rejection) => return rejection,
        Ok(id) => id,
    };
    let session = if let Some(found) = sessions.get(&session_id) {
        found
    } else {
        return unknown_session_response();
    };
    let listing = session.list_resources().await;
    match listing {
        Err(e) => internal_error_response(request.id, format!("list_resources: {e}")),
        Ok(result) => {
            let reply = JsonRpcResponse::Success {
                jsonrpc: "2.0".into(),
                id: request.id,
                result,
            };
            (StatusCode::OK, Json(reply)).into_response()
        }
    }
}
async fn handle_resources_read(
sessions: &SessionManager,
headers: &HeaderMap,
request: JsonRpcRequest,
) -> Response {
let session_id = match extract_session_id(headers) {
Ok(id) => id,
Err(resp) => return resp,
};
let session = match sessions.get(&session_id) {
Some(s) => s,
None => return unknown_session_response(),
};
let params: ReadResourceRequestParams = match request.params.clone() {
Some(v) => match serde_json::from_value(v) {
Ok(p) => p,
Err(e) => {
return invalid_params_response(
request.id,
format!("resources/read params: {e}"),
);
}
},
None => return invalid_params_response(request.id, "missing params".into()),
};
let token = session.register_in_flight(&request.id);
let _guard = InFlightGuard {
session: Arc::clone(&session),
id: request.id.clone(),
};
let result = tokio::select! {
biased;
_ = token.cancelled() => {
return cancelled_response(request.id);
}
result = session.read_resource(¶ms.uri) => result,
};
match result {
Ok(result) => {
let body = JsonRpcResponse::Success {
jsonrpc: "2.0".into(),
id: request.id,
result,
};
(StatusCode::OK, Json(body)).into_response()
}
Err(ReadResourceError::ResourceNotFound(uri)) => {
invalid_params_response(request.id, format!("resource not found: {uri}"))
}
Err(ReadResourceError::Upstream(e)) => {
internal_error_response(request.id, format!("upstream read_resource: {e}"))
}
}
}
/// Reads the `Mcp-Session-Id` header as an owned string.
///
/// A missing header, or one that is not a valid string value (`HeaderValue::to_str`
/// accepts only visible ASCII), is a malformed request. It is rejected with
/// `400 Bad Request`, as the MCP Streamable-HTTP spec recommends. `404` is reserved
/// for ids that name no live session, which prompts clients to re-initialize.
fn extract_session_id(headers: &HeaderMap) -> Result<String, Response> {
    let raw = match headers.get(SESSION_ID_HEADER) {
        Some(raw) => raw,
        None => {
            return Err((
                StatusCode::BAD_REQUEST,
                format!("missing {SESSION_ID_HEADER} header"),
            )
                .into_response());
        }
    };
    raw.to_str().map(str::to_owned).map_err(|_| {
        (
            StatusCode::BAD_REQUEST,
            format!("{SESSION_ID_HEADER} must contain only visible ASCII characters"),
        )
            .into_response()
    })
}
/// Plain-text `404` returned when `Mcp-Session-Id` names no live session.
fn unknown_session_response() -> Response {
    let status = StatusCode::NOT_FOUND;
    let reason = "unknown session";
    (status, reason).into_response()
}
/// JSON-RPC `-32800` error (HTTP 200) for a request aborted by `notifications/cancelled`.
fn cancelled_response(id: serde_json::Value) -> Response {
    let message = String::from("request cancelled");
    json_rpc_error_response(StatusCode::OK, id, REQUEST_CANCELLED, message)
}
/// Drop guard that removes a request from its session's in-flight registry.
///
/// Created right after `register_in_flight`, so the entry is deregistered however
/// the handler exits (success, error, or cancellation).
struct InFlightGuard {
    // Session whose in-flight registry holds `id`.
    session: Arc<Session>,
    // JSON-RPC id the request was registered under.
    id: serde_json::Value,
}
impl Drop for InFlightGuard {
    fn drop(&mut self) {
        let InFlightGuard { ref session, ref id } = *self;
        session.deregister_in_flight(id);
    }
}
/// Enforces the Streamable-HTTP `Accept` requirement for `POST` requests.
///
/// The client must accept both `application/json` and `text/event-stream`.
/// `*/*` covers both; `application/*` covers only JSON and `text/*` only SSE. A
/// partial wildcard alone is therefore not sufficient. All `Accept` header lines are
/// considered, since repeated fields are equivalent to one comma-joined list.
/// Media-range parameters (including `q`) are ignored. Otherwise responds
/// `406 Not Acceptable`.
fn require_streamable_http_accept(headers: &HeaderMap) -> Result<(), Response> {
    let mut json = false;
    let mut sse = false;
    let ranges = headers
        .get_all(axum::http::header::ACCEPT)
        .iter()
        .filter_map(|v| v.to_str().ok())
        .flat_map(|v| v.split(','));
    for range in ranges {
        let media = range.split(';').next().unwrap_or("").trim().to_ascii_lowercase();
        match media.as_str() {
            "*/*" => {
                json = true;
                sse = true;
            }
            "application/json" | "application/*" => json = true,
            "text/event-stream" | "text/*" => sse = true,
            _ => {}
        }
    }
    if json && sse {
        Ok(())
    } else {
        Err((
            StatusCode::NOT_ACCEPTABLE,
            "Accept header must list both application/json and text/event-stream",
        )
            .into_response())
    }
}
/// Capabilities advertised in `initialize`: tools and resources, both with
/// `listChanged` enabled; resource subscriptions and every other capability are
/// left unset.
fn server_capabilities() -> ServerCapabilities {
    let tools = ToolsCapability {
        list_changed: Some(true),
    };
    let resources = ResourcesCapability {
        list_changed: Some(true),
        subscribe: None,
    };
    ServerCapabilities {
        tools: Some(tools),
        resources: Some(resources),
        prompts: None,
        completions: None,
        logging: None,
        experimental: None,
        tasks: None,
    }
}
/// Server identity reported in `initialize`: name `oaip` and the crate version.
fn server_info() -> Implementation {
    let version = env!("CARGO_PKG_VERSION");
    Implementation {
        version: version.into(),
        name: "oaip".into(),
        title: None,
        description: None,
        website_url: None,
        icons: None,
    }
}
fn json_rpc_error_response(
status: StatusCode,
id: serde_json::Value,
code: i64,
message: String,
) -> Response {
let body: JsonRpcResponse<()> = JsonRpcResponse::Error {
jsonrpc: "2.0".into(),
id,
error: JsonRpcError {
code,
message,
data: None,
},
};
(status, Json(body)).into_response()
}
/// HTTP 400 carrying a JSON-RPC `-32700` parse error with a null id (no request id
/// is available when the body could not be decoded).
fn parse_error_response(message: String) -> Response {
    let unknown_id = serde_json::Value::Null;
    json_rpc_error_response(StatusCode::BAD_REQUEST, unknown_id, PARSE_ERROR, message)
}
/// JSON-RPC `-32600` (invalid request) error, delivered with HTTP 200.
fn invalid_request_response(id: serde_json::Value, message: String) -> Response {
    let status = StatusCode::OK;
    json_rpc_error_response(status, id, INVALID_REQUEST, message)
}
/// JSON-RPC `-32602` (invalid params) error, delivered with HTTP 200.
fn invalid_params_response(id: serde_json::Value, message: String) -> Response {
    let status = StatusCode::OK;
    json_rpc_error_response(status, id, INVALID_PARAMS, message)
}
/// JSON-RPC `-32603` (internal error), delivered with HTTP 200.
fn internal_error_response(id: serde_json::Value, message: String) -> Response {
    let status = StatusCode::OK;
    json_rpc_error_response(status, id, INTERNAL_ERROR, message)
}
/// JSON-RPC `-32601` (method not found) error naming `method`, delivered with HTTP 200.
fn method_not_found_response(id: serde_json::Value, method: &str) -> Response {
    let message = format!("method not found: {method}");
    json_rpc_error_response(StatusCode::OK, id, METHOD_NOT_FOUND, message)
}