use std::collections::HashMap;
use std::sync::RwLock;
use std::time::{Duration, Instant};
use serde::{Deserialize, Serialize};
use tracing::{debug, info, warn};
/// Header name used to carry the MCP protocol version.
pub const MCP_PROTOCOL_VERSION_HEADER: &str = "MCP-Protocol-Version";
/// Header name used to carry the MCP session ID.
pub const MCP_SESSION_ID_HEADER: &str = "MCP-Session-Id";
/// Protocol versions accepted by [`validate_protocol_version`]; the error
/// message for rejected versions lists them in this order.
pub const SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &["2024-11-05", "2025-03-26", "2025-11-25"];
/// State tracked for a single MCP session.
///
/// `Instant` values cannot be serialized, so both timestamps are marked
/// `#[serde(skip)]`: they are omitted on serialization and come back as
/// `None` (their `Default`) on deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct McpSession {
/// Session identifier; [`SessionStore::create`] fills it with a random UUID v4 string.
pub session_id: String,
/// Protocol version supplied when the session was created.
pub protocol_version: String,
/// Creation time; `None` for a deserialized session.
#[serde(skip)]
pub created_at: Option<Instant>,
/// Time of creation or of the most recent successful `validate` call;
/// the store's expiry logic measures idle time from this value.
#[serde(skip)]
pub last_active: Option<Instant>,
}
/// Thread-safe, in-memory registry of MCP sessions keyed by session ID.
///
/// Sessions idle for at least `timeout` (measured from `last_active`) are
/// removed by [`SessionStore::prune_expired`].
pub struct SessionStore {
// Session ID -> session state, guarded for concurrent readers/writers.
sessions: RwLock<HashMap<String, McpSession>>,
// Maximum allowed idle time before a session counts as expired.
timeout: Duration,
}
impl SessionStore {
    /// Creates an empty store whose sessions expire after `timeout` of inactivity.
    #[must_use]
    pub fn new(timeout: Duration) -> Self {
        Self {
            sessions: RwLock::new(HashMap::new()),
            timeout,
        }
    }

    /// Acquires the write lock, recovering the guard if the lock is poisoned.
    ///
    /// Recovering instead of propagating the panic means one panicking thread
    /// does not disable session handling for every other caller.
    fn write_sessions(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<String, McpSession>> {
        self.sessions.write().unwrap_or_else(|e| e.into_inner())
    }

    /// Acquires the read lock, recovering the guard if the lock is poisoned.
    fn read_sessions(&self) -> std::sync::RwLockReadGuard<'_, HashMap<String, McpSession>> {
        self.sessions.read().unwrap_or_else(|e| e.into_inner())
    }

    /// Returns `true` when `session` has been idle for at least `timeout` as of `now`.
    ///
    /// A session without a `last_active` timestamp (e.g. one that came from
    /// deserialization) is treated as expired.
    fn is_expired(session: &McpSession, now: Instant, timeout: Duration) -> bool {
        match session.last_active {
            Some(last_active) => now.saturating_duration_since(last_active) >= timeout,
            None => true,
        }
    }

    /// Registers a new session for `protocol_version` and returns its ID,
    /// a freshly generated random UUID v4 string.
    #[must_use]
    pub fn create(&self, protocol_version: String) -> String {
        let session_id = uuid::Uuid::new_v4().to_string();
        let now = Instant::now();
        let session = McpSession {
            session_id: session_id.clone(),
            protocol_version,
            created_at: Some(now),
            last_active: Some(now),
        };
        // The guard is a temporary, so the lock is released before logging.
        self.write_sessions().insert(session_id.clone(), session);
        info!(session_id = %session_id, "MCP session created");
        session_id
    }

    /// Looks up `session_id`, refreshes its `last_active` timestamp and returns
    /// a snapshot of the session.
    ///
    /// Returns `None` when the ID is unknown, or when the session has already
    /// exceeded the idle timeout. In the latter case the stale entry is removed,
    /// so an expired session cannot be revived just because
    /// [`SessionStore::prune_expired`] has not run yet.
    pub fn validate(&self, session_id: &str) -> Option<McpSession> {
        let mut sessions = self.write_sessions();
        // Read the clock under the lock so no stored `last_active` is newer than `now`.
        let now = Instant::now();
        let session = match sessions.get_mut(session_id) {
            Some(session) => session,
            None => {
                debug!(session_id = %session_id, "Unknown session ID");
                return None;
            }
        };
        if Self::is_expired(session, now, self.timeout) {
            sessions.remove(session_id);
            warn!(session_id = %session_id, "Rejecting expired MCP session");
            return None;
        }
        session.last_active = Some(now);
        Some(session.clone())
    }

    /// Removes `session_id` from the store; returns `true` if it was present.
    pub fn remove(&self, session_id: &str) -> bool {
        let removed = self.write_sessions().remove(session_id).is_some();
        if removed {
            info!(session_id = %session_id, "MCP session removed");
        }
        removed
    }

    /// Removes every session idle for at least the configured timeout and
    /// returns how many were removed.
    pub fn prune_expired(&self) -> usize {
        let mut sessions = self.write_sessions();
        // Read the clock under the lock so no stored `last_active` is newer than `now`.
        let now = Instant::now();
        let timeout = self.timeout;
        let before = sessions.len();
        sessions.retain(|id, session| {
            let expired = Self::is_expired(session, now, timeout);
            if expired {
                warn!(session_id = %id, "Pruning expired MCP session");
            }
            !expired
        });
        before - sessions.len()
    }

