mcp_host/server/
session.rs

1//! Session management
2//!
3//! Tracks individual MCP client connections with state storage, lifecycle state machine,
4//! and per-session tool/resource/prompt customization.
5
6use std::collections::{HashMap, HashSet};
7use std::sync::{Arc, RwLock};
8
9use dashmap::DashMap;
10use serde_json::Value;
11use uuid::Uuid;
12
13use crate::protocol::capabilities::ClientCapabilities;
14use crate::protocol::types::Implementation;
15use crate::registry::prompts::Prompt;
16use crate::registry::resources::Resource;
17use crate::registry::tools::Tool;
18use crate::server::profile::SessionProfile;
19
20/// Session state storage (thread-safe)
21pub type SessionState = Arc<RwLock<HashMap<String, Value>>>;
22
/// Lifecycle state machine of a session.
///
/// Expected flow: `Created` -> `Ready` once the client has been initialized,
/// `Ready` <-> `Degraded` as errors accumulate and are cleared by a success, and any
/// state -> `Closed` when the session is shut down.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionLifecycle {
    /// Constructed, but the client has not been initialized yet.
    Created,
    /// Initialized and fully healthy.
    Ready,
    /// Still serving requests, but recent requests have been failing.
    Degraded,
    /// Shut down; no further requests should be served.
    Closed,
}

impl SessionLifecycle {
    /// Returns `true` while requests may be served, i.e. in `Ready` or `Degraded`.
    pub fn can_accept_requests(&self) -> bool {
        match *self {
            Self::Ready | Self::Degraded => true,
            Self::Created | Self::Closed => false,
        }
    }

