use std::convert::Infallible;
use std::sync::Arc;
use axum::{
extract::{Query, State},
http::{HeaderMap, StatusCode},
response::{
sse::{Event, KeepAlive, Sse},
IntoResponse,
},
routing::{get, post},
Json, Router,
};
use serde::Deserialize;
use serde_json::json;
use tokio_stream::wrappers::BroadcastStream;
use tokio_stream::{Stream, StreamExt};
use tower_http::cors::{Any, CorsLayer};
use super::protocol::{McpHandler, McpRequest, McpResponse};
use crate::realtime::{EventType, RealtimeEvent, RealtimeManager};
/// Shared state cloned into every axum handler via `State`.
#[derive(Clone)]
struct AppState {
    /// Dispatches JSON-RPC requests received on `POST /mcp`.
    handler: Arc<dyn McpHandler>,
    /// When `Some`, `/mcp` and `/v1/events` require `Authorization: Bearer <key>`;
    /// when `None`, no authentication is performed.
    api_key: Option<String>,
    /// Event source for `GET /v1/events`; `None` makes that endpoint answer 503.
    realtime: Option<RealtimeManager>,
}
/// Handles one JSON-RPC message posted to `POST /mcp`.
///
/// - If an API key is configured and the bearer token does not match, replies
///   with HTTP 401 and a JSON-RPC error (code `-32000`, message `"Unauthorized"`).
/// - Messages without an `id` are notifications: they are still dispatched to the
///   handler (for its side effects), its response is discarded, and HTTP 202
///   Accepted is returned with an empty body.
/// - Otherwise the handler's response is serialized and returned with HTTP 200.
///
/// If serialization fails, the body falls back to JSON `null`.
async fn handle_mcp(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(request): Json<McpRequest>,
) -> impl IntoResponse {
    if let Some(ref expected) = state.api_key {
        if !check_bearer(&headers, expected) {
            let err = McpResponse::error(request.id, -32000, "Unauthorized".to_string());
            return (
                StatusCode::UNAUTHORIZED,
                Json(serde_json::to_value(err).unwrap_or_default()),
            )
                .into_response();
        }
    }
    let is_notification = request.id.is_none();
    let response = state.handler.handle_request(request);
    if is_notification {
        // Notifications have no JSON-RPC reply; the Streamable HTTP transport
        // expects 202 Accepted with no body, not a JSON `null`.
        return StatusCode::ACCEPTED.into_response();
    }
    (
        StatusCode::OK,
        Json(serde_json::to_value(response).unwrap_or_default()),
    )
        .into_response()
}
async fn handle_health() -> impl IntoResponse {
Json(json!({
"status": "ok",
"version": env!("CARGO_PKG_VERSION"),
"protocol": "2025-11-25"
}))
}
/// Query-string parameters accepted by `GET /v1/events`.
#[derive(Debug, Clone, Deserialize)]
struct EventsQuery {
    /// Comma-separated event type names (e.g. `memory_created,sync_failed`);
    /// names are trimmed and unknown ones ignored by `parsed_event_types`.
    event_types: Option<String>,
    /// When set, only events whose `data.workspace` string equals this value pass.
    workspace: Option<String>,
}
impl EventsQuery {
    /// Parses the `event_types` parameter into a list of known [`EventType`]s.
    ///
    /// Each comma-separated entry is trimmed and looked up with `parse_event_type`;
    /// unrecognised entries are skipped. Returns `None` when the parameter is absent
    /// or when no entry is recognised, which callers treat as "no type filter".
    fn parsed_event_types(&self) -> Option<Vec<EventType>> {
        let mut wanted = Vec::new();
        for token in self.event_types.as_deref()?.split(',') {
            if let Some(kind) = parse_event_type(token.trim()) {
                wanted.push(kind);
            }
        }
        (!wanted.is_empty()).then_some(wanted)
    }
}
/// Looks up the [`EventType`] whose wire name is exactly `s` (case-sensitive).
///
/// Returns `None` for any unrecognised name, including the empty string.
fn parse_event_type(s: &str) -> Option<EventType> {
    const KNOWN: [(&str, EventType); 8] = [
        ("memory_created", EventType::MemoryCreated),
        ("memory_updated", EventType::MemoryUpdated),
        ("memory_deleted", EventType::MemoryDeleted),
        ("crossref_created", EventType::CrossrefCreated),
        ("crossref_deleted", EventType::CrossrefDeleted),
        ("sync_started", EventType::SyncStarted),
        ("sync_completed", EventType::SyncCompleted),
        ("sync_failed", EventType::SyncFailed),
    ];
    KNOWN
        .iter()
        .find(|(name, _)| *name == s)
        .map(|&(_, kind)| kind)
}
/// Returns the snake_case wire name of `et`, used as the SSE `event:` field.
///
/// This is the inverse of `parse_event_type`. The match is deliberately
/// exhaustive so that adding an [`EventType`] variant fails to compile here.
fn event_type_to_str(et: EventType) -> &'static str {
    match et {
        // Memory lifecycle.
        EventType::MemoryCreated => "memory_created",
        EventType::MemoryUpdated => "memory_updated",
        EventType::MemoryDeleted => "memory_deleted",
        // Cross-reference lifecycle.
        EventType::CrossrefCreated => "crossref_created",
        EventType::CrossrefDeleted => "crossref_deleted",
        // Sync lifecycle.
        EventType::SyncStarted => "sync_started",
        EventType::SyncCompleted => "sync_completed",
        EventType::SyncFailed => "sync_failed",
    }
}
/// Reconnection delay, in milliseconds, sent to SSE clients as the `retry:` field
/// of the first event on every `/v1/events` stream.
const SSE_RETRY_MS: u64 = 3000;
/// Converts a [`RealtimeEvent`] into an SSE [`Event`].
///
/// The SSE `event:` name comes from `event_type_to_str`, the `data:` payload is the
/// event serialized as JSON (or `{}` if serialization fails), and the `id:` field is
/// set from `seq_id` only when one is present, so clients can resume via `Last-Event-ID`.
fn realtime_event_to_sse(event: &RealtimeEvent) -> Event {
    let payload = match serde_json::to_string(event) {
        Ok(json) => json,
        Err(_) => String::from("{}"),
    };
    let base = Event::default()
        .event(event_type_to_str(event.event_type))
        .data(payload);
    match event.seq_id {
        Some(seq) => base.id(seq.to_string()),
        None => base,
    }
}
async fn handle_events(
State(state): State<AppState>,
headers: HeaderMap,
Query(query): Query<EventsQuery>,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
if let Some(ref expected) = state.api_key {
if !check_bearer(&headers, expected) {
return Err(StatusCode::UNAUTHORIZED);
}
}
let manager = match state.realtime {
Some(m) => m,
None => return Err(StatusCode::SERVICE_UNAVAILABLE),
};
let last_event_id: Option<u64> = headers
.get("last-event-id")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
let event_type_filter = query.parsed_event_types();
let workspace_filter = query.workspace.clone();
let apply_filters = {
let et_filter = event_type_filter.clone();
let ws_filter = workspace_filter.clone();
move |event: &RealtimeEvent| -> bool {
if let Some(ref types) = et_filter {
if !types.contains(&event.event_type) {
return false;
}
}
if let Some(ref ws) = ws_filter {
let event_ws = event
.data
.as_ref()
.and_then(|d: &serde_json::Value| d.get("workspace"))
.and_then(|v: &serde_json::Value| v.as_str());
match event_ws {
Some(ews) if ews == ws => {}
_ => return false,
}
}
true
}
};
let rx = manager.subscribe();
let broadcast_stream = BroadcastStream::new(rx);
let replay_events: Vec<Result<Event, Infallible>> = if let Some(last_id) = last_event_id {
manager
.get_events_after(last_id)
.into_iter()
.filter(|e| apply_filters(e))
.map(|e| Ok::<Event, Infallible>(realtime_event_to_sse(&e)))
.collect()
} else {
vec![]
};
let replay_stream = tokio_stream::iter(replay_events);
let live_stream = broadcast_stream.filter_map(move |result| {
match result {
Err(_lagged) => None,
Ok(event) => {
if !apply_filters(&event) {
return None;
}
Some(Ok::<Event, Infallible>(realtime_event_to_sse(&event)))
}
}
});
let combined = replay_stream.chain(live_stream);
let retry_event = std::iter::once(Ok::<Event, Infallible>(
Event::default().retry(std::time::Duration::from_millis(SSE_RETRY_MS)),
));
let full_stream = tokio_stream::iter(retry_event).chain(combined);
Ok(Sse::new(full_stream)
.keep_alive(KeepAlive::new().interval(std::time::Duration::from_secs(30))))
}
/// Returns `true` when the `Authorization` header carries a bearer token equal to
/// `expected`.
///
/// The scheme name is matched case-insensitively (`Bearer`, `bearer`, ...), and
/// one or more spaces may separate it from the token. The token is compared
/// without short-circuiting on the first differing byte, so response timing does
/// not reveal how much of a guessed token was correct.
///
/// Returns `false` for any of:
/// - a missing header,
/// - a header value that is not visible ASCII,
/// - a different scheme,
/// - a wrong token.
fn check_bearer(headers: &HeaderMap, expected: &str) -> bool {
    headers
        .get("authorization")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.split_once(' '))
        .map(|(scheme, token)| {
            scheme.eq_ignore_ascii_case("Bearer")
                && constant_time_eq(token.trim_start_matches(' ').as_bytes(), expected.as_bytes())
        })
        .unwrap_or(false)
}

