Skip to main content

ai_session/context/
mod.rs

1//! Session context management for AI agents
2//!
3//! This module provides intelligent context management for AI agents, including conversation
4//! history, task context, and workspace state. The context system is designed to maximize
5//! AI performance while minimizing token usage through intelligent compression and summarization.
6//!
7//! # Key Features
8//!
9//! - **Efficient History Management**: Real zstd compression for optimized context handling
10//! - **Task Context**: Structured task and goal tracking for AI agents
11//! - **Agent State**: Persistent agent memory and decision tracking
12//! - **Workspace State**: File system and project state awareness
13//! - **Smart Compression**: Zstd-based compression with message summarization
14//!
15//! # Examples
16//!
17//! ## Basic Context Management
18//!
19//! ```rust
20//! use ai_session::context::{SessionContext, Message, MessageRole};
21//! use ai_session::SessionId;
22//! use chrono::Utc;
23//!
24//! let session_id = SessionId::new();
25//! let mut context = SessionContext::new(session_id);
26//!
27//! // Add conversation messages
28//! let user_message = Message {
29//!     role: MessageRole::User,
30//!     content: "Help me implement a REST API".to_string(),
31//!     timestamp: Utc::now(),
32//!     token_count: 7,
33//! };
34//! context.add_message(user_message);
35//!
36//! let assistant_message = Message {
37//!     role: MessageRole::Assistant,
38//!     content: "I'll help you create a REST API. Let's start with the basic structure...".to_string(),
39//!     timestamp: Utc::now(),
40//!     token_count: 18,
41//! };
42//! context.add_message(assistant_message);
43//!
44//! // Check context stats
45//! println!("Messages: {}", context.get_message_count());
46//! println!("Total tokens: {}", context.get_total_tokens());
47//! ```
48//!
49//! ## Context Compression
50//!
51//! ```rust
52//! use ai_session::context::{SessionContext, Message, MessageRole};
53//! use ai_session::SessionId;
54//!
55//! # tokio_test::block_on(async {
56//! let session_id = SessionId::new();
57//! let mut context = SessionContext::new(session_id);
58//!
59//! // Fill context with many messages...
60//! for i in 0..100 {
61//!     let message = Message {
62//!         role: MessageRole::User,
63//!         content: format!("Message {}", i),
64//!         timestamp: chrono::Utc::now(),
65//!         token_count: 5,
66//!     };
67//!     context.add_message(message);
68//! }
69//!
70//! println!("Before compression: {} tokens", context.get_total_tokens());
71//!
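//! // `compress_context` only compresses once usage exceeds
//! // `max_tokens * compression_threshold` (80,000 tokens with the default config),
//! // so with this small history the call below returns `false`.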
73//! if context.get_total_tokens() > 400 {
74//!     let compressed = context.compress_context().await;
75//!     if compressed {
76//!         println!("After compression: {} tokens", context.get_total_tokens());
77//!     }
78//! }
79//! # });
80//! ```
81
82use crate::core::SessionId;
83use chrono::{DateTime, Utc};
84use serde::{Deserialize, Serialize};
85use std::collections::HashMap;
86use std::io::Read as IoRead;
87
/// Session configuration
///
/// Limits that control how large a session's conversation history may grow and how it is
/// compressed. Because of `#[serde(default)]`, fields missing from serialized input are
/// filled from the `Default` implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct SessionConfig {
    /// Maximum token limit for the session
    pub max_tokens: usize,
    /// Number of recent messages to keep uncompressed
    pub keep_recent_messages: usize,
    /// Compression level for zstd (1-22, default 3)
    pub compression_level: i32,
    /// Threshold ratio to trigger compression (e.g., 0.8 means compress when 80% full)
    pub compression_threshold: f32,
}
101
impl Default for SessionConfig {
    /// Returns the stock configuration: a 100,000-token window, the 20 newest messages kept
    /// uncompressed, zstd level 3, and a compression threshold of 80% of the window.
    fn default() -> Self {
        Self {
            max_tokens: 100_000,        // Default context window
            keep_recent_messages: 20,   // Keep last 20 messages uncompressed
            compression_level: 3,       // Balanced compression
            compression_threshold: 0.8, // Compress at 80% capacity
        }
    }
}
112
/// Session context containing AI-relevant state
///
/// Bundles the conversation history with task, agent, and workspace state for one session.
/// All fields are public, so callers can read and mutate them directly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionContext {
    /// Session ID
    pub session_id: SessionId,
    /// Conversation history (token-efficient)
    pub conversation_history: TokenEfficientHistory,
    /// Current task context
    pub task_context: TaskContext,
    /// Agent state
    pub agent_state: AgentState,
    /// Workspace state
    pub workspace_state: WorkspaceState,
    /// Context metadata (free-form JSON values keyed by name)
    pub metadata: HashMap<String, serde_json::Value>,
    /// Session configuration
    pub config: SessionConfig,
}
131
132impl SessionContext {
133    /// Create a new session context
134    pub fn new(session_id: SessionId) -> Self {
135        let config = SessionConfig::default();
136        let mut conversation_history = TokenEfficientHistory::new();
137        conversation_history.max_tokens = config.max_tokens;
138        conversation_history.keep_recent = config.keep_recent_messages;
139        conversation_history.compression_level = config.compression_level;
140
141        Self {
142            session_id,
143            conversation_history,
144            task_context: TaskContext::default(),
145            agent_state: AgentState::default(),
146            workspace_state: WorkspaceState::default(),
147            metadata: HashMap::new(),
148            config,
149        }
150    }
151
152    /// Add a message to the conversation history (takes Message struct)
153    pub fn add_message(&mut self, message: Message) {
154        self.conversation_history.add_message_struct(message);
155    }
156
157    /// Add a message to the conversation history (legacy method)
158    pub fn add_message_raw(&mut self, role: MessageRole, content: String) {
159        self.conversation_history.add_message(role, content);
160    }
161
162    /// Get the total number of messages in the conversation history
163    pub fn get_message_count(&self) -> usize {
164        self.conversation_history.messages.len()
165    }
166
167    /// Get the total estimated token count
168    pub fn get_total_tokens(&self) -> usize {
169        self.conversation_history.current_tokens
170    }
171
172    /// Get the most recent n messages
173    pub fn get_recent_messages(&self, n: usize) -> Vec<&Message> {
174        let message_count = self.conversation_history.messages.len();
175        if n >= message_count {
176            self.conversation_history.messages.iter().collect()
177        } else {
178            self.conversation_history
179                .messages
180                .iter()
181                .skip(message_count - n)
182                .collect()
183        }
184    }
185
186    /// Compress the context if needed, returns true if compression occurred
187    pub async fn compress_context(&mut self) -> bool {
188        // Check if compression is needed
189        let threshold = (self.conversation_history.max_tokens as f32
190            * self.config.compression_threshold) as usize;
191
192        if self.conversation_history.current_tokens > threshold {
193            self.conversation_history.compress_old_messages();
194            true
195        } else {
196            false
197        }
198    }
199
200    /// Update the task context
201    pub fn update_task(&mut self, task: TaskContext) {
202        self.task_context = task;
203    }
204
205    /// Get a summary of the context
206    pub fn summarize(&self) -> ContextSummary {
207        ContextSummary {
208            session_id: self.session_id.clone(),
209            message_count: self.conversation_history.messages.len(),
210            current_task: self.task_context.name.clone(),
211            agent_state: self.agent_state.state.clone(),
212            workspace_files: self.workspace_state.tracked_files.len(),
213        }
214    }
215
216    /// Get compression statistics
217    pub fn get_compression_stats(&self) -> CompressionStats {
218        self.conversation_history.get_compression_stats()
219    }
220}
221
/// Token-efficient conversation history with real zstd compression
///
/// Recent messages are held uncompressed in `messages`; older ones are moved into
/// `compressed_history`. Every field has a serde default, so partially populated serialized
/// data still deserializes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenEfficientHistory {
    /// Active messages in the conversation (uncompressed, recent)
    #[serde(default)]
    pub messages: Vec<Message>,
    /// Compressed older messages (zstd compressed)
    #[serde(default)]
    pub compressed_history: Option<CompressedHistory>,
    /// Maximum token limit; exceeding it when adding a message triggers compression
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// Current token count (approximate)
    #[serde(default)]
    pub current_tokens: usize,
    /// Number of recent messages to keep uncompressed
    #[serde(default = "default_keep_recent")]
    pub keep_recent: usize,
    /// Compression level for zstd
    #[serde(default = "default_compression_level")]
    pub compression_level: i32,
    /// Total messages ever added (including compressed)
    #[serde(default)]
    pub total_messages_added: usize,
    /// Total tokens saved through compression
    #[serde(default)]
    pub tokens_saved_by_compression: usize,
}
250
/// Serde default for `TokenEfficientHistory::max_tokens` (100,000 tokens).
fn default_max_tokens() -> usize {
    100_000
}
254
/// Serde default for `TokenEfficientHistory::keep_recent` (20 messages).
fn default_keep_recent() -> usize {
    20
}
258
/// Serde default for `TokenEfficientHistory::compression_level` (zstd level 3).
fn default_compression_level() -> i32 {
    3
}
262
impl Default for TokenEfficientHistory {
    /// Equivalent to [`TokenEfficientHistory::new`].
    fn default() -> Self {
        Self::new()
    }
}
268
269impl TokenEfficientHistory {
270    /// Create a new history
271    pub fn new() -> Self {
272        Self {
273            messages: Vec::new(),
274            compressed_history: None,
275            max_tokens: 100_000,
276            current_tokens: 0,
277            keep_recent: 20,
278            compression_level: 3,
279            total_messages_added: 0,
280            tokens_saved_by_compression: 0,
281        }
282    }
283
284    /// Add a message to the history
285    pub fn add_message(&mut self, role: MessageRole, content: String) {
286        let token_estimate = estimate_tokens(&content);
287        let message = Message {
288            role,
289            content,
290            timestamp: Utc::now(),
291            token_count: token_estimate,
292        };
293
294        self.messages.push(message);
295        self.current_tokens += token_estimate;
296        self.total_messages_added += 1;
297
298        // Compress if needed
299        if self.current_tokens > self.max_tokens {
300            self.compress_old_messages();
301        }
302    }
303
304    /// Add a message struct directly to the history
305    pub fn add_message_struct(&mut self, message: Message) {
306        self.current_tokens += message.token_count;
307        self.messages.push(message);
308        self.total_messages_added += 1;
309
310        // Compress if needed
311        if self.current_tokens > self.max_tokens {
312            self.compress_old_messages();
313        }
314    }
315
316    /// Compress old messages using zstd
317    pub fn compress_old_messages(&mut self) {
318        // Keep only the most recent messages uncompressed
319        if self.messages.len() <= self.keep_recent {
320            return;
321        }
322
323        // Split messages: older ones to compress, recent ones to keep
324        let split_point = self.messages.len() - self.keep_recent;
325        let messages_to_compress: Vec<Message> = self.messages.drain(..split_point).collect();
326
327        if messages_to_compress.is_empty() {
328            return;
329        }
330
331        // Calculate tokens being compressed
332        let tokens_to_compress: usize = messages_to_compress.iter().map(|m| m.token_count).sum();
333
334        // Serialize messages to JSON
335        let json_data = match serde_json::to_vec(&messages_to_compress) {
336            Ok(data) => data,
337            Err(e) => {
338                tracing::warn!("Failed to serialize messages for compression: {}", e);
339                // Put messages back if serialization fails
340                let mut restored = messages_to_compress;
341                restored.append(&mut self.messages);
342                self.messages = restored;
343                return;
344            }
345        };
346
347        // Compress using zstd
348        let compressed_data = match zstd::encode_all(json_data.as_slice(), self.compression_level) {
349            Ok(data) => data,
350            Err(e) => {
351                tracing::warn!("Failed to compress messages: {}", e);
352                // Put messages back if compression fails
353                let mut restored = messages_to_compress;
354                restored.append(&mut self.messages);
355                self.messages = restored;
356                return;
357            }
358        };
359
360        // Calculate compression ratio
361        let original_size = json_data.len();
362        let compressed_size = compressed_data.len();
363        let compression_ratio = if original_size > 0 {
364            1.0 - (compressed_size as f64 / original_size as f64)
365        } else {
366            0.0
367        };
368
369        // Create summary of compressed content
370        let summary = create_compression_summary(&messages_to_compress);
371
372        // Merge with existing compressed history if any
373        let new_compressed = if let Some(existing) = self.compressed_history.take() {
374            CompressedHistory {
375                compressed_data: merge_compressed_data(
376                    &existing.compressed_data,
377                    &compressed_data,
378                    self.compression_level,
379                ),
380                summary: format!("{}\n---\n{}", existing.summary, summary),
381                message_count: existing.message_count + messages_to_compress.len(),
382                original_tokens: existing.original_tokens + tokens_to_compress,
383                compressed_bytes: existing.compressed_bytes + compressed_size,
384                compression_ratio: (existing.compression_ratio + compression_ratio) / 2.0,
385            }
386        } else {
387            CompressedHistory {
388                compressed_data,
389                summary,
390                message_count: messages_to_compress.len(),
391                original_tokens: tokens_to_compress,
392                compressed_bytes: compressed_size,
393                compression_ratio,
394            }
395        };
396
397        // Update state
398        self.compressed_history = Some(new_compressed);
399        self.current_tokens -= tokens_to_compress;
400        self.tokens_saved_by_compression += tokens_to_compress;
401
402        // Add a small token cost for the summary (accessible without decompression)
403        let summary_tokens = estimate_tokens(
404            self.compressed_history
405                .as_ref()
406                .map(|h| h.summary.as_str())
407                .unwrap_or(""),
408        );
409        self.current_tokens += summary_tokens.min(100); // Cap summary token cost
410
411        tracing::info!(
412            "Compressed {} messages ({} tokens) with {:.1}% ratio",
413            messages_to_compress.len(),
414            tokens_to_compress,
415            compression_ratio * 100.0
416        );
417    }
418
419    /// Decompress and retrieve all historical messages
420    pub fn decompress_history(&self) -> Option<Vec<Message>> {
421        let compressed = self.compressed_history.as_ref()?;
422
423        // Decompress using zstd
424        let mut decompressed = Vec::new();
425        let mut decoder = match zstd::Decoder::new(compressed.compressed_data.as_slice()) {
426            Ok(d) => d,
427            Err(e) => {
428                tracing::error!("Failed to create zstd decoder: {}", e);
429                return None;
430            }
431        };
432
433        if let Err(e) = decoder.read_to_end(&mut decompressed) {
434            tracing::error!("Failed to decompress history: {}", e);
435            return None;
436        }
437
438        // Deserialize messages
439        match serde_json::from_slice(&decompressed) {
440            Ok(messages) => Some(messages),
441            Err(e) => {
442                tracing::error!("Failed to deserialize decompressed messages: {}", e);
443                None
444            }
445        }
446    }
447
448    /// Get all messages including decompressed history
449    pub fn get_all_messages(&self) -> Vec<Message> {
450        let mut all_messages = self.decompress_history().unwrap_or_default();
451        all_messages.extend(self.messages.clone());
452        all_messages
453    }
454
455    /// Get messages within token limit
456    pub fn get_messages_within_limit(&self, token_limit: usize) -> Vec<&Message> {
457        let mut messages = Vec::new();
458        let mut tokens = 0;
459
460        // Start from most recent messages
461        for message in self.messages.iter().rev() {
462            if tokens + message.token_count <= token_limit {
463                messages.push(message);
464                tokens += message.token_count;
465            } else {
466                break;
467            }
468        }
469
470        messages.reverse();
471        messages
472    }
473
474    /// Get compression statistics
475    pub fn get_compression_stats(&self) -> CompressionStats {
476        let compressed_stats = self.compressed_history.as_ref().map(|h| {
477            (
478                h.message_count,
479                h.original_tokens,
480                h.compressed_bytes,
481                h.compression_ratio,
482            )
483        });
484
485        CompressionStats {
486            total_messages_added: self.total_messages_added,
487            active_messages: self.messages.len(),
488            compressed_messages: compressed_stats.map(|(c, _, _, _)| c).unwrap_or(0),
489            active_tokens: self.current_tokens,
490            tokens_saved: self.tokens_saved_by_compression,
491            compressed_bytes: compressed_stats.map(|(_, _, b, _)| b).unwrap_or(0),
492            compression_ratio: compressed_stats.map(|(_, _, _, r)| r).unwrap_or(0.0),
493        }
494    }
495}
496
/// Approximate the number of tokens `content` will occupy.
///
/// Two heuristics are computed and the larger one wins:
/// - word-based: `words * 1.3` (truncated) plus one token per character that is neither
///   alphanumeric nor whitespace, plus a fixed overhead of 2;
/// - length-based: one token per four characters, which dominates for long unbroken text.
///
/// An empty string is charged exactly one token.
fn estimate_tokens(content: &str) -> usize {
    if content.is_empty() {
        return 1;
    }

    // Single pass over the characters: total length and symbol count together.
    let (total_chars, symbol_chars) =
        content
            .chars()
            .fold((0usize, 0usize), |(total, symbols), ch| {
                let is_symbol = !(ch.is_alphanumeric() || ch.is_whitespace());
                (total + 1, symbols + usize::from(is_symbol))
            });
    let words = content.split_whitespace().count();

    let word_based = (words as f64 * 1.3) as usize + symbol_chars + 2;
    let length_based = total_chars / 4;

    // `word_based` is always at least 2, so the result can never drop below 1.
    std::cmp::max(word_based, length_based)
}
527
528/// Create a summary of compressed messages
529fn create_compression_summary(messages: &[Message]) -> String {
530    if messages.is_empty() {
531        return String::new();
532    }
533
534    let first = messages.first().unwrap();
535    let last = messages.last().unwrap();
536
537    let user_count = messages
538        .iter()
539        .filter(|m| m.role == MessageRole::User)
540        .count();
541    let assistant_count = messages
542        .iter()
543        .filter(|m| m.role == MessageRole::Assistant)
544        .count();
545
546    format!(
547        "[Compressed: {} messages ({} user, {} assistant) from {} to {}]",
548        messages.len(),
549        user_count,
550        assistant_count,
551        first.timestamp.format("%H:%M:%S"),
552        last.timestamp.format("%H:%M:%S")
553    )
554}
555
556/// Merge two compressed data blocks
557fn merge_compressed_data(existing: &[u8], new: &[u8], level: i32) -> Vec<u8> {
558    // Decompress both, merge, recompress
559    // This is less efficient but maintains a single compressed block
560
561    let mut existing_decompressed = Vec::new();
562    if let Ok(mut decoder) = zstd::Decoder::new(existing) {
563        let _ = decoder.read_to_end(&mut existing_decompressed);
564    }
565
566    let mut new_decompressed = Vec::new();
567    if let Ok(mut decoder) = zstd::Decoder::new(new) {
568        let _ = decoder.read_to_end(&mut new_decompressed);
569    }
570
571    // Parse both as message arrays and merge
572    let existing_messages: Vec<Message> =
573        serde_json::from_slice(&existing_decompressed).unwrap_or_default();
574    let new_messages: Vec<Message> = serde_json::from_slice(&new_decompressed).unwrap_or_default();
575
576    let mut merged = existing_messages;
577    merged.extend(new_messages);
578
579    // Recompress
580    let json_data = serde_json::to_vec(&merged).unwrap_or_default();
581    zstd::encode_all(json_data.as_slice(), level).unwrap_or_else(|_| new.to_vec())
582}
583
/// A message in the conversation
///
/// Serialized as JSON when older messages are moved into compressed storage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Role of the message sender
    pub role: MessageRole,
    /// Message content
    pub content: String,
    /// Timestamp (UTC)
    pub timestamp: DateTime<Utc>,
    /// Estimated token count; added to the history's running total when the message is stored
    pub token_count: usize,
}
596
/// Message role
///
/// Identifies who produced a [`Message`]. Compression summaries report `User` and
/// `Assistant` counts individually; the other roles only appear in the overall total.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageRole {
    System,
    User,
    Assistant,
    Tool,
}
605
/// Compressed history with zstd compression
///
/// Holds older messages as a zstd-compressed JSON array plus bookkeeping about what was
/// compressed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressedHistory {
    /// Zstd compressed message data (serialized as a base64 string)
    #[serde(with = "base64_serde")]
    pub compressed_data: Vec<u8>,
    /// Human-readable summary of compressed content
    pub summary: String,
    /// Number of messages compressed
    pub message_count: usize,
    /// Original token count before compression
    pub original_tokens: usize,
    /// Size of compressed data in bytes
    pub compressed_bytes: usize,
    /// Compression ratio achieved, computed as `1 - compressed_size / original_size`
    /// (normally 0.0 to 1.0; can be negative if the compressed form is larger)
    pub compression_ratio: f64,
}
623
624/// Base64 serialization for binary data
625mod base64_serde {
626    use base64::{Engine, engine::general_purpose::STANDARD};
627    use serde::{Deserialize, Deserializer, Serialize, Serializer};
628
629    pub fn serialize<S>(data: &Vec<u8>, serializer: S) -> Result<S::Ok, S::Error>
630    where
631        S: Serializer,
632    {
633        let encoded = STANDARD.encode(data);
634        encoded.serialize(serializer)
635    }
636
637    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
638    where
639        D: Deserializer<'de>,
640    {
641        let encoded = String::deserialize(deserializer)?;
642        STANDARD.decode(&encoded).map_err(serde::de::Error::custom)
643    }
644}
645
/// Compression statistics
///
/// A snapshot produced by `TokenEfficientHistory::get_compression_stats`. The
/// compressed-side figures are zero when no compressed history exists.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionStats {
    /// Total messages ever added
    pub total_messages_added: usize,
    /// Currently active (uncompressed) messages
    pub active_messages: usize,
    /// Messages in compressed storage
    pub compressed_messages: usize,
    /// Current active token count
    pub active_tokens: usize,
    /// Tokens saved by compression
    pub tokens_saved: usize,
    /// Size of compressed data in bytes
    pub compressed_bytes: usize,
    /// Average compression ratio
    pub compression_ratio: f64,
}
664
/// Current task context
///
/// Every descriptive field is optional; the derived `Default` leaves them all `None` with
/// empty metadata.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TaskContext {
    /// Task ID
    pub id: Option<String>,
    /// Task name (surfaced as `current_task` in `ContextSummary`)
    pub name: Option<String>,
    /// Task description
    pub description: Option<String>,
    /// Task type
    pub task_type: Option<String>,
    /// Priority
    pub priority: Option<TaskPriority>,
    /// Started at
    pub started_at: Option<DateTime<Utc>>,
    /// Additional context (free-form JSON values keyed by name)
    pub metadata: HashMap<String, serde_json::Value>,
}
683
/// Task priority
///
/// Variants are listed from least to most urgent. No ordering traits are derived, so
/// priorities can be compared for equality only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskPriority {
    Low,
    Medium,
    High,
    Critical,
}
692
/// Agent state
///
/// The derived `Default` yields an empty state string, no capabilities, no metrics, and no
/// recorded error.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AgentState {
    /// Current state (free-form string; copied into `ContextSummary::agent_state`)
    pub state: String,
    /// Agent capabilities
    pub capabilities: Vec<String>,
    /// Performance metrics, keyed by metric name
    pub metrics: HashMap<String, f64>,
    /// Last error (if any)
    pub last_error: Option<String>,
}
705
/// Workspace state
///
/// The derived `Default` yields an empty working directory, no tracked files, and no
/// recorded changes.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WorkspaceState {
    /// Current working directory
    pub working_directory: String,
    /// Tracked files (the number of entries is reported as `ContextSummary::workspace_files`)
    pub tracked_files: HashMap<String, FileState>,
    /// Recent changes
    pub recent_changes: Vec<FileChange>,
}
716
/// File state
///
/// Snapshot of a single file tracked in `WorkspaceState::tracked_files`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileState {
    /// File path
    pub path: String,
    /// Last modified (UTC)
    pub last_modified: DateTime<Utc>,
    /// File hash (the hashing algorithm is not specified in this module)
    pub hash: String,
    /// Is modified
    pub is_modified: bool,
}
729
/// File change event
///
/// One entry of `WorkspaceState::recent_changes`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileChange {
    /// File path
    pub path: String,
    /// Change type
    pub change_type: FileChangeType,
    /// Timestamp (UTC)
    pub timestamp: DateTime<Utc>,
}
740
/// Type of file change
///
/// Classifies a [`FileChange`] event.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FileChangeType {
    Created,
    Modified,
    Deleted,
    Renamed,
}
749
/// Context summary
///
/// Lightweight overview of a session, produced by `SessionContext::summarize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextSummary {
    /// Session ID
    pub session_id: SessionId,
    /// Number of active (uncompressed) messages
    pub message_count: usize,
    /// Current task name
    pub current_task: Option<String>,
    /// Agent state
    pub agent_state: String,
    /// Number of tracked files
    pub workspace_files: usize,
}
764
#[cfg(test)]
mod tests {
    use super::*;

    /// Staying under the limit keeps every message active; crossing it moves everything
    /// except the newest `keep_recent` messages into compressed storage.
    #[test]
    fn test_token_efficient_history() {
        let mut history = TokenEfficientHistory::new();
        history.max_tokens = 200; // Moderate limit for testing
        history.keep_recent = 3; // Keep 3 recent messages

        for i in 0..10 {
            history.add_message(MessageRole::User, format!("Test message number {}", i));
        }

        // Ten short messages stay well below 200 tokens, so nothing is compressed yet and
        // the running total equals the sum of the per-message estimates.
        assert_eq!(history.messages.len(), 10);
        assert!(history.compressed_history.is_none());
        let summed: usize = history.messages.iter().map(|m| m.token_count).sum();
        assert_eq!(history.current_tokens, summed);
        assert!(history.current_tokens <= history.max_tokens);

        // Tighten the limit so the next message pushes the history over it.
        let tokens_before = history.current_tokens;
        history.max_tokens = tokens_before / 2;
        history.add_message(MessageRole::User, "Test message number 10".to_string());

        assert_eq!(history.messages.len(), 3);
        let compressed = history
            .compressed_history
            .as_ref()
            .expect("exceeding max_tokens must compress older messages");
        assert_eq!(compressed.message_count, 8);
        assert_eq!(history.total_messages_added, 11);
        assert_eq!(history.get_all_messages().len(), 11);
        assert!(
            history.current_tokens < tokens_before,
            "current_tokens {} did not drop below {} after compression",
            history.current_tokens,
            tokens_before
        );
    }

    /// Compressed data must round-trip through zstd and JSON without losing messages.
    #[test]
    fn test_zstd_compression() {
        let mut history = TokenEfficientHistory::new();
        history.max_tokens = 50;
        history.keep_recent = 2;

        // Add many messages to trigger compression
        for i in 0..20 {
            history.add_message(MessageRole::User, format!("Test message number {}", i));
        }

        let expected_count = history
            .compressed_history
            .as_ref()
            .expect("history should have been compressed")
            .message_count;
        assert!(expected_count > 0);

        // Verify we can decompress everything that was recorded as compressed
        let messages = history
            .decompress_history()
            .expect("compressed history should decompress");
        assert_eq!(messages.len(), expected_count);
        assert_eq!(messages[0].content, "Test message number 0");
    }

    /// Statistics must account for every message exactly once.
    #[test]
    fn test_compression_stats() {
        let mut history = TokenEfficientHistory::new();
        history.max_tokens = 30;
        history.keep_recent = 2;

        for i in 0..10 {
            history.add_message(MessageRole::User, format!("Message {}", i));
        }

        let stats = history.get_compression_stats();
        assert_eq!(stats.total_messages_added, 10);
        assert_eq!(stats.active_messages, 2);
        assert_eq!(stats.compressed_messages, 8);
        assert_eq!(
            stats.compressed_messages + stats.active_messages,
            stats.total_messages_added
        );
        assert!(stats.compressed_bytes > 0);
        assert!(stats.tokens_saved > 0);
    }

    #[test]
    fn test_context_summary() {
        let session_id = SessionId::new();
        let mut context = SessionContext::new(session_id.clone());

        context.add_message_raw(MessageRole::User, "Hello".to_string());
        context.add_message_raw(MessageRole::Assistant, "Hi there!".to_string());

        let summary = context.summarize();
        assert_eq!(summary.session_id, session_id);
        assert_eq!(summary.message_count, 2);
    }

    #[test]
    fn test_new_api_methods() {
        let session_id = SessionId::new();
        let mut context = SessionContext::new(session_id.clone());

        // Test message count
        assert_eq!(context.get_message_count(), 0);

        // Add a message using new API
        let message = Message {
            role: MessageRole::User,
            content: "Test message".to_string(),
            timestamp: Utc::now(),
            token_count: 3,
        };
        context.add_message(message);

        // Check message count and tokens
        assert_eq!(context.get_message_count(), 1);
        assert_eq!(context.get_total_tokens(), 3);

        // Test get_recent_messages, including a request for more than is available
        let recent = context.get_recent_messages(1);
        assert_eq!(recent.len(), 1);
        assert_eq!(recent[0].content, "Test message");
        assert_eq!(context.get_recent_messages(5).len(), 1);
        assert!(context.get_recent_messages(0).is_empty());

        // Test config field
        assert_eq!(context.config.max_tokens, 100_000);
    }

    /// `compress_context` compresses once usage passes the configured threshold and is a
    /// no-op while usage is below it.
    #[tokio::test]
    async fn test_compress_context() {
        let session_id = SessionId::new();
        let mut context = SessionContext::new(session_id);

        // 150-token window with a 50% threshold: compression is due above 75 tokens, while
        // automatic compression on add only starts above 150.
        context.config.max_tokens = 150;
        context.config.compression_threshold = 0.5;
        context.conversation_history.max_tokens = 150;
        context.conversation_history.keep_recent = 3;

        for i in 0..10 {
            let message = Message {
                role: MessageRole::User,
                content: format!("Message {}", i),
                timestamp: Utc::now(),
                token_count: 10,
            };
            context.add_message(message);
        }

        // 100 tokens: below the hard limit, so nothing has been compressed automatically.
        assert_eq!(context.get_total_tokens(), 100);
        assert_eq!(context.get_compression_stats().compressed_messages, 0);

        // Above the 75-token threshold: explicit compression must do real work.
        assert!(context.compress_context().await);

        let stats = context.get_compression_stats();
        assert_eq!(stats.compressed_messages, 7);
        assert_eq!(stats.active_messages, 3);
        assert_eq!(context.get_message_count(), 3);
        assert_eq!(stats.total_messages_added, 10);

        // Usage is now below the threshold, so a second call reports no compression.
        assert!(context.get_total_tokens() <= 75);
        assert!(!context.compress_context().await);
    }

    #[test]
    fn test_estimate_tokens() {
        // Simple string
        let tokens = estimate_tokens("Hello world");
        assert!(tokens >= 2);

        // Empty string
        let tokens = estimate_tokens("");
        assert_eq!(tokens, 1); // Minimum 1

        // String with special characters
        let tokens = estimate_tokens("Hello, world! How are you?");
        assert!(tokens >= 5);

        // Long string
        let tokens = estimate_tokens(&"word ".repeat(100));
        assert!(tokens >= 100);
    }

    /// Compressed and active messages are returned together, in insertion order.
    #[test]
    fn test_get_all_messages() {
        let mut history = TokenEfficientHistory::new();
        history.max_tokens = 10; // Low enough that compression really happens
        history.keep_recent = 2;

        for i in 0..5 {
            history.add_message(MessageRole::User, format!("Message {}", i));
        }

        assert!(history.compressed_history.is_some());
        assert_eq!(history.messages.len(), 2);

        // Get all messages (compressed + active)
        let all = history.get_all_messages();
        let contents: Vec<&str> = all.iter().map(|m| m.content.as_str()).collect();
        assert_eq!(
            contents,
            ["Message 0", "Message 1", "Message 2", "Message 3", "Message 4"]
        );
    }
}