    /// Returns `true` only in the fully healthy `Ready` state.
    pub fn is_healthy(&self) -> bool {
        *self == Self::Ready
    }
}
47
48/// Session represents a single MCP client connection
49#[derive(Clone)]
50pub struct Session {
51    /// Unique session ID
52    pub id: String,
53
54    /// Client information (set during initialization)
55    pub client_info: Option<Implementation>,
56
57    /// Client capabilities (set during initialization)
58    pub capabilities: Option<ClientCapabilities>,
59
60    /// Negotiated protocol version (set during initialization)
61    pub protocol_version: Option<String>,
62
63    /// Session lifecycle state (Arc-wrapped for shared state across clones)
64    lifecycle: Arc<RwLock<SessionLifecycle>>,
65
66    /// Error count for degraded state transitions
67    error_count: Arc<RwLock<u32>>,
68
69    /// Per-session state storage
70    state: SessionState,
71
72    // Tool customization
73    /// Tool overrides (same name, different implementation)
74    tool_overrides: Arc<DashMap<String, Arc<dyn Tool>>>,
75    /// Extra tools added to this session
76    tool_extras: Arc<DashMap<String, Arc<dyn Tool>>>,
77    /// Tools hidden from this session
78    tool_hidden: Arc<RwLock<HashSet<String>>>,
79    /// Tool aliases (alias -> actual name)
80    tool_aliases: Arc<RwLock<HashMap<String, String>>>,
81
82    // Resource customization
83    /// Resource overrides (same URI, different implementation)
84    resource_overrides: Arc<DashMap<String, Arc<dyn Resource>>>,
85    /// Extra resources added to this session
86    resource_extras: Arc<DashMap<String, Arc<dyn Resource>>>,
87    /// Resources hidden from this session
88    resource_hidden: Arc<RwLock<HashSet<String>>>,
89
90    // Prompt customization
91    /// Prompt overrides (same name, different implementation)
92    prompt_overrides: Arc<DashMap<String, Arc<dyn Prompt>>>,
93    /// Extra prompts added to this session
94    prompt_extras: Arc<DashMap<String, Arc<dyn Prompt>>>,
95    /// Prompts hidden from this session
96    prompt_hidden: Arc<RwLock<HashSet<String>>>,
97}
98
99impl Session {
100    /// Create new session with random UUID
101    pub fn new() -> Self {
102        Self {
103            id: Uuid::new_v4().to_string(),
104            client_info: None,
105            capabilities: None,
106            protocol_version: None,
107            lifecycle: Arc::new(RwLock::new(SessionLifecycle::Created)),
108            error_count: Arc::new(RwLock::new(0)),
109            state: Arc::new(RwLock::new(HashMap::new())),
110            // Tool customization
111            tool_overrides: Arc::new(DashMap::new()),
112            tool_extras: Arc::new(DashMap::new()),
113            tool_hidden: Arc::new(RwLock::new(HashSet::new())),
114            tool_aliases: Arc::new(RwLock::new(HashMap::new())),
115            // Resource customization
116            resource_overrides: Arc::new(DashMap::new()),
117            resource_extras: Arc::new(DashMap::new()),
118            resource_hidden: Arc::new(RwLock::new(HashSet::new())),
119            // Prompt customization
120            prompt_overrides: Arc::new(DashMap::new()),
121            prompt_extras: Arc::new(DashMap::new()),
122            prompt_hidden: Arc::new(RwLock::new(HashSet::new())),
123        }
124    }
125
126    /// Create session with specific ID
127    pub fn with_id(id: impl Into<String>) -> Self {
128        Self {
129            id: id.into(),
130            client_info: None,
131            capabilities: None,
132            protocol_version: None,
133            lifecycle: Arc::new(RwLock::new(SessionLifecycle::Created)),
134            error_count: Arc::new(RwLock::new(0)),
135            state: Arc::new(RwLock::new(HashMap::new())),
136            // Tool customization
137            tool_overrides: Arc::new(DashMap::new()),
138            tool_extras: Arc::new(DashMap::new()),
139            tool_hidden: Arc::new(RwLock::new(HashSet::new())),
140            tool_aliases: Arc::new(RwLock::new(HashMap::new())),
141            // Resource customization
142            resource_overrides: Arc::new(DashMap::new()),
143            resource_extras: Arc::new(DashMap::new()),
144            resource_hidden: Arc::new(RwLock::new(HashSet::new())),
145            // Prompt customization
146            prompt_overrides: Arc::new(DashMap::new()),
147            prompt_extras: Arc::new(DashMap::new()),
148            prompt_hidden: Arc::new(RwLock::new(HashSet::new())),
149        }
150    }
151
152    /// Initialize session with client info and capabilities
153    /// Transitions: Created -> Ready
154    pub fn initialize(
155        &mut self,
156        client_info: Implementation,
157        capabilities: ClientCapabilities,
158        protocol_version: String,
159    ) {
160        self.client_info = Some(client_info);
161        self.capabilities = Some(capabilities);
162        self.protocol_version = Some(protocol_version);
163        *self.lifecycle.write().unwrap() = SessionLifecycle::Ready;
164    }
165
166    /// Check if session is initialized (Ready or Degraded)
167    pub fn is_initialized(&self) -> bool {
168        self.lifecycle.read().unwrap().can_accept_requests()
169    }
170
171    /// Get the negotiated protocol version
172    pub fn protocol_version(&self) -> Option<&str> {
173        self.protocol_version.as_deref()
174    }
175
176    /// Record an error, potentially transitioning to Degraded state
177    /// Transitions: Ready -> Degraded (after threshold errors)
178    pub fn record_error(&mut self) {
179        if let Ok(mut count) = self.error_count.write() {
180            *count += 1;
181            // Transition to Degraded after 3 consecutive errors
182            if *count >= 3 && *self.lifecycle.read().unwrap() == SessionLifecycle::Ready {
183                *self.lifecycle.write().unwrap() = SessionLifecycle::Degraded;
184            }
185        }
186    }
187
188    /// Record success, potentially recovering from Degraded state
189    /// Transitions: Degraded -> Ready (resets error count)
190    pub fn record_success(&mut self) {
191        if let Ok(mut count) = self.error_count.write() {
192            *count = 0;
193            if *self.lifecycle.read().unwrap() == SessionLifecycle::Degraded {
194                *self.lifecycle.write().unwrap() = SessionLifecycle::Ready;
195            }
196        }
197    }
198
199    /// Close the session
200    /// Transitions: Any -> Closed
201    pub fn close(&mut self) {
202        *self.lifecycle.write().unwrap() = SessionLifecycle::Closed;
203    }
204
205    /// Get current lifecycle state
206    pub fn lifecycle(&self) -> SessionLifecycle {
207        *self.lifecycle.read().unwrap()
208    }
209
210    /// Get current error count
211    pub fn error_count(&self) -> u32 {
212        self.error_count.read().map(|c| *c).unwrap_or(0)
213    }
214
215    /// Get state value
216    pub fn get_state(&self, key: &str) -> Option<Value> {
217        self.state.read().ok()?.get(key).cloned()
218    }
219
220    /// Set state value
221    pub fn set_state(&self, key: impl Into<String>, value: Value) {
222        if let Ok(mut state) = self.state.write() {
223            state.insert(key.into(), value);
224        }
225    }
226
227    /// Remove state value
228    pub fn remove_state(&self, key: &str) -> Option<Value> {
229        self.state.write().ok()?.remove(key)
230    }
231
232    /// Clear all state
233    pub fn clear_state(&self) {
234        if let Ok(mut state) = self.state.write() {
235            state.clear();
236        }
237    }
238
239    /// Get all state keys
240    pub fn state_keys(&self) -> Vec<String> {
241        self.state
242            .read()
243            .ok()
244            .map(|state| state.keys().cloned().collect())
245            .unwrap_or_default()
246    }
247
248    // ==================== Tool Management ====================
249
250    /// Add an extra tool to this session
251    pub fn add_tool(&self, tool: Arc<dyn Tool>) {
252        let name = tool.name().to_string();
253        self.tool_extras.insert(name, tool);
254    }
255
256    /// Override a tool (same name, different implementation)
257    pub fn override_tool(&self, name: impl Into<String>, tool: Arc<dyn Tool>) {
258        self.tool_overrides.insert(name.into(), tool);
259    }
260
261    /// Hide a tool from this session
262    pub fn hide_tool(&self, name: impl Into<String>) {
263        if let Ok(mut hidden) = self.tool_hidden.write() {
264            hidden.insert(name.into());
265        }
266    }
267
268    /// Unhide a tool
269    pub fn unhide_tool(&self, name: &str) {
270        if let Ok(mut hidden) = self.tool_hidden.write() {
271            hidden.remove(name);
272        }
273    }
274
275    /// Add a tool alias (alias -> target)
276    pub fn alias_tool(&self, alias: impl Into<String>, target: impl Into<String>) {
277        if let Ok(mut aliases) = self.tool_aliases.write() {
278            aliases.insert(alias.into(), target.into());
279        }
280    }
281
282    /// Remove a tool alias
283    pub fn remove_tool_alias(&self, alias: &str) {
284        if let Ok(mut aliases) = self.tool_aliases.write() {
285            aliases.remove(alias);
286        }
287    }
288
289    /// Check if a tool is hidden
290    pub fn is_tool_hidden(&self, name: &str) -> bool {
291        self.tool_hidden
292            .read()
293            .map(|hidden| hidden.contains(name))
294            .unwrap_or(false)
295    }
296
297    /// Resolve a tool alias to its target name
298    pub fn resolve_tool_alias<'a>(&self, name: &'a str) -> std::borrow::Cow<'a, str> {
299        self.tool_aliases
300            .read()
301            .ok()
302            .and_then(|aliases| aliases.get(name).cloned())
303            .map(std::borrow::Cow::Owned)
304            .unwrap_or(std::borrow::Cow::Borrowed(name))
305    }
306
307    /// Get a tool override for this session
308    pub fn get_tool_override(&self, name: &str) -> Option<Arc<dyn Tool>> {
309        self.tool_overrides.get(name).map(|r| Arc::clone(&r))
310    }
311
312    /// Get an extra tool added to this session
313    pub fn get_tool_extra(&self, name: &str) -> Option<Arc<dyn Tool>> {
314        self.tool_extras.get(name).map(|r| Arc::clone(&r))
315    }
316
317    /// Get all tool overrides
318    pub fn tool_overrides(&self) -> &Arc<DashMap<String, Arc<dyn Tool>>> {
319        &self.tool_overrides
320    }
321
322    /// Get all extra tools
323    pub fn tool_extras(&self) -> &Arc<DashMap<String, Arc<dyn Tool>>> {
324        &self.tool_extras
325    }
326
327    // ==================== Resource Management ====================
328
329    /// Add an extra resource to this session
330    pub fn add_resource(&self, resource: Arc<dyn Resource>) {
331        let uri = resource.uri().to_string();
332        self.resource_extras.insert(uri, resource);
333    }
334
335    /// Override a resource (same URI, different implementation)
336    pub fn override_resource(&self, uri: impl Into<String>, resource: Arc<dyn Resource>) {
337        self.resource_overrides.insert(uri.into(), resource);
338    }
339
340    /// Hide a resource from this session
341    pub fn hide_resource(&self, uri: impl Into<String>) {
342        if let Ok(mut hidden) = self.resource_hidden.write() {
343            hidden.insert(uri.into());
344        }
345    }
346
347    /// Unhide a resource
348    pub fn unhide_resource(&self, uri: &str) {
349        if let Ok(mut hidden) = self.resource_hidden.write() {
350            hidden.remove(uri);
351        }
352    }
353
354    /// Check if a resource is hidden
355    pub fn is_resource_hidden(&self, uri: &str) -> bool {
356        self.resource_hidden
357            .read()
358            .map(|hidden| hidden.contains(uri))
359            .unwrap_or(false)
360    }
361
362    /// Get a resource override for this session
363    pub fn get_resource_override(&self, uri: &str) -> Option<Arc<dyn Resource>> {
364        self.resource_overrides.get(uri).map(|r| Arc::clone(&r))
365    }
366
367    /// Get an extra resource added to this session
368    pub fn get_resource_extra(&self, uri: &str) -> Option<Arc<dyn Resource>> {
369        self.resource_extras.get(uri).map(|r| Arc::clone(&r))
370    }
371
372    /// Get all resource overrides
373    pub fn resource_overrides(&self) -> &Arc<DashMap<String, Arc<dyn Resource>>> {
374        &self.resource_overrides
375    }
376
377    /// Get all extra resources
378    pub fn resource_extras(&self) -> &Arc<DashMap<String, Arc<dyn Resource>>> {
379        &self.resource_extras
380    }
381
382    // ==================== Prompt Management ====================
383
384    /// Add an extra prompt to this session
385    pub fn add_prompt(&self, prompt: Arc<dyn Prompt>) {
386        let name = prompt.name().to_string();
387        self.prompt_extras.insert(name, prompt);
388    }
389
390    /// Override a prompt (same name, different implementation)
391    pub fn override_prompt(&self, name: impl Into<String>, prompt: Arc<dyn Prompt>) {
392        self.prompt_overrides.insert(name.into(), prompt);
393    }
394
395    /// Hide a prompt from this session
396    pub fn hide_prompt(&self, name: impl Into<String>) {
397        if let Ok(mut hidden) = self.prompt_hidden.write() {
398            hidden.insert(name.into());
399        }
400    }
401
402    /// Unhide a prompt
403    pub fn unhide_prompt(&self, name: &str) {
404        if let Ok(mut hidden) = self.prompt_hidden.write() {
405            hidden.remove(name);
406        }
407    }
408
409    /// Check if a prompt is hidden
410    pub fn is_prompt_hidden(&self, name: &str) -> bool {
411        self.prompt_hidden
412            .read()
413            .map(|hidden| hidden.contains(name))
414            .unwrap_or(false)
415    }
416
417    /// Get a prompt override for this session
418    pub fn get_prompt_override(&self, name: &str) -> Option<Arc<dyn Prompt>> {
419        self.prompt_overrides.get(name).map(|r| Arc::clone(&r))
420    }
421
422    /// Get an extra prompt added to this session
423    pub fn get_prompt_extra(&self, name: &str) -> Option<Arc<dyn Prompt>> {
424        self.prompt_extras.get(name).map(|r| Arc::clone(&r))
425    }
426
427    /// Get all prompt overrides
428    pub fn prompt_overrides(&self) -> &Arc<DashMap<String, Arc<dyn Prompt>>> {
429        &self.prompt_overrides
430    }
431
432    /// Get all extra prompts
433    pub fn prompt_extras(&self) -> &Arc<DashMap<String, Arc<dyn Prompt>>> {
434        &self.prompt_extras
435    }
436
437    // ==================== Profile Management ====================
438
439    /// Apply a session profile
440    ///
441    /// This adds/overrides/hides tools, resources, and prompts according to the profile
442    pub fn apply_profile(&self, profile: &SessionProfile) {
443        // Apply tool configuration
444        for tool in &profile.tool_extras {
445            self.add_tool(Arc::clone(tool));
446        }
447        for (name, tool) in &profile.tool_overrides {
448            self.override_tool(name.clone(), Arc::clone(tool));
449        }
450        for name in &profile.tool_hidden {
451            self.hide_tool(name.clone());
452        }
453        for (alias, target) in &profile.tool_aliases {
454            self.alias_tool(alias.clone(), target.clone());
455        }
456
457        // Apply resource configuration
458        for resource in &profile.resource_extras {
459            self.add_resource(Arc::clone(resource));
460        }
461        for (uri, resource) in &profile.resource_overrides {
462            self.override_resource(uri.clone(), Arc::clone(resource));
463        }
464        for uri in &profile.resource_hidden {
465            self.hide_resource(uri.clone());
466        }
467
468        // Apply prompt configuration
469        for prompt in &profile.prompt_extras {
470            self.add_prompt(Arc::clone(prompt));
471        }
472        for (name, prompt) in &profile.prompt_overrides {
473            self.override_prompt(name.clone(), Arc::clone(prompt));
474        }
475        for name in &profile.prompt_hidden {
476            self.hide_prompt(name.clone());
477        }
478    }
479
480    /// Clear all session customizations
481    pub fn clear_customizations(&self) {
482        // Clear tools
483        self.tool_overrides.clear();
484        self.tool_extras.clear();
485        if let Ok(mut hidden) = self.tool_hidden.write() {
486            hidden.clear();
487        }
488        if let Ok(mut aliases) = self.tool_aliases.write() {
489            aliases.clear();
490        }
491
492        // Clear resources
493        self.resource_overrides.clear();
494        self.resource_extras.clear();
495        if let Ok(mut hidden) = self.resource_hidden.write() {
496            hidden.clear();
497        }
498
499        // Clear prompts
500        self.prompt_overrides.clear();
501        self.prompt_extras.clear();
502        if let Ok(mut hidden) = self.prompt_hidden.write() {
503            hidden.clear();
504        }
505    }
506}
507
508// Implement Debug manually since trait objects don't implement Debug
509impl std::fmt::Debug for Session {
510    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
511        f.debug_struct("Session")
512            .field("id", &self.id)
513            .field("client_info", &self.client_info)
514            .field("capabilities", &self.capabilities)
515            .field("lifecycle", &self.lifecycle)
516            .field("error_count", &self.error_count())
517            .field("tool_overrides_count", &self.tool_overrides.len())
518            .field("tool_extras_count", &self.tool_extras.len())
519            .field("resource_overrides_count", &self.resource_overrides.len())
520            .field("resource_extras_count", &self.resource_extras.len())
521            .field("prompt_overrides_count", &self.prompt_overrides.len())
522            .field("prompt_extras_count", &self.prompt_extras.len())
523            .finish()
524    }
525}
526
527impl Default for Session {
528    fn default() -> Self {
529        Self::new()
530    }
531}
532
#[cfg(test)]
mod tests {
    use super::*;

