use std::collections::HashMap;
use std::time::Instant;
use tokio::sync::RwLock;
/// Client-side state for a single MCP server connection.
///
/// Tracks the session id recorded for the server (if any), whether
/// initialization has been marked, and when the session was last used so idle
/// sessions can be detected with [`McpSession::is_stale`].
#[derive(Debug, Clone)]
pub struct McpSession {
    /// Session id last recorded for this server; `None` until one is set.
    pub session_id: Option<String>,
    /// Instant of the most recent recorded activity on this session.
    pub last_activity: Instant,
    /// URL of the server this session belongs to.
    pub server_url: String,
    /// Whether [`McpSession::mark_initialized`] has been called.
    pub initialized: bool,
}

impl McpSession {
    /// Creates a fresh, uninitialized session for `server_url` with no session id.
    ///
    /// The activity timestamp is set to the current instant.
    pub fn new(server_url: impl Into<String>) -> Self {
        Self {
            session_id: None,
            last_activity: Instant::now(),
            server_url: server_url.into(),
            initialized: false,
        }
    }

    /// Records a session id and refreshes the activity timestamp.
    ///
    /// Passing `None` leaves any previously stored id untouched; only the
    /// timestamp is refreshed. Callers can therefore forward an optional id
    /// without accidentally clearing a known one.
    pub fn update_session_id(&mut self, session_id: Option<String>) {
        if let Some(id) = session_id {
            self.session_id = Some(id);
        }
        self.touch();
    }

    /// Marks the session as initialized and refreshes the activity timestamp.
    pub fn mark_initialized(&mut self) {
        self.initialized = true;
        self.touch();
    }

    /// Returns `true` if strictly more than `max_idle_secs` seconds have elapsed
    /// since the last recorded activity.
    ///
    /// The comparison uses the full-precision elapsed duration rather than
    /// whole seconds. A session idle for `max_idle_secs` plus a fraction of a
    /// second is therefore reported as stale. With `max_idle_secs == 0`, any
    /// elapsed time counts as stale.
    #[must_use]
    pub fn is_stale(&self, max_idle_secs: u64) -> bool {
        self.last_activity.elapsed() > std::time::Duration::from_secs(max_idle_secs)
    }

    /// Refreshes the activity timestamp to the current instant.
    pub fn touch(&mut self) {
        self.last_activity = Instant::now();
    }
}
/// Registry of [`McpSession`]s keyed by server name.
///
/// The map is guarded by an async (tokio) `RwLock`, so the manager's methods
/// take `&self` and can be called from async tasks. Sessions idle for longer
/// than the configured timeout are treated as expired.
pub struct McpSessionManager {
    sessions: RwLock<HashMap<String, McpSession>>,
    /// Idle timeout, in seconds, after which a session is considered stale.
    max_idle_secs: u64,
}

impl McpSessionManager {
    /// Idle timeout used by [`McpSessionManager::new`] and `Default` (30 minutes).
    pub const DEFAULT_MAX_IDLE_SECS: u64 = 1800;

    /// Creates an empty manager using [`Self::DEFAULT_MAX_IDLE_SECS`].
    pub fn new() -> Self {
        Self::with_idle_timeout(Self::DEFAULT_MAX_IDLE_SECS)
    }

    /// Creates an empty manager whose sessions expire after `max_idle_secs` of inactivity.
    pub fn with_idle_timeout(max_idle_secs: u64) -> Self {
        Self {
            sessions: RwLock::new(HashMap::new()),
            max_idle_secs,
        }
    }

    /// Returns a snapshot (clone) of the session for `server_name`, creating one if needed.
    ///
    /// An existing session is reused only when it is not stale and was created
    /// for the same `server_url`. Otherwise it is replaced by a fresh session
    /// with no session id, not initialized. Reusing a session does not refresh
    /// its activity timestamp; use [`Self::touch`] or the other mutators for that.
    pub async fn get_or_create(&self, server_name: &str, server_url: &str) -> McpSession {
        let mut sessions = self.sessions.write().await;
        if let Some(existing) = sessions.get(server_name) {
            // A cached session id was obtained from `existing.server_url`.
            // If the caller now targets a different URL, that id (and the
            // stored URL) should not be handed back, so treat it like expiry.
            let reusable =
                existing.server_url == server_url && !existing.is_stale(self.max_idle_secs);
            if reusable {
                return existing.clone();
            }
        }
        let fresh = McpSession::new(server_url);
        sessions.insert(server_name.to_string(), fresh.clone());
        fresh
    }

    /// Returns the stored session id for `server_name`.
    ///
    /// Returns `None` if there is no session or it has no id yet.
    pub async fn get_session_id(&self, server_name: &str) -> Option<String> {
        let sessions = self.sessions.read().await;
        sessions.get(server_name).and_then(|s| s.session_id.clone())
    }

    /// Records a session id on an existing session via [`McpSession::update_session_id`].
    ///
    /// A `None` id keeps the current one but refreshes the timestamp. This is
    /// a no-op if no session exists for `server_name`.
    pub async fn update_session_id(&self, server_name: &str, session_id: Option<String>) {
        self.modify(server_name, |session| session.update_session_id(session_id))
            .await;
    }

