//! Context optimizer: shrinks file contents so they fit a per-file token budget.

use crate::models::FileContext;
use crate::ResearchError;
/// Shrinks `FileContext` contents so they fit within a per-file token budget.
#[derive(Debug, Clone)]
pub struct ContextOptimizer {
    // Hard cap on the estimated token count allowed for a single file.
    max_tokens_per_file: usize,
    // Minimum budget reserved for "important" tokens. NOTE(review): only
    // exposed via getter/setter; never read by the optimization logic —
    // confirm whether this is intentional.
    min_important_tokens: usize,
}
14
15impl ContextOptimizer {
16 pub fn new(max_tokens_per_file: usize) -> Self {
18 ContextOptimizer {
19 max_tokens_per_file,
20 min_important_tokens: 100,
21 }
22 }
23
24 pub fn optimize_file(&self, file: &FileContext) -> Result<FileContext, ResearchError> {
26 let mut optimized = file.clone();
27
28 if let Some(content) = &file.content {
29 let tokens = self.estimate_tokens(content);
30
31 if tokens > self.max_tokens_per_file {
32 optimized.content = Some(self.summarize_content(content)?);
34 }
35 }
36
37 Ok(optimized)
38 }
39
40 pub fn optimize_files(
42 &self,
43 files: Vec<FileContext>,
44 ) -> Result<Vec<FileContext>, ResearchError> {
45 files
46 .into_iter()
47 .map(|file| self.optimize_file(&file))
48 .collect()
49 }
50
51 pub fn estimate_tokens(&self, content: &str) -> usize {
53 (content.len() / 4).max(1)
54 }
55
56 fn summarize_content(&self, content: &str) -> Result<String, ResearchError> {
58 let lines: Vec<&str> = content.lines().collect();
59
60 if lines.is_empty() {
61 return Ok(String::new());
62 }
63
64 let mut summary = String::new();
65
66 let mut included_lines = Vec::new();
68
69 for (idx, line) in lines.iter().enumerate() {
70 let trimmed = line.trim();
71
72 if trimmed.starts_with("use ") || trimmed.starts_with("import ") {
74 included_lines.push((idx, line));
75 continue;
76 }
77
78 if trimmed.starts_with("pub struct ")
80 || trimmed.starts_with("pub enum ")
81 || trimmed.starts_with("pub trait ")
82 || trimmed.starts_with("pub type ")
83 || trimmed.starts_with("struct ")
84 || trimmed.starts_with("enum ")
85 || trimmed.starts_with("trait ")
86 || trimmed.starts_with("type ")
87 {
88 included_lines.push((idx, line));
89 continue;
90 }
91
92 if trimmed.starts_with("pub fn ")
94 || trimmed.starts_with("pub async fn ")
95 || trimmed.starts_with("fn ")
96 || trimmed.starts_with("async fn ")
97 {
98 included_lines.push((idx, line));
99 continue;
100 }
101
102 if trimmed.starts_with("//") || trimmed.starts_with("/*") {
104 included_lines.push((idx, line));
105 continue;
106 }
107 }
108
109 if !included_lines.is_empty() {
111 for (_, line) in included_lines {
112 summary.push_str(line);
113 summary.push('\n');
114 }
115
116 summary.push_str("\n// ... (content truncated for context window) ...\n");
118
119 if self.estimate_tokens(&summary) <= self.max_tokens_per_file {
121 return Ok(summary);
122 }
123 }
124
125 let max_lines = (self.max_tokens_per_file * 4) / 50; let mut result = String::new();
128
129 for line in lines.iter().take(max_lines) {
130 result.push_str(line);
131 result.push('\n');
132 }
133
134 if lines.len() > max_lines {
135 result.push_str("\n// ... (content truncated for context window) ...\n");
136 }
137
138 Ok(result)
139 }
140
141 pub fn extract_key_sections(&self, content: &str) -> Vec<String> {
143 let mut sections = Vec::new();
144 let lines: Vec<&str> = content.lines().collect();
145
146 let mut current_section = String::new();
147 let mut in_function = false;
148
149 for line in lines {
150 let trimmed = line.trim();
151
152 if trimmed.starts_with("pub fn ")
154 || trimmed.starts_with("pub async fn ")
155 || trimmed.starts_with("fn ")
156 || trimmed.starts_with("async fn ")
157 {
158 if !current_section.is_empty() {
159 sections.push(current_section.clone());
160 current_section.clear();
161 }
162 in_function = true;
163 current_section.push_str(line);
164 current_section.push('\n');
165 } else if in_function {
166 current_section.push_str(line);
167 current_section.push('\n');
168
169 if trimmed == "}" {
171 in_function = false;
172 sections.push(current_section.clone());
173 current_section.clear();
174 }
175 }
176 }
177
178 if !current_section.is_empty() {
179 sections.push(current_section);
180 }
181
182 sections
183 }
184
185 pub fn max_tokens_per_file(&self) -> usize {
187 self.max_tokens_per_file
188 }
189
190 pub fn set_max_tokens_per_file(&mut self, max_tokens: usize) {
192 self.max_tokens_per_file = max_tokens;
193 }
194
195 pub fn min_important_tokens(&self) -> usize {
197 self.min_important_tokens
198 }
199
200 pub fn set_min_important_tokens(&mut self, min_tokens: usize) {
202 self.min_important_tokens = min_tokens;
203 }
204}
205
206impl Default for ContextOptimizer {
207 fn default() -> Self {
208 ContextOptimizer::new(2048) }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[test]
    fn test_context_optimizer_creation() {
        // A freshly built optimizer reports the budget it was given.
        let opt = ContextOptimizer::new(2048);
        assert_eq!(opt.max_tokens_per_file(), 2048);
    }

    #[test]
    fn test_context_optimizer_default() {
        // Default budget is 2048 tokens.
        let opt = ContextOptimizer::default();
        assert_eq!(opt.max_tokens_per_file(), 2048);
    }

    #[test]
    fn test_estimate_tokens() {
        // 400 characters at ~4 chars/token -> 100 tokens.
        let opt = ContextOptimizer::new(2048);
        let text = "x".repeat(400);
        assert_eq!(opt.estimate_tokens(&text), 100);
    }

    #[test]
    fn test_optimize_file_small_content() {
        // Content already within budget passes through untouched.
        let opt = ContextOptimizer::new(2048);
        let input = FileContext {
            path: PathBuf::from("src/main.rs"),
            relevance: 0.9,
            summary: None,
            content: Some("fn main() {}".to_string()),
        };

        let outcome = opt.optimize_file(&input);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().content, input.content);
    }

    #[test]
    fn test_optimize_file_large_content() {
        // Over-budget content must be summarized down to the budget.
        let opt = ContextOptimizer::new(100);
        let big = "fn main() {}\n".repeat(100);
        let input = FileContext {
            path: PathBuf::from("src/main.rs"),
            relevance: 0.9,
            summary: None,
            content: Some(big),
        };

        let outcome = opt.optimize_file(&input);
        assert!(outcome.is_ok());

        let shrunk = outcome.unwrap();
        assert!(shrunk.content.is_some());
        let tokens = opt.estimate_tokens(shrunk.content.as_ref().unwrap());
        assert!(tokens <= 100);
    }

    #[test]
    fn test_summarize_content_with_imports() {
        // Import lines survive summarization.
        let opt = ContextOptimizer::new(2048);
        let src = "use std::path::PathBuf;\nuse std::collections::HashMap;\n\nfn main() {}\n";

        let outcome = opt.summarize_content(src);
        assert!(outcome.is_ok());

        let text = outcome.unwrap();
        assert!(text.contains("use std::path::PathBuf"));
        assert!(text.contains("use std::collections::HashMap"));
    }

    #[test]
    fn test_summarize_content_with_types() {
        // Type declarations survive summarization.
        let opt = ContextOptimizer::new(2048);
        let src = "pub struct MyStruct {\n field: String,\n}\n\nfn main() {}\n";

        let outcome = opt.summarize_content(src);
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().contains("pub struct MyStruct"));
    }

    #[test]
    fn test_extract_key_sections() {
        // Two top-level functions produce at least one section.
        let opt = ContextOptimizer::new(2048);
        let src =
            "fn helper() {\n println!(\"hello\");\n}\n\nfn main() {\n helper();\n}\n";

        assert!(!opt.extract_key_sections(src).is_empty());
    }

    #[test]
    fn test_optimize_files() {
        // Batch optimization preserves the file count.
        let opt = ContextOptimizer::new(2048);
        let inputs = vec![
            FileContext {
                path: PathBuf::from("src/main.rs"),
                relevance: 0.9,
                summary: None,
                content: Some("fn main() {}".to_string()),
            },
            FileContext {
                path: PathBuf::from("src/lib.rs"),
                relevance: 0.8,
                summary: None,
                content: Some("pub fn helper() {}".to_string()),
            },
        ];

        let outcome = opt.optimize_files(inputs);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().len(), 2);
    }

    #[test]
    fn test_set_max_tokens_per_file() {
        let mut opt = ContextOptimizer::new(2048);
        opt.set_max_tokens_per_file(4096);
        assert_eq!(opt.max_tokens_per_file(), 4096);
    }

    #[test]
    fn test_set_min_important_tokens() {
        let mut opt = ContextOptimizer::new(2048);
        opt.set_min_important_tokens(200);
        assert_eq!(opt.min_important_tokens(), 200);
    }
}