1use std::collections::{HashMap, VecDeque};
68use std::sync::Arc;
69use std::time::Duration as StdDuration;
70
71use chrono::{DateTime, Duration, Utc};
72use dashmap::DashMap;
73use parking_lot::RwLock;
74use serde::{Deserialize, Serialize};
75use tokio::time::{Interval, interval};
76
77use crate::context::{
78 ClientIdExtractor, ClientSession, CompletionContext, ElicitationContext, RequestInfo,
79};
80
/// Tunable limits and behaviour switches for a [`SessionManager`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionConfig {
    /// Maximum number of sessions kept at once; creating a session when this many exist
    /// evicts the least-recently-active ones first.
    pub max_sessions: usize,
    /// Sessions whose `last_activity` is older than `now - session_timeout` are removed by
    /// the background cleanup task.
    pub session_timeout: Duration,
    /// Size limit of the request history; the oldest entries are dropped first.
    pub max_request_history: usize,
    /// When set, a session is terminated once its `request_count` exceeds this value
    /// (checked in `update_client_activity`). `None` disables the limit.
    pub max_requests_per_session: Option<usize>,
    /// Tick period of the cleanup task spawned by `SessionManager::start`.
    pub cleanup_interval: StdDuration,
    /// When `false`, `record_request` returns immediately without recording anything.
    pub enable_analytics: bool,
}
97
98impl Default for SessionConfig {
99 fn default() -> Self {
100 Self {
101 max_sessions: 1000,
102 session_timeout: Duration::hours(24),
103 max_request_history: 1000,
104 max_requests_per_session: None,
105 cleanup_interval: StdDuration::from_secs(300), enable_analytics: true,
107 }
108 }
109}
110
/// Point-in-time analytics snapshot produced by `SessionManager::get_analytics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionAnalytics {
    /// Number of sessions created since the manager was constructed, including sessions that
    /// have since ended.
    pub total_sessions: usize,
    /// Number of sessions currently held by the manager.
    pub active_sessions: usize,
    /// Requests counted by `record_request` (only while analytics are enabled).
    pub total_requests: usize,
    /// Recorded requests whose `success` flag was `true`.
    pub successful_requests: usize,
    /// Recorded requests whose `success` flag was `false`.
    pub failed_requests: usize,
    /// Mean `session_duration()` over the currently active sessions; zero when there are none.
    pub avg_session_duration: Duration,
    /// Up to 10 `(client_id, request_count)` pairs from the retained request history,
    /// highest count first.
    pub top_clients: Vec<(String, usize)>,
    /// Up to 10 `(method_name, request_count)` pairs from the retained request history,
    /// highest count first.
    pub top_methods: Vec<(String, usize)>,
    /// Requests in the retained history from the last hour, divided by 60.
    pub requests_per_minute: f64,
}
133
/// Tracks client sessions, request history, lifecycle events and per-client
/// elicitation/completion state.
///
/// Shared state is held behind `Arc`s so that the cleanup task spawned by
/// `SessionManager::start` can operate on it.
#[derive(Debug)]
pub struct SessionManager {
    /// Configuration supplied at construction.
    config: SessionConfig,
    /// Live sessions keyed by client id.
    sessions: Arc<DashMap<String, ClientSession>>,
    /// Client-id extractor; `authenticate_client` registers tokens with it.
    client_extractor: Arc<ClientIdExtractor>,
    /// Bounded history of recorded (sanitized) requests, oldest at the front.
    request_history: Arc<RwLock<VecDeque<RequestInfo>>>,
    /// Bounded history (1000 entries) of session lifecycle events, oldest at the front.
    session_history: Arc<RwLock<VecDeque<SessionEvent>>>,
    /// Set to an interval by `start`; the spawned cleanup task ticks its own interval.
    cleanup_timer: Arc<RwLock<Option<Interval>>>,
    /// Aggregate counters backing `SessionAnalytics`.
    stats: Arc<RwLock<SessionStats>>,
    /// Pending elicitations keyed by client id.
    pending_elicitations: Arc<DashMap<String, Vec<ElicitationContext>>>,
    /// Active completions keyed by client id.
    active_completions: Arc<DashMap<String, Vec<CompletionContext>>>,
}
156
/// Running counters maintained by `SessionManager`.
#[derive(Debug, Default)]
struct SessionStats {
    /// Sessions ever created (never decremented).
    total_sessions: usize,
    /// Requests passed to `record_request` while analytics were enabled.
    total_requests: usize,
    /// Subset of `total_requests` that succeeded.
    successful_requests: usize,
    /// Subset of `total_requests` that failed.
    failed_requests: usize,
    /// Summed durations of sessions removed by termination, expiry or capacity eviction.
    /// NOTE(review): not read anywhere in this file.
    total_session_duration: Duration,
}
166
/// A single entry in the session lifecycle history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionEvent {
    /// Time the event was recorded (`Utc::now()` when it was appended).
    pub timestamp: DateTime<Utc>,
    /// Client whose session the event concerns.
    pub client_id: String,
    /// What happened to the session.
    pub event_type: SessionEventType,
    /// Extra details, e.g. `client_name` on authentication or `reason` on capacity
    /// eviction; frequently empty.
    pub metadata: HashMap<String, serde_json::Value>,
}
179
/// Kind of lifecycle transition recorded in a [`SessionEvent`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SessionEventType {
    /// A new session was inserted by `get_or_create_session`.
    Created,
    /// The session was authenticated via `authenticate_client`.
    Authenticated,
    /// NOTE(review): not emitted anywhere in this file.
    Updated,
    /// The cleanup task removed the session after `session_timeout` of inactivity.
    Expired,
    /// The session was removed by `terminate_session` (including request-cap overruns) or
    /// by capacity eviction (with `metadata.reason = "capacity_eviction"`).
    Terminated,
}
194
195impl SessionManager {
196 #[must_use]
198 pub fn new(config: SessionConfig) -> Self {
199 Self {
200 config,
201 sessions: Arc::new(DashMap::new()),
202 client_extractor: Arc::new(ClientIdExtractor::new()),
203 request_history: Arc::new(RwLock::new(VecDeque::new())),
204 session_history: Arc::new(RwLock::new(VecDeque::new())),
205 cleanup_timer: Arc::new(RwLock::new(None)),
206 stats: Arc::new(RwLock::new(SessionStats::default())),
207 pending_elicitations: Arc::new(DashMap::new()),
208 active_completions: Arc::new(DashMap::new()),
209 }
210 }
211
212 pub fn start(&self) {
214 let mut timer_guard = self.cleanup_timer.write();
215 if timer_guard.is_none() {
216 *timer_guard = Some(interval(self.config.cleanup_interval));
217 }
218 drop(timer_guard);
219
220 let sessions = self.sessions.clone();
222 let config = self.config.clone();
223 let session_history = self.session_history.clone();
224 let stats = self.stats.clone();
225 let pending_elicitations = self.pending_elicitations.clone();
226 let active_completions = self.active_completions.clone();
227
228 tokio::spawn(async move {
229 let mut timer = interval(config.cleanup_interval);
230 loop {
231 timer.tick().await;
232 Self::cleanup_expired_sessions(
233 &sessions,
234 &config,
235 &session_history,
236 &stats,
237 &pending_elicitations,
238 &active_completions,
239 );
240 }
241 });
242 }
243
244 #[must_use]
246 pub fn get_or_create_session(
247 &self,
248 client_id: String,
249 transport_type: String,
250 ) -> ClientSession {
251 self.sessions.get(&client_id).map_or_else(
252 || {
253 self.enforce_capacity();
255
256 let session = ClientSession::new(client_id.clone(), transport_type);
257 self.sessions.insert(client_id.clone(), session.clone());
258
259 let mut stats = self.stats.write();
261 stats.total_sessions += 1;
262 drop(stats);
263
264 self.record_session_event(client_id, SessionEventType::Created, HashMap::new());
265
266 session
267 },
268 |session| session.clone(),
269 )
270 }
271
272 pub fn update_client_activity(&self, client_id: &str) {
274 if let Some(mut session) = self.sessions.get_mut(client_id) {
275 session.update_activity();
276
277 if let Some(cap) = self.config.max_requests_per_session
279 && session.request_count > cap
280 {
281 drop(session);
284 let _ = self.terminate_session(client_id);
285 }
286 }
287 }
288
289 #[must_use]
291 pub fn authenticate_client(
292 &self,
293 client_id: &str,
294 client_name: Option<String>,
295 token: Option<String>,
296 ) -> bool {
297 if let Some(mut session) = self.sessions.get_mut(client_id) {
298 session.authenticate(client_name.clone());
299
300 if let Some(token) = token {
301 self.client_extractor
302 .register_token(token, client_id.to_string());
303 }
304
305 let mut metadata = HashMap::new();
306 if let Some(name) = client_name {
307 metadata.insert("client_name".to_string(), serde_json::json!(name));
308 }
309
310 self.record_session_event(
311 client_id.to_string(),
312 SessionEventType::Authenticated,
313 metadata,
314 );
315
316 return true;
317 }
318 false
319 }
320
321 pub fn record_request(&self, mut request_info: RequestInfo) {
323 if !self.config.enable_analytics {
324 return;
325 }
326
327 self.update_client_activity(&request_info.client_id);
329
330 let mut stats = self.stats.write();
332 stats.total_requests += 1;
333 if request_info.success {
334 stats.successful_requests += 1;
335 } else {
336 stats.failed_requests += 1;
337 }
338 drop(stats);
339
340 let mut history = self.request_history.write();
342 if history.len() >= self.config.max_request_history {
343 history.pop_front();
344 }
345
346 request_info.parameters = self.sanitize_parameters(request_info.parameters);
348 history.push_back(request_info);
349 }
350
351 #[must_use]
353 pub fn get_analytics(&self) -> SessionAnalytics {
354 let sessions = self.sessions.clone();
355
356 let active_sessions = sessions.len();
358
359 let total_duration = sessions
361 .iter()
362 .map(|entry| entry.session_duration())
363 .reduce(|acc, dur| acc + dur)
364 .unwrap_or_else(Duration::zero);
365
366 let avg_session_duration = if active_sessions > 0 {
367 total_duration / active_sessions as i32
368 } else {
369 Duration::zero()
370 };
371
372 let mut client_requests: HashMap<String, usize> = HashMap::new();
374 let mut method_requests: HashMap<String, usize> = HashMap::new();
375
376 let (recent_requests, top_clients, top_methods) = {
377 let history = self.request_history.read();
378 for request in history.iter() {
379 *client_requests
380 .entry(request.client_id.clone())
381 .or_insert(0) += 1;
382 *method_requests
383 .entry(request.method_name.clone())
384 .or_insert(0) += 1;
385 }
386
387 let mut top_clients: Vec<(String, usize)> = client_requests.into_iter().collect();
388 top_clients.sort_by(|a, b| b.1.cmp(&a.1));
389 top_clients.truncate(10);
390
391 let mut top_methods: Vec<(String, usize)> = method_requests.into_iter().collect();
392 top_methods.sort_by(|a, b| b.1.cmp(&a.1));
393 top_methods.truncate(10);
394
395 let one_hour_ago = Utc::now() - Duration::hours(1);
397 let recent_requests = history
398 .iter()
399 .filter(|req| req.timestamp > one_hour_ago)
400 .count();
401 drop(history);
402
403 (recent_requests, top_clients, top_methods)
404 };
405 let requests_per_minute = recent_requests as f64 / 60.0;
406
407 let stats = self.stats.read();
408 SessionAnalytics {
409 total_sessions: stats.total_sessions,
410 active_sessions,
411 total_requests: stats.total_requests,
412 successful_requests: stats.successful_requests,
413 failed_requests: stats.failed_requests,
414 avg_session_duration,
415 top_clients,
416 top_methods,
417 requests_per_minute,
418 }
419 }
420
421 #[must_use]
423 pub fn get_active_sessions(&self) -> Vec<ClientSession> {
424 self.sessions
425 .iter()
426 .map(|entry| entry.value().clone())
427 .collect()
428 }
429
430 #[must_use]
432 pub fn get_session(&self, client_id: &str) -> Option<ClientSession> {
433 self.sessions.get(client_id).map(|session| session.clone())
434 }
435
436 #[must_use]
438 pub fn client_extractor(&self) -> Arc<ClientIdExtractor> {
439 self.client_extractor.clone()
440 }
441
442 #[must_use]
444 pub fn terminate_session(&self, client_id: &str) -> bool {
445 if let Some((_, session)) = self.sessions.remove(client_id) {
446 let mut stats = self.stats.write();
447 stats.total_session_duration += session.session_duration();
448 drop(stats);
449
450 self.pending_elicitations.remove(client_id);
452 self.active_completions.remove(client_id);
453
454 self.record_session_event(
455 client_id.to_string(),
456 SessionEventType::Terminated,
457 HashMap::new(),
458 );
459
460 true
461 } else {
462 false
463 }
464 }
465
466 #[must_use]
468 pub fn get_request_history(&self, limit: Option<usize>) -> Vec<RequestInfo> {
469 let history = self.request_history.read();
470 let limit = limit.unwrap_or(100);
471
472 history.iter().rev().take(limit).cloned().collect()
473 }
474
475 #[must_use]
477 pub fn get_session_events(&self, limit: Option<usize>) -> Vec<SessionEvent> {
478 let events = self.session_history.read();
479 let limit = limit.unwrap_or(100);
480
481 events.iter().rev().take(limit).cloned().collect()
482 }
483
484 pub fn add_pending_elicitation(&self, client_id: String, elicitation: ElicitationContext) {
488 self.pending_elicitations
489 .entry(client_id)
490 .or_default()
491 .push(elicitation);
492 }
493
494 #[must_use]
496 pub fn get_pending_elicitations(&self, client_id: &str) -> Vec<ElicitationContext> {
497 self.pending_elicitations
498 .get(client_id)
499 .map(|entry| entry.clone())
500 .unwrap_or_default()
501 }
502
503 pub fn update_elicitation_state(
505 &self,
506 client_id: &str,
507 elicitation_id: &str,
508 state: crate::context::ElicitationState,
509 ) -> bool {
510 if let Some(mut elicitations) = self.pending_elicitations.get_mut(client_id) {
511 for elicitation in elicitations.iter_mut() {
512 if elicitation.elicitation_id == elicitation_id {
513 elicitation.set_state(state);
514 return true;
515 }
516 }
517 }
518 false
519 }
520
521 pub fn remove_completed_elicitations(&self, client_id: &str) {
523 if let Some(mut elicitations) = self.pending_elicitations.get_mut(client_id) {
524 elicitations.retain(|e| !e.is_complete());
525 }
526 }
527
528 pub fn clear_elicitations(&self, client_id: &str) {
530 self.pending_elicitations.remove(client_id);
531 }
532
533 pub fn add_active_completion(&self, client_id: String, completion: CompletionContext) {
537 self.active_completions
538 .entry(client_id)
539 .or_default()
540 .push(completion);
541 }
542
543 #[must_use]
545 pub fn get_active_completions(&self, client_id: &str) -> Vec<CompletionContext> {
546 self.active_completions
547 .get(client_id)
548 .map(|entry| entry.clone())
549 .unwrap_or_default()
550 }
551
552 pub fn remove_completion(&self, client_id: &str, completion_id: &str) -> bool {
554 if let Some(mut completions) = self.active_completions.get_mut(client_id) {
555 let original_len = completions.len();
556 completions.retain(|c| c.completion_id != completion_id);
557 return completions.len() < original_len;
558 }
559 false
560 }
561
562 pub fn clear_completions(&self, client_id: &str) {
564 self.active_completions.remove(client_id);
565 }
566
567 #[must_use]
569 pub fn get_enhanced_analytics(&self) -> SessionAnalytics {
570 let analytics = self.get_analytics();
571
572 let mut _total_elicitations = 0;
574 let mut _pending_elicitations = 0;
575 let mut _total_completions = 0;
576
577 for entry in self.pending_elicitations.iter() {
578 let elicitations = entry.value();
579 _total_elicitations += elicitations.len();
580 _pending_elicitations += elicitations.iter().filter(|e| !e.is_complete()).count();
581 }
582
583 for entry in self.active_completions.iter() {
584 _total_completions += entry.value().len();
585 }
586
587 analytics
591 }
592
593 fn cleanup_expired_sessions(
596 sessions: &Arc<DashMap<String, ClientSession>>,
597 config: &SessionConfig,
598 session_history: &Arc<RwLock<VecDeque<SessionEvent>>>,
599 stats: &Arc<RwLock<SessionStats>>,
600 pending_elicitations: &Arc<DashMap<String, Vec<ElicitationContext>>>,
601 active_completions: &Arc<DashMap<String, Vec<CompletionContext>>>,
602 ) {
603 let cutoff_time = Utc::now() - config.session_timeout;
604 let mut expired_sessions = Vec::new();
605
606 for entry in sessions.iter() {
607 if entry.last_activity < cutoff_time {
608 expired_sessions.push(entry.client_id.clone());
609 }
610 }
611
612 for client_id in expired_sessions {
613 if let Some((_, session)) = sessions.remove(&client_id) {
614 let mut stats_guard = stats.write();
616 stats_guard.total_session_duration += session.session_duration();
617 drop(stats_guard);
618
619 pending_elicitations.remove(&client_id);
621 active_completions.remove(&client_id);
622
623 let event = SessionEvent {
625 timestamp: Utc::now(),
626 client_id,
627 event_type: SessionEventType::Expired,
628 metadata: HashMap::new(),
629 };
630
631 let mut history = session_history.write();
632 if history.len() >= 1000 {
633 history.pop_front();
634 }
635 history.push_back(event);
636 }
637 }
638 }
639
640 fn record_session_event(
641 &self,
642 client_id: String,
643 event_type: SessionEventType,
644 metadata: HashMap<String, serde_json::Value>,
645 ) {
646 let event = SessionEvent {
647 timestamp: Utc::now(),
648 client_id,
649 event_type,
650 metadata,
651 };
652
653 let mut history = self.session_history.write();
654 if history.len() >= 1000 {
655 history.pop_front();
656 }
657 history.push_back(event);
658 }
659
660 fn enforce_capacity(&self) {
663 let target = self.config.max_sessions;
664 if self.sessions.len() < target {
666 return;
667 }
668
669 let mut entries: Vec<_> = self
671 .sessions
672 .iter()
673 .map(|entry| (entry.key().clone(), entry.last_activity))
674 .collect();
675 entries.sort_by_key(|(_, ts)| *ts);
676
677 let mut to_evict = self.sessions.len().saturating_sub(target) + 1; for (client_id, _) in entries {
680 if to_evict == 0 {
681 break;
682 }
683 if let Some((_, session)) = self.sessions.remove(&client_id) {
684 let mut stats = self.stats.write();
685 stats.total_session_duration += session.session_duration();
686 drop(stats);
687
688 let event = SessionEvent {
690 timestamp: Utc::now(),
691 client_id: client_id.clone(),
692 event_type: SessionEventType::Terminated,
693 metadata: {
694 let mut m = HashMap::new();
695 m.insert("reason".to_string(), serde_json::json!("capacity_eviction"));
696 m
697 },
698 };
699 {
700 let mut history = self.session_history.write();
701 if history.len() >= 1000 {
702 history.pop_front();
703 }
704 history.push_back(event);
705 } to_evict = to_evict.saturating_sub(1);
707 }
708 }
709 }
710
711 fn sanitize_parameters(&self, mut params: serde_json::Value) -> serde_json::Value {
712 let _ = self; if let Some(obj) = params.as_object_mut() {
715 let sensitive_keys = &["password", "token", "api_key", "secret", "auth"];
716 for key in sensitive_keys {
717 if obj.contains_key(*key) {
718 obj.insert(
719 (*key).to_string(),
720 serde_json::Value::String("[REDACTED]".to_string()),
721 );
722 }
723 }
724 }
725 params
726 }
727}
728
729impl Default for SessionManager {
730 fn default() -> Self {
731 Self::new(SessionConfig::default())
732 }
733}
734
// Unit tests for SessionManager's session lifecycle, request recording and parameter
// sanitization.
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly created session is unauthenticated and counted as both created and active.
    #[tokio::test]
    async fn test_session_creation() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.get_or_create_session("client-1".to_string(), "http".to_string());

        assert_eq!(session.client_id, "client-1");
        assert_eq!(session.transport_type, "http");
        assert!(!session.authenticated);

        let analytics = manager.get_analytics();
        assert_eq!(analytics.total_sessions, 1);
        assert_eq!(analytics.active_sessions, 1);
    }

    // Authentication updates the stored session (get_or_create_session returns a clone).
    #[tokio::test]
    async fn test_session_authentication() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.get_or_create_session("client-1".to_string(), "http".to_string());
        assert!(!session.authenticated);

        let success = manager.authenticate_client(
            "client-1",
            Some("Test Client".to_string()),
            Some("token123".to_string()),
        );

        assert!(success);

        // Re-fetch: the earlier `session` is a snapshot taken before authentication.
        let updated_session = manager.get_session("client-1").unwrap();
        assert!(updated_session.authenticated);
        assert_eq!(updated_session.client_name, Some("Test Client".to_string()));
    }

    // A successful request is counted and appears in the request history.
    #[tokio::test]
    async fn test_request_recording() {
        let mut manager = SessionManager::new(SessionConfig::default());
        // Analytics are already enabled by default; set explicitly to make the dependency clear.
        manager.config.enable_analytics = true;

        let request = RequestInfo::new(
            "client-1".to_string(),
            "test_method".to_string(),
            serde_json::json!({"param": "value"}),
        )
        .complete_success(100);

        manager.record_request(request);

        let analytics = manager.get_analytics();
        assert_eq!(analytics.total_requests, 1);
        assert_eq!(analytics.successful_requests, 1);
        assert_eq!(analytics.failed_requests, 0);

        let history = manager.get_request_history(Some(10));
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].method_name, "test_method");
    }

    // Terminating a session removes it and drops it from the active count.
    #[tokio::test]
    async fn test_session_termination() {
        let manager = SessionManager::new(SessionConfig::default());

        let _ = manager.get_or_create_session("client-1".to_string(), "http".to_string());
        assert!(manager.get_session("client-1").is_some());

        let terminated = manager.terminate_session("client-1");
        assert!(terminated);
        assert!(manager.get_session("client-1").is_none());

        let analytics = manager.get_analytics();
        assert_eq!(analytics.active_sessions, 0);
    }

    // Sensitive keys are redacted; other keys are passed through untouched.
    #[tokio::test]
    async fn test_parameter_sanitization() {
        let manager = SessionManager::new(SessionConfig::default());

        let sensitive_params = serde_json::json!({
            "username": "testuser",
            "password": "secret123",
            "api_key": "key456",
            "data": "normal_data"
        });

        let sanitized = manager.sanitize_parameters(sensitive_params);
        let obj = sanitized.as_object().unwrap();

        assert_eq!(
            obj["username"],
            serde_json::Value::String("testuser".to_string())
        );
        assert_eq!(
            obj["password"],
            serde_json::Value::String("[REDACTED]".to_string())
        );
        assert_eq!(
            obj["api_key"],
            serde_json::Value::String("[REDACTED]".to_string())
        );
        assert_eq!(
            obj["data"],
            serde_json::Value::String("normal_data".to_string())
        );
    }
}