1use std::collections::{HashMap, HashSet};
7use std::sync::{Arc, RwLock};
8
9use dashmap::DashMap;
10use serde_json::Value;
11use uuid::Uuid;
12
13use crate::protocol::capabilities::ClientCapabilities;
14use crate::protocol::types::Implementation;
15use crate::registry::prompts::Prompt;
16use crate::registry::resources::Resource;
17use crate::registry::tools::Tool;
18use crate::server::profile::SessionProfile;
19
/// Per-session key/value store of JSON values.
///
/// The map sits behind `Arc<RwLock<..>>`, so it can be shared between owners and mutated
/// through shared references.
pub type SessionState = Arc<RwLock<HashMap<String, Value>>>;
22
/// Lifecycle phase of a session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionLifecycle {
    /// Starting phase; requests are not accepted.
    Created,
    /// Healthy phase; requests are accepted.
    Ready,
    /// Unhealthy phase that still accepts requests.
    Degraded,
    /// Final phase; requests are not accepted.
    Closed,
}

impl SessionLifecycle {
    /// Returns `true` for the phases in which requests may be served
    /// (`Ready` and `Degraded`), `false` for `Created` and `Closed`.
    pub fn can_accept_requests(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Ready | Self::Degraded => true,
            Self::Created | Self::Closed => false,
        }
    }

    /// Returns `true` only for `Ready`.
    pub fn is_healthy(&self) -> bool {
        *self == Self::Ready
    }
}
47
/// A client session: identity, negotiated client metadata, lifecycle tracking, a key/value
/// state store, and per-session customizations of tools, resources and prompts.
///
/// `Clone` is derived, so cloning copies `id`, `client_info`, `capabilities` and
/// `protocol_version` by value, while every `Arc`-wrapped field (lifecycle, error count, state
/// and all customization collections) stays shared between the clones.
// NOTE(review): because of that split, `initialize` on one clone updates the shared lifecycle
// but not the other clone's `client_info`/`capabilities`/`protocol_version` — confirm intended.
#[derive(Clone)]
pub struct Session {
    /// Session identifier: a UUID v4 string from `new`, or the caller's value from `with_id`.
    pub id: String,

    /// Client implementation info; `None` until `initialize` stores it.
    pub client_info: Option<Implementation>,

    /// Client capabilities; `None` until `initialize` stores them.
    pub capabilities: Option<ClientCapabilities>,

    /// Protocol version string; `None` until `initialize` stores it.
    pub protocol_version: Option<String>,

    /// Current lifecycle phase.
    lifecycle: Arc<RwLock<SessionLifecycle>>,

    /// Error counter: incremented by `record_error`, reset to zero by `record_success`.
    error_count: Arc<RwLock<u32>>,

    /// Key/value state accessed through `get_state`, `set_state` and friends.
    state: SessionState,

    /// Tools registered through `override_tool`, keyed by the name passed there.
    tool_overrides: Arc<DashMap<String, Arc<dyn Tool>>>,
    /// Tools registered through `add_tool`, keyed by the tool's own `name()`.
    tool_extras: Arc<DashMap<String, Arc<dyn Tool>>>,
    /// Names marked through `hide_tool`.
    tool_hidden: Arc<RwLock<HashSet<String>>>,
    /// Alias -> target name pairs registered through `alias_tool`.
    tool_aliases: Arc<RwLock<HashMap<String, String>>>,

    /// Resources registered through `override_resource`, keyed by the URI passed there.
    resource_overrides: Arc<DashMap<String, Arc<dyn Resource>>>,
    /// Resources registered through `add_resource`, keyed by the resource's own `uri()`.
    resource_extras: Arc<DashMap<String, Arc<dyn Resource>>>,
    /// URIs marked through `hide_resource`.
    resource_hidden: Arc<RwLock<HashSet<String>>>,

