oxirs_chat/
session_manager.rs

1//! Session management and chat functionality
2
3use crate::messages::Message;
4use serde::{Deserialize, Serialize};
5use std::{
6    collections::{HashMap, HashSet, VecDeque},
7    time::SystemTime,
8};
9
/// Session configuration
///
/// Tunables for a chat session. Baseline values are provided by the
/// `Default` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatConfig {
    /// Token budget for the conversation context.
    /// NOTE(review): presumably the limit passed to `ContextWindow::needs_compression`; confirm at the call site.
    pub max_context_tokens: usize,
    /// Number of messages kept in the sliding context window.
    /// NOTE(review): presumably used as `ContextWindow::new`'s `window_size`; confirm.
    pub sliding_window_size: usize,
    /// Whether context compression is enabled.
    pub enable_context_compression: bool,
    /// Presumably the model sampling temperature — not consumed in this file, confirm with the caller.
    pub temperature: f32,
    /// Presumably the maximum number of tokens per response — not consumed in this file, confirm.
    pub max_tokens: usize,
    /// Timeout, in seconds (unit taken from the field name).
    pub timeout_seconds: u64,
    /// Whether topic tracking is enabled.
    pub enable_topic_tracking: bool,
    /// Whether sentiment analysis is enabled.
    pub enable_sentiment_analysis: bool,
    /// Whether intent detection is enabled.
    pub enable_intent_detection: bool,
}
23
24impl Default for ChatConfig {
25    fn default() -> Self {
26        Self {
27            max_context_tokens: 8000,
28            sliding_window_size: 20,
29            enable_context_compression: true,
30            temperature: 0.7,
31            max_tokens: 2000,
32            timeout_seconds: 30,
33            enable_topic_tracking: true,
34            enable_sentiment_analysis: true,
35            enable_intent_detection: true,
36        }
37    }
38}
39
/// Session state enumeration
///
/// Lifecycle state of a chat session (stored in `SessionData::session_state`).
/// NOTE(review): the permitted transitions between these states are not
/// defined in this file — confirm them against the session manager.
/// Variant order is kept as-is because index-based serde formats may encode it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SessionState {
    Active,
    Idle,
    Suspended,
    Archived,
    Expired,
}
49
/// Context window for managing conversation context.
///
/// Tracks the IDs of the messages currently in context (oldest at the front),
/// a set of pinned IDs that eviction and compression never remove, an
/// importance score per message, and an *approximate* running token count.
#[derive(Debug, Clone)]
pub struct ContextWindow {
    /// Message count at which `add_message` starts evicting unpinned messages.
    pub window_size: usize,
    /// Message IDs currently in the window, oldest first.
    pub active_messages: VecDeque<String>,
    /// Message IDs that eviction and compression must keep.
    pub pinned_messages: HashSet<String>,
    /// Summary recorded by the most recent `compress_context` call.
    pub context_summary: Option<String>,
    /// Importance score per message ID.
    pub importance_scores: HashMap<String, f32>,
    /// Estimated number of tokens held by the window (not exact).
    pub token_count: usize,
    /// Time of the most recent `compress_context` call.
    pub last_compression: Option<SystemTime>,
}

impl ContextWindow {
    /// Creates an empty window that holds up to `window_size` messages
    /// (pinned messages may push it past that limit).
    pub fn new(window_size: usize) -> Self {
        Self {
            window_size,
            active_messages: VecDeque::new(),
            pinned_messages: HashSet::new(),
            context_summary: None,
            importance_scores: HashMap::new(),
            token_count: 0,
            last_compression: None,
        }
    }

    /// Appends `message_id` with its `importance` score and `tokens` count.
    ///
    /// While the window is full, the oldest *unpinned* message is evicted.
    /// Pinned messages are skipped over rather than blocking eviction, so a
    /// pinned message at the front no longer lets the window grow without
    /// bound. If every message is pinned, nothing is evicted and the window
    /// temporarily exceeds `window_size`.
    ///
    /// Per-message token counts are not stored, so an evicted message is
    /// assumed to hold the window's current average number of tokens.
    pub fn add_message(&mut self, message_id: String, importance: f32, tokens: usize) {
        while self.active_messages.len() >= self.window_size {
            let oldest_unpinned = self
                .active_messages
                .iter()
                .position(|id| !self.pinned_messages.contains(id));
            let index = match oldest_unpinned {
                Some(index) => index,
                None => break, // everything is pinned: nothing may be evicted
            };

            // `len() > 0` here because `position` found an element.
            let estimated_tokens = self.token_count / self.active_messages.len();
            if let Some(removed_id) = self.active_messages.remove(index) {
                self.importance_scores.remove(&removed_id);
                self.token_count = self.token_count.saturating_sub(estimated_tokens);
            }
        }

        self.active_messages.push_back(message_id.clone());
        self.importance_scores.insert(message_id, importance);
        self.token_count += tokens;
    }

