use std::{
collections::HashMap,
pin::Pin,
result::Result as StdResult,
sync::{
self, Arc,
atomic::{AtomicBool, AtomicU64, Ordering},
},
task::{Context, Poll},
time::{Duration, Instant},
};
use async_trait::async_trait;
use axum::{
Json, Router,
extract::State,
http::{HeaderMap, HeaderValue, StatusCode, header, uri::Authority},
response::{
IntoResponse, Response,
sse::{Event, KeepAlive, Sse},
},
routing::{get, post},
};
use dashmap::DashMap;
use eventsource_stream::Eventsource;
use futures::{Sink, Stream, StreamExt, channel::mpsc};
use reqwest::Client as HttpClient;
use serde_json::Value;
use tokio::{
net::TcpListener,
sync::{Mutex, oneshot},
task::JoinHandle,
time::{interval, sleep, timeout},
};
use tokio_util::sync::CancellationToken;
use tower_http::cors::CorsLayer;
use tracing::{debug, error, info};
use url::Url;
use uuid::Uuid;
use crate::{
auth::OAuth2Client,
error::{Error, Result},
schema::{
JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse,
LATEST_PROTOCOL_VERSION, RequestId,
},
transport::{Transport, TransportStream},
};
/// Timeout configured on the client transport's reqwest client (see `HttpClientTransport::new`).
///
/// NOTE(review): reqwest's client-level timeout covers the whole request including reading
/// the body, so long-lived SSE GET streams opened with this client may be cut off after this
/// duration — confirm this is acceptable.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(120);
/// Idle time after which the server's periodic cleanup task removes a session.
const SESSION_TIMEOUT: Duration = Duration::from_secs(3600);
/// Server-side state for one client session.
///
/// Every field is a shared handle, so clones of a session observe the same activity
/// timestamp, outbound channel, SSE event counter and streaming flag.
#[derive(Debug, Clone)]
pub struct HttpSession {
    /// Last observed activity; compared against `SESSION_TIMEOUT` by the cleanup task.
    pub last_activity: Arc<sync::Mutex<Instant>>,
    /// Outbound messages destined for this client.
    pub sender: mpsc::UnboundedSender<JSONRPCMessage>,
    /// Receiving half of `sender`; locked by whichever HTTP response is delivering messages.
    pub receiver: Arc<Mutex<mpsc::UnboundedReceiver<JSONRPCMessage>>>,
    /// Last SSE event id handed out (0 means none yet, so the first id is 1).
    event_counter: Arc<AtomicU64>,
    /// True while a standalone GET SSE stream is open for this session.
    streaming: Arc<AtomicBool>,
}

impl HttpSession {
    /// Allocates the next SSE event id.
    ///
    /// The counter can be pushed to `u64::MAX` by a client-supplied `Last-Event-ID`
    /// header, so the increment wraps (as `fetch_add` itself does) instead of
    /// overflowing, which would panic in debug builds.
    fn next_event_id(&self) -> u64 {
        self.event_counter
            .fetch_add(1, Ordering::Relaxed)
            .wrapping_add(1)
    }

    /// Raises the counter to at least `last_event_id` (never lowers it), so ids issued
    /// after a resumed stream are greater than the one the client reported.
    fn bump_event_id(&self, last_event_id: u64) {
        self.event_counter
            .fetch_max(last_event_id, Ordering::Relaxed);
    }
}
/// Drop guard that holds a session's `streaming` flag at `true` for as long as it lives.
struct StreamingGuard {
    flag: Arc<AtomicBool>,
}

impl StreamingGuard {
    /// Raises `flag` and returns a guard that lowers it again when dropped.
    fn new(flag: Arc<AtomicBool>) -> Self {
        let guard = Self { flag };
        guard.flag.store(true, Ordering::SeqCst);
        guard
    }
}

impl Drop for StreamingGuard {
    fn drop(&mut self) {
        // The stream owning this guard is gone, so the session is no longer streaming.
        let StreamingGuard { flag } = self;
        flag.store(false, Ordering::SeqCst);
    }
}
/// Bookkeeping for the client's optional background SSE listener.
#[derive(Debug)]
struct HttpSseState {
    /// True while a listener task is running (claimed via `swap`).
    running: AtomicBool,
    /// Cleared once the server rejects the SSE GET with 405; starts optimistic.
    supported: AtomicBool,
}

