use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::time::Duration;
use axum::extract::{Request, State};
use axum::http::{HeaderMap, HeaderName, HeaderValue, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use solo_core::TenantId;
use tokio::sync::broadcast;
use uuid::Uuid;
use crate::auth::AuthenticatedPrincipal;
/// Name of the request/response header carrying the MCP session id.
///
/// Must stay lower-case: it is passed to `HeaderName::from_static`, which
/// rejects upper-case names.
pub const MCP_SESSION_ID_HEADER: &str = "mcp-session-id";
/// Idle time, in milliseconds, after which a session expires (30 minutes).
pub const MCP_SESSION_INACTIVITY_TTL_MS: u64 = 30 * 60 * 1_000;
/// Maximum lifetime of a session in milliseconds, regardless of activity (4 hours).
pub const MCP_SESSION_ABSOLUTE_TTL_MS: u64 = 4 * 60 * 60 * 1_000;
/// Period, in seconds, of the background sweep that evicts expired sessions.
pub const MCP_SESSION_SWEEP_INTERVAL_SECS: u64 = 60;
/// Name of the `Last-Event-ID` HTTP header (lower-case form).
pub const MCP_LAST_EVENT_ID_HEADER: &str = "last-event-id";
/// Capacity of each session's event broadcast channel and of its replay buffer.
pub const MCP_SESSION_EVENT_BUFFER_CAPACITY: usize = 256;
/// Opaque identifier of an MCP session, carried in the `mcp-session-id` header.
///
/// Ids minted by [`SessionId::new`] are UUIDv7 strings. Ids produced by
/// [`SessionId::parse`] contain only visible ASCII, so both kinds are valid
/// HTTP header values. Note that the derived `Deserialize` impl does not apply
/// the `parse` validation.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SessionId(String);

impl SessionId {
    /// Mints a fresh session id (a UUIDv7 rendered as a string).
    pub fn new() -> Self {
        Self(Uuid::now_v7().to_string())
    }

    /// Parses a client-presented session id.
    ///
    /// Surrounding whitespace is trimmed. Returns `None` when the remainder is
    /// empty or contains any byte outside visible ASCII (`'!'..='~'`).
    /// Such a value cannot be an id minted by [`SessionId::new`], and it
    /// could not be echoed back in a header either.
    pub fn parse(raw: &str) -> Option<Self> {
        let s = raw.trim();
        if s.is_empty() || !s.bytes().all(|b| b.is_ascii_graphic()) {
            return None;
        }
        Some(Self(s.to_owned()))
    }

    /// Borrows the id as a string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl Default for SessionId {
    /// Same as [`SessionId::new`]: every default value is a fresh id.
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Display for SessionId {
    /// Writes the raw id string, unquoted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
/// Kind of an event published on a session's event stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum McpEventKind {
    Init,
    Message,
    Progress,
    Lagged,
    Heartbeat,
}

/// Stream event name for [`McpEventKind::Init`].
pub const MCP_STREAM_EVENT_INIT_NAME: &str = "init";
/// Stream event name for [`McpEventKind::Message`].
pub const MCP_STREAM_EVENT_MESSAGE_NAME: &str = "message";
/// Stream event name for [`McpEventKind::Progress`].
pub const MCP_STREAM_EVENT_PROGRESS_NAME: &str = "progress";
/// Stream event name for [`McpEventKind::Lagged`].
pub const MCP_STREAM_EVENT_LAGGED_NAME: &str = "lagged";
/// Stream event name for [`McpEventKind::Heartbeat`].
pub const MCP_STREAM_EVENT_HEARTBEAT_NAME: &str = "heartbeat";

impl McpEventKind {
    /// Returns the stream event name for this kind (one of the
    /// `MCP_STREAM_EVENT_*_NAME` constants).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Init => MCP_STREAM_EVENT_INIT_NAME,
            Self::Message => MCP_STREAM_EVENT_MESSAGE_NAME,
            Self::Progress => MCP_STREAM_EVENT_PROGRESS_NAME,
            Self::Lagged => MCP_STREAM_EVENT_LAGGED_NAME,
            Self::Heartbeat => MCP_STREAM_EVENT_HEARTBEAT_NAME,
        }
    }
}
/// One event published on a session's event stream.
#[derive(Clone, Debug)]
pub struct McpStreamEvent {
    /// Per-session event id; the publishing session hands these out from 1 upward.
    pub id: u64,
    /// Kind of the event; `McpEventKind::as_str` gives its stream event name.
    pub event: McpEventKind,
    /// JSON payload of the event.
    pub data: serde_json::Value,
}
/// Server-side state of a single MCP session.
#[derive(Debug)]
pub struct SessionState {
    /// Tenant the session belongs to.
    pub tenant_id: TenantId,
    /// Principal associated with the session, if any.
    pub principal: Option<AuthenticatedPrincipal>,
    /// Creation time in milliseconds since the Unix epoch; anchors the absolute TTL.
    pub created_at_ms: i64,
    /// Last access time in milliseconds since the Unix epoch; anchors the inactivity TTL.
    pub last_accessed_at_ms: std::sync::atomic::AtomicI64,
    /// Live fan-out of published events to current subscribers.
    pub event_tx: broadcast::Sender<McpStreamEvent>,
    /// Next id handed out by [`SessionState::publish_event`]; starts at 1.
    /// NOTE(review): advance this only through `publish_event`, which keeps
    /// id allocation and buffering under the replay-buffer lock.
    pub next_event_id: AtomicU64,
    /// Most recently published events, oldest first, holding at most
    /// `MCP_SESSION_EVENT_BUFFER_CAPACITY` entries.
    pub event_replay_buffer: Arc<std::sync::Mutex<VecDeque<McpStreamEvent>>>,
}

