// mcp_session/lib.rs
1//! Bounded session management for MCP servers.
2//!
3//! This crate provides [`BoundedSessionManager`], a wrapper around rmcp's
4//! [`LocalSessionManager`] that enforces:
5//!
6//! - **Maximum concurrent sessions** with FIFO eviction of the oldest session
7//!   when the limit is reached.
8//! - **Optional rate limiting** on session creation via a sliding-window
9//!   counter.
10//! - **Idle timeout** via rmcp's `keep_alive` configuration (passed through).
11//!
12//! # Quick start
13//!
14//! ```rust,no_run
15//! use std::sync::Arc;
16//! use mcp_session::BoundedSessionManager;
17//! use mcp_session::SessionConfig;
18//!
19//! let mut session_config = SessionConfig::default();
20//! session_config.keep_alive = Some(std::time::Duration::from_secs(4 * 60 * 60));
21//!
22//! let manager = Arc::new(
23//!     BoundedSessionManager::new(session_config, 100)
24//!         .with_rate_limit(10, std::time::Duration::from_secs(60)),
25//! );
26//!
27//! // Pass `manager` to `StreamableHttpService::new(factory, manager, config)`
28//! ```
29
30#![warn(missing_docs)]
31
32use std::collections::VecDeque;
33use std::time::{Duration, Instant};
34
35use futures_core::Stream;
36use rmcp::model::{ClientJsonRpcMessage, ServerJsonRpcMessage};
37use rmcp::transport::{
38    streamable_http_server::session::{
39        local::{LocalSessionManager, LocalSessionManagerError, LocalSessionWorker},
40        ServerSseMessage, SessionManager,
41    },
42    WorkerTransport,
43};
44
45// Re-export types that consumers need so they don't have to depend on rmcp
46// directly for basic session configuration.
47pub use rmcp::transport::streamable_http_server::session::local::SessionConfig;
48pub use rmcp::transport::streamable_http_server::session::SessionId;
49
50// ---------------------------------------------------------------------------
51// Error type
52// ---------------------------------------------------------------------------
53
/// Errors returned by [`BoundedSessionManager`].
#[derive(Debug, thiserror::Error)]
pub enum BoundedSessionError {
    /// Propagated unchanged from the inner [`LocalSessionManager`];
    /// `#[error(transparent)]` forwards its `Display`/`source` as-is.
    #[error(transparent)]
    Inner(#[from] LocalSessionManagerError),
    /// Session creation was rejected because the sliding-window rate limit
    /// was exceeded. The caller may retry once the window advances.
    #[error("session creation rate limit exceeded")]
    RateLimited,
}
64
65// ---------------------------------------------------------------------------
66// RateLimiter
67// ---------------------------------------------------------------------------
68
/// Sliding-window rate limiter for session creation.
///
/// Records one timestamp per successful reservation; timestamps older than
/// `window` are pruned lazily on the next [`reserve`](Self::reserve) call.
struct RateLimiter {
    /// Maximum number of reservations allowed within any rolling `window`.
    max_creates: usize,
    /// Length of the sliding window.
    window: Duration,
    /// Reservation timestamps, oldest first, guarded by an async mutex so
    /// the lock can be taken from async context without blocking the runtime.
    tracker: tokio::sync::Mutex<VecDeque<Instant>>,
}
75
76impl RateLimiter {
77    fn new(max_creates: usize, window: Duration) -> Self {
78        Self {
79            max_creates,
80            window,
81            tracker: tokio::sync::Mutex::new(VecDeque::new()),
82        }
83    }
84
85    /// Reserve a slot. Returns `Err(BoundedSessionError::RateLimited)` if the
86    /// window is full. On success, the caller **must** eventually call
87    /// [`rollback`](Self::rollback) if session creation subsequently fails, to
88    /// return the slot.
89    async fn reserve(&self) -> Result<Instant, BoundedSessionError> {
90        let mut tracker = self.tracker.lock().await;
91        let now = Instant::now();
92        // Prune entries that have fallen outside the window.
93        while tracker
94            .front()
95            .is_some_and(|t| now.duration_since(*t) > self.window)
96        {
97            tracker.pop_front();
98        }
99        if tracker.len() >= self.max_creates {
100            return Err(BoundedSessionError::RateLimited);
101        }
102        tracker.push_back(now);
103        Ok(now)
104    }
105
106    /// Roll back a previously reserved slot (identified by its timestamp) when
107    /// session creation fails after [`reserve`](Self::reserve) succeeds.
108    async fn rollback(&self, reserved_at: Instant) {
109        let mut tracker = self.tracker.lock().await;
110        // Find and remove exactly one entry matching the reserved timestamp.
111        // Under concurrent interleaving, a later reservation may have been
112        // pushed after ours, so the entry is not necessarily at the back.
113        if let Some(pos) = tracker.iter().rposition(|t| *t == reserved_at) {
114            tracker.remove(pos);
115        }
116    }
117}
118
119// ---------------------------------------------------------------------------
120// BoundedSessionManager
121// ---------------------------------------------------------------------------
122
/// Wraps [`LocalSessionManager`] and limits the number of concurrent sessions.
///
/// When the limit is reached, the oldest session (by creation order) is closed
/// before the new one is created. This prevents unbounded memory growth when
/// many clients connect without explicitly closing their sessions.
///
/// Optionally, a rate limit can be applied to session creation via
/// [`BoundedSessionManager::with_rate_limit`].
///
/// # Concurrency note
///
/// Under concurrent session creation, the live count may transiently exceed
/// `max_sessions` by at most the number of concurrent callers. The limit is
/// best-effort under contention; use a semaphore if exact enforcement is
/// required.
pub struct BoundedSessionManager {
    /// The wrapped rmcp session manager that owns the actual sessions.
    inner: LocalSessionManager,
    /// Maximum number of concurrent sessions (enforced best-effort).
    max_sessions: usize,
    /// Tracks session IDs in creation order for FIFO eviction. May lag behind
    /// the inner sessions map (e.g. after keep-alive expiry); stale entries
    /// are pruned during `create_session`.
    creation_order: tokio::sync::Mutex<VecDeque<SessionId>>,
    /// Optional sliding-window rate limiter for session creation.
    rate_limiter: Option<RateLimiter>,
}
146
147impl BoundedSessionManager {
148    /// Create a new `BoundedSessionManager`.
149    ///
150    /// * `session_config` — passed through to the inner [`LocalSessionManager`].
151    /// * `max_sessions`   — maximum number of concurrent sessions. When this
152    ///   limit is reached, the oldest session is evicted before creating a new
153    ///   one. Must be at least 1.
154    ///
155    /// # Panics
156    ///
157    /// Panics if `max_sessions` is 0.
158    pub fn new(session_config: SessionConfig, max_sessions: usize) -> Self {
159        assert!(max_sessions >= 1, "max_sessions must be at least 1, got 0");
160        let mut inner = LocalSessionManager::default();
161        inner.session_config = session_config;
162        Self {
163            inner,
164            max_sessions,
165            creation_order: tokio::sync::Mutex::new(VecDeque::new()),
166            rate_limiter: None,
167        }
168    }
169
170    /// Configure a rate limit on session creation.
171    ///
172    /// At most `max_creates` sessions may be created within any rolling
173    /// `window` duration. If exceeded, [`BoundedSessionError::RateLimited`] is
174    /// returned and no eviction is performed.
175    ///
176    /// # Panics
177    ///
178    /// Panics if `max_creates` is 0. Pass no rate limit instead of 0 — a limit
179    /// of zero would silently block all session creation.
180    #[must_use]
181    pub fn with_rate_limit(mut self, max_creates: usize, window: Duration) -> Self {
182        assert!(
183            max_creates >= 1,
184            "max_creates must be at least 1; pass no rate limit instead of 0"
185        );
186        self.rate_limiter = Some(RateLimiter::new(max_creates, window));
187        self
188    }
189}
190
impl SessionManager for BoundedSessionManager {
    type Error = BoundedSessionError;
    type Transport = WorkerTransport<LocalSessionWorker>;

