Skip to main content

fastmcp_core/
state.rs

1//! Session state storage for per-session key-value data.
2
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5
6/// Thread-safe session state container for per-session key-value storage.
7///
8/// This allows handlers to store and retrieve state that persists across
9/// requests within a single MCP session. The state is typed as JSON values
10/// to support flexible data storage.
11///
12/// # Thread Safety
13///
14/// SessionState is designed for concurrent access from multiple handlers.
15/// Operations are synchronized via an internal mutex.
16///
17/// # Example
18///
19/// ```ignore
20/// // In a tool handler:
21/// ctx.set_state("counter", 42);
22/// let count: Option<i32> = ctx.get_state("counter");
23/// ```
24#[derive(Debug, Clone, Default)]
25pub struct SessionState {
26    inner: Arc<Mutex<HashMap<String, serde_json::Value>>>,
27}
28
29impl SessionState {
30    /// Creates a new empty session state.
31    #[must_use]
32    pub fn new() -> Self {
33        Self::default()
34    }
35
36    /// Gets a value from session state by key.
37    ///
38    /// Returns `None` if the key doesn't exist or if deserialization fails.
39    ///
40    /// # Type Parameters
41    ///
42    /// * `T` - The expected type of the value (must implement Deserialize)
43    #[must_use]
44    pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
45        let guard = self.inner.lock().ok()?;
46        let value = guard.get(key)?;
47        serde_json::from_value(value.clone()).ok()
48    }
49
50    /// Gets a raw JSON value from session state by key.
51    ///
52    /// Returns `None` if the key doesn't exist.
53    #[must_use]
54    pub fn get_raw(&self, key: &str) -> Option<serde_json::Value> {
55        let guard = self.inner.lock().ok()?;
56        guard.get(key).cloned()
57    }
58
59    /// Sets a value in session state.
60    ///
61    /// The value is serialized to JSON for storage. Returns `true` if
62    /// the value was successfully stored.
63    ///
64    /// # Type Parameters
65    ///
66    /// * `T` - The type of the value (must implement Serialize)
67    pub fn set<T: serde::Serialize>(&self, key: impl Into<String>, value: T) -> bool {
68        let Ok(json_value) = serde_json::to_value(value) else {
69            return false;
70        };
71        let Ok(mut guard) = self.inner.lock() else {
72            return false;
73        };
74        guard.insert(key.into(), json_value);
75        true
76    }
77
78    /// Sets a raw JSON value in session state.
79    ///
80    /// Returns `true` if the value was successfully stored.
81    pub fn set_raw(&self, key: impl Into<String>, value: serde_json::Value) -> bool {
82        let Ok(mut guard) = self.inner.lock() else {
83            return false;
84        };
85        guard.insert(key.into(), value);
86        true
87    }
88
89    /// Removes a value from session state.
90    ///
91    /// Returns the previous value if it existed.
92    pub fn remove(&self, key: &str) -> Option<serde_json::Value> {
93        let mut guard = self.inner.lock().ok()?;
94        guard.remove(key)
95    }
96
97    /// Checks if a key exists in session state.
98    #[must_use]
99    pub fn contains(&self, key: &str) -> bool {
100        self.inner
101            .lock()
102            .map(|g| g.contains_key(key))
103            .unwrap_or(false)
104    }
105
106    /// Returns the number of entries in session state.
107    #[must_use]
108    pub fn len(&self) -> usize {
109        self.inner.lock().map(|g| g.len()).unwrap_or(0)
110    }
111
112    /// Returns true if session state is empty.
113    #[must_use]
114    pub fn is_empty(&self) -> bool {
115        self.len() == 0
116    }
117
118    /// Clears all session state.
119    pub fn clear(&self) {
120        if let Ok(mut guard) = self.inner.lock() {
121            guard.clear();
122        }
123    }
124}
125
126// ============================================================================
127// Dynamic Component Enable/Disable Helpers
128// ============================================================================
129
/// Key under which the set of disabled tool names is stored in session state.
pub const DISABLED_TOOLS_KEY: &str = "fastmcp.disabled_tools";
/// Key under which the set of disabled resource URIs is stored in session state.
pub const DISABLED_RESOURCES_KEY: &str = "fastmcp.disabled_resources";
/// Key under which the set of disabled prompt names is stored in session state.
pub const DISABLED_PROMPTS_KEY: &str = "fastmcp.disabled_prompts";
136
137impl SessionState {
138    /// Returns whether a tool is enabled (not disabled) for this session.
139    ///
140    /// Tools are enabled by default unless explicitly disabled.
141    #[must_use]
142    pub fn is_tool_enabled(&self, name: &str) -> bool {
143        !self.is_in_disabled_set(DISABLED_TOOLS_KEY, name)
144    }
145
146    /// Returns whether a resource is enabled (not disabled) for this session.
147    ///
148    /// Resources are enabled by default unless explicitly disabled.
149    #[must_use]
150    pub fn is_resource_enabled(&self, uri: &str) -> bool {
151        !self.is_in_disabled_set(DISABLED_RESOURCES_KEY, uri)
152    }
153
154    /// Returns whether a prompt is enabled (not disabled) for this session.
155    ///
156    /// Prompts are enabled by default unless explicitly disabled.
157    #[must_use]
158    pub fn is_prompt_enabled(&self, name: &str) -> bool {
159        !self.is_in_disabled_set(DISABLED_PROMPTS_KEY, name)
160    }
161
162    /// Returns the set of disabled tools.
163    #[must_use]
164    pub fn disabled_tools(&self) -> std::collections::HashSet<String> {
165        self.get::<std::collections::HashSet<String>>(DISABLED_TOOLS_KEY)
166            .unwrap_or_default()
167    }
168
169    /// Returns the set of disabled resources.
170    #[must_use]
171    pub fn disabled_resources(&self) -> std::collections::HashSet<String> {
172        self.get::<std::collections::HashSet<String>>(DISABLED_RESOURCES_KEY)
173            .unwrap_or_default()
174    }
175
176    /// Returns the set of disabled prompts.
177    #[must_use]
178    pub fn disabled_prompts(&self) -> std::collections::HashSet<String> {
179        self.get::<std::collections::HashSet<String>>(DISABLED_PROMPTS_KEY)
180            .unwrap_or_default()
181    }
182
183    // Helper: Check if a name is in a disabled set
184    fn is_in_disabled_set(&self, key: &str, name: &str) -> bool {
185        self.get::<std::collections::HashSet<String>>(key)
186            .map(|set| set.contains(name))
187            .unwrap_or(false)
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Builds an owned string set from a slice of literals.
    fn string_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| (*name).to_string()).collect()
    }

    // ========================================================================
    // Core Key-Value Tests
    // ========================================================================

    #[test]
    fn test_session_state_new() {
        // A freshly created state holds nothing.
        let state = SessionState::new();
        assert_eq!(state.len(), 0);
        assert!(state.is_empty());
    }

    #[test]
    fn test_session_state_set_get() {
        let state = SessionState::new();

        // Round-trip a string.
        let stored = state.set("name", "Alice");
        assert!(stored);
        assert_eq!(state.get::<String>("name"), Some(String::from("Alice")));

        // Round-trip a number.
        let stored = state.set("count", 42);
        assert!(stored);
        assert_eq!(state.get::<i32>("count"), Some(42));
    }

    #[test]
    fn test_session_state_get_nonexistent() {
        let state = SessionState::new();
        assert!(state.get::<String>("nonexistent").is_none());
    }

    #[test]
    fn test_session_state_type_mismatch() {
        let state = SessionState::new();
        state.set("count", 42);

        // A number cannot be read back as a string, so the lookup yields None.
        assert!(state.get::<String>("count").is_none());
    }

    #[test]
    fn test_session_state_get_raw() {
        let state = SessionState::new();
        state.set("value", serde_json::json!({"nested": true}));

        let raw = state
            .get_raw("value")
            .expect("value stored above should be present");
        assert_eq!(raw["nested"], serde_json::json!(true));
    }

    #[test]
    fn test_session_state_set_raw() {
        let state = SessionState::new();
        let stored = state.set_raw("key", serde_json::json!([1, 2, 3]));
        assert!(stored);

        assert_eq!(state.get::<Vec<i32>>("key"), Some(vec![1, 2, 3]));
    }

    #[test]
    fn test_session_state_remove() {
        let state = SessionState::new();
        state.set("key", "value");
        assert!(state.contains("key"));

        // Removing an existing key hands back its value and deletes the entry.
        assert!(state.remove("key").is_some());
        assert!(!state.contains("key"));
    }

    #[test]
    fn test_session_state_contains() {
        let state = SessionState::new();
        assert!(!state.contains("key"));

        state.set("key", "value");
        assert!(state.contains("key"));
    }

    #[test]
    fn test_session_state_len() {
        let state = SessionState::new();
        assert_eq!(state.len(), 0);

        state.set("a", 1);
        assert_eq!(state.len(), 1);

        state.set("b", 2);
        assert_eq!(state.len(), 2);

        state.remove("a");
        assert_eq!(state.len(), 1);
    }

    #[test]
    fn test_session_state_clear() {
        let state = SessionState::new();
        state.set("a", 1);
        state.set("b", 2);
        assert_eq!(state.len(), 2);

        state.clear();
        assert!(state.is_empty());
    }

    #[test]
    fn test_session_state_clone() {
        let original = SessionState::new();
        original.set("key", "value");

        // A clone is a second handle onto the same map, so a write through
        // the clone must be visible through the original.
        let handle = original.clone();
        handle.set("key2", "value2");

        assert!(original.contains("key2"));
    }

    // ========================================================================
    // Dynamic Enable/Disable Tests
    // ========================================================================

    #[test]
    fn test_is_tool_enabled_default() {
        let state = SessionState::new();

        // With nothing recorded, every tool is enabled.
        assert!(state.is_tool_enabled("any_tool"));
        assert!(state.is_tool_enabled("another_tool"));
    }

    #[test]
    fn test_is_tool_enabled_disabled() {
        let state = SessionState::new();

        // Disable one tool by writing the disabled set directly.
        state.set(DISABLED_TOOLS_KEY, string_set(&["my_tool"]));

        assert!(!state.is_tool_enabled("my_tool"));
        assert!(state.is_tool_enabled("other_tool"));
    }

    #[test]
    fn test_is_resource_enabled_default() {
        let state = SessionState::new();

        // With nothing recorded, every resource is enabled.
        assert!(state.is_resource_enabled("file://path"));
        assert!(state.is_resource_enabled("http://example.com"));
    }

    #[test]
    fn test_is_resource_enabled_disabled() {
        let state = SessionState::new();

        // Disable one resource by writing the disabled set directly.
        state.set(DISABLED_RESOURCES_KEY, string_set(&["file://secret"]));

        assert!(!state.is_resource_enabled("file://secret"));
        assert!(state.is_resource_enabled("file://public"));
    }

    #[test]
    fn test_is_prompt_enabled_default() {
        let state = SessionState::new();

        // With nothing recorded, every prompt is enabled.
        assert!(state.is_prompt_enabled("any_prompt"));
    }

    #[test]
    fn test_is_prompt_enabled_disabled() {
        let state = SessionState::new();

        // Disable one prompt by writing the disabled set directly.
        state.set(DISABLED_PROMPTS_KEY, string_set(&["admin_prompt"]));

        assert!(!state.is_prompt_enabled("admin_prompt"));
        assert!(state.is_prompt_enabled("user_prompt"));
    }

    #[test]
    fn test_disabled_sets_return_empty_by_default() {
        let state = SessionState::new();

        assert!(state.disabled_tools().is_empty());
        assert!(state.disabled_resources().is_empty());
        assert!(state.disabled_prompts().is_empty());
    }

    #[test]
    fn test_disabled_tools_returns_set() {
        let state = SessionState::new();
        state.set(DISABLED_TOOLS_KEY, string_set(&["tool1", "tool2"]));

        let disabled = state.disabled_tools();
        assert_eq!(disabled.len(), 2);
        assert!(disabled.contains("tool1"));
        assert!(disabled.contains("tool2"));
    }
}
403}