Skip to main content

plexus_substrate/
mcp_session.rs

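//! SQLite-backed MCP session manager that persists session IDs across restarts.
//!
//! This module provides a [`SessionManager`] implementation that records each
//! session's ID and timestamps in SQLite. The workers that actually serve a
//! session live only in memory and are not restored after a restart: a persisted
//! ID without a live worker is treated as stale and removed by `has_session`, so
//! the client has to start a new session.
//!
//! Sessions not seen for longer than the configured maximum age (30 days by
//! default) are deleted on startup.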
7
8use std::{
9    collections::HashMap,
10    path::PathBuf,
11    time::Duration,
12};
13
14use futures::Stream;
15use sqlx::{
16    sqlite::{SqliteConnectOptions, SqlitePool},
17    ConnectOptions,
18};
19use thiserror::Error;
20use tokio::sync::RwLock;
21use tokio_stream::wrappers::ReceiverStream;
22
23use rmcp::{
24    model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
25    transport::{
26        WorkerTransport,
27        common::server_side_http::{SessionId, ServerSseMessage, session_id},
28        streamable_http_server::session::{
29            SessionManager,
30            local::{
31                LocalSessionWorker, LocalSessionHandle, SessionConfig,
32                SessionError, create_local_session, EventIdParseError,
33            },
34        },
35    },
36};
37
/// Default value of `SqliteSessionConfig::max_session_age`: 30 days.
///
/// Sessions whose last activity is older than this are deleted by cleanup.
pub const DEFAULT_SESSION_MAX_AGE: Duration = Duration::from_secs(60 * 60 * 24 * 30);
40
41/// Configuration for SQLite session storage
42#[derive(Debug, Clone)]
43pub struct SqliteSessionConfig {
44    /// Path to SQLite database
45    pub db_path: PathBuf,
46    /// Session worker configuration
47    pub session_config: SessionConfig,
48    /// Maximum age for sessions before cleanup (default: 30 days)
49    pub max_session_age: Duration,
50}
51
52impl Default for SqliteSessionConfig {
53    fn default() -> Self {
54        Self {
55            db_path: PathBuf::from("mcp_sessions.db"),
56            session_config: SessionConfig::default(),
57            max_session_age: DEFAULT_SESSION_MAX_AGE,
58        }
59    }
60}
61
62/// Error types for SQLite session manager
63#[derive(Debug, Error)]
64pub enum SqliteSessionError {
65    #[error("Session not found: {0}")]
66    SessionNotFound(SessionId),
67    #[error("Session error: {0}")]
68    SessionError(#[from] SessionError),
69    #[error("Invalid event id: {0}")]
70    InvalidEventId(#[from] EventIdParseError),
71    #[error("Database error: {0}")]
72    DatabaseError(String),
73}
74
75/// SQLite-backed session manager
76///
77/// Persists session IDs to SQLite so clients can reconnect after server restart.
78/// The actual session workers are created on-demand, but session identity persists.
79pub struct SqliteSessionManager {
80    pool: SqlitePool,
81    /// In-memory session handles (runtime state)
82    sessions: RwLock<HashMap<SessionId, LocalSessionHandle>>,
83    session_config: SessionConfig,
84    /// Maximum age for sessions before cleanup
85    max_session_age: Duration,
86}
87
88impl SqliteSessionManager {
89    /// Create a new SQLite session manager
90    pub async fn new(config: SqliteSessionConfig) -> Result<Self, SqliteSessionError> {
91        let db_url = format!("sqlite:{}?mode=rwc", config.db_path.display());
92        let connect_options: SqliteConnectOptions = db_url
93            .parse()
94            .map_err(|e| SqliteSessionError::DatabaseError(format!("Failed to parse DB URL: {}", e)))?;
95        let connect_options = connect_options.disable_statement_logging();
96
97        let pool = SqlitePool::connect_with(connect_options.clone())
98            .await
99            .map_err(|e| SqliteSessionError::DatabaseError(format!("Failed to connect: {}", e)))?;
100
101        let manager = Self {
102            pool,
103            sessions: RwLock::new(HashMap::new()),
104            session_config: config.session_config,
105            max_session_age: config.max_session_age,
106        };
107
108        manager.run_migrations().await?;
109
110        // Clean up old sessions on startup
111        let cleaned = manager.cleanup_old_sessions().await?;
112        if cleaned > 0 {
113            tracing::info!(count = cleaned, "Cleaned up old MCP sessions");
114        }
115
116        // Log persisted sessions (for debugging)
117        let persisted = manager.count_persisted_sessions().await?;
118        if persisted > 0 {
119            tracing::info!(
120                count = persisted,
121                "Found persisted MCP sessions (clients will need to reconnect)"
122            );
123        }
124
125        Ok(manager)
126    }
127
128    /// Count persisted sessions in database
129    async fn count_persisted_sessions(&self) -> Result<usize, SqliteSessionError> {
130        let row = sqlx::query("SELECT COUNT(*) as count FROM mcp_sessions")
131            .fetch_one(&self.pool)
132            .await
133            .map_err(|e| SqliteSessionError::DatabaseError(format!("Failed to count sessions: {}", e)))?;
134
135        let count: i64 = sqlx::Row::get(&row, "count");
136        Ok(count as usize)
137    }
138
139
140    /// Clean up sessions older than max_session_age
141    ///
142    /// Returns the number of sessions cleaned up
143    pub async fn cleanup_old_sessions(&self) -> Result<usize, SqliteSessionError> {
144        let cutoff = std::time::SystemTime::now()
145            .duration_since(std::time::UNIX_EPOCH)
146            .unwrap()
147            .as_secs() as i64
148            - self.max_session_age.as_secs() as i64;
149
150        let result = sqlx::query("DELETE FROM mcp_sessions WHERE last_seen_at < ?")
151            .bind(cutoff)
152            .execute(&self.pool)
153            .await
154            .map_err(|e| SqliteSessionError::DatabaseError(format!("Failed to cleanup sessions: {}", e)))?;
155
156        Ok(result.rows_affected() as usize)
157    }
158
159    /// Run database migrations
160    async fn run_migrations(&self) -> Result<(), SqliteSessionError> {
161        sqlx::query(
162            r#"
163            CREATE TABLE IF NOT EXISTS mcp_sessions (
164                id TEXT PRIMARY KEY,
165                created_at INTEGER NOT NULL,
166                last_seen_at INTEGER NOT NULL
167            );
168
169            CREATE TABLE IF NOT EXISTS mcp_session_cache (
170                id INTEGER PRIMARY KEY AUTOINCREMENT,
171                session_id TEXT NOT NULL,
172                event_id TEXT NOT NULL,
173                message TEXT NOT NULL,
174                created_at INTEGER NOT NULL,
175                FOREIGN KEY (session_id) REFERENCES mcp_sessions(id) ON DELETE CASCADE
176            );
177
178            CREATE INDEX IF NOT EXISTS idx_session_cache_session ON mcp_session_cache(session_id);
179            CREATE INDEX IF NOT EXISTS idx_session_cache_event ON mcp_session_cache(session_id, event_id);
180            "#,
181        )
182        .execute(&self.pool)
183        .await
184        .map_err(|e| SqliteSessionError::DatabaseError(format!("Migration failed: {}", e)))?;
185
186        Ok(())
187    }
188
189    /// Record a session in the database
190    async fn persist_session(&self, id: &SessionId) -> Result<(), SqliteSessionError> {
191        let now = std::time::SystemTime::now()
192            .duration_since(std::time::UNIX_EPOCH)
193            .unwrap()
194            .as_secs() as i64;
195
196        sqlx::query(
197            "INSERT OR REPLACE INTO mcp_sessions (id, created_at, last_seen_at) VALUES (?, ?, ?)",
198        )
199        .bind(id.as_ref())
200        .bind(now)
201        .bind(now)
202        .execute(&self.pool)
203        .await
204        .map_err(|e| SqliteSessionError::DatabaseError(format!("Failed to persist session: {}", e)))?;
205
206        Ok(())
207    }
208
209    /// Update last seen timestamp
210    async fn touch_session(&self, id: &SessionId) -> Result<(), SqliteSessionError> {
211        let now = std::time::SystemTime::now()
212            .duration_since(std::time::UNIX_EPOCH)
213            .unwrap()
214            .as_secs() as i64;
215
216        sqlx::query("UPDATE mcp_sessions SET last_seen_at = ? WHERE id = ?")
217            .bind(now)
218            .bind(id.as_ref())
219            .execute(&self.pool)
220            .await
221            .map_err(|e| SqliteSessionError::DatabaseError(format!("Failed to touch session: {}", e)))?;
222
223        Ok(())
224    }
225
226    /// Check if a session exists in the database
227    async fn session_exists_in_db(&self, id: &SessionId) -> Result<bool, SqliteSessionError> {
228        let row = sqlx::query("SELECT 1 FROM mcp_sessions WHERE id = ?")
229            .bind(id.as_ref())
230            .fetch_optional(&self.pool)
231            .await
232            .map_err(|e| SqliteSessionError::DatabaseError(format!("Failed to check session: {}", e)))?;
233
234        Ok(row.is_some())
235    }
236
237    /// Remove a session from the database
238    async fn remove_session_from_db(&self, id: &SessionId) -> Result<(), SqliteSessionError> {
239        sqlx::query("DELETE FROM mcp_sessions WHERE id = ?")
240            .bind(id.as_ref())
241            .execute(&self.pool)
242            .await
243            .map_err(|e| SqliteSessionError::DatabaseError(format!("Failed to remove session: {}", e)))?;
244
245        Ok(())
246    }
247
248    /// Recreate a session worker for a known session ID (for reconnection after restart)
249    async fn recreate_session(
250        &self,
251        id: SessionId,
252    ) -> Result<WorkerTransport<LocalSessionWorker>, SqliteSessionError> {
253        let (handle, worker) = create_local_session(id.clone(), self.session_config.clone());
254        self.sessions.write().await.insert(id.clone(), handle);
255        self.touch_session(&id).await?;
256        Ok(WorkerTransport::spawn(worker))
257    }
258}
259
260impl SessionManager for SqliteSessionManager {
261    type Error = SqliteSessionError;
262    type Transport = WorkerTransport<LocalSessionWorker>;
263
264    async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
265        let id = session_id();
266        let (handle, worker) = create_local_session(id.clone(), self.session_config.clone());
267
268        // Persist to database
269        self.persist_session(&id).await?;
270
271        // Store in memory
272        self.sessions.write().await.insert(id.clone(), handle);
273
274        tracing::info!(session_id = ?id, "Created new persistent MCP session");
275        Ok((id, WorkerTransport::spawn(worker)))
276    }
277
278    async fn initialize_session(
279        &self,
280        id: &SessionId,
281        message: ClientJsonRpcMessage,
282    ) -> Result<ServerJsonRpcMessage, Self::Error> {
283        // Check if session exists in memory
284        let sessions = self.sessions.read().await;
285        if let Some(handle) = sessions.get(id) {
286            let response = handle.initialize(message).await?;
287            return Ok(response);
288        }
289        drop(sessions);
290
291        // Check if session exists in database (reconnection case)
292        if self.session_exists_in_db(id).await? {
293            tracing::info!(session_id = ?id, "Reconnecting to persisted MCP session");
294            // Note: For reconnection, the transport would need to be re-established
295            // This is a limitation - we can track the session but can't fully restore state
296            return Err(SqliteSessionError::SessionNotFound(id.clone()));
297        }
298
299        Err(SqliteSessionError::SessionNotFound(id.clone()))
300    }
301
302    async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
303        // Only return true if the session worker is active in memory
304        // Workers can't be restored without handler connection (rmcp limitation)
305        if self.sessions.read().await.contains_key(id) {
306            return Ok(true);
307        }
308
309        // Session in DB but no active worker - remove stale entry and return false
310        // Client will get 401 and should reconnect with fresh session
311        if self.session_exists_in_db(id).await? {
312            tracing::info!(session_id = ?id, "Removing stale session from DB (no active worker)");
313            self.remove_session_from_db(id).await.ok();
314        }
315
316        Ok(false)
317    }
318
319    async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
320        // Remove from memory
321        let mut sessions = self.sessions.write().await;
322        if let Some(handle) = sessions.remove(id) {
323            handle.close().await?;
324        }
325
326        // Remove from database
327        self.remove_session_from_db(id).await?;
328
329        tracing::info!(session_id = ?id, "Closed MCP session");
330        Ok(())
331    }
332
333    async fn create_stream(
334        &self,
335        id: &SessionId,
336        message: ClientJsonRpcMessage,
337    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
338        let sessions = self.sessions.read().await;
339        let handle = sessions
340            .get(id)
341            .ok_or(SqliteSessionError::SessionNotFound(id.clone()))?;
342
343        let receiver = handle.establish_request_wise_channel().await?;
344        handle
345            .push_message(message, receiver.http_request_id)
346            .await?;
347
348        self.touch_session(id).await.ok(); // Best effort
349        Ok(ReceiverStream::new(receiver.inner))
350    }
351
352    async fn create_standalone_stream(
353        &self,
354        id: &SessionId,
355    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
356        let sessions = self.sessions.read().await;
357        let handle = sessions
358            .get(id)
359            .ok_or(SqliteSessionError::SessionNotFound(id.clone()))?;
360
361        let receiver = handle.establish_common_channel().await?;
362        self.touch_session(id).await.ok(); // Best effort
363        Ok(ReceiverStream::new(receiver.inner))
364    }
365
366    async fn resume(
367        &self,
368        id: &SessionId,
369        last_event_id: String,
370    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
371        // Check memory first
372        {
373            let sessions = self.sessions.read().await;
374            if let Some(handle) = sessions.get(id) {
375                let receiver = handle.resume(last_event_id.parse()?).await?;
376                self.touch_session(id).await.ok();
377                return Ok(ReceiverStream::new(receiver.inner));
378            }
379        }
380
381        // Check if this is a reconnection after restart
382        if self.session_exists_in_db(id).await? {
383            tracing::info!(session_id = ?id, last_event_id, "Session reconnection attempt - recreating worker");
384            // Recreate the session worker
385            let _transport = self.recreate_session(id.clone()).await?;
386
387            // Now try to get the handle and resume
388            let sessions = self.sessions.read().await;
389            if let Some(handle) = sessions.get(id) {
390                let receiver = handle.resume(last_event_id.parse()?).await?;
391                return Ok(ReceiverStream::new(receiver.inner));
392            }
393        }
394
395        Err(SqliteSessionError::SessionNotFound(id.clone()))
396    }
397
398    async fn accept_message(
399        &self,
400        id: &SessionId,
401        message: ClientJsonRpcMessage,
402    ) -> Result<(), Self::Error> {
403        let sessions = self.sessions.read().await;
404        let handle = sessions
405            .get(id)
406            .ok_or(SqliteSessionError::SessionNotFound(id.clone()))?;
407
408        handle.push_message(message, None).await?;
409        self.touch_session(id).await.ok(); // Best effort
410        Ok(())
411    }
412}