ricecoder_research/
context_provider.rs

1//! Automatic context provision for AI providers
2
3use crate::context_builder::ContextBuilder;
4use crate::context_optimizer::ContextOptimizer;
5use crate::models::{CodeContext, FileContext};
6use crate::relevance_scorer::RelevanceScorer;
7use crate::ResearchError;
8
9/// Provides context automatically to AI providers
10#[derive(Debug, Clone)]
11pub struct ContextProvider {
12    /// Context builder for selecting files
13    context_builder: ContextBuilder,
14    /// Context optimizer for fitting within token budgets
15    context_optimizer: ContextOptimizer,
16    /// Relevance scorer for ranking files
17    relevance_scorer: RelevanceScorer,
18}
19
20impl ContextProvider {
21    /// Create a new context provider
22    pub fn new(max_tokens: usize, max_tokens_per_file: usize) -> Self {
23        ContextProvider {
24            context_builder: ContextBuilder::new(max_tokens),
25            context_optimizer: ContextOptimizer::new(max_tokens_per_file),
26            relevance_scorer: RelevanceScorer::new(),
27        }
28    }
29
30    /// Provide context for a code generation task
31    pub fn provide_context_for_generation(
32        &self,
33        query: &str,
34        available_files: Vec<FileContext>,
35    ) -> Result<CodeContext, ResearchError> {
36        // Select relevant files
37        let relevant_files = self
38            .context_builder
39            .select_relevant_files(query, available_files)?;
40
41        // Optimize files to fit within token budget
42        let optimized_files = self.context_optimizer.optimize_files(relevant_files)?;
43
44        // Build final context
45        self.context_builder.build_context(optimized_files)
46    }
47
48    /// Provide context for a code review task
49    pub fn provide_context_for_review(
50        &self,
51        query: &str,
52        available_files: Vec<FileContext>,
53    ) -> Result<CodeContext, ResearchError> {
54        // For review, we want to include more context
55        let mut builder = self.context_builder.clone();
56        builder.set_max_tokens(builder.max_tokens() * 2); // Double the budget for review
57
58        // Select relevant files
59        let relevant_files = builder.select_relevant_files(query, available_files)?;
60
61        // Optimize files
62        let optimized_files = self.context_optimizer.optimize_files(relevant_files)?;
63
64        // Build final context
65        builder.build_context(optimized_files)
66    }
67
68    /// Provide context for a refactoring task
69    pub fn provide_context_for_refactoring(
70        &self,
71        query: &str,
72        available_files: Vec<FileContext>,
73    ) -> Result<CodeContext, ResearchError> {
74        // For refactoring, include related files
75        let relevant_files = self
76            .context_builder
77            .select_relevant_files(query, available_files)?;
78
79        // Optimize files
80        let optimized_files = self.context_optimizer.optimize_files(relevant_files)?;
81
82        // Build final context
83        self.context_builder.build_context(optimized_files)
84    }
85
86    /// Provide context for a documentation task
87    pub fn provide_context_for_documentation(
88        &self,
89        query: &str,
90        available_files: Vec<FileContext>,
91    ) -> Result<CodeContext, ResearchError> {
92        // For documentation, prioritize files with summaries
93        let mut scored_files: Vec<(FileContext, f32)> = available_files
94            .into_iter()
95            .map(|file| {
96                let mut score = self.relevance_scorer.score_file(&file, query);
97                // Boost score if file has a summary
98                if file.summary.is_some() {
99                    score += 0.2;
100                }
101                (file, score)
102            })
103            .collect();
104
105        // Sort by score
106        scored_files.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
107
108        // Extract files
109        let mut files: Vec<FileContext> = scored_files
110            .into_iter()
111            .map(|(mut file, score)| {
112                file.relevance = score;
113                file
114            })
115            .collect();
116
117        // Filter out files with zero relevance
118        files.retain(|f| f.relevance > 0.0);
119
120        // Optimize files
121        let optimized_files = self.context_optimizer.optimize_files(files)?;
122
123        // Build final context
124        self.context_builder.build_context(optimized_files)
125    }
126
127    /// Get the context builder
128    pub fn context_builder(&self) -> &ContextBuilder {
129        &self.context_builder
130    }
131
132    /// Get the context optimizer
133    pub fn context_optimizer(&self) -> &ContextOptimizer {
134        &self.context_optimizer
135    }
136
137    /// Get the relevance scorer
138    pub fn relevance_scorer(&self) -> &RelevanceScorer {
139        &self.relevance_scorer
140    }
141
142    /// Set maximum tokens
143    pub fn set_max_tokens(&mut self, max_tokens: usize) {
144        self.context_builder.set_max_tokens(max_tokens);
145    }
146
147    /// Set maximum tokens per file
148    pub fn set_max_tokens_per_file(&mut self, max_tokens: usize) {
149        self.context_optimizer.set_max_tokens_per_file(max_tokens);
150    }
151}
152
153impl Default for ContextProvider {
154    fn default() -> Self {
155        ContextProvider::new(4096, 2048) // Default: 4K total, 2K per file
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Provider with the budgets used throughout these tests.
    fn provider() -> ContextProvider {
        ContextProvider::new(4096, 2048)
    }

    /// Builds an unscored file fixture with the given path, optional summary
    /// and content.
    fn file(path: &str, summary: Option<&str>, content: &str) -> FileContext {
        FileContext {
            path: PathBuf::from(path),
            relevance: 0.0,
            summary: summary.map(String::from),
            content: Some(String::from(content)),
        }
    }

    /// The `src/main.rs` fixture shared by most tests.
    fn main_rs() -> FileContext {
        file("src/main.rs", Some("Main entry point"), "fn main() {}")
    }

    /// Asserts that the provider call succeeded and produced at least one file.
    fn assert_has_files(result: Result<CodeContext, ResearchError>) {
        assert!(result.is_ok());

        let context = result.unwrap();
        assert!(!context.files.is_empty());
    }

    #[test]
    fn test_context_provider_creation() {
        let provider = provider();
        assert_eq!(provider.context_builder().max_tokens(), 4096);
        assert_eq!(provider.context_optimizer().max_tokens_per_file(), 2048);
    }

    #[test]
    fn test_context_provider_default() {
        let provider = ContextProvider::default();
        assert_eq!(provider.context_builder().max_tokens(), 4096);
        assert_eq!(provider.context_optimizer().max_tokens_per_file(), 2048);
    }

    #[test]
    fn test_provide_context_for_generation() {
        let files = vec![
            main_rs(),
            file("src/lib.rs", Some("Library module"), "pub fn helper() {}"),
        ];

        assert_has_files(provider().provide_context_for_generation("main", files));
    }

    #[test]
    fn test_provide_context_for_review() {
        let files = vec![main_rs()];

        assert_has_files(provider().provide_context_for_review("main", files));
    }

    #[test]
    fn test_provide_context_for_refactoring() {
        let files = vec![main_rs()];

        assert_has_files(provider().provide_context_for_refactoring("main", files));
    }

    #[test]
    fn test_provide_context_for_documentation() {
        // One file with a summary and one without.
        let files = vec![
            main_rs(),
            file("src/lib.rs", None, "pub fn helper() {}"),
        ];

        assert_has_files(provider().provide_context_for_documentation("main", files));
    }

    #[test]
    fn test_set_max_tokens() {
        let mut provider = provider();
        provider.set_max_tokens(8192);
        assert_eq!(provider.context_builder().max_tokens(), 8192);
    }

    #[test]
    fn test_set_max_tokens_per_file() {
        let mut provider = provider();
        provider.set_max_tokens_per_file(4096);
        assert_eq!(provider.context_optimizer().max_tokens_per_file(), 4096);
    }

    #[test]
    fn test_provide_context_empty_files() {
        let result = provider().provide_context_for_generation("test", vec![]);
        assert!(result.is_ok());

        let context = result.unwrap();
        assert!(context.files.is_empty());
    }
}