mcp_session/lib.rs
1//! Bounded session management for MCP servers.
2//!
3//! This crate provides [`BoundedSessionManager`], a wrapper around rmcp's
4//! [`LocalSessionManager`] that enforces:
5//!
6//! - **Maximum concurrent sessions** with FIFO eviction of the oldest session
7//! when the limit is reached.
8//! - **Optional rate limiting** on session creation via a sliding-window
9//! counter.
10//! - **Idle timeout** via rmcp's `keep_alive` configuration (passed through).
11//!
12//! # Quick start
13//!
14//! ```rust,no_run
15//! use std::sync::Arc;
16//! use mcp_session::BoundedSessionManager;
17//! use mcp_session::SessionConfig;
18//!
19//! let mut session_config = SessionConfig::default();
20//! session_config.keep_alive = Some(std::time::Duration::from_secs(4 * 60 * 60));
21//!
22//! let manager = Arc::new(
23//! BoundedSessionManager::new(session_config, 100)
24//! .with_rate_limit(10, std::time::Duration::from_secs(60)),
25//! );
26//!
27//! // Pass `manager` to `StreamableHttpService::new(factory, manager, config)`
28//! ```
29
30#![warn(missing_docs)]
31
32use std::collections::VecDeque;
33use std::time::{Duration, Instant};
34
35use futures_core::Stream;
36use rmcp::model::{ClientJsonRpcMessage, ServerJsonRpcMessage};
37use rmcp::transport::{
38 streamable_http_server::session::{
39 local::{LocalSessionManager, LocalSessionManagerError, LocalSessionWorker},
40 ServerSseMessage, SessionManager,
41 },
42 WorkerTransport,
43};
44
45// Re-export types that consumers need so they don't have to depend on rmcp
46// directly for basic session configuration.
47pub use rmcp::transport::streamable_http_server::session::local::SessionConfig;
48pub use rmcp::transport::streamable_http_server::session::SessionId;
49
50// ---------------------------------------------------------------------------
51// Error type
52// ---------------------------------------------------------------------------
53
54/// Errors returned by [`BoundedSessionManager`].
55#[derive(Debug, thiserror::Error)]
56pub enum BoundedSessionError {
57 /// Propagated from the inner [`LocalSessionManager`].
58 #[error(transparent)]
59 Inner(#[from] LocalSessionManagerError),
60 /// Session creation was rejected because the rate limit was exceeded.
61 #[error("session creation rate limit exceeded")]
62 RateLimited,
63}
64
65// ---------------------------------------------------------------------------
66// RateLimiter
67// ---------------------------------------------------------------------------
68
69/// Sliding-window rate limiter for session creation.
70struct RateLimiter {
71 max_creates: usize,
72 window: Duration,
73 tracker: tokio::sync::Mutex<VecDeque<Instant>>,
74}
75
76impl RateLimiter {
77 fn new(max_creates: usize, window: Duration) -> Self {
78 Self {
79 max_creates,
80 window,
81 tracker: tokio::sync::Mutex::new(VecDeque::new()),
82 }
83 }
84
85 /// Reserve a slot. Returns `Err(BoundedSessionError::RateLimited)` if the
86 /// window is full. On success, the caller **must** eventually call
87 /// [`rollback`](Self::rollback) if session creation subsequently fails, to
88 /// return the slot.
89 async fn reserve(&self) -> Result<Instant, BoundedSessionError> {
90 let mut tracker = self.tracker.lock().await;
91 let now = Instant::now();
92 // Prune entries that have fallen outside the window.
93 while tracker
94 .front()
95 .is_some_and(|t| now.duration_since(*t) > self.window)
96 {
97 tracker.pop_front();
98 }
99 if tracker.len() >= self.max_creates {
100 return Err(BoundedSessionError::RateLimited);
101 }
102 tracker.push_back(now);
103 Ok(now)
104 }
105
106 /// Roll back a previously reserved slot (identified by its timestamp) when
107 /// session creation fails after [`reserve`](Self::reserve) succeeds.
108 async fn rollback(&self, reserved_at: Instant) {
109 let mut tracker = self.tracker.lock().await;
110 // Find and remove exactly one entry matching the reserved timestamp.
111 // Under concurrent interleaving, a later reservation may have been
112 // pushed after ours, so the entry is not necessarily at the back.
113 if let Some(pos) = tracker.iter().rposition(|t| *t == reserved_at) {
114 tracker.remove(pos);
115 }
116 }
117}
118
119// ---------------------------------------------------------------------------
120// BoundedSessionManager
121// ---------------------------------------------------------------------------
122
123/// Wraps [`LocalSessionManager`] and limits the number of concurrent sessions.
124///
125/// When the limit is reached, the oldest session (by creation order) is closed
126/// before the new one is created. This prevents unbounded memory growth when
127/// many clients connect without explicitly closing their sessions.
128///
129/// Optionally, a rate limit can be applied to session creation via
130/// [`BoundedSessionManager::with_rate_limit`].
131///
132/// # Concurrency note
133///
134/// Under concurrent session creation, the live count may transiently exceed
135/// `max_sessions` by at most the number of concurrent callers. The limit is
136/// best-effort under contention; use a semaphore if exact enforcement is
137/// required.
138pub struct BoundedSessionManager {
139 inner: LocalSessionManager,
140 max_sessions: usize,
141 /// Tracks session IDs in creation order for FIFO eviction.
142 creation_order: tokio::sync::Mutex<VecDeque<SessionId>>,
143 /// Optional sliding-window rate limiter for session creation.
144 rate_limiter: Option<RateLimiter>,
145}
146
147impl BoundedSessionManager {
148 /// Create a new `BoundedSessionManager`.
149 ///
150 /// * `session_config` — passed through to the inner [`LocalSessionManager`].
151 /// * `max_sessions` — maximum number of concurrent sessions. When this
152 /// limit is reached, the oldest session is evicted before creating a new
153 /// one. Must be at least 1.
154 ///
155 /// # Panics
156 ///
157 /// Panics if `max_sessions` is 0.
158 pub fn new(session_config: SessionConfig, max_sessions: usize) -> Self {
159 assert!(max_sessions >= 1, "max_sessions must be at least 1, got 0");
160 let mut inner = LocalSessionManager::default();
161 inner.session_config = session_config;
162 Self {
163 inner,
164 max_sessions,
165 creation_order: tokio::sync::Mutex::new(VecDeque::new()),
166 rate_limiter: None,
167 }
168 }
169
170 /// Configure a rate limit on session creation.
171 ///
172 /// At most `max_creates` sessions may be created within any rolling
173 /// `window` duration. If exceeded, [`BoundedSessionError::RateLimited`] is
174 /// returned and no eviction is performed.
175 ///
176 /// # Panics
177 ///
178 /// Panics if `max_creates` is 0. Pass no rate limit instead of 0 — a limit
179 /// of zero would silently block all session creation.
180 #[must_use]
181 pub fn with_rate_limit(mut self, max_creates: usize, window: Duration) -> Self {
182 assert!(
183 max_creates >= 1,
184 "max_creates must be at least 1; pass no rate limit instead of 0"
185 );
186 self.rate_limiter = Some(RateLimiter::new(max_creates, window));
187 self
188 }
189}
190
191impl SessionManager for BoundedSessionManager {
192 type Error = BoundedSessionError;
193 type Transport = WorkerTransport<LocalSessionWorker>;
194
195 async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
196 // ----------------------------------------------------------------
197 // Critical section 1: rate-limit check.
198 // ----------------------------------------------------------------
199 let rate_reserved_at = if let Some(ref limiter) = self.rate_limiter {
200 Some(limiter.reserve().await?)
201 } else {
202 None
203 };
204
205 // ----------------------------------------------------------------
206 // Determine eviction candidate (short critical section).
207 // ----------------------------------------------------------------
208 let evict_candidate = {
209 let order = self.creation_order.lock().await;
210 // Use the inner sessions map for the authoritative live count so
211 // that expired sessions (which are removed from inner but remain
212 // in the deque) do not consume a capacity slot.
213 let live_count = self.inner.sessions.read().await.len();
214 if live_count >= self.max_sessions {
215 order.front().cloned()
216 } else {
217 None
218 }
219 };
220
221 // ----------------------------------------------------------------
222 // Evict oldest (no lock held across this await).
223 // ----------------------------------------------------------------
224 if let Some(ref oldest) = evict_candidate {
225 // Ignore errors: the session may have already expired.
226 let _ = self.inner.close_session(oldest).await;
227 }
228
229 // ----------------------------------------------------------------
230 // Create new session (no lock held across this await).
231 // ----------------------------------------------------------------
232 let result = self.inner.create_session().await;
233
234 // Roll back the rate-limit slot if creation failed.
235 if result.is_err() {
236 if let (Some(ref limiter), Some(reserved_at)) = (&self.rate_limiter, rate_reserved_at) {
237 limiter.rollback(reserved_at).await;
238 }
239 }
240
241 let (id, transport) = result?;
242
243 // ----------------------------------------------------------------
244 // Critical section 2: update the creation-order deque.
245 // ----------------------------------------------------------------
246 {
247 let mut order = self.creation_order.lock().await;
248 // Remove the evicted entry if it's still present.
249 if let Some(ref oldest) = evict_candidate {
250 order.retain(|s| s != oldest);
251 }
252 // Prune any deque entries for sessions that are no longer live
253 // (handles the drift caused by keep_alive expiry: finding #4).
254 let live_ids: std::collections::HashSet<_> = {
255 // Snapshot the live session IDs without holding two locks
256 // simultaneously (creation_order lock is already held here;
257 // sessions is a RwLock so a read lock is fine).
258 self.inner.sessions.read().await.keys().cloned().collect()
259 };
260 order.retain(|s| live_ids.contains(s));
261 order.push_back(id.clone());
262 }
263
264 Ok((id, transport))
265 }
266
267 async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
268 self.inner.close_session(id).await?;
269 let mut order = self.creation_order.lock().await;
270 order.retain(|s| s != id);
271 Ok(())
272 }
273
274 async fn initialize_session(
275 &self,
276 id: &SessionId,
277 message: ClientJsonRpcMessage,
278 ) -> Result<ServerJsonRpcMessage, Self::Error> {
279 self.inner
280 .initialize_session(id, message)
281 .await
282 .map_err(Into::into)
283 }
284
285 async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
286 self.inner.has_session(id).await.map_err(Into::into)
287 }
288
289 async fn create_stream(
290 &self,
291 id: &SessionId,
292 message: ClientJsonRpcMessage,
293 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + Sync + 'static, Self::Error> {
294 self.inner
295 .create_stream(id, message)
296 .await
297 .map_err(Into::into)
298 }
299
300 async fn accept_message(
301 &self,
302 id: &SessionId,
303 message: ClientJsonRpcMessage,
304 ) -> Result<(), Self::Error> {
305 self.inner
306 .accept_message(id, message)
307 .await
308 .map_err(Into::into)
309 }
310
311 async fn create_standalone_stream(
312 &self,
313 id: &SessionId,
314 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + Sync + 'static, Self::Error> {
315 self.inner
316 .create_standalone_stream(id)
317 .await
318 .map_err(Into::into)
319 }
320
321 async fn resume(
322 &self,
323 id: &SessionId,
324 last_event_id: String,
325 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + Sync + 'static, Self::Error> {
326 self.inner
327 .resume(id, last_event_id)
328 .await
329 .map_err(Into::into)
330 }
331}