Skip to main content

libgrite_core/context/extractor/
mod.rs

1mod regex_fallback;
2mod ts_engine;
3
4use std::path::Path;
5
6use crate::types::event::SymbolInfo;
7
/// Detect the programming language of a file from its extension.
///
/// Returns the language identifier used by [`extract_symbols`] and
/// [`generate_summary`] (e.g. `"rust"`, `"typescriptreact"`, `"cpp"`), or
/// `"unknown"` when the path has no extension, the extension is not valid
/// UTF-8, or it is not recognised.
///
/// Matching is case-sensitive: `main.RS` yields `"unknown"`.
pub fn detect_language(path: &str) -> &'static str {
    // A missing or non-UTF-8 extension becomes "", which falls through to "unknown".
    let ext = Path::new(path)
        .extension()
        .and_then(|e| e.to_str())
        .unwrap_or_default();

    match ext {
        "rs" => "rust",
        // `.pyi` type stubs use ordinary Python syntax.
        "py" | "pyi" => "python",
        // `.mts` / `.cts` are the ES-module / CommonJS flavours of TypeScript.
        "ts" | "mts" | "cts" => "typescript",
        "tsx" => "typescriptreact",
        // `.mjs` / `.cjs` are the ES-module / CommonJS flavours of JavaScript.
        "js" | "jsx" | "mjs" | "cjs" => "javascript",
        "go" => "go",
        "java" => "java",
        // `.h` is ambiguous between C and C++; it is treated as C.
        "c" | "h" => "c",
        "cpp" | "cc" | "cxx" | "hpp" | "hh" | "hxx" => "cpp",
        "rb" => "ruby",
        "ex" | "exs" => "elixir",
        _ => "unknown",
    }
}
26
27/// Extract symbols from source code using tree-sitter (with regex fallback)
28pub fn extract_symbols(content: &str, language: &str) -> Vec<SymbolInfo> {
29    match ts_engine::extract(content, language) {
30        Some(symbols) => symbols,
31        None => regex_fallback::extract(content, language),
32    }
33}
34
35/// Generate a short summary of a file based on its symbols
36pub fn generate_summary(path: &str, symbols: &[SymbolInfo], language: &str) -> String {
37    let display_language = match language {
38        "typescriptreact" => "typescript",
39        other => other,
40    };
41
42    if symbols.is_empty() {
43        return format!("{} file", display_language);
44    }
45
46    let structs: Vec<&str> = symbols
47        .iter()
48        .filter(|s| s.kind == "struct" || s.kind == "class" || s.kind == "interface")
49        .map(|s| s.name.as_str())
50        .collect();
51
52    let functions: Vec<&str> = symbols
53        .iter()
54        .filter(|s| s.kind == "function" || s.kind == "method")
55        .map(|s| s.name.as_str())
56        .collect();
57
58    let mut parts = Vec::new();
59    if !structs.is_empty() {
60        let names: String = structs
61            .iter()
62            .take(3)
63            .copied()
64            .collect::<Vec<_>>()
65            .join(", ");
66        if structs.len() > 3 {
67            parts.push(format!("defines {} (+{} more)", names, structs.len() - 3));
68        } else {
69            parts.push(format!("defines {}", names));
70        }
71    }
72    if !functions.is_empty() {
73        parts.push(format!("{} functions", functions.len()));
74    }
75
76    if parts.is_empty() {
77        format!(
78            "{} ({})",
79            Path::new(path)
80                .file_name()
81                .unwrap_or_default()
82                .to_string_lossy(),
83            display_language
84        )
85    } else {
86        parts.join("; ")
87    }
88}
89
#[cfg(test)]
mod tests {
    //! Unit tests for language detection, symbol extraction and summaries.

    use super::*;

    /// Build a `SymbolInfo` for tests where the line range is irrelevant.
    fn symbol(name: &str, kind: &str) -> SymbolInfo {
        SymbolInfo {
            name: name.to_string(),
            kind: kind.to_string(),
            line_start: 1,
            line_end: 1,
        }
    }

    #[test]
    fn test_detect_language() {
        assert_eq!(detect_language("src/main.rs"), "rust");
        assert_eq!(detect_language("app.py"), "python");
        assert_eq!(detect_language("index.ts"), "typescript");
        assert_eq!(detect_language("component.tsx"), "typescriptreact");
        assert_eq!(detect_language("main.go"), "go");
        assert_eq!(detect_language("Main.java"), "java");
        assert_eq!(detect_language("main.c"), "c");
        assert_eq!(detect_language("main.cpp"), "cpp");
        assert_eq!(detect_language("app.rb"), "ruby");
        assert_eq!(detect_language("lib.ex"), "elixir");
        assert_eq!(detect_language("README.md"), "unknown");
    }

