1use std::sync::Arc;
10use std::sync::RwLock;
11use std::sync::atomic::{AtomicU8, Ordering};
12
13use crate::router::Extensions;
14
/// Lifecycle phase of a session.
///
/// The discriminants are pinned with `#[repr(u8)]` because the phase is kept in
/// an atomic `u8` by `SessionState`; changing a value changes that encoding.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum SessionPhase {
    /// No initialization has been started yet.
    Uninitialized = 0,
    /// Initialization has been started but not completed.
    Initializing = 1,
    /// Initialization has completed.
    Initialized = 2,
}
27
28impl From<u8> for SessionPhase {
29 fn from(value: u8) -> Self {
30 match value {
31 0 => SessionPhase::Uninitialized,
32 1 => SessionPhase::Initializing,
33 2 => SessionPhase::Initialized,
34 _ => SessionPhase::Uninitialized,
35 }
36 }
37}
38
39#[derive(Clone)]
69pub struct SessionState {
70 phase: Arc<AtomicU8>,
71 extensions: Arc<RwLock<Extensions>>,
72}
73
74impl Default for SessionState {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80impl SessionState {
81 pub fn new() -> Self {
83 Self {
84 phase: Arc::new(AtomicU8::new(SessionPhase::Uninitialized as u8)),
85 extensions: Arc::new(RwLock::new(Extensions::new())),
86 }
87 }
88
89 pub fn insert<T: Send + Sync + Clone + 'static>(&self, val: T) {
104 if let Ok(mut ext) = self.extensions.write() {
105 ext.insert(val);
106 }
107 }
108
109 pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
125 self.extensions
126 .read()
127 .ok()
128 .and_then(|ext| ext.get::<T>().cloned())
129 }
130
131 pub fn phase(&self) -> SessionPhase {
133 SessionPhase::from(self.phase.load(Ordering::Acquire))
134 }
135
136 pub fn is_initialized(&self) -> bool {
138 self.phase() == SessionPhase::Initialized
139 }
140
141 pub fn mark_initializing(&self) -> bool {
145 self.phase
146 .compare_exchange(
147 SessionPhase::Uninitialized as u8,
148 SessionPhase::Initializing as u8,
149 Ordering::AcqRel,
150 Ordering::Acquire,
151 )
152 .is_ok()
153 }
154
155 pub fn mark_initialized(&self) -> bool {
166 if self
168 .phase
169 .compare_exchange(
170 SessionPhase::Initializing as u8,
171 SessionPhase::Initialized as u8,
172 Ordering::AcqRel,
173 Ordering::Acquire,
174 )
175 .is_ok()
176 {
177 return true;
178 }
179
180 self.phase
184 .compare_exchange(
185 SessionPhase::Uninitialized as u8,
186 SessionPhase::Initialized as u8,
187 Ordering::AcqRel,
188 Ordering::Acquire,
189 )
190 .is_ok()
191 }
192
193 pub fn is_request_allowed(&self, method: &str) -> bool {
198 match self.phase() {
199 SessionPhase::Uninitialized => {
200 matches!(method, "initialize" | "ping" | "server/discover")
202 }
203 SessionPhase::Initializing | SessionPhase::Initialized => true,
204 }
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives one session through every phase.
    ///
    /// At each step it checks the request gate and the return values of the
    /// transition methods.
    #[test]
    fn test_session_lifecycle() {
        let state = SessionState::new();

        // Fresh session: only the handshake-style methods pass the gate.
        assert_eq!(state.phase(), SessionPhase::Uninitialized);
        assert!(!state.is_initialized());
        for method in &["initialize", "ping"] {
            assert!(state.is_request_allowed(method));
        }
        assert!(!state.is_request_allowed("tools/list"));

        // Starting initialization succeeds once; a second attempt is refused.
        assert!(state.mark_initializing());
        assert_eq!(state.phase(), SessionPhase::Initializing);
        assert!(!state.is_initialized());
        assert!(!state.mark_initializing());

        // While initializing, ordinary methods are already allowed.
        assert!(state.is_request_allowed("tools/list"));

        // Completing initialization also succeeds exactly once.
        assert!(state.mark_initialized());
        assert_eq!(state.phase(), SessionPhase::Initialized);
        assert!(state.is_initialized());
        assert!(!state.mark_initialized());
    }

    /// Phase changes made through one handle are visible through a clone, both ways.
    #[test]
    fn test_session_clone_shares_state() {
        let original = SessionState::new();
        let twin = original.clone();

        original.mark_initializing();
        assert_eq!(twin.phase(), SessionPhase::Initializing);

        twin.mark_initialized();
        assert_eq!(original.phase(), SessionPhase::Initialized);
    }

    /// A stored value is retrievable by its type; other types stay absent.
    #[test]
    fn test_session_extensions_insert_and_get() {
        let state = SessionState::new();

        state.insert(42u32);
        assert_eq!(state.get::<u32>(), Some(42));

        assert_eq!(state.get::<String>(), None);
    }

    /// Inserting a second value of the same type replaces the first.
    #[test]
    fn test_session_extensions_overwrite() {
        let state = SessionState::new();

        for value in &[42u32, 100u32] {
            state.insert(*value);
            assert_eq!(state.get::<u32>(), Some(*value));
        }
    }

    /// Values of different types coexist without clobbering each other.
    #[test]
    fn test_session_extensions_multiple_types() {
        let state = SessionState::new();

        state.insert(42u32);
        state.insert("hello".to_string());
        state.insert(true);

        assert_eq!(state.get::<u32>(), Some(42));
        assert_eq!(state.get::<String>(), Some("hello".to_string()));
        assert_eq!(state.get::<bool>(), Some(true));
    }

    /// The extension map is shared by clones in both directions.
    #[test]
    fn test_session_extensions_shared_across_clones() {
        let original = SessionState::new();
        let twin = original.clone();

        original.insert(42u32);
        assert_eq!(twin.get::<u32>(), Some(42));

        twin.insert("world".to_string());
        assert_eq!(original.get::<String>(), Some("world".to_string()));
    }

    /// `mark_initialized` may skip the `Initializing` phase entirely.
    #[test]
    fn test_mark_initialized_from_uninitialized() {
        let state = SessionState::new();
        assert_eq!(state.phase(), SessionPhase::Uninitialized);

        assert!(state.mark_initialized());
        assert_eq!(state.phase(), SessionPhase::Initialized);
        assert!(state.is_initialized());

        for method in &["tools/list", "ping"] {
            assert!(state.is_request_allowed(method));
        }
    }

    /// A repeated `mark_initialized` reports `false` and leaves the phase unchanged.
    #[test]
    fn test_mark_initialized_idempotent_when_already_initialized() {
        let state = SessionState::new();

        state.mark_initializing();
        state.mark_initialized();
        assert_eq!(state.phase(), SessionPhase::Initialized);

        assert!(!state.mark_initialized());
        assert_eq!(state.phase(), SessionPhase::Initialized);
    }

    /// Arbitrary user-defined `Clone` types round-trip through the extension map.
    #[test]
    fn test_session_extensions_custom_type() {
        #[derive(Debug, Clone, PartialEq)]
        struct UserClaims {
            user_id: String,
            role: String,
        }

        let state = SessionState::new();
        state.insert(UserClaims {
            user_id: "user123".to_string(),
            role: "admin".to_string(),
        });

        let stored = state
            .get::<UserClaims>()
            .expect("claims were inserted just above");
        assert_eq!(stored.user_id, "user123");
        assert_eq!(stored.role, "admin");
    }
}