Skip to main content

reflex/semantic/
context.rs

1//! Codebase context extraction for semantic query generation
2//!
3//! This module extracts rich context about the indexed codebase to help LLMs
4//! generate better search queries. Context includes language distribution,
5//! directory structure, monorepo detection, and more.
6
7use crate::cache::CacheManager;
8use anyhow::{Context as AnyhowContext, Result};
9use rusqlite::Connection;
10use std::collections::HashMap;
11use std::path::Path;
12
13/// Comprehensive codebase context for LLM prompt injection
14#[derive(Debug, Clone)]
15pub struct CodebaseContext {
16    /// Total number of indexed files
17    pub total_files: usize,
18
19    /// Language distribution with counts and percentages
20    pub languages: Vec<LanguageInfo>,
21
22    /// Top-level directories (first path segment)
23    pub top_level_dirs: Vec<String>,
24
25    /// Common path patterns (depth 2-3) for framework-aware suggestions
26    pub common_paths: Vec<String>,
27
28    /// Whether this appears to be a monorepo
29    pub is_monorepo: bool,
30
31    /// Number of detected projects in monorepo (if applicable)
32    pub project_count: Option<usize>,
33
34    /// Dominant language (if any language is >60% of files)
35    pub dominant_language: Option<LanguageInfo>,
36}
37
/// File count and share of the codebase for a single language.
///
/// Derives `PartialEq` so values can be compared directly (for example with
/// `assert_eq!`). `Eq` is intentionally not derived because `percentage` is a float.
#[derive(Debug, Clone, PartialEq)]
pub struct LanguageInfo {
    /// Language name as stored in the index.
    pub name: String,
    /// Number of indexed files written in this language.
    pub file_count: usize,
    /// Share of all indexed files, on a 0-100 scale.
    pub percentage: f64,
}
45
46impl CodebaseContext {
47    /// Extract comprehensive context from cache
48    pub fn extract(cache: &CacheManager) -> Result<Self> {
49        let db_path = cache.path().join("meta.db");
50        let conn =
51            Connection::open(&db_path).context("Failed to open database for context extraction")?;
52
53        // Get total file count
54        let total_files: usize = conn
55            .query_row("SELECT COUNT(*) FROM files", [], |row| row.get(0))
56            .unwrap_or(0);
57
58        // Extract language distribution
59        let languages = extract_language_distribution(&conn, total_files)?;
60
61        // Find dominant language (>60% of files)
62        let dominant_language = languages
63            .iter()
64            .find(|lang| lang.percentage > 60.0)
65            .cloned();
66
67        // Extract file paths for directory analysis
68        let file_paths = extract_file_paths(&conn)?;
69
70        // Analyze directory structure
71        let top_level_dirs = extract_top_level_dirs(&file_paths);
72        let common_paths = extract_common_paths(&file_paths, 2, 10); // depth 2-3, top 10
73
74        // Detect monorepo
75        let (is_monorepo, project_count) = detect_monorepo(&file_paths);
76
77        Ok(Self {
78            total_files,
79            languages,
80            top_level_dirs,
81            common_paths,
82            is_monorepo,
83            project_count,
84            dominant_language,
85        })
86    }
87
88    /// Format context as a human-readable string for LLM prompt injection
89    pub fn to_prompt_string(&self) -> String {
90        let mut parts = Vec::new();
91
92        // Language distribution (Tier 1)
93        if !self.languages.is_empty() {
94            let lang_summary: Vec<String> = self
95                .languages
96                .iter()
97                .map(|lang| {
98                    format!(
99                        "{} ({} files, {:.0}%)",
100                        lang.name, lang.file_count, lang.percentage
101                    )
102                })
103                .collect();
104            parts.push(format!("**Languages:** {}", lang_summary.join(", ")));
105        }
106
107        // File scale indicator (Tier 1)
108        let scale_hint = if self.total_files < 100 {
109            "small codebase - broad queries work well"
110        } else if self.total_files < 1000 {
111            "medium codebase - moderate specificity recommended"
112        } else {
113            "large codebase - use specific filters for best results"
114        };
115        parts.push(format!(
116            "**Total files:** {} ({})",
117            self.total_files, scale_hint
118        ));
119
120        // Top-level directories (Tier 1)
121        if !self.top_level_dirs.is_empty() {
122            parts.push(format!(
123                "**Top-level directories:** {}",
124                self.top_level_dirs.join(", ")
125            ));
126        }
127
128        // Dominant language (Tier 2)
129        if let Some(ref dominant) = self.dominant_language {
130            parts.push(format!(
131                "**Primary language:** {} ({:.0}% of codebase)",
132                dominant.name, dominant.percentage
133            ));
134        }
135
136        // Common paths (Tier 2)
137        if !self.common_paths.is_empty() {
138            let paths_str = self
139                .common_paths
140                .iter()
141                .take(8) // Limit to 8 most common
142                .map(|p| p.as_str())
143                .collect::<Vec<_>>()
144                .join(", ");
145            parts.push(format!("**Common paths:** {}", paths_str));
146        }
147
148        // Monorepo info (Tier 2)
149        if self.is_monorepo {
150            if let Some(count) = self.project_count {
151                parts.push(format!("**Monorepo:** Yes ({} projects detected - use --file to target specific projects)", count));
152            } else {
153                parts
154                    .push("**Monorepo:** Yes (use --file to target specific projects)".to_string());
155            }
156        }
157
158        parts.join("\n")
159    }
160}
161
162/// Extract language distribution with counts and percentages
163fn extract_language_distribution(
164    conn: &Connection,
165    total_files: usize,
166) -> Result<Vec<LanguageInfo>> {
167    let mut stmt = conn.prepare(
168        "SELECT language, COUNT(*) as count
169         FROM files
170         WHERE language IS NOT NULL
171         GROUP BY language
172         ORDER BY count DESC",
173    )?;
174
175    let languages = stmt
176        .query_map([], |row| {
177            let name: String = row.get(0)?;
178            let file_count: usize = row.get(1)?;
179            let percentage = if total_files > 0 {
180                (file_count as f64 / total_files as f64) * 100.0
181            } else {
182                0.0
183            };
184
185            Ok(LanguageInfo {
186                name,
187                file_count,
188                percentage,
189            })
190        })?
191        .collect::<Result<Vec<_>, _>>()?;
192
193    Ok(languages)
194}
195
196/// Extract all file paths from database
197fn extract_file_paths(conn: &Connection) -> Result<Vec<String>> {
198    let mut stmt = conn.prepare("SELECT path FROM files")?;
199    let paths = stmt
200        .query_map([], |row| row.get(0))?
201        .collect::<Result<Vec<_>, _>>()?;
202    Ok(paths)
203}
204
/// Extract the most populated top-level directories (first path segment).
///
/// Paths are expected to use `/` as separator. Only paths that actually
/// contain a directory component are counted, so root-level files such as
/// `README.md` are never reported as directories. Hidden directories (names
/// starting with `.`) and empty first segments (absolute paths) are skipped.
///
/// Returns at most 10 directory names, each with a trailing `/`, ordered by
/// file count descending. Ties are broken alphabetically so the result is
/// deterministic regardless of hash-map iteration order.
fn extract_top_level_dirs(paths: &[String]) -> Vec<String> {
    let mut dir_counts: HashMap<&str, usize> = HashMap::new();

    for path in paths {
        // `split_once` yields nothing for a bare file name, which is exactly
        // the case that must not be mistaken for a directory.
        if let Some((first_segment, _rest)) = path.split_once('/') {
            if !first_segment.is_empty() && !first_segment.starts_with('.') {
                *dir_counts.entry(first_segment).or_insert(0) += 1;
            }
        }
    }

    // Most files first; alphabetical among equals for a stable ordering
    let mut dirs: Vec<(&str, usize)> = dir_counts.into_iter().collect();
    dirs.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));

    // Return top 10 directories with trailing slash
    dirs.into_iter()
        .take(10)
        .map(|(dir, _)| format!("{}/", dir))
        .collect()
}
227
/// Extract the most common directory prefixes of the given file paths.
///
/// Paths are expected to use `/` as separator. For every `depth` from
/// `min_depth` up to 3, the first `depth + 1` segments of a path form a
/// candidate prefix (so depth 1 yields `app/models`, depth 2 yields
/// `app/models/concerns`). The final segment of a path is its file name and is
/// never part of a prefix, so a file is never reported as a directory.
///
/// Single-segment prefixes are skipped, as are prefixes containing a hidden
/// directory (name starting with `.`) or a `node_modules`, `vendor` or `target`
/// directory at any level. Directory names are compared exactly, so look-alikes
/// such as `targets` are kept.
///
/// Only prefixes shared by at least 3 files are returned, at most
/// `max_results` of them, each with a trailing `/`. They are ordered by file
/// count descending with alphabetical tie-breaking, which keeps the result
/// deterministic.
fn extract_common_paths(paths: &[String], min_depth: usize, max_results: usize) -> Vec<String> {
    const MAX_DEPTH: usize = 3;
    // Prefixes seen fewer times than this are treated as noise
    const MIN_COUNT: usize = 3;
    const NOISE_DIRS: [&str; 3] = ["node_modules", "vendor", "target"];

    let mut path_counts: HashMap<String, usize> = HashMap::new();

    for path in paths {
        let segments: Vec<&str> = path.split('/').collect();
        // `split` always yields at least one item; the last one is the file name.
        let dir_segments = &segments[..segments.len() - 1];

        for depth in min_depth..=MAX_DEPTH {
            if dir_segments.len() <= depth {
                // Not enough directory levels; deeper prefixes cannot exist either
                break;
            }
            let prefix = &dir_segments[..=depth];

            // A lone segment carries no directory structure worth reporting
            if prefix.len() < 2 {
                continue;
            }
            // Skip hidden directories and common noise, wherever they appear
            let is_noise = prefix
                .iter()
                .any(|segment| segment.starts_with('.') || NOISE_DIRS.contains(segment));
            if is_noise {
                continue;
            }

            *path_counts.entry(prefix.join("/")).or_insert(0) += 1;
        }
    }

    let mut common_paths: Vec<(String, usize)> = path_counts
        .into_iter()
        .filter(|(_, count)| *count >= MIN_COUNT)
        .collect();

    // Most files first; alphabetical among equals for a stable ordering
    common_paths.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));

    // Return top paths with trailing slash
    common_paths
        .into_iter()
        .take(max_results)
        .map(|(path, _)| format!("{}/", path))
        .collect()
}
273
/// Detect whether the paths look like a monorepo by counting package manifests.
///
/// A path counts as a sub-project when its file name is one of the known
/// package-manager manifests (compared case-insensitively and against the whole
/// file name, so `my-package.json` does not count) and the path has more than
/// two components, e.g. `packages/foo/package.json`. Root-level manifests and
/// manifests only one directory deep are therefore ignored.
///
/// Returns `(true, Some(count))` when at least two such manifests are found,
/// otherwise `(false, None)`.
fn detect_monorepo(paths: &[String]) -> (bool, Option<usize>) {
    const PACKAGE_FILES: [&str; 7] = [
        "package.json",
        "Cargo.toml",
        "go.mod",
        "composer.json",
        "pom.xml",
        "build.gradle",
        "Gemfile",
    ];

    let project_count = paths
        .iter()
        .filter(|path| {
            let path = Path::new(path.as_str());
            // Compare the file name itself: a suffix test on the whole path
            // would also accept names like `not-package.json`.
            let is_manifest = path
                .file_name()
                .and_then(|name| name.to_str())
                .map_or(false, |name| {
                    PACKAGE_FILES
                        .iter()
                        .any(|pkg_file| name.eq_ignore_ascii_case(pkg_file))
                });
            // Only manifests nested in a sub-project directory hint at a monorepo
            is_manifest && path.components().count() > 2
        })
        .count();

    if project_count >= 2 {
        (true, Some(project_count))
    } else {
        (false, None)
    }
}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Build owned path strings from string literals.
    fn owned(paths: &[&str]) -> Vec<String> {
        paths.iter().map(|p| p.to_string()).collect()
    }

    #[test]
    fn test_extract_top_level_dirs() {
        let paths = owned(&[
            "src/main.rs",
            "src/lib.rs",
            "app/models/user.rb",
            "app/controllers/home.rb",
            "tests/test.rs",
        ]);

        let dirs = extract_top_level_dirs(&paths);
        assert_eq!(dirs.len(), 3);
        assert!(dirs.contains(&"src/".to_string()));
        assert!(dirs.contains(&"app/".to_string()));
        assert!(dirs.contains(&"tests/".to_string()));
    }

    #[test]
    fn test_extract_common_paths() {
        let paths = owned(&[
            "app/models/user.rb",
            "app/models/post.rb",
            "app/models/comment.rb",
            "app/models/article.rb",
            "app/controllers/home.rb",
            "app/controllers/posts.rb",
            "app/controllers/articles.rb",
            "app/controllers/users.rb",
            "src/main.rs",
        ]);

        let common = extract_common_paths(&paths, 1, 10);
        assert!(common.contains(&"app/models/".to_string()));
        assert!(common.contains(&"app/controllers/".to_string()));
    }

    #[test]
    fn test_detect_monorepo() {
        let monorepo_paths = owned(&[
            "packages/web/package.json",
            "packages/api/package.json",
            "packages/shared/package.json",
        ]);

        let (is_monorepo, count) = detect_monorepo(&monorepo_paths);
        assert!(is_monorepo);
        assert_eq!(count, Some(3));

        let single_project = owned(&["package.json", "src/main.ts"]);

        let (is_mono, count) = detect_monorepo(&single_project);
        assert!(!is_mono);
        assert_eq!(count, None);
    }

    /// Exercises the real percentage computation against an in-memory database
    /// instead of asserting a hand-written field value.
    #[test]
    fn test_language_percentage() {
        let conn = rusqlite::Connection::open_in_memory().expect("open in-memory database");
        conn.execute_batch(
            "CREATE TABLE files (path TEXT, language TEXT);
             INSERT INTO files (path, language) VALUES
                 ('src/a.rs', 'Rust'),
                 ('src/b.rs', 'Rust'),
                 ('src/c.rs', 'Rust'),
                 ('tools/x.py', 'Python');",
        )
        .expect("seed files table");

        let languages = extract_language_distribution(&conn, 4).expect("query distribution");
        assert_eq!(languages.len(), 2);

        // Ordered by file count, largest first
        assert_eq!(languages[0].name, "Rust");
        assert_eq!(languages[0].file_count, 3);
        assert!((languages[0].percentage - 75.0).abs() < 1e-9);

        assert_eq!(languages[1].name, "Python");
        assert_eq!(languages[1].file_count, 1);
        assert!((languages[1].percentage - 25.0).abs() < 1e-9);

        // A zero total must not divide by zero
        let zero_total = extract_language_distribution(&conn, 0).expect("query distribution");
        assert!(zero_total.iter().all(|lang| lang.percentage == 0.0));
    }

    #[test]
    fn test_to_prompt_string() {
        let rust = LanguageInfo {
            name: "Rust".to_string(),
            file_count: 64,
            percentage: 64.0,
        };
        let context = CodebaseContext {
            total_files: 100,
            languages: vec![rust.clone()],
            top_level_dirs: vec!["src/".to_string(), "tests/".to_string()],
            common_paths: Vec::new(),
            is_monorepo: true,
            project_count: Some(3),
            dominant_language: Some(rust),
        };

        let prompt = context.to_prompt_string();
        assert!(prompt.contains("**Languages:** Rust (64 files, 64%)"));
        assert!(prompt
            .contains("**Total files:** 100 (medium codebase - moderate specificity recommended)"));
        assert!(prompt.contains("**Top-level directories:** src/, tests/"));
        assert!(prompt.contains("**Primary language:** Rust (64% of codebase)"));
        assert!(prompt.contains("**Monorepo:** Yes (3 projects detected"));
        // Empty sections are omitted entirely
        assert!(!prompt.contains("**Common paths:**"));
    }
}