    #[test]
    fn test_detect_language_aliases() {
        // Alternative extensions must map to the same identifier as the primary one.
        assert_eq!(detect_language("index.js"), "javascript");
        assert_eq!(detect_language("App.jsx"), "javascript");
        assert_eq!(detect_language("point.h"), "c");
        assert_eq!(detect_language("vec.hpp"), "cpp");
        assert_eq!(detect_language("vec.cc"), "cpp");
        assert_eq!(detect_language("vec.cxx"), "cpp");
        assert_eq!(detect_language("mix.exs"), "elixir");
    }

    #[test]
    fn test_detect_language_without_extension() {
        assert_eq!(detect_language("Makefile"), "unknown");
        // A leading dot is part of the file name, not an extension.
        assert_eq!(detect_language(".gitignore"), "unknown");
        assert_eq!(detect_language(""), "unknown");
    }

    #[test]
    fn test_extract_rust_symbols() {
        let content = r#"
pub struct Config {
    pub name: String,
}

pub enum State {
    Open,
    Closed,
}

pub trait Handler {
    fn handle(&self);
}

impl Config {
    pub fn new(name: String) -> Self {
        Self { name }
    }

    pub async fn load() -> Self {
        todo!()
    }
}

pub const MAX_SIZE: usize = 100;
"#;

        let symbols = extract_symbols(content, "rust");
        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();

        assert!(names.contains(&"Config"));
        assert!(names.contains(&"State"));
        assert!(names.contains(&"Handler"));
        assert!(names.contains(&"new"));
        assert!(names.contains(&"load"));
        assert!(names.contains(&"MAX_SIZE"));
    }

    #[test]
    fn test_extract_python_symbols() {
        let content = r#"
class MyClass:
    pass

def my_function():
    pass

async def async_func():
    pass
"#;

        let symbols = extract_symbols(content, "python");
        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();

        assert!(names.contains(&"MyClass"));
        assert!(names.contains(&"my_function"));
        assert!(names.contains(&"async_func"));
    }

    #[test]
    fn test_extract_go_symbols() {
        let content = r#"
func main() {
}

func (s *Server) Start() error {
    return nil
}

type Config struct {
    Name string
}

type Handler interface {
    Handle()
}
"#;

        let symbols = extract_symbols(content, "go");
        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();

        assert!(names.contains(&"main"));
        assert!(names.contains(&"Start"));
        assert!(names.contains(&"Config"));
        assert!(names.contains(&"Handler"));
    }

    #[test]
    fn test_extract_typescript_symbols() {
        let content = r#"
export function greet(name: string): string {
    return `Hello, ${name}!`;
}

export class UserService {
    constructor() {}
}

export interface Config {
    name: string;
}

type UserId = string;

const fetchData = async (url: string) => {
    return fetch(url);
};
"#;

        let symbols = extract_symbols(content, "typescript");
        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();

        assert!(names.contains(&"greet"));
        assert!(names.contains(&"UserService"));
        assert!(names.contains(&"Config"));
        assert!(names.contains(&"UserId"));
        assert!(names.contains(&"fetchData"));
    }

    #[test]
    fn test_extract_java_symbols() {
        let content = r#"
public class UserService {
    private String name;

    public UserService(String name) {
        this.name = name;
    }

    public String getName() {
        return name;
    }
}

public interface Repository {
    void save(Object entity);
}

public enum Status {
    ACTIVE,
    INACTIVE
}
"#;

        let symbols = extract_symbols(content, "java");
        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();

        assert!(names.contains(&"UserService"));
        assert!(names.contains(&"getName"));
        assert!(names.contains(&"Repository"));
        assert!(names.contains(&"Status"));
    }

    #[test]
    fn test_extract_c_symbols() {
        let content = r#"
struct Point {
    int x;
    int y;
};

enum Color {
    RED,
    GREEN,
    BLUE
};

typedef unsigned long ulong;

int main(int argc, char** argv) {
    return 0;
}
"#;

        let symbols = extract_symbols(content, "c");
        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();

        assert!(names.contains(&"Point"));
        assert!(names.contains(&"Color"));
        assert!(names.contains(&"main"));
    }

