Skip to main content

fastmcp_core/
state.rs

1//! Session state storage for per-session key-value data.
2
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5
6/// Thread-safe session state container for per-session key-value storage.
7///
8/// This allows handlers to store and retrieve state that persists across
9/// requests within a single MCP session. The state is typed as JSON values
10/// to support flexible data storage.
11///
12/// # Thread Safety
13///
14/// SessionState is designed for concurrent access from multiple handlers.
15/// Operations are synchronized via an internal mutex.
16///
17/// # Example
18///
19/// ```ignore
20/// // In a tool handler:
21/// ctx.set_state("counter", 42);
22/// let count: Option<i32> = ctx.get_state("counter");
23/// ```
24#[derive(Debug, Clone, Default)]
25pub struct SessionState {
26    inner: Arc<Mutex<HashMap<String, serde_json::Value>>>,
27    local: Option<Arc<Mutex<HashMap<String, serde_json::Value>>>>,
28}
29
30impl SessionState {
31    /// Creates a new empty session state.
32    #[must_use]
33    pub fn new() -> Self {
34        Self::default()
35    }
36
37    /// Returns a view with request-local overrides layered on top of the
38    /// shared session state.
39    ///
40    /// Cloning the returned value preserves the same local override map,
41    /// while ordinary writes continue to target the shared session state.
42    #[must_use]
43    pub fn with_local_overrides(&self) -> Self {
44        Self {
45            inner: Arc::clone(&self.inner),
46            local: Some(
47                self.local
48                    .as_ref()
49                    .map_or_else(|| Arc::new(Mutex::new(HashMap::new())), Arc::clone),
50            ),
51        }
52    }
53
54    /// Creates an isolated snapshot of the current session state.
55    ///
56    /// Unlike [`Clone`], which shares the same underlying storage, this copies
57    /// the current key-value map into a fresh container so later writes do not
58    /// bleed across requests.
59    #[must_use]
60    pub fn snapshot(&self) -> Self {
61        let mut snapshot = self
62            .inner
63            .lock()
64            .unwrap_or_else(std::sync::PoisonError::into_inner)
65            .clone();
66        if let Some(local) = &self.local {
67            snapshot.extend(
68                local
69                    .lock()
70                    .unwrap_or_else(std::sync::PoisonError::into_inner)
71                    .clone(),
72            );
73        }
74        Self {
75            inner: Arc::new(Mutex::new(snapshot)),
76            local: None,
77        }
78    }
79
80    /// Gets a value from session state by key.
81    ///
82    /// Returns `None` if the key doesn't exist or if deserialization fails.
83    ///
84    /// # Type Parameters
85    ///
86    /// * `T` - The expected type of the value (must implement Deserialize)
87    #[must_use]
88    pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
89        let value = self.get_raw(key)?;
90        serde_json::from_value(value).ok()
91    }
92
93    /// Gets a raw JSON value from session state by key.
94    ///
95    /// Returns `None` if the key doesn't exist.
96    #[must_use]
97    pub fn get_raw(&self, key: &str) -> Option<serde_json::Value> {
98        if let Some(local) = &self.local {
99            let guard = local.lock().ok()?;
100            if let Some(value) = guard.get(key) {
101                return Some(value.clone());
102            }
103        }
104        let guard = self.inner.lock().ok()?;
105        guard.get(key).cloned()
106    }
107
108    /// Sets a value in session state.
109    ///
110    /// The value is serialized to JSON for storage. Returns `true` if
111    /// the value was successfully stored.
112    ///
113    /// # Type Parameters
114    ///
115    /// * `T` - The type of the value (must implement Serialize)
116    pub fn set<T: serde::Serialize>(&self, key: impl Into<String>, value: T) -> bool {
117        let Ok(json_value) = serde_json::to_value(value) else {
118            return false;
119        };
120        let Ok(mut guard) = self.inner.lock() else {
121            return false;
122        };
123        guard.insert(key.into(), json_value);
124        true
125    }
126
127    /// Sets a raw JSON value in session state.
128    ///
129    /// Returns `true` if the value was successfully stored.
130    pub fn set_raw(&self, key: impl Into<String>, value: serde_json::Value) -> bool {
131        let Ok(mut guard) = self.inner.lock() else {
132            return false;
133        };
134        guard.insert(key.into(), value);
135        true
136    }
137
138    /// Sets a request-local raw JSON value layered over the shared session state.
139    ///
140    /// Returns `false` if this state does not have local overrides enabled.
141    pub fn set_local_raw(&self, key: impl Into<String>, value: serde_json::Value) -> bool {
142        let Some(local) = &self.local else {
143            return false;
144        };
145        let Ok(mut guard) = local.lock() else {
146            return false;
147        };
148        guard.insert(key.into(), value);
149        true
150    }
151
152    /// Sets a request-local value layered over the shared session state.
153    ///
154    /// Returns `false` if serialization fails or local overrides are unavailable.
155    pub fn set_local<T: serde::Serialize>(&self, key: impl Into<String>, value: T) -> bool {
156        let Ok(json_value) = serde_json::to_value(value) else {
157            return false;
158        };
159        self.set_local_raw(key, json_value)
160    }
161
162    /// Removes a value from session state.
163    ///
164    /// Returns the previous value if it existed.
165    pub fn remove(&self, key: &str) -> Option<serde_json::Value> {
166        if let Some(local) = &self.local {
167            let mut guard = local.lock().ok()?;
168            if let Some(value) = guard.remove(key) {
169                return Some(value);
170            }
171        }
172        let mut guard = self.inner.lock().ok()?;
173        guard.remove(key)
174    }
175
176    /// Checks if a key exists in session state.
177    #[must_use]
178    pub fn contains(&self, key: &str) -> bool {
179        self.get_raw(key).is_some()
180    }
181
182    /// Returns the number of entries in session state.
183    #[must_use]
184    pub fn len(&self) -> usize {
185        let shared_keys = self
186            .inner
187            .lock()
188            .map(|g| g.keys().cloned().collect::<std::collections::HashSet<_>>())
189            .unwrap_or_default();
190        if let Some(local) = &self.local {
191            let local_keys = local
192                .lock()
193                .map(|g| g.keys().cloned().collect::<std::collections::HashSet<_>>())
194                .unwrap_or_default();
195            shared_keys.union(&local_keys).count()
196        } else {
197            shared_keys.len()
198        }
199    }
200
201    /// Returns true if session state is empty.
202    #[must_use]
203    pub fn is_empty(&self) -> bool {
204        self.len() == 0
205    }
206
207    /// Clears all session state.
208    pub fn clear(&self) {
209        if let Ok(mut guard) = self.inner.lock() {
210            guard.clear();
211        }
212        if let Some(local) = &self.local
213            && let Ok(mut guard) = local.lock()
214        {
215            guard.clear();
216        }
217    }
218}
219
220// ============================================================================
221// Dynamic Component Enable/Disable Helpers
222// ============================================================================
223
/// Key under which the session's set of disabled tool names is stored.
pub const DISABLED_TOOLS_KEY: &str = "fastmcp.disabled_tools";
/// Key under which the session's set of disabled resource URIs is stored.
pub const DISABLED_RESOURCES_KEY: &str = "fastmcp.disabled_resources";
/// Key under which the session's set of disabled prompt names is stored.
pub const DISABLED_PROMPTS_KEY: &str = "fastmcp.disabled_prompts";
230
231impl SessionState {
232    /// Returns whether a tool is enabled (not disabled) for this session.
233    ///
234    /// Tools are enabled by default unless explicitly disabled.
235    #[must_use]
236    pub fn is_tool_enabled(&self, name: &str) -> bool {
237        !self.is_in_disabled_set(DISABLED_TOOLS_KEY, name)
238    }
239
240    /// Returns whether a resource is enabled (not disabled) for this session.
241    ///
242    /// Resources are enabled by default unless explicitly disabled.
243    #[must_use]
244    pub fn is_resource_enabled(&self, uri: &str) -> bool {
245        !self.is_in_disabled_set(DISABLED_RESOURCES_KEY, uri)
246    }
247
248    /// Returns whether a prompt is enabled (not disabled) for this session.
249    ///
250    /// Prompts are enabled by default unless explicitly disabled.
251    #[must_use]
252    pub fn is_prompt_enabled(&self, name: &str) -> bool {
253        !self.is_in_disabled_set(DISABLED_PROMPTS_KEY, name)
254    }
255
256    /// Returns the set of disabled tools.
257    #[must_use]
258    pub fn disabled_tools(&self) -> std::collections::HashSet<String> {
259        self.get::<std::collections::HashSet<String>>(DISABLED_TOOLS_KEY)
260            .unwrap_or_default()
261    }
262
263    /// Returns the set of disabled resources.
264    #[must_use]
265    pub fn disabled_resources(&self) -> std::collections::HashSet<String> {
266        self.get::<std::collections::HashSet<String>>(DISABLED_RESOURCES_KEY)
267            .unwrap_or_default()
268    }
269
270    /// Returns the set of disabled prompts.
271    #[must_use]
272    pub fn disabled_prompts(&self) -> std::collections::HashSet<String> {
273        self.get::<std::collections::HashSet<String>>(DISABLED_PROMPTS_KEY)
274            .unwrap_or_default()
275    }
276
277    // Helper: Check if a name is in a disabled set
278    fn is_in_disabled_set(&self, key: &str, name: &str) -> bool {
279        self.get::<std::collections::HashSet<String>>(key)
280            .map(|set| set.contains(name))
281            .unwrap_or(false)
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Builds an owned string set from string literals.
    fn names(items: &[&str]) -> HashSet<String> {
        items.iter().map(|item| (*item).to_string()).collect()
    }

    #[test]
    fn test_session_state_new() {
        let state = SessionState::new();
        assert!(state.is_empty());
        assert_eq!(state.len(), 0);
    }

    #[test]
    fn test_session_state_set_get() {
        let state = SessionState::new();

        // Set a string value
        assert!(state.set("name", "Alice"));
        let name: Option<String> = state.get("name");
        assert_eq!(name, Some("Alice".to_string()));

        // Set a number value
        assert!(state.set("count", 42));
        let count: Option<i32> = state.get("count");
        assert_eq!(count, Some(42));
    }

    #[test]
    fn test_session_state_get_nonexistent() {
        let state = SessionState::new();
        let value: Option<String> = state.get("nonexistent");
        assert!(value.is_none());
    }

    #[test]
    fn test_session_state_type_mismatch() {
        let state = SessionState::new();
        state.set("count", 42);

        // Try to get as wrong type - should return None
        let value: Option<String> = state.get("count");
        assert!(value.is_none());
    }

    #[test]
    fn test_session_state_get_raw() {
        let state = SessionState::new();
        state.set("value", serde_json::json!({"nested": true}));

        let raw = state.get_raw("value");
        assert!(raw.is_some());
        assert_eq!(raw.unwrap()["nested"], serde_json::json!(true));
    }

    #[test]
    fn test_session_state_set_raw() {
        let state = SessionState::new();
        assert!(state.set_raw("key", serde_json::json!([1, 2, 3])));

        let value: Option<Vec<i32>> = state.get("key");
        assert_eq!(value, Some(vec![1, 2, 3]));
    }

    #[test]
    fn test_session_state_remove() {
        let state = SessionState::new();
        state.set("key", "value");
        assert!(state.contains("key"));

        let removed = state.remove("key");
        assert!(removed.is_some());
        assert!(!state.contains("key"));
    }

    #[test]
    fn test_session_state_contains() {
        let state = SessionState::new();
        assert!(!state.contains("key"));

        state.set("key", "value");
        assert!(state.contains("key"));
    }

    #[test]
    fn test_session_state_contains_null_value() {
        // A stored JSON null is still an entry.
        let state = SessionState::new();
        assert!(state.set_raw("nothing", serde_json::Value::Null));
        assert!(state.contains("nothing"));
    }

    #[test]
    fn test_session_state_len() {
        let state = SessionState::new();
        assert_eq!(state.len(), 0);

        state.set("a", 1);
        assert_eq!(state.len(), 1);

        state.set("b", 2);
        assert_eq!(state.len(), 2);

        state.remove("a");
        assert_eq!(state.len(), 1);
    }

    #[test]
    fn test_session_state_clear() {
        let state = SessionState::new();
        state.set("a", 1);
        state.set("b", 2);
        assert_eq!(state.len(), 2);

        state.clear();
        assert!(state.is_empty());
    }

    #[test]
    fn test_session_state_clone() {
        let state = SessionState::new();
        state.set("key", "value");

        // Clone should share the same underlying state
        let cloned = state.clone();
        cloned.set("key2", "value2");

        assert!(state.contains("key2"));
    }

    #[test]
    fn test_session_state_snapshot_is_isolated() {
        let state = SessionState::new();
        state.set("counter", 1);

        let snapshot = state.snapshot();
        state.set("counter", 2);
        snapshot.set("only_in_snapshot", true);

        let live_counter: Option<i32> = state.get("counter");
        let snap_counter: Option<i32> = snapshot.get("counter");
        let live_only: Option<bool> = state.get("only_in_snapshot");
        let snap_only: Option<bool> = snapshot.get("only_in_snapshot");

        assert_eq!(live_counter, Some(2));
        assert_eq!(snap_counter, Some(1));
        assert_eq!(live_only, None);
        assert_eq!(snap_only, Some(true));
    }

    // ========================================================================
    // Local Override Layer Tests
    // ========================================================================

    #[test]
    fn test_set_local_requires_override_layer() {
        let state = SessionState::new();

        // A plain state has no override layer, so local writes are rejected.
        assert!(!state.set_local("key", 1));
        assert!(!state.set_local_raw("key", serde_json::json!(1)));
        assert!(!state.contains("key"));
    }

    #[test]
    fn test_local_override_shadows_shared_value() {
        let state = SessionState::new();
        state.set("mode", "shared");

        let view = state.with_local_overrides();
        assert!(view.set_local("mode", "local"));

        assert_eq!(view.get::<String>("mode"), Some("local".to_string()));
        // The original handle has no override layer and still sees the shared value.
        assert_eq!(state.get::<String>("mode"), Some("shared".to_string()));
    }

    #[test]
    fn test_local_view_set_writes_shared_state() {
        let state = SessionState::new();
        let view = state.with_local_overrides();

        assert!(view.set("key", 1));

        assert_eq!(state.get::<i32>("key"), Some(1));
    }

    #[test]
    fn test_local_view_clone_shares_override_layer() {
        let state = SessionState::new();
        let view = state.with_local_overrides();
        let twin = view.clone();
        let nested = view.with_local_overrides();

        assert!(twin.set_local("from_twin", 1));
        assert!(nested.set_local("from_nested", 2));

        assert_eq!(view.get::<i32>("from_twin"), Some(1));
        assert_eq!(view.get::<i32>("from_nested"), Some(2));
        assert!(!state.contains("from_twin"));
    }

    #[test]
    fn test_len_counts_shadowed_keys_once() {
        let state = SessionState::new();
        state.set("a", 1);
        state.set("b", 2);

        let view = state.with_local_overrides();
        assert!(view.set_local("b", 20));
        assert!(view.set_local("c", 30));

        assert_eq!(view.len(), 3);
        assert_eq!(state.len(), 2);
        assert!(!view.is_empty());
    }

    #[test]
    fn test_remove_takes_override_before_shared_value() {
        let state = SessionState::new();
        state.set("key", "shared");
        let view = state.with_local_overrides();
        assert!(view.set_local("key", "local"));

        // First removal drops only the override; the shared value shows through.
        assert_eq!(view.remove("key"), Some(serde_json::json!("local")));
        assert_eq!(view.get::<String>("key"), Some("shared".to_string()));

        // Second removal reaches the shared map.
        assert_eq!(view.remove("key"), Some(serde_json::json!("shared")));
        assert!(!view.contains("key"));
        assert!(!state.contains("key"));
    }

    #[test]
    fn test_snapshot_flattens_local_overrides() {
        let state = SessionState::new();
        state.set("a", 1);
        let view = state.with_local_overrides();
        assert!(view.set_local("a", 2));
        assert!(view.set_local("b", 3));

        let snapshot = view.snapshot();
        assert!(view.set_local("a", 9));

        assert_eq!(snapshot.get::<i32>("a"), Some(2));
        assert_eq!(snapshot.get::<i32>("b"), Some(3));
        // The snapshot itself carries no override layer.
        assert!(!snapshot.set_local("c", 4));
    }

    #[test]
    fn test_clear_empties_both_layers() {
        let state = SessionState::new();
        state.set("shared", 1);
        let view = state.with_local_overrides();
        assert!(view.set_local("local", 2));

        view.clear();

        assert!(view.is_empty());
        assert!(state.is_empty());
    }

    // ========================================================================
    // Dynamic Enable/Disable Tests
    // ========================================================================

    #[test]
    fn test_is_tool_enabled_default() {
        let state = SessionState::new();

        // Tools are enabled by default
        assert!(state.is_tool_enabled("any_tool"));
        assert!(state.is_tool_enabled("another_tool"));
    }

    #[test]
    fn test_is_tool_enabled_disabled() {
        let state = SessionState::new();

        // Manually disable a tool by setting the disabled set
        state.set(DISABLED_TOOLS_KEY, names(&["my_tool"]));

        assert!(!state.is_tool_enabled("my_tool"));
        assert!(state.is_tool_enabled("other_tool"));
    }

    #[test]
    fn test_is_resource_enabled_default() {
        let state = SessionState::new();

        // Resources are enabled by default
        assert!(state.is_resource_enabled("file://path"));
        assert!(state.is_resource_enabled("http://example.com"));
    }

    #[test]
    fn test_is_resource_enabled_disabled() {
        let state = SessionState::new();

        // Manually disable a resource
        state.set(DISABLED_RESOURCES_KEY, names(&["file://secret"]));

        assert!(!state.is_resource_enabled("file://secret"));
        assert!(state.is_resource_enabled("file://public"));
    }

    #[test]
    fn test_is_prompt_enabled_default() {
        let state = SessionState::new();

        // Prompts are enabled by default
        assert!(state.is_prompt_enabled("any_prompt"));
    }

    #[test]
    fn test_is_prompt_enabled_disabled() {
        let state = SessionState::new();

        // Manually disable a prompt
        state.set(DISABLED_PROMPTS_KEY, names(&["admin_prompt"]));

        assert!(!state.is_prompt_enabled("admin_prompt"));
        assert!(state.is_prompt_enabled("user_prompt"));
    }

    #[test]
    fn test_disabled_sets_return_empty_by_default() {
        let state = SessionState::new();

        assert!(state.disabled_tools().is_empty());
        assert!(state.disabled_resources().is_empty());
        assert!(state.disabled_prompts().is_empty());
    }

    #[test]
    fn test_disabled_tools_returns_set() {
        let state = SessionState::new();

        state.set(DISABLED_TOOLS_KEY, names(&["tool1", "tool2"]));

        let result = state.disabled_tools();
        assert_eq!(result.len(), 2);
        assert!(result.contains("tool1"));
        assert!(result.contains("tool2"));
    }
}