// mcp_session/lib.rs
1//! Bounded session management for MCP servers.
2//!
3//! This crate provides [`BoundedSessionManager`], a wrapper around rmcp's
4//! [`LocalSessionManager`] that enforces:
5//!
6//! - **Maximum concurrent sessions** with FIFO eviction of the oldest session
7//!   when the limit is reached.
8//! - **Optional rate limiting** on session creation via a sliding-window
9//!   counter.
10//! - **Idle timeout** via rmcp's `keep_alive` configuration (passed through).
11//!
12//! # Quick start
13//!
14//! ```rust,ignore
15//! use std::sync::Arc;
16//! use mcp_session::BoundedSessionManager;
17//! use mcp_session::SessionConfig;
18//!
19//! let manager = Arc::new(
20//!     BoundedSessionManager::new(
21//!         SessionConfig {
22//!             keep_alive: Some(std::time::Duration::from_secs(4 * 60 * 60)),
23//!             ..Default::default()
24//!         },
25//!         100, // max concurrent sessions
26//!     )
27//!     .with_rate_limit(10, std::time::Duration::from_secs(60)),
28//! );
29//!
30//! // Pass `manager` to `StreamableHttpService::new(factory, manager, config)`
31//! ```
32
33#![warn(missing_docs)]
34
35use std::collections::VecDeque;
36use std::time::{Duration, Instant};
37
38use futures_core::Stream;
39use rmcp::model::{ClientJsonRpcMessage, ServerJsonRpcMessage};
40use rmcp::transport::{
41    streamable_http_server::session::{
42        local::{LocalSessionManager, LocalSessionManagerError, LocalSessionWorker},
43        ServerSseMessage, SessionManager,
44    },
45    WorkerTransport,
46};
47
48// Re-export types that consumers need so they don't have to depend on rmcp
49// directly for basic session configuration.
50pub use rmcp::transport::streamable_http_server::session::local::SessionConfig;
51pub use rmcp::transport::streamable_http_server::session::SessionId;
52
53// ---------------------------------------------------------------------------
54// Error type
55// ---------------------------------------------------------------------------
56
/// Errors returned by [`BoundedSessionManager`].
#[derive(Debug, thiserror::Error)]
pub enum BoundedSessionError {
    /// Propagated unchanged from the inner [`LocalSessionManager`].
    #[error(transparent)]
    Inner(#[from] LocalSessionManagerError),
    /// Session creation was rejected because the rate limit was exceeded.
    /// Returned before any eviction is attempted, so a rate-limited call has
    /// no side effects on existing sessions.
    #[error("session creation rate limit exceeded")]
    RateLimited,
}
67
68// ---------------------------------------------------------------------------
69// RateLimiter
70// ---------------------------------------------------------------------------
71
/// Sliding-window rate limiter for session creation.
///
/// Each successful reservation records an [`Instant`] in `tracker`; entries
/// older than `window` are pruned on every reservation attempt, so the deque
/// always holds only the timestamps inside the current window.
struct RateLimiter {
    /// Maximum number of reservations allowed within any rolling `window`.
    max_creates: usize,
    /// Length of the sliding window.
    window: Duration,
    /// Timestamps of successful reservations, oldest first.
    tracker: tokio::sync::Mutex<VecDeque<Instant>>,
}
78
79impl RateLimiter {
80    fn new(max_creates: usize, window: Duration) -> Self {
81        Self {
82            max_creates,
83            window,
84            tracker: tokio::sync::Mutex::new(VecDeque::new()),
85        }
86    }
87
88    /// Reserve a slot. Returns `Err(BoundedSessionError::RateLimited)` if the
89    /// window is full. On success, the caller **must** eventually call
90    /// [`rollback`](Self::rollback) if session creation subsequently fails, to
91    /// return the slot.
92    async fn reserve(&self) -> Result<Instant, BoundedSessionError> {
93        let mut tracker = self.tracker.lock().await;
94        let now = Instant::now();
95        // Prune entries that have fallen outside the window.
96        while tracker
97            .front()
98            .is_some_and(|t| now.duration_since(*t) > self.window)
99        {
100            tracker.pop_front();
101        }
102        if tracker.len() >= self.max_creates {
103            return Err(BoundedSessionError::RateLimited);
104        }
105        tracker.push_back(now);
106        Ok(now)
107    }
108
109    /// Roll back a previously reserved slot (identified by its timestamp) when
110    /// session creation fails after [`reserve`](Self::reserve) succeeds.
111    async fn rollback(&self, reserved_at: Instant) {
112        let mut tracker = self.tracker.lock().await;
113        // Find and remove exactly one entry matching the reserved timestamp.
114        // Under concurrent interleaving, a later reservation may have been
115        // pushed after ours, so the entry is not necessarily at the back.
116        if let Some(pos) = tracker.iter().rposition(|t| *t == reserved_at) {
117            tracker.remove(pos);
118        }
119    }
120}
121
122// ---------------------------------------------------------------------------
123// BoundedSessionManager
124// ---------------------------------------------------------------------------
125
/// Wraps [`LocalSessionManager`] and limits the number of concurrent sessions.
///
/// When the limit is reached, the oldest session (by creation order) is closed
/// before the new one is created. This prevents unbounded memory growth when
/// many clients connect without explicitly closing their sessions.
///
/// Optionally, a rate limit can be applied to session creation via
/// [`BoundedSessionManager::with_rate_limit`].
///
/// # Concurrency note
///
/// Under concurrent session creation, the live count may transiently exceed
/// `max_sessions` by at most the number of concurrent callers. The limit is
/// best-effort under contention; use a semaphore if exact enforcement is
/// required.
pub struct BoundedSessionManager {
    /// The wrapped rmcp session manager; holds the authoritative session map.
    inner: LocalSessionManager,
    /// Maximum number of concurrent sessions (>= 1, enforced in `new`).
    max_sessions: usize,
    /// Tracks session IDs in creation order for FIFO eviction. May lag behind
    /// `inner` when sessions expire via keep_alive; pruned during creation.
    creation_order: tokio::sync::Mutex<VecDeque<SessionId>>,
    /// Optional sliding-window rate limiter for session creation.
    rate_limiter: Option<RateLimiter>,
}
149
150impl BoundedSessionManager {
151    /// Create a new `BoundedSessionManager`.
152    ///
153    /// * `session_config` — passed through to the inner [`LocalSessionManager`].
154    /// * `max_sessions`   — maximum number of concurrent sessions. When this
155    ///   limit is reached, the oldest session is evicted before creating a new
156    ///   one. Must be at least 1.
157    ///
158    /// # Panics
159    ///
160    /// Panics if `max_sessions` is 0.
161    pub fn new(session_config: SessionConfig, max_sessions: usize) -> Self {
162        assert!(max_sessions >= 1, "max_sessions must be at least 1, got 0");
163        Self {
164            inner: LocalSessionManager {
165                session_config,
166                ..Default::default()
167            },
168            max_sessions,
169            creation_order: tokio::sync::Mutex::new(VecDeque::new()),
170            rate_limiter: None,
171        }
172    }
173
174    /// Configure a rate limit on session creation.
175    ///
176    /// At most `max_creates` sessions may be created within any rolling
177    /// `window` duration. If exceeded, [`BoundedSessionError::RateLimited`] is
178    /// returned and no eviction is performed.
179    ///
180    /// # Panics
181    ///
182    /// Panics if `max_creates` is 0. Pass no rate limit instead of 0 — a limit
183    /// of zero would silently block all session creation.
184    #[must_use]
185    pub fn with_rate_limit(mut self, max_creates: usize, window: Duration) -> Self {
186        assert!(
187            max_creates >= 1,
188            "max_creates must be at least 1; pass no rate limit instead of 0"
189        );
190        self.rate_limiter = Some(RateLimiter::new(max_creates, window));
191        self
192    }
193}
194
195impl SessionManager for BoundedSessionManager {
196    type Error = BoundedSessionError;
197    type Transport = WorkerTransport<LocalSessionWorker>;
198
199    async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
200        // ----------------------------------------------------------------
201        // Critical section 1: rate-limit check.
202        // ----------------------------------------------------------------
203        let rate_reserved_at = if let Some(ref limiter) = self.rate_limiter {
204            Some(limiter.reserve().await?)
205        } else {
206            None
207        };
208
209        // ----------------------------------------------------------------
210        // Determine eviction candidate (short critical section).
211        // ----------------------------------------------------------------
212        let evict_candidate = {
213            let order = self.creation_order.lock().await;
214            // Use the inner sessions map for the authoritative live count so
215            // that expired sessions (which are removed from inner but remain
216            // in the deque) do not consume a capacity slot.
217            let live_count = self.inner.sessions.read().await.len();
218            if live_count >= self.max_sessions {
219                order.front().cloned()
220            } else {
221                None
222            }
223        };
224
225        // ----------------------------------------------------------------
226        // Evict oldest (no lock held across this await).
227        // ----------------------------------------------------------------
228        if let Some(ref oldest) = evict_candidate {
229            // Ignore errors: the session may have already expired.
230            let _ = self.inner.close_session(oldest).await;
231        }
232
233        // ----------------------------------------------------------------
234        // Create new session (no lock held across this await).
235        // ----------------------------------------------------------------
236        let result = self.inner.create_session().await;
237
238        // Roll back the rate-limit slot if creation failed.
239        if result.is_err() {
240            if let (Some(ref limiter), Some(reserved_at)) = (&self.rate_limiter, rate_reserved_at) {
241                limiter.rollback(reserved_at).await;
242            }
243        }
244
245        let (id, transport) = result?;
246
247        // ----------------------------------------------------------------
248        // Critical section 2: update the creation-order deque.
249        // ----------------------------------------------------------------
250        {
251            let mut order = self.creation_order.lock().await;
252            // Remove the evicted entry if it's still present.
253            if let Some(ref oldest) = evict_candidate {
254                order.retain(|s| s != oldest);
255            }
256            // Prune any deque entries for sessions that are no longer live
257            // (handles the drift caused by keep_alive expiry: finding #4).
258            let live_ids: std::collections::HashSet<_> = {
259                // Snapshot the live session IDs without holding two locks
260                // simultaneously (creation_order lock is already held here;
261                // sessions is a RwLock so a read lock is fine).
262                self.inner.sessions.read().await.keys().cloned().collect()
263            };
264            order.retain(|s| live_ids.contains(s));
265            order.push_back(id.clone());
266        }
267
268        Ok((id, transport))
269    }
270
271    async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
272        self.inner.close_session(id).await?;
273        let mut order = self.creation_order.lock().await;
274        order.retain(|s| s != id);
275        Ok(())
276    }
277
278    async fn initialize_session(
279        &self,
280        id: &SessionId,
281        message: ClientJsonRpcMessage,
282    ) -> Result<ServerJsonRpcMessage, Self::Error> {
283        self.inner
284            .initialize_session(id, message)
285            .await
286            .map_err(Into::into)
287    }
288
289    async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
290        self.inner.has_session(id).await.map_err(Into::into)
291    }
292
293    async fn create_stream(
294        &self,
295        id: &SessionId,
296        message: ClientJsonRpcMessage,
297    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + Sync + 'static, Self::Error> {
298        self.inner
299            .create_stream(id, message)
300            .await
301            .map_err(Into::into)
302    }
303
304    async fn accept_message(
305        &self,
306        id: &SessionId,
307        message: ClientJsonRpcMessage,
308    ) -> Result<(), Self::Error> {
309        self.inner
310            .accept_message(id, message)
311            .await
312            .map_err(Into::into)
313    }
314
315    async fn create_standalone_stream(
316        &self,
317        id: &SessionId,
318    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + Sync + 'static, Self::Error> {
319        self.inner
320            .create_standalone_stream(id)
321            .await
322            .map_err(Into::into)
323    }
324
325    async fn resume(
326        &self,
327        id: &SessionId,
328        last_event_id: String,
329    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + Sync + 'static, Self::Error> {
330        self.inner
331            .resume(id, last_event_id)
332            .await
333            .map_err(Into::into)
334    }
335}