1use chrono::{DateTime, Local};
7use serde::{Deserialize, Serialize};
8
/// Token budget used for sessions whose provider is `"openai"` (case-insensitive).
const OPENAI_CONTEXT_WINDOW: usize = 128_000;
/// Token budget used for sessions whose provider is `"anthropic"` (case-insensitive).
const ANTHROPIC_CONTEXT_WINDOW: usize = 200_000;
/// Divisor used by the token estimator.
///
/// Despite the name, it is applied to the UTF-8 *byte* length of the text
/// (rounded up), not to a character count.
const CHARS_PER_TOKEN: usize = 4;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Message {
18 pub role: MessageRole,
20
21 pub content: String,
23
24 pub tokens: usize,
26
27 pub timestamp: DateTime<Local>,
29
30 pub metadata: Option<MessageMetadata>,
32}
33
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
36pub enum MessageRole {
37 User,
39
40 AssistantThinking,
42
43 AssistantTools,
45
46 AssistantQueries,
48
49 AssistantExecuting,
51
52 AssistantAnswer,
54
55 System,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct MessageMetadata {
62 #[serde(default)]
64 pub queries: Vec<String>,
65
66 #[serde(default)]
68 pub tool_calls: Vec<String>,
69
70 #[serde(default)]
72 pub results_count: usize,
73
74 #[serde(default)]
76 pub execution_time_ms: Option<u64>,
77
78 #[serde(default)]
80 pub needs_context: bool,
81}
82
83pub struct ChatSession {
85 messages: Vec<Message>,
87
88 provider: String,
90
91 model: String,
93
94 context_limit: usize,
96
97 total_tokens: usize,
99}
100
101impl ChatSession {
102 pub fn new(provider: String, model: String) -> Self {
104 let context_limit = Self::get_context_limit(&provider);
105
106 Self {
107 messages: Vec::new(),
108 provider,
109 model,
110 context_limit,
111 total_tokens: 0,
112 }
113 }
114
115 pub fn add_user_message(&mut self, content: String) {
117 let tokens = Self::estimate_tokens(&content);
118 let message = Message {
119 role: MessageRole::User,
120 content,
121 tokens,
122 timestamp: Local::now(),
123 metadata: None,
124 };
125
126 self.total_tokens += tokens;
127 self.messages.push(message);
128 }
129
130 pub fn add_assistant_message(
132 &mut self,
133 content: String,
134 role: MessageRole,
135 metadata: Option<MessageMetadata>,
136 ) {
137 let tokens = Self::estimate_tokens(&content);
138 let message = Message {
139 role,
140 content,
141 tokens,
142 timestamp: Local::now(),
143 metadata,
144 };
145
146 self.total_tokens += tokens;
147 self.messages.push(message);
148 }
149
150 pub fn add_thinking_message(&mut self, reasoning: String, needs_context: bool) {
152 let metadata = MessageMetadata {
153 queries: Vec::new(),
154 tool_calls: Vec::new(),
155 results_count: 0,
156 execution_time_ms: None,
157 needs_context,
158 };
159 self.add_assistant_message(reasoning, MessageRole::AssistantThinking, Some(metadata));
160 }
161
162 pub fn add_tools_message(&mut self, content: String, tool_calls: Vec<String>) {
164 let metadata = MessageMetadata {
165 queries: Vec::new(),
166 tool_calls,
167 results_count: 0,
168 execution_time_ms: None,
169 needs_context: false,
170 };
171 self.add_assistant_message(content, MessageRole::AssistantTools, Some(metadata));
172 }
173
174 pub fn add_queries_message(&mut self, queries: Vec<String>) {
176 let content = format!("Generated {} queries", queries.len());
177 let metadata = MessageMetadata {
178 queries: queries.clone(),
179 tool_calls: Vec::new(),
180 results_count: 0,
181 execution_time_ms: None,
182 needs_context: false,
183 };
184 self.add_assistant_message(content, MessageRole::AssistantQueries, Some(metadata));
185 }
186
187 pub fn add_execution_message(&mut self, results_count: usize, execution_time_ms: u64) {
189 let content = format!(
190 "Found {} result{}",
191 results_count,
192 if results_count == 1 { "" } else { "s" }
193 );
194 let metadata = MessageMetadata {
195 queries: Vec::new(),
196 tool_calls: Vec::new(),
197 results_count,
198 execution_time_ms: Some(execution_time_ms),
199 needs_context: false,
200 };
201 self.add_assistant_message(content, MessageRole::AssistantExecuting, Some(metadata));
202 }
203
204 pub fn add_answer_message(&mut self, answer: String) {
206 self.add_assistant_message(answer, MessageRole::AssistantAnswer, None);
207 }
208
209 pub fn add_system_message(&mut self, content: String) {
211 let tokens = Self::estimate_tokens(&content);
212 let message = Message {
213 role: MessageRole::System,
214 content,
215 tokens,
216 timestamp: Local::now(),
217 metadata: None,
218 };
219
220 self.total_tokens += tokens;
221 self.messages.push(message);
222 }
223
224 pub fn clear(&mut self) {
226 self.messages.clear();
227 self.total_tokens = 0;
228 }
229
230 pub fn messages(&self) -> &[Message] {
232 &self.messages
233 }
234
235 pub fn total_tokens(&self) -> usize {
237 self.total_tokens
238 }
239
240 pub fn context_limit(&self) -> usize {
242 self.context_limit
243 }
244
245 pub fn context_usage(&self) -> f32 {
247 if self.context_limit == 0 {
248 return 0.0;
249 }
250 (self.total_tokens as f32) / (self.context_limit as f32)
251 }
252
253 pub fn is_near_limit(&self) -> bool {
255 self.context_usage() > 0.8
256 }
257
258 pub fn should_compact(&self) -> bool {
260 self.context_usage() > 0.9
261 }
262
263 pub fn provider(&self) -> &str {
265 &self.provider
266 }
267
268 pub fn model(&self) -> &str {
270 &self.model
271 }
272
273 pub fn update_provider(&mut self, provider: String, model: String) {
275 self.provider = provider.clone();
276 self.model = model;
277 self.context_limit = Self::get_context_limit(&provider);
278 }
279
280 pub fn build_context(&self) -> String {
285 let mut context = String::new();
286
287 context.push_str("Previous conversation:\n");
288 context.push_str("======================\n\n");
289
290 for msg in &self.messages {
291 match msg.role {
292 MessageRole::User => {
293 context.push_str(&format!("User: {}\n\n", msg.content));
294 }
295 MessageRole::AssistantThinking
296 | MessageRole::AssistantTools
297 | MessageRole::AssistantQueries
298 | MessageRole::AssistantExecuting
299 | MessageRole::AssistantAnswer => {
300 context.push_str(&format!("Assistant: {}\n\n", msg.content));
301 }
302 MessageRole::System => {
303 context.push_str(&format!("[System Note: {}]\n\n", msg.content));
304 }
305 }
306 }
307
308 context
309 }
310
311 pub fn prepare_compaction(&self, keep_recent: usize) -> (String, usize, usize) {
318 if self.messages.len() <= keep_recent {
319 return (String::new(), self.messages.len(), 0);
320 }
321
322 let split_point = self.messages.len() - keep_recent;
323 let old_messages = &self.messages[..split_point];
324
325 let mut summary_text = String::new();
326 let mut tokens_to_compact = 0;
327
328 for msg in old_messages {
329 tokens_to_compact += msg.tokens;
330
331 match msg.role {
332 MessageRole::User => {
333 summary_text.push_str(&format!("User: {}\n\n", msg.content));
334 }
335 MessageRole::AssistantThinking
336 | MessageRole::AssistantTools
337 | MessageRole::AssistantQueries
338 | MessageRole::AssistantExecuting
339 | MessageRole::AssistantAnswer => {
340 summary_text.push_str(&format!("Assistant: {}\n\n", msg.content));
341 }
342 MessageRole::System => {
343 summary_text.push_str(&format!("[System: {}]\n\n", msg.content));
344 }
345 }
346 }
347
348 (summary_text, old_messages.len(), tokens_to_compact)
349 }
350
351 pub fn apply_compaction(&mut self, remove_count: usize, summary: String) {
356 if remove_count >= self.messages.len() {
357 return;
359 }
360
361 let removed_tokens: usize = self.messages[..remove_count].iter().map(|m| m.tokens).sum();
363
364 self.messages.drain(..remove_count);
366
367 let summary_tokens = Self::estimate_tokens(&summary);
369 let summary_msg = Message {
370 role: MessageRole::System,
371 content: format!("Summary of previous conversation: {}", summary),
372 tokens: summary_tokens,
373 timestamp: Local::now(),
374 metadata: None,
375 };
376
377 self.messages.insert(0, summary_msg);
378
379 self.total_tokens = self.total_tokens - removed_tokens + summary_tokens;
381 }
382
383 fn estimate_tokens(text: &str) -> usize {
385 (text.len() + CHARS_PER_TOKEN - 1) / CHARS_PER_TOKEN
386 }
387
388 fn get_context_limit(provider: &str) -> usize {
390 match provider.to_lowercase().as_str() {
391 "openai" => OPENAI_CONTEXT_WINDOW,
392 "anthropic" => ANTHROPIC_CONTEXT_WINDOW,
393 _ => 32_000, }
395 }
396}
397
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_session() {
        let session = ChatSession::new("openai".to_string(), "gpt-4o-mini".to_string());
        assert_eq!(session.messages().len(), 0);
        assert_eq!(session.total_tokens(), 0);
        assert_eq!(session.context_limit(), OPENAI_CONTEXT_WINDOW);
    }

    #[test]
    fn test_unknown_provider_uses_default_limit() {
        let session = ChatSession::new("local".to_string(), "llama".to_string());
        assert_eq!(session.context_limit(), 32_000);
    }

    #[test]
    fn test_update_provider() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o".to_string());
        session.update_provider("Anthropic".to_string(), "claude".to_string());
        assert_eq!(session.provider(), "Anthropic");
        assert_eq!(session.model(), "claude");
        assert_eq!(session.context_limit(), ANTHROPIC_CONTEXT_WINDOW);
    }

    #[test]
    fn test_add_messages() {
        let mut session = ChatSession::new("anthropic".to_string(), "claude-3-5-haiku".to_string());

        session.add_user_message("Hello!".to_string());
        assert_eq!(session.messages().len(), 1);
        assert!(session.total_tokens() > 0);

        session.add_answer_message("Hi there!".to_string());
        assert_eq!(session.messages().len(), 2);
    }

    #[test]
    fn test_clear() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o".to_string());
        session.add_user_message("Test".to_string());
        session.add_answer_message("Response".to_string());

        assert_eq!(session.messages().len(), 2);

        session.clear();
        assert_eq!(session.messages().len(), 0);
        assert_eq!(session.total_tokens(), 0);
    }

    #[test]
    fn test_context_usage() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o-mini".to_string());
        assert_eq!(session.context_usage(), 0.0);

        // Sized to a quarter of the window: (window * CHARS_PER_TOKEN / 4) bytes.
        let large_text = "a".repeat(OPENAI_CONTEXT_WINDOW * CHARS_PER_TOKEN / 4);
        session.add_user_message(large_text);

        let usage = session.context_usage();
        assert!(usage > 0.2 && usage < 0.3);
    }

    #[test]
    fn test_build_context_format() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o".to_string());
        session.add_system_message("note".to_string());
        session.add_user_message("hi".to_string());
        session.add_answer_message("yo".to_string());

        assert_eq!(
            session.build_context(),
            "Previous conversation:\n======================\n\n\
             [System Note: note]\n\nUser: hi\n\nAssistant: yo\n\n"
        );
    }

    #[test]
    fn test_prepare_compaction() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o-mini".to_string());

        for i in 0..10 {
            session.add_user_message(format!("Message {}", i));
            session.add_answer_message(format!("Response {}", i));
        }

        let (summary_text, old_count, tokens) = session.prepare_compaction(4);

        // 20 messages, 4 kept -> 16 to compact.
        assert_eq!(old_count, 16);
        assert!(!summary_text.is_empty());
        assert!(summary_text.starts_with("User: Message 0\n\nAssistant: Response 0\n\n"));
        assert!(tokens > 0);
    }

    #[test]
    fn test_apply_compaction() {
        let mut session = ChatSession::new("anthropic".to_string(), "claude".to_string());

        for i in 0..6 {
            session.add_user_message(format!("Q{}", i));
            session.add_answer_message(format!("A{}", i));
        }

        let initial_count = session.messages().len();
        let initial_tokens = session.total_tokens();

        session.apply_compaction(8, "This is a summary".to_string());

        // 8 messages removed, one summary message prepended.
        assert_eq!(session.messages().len(), initial_count - 8 + 1);
        assert_eq!(session.messages()[0].role, MessageRole::System);

        assert!(session.total_tokens() < initial_tokens);
    }

    #[test]
    fn test_estimate_tokens() {
        // 13 bytes -> ceil(13 / 4) = 4.
        assert_eq!(ChatSession::estimate_tokens("Hello, world!"), 4);
        assert_eq!(ChatSession::estimate_tokens(""), 0);
        assert_eq!(ChatSession::estimate_tokens("abcd"), 1);
        assert_eq!(ChatSession::estimate_tokens("abcde"), 2);
    }
}