impl HttpSseState {
    /// Creates state with no listener running and SSE assumed to be supported.
    fn new() -> Self {
        Self {
            running: AtomicBool::from(false),
            supported: AtomicBool::from(true),
        }
    }
}
/// State shared by the axum handlers (`handle_post`, `handle_get`) through `State`.
#[derive(Clone)]
struct HttpServerState {
    /// Live sessions keyed by their `Mcp-Session-Id`.
    sessions: Arc<DashMap<String, HttpSession>>,
    /// Forwards each client message, tagged with its session id, to the `HttpServerStream`.
    incoming_tx: mpsc::UnboundedSender<(JSONRPCMessage, String)>,
    /// Cancelled to stop the axum server and the session cleanup task.
    shutdown: CancellationToken,
}
/// Client side of the HTTP transport.
///
/// It POSTs each outgoing message to `endpoint`. It may also listen on a GET SSE stream
/// for server-initiated messages.
#[doc(hidden)]
pub struct HttpClientTransport {
    /// URL used for both the POST requests and the SSE GET.
    endpoint: String,
    /// reqwest client built in `new` with `DEFAULT_TIMEOUT`.
    client: HttpClient,
    /// Session id captured from the `initialize` response and sent on later requests.
    session_id: Arc<Mutex<Option<String>>>,
    /// Most recent non-empty SSE event id seen, sent as `Last-Event-ID` on the SSE GET.
    last_event_id: Arc<Mutex<Option<String>>>,
    /// Channel halves created by `connect` and consumed by `framed`.
    sender: Option<mpsc::UnboundedSender<JSONRPCMessage>>,
    receiver: Option<mpsc::UnboundedReceiver<JSONRPCMessage>>,
    /// Whether the background SSE listener is running / supported by the server.
    sse_state: Arc<HttpSseState>,
    /// Cancels the background SSE listener.
    sse_shutdown: CancellationToken,
    /// Optional OAuth client whose token is sent as a bearer `Authorization` header.
    oauth_client: Option<Arc<OAuth2Client>>,
}
/// Server side of the HTTP transport: an axum server on `bind_addr`.
#[doc(hidden)]
pub struct HttpServerTransport {
    /// Address to bind; rewritten to the actual local address once `start` binds.
    pub bind_addr: String,
    /// Router built by `start` (kept for reference; the server task owns its own clone).
    router: Option<Router>,
    /// Shared handler state created by `start`.
    state: Option<HttpServerState>,
    /// Task running `axum::serve`; `Some` once started.
    server_handle: Option<JoinHandle<Result<()>>>,
    /// Receiver of client messages, handed to the stream by `framed`.
    incoming_rx: Option<mpsc::UnboundedReceiver<(JSONRPCMessage, String)>>,
    /// Token cancelled on drop to stop the server; `framed` takes it to keep the server alive.
    shutdown_token: Option<CancellationToken>,
}
/// Stream/sink half returned by `HttpClientTransport::framed`.
///
/// Outgoing messages go to the background HTTP task through `sender`. Incoming messages
/// (replies and SSE pushes) arrive on `receiver`.
struct HttpTransportStream {
    sender: mpsc::UnboundedSender<JSONRPCMessage>,
    receiver: mpsc::UnboundedReceiver<JSONRPCMessage>,
    _http_task: JoinHandle<()>,
    sse_shutdown: CancellationToken,
}

impl Drop for HttpTransportStream {
    /// Stops the POST task and signals the SSE listener to shut down.
    fn drop(&mut self) {
        let Self {
            _http_task: task,
            sse_shutdown: shutdown,
            ..
        } = self;
        task.abort();
        shutdown.cancel();
    }
}
/// Stores the server-assigned `Mcp-Session-Id` from an `initialize` response.
///
/// Returns the stored id. Returns `None` without touching `session_id` when this is not
/// an initialize exchange, or when the header is missing or not valid visible ASCII.
async fn update_session_id(
    is_initialize: bool,
    headers: &HeaderMap,
    session_id: &Arc<Mutex<Option<String>>>,
) -> Option<String> {
    if is_initialize {
        let sid_str = headers.get("Mcp-Session-Id")?.to_str().ok()?;
        let updated = sid_str.to_owned();
        *session_id.lock().await = Some(updated.clone());
        debug!("Got session ID: {}", sid_str);
        Some(updated)
    } else {
        None
    }
}
/// Returns `true` only for JSON-RPC requests.
///
/// Notifications and responses never receive a JSON-RPC reply, so their HTTP bodies are
/// not read.
fn expects_response(msg: &JSONRPCMessage) -> bool {
    matches!(msg, JSONRPCMessage::Request(_))
}
/// Returns whether `status` is 2xx; logs an error for anything else.
fn validate_status(status: reqwest::StatusCode) -> bool {
    let ok = status.is_success();
    if !ok {
        error!("HTTP request failed with status: {}", status);
    }
    ok
}
/// Decodes a JSON response body as a `JSONRPCMessage` and forwards it to `sender`.
///
/// Decode and send failures are logged, not returned.
async fn forward_response(
    response: reqwest::Response,
    sender: &mpsc::UnboundedSender<JSONRPCMessage>,
) {
    let response_msg = match response.json::<JSONRPCMessage>().await {
        Ok(msg) => msg,
        Err(e) => {
            error!("Failed to parse response: {}", e);
            return;
        }
    };
    debug!("HTTP client received response: {:?}", response_msg);
    if let Err(e) = sender.unbounded_send(response_msg) {
        error!("Failed to forward response: {}", e);
    }
}
/// Returns whether the response's `Content-Type` starts with `text/event-stream`.
fn response_is_sse(response: &reqwest::Response) -> bool {
    let Some(content_type) = response.headers().get(header::CONTENT_TYPE) else {
        return false;
    };
    matches!(content_type.to_str(), Ok(value) if value.starts_with("text/event-stream"))
}
async fn forward_sse_response(
response: reqwest::Response,
sender: &mpsc::UnboundedSender<JSONRPCMessage>,
last_event_id: &Arc<Mutex<Option<String>>>,
) {
let stream = response.bytes_stream().eventsource();
futures::pin_mut!(stream);
while let Some(event) = stream.next().await {
match event {
Ok(event) => {
let event_id = event.id;
if !event_id.is_empty() {
let mut guard = last_event_id.lock().await;
*guard = Some(event_id);
}
let data = event.data;
if let Ok(msg) = serde_json::from_str::<JSONRPCMessage>(&data)
&& sender.unbounded_send(msg).is_err()
{
break;
}
}
Err(e) => {
error!("SSE response error: {:?}", e);
break;
}
}
}
}
/// Processes the HTTP response to a POSTed message.
///
/// Non-2xx statuses are logged. A body is read only for requests, and never for
/// 202/204 replies. SSE bodies are streamed and anything else is decoded as one JSON
/// message; either way, results go to `sender`.
async fn handle_http_response(
    msg: &JSONRPCMessage,
    response: reqwest::Response,
    sender: &mpsc::UnboundedSender<JSONRPCMessage>,
    last_event_id: &Arc<Mutex<Option<String>>>,
) {
    let status = response.status();
    // Short-circuit order matters: `validate_status` logs, so it runs first for every message.
    let nothing_to_read = !validate_status(status)
        || !expects_response(msg)
        || matches!(status, StatusCode::ACCEPTED | StatusCode::NO_CONTENT);
    if nothing_to_read {
        return;
    }
    if response_is_sse(&response) {
        forward_sse_response(response, sender, last_event_id).await;
    } else {
        forward_response(response, sender).await;
    }
}
/// Outcome of [`validate_origin`].
///
/// `Ok(())` means the request may proceed. `Err` carries a ready-made, boxed
/// `403 Forbidden` response to return instead.
type OriginResult = StdResult<(), Box<Response>>;

