// mcp_session/lib.rs
//! Bounded session management for MCP servers.
//!
//! This crate provides [`BoundedSessionManager`], a wrapper around rmcp's
//! [`LocalSessionManager`] that enforces:
//!
//! - **Maximum concurrent sessions** with FIFO eviction of the oldest session
//!   when the limit is reached.
//! - **Optional rate limiting** on session creation via a sliding-window
//!   counter.
//! - **Idle timeout** via rmcp's `keep_alive` configuration (passed through).
//!
//! # Quick start
//!
//! ```rust,ignore
//! use std::sync::Arc;
//! use mcp_session::BoundedSessionManager;
//! use mcp_session::SessionConfig;
//!
//! let manager = Arc::new(
//!     BoundedSessionManager::new(
//!         SessionConfig {
//!             keep_alive: Some(std::time::Duration::from_secs(4 * 60 * 60)),
//!             ..Default::default()
//!         },
//!         100, // max concurrent sessions
//!     )
//!     .with_rate_limit(10, std::time::Duration::from_secs(60)),
//! );
//!
//! // Pass `manager` to `StreamableHttpService::new(factory, manager, config)`
//! ```

#![warn(missing_docs)]

35use std::collections::VecDeque;
36use std::time::{Duration, Instant};
37
38use futures_core::Stream;
39use rmcp::model::{ClientJsonRpcMessage, ServerJsonRpcMessage};
40use rmcp::transport::{
41 streamable_http_server::session::{
42 local::{LocalSessionManager, LocalSessionManagerError, LocalSessionWorker},
43 ServerSseMessage, SessionManager,
44 },
45 WorkerTransport,
46};
47
48// Re-export types that consumers need so they don't have to depend on rmcp
49// directly for basic session configuration.
50pub use rmcp::transport::streamable_http_server::session::local::SessionConfig;
51pub use rmcp::transport::streamable_http_server::session::SessionId;
52
53// ---------------------------------------------------------------------------
54// Error type
55// ---------------------------------------------------------------------------
56
/// Errors returned by [`BoundedSessionManager`].
#[derive(Debug, thiserror::Error)]
pub enum BoundedSessionError {
    /// Propagated from the inner [`LocalSessionManager`]; the error message
    /// is forwarded unchanged (`#[error(transparent)]`).
    #[error(transparent)]
    Inner(#[from] LocalSessionManagerError),
    /// Session creation was rejected because the rate limit was exceeded.
    /// Only produced when a limit was configured via
    /// [`BoundedSessionManager::with_rate_limit`]; retrying after the
    /// sliding window has moved on may succeed.
    #[error("session creation rate limit exceeded")]
    RateLimited,
}
67
68// ---------------------------------------------------------------------------
69// RateLimiter
70// ---------------------------------------------------------------------------
71
/// Sliding-window rate limiter for session creation.
///
/// Timestamps of successful reservations are kept in chronological order in
/// `tracker`; entries older than `window` are pruned lazily on each call to
/// `reserve`.
struct RateLimiter {
    /// Maximum number of reservations allowed within any rolling `window`.
    max_creates: usize,
    /// Length of the sliding window.
    window: Duration,
    /// Reservation timestamps, oldest at the front (push-back order).
    tracker: tokio::sync::Mutex<VecDeque<Instant>>,
}
78
79impl RateLimiter {
80 fn new(max_creates: usize, window: Duration) -> Self {
81 Self {
82 max_creates,
83 window,
84 tracker: tokio::sync::Mutex::new(VecDeque::new()),
85 }
86 }
87
88 /// Reserve a slot. Returns `Err(BoundedSessionError::RateLimited)` if the
89 /// window is full. On success, the caller **must** eventually call
90 /// [`rollback`](Self::rollback) if session creation subsequently fails, to
91 /// return the slot.
92 async fn reserve(&self) -> Result<Instant, BoundedSessionError> {
93 let mut tracker = self.tracker.lock().await;
94 let now = Instant::now();
95 // Prune entries that have fallen outside the window.
96 while tracker
97 .front()
98 .is_some_and(|t| now.duration_since(*t) > self.window)
99 {
100 tracker.pop_front();
101 }
102 if tracker.len() >= self.max_creates {
103 return Err(BoundedSessionError::RateLimited);
104 }
105 tracker.push_back(now);
106 Ok(now)
107 }
108
109 /// Roll back a previously reserved slot (identified by its timestamp) when
110 /// session creation fails after [`reserve`](Self::reserve) succeeds.
111 async fn rollback(&self, reserved_at: Instant) {
112 let mut tracker = self.tracker.lock().await;
113 // Find and remove exactly one entry matching the reserved timestamp.
114 // Under concurrent interleaving, a later reservation may have been
115 // pushed after ours, so the entry is not necessarily at the back.
116 if let Some(pos) = tracker.iter().rposition(|t| *t == reserved_at) {
117 tracker.remove(pos);
118 }
119 }
120}
121
122// ---------------------------------------------------------------------------
123// BoundedSessionManager
124// ---------------------------------------------------------------------------
125
/// Wraps [`LocalSessionManager`] and limits the number of concurrent sessions.
///
/// When the limit is reached, the oldest session (by creation order) is closed
/// before the new one is created. This prevents unbounded memory growth when
/// many clients connect without explicitly closing their sessions.
///
/// Optionally, a rate limit can be applied to session creation via
/// [`BoundedSessionManager::with_rate_limit`].
///
/// # Concurrency note
///
/// Under concurrent session creation, the live count may transiently exceed
/// `max_sessions` by at most the number of concurrent callers. The limit is
/// best-effort under contention; use a semaphore if exact enforcement is
/// required.
pub struct BoundedSessionManager {
    /// The wrapped rmcp manager; its `sessions` map is treated as the
    /// authoritative set of live sessions.
    inner: LocalSessionManager,
    /// Maximum number of concurrent sessions (always >= 1, enforced in `new`).
    max_sessions: usize,
    /// Tracks session IDs in creation order for FIFO eviction.
    /// May lag behind `inner` when sessions expire via `keep_alive`; stale
    /// entries are pruned lazily during `create_session`.
    creation_order: tokio::sync::Mutex<VecDeque<SessionId>>,
    /// Optional sliding-window rate limiter for session creation.
    /// `None` means creation is not rate limited.
    rate_limiter: Option<RateLimiter>,
}
149
150impl BoundedSessionManager {
151 /// Create a new `BoundedSessionManager`.
152 ///
153 /// * `session_config` — passed through to the inner [`LocalSessionManager`].
154 /// * `max_sessions` — maximum number of concurrent sessions. When this
155 /// limit is reached, the oldest session is evicted before creating a new
156 /// one. Must be at least 1.
157 ///
158 /// # Panics
159 ///
160 /// Panics if `max_sessions` is 0.
161 pub fn new(session_config: SessionConfig, max_sessions: usize) -> Self {
162 assert!(max_sessions >= 1, "max_sessions must be at least 1, got 0");
163 Self {
164 inner: LocalSessionManager {
165 session_config,
166 ..Default::default()
167 },
168 max_sessions,
169 creation_order: tokio::sync::Mutex::new(VecDeque::new()),
170 rate_limiter: None,
171 }
172 }
173
174 /// Configure a rate limit on session creation.
175 ///
176 /// At most `max_creates` sessions may be created within any rolling
177 /// `window` duration. If exceeded, [`BoundedSessionError::RateLimited`] is
178 /// returned and no eviction is performed.
179 ///
180 /// # Panics
181 ///
182 /// Panics if `max_creates` is 0. Pass no rate limit instead of 0 — a limit
183 /// of zero would silently block all session creation.
184 #[must_use]
185 pub fn with_rate_limit(mut self, max_creates: usize, window: Duration) -> Self {
186 assert!(
187 max_creates >= 1,
188 "max_creates must be at least 1; pass no rate limit instead of 0"
189 );
190 self.rate_limiter = Some(RateLimiter::new(max_creates, window));
191 self
192 }
193}
194
195impl SessionManager for BoundedSessionManager {
196 type Error = BoundedSessionError;
197 type Transport = WorkerTransport<LocalSessionWorker>;
198
199 async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
200 // ----------------------------------------------------------------
201 // Critical section 1: rate-limit check.
202 // ----------------------------------------------------------------
203 let rate_reserved_at = if let Some(ref limiter) = self.rate_limiter {
204 Some(limiter.reserve().await?)
205 } else {
206 None
207 };
208
209 // ----------------------------------------------------------------
210 // Determine eviction candidate (short critical section).
211 // ----------------------------------------------------------------
212 let evict_candidate = {
213 let order = self.creation_order.lock().await;
214 // Use the inner sessions map for the authoritative live count so
215 // that expired sessions (which are removed from inner but remain
216 // in the deque) do not consume a capacity slot.
217 let live_count = self.inner.sessions.read().await.len();
218 if live_count >= self.max_sessions {
219 order.front().cloned()
220 } else {
221 None
222 }
223 };
224
225 // ----------------------------------------------------------------
226 // Evict oldest (no lock held across this await).
227 // ----------------------------------------------------------------
228 if let Some(ref oldest) = evict_candidate {
229 // Ignore errors: the session may have already expired.
230 let _ = self.inner.close_session(oldest).await;
231 }
232
233 // ----------------------------------------------------------------
234 // Create new session (no lock held across this await).
235 // ----------------------------------------------------------------
236 let result = self.inner.create_session().await;
237
238 // Roll back the rate-limit slot if creation failed.
239 if result.is_err() {
240 if let (Some(ref limiter), Some(reserved_at)) = (&self.rate_limiter, rate_reserved_at) {
241 limiter.rollback(reserved_at).await;
242 }
243 }
244
245 let (id, transport) = result?;
246
247 // ----------------------------------------------------------------
248 // Critical section 2: update the creation-order deque.
249 // ----------------------------------------------------------------
250 {
251 let mut order = self.creation_order.lock().await;
252 // Remove the evicted entry if it's still present.
253 if let Some(ref oldest) = evict_candidate {
254 order.retain(|s| s != oldest);
255 }
256 // Prune any deque entries for sessions that are no longer live
257 // (handles the drift caused by keep_alive expiry: finding #4).
258 let live_ids: std::collections::HashSet<_> = {
259 // Snapshot the live session IDs without holding two locks
260 // simultaneously (creation_order lock is already held here;
261 // sessions is a RwLock so a read lock is fine).
262 self.inner.sessions.read().await.keys().cloned().collect()
263 };
264 order.retain(|s| live_ids.contains(s));
265 order.push_back(id.clone());
266 }
267
268 Ok((id, transport))
269 }
270
271 async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
272 self.inner.close_session(id).await?;
273 let mut order = self.creation_order.lock().await;
274 order.retain(|s| s != id);
275 Ok(())
276 }
277
278 async fn initialize_session(
279 &self,
280 id: &SessionId,
281 message: ClientJsonRpcMessage,
282 ) -> Result<ServerJsonRpcMessage, Self::Error> {
283 self.inner
284 .initialize_session(id, message)
285 .await
286 .map_err(Into::into)
287 }
288
289 async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
290 self.inner.has_session(id).await.map_err(Into::into)
291 }
292
293 async fn create_stream(
294 &self,
295 id: &SessionId,
296 message: ClientJsonRpcMessage,
297 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + Sync + 'static, Self::Error> {
298 self.inner
299 .create_stream(id, message)
300 .await
301 .map_err(Into::into)
302 }
303
304 async fn accept_message(
305 &self,
306 id: &SessionId,
307 message: ClientJsonRpcMessage,
308 ) -> Result<(), Self::Error> {
309 self.inner
310 .accept_message(id, message)
311 .await
312 .map_err(Into::into)
313 }
314
315 async fn create_standalone_stream(
316 &self,
317 id: &SessionId,
318 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + Sync + 'static, Self::Error> {
319 self.inner
320 .create_standalone_stream(id)
321 .await
322 .map_err(Into::into)
323 }
324
325 async fn resume(
326 &self,
327 id: &SessionId,
328 last_event_id: String,
329 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + Sync + 'static, Self::Error> {
330 self.inner
331 .resume(id, last_event_id)
332 .await
333 .map_err(Into::into)
334 }
335}