    /// Run `initialize` with fixed client info, default capabilities and protocol
    /// version "2025-06-18".
    fn initialize_for_test(session: &mut Session) {
        session.initialize(
            Implementation {
                name: "test".into(),
                version: "1.0".into(),
            },
            ClientCapabilities::default(),
            "2025-06-18".to_string(),
        );
    }

    #[test]
    fn test_session_creation() {
        let session = Session::new();
        assert!(!session.id.is_empty());
        assert_eq!(session.lifecycle(), SessionLifecycle::Created);
        assert!(!session.is_initialized());
        assert!(session.client_info.is_none());
        assert!(session.capabilities.is_none());
    }

    #[test]
    fn test_session_with_id() {
        let session = Session::with_id("test-session");
        assert_eq!(session.id, "test-session");
        assert_eq!(session.lifecycle(), SessionLifecycle::Created);
    }

    #[test]
    fn test_session_initialization() {
        let mut session = Session::new();
        let client_info = Implementation {
            name: "test-client".to_string(),
            version: "1.0.0".to_string(),
        };
        let capabilities = ClientCapabilities::default();

        session.initialize(client_info.clone(), capabilities, "2025-06-18".to_string());

        assert!(session.is_initialized());
        assert_eq!(session.lifecycle(), SessionLifecycle::Ready);
        assert_eq!(session.client_info.unwrap().name, "test-client");
    }