    /// Returns the number of stored sessions, including expired ones that
    /// have not been pruned or rejected yet.
    #[must_use]
    pub fn active_count(&self) -> usize {
        self.read_sessions().len()
    }
}
impl Default for SessionStore {
    /// Builds a store whose sessions expire after one hour of inactivity.
    fn default() -> Self {
        const ONE_HOUR: Duration = Duration::from_secs(60 * 60);
        Self::new(ONE_HOUR)
    }
}
/// Accepts `version` only if it appears in [`SUPPORTED_PROTOCOL_VERSIONS`],
/// returning the same slice on success.
///
/// # Errors
///
/// Returns a message naming the rejected version and listing every
/// supported version, comma-separated.
pub fn validate_protocol_version(version: &str) -> Result<&str, String> {
    let is_known = SUPPORTED_PROTOCOL_VERSIONS
        .iter()
        .any(|supported| *supported == version);
    if !is_known {
        let supported_list = SUPPORTED_PROTOCOL_VERSIONS.join(", ");
        return Err(format!(
            "unsupported MCP protocol version: {version}. Supported: {supported_list}"
        ));
    }
    Ok(version)
}
/// Checks an `Origin` value against an allowlist.
///
/// Rules, in order:
/// * an empty `origin` is always accepted;
/// * a `"*"` entry in `allowed_origins` accepts every origin;
/// * an empty `allowed_origins` rejects every non-empty origin (strict mode);
/// * otherwise `origin` must exactly equal one of the entries.
///
/// # Errors
///
/// Returns a human-readable reason when the origin is rejected.
pub fn validate_origin(origin: &str, allowed_origins: &[String]) -> Result<(), String> {
    let wildcard = allowed_origins.iter().any(|entry| entry == "*");
    if origin.is_empty() || wildcard {
        return Ok(());
    }
    match allowed_origins {
        [] => Err("no origins allowed (strict mode)".into()),
        entries if entries.iter().any(|entry| entry == origin) => Ok(()),
        _ => Err(format!("origin not allowed: {origin}")),
    }
}
// Unit tests for the session store and the header validation helpers.
#[cfg(test)]
mod tests {
use super::*;
// A newly created session is stored and validates with its recorded protocol version.
#[test]
fn create_and_validate_session() {
let store = SessionStore::default();
let id = store.create("2025-11-25".into());
assert_eq!(store.active_count(), 1);
let session = store.validate(&id);
assert!(session.is_some());
assert_eq!(session.unwrap().protocol_version, "2025-11-25");
}
// IDs that were never issued by the store are rejected.
#[test]
fn validate_unknown_session_returns_none() {
let store = SessionStore::default();
assert!(store.validate("nonexistent").is_none());
}
// Removal succeeds once; removing the same ID again reports `false`.
#[test]
fn remove_session() {
let store = SessionStore::default();
let id = store.create("2025-11-25".into());
assert!(store.remove(&id));
assert_eq!(store.active_count(), 0);
assert!(!store.remove(&id)); }
// With a 1 ms timeout, sleeping 10 ms guarantees the session is past its idle limit.
#[test]
fn prune_expired() {
let store = SessionStore::new(Duration::from_millis(1));
let _ = store.create("2025-11-25".into());
std::thread::sleep(Duration::from_millis(10));
let pruned = store.prune_expired();
assert_eq!(pruned, 1);
assert_eq!(store.active_count(), 0);
}
// Each version listed in SUPPORTED_PROTOCOL_VERSIONS is accepted.
#[test]
fn validate_protocol_version_supported() {
assert!(validate_protocol_version("2025-11-25").is_ok());
assert!(validate_protocol_version("2025-03-26").is_ok());
assert!(validate_protocol_version("2024-11-05").is_ok());
}
// Unknown and empty version strings are rejected.
#[test]
fn validate_protocol_version_unsupported() {
assert!(validate_protocol_version("1999-01-01").is_err());
assert!(validate_protocol_version("").is_err());
}
// An empty origin is accepted even when the allowlist is empty (strict mode).
#[test]
fn origin_empty_allowed() {
assert!(validate_origin("", &[]).is_ok());
}
// A "*" entry accepts any origin.
#[test]
fn origin_wildcard_allows_any() {
assert!(validate_origin("http://evil.com", &["*".into()]).is_ok());
}
// An empty allowlist rejects every non-empty origin.
#[test]
fn origin_strict_rejects_all() {
assert!(validate_origin("http://localhost", &[]).is_err());
}
// An exact match against an allowlist entry is accepted.
#[test]
fn origin_matched() {
let allowed = vec!["http://localhost:8090".into()];
assert!(validate_origin("http://localhost:8090", &allowed).is_ok());
}
// An origin missing from a non-empty allowlist is rejected.
#[test]
fn origin_not_matched() {
let allowed = vec!["http://localhost:8090".into()];
assert!(validate_origin("http://evil.com", &allowed).is_err());
}
// Any entry of a multi-entry allowlist may match; others are still rejected.
#[test]
fn origin_multiple_allowed() {
let allowed = vec![
"http://localhost:8090".into(),
"http://localhost:3000".into(),
];
assert!(validate_origin("http://localhost:3000", &allowed).is_ok());
assert!(validate_origin("http://other.com", &allowed).is_err());
}
// A thread panics while holding the write lock, which poisons it. The store
// must keep working afterwards because its accessors recover the guard from
// the PoisonError via `into_inner` instead of unwrapping.
#[test]
fn session_store_survives_poisoned_lock() {
use std::sync::Arc;
let store = Arc::new(SessionStore::default());
let store2 = Arc::clone(&store);
let _ = std::thread::spawn(move || {
let _guard = store2.sessions.write().unwrap();
panic!("intentional panic to poison lock");
})
.join();
let id = store.create("2025-11-25".into());
assert_eq!(store.active_count(), 1);
assert!(store.validate(&id).is_some());
assert!(store.remove(&id));
assert_eq!(store.active_count(), 0);
}
}