Skip to main content

libgrite_core/context/extractor/
mod.rs

1mod ts_engine;
2mod regex_fallback;
3
4use std::path::Path;
5
6use crate::types::event::SymbolInfo;
7
/// Detect the programming language of a file from its extension.
///
/// Only the final extension of `path` (as reported by [`Path::extension`]) is
/// inspected. The comparison is ASCII case-insensitive, so `MAIN.RS` and
/// `main.rs` both map to `"rust"`.
///
/// Returns a fixed language identifier such as `"rust"`, `"python"` or
/// `"typescriptreact"`, or `"unknown"` when the path has no extension or the
/// extension is not one of the recognised ones.
pub fn detect_language(path: &str) -> &'static str {
    // Normalise case up front: upper- or mixed-case extensions (`.PY`, `.Java`)
    // would otherwise fall through to "unknown".
    let extension = Path::new(path)
        .extension()
        .and_then(|ext| ext.to_str())
        .map(|ext| ext.to_ascii_lowercase());

    match extension.as_deref() {
        Some("rs") => "rust",
        Some("py") => "python",
        Some("ts") => "typescript",
        Some("tsx") => "typescriptreact",
        Some("js") => "javascript",
        Some("jsx") => "javascript",
        Some("go") => "go",
        Some("java") => "java",
        Some("c") | Some("h") => "c",
        Some("cpp") | Some("hpp") | Some("cc") | Some("cxx") => "cpp",
        Some("rb") => "ruby",
        Some("ex") | Some("exs") => "elixir",
        _ => "unknown",
    }
}
26
27/// Extract symbols from source code using tree-sitter (with regex fallback)
28pub fn extract_symbols(content: &str, language: &str) -> Vec<SymbolInfo> {
29    match ts_engine::extract(content, language) {
30        Some(symbols) => symbols,
31        None => regex_fallback::extract(content, language),
32    }
33}
34
35/// Generate a short summary of a file based on its symbols
36pub fn generate_summary(path: &str, symbols: &[SymbolInfo], language: &str) -> String {
37    let display_language = match language {
38        "typescriptreact" => "typescript",
39        other => other,
40    };
41
42    if symbols.is_empty() {
43        return format!("{} file", display_language);
44    }
45
46    let structs: Vec<&str> = symbols.iter()
47        .filter(|s| s.kind == "struct" || s.kind == "class" || s.kind == "interface")
48        .map(|s| s.name.as_str())
49        .collect();
50
51    let functions: Vec<&str> = symbols.iter()
52        .filter(|s| s.kind == "function" || s.kind == "method")
53        .map(|s| s.name.as_str())
54        .collect();
55
56    let mut parts = Vec::new();
57    if !structs.is_empty() {
58        let names: String = structs.iter().take(3).copied().collect::<Vec<_>>().join(", ");
59        if structs.len() > 3 {
60            parts.push(format!("defines {} (+{} more)", names, structs.len() - 3));
61        } else {
62            parts.push(format!("defines {}", names));
63        }
64    }
65    if !functions.is_empty() {
66        parts.push(format!("{} functions", functions.len()));
67    }
68
69    if parts.is_empty() {
70        format!("{} ({})", Path::new(path).file_name().unwrap_or_default().to_string_lossy(), display_language)
71    } else {
72        parts.join("; ")
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the names of `symbols` in the order they were extracted.
    fn names_of(symbols: &[SymbolInfo]) -> Vec<&str> {
        symbols.iter().map(|symbol| symbol.name.as_str()).collect()
    }

    /// Extracts symbols from `source` and asserts that every name in
    /// `expected` is among them.
    fn assert_extracts(source: &str, language: &str, expected: &[&str]) {
        let symbols = extract_symbols(source, language);
        let found = names_of(&symbols);
        for wanted in expected {
            assert!(
                found.contains(wanted),
                "expected symbol `{}` among {:?}",
                wanted,
                found
            );
        }
    }

    /// Every supported extension maps to its language id; others are "unknown".
    #[test]
    fn test_detect_language() {
        let cases = [
            ("src/main.rs", "rust"),
            ("app.py", "python"),
            ("index.ts", "typescript"),
            ("component.tsx", "typescriptreact"),
            ("main.go", "go"),
            ("Main.java", "java"),
            ("main.c", "c"),
            ("main.cpp", "cpp"),
            ("app.rb", "ruby"),
            ("lib.ex", "elixir"),
            ("README.md", "unknown"),
        ];
        for (path, language) in cases.iter() {
            assert_eq!(detect_language(path), *language, "path: {}", path);
        }
    }

    /// Rust: structs, enums, traits, impl methods and consts are all reported.
    #[test]
    fn test_extract_rust_symbols() {
        let source = r#"
pub struct Config {
    pub name: String,
}

pub enum State {
    Open,
    Closed,
}

pub trait Handler {
    fn handle(&self);
}

impl Config {
    pub fn new(name: String) -> Self {
        Self { name }
    }

    pub async fn load() -> Self {
        todo!()
    }
}

pub const MAX_SIZE: usize = 100;
"#;

        assert_extracts(
            source,
            "rust",
            &["Config", "State", "Handler", "new", "load", "MAX_SIZE"],
        );
    }

    /// Python: classes plus sync and async functions are reported.
    #[test]
    fn test_extract_python_symbols() {
        let source = r#"
class MyClass:
    pass

def my_function():
    pass

async def async_func():
    pass
"#;

        assert_extracts(source, "python", &["MyClass", "my_function", "async_func"]);
    }

    /// Go: free functions, receiver methods, structs and interfaces are reported.
    #[test]
    fn test_extract_go_symbols() {
        let source = r#"
func main() {
}

func (s *Server) Start() error {
    return nil
}

type Config struct {
    Name string
}

type Handler interface {
    Handle()
}
"#;

        assert_extracts(source, "go", &["main", "Start", "Config", "Handler"]);
    }

    /// TypeScript: functions, classes, interfaces, type aliases and arrow
    /// functions bound to a const are reported.
    #[test]
    fn test_extract_typescript_symbols() {
        let source = r#"
export function greet(name: string): string {
    return `Hello, ${name}!`;
}

export class UserService {
    constructor() {}
}

export interface Config {
    name: string;
}

type UserId = string;

const fetchData = async (url: string) => {
    return fetch(url);
};
"#;

        assert_extracts(
            source,
            "typescript",
            &["greet", "UserService", "Config", "UserId", "fetchData"],
        );
    }

    /// Java: classes, methods, interfaces and enums are reported.
    #[test]
    fn test_extract_java_symbols() {
        let source = r#"
public class UserService {
    private String name;

    public UserService(String name) {
        this.name = name;
    }

    public String getName() {
        return name;
    }
}

public interface Repository {
    void save(Object entity);
}

public enum Status {
    ACTIVE,
    INACTIVE
}
"#;

        assert_extracts(
            source,
            "java",
            &["UserService", "getName", "Repository", "Status"],
        );
    }

    /// C: structs, enums and function definitions are reported.
    #[test]
    fn test_extract_c_symbols() {
        let source = r#"
struct Point {
    int x;
    int y;
};

enum Color {
    RED,
    GREEN,
    BLUE
};

typedef unsigned long ulong;

int main(int argc, char** argv) {
    return 0;
}
"#;

        assert_extracts(source, "c", &["Point", "Color", "main"]);
    }

    /// Ruby: modules, classes and instance methods are reported.
    #[test]
    fn test_extract_ruby_symbols() {
        let source = r#"
module Authentication
  class User
    def initialize(name)
      @name = name
    end

    def self.find(id)
      new("user_#{id}")
    end

    def greet
      "Hello, #{@name}"
    end
  end
end
"#;

        assert_extracts(
            source,
            "ruby",
            &["Authentication", "User", "initialize", "greet"],
        );
    }

    /// The summary names the struct and counts both functions.
    #[test]
    fn test_generate_summary() {
        let symbols = vec![
            SymbolInfo {
                name: "Config".to_string(),
                kind: "struct".to_string(),
                line_start: 1,
                line_end: 10,
            },
            SymbolInfo {
                name: "new".to_string(),
                kind: "function".to_string(),
                line_start: 12,
                line_end: 20,
            },
            SymbolInfo {
                name: "load".to_string(),
                kind: "function".to_string(),
                line_start: 22,
                line_end: 30,
            },
        ];

        let summary = generate_summary("src/config.rs", &symbols, "rust");
        assert!(summary.contains("Config"), "summary was: {}", summary);
        assert!(summary.contains("2 functions"), "summary was: {}", summary);
    }

    /// A `.tsx` file with a single function still yields a function count.
    #[test]
    fn test_generate_summary_tsx() {
        let symbols = vec![SymbolInfo {
            name: "App".to_string(),
            kind: "function".to_string(),
            line_start: 1,
            line_end: 10,
        }];

        let summary = generate_summary("src/App.tsx", &symbols, "typescriptreact");
        assert!(summary.contains("1 functions"), "summary was: {}", summary);
    }

    /// A language no extractor knows about produces no symbols.
    #[test]
    fn test_fallback_for_unknown_language() {
        assert!(extract_symbols("fn main() {}", "brainfuck").is_empty());
    }

    /// Rust symbols carry 1-based, inclusive start/end lines.
    #[test]
    fn test_rust_accurate_line_ranges() {
        let source = r#"pub struct Config {
    pub name: String,
    pub value: u32,
}

pub fn process(config: &Config) -> String {
    format!("{}: {}", config.name, config.value)
}
"#;

        let symbols = extract_symbols(source, "rust");

        let config = symbols
            .iter()
            .find(|symbol| symbol.name == "Config" && symbol.kind == "struct")
            .expect("struct `Config` should be extracted");
        assert_eq!(config.line_start, 1);
        assert_eq!(config.line_end, 4);

        let process = symbols
            .iter()
            .find(|symbol| symbol.name == "process")
            .expect("function `process` should be extracted");
        assert_eq!(process.line_start, 6);
        assert_eq!(process.line_end, 8);
    }

    /// Python symbols carry 1-based, inclusive start/end lines.
    #[test]
    fn test_python_accurate_line_ranges() {
        let source = r#"class MyClass:
    def __init__(self):
        self.x = 0

    def method(self):
        return self.x

def standalone():
    pass
"#;

        let symbols = extract_symbols(source, "python");

        let class = symbols
            .iter()
            .find(|symbol| symbol.name == "MyClass")
            .expect("class `MyClass` should be extracted");
        assert_eq!(class.line_start, 1);
        assert_eq!(class.line_end, 6);

        let standalone = symbols
            .iter()
            .find(|symbol| symbol.name == "standalone")
            .expect("function `standalone` should be extracted");
        assert_eq!(standalone.line_start, 8);
        assert_eq!(standalone.line_end, 9);
    }
}