/// The single `403 Forbidden` response used for every origin rejection.
fn forbidden_origin() -> Box<Response> {
    Box::new((StatusCode::FORBIDDEN, "Invalid Origin").into_response())
}

/// Same-origin check between the `Origin` and `Host` request headers.
///
/// Requests without an `Origin` header (non-browser clients) are accepted. Otherwise the
/// request is rejected unless all of the following hold:
/// - the origin is a parseable, non-`null` `http`/`https` URL;
/// - its host equals the `Host` header's host, ignoring ASCII case;
/// - the ports agree. If `Host` carries a port, the origin's explicit or scheme-default
///   port must match it. If not, the origin must not carry a non-default port.
fn validate_origin(headers: &HeaderMap) -> OriginResult {
    let Some(origin) = headers.get(header::ORIGIN) else {
        return Ok(());
    };
    let origin_str = origin.to_str().map_err(|_| forbidden_origin())?;
    // Opaque origins (sandboxed iframes, file://, ...) cannot be matched to a host.
    if origin_str == "null" {
        return Err(forbidden_origin());
    }
    let origin_url = Url::parse(origin_str).map_err(|_| forbidden_origin())?;
    if !matches!(origin_url.scheme(), "http" | "https") {
        return Err(forbidden_origin());
    }
    let authority = headers
        .get(header::HOST)
        .and_then(|value| value.to_str().ok())
        .and_then(|host| host.parse::<Authority>().ok())
        .ok_or_else(forbidden_origin)?;
    let origin_host = origin_url.host_str().ok_or_else(forbidden_origin)?;
    if !origin_host.eq_ignore_ascii_case(authority.host()) {
        return Err(forbidden_origin());
    }
    // `Url::port` is `None` when the port is the scheme default, so a default port is
    // treated the same as no port.
    let ports_match = match authority.port_u16() {
        Some(expected_port) => origin_url.port_or_known_default() == Some(expected_port),
        None => origin_url.port().is_none(),
    };
    if ports_match {
        Ok(())
    } else {
        Err(forbidden_origin())
    }
}
/// Parses the `Last-Event-ID` request header as a numeric event id.
///
/// Returns `None` if the header is absent, not visible ASCII, or not a `u64`.
fn parse_last_event_id(headers: &HeaderMap) -> Option<u64> {
    let raw = headers.get("Last-Event-ID")?.to_str().ok()?;
    raw.parse().ok()
}
/// Advances the session's SSE event counter past a client-reported `Last-Event-ID`.
///
/// Does nothing when the header is missing or unparseable.
fn apply_last_event_id(headers: &HeaderMap, session: &HttpSession) {
    let Some(resume_from) = parse_last_event_id(headers) else {
        return;
    };
    session.bump_event_id(resume_from);
}
/// Wraps `message` as an SSE event, using the session's next event id.
///
/// # Panics
/// Panics if `message` fails to serialize to JSON.
fn build_sse_event(session: &HttpSession, message: &JSONRPCMessage) -> Event {
    // Allocate the id before serializing, matching the original evaluation order.
    let event_id = session.next_event_id();
    let payload = serde_json::to_string(message).unwrap();
    Event::default().id(event_id.to_string()).data(payload)
}
impl HttpClientTransport {
pub fn new(endpoint: impl Into<String>) -> Self {
Self {
endpoint: endpoint.into(),
client: HttpClient::builder()
.timeout(DEFAULT_TIMEOUT)
.build()
.expect("Failed to create HTTP client"),
session_id: Arc::new(Mutex::new(None)),
last_event_id: Arc::new(Mutex::new(None)),
sender: None,
receiver: None,
sse_state: Arc::new(HttpSseState::new()),
sse_shutdown: CancellationToken::new(),
oauth_client: None,
}
}
pub fn with_oauth(mut self, oauth_client: Arc<OAuth2Client>) -> Self {
self.oauth_client = Some(oauth_client);
self
}
async fn connect_sse(
client: HttpClient,
endpoint: String,
session_id: Arc<Mutex<Option<String>>>,
last_event_id: Arc<Mutex<Option<String>>>,
sender: mpsc::UnboundedSender<JSONRPCMessage>,
oauth_client: Option<Arc<OAuth2Client>>,
shutdown: CancellationToken,
) -> Result<SseOutcome> {
let Some(session_id_value) = session_id.lock().await.clone() else {
return Ok(SseOutcome::NoSession);
};
let mut headers = HeaderMap::new();
headers.insert(
header::ACCEPT,
HeaderValue::from_static("text/event-stream"),
);
headers.insert(
"MCP-Protocol-Version",
HeaderValue::from_static(LATEST_PROTOCOL_VERSION),
);
headers.insert(
"Mcp-Session-Id",
HeaderValue::from_str(&session_id_value)
.map_err(|_| Error::Transport("Invalid session ID".into()))?,
);
if let Some(last_event_id_value) = last_event_id.lock().await.clone() {
headers.insert(
"Last-Event-ID",
HeaderValue::from_str(&last_event_id_value)
.map_err(|_| Error::Transport("Invalid Last-Event-ID".into()))?,
);
}
if let Some(oauth_client) = &oauth_client {
let token = oauth_client.get_valid_token().await?;
headers.insert(
header::AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {token}"))
.map_err(|_| Error::Transport("Invalid authorization token".into()))?,
);
}
let response = client
.get(&endpoint)
.headers(headers)
.send()
.await
.map_err(|e| Error::Transport(format!("Failed to connect SSE: {e}")))?;
if response.status() == StatusCode::METHOD_NOT_ALLOWED {
return Ok(SseOutcome::NotSupported);
}
if !response.status().is_success() {
return Err(Error::Transport(format!(
"SSE connection failed with status: {}",
response.status()
)));
}
let stream = response.bytes_stream().eventsource();
futures::pin_mut!(stream);
loop {
tokio::select! {
_ = shutdown.cancelled() => break,
event = stream.next() => {
match event {
Some(Ok(event)) => {
let event_id = event.id;
if !event_id.is_empty() {
let mut guard = last_event_id.lock().await;
*guard = Some(event_id);
}
let data = event.data;
if let Ok(msg) = serde_json::from_str::<JSONRPCMessage>(&data)
&& sender.unbounded_send(msg).is_err()
{
break;
}
}
Some(Err(e)) => {
error!("SSE error: {:?}", e);
break;
}
None => break,
}
}
}
}
Ok(SseOutcome::Closed)
}
}
/// Everything `maybe_start_sse` needs to spawn the background SSE listener.
///
/// The fields mirror the arguments of `HttpClientTransport::connect_sse`, plus the shared
/// `sse_state` flags.
struct SseStartContext {
    client: HttpClient,
    endpoint: String,
    session_id: Arc<Mutex<Option<String>>>,
    last_event_id: Arc<Mutex<Option<String>>>,
    sender: mpsc::UnboundedSender<JSONRPCMessage>,
    oauth_client: Option<Arc<OAuth2Client>>,
    sse_state: Arc<HttpSseState>,
    shutdown: CancellationToken,
}
fn maybe_start_sse(context: SseStartContext) {
if !context.sse_state.supported.load(Ordering::SeqCst) {
return;
}
if context.sse_state.running.swap(true, Ordering::SeqCst) {
return;
}
tokio::spawn(async move {
let outcome = HttpClientTransport::connect_sse(
context.client,
context.endpoint,
context.session_id,
context.last_event_id,
context.sender,
context.oauth_client,
context.shutdown,
)
.await;
match outcome {
Ok(SseOutcome::NotSupported) => {
context.sse_state.supported.store(false, Ordering::SeqCst);
}
Ok(_) => {}
Err(err) => {
debug!("SSE connection failed (server may not support it): {}", err);
}
}
context.sse_state.running.store(false, Ordering::SeqCst);
});
}
/// How a call to `HttpClientTransport::connect_sse` ended without error.
#[derive(Debug, Clone, Copy)]
enum SseOutcome {
    /// No session id was known yet, so no GET request was made.
    NoSession,
    /// The server answered the GET with 405 Method Not Allowed.
    NotSupported,
    /// The stream was opened and later stopped for one of these reasons:
    /// - it ended or errored;
    /// - the forwarding channel closed;
    /// - shutdown was requested.
    Closed,
}
#[async_trait]
impl Transport for HttpClientTransport {
async fn connect(&mut self) -> Result<()> {
info!("Connecting to HTTP endpoint: {}", self.endpoint);
let (tx, rx) = mpsc::unbounded();
self.sender = Some(tx);
self.receiver = Some(rx);
Ok(())
}
fn framed(mut self: Box<Self>) -> Result<Box<dyn TransportStream>> {
let sender = self.sender.take().ok_or(Error::TransportDisconnected)?;
let receiver = self.receiver.take().ok_or(Error::TransportDisconnected)?;
let endpoint = self.endpoint.clone();
let client = self.client.clone();
let session_id = self.session_id.clone();
let last_event_id = self.last_event_id.clone();
let oauth_client = self.oauth_client.clone();
let sse_state = self.sse_state.clone();
let sse_shutdown = self.sse_shutdown.clone();
let (http_tx, mut http_rx) = mpsc::unbounded::<JSONRPCMessage>();
let sender_clone = sender;
let http_task = tokio::spawn(async move {
while let Some(msg) = http_rx.next().await {
debug!("HTTP client sending message: {:?}", msg);
let mut headers = HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
HeaderValue::from_static("application/json"),
);
headers.insert(
header::ACCEPT,
HeaderValue::from_static("application/json, text/event-stream"),
);
headers.insert(
"MCP-Protocol-Version",
HeaderValue::from_static(LATEST_PROTOCOL_VERSION),
);
if let Some(ref sid) = *session_id.lock().await {
headers.insert("Mcp-Session-Id", HeaderValue::from_str(sid).unwrap());
}
if let Some(oauth_client) = &oauth_client {
match oauth_client.get_valid_token().await {
Ok(token) => {
headers.insert(
header::AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {token}")).unwrap(),
);
}
Err(e) => {
error!("Failed to get OAuth token: {}", e);
continue;
}
}
}
let is_initialize = matches!(&msg, JSONRPCMessage::Request(req) if req.request.method == "initialize");
match client
.post(&endpoint)
.headers(headers)
.json(&msg)
.send()
.await
{
Ok(response) => {
debug!("HTTP response status: {}", response.status());
update_session_id(is_initialize, response.headers(), &session_id).await;
handle_http_response(&msg, response, &sender_clone, &last_event_id).await;
maybe_start_sse(SseStartContext {
client: client.clone(),
endpoint: endpoint.clone(),
session_id: session_id.clone(),
last_event_id: last_event_id.clone(),
sender: sender_clone.clone(),
oauth_client: oauth_client.clone(),
sse_state: sse_state.clone(),
shutdown: sse_shutdown.clone(),
});
}
Err(e) => {
error!("Failed to send HTTP request to {}: {:?}", endpoint, e);
}
}
}
});
Ok(Box::new(HttpTransportStream {
sender: http_tx,
receiver,
_http_task: http_task,
sse_shutdown: self.sse_shutdown.clone(),
}))
}
fn remote_addr(&self) -> String {
self.endpoint.clone()
}
}
impl HttpServerTransport {
    /// Creates an unstarted server transport that will bind to `bind_addr`
    /// (e.g. `"127.0.0.1:8080"`) when `start` runs.
    pub fn new(bind_addr: impl Into<String>) -> Self {
        Self {
            bind_addr: bind_addr.into(),
            router: None,
            state: None,
            server_handle: None,
            incoming_rx: None,
            shutdown_token: None,
        }
    }
    /// Binds the listener and spawns both the axum server and the session-expiry task.
    ///
    /// Returns `Ok(())` immediately if the server task already exists. After binding,
    /// `bind_addr` is replaced with the listener's actual local address.
    ///
    /// # Errors
    /// Returns `Error::Transport` in any of these cases:
    /// - binding fails;
    /// - the local address can't be read;
    /// - the server task drops its ready signal.
    pub async fn start(&mut self) -> Result<()> {
        if self.server_handle.is_some() {
            return Ok(());
        }
        let (incoming_tx, incoming_rx) = mpsc::unbounded();
        self.incoming_rx = Some(incoming_rx);
        let state = HttpServerState {
            sessions: Arc::new(DashMap::new()),
            incoming_tx,
            shutdown: CancellationToken::new(),
        };
        self.state = Some(state.clone());
        // Clones of a CancellationToken share state: cancelling this one (e.g. in Drop)
        // stops both the server and the cleanup task below.
        self.shutdown_token = Some(state.shutdown.clone());
        let router = Router::new()
            .route("/", post(handle_post))
            .route("/", get(handle_get))
            .layer(CorsLayer::permissive())
            .with_state(state.clone());
        self.router = Some(router.clone());
        let listener = TcpListener::bind(&self.bind_addr).await.map_err(|e| {
            Error::Transport(format!("Failed to bind to {}: {}", self.bind_addr, e))
        })?;
        self.bind_addr = listener
            .local_addr()
            .map_err(|e| Error::Transport(format!("Failed to get local address: {e}")))?
            .to_string();
        let bind_addr = self.bind_addr.clone();
        let shutdown = state.shutdown.clone();
        let (ready_tx, ready_rx) = oneshot::channel();
        let bind_addr_clone = bind_addr.clone();
        let cleanup_sessions = state.sessions.clone();
        let cleanup_shutdown = state.shutdown.clone();
        // Session-expiry task: every 60 s, drop sessions idle longer than SESSION_TIMEOUT.
        tokio::spawn(async move {
            let mut interval = interval(Duration::from_secs(60));
            loop {
                tokio::select! {
                    _ = cleanup_shutdown.cancelled() => break,
                    _ = interval.tick() => {
                        let now = Instant::now();
                        // Collect ids first and remove afterwards: calling `remove` while a
                        // DashMap iterator still holds shard locks can deadlock.
                        let expired: Vec<String> = cleanup_sessions
                            .iter()
                            .filter(|entry| {
                                let last_active = *entry.value().last_activity.lock().unwrap();
                                now.duration_since(last_active) > SESSION_TIMEOUT
                            })
                            .map(|entry| entry.key().clone())
                            .collect();
                        for id in expired {
                            debug!("Removing expired session: {}", id);
                            cleanup_sessions.remove(&id);
                        }
                    }
                }
            }
        });
        let server_handle = tokio::spawn(async move {
            info!("HTTP server starting on {}", bind_addr);
            // Signalled before `axum::serve` is polled; the listener is already bound, so
            // connections queue in the backlog until serving begins.
            ready_tx.send(()).ok();
            axum::serve(listener, router)
                .with_graceful_shutdown(async move {
                    shutdown.cancelled().await;
                })
                .await
                .map_err(|e| Error::Transport(format!("Server error: {e}")))
        });
        self.server_handle = Some(server_handle);
        ready_rx
            .await
            .map_err(|_| Error::Transport("Server failed to start".into()))?;
        // Best-effort grace period so the serve loop is running before callers connect.
        sleep(Duration::from_millis(100)).await;
        info!("HTTP server ready on {}", bind_addr_clone);
        Ok(())
    }
}
#[async_trait]
impl Transport for HttpServerTransport {
    /// Starts the HTTP server (delegates to `HttpServerTransport::start`).
    async fn connect(&mut self) -> Result<()> {
        self.start().await
    }
    /// Converts the started server into a message stream/sink.
    ///
    /// # Errors
    /// `Error::TransportDisconnected` if `start` has not run (no incoming receiver).
    fn framed(mut self: Box<Self>) -> Result<Box<dyn TransportStream>> {
        // Taking the token means `Drop for HttpServerTransport` will not cancel the server
        // when `self` is dropped at the end of this call; dropping a CancellationToken
        // does not cancel it.
        let _shutdown_token = self.shutdown_token.take();
        let incoming_rx = self
            .incoming_rx
            .take()
            .ok_or(Error::TransportDisconnected)?;
        let session_id = Uuid::new_v4().to_string();
        let (tx, rx) = mpsc::unbounded();
        // NOTE(review): this registers a placeholder session under a random id.
        // - No handler visible here ever reads its receiver.
        // - `single_session_id` only resolves when exactly one session exists, so this
        //   extra entry disables that fallback once a real client initializes.
        // Confirm the placeholder is intended.
        if let Some(state) = &self.state {
            let session = HttpSession {
                last_activity: Arc::new(sync::Mutex::new(Instant::now())),
                sender: tx,
                receiver: Arc::new(Mutex::new(rx)),
                event_counter: Arc::new(AtomicU64::new(0)),
                streaming: Arc::new(AtomicBool::new(false)),
            };
            state.sessions.insert(session_id, session);
        }
        let stream = HttpServerStream {
            incoming_rx,
            state: self.state.clone(),
            request_sessions: Arc::new(DashMap::new()),
        };
        Ok(Box::new(stream))
    }
    /// The bound local address (the actual one once `start` has run).
    fn remote_addr(&self) -> String {
        self.bind_addr.clone()
    }
}
/// Server-side stream/sink returned by `HttpServerTransport::framed`.
struct HttpServerStream {
    /// Client messages from the handlers, tagged with their session id.
    incoming_rx: mpsc::UnboundedReceiver<(JSONRPCMessage, String)>,
    /// Shared handler state used to look up session channels when sending.
    state: Option<HttpServerState>,
    /// Session that issued each pending request, used to route the matching response.
    request_sessions: Arc<DashMap<RequestId, String>>,
}
impl Stream for HttpServerStream {
    type Item = Result<JSONRPCMessage>;

