1use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5
6#[derive(Debug, Clone, Default)]
25pub struct SessionState {
26 inner: Arc<Mutex<HashMap<String, serde_json::Value>>>,
27}
28
29impl SessionState {
30 #[must_use]
32 pub fn new() -> Self {
33 Self::default()
34 }
35
36 #[must_use]
44 pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
45 let guard = self.inner.lock().ok()?;
46 let value = guard.get(key)?;
47 serde_json::from_value(value.clone()).ok()
48 }
49
50 #[must_use]
54 pub fn get_raw(&self, key: &str) -> Option<serde_json::Value> {
55 let guard = self.inner.lock().ok()?;
56 guard.get(key).cloned()
57 }
58
59 pub fn set<T: serde::Serialize>(&self, key: impl Into<String>, value: T) -> bool {
68 let Ok(json_value) = serde_json::to_value(value) else {
69 return false;
70 };
71 let Ok(mut guard) = self.inner.lock() else {
72 return false;
73 };
74 guard.insert(key.into(), json_value);
75 true
76 }
77
78 pub fn set_raw(&self, key: impl Into<String>, value: serde_json::Value) -> bool {
82 let Ok(mut guard) = self.inner.lock() else {
83 return false;
84 };
85 guard.insert(key.into(), value);
86 true
87 }
88
89 pub fn remove(&self, key: &str) -> Option<serde_json::Value> {
93 let mut guard = self.inner.lock().ok()?;
94 guard.remove(key)
95 }
96
97 #[must_use]
99 pub fn contains(&self, key: &str) -> bool {
100 self.inner
101 .lock()
102 .map(|g| g.contains_key(key))
103 .unwrap_or(false)
104 }
105
106 #[must_use]
108 pub fn len(&self) -> usize {
109 self.inner.lock().map(|g| g.len()).unwrap_or(0)
110 }
111
112 #[must_use]
114 pub fn is_empty(&self) -> bool {
115 self.len() == 0
116 }
117
118 pub fn clear(&self) {
120 if let Ok(mut guard) = self.inner.lock() {
121 guard.clear();
122 }
123 }
124}
125
/// Session-state key holding the names of tools disabled for this session.
///
/// `SessionState` reads the value stored here as a `HashSet<String>`
/// (see `SessionState::is_tool_enabled` and `SessionState::disabled_tools`).
pub const DISABLED_TOOLS_KEY: &str = "fastmcp.disabled_tools";

/// Session-state key holding the URIs of resources disabled for this session.
///
/// `SessionState` reads the value stored here as a `HashSet<String>`
/// (see `SessionState::is_resource_enabled` and
/// `SessionState::disabled_resources`).
pub const DISABLED_RESOURCES_KEY: &str = "fastmcp.disabled_resources";

