Skip to main content

tower_mcp/
session.rs

//! MCP session state management.
//!
//! Tracks the lifecycle state of an MCP connection as defined by the MCP specification.
//! A session progresses through three phases: `Uninitialized` -> `Initializing` -> `Initialized`.
//!
//! Sessions also support type-safe extensions for storing arbitrary session-scoped data,
//! such as authentication claims or user roles.
8
9use std::sync::Arc;
10use std::sync::RwLock;
11use std::sync::atomic::{AtomicU8, Ordering};
12
13use crate::router::Extensions;
14
/// Lifecycle phase of an MCP session.
///
/// A session only ever moves forward through the phases, in this order:
/// `Uninitialized` -> `Initializing` -> `Initialized`. The explicit `u8`
/// discriminants allow the phase to be stored in an `AtomicU8`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum SessionPhase {
    /// Starting phase: only `initialize` and `ping` requests are accepted.
    Uninitialized = 0,
    /// The server has answered `initialize` and is waiting for the client's
    /// `initialized` notification.
    Initializing = 1,
    /// The `initialized` notification has arrived; normal operation.
    Initialized = 2,
}
26
27impl From<u8> for SessionPhase {
28    fn from(value: u8) -> Self {
29        match value {
30            0 => SessionPhase::Uninitialized,
31            1 => SessionPhase::Initializing,
32            2 => SessionPhase::Initialized,
33            _ => SessionPhase::Uninitialized,
34        }
35    }
36}
37
38/// Shared session state that can be cloned across requests.
39///
40/// Uses atomic operations for thread-safe state transitions. Includes a type-safe
41/// extensions map for storing session-scoped data like authentication claims.
42///
43/// # Example
44///
45/// ```rust
46/// use tower_mcp::SessionState;
47///
48/// #[derive(Debug, Clone)]
49/// struct UserClaims {
50///     user_id: String,
51///     role: String,
52/// }
53///
54/// let session = SessionState::new();
55///
56/// // Store auth claims in the session
57/// session.insert(UserClaims {
58///     user_id: "user123".to_string(),
59///     role: "admin".to_string(),
60/// });
61///
62/// // Retrieve claims later
63/// if let Some(claims) = session.get::<UserClaims>() {
64///     assert_eq!(claims.role, "admin");
65/// }
66/// ```
67#[derive(Clone)]
68pub struct SessionState {
69    phase: Arc<AtomicU8>,
70    extensions: Arc<RwLock<Extensions>>,
71}
72
73impl Default for SessionState {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl SessionState {
80    /// Create a new session in the Uninitialized phase
81    pub fn new() -> Self {
82        Self {
83            phase: Arc::new(AtomicU8::new(SessionPhase::Uninitialized as u8)),
84            extensions: Arc::new(RwLock::new(Extensions::new())),
85        }
86    }
87
88    /// Insert a value into the session extensions.
89    ///
90    /// This is typically used by auth middleware to store claims that can
91    /// be checked by capability filters.
92    ///
93    /// # Example
94    ///
95    /// ```rust
96    /// use tower_mcp::SessionState;
97    ///
98    /// let session = SessionState::new();
99    /// session.insert(42u32);
100    /// assert_eq!(session.get::<u32>(), Some(42));
101    /// ```
102    pub fn insert<T: Send + Sync + Clone + 'static>(&self, val: T) {
103        if let Ok(mut ext) = self.extensions.write() {
104            ext.insert(val);
105        }
106    }
107
108    /// Get a cloned value from the session extensions.
109    ///
110    /// Returns `None` if no value of the given type has been inserted or if
111    /// the lock cannot be acquired.
112    ///
113    /// # Example
114    ///
115    /// ```rust
116    /// use tower_mcp::SessionState;
117    ///
118    /// let session = SessionState::new();
119    /// session.insert("hello".to_string());
120    /// assert_eq!(session.get::<String>(), Some("hello".to_string()));
121    /// assert_eq!(session.get::<u32>(), None);
122    /// ```
123    pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
124        self.extensions
125            .read()
126            .ok()
127            .and_then(|ext| ext.get::<T>().cloned())
128    }
129
130    /// Get the current session phase
131    pub fn phase(&self) -> SessionPhase {
132        SessionPhase::from(self.phase.load(Ordering::Acquire))
133    }
134
135    /// Check if the session is initialized (operation phase)
136    pub fn is_initialized(&self) -> bool {
137        self.phase() == SessionPhase::Initialized
138    }
139
140    /// Transition from Uninitialized to Initializing.
141    /// Called after responding to an `initialize` request.
142    /// Returns true if the transition was successful.
143    pub fn mark_initializing(&self) -> bool {
144        self.phase
145            .compare_exchange(
146                SessionPhase::Uninitialized as u8,
147                SessionPhase::Initializing as u8,
148                Ordering::AcqRel,
149                Ordering::Acquire,
150            )
151            .is_ok()
152    }
153
154    /// Transition from Initializing to Initialized.
155    /// Called when receiving an `initialized` notification.
156    /// Returns true if the transition was successful.
157    pub fn mark_initialized(&self) -> bool {
158        self.phase
159            .compare_exchange(
160                SessionPhase::Initializing as u8,
161                SessionPhase::Initialized as u8,
162                Ordering::AcqRel,
163                Ordering::Acquire,
164            )
165            .is_ok()
166    }
167
168    /// Check if a request method is allowed in the current phase.
169    /// Per spec:
170    /// - Before initialization: only `initialize` and `ping` are valid
171    /// - During all phases: `ping` is always valid
172    pub fn is_request_allowed(&self, method: &str) -> bool {
173        match self.phase() {
174            SessionPhase::Uninitialized => {
175                matches!(method, "initialize" | "ping")
176            }
177            SessionPhase::Initializing | SessionPhase::Initialized => true,
178        }
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_lifecycle() {
        let session = SessionState::new();

        // Initial state
        assert_eq!(session.phase(), SessionPhase::Uninitialized);
        assert!(!session.is_initialized());

        // Only initialize and ping allowed
        assert!(session.is_request_allowed("initialize"));
        assert!(session.is_request_allowed("ping"));
        assert!(!session.is_request_allowed("tools/list"));

        // Transition to initializing
        assert!(session.mark_initializing());
        assert_eq!(session.phase(), SessionPhase::Initializing);
        assert!(!session.is_initialized());

        // Can't mark initializing again
        assert!(!session.mark_initializing());

        // All requests allowed during initializing
        assert!(session.is_request_allowed("tools/list"));

        // Transition to initialized
        assert!(session.mark_initialized());
        assert_eq!(session.phase(), SessionPhase::Initialized);
        assert!(session.is_initialized());

        // Can't mark initialized again
        assert!(!session.mark_initialized());
    }

    #[test]
    fn test_cannot_skip_initializing_phase() {
        let session = SessionState::new();

        // Uninitialized -> Initialized is not a legal transition.
        assert!(!session.mark_initialized());
        assert_eq!(session.phase(), SessionPhase::Uninitialized);
        assert!(!session.is_initialized());
    }

    #[test]
    fn test_requests_allowed_once_initialized() {
        let session = SessionState::new();
        assert!(session.mark_initializing());
        assert!(session.mark_initialized());

        assert!(session.is_request_allowed("tools/list"));
        assert!(session.is_request_allowed("ping"));
    }

    #[test]
    fn test_phase_from_u8() {
        assert_eq!(SessionPhase::from(0), SessionPhase::Uninitialized);
        assert_eq!(SessionPhase::from(1), SessionPhase::Initializing);
        assert_eq!(SessionPhase::from(2), SessionPhase::Initialized);

        // Out-of-range values fall back to the most restrictive phase.
        assert_eq!(SessionPhase::from(3), SessionPhase::Uninitialized);
        assert_eq!(SessionPhase::from(u8::MAX), SessionPhase::Uninitialized);
    }

    #[test]
    fn test_session_clone_shares_state() {
        let session1 = SessionState::new();
        let session2 = session1.clone();

        assert!(session1.mark_initializing());
        assert_eq!(session2.phase(), SessionPhase::Initializing);

        assert!(session2.mark_initialized());
        assert_eq!(session1.phase(), SessionPhase::Initialized);
    }

    #[test]
    fn test_session_extensions_insert_and_get() {
        let session = SessionState::new();

        // Insert and retrieve a value
        session.insert(42u32);
        assert_eq!(session.get::<u32>(), Some(42));

        // Different type returns None
        assert_eq!(session.get::<String>(), None);
    }

    #[test]
    fn test_session_extensions_overwrite() {
        let session = SessionState::new();

        session.insert(42u32);
        assert_eq!(session.get::<u32>(), Some(42));

        // Overwrite with new value
        session.insert(100u32);
        assert_eq!(session.get::<u32>(), Some(100));
    }

    #[test]
    fn test_session_extensions_multiple_types() {
        let session = SessionState::new();

        session.insert(42u32);
        session.insert("hello".to_string());
        session.insert(true);

        assert_eq!(session.get::<u32>(), Some(42));
        assert_eq!(session.get::<String>(), Some("hello".to_string()));
        assert_eq!(session.get::<bool>(), Some(true));
    }

    #[test]
    fn test_session_extensions_shared_across_clones() {
        let session1 = SessionState::new();
        let session2 = session1.clone();

        // Insert in one clone
        session1.insert(42u32);

        // Should be visible in the other
        assert_eq!(session2.get::<u32>(), Some(42));

        // Insert in the second clone
        session2.insert("world".to_string());

        // Should be visible in the first
        assert_eq!(session1.get::<String>(), Some("world".to_string()));
    }

    #[test]
    fn test_session_extensions_custom_type() {
        #[derive(Debug, Clone, PartialEq)]
        struct UserClaims {
            user_id: String,
            role: String,
        }

        let session = SessionState::new();

        session.insert(UserClaims {
            user_id: "user123".to_string(),
            role: "admin".to_string(),
        });

        let claims = session
            .get::<UserClaims>()
            .expect("claims were inserted above");
        assert_eq!(
            claims,
            UserClaims {
                user_id: "user123".to_string(),
                role: "admin".to_string(),
            }
        );
    }
}
303        let claims = claims.unwrap();
304        assert_eq!(claims.user_id, "user123");
305        assert_eq!(claims.role, "admin");
306    }
307}