    /// Marks `message_id` as pinned. Does not add it to the window.
    pub fn pin_message(&mut self, message_id: String) {
        self.pinned_messages.insert(message_id);
    }

    /// Removes `message_id` from the pinned set, if present.
    pub fn unpin_message(&mut self, message_id: &str) {
        self.pinned_messages.remove(message_id);
    }

    /// Returns a copy of the message IDs in the window, oldest first.
    pub fn get_context_messages(&self) -> Vec<String> {
        self.active_messages.iter().cloned().collect()
    }

    /// Records `summary` and drops the older half of the unpinned messages.
    ///
    /// Of the `n` unpinned entries (oldest first), the oldest `n - n / 2` are
    /// removed, so the newest `n / 2` survive. Pinned messages are always kept.
    /// The token count is then re-estimated at 50 tokens per message that
    /// still has an importance score.
    pub fn compress_context(&mut self, summary: String) {
        self.context_summary = Some(summary);
        self.last_compression = Some(SystemTime::now());

        let unpinned: Vec<&String> = self
            .active_messages
            .iter()
            .filter(|id| !self.pinned_messages.contains(*id))
            .collect();
        let drop_count = unpinned.len() - unpinned.len() / 2;
        let to_drop: HashSet<String> = unpinned.into_iter().take(drop_count).cloned().collect();

        // One linear pass instead of a `retain` per removed ID.
        self.active_messages.retain(|id| !to_drop.contains(id));
        for id in &to_drop {
            self.importance_scores.remove(id);
        }

        // Simplified estimation
        self.token_count = self.importance_scores.len() * 50;
    }