    /// Yields the next client message.
    ///
    /// Before yielding a request, it records which session the request came from, so
    /// the response can be routed back.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let Poll::Ready(next) = self.incoming_rx.poll_next_unpin(cx) else {
            return Poll::Pending;
        };
        let Some((msg, origin_session)) = next else {
            return Poll::Ready(None);
        };
        if let JSONRPCMessage::Request(req) = &msg {
            self.request_sessions.insert(req.id.clone(), origin_session);
        }
        Poll::Ready(Some(Ok(msg)))
    }
}
impl Sink<JSONRPCMessage> for HttpServerStream {
    type Error = Error;
    // Session channels are unbounded, so the sink is always ready and never buffers.
    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
        Poll::Ready(Ok(()))
    }
    /// Routes an outgoing message to the owning session's channel.
    ///
    /// - Responses go to the session that sent the matching request (recorded in
    ///   `poll_next`). The mapping is removed, so each request id is routed once.
    /// - Requests and notifications go to the session named in `_meta.sessionId`, or
    ///   to the only session if exactly one exists.
    ///
    /// Unroutable messages, and sends to closed session channels, are dropped. This
    /// method always returns `Ok(())`.
    fn start_send(self: Pin<&mut Self>, item: JSONRPCMessage) -> Result<()> {
        if let Some(state) = &self.state {
            match &item {
                JSONRPCMessage::Response(resp) => {
                    // Error responses may carry no id; those cannot be routed.
                    let response_id = match resp {
                        JSONRPCResponse::Result(result) => Some(result.id.clone()),
                        JSONRPCResponse::Error(error) => error.id.clone(),
                    };
                    if let Some(response_id) = response_id
                        && let Some((_, session_id)) = self.request_sessions.remove(&response_id)
                        && let Some(session) = state.sessions.get(&session_id)
                    {
                        session.sender.unbounded_send(item).ok();
                    }
                }
                JSONRPCMessage::Notification(_) | JSONRPCMessage::Request(_) => {
                    if let Some(session_id) = resolve_session_id_for_message(&item, state) {
                        if let Some(session) = state.sessions.get(&session_id) {
                            session.sender.unbounded_send(item).ok();
                        }
                    } else {
                        debug!("Dropping HTTP message without session context");
                    }
                }
            }
        }
        Ok(())
    }
    // Nothing is buffered locally, so flush and close are no-ops.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
        Poll::Ready(Ok(()))
    }
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
        Poll::Ready(Ok(()))
    }
}
impl TransportStream for HttpServerStream {}
/// Chooses the session an outgoing server message should be delivered to.
///
/// An explicit `_meta.sessionId` on a request or notification wins. Otherwise the message
/// falls back to the only session, if exactly one exists. Responses carry no session
/// metadata, so they always take the fallback.
fn resolve_session_id_for_message(
    message: &JSONRPCMessage,
    state: &HttpServerState,
) -> Option<String> {
    let explicit = match message {
        JSONRPCMessage::Request(request) => session_id_from_request(request),
        JSONRPCMessage::Notification(notification) => session_id_from_notification(notification),
        JSONRPCMessage::Response(_) => None,
    };
    if explicit.is_some() {
        return explicit;
    }
    single_session_id(state)
}
/// Returns the id of the only registered session, or `None` if there are zero or several.
fn single_session_id(state: &HttpServerState) -> Option<String> {
    if state.sessions.len() != 1 {
        return None;
    }
    state
        .sessions
        .iter()
        .next()
        .map(|only| only.key().clone())
}
/// Extracts `params._meta.sessionId` from a request, if present and a string.
fn session_id_from_request(request: &JSONRPCRequest) -> Option<String> {
    let params = request.request.params.as_ref()?;
    let meta = params._meta.as_ref()?;
    session_id_from_meta(Some(&meta.other))
}
/// Extracts `params._meta.sessionId` from a notification, if present and a string.
fn session_id_from_notification(notification: &JSONRPCNotification) -> Option<String> {
    let params = notification.notification.params.as_ref()?;
    session_id_from_meta(params._meta.as_ref())
}
fn session_id_from_meta(meta: Option<&HashMap<String, Value>>) -> Option<String> {
meta.and_then(|map| map.get("sessionId"))
.and_then(Value::as_str)
.map(|value| value.to_string())
}
impl Drop for HttpServerTransport {
    /// Stops the server and its cleanup task, unless `framed` has already taken the token.
    fn drop(&mut self) {
        let Some(token) = self.shutdown_token.as_ref() else {
            return;
        };
        token.cancel();
    }
}
/// Handles `POST /`: accepts one JSON-RPC message from a client.
///
/// Every request is first checked:
/// - a foreign `Origin` is rejected with 403 (`validate_origin`);
/// - a mismatched `MCP-Protocol-Version` header is rejected with 400.
///
/// `initialize` requests:
/// - create (or replace) a session and forward the message;
/// - wait up to 5 s for the first outbound message on that session;
/// - return that message as JSON with an `Mcp-Session-Id` header.
///
/// All other messages need a known `Mcp-Session-Id` (400 if missing, 404 if unknown).
/// Looking up the session refreshes its `last_activity`.
///
/// Requests are then answered in one of three ways:
/// - with 202, if a standalone GET stream is open for the session;
/// - as an SSE stream that ends after the first response, if the client accepts
///   `text/event-stream`;
/// - otherwise as one JSON body, after waiting up to 30 s.
///
/// Responses and notifications are forwarded and acknowledged with 202.
async fn handle_post(
    State(state): State<HttpServerState>,
    headers: HeaderMap,
    Json(message): Json<JSONRPCMessage>,
) -> Response {
    debug!("HTTP server received POST request: {:?}", message);
    if let Err(response) = validate_origin(&headers) {
        return *response;
    }
    if let Some(version) = headers.get("MCP-Protocol-Version")
        && version != LATEST_PROTOCOL_VERSION
    {
        return (StatusCode::BAD_REQUEST, "Unsupported protocol version").into_response();
    }
    let session_id = headers
        .get("Mcp-Session-Id")
        .and_then(|v| v.to_str().ok())
        .map(|s| s.to_string());
    if matches!(&message, JSONRPCMessage::Request(req) if req.request.method == "initialize") {
        // NOTE(review): a client-supplied Mcp-Session-Id is reused as-is here, and `insert`
        // below overwrites any existing session with that id. Consider always generating a
        // fresh id on initialize.
        let new_session_id = session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
        let (tx, rx) = mpsc::unbounded();
        let session = HttpSession {
            last_activity: Arc::new(sync::Mutex::new(Instant::now())),
            sender: tx,
            receiver: Arc::new(Mutex::new(rx)),
            event_counter: Arc::new(AtomicU64::new(0)),
            streaming: Arc::new(AtomicBool::new(false)),
        };
        state
            .sessions
            .insert(new_session_id.clone(), session.clone());
        state
            .incoming_tx
            .unbounded_send((message, new_session_id.clone()))
            .ok();
        let receiver = session.receiver.clone();
        // The first message delivered to the new session is assumed to be the
        // initialize response.
        let response = timeout(Duration::from_secs(5), async move {
            let mut receiver = receiver.lock().await;
            receiver.next().await
        })
        .await;
        match response {
            Ok(Some(response)) => {
                let mut http_response = Json::<JSONRPCMessage>(response).into_response();
                // The id came from a visible-ASCII header or a UUID, so it is a valid value.
                http_response.headers_mut().insert(
                    "Mcp-Session-Id",
                    HeaderValue::from_str(&new_session_id).unwrap(),
                );
                return http_response;
            }
            Ok(None) => {
                return (StatusCode::INTERNAL_SERVER_ERROR, "No response from server")
                    .into_response();
            }
            Err(_) => {
                return (StatusCode::REQUEST_TIMEOUT, "Initialization timeout").into_response();
            }
        }
    }
    let session_id = match session_id {
        Some(id) => id,
        None => return (StatusCode::BAD_REQUEST, "Missing session ID").into_response(),
    };
    let session = if let Some(session) = state.sessions.get(&session_id) {
        *session.last_activity.lock().unwrap() = Instant::now();
        session.clone()
    } else {
        return (StatusCode::NOT_FOUND, "Session not found").into_response();
    };
    match &message {
        JSONRPCMessage::Request(_) => {
            state
                .incoming_tx
                .unbounded_send((message, session_id.clone()))
                .ok();
            // An open GET stream owns the session receiver, so the reply is delivered there.
            if session.streaming.load(Ordering::SeqCst) {
                return StatusCode::ACCEPTED.into_response();
            }
            let accepts_sse = headers
                .get(header::ACCEPT)
                .and_then(|v| v.to_str().ok())
                .map(|v| v.contains("text/event-stream"))
                .unwrap_or(false);
            // NOTE(review): all pending requests on a session share this one receiver.
            // With concurrent requests or interleaved server messages, the message returned
            // (or the Response that ends the SSE stream) may belong to a different request.
            if accepts_sse {
                apply_last_event_id(&headers, &session);
                let receiver = session.receiver.clone();
                let stream = async_stream::stream! {
                    let mut receiver = receiver.lock().await;
                    while let Some(msg) = receiver.next().await {
                        yield Ok::<_, Error>(build_sse_event(&session, &msg));
                        if matches!(&msg, JSONRPCMessage::Response(_)) {
                            break;
                        }
                    }
                };
                Sse::new(stream)
                    .keep_alive(KeepAlive::default())
                    .into_response()
            } else {
                let receiver = session.receiver.clone();
                let response = timeout(
                    Duration::from_secs(30), async move {
                        let mut receiver = receiver.lock().await;
                        receiver.next().await
                    },
                )
                .await;
                match response {
                    Ok(Some(response)) => Json::<JSONRPCMessage>(response).into_response(),
                    Ok(None) => (StatusCode::INTERNAL_SERVER_ERROR, "No response").into_response(),
                    Err(_) => (StatusCode::REQUEST_TIMEOUT, "Request timeout").into_response(),
                }
            }
        }
        JSONRPCMessage::Response(_) | JSONRPCMessage::Notification(_) => {
            state.incoming_tx.unbounded_send((message, session_id)).ok();
            StatusCode::ACCEPTED.into_response()
        }
    }
}
async fn handle_get(State(state): State<HttpServerState>, headers: HeaderMap) -> Response {
if let Err(response) = validate_origin(&headers) {
return *response;
}
if let Some(version) = headers.get("MCP-Protocol-Version")
&& version != LATEST_PROTOCOL_VERSION
{
return (StatusCode::BAD_REQUEST, "Unsupported protocol version").into_response();
}
let session_id = headers
.get("Mcp-Session-Id")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let session_id = match session_id {
Some(id) => id,
None => return (StatusCode::BAD_REQUEST, "Missing session ID").into_response(),
};
let session = if let Some(session) = state.sessions.get(&session_id) {
apply_last_event_id(&headers, &session);
session.clone()
} else {
return (StatusCode::NOT_FOUND, "Session not found").into_response();
};
let receiver = session.receiver.clone();
let streaming = session.streaming.clone();
let stream = async_stream::stream! {
let _guard = StreamingGuard::new(streaming);
let mut receiver = receiver.lock().await;
while let Some(msg) = receiver.next().await {
yield Ok::<_, Error>(build_sse_event(&session, &msg));
}
};
Sse::new(stream)
.keep_alive(KeepAlive::default())
.into_response()
}
impl Stream for HttpTransportStream {
    type Item = Result<JSONRPCMessage>;