    #[test]
    fn test_session_lifecycle_transitions() {
        let mut session = Session::new();
        assert_eq!(session.lifecycle(), SessionLifecycle::Created);
        assert!(!session.lifecycle().can_accept_requests());

        // Initialize -> Ready
        session.initialize(
            Implementation {
                name: "test".into(),
                version: "1.0".into(),
            },
            ClientCapabilities::default(),
            "2025-06-18".to_string(),
        );
        assert_eq!(session.lifecycle(), SessionLifecycle::Ready);
        assert!(session.lifecycle().can_accept_requests());
        assert!(session.lifecycle().is_healthy());

        // Record errors -> Degraded (after 3)
        session.record_error();
        assert_eq!(session.lifecycle(), SessionLifecycle::Ready);
        session.record_error();
        assert_eq!(session.lifecycle(), SessionLifecycle::Ready);
        session.record_error();
        assert_eq!(session.lifecycle(), SessionLifecycle::Degraded);
        assert!(session.lifecycle().can_accept_requests());
        assert!(!session.lifecycle().is_healthy());

        // Record success -> Ready (recovery)
        session.record_success();
        assert_eq!(session.lifecycle(), SessionLifecycle::Ready);
        assert_eq!(session.error_count(), 0);

        // Close -> Closed
        session.close();
        assert_eq!(session.lifecycle(), SessionLifecycle::Closed);
        assert!(!session.lifecycle().can_accept_requests());
    }