    /// Returns `true` when the estimated token count exceeds `max_tokens` and
    /// no summary has been recorded yet.
    ///
    /// NOTE(review): once `compress_context` has run this stays `false` no
    /// matter how large the window grows again — confirm that is intended.
    pub fn needs_compression(&self, max_tokens: usize) -> bool {
        self.token_count > max_tokens && self.context_summary.is_none()
    }
}
145
146/// Topic tracking for conversation analysis
147#[derive(Debug, Clone)]
148pub struct TopicTracker {
149    pub current_topics: Vec<Topic>,
150    pub topic_history: Vec<TopicTransition>,
151    pub confidence_threshold: f32,
152    pub max_topics: usize,
153}
154
155impl Default for TopicTracker {
156    fn default() -> Self {
157        Self::new()
158    }
159}
160
161impl TopicTracker {
162    pub fn new() -> Self {
163        Self {
164            current_topics: Vec::new(),
165            topic_history: Vec::new(),
166            confidence_threshold: 0.6,
167            max_topics: 5,
168        }
169    }
170
171    pub fn analyze_message(&mut self, message: &Message) -> Option<TopicTransition> {
172        // Simplified topic analysis - in real implementation, this would use NLP
173        let content = message.content.to_text().to_lowercase();
174
175        // Basic keyword-based topic detection
176        let detected_topics = self.extract_topics(&content);
177
178        if !detected_topics.is_empty() {
179            let transition = self.update_topics(detected_topics, &message.id);
180            if let Some(ref t) = transition {
181                self.topic_history.push(t.clone());
182            }
183            transition
184        } else {
185            None
186        }
187    }
188
189    fn extract_topics(&self, content: &str) -> Vec<String> {
190        let mut topics = Vec::new();
191
192        // Simple keyword matching - in practice, use proper NLP
193        if content.contains("sparql") || content.contains("query") {
194            topics.push("SPARQL Queries".to_string());
195        }
196        if content.contains("graph") || content.contains("rdf") {
197            topics.push("Knowledge Graphs".to_string());
198        }
199        if content.contains("data") || content.contains("dataset") {
200            topics.push("Data Management".to_string());
201        }
202
203        topics
204    }
205
206    fn update_topics(
207        &mut self,
208        new_topics: Vec<String>,
209        trigger_message_id: &str,
210    ) -> Option<TopicTransition> {
211        // Determine transition type
212        let transition_type = if self.current_topics.is_empty() {
213            TransitionType::NewTopic
214        } else {
215            // Check for topic overlap
216            let current_topic_names: HashSet<String> =
217                self.current_topics.iter().map(|t| t.name.clone()).collect();
218            let new_topic_names: HashSet<String> = new_topics.iter().cloned().collect();
219
220            if current_topic_names.intersection(&new_topic_names).count() > 0 {
221                TransitionType::TopicReturn
222            } else {
223                TransitionType::TopicShift
224            }
225        };
226
227        // Update current topics
228        self.current_topics.clear();
229        for topic_name in new_topics {
230            self.current_topics.push(Topic {
231                name: topic_name.clone(),
232                confidence: 0.8, // Simplified confidence
233                first_mentioned: chrono::Utc::now(),
234                last_mentioned: chrono::Utc::now(),
235                message_count: 1,
236                keywords: vec![topic_name.to_lowercase()],
237            });
238        }
239
240        Some(TopicTransition {
241            from_topics: Vec::new(), // Simplified
242            to_topics: self.current_topics.iter().map(|t| t.name.clone()).collect(),
243            timestamp: chrono::Utc::now(),
244            trigger_message_id: trigger_message_id.to_string(),
245            confidence: 0.8,
246            transition_type,
247        })
248    }
249
250    pub fn get_current_topic_summary(&self) -> String {
251        if self.current_topics.is_empty() {
252            "No specific topic detected".to_string()
253        } else {
254            let topic_names: Vec<String> =
255                self.current_topics.iter().map(|t| t.name.clone()).collect();
256            format!("Current topics: {}", topic_names.join(", "))
257        }
258    }
259}
260
261/// Topic information
262#[derive(Debug, Clone, Serialize, Deserialize)]
263pub struct Topic {
264    pub name: String,
265    pub confidence: f32,
266    pub first_mentioned: chrono::DateTime<chrono::Utc>,
267    pub last_mentioned: chrono::DateTime<chrono::Utc>,
268    pub message_count: usize,
269    pub keywords: Vec<String>,
270}
271
272impl Topic {
273    pub fn update_mention(&mut self) {
274        self.last_mentioned = chrono::Utc::now();
275        self.message_count += 1;
276    }
277
278    pub fn add_keyword(&mut self, keyword: String) {
279        if !self.keywords.contains(&keyword) {
280            self.keywords.push(keyword);
281        }
282    }
283
284    pub fn get_relevance_score(&self) -> f32 {
285        let time_decay = {
286            let now = chrono::Utc::now();
287            let hours_since = now.signed_duration_since(self.last_mentioned).num_hours() as f32;
288            (-hours_since / 24.0).exp() // Exponential decay over days
289        };
290
291        let frequency_boost = (self.message_count as f32).ln().max(1.0);
292
293        self.confidence * time_decay * frequency_boost
294    }
295}
296
/// Topic transition information
///
/// Describes a change in the detected conversation topics. Produced by
/// `TopicTracker` and stored in its `topic_history`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TopicTransition {
    /// Topics that were current before the transition.
    /// NOTE(review): producers may leave this empty if they do not track prior topics.
    pub from_topics: Vec<String>,
    /// Topics that are current after the transition.
    pub to_topics: Vec<String>,
    /// When the transition was recorded (UTC).
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// ID of the message that triggered the transition.
    pub trigger_message_id: String,
    /// Confidence in the transition.
    pub confidence: f32,
    /// Kind of transition.
    pub transition_type: TransitionType,
}
307
/// Type of topic transition
///
/// Also compares equal to short lowercase labels (`"new"`, `"shift"`,
/// `"return"`, `"merge"`, `"split"`) through its `PartialEq<&str>` impl.
/// Variant order is kept as-is because index-based serde formats may encode it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransitionType {
    /// Assigned by `TopicTracker` when no topics were current beforehand.
    NewTopic,
    /// Assigned by `TopicTracker` when the new topics share none with the current ones.
    TopicShift,
    /// Assigned by `TopicTracker` when at least one new topic was already current.
    TopicReturn,
    /// NOTE(review): not produced anywhere in this file; semantics unconfirmed.
    TopicMerge,
    /// NOTE(review): not produced anywhere in this file; semantics unconfirmed.
    TopicSplit,
}
317
318impl PartialEq<&str> for TransitionType {
319    fn eq(&self, other: &&str) -> bool {
320        matches!(
321            (self, *other),
322            (TransitionType::NewTopic, "new")
323                | (TransitionType::TopicShift, "shift")
324                | (TransitionType::TopicReturn, "return")
325                | (TransitionType::TopicMerge, "merge")
326                | (TransitionType::TopicSplit, "split")
327        )
328    }
329}
330
331/// Session performance metrics
332#[derive(Debug, Default, Clone, Serialize, Deserialize)]
333pub struct SessionMetrics {
334    pub total_messages: usize,
335    pub user_messages: usize,
336    pub assistant_messages: usize,
337    pub average_response_time: f64,
338    pub total_tokens_used: usize,
339    pub successful_queries: usize,
340    pub failed_queries: usize,
341    pub context_compressions: usize,
342    pub topic_transitions: usize,
343    pub user_satisfaction_scores: Vec<f32>,
344    pub error_count: usize,
345    pub warning_count: usize,
346    pub cache_hits: usize,
347    pub cache_misses: usize,
348    pub last_updated: chrono::DateTime<chrono::Utc>,
349}
350
351impl SessionMetrics {
352    pub fn update_response_time(&mut self, response_time_ms: u64) {
353        let response_time_s = response_time_ms as f64 / 1000.0;
354        if self.assistant_messages == 0 {
355            self.average_response_time = response_time_s;
356        } else {
357            self.average_response_time =
358                (self.average_response_time * self.assistant_messages as f64 + response_time_s)
359                    / (self.assistant_messages as f64 + 1.0);
360        }
361        self.assistant_messages += 1;
362        self.total_messages += 1;
363        self.last_updated = chrono::Utc::now();
364    }
365
366    pub fn add_user_message(&mut self) {
367        self.user_messages += 1;
368        self.total_messages += 1;
369        self.last_updated = chrono::Utc::now();
370    }
371
372    pub fn add_successful_query(&mut self, tokens_used: usize) {
373        self.successful_queries += 1;
374        self.total_tokens_used += tokens_used;
375        self.last_updated = chrono::Utc::now();
376    }
377
378    pub fn add_failed_query(&mut self) {
379        self.failed_queries += 1;
380        self.error_count += 1;
381        self.last_updated = chrono::Utc::now();
382    }
383
384    pub fn add_context_compression(&mut self) {
385        self.context_compressions += 1;
386        self.last_updated = chrono::Utc::now();
387    }
388
389    pub fn add_topic_transition(&mut self) {
390        self.topic_transitions += 1;
391        self.last_updated = chrono::Utc::now();
392    }
393
394    pub fn add_satisfaction_score(&mut self, score: f32) {
395        self.user_satisfaction_scores.push(score.clamp(0.0, 5.0));
396        self.last_updated = chrono::Utc::now();
397    }
398
399    pub fn get_average_satisfaction(&self) -> f32 {
400        if self.user_satisfaction_scores.is_empty() {
401            0.0
402        } else {
403            self.user_satisfaction_scores.iter().sum::<f32>()
404                / self.user_satisfaction_scores.len() as f32
405        }
406    }
407
408    pub fn get_query_success_rate(&self) -> f32 {
409        let total_queries = self.successful_queries + self.failed_queries;
410        if total_queries == 0 {
411            0.0
412        } else {
413            self.successful_queries as f32 / total_queries as f32
414        }
415    }
416
417    pub fn get_cache_hit_rate(&self) -> f32 {
418        let total_cache_requests = self.cache_hits + self.cache_misses;
419        if total_cache_requests == 0 {
420            0.0
421        } else {
422            self.cache_hits as f32 / total_cache_requests as f32
423        }
424    }
425}
426
/// Chat session data for serialization
///
/// Serializable snapshot of a chat session: configuration, message log,
/// lifecycle state, context/topic state and metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionData {
    /// Session identifier.
    pub id: String,
    /// Configuration the session runs with.
    pub config: ChatConfig,
    /// Messages of the session. NOTE(review): presumably chronological; confirm.
    pub messages: Vec<Message>,
    /// Creation time (UTC).
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Time of the most recent activity (UTC); maintained by the caller.
    pub last_activity: chrono::DateTime<chrono::Utc>,
    /// Free-form string preferences. Also holds the `"user_id"` entry read and
    /// written by `user_id` / `set_user_id`.
    pub user_preferences: HashMap<String, String>,
    /// Current lifecycle state.
    pub session_state: SessionState,
    /// Summary of compressed context, if any.
    /// NOTE(review): presumably mirrors `ContextWindow::context_summary`; confirm.
    pub context_summary: Option<String>,
    /// IDs of pinned messages.
    pub pinned_messages: HashSet<String>,
    /// Topics current at snapshot time.
    pub current_topics: Vec<Topic>,
    /// Recorded topic transitions.
    pub topic_history: Vec<TopicTransition>,
    /// Accumulated performance metrics.
    pub performance_metrics: SessionMetrics,
}
443
444// Compatibility aliases for lib.rs
445impl SessionData {
446    pub fn user_id(&self) -> Option<&str> {
447        self.user_preferences.get("user_id").map(|s| s.as_str())
448    }
449
450    pub fn set_user_id(&mut self, user_id: String) {
451        self.user_preferences.insert("user_id".to_string(), user_id);
452    }
453}
454
455// Re-export with compatibility
456pub use SessionState as state;