1use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5
6#[derive(Debug, Clone, Default)]
25pub struct SessionState {
26 inner: Arc<Mutex<HashMap<String, serde_json::Value>>>,
27 local: Option<Arc<Mutex<HashMap<String, serde_json::Value>>>>,
28}
29
30impl SessionState {
31 #[must_use]
33 pub fn new() -> Self {
34 Self::default()
35 }
36
37 #[must_use]
43 pub fn with_local_overrides(&self) -> Self {
44 Self {
45 inner: Arc::clone(&self.inner),
46 local: Some(
47 self.local
48 .as_ref()
49 .map_or_else(|| Arc::new(Mutex::new(HashMap::new())), Arc::clone),
50 ),
51 }
52 }
53
54 #[must_use]
60 pub fn snapshot(&self) -> Self {
61 let mut snapshot = self
62 .inner
63 .lock()
64 .unwrap_or_else(std::sync::PoisonError::into_inner)
65 .clone();
66 if let Some(local) = &self.local {
67 snapshot.extend(
68 local
69 .lock()
70 .unwrap_or_else(std::sync::PoisonError::into_inner)
71 .clone(),
72 );
73 }
74 Self {
75 inner: Arc::new(Mutex::new(snapshot)),
76 local: None,
77 }
78 }
79
80 #[must_use]
88 pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
89 let value = self.get_raw(key)?;
90 serde_json::from_value(value).ok()
91 }
92
93 #[must_use]
97 pub fn get_raw(&self, key: &str) -> Option<serde_json::Value> {
98 if let Some(local) = &self.local {
99 let guard = local.lock().ok()?;
100 if let Some(value) = guard.get(key) {
101 return Some(value.clone());
102 }
103 }
104 let guard = self.inner.lock().ok()?;
105 guard.get(key).cloned()
106 }
107
108 pub fn set<T: serde::Serialize>(&self, key: impl Into<String>, value: T) -> bool {
117 let Ok(json_value) = serde_json::to_value(value) else {
118 return false;
119 };
120 let Ok(mut guard) = self.inner.lock() else {
121 return false;
122 };
123 guard.insert(key.into(), json_value);
124 true
125 }
126
127 pub fn set_raw(&self, key: impl Into<String>, value: serde_json::Value) -> bool {
131 let Ok(mut guard) = self.inner.lock() else {
132 return false;
133 };
134 guard.insert(key.into(), value);
135 true
136 }
137
138 pub fn set_local_raw(&self, key: impl Into<String>, value: serde_json::Value) -> bool {
142 let Some(local) = &self.local else {
143 return false;
144 };
145 let Ok(mut guard) = local.lock() else {
146 return false;
147 };
148 guard.insert(key.into(), value);
149 true
150 }
151
152 pub fn set_local<T: serde::Serialize>(&self, key: impl Into<String>, value: T) -> bool {
156 let Ok(json_value) = serde_json::to_value(value) else {
157 return false;
158 };
159 self.set_local_raw(key, json_value)
160 }
161
162 pub fn remove(&self, key: &str) -> Option<serde_json::Value> {
166 if let Some(local) = &self.local {
167 let mut guard = local.lock().ok()?;
168 if let Some(value) = guard.remove(key) {
169 return Some(value);
170 }
171 }
172 let mut guard = self.inner.lock().ok()?;
173 guard.remove(key)
174 }
175
176 #[must_use]
178 pub fn contains(&self, key: &str) -> bool {
179 self.get_raw(key).is_some()
180 }
181
182 #[must_use]
184 pub fn len(&self) -> usize {
185 let shared_keys = self
186 .inner
187 .lock()
188 .map(|g| g.keys().cloned().collect::<std::collections::HashSet<_>>())
189 .unwrap_or_default();
190 if let Some(local) = &self.local {
191 let local_keys = local
192 .lock()
193 .map(|g| g.keys().cloned().collect::<std::collections::HashSet<_>>())
194 .unwrap_or_default();
195 shared_keys.union(&local_keys).count()
196 } else {
197 shared_keys.len()
198 }
199 }
200
201 #[must_use]
203 pub fn is_empty(&self) -> bool {
204 self.len() == 0
205 }
206
207 pub fn clear(&self) {
209 if let Ok(mut guard) = self.inner.lock() {
210 guard.clear();
211 }
212 if let Some(local) = &self.local
213 && let Ok(mut guard) = local.lock()
214 {
215 guard.clear();
216 }
217 }
218}
219
/// Session-state key holding the set of disabled tool names.
///
/// Read as a `HashSet<String>` by [`SessionState::is_tool_enabled`] and
/// [`SessionState::disabled_tools`]; a missing or non-deserializable value is
/// treated as "nothing disabled".
pub const DISABLED_TOOLS_KEY: &str = "fastmcp.disabled_tools";
/// Session-state key holding the set of disabled resource URIs.
///
/// Read as a `HashSet<String>` by [`SessionState::is_resource_enabled`] and
/// [`SessionState::disabled_resources`].
pub const DISABLED_RESOURCES_KEY: &str = "fastmcp.disabled_resources";
/// Session-state key holding the set of disabled prompt names.
///
/// Read as a `HashSet<String>` by [`SessionState::is_prompt_enabled`] and
/// [`SessionState::disabled_prompts`].
pub const DISABLED_PROMPTS_KEY: &str = "fastmcp.disabled_prompts";
230
231impl SessionState {
232 #[must_use]
236 pub fn is_tool_enabled(&self, name: &str) -> bool {
237 !self.is_in_disabled_set(DISABLED_TOOLS_KEY, name)
238 }
239
240 #[must_use]
244 pub fn is_resource_enabled(&self, uri: &str) -> bool {
245 !self.is_in_disabled_set(DISABLED_RESOURCES_KEY, uri)
246 }
247
248 #[must_use]
252 pub fn is_prompt_enabled(&self, name: &str) -> bool {
253 !self.is_in_disabled_set(DISABLED_PROMPTS_KEY, name)
254 }
255
256 #[must_use]
258 pub fn disabled_tools(&self) -> std::collections::HashSet<String> {
259 self.get::<std::collections::HashSet<String>>(DISABLED_TOOLS_KEY)
260 .unwrap_or_default()
261 }
262
263 #[must_use]
265 pub fn disabled_resources(&self) -> std::collections::HashSet<String> {
266 self.get::<std::collections::HashSet<String>>(DISABLED_RESOURCES_KEY)
267 .unwrap_or_default()
268 }
269
270 #[must_use]
272 pub fn disabled_prompts(&self) -> std::collections::HashSet<String> {
273 self.get::<std::collections::HashSet<String>>(DISABLED_PROMPTS_KEY)
274 .unwrap_or_default()
275 }
276
277 fn is_in_disabled_set(&self, key: &str, name: &str) -> bool {
279 self.get::<std::collections::HashSet<String>>(key)
280 .map(|set| set.contains(name))
281 .unwrap_or(false)
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_state_new() {
        let state = SessionState::new();
        assert!(state.is_empty());
        assert_eq!(state.len(), 0);
    }

    #[test]
    fn test_session_state_set_get() {
        let state = SessionState::new();

        assert!(state.set("name", "Alice"));
        let name: Option<String> = state.get("name");
        assert_eq!(name, Some("Alice".to_string()));

        assert!(state.set("count", 42));
        let count: Option<i32> = state.get("count");
        assert_eq!(count, Some(42));
    }

    #[test]
    fn test_session_state_get_nonexistent() {
        let state = SessionState::new();
        let value: Option<String> = state.get("nonexistent");
        assert!(value.is_none());
    }

    #[test]
    fn test_session_state_type_mismatch() {
        let state = SessionState::new();
        state.set("count", 42);

        // A stored integer must not deserialize as a String.
        let value: Option<String> = state.get("count");
        assert!(value.is_none());
    }

    #[test]
    fn test_session_state_set_rejects_unserializable_value() {
        let state = SessionState::new();
        // serde_json only accepts string-like map keys; tuple keys fail to serialize.
        let mut bad: std::collections::HashMap<(i32, i32), i32> =
            std::collections::HashMap::new();
        bad.insert((1, 2), 3);

        assert!(!state.set("bad", bad));
        assert!(!state.contains("bad"));
    }

    #[test]
    fn test_session_state_get_raw() {
        let state = SessionState::new();
        state.set("value", serde_json::json!({"nested": true}));

        let raw = state.get_raw("value");
        assert!(raw.is_some());
        assert_eq!(raw.unwrap()["nested"], serde_json::json!(true));
    }

    #[test]
    fn test_session_state_set_raw() {
        let state = SessionState::new();
        assert!(state.set_raw("key", serde_json::json!([1, 2, 3])));

        let value: Option<Vec<i32>> = state.get("key");
        assert_eq!(value, Some(vec![1, 2, 3]));
    }

    #[test]
    fn test_session_state_remove() {
        let state = SessionState::new();
        state.set("key", "value");
        assert!(state.contains("key"));

        let removed = state.remove("key");
        assert!(removed.is_some());
        assert!(!state.contains("key"));
    }

    #[test]
    fn test_session_state_contains() {
        let state = SessionState::new();
        assert!(!state.contains("key"));

        state.set("key", "value");
        assert!(state.contains("key"));
    }

    #[test]
    fn test_session_state_len() {
        let state = SessionState::new();
        assert_eq!(state.len(), 0);

        state.set("a", 1);
        assert_eq!(state.len(), 1);

        state.set("b", 2);
        assert_eq!(state.len(), 2);

        state.remove("a");
        assert_eq!(state.len(), 1);
    }

    #[test]
    fn test_session_state_clear() {
        let state = SessionState::new();
        state.set("a", 1);
        state.set("b", 2);
        assert_eq!(state.len(), 2);

        state.clear();
        assert!(state.is_empty());
    }

    #[test]
    fn test_session_state_clone() {
        let state = SessionState::new();
        state.set("key", "value");

        // Clones share the same underlying map.
        let cloned = state.clone();
        cloned.set("key2", "value2");

        assert!(state.contains("key2"));
    }

    #[test]
    fn test_session_state_snapshot_is_isolated() {
        let state = SessionState::new();
        state.set("counter", 1);

        let snapshot = state.snapshot();
        state.set("counter", 2);
        snapshot.set("only_in_snapshot", true);

        let live_counter: Option<i32> = state.get("counter");
        let snap_counter: Option<i32> = snapshot.get("counter");
        let live_only: Option<bool> = state.get("only_in_snapshot");
        let snap_only: Option<bool> = snapshot.get("only_in_snapshot");

        assert_eq!(live_counter, Some(2));
        assert_eq!(snap_counter, Some(1));
        assert_eq!(live_only, None);
        assert_eq!(snap_only, Some(true));
    }

    #[test]
    fn test_set_local_requires_local_layer() {
        let state = SessionState::new();
        assert!(!state.set_local("key", 1));
        assert!(!state.contains("key"));
    }

    #[test]
    fn test_local_override_shadows_shared() {
        let shared = SessionState::new();
        assert!(shared.set("mode", "shared"));

        let scoped = shared.with_local_overrides();
        assert!(scoped.set_local("mode", "local"));

        assert_eq!(scoped.get::<String>("mode"), Some("local".to_string()));
        assert_eq!(shared.get::<String>("mode"), Some("shared".to_string()));
        // The key exists in both layers but is counted once.
        assert_eq!(scoped.len(), 1);

        // Removing drops only the override, re-exposing the shared value.
        assert_eq!(scoped.remove("mode"), Some(serde_json::json!("local")));
        assert_eq!(scoped.get::<String>("mode"), Some("shared".to_string()));
    }

    #[test]
    fn test_local_writes_do_not_leak_to_shared() {
        let shared = SessionState::new();
        let scoped = shared.with_local_overrides();

        assert!(scoped.set("visible", 1));
        assert!(scoped.set_local("private", 2));

        assert!(shared.contains("visible"));
        assert!(!shared.contains("private"));
        assert_eq!(scoped.len(), 2);
        assert_eq!(shared.len(), 1);
    }

    #[test]
    fn test_nested_local_overrides_share_layer() {
        let scoped = SessionState::new().with_local_overrides();
        let nested = scoped.with_local_overrides();

        assert!(nested.set_local("x", true));
        assert_eq!(scoped.get::<bool>("x"), Some(true));
    }

    #[test]
    fn test_snapshot_merges_local_overrides() {
        let shared = SessionState::new();
        shared.set("a", 1);
        shared.set("b", 1);

        let scoped = shared.with_local_overrides();
        scoped.set_local("b", 2);
        scoped.set_local("c", 3);

        let snap = scoped.snapshot();
        assert_eq!(snap.get::<i32>("a"), Some(1));
        assert_eq!(snap.get::<i32>("b"), Some(2));
        assert_eq!(snap.get::<i32>("c"), Some(3));
        assert_eq!(snap.len(), 3);

        // The snapshot has no local layer and is detached from the source.
        assert!(!snap.set_local("d", 4));
        scoped.set_local("c", 30);
        assert_eq!(snap.get::<i32>("c"), Some(3));
    }

    #[test]
    fn test_clear_empties_both_layers() {
        let shared = SessionState::new();
        let scoped = shared.with_local_overrides();
        scoped.set("a", 1);
        scoped.set_local("b", 2);
        assert_eq!(scoped.len(), 2);

        scoped.clear();
        assert!(scoped.is_empty());
        assert!(shared.is_empty());
    }

    #[test]
    fn test_is_tool_enabled_default() {
        let state = SessionState::new();

        assert!(state.is_tool_enabled("any_tool"));
        assert!(state.is_tool_enabled("another_tool"));
    }

    #[test]
    fn test_is_tool_enabled_disabled() {
        let state = SessionState::new();

        let mut disabled: std::collections::HashSet<String> = std::collections::HashSet::new();
        disabled.insert("my_tool".to_string());
        state.set(super::DISABLED_TOOLS_KEY, disabled);

        assert!(!state.is_tool_enabled("my_tool"));
        assert!(state.is_tool_enabled("other_tool"));
    }

    #[test]
    fn test_is_resource_enabled_default() {
        let state = SessionState::new();

        assert!(state.is_resource_enabled("file://path"));
        assert!(state.is_resource_enabled("http://example.com"));
    }

    #[test]
    fn test_is_resource_enabled_disabled() {
        let state = SessionState::new();

        let mut disabled: std::collections::HashSet<String> = std::collections::HashSet::new();
        disabled.insert("file://secret".to_string());
        state.set(super::DISABLED_RESOURCES_KEY, disabled);

        assert!(!state.is_resource_enabled("file://secret"));
        assert!(state.is_resource_enabled("file://public"));
    }

    #[test]
    fn test_is_prompt_enabled_default() {
        let state = SessionState::new();

        assert!(state.is_prompt_enabled("any_prompt"));
    }

    #[test]
    fn test_is_prompt_enabled_disabled() {
        let state = SessionState::new();

        let mut disabled: std::collections::HashSet<String> = std::collections::HashSet::new();
        disabled.insert("admin_prompt".to_string());
        state.set(super::DISABLED_PROMPTS_KEY, disabled);

        assert!(!state.is_prompt_enabled("admin_prompt"));
        assert!(state.is_prompt_enabled("user_prompt"));
    }

    #[test]
    fn test_disabled_sets_return_empty_by_default() {
        let state = SessionState::new();

        assert!(state.disabled_tools().is_empty());
        assert!(state.disabled_resources().is_empty());
        assert!(state.disabled_prompts().is_empty());
    }

    #[test]
    fn test_disabled_tools_returns_set() {
        let state = SessionState::new();

        let mut disabled: std::collections::HashSet<String> = std::collections::HashSet::new();
        disabled.insert("tool1".to_string());
        disabled.insert("tool2".to_string());
        state.set(super::DISABLED_TOOLS_KEY, disabled);

        let result = state.disabled_tools();
        assert_eq!(result.len(), 2);
        assert!(result.contains("tool1"));
        assert!(result.contains("tool2"));
    }
}