// oxify_connect_llm/compression.rs

1//! Prompt compression utilities for optimizing token usage
2//!
3//! This module provides tools to compress prompts, estimate token counts,
4//! and warn about model limits to reduce costs and improve performance.
5
6use std::collections::HashMap;
7
8/// Token limits for common models
/// Token limits for common models.
///
/// Provides lookup of the maximum context-window size (in tokens) for a
/// fixed set of well-known model identifiers.
pub struct ModelLimits;

impl ModelLimits {
    /// Get the token limit for a given model.
    ///
    /// Returns `None` for model names not in the built-in table.
    pub fn get_limit(model: &str) -> Option<u32> {
        // A `match` on string literals avoids allocating and rebuilding a
        // HashMap on every call; the compiler generates an efficient
        // comparison table instead.
        let limit = match model {
            // OpenAI models
            "gpt-4" => 8192,
            "gpt-4-32k" => 32768,
            "gpt-4-turbo" | "gpt-4-turbo-preview" | "gpt-4o" | "gpt-4o-mini" | "o1-preview"
            | "o1-mini" => 128000,
            "gpt-3.5-turbo" => 4096,
            "gpt-3.5-turbo-16k" => 16384,
            // Anthropic models
            "claude-3-opus-20240229"
            | "claude-3-sonnet-20240229"
            | "claude-3-haiku-20240307"
            | "claude-3-5-sonnet-20241022"
            | "claude-3-5-haiku-20241022"
            | "claude-2.1" => 200000,
            "claude-2.0" | "claude-instant-1.2" => 100000,
            // Google models
            "gemini-pro" => 32760,
            "gemini-1.5-pro" | "gemini-1.5-flash" | "gemini-2.0-flash" => 1048576,
            // Cohere models
            "command" => 4096,
            "command-r" | "command-r-plus" => 128000,
            // Mistral models
            "mistral-large-latest"
            | "mistral-medium-latest"
            | "mistral-small-latest"
            | "open-mixtral-8x7b" => 32000,
            _ => return None,
        };
        Some(limit)
    }

    /// Check if a model has a known limit.
    pub fn has_limit(model: &str) -> bool {
        Self::get_limit(model).is_some()
    }
}
62
63/// Compression statistics
/// Compression statistics
#[derive(Debug, Clone, Default)]
pub struct CompressionStats {
    /// Length of the original text in bytes (`str::len`)
    pub original_length: usize,
    /// Length of the compressed text in bytes (`str::len`)
    pub compressed_length: usize,
    /// Estimated token count of the original text
    pub estimated_original_tokens: u32,
    /// Estimated token count of the compressed text
    pub estimated_compressed_tokens: u32,
    /// Size ratio `compressed_length / original_length` (1.0 for empty input)
    pub compression_ratio: f64,
    /// Tokens saved: `estimated_original_tokens - estimated_compressed_tokens`,
    /// saturating at zero
    pub token_savings: u32,
}