    /// Marks an existing session as initialized; no-op if it does not exist.
    pub async fn mark_initialized(&self, server_name: &str) {
        self.modify(server_name, McpSession::mark_initialized).await;
    }

    /// Returns `true` only if a session exists for `server_name` and is marked initialized.
    ///
    /// Staleness is not checked here.
    pub async fn is_initialized(&self, server_name: &str) -> bool {
        let sessions = self.sessions.read().await;
        matches!(sessions.get(server_name), Some(s) if s.initialized)
    }

    /// Refreshes the activity timestamp of an existing session; no-op if it does not exist.
    pub async fn touch(&self, server_name: &str) {
        self.modify(server_name, McpSession::touch).await;
    }

    /// Forgets the session for `server_name`, if any.
    pub async fn terminate(&self, server_name: &str) {
        let mut sessions = self.sessions.write().await;
        sessions.remove(server_name);
    }

    /// Returns the names of all tracked servers, in unspecified order.
    ///
    /// Stale sessions are included until [`Self::cleanup_stale`] or
    /// [`Self::get_or_create`] replaces or removes them.
    pub async fn active_servers(&self) -> Vec<String> {
        let sessions = self.sessions.read().await;
        sessions.keys().cloned().collect()
    }

    /// Removes every stale session and returns how many were removed.
    pub async fn cleanup_stale(&self) -> usize {
        let mut sessions = self.sessions.write().await;
        let before_len = sessions.len();
        sessions.retain(|_, session| !session.is_stale(self.max_idle_secs));
        // `retain` only ever shrinks the map, so this cannot underflow.
        before_len - sessions.len()
    }

    /// Applies `f` to the session for `server_name` under the write lock, if it exists.
    async fn modify(&self, server_name: &str, f: impl FnOnce(&mut McpSession)) {
        let mut sessions = self.sessions.write().await;
        if let Some(session) = sessions.get_mut(server_name) {
            f(session);
        }
    }
}

impl Default for McpSessionManager {
    /// Equivalent to [`McpSessionManager::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for [`McpSession`] and [`McpSessionManager`].

    use super::*;
    use std::time::{Duration, Instant};

    const EXAMPLE_URL: &str = "https://mcp.example.com";
    const NOTION: &str = "notion";
    const NOTION_URL: &str = "https://mcp.notion.com";

    #[test]
    fn test_session_creation() {
        let fresh = McpSession::new(EXAMPLE_URL);
        assert_eq!(fresh.server_url, EXAMPLE_URL);
        assert!(!fresh.initialized);
        assert_eq!(fresh.session_id, None);
    }

    #[test]
    fn test_session_update() {
        let mut s = McpSession::new(EXAMPLE_URL);
        s.update_session_id(Some(String::from("session-123")));
        assert_eq!(s.session_id.as_deref(), Some("session-123"));

        s.mark_initialized();
        assert!(s.initialized);
    }

    #[test]
    fn test_session_staleness() {
        let mut s = McpSession::new(EXAMPLE_URL);
        assert!(!s.is_stale(1800));

        // Pretend the last activity happened ten seconds ago.
        s.last_activity = Instant::now() - Duration::from_secs(10);
        assert!(s.is_stale(5));
        assert!(!s.is_stale(15));
    }

    #[tokio::test]
    async fn test_session_manager_get_or_create() {
        let mgr = McpSessionManager::new();

        let before = mgr.get_or_create(NOTION, NOTION_URL).await;
        assert_eq!(before.session_id, None);

        mgr.update_session_id(NOTION, Some("session-abc".into())).await;

        let after = mgr.get_or_create(NOTION, NOTION_URL).await;
        assert_eq!(after.session_id.as_deref(), Some("session-abc"));
    }

    #[tokio::test]
    async fn test_session_manager_terminate() {
        let mgr = McpSessionManager::new();
        mgr.get_or_create(NOTION, NOTION_URL).await;
        mgr.update_session_id(NOTION, Some("session-123".into())).await;

        mgr.terminate(NOTION).await;

        let recreated = mgr.get_or_create(NOTION, NOTION_URL).await;
        assert_eq!(recreated.session_id, None);
    }

    #[tokio::test]
    async fn test_session_manager_initialization() {
        let mgr = McpSessionManager::new();
        mgr.get_or_create(NOTION, NOTION_URL).await;

        let initialized_before = mgr.is_initialized(NOTION).await;
        assert!(!initialized_before);

        mgr.mark_initialized(NOTION).await;

        let initialized_after = mgr.is_initialized(NOTION).await;
        assert!(initialized_after);
    }

    #[tokio::test]
    async fn test_active_servers() {
        let mgr = McpSessionManager::new();
        for &(name, url) in &[(NOTION, NOTION_URL), ("github", "https://mcp.github.com")] {
            mgr.get_or_create(name, url).await;
        }

        let mut names = mgr.active_servers().await;
        names.sort();
        assert_eq!(names, vec!["github".to_string(), "notion".to_string()]);
    }
}