Skip to main content

oxirs_chat/
session_manager.rs

1//! Session management and chat functionality
2
3use crate::messages::Message;
4use serde::{Deserialize, Serialize};
5use std::{
6    collections::{HashMap, HashSet, VecDeque},
7    time::SystemTime,
8};
9
10/// Session configuration
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ChatConfig {
13    pub max_context_tokens: usize,
14    pub sliding_window_size: usize,
15    pub enable_context_compression: bool,
16    pub temperature: f32,
17    pub max_tokens: usize,
18    pub timeout_seconds: u64,
19    pub enable_topic_tracking: bool,
20    pub enable_sentiment_analysis: bool,
21    pub enable_intent_detection: bool,
22}
23
24impl Default for ChatConfig {
25    fn default() -> Self {
26        Self {
27            max_context_tokens: 8000,
28            sliding_window_size: 20,
29            enable_context_compression: true,
30            temperature: 0.7,
31            max_tokens: 2000,
32            timeout_seconds: 30,
33            enable_topic_tracking: true,
34            enable_sentiment_analysis: true,
35            enable_intent_detection: true,
36        }
37    }
38}
39
40/// Session state enumeration
41#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
42pub enum SessionState {
43    Active,
44    Idle,
45    Suspended,
46    Archived,
47    Expired,
48}
49
/// Context window for managing conversation context
///
/// Tracks which message IDs are currently part of the context (oldest first),
/// which of them are pinned (never evicted or compressed away), an importance
/// score per message, and a running token estimate.
#[derive(Debug, Clone)]
pub struct ContextWindow {
    /// Target number of messages; `add_message` evicts unpinned messages to stay
    /// within it (it can be exceeded only when every message is pinned).
    pub window_size: usize,
    /// Message IDs in insertion order, oldest at the front.
    pub active_messages: VecDeque<String>, // Message IDs
    /// IDs that eviction and compression must keep.
    pub pinned_messages: HashSet<String>,
    /// Summary installed by the most recent `compress_context` call.
    pub context_summary: Option<String>,
    /// Importance score passed to `add_message`, keyed by message ID.
    pub importance_scores: HashMap<String, f32>,
    /// Approximate token count of the window (see notes on `add_message`).
    pub token_count: usize,
    /// When `compress_context` last ran.
    pub last_compression: Option<SystemTime>,
}

impl ContextWindow {
    /// Creates an empty window targeting `window_size` messages.
    pub fn new(window_size: usize) -> Self {
        Self {
            window_size,
            active_messages: VecDeque::new(),
            pinned_messages: HashSet::new(),
            context_summary: None,
            importance_scores: HashMap::new(),
            token_count: 0,
            last_compression: None,
        }
    }

    /// Appends `message_id` with its `importance` and `tokens`, evicting the
    /// oldest *unpinned* messages first while the window is full.
    ///
    /// Pinned messages are skipped rather than blocking eviction. Previously a
    /// pinned message at the front stopped eviction entirely, so one pin let the
    /// window grow without bound. If every message is pinned, nothing is evicted.
    pub fn add_message(&mut self, message_id: String, importance: f32, tokens: usize) {
        while self.active_messages.len() >= self.window_size {
            let oldest_unpinned = self
                .active_messages
                .iter()
                .position(|id| !self.pinned_messages.contains(id));
            match oldest_unpinned {
                Some(index) => self.evict_at(index),
                // Everything left is pinned: allow the window to overflow.
                None => break,
            }
        }

        self.active_messages.push_back(message_id.clone());
        self.importance_scores.insert(message_id, importance);
        self.token_count += tokens;
    }

    /// Removes the message at `index` and subtracts its estimated token cost.
    fn evict_at(&mut self, index: usize) {
        if let Some(evicted_id) = self.active_messages.remove(index) {
            if let Some(evicted_importance) = self.importance_scores.remove(&evicted_id) {
                // NOTE(review): per-message token counts are not stored, so the
                // evicted cost is estimated as `importance * 100` (simplified
                // heuristic kept from the original code) — confirm intent.
                self.token_count = self
                    .token_count
                    .saturating_sub((evicted_importance * 100.0) as usize);
            }
        }
    }

    /// Marks `message_id` as pinned so eviction and compression keep it.
    pub fn pin_message(&mut self, message_id: String) {
        self.pinned_messages.insert(message_id);
    }

