ricecoder_research/
context_provider.rs1use crate::context_builder::ContextBuilder;
4use crate::context_optimizer::ContextOptimizer;
5use crate::models::{CodeContext, FileContext};
6use crate::relevance_scorer::RelevanceScorer;
7use crate::ResearchError;
8
9#[derive(Debug, Clone)]
11pub struct ContextProvider {
12 context_builder: ContextBuilder,
14 context_optimizer: ContextOptimizer,
16 relevance_scorer: RelevanceScorer,
18}
19
20impl ContextProvider {
21 pub fn new(max_tokens: usize, max_tokens_per_file: usize) -> Self {
23 ContextProvider {
24 context_builder: ContextBuilder::new(max_tokens),
25 context_optimizer: ContextOptimizer::new(max_tokens_per_file),
26 relevance_scorer: RelevanceScorer::new(),
27 }
28 }
29
30 pub fn provide_context_for_generation(
32 &self,
33 query: &str,
34 available_files: Vec<FileContext>,
35 ) -> Result<CodeContext, ResearchError> {
36 let relevant_files = self
38 .context_builder
39 .select_relevant_files(query, available_files)?;
40
41 let optimized_files = self.context_optimizer.optimize_files(relevant_files)?;
43
44 self.context_builder.build_context(optimized_files)
46 }
47
48 pub fn provide_context_for_review(
50 &self,
51 query: &str,
52 available_files: Vec<FileContext>,
53 ) -> Result<CodeContext, ResearchError> {
54 let mut builder = self.context_builder.clone();
56 builder.set_max_tokens(builder.max_tokens() * 2); let relevant_files = builder.select_relevant_files(query, available_files)?;
60
61 let optimized_files = self.context_optimizer.optimize_files(relevant_files)?;
63
64 builder.build_context(optimized_files)
66 }
67
68 pub fn provide_context_for_refactoring(
70 &self,
71 query: &str,
72 available_files: Vec<FileContext>,
73 ) -> Result<CodeContext, ResearchError> {
74 let relevant_files = self
76 .context_builder
77 .select_relevant_files(query, available_files)?;
78
79 let optimized_files = self.context_optimizer.optimize_files(relevant_files)?;
81
82 self.context_builder.build_context(optimized_files)
84 }
85
86 pub fn provide_context_for_documentation(
88 &self,
89 query: &str,
90 available_files: Vec<FileContext>,
91 ) -> Result<CodeContext, ResearchError> {
92 let mut scored_files: Vec<(FileContext, f32)> = available_files
94 .into_iter()
95 .map(|file| {
96 let mut score = self.relevance_scorer.score_file(&file, query);
97 if file.summary.is_some() {
99 score += 0.2;
100 }
101 (file, score)
102 })
103 .collect();
104
105 scored_files.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
107
108 let mut files: Vec<FileContext> = scored_files
110 .into_iter()
111 .map(|(mut file, score)| {
112 file.relevance = score;
113 file
114 })
115 .collect();
116
117 files.retain(|f| f.relevance > 0.0);
119
120 let optimized_files = self.context_optimizer.optimize_files(files)?;
122
123 self.context_builder.build_context(optimized_files)
125 }
126
127 pub fn context_builder(&self) -> &ContextBuilder {
129 &self.context_builder
130 }
131
132 pub fn context_optimizer(&self) -> &ContextOptimizer {
134 &self.context_optimizer
135 }
136
137 pub fn relevance_scorer(&self) -> &RelevanceScorer {
139 &self.relevance_scorer
140 }
141
142 pub fn set_max_tokens(&mut self, max_tokens: usize) {
144 self.context_builder.set_max_tokens(max_tokens);
145 }
146
147 pub fn set_max_tokens_per_file(&mut self, max_tokens: usize) {
149 self.context_optimizer.set_max_tokens_per_file(max_tokens);
150 }
151}
152
153impl Default for ContextProvider {
154 fn default() -> Self {
155 ContextProvider::new(4096, 2048) }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Provider with the budgets every test below starts from.
    fn provider() -> ContextProvider {
        ContextProvider::new(4096, 2048)
    }

    /// Builds a zero-relevance `FileContext` fixture.
    fn fixture(path: &str, summary: Option<&str>, content: &str) -> FileContext {
        FileContext {
            path: PathBuf::from(path),
            relevance: 0.0,
            summary: summary.map(String::from),
            content: Some(content.to_string()),
        }
    }

    /// The single `src/main.rs` fixture shared by several tests.
    fn main_rs() -> FileContext {
        fixture("src/main.rs", Some("Main entry point"), "fn main() {}")
    }

    #[test]
    fn test_context_provider_creation() {
        let p = provider();
        assert_eq!(p.context_builder().max_tokens(), 4096);
        assert_eq!(p.context_optimizer().max_tokens_per_file(), 2048);
    }

    #[test]
    fn test_context_provider_default() {
        let p = ContextProvider::default();
        assert_eq!(p.context_builder().max_tokens(), 4096);
        assert_eq!(p.context_optimizer().max_tokens_per_file(), 2048);
    }

    #[test]
    fn test_provide_context_for_generation() {
        let inputs = vec![
            main_rs(),
            fixture("src/lib.rs", Some("Library module"), "pub fn helper() {}"),
        ];

        let outcome = provider().provide_context_for_generation("main", inputs);
        assert!(outcome.is_ok());
        assert!(!outcome.unwrap().files.is_empty());
    }

    #[test]
    fn test_provide_context_for_review() {
        let outcome = provider().provide_context_for_review("main", vec![main_rs()]);
        assert!(outcome.is_ok());
        assert!(!outcome.unwrap().files.is_empty());
    }

    #[test]
    fn test_provide_context_for_refactoring() {
        let outcome = provider().provide_context_for_refactoring("main", vec![main_rs()]);
        assert!(outcome.is_ok());
        assert!(!outcome.unwrap().files.is_empty());
    }

    #[test]
    fn test_provide_context_for_documentation() {
        let inputs = vec![
            main_rs(),
            fixture("src/lib.rs", None, "pub fn helper() {}"),
        ];

        let outcome = provider().provide_context_for_documentation("main", inputs);
        assert!(outcome.is_ok());
        assert!(!outcome.unwrap().files.is_empty());
    }

    #[test]
    fn test_set_max_tokens() {
        let mut p = provider();
        p.set_max_tokens(8192);
        assert_eq!(p.context_builder().max_tokens(), 8192);
    }

    #[test]
    fn test_set_max_tokens_per_file() {
        let mut p = provider();
        p.set_max_tokens_per_file(4096);
        assert_eq!(p.context_optimizer().max_tokens_per_file(), 4096);
    }

    #[test]
    fn test_provide_context_empty_files() {
        let outcome = provider().provide_context_for_generation("test", Vec::new());
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().files.is_empty());
    }
}