use std::collections::HashMap;
use std::sync::OnceLock;
7
/// Lookup table of known per-model token limits.
#[derive(Debug, Clone, Copy)]
pub struct ModelLimits;

impl ModelLimits {
    /// Returns the token limit for the model identifier `model`, or `None` if it is unknown.
    ///
    /// Matching is exact and case-sensitive against the built-in table.
    pub fn get_limit(model: &str) -> Option<u32> {
        // Built once on first use and shared by every later call, instead of
        // allocating and hashing the whole table on each lookup.
        static LIMITS: OnceLock<HashMap<&'static str, u32>> = OnceLock::new();

        LIMITS
            .get_or_init(|| {
                [
                    ("gpt-4", 8192),
                    ("gpt-4-32k", 32768),
                    ("gpt-4-turbo", 128000),
                    ("gpt-4-turbo-preview", 128000),
                    ("gpt-4o", 128000),
                    ("gpt-4o-mini", 128000),
                    ("o1-preview", 128000),
                    ("o1-mini", 128000),
                    ("gpt-3.5-turbo", 4096),
                    ("gpt-3.5-turbo-16k", 16384),
                    ("claude-3-opus-20240229", 200000),
                    ("claude-3-sonnet-20240229", 200000),
                    ("claude-3-haiku-20240307", 200000),
                    ("claude-3-5-sonnet-20241022", 200000),
                    ("claude-3-5-haiku-20241022", 200000),
                    ("claude-2.1", 200000),
                    ("claude-2.0", 100000),
                    ("claude-instant-1.2", 100000),
                    ("gemini-pro", 32760),
                    ("gemini-1.5-pro", 1048576),
                    ("gemini-1.5-flash", 1048576),
                    ("gemini-2.0-flash", 1048576),
                    ("command", 4096),
                    ("command-r", 128000),
                    ("command-r-plus", 128000),
                    ("mistral-large-latest", 32000),
                    ("mistral-medium-latest", 32000),
                    ("mistral-small-latest", 32000),
                    ("open-mixtral-8x7b", 32000),
                ]
                .iter()
                .copied()
                .collect()
            })
            .get(model)
            .copied()
    }

    /// Returns `true` when [`ModelLimits::get_limit`] knows a limit for `model`.
    pub fn has_limit(model: &str) -> bool {
        Self::get_limit(model).is_some()
    }
}
62
/// Size and token statistics describing the effect of one compression pass.
///
/// Note: the derived `Default` sets every field to zero, including
/// `compression_ratio`, whereas [`CompressionStats::new`] reports `1.0` for empty input.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct CompressionStats {
    /// Length of the original text in bytes (`str::len`).
    pub original_length: usize,
    /// Length of the compressed text in bytes (`str::len`).
    pub compressed_length: usize,
    /// Caller-supplied token estimate for the original text.
    pub estimated_original_tokens: u32,
    /// Caller-supplied token estimate for the compressed text.
    pub estimated_compressed_tokens: u32,
    /// `compressed_length / original_length`, or `1.0` when the original is empty.
    pub compression_ratio: f64,
    /// `estimated_original_tokens - estimated_compressed_tokens`, saturating at zero.
    pub token_savings: u32,
}