/// Compares two byte strings by accumulating XOR differences over every byte.
///
/// This is best effort: it avoids an early exit on content mismatch, but a length
/// mismatch returns immediately, so the secret's length (not its contents) can
/// still be observed through timing.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
/// Serves the MCP HTTP transport on `0.0.0.0:<port>` until the server stops or fails.
///
/// Routes:
/// - `POST /mcp`: JSON-RPC requests
/// - `GET /health`: liveness probe
/// - `GET /v1/events`: Server-Sent Events stream of realtime events
///
/// CORS allows any origin, method and header. When `api_key` is `Some`, `/mcp`
/// and `/v1/events` require a matching bearer token; `/health` is always open.
///
/// # Errors
/// Returns an error if the listener cannot bind or if `axum::serve` fails.
pub async fn serve_http(
    handler: Arc<dyn McpHandler>,
    port: u16,
    api_key: Option<String>,
    realtime: Option<RealtimeManager>,
) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
    if api_key.is_none() {
        // The listener binds every interface and CORS admits any origin, so an
        // unset key leaves these endpoints open to anything that can reach the port.
        tracing::warn!(
            "HTTP transport has no API key configured; /mcp and /v1/events accept unauthenticated requests"
        );
    }
    let state = AppState {
        handler,
        api_key,
        realtime,
    };
    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);
    let app = Router::new()
        .route("/mcp", post(handle_mcp))
        .route("/health", get(handle_health))
        .route("/v1/events", get(handle_events))
        .layer(cors)
        .with_state(state);
    let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
    let listener = tokio::net::TcpListener::bind(addr).await?;
    tracing::info!("HTTP transport listening on {}", addr);
    axum::serve(listener, app).await?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::realtime::RealtimeEvent;

    /// Every `EventType` variant paired with its wire name, for table-driven tests.
    const ALL_EVENT_TYPES: &[(EventType, &str)] = &[
        (EventType::MemoryCreated, "memory_created"),
        (EventType::MemoryUpdated, "memory_updated"),
        (EventType::MemoryDeleted, "memory_deleted"),
        (EventType::CrossrefCreated, "crossref_created"),
        (EventType::CrossrefDeleted, "crossref_deleted"),
        (EventType::SyncStarted, "sync_started"),
        (EventType::SyncCompleted, "sync_completed"),
        (EventType::SyncFailed, "sync_failed"),
    ];

    /// The same `Last-Event-ID` parsing expression that `handle_events` uses.
    fn parse_last_event_id(headers: &HeaderMap) -> Option<u64> {
        headers
            .get("last-event-id")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok())
    }

    #[test]
    fn test_check_bearer_valid() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer my-secret".parse().unwrap());
        assert!(check_bearer(&headers, "my-secret"));
    }

    #[test]
    fn test_check_bearer_invalid_token() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer wrong".parse().unwrap());
        assert!(!check_bearer(&headers, "my-secret"));
    }

    #[test]
    fn test_check_bearer_missing_header() {
        let headers = HeaderMap::new();
        assert!(!check_bearer(&headers, "my-secret"));
    }

    #[test]
    fn test_check_bearer_bad_scheme() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Basic abc123".parse().unwrap());
        assert!(!check_bearer(&headers, "abc123"));
    }

    #[test]
    fn test_sse_event_serialization() {
        let event = RealtimeEvent::memory_created(42, "hello world".to_string());
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("\"type\":\"memory_created\""));
        assert!(json.contains("\"memory_id\":42"));
        assert_eq!(event_type_to_str(event.event_type), "memory_created");
    }

    #[test]
    fn test_sse_event_type_to_str_all_variants() {
        for &(kind, name) in ALL_EVENT_TYPES {
            assert_eq!(event_type_to_str(kind), name);
        }
    }

    /// `parse_event_type` and `event_type_to_str` must stay exact inverses.
    #[test]
    fn test_event_type_round_trip_all_variants() {
        for &(kind, name) in ALL_EVENT_TYPES {
            assert_eq!(parse_event_type(name), Some(kind));
            assert_eq!(parse_event_type(event_type_to_str(kind)), Some(kind));
        }
    }

    #[test]
    fn test_parse_event_type_known() {
        assert_eq!(
            parse_event_type("memory_created"),
            Some(EventType::MemoryCreated)
        );
        assert_eq!(parse_event_type("sync_failed"), Some(EventType::SyncFailed));
    }

    #[test]
    fn test_parse_event_type_unknown_is_none() {
        assert_eq!(parse_event_type("unknown_event"), None);
        assert_eq!(parse_event_type(""), None);
    }

    #[test]
    fn test_events_query_parsed_event_types_none() {
        let q = EventsQuery {
            event_types: None,
            workspace: None,
        };
        assert!(q.parsed_event_types().is_none());
    }

    #[test]
    fn test_events_query_parsed_event_types_single() {
        let q = EventsQuery {
            event_types: Some("memory_created".to_string()),
            workspace: None,
        };
        let types = q.parsed_event_types().unwrap();
        assert_eq!(types, vec![EventType::MemoryCreated]);
    }

    #[test]
    fn test_events_query_parsed_event_types_multiple() {
        let q = EventsQuery {
            event_types: Some("memory_created,memory_deleted,sync_failed".to_string()),
            workspace: None,
        };
        let types = q.parsed_event_types().unwrap();
        assert_eq!(
            types,
            vec![
                EventType::MemoryCreated,
                EventType::MemoryDeleted,
                EventType::SyncFailed
            ]
        );
    }

    #[test]
    fn test_events_query_parsed_event_types_with_spaces() {
        let q = EventsQuery {
            event_types: Some("memory_created, memory_updated".to_string()),
            workspace: None,
        };
        let types = q.parsed_event_types().unwrap();
        assert_eq!(
            types,
            vec![EventType::MemoryCreated, EventType::MemoryUpdated]
        );
    }

    #[test]
    fn test_events_query_parsed_event_types_all_unknown_returns_none() {
        let q = EventsQuery {
            event_types: Some("bogus,fake".to_string()),
            workspace: None,
        };
        assert!(q.parsed_event_types().is_none());
    }

    #[test]
    fn test_event_type_filter_matches() {
        use crate::realtime::SubscriptionFilter;
        let filter = SubscriptionFilter {
            event_types: Some(vec![EventType::MemoryCreated]),
            memory_ids: None,
            tags: None,
        };
        let created = RealtimeEvent::memory_created(1, "test".to_string());
        let deleted = RealtimeEvent::memory_deleted(1);
        assert!(filter.matches(&created));
        assert!(!filter.matches(&deleted));
    }

    #[test]
    fn test_auth_rejection_no_header() {
        let headers = HeaderMap::new();
        assert!(!check_bearer(&headers, "secret-key"));
    }

    // NOTE(review): tautological; it never exercises a handler. Kept as a
    // placeholder until the no-key path can be driven through a request.
    #[test]
    fn test_auth_no_key_configured_always_passes() {
        let has_key: Option<String> = None;
        assert!(has_key.is_none());
    }

    // NOTE(review): tautological; the 30s interval is a literal inside
    // `handle_events` and is not observable from here.
    #[test]
    fn test_keep_alive_interval_is_30s() {
        let interval = std::time::Duration::from_secs(30);
        assert_eq!(interval.as_secs(), 30);
    }

    #[test]
    fn test_last_event_id_header_valid() {
        let mut headers = HeaderMap::new();
        headers.insert("last-event-id", "42".parse().unwrap());
        assert_eq!(parse_last_event_id(&headers), Some(42));
    }

    #[test]
    fn test_last_event_id_header_missing_is_none() {
        let headers = HeaderMap::new();
        assert!(parse_last_event_id(&headers).is_none());
    }

    #[test]
    fn test_last_event_id_header_non_numeric_is_none() {
        let mut headers = HeaderMap::new();
        headers.insert("last-event-id", "not-a-number".parse().unwrap());
        assert!(parse_last_event_id(&headers).is_none());
    }

    #[test]
    fn test_last_event_id_header_zero() {
        let mut headers = HeaderMap::new();
        headers.insert("last-event-id", "0".parse().unwrap());
        assert_eq!(parse_last_event_id(&headers), Some(0));
    }

    #[test]
    fn test_realtime_event_to_sse_with_seq_id() {
        use crate::realtime::RealtimeManager;
        let manager = RealtimeManager::new();
        let _rx = manager.subscribe();
        manager.broadcast(RealtimeEvent::memory_created(1, "hello".to_string()));
        let buffered = manager.get_events_after(0);
        assert_eq!(buffered.len(), 1);
        let event = &buffered[0];
        assert_eq!(event.seq_id, Some(1));
        // Smoke test: the conversion (including setting the `id:` field) must not panic.
        let _sse = realtime_event_to_sse(event);
    }

    #[test]
    fn test_realtime_event_to_sse_without_seq_id_no_id_field() {
        let event = RealtimeEvent::memory_created(5, "no id".to_string());
        assert!(event.seq_id.is_none());
        // Smoke test: the conversion must not panic when `seq_id` is absent.
        let _sse = realtime_event_to_sse(&event);
    }

    #[test]
    fn test_replay_events_after_last_id() {
        use crate::realtime::RealtimeManager;
        let manager = RealtimeManager::new();
        let _rx = manager.subscribe();
        for i in 1..=5i64 {
            manager.broadcast(RealtimeEvent::memory_created(i, format!("ev{i}")));
        }
        let last_id: u64 = 3;
        let replayed = manager.get_events_after(last_id);
        assert_eq!(replayed.len(), 2);
        let ids: Vec<u64> = replayed.iter().filter_map(|e| e.seq_id).collect();
        assert_eq!(ids, vec![4, 5]);
    }

    #[test]
    fn test_retry_constant_is_3000ms() {
        assert_eq!(SSE_RETRY_MS, 3000);
    }
}