use chrono::{DateTime, Local};
use serde::{Deserialize, Serialize};

/// Approximate context window, in tokens, for OpenAI models.
const OPENAI_CONTEXT_WINDOW: usize = 128_000;
/// Approximate context window, in tokens, for Anthropic models.
const ANTHROPIC_CONTEXT_WINDOW: usize = 200_000;
/// Rough heuristic used for token estimation: about four characters per token.
const CHARS_PER_TOKEN: usize = 4;

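/// A single chat message together with its estimated token cost.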
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who produced this message.
    pub role: MessageRole,

    /// The message text.
    pub content: String,

    /// Estimated token count for `content`.
    pub tokens: usize,

    /// When the message was added to the session.
    pub timestamp: DateTime<Local>,

    /// Optional structured details attached by the assistant pipeline.
    pub metadata: Option<MessageMetadata>,
}

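/// Identifies who produced a [`Message`] and which assistant pipeline stage it belongs to.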
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageRole {
    /// A message typed by the user.
    User,

    /// The assistant's reasoning step.
    AssistantThinking,

    /// The assistant selecting tools to call.
    AssistantTools,

    /// The assistant generating search queries.
    AssistantQueries,

    /// The assistant executing queries and collecting results.
    AssistantExecuting,

    /// The assistant's final answer.
    AssistantAnswer,

    /// A system note, such as a compaction summary.
    System,
}

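/// Structured details attached to assistant messages: queries issued, tools called, result counts, and timing.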
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageMetadata {
    /// Search queries generated for this step.
    #[serde(default)]
    pub queries: Vec<String>,

    /// Tools invoked for this step.
    #[serde(default)]
    pub tool_calls: Vec<String>,

    /// Number of results returned by execution.
    #[serde(default)]
    pub results_count: usize,

    /// Wall-clock execution time, if measured.
    #[serde(default)]
    pub execution_time_ms: Option<u64>,

    /// Whether the assistant decided it needs more context.
    #[serde(default)]
    pub needs_context: bool,
}

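/// An in-memory chat session: the ordered message history, a running token
/// estimate, and the provider-dependent context budget used to decide when the
/// history should be compacted.
///
/// Typical flow (sketch; the import path and `summarize_with_model` are
/// placeholders, not part of this module):
///
/// ```ignore
/// let mut session = ChatSession::new("openai".to_string(), "gpt-4o-mini".to_string());
/// session.add_user_message("Hello!".to_string());
/// if session.should_compact() {
///     let (transcript, old_count, _tokens) = session.prepare_compaction(4);
///     let summary = summarize_with_model(&transcript); // hypothetical summarizer
///     session.apply_compaction(old_count, summary);
/// }
/// ```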
pub struct ChatSession {
    /// Ordered message history.
    messages: Vec<Message>,

    /// Provider name, e.g. "openai" or "anthropic".
    provider: String,

    /// Model identifier for the provider.
    model: String,

    /// Context-window budget for the current provider, in tokens.
    context_limit: usize,

    /// Running token estimate across all messages.
    total_tokens: usize,
}

impl ChatSession {
    /// Creates an empty session for the given provider and model.
    pub fn new(provider: String, model: String) -> Self {
        let context_limit = Self::get_context_limit(&provider);

        Self {
            messages: Vec::new(),
            provider,
            model,
            context_limit,
            total_tokens: 0,
        }
    }

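    /// Appends a user message, estimating and accumulating its token cost.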
    pub fn add_user_message(&mut self, content: String) {
        let tokens = Self::estimate_tokens(&content);
        let message = Message {
            role: MessageRole::User,
            content,
            tokens,
            timestamp: Local::now(),
            metadata: None,
        };

        self.total_tokens += tokens;
        self.messages.push(message);
    }

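    /// Appends an assistant message with the given pipeline role and optional metadata.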
    pub fn add_assistant_message(
        &mut self,
        content: String,
        role: MessageRole,
        metadata: Option<MessageMetadata>,
    ) {
        let tokens = Self::estimate_tokens(&content);
        let message = Message {
            role,
            content,
            tokens,
            timestamp: Local::now(),
            metadata,
        };

        self.total_tokens += tokens;
        self.messages.push(message);
    }

    /// Records the assistant's reasoning step.
    pub fn add_thinking_message(&mut self, reasoning: String, needs_context: bool) {
        let metadata = MessageMetadata {
            queries: Vec::new(),
            tool_calls: Vec::new(),
            results_count: 0,
            execution_time_ms: None,
            needs_context,
        };
        self.add_assistant_message(reasoning, MessageRole::AssistantThinking, Some(metadata));
    }

    /// Records the tool calls the assistant selected.
    pub fn add_tools_message(&mut self, content: String, tool_calls: Vec<String>) {
        let metadata = MessageMetadata {
            queries: Vec::new(),
            tool_calls,
            results_count: 0,
            execution_time_ms: None,
            needs_context: false,
        };
        self.add_assistant_message(content, MessageRole::AssistantTools, Some(metadata));
    }

    /// Records the search queries the assistant generated.
    pub fn add_queries_message(&mut self, queries: Vec<String>) {
        let content = format!("Generated {} queries", queries.len());
        let metadata = MessageMetadata {
            queries,
            tool_calls: Vec::new(),
            results_count: 0,
            execution_time_ms: None,
            needs_context: false,
        };
        self.add_assistant_message(content, MessageRole::AssistantQueries, Some(metadata));
    }

    /// Records the result count and timing of an execution step.
    pub fn add_execution_message(&mut self, results_count: usize, execution_time_ms: u64) {
        let content = format!("Found {} results", results_count);
        let metadata = MessageMetadata {
            queries: Vec::new(),
            tool_calls: Vec::new(),
            results_count,
            execution_time_ms: Some(execution_time_ms),
            needs_context: false,
        };
        self.add_assistant_message(content, MessageRole::AssistantExecuting, Some(metadata));
    }

    /// Records the assistant's final answer.
    pub fn add_answer_message(&mut self, answer: String) {
        self.add_assistant_message(answer, MessageRole::AssistantAnswer, None);
    }

    /// Appends a system note, estimating and accumulating its token cost.
    pub fn add_system_message(&mut self, content: String) {
        let tokens = Self::estimate_tokens(&content);
        let message = Message {
            role: MessageRole::System,
            content,
            tokens,
            timestamp: Local::now(),
            metadata: None,
        };

        self.total_tokens += tokens;
        self.messages.push(message);
    }

    /// Removes all messages and resets the token count.
    pub fn clear(&mut self) {
        self.messages.clear();
        self.total_tokens = 0;
    }

    pub fn messages(&self) -> &[Message] {
        &self.messages
    }

    pub fn total_tokens(&self) -> usize {
        self.total_tokens
    }

    pub fn context_limit(&self) -> usize {
        self.context_limit
    }

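    /// Fraction of the context window currently in use; may exceed 1.0 once the
    /// estimate overruns the budget.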
    pub fn context_usage(&self) -> f32 {
        if self.context_limit == 0 {
            return 0.0;
        }
        (self.total_tokens as f32) / (self.context_limit as f32)
    }

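    /// True once estimated usage exceeds 80% of the context window.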
    pub fn is_near_limit(&self) -> bool {
        self.context_usage() > 0.8
    }

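    /// True once estimated usage exceeds 90% of the context window, signalling
    /// that the history should be compacted.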
    pub fn should_compact(&self) -> bool {
        self.context_usage() > 0.9
    }

    pub fn provider(&self) -> &str {
        &self.provider
    }

    pub fn model(&self) -> &str {
        &self.model
    }

    /// Switches provider and model, refreshing the context budget accordingly.
    pub fn update_provider(&mut self, provider: String, model: String) {
        self.context_limit = Self::get_context_limit(&provider);
        self.provider = provider;
        self.model = model;
    }

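    /// Renders the full message history as a plain-text transcript suitable for
    /// inclusion in a prompt.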
    pub fn build_context(&self) -> String {
        let mut context = String::new();

        context.push_str("Previous conversation:\n");
        context.push_str("======================\n\n");

        for msg in &self.messages {
            match msg.role {
                MessageRole::User => {
                    context.push_str(&format!("User: {}\n\n", msg.content));
                }
                MessageRole::AssistantThinking
                | MessageRole::AssistantTools
                | MessageRole::AssistantQueries
                | MessageRole::AssistantExecuting
                | MessageRole::AssistantAnswer => {
                    context.push_str(&format!("Assistant: {}\n\n", msg.content));
                }
                MessageRole::System => {
                    context.push_str(&format!("[System Note: {}]\n\n", msg.content));
                }
            }
        }

        context
    }

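    /// Gathers everything except the `keep_recent` newest messages into a
    /// transcript to be summarized. Returns the transcript, the number of
    /// messages it covers, and their combined token estimate. When nothing is
    /// old enough to compact, the transcript is empty and the count equals the
    /// full message count, which makes a paired `apply_compaction` call a no-op.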
    pub fn prepare_compaction(&self, keep_recent: usize) -> (String, usize, usize) {
        if self.messages.len() <= keep_recent {
            return (String::new(), self.messages.len(), 0);
        }

        let split_point = self.messages.len() - keep_recent;
        let old_messages = &self.messages[..split_point];

        let mut summary_text = String::new();
        let mut tokens_to_compact = 0;

        for msg in old_messages {
            tokens_to_compact += msg.tokens;

            match msg.role {
                MessageRole::User => {
                    summary_text.push_str(&format!("User: {}\n\n", msg.content));
                }
                MessageRole::AssistantThinking
                | MessageRole::AssistantTools
                | MessageRole::AssistantQueries
                | MessageRole::AssistantExecuting
                | MessageRole::AssistantAnswer => {
                    summary_text.push_str(&format!("Assistant: {}\n\n", msg.content));
                }
                MessageRole::System => {
                    summary_text.push_str(&format!("[System: {}]\n\n", msg.content));
                }
            }
        }

        (summary_text, old_messages.len(), tokens_to_compact)
    }

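    /// Drops the oldest `remove_count` messages, inserts a single system
    /// message carrying their summary, and adjusts the running token total.
    /// Does nothing if `remove_count` would remove the entire history.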
    pub fn apply_compaction(&mut self, remove_count: usize, summary: String) {
        if remove_count >= self.messages.len() {
            return;
        }

        let removed_tokens: usize = self.messages[..remove_count]
            .iter()
            .map(|m| m.tokens)
            .sum();

        self.messages.drain(..remove_count);

        let summary_tokens = Self::estimate_tokens(&summary);
        let summary_msg = Message {
            role: MessageRole::System,
            content: format!("Summary of previous conversation: {}", summary),
            tokens: summary_tokens,
            timestamp: Local::now(),
            metadata: None,
        };

        self.messages.insert(0, summary_msg);

        self.total_tokens = self.total_tokens - removed_tokens + summary_tokens;
    }

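    /// Rough token estimate: character count divided by `CHARS_PER_TOKEN`, rounded up.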
    fn estimate_tokens(text: &str) -> usize {
        (text.len() + CHARS_PER_TOKEN - 1) / CHARS_PER_TOKEN
    }

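    /// Maps a provider name to its context-window size in tokens, falling back
    /// to a conservative 32k for unknown providers.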
    fn get_context_limit(provider: &str) -> usize {
        match provider.to_lowercase().as_str() {
            "openai" => OPENAI_CONTEXT_WINDOW,
            "anthropic" => ANTHROPIC_CONTEXT_WINDOW,
            _ => 32_000,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_session() {
        let session = ChatSession::new("openai".to_string(), "gpt-4o-mini".to_string());
        assert_eq!(session.messages().len(), 0);
        assert_eq!(session.total_tokens(), 0);
        assert_eq!(session.context_limit(), OPENAI_CONTEXT_WINDOW);
    }

    #[test]
    fn test_add_messages() {
        let mut session = ChatSession::new("anthropic".to_string(), "claude-3-5-haiku".to_string());

        session.add_user_message("Hello!".to_string());
        assert_eq!(session.messages().len(), 1);
        assert!(session.total_tokens() > 0);

        session.add_answer_message("Hi there!".to_string());
        assert_eq!(session.messages().len(), 2);
    }

    #[test]
    fn test_clear() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o".to_string());
        session.add_user_message("Test".to_string());
        session.add_answer_message("Response".to_string());

        assert_eq!(session.messages().len(), 2);

        session.clear();
        assert_eq!(session.messages().len(), 0);
        assert_eq!(session.total_tokens(), 0);
    }

    #[test]
    fn test_context_usage() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o-mini".to_string());
        assert_eq!(session.context_usage(), 0.0);

        // Fill roughly a quarter of the context window.
        let large_text = "a".repeat(OPENAI_CONTEXT_WINDOW * CHARS_PER_TOKEN / 4);
        session.add_user_message(large_text);

        let usage = session.context_usage();
        assert!(usage > 0.2 && usage < 0.3);
    }

    #[test]
    fn test_prepare_compaction() {
        let mut session = ChatSession::new("openai".to_string(), "gpt-4o-mini".to_string());

        for i in 0..10 {
            session.add_user_message(format!("Message {}", i));
            session.add_answer_message(format!("Response {}", i));
        }

        // 20 messages total; keeping the 4 newest leaves 16 to compact.
        let (summary_text, old_count, tokens) = session.prepare_compaction(4);

        assert_eq!(old_count, 16);
        assert!(!summary_text.is_empty());
        assert!(tokens > 0);
    }

    #[test]
    fn test_apply_compaction() {
        let mut session = ChatSession::new("anthropic".to_string(), "claude".to_string());

        for i in 0..6 {
            session.add_user_message(format!("Q{}", i));
            session.add_answer_message(format!("A{}", i));
        }

        let initial_count = session.messages().len();
        let initial_tokens = session.total_tokens();

        session.apply_compaction(8, "This is a summary".to_string());

        // 12 messages minus the 8 removed, plus the inserted summary message.
        assert_eq!(session.messages().len(), initial_count - 8 + 1);
        assert_eq!(session.messages()[0].role, MessageRole::System);

        assert!(session.total_tokens() < initial_tokens);
    }

    #[test]
    fn test_estimate_tokens() {
        let text = "Hello, world!";
        let tokens = ChatSession::estimate_tokens(text);
        assert_eq!(tokens, (text.len() + CHARS_PER_TOKEN - 1) / CHARS_PER_TOKEN);
        assert_eq!(tokens, 4);
    }
}