impl CompressionStats {
    /// Builds statistics from the original and compressed texts and their token estimates.
    ///
    /// Lengths are measured in bytes. The ratio falls back to `1.0` for an empty
    /// original to avoid dividing by zero. Token savings never underflow: if the
    /// compressed estimate is larger, savings are reported as `0`.
    pub fn new(
        original: &str,
        compressed: &str,
        original_tokens: u32,
        compressed_tokens: u32,
    ) -> Self {
        let original_length = original.len();
        let compressed_length = compressed.len();

        let compression_ratio = if original_length == 0 {
            1.0
        } else {
            compressed_length as f64 / original_length as f64
        };

        Self {
            original_length,
            compressed_length,
            estimated_original_tokens: original_tokens,
            estimated_compressed_tokens: compressed_tokens,
            compression_ratio,
            token_savings: original_tokens.saturating_sub(compressed_tokens),
        }
    }
}
105
106pub struct PromptCompressor {
108 remove_whitespace: bool,
110 remove_empty_lines: bool,
112 trim_lines: bool,
114}
115
116impl Default for PromptCompressor {
117 fn default() -> Self {
118 Self::new()
119 }
120}
121
122impl PromptCompressor {
123 pub fn new() -> Self {
125 Self {
126 remove_whitespace: true,
127 remove_empty_lines: true,
128 trim_lines: true,
129 }
130 }
131
132 pub fn with_whitespace_removal(mut self, remove: bool) -> Self {
134 self.remove_whitespace = remove;
135 self
136 }
137
138 pub fn with_empty_line_removal(mut self, remove: bool) -> Self {
140 self.remove_empty_lines = remove;
141 self
142 }
143
144 pub fn with_line_trimming(mut self, trim: bool) -> Self {
146 self.trim_lines = trim;
147 self
148 }
149
150 pub fn estimate_tokens(text: &str) -> u32 {
155 ((text.len() as f64) / 4.0).ceil() as u32
158 }
159
160 pub fn compress(&self, text: &str) -> (String, CompressionStats) {
168 let original_tokens = Self::estimate_tokens(text);
169 let mut result = text.to_string();
170
171 if self.remove_whitespace {
173 result = self.normalize_whitespace(&result);
174 }
175
176 if self.remove_empty_lines || self.trim_lines {
178 let lines: Vec<String> = result
179 .lines()
180 .filter_map(|line| {
181 let processed = if self.trim_lines { line.trim() } else { line };
182
183 if self.remove_empty_lines && processed.is_empty() {
184 None
185 } else {
186 Some(processed.to_string())
187 }
188 })
189 .collect();
190
191 result = lines.join("\n");
192 }
193
194 result = result.trim().to_string();
196
197 let compressed_tokens = Self::estimate_tokens(&result);
198 let stats = CompressionStats::new(text, &result, original_tokens, compressed_tokens);
199
200 (result, stats)
201 }
202
203 fn normalize_whitespace(&self, text: &str) -> String {
205 let mut result = String::with_capacity(text.len());
206 let mut prev_was_space = false;
207
208 for ch in text.chars() {
209 if ch.is_whitespace() && ch != '\n' {
210 if !prev_was_space {
211 result.push(' ');
212 prev_was_space = true;
213 }
214 } else {
215 result.push(ch);
216 prev_was_space = false;
217 }
218 }
219
220 result
221 }
222
223 pub fn check_limit(text: &str, model: &str) -> Option<u32> {
232 let estimated_tokens = Self::estimate_tokens(text);
233
234 if let Some(limit) = ModelLimits::get_limit(model) {
235 if estimated_tokens > limit {
236 return Some(estimated_tokens);
237 }
238 }
239
240 None
241 }
242
243 pub fn get_limit_warning(text: &str, model: &str) -> Option<String> {
252 if let Some(tokens) = Self::check_limit(text, model) {
253 if let Some(limit) = ModelLimits::get_limit(model) {
254 return Some(format!(
255 "Prompt exceeds model limit: {} tokens (limit: {} tokens for {})",
256 tokens, limit, model
257 ));
258 }
259 }
260
261 None
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_estimation() {
        let text = "Hello, world!";
        let tokens = PromptCompressor::estimate_tokens(text);
        assert_eq!(tokens, 4);

        let long_text = "This is a longer text with multiple words and punctuation.";
        let long_tokens = PromptCompressor::estimate_tokens(long_text);
        assert_eq!(long_tokens, 15);
    }

    #[test]
    fn test_token_estimation_rounding() {
        assert_eq!(PromptCompressor::estimate_tokens(""), 0);
        assert_eq!(PromptCompressor::estimate_tokens("abcd"), 1);
        assert_eq!(PromptCompressor::estimate_tokens("abcde"), 2);
    }

    #[test]
    fn test_whitespace_compression() {
        let compressor = PromptCompressor::new();
        // Escapes make the redundant whitespace explicit so it cannot be collapsed
        // accidentally by editors or formatting tools.
        let text = "Hello \t world\t\twith \t extra \t spaces";
        let (compressed, stats) = compressor.compress(text);

        assert_eq!(compressed, "Hello world with extra spaces");
        assert!(stats.compressed_length < stats.original_length);
        assert!(stats.compression_ratio < 1.0);
    }

    #[test]
    fn test_empty_line_removal() {
        let compressor = PromptCompressor::new();
        let text = "Line 1\n\n\nLine 2\n\nLine 3";
        let (compressed, _) = compressor.compress(text);

        assert_eq!(compressed, "Line 1\nLine 2\nLine 3");
    }

    #[test]
    fn test_empty_line_removal_without_trimming() {
        let compressor = PromptCompressor::new().with_line_trimming(false);
        let (compressed, _) = compressor.compress("a\n\nb");

        assert_eq!(compressed, "a\nb");
    }

    #[test]
    fn test_line_trimming() {
        let compressor = PromptCompressor::new();
        let text = " Line 1 \n Line 2 \n Line 3 ";
        let (compressed, _) = compressor.compress(text);

        assert_eq!(compressed, "Line 1\nLine 2\nLine 3");
    }

    #[test]
    fn test_tabs_and_carriage_returns() {
        let compressor = PromptCompressor::new();
        let (compressed, _) = compressor.compress("a\t\tb\r\nc");

        assert_eq!(compressed, "a b\nc");
    }

    #[test]
    fn test_compression_disabled() {
        let compressor = PromptCompressor::new()
            .with_whitespace_removal(false)
            .with_empty_line_removal(false)
            .with_line_trimming(false);

        let text = "Hello world\n\n\ntest";
        let (compressed, _) = compressor.compress(text);

        assert_eq!(compressed, "Hello world\n\n\ntest");
    }

    #[test]
    fn test_empty_input() {
        let (compressed, stats) = PromptCompressor::new().compress("");

        assert_eq!(compressed, "");
        assert_eq!(stats.original_length, 0);
        assert_eq!(stats.compressed_length, 0);
        assert_eq!(stats.estimated_original_tokens, 0);
        assert_eq!(stats.token_savings, 0);
        assert!((stats.compression_ratio - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_whitespace_only_input() {
        let (compressed, _) = PromptCompressor::new().compress("   \n\t\n  ");

        assert_eq!(compressed, "");
    }

    #[test]
    fn test_model_limits() {
        assert_eq!(ModelLimits::get_limit("gpt-4"), Some(8192));
        assert_eq!(ModelLimits::get_limit("gpt-4-32k"), Some(32768));
        assert_eq!(
            ModelLimits::get_limit("claude-3-opus-20240229"),
            Some(200000)
        );
        assert_eq!(ModelLimits::get_limit("unknown-model"), None);
        // Repeated lookups must stay consistent.
        for _ in 0..3 {
            assert_eq!(ModelLimits::get_limit("gpt-4o"), Some(128000));
        }
    }

    #[test]
    fn test_limit_checking() {
        let text = "x".repeat(20000);
        let result = PromptCompressor::check_limit(&text, "gpt-3.5-turbo");

        assert!(result.is_some());
        assert!(result.unwrap() > 4096);
    }

    #[test]
    fn test_limit_boundary() {
        // 16384 bytes -> exactly 4096 estimated tokens, which fits.
        let at_limit = "x".repeat(4096 * 4);
        assert_eq!(PromptCompressor::check_limit(&at_limit, "gpt-3.5-turbo"), None);

        // One more byte rounds up to 4097 tokens, which exceeds the limit.
        let over_limit = "x".repeat(4096 * 4 + 1);
        assert_eq!(
            PromptCompressor::check_limit(&over_limit, "gpt-3.5-turbo"),
            Some(4097)
        );
    }

    #[test]
    fn test_unknown_model_never_warns() {
        let text = "x".repeat(1_000_000);

        assert_eq!(PromptCompressor::check_limit(&text, "unknown-model"), None);
        assert!(PromptCompressor::get_limit_warning(&text, "unknown-model").is_none());
    }

    #[test]
    fn test_limit_warning() {
        let text = "x".repeat(20000);
        let warning = PromptCompressor::get_limit_warning(&text, "gpt-3.5-turbo");

        assert!(warning.is_some());
        assert!(warning.unwrap().contains("exceeds model limit"));
    }

    #[test]
    fn test_limit_warning_message() {
        let text = "x".repeat(20000);
        let warning = PromptCompressor::get_limit_warning(&text, "gpt-3.5-turbo");

        assert_eq!(
            warning.as_deref(),
            Some("Prompt exceeds model limit: 5000 tokens (limit: 4096 tokens for gpt-3.5-turbo)")
        );
    }

    #[test]
    fn test_no_limit_warning() {
        let text = "Short text";
        let warning = PromptCompressor::get_limit_warning(text, "gpt-4");

        assert!(warning.is_none());
    }

    #[test]
    fn test_compression_stats() {
        let compressor = PromptCompressor::new();
        let text = "Hello  world   with many\t\tspaces\n\n\nand empty lines";
        let (_, stats) = compressor.compress(text);

        assert!(stats.original_length > stats.compressed_length);
        assert!(stats.estimated_original_tokens > stats.estimated_compressed_tokens);
        assert!(stats.token_savings > 0);
        assert!(stats.compression_ratio < 1.0);
        assert!(stats.compression_ratio > 0.0);
    }

    #[test]
    fn test_compression_preserves_content() {
        let compressor = PromptCompressor::new();
        let text = "Important data: value1, value2, value3";
        let (compressed, _) = compressor.compress(text);

        assert!(compressed.contains("Important data: value1, value2, value3"));
    }

    #[test]
    fn test_model_limit_has_limit() {
        assert!(ModelLimits::has_limit("gpt-4"));
        assert!(ModelLimits::has_limit("claude-3-opus-20240229"));
        assert!(!ModelLimits::has_limit("unknown-model"));
    }
}