Skip to main content

reflex/semantic/
chat_session.rs

//! Chat session management for interactive `rfx ask` mode.
//!
//! This module manages conversation state, message history, token tracking,
//! and context window management for the TUI chat interface.
5
6use chrono::{DateTime, Local};
7use serde::{Deserialize, Serialize};
8
/// Context window size (in tokens) assumed for the `openai` provider.
const OPENAI_CONTEXT_WINDOW: usize = 128_000;

/// Context window size (in tokens) assumed for the `anthropic` provider.
const ANTHROPIC_CONTEXT_WINDOW: usize = 200_000;

/// Heuristic divisor used to turn text length into an estimated token count
/// (roughly four characters per token, a common rule of thumb).
const CHARS_PER_TOKEN: usize = 4;
14
15/// A single message in the conversation
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Message {
18    /// Message role
19    pub role: MessageRole,
20
21    /// Message content
22    pub content: String,
23
24    /// Estimated token count for this message
25    pub tokens: usize,
26
27    /// Timestamp when message was created
28    pub timestamp: DateTime<Local>,
29
30    /// Optional metadata (queries executed, results found, etc.)
31    pub metadata: Option<MessageMetadata>,
32}
33
34/// Message role in conversation
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
36pub enum MessageRole {
37    /// User input
38    User,
39
40    /// Assistant - Phase 1: Thinking/Assessment
41    AssistantThinking,
42
43    /// Assistant - Phase 2: Tool gathering results
44    AssistantTools,
45
46    /// Assistant - Phase 3: Generated queries
47    AssistantQueries,
48
49    /// Assistant - Phase 4: Execution status
50    AssistantExecuting,
51
52    /// Assistant - Phase 5: Final answer
53    AssistantAnswer,
54
55    /// System message (for compaction summaries, etc.)
56    System,
57}
58
59/// Metadata attached to assistant messages
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct MessageMetadata {
62    /// Generated queries (for AssistantQueries phase)
63    #[serde(default)]
64    pub queries: Vec<String>,
65
66    /// Tool calls made (for AssistantTools phase)
67    #[serde(default)]
68    pub tool_calls: Vec<String>,
69
70    /// Number of results found (for AssistantExecuting phase)
71    #[serde(default)]
72    pub results_count: usize,
73
74    /// Execution time in milliseconds
75    #[serde(default)]
76    pub execution_time_ms: Option<u64>,
77
78    /// Whether this needs more context (for AssistantThinking phase)
79    #[serde(default)]
80    pub needs_context: bool,
81}
82
83/// Chat session state
84pub struct ChatSession {
85    /// Conversation history
86    messages: Vec<Message>,
87
88    /// LLM provider name
89    provider: String,
90
91    /// Model name
92    model: String,
93
94    /// Context window limit for current model
95    context_limit: usize,
96
97    /// Total tokens used in conversation
98    total_tokens: usize,
99}
100
101impl ChatSession {
102    /// Create a new chat session
103    pub fn new(provider: String, model: String) -> Self {
104        let context_limit = Self::get_context_limit(&provider);
105
106        Self {
107            messages: Vec::new(),
108            provider,
109            model,
110            context_limit,
111            total_tokens: 0,
112        }
113    }
114
115    /// Add a user message to the conversation
116    pub fn add_user_message(&mut self, content: String) {
117        let tokens = Self::estimate_tokens(&content);
118        let message = Message {
119            role: MessageRole::User,
120            content,
121            tokens,
122            timestamp: Local::now(),
123            metadata: None,
124        };
125
126        self.total_tokens += tokens;
127        self.messages.push(message);
128    }
129
130    /// Add an assistant message to the conversation (generic)
131    pub fn add_assistant_message(
132        &mut self,
133        content: String,
134        role: MessageRole,
135        metadata: Option<MessageMetadata>,
136    ) {
137        let tokens = Self::estimate_tokens(&content);
138        let message = Message {
139            role,
140            content,
141            tokens,
142            timestamp: Local::now(),
143            metadata,
144        };
145
146        self.total_tokens += tokens;
147        self.messages.push(message);
148    }
149
150    /// Add a thinking/assessment message
151    pub fn add_thinking_message(&mut self, reasoning: String, needs_context: bool) {
152        let metadata = MessageMetadata {
153            queries: Vec::new(),
154            tool_calls: Vec::new(),
155            results_count: 0,
156            execution_time_ms: None,
157            needs_context,
158        };
159        self.add_assistant_message(reasoning, MessageRole::AssistantThinking, Some(metadata));
160    }
161
162    /// Add a tool gathering message
163    pub fn add_tools_message(&mut self, content: String, tool_calls: Vec<String>) {
164        let metadata = MessageMetadata {
165            queries: Vec::new(),
166            tool_calls,
167            results_count: 0,
168            execution_time_ms: None,
169            needs_context: false,
170        };
171        self.add_assistant_message(content, MessageRole::AssistantTools, Some(metadata));
172    }
173
174    /// Add a queries generated message
175    pub fn add_queries_message(&mut self, queries: Vec<String>) {
176        let content = format!("Generated {} queries", queries.len());
177        let metadata = MessageMetadata {
178            queries: queries.clone(),
179            tool_calls: Vec::new(),
180            results_count: 0,
181            execution_time_ms: None,
182            needs_context: false,
183        };
184        self.add_assistant_message(content, MessageRole::AssistantQueries, Some(metadata));
185    }
186
187    /// Add an execution status message
188    pub fn add_execution_message(&mut self, results_count: usize, execution_time_ms: u64) {
189        let content = format!(
190            "Found {} result{}",
191            results_count,
192            if results_count == 1 { "" } else { "s" }
193        );
194        let metadata = MessageMetadata {
195            queries: Vec::new(),
196            tool_calls: Vec::new(),
197            results_count,
198            execution_time_ms: Some(execution_time_ms),
199            needs_context: false,
200        };
201        self.add_assistant_message(content, MessageRole::AssistantExecuting, Some(metadata));
202    }
203
204    /// Add a final answer message
205    pub fn add_answer_message(&mut self, answer: String) {
206        self.add_assistant_message(answer, MessageRole::AssistantAnswer, None);
207    }
208
209    /// Add a system message (e.g., compaction summary)
210    pub fn add_system_message(&mut self, content: String) {
211        let tokens = Self::estimate_tokens(&content);
212        let message = Message {
213            role: MessageRole::System,
214            content,
215            tokens,
216            timestamp: Local::now(),
217            metadata: None,
218        };
219
220        self.total_tokens += tokens;
221        self.messages.push(message);
222    }
223
224    /// Clear all messages and reset token count
225    pub fn clear(&mut self) {
226        self.messages.clear();
227        self.total_tokens = 0;
228    }
229
230    /// Get all messages in the conversation
231    pub fn messages(&self) -> &[Message] {
232        &self.messages
233    }
234
235    /// Get total token count
236    pub fn total_tokens(&self) -> usize {
237        self.total_tokens
238    }
239
240    /// Get context window limit
241    pub fn context_limit(&self) -> usize {
242        self.context_limit
243    }
244
245    /// Get context usage as percentage (0.0 to 1.0)
246    pub fn context_usage(&self) -> f32 {
247        if self.context_limit == 0 {
248            return 0.0;
249        }
250        (self.total_tokens as f32) / (self.context_limit as f32)
251    }
252
253    /// Check if we're approaching context limit (>80%)
254    pub fn is_near_limit(&self) -> bool {
255        self.context_usage() > 0.8
256    }
257
258    /// Check if we should suggest compaction (>90%)
259    pub fn should_compact(&self) -> bool {
260        self.context_usage() > 0.9
261    }
262
263    /// Get provider name
264    pub fn provider(&self) -> &str {
265        &self.provider
266    }
267
268    /// Get model name
269    pub fn model(&self) -> &str {
270        &self.model
271    }
272
273    /// Update provider and model (for /model command)
274    pub fn update_provider(&mut self, provider: String, model: String) {
275        self.provider = provider.clone();
276        self.model = model;
277        self.context_limit = Self::get_context_limit(&provider);
278    }
279
280    /// Build conversation history for LLM prompt
281    ///
282    /// Returns a formatted string suitable for including in LLM prompts,
283    /// containing all messages in chronological order.
284    pub fn build_context(&self) -> String {
285        let mut context = String::new();
286
287        context.push_str("Previous conversation:\n");
288        context.push_str("======================\n\n");
289
290        for msg in &self.messages {
291            match msg.role {
292                MessageRole::User => {
293                    context.push_str(&format!("User: {}\n\n", msg.content));
294                }
295                MessageRole::AssistantThinking
296                | MessageRole::AssistantTools
297                | MessageRole::AssistantQueries
298                | MessageRole::AssistantExecuting
299                | MessageRole::AssistantAnswer => {
300                    context.push_str(&format!("Assistant: {}\n\n", msg.content));
301                }
302                MessageRole::System => {
303                    context.push_str(&format!("[System Note: {}]\n\n", msg.content));
304                }
305            }
306        }
307
308        context
309    }
310
311    /// Compact old messages by summarizing them
312    ///
313    /// Keeps the last `keep_recent` messages verbatim and returns the older
314    /// messages as a formatted string that can be sent to an LLM for summarization.
315    ///
316    /// Returns (old_messages_for_summary, kept_messages_count, tokens_to_compact)
317    pub fn prepare_compaction(&self, keep_recent: usize) -> (String, usize, usize) {
318        if self.messages.len() <= keep_recent {
319            return (String::new(), self.messages.len(), 0);
320        }
321
322        let split_point = self.messages.len() - keep_recent;
323        let old_messages = &self.messages[..split_point];
324
325        let mut summary_text = String::new();
326        let mut tokens_to_compact = 0;
327
328        for msg in old_messages {
329            tokens_to_compact += msg.tokens;
330
331            match msg.role {
332                MessageRole::User => {
333                    summary_text.push_str(&format!("User: {}\n\n", msg.content));
334                }
335                MessageRole::AssistantThinking
336                | MessageRole::AssistantTools
337                | MessageRole::AssistantQueries
338                | MessageRole::AssistantExecuting
339                | MessageRole::AssistantAnswer => {
340                    summary_text.push_str(&format!("Assistant: {}\n\n", msg.content));
341                }
342                MessageRole::System => {
343                    summary_text.push_str(&format!("[System: {}]\n\n", msg.content));
344                }
345            }
346        }
347
348        (summary_text, old_messages.len(), tokens_to_compact)
349    }
350
351    /// Apply compaction by replacing old messages with a summary
352    ///
353    /// Removes the first `remove_count` messages and replaces them with
354    /// a single system message containing the summary.
355    pub fn apply_compaction(&mut self, remove_count: usize, summary: String) {
356        if remove_count >= self.messages.len() {
357            // Safety check: don't remove all messages
358            return;
359        }
360
361        // Calculate tokens being removed
362        let removed_tokens: usize = self.messages[..remove_count].iter().map(|m| m.tokens).sum();
363
364        // Remove old messages
365        self.messages.drain(..remove_count);
366
367        // Add summary as system message at the beginning
368        let summary_tokens = Self::estimate_tokens(&summary);
369        let summary_msg = Message {
370            role: MessageRole::System,
371            content: format!("Summary of previous conversation: {}", summary),
372            tokens: summary_tokens,
373            timestamp: Local::now(),
374            metadata: None,
375        };
376
377        self.messages.insert(0, summary_msg);
378
379        // Update total token count
380        self.total_tokens = self.total_tokens - removed_tokens + summary_tokens;
381    }
382
383    /// Estimate token count from text (rough heuristic: ~4 chars per token)
384    fn estimate_tokens(text: &str) -> usize {
385        (text.len() + CHARS_PER_TOKEN - 1) / CHARS_PER_TOKEN
386    }
387
388    /// Get context window limit for a provider
389    fn get_context_limit(provider: &str) -> usize {
390        match provider.to_lowercase().as_str() {
391            "openai" => OPENAI_CONTEXT_WINDOW,
392            "anthropic" => ANTHROPIC_CONTEXT_WINDOW,
393            _ => 32_000, // Conservative default
394        }
395    }
396}
397
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_session() {
        let session = ChatSession::new("openai".to_string(), "gpt-4o-mini".to_string());
        assert_eq!(session.messages().len(), 0);
        assert_eq!(session.total_tokens(), 0);
        assert_eq!(session.context_limit(), OPENAI_CONTEXT_WINDOW);
    }

    #[test]
    fn test_unknown_provider_uses_default_limit() {
        let session = ChatSession::new("ollama".to_string(), "llama3".to_string());
        assert_eq!(session.context_limit(), 32_000);
    }

    #[test]
    fn test_provider_is_case_insensitive() {
        let session = ChatSession::new("OpenAI".to_string(), "gpt-4o".to_string());
        assert_eq!(session.context_limit(), OPENAI_CONTEXT_WINDOW);
    }

    #[test]
    fn test_update_provider() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o".to_string());
        session.add_user_message("keep me".to_string());
        let tokens_before = session.total_tokens();

        session.update_provider("anthropic".to_string(), "claude".to_string());

        assert_eq!(session.provider(), "anthropic");
        assert_eq!(session.model(), "claude");
        assert_eq!(session.context_limit(), ANTHROPIC_CONTEXT_WINDOW);
        // History and token count are untouched by a provider switch.
        assert_eq!(session.messages().len(), 1);
        assert_eq!(session.total_tokens(), tokens_before);
    }

    #[test]
    fn test_add_messages() {
        let mut session = ChatSession::new("anthropic".to_string(), "claude-3-5-haiku".to_string());

        session.add_user_message("Hello!".to_string());
        assert_eq!(session.messages().len(), 1);
        assert!(session.total_tokens() > 0);

        session.add_answer_message("Hi there!".to_string());
        assert_eq!(session.messages().len(), 2);
    }

    #[test]
    fn test_total_tokens_matches_message_sum() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o".to_string());
        session.add_user_message("question".to_string());
        session.add_thinking_message("reasoning".to_string(), true);
        session.add_tools_message("tools".to_string(), vec!["grep".to_string()]);
        session.add_queries_message(vec!["q1".to_string(), "q2".to_string()]);
        session.add_execution_message(3, 10);
        session.add_answer_message("answer".to_string());
        session.add_system_message("note".to_string());

        let sum: usize = session.messages().iter().map(|m| m.tokens).sum();
        assert_eq!(session.total_tokens(), sum);
    }

    #[test]
    fn test_queries_message_metadata() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o".to_string());
        session.add_queries_message(vec!["a".to_string(), "b".to_string()]);

        let msg = &session.messages()[0];
        assert_eq!(msg.role, MessageRole::AssistantQueries);
        let meta = msg.metadata.as_ref().expect("queries message carries metadata");
        assert_eq!(meta.queries, vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn test_execution_message_pluralization() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o".to_string());
        session.add_execution_message(1, 5);
        session.add_execution_message(2, 7);

        assert_eq!(session.messages()[0].content, "Found 1 result");
        assert_eq!(session.messages()[1].content, "Found 2 results");
        let meta = session.messages()[0].metadata.as_ref().expect("execution metadata");
        assert_eq!(meta.results_count, 1);
        assert_eq!(meta.execution_time_ms, Some(5));
    }

    #[test]
    fn test_clear() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o".to_string());
        session.add_user_message("Test".to_string());
        session.add_answer_message("Response".to_string());

        assert_eq!(session.messages().len(), 2);

        session.clear();
        assert_eq!(session.messages().len(), 0);
        assert_eq!(session.total_tokens(), 0);
    }

    #[test]
    fn test_context_usage() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o-mini".to_string());
        assert_eq!(session.context_usage(), 0.0);

        // Add a message that's roughly 1/4 of the context window
        let large_text = "a".repeat(OPENAI_CONTEXT_WINDOW * CHARS_PER_TOKEN / 4);
        session.add_user_message(large_text);

        let usage = session.context_usage();
        assert!(usage > 0.2 && usage < 0.3); // Should be around 25%
    }

    #[test]
    fn test_limit_thresholds() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o-mini".to_string());

        // ~85% of the window: near the limit, but not yet time to compact.
        session.add_user_message("a".repeat(OPENAI_CONTEXT_WINDOW * CHARS_PER_TOKEN * 85 / 100));
        assert!(session.is_near_limit());
        assert!(!session.should_compact());

        // ~95% of the window: compaction should be suggested.
        session.add_user_message("a".repeat(OPENAI_CONTEXT_WINDOW * CHARS_PER_TOKEN * 10 / 100));
        assert!(session.should_compact());
    }

    #[test]
    fn test_build_context_format() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o".to_string());
        session.add_user_message("hi".to_string());
        session.add_answer_message("yo".to_string());
        session.add_system_message("note".to_string());

        let context = session.build_context();
        assert!(context.starts_with("Previous conversation:\n"));
        assert!(context.ends_with("User: hi\n\nAssistant: yo\n\n[System Note: note]\n\n"));
    }

    #[test]
    fn test_prepare_compaction() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o-mini".to_string());

        for i in 0..10 {
            session.add_user_message(format!("Message {}", i));
            session.add_answer_message(format!("Response {}", i));
        }

        let (summary_text, old_count, tokens) = session.prepare_compaction(4);

        assert_eq!(old_count, 16); // 20 messages - 4 kept = 16 old
        assert!(!summary_text.is_empty());
        assert!(tokens > 0);
    }

    #[test]
    fn test_prepare_compaction_formats_system_messages() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o".to_string());
        session.add_system_message("note".to_string());
        session.add_user_message("recent".to_string());

        let (summary_text, _, _) = session.prepare_compaction(1);
        assert_eq!(summary_text, "[System: note]\n\n");
    }

    #[test]
    fn test_prepare_compaction_nothing_to_compact() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o".to_string());
        session.add_user_message("one".to_string());
        session.add_answer_message("two".to_string());

        let (summary_text, _, tokens) = session.prepare_compaction(4);
        assert!(summary_text.is_empty());
        assert_eq!(tokens, 0);
    }

    #[test]
    fn test_apply_compaction() {
        let mut session = ChatSession::new("anthropic".to_string(), "claude".to_string());

        for i in 0..6 {
            session.add_user_message(format!("Q{}", i));
            session.add_answer_message(format!("A{}", i));
        }

        let initial_count = session.messages().len();
        let initial_tokens = session.total_tokens();
        assert_eq!(initial_count, 12);

        session.apply_compaction(8, "This is a summary".to_string());

        // Should have: 1 summary + (12 - 8) kept messages = 5 total
        assert_eq!(session.messages().len(), initial_count - 8 + 1);
        assert_eq!(session.messages()[0].role, MessageRole::System);

        // Token count should be updated
        assert!(session.total_tokens() < initial_tokens);
    }

    #[test]
    fn test_apply_compaction_refuses_to_remove_everything() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o".to_string());
        session.add_user_message("Q".to_string());
        session.add_answer_message("A".to_string());
        let tokens_before = session.total_tokens();

        session.apply_compaction(2, "summary".to_string());

        assert_eq!(session.messages().len(), 2);
        assert_eq!(session.messages()[0].role, MessageRole::User);
        assert_eq!(session.total_tokens(), tokens_before);
    }

    #[test]
    fn test_estimate_tokens() {
        let text = "Hello, world!"; // 13 chars
        let tokens = ChatSession::estimate_tokens(text);
        // Uses ceiling division: (13 + 4 - 1) / 4 = 16 / 4 = 4
        assert_eq!(tokens, (text.len() + CHARS_PER_TOKEN - 1) / CHARS_PER_TOKEN);
        assert_eq!(tokens, 4);
        assert_eq!(ChatSession::estimate_tokens(""), 0);
    }
}