impl SessionState {
    /// Creates a session for `tenant_id`/`principal` stamped with the current time.
    ///
    /// The broadcast channel and the replay buffer are both sized to
    /// `MCP_SESSION_EVENT_BUFFER_CAPACITY`, and event ids start at 1.
    pub fn new(tenant_id: TenantId, principal: Option<AuthenticatedPrincipal>) -> Self {
        let now = now_ms();
        // The initial receiver is dropped; listeners attach via `subscribe_events`.
        let (event_tx, _) = broadcast::channel(MCP_SESSION_EVENT_BUFFER_CAPACITY);
        let event_replay_buffer = Arc::new(std::sync::Mutex::new(VecDeque::with_capacity(
            MCP_SESSION_EVENT_BUFFER_CAPACITY,
        )));
        Self {
            tenant_id,
            principal,
            created_at_ms: now,
            last_accessed_at_ms: std::sync::atomic::AtomicI64::new(now),
            event_tx,
            next_event_id: AtomicU64::new(1),
            event_replay_buffer,
        }
    }

    /// Returns `true` once `now_ms` reaches either deadline: the absolute one
    /// (creation + `MCP_SESSION_ABSOLUTE_TTL_MS`) or the inactivity one
    /// (last access + `MCP_SESSION_INACTIVITY_TTL_MS`).
    fn is_expired(&self, now_ms: i64) -> bool {
        let absolute_deadline = self
            .created_at_ms
            .saturating_add(MCP_SESSION_ABSOLUTE_TTL_MS as i64);
        if now_ms >= absolute_deadline {
            return true;
        }
        let last = self
            .last_accessed_at_ms
            .load(std::sync::atomic::Ordering::Relaxed);
        let inactivity_deadline = last.saturating_add(MCP_SESSION_INACTIVITY_TTL_MS as i64);
        now_ms >= inactivity_deadline
    }

    /// Records the current time as the session's last access.
    fn touch(&self) {
        self.last_accessed_at_ms
            .store(now_ms(), std::sync::atomic::Ordering::Relaxed);
    }

    /// Publishes an event: assigns it the next id, appends it to the replay
    /// buffer (evicting the oldest entry when full), and broadcasts it to
    /// current subscribers. Returns the assigned id.
    ///
    /// Id allocation, buffering and broadcasting all happen under the
    /// replay-buffer lock. Concurrent publishers therefore cannot interleave:
    /// the buffer and the broadcast stream see ids in strictly increasing order.
    pub fn publish_event(&self, kind: McpEventKind, data: serde_json::Value) -> u64 {
        let mut buf = self.lock_replay_buffer();
        let id = self
            .next_event_id
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        let event = McpStreamEvent {
            id,
            event: kind,
            data,
        };
        if buf.len() >= MCP_SESSION_EVENT_BUFFER_CAPACITY {
            buf.pop_front();
        }
        buf.push_back(event.clone());
        // Having no live receivers is not an error: the event stays in the replay buffer.
        let _ = self.event_tx.send(event);
        id
    }

    /// Returns a new receiver for events published after this call.
    pub fn subscribe_events(&self) -> broadcast::Receiver<McpStreamEvent> {
        self.event_tx.subscribe()
    }