    /// Yields messages forwarded by the POST task and the SSE listener; never yields `Err`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.receiver
            .poll_next_unpin(cx)
            .map(|next| next.map(Ok))
    }
}
impl Sink<JSONRPCMessage> for HttpTransportStream {
type Error = Error;
fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
Poll::Ready(Ok(()))
}
fn start_send(self: Pin<&mut Self>, item: JSONRPCMessage) -> Result<()> {
self.sender
.unbounded_send(item)
.map_err(|_| Error::ConnectionClosed)
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
Poll::Ready(Ok(()))
}
}
impl TransportStream for HttpTransportStream {}
#[cfg(test)]
mod tests {
    use futures::SinkExt;

    use super::*;
    use crate::schema::{JSONRPCNotification, Notification};

    /// Builds an idle session with a zeroed event counter and no open stream.
    fn fresh_session() -> HttpSession {
        let (tx, rx) = mpsc::unbounded();
        HttpSession {
            last_activity: Arc::new(sync::Mutex::new(Instant::now())),
            sender: tx,
            receiver: Arc::new(Mutex::new(rx)),
            event_counter: Arc::new(AtomicU64::new(0)),
            streaming: Arc::new(AtomicBool::new(false)),
        }
    }

    #[tokio::test]
    async fn test_http_client_transport_creation() {
        let client = HttpClientTransport::new("http://localhost:8080");
        assert_eq!(client.endpoint, "http://localhost:8080");
    }

