1use std::{
9 collections::HashMap,
10 path::PathBuf,
11 time::Duration,
12};
13
14use futures::Stream;
15use sqlx::{
16 sqlite::{SqliteConnectOptions, SqlitePool},
17 ConnectOptions,
18};
19use thiserror::Error;
20use tokio::sync::RwLock;
21use tokio_stream::wrappers::ReceiverStream;
22
23use rmcp::{
24 model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
25 transport::{
26 WorkerTransport,
27 common::server_side_http::{SessionId, ServerSseMessage, session_id},
28 streamable_http_server::session::{
29 SessionManager,
30 local::{
31 LocalSessionWorker, LocalSessionHandle, SessionConfig,
32 SessionError, create_local_session, EventIdParseError,
33 },
34 },
35 },
36};
37
/// Default value of [`SqliteSessionConfig::max_session_age`]: thirty days.
///
/// Sessions whose last activity is older than this are purged by
/// `SqliteSessionManager::cleanup_old_sessions` when the default configuration is used.
pub const DEFAULT_SESSION_MAX_AGE: Duration = Duration::from_secs(60 * 60 * 24 * 30);
40
41#[derive(Debug, Clone)]
43pub struct SqliteSessionConfig {
44 pub db_path: PathBuf,
46 pub session_config: SessionConfig,
48 pub max_session_age: Duration,
50}
51
52impl Default for SqliteSessionConfig {
53 fn default() -> Self {
54 Self {
55 db_path: PathBuf::from("mcp_sessions.db"),
56 session_config: SessionConfig::default(),
57 max_session_age: DEFAULT_SESSION_MAX_AGE,
58 }
59 }
60}
61
62#[derive(Debug, Error)]
64pub enum SqliteSessionError {
65 #[error("Session not found: {0}")]
66 SessionNotFound(SessionId),
67 #[error("Session error: {0}")]
68 SessionError(#[from] SessionError),
69 #[error("Invalid event id: {0}")]
70 InvalidEventId(#[from] EventIdParseError),
71 #[error("Database error: {0}")]
72 DatabaseError(String),
73}
74
/// `SessionManager` that keeps session workers in memory and records each session's id
/// and timestamps in a SQLite database.
///
/// Only bookkeeping lives in SQLite: a row without a matching entry in `sessions` has no
/// worker behind it and cannot serve requests.
pub struct SqliteSessionManager {
    /// Connection pool for the session database.
    pool: SqlitePool,
    /// Handles for sessions whose worker was created by this process, keyed by id.
    sessions: RwLock<HashMap<SessionId, LocalSessionHandle>>,
    /// Passed (cloned) to `create_local_session` for every new worker.
    session_config: SessionConfig,
    /// Idle time after which `cleanup_old_sessions` deletes a session row.
    max_session_age: Duration,
}
87
88impl SqliteSessionManager {
89 pub async fn new(config: SqliteSessionConfig) -> Result<Self, SqliteSessionError> {
91 let db_url = format!("sqlite:{}?mode=rwc", config.db_path.display());
92 let connect_options: SqliteConnectOptions = db_url
93 .parse()
94 .map_err(|e| SqliteSessionError::DatabaseError(format!("Failed to parse DB URL: {}", e)))?;
95 let connect_options = connect_options.disable_statement_logging();
96
97 let pool = SqlitePool::connect_with(connect_options.clone())
98 .await
99 .map_err(|e| SqliteSessionError::DatabaseError(format!("Failed to connect: {}", e)))?;
100
101 let manager = Self {
102 pool,
103 sessions: RwLock::new(HashMap::new()),
104 session_config: config.session_config,
105 max_session_age: config.max_session_age,
106 };
107
108 manager.run_migrations().await?;
109
110 let cleaned = manager.cleanup_old_sessions().await?;
112 if cleaned > 0 {
113 tracing::info!(count = cleaned, "Cleaned up old MCP sessions");
114 }
115
116 let persisted = manager.count_persisted_sessions().await?;
118 if persisted > 0 {
119 tracing::info!(
120 count = persisted,
121 "Found persisted MCP sessions (clients will need to reconnect)"
122 );
123 }
124
125 Ok(manager)
126 }
127
128 async fn count_persisted_sessions(&self) -> Result<usize, SqliteSessionError> {
130 let row = sqlx::query("SELECT COUNT(*) as count FROM mcp_sessions")
131 .fetch_one(&self.pool)
132 .await
133 .map_err(|e| SqliteSessionError::DatabaseError(format!("Failed to count sessions: {}", e)))?;
134
135 let count: i64 = sqlx::Row::get(&row, "count");
136 Ok(count as usize)
137 }
138
139
140 pub async fn cleanup_old_sessions(&self) -> Result<usize, SqliteSessionError> {
144 let cutoff = std::time::SystemTime::now()
145 .duration_since(std::time::UNIX_EPOCH)
146 .unwrap()
147 .as_secs() as i64
148 - self.max_session_age.as_secs() as i64;
149
150 let result = sqlx::query("DELETE FROM mcp_sessions WHERE last_seen_at < ?")
151 .bind(cutoff)
152 .execute(&self.pool)
153 .await
154 .map_err(|e| SqliteSessionError::DatabaseError(format!("Failed to cleanup sessions: {}", e)))?;
155
156 Ok(result.rows_affected() as usize)
157 }
158
159 async fn run_migrations(&self) -> Result<(), SqliteSessionError> {
161 sqlx::query(
162 r#"
163 CREATE TABLE IF NOT EXISTS mcp_sessions (
164 id TEXT PRIMARY KEY,
165 created_at INTEGER NOT NULL,
166 last_seen_at INTEGER NOT NULL
167 );
168
169 CREATE TABLE IF NOT EXISTS mcp_session_cache (
170 id INTEGER PRIMARY KEY AUTOINCREMENT,
171 session_id TEXT NOT NULL,
172 event_id TEXT NOT NULL,
173 message TEXT NOT NULL,
174 created_at INTEGER NOT NULL,
175 FOREIGN KEY (session_id) REFERENCES mcp_sessions(id) ON DELETE CASCADE
176 );
177
178 CREATE INDEX IF NOT EXISTS idx_session_cache_session ON mcp_session_cache(session_id);
179 CREATE INDEX IF NOT EXISTS idx_session_cache_event ON mcp_session_cache(session_id, event_id);
180 "#,
181 )
182 .execute(&self.pool)
183 .await
184 .map_err(|e| SqliteSessionError::DatabaseError(format!("Migration failed: {}", e)))?;
185
186 Ok(())
187 }
188
189 async fn persist_session(&self, id: &SessionId) -> Result<(), SqliteSessionError> {
191 let now = std::time::SystemTime::now()
192 .duration_since(std::time::UNIX_EPOCH)
193 .unwrap()
194 .as_secs() as i64;
195
196 sqlx::query(
197 "INSERT OR REPLACE INTO mcp_sessions (id, created_at, last_seen_at) VALUES (?, ?, ?)",
198 )
199 .bind(id.as_ref())
200 .bind(now)
201 .bind(now)
202 .execute(&self.pool)
203 .await
204 .map_err(|e| SqliteSessionError::DatabaseError(format!("Failed to persist session: {}", e)))?;
205
206 Ok(())
207 }
208
209 async fn touch_session(&self, id: &SessionId) -> Result<(), SqliteSessionError> {
211 let now = std::time::SystemTime::now()
212 .duration_since(std::time::UNIX_EPOCH)
213 .unwrap()
214 .as_secs() as i64;
215
216 sqlx::query("UPDATE mcp_sessions SET last_seen_at = ? WHERE id = ?")
217 .bind(now)
218 .bind(id.as_ref())
219 .execute(&self.pool)
220 .await
221 .map_err(|e| SqliteSessionError::DatabaseError(format!("Failed to touch session: {}", e)))?;
222
223 Ok(())
224 }
225
226 async fn session_exists_in_db(&self, id: &SessionId) -> Result<bool, SqliteSessionError> {
228 let row = sqlx::query("SELECT 1 FROM mcp_sessions WHERE id = ?")
229 .bind(id.as_ref())
230 .fetch_optional(&self.pool)
231 .await
232 .map_err(|e| SqliteSessionError::DatabaseError(format!("Failed to check session: {}", e)))?;
233
234 Ok(row.is_some())
235 }
236
237 async fn remove_session_from_db(&self, id: &SessionId) -> Result<(), SqliteSessionError> {
239 sqlx::query("DELETE FROM mcp_sessions WHERE id = ?")
240 .bind(id.as_ref())
241 .execute(&self.pool)
242 .await
243 .map_err(|e| SqliteSessionError::DatabaseError(format!("Failed to remove session: {}", e)))?;
244
245 Ok(())
246 }
247
248 async fn recreate_session(
250 &self,
251 id: SessionId,
252 ) -> Result<WorkerTransport<LocalSessionWorker>, SqliteSessionError> {
253 let (handle, worker) = create_local_session(id.clone(), self.session_config.clone());
254 self.sessions.write().await.insert(id.clone(), handle);
255 self.touch_session(&id).await?;
256 Ok(WorkerTransport::spawn(worker))
257 }
258}
259
260impl SessionManager for SqliteSessionManager {
261 type Error = SqliteSessionError;
262 type Transport = WorkerTransport<LocalSessionWorker>;
263
264 async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
265 let id = session_id();
266 let (handle, worker) = create_local_session(id.clone(), self.session_config.clone());
267
268 self.persist_session(&id).await?;
270
271 self.sessions.write().await.insert(id.clone(), handle);
273
274 tracing::info!(session_id = ?id, "Created new persistent MCP session");
275 Ok((id, WorkerTransport::spawn(worker)))
276 }
277
278 async fn initialize_session(
279 &self,
280 id: &SessionId,
281 message: ClientJsonRpcMessage,
282 ) -> Result<ServerJsonRpcMessage, Self::Error> {
283 let sessions = self.sessions.read().await;
285 if let Some(handle) = sessions.get(id) {
286 let response = handle.initialize(message).await?;
287 return Ok(response);
288 }
289 drop(sessions);
290
291 if self.session_exists_in_db(id).await? {
293 tracing::info!(session_id = ?id, "Reconnecting to persisted MCP session");
294 return Err(SqliteSessionError::SessionNotFound(id.clone()));
297 }
298
299 Err(SqliteSessionError::SessionNotFound(id.clone()))
300 }
301
302 async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
303 if self.sessions.read().await.contains_key(id) {
306 return Ok(true);
307 }
308
309 if self.session_exists_in_db(id).await? {
312 tracing::info!(session_id = ?id, "Removing stale session from DB (no active worker)");
313 self.remove_session_from_db(id).await.ok();
314 }
315
316 Ok(false)
317 }
318
319 async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
320 let mut sessions = self.sessions.write().await;
322 if let Some(handle) = sessions.remove(id) {
323 handle.close().await?;
324 }
325
326 self.remove_session_from_db(id).await?;
328
329 tracing::info!(session_id = ?id, "Closed MCP session");
330 Ok(())
331 }
332
333 async fn create_stream(
334 &self,
335 id: &SessionId,
336 message: ClientJsonRpcMessage,
337 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
338 let sessions = self.sessions.read().await;
339 let handle = sessions
340 .get(id)
341 .ok_or(SqliteSessionError::SessionNotFound(id.clone()))?;
342
343 let receiver = handle.establish_request_wise_channel().await?;
344 handle
345 .push_message(message, receiver.http_request_id)
346 .await?;
347
348 self.touch_session(id).await.ok(); Ok(ReceiverStream::new(receiver.inner))
350 }
351
352 async fn create_standalone_stream(
353 &self,
354 id: &SessionId,
355 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
356 let sessions = self.sessions.read().await;
357 let handle = sessions
358 .get(id)
359 .ok_or(SqliteSessionError::SessionNotFound(id.clone()))?;
360
361 let receiver = handle.establish_common_channel().await?;
362 self.touch_session(id).await.ok(); Ok(ReceiverStream::new(receiver.inner))
364 }
365
366 async fn resume(
367 &self,
368 id: &SessionId,
369 last_event_id: String,
370 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
371 {
373 let sessions = self.sessions.read().await;
374 if let Some(handle) = sessions.get(id) {
375 let receiver = handle.resume(last_event_id.parse()?).await?;
376 self.touch_session(id).await.ok();
377 return Ok(ReceiverStream::new(receiver.inner));
378 }
379 }
380
381 if self.session_exists_in_db(id).await? {
383 tracing::info!(session_id = ?id, last_event_id, "Session reconnection attempt - recreating worker");
384 let _transport = self.recreate_session(id.clone()).await?;
386
387 let sessions = self.sessions.read().await;
389 if let Some(handle) = sessions.get(id) {
390 let receiver = handle.resume(last_event_id.parse()?).await?;
391 return Ok(ReceiverStream::new(receiver.inner));
392 }
393 }
394
395 Err(SqliteSessionError::SessionNotFound(id.clone()))
396 }
397
398 async fn accept_message(
399 &self,
400 id: &SessionId,
401 message: ClientJsonRpcMessage,
402 ) -> Result<(), Self::Error> {
403 let sessions = self.sessions.read().await;
404 let handle = sessions
405 .get(id)
406 .ok_or(SqliteSessionError::SessionNotFound(id.clone()))?;
407
408 handle.push_message(message, None).await?;
409 self.touch_session(id).await.ok(); Ok(())
411 }
412}