    #[test]
    fn test_protocol_version_set_on_initialize() {
        let mut session = Session::new();
        assert_eq!(session.protocol_version(), None);
        initialize_for_test(&mut session);
        assert_eq!(session.protocol_version(), Some("2025-06-18"));
    }

    #[test]
    fn test_success_resets_error_streak() {
        let mut session = Session::new();
        initialize_for_test(&mut session);

        session.record_error();
        session.record_error();
        session.record_success();
        session.record_error();

        // Only one error since the last success: still below the degradation threshold.
        assert_eq!(session.error_count(), 1);
        assert_eq!(session.lifecycle(), SessionLifecycle::Ready);
    }

    #[test]
    fn test_closed_session_is_not_reopened_by_error_tracking() {
        let mut session = Session::new();
        initialize_for_test(&mut session);
        session.close();

        for _ in 0..5 {
            session.record_error();
        }
        assert_eq!(session.lifecycle(), SessionLifecycle::Closed);

        session.record_success();
        assert_eq!(session.lifecycle(), SessionLifecycle::Closed);
        assert!(!session.is_initialized());
    }

    #[test]
    fn test_clone_shares_lifecycle_and_error_count() {
        let mut original = Session::new();
        initialize_for_test(&mut original);
        let mut copy = original.clone();

        copy.record_error();
        assert_eq!(original.error_count(), 1);

        copy.close();
        assert_eq!(original.lifecycle(), SessionLifecycle::Closed);
    }