    /// Create a new session, enforcing the optional creation rate limit and
    /// the `max_sessions` cap (FIFO-evicting the oldest session when full).
    ///
    /// Lock ordering: `creation_order` is always acquired before the inner
    /// `sessions` read lock — both in the eviction-candidate section and in
    /// critical section 2 below — so the two sections cannot deadlock
    /// against each other.
    async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
        // ----------------------------------------------------------------
        // Critical section 1: rate-limit check.
        // ----------------------------------------------------------------
        // Reserve a slot up front; if creation fails later, the slot is
        // rolled back so a failed attempt does not consume budget.
        let rate_reserved_at = if let Some(ref limiter) = self.rate_limiter {
            Some(limiter.reserve().await?)
        } else {
            None
        };

        // ----------------------------------------------------------------
        // Determine eviction candidate (short critical section).
        // ----------------------------------------------------------------
        let evict_candidate = {
            let order = self.creation_order.lock().await;
            // Use the inner sessions map for the authoritative live count so
            // that expired sessions (which are removed from inner but remain
            // in the deque) do not consume a capacity slot.
            let live_count = self.inner.sessions.read().await.len();
            if live_count >= self.max_sessions {
                order.front().cloned()
            } else {
                None
            }
        };

        // ----------------------------------------------------------------
        // Evict oldest (no lock held across this await).
        // ----------------------------------------------------------------
        if let Some(ref oldest) = evict_candidate {
            // Ignore errors: the session may have already expired.
            let _ = self.inner.close_session(oldest).await;
        }