    /// Copies the current replay buffer, oldest event first.
    pub fn snapshot_replay_buffer(&self) -> Vec<McpStreamEvent> {
        self.lock_replay_buffer().iter().cloned().collect()
    }

    /// Locks the replay buffer, recovering the data if a previous holder panicked.
    fn lock_replay_buffer(&self) -> std::sync::MutexGuard<'_, VecDeque<McpStreamEvent>> {
        self.event_replay_buffer
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
/// Cheaply cloneable handle to the in-memory table of MCP sessions.
///
/// All clones share one table. Expired sessions are evicted in three ways:
/// - lazily by [`SessionStore::get`];
/// - on demand by [`SessionStore::sweep_now`];
/// - periodically by a background task, when the store was created inside
///   a Tokio runtime.
#[derive(Clone)]
pub struct SessionStore {
    inner: Arc<SessionStoreInner>,
}

/// State shared by every clone of a [`SessionStore`].
struct SessionStoreInner {
    /// Sessions keyed by id, including expired ones not yet evicted.
    sessions: DashMap<SessionId, Arc<SessionState>>,
    /// Background sweep task, if one was spawned; aborted when this is dropped.
    sweep_task: std::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
}

impl SessionStore {
    /// Creates an empty store.
    ///
    /// When called from within a Tokio runtime, this also spawns a task that
    /// evicts expired sessions every `MCP_SESSION_SWEEP_INTERVAL_SECS` seconds.
    /// Outside a runtime no task is spawned, instead of panicking as
    /// `tokio::spawn` would. Expired sessions are then still evicted by `get`
    /// and `sweep_now`.
    pub fn new() -> Self {
        let inner = Arc::new(SessionStoreInner {
            sessions: DashMap::new(),
            sweep_task: std::sync::Mutex::new(None),
        });
        if let Ok(runtime) = tokio::runtime::Handle::try_current() {
            // Hold only a weak reference so the task never keeps the store alive.
            let weak = Arc::downgrade(&inner);
            let sweep = runtime.spawn(async move {
                let mut tick =
                    tokio::time::interval(Duration::from_secs(MCP_SESSION_SWEEP_INTERVAL_SECS));
                // The first tick fires immediately; consume it so the first sweep
                // happens one full interval after construction.
                tick.tick().await;
                loop {
                    tick.tick().await;
                    let Some(inner) = weak.upgrade() else {
                        return;
                    };
                    sweep_once(&inner.sessions);
                }
            });
            *inner.sweep_task.lock().expect("sweep_task mutex poisoned") = Some(sweep);
        }
        Self { inner }
    }

    /// Creates an empty store that never spawns a background sweep task.
    #[cfg(test)]
    pub(crate) fn new_for_tests_no_sweep() -> Self {
        let inner = Arc::new(SessionStoreInner {
            sessions: DashMap::new(),
            sweep_task: std::sync::Mutex::new(None),
        });
        Self { inner }
    }

    /// Stores `state` under a freshly minted id and returns that id.
    pub fn insert(&self, state: SessionState) -> SessionId {
        let id = SessionId::new();
        self.inner.sessions.insert(id.clone(), Arc::new(state));
        id
    }

    /// Looks up a live session and refreshes its last-access time.
    ///
    /// Returns `None` if the id is unknown or the session has expired. An
    /// expired entry is evicted as a side effect.
    pub fn get(&self, id: &SessionId) -> Option<Arc<SessionState>> {
        let now = now_ms();
        // Clone the Arc out so no shard guard outlives this statement.
        let state = self
            .inner
            .sessions
            .get(id)
            .map(|entry| Arc::clone(entry.value()))?;
        if state.is_expired(now) {
            // Re-check under the shard lock so a session that a concurrent
            // lookup has just refreshed is not evicted.
            self.inner.sessions.remove_if(id, |_, s| s.is_expired(now));
            return None;
        }
        state.touch();
        Some(state)
    }

    /// Removes the session with `id`. Returns whether it was present.
    pub fn delete(&self, id: &SessionId) -> bool {
        self.inner.sessions.remove(id).is_some()
    }

    /// Number of stored sessions, including expired ones not yet evicted.
    pub fn len(&self) -> usize {
        self.inner.sessions.len()
    }

    /// Returns `true` when no sessions are stored.
    pub fn is_empty(&self) -> bool {
        self.inner.sessions.is_empty()
    }