    #[test]
    fn test_session_state() {
        let session = Session::new();

        // Set state
        session.set_state("key1", Value::String("value1".to_string()));
        session.set_state("key2", Value::Number(42.into()));

        // Get state
        assert_eq!(
            session.get_state("key1"),
            Some(Value::String("value1".to_string()))
        );
        assert_eq!(session.get_state("key2"), Some(Value::Number(42.into())));
        assert_eq!(session.get_state("nonexistent"), None);

        // State keys
        let keys = session.state_keys();
        assert_eq!(keys.len(), 2);
        assert!(keys.contains(&"key1".to_string()));
        assert!(keys.contains(&"key2".to_string()));

        // Remove state
        let removed = session.remove_state("key1");
        assert_eq!(removed, Some(Value::String("value1".to_string())));
        assert_eq!(session.get_state("key1"), None);

        // Clear state
        session.clear_state();
        assert_eq!(session.state_keys().len(), 0);
    }

    #[test]
    fn test_session_clone() {
        let session1 = Session::with_id("test");
        session1.set_state("shared", Value::Bool(true));

        let session2 = session1.clone();

        // Both sessions share the same state storage
        assert_eq!(session1.id, session2.id);
        assert_eq!(session2.get_state("shared"), Some(Value::Bool(true)));

        // Modifying state in one affects the other
        session2.set_state("shared", Value::Bool(false));
        assert_eq!(session1.get_state("shared"), Some(Value::Bool(false)));
    }