    #[test]
    fn test_extract_ruby_symbols() {
        let content = r#"
module Authentication
  class User
    def initialize(name)
      @name = name
    end

    def self.find(id)
      new("user_#{id}")
    end

    def greet
      "Hello, #{@name}"
    end
  end
end
"#;

        let symbols = extract_symbols(content, "ruby");
        let names: Vec<&str> = symbols.iter().map(|s| s.name.as_str()).collect();

        assert!(names.contains(&"Authentication"));
        assert!(names.contains(&"User"));
        assert!(names.contains(&"initialize"));
        assert!(names.contains(&"greet"));
    }

    #[test]
    fn test_generate_summary() {
        let symbols = vec![
            SymbolInfo {
                name: "Config".to_string(),
                kind: "struct".to_string(),
                line_start: 1,
                line_end: 10,
            },
            SymbolInfo {
                name: "new".to_string(),
                kind: "function".to_string(),
                line_start: 12,
                line_end: 20,
            },
            SymbolInfo {
                name: "load".to_string(),
                kind: "function".to_string(),
                line_start: 22,
                line_end: 30,
            },
        ];

        let summary = generate_summary("src/config.rs", &symbols, "rust");
        assert!(summary.contains("Config"));
        assert!(summary.contains("2 functions"));
    }

    #[test]
    fn test_generate_summary_tsx() {
        let symbols = vec![SymbolInfo {
            name: "App".to_string(),
            kind: "function".to_string(),
            line_start: 1,
            line_end: 10,
        }];

        let summary = generate_summary("src/App.tsx", &symbols, "typescriptreact");
        assert!(summary.contains("1 functions"));

        // The language label only appears when there is no type/function
        // information; check that `typescriptreact` is shown as `typescript`.
        assert_eq!(
            generate_summary("src/App.tsx", &[], "typescriptreact"),
            "typescript file"
        );
        assert_eq!(
            generate_summary("src/App.tsx", &[symbol("Props", "type")], "typescriptreact"),
            "App.tsx (typescript)"
        );
    }

    #[test]
    fn test_generate_summary_truncates_type_list() {
        let symbols: Vec<SymbolInfo> = ["A", "B", "C", "D", "E"]
            .iter()
            .map(|n| symbol(n, "class"))
            .collect();

        assert_eq!(
            generate_summary("lib.py", &symbols, "python"),
            "defines A, B, C (+2 more)"
        );
        // Exactly three names: no "(+0 more)" suffix.
        assert_eq!(generate_summary("lib.py", &symbols[..3], "python"), "defines A, B, C");
    }

    #[test]
    fn test_generate_summary_falls_back_to_file_name() {
        // Symbols exist, but none is a type-like or callable kind.
        let symbols = vec![symbol("State", "enum"), symbol("MAX_SIZE", "const")];
        assert_eq!(
            generate_summary("src/state.rs", &symbols, "rust"),
            "state.rs (rust)"
        );
        assert_eq!(generate_summary("src/empty.rs", &[], "rust"), "rust file");
    }

    #[test]
    fn test_fallback_for_unknown_language() {
        let symbols = extract_symbols("fn main() {}", "brainfuck");
        assert!(symbols.is_empty());
    }

    #[test]
    fn test_rust_accurate_line_ranges() {
        let content = r#"pub struct Config {
    pub name: String,
    pub value: u32,
}

pub fn process(config: &Config) -> String {
    format!("{}: {}", config.name, config.value)
}
"#;

        let symbols = extract_symbols(content, "rust");

        let config = symbols
            .iter()
            .find(|s| s.name == "Config" && s.kind == "struct")
            .unwrap();
        assert_eq!(config.line_start, 1);
        assert_eq!(config.line_end, 4);

        let process = symbols.iter().find(|s| s.name == "process").unwrap();
        assert_eq!(process.line_start, 6);
        assert_eq!(process.line_end, 8);
    }

    #[test]
    fn test_python_accurate_line_ranges() {
        let content = r#"class MyClass:
    def __init__(self):
        self.x = 0

    def method(self):
        return self.x

def standalone():
    pass
"#;

        let symbols = extract_symbols(content, "python");

        let class = symbols.iter().find(|s| s.name == "MyClass").unwrap();
        assert_eq!(class.line_start, 1);
        assert_eq!(class.line_end, 6);

        let standalone = symbols.iter().find(|s| s.name == "standalone").unwrap();
        assert_eq!(standalone.line_start, 8);
        assert_eq!(standalone.line_end, 9);
    }
}