// reflex/semantic/chat_session.rs
//! Chat session management for interactive `rfx ask` mode
//!
//! This module manages conversation state, message history, token tracking,
//! and context window management for the TUI chat interface.

use chrono::{DateTime, Local};
use serde::{Deserialize, Serialize};

/// Maximum context window sizes by provider (in tokens)
const OPENAI_CONTEXT_WINDOW: usize = 128_000;
const ANTHROPIC_CONTEXT_WINDOW: usize = 200_000;
const GROQ_CONTEXT_WINDOW: usize = 32_000; // Conservative default for Groq

/// Rough estimate: ~4 characters per token (common heuristic; see `estimate_tokens`)
const CHARS_PER_TOKEN: usize = 4;

/// A single message in the conversation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Message role (user, one of the assistant phases, or system)
    pub role: MessageRole,

    /// Raw message text as displayed/sent
    pub content: String,

    /// Estimated token count for this message (heuristic, not provider-reported)
    pub tokens: usize,

    /// Timestamp when message was created (local time)
    pub timestamp: DateTime<Local>,

    /// Optional metadata (queries executed, results found, etc.)
    pub metadata: Option<MessageMetadata>,
}

/// Message role in conversation
///
/// Assistant output is split into five phases so the TUI can render each
/// stage of a response separately.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageRole {
    /// User input
    User,

    /// Assistant - Phase 1: Thinking/Assessment
    AssistantThinking,

    /// Assistant - Phase 2: Tool gathering results
    AssistantTools,

    /// Assistant - Phase 3: Generated queries
    AssistantQueries,

    /// Assistant - Phase 4: Execution status
    AssistantExecuting,

    /// Assistant - Phase 5: Final answer
    AssistantAnswer,

    /// System message (for compaction summaries, etc.)
    System,
}

/// Metadata attached to assistant messages
///
/// Each phase populates only the fields relevant to it; `#[serde(default)]`
/// lets older serialized sessions deserialize with missing fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageMetadata {
    /// Generated queries (for AssistantQueries phase)
    #[serde(default)]
    pub queries: Vec<String>,

    /// Tool calls made (for AssistantTools phase)
    #[serde(default)]
    pub tool_calls: Vec<String>,

    /// Number of results found (for AssistantExecuting phase)
    #[serde(default)]
    pub results_count: usize,

    /// Execution time in milliseconds (for AssistantExecuting phase)
    #[serde(default)]
    pub execution_time_ms: Option<u64>,

    /// Whether this needs more context (for AssistantThinking phase)
    #[serde(default)]
    pub needs_context: bool,
}

/// Chat session state
///
/// Tracks the conversation history plus a running (estimated) token total
/// against the provider's context window.
pub struct ChatSession {
    /// Conversation history, in chronological order
    messages: Vec<Message>,

    /// LLM provider name (e.g. "openai", "anthropic", "groq")
    provider: String,

    /// Model name
    model: String,

    /// Context window limit for current model (tokens, derived from provider)
    context_limit: usize,

    /// Total estimated tokens used in conversation
    total_tokens: usize,
}

