ricecoder_research/
context_builder.rs1use crate::models::{CodeContext, FileContext};
4use crate::semantic_index::SemanticIndex;
5use crate::ResearchError;
6
7#[derive(Debug, Clone)]
9pub struct ContextBuilder {
10 max_tokens: usize,
12 semantic_index: Option<SemanticIndex>,
14}
15
16impl ContextBuilder {
17 pub fn new(max_tokens: usize) -> Self {
19 ContextBuilder {
20 max_tokens,
21 semantic_index: None,
22 }
23 }
24
25 pub fn with_semantic_index(mut self, index: SemanticIndex) -> Self {
27 self.semantic_index = Some(index);
28 self
29 }
30
31 pub fn select_relevant_files(
33 &self,
34 query: &str,
35 all_files: Vec<FileContext>,
36 ) -> Result<Vec<FileContext>, ResearchError> {
37 if all_files.is_empty() {
38 return Ok(Vec::new());
39 }
40
41 let mut scored_files: Vec<(FileContext, f32)> = all_files
43 .into_iter()
44 .map(|file| {
45 let relevance = self.calculate_file_relevance(&file, query);
46 (file, relevance)
47 })
48 .collect();
49
50 scored_files.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
52
53 let mut result: Vec<FileContext> = scored_files
55 .into_iter()
56 .map(|(mut file, relevance)| {
57 file.relevance = relevance;
58 file
59 })
60 .collect();
61
62 result.retain(|f| f.relevance > 0.0);
64
65 Ok(result)
66 }
67
68 pub fn build_context(&self, files: Vec<FileContext>) -> Result<CodeContext, ResearchError> {
70 let mut total_tokens = 0;
71 let mut included_files = Vec::new();
72 let mut all_symbols = Vec::new();
73 let mut all_references = Vec::new();
74
75 for file in files {
76 let file_tokens = file
78 .content
79 .as_ref()
80 .map(|c| (c.len() / 4).max(1))
81 .unwrap_or(0);
82
83 if total_tokens + file_tokens > self.max_tokens && !included_files.is_empty() {
85 break;
86 }
87
88 total_tokens += file_tokens;
89 included_files.push(file);
90 }
91
92 if let Some(index) = &self.semantic_index {
94 for file in &included_files {
95 let symbols = index.get_symbols_in_file(&file.path);
96 for symbol in symbols {
97 all_symbols.push(symbol.clone());
98 let refs = index.get_references_to_symbol(&symbol.id);
99 for reference in refs {
100 all_references.push(reference.clone());
101 }
102 }
103 }
104 }
105
106 Ok(CodeContext {
107 files: included_files,
108 symbols: all_symbols,
109 references: all_references,
110 total_tokens,
111 })
112 }
113
114 fn calculate_file_relevance(&self, file: &FileContext, query: &str) -> f32 {
116 let mut score: f32 = 0.0;
117
118 let path_str = file.path.to_string_lossy().to_lowercase();
120 let query_lower = query.to_lowercase();
121
122 if path_str.contains(&query_lower) {
123 score += 0.5;
124 }
125
126 if let Some(content) = &file.content {
128 let content_lower = content.to_lowercase();
129 let query_words: Vec<&str> = query_lower.split_whitespace().collect();
130
131 for word in query_words {
132 if content_lower.contains(word) {
133 score += 0.1;
134 }
135 }
136 }
137
138 if let Some(summary) = &file.summary {
140 let summary_lower = summary.to_lowercase();
141 if summary_lower.contains(&query_lower) {
142 score += 0.3;
143 }
144 }
145
146 score.min(1.0)
148 }
149
150 pub fn max_tokens(&self) -> usize {
152 self.max_tokens
153 }
154
155 pub fn set_max_tokens(&mut self, max_tokens: usize) {
157 self.max_tokens = max_tokens;
158 }
159}
160
161impl Default for ContextBuilder {
162 fn default() -> Self {
163 ContextBuilder::new(4096) }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Builds a `FileContext` with zero relevance from borrowed test data.
    fn file(path: &str, summary: Option<&str>, content: Option<&str>) -> FileContext {
        FileContext {
            path: PathBuf::from(path),
            relevance: 0.0,
            summary: summary.map(str::to_string),
            content: content.map(str::to_string),
        }
    }

    #[test]
    fn test_context_builder_creation() {
        let builder = ContextBuilder::new(8192);
        assert_eq!(builder.max_tokens(), 8192);
    }

    #[test]
    fn test_context_builder_default() {
        let builder = ContextBuilder::default();
        assert_eq!(builder.max_tokens(), 4096);
    }

    #[test]
    fn test_set_max_tokens() {
        let mut builder = ContextBuilder::new(100);
        builder.set_max_tokens(2048);
        assert_eq!(builder.max_tokens(), 2048);
    }

    #[test]
    fn test_select_relevant_files_empty() {
        let builder = ContextBuilder::new(4096);
        let result = builder.select_relevant_files("test", vec![]);
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_select_relevant_files_with_query() {
        let builder = ContextBuilder::new(4096);
        let main = file("src/main.rs", Some("Main entry point"), Some("fn main() {}"));
        let lib = file("src/lib.rs", Some("Library module"), Some("pub fn helper() {}"));

        let files = builder.select_relevant_files("main", vec![main, lib]).unwrap();

        // main.rs: path 0.5 + content 0.1 + summary 0.3; lib.rs scores 0 and is dropped.
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].path, PathBuf::from("src/main.rs"));
        assert!((files[0].relevance - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_select_relevant_files_sorted_by_relevance() {
        let builder = ContextBuilder::new(4096);
        let content_hit = file("src/a.rs", None, Some("fn parse() {}"));
        let path_hit = file("src/parse.rs", None, None);

        let files = builder
            .select_relevant_files("parse", vec![content_hit, path_hit])
            .unwrap();

        assert_eq!(files.len(), 2);
        assert_eq!(files[0].path, PathBuf::from("src/parse.rs"));
        assert_eq!(files[0].relevance, 0.5);
        assert_eq!(files[1].path, PathBuf::from("src/a.rs"));
        assert_eq!(files[1].relevance, 0.1);
    }

    #[test]
    fn test_build_context_respects_token_budget() {
        let builder = ContextBuilder::new(100);
        let xs = "x".repeat(200);
        let ys = "y".repeat(200);
        let first = file("src/file1.rs", None, Some(xs.as_str()));
        let second = file("src/file2.rs", None, Some(ys.as_str()));

        // 200 bytes is estimated at 50 tokens, so both files fit the budget exactly.
        let context = builder.build_context(vec![first, second]).unwrap();
        assert_eq!(context.files.len(), 2);
        assert_eq!(context.total_tokens, 100);
    }

    #[test]
    fn test_build_context_stops_at_first_file_over_budget() {
        let builder = ContextBuilder::new(60);
        let content = "x".repeat(200);
        let files = vec![
            file("src/file1.rs", None, Some(content.as_str())),
            file("src/file2.rs", None, Some(content.as_str())),
        ];

        let context = builder.build_context(files).unwrap();
        assert_eq!(context.files.len(), 1);
        assert_eq!(context.files[0].path, PathBuf::from("src/file1.rs"));
        assert_eq!(context.total_tokens, 50);
    }

    #[test]
    fn test_build_context_always_includes_first_file() {
        let builder = ContextBuilder::new(10);
        let content = "x".repeat(200);
        let files = vec![file("src/big.rs", None, Some(content.as_str()))];

        let context = builder.build_context(files).unwrap();
        assert_eq!(context.files.len(), 1);
        assert_eq!(context.total_tokens, 50);
    }

    #[test]
    fn test_build_context_token_estimate_bounds() {
        let builder = ContextBuilder::new(4096);
        let files = vec![
            file("src/tiny.rs", None, Some("ab")),
            file("src/none.rs", None, None),
        ];

        // Tiny content rounds up to one token; absent content costs nothing.
        let context = builder.build_context(files).unwrap();
        assert_eq!(context.files.len(), 2);
        assert_eq!(context.total_tokens, 1);
    }

    #[test]
    fn test_build_context_without_index_has_no_symbols() {
        let builder = ContextBuilder::new(4096);
        let files = vec![file("src/main.rs", None, Some("fn main() {}"))];

        let context = builder.build_context(files).unwrap();
        assert!(context.symbols.is_empty());
        assert!(context.references.is_empty());
    }

    #[test]
    fn test_calculate_file_relevance_path_match() {
        let builder = ContextBuilder::new(4096);
        let f = file("src/utils.rs", None, None);
        assert_eq!(builder.calculate_file_relevance(&f, "utils"), 0.5);
    }

    #[test]
    fn test_calculate_file_relevance_content_match() {
        let builder = ContextBuilder::new(4096);
        let f = file("src/main.rs", None, Some("fn helper_function() {}"));
        assert_eq!(builder.calculate_file_relevance(&f, "helper"), 0.1);
    }

    #[test]
    fn test_calculate_file_relevance_summary_match() {
        let builder = ContextBuilder::new(4096);
        let f = file("src/main.rs", Some("Utility functions for parsing"), None);
        assert_eq!(builder.calculate_file_relevance(&f, "parsing"), 0.3);
    }

    #[test]
    fn test_calculate_file_relevance_no_match() {
        let builder = ContextBuilder::new(4096);
        let f = file("src/main.rs", None, Some("fn main() {}"));
        assert_eq!(builder.calculate_file_relevance(&f, "nonexistent"), 0.0);
    }

    #[test]
    fn test_calculate_file_relevance_is_case_insensitive() {
        let builder = ContextBuilder::new(4096);
        let f = file("src/main.rs", Some("Main entry point"), Some("fn main() {}"));
        let upper = builder.calculate_file_relevance(&f, "MAIN");
        let lower = builder.calculate_file_relevance(&f, "main");
        assert!(upper > 0.0);
        assert_eq!(upper, lower);
    }

    #[test]
    fn test_calculate_file_relevance_is_capped_at_one() {
        let builder = ContextBuilder::new(4096);
        // Path 0.5 + three content words 0.3 + summary 0.3 would exceed 1.0.
        let f = file(
            "src/main entry point.rs",
            Some("Main entry point"),
            Some("main entry point"),
        );
        assert_eq!(builder.calculate_file_relevance(&f, "main entry point"), 1.0);
    }
}