    #[tokio::test]
    async fn test_http_server_transport_creation() {
        let server = HttpServerTransport::new("127.0.0.1:8080");
        assert_eq!(server.bind_addr, "127.0.0.1:8080");
    }

    #[tokio::test]
    async fn test_session_management() {
        let session = fresh_session();
        let before = *session.last_activity.lock().unwrap();
        sleep(Duration::from_millis(10)).await;
        *session.last_activity.lock().unwrap() = Instant::now();
        let after = *session.last_activity.lock().unwrap();
        assert!(after > before);
    }

    #[tokio::test]
    async fn test_http_transport_stream() {
        let (a_to_b_tx, a_to_b_rx) = mpsc::unbounded();
        let (b_to_a_tx, b_to_a_rx) = mpsc::unbounded();
        let mut side_a = HttpTransportStream {
            sender: a_to_b_tx,
            receiver: b_to_a_rx,
            _http_task: tokio::spawn(async {}),
            sse_shutdown: CancellationToken::new(),
        };
        let mut side_b = HttpTransportStream {
            sender: b_to_a_tx,
            receiver: a_to_b_rx,
            _http_task: tokio::spawn(async {}),
            sse_shutdown: CancellationToken::new(),
        };
        let sent = JSONRPCMessage::Notification(JSONRPCNotification {
            jsonrpc: "2.0".to_string(),
            notification: Notification {
                method: "test".to_string(),
                params: None,
            },
        });
        side_a.send(sent.clone()).await.unwrap();
        let received = side_b.next().await.unwrap().unwrap();
        let (JSONRPCMessage::Notification(expected), JSONRPCMessage::Notification(actual)) =
            (sent, received)
        else {
            panic!("Message type mismatch");
        };
        assert_eq!(expected.notification.method, actual.notification.method);
    }

    #[test]
    fn test_validate_origin_allows_same_host() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::ORIGIN,
            HeaderValue::from_static("http://localhost:8080"),
        );
        headers.insert(header::HOST, HeaderValue::from_static("localhost:8080"));
        assert!(validate_origin(&headers).is_ok());
    }

    #[test]
    fn test_validate_origin_rejects_mismatch() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::ORIGIN,
            HeaderValue::from_static("http://example.com"),
        );
        headers.insert(header::HOST, HeaderValue::from_static("localhost:8080"));
        let rejection = validate_origin(&headers).unwrap_err();
        assert_eq!(rejection.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn test_apply_last_event_id_advances_counter() {
        let session = fresh_session();
        let mut headers = HeaderMap::new();
        headers.insert("Last-Event-ID", HeaderValue::from_static("5"));
        apply_last_event_id(&headers, &session);
        assert_eq!(session.next_event_id(), 6);
    }
}