    /// Prompts registered through `override_prompt`, keyed by the name passed there.
    prompt_overrides: Arc<DashMap<String, Arc<dyn Prompt>>>,
    /// Prompts registered through `add_prompt`, keyed by the prompt's own `name()`.
    prompt_extras: Arc<DashMap<String, Arc<dyn Prompt>>>,
    /// Names marked through `hide_prompt`.
    prompt_hidden: Arc<RwLock<HashSet<String>>>,
}
98
99impl Session {
100 pub fn new() -> Self {
102 Self {
103 id: Uuid::new_v4().to_string(),
104 client_info: None,
105 capabilities: None,
106 protocol_version: None,
107 lifecycle: Arc::new(RwLock::new(SessionLifecycle::Created)),
108 error_count: Arc::new(RwLock::new(0)),
109 state: Arc::new(RwLock::new(HashMap::new())),
110 tool_overrides: Arc::new(DashMap::new()),
112 tool_extras: Arc::new(DashMap::new()),
113 tool_hidden: Arc::new(RwLock::new(HashSet::new())),
114 tool_aliases: Arc::new(RwLock::new(HashMap::new())),
115 resource_overrides: Arc::new(DashMap::new()),
117 resource_extras: Arc::new(DashMap::new()),
118 resource_hidden: Arc::new(RwLock::new(HashSet::new())),
119 prompt_overrides: Arc::new(DashMap::new()),
121 prompt_extras: Arc::new(DashMap::new()),
122 prompt_hidden: Arc::new(RwLock::new(HashSet::new())),
123 }
124 }
125
126 pub fn with_id(id: impl Into<String>) -> Self {
128 Self {
129 id: id.into(),
130 client_info: None,
131 capabilities: None,
132 protocol_version: None,
133 lifecycle: Arc::new(RwLock::new(SessionLifecycle::Created)),
134 error_count: Arc::new(RwLock::new(0)),
135 state: Arc::new(RwLock::new(HashMap::new())),
136 tool_overrides: Arc::new(DashMap::new()),
138 tool_extras: Arc::new(DashMap::new()),
139 tool_hidden: Arc::new(RwLock::new(HashSet::new())),
140 tool_aliases: Arc::new(RwLock::new(HashMap::new())),
141 resource_overrides: Arc::new(DashMap::new()),
143 resource_extras: Arc::new(DashMap::new()),
144 resource_hidden: Arc::new(RwLock::new(HashSet::new())),
145 prompt_overrides: Arc::new(DashMap::new()),
147 prompt_extras: Arc::new(DashMap::new()),
148 prompt_hidden: Arc::new(RwLock::new(HashSet::new())),
149 }
150 }
151
152 pub fn initialize(
155 &mut self,
156 client_info: Implementation,
157 capabilities: ClientCapabilities,
158 protocol_version: String,
159 ) {
160 self.client_info = Some(client_info);
161 self.capabilities = Some(capabilities);
162 self.protocol_version = Some(protocol_version);
163 *self.lifecycle.write().unwrap() = SessionLifecycle::Ready;
164 }
165
166 pub fn is_initialized(&self) -> bool {
168 self.lifecycle.read().unwrap().can_accept_requests()
169 }
170
171 pub fn protocol_version(&self) -> Option<&str> {
173 self.protocol_version.as_deref()
174 }
175
176 pub fn record_error(&mut self) {
179 if let Ok(mut count) = self.error_count.write() {
180 *count += 1;
181 if *count >= 3 && *self.lifecycle.read().unwrap() == SessionLifecycle::Ready {
183 *self.lifecycle.write().unwrap() = SessionLifecycle::Degraded;
184 }
185 }
186 }
187
188 pub fn record_success(&mut self) {
191 if let Ok(mut count) = self.error_count.write() {
192 *count = 0;
193 if *self.lifecycle.read().unwrap() == SessionLifecycle::Degraded {
194 *self.lifecycle.write().unwrap() = SessionLifecycle::Ready;
195 }
196 }
197 }
198
199 pub fn close(&mut self) {
202 *self.lifecycle.write().unwrap() = SessionLifecycle::Closed;
203 }
204
205 pub fn lifecycle(&self) -> SessionLifecycle {
207 *self.lifecycle.read().unwrap()
208 }
209
210 pub fn error_count(&self) -> u32 {
212 self.error_count.read().map(|c| *c).unwrap_or(0)
213 }
214
215 pub fn get_state(&self, key: &str) -> Option<Value> {
217 self.state.read().ok()?.get(key).cloned()
218 }
219
220 pub fn set_state(&self, key: impl Into<String>, value: Value) {
222 if let Ok(mut state) = self.state.write() {
223 state.insert(key.into(), value);
224 }
225 }
226
227 pub fn remove_state(&self, key: &str) -> Option<Value> {
229 self.state.write().ok()?.remove(key)
230 }
231
232 pub fn clear_state(&self) {
234 if let Ok(mut state) = self.state.write() {
235 state.clear();
236 }
237 }
238
239 pub fn state_keys(&self) -> Vec<String> {
241 self.state
242 .read()
243 .ok()
244 .map(|state| state.keys().cloned().collect())
245 .unwrap_or_default()
246 }
247
248 pub fn add_tool(&self, tool: Arc<dyn Tool>) {
252 let name = tool.name().to_string();
253 self.tool_extras.insert(name, tool);
254 }
255
256 pub fn override_tool(&self, name: impl Into<String>, tool: Arc<dyn Tool>) {
258 self.tool_overrides.insert(name.into(), tool);
259 }
260
261 pub fn hide_tool(&self, name: impl Into<String>) {
263 if let Ok(mut hidden) = self.tool_hidden.write() {
264 hidden.insert(name.into());
265 }
266 }
267
268 pub fn unhide_tool(&self, name: &str) {
270 if let Ok(mut hidden) = self.tool_hidden.write() {
271 hidden.remove(name);
272 }
273 }
274
275 pub fn alias_tool(&self, alias: impl Into<String>, target: impl Into<String>) {
277 if let Ok(mut aliases) = self.tool_aliases.write() {
278 aliases.insert(alias.into(), target.into());
279 }
280 }
281
282 pub fn remove_tool_alias(&self, alias: &str) {
284 if let Ok(mut aliases) = self.tool_aliases.write() {
285 aliases.remove(alias);
286 }
287 }
288
289 pub fn is_tool_hidden(&self, name: &str) -> bool {
291 self.tool_hidden
292 .read()
293 .map(|hidden| hidden.contains(name))
294 .unwrap_or(false)
295 }
296
297 pub fn resolve_tool_alias<'a>(&self, name: &'a str) -> std::borrow::Cow<'a, str> {
299 self.tool_aliases
300 .read()
301 .ok()
302 .and_then(|aliases| aliases.get(name).cloned())
303 .map(std::borrow::Cow::Owned)
304 .unwrap_or(std::borrow::Cow::Borrowed(name))
305 }
306
307 pub fn get_tool_override(&self, name: &str) -> Option<Arc<dyn Tool>> {
309 self.tool_overrides.get(name).map(|r| Arc::clone(&r))
310 }
311
312 pub fn get_tool_extra(&self, name: &str) -> Option<Arc<dyn Tool>> {
314 self.tool_extras.get(name).map(|r| Arc::clone(&r))
315 }
316
317 pub fn tool_overrides(&self) -> &Arc<DashMap<String, Arc<dyn Tool>>> {
319 &self.tool_overrides
320 }
321
322 pub fn tool_extras(&self) -> &Arc<DashMap<String, Arc<dyn Tool>>> {
324 &self.tool_extras
325 }
326
327 pub fn add_resource(&self, resource: Arc<dyn Resource>) {
331 let uri = resource.uri().to_string();
332 self.resource_extras.insert(uri, resource);
333 }
334
335 pub fn override_resource(&self, uri: impl Into<String>, resource: Arc<dyn Resource>) {
337 self.resource_overrides.insert(uri.into(), resource);
338 }
339
340 pub fn hide_resource(&self, uri: impl Into<String>) {
342 if let Ok(mut hidden) = self.resource_hidden.write() {
343 hidden.insert(uri.into());
344 }
345 }
346
347 pub fn unhide_resource(&self, uri: &str) {
349 if let Ok(mut hidden) = self.resource_hidden.write() {
350 hidden.remove(uri);
351 }
352 }
353
354 pub fn is_resource_hidden(&self, uri: &str) -> bool {
356 self.resource_hidden
357 .read()
358 .map(|hidden| hidden.contains(uri))
359 .unwrap_or(false)
360 }
361
362 pub fn get_resource_override(&self, uri: &str) -> Option<Arc<dyn Resource>> {
364 self.resource_overrides.get(uri).map(|r| Arc::clone(&r))
365 }
366
367 pub fn get_resource_extra(&self, uri: &str) -> Option<Arc<dyn Resource>> {
369 self.resource_extras.get(uri).map(|r| Arc::clone(&r))
370 }
371
372 pub fn resource_overrides(&self) -> &Arc<DashMap<String, Arc<dyn Resource>>> {
374 &self.resource_overrides
375 }
376
377 pub fn resource_extras(&self) -> &Arc<DashMap<String, Arc<dyn Resource>>> {
379 &self.resource_extras
380 }
381
382 pub fn add_prompt(&self, prompt: Arc<dyn Prompt>) {
386 let name = prompt.name().to_string();
387 self.prompt_extras.insert(name, prompt);
388 }
389
390 pub fn override_prompt(&self, name: impl Into<String>, prompt: Arc<dyn Prompt>) {
392 self.prompt_overrides.insert(name.into(), prompt);
393 }
394
395 pub fn hide_prompt(&self, name: impl Into<String>) {
397 if let Ok(mut hidden) = self.prompt_hidden.write() {
398 hidden.insert(name.into());
399 }
400 }
401
402 pub fn unhide_prompt(&self, name: &str) {
404 if let Ok(mut hidden) = self.prompt_hidden.write() {
405 hidden.remove(name);
406 }
407 }
408
409 pub fn is_prompt_hidden(&self, name: &str) -> bool {
411 self.prompt_hidden
412 .read()
413 .map(|hidden| hidden.contains(name))
414 .unwrap_or(false)
415 }
416
417 pub fn get_prompt_override(&self, name: &str) -> Option<Arc<dyn Prompt>> {
419 self.prompt_overrides.get(name).map(|r| Arc::clone(&r))
420 }
421
422 pub fn get_prompt_extra(&self, name: &str) -> Option<Arc<dyn Prompt>> {
424 self.prompt_extras.get(name).map(|r| Arc::clone(&r))
425 }
426
427 pub fn prompt_overrides(&self) -> &Arc<DashMap<String, Arc<dyn Prompt>>> {
429 &self.prompt_overrides
430 }
431
432 pub fn prompt_extras(&self) -> &Arc<DashMap<String, Arc<dyn Prompt>>> {
434 &self.prompt_extras
435 }
436
437 pub fn apply_profile(&self, profile: &SessionProfile) {
443 for tool in &profile.tool_extras {
445 self.add_tool(Arc::clone(tool));
446 }
447 for (name, tool) in &profile.tool_overrides {
448 self.override_tool(name.clone(), Arc::clone(tool));
449 }
450 for name in &profile.tool_hidden {
451 self.hide_tool(name.clone());
452 }
453 for (alias, target) in &profile.tool_aliases {
454 self.alias_tool(alias.clone(), target.clone());
455 }
456
457 for resource in &profile.resource_extras {
459 self.add_resource(Arc::clone(resource));
460 }
461 for (uri, resource) in &profile.resource_overrides {
462 self.override_resource(uri.clone(), Arc::clone(resource));
463 }
464 for uri in &profile.resource_hidden {
465 self.hide_resource(uri.clone());
466 }
467
468 for prompt in &profile.prompt_extras {
470 self.add_prompt(Arc::clone(prompt));
471 }
472 for (name, prompt) in &profile.prompt_overrides {
473 self.override_prompt(name.clone(), Arc::clone(prompt));
474 }
475 for name in &profile.prompt_hidden {
476 self.hide_prompt(name.clone());
477 }
478 }
479
480 pub fn clear_customizations(&self) {
482 self.tool_overrides.clear();
484 self.tool_extras.clear();
485 if let Ok(mut hidden) = self.tool_hidden.write() {
486 hidden.clear();
487 }
488 if let Ok(mut aliases) = self.tool_aliases.write() {
489 aliases.clear();
490 }
491
492 self.resource_overrides.clear();
494 self.resource_extras.clear();
495 if let Ok(mut hidden) = self.resource_hidden.write() {
496 hidden.clear();
497 }
498
499 self.prompt_overrides.clear();
501 self.prompt_extras.clear();
502 if let Ok(mut hidden) = self.prompt_hidden.write() {
503 hidden.clear();
504 }
505 }
506}
507
508impl std::fmt::Debug for Session {
510 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
511 f.debug_struct("Session")
512 .field("id", &self.id)
513 .field("client_info", &self.client_info)
514 .field("capabilities", &self.capabilities)
515 .field("lifecycle", &self.lifecycle)
516 .field("error_count", &self.error_count())
517 .field("tool_overrides_count", &self.tool_overrides.len())
518 .field("tool_extras_count", &self.tool_extras.len())
519 .field("resource_overrides_count", &self.resource_overrides.len())
520 .field("resource_extras_count", &self.resource_extras.len())
521 .field("prompt_overrides_count", &self.prompt_overrides.len())
522 .field("prompt_extras_count", &self.prompt_extras.len())
523 .finish()
524 }
525}
526
527impl Default for Session {
528 fn default() -> Self {
529 Self::new()
530 }
531}
532
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the client descriptor handed to `Session::initialize` in these tests.
    fn client(name: &str, version: &str) -> Implementation {
        Implementation {
            name: name.to_string(),
            version: version.to_string(),
        }
    }

    /// A new session has an id, is in `Created`, and carries no client metadata.
    #[test]
    fn test_session_creation() {
        let fresh = Session::new();

        assert!(!fresh.id.is_empty());
        assert_eq!(fresh.lifecycle(), SessionLifecycle::Created);
        assert!(!fresh.is_initialized());
        assert!(fresh.client_info.is_none());
        assert!(fresh.capabilities.is_none());
    }

    /// `with_id` keeps the supplied id and starts in `Created`.
    #[test]
    fn test_session_with_id() {
        let named = Session::with_id("test-session");

        assert_eq!(named.id, "test-session");
        assert_eq!(named.lifecycle(), SessionLifecycle::Created);
    }

    /// `initialize` stores the client info and moves the session to `Ready`.
    #[test]
    fn test_session_initialization() {
        let mut session = Session::new();

        session.initialize(
            client("test-client", "1.0.0"),
            ClientCapabilities::default(),
            "2025-06-18".to_string(),
        );

        assert!(session.is_initialized());
        assert_eq!(session.lifecycle(), SessionLifecycle::Ready);
        let stored = session.client_info.unwrap();
        assert_eq!(stored.name, "test-client");
    }

    /// Walks Created -> Ready -> Degraded -> Ready -> Closed.
    #[test]
    fn test_session_lifecycle_transitions() {
        let mut session = Session::new();
        assert_eq!(session.lifecycle(), SessionLifecycle::Created);
        assert!(!session.lifecycle().can_accept_requests());

        // Initialization makes the session ready and healthy.
        session.initialize(
            client("test", "1.0"),
            ClientCapabilities::default(),
            "2025-06-18".to_string(),
        );
        assert_eq!(session.lifecycle(), SessionLifecycle::Ready);
        assert!(session.lifecycle().can_accept_requests());
        assert!(session.lifecycle().is_healthy());

        // The first two errors leave the session ready...
        for _ in 0..2 {
            session.record_error();
            assert_eq!(session.lifecycle(), SessionLifecycle::Ready);
        }
        // ...and the third one degrades it, while requests are still accepted.
        session.record_error();
        assert_eq!(session.lifecycle(), SessionLifecycle::Degraded);
        assert!(session.lifecycle().can_accept_requests());
        assert!(!session.lifecycle().is_healthy());

        // A success restores the session and resets the counter.
        session.record_success();
        assert_eq!(session.lifecycle(), SessionLifecycle::Ready);
        assert_eq!(session.error_count(), 0);

        // Closing stops request handling.
        session.close();
        assert_eq!(session.lifecycle(), SessionLifecycle::Closed);
        assert!(!session.lifecycle().can_accept_requests());
    }

    /// Exercises set / get / keys / remove / clear on the state store.
    #[test]
    fn test_session_state() {
        let session = Session::new();
        let first = Value::String("value1".to_string());
        let second = Value::Number(42.into());

        session.set_state("key1", first.clone());
        session.set_state("key2", second.clone());

        assert_eq!(session.get_state("key1"), Some(first.clone()));
        assert_eq!(session.get_state("key2"), Some(second));
        assert_eq!(session.get_state("nonexistent"), None);

        let keys = session.state_keys();
        assert_eq!(keys.len(), 2);
        assert!(keys.iter().any(|k| k == "key1"));
        assert!(keys.iter().any(|k| k == "key2"));

        let removed = session.remove_state("key1");
        assert_eq!(removed, Some(first));
        assert_eq!(session.get_state("key1"), None);

        session.clear_state();
        assert_eq!(session.state_keys().len(), 0);
    }

    /// Clones share the same state store: a write through one is visible through the other.
    #[test]
    fn test_session_clone() {
        let original = Session::with_id("test");
        original.set_state("shared", Value::Bool(true));

        let copy = original.clone();

        assert_eq!(original.id, copy.id);
        assert_eq!(copy.get_state("shared"), Some(Value::Bool(true)));

        copy.set_state("shared", Value::Bool(false));
        assert_eq!(original.get_state("shared"), Some(Value::Bool(false)));
    }
}