/// Session-state key holding the names of prompts disabled for this session.
///
/// `SessionState` reads the value stored here as a `HashSet<String>`
/// (see `SessionState::is_prompt_enabled` and `SessionState::disabled_prompts`).
pub const DISABLED_PROMPTS_KEY: &str = "fastmcp.disabled_prompts";
136
137impl SessionState {
138 #[must_use]
142 pub fn is_tool_enabled(&self, name: &str) -> bool {
143 !self.is_in_disabled_set(DISABLED_TOOLS_KEY, name)
144 }
145
146 #[must_use]
150 pub fn is_resource_enabled(&self, uri: &str) -> bool {
151 !self.is_in_disabled_set(DISABLED_RESOURCES_KEY, uri)
152 }
153
154 #[must_use]
158 pub fn is_prompt_enabled(&self, name: &str) -> bool {
159 !self.is_in_disabled_set(DISABLED_PROMPTS_KEY, name)
160 }
161
162 #[must_use]
164 pub fn disabled_tools(&self) -> std::collections::HashSet<String> {
165 self.get::<std::collections::HashSet<String>>(DISABLED_TOOLS_KEY)
166 .unwrap_or_default()
167 }
168
169 #[must_use]
171 pub fn disabled_resources(&self) -> std::collections::HashSet<String> {
172 self.get::<std::collections::HashSet<String>>(DISABLED_RESOURCES_KEY)
173 .unwrap_or_default()
174 }
175
176 #[must_use]
178 pub fn disabled_prompts(&self) -> std::collections::HashSet<String> {
179 self.get::<std::collections::HashSet<String>>(DISABLED_PROMPTS_KEY)
180 .unwrap_or_default()
181 }
182
183 fn is_in_disabled_set(&self, key: &str, name: &str) -> bool {
185 self.get::<std::collections::HashSet<String>>(key)
186 .map(|set| set.contains(name))
187 .unwrap_or(false)
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Builds an owned `HashSet<String>` from string slices.
    fn names(items: &[&str]) -> HashSet<String> {
        items.iter().map(|s| (*s).to_owned()).collect()
    }

    #[test]
    fn test_session_state_new() {
        let state = SessionState::new();
        assert!(state.is_empty());
        assert_eq!(state.len(), 0);
    }

    #[test]
    fn test_session_state_set_get() {
        let state = SessionState::new();

        assert!(state.set("name", "Alice"));
        let name: Option<String> = state.get("name");
        assert_eq!(name, Some("Alice".to_string()));

        assert!(state.set("count", 42));
        let count: Option<i32> = state.get("count");
        assert_eq!(count, Some(42));
    }

    #[test]
    fn test_session_state_set_overwrites_existing_value() {
        let state = SessionState::new();
        assert!(state.set("key", 1));
        assert!(state.set("key", 2));

        assert_eq!(state.get::<i32>("key"), Some(2));
        assert_eq!(state.len(), 1);
    }

    #[test]
    fn test_session_state_get_nonexistent() {
        let state = SessionState::new();
        let value: Option<String> = state.get("nonexistent");
        assert!(value.is_none());
        assert!(state.get_raw("nonexistent").is_none());
    }

    #[test]
    fn test_session_state_type_mismatch() {
        let state = SessionState::new();
        state.set("count", 42);

        // A stored number cannot be read back as a string.
        let value: Option<String> = state.get("count");
        assert!(value.is_none());
        // A failed typed read leaves the stored value untouched.
        assert_eq!(state.get::<i32>("count"), Some(42));
    }

    #[test]
    fn test_session_state_get_raw() {
        let state = SessionState::new();
        state.set("value", serde_json::json!({"nested": true}));

        let raw = state.get_raw("value").expect("value was just stored");
        assert_eq!(raw["nested"], serde_json::json!(true));
    }

    #[test]
    fn test_session_state_set_raw() {
        let state = SessionState::new();
        assert!(state.set_raw("key", serde_json::json!([1, 2, 3])));

        let value: Option<Vec<i32>> = state.get("key");
        assert_eq!(value, Some(vec![1, 2, 3]));
    }

    #[test]
    fn test_session_state_remove() {
        let state = SessionState::new();
        state.set("key", "value");
        assert!(state.contains("key"));

        let removed = state.remove("key");
        assert_eq!(removed, Some(serde_json::json!("value")));
        assert!(!state.contains("key"));

        // Removing again finds nothing.
        assert!(state.remove("key").is_none());
    }

    #[test]
    fn test_session_state_contains() {
        let state = SessionState::new();
        assert!(!state.contains("key"));

        state.set("key", "value");
        assert!(state.contains("key"));
    }

    #[test]
    fn test_session_state_len() {
        let state = SessionState::new();
        assert_eq!(state.len(), 0);

        state.set("a", 1);
        assert_eq!(state.len(), 1);

        state.set("b", 2);
        assert_eq!(state.len(), 2);

        state.remove("a");
        assert_eq!(state.len(), 1);
    }

    #[test]
    fn test_session_state_clear() {
        let state = SessionState::new();
        state.set("a", 1);
        state.set("b", 2);
        assert_eq!(state.len(), 2);

        state.clear();
        assert!(state.is_empty());
    }

    #[test]
    fn test_session_state_clone() {
        let state = SessionState::new();
        state.set("key", "value");

        // Clones share the same underlying map.
        let cloned = state.clone();
        cloned.set("key2", "value2");

        assert!(state.contains("key2"));
    }

    #[test]
    fn test_is_tool_enabled_default() {
        let state = SessionState::new();

        // Nothing is disabled until a disabled set is stored.
        assert!(state.is_tool_enabled("any_tool"));
        assert!(state.is_tool_enabled("another_tool"));
    }

    #[test]
    fn test_is_tool_enabled_disabled() {
        let state = SessionState::new();
        state.set(DISABLED_TOOLS_KEY, names(&["my_tool"]));

        assert!(!state.is_tool_enabled("my_tool"));
        assert!(state.is_tool_enabled("other_tool"));
    }

    #[test]
    fn test_is_resource_enabled_default() {
        let state = SessionState::new();

        assert!(state.is_resource_enabled("file://path"));
        assert!(state.is_resource_enabled("http://example.com"));
    }

    #[test]
    fn test_is_resource_enabled_disabled() {
        let state = SessionState::new();
        state.set(DISABLED_RESOURCES_KEY, names(&["file://secret"]));

        assert!(!state.is_resource_enabled("file://secret"));
        assert!(state.is_resource_enabled("file://public"));
    }

    #[test]
    fn test_is_prompt_enabled_default() {
        let state = SessionState::new();

        assert!(state.is_prompt_enabled("any_prompt"));
    }

    #[test]
    fn test_is_prompt_enabled_disabled() {
        let state = SessionState::new();
        state.set(DISABLED_PROMPTS_KEY, names(&["admin_prompt"]));

        assert!(!state.is_prompt_enabled("admin_prompt"));
        assert!(state.is_prompt_enabled("user_prompt"));
    }

    #[test]
    fn test_disabled_set_from_raw_json_array() {
        let state = SessionState::new();
        // A plain JSON array of strings is accepted as a disabled set.
        state.set_raw(DISABLED_PROMPTS_KEY, serde_json::json!(["admin_prompt"]));

        assert!(!state.is_prompt_enabled("admin_prompt"));
        assert!(state.is_prompt_enabled("user_prompt"));
    }

    #[test]
    fn test_malformed_disabled_set_is_ignored() {
        let state = SessionState::new();
        // A value that is not a set of strings is treated as "nothing
        // disabled" rather than as an error.
        state.set_raw(DISABLED_TOOLS_KEY, serde_json::json!("my_tool"));

        assert!(state.is_tool_enabled("my_tool"));
        assert!(state.disabled_tools().is_empty());
    }

    #[test]
    fn test_disabled_sets_return_empty_by_default() {
        let state = SessionState::new();

        assert!(state.disabled_tools().is_empty());
        assert!(state.disabled_resources().is_empty());
        assert!(state.disabled_prompts().is_empty());
    }

    #[test]
    fn test_disabled_tools_returns_set() {
        let state = SessionState::new();
        state.set(DISABLED_TOOLS_KEY, names(&["tool1", "tool2"]));

        let result = state.disabled_tools();
        assert_eq!(result.len(), 2);
        assert!(result.contains("tool1"));
        assert!(result.contains("tool2"));
    }

    #[test]
    fn test_disabled_resources_and_prompts_return_sets() {
        let state = SessionState::new();
        state.set(DISABLED_RESOURCES_KEY, names(&["file://a", "file://b"]));
        state.set(DISABLED_PROMPTS_KEY, names(&["p1"]));

        assert_eq!(state.disabled_resources(), names(&["file://a", "file://b"]));
        assert_eq!(state.disabled_prompts(), names(&["p1"]));
        // Each category is stored under its own key.
        assert!(state.disabled_tools().is_empty());
    }
}