    #[test]
    fn test_tool_hiding_and_aliases() {
        let session = Session::new();

        assert!(!session.is_tool_hidden("search"));
        session.hide_tool("search");
        assert!(session.is_tool_hidden("search"));
        session.unhide_tool("search");
        assert!(!session.is_tool_hidden("search"));

        session.alias_tool("find", "search");
        assert_eq!(session.resolve_tool_alias("find"), "search");
        // Names without an alias are passed through without allocating.
        assert!(matches!(
            session.resolve_tool_alias("other"),
            std::borrow::Cow::Borrowed("other")
        ));

        session.remove_tool_alias("find");
        assert_eq!(session.resolve_tool_alias("find"), "find");
    }

    #[test]
    fn test_resource_and_prompt_hiding() {
        let session = Session::new();

        session.hide_resource("file:///data.txt");
        session.hide_prompt("greeting");
        assert!(session.is_resource_hidden("file:///data.txt"));
        assert!(session.is_prompt_hidden("greeting"));

        session.unhide_resource("file:///data.txt");
        session.unhide_prompt("greeting");
        assert!(!session.is_resource_hidden("file:///data.txt"));
        assert!(!session.is_prompt_hidden("greeting"));
    }

    #[test]
    fn test_clear_customizations_resets_visibility_and_aliases() {
        let session = Session::new();
        session.hide_tool("tool");
        session.alias_tool("alias", "tool");
        session.hide_resource("file:///data.txt");
        session.hide_prompt("prompt");
        session.set_state("kept", Value::Bool(true));

        session.clear_customizations();

        assert!(!session.is_tool_hidden("tool"));
        assert_eq!(session.resolve_tool_alias("alias"), "alias");
        assert!(!session.is_resource_hidden("file:///data.txt"));
        assert!(!session.is_prompt_hidden("prompt"));
        assert!(session.tool_overrides().is_empty());
        assert!(session.tool_extras().is_empty());
        // Key/value state is not a customization and survives.
        assert_eq!(session.get_state("kept"), Some(Value::Bool(true)));
    }
}