impl CompressionStats {
    /// Build statistics comparing an original text against its compressed form.
    ///
    /// Token counts are supplied by the caller (typically from
    /// `PromptCompressor::estimate_tokens`). An empty original yields a
    /// compression ratio of 1.0 to avoid division by zero.
    pub fn new(
        original: &str,
        compressed: &str,
        original_tokens: u32,
        compressed_tokens: u32,
    ) -> Self {
        let (orig_len, comp_len) = (original.len(), compressed.len());

        // Guard the empty-input case so the ratio is always well-defined.
        let ratio = match orig_len {
            0 => 1.0,
            n => comp_len as f64 / n as f64,
        };

        Self {
            original_length: orig_len,
            compressed_length: comp_len,
            estimated_original_tokens: original_tokens,
            estimated_compressed_tokens: compressed_tokens,
            compression_ratio: ratio,
            token_savings: original_tokens.saturating_sub(compressed_tokens),
        }
    }
}
105
106/// Prompt compressor for optimizing token usage
/// Prompt compressor for optimizing token usage
///
/// All three passes are enabled by default (see [`PromptCompressor::new`]);
/// each can be toggled individually via the `with_*` builder methods.
pub struct PromptCompressor {
    /// Whether to collapse runs of non-newline whitespace into a single space
    remove_whitespace: bool,
    /// Whether to drop lines that are empty (after trimming, when enabled)
    remove_empty_lines: bool,
    /// Whether to trim leading and trailing whitespace from each line
    trim_lines: bool,
}
115
116impl Default for PromptCompressor {
117    fn default() -> Self {
118        Self::new()
119    }
120}
121
122impl PromptCompressor {
123    /// Create a new prompt compressor with default settings
124    pub fn new() -> Self {
125        Self {
126            remove_whitespace: true,
127            remove_empty_lines: true,
128            trim_lines: true,
129        }
130    }
131
132    /// Configure whitespace removal
133    pub fn with_whitespace_removal(mut self, remove: bool) -> Self {
134        self.remove_whitespace = remove;
135        self
136    }
137
138    /// Configure empty line removal
139    pub fn with_empty_line_removal(mut self, remove: bool) -> Self {
140        self.remove_empty_lines = remove;
141        self
142    }
143
144    /// Configure line trimming
145    pub fn with_line_trimming(mut self, trim: bool) -> Self {
146        self.trim_lines = trim;
147        self
148    }
149
150    /// Estimate token count for a text
151    ///
152    /// Uses a simple heuristic: ~4 characters per token for English text.
153    /// This is approximate and may vary by language and tokenizer.
154    pub fn estimate_tokens(text: &str) -> u32 {
155        // Simple heuristic: average of 4 characters per token
156        // This is conservative and works reasonably well for English
157        ((text.len() as f64) / 4.0).ceil() as u32
158    }
159
160    /// Compress a prompt to reduce token usage
161    ///
162    /// # Arguments
163    /// * `text` - The text to compress
164    ///
165    /// # Returns
166    /// A tuple of (compressed_text, compression_stats)
167    pub fn compress(&self, text: &str) -> (String, CompressionStats) {
168        let original_tokens = Self::estimate_tokens(text);
169        let mut result = text.to_string();
170
171        // Remove extra whitespace between words
172        if self.remove_whitespace {
173            result = self.normalize_whitespace(&result);
174        }
175
176        // Process lines
177        if self.remove_empty_lines || self.trim_lines {
178            let lines: Vec<String> = result
179                .lines()
180                .filter_map(|line| {
181                    let processed = if self.trim_lines { line.trim() } else { line };
182
183                    if self.remove_empty_lines && processed.is_empty() {
184                        None
185                    } else {
186                        Some(processed.to_string())
187                    }
188                })
189                .collect();
190
191            result = lines.join("\n");
192        }
193
194        // Final cleanup: ensure no trailing whitespace
195        result = result.trim().to_string();
196
197        let compressed_tokens = Self::estimate_tokens(&result);
198        let stats = CompressionStats::new(text, &result, original_tokens, compressed_tokens);
199
200        (result, stats)
201    }
202
203    /// Normalize whitespace by replacing multiple spaces with single space
204    fn normalize_whitespace(&self, text: &str) -> String {
205        let mut result = String::with_capacity(text.len());
206        let mut prev_was_space = false;
207
208        for ch in text.chars() {
209            if ch.is_whitespace() && ch != '\n' {
210                if !prev_was_space {
211                    result.push(' ');
212                    prev_was_space = true;
213                }
214            } else {
215                result.push(ch);
216                prev_was_space = false;
217            }
218        }
219
220        result
221    }
222
223    /// Check if a prompt exceeds the model's token limit
224    ///
225    /// # Arguments
226    /// * `text` - The text to check
227    /// * `model` - The model name
228    ///
229    /// # Returns
230    /// `Some(tokens)` if the limit is exceeded, `None` otherwise
231    pub fn check_limit(text: &str, model: &str) -> Option<u32> {
232        let estimated_tokens = Self::estimate_tokens(text);
233
234        if let Some(limit) = ModelLimits::get_limit(model) {
235            if estimated_tokens > limit {
236                return Some(estimated_tokens);
237            }
238        }
239
240        None
241    }
242
243    /// Get a warning message if the prompt exceeds the model limit
244    ///
245    /// # Arguments
246    /// * `text` - The text to check
247    /// * `model` - The model name
248    ///
249    /// # Returns
250    /// A warning message if the limit is exceeded, `None` otherwise
251    pub fn get_limit_warning(text: &str, model: &str) -> Option<String> {
252        if let Some(tokens) = Self::check_limit(text, model) {
253            if let Some(limit) = ModelLimits::get_limit(model) {
254                return Some(format!(
255                    "Prompt exceeds model limit: {} tokens (limit: {} tokens for {})",
256                    tokens, limit, model
257                ));
258            }
259        }
260
261        None
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_estimation() {
        // "Hello, world!" is 13 bytes: 13 / 4 = 3.25, ceil -> 4.
        assert_eq!(PromptCompressor::estimate_tokens("Hello, world!"), 4);

        // 59 bytes: 59 / 4 = 14.75, ceil -> 15.
        let sentence = "This is a longer text with multiple words and punctuation.";
        assert_eq!(PromptCompressor::estimate_tokens(sentence), 15);
    }

    #[test]
    fn test_whitespace_compression() {
        let (out, stats) =
            PromptCompressor::new().compress("Hello    world   with    extra    spaces");

        assert_eq!(out, "Hello world with extra spaces");
        assert!(stats.compressed_length < stats.original_length);
        assert!(stats.compression_ratio < 1.0);
    }

    #[test]
    fn test_empty_line_removal() {
        let (out, _) = PromptCompressor::new().compress("Line 1\n\n\nLine 2\n\nLine 3");
        assert_eq!(out, "Line 1\nLine 2\nLine 3");
    }

    #[test]
    fn test_line_trimming() {
        let (out, _) = PromptCompressor::new().compress("  Line 1  \n  Line 2  \n  Line 3  ");
        assert_eq!(out, "Line 1\nLine 2\nLine 3");
    }

    #[test]
    fn test_compression_disabled() {
        let passthrough = PromptCompressor::new()
            .with_whitespace_removal(false)
            .with_empty_line_removal(false)
            .with_line_trimming(false);

        // With all passes off, only the final whole-string trim applies,
        // which is a no-op here.
        let (out, _) = passthrough.compress("Hello    world\n\n\ntest");
        assert_eq!(out, "Hello    world\n\n\ntest");
    }

    #[test]
    fn test_model_limits() {
        assert_eq!(ModelLimits::get_limit("gpt-4"), Some(8192));
        assert_eq!(ModelLimits::get_limit("gpt-4-32k"), Some(32768));
        assert_eq!(
            ModelLimits::get_limit("claude-3-opus-20240229"),
            Some(200000)
        );
        assert_eq!(ModelLimits::get_limit("unknown-model"), None);
    }

    #[test]
    fn test_limit_checking() {
        // 20000 chars (~5000 tokens) is well over gpt-3.5-turbo's 4096 limit.
        let oversized = "x".repeat(20000);
        let exceeded = PromptCompressor::check_limit(&oversized, "gpt-3.5-turbo");

        assert!(matches!(exceeded, Some(tokens) if tokens > 4096));
    }

    #[test]
    fn test_limit_warning() {
        let oversized = "x".repeat(20000);
        let warning = PromptCompressor::get_limit_warning(&oversized, "gpt-3.5-turbo")
            .expect("oversized prompt should produce a warning");

        assert!(warning.contains("exceeds model limit"));
    }

    #[test]
    fn test_no_limit_warning() {
        assert!(PromptCompressor::get_limit_warning("Short text", "gpt-4").is_none());
    }

    #[test]
    fn test_compression_stats() {
        let input = "Hello    world   with    many    spaces\n\n\nand empty lines";
        let (_, stats) = PromptCompressor::new().compress(input);

        assert!(stats.original_length > stats.compressed_length);
        assert!(stats.estimated_original_tokens > stats.estimated_compressed_tokens);
        assert!(stats.token_savings > 0);
        assert!(stats.compression_ratio > 0.0 && stats.compression_ratio < 1.0);
    }

    #[test]
    fn test_compression_preserves_content() {
        // Words and punctuation survive; only whitespace is reduced.
        let (out, _) = PromptCompressor::new().compress("Important   data:  value1,   value2,   value3");
        assert!(out.contains("Important data: value1, value2, value3"));
    }

    #[test]
    fn test_model_limit_has_limit() {
        assert!(ModelLimits::has_limit("gpt-4"));
        assert!(ModelLimits::has_limit("claude-3-opus-20240229"));
        assert!(!ModelLimits::has_limit("unknown-model"));
    }
}