use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::{Arc, Mutex};
use anyhow::Result;
use axum::extract::{Query, State};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse;
use axum::routing::{get, post};
use axum::{Json, Router};
use futures_util::stream::Stream;
use serde_json::{json, Value};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
use tower_http::cors::CorsLayer;
use uuid::Uuid;
use crate::server::stdio::process_message;
use crate::server::NexusServer;
/// Sending half of one SSE session's outbound channel.
type SessionSender = mpsc::UnboundedSender<String>;

/// Registry of live SSE sessions: session id -> sender feeding that session's stream.
type SenderMap = Arc<Mutex<HashMap<String, SessionSender>>>;

/// State shared (by cheap `Arc` clones) across all HTTP handlers.
#[derive(Clone)]
struct AppState {
    /// Server instance passed to `process_message` for every incoming message.
    server: Arc<NexusServer>,
    /// Sessions opened via `GET /sse`, looked up by `POST /messages`.
    sessions: SenderMap,
}
/// Turns a session's outbound channel into a stream of SSE events.
///
/// Every string received on `rx` is emitted verbatim as the data of an event
/// named `message`. The stream never yields an error.
fn sse_stream(
    rx: mpsc::UnboundedReceiver<String>,
) -> impl Stream<Item = Result<Event, Infallible>> {
    let to_event = |payload: String| -> Result<Event, Infallible> {
        Ok(Event::default().event("message").data(payload))
    };
    UnboundedReceiverStream::new(rx).map(to_event)
}
async fn sse_handler(State(state): State<AppState>) -> impl IntoResponse {
let session_id = Uuid::new_v4().to_string();
let (tx, rx) = mpsc::unbounded_channel::<String>();
state
.sessions
.lock()
.unwrap()
.insert(session_id.clone(), tx);
let endpoint_event = Ok::<_, Infallible>(
Event::default()
.event("endpoint")
.data(format!("/messages?sessionId={session_id}")),
);
let stream = futures_util::stream::once(async move { endpoint_event }).chain(sse_stream(rx));
Sse::new(stream).keep_alive(KeepAlive::default())
}
async fn messages_handler(
State(state): State<AppState>,
Query(params): Query<HashMap<String, String>>,
body: String,
) -> impl IntoResponse {
let session_id = match params.get("sessionId") {
Some(id) => id.clone(),
None => {
return Json(json!({ "error": "missing sessionId query parameter" })).into_response()
}
};
let response = match process_message(&state.server, body.trim()) {
Some(r) => r,
None => {
return axum::http::StatusCode::ACCEPTED.into_response();
}
};
{
let sessions = state.sessions.lock().unwrap();
if let Some(tx) = sessions.get(&session_id) {
let _ = tx.send(response);
}
}
axum::http::StatusCode::ACCEPTED.into_response()
}
async fn health_handler() -> Json<Value> {
Json(json!({
"ok": true,
"transport": "http-sse",
"version": env!("CARGO_PKG_VERSION")
}))
}
/// Serves `server` over HTTP + Server-Sent Events on `0.0.0.0:{port}`.
///
/// Routes:
/// - `GET /sse` opens a session and streams its responses,
/// - `POST /messages?sessionId=<id>` submits a JSON-RPC message for a session,
/// - `GET /health` reports liveness.
///
/// Runs until `shutdown_signal` resolves, then flushes `atlas::vault_store`.
///
/// # Errors
/// Returns an error if the port cannot be bound or the server fails. In the
/// latter case the vault store is still flushed before the error is returned.
pub async fn run_http(server: NexusServer, port: u16) -> Result<()> {
    let state = AppState {
        server: Arc::new(server),
        sessions: Arc::new(Mutex::new(HashMap::new())),
    };
    // NOTE(review): binding all interfaces combined with `CorsLayer::permissive()`
    // lets any origin on any reachable host drive this server. Consider
    // 127.0.0.1 and a restricted origin list unless remote access is intended.
    let app = Router::new()
        .route("/sse", get(sse_handler))
        .route("/messages", post(messages_handler))
        .route("/health", get(health_handler))
        .layer(CorsLayer::permissive())
        .with_state(state);
    let listener = tokio::net::TcpListener::bind(("0.0.0.0", port)).await?;
    eprintln!(
        "bctx mcp: HTTP/SSE on http://0.0.0.0:{port} \
         (SSE: GET /sse requests: POST /messages health: GET /health)"
    );
    let served = axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await;
    // Flush on every exit path. Propagating the serve error with `?` first
    // would skip this and lose buffered vault data.
    atlas::vault_store::flush();
    served?;
    Ok(())
}
async fn shutdown_signal() {
tokio::signal::ctrl_c()
.await
.expect("failed to install Ctrl-C handler");
eprintln!("\nbctx mcp: shutting down");
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty session registry.
    fn empty_registry() -> SenderMap {
        Arc::new(Mutex::new(HashMap::new()))
    }

    #[test]
    fn app_state_is_clone() {
        let original = AppState {
            server: Arc::new(NexusServer::default()),
            sessions: empty_registry(),
        };
        let _duplicate = original.clone();
    }

    #[test]
    fn session_registry_insert_and_lookup() {
        let registry = empty_registry();
        let (sender, _receiver) = mpsc::unbounded_channel::<String>();
        registry
            .lock()
            .unwrap()
            .insert(String::from("session-abc"), sender);
        let map = registry.lock().unwrap();
        assert!(map.contains_key("session-abc"));
        assert!(!map.contains_key("session-xyz"));
    }

    #[test]
    fn unknown_session_is_a_noop() {
        let registry = empty_registry();
        let delivered = registry
            .lock()
            .unwrap()
            .get("nonexistent")
            .map(|sender| sender.send(String::from("msg")));
        assert!(delivered.is_none());
    }

    #[tokio::test]
    async fn ping_roundtrip_via_process_message() {
        let server = NexusServer::default();
        let request = r#"{"jsonrpc":"2.0","id":1,"method":"ping","params":{}}"#;
        let raw = process_message(&server, request).unwrap();
        let reply: serde_json::Value = serde_json::from_str(&raw).unwrap();
        assert_eq!(reply["jsonrpc"], "2.0");
        assert_eq!(reply["id"], 1);
        assert!(reply["result"].is_object());
    }

    #[tokio::test]
    async fn notification_returns_none() {
        let server = NexusServer::default();
        let outcome = process_message(
            &server,
            r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#,
        );
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn session_channel_delivers_message() {
        let registry = empty_registry();
        let (sender, mut receiver) = mpsc::unbounded_channel::<String>();
        registry.lock().unwrap().insert(String::from("s1"), sender);
        let _ = registry.lock().unwrap()["s1"].send(String::from("hello"));
        assert_eq!(receiver.recv().await.as_deref(), Some("hello"));
    }
}
#[cfg(test)]
mod integration {
    use super::*;
    use reqwest::Client;
    use std::time::Duration;
    use tokio::sync::oneshot;
    use tokio::time::timeout;

    /// Upper bound for every wait on the SSE stream in these tests.
    const WAIT: Duration = Duration::from_secs(2);

    /// Starts the HTTP/SSE app on an ephemeral localhost port and returns that port.
    async fn spawn_server() -> u16 {
        let state = AppState {
            server: Arc::new(NexusServer::default()),
            sessions: Arc::new(Mutex::new(HashMap::new())),
        };
        let app = Router::new()
            .route("/sse", get(sse_handler))
            .route("/messages", post(messages_handler))
            .route("/health", get(health_handler))
            .layer(CorsLayer::permissive())
            .with_state(state);
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        tokio::spawn(async move { axum::serve(listener, app).await.unwrap() });
        port
    }

    /// Connects to `GET /sse` and parses events in a background task.
    ///
    /// Returns two receivers. The first yields the session id announced by the
    /// `endpoint` event. The second yields the JSON payload of the first
    /// `message` event that arrives after it; the task then exits.
    fn spawn_sse_reader(
        client: Client,
        port: u16,
    ) -> (oneshot::Receiver<String>, oneshot::Receiver<serde_json::Value>) {
        let (endpoint_tx, endpoint_rx) = oneshot::channel::<String>();
        let (response_tx, response_rx) = oneshot::channel::<serde_json::Value>();
        tokio::spawn(async move {
            let mut resp = client
                .get(format!("http://127.0.0.1:{port}/sse"))
                .send()
                .await
                .unwrap();
            let mut endpoint_tx = Some(endpoint_tx);
            let mut response_tx = Some(response_tx);
            // Buffer raw bytes so a UTF-8 sequence split across chunks is not mangled.
            let mut buf: Vec<u8> = Vec::new();
            while let Some(chunk) = resp.chunk().await.unwrap() {
                buf.extend_from_slice(&chunk);
                // SSE events are terminated by a blank line.
                while let Some(end) = buf.windows(2).position(|w| w == b"\n\n") {
                    let raw: Vec<u8> = buf.drain(..end + 2).collect();
                    let event = String::from_utf8_lossy(&raw[..end]);
                    let Some(data) = event.lines().find_map(|l| l.strip_prefix("data: "))
                    else {
                        continue;
                    };
                    if endpoint_tx.is_some() && event.contains("event: endpoint") {
                        if let Some(sid) = data.split("sessionId=").nth(1) {
                            if let Some(tx) = endpoint_tx.take() {
                                let _ = tx.send(sid.to_string());
                            }
                        }
                    }
                    if endpoint_tx.is_none() && event.contains("event: message") {
                        if let Ok(value) = serde_json::from_str::<serde_json::Value>(data) {
                            if let Some(tx) = response_tx.take() {
                                let _ = tx.send(value);
                            }
                            return;
                        }
                    }
                }
            }
        });
        (endpoint_rx, response_rx)
    }

    /// Opens an SSE session and waits for its id. Returns the id plus the
    /// receiver for the first response delivered on the stream.
    async fn open_session(
        client: &Client,
        port: u16,
    ) -> (String, oneshot::Receiver<serde_json::Value>) {
        let (endpoint_rx, response_rx) = spawn_sse_reader(client.clone(), port);
        let session_id = timeout(WAIT, endpoint_rx)
            .await
            .expect("timed out waiting for SSE endpoint event")
            .unwrap();
        (session_id, response_rx)
    }

    /// POSTs a JSON-RPC `body` to `/messages` for `session_id`.
    async fn post_message(
        client: &Client,
        port: u16,
        session_id: &str,
        body: &'static str,
    ) -> reqwest::Response {
        client
            .post(format!(
                "http://127.0.0.1:{port}/messages?sessionId={session_id}"
            ))
            .header("Content-Type", "application/json")
            .body(body)
            .send()
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn health_endpoint_returns_ok() {
        let port = spawn_server().await;
        let resp = Client::new()
            .get(format!("http://127.0.0.1:{port}/health"))
            .send()
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
        let body: serde_json::Value = resp.json().await.unwrap();
        assert_eq!(body["ok"], true);
        assert_eq!(body["transport"], "http-sse");
        assert!(body["version"].is_string());
    }

    #[tokio::test]
    async fn messages_without_session_id_returns_error_body() {
        let port = spawn_server().await;
        let resp = Client::new()
            .post(format!("http://127.0.0.1:{port}/messages"))
            .header("Content-Type", "application/json")
            .body(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#)
            .send()
            .await
            .unwrap();
        let body: serde_json::Value = resp.json().await.unwrap();
        assert!(
            body["error"].is_string(),
            "expected error field, got: {body}"
        );
    }

    #[tokio::test]
    async fn messages_unknown_session_accepted_silently() {
        let port = spawn_server().await;
        let resp = post_message(
            &Client::new(),
            port,
            "does-not-exist",
            r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#,
        )
        .await;
        assert_eq!(resp.status(), 202);
    }

    #[tokio::test]
    async fn tools_list_returned_via_sse() {
        let port = spawn_server().await;
        let client = Client::new();
        let (session_id, response_rx) = open_session(&client, port).await;
        let resp = post_message(
            &client,
            port,
            &session_id,
            r#"{"jsonrpc":"2.0","id":99,"method":"tools/list"}"#,
        )
        .await;
        assert_eq!(resp.status(), 202);
        let response = timeout(WAIT, response_rx)
            .await
            .expect("timed out waiting for SSE response")
            .unwrap();
        assert_eq!(response["jsonrpc"], "2.0");
        assert_eq!(response["id"], 99);
        let tools = response["result"]["tools"].as_array().unwrap();
        assert!(tools.len() >= 36, "expected ≥36 tools, got {}", tools.len());
    }

    #[tokio::test]
    async fn ping_end_to_end_via_sse() {
        let port = spawn_server().await;
        let client = Client::new();
        let (session_id, response_rx) = open_session(&client, port).await;
        let resp = post_message(
            &client,
            port,
            &session_id,
            r#"{"jsonrpc":"2.0","id":7,"method":"ping"}"#,
        )
        .await;
        assert_eq!(resp.status(), 202);
        let response = timeout(WAIT, response_rx)
            .await
            .expect("timed out waiting for ping response via SSE")
            .unwrap();
        assert_eq!(response["jsonrpc"], "2.0");
        assert_eq!(response["id"], 7);
        assert!(response["result"].is_object());
        assert!(response.get("error").is_none());
    }
}