    /// Removes the pin from `message_id`, if it was pinned.
    pub fn unpin_message(&mut self, message_id: &str) {
        self.pinned_messages.remove(message_id);
    }

    /// Returns the IDs currently in the window, oldest first.
    pub fn get_context_messages(&self) -> Vec<String> {
        self.active_messages.iter().cloned().collect()
    }

    /// Installs `summary` and drops the older half of the unpinned messages.
    ///
    /// Of `n` unpinned messages, the oldest `n - n / 2` are removed (together
    /// with their importance scores) and the newest `n / 2` are kept; pinned
    /// messages always stay. `token_count` is then re-estimated as 50 tokens per
    /// message that still has an importance score.
    pub fn compress_context(&mut self, summary: String) {
        self.context_summary = Some(summary);
        self.last_compression = Some(SystemTime::now());

        let unpinned: Vec<&String> = self
            .active_messages
            .iter()
            .filter(|id| !self.pinned_messages.contains(*id))
            .collect();
        let drop_count = unpinned.len() - unpinned.len() / 2;
        let dropped: HashSet<String> = unpinned.into_iter().take(drop_count).cloned().collect();

        // A single retain pass instead of one retain per dropped message.
        self.active_messages.retain(|id| !dropped.contains(id));
        for id in &dropped {
            self.importance_scores.remove(id);
        }

        // Simplified estimation: 50 tokens per remaining scored message.
        self.token_count = self.importance_scores.len() * 50;
    }

