reflex/semantic/chat_session.rs

//! Chat session management for interactive `rfx ask` mode
//!
//! This module manages conversation state, message history, token tracking,
//! and context window management for the TUI chat interface.
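//!
//! A minimal usage sketch (illustrative; the import path is assumed from this
//! file's location in the crate, and `summarize_with_llm` is a hypothetical helper):
//!
//! ```ignore
//! use reflex::semantic::chat_session::ChatSession;
//!
//! let mut session = ChatSession::new("openai".to_string(), "gpt-4o-mini".to_string());
//! session.add_user_message("Where is the index built?".to_string());
//!
//! // Prior turns are folded into the next LLM prompt as plain text.
//! let history = session.build_context();
//!
//! // When the window fills up, summarize the old turns and compact them.
//! if session.should_compact() {
//!     let (old_text, old_count, _tokens) = session.prepare_compaction(4);
//!     if old_count > 0 {
//!         let summary = summarize_with_llm(&old_text); // hypothetical LLM call
//!         session.apply_compaction(old_count, summary);
//!     }
//! }
//! ```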

use chrono::{DateTime, Local};
use serde::{Deserialize, Serialize};

/// Maximum context window sizes by provider (in tokens)
const OPENAI_CONTEXT_WINDOW: usize = 128_000;
const ANTHROPIC_CONTEXT_WINDOW: usize = 200_000;
/// Rough estimate: 4 characters per token (common heuristic)
const CHARS_PER_TOKEN: usize = 4;

/// A single message in the conversation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Message role
    pub role: MessageRole,

    /// Message content
    pub content: String,

    /// Estimated token count for this message
    pub tokens: usize,

    /// Timestamp when message was created
    pub timestamp: DateTime<Local>,

    /// Optional metadata (queries executed, results found, etc.)
    pub metadata: Option<MessageMetadata>,
}

/// Message role in conversation
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageRole {
    /// User input
    User,

    /// Assistant - Phase 1: Thinking/Assessment
    AssistantThinking,

    /// Assistant - Phase 2: Tool gathering results
    AssistantTools,

    /// Assistant - Phase 3: Generated queries
    AssistantQueries,

    /// Assistant - Phase 4: Execution status
    AssistantExecuting,

    /// Assistant - Phase 5: Final answer
    AssistantAnswer,

    /// System message (for compaction summaries, etc.)
    System,
}

/// Metadata attached to assistant messages
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageMetadata {
    /// Generated queries (for AssistantQueries phase)
    #[serde(default)]
    pub queries: Vec<String>,

    /// Tool calls made (for AssistantTools phase)
    #[serde(default)]
    pub tool_calls: Vec<String>,

    /// Number of results found (for AssistantExecuting phase)
    #[serde(default)]
    pub results_count: usize,

    /// Execution time in milliseconds
    #[serde(default)]
    pub execution_time_ms: Option<u64>,

    /// Whether this needs more context (for AssistantThinking phase)
    #[serde(default)]
    pub needs_context: bool,
}

/// Chat session state
pub struct ChatSession {
    /// Conversation history
    messages: Vec<Message>,

    /// LLM provider name
    provider: String,

    /// Model name
    model: String,

    /// Context window limit (derived from the provider, see `get_context_limit`)
    context_limit: usize,

    /// Total tokens used in conversation
    total_tokens: usize,
}

impl ChatSession {
    /// Create a new chat session
    pub fn new(provider: String, model: String) -> Self {
        let context_limit = Self::get_context_limit(&provider);

        Self {
            messages: Vec::new(),
            provider,
            model,
            context_limit,
            total_tokens: 0,
        }
    }

    /// Add a user message to the conversation
    pub fn add_user_message(&mut self, content: String) {
        let tokens = Self::estimate_tokens(&content);
        let message = Message {
            role: MessageRole::User,
            content,
            tokens,
            timestamp: Local::now(),
            metadata: None,
        };

        self.total_tokens += tokens;
        self.messages.push(message);
    }

    /// Add an assistant message to the conversation (generic)
    pub fn add_assistant_message(
        &mut self,
        content: String,
        role: MessageRole,
        metadata: Option<MessageMetadata>,
    ) {
        let tokens = Self::estimate_tokens(&content);
        let message = Message {
            role,
            content,
            tokens,
            timestamp: Local::now(),
            metadata,
        };

        self.total_tokens += tokens;
        self.messages.push(message);
    }

    /// Add a thinking/assessment message
    pub fn add_thinking_message(&mut self, reasoning: String, needs_context: bool) {
        let metadata = MessageMetadata {
            queries: Vec::new(),
            tool_calls: Vec::new(),
            results_count: 0,
            execution_time_ms: None,
            needs_context,
        };
        self.add_assistant_message(reasoning, MessageRole::AssistantThinking, Some(metadata));
    }

    /// Add a tool gathering message
    pub fn add_tools_message(&mut self, content: String, tool_calls: Vec<String>) {
        let metadata = MessageMetadata {
            queries: Vec::new(),
            tool_calls,
            results_count: 0,
            execution_time_ms: None,
            needs_context: false,
        };
        self.add_assistant_message(content, MessageRole::AssistantTools, Some(metadata));
    }

    /// Add a queries generated message
    pub fn add_queries_message(&mut self, queries: Vec<String>) {
        let content = format!("Generated {} queries", queries.len());
        let metadata = MessageMetadata {
            queries,
            tool_calls: Vec::new(),
            results_count: 0,
            execution_time_ms: None,
            needs_context: false,
        };
        self.add_assistant_message(content, MessageRole::AssistantQueries, Some(metadata));
    }

    /// Add an execution status message
    pub fn add_execution_message(&mut self, results_count: usize, execution_time_ms: u64) {
        let content = format!("Found {} results", results_count);
        let metadata = MessageMetadata {
            queries: Vec::new(),
            tool_calls: Vec::new(),
            results_count,
            execution_time_ms: Some(execution_time_ms),
            needs_context: false,
        };
        self.add_assistant_message(content, MessageRole::AssistantExecuting, Some(metadata));
    }

    /// Add a final answer message
    pub fn add_answer_message(&mut self, answer: String) {
        self.add_assistant_message(answer, MessageRole::AssistantAnswer, None);
    }

    /// Add a system message (e.g., compaction summary)
    pub fn add_system_message(&mut self, content: String) {
        let tokens = Self::estimate_tokens(&content);
        let message = Message {
            role: MessageRole::System,
            content,
            tokens,
            timestamp: Local::now(),
            metadata: None,
        };

        self.total_tokens += tokens;
        self.messages.push(message);
    }

    /// Clear all messages and reset token count
    pub fn clear(&mut self) {
        self.messages.clear();
        self.total_tokens = 0;
    }

    /// Get all messages in the conversation
    pub fn messages(&self) -> &[Message] {
        &self.messages
    }

    /// Get total token count
    pub fn total_tokens(&self) -> usize {
        self.total_tokens
    }

    /// Get context window limit
    pub fn context_limit(&self) -> usize {
        self.context_limit
    }

    /// Get context usage as a fraction of the context limit (0.0 to 1.0)
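    ///
    /// For example, 32_000 tokens used against a 128_000-token limit yields 0.25.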
    pub fn context_usage(&self) -> f32 {
        if self.context_limit == 0 {
            return 0.0;
        }
        (self.total_tokens as f32) / (self.context_limit as f32)
    }

    /// Check if we're approaching the context limit (>80%)
    pub fn is_near_limit(&self) -> bool {
        self.context_usage() > 0.8
    }

    /// Check if we should suggest compaction (>90%)
    pub fn should_compact(&self) -> bool {
        self.context_usage() > 0.9
    }

    /// Get provider name
    pub fn provider(&self) -> &str {
        &self.provider
    }

    /// Get model name
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Update provider and model (for /model command)
    pub fn update_provider(&mut self, provider: String, model: String) {
        self.context_limit = Self::get_context_limit(&provider);
        self.provider = provider;
        self.model = model;
    }

    /// Build conversation history for LLM prompt
    ///
    /// Returns a formatted string suitable for including in LLM prompts,
    /// containing all messages in chronological order.
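    ///
    /// The output looks like (illustrative):
    ///
    /// ```text
    /// Previous conversation:
    /// ======================
    ///
    /// User: <content>
    ///
    /// Assistant: <content>
    ///
    /// [System Note: <content>]
    /// ```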
    pub fn build_context(&self) -> String {
        let mut context = String::new();

        context.push_str("Previous conversation:\n");
        context.push_str("======================\n\n");

        for msg in &self.messages {
            match msg.role {
                MessageRole::User => {
                    context.push_str(&format!("User: {}\n\n", msg.content));
                }
                MessageRole::AssistantThinking
                | MessageRole::AssistantTools
                | MessageRole::AssistantQueries
                | MessageRole::AssistantExecuting
                | MessageRole::AssistantAnswer => {
                    context.push_str(&format!("Assistant: {}\n\n", msg.content));
                }
                MessageRole::System => {
                    context.push_str(&format!("[System Note: {}]\n\n", msg.content));
                }
            }
        }

        context
    }

    /// Prepare old messages for compaction by summarization
    ///
    /// Keeps the last `keep_recent` messages verbatim and returns the older
    /// messages as a formatted string that can be sent to an LLM for summarization.
    ///
    /// Returns `(old_messages_text, old_messages_count, tokens_to_compact)`,
    /// where `old_messages_count` is the number of messages to pass to
    /// [`apply_compaction`](Self::apply_compaction).
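    ///
    /// Intended flow, sketched below; the summarization call itself lives
    /// outside this type (`summarize` is a hypothetical stand-in):
    ///
    /// ```ignore
    /// let (old_text, old_count, _tokens) = session.prepare_compaction(4);
    /// if old_count > 0 {
    ///     let summary = summarize(&old_text); // e.g. an LLM call
    ///     session.apply_compaction(old_count, summary);
    /// }
    /// ```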
    pub fn prepare_compaction(&self, keep_recent: usize) -> (String, usize, usize) {
        if self.messages.len() <= keep_recent {
            // Nothing old enough to compact
            return (String::new(), 0, 0);
        }

        let split_point = self.messages.len() - keep_recent;
        let old_messages = &self.messages[..split_point];

        let mut summary_text = String::new();
        let mut tokens_to_compact = 0;

        for msg in old_messages {
            tokens_to_compact += msg.tokens;

            match msg.role {
                MessageRole::User => {
                    summary_text.push_str(&format!("User: {}\n\n", msg.content));
                }
                MessageRole::AssistantThinking
                | MessageRole::AssistantTools
                | MessageRole::AssistantQueries
                | MessageRole::AssistantExecuting
                | MessageRole::AssistantAnswer => {
                    summary_text.push_str(&format!("Assistant: {}\n\n", msg.content));
                }
                MessageRole::System => {
                    summary_text.push_str(&format!("[System: {}]\n\n", msg.content));
                }
            }
        }

        (summary_text, old_messages.len(), tokens_to_compact)
    }

    /// Apply compaction by replacing old messages with a summary
    ///
    /// Removes the first `remove_count` messages and replaces them with
    /// a single system message containing the summary.
    pub fn apply_compaction(&mut self, remove_count: usize, summary: String) {
        if remove_count >= self.messages.len() {
            // Safety check: don't remove all messages
            return;
        }

        // Calculate tokens being removed
        let removed_tokens: usize = self.messages[..remove_count]
            .iter()
            .map(|m| m.tokens)
            .sum();

        // Remove old messages
        self.messages.drain(..remove_count);

        // Add summary as system message at the beginning
        let summary_tokens = Self::estimate_tokens(&summary);
        let summary_msg = Message {
            role: MessageRole::System,
            content: format!("Summary of previous conversation: {}", summary),
            tokens: summary_tokens,
            timestamp: Local::now(),
            metadata: None,
        };

        self.messages.insert(0, summary_msg);

        // Update total token count
        self.total_tokens = self.total_tokens - removed_tokens + summary_tokens;
    }

    /// Estimate token count from text (rough heuristic: ~4 chars per token)
    fn estimate_tokens(text: &str) -> usize {
        (text.len() + CHARS_PER_TOKEN - 1) / CHARS_PER_TOKEN
    }

    /// Get context window limit for a provider
    fn get_context_limit(provider: &str) -> usize {
        match provider.to_lowercase().as_str() {
            "openai" => OPENAI_CONTEXT_WINDOW,
            "anthropic" => ANTHROPIC_CONTEXT_WINDOW,
            _ => 32_000, // Conservative default
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_session() {
        let session = ChatSession::new("openai".to_string(), "gpt-4o-mini".to_string());
        assert_eq!(session.messages().len(), 0);
        assert_eq!(session.total_tokens(), 0);
        assert_eq!(session.context_limit(), OPENAI_CONTEXT_WINDOW);
    }

    #[test]
    fn test_add_messages() {
        let mut session = ChatSession::new("anthropic".to_string(), "claude-3-5-haiku".to_string());

        session.add_user_message("Hello!".to_string());
        assert_eq!(session.messages().len(), 1);
        assert!(session.total_tokens() > 0);

        session.add_answer_message("Hi there!".to_string());
        assert_eq!(session.messages().len(), 2);
    }

    #[test]
    fn test_clear() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o".to_string());
        session.add_user_message("Test".to_string());
        session.add_answer_message("Response".to_string());

        assert_eq!(session.messages().len(), 2);

        session.clear();
        assert_eq!(session.messages().len(), 0);
        assert_eq!(session.total_tokens(), 0);
    }

    #[test]
    fn test_context_usage() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o-mini".to_string());
        assert_eq!(session.context_usage(), 0.0);

        // Add a message that's roughly 1/4 of the context window
        let large_text = "a".repeat(OPENAI_CONTEXT_WINDOW * CHARS_PER_TOKEN / 4);
        session.add_user_message(large_text);

        let usage = session.context_usage();
        assert!(usage > 0.2 && usage < 0.3); // Should be around 25%
    }

    #[test]
    fn test_prepare_compaction() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o-mini".to_string());

        for i in 0..10 {
            session.add_user_message(format!("Message {}", i));
            session.add_answer_message(format!("Response {}", i));
        }

        let (summary_text, old_count, tokens) = session.prepare_compaction(4);

        assert_eq!(old_count, 16); // 20 messages - 4 kept = 16 old
        assert!(!summary_text.is_empty());
        assert!(tokens > 0);
    }

    #[test]
    fn test_apply_compaction() {
        let mut session = ChatSession::new("anthropic".to_string(), "claude".to_string());

        for i in 0..6 {
            session.add_user_message(format!("Q{}", i));
            session.add_answer_message(format!("A{}", i));
        }

        let initial_count = session.messages().len();
        assert_eq!(initial_count, 12); // 6 user + 6 answer messages
        let initial_tokens = session.total_tokens();

        session.apply_compaction(8, "This is a summary".to_string());

        // Should have: 1 summary + 4 kept messages = 5 total
        assert_eq!(session.messages().len(), 5);
        assert_eq!(session.messages()[0].role, MessageRole::System);

        // Token count should be updated
        assert!(session.total_tokens() < initial_tokens);
    }

    #[test]
    fn test_estimate_tokens() {
        let text = "Hello, world!"; // 13 chars
        let tokens = ChatSession::estimate_tokens(text);
        // Uses ceiling division: (13 + 4 - 1) / 4 = 16 / 4 = 4
        assert_eq!(tokens, (text.len() + CHARS_PER_TOKEN - 1) / CHARS_PER_TOKEN);
        assert_eq!(tokens, 4);
    }
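
    // Extra coverage (sketch): update_provider should swap the context limit
    // along with the provider; the model names here are illustrative.
    #[test]
    fn test_update_provider() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o".to_string());
        assert_eq!(session.context_limit(), OPENAI_CONTEXT_WINDOW);

        session.update_provider("anthropic".to_string(), "claude-3-5-sonnet".to_string());
        assert_eq!(session.provider(), "anthropic");
        assert_eq!(session.model(), "claude-3-5-sonnet");
        assert_eq!(session.context_limit(), ANTHROPIC_CONTEXT_WINDOW);
    }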
}