    /// Evicts every expired session immediately.
    pub fn sweep_now(&self) {
        sweep_once(&self.inner.sessions);
    }
}

impl Default for SessionStore {
    /// Same as [`SessionStore::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for SessionStoreInner {
    /// Aborts the background sweep task, if any, once the last store handle is gone.
    fn drop(&mut self) {
        if let Ok(mut guard) = self.sweep_task.lock()
            && let Some(handle) = guard.take()
        {
            handle.abort();
        }
    }
}
/// Removes every session in `sessions` that is expired as of a single
/// `now_ms()` reading.
fn sweep_once(sessions: &DashMap<SessionId, Arc<SessionState>>) {
    let now = now_ms();
    // `retain` checks and removes each entry while holding that entry's shard
    // lock. Eviction therefore reflects the session's state at removal time,
    // not a key snapshot taken earlier in the scan, and no temporary Vec of
    // keys is needed.
    sessions.retain(|_, state| !state.is_expired(now));
}
/// Current wall-clock time as milliseconds since the Unix epoch, in UTC.
fn now_ms() -> i64 {
    let now = chrono::Utc::now();
    now.timestamp_millis()
}
/// Value of the `error` field in the 404 body returned for an unknown or
/// expired `mcp-session-id`.
pub const MCP_SESSION_EXPIRED_ERROR: &str = "session_expired";
/// Axum middleware that resolves the `mcp-session-id` request header.
///
/// - No header: the request is passed on unchanged.
/// - Header naming a live session: the parsed [`SessionId`] and its
///   `Arc<SessionState>` are inserted into the request extensions before the
///   request continues.
/// - Anything else short-circuits with the 404 `session_expired` response.
///   That covers an unknown or expired id, a blank value, and a value that is
///   not visible ASCII.
pub async fn mcp_session_middleware(
    State(store): State<SessionStore>,
    mut req: Request,
    next: Next,
) -> Response {
    // Copy the header out as an owned value so `req` is free to be mutated or
    // moved below. `Err` carries a lossy rendering of a header that is present
    // but not visible ASCII.
    let presented = req.headers().get(MCP_SESSION_ID_HEADER).map(|h| {
        h.to_str()
            .map(str::to_owned)
            .map_err(|_| String::from_utf8_lossy(h.as_bytes()).into_owned())
    });
    let raw = match presented {
        // No header: the request is not bound to a session.
        None => return next.run(req).await,
        // A header was sent but is unreadable. It cannot name a session this
        // server issued, so reject it rather than treating the request as
        // session-less.
        Some(Err(lossy)) => return session_expired_response(&lossy),
        Some(Ok(raw)) => raw,
    };
    let Some(id) = SessionId::parse(&raw) else {
        return session_expired_response(&raw);
    };
    let Some(state) = store.get(&id) else {
        return session_expired_response(&raw);
    };
    req.extensions_mut().insert(id);
    req.extensions_mut().insert(state);
    next.run(req).await
}
/// Builds the `404 Not Found` JSON response for an `mcp-session-id` that
/// cannot be resolved to a live session.
///
/// The body echoes `presented_id` and tells the client to re-initialize
/// without a session id.
fn session_expired_response(presented_id: &str) -> Response {
    let message = format!(
        "Mcp-Session-Id `{presented_id}` is unknown or expired; \
         re-initialize via POST /mcp without Mcp-Session-Id"
    );
    let payload = serde_json::json!({
        "error": MCP_SESSION_EXPIRED_ERROR,
        "status": 404,
        "message": message,
        "retry": "re-initialize",
    });
    (StatusCode::NOT_FOUND, axum::Json(payload)).into_response()
}
/// Writes `id` into `headers` as the `mcp-session-id` header, replacing any
/// existing values for that header.
///
/// # Panics
///
/// Panics if `id` is not a valid header value. Ids minted by
/// `SessionId::new` are UUID strings, which always are.
pub fn set_session_id_header(headers: &mut HeaderMap, id: &SessionId) {
    let header_value = match HeaderValue::from_str(id.as_str()) {
        Ok(v) => v,
        Err(e) => Err(e).expect("SessionId is ASCII-safe (UUID) for HeaderValue"),
    };
    let header_name = HeaderName::from_static(MCP_SESSION_ID_HEADER);
    headers.insert(header_name, header_value);
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;

    /// Tenant used for every session created in these tests.
    fn fake_tenant() -> TenantId {
        TenantId::default_tenant()
    }

    /// A newly created session for `fake_tenant()` with no principal.
    fn fresh_state() -> SessionState {
        SessionState::new(fake_tenant(), None)
    }

    #[test]
    fn session_store_insert_returns_unique_id() {
        let store = SessionStore::new_for_tests_no_sweep();
        let id_a = store.insert(fresh_state());
        let id_b = store.insert(fresh_state());
        assert_ne!(id_a, id_b, "two inserts must produce distinct ids");
        assert_eq!(store.len(), 2);
    }

    #[test]
    fn session_store_get_returns_state_when_present() {
        let store = SessionStore::new_for_tests_no_sweep();
        let id = store.insert(fresh_state());
        let got = store.get(&id);
        assert!(got.is_some(), "get must return Some for a just-inserted id");
        assert_eq!(got.unwrap().tenant_id, fake_tenant());
    }

    /// Builds a session whose creation and last-access timestamps are both
    /// `delta_ms` milliseconds in the past.
    fn aged_state(
        tenant_id: TenantId,
        principal: Option<AuthenticatedPrincipal>,
        delta_ms: i64,
    ) -> SessionState {
        let now = now_ms();
        let shifted = now.saturating_sub(delta_ms);
        let mut state = SessionState::new(tenant_id, principal);
        state.created_at_ms = shifted;
        state.last_accessed_at_ms.store(shifted, Ordering::Relaxed);
        state
    }

    #[test]
    fn session_store_get_returns_none_when_expired_by_inactivity() {
        let store = SessionStore::new_for_tests_no_sweep();
        let stale_delta = MCP_SESSION_INACTIVITY_TTL_MS as i64 + 1;
        let stale = Arc::new(aged_state(fake_tenant(), None, stale_delta));
        let id = SessionId::new();
        store.inner.sessions.insert(id.clone(), stale);
        assert!(
            store.get(&id).is_none(),
            "session inactive past TTL must read as expired"
        );
        assert!(
            store.inner.sessions.get(&id).is_none(),
            "expired entry must be removed from the underlying map"
        );
    }

    #[test]
    fn session_store_get_returns_none_when_expired_by_absolute_ttl() {
        let store = SessionStore::new_for_tests_no_sweep();
        let absolute_delta = MCP_SESSION_ABSOLUTE_TTL_MS as i64 + 1;
        let state = aged_state(fake_tenant(), None, absolute_delta);
        // Refresh last access so that only the absolute TTL can expire the session.
        state.last_accessed_at_ms.store(now_ms(), Ordering::Relaxed);
        let aged = Arc::new(state);
        let id = SessionId::new();
        store.inner.sessions.insert(id.clone(), aged);
        assert!(
            store.get(&id).is_none(),
            "session past absolute TTL must read as expired even when recently touched"
        );
    }

    #[test]
    fn session_store_get_refreshes_last_accessed_on_hit() {
        let store = SessionStore::new_for_tests_no_sweep();
        let id = store.insert(fresh_state());
        let before = store
            .inner
            .sessions
            .get(&id)
            .unwrap()
            .last_accessed_at_ms
            .load(Ordering::Relaxed);
        // now_ms() has millisecond resolution; wait long enough to observe a change.
        std::thread::sleep(std::time::Duration::from_millis(5));
        let _ = store.get(&id).expect("session must still be present");
        let after = store
            .inner
            .sessions
            .get(&id)
            .unwrap()
            .last_accessed_at_ms
            .load(Ordering::Relaxed);
        assert!(
            after > before,
            "get must bump last_accessed_at_ms (before={before}, after={after})"
        );
    }

    #[test]
    fn session_store_delete_returns_true_when_present() {
        let store = SessionStore::new_for_tests_no_sweep();
        let id = store.insert(fresh_state());
        assert!(store.delete(&id));
        assert!(store.get(&id).is_none(), "deleted session must not read");
    }

    #[test]
    fn session_store_delete_returns_false_when_absent() {
        let store = SessionStore::new_for_tests_no_sweep();
        assert!(!store.delete(&SessionId::new()));
    }

    #[test]
    fn session_store_sweep_now_removes_expired() {
        let store = SessionStore::new_for_tests_no_sweep();
        let healthy_id = store.insert(fresh_state());
        let stale_delta = MCP_SESSION_INACTIVITY_TTL_MS as i64 + 1;
        let stale = Arc::new(aged_state(fake_tenant(), None, stale_delta));
        let stale_id = SessionId::new();
        store.inner.sessions.insert(stale_id.clone(), stale);
        assert_eq!(store.len(), 2);
        store.sweep_now();
        assert_eq!(store.len(), 1, "sweep must drop the expired session");
        assert!(
            store.get(&healthy_id).is_some(),
            "sweep must preserve the healthy session"
        );
        assert!(
            store.inner.sessions.get(&stale_id).is_none(),
            "stale id must be gone from the map after sweep"
        );
    }

    #[tokio::test]
    async fn session_store_background_sweep_removes_expired() {
        let store = SessionStore::new();
        // Inside a runtime, `new` must register the periodic sweep task. The
        // sweep interval (MCP_SESSION_SWEEP_INTERVAL_SECS) is not waited out
        // here, so the expiry pass itself is triggered via `sweep_now`.
        assert!(
            store
                .inner
                .sweep_task
                .lock()
                .expect("sweep_task mutex poisoned")
                .is_some(),
            "SessionStore::new inside a runtime must spawn the sweep task"
        );
        let stale_delta = MCP_SESSION_INACTIVITY_TTL_MS as i64 + 1;
        let stale = Arc::new(aged_state(fake_tenant(), None, stale_delta));
        let stale_id = SessionId::new();
        store.inner.sessions.insert(stale_id.clone(), stale);
        store.sweep_now();
        assert!(store.inner.sessions.get(&stale_id).is_none());
    }

    #[test]
    fn session_id_round_trips_through_string() {
        let id = SessionId::new();
        let s = id.as_str().to_string();
        let parsed = SessionId::parse(&s).expect("ASCII round-trip");
        assert_eq!(id, parsed);
    }

    #[test]
    fn session_id_parse_rejects_empty_string() {
        assert!(SessionId::parse("").is_none());
        assert!(SessionId::parse(" ").is_none());
    }

    #[test]
    fn session_state_publish_event_returns_monotonic_ids() {
        let state = fresh_state();
        let id1 = state.publish_event(McpEventKind::Init, serde_json::json!({"connected": true}));
        let id2 = state.publish_event(McpEventKind::Message, serde_json::json!({"hello": 1}));
        let id3 = state.publish_event(McpEventKind::Progress, serde_json::json!({"progress": 5}));
        assert_eq!(
            id1, 1,
            "first event must allocate id 1 (id 0 reserved for client sentinel)"
        );
        assert_eq!(id2, 2);
        assert_eq!(id3, 3);
    }

    #[tokio::test]
    async fn session_state_publish_event_broadcasts_to_subscribers() {
        let state = fresh_state();
        let mut rx = state.subscribe_events();
        let id = state.publish_event(
            McpEventKind::Message,
            serde_json::json!({"jsonrpc": "2.0", "method": "notifications/message"}),
        );
        let received = rx
            .recv()
            .await
            .expect("subscriber must observe the broadcast event");
        assert_eq!(received.id, id);
        assert_eq!(received.event, McpEventKind::Message);
        assert_eq!(received.data["method"], "notifications/message");
    }

    #[test]
    fn session_state_event_buffer_capacity_256() {
        let state = fresh_state();
        // Overflow the buffer by 50 so that the oldest events must be evicted.
        let total = (MCP_SESSION_EVENT_BUFFER_CAPACITY + 50) as u64;
        for _ in 0..total {
            state.publish_event(McpEventKind::Message, serde_json::json!({}));
        }
        let snapshot = state.snapshot_replay_buffer();
        assert_eq!(
            snapshot.len(),
            MCP_SESSION_EVENT_BUFFER_CAPACITY,
            "replay buffer must retain exactly {} entries after overflow",
            MCP_SESSION_EVENT_BUFFER_CAPACITY,
        );
        // Ids start at 1, so the surviving window is (total - capacity + 1)..=total.
        let expected_first_id = total - MCP_SESSION_EVENT_BUFFER_CAPACITY as u64 + 1;
        let expected_last_id = total;
        assert_eq!(
            snapshot.first().unwrap().id,
            expected_first_id,
            "oldest retained event id must be {expected_first_id}",
        );
        assert_eq!(
            snapshot.last().unwrap().id,
            expected_last_id,
            "newest retained event id must be {expected_last_id}",
        );
        for win in snapshot.windows(2) {
            assert_eq!(
                win[1].id,
                win[0].id + 1,
                "replay buffer must be contiguous (no gaps)",
            );
        }
    }

    #[test]
    fn session_state_publish_event_no_subscribers_is_lossless_to_buffer() {
        let state = fresh_state();
        let id = state.publish_event(McpEventKind::Init, serde_json::json!({"hi": true}));
        let snapshot = state.snapshot_replay_buffer();
        assert_eq!(snapshot.len(), 1);
        assert_eq!(snapshot[0].id, id);
    }
}