    /// True when the token estimate exceeds `max_tokens` and no summary has
    /// been installed yet.
    pub fn needs_compression(&self, max_tokens: usize) -> bool {
        self.token_count > max_tokens && self.context_summary.is_none()
    }
}
145
146/// Topic tracking for conversation analysis
147#[derive(Debug, Clone)]
148pub struct TopicTracker {
149    pub current_topics: Vec<Topic>,
150    pub topic_history: Vec<TopicTransition>,
151    pub confidence_threshold: f32,
152    pub max_topics: usize,
153}
154
155impl Default for TopicTracker {
156    fn default() -> Self {
157        Self::new()
158    }
159}
160
161impl TopicTracker {
162    pub fn new() -> Self {
163        Self {
164            current_topics: Vec::new(),
165            topic_history: Vec::new(),
166            confidence_threshold: 0.6,
167            max_topics: 5,
168        }
169    }
170
171    pub fn analyze_message(&mut self, message: &Message) -> Option<TopicTransition> {
172        // Simplified topic analysis - in real implementation, this would use NLP
173        let content = message.content.to_text().to_lowercase();
174
175        // Basic keyword-based topic detection
176        let detected_topics = self.extract_topics(&content);
177
178        if !detected_topics.is_empty() {
179            let transition = self.update_topics(detected_topics, &message.id);
180            if let Some(ref t) = transition {
181                self.topic_history.push(t.clone());
182            }
183            transition
184        } else {
185            None
186        }
187    }
188
189    fn extract_topics(&self, content: &str) -> Vec<String> {
190        let mut topics = Vec::new();
191
192        // Simple keyword matching - in practice, use proper NLP
193        if content.contains("sparql") || content.contains("query") {
194            topics.push("SPARQL Queries".to_string());
195        }
196        if content.contains("graph") || content.contains("rdf") {
197            topics.push("Knowledge Graphs".to_string());
198        }
199        if content.contains("data") || content.contains("dataset") {
200            topics.push("Data Management".to_string());
201        }
202
203        topics
204    }
205
206    fn update_topics(
207        &mut self,
208        new_topics: Vec<String>,
209        trigger_message_id: &str,
210    ) -> Option<TopicTransition> {
211        // Determine transition type
212        let transition_type = if self.current_topics.is_empty() {
213            TransitionType::NewTopic
214        } else {
215            // Check for topic overlap
216            let current_topic_names: HashSet<String> =
217                self.current_topics.iter().map(|t| t.name.clone()).collect();
218            let new_topic_names: HashSet<String> = new_topics.iter().cloned().collect();
219
220            if current_topic_names.intersection(&new_topic_names).count() > 0 {
221                TransitionType::TopicReturn
222            } else {
223                TransitionType::TopicShift
224            }
225        };
226
227        // Update current topics
228        self.current_topics.clear();
229        for topic_name in new_topics {
230            self.current_topics.push(Topic {
231                name: topic_name.clone(),
232                confidence: 0.8, // Simplified confidence
233                first_mentioned: chrono::Utc::now(),
234                last_mentioned: chrono::Utc::now(),
235                message_count: 1,
236                keywords: vec![topic_name.to_lowercase()],
237            });
238        }
239
240        Some(TopicTransition {
241            from_topics: Vec::new(), // Simplified
242            to_topics: self.current_topics.iter().map(|t| t.name.clone()).collect(),
243            timestamp: chrono::Utc::now(),
244            trigger_message_id: trigger_message_id.to_string(),
245            confidence: 0.8,
246            transition_type,
247        })
248    }
249
250    pub fn get_current_topic_summary(&self) -> String {
251        if self.current_topics.is_empty() {
252            "No specific topic detected".to_string()
253        } else {
254            let topic_names: Vec<String> =
255                self.current_topics.iter().map(|t| t.name.clone()).collect();
256            format!("Current topics: {}", topic_names.join(", "))
257        }
258    }
259}
260
261/// Topic information
262#[derive(Debug, Clone, Serialize, Deserialize)]
263pub struct Topic {
264    pub name: String,
265    pub confidence: f32,
266    pub first_mentioned: chrono::DateTime<chrono::Utc>,
267    pub last_mentioned: chrono::DateTime<chrono::Utc>,
268    pub message_count: usize,
269    pub keywords: Vec<String>,
270}
271
272impl Topic {
273    pub fn update_mention(&mut self) {
274        self.last_mentioned = chrono::Utc::now();
275        self.message_count += 1;
276    }
277
278    pub fn add_keyword(&mut self, keyword: String) {
279        if !self.keywords.contains(&keyword) {
280            self.keywords.push(keyword);
281        }
282    }
283
284    pub fn get_relevance_score(&self) -> f32 {
285        let time_decay = {
286            let now = chrono::Utc::now();
287            let hours_since = now.signed_duration_since(self.last_mentioned).num_hours() as f32;
288            (-hours_since / 24.0).exp() // Exponential decay over days
289        };
290
291        let frequency_boost = (self.message_count as f32).ln().max(1.0);
292
293        self.confidence * time_decay * frequency_boost
294    }
295}
296
297/// Topic transition information
298#[derive(Debug, Clone, Serialize, Deserialize)]
299pub struct TopicTransition {
300    pub from_topics: Vec<String>,
301    pub to_topics: Vec<String>,
302    pub timestamp: chrono::DateTime<chrono::Utc>,
303    pub trigger_message_id: String,
304    pub confidence: f32,
305    pub transition_type: TransitionType,
306}
307
308/// Type of topic transition
309#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
310pub enum TransitionType {
311    NewTopic,
312    TopicShift,
313    TopicReturn,
314    TopicMerge,
315    TopicSplit,
316}
317
318impl PartialEq<&str> for TransitionType {
319    fn eq(&self, other: &&str) -> bool {
320        matches!(
321            (self, *other),
322            (TransitionType::NewTopic, "new")
323                | (TransitionType::TopicShift, "shift")
324                | (TransitionType::TopicReturn, "return")
325                | (TransitionType::TopicMerge, "merge")
326                | (TransitionType::TopicSplit, "split")
327        )
328    }
329}
330
331/// Session performance metrics
332#[derive(Debug, Default, Clone, Serialize, Deserialize)]
333pub struct SessionMetrics {
334    pub total_messages: usize,
335    pub user_messages: usize,
336    pub assistant_messages: usize,
337    pub average_response_time: f64,
338    pub total_tokens_used: usize,
339    pub successful_queries: usize,
340    pub failed_queries: usize,
341    pub context_compressions: usize,
342    pub topic_transitions: usize,
343    pub user_satisfaction_scores: Vec<f32>,
344    pub error_count: usize,
345    pub warning_count: usize,
346    pub cache_hits: usize,
347    pub cache_misses: usize,
348    pub last_updated: chrono::DateTime<chrono::Utc>,
349}
350
351impl SessionMetrics {
352    pub fn update_response_time(&mut self, response_time_ms: u64) {
353        let response_time_s = response_time_ms as f64 / 1000.0;
354        if self.assistant_messages == 0 {
355            self.average_response_time = response_time_s;
356        } else {
357            self.average_response_time =
358                (self.average_response_time * self.assistant_messages as f64 + response_time_s)
359                    / (self.assistant_messages as f64 + 1.0);
360        }
361        self.assistant_messages += 1;
362        self.total_messages += 1;
363        self.last_updated = chrono::Utc::now();
364    }
365
366    pub fn add_user_message(&mut self) {
367        self.user_messages += 1;
368        self.total_messages += 1;
369        self.last_updated = chrono::Utc::now();
370    }
371
372    pub fn add_successful_query(&mut self, tokens_used: usize) {
373        self.successful_queries += 1;
374        self.total_tokens_used += tokens_used;
375        self.last_updated = chrono::Utc::now();
376    }
377
378    pub fn add_failed_query(&mut self) {
379        self.failed_queries += 1;
380        self.error_count += 1;
381        self.last_updated = chrono::Utc::now();
382    }
383
384    pub fn add_context_compression(&mut self) {
385        self.context_compressions += 1;
386        self.last_updated = chrono::Utc::now();
387    }
388
389    pub fn add_topic_transition(&mut self) {
390        self.topic_transitions += 1;
391        self.last_updated = chrono::Utc::now();
392    }
393
394    pub fn add_satisfaction_score(&mut self, score: f32) {
395        self.user_satisfaction_scores.push(score.clamp(0.0, 5.0));
396        self.last_updated = chrono::Utc::now();
397    }
398
399    pub fn get_average_satisfaction(&self) -> f32 {
400        if self.user_satisfaction_scores.is_empty() {
401            0.0
402        } else {
403            self.user_satisfaction_scores.iter().sum::<f32>()
404                / self.user_satisfaction_scores.len() as f32
405        }
406    }
407
408    pub fn get_query_success_rate(&self) -> f32 {
409        let total_queries = self.successful_queries + self.failed_queries;
410        if total_queries == 0 {
411            0.0
412        } else {
413            self.successful_queries as f32 / total_queries as f32
414        }
415    }
416
417    pub fn get_cache_hit_rate(&self) -> f32 {
418        let total_cache_requests = self.cache_hits + self.cache_misses;
419        if total_cache_requests == 0 {
420            0.0
421        } else {
422            self.cache_hits as f32 / total_cache_requests as f32
423        }
424    }
425}
426
427/// Chat session data for serialization
428#[derive(Debug, Clone, Serialize, Deserialize)]
429pub struct SessionData {
430    pub id: String,
431    pub config: ChatConfig,
432    pub messages: Vec<Message>,
433    pub created_at: chrono::DateTime<chrono::Utc>,
434    pub last_activity: chrono::DateTime<chrono::Utc>,
435    pub user_preferences: HashMap<String, String>,
436    pub session_state: SessionState,
437    pub context_summary: Option<String>,
438    pub pinned_messages: HashSet<String>,
439    pub current_topics: Vec<Topic>,
440    pub topic_history: Vec<TopicTransition>,
441    pub performance_metrics: SessionMetrics,
442}
443
444// Compatibility aliases for lib.rs
445impl SessionData {
446    pub fn user_id(&self) -> Option<&str> {
447        self.user_preferences.get("user_id").map(|s| s.as_str())
448    }
449
450    pub fn set_user_id(&mut self, user_id: String) {
451        self.user_preferences.insert("user_id".to_string(), user_id);
452    }
453}
454
455// Re-export with compatibility
456pub use SessionState as state;
457
// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Topic` stamped with the current time.
    fn make_topic(name: &str, confidence: f32, message_count: usize, keywords: &[&str]) -> Topic {
        Topic {
            name: name.to_string(),
            confidence,
            first_mentioned: chrono::Utc::now(),
            last_mentioned: chrono::Utc::now(),
            message_count,
            keywords: keywords.iter().map(|k| k.to_string()).collect(),
        }
    }

    /// Builds an empty, active `SessionData` with the default configuration.
    fn make_session(id: &str) -> SessionData {
        SessionData {
            id: id.to_string(),
            config: ChatConfig::default(),
            messages: Vec::new(),
            created_at: chrono::Utc::now(),
            last_activity: chrono::Utc::now(),
            user_preferences: HashMap::new(),
            session_state: SessionState::Active,
            context_summary: None,
            pinned_messages: HashSet::new(),
            current_topics: Vec::new(),
            topic_history: Vec::new(),
            performance_metrics: SessionMetrics::default(),
        }
    }

    // ── ChatConfig ────────────────────────────────────────────────────────────

    #[test]
    fn test_chat_config_default_has_reasonable_values() {
        let cfg = ChatConfig::default();
        assert!(cfg.max_context_tokens > 0);
        assert!(cfg.sliding_window_size > 0);
        assert!(cfg.max_tokens > 0);
        assert!(cfg.timeout_seconds > 0);
    }

    #[test]
    fn test_chat_config_temperature_in_range() {
        let cfg = ChatConfig::default();
        assert!(cfg.temperature >= 0.0 && cfg.temperature <= 2.0);
    }

    #[test]
    fn test_chat_config_flags_set() {
        let cfg = ChatConfig::default();
        assert!(cfg.enable_topic_tracking);
        assert!(cfg.enable_context_compression);
    }

    #[test]
    fn test_chat_config_custom_values() {
        let cfg = ChatConfig {
            max_context_tokens: 4096,
            sliding_window_size: 10,
            enable_context_compression: false,
            temperature: 0.5,
            max_tokens: 512,
            timeout_seconds: 15,
            enable_topic_tracking: false,
            enable_sentiment_analysis: false,
            enable_intent_detection: false,
        };
        assert_eq!(cfg.max_context_tokens, 4096);
        assert_eq!(cfg.sliding_window_size, 10);
        assert!(!cfg.enable_context_compression);
    }

    // ── SessionState ──────────────────────────────────────────────────────────

    #[test]
    fn test_session_state_active_eq() {
        assert_eq!(SessionState::Active, SessionState::Active);
        assert_ne!(SessionState::Active, SessionState::Expired);
    }

    #[test]
    fn test_session_state_all_variants_accessible() {
        let states = [
            SessionState::Active,
            SessionState::Idle,
            SessionState::Suspended,
            SessionState::Archived,
            SessionState::Expired,
        ];
        assert_eq!(states.len(), 5);
    }

    // ── ContextWindow ─────────────────────────────────────────────────────────

    #[test]
    fn test_context_window_new_empty() {
        let cw = ContextWindow::new(10);
        assert_eq!(cw.window_size, 10);
        assert!(cw.active_messages.is_empty());
        assert_eq!(cw.token_count, 0);
    }

    #[test]
    fn test_context_window_add_message() {
        let mut cw = ContextWindow::new(5);
        cw.add_message("msg1".to_string(), 1.0, 10);
        assert_eq!(cw.active_messages.len(), 1);
        assert_eq!(cw.token_count, 10);
    }

    #[test]
    fn test_context_window_evicts_oldest_when_full() {
        let mut cw = ContextWindow::new(3);
        for id in ["a", "b", "c", "d"] {
            cw.add_message(id.to_string(), 0.5, 5);
        }
        // 'a' is evicted; the newest three remain in insertion order.
        assert_eq!(cw.get_context_messages(), vec!["b", "c", "d"]);
    }

    #[test]
    fn test_context_window_pin_message() {
        let mut cw = ContextWindow::new(2);
        cw.add_message("pinned".to_string(), 1.0, 5);
        cw.pin_message("pinned".to_string());
        assert!(cw.pinned_messages.contains("pinned"));
    }

    #[test]
    fn test_context_window_unpin_message() {
        let mut cw = ContextWindow::new(5);
        cw.pin_message("msg".to_string());
        cw.unpin_message("msg");
        assert!(!cw.pinned_messages.contains("msg"));
    }

    #[test]
    fn test_context_window_multiple_pins() {
        let mut cw = ContextWindow::new(10);
        cw.pin_message("m1".to_string());
        cw.pin_message("m2".to_string());
        assert_eq!(cw.pinned_messages.len(), 2);
    }

    #[test]
    fn test_context_window_get_context_messages() {
        let mut cw = ContextWindow::new(5);
        cw.add_message("m1".to_string(), 0.5, 3);
        cw.add_message("m2".to_string(), 0.8, 4);
        assert_eq!(cw.get_context_messages(), vec!["m1", "m2"]);
    }

    #[test]
    fn test_context_window_importance_scores_stored() {
        let mut cw = ContextWindow::new(10);
        cw.add_message("m1".to_string(), 0.9, 5);
        assert_eq!(cw.importance_scores.get("m1"), Some(&0.9));
    }

    #[test]
    fn test_context_window_needs_compression() {
        let mut cw = ContextWindow::new(100);
        cw.token_count = 5000;
        cw.context_summary = None;
        assert!(cw.needs_compression(4000));
        assert!(!cw.needs_compression(6000));
    }

    #[test]
    fn test_context_window_needs_compression_false_when_summarized() {
        let mut cw = ContextWindow::new(100);
        cw.token_count = 9000;
        cw.context_summary = Some("summary".to_string());
        // Already compressed — should not trigger again.
        assert!(!cw.needs_compression(100));
    }

    #[test]
    fn test_context_window_compress_sets_summary() {
        let mut cw = ContextWindow::new(10);
        cw.add_message("m1".to_string(), 0.5, 10);
        cw.compress_context("compressed summary".to_string());
        assert_eq!(cw.context_summary.as_deref(), Some("compressed summary"));
        assert!(cw.last_compression.is_some());
    }

    // ── TopicTracker ──────────────────────────────────────────────────────────

    #[test]
    fn test_topic_tracker_new_empty() {
        let t = TopicTracker::new();
        assert!(t.current_topics.is_empty());
        assert!(t.topic_history.is_empty());
    }

    #[test]
    fn test_topic_tracker_default() {
        let t = TopicTracker::default();
        assert!(t.current_topics.is_empty());
        assert_eq!(t.max_topics, 5);
    }

    #[test]
    fn test_topic_tracker_summary_no_topics() {
        let t = TopicTracker::new();
        assert_eq!(t.get_current_topic_summary(), "No specific topic detected");
    }

    #[test]
    fn test_topic_tracker_summary_with_topics() {
        let mut tracker = TopicTracker::new();
        tracker.current_topics.push(make_topic("SPARQL", 0.9, 1, &[]));
        assert_eq!(tracker.get_current_topic_summary(), "Current topics: SPARQL");
    }

    #[test]
    fn test_topic_tracker_analyze_sparql_message() {
        use crate::messages::{MessageContent, MessageRole};
        let mut tracker = TopicTracker::new();
        let msg = crate::messages::Message {
            id: "m1".to_string(),
            role: MessageRole::User,
            content: MessageContent::Text("How do I write a sparql query?".to_string()),
            timestamp: chrono::Utc::now(),
            metadata: None,
            thread_id: None,
            parent_message_id: None,
            token_count: None,
            reactions: Vec::new(),
            attachments: Vec::new(),
            rich_elements: Vec::new(),
        };
        let transition = tracker
            .analyze_message(&msg)
            .expect("a SPARQL question must produce a topic transition");
        assert_eq!(transition.to_topics, vec!["SPARQL Queries".to_string()]);
        assert_eq!(transition.transition_type, TransitionType::NewTopic);
        assert_eq!(transition.trigger_message_id, "m1");
        assert_eq!(tracker.topic_history.len(), 1);
        assert_eq!(tracker.current_topics.len(), 1);
    }

    // ── Topic ─────────────────────────────────────────────────────────────────

    #[test]
    fn test_topic_add_keyword() {
        let mut topic = make_topic("SPARQL", 0.9, 1, &["sparql"]);
        topic.add_keyword("query".to_string());
        assert_eq!(topic.keywords, vec!["sparql", "query"]);
    }

    #[test]
    fn test_topic_add_keyword_no_duplicate() {
        let mut topic = make_topic("RDF", 0.8, 1, &["rdf"]);
        topic.add_keyword("rdf".to_string()); // duplicate
        assert_eq!(topic.keywords.len(), 1);
    }

    #[test]
    fn test_topic_update_mention_increments_count() {
        let mut topic = make_topic("Graphs", 0.7, 1, &[]);
        topic.update_mention();
        assert_eq!(topic.message_count, 2);
    }

    #[test]
    fn test_topic_relevance_score_positive() {
        let topic = make_topic("Test", 0.9, 5, &[]);
        assert!(topic.get_relevance_score() > 0.0);
    }

    // ── TopicTransition / TransitionType ──────────────────────────────────────

    #[test]
    fn test_transition_type_partial_eq_str() {
        assert_eq!(TransitionType::NewTopic, "new");
        assert_eq!(TransitionType::TopicShift, "shift");
        assert_eq!(TransitionType::TopicReturn, "return");
        assert_eq!(TransitionType::TopicMerge, "merge");
        assert_eq!(TransitionType::TopicSplit, "split");
    }

    #[test]
    fn test_transition_type_ne_wrong_str() {
        assert_ne!(TransitionType::NewTopic, "shift");
    }

    // ── SessionMetrics ────────────────────────────────────────────────────────

    #[test]
    fn test_session_metrics_default_zero() {
        let m = SessionMetrics::default();
        assert_eq!(m.total_messages, 0);
        assert_eq!(m.successful_queries, 0);
    }

    #[test]
    fn test_session_metrics_add_user_message() {
        let mut m = SessionMetrics::default();
        m.add_user_message();
        assert_eq!(m.user_messages, 1);
        assert_eq!(m.total_messages, 1);
    }

    #[test]
    fn test_session_metrics_add_successful_query() {
        let mut m = SessionMetrics::default();
        m.add_successful_query(100);
        assert_eq!(m.successful_queries, 1);
        assert_eq!(m.total_tokens_used, 100);
    }

    #[test]
    fn test_session_metrics_tokens_accumulate() {
        let mut m = SessionMetrics::default();
        m.add_successful_query(100);
        m.add_successful_query(250);
        assert_eq!(m.total_tokens_used, 350);
    }

    #[test]
    fn test_session_metrics_add_failed_query() {
        let mut m = SessionMetrics::default();
        m.add_failed_query();
        assert_eq!(m.failed_queries, 1);
        assert_eq!(m.error_count, 1);
    }

    #[test]
    fn test_session_metrics_query_success_rate() {
        let mut m = SessionMetrics::default();
        m.add_successful_query(10);
        m.add_successful_query(10);
        m.add_failed_query();
        assert!((m.get_query_success_rate() - 2.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_session_metrics_query_success_rate_zero_total() {
        let m = SessionMetrics::default();
        assert_eq!(m.get_query_success_rate(), 0.0);
    }

    #[test]
    fn test_session_metrics_cache_hit_rate() {
        let m = SessionMetrics {
            cache_hits: 3,
            cache_misses: 1,
            ..SessionMetrics::default()
        };
        assert!((m.get_cache_hit_rate() - 0.75).abs() < 1e-6);
    }

    #[test]
    fn test_session_metrics_cache_hit_rate_zero_total() {
        let m = SessionMetrics::default();
        assert_eq!(m.get_cache_hit_rate(), 0.0);
    }

    #[test]
    fn test_session_metrics_satisfaction_score() {
        let mut m = SessionMetrics::default();
        m.add_satisfaction_score(4.0);
        m.add_satisfaction_score(2.0);
        assert!((m.get_average_satisfaction() - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_session_metrics_satisfaction_clamped_to_five() {
        let mut m = SessionMetrics::default();
        m.add_satisfaction_score(10.0); // clamped to 5.0
        assert!((m.get_average_satisfaction() - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_session_metrics_satisfaction_clamped_to_zero() {
        let mut m = SessionMetrics::default();
        m.add_satisfaction_score(-5.0); // clamped to 0.0
        assert!(m.get_average_satisfaction().abs() < 1e-6);
    }

    #[test]
    fn test_session_metrics_average_satisfaction_empty() {
        let m = SessionMetrics::default();
        assert_eq!(m.get_average_satisfaction(), 0.0);
    }

    #[test]
    fn test_session_metrics_context_compression_tracked() {
        let mut m = SessionMetrics::default();
        m.add_context_compression();
        assert_eq!(m.context_compressions, 1);
    }

    #[test]
    fn test_session_metrics_topic_transition_tracked() {
        let mut m = SessionMetrics::default();
        m.add_topic_transition();
        assert_eq!(m.topic_transitions, 1);
    }

    #[test]
    fn test_session_metrics_update_response_time_first_call() {
        let mut m = SessionMetrics::default();
        m.update_response_time(500);
        assert!((m.average_response_time - 0.5).abs() < 1e-6);
        assert_eq!(m.assistant_messages, 1);
    }

    #[test]
    fn test_session_metrics_update_response_time_average() {
        let mut m = SessionMetrics::default();
        m.update_response_time(1000); // 1s
        m.update_response_time(3000); // 3s
        // Average: (1 + 3) / 2 = 2s
        assert!((m.average_response_time - 2.0).abs() < 1e-9);
        assert_eq!(m.total_messages, 2);
    }

    // ── SessionData ───────────────────────────────────────────────────────────

    #[test]
    fn test_session_data_user_id_roundtrip() {
        let mut sd = make_session("test-session");
        sd.set_user_id("alice".to_string());
        assert_eq!(sd.user_id(), Some("alice"));
    }

    #[test]
    fn test_session_data_user_id_none_initially() {
        let sd = make_session("s1");
        assert!(sd.user_id().is_none());
    }
}