reflex/semantic/
context.rs

1//! Codebase context extraction for semantic query generation
2//!
3//! This module extracts rich context about the indexed codebase to help LLMs
4//! generate better search queries. Context includes language distribution,
5//! directory structure, monorepo detection, and more.
6
7use crate::cache::CacheManager;
8use anyhow::{Context as AnyhowContext, Result};
9use rusqlite::Connection;
10use std::collections::HashMap;
11use std::path::Path;
12
13/// Comprehensive codebase context for LLM prompt injection
14#[derive(Debug, Clone)]
15pub struct CodebaseContext {
16    /// Total number of indexed files
17    pub total_files: usize,
18
19    /// Language distribution with counts and percentages
20    pub languages: Vec<LanguageInfo>,
21
22    /// Top-level directories (first path segment)
23    pub top_level_dirs: Vec<String>,
24
25    /// Common path patterns (depth 2-3) for framework-aware suggestions
26    pub common_paths: Vec<String>,
27
28    /// Whether this appears to be a monorepo
29    pub is_monorepo: bool,
30
31    /// Number of detected projects in monorepo (if applicable)
32    pub project_count: Option<usize>,
33
34    /// Dominant language (if any language is >60% of files)
35    pub dominant_language: Option<LanguageInfo>,
36}
37
/// File count and share of the codebase for a single language.
#[derive(Debug, Clone, PartialEq)]
pub struct LanguageInfo {
    /// Language name as recorded in the index.
    pub name: String,
    /// Number of indexed files written in this language.
    pub file_count: usize,
    /// Share of all indexed files, expressed on a 0-100 scale.
    pub percentage: f64,
}
45
46impl CodebaseContext {
47    /// Extract comprehensive context from cache
48    pub fn extract(cache: &CacheManager) -> Result<Self> {
49        let db_path = cache.path().join("meta.db");
50        let conn = Connection::open(&db_path)
51            .context("Failed to open database for context extraction")?;
52
53        // Get total file count
54        let total_files: usize = conn.query_row(
55            "SELECT COUNT(*) FROM files",
56            [],
57            |row| row.get(0),
58        ).unwrap_or(0);
59
60        // Extract language distribution
61        let languages = extract_language_distribution(&conn, total_files)?;
62
63        // Find dominant language (>60% of files)
64        let dominant_language = languages.iter()
65            .find(|lang| lang.percentage > 60.0)
66            .cloned();
67
68        // Extract file paths for directory analysis
69        let file_paths = extract_file_paths(&conn)?;
70
71        // Analyze directory structure
72        let top_level_dirs = extract_top_level_dirs(&file_paths);
73        let common_paths = extract_common_paths(&file_paths, 2, 10); // depth 2-3, top 10
74
75        // Detect monorepo
76        let (is_monorepo, project_count) = detect_monorepo(&file_paths);
77
78        Ok(Self {
79            total_files,
80            languages,
81            top_level_dirs,
82            common_paths,
83            is_monorepo,
84            project_count,
85            dominant_language,
86        })
87    }
88
89    /// Format context as a human-readable string for LLM prompt injection
90    pub fn to_prompt_string(&self) -> String {
91        let mut parts = Vec::new();
92
93        // Language distribution (Tier 1)
94        if !self.languages.is_empty() {
95            let lang_summary: Vec<String> = self.languages.iter()
96                .map(|lang| {
97                    format!("{} ({} files, {:.0}%)",
98                            lang.name, lang.file_count, lang.percentage)
99                })
100                .collect();
101            parts.push(format!("**Languages:** {}", lang_summary.join(", ")));
102        }
103
104        // File scale indicator (Tier 1)
105        let scale_hint = if self.total_files < 100 {
106            "small codebase - broad queries work well"
107        } else if self.total_files < 1000 {
108            "medium codebase - moderate specificity recommended"
109        } else {
110            "large codebase - use specific filters for best results"
111        };
112        parts.push(format!("**Total files:** {} ({})", self.total_files, scale_hint));
113
114        // Top-level directories (Tier 1)
115        if !self.top_level_dirs.is_empty() {
116            parts.push(format!("**Top-level directories:** {}",
117                             self.top_level_dirs.join(", ")));
118        }
119
120        // Dominant language (Tier 2)
121        if let Some(ref dominant) = self.dominant_language {
122            parts.push(format!("**Primary language:** {} ({:.0}% of codebase)",
123                             dominant.name, dominant.percentage));
124        }
125
126        // Common paths (Tier 2)
127        if !self.common_paths.is_empty() {
128            let paths_str = self.common_paths.iter()
129                .take(8) // Limit to 8 most common
130                .map(|p| p.as_str())
131                .collect::<Vec<_>>()
132                .join(", ");
133            parts.push(format!("**Common paths:** {}", paths_str));
134        }
135
136        // Monorepo info (Tier 2)
137        if self.is_monorepo {
138            if let Some(count) = self.project_count {
139                parts.push(format!("**Monorepo:** Yes ({} projects detected - use --file to target specific projects)", count));
140            } else {
141                parts.push("**Monorepo:** Yes (use --file to target specific projects)".to_string());
142            }
143        }
144
145        parts.join("\n")
146    }
147}
148
149/// Extract language distribution with counts and percentages
150fn extract_language_distribution(conn: &Connection, total_files: usize) -> Result<Vec<LanguageInfo>> {
151    let mut stmt = conn.prepare(
152        "SELECT language, COUNT(*) as count
153         FROM files
154         WHERE language IS NOT NULL
155         GROUP BY language
156         ORDER BY count DESC"
157    )?;
158
159    let languages = stmt.query_map([], |row| {
160        let name: String = row.get(0)?;
161        let file_count: usize = row.get(1)?;
162        let percentage = if total_files > 0 {
163            (file_count as f64 / total_files as f64) * 100.0
164        } else {
165            0.0
166        };
167
168        Ok(LanguageInfo {
169            name,
170            file_count,
171            percentage,
172        })
173    })?
174    .collect::<Result<Vec<_>, _>>()?;
175
176    Ok(languages)
177}
178
179/// Extract all file paths from database
180fn extract_file_paths(conn: &Connection) -> Result<Vec<String>> {
181    let mut stmt = conn.prepare("SELECT path FROM files")?;
182    let paths = stmt.query_map([], |row| row.get(0))?
183        .collect::<Result<Vec<_>, _>>()?;
184    Ok(paths)
185}
186
/// Extract the most populated top-level directories (first path segment).
///
/// Returns at most 10 directory names, each with a trailing slash, ordered by
/// the number of files beneath them (descending). Ties are broken
/// alphabetically so the output is deterministic.
///
/// Files that sit directly at the root (paths without a `/`) are not
/// directories and are ignored, as are hidden directories (leading `.`) and
/// empty first segments.
fn extract_top_level_dirs(paths: &[String]) -> Vec<String> {
    const MAX_DIRS: usize = 10;

    let mut dir_counts: HashMap<&str, usize> = HashMap::new();

    for path in paths {
        // Only a path containing a separator has a directory as its first segment
        if let Some((first_segment, _)) = path.split_once('/') {
            if !first_segment.is_empty() && !first_segment.starts_with('.') {
                *dir_counts.entry(first_segment).or_insert(0) += 1;
            }
        }
    }

    // Highest count first; the name is the tie-breaker because HashMap
    // iteration order is unspecified
    let mut dirs: Vec<(&str, usize)> = dir_counts.into_iter().collect();
    dirs.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));

    dirs.into_iter()
        .take(MAX_DIRS)
        .map(|(dir, _)| format!("{}/", dir))
        .collect()
}
209
/// Extract frequently occurring directory prefixes.
///
/// For every path, the directory prefixes at depths `min_depth..=3` are
/// counted, where a prefix at depth `d` consists of the first `d + 1` path
/// segments (depth 1 is `app/models`). Single-segment prefixes are never
/// reported. The final segment of a path is its file name and is never part of
/// a prefix.
///
/// Prefixes are dropped when any of their segments is hidden (leading `.`) or
/// is exactly `node_modules`, `vendor` or `target`, including when that
/// segment is the first one. Only prefixes seen at least 3 times are kept.
///
/// Returns at most `max_results` prefixes with a trailing slash, ordered by
/// count (descending) and then alphabetically so the output is deterministic.
///
/// NOTE(review): depth is zero-based, so `min_depth = 2` starts at
/// three-segment prefixes such as `a/b/c/` - confirm this matches what callers
/// mean by "depth 2".
fn extract_common_paths(paths: &[String], min_depth: usize, max_results: usize) -> Vec<String> {
    const MAX_DEPTH: usize = 3;
    const MIN_OCCURRENCES: usize = 3;
    const NOISE_DIRS: [&str; 3] = ["node_modules", "vendor", "target"];

    let mut path_counts: HashMap<String, usize> = HashMap::new();

    for path in paths {
        let segments: Vec<&str> = path.split('/').collect();
        // `split` always yields at least one item; the last one is the file name
        let dir_segments = &segments[..segments.len() - 1];

        // Depth 0 would be a lone segment without any directory structure
        for depth in min_depth.max(1)..=MAX_DEPTH {
            if dir_segments.len() <= depth {
                break; // deeper prefixes cannot exist for this path either
            }
            let prefix = &dir_segments[..=depth];

            let is_noise = prefix
                .iter()
                .any(|segment| segment.starts_with('.') || NOISE_DIRS.contains(segment));
            if is_noise {
                break; // every deeper prefix contains the same noisy segment
            }

            *path_counts.entry(prefix.join("/")).or_insert(0) += 1;
        }
    }

    // Keep only prefixes that occur often enough to be signal rather than noise
    let mut common_paths: Vec<(String, usize)> = path_counts
        .into_iter()
        .filter(|(_, count)| *count >= MIN_OCCURRENCES)
        .collect();

    // Highest count first; the path is the tie-breaker because HashMap
    // iteration order is unspecified
    common_paths.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));

    common_paths
        .into_iter()
        .take(max_results)
        .map(|(path, _)| format!("{}/", path))
        .collect()
}
253
/// Detect a monorepo by counting package-manager manifests in nested directories.
///
/// A path counts as a project when its file name equals one of the known
/// manifest names (compared ASCII case-insensitively) and the path has more
/// than two components, e.g. `packages/foo/package.json`. Manifests at the
/// root or only one directory deep are not counted.
///
/// Returns `(true, Some(count))` when at least two such manifests exist and
/// `(false, None)` otherwise.
fn detect_monorepo(paths: &[String]) -> (bool, Option<usize>) {
    const PACKAGE_FILES: [&str; 7] = [
        "package.json",
        "Cargo.toml",
        "go.mod",
        "composer.json",
        "pom.xml",
        "build.gradle",
        "Gemfile",
    ];

    let project_count = paths
        .iter()
        .filter(|path| {
            let path = Path::new(path.as_str());

            // Match the whole file name, not a suffix, so `my-package.json`
            // is not mistaken for a manifest. The comparison ignores case on
            // both sides; lower-casing only the path made `Cargo.toml` and
            // `Gemfile` impossible to match.
            let is_manifest = path
                .file_name()
                .and_then(|name| name.to_str())
                .map_or(false, |name| {
                    PACKAGE_FILES
                        .iter()
                        .any(|pkg_file| name.eq_ignore_ascii_case(pkg_file))
                });

            // Skip manifests that are not nested deeply enough to indicate a sub-project
            is_manifest && path.components().count() > 2
        })
        .count();

    let is_monorepo = project_count >= 2;
    let project_count_opt = if is_monorepo { Some(project_count) } else { None };

    (is_monorepo, project_count_opt)
}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Turn string literals into the owned paths the helpers expect.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn test_extract_top_level_dirs() {
        let dirs = extract_top_level_dirs(&owned(&[
            "src/main.rs",
            "src/lib.rs",
            "app/models/user.rb",
            "app/controllers/home.rb",
            "tests/test.rs",
        ]));

        assert_eq!(dirs.len(), 3);
        for expected in ["src/", "app/", "tests/"] {
            assert!(dirs.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_extract_common_paths() {
        let common = extract_common_paths(
            &owned(&[
                "app/models/user.rb",
                "app/models/post.rb",
                "app/models/comment.rb",
                "app/models/article.rb",
                "app/controllers/home.rb",
                "app/controllers/posts.rb",
                "app/controllers/articles.rb",
                "app/controllers/users.rb",
                "src/main.rs",
            ]),
            1,
            10,
        );

        assert!(common.contains(&"app/models/".to_string()));
        assert!(common.contains(&"app/controllers/".to_string()));
    }

    #[test]
    fn test_detect_monorepo() {
        // Three nested manifests are enough to be classified as a monorepo
        let (is_monorepo, count) = detect_monorepo(&owned(&[
            "packages/web/package.json",
            "packages/api/package.json",
            "packages/shared/package.json",
        ]));
        assert!(is_monorepo);
        assert_eq!(count, Some(3));

        // A lone root-level manifest is an ordinary single project
        let (is_mono, _) = detect_monorepo(&owned(&["package.json", "src/main.ts"]));
        assert!(!is_mono);
    }

    #[test]
    fn test_language_percentage() {
        let lang = LanguageInfo {
            name: String::from("Rust"),
            file_count: 64,
            percentage: 64.0,
        };

        assert_eq!(lang.percentage, 64.0);
    }
}