        // ----------------------------------------------------------------
        // Create new session (no lock held across this await).
        // ----------------------------------------------------------------
        let result = self.inner.create_session().await;

        // Roll back the rate-limit slot if creation failed, returning the
        // reserved slot to the sliding window.
        if result.is_err() {
            if let (Some(ref limiter), Some(reserved_at)) = (&self.rate_limiter, rate_reserved_at) {
                limiter.rollback(reserved_at).await;
            }
        }

        let (id, transport) = result?;

        // ----------------------------------------------------------------
        // Critical section 2: update the creation-order deque.
        // ----------------------------------------------------------------
        {
            let mut order = self.creation_order.lock().await;
            // Remove the evicted entry if it's still present.
            if let Some(ref oldest) = evict_candidate {
                order.retain(|s| s != oldest);
            }
            // Prune any deque entries for sessions that are no longer live.
            // This handles the drift caused by keep_alive expiry: the inner
            // manager removes expired sessions on its own, but the deque
            // would otherwise keep their IDs forever.
            let live_ids: std::collections::HashSet<_> = {
                // NOTE: the creation_order lock IS held while taking the
                // sessions read lock here; this is safe because it matches
                // the lock order used when choosing the eviction candidate
                // above, and the read lock is held only for this snapshot.
                self.inner.sessions.read().await.keys().cloned().collect()
            };
            order.retain(|s| live_ids.contains(s));
            order.push_back(id.clone());
        }

        Ok((id, transport))
    }

    /// Close a session via the inner manager and drop its entry from the
    /// creation-order deque so it can no longer be picked for eviction.
    async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
        self.inner.close_session(id).await?;
        let mut order = self.creation_order.lock().await;
        order.retain(|s| s != id);
        Ok(())
    }

    /// Delegates to the inner [`LocalSessionManager`].
    async fn initialize_session(
        &self,
        id: &SessionId,
        message: ClientJsonRpcMessage,
    ) -> Result<ServerJsonRpcMessage, Self::Error> {
        self.inner
            .initialize_session(id, message)
            .await
            .map_err(Into::into)
    }

    /// Delegates to the inner [`LocalSessionManager`].
    async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
        self.inner.has_session(id).await.map_err(Into::into)
    }

    /// Delegates to the inner [`LocalSessionManager`].
    async fn create_stream(
        &self,
        id: &SessionId,
        message: ClientJsonRpcMessage,
    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + Sync + 'static, Self::Error> {
        self.inner
            .create_stream(id, message)
            .await
            .map_err(Into::into)
    }

    /// Delegates to the inner [`LocalSessionManager`].
    async fn accept_message(
        &self,
        id: &SessionId,
        message: ClientJsonRpcMessage,
    ) -> Result<(), Self::Error> {
        self.inner
            .accept_message(id, message)
            .await
            .map_err(Into::into)
    }

    /// Delegates to the inner [`LocalSessionManager`].
    async fn create_standalone_stream(
        &self,
        id: &SessionId,
    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + Sync + 'static, Self::Error> {
        self.inner
            .create_standalone_stream(id)
            .await
            .map_err(Into::into)
    }

    /// Delegates to the inner [`LocalSessionManager`].
    async fn resume(
        &self,
        id: &SessionId,
        last_event_id: String,
    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + Sync + 'static, Self::Error> {
        self.inner
            .resume(id, last_event_id)
            .await
            .map_err(Into::into)
    }
}