ricecoder_research/
context_optimizer.rs

1//! Context optimization for fitting large files within token budgets
2
3use crate::models::FileContext;
4use crate::ResearchError;
5
6/// Optimizes context by summarizing large files to fit token budgets
/// Optimizes context by summarizing large files to fit token budgets.
///
/// Holds a per-file token budget (`max_tokens_per_file`) that the
/// summarization routines enforce.
#[derive(Debug, Clone)]
pub struct ContextOptimizer {
    /// Maximum tokens allowed per file before its content is summarized.
    max_tokens_per_file: usize,
    /// Minimum tokens to preserve for important sections.
    /// NOTE(review): exposed via getter/setter but not currently consulted by
    /// the summarization logic — confirm intended use.
    min_important_tokens: usize,
}
14
15impl ContextOptimizer {
16    /// Create a new context optimizer
17    pub fn new(max_tokens_per_file: usize) -> Self {
18        ContextOptimizer {
19            max_tokens_per_file,
20            min_important_tokens: 100,
21        }
22    }
23
24    /// Optimize a file to fit within token budget
25    pub fn optimize_file(&self, file: &FileContext) -> Result<FileContext, ResearchError> {
26        let mut optimized = file.clone();
27
28        if let Some(content) = &file.content {
29            let tokens = self.estimate_tokens(content);
30
31            if tokens > self.max_tokens_per_file {
32                // Summarize the file
33                optimized.content = Some(self.summarize_content(content)?);
34            }
35        }
36
37        Ok(optimized)
38    }
39
40    /// Optimize multiple files
41    pub fn optimize_files(
42        &self,
43        files: Vec<FileContext>,
44    ) -> Result<Vec<FileContext>, ResearchError> {
45        files
46            .into_iter()
47            .map(|file| self.optimize_file(&file))
48            .collect()
49    }
50
51    /// Estimate token count for content (rough approximation: 1 token per 4 characters)
52    pub fn estimate_tokens(&self, content: &str) -> usize {
53        (content.len() / 4).max(1)
54    }
55
56    /// Summarize content to fit within token budget
57    fn summarize_content(&self, content: &str) -> Result<String, ResearchError> {
58        let lines: Vec<&str> = content.lines().collect();
59
60        if lines.is_empty() {
61            return Ok(String::new());
62        }
63
64        let mut summary = String::new();
65
66        // Include important sections: imports, type definitions, function signatures
67        let mut included_lines = Vec::new();
68
69        for (idx, line) in lines.iter().enumerate() {
70            let trimmed = line.trim();
71
72            // Always include imports
73            if trimmed.starts_with("use ") || trimmed.starts_with("import ") {
74                included_lines.push((idx, line));
75                continue;
76            }
77
78            // Include type definitions
79            if trimmed.starts_with("pub struct ")
80                || trimmed.starts_with("pub enum ")
81                || trimmed.starts_with("pub trait ")
82                || trimmed.starts_with("pub type ")
83                || trimmed.starts_with("struct ")
84                || trimmed.starts_with("enum ")
85                || trimmed.starts_with("trait ")
86                || trimmed.starts_with("type ")
87            {
88                included_lines.push((idx, line));
89                continue;
90            }
91
92            // Include function signatures
93            if trimmed.starts_with("pub fn ")
94                || trimmed.starts_with("pub async fn ")
95                || trimmed.starts_with("fn ")
96                || trimmed.starts_with("async fn ")
97            {
98                included_lines.push((idx, line));
99                continue;
100            }
101
102            // Include comments
103            if trimmed.starts_with("//") || trimmed.starts_with("/*") {
104                included_lines.push((idx, line));
105                continue;
106            }
107        }
108
109        // If we have important lines, use them
110        if !included_lines.is_empty() {
111            for (_, line) in included_lines {
112                summary.push_str(line);
113                summary.push('\n');
114            }
115
116            // Add ellipsis to indicate truncation
117            summary.push_str("\n// ... (content truncated for context window) ...\n");
118
119            // Check if summary fits within budget
120            if self.estimate_tokens(&summary) <= self.max_tokens_per_file {
121                return Ok(summary);
122            }
123        }
124
125        // Fallback: just take first N lines
126        let max_lines = (self.max_tokens_per_file * 4) / 50; // Rough estimate
127        let mut result = String::new();
128
129        for line in lines.iter().take(max_lines) {
130            result.push_str(line);
131            result.push('\n');
132        }
133
134        if lines.len() > max_lines {
135            result.push_str("\n// ... (content truncated for context window) ...\n");
136        }
137
138        Ok(result)
139    }
140
141    /// Extract key sections from content
142    pub fn extract_key_sections(&self, content: &str) -> Vec<String> {
143        let mut sections = Vec::new();
144        let lines: Vec<&str> = content.lines().collect();
145
146        let mut current_section = String::new();
147        let mut in_function = false;
148
149        for line in lines {
150            let trimmed = line.trim();
151
152            // Start of function
153            if trimmed.starts_with("pub fn ")
154                || trimmed.starts_with("pub async fn ")
155                || trimmed.starts_with("fn ")
156                || trimmed.starts_with("async fn ")
157            {
158                if !current_section.is_empty() {
159                    sections.push(current_section.clone());
160                    current_section.clear();
161                }
162                in_function = true;
163                current_section.push_str(line);
164                current_section.push('\n');
165            } else if in_function {
166                current_section.push_str(line);
167                current_section.push('\n');
168
169                // Simple heuristic: end of function (closing brace at start of line)
170                if trimmed == "}" {
171                    in_function = false;
172                    sections.push(current_section.clone());
173                    current_section.clear();
174                }
175            }
176        }
177
178        if !current_section.is_empty() {
179            sections.push(current_section);
180        }
181
182        sections
183    }
184
185    /// Get maximum tokens per file
186    pub fn max_tokens_per_file(&self) -> usize {
187        self.max_tokens_per_file
188    }
189
190    /// Set maximum tokens per file
191    pub fn set_max_tokens_per_file(&mut self, max_tokens: usize) {
192        self.max_tokens_per_file = max_tokens;
193    }
194
195    /// Get minimum important tokens
196    pub fn min_important_tokens(&self) -> usize {
197        self.min_important_tokens
198    }
199
200    /// Set minimum important tokens
201    pub fn set_min_important_tokens(&mut self, min_tokens: usize) {
202        self.min_important_tokens = min_tokens;
203    }
204}
205
206impl Default for ContextOptimizer {
207    fn default() -> Self {
208        ContextOptimizer::new(2048) // Default to 2K tokens per file
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[test]
    fn test_context_optimizer_creation() {
        assert_eq!(ContextOptimizer::new(2048).max_tokens_per_file(), 2048);
    }

    #[test]
    fn test_context_optimizer_default() {
        assert_eq!(ContextOptimizer::default().max_tokens_per_file(), 2048);
    }

    #[test]
    fn test_estimate_tokens() {
        // 400 characters at ~4 chars/token => exactly 100 tokens.
        let text = "x".repeat(400);
        assert_eq!(ContextOptimizer::new(2048).estimate_tokens(&text), 100);
    }

    #[test]
    fn test_optimize_file_small_content() {
        let opt = ContextOptimizer::new(2048);
        let input = FileContext {
            path: PathBuf::from("src/main.rs"),
            relevance: 0.9,
            summary: None,
            content: Some("fn main() {}".to_string()),
        };

        let output = opt.optimize_file(&input).expect("optimize should succeed");
        // Content already within budget must pass through untouched.
        assert_eq!(output.content, input.content);
    }

    #[test]
    fn test_optimize_file_large_content() {
        // Deliberately tiny budget to force summarization.
        let opt = ContextOptimizer::new(100);
        let input = FileContext {
            path: PathBuf::from("src/main.rs"),
            relevance: 0.9,
            summary: None,
            content: Some("fn main() {}\n".repeat(100)),
        };

        let output = opt.optimize_file(&input).expect("optimize should succeed");
        let body = output.content.expect("optimized content should be present");
        // Optimized content should fit the budget.
        assert!(opt.estimate_tokens(&body) <= 100);
    }

    #[test]
    fn test_summarize_content_with_imports() {
        let src = "use std::path::PathBuf;\nuse std::collections::HashMap;\n\nfn main() {}\n";

        let summary = ContextOptimizer::new(2048)
            .summarize_content(src)
            .expect("summarize should succeed");
        assert!(summary.contains("use std::path::PathBuf"));
        assert!(summary.contains("use std::collections::HashMap"));
    }

    #[test]
    fn test_summarize_content_with_types() {
        let src = "pub struct MyStruct {\n    field: String,\n}\n\nfn main() {}\n";

        let summary = ContextOptimizer::new(2048)
            .summarize_content(src)
            .expect("summarize should succeed");
        assert!(summary.contains("pub struct MyStruct"));
    }

    #[test]
    fn test_extract_key_sections() {
        let src = "fn helper() {\n    println!(\"hello\");\n}\n\nfn main() {\n    helper();\n}\n";

        let sections = ContextOptimizer::new(2048).extract_key_sections(src);
        assert!(!sections.is_empty());
    }

    #[test]
    fn test_optimize_files() {
        let inputs = vec![
            FileContext {
                path: PathBuf::from("src/main.rs"),
                relevance: 0.9,
                summary: None,
                content: Some("fn main() {}".to_string()),
            },
            FileContext {
                path: PathBuf::from("src/lib.rs"),
                relevance: 0.8,
                summary: None,
                content: Some("pub fn helper() {}".to_string()),
            },
        ];

        let outputs = ContextOptimizer::new(2048)
            .optimize_files(inputs)
            .expect("optimize_files should succeed");
        assert_eq!(outputs.len(), 2);
    }

    #[test]
    fn test_set_max_tokens_per_file() {
        let mut opt = ContextOptimizer::new(2048);
        opt.set_max_tokens_per_file(4096);
        assert_eq!(opt.max_tokens_per_file(), 4096);
    }

    #[test]
    fn test_set_min_important_tokens() {
        let mut opt = ContextOptimizer::new(2048);
        opt.set_min_important_tokens(200);
        assert_eq!(opt.min_important_tokens(), 200);
    }
}