103impl ChatSession {
104    /// Create a new chat session
105    pub fn new(provider: String, model: String) -> Self {
106        let context_limit = Self::get_context_limit(&provider);
107
108        Self {
109            messages: Vec::new(),
110            provider,
111            model,
112            context_limit,
113            total_tokens: 0,
114        }
115    }
116
117    /// Add a user message to the conversation
118    pub fn add_user_message(&mut self, content: String) {
119        let tokens = Self::estimate_tokens(&content);
120        let message = Message {
121            role: MessageRole::User,
122            content,
123            tokens,
124            timestamp: Local::now(),
125            metadata: None,
126        };
127
128        self.total_tokens += tokens;
129        self.messages.push(message);
130    }
131
132    /// Add an assistant message to the conversation (generic)
133    pub fn add_assistant_message(&mut self, content: String, role: MessageRole, metadata: Option<MessageMetadata>) {
134        let tokens = Self::estimate_tokens(&content);
135        let message = Message {
136            role,
137            content,
138            tokens,
139            timestamp: Local::now(),
140            metadata,
141        };
142
143        self.total_tokens += tokens;
144        self.messages.push(message);
145    }
146
147    /// Add a thinking/assessment message
148    pub fn add_thinking_message(&mut self, reasoning: String, needs_context: bool) {
149        let metadata = MessageMetadata {
150            queries: Vec::new(),
151            tool_calls: Vec::new(),
152            results_count: 0,
153            execution_time_ms: None,
154            needs_context,
155        };
156        self.add_assistant_message(reasoning, MessageRole::AssistantThinking, Some(metadata));
157    }
158
159    /// Add a tool gathering message
160    pub fn add_tools_message(&mut self, content: String, tool_calls: Vec<String>) {
161        let metadata = MessageMetadata {
162            queries: Vec::new(),
163            tool_calls,
164            results_count: 0,
165            execution_time_ms: None,
166            needs_context: false,
167        };
168        self.add_assistant_message(content, MessageRole::AssistantTools, Some(metadata));
169    }
170
171    /// Add a queries generated message
172    pub fn add_queries_message(&mut self, queries: Vec<String>) {
173        let content = format!("Generated {} queries", queries.len());
174        let metadata = MessageMetadata {
175            queries: queries.clone(),
176            tool_calls: Vec::new(),
177            results_count: 0,
178            execution_time_ms: None,
179            needs_context: false,
180        };
181        self.add_assistant_message(content, MessageRole::AssistantQueries, Some(metadata));
182    }
183
184    /// Add an execution status message
185    pub fn add_execution_message(&mut self, results_count: usize, execution_time_ms: u64) {
186        let content = format!("Found {} results", results_count);
187        let metadata = MessageMetadata {
188            queries: Vec::new(),
189            tool_calls: Vec::new(),
190            results_count,
191            execution_time_ms: Some(execution_time_ms),
192            needs_context: false,
193        };
194        self.add_assistant_message(content, MessageRole::AssistantExecuting, Some(metadata));
195    }
196
197    /// Add a final answer message
198    pub fn add_answer_message(&mut self, answer: String) {
199        self.add_assistant_message(answer, MessageRole::AssistantAnswer, None);
200    }
201
202    /// Add a system message (e.g., compaction summary)
203    pub fn add_system_message(&mut self, content: String) {
204        let tokens = Self::estimate_tokens(&content);
205        let message = Message {
206            role: MessageRole::System,
207            content,
208            tokens,
209            timestamp: Local::now(),
210            metadata: None,
211        };
212
213        self.total_tokens += tokens;
214        self.messages.push(message);
215    }
216
217    /// Clear all messages and reset token count
218    pub fn clear(&mut self) {
219        self.messages.clear();
220        self.total_tokens = 0;
221    }
222
223    /// Get all messages in the conversation
224    pub fn messages(&self) -> &[Message] {
225        &self.messages
226    }
227
228    /// Get total token count
229    pub fn total_tokens(&self) -> usize {
230        self.total_tokens
231    }
232
233    /// Get context window limit
234    pub fn context_limit(&self) -> usize {
235        self.context_limit
236    }
237
238    /// Get context usage as percentage (0.0 to 1.0)
239    pub fn context_usage(&self) -> f32 {
240        if self.context_limit == 0 {
241            return 0.0;
242        }
243        (self.total_tokens as f32) / (self.context_limit as f32)
244    }
245
246    /// Check if we're approaching context limit (>80%)
247    pub fn is_near_limit(&self) -> bool {
248        self.context_usage() > 0.8
249    }
250
251    /// Check if we should suggest compaction (>90%)
252    pub fn should_compact(&self) -> bool {
253        self.context_usage() > 0.9
254    }
255
256    /// Get provider name
257    pub fn provider(&self) -> &str {
258        &self.provider
259    }
260
261    /// Get model name
262    pub fn model(&self) -> &str {
263        &self.model
264    }
265
266    /// Update provider and model (for /model command)
267    pub fn update_provider(&mut self, provider: String, model: String) {
268        self.provider = provider.clone();
269        self.model = model;
270        self.context_limit = Self::get_context_limit(&provider);
271    }
272
273    /// Build conversation history for LLM prompt
274    ///
275    /// Returns a formatted string suitable for including in LLM prompts,
276    /// containing all messages in chronological order.
277    pub fn build_context(&self) -> String {
278        let mut context = String::new();
279
280        context.push_str("Previous conversation:\n");
281        context.push_str("======================\n\n");
282
283        for msg in &self.messages {
284            match msg.role {
285                MessageRole::User => {
286                    context.push_str(&format!("User: {}\n\n", msg.content));
287                }
288                MessageRole::AssistantThinking
289                | MessageRole::AssistantTools
290                | MessageRole::AssistantQueries
291                | MessageRole::AssistantExecuting
292                | MessageRole::AssistantAnswer => {
293                    context.push_str(&format!("Assistant: {}\n\n", msg.content));
294                }
295                MessageRole::System => {
296                    context.push_str(&format!("[System Note: {}]\n\n", msg.content));
297                }
298            }
299        }
300
301        context
302    }
303
304    /// Compact old messages by summarizing them
305    ///
306    /// Keeps the last `keep_recent` messages verbatim and returns the older
307    /// messages as a formatted string that can be sent to an LLM for summarization.
308    ///
309    /// Returns (old_messages_for_summary, kept_messages_count, tokens_to_compact)
310    pub fn prepare_compaction(&self, keep_recent: usize) -> (String, usize, usize) {
311        if self.messages.len() <= keep_recent {
312            return (String::new(), self.messages.len(), 0);
313        }
314
315        let split_point = self.messages.len() - keep_recent;
316        let old_messages = &self.messages[..split_point];
317
318        let mut summary_text = String::new();
319        let mut tokens_to_compact = 0;
320
321        for msg in old_messages {
322            tokens_to_compact += msg.tokens;
323
324            match msg.role {
325                MessageRole::User => {
326                    summary_text.push_str(&format!("User: {}\n\n", msg.content));
327                }
328                MessageRole::AssistantThinking
329                | MessageRole::AssistantTools
330                | MessageRole::AssistantQueries
331                | MessageRole::AssistantExecuting
332                | MessageRole::AssistantAnswer => {
333                    summary_text.push_str(&format!("Assistant: {}\n\n", msg.content));
334                }
335                MessageRole::System => {
336                    summary_text.push_str(&format!("[System: {}]\n\n", msg.content));
337                }
338            }
339        }
340
341        (summary_text, old_messages.len(), tokens_to_compact)
342    }
343
344    /// Apply compaction by replacing old messages with a summary
345    ///
346    /// Removes the first `remove_count` messages and replaces them with
347    /// a single system message containing the summary.
348    pub fn apply_compaction(&mut self, remove_count: usize, summary: String) {
349        if remove_count >= self.messages.len() {
350            // Safety check: don't remove all messages
351            return;
352        }
353
354        // Calculate tokens being removed
355        let removed_tokens: usize = self.messages[..remove_count]
356            .iter()
357            .map(|m| m.tokens)
358            .sum();
359
360        // Remove old messages
361        self.messages.drain(..remove_count);
362
363        // Add summary as system message at the beginning
364        let summary_tokens = Self::estimate_tokens(&summary);
365        let summary_msg = Message {
366            role: MessageRole::System,
367            content: format!("Summary of previous conversation: {}", summary),
368            tokens: summary_tokens,
369            timestamp: Local::now(),
370            metadata: None,
371        };
372
373        self.messages.insert(0, summary_msg);
374
375        // Update total token count
376        self.total_tokens = self.total_tokens - removed_tokens + summary_tokens;
377    }
378
379    /// Estimate token count from text (rough heuristic: ~4 chars per token)
380    fn estimate_tokens(text: &str) -> usize {
381        (text.len() + CHARS_PER_TOKEN - 1) / CHARS_PER_TOKEN
382    }
383
384    /// Get context window limit for a provider
385    fn get_context_limit(provider: &str) -> usize {
386        match provider.to_lowercase().as_str() {
387            "openai" => OPENAI_CONTEXT_WINDOW,
388            "anthropic" => ANTHROPIC_CONTEXT_WINDOW,
389            "groq" => GROQ_CONTEXT_WINDOW,
390            _ => 32_000, // Conservative default
391        }
392    }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_session() {
        let session = ChatSession::new("openai".to_string(), "gpt-4o-mini".to_string());
        assert_eq!(session.messages().len(), 0);
        assert_eq!(session.total_tokens(), 0);
        assert_eq!(session.context_limit(), OPENAI_CONTEXT_WINDOW);
    }

    #[test]
    fn test_add_messages() {
        let mut session = ChatSession::new("anthropic".to_string(), "claude-3-5-haiku".to_string());

        session.add_user_message("Hello!".to_string());
        assert_eq!(session.messages().len(), 1);
        assert!(session.total_tokens() > 0);

        session.add_answer_message("Hi there!".to_string());
        assert_eq!(session.messages().len(), 2);
    }

    #[test]
    fn test_clear() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o".to_string());
        session.add_user_message("Test".to_string());
        session.add_answer_message("Response".to_string());

        assert_eq!(session.messages().len(), 2);

        session.clear();
        assert_eq!(session.messages().len(), 0);
        assert_eq!(session.total_tokens(), 0);
    }

    #[test]
    fn test_context_usage() {
        let mut session = ChatSession::new("groq".to_string(), "llama-3.3-70b".to_string());
        assert_eq!(session.context_usage(), 0.0);

        // Add a message that's roughly 1/4 of the context window
        let large_text = "a".repeat(GROQ_CONTEXT_WINDOW * CHARS_PER_TOKEN / 4);
        session.add_user_message(large_text);

        let usage = session.context_usage();
        assert!(usage > 0.2 && usage < 0.3); // Should be around 25%
    }

    #[test]
    fn test_prepare_compaction() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o-mini".to_string());

        for i in 0..10 {
            session.add_user_message(format!("Message {}", i));
            session.add_answer_message(format!("Response {}", i));
        }

        let (summary_text, old_count, tokens) = session.prepare_compaction(4);

        assert_eq!(old_count, 16); // 20 messages - 4 kept = 16 old
        assert!(!summary_text.is_empty());
        assert!(tokens > 0);
    }

    #[test]
    fn test_apply_compaction() {
        let mut session = ChatSession::new("anthropic".to_string(), "claude".to_string());

        for i in 0..6 {
            session.add_user_message(format!("Q{}", i));
            session.add_answer_message(format!("A{}", i));
        }

        let initial_count = session.messages().len();
        let initial_tokens = session.total_tokens();
        // 6 question/answer pairs were added above.
        assert_eq!(initial_count, 12);

        session.apply_compaction(8, "This is a summary".to_string());

        // Should have: 1 summary + 4 kept messages = 5 total
        assert_eq!(session.messages().len(), 5);
        assert_eq!(session.messages()[0].role, MessageRole::System);

        // Token count should be updated
        assert!(session.total_tokens() < initial_tokens);
    }

    #[test]
    fn test_estimate_tokens() {
        let text = "Hello, world!"; // 13 chars
        let tokens = ChatSession::estimate_tokens(text);
        // Uses ceiling division: (13 + 4 - 1) / 4 = 16 / 4 = 4
        assert_eq!(tokens, (text.len() + CHARS_PER_TOKEN - 1) / CHARS_PER_TOKEN);
        assert_eq!(tokens, 4);
    }
}