Skip to main content

git_lore/parser/
mod.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3
4use anyhow::{Context, Result};
5use serde::{Deserialize, Serialize};
6use tree_sitter::{Node, Parser};
7
8#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
9#[serde(rename_all = "snake_case")]
10pub enum ScopeKind {
11    File,
12    Function,
13    Method,
14    Class,
15    Module,
16    Struct,
17    Enum,
18    Trait,
19    Impl,
20    Interface,
21    TypeAlias,
22    ArrowFunction,
23    Unknown,
24}
25
26#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
27pub struct ScopeContext {
28    pub path: PathBuf,
29    pub language: String,
30    pub kind: ScopeKind,
31    pub name: String,
32    pub start_line: usize,
33    pub end_line: usize,
34    pub cursor_line: Option<usize>,
35}
36
37impl ScopeContext {
38    pub fn key(&self) -> String {
39        format!("{}::{}", self.kind_label(), self.name)
40    }
41
42    pub fn kind_label(&self) -> &'static str {
43        match self.kind {
44            ScopeKind::File => "file",
45            ScopeKind::Function => "function",
46            ScopeKind::Method => "method",
47            ScopeKind::Class => "class",
48            ScopeKind::Module => "module",
49            ScopeKind::Struct => "struct",
50            ScopeKind::Enum => "enum",
51            ScopeKind::Trait => "trait",
52            ScopeKind::Impl => "impl",
53            ScopeKind::Interface => "interface",
54            ScopeKind::TypeAlias => "type_alias",
55            ScopeKind::ArrowFunction => "arrow_function",
56            ScopeKind::Unknown => "unknown",
57        }
58    }
59}
60
61pub fn detect_scope(path: impl AsRef<Path>, cursor_line: Option<usize>) -> Result<ScopeContext> {
62    let path = path.as_ref();
63    let source = fs::read_to_string(path)
64        .with_context(|| format!("failed to read source file {}", path.display()))?;
65
66    match path.extension().and_then(|value| value.to_str()).map(|value| value.to_ascii_lowercase()).as_deref() {
67        Some("rs") => detect_scoped_source(
68            path,
69            &source,
70            cursor_line,
71            "rust",
72            tree_sitter_rust::language(),
73            rust_is_scope_node,
74            rust_scope_kind,
75            rust_scope_name,
76        ),
77        Some("js") | Some("jsx") | Some("mjs") | Some("cjs") => detect_scoped_source(
78            path,
79            &source,
80            cursor_line,
81            "javascript",
82            tree_sitter_javascript::language(),
83            javascript_is_scope_node,
84            javascript_scope_kind,
85            javascript_scope_name,
86        ),
87        Some("ts") | Some("mts") | Some("cts") => detect_scoped_source(
88            path,
89            &source,
90            cursor_line,
91            "typescript",
92            tree_sitter_typescript::language_typescript(),
93            typescript_is_scope_node,
94            typescript_scope_kind,
95            typescript_scope_name,
96        ),
97        Some("tsx") => detect_scoped_source(
98            path,
99            &source,
100            cursor_line,
101            "typescript",
102            tree_sitter_typescript::language_tsx(),
103            typescript_is_scope_node,
104            typescript_scope_kind,
105            typescript_scope_name,
106        ),
107        _ => Ok(file_scope(path, cursor_line, inferred_language_label(path))),
108    }
109}
110
111fn detect_scoped_source(
112    path: &Path,
113    source: &str,
114    cursor_line: Option<usize>,
115    language_label: &str,
116    language: tree_sitter::Language,
117    scope_matcher: fn(&str) -> bool,
118    kind_mapper: fn(&str) -> ScopeKind,
119    name_resolver: fn(Node<'_>, &str) -> String,
120) -> Result<ScopeContext> {
121    let mut parser = Parser::new();
122    parser.set_language(language).context("failed to load tree-sitter grammar")?;
123
124    let tree = parser
125        .parse(source, None)
126        .ok_or_else(|| anyhow::anyhow!("failed to parse source with tree-sitter"))?;
127
128    let root = tree.root_node();
129    let line = cursor_line.unwrap_or(1).max(1);
130
131    if let Some(scope_node) = best_scope_node(root, line, scope_matcher) {
132        return Ok(scope_context_from_node(
133            path,
134            source,
135            scope_node,
136            cursor_line,
137            language_label,
138            kind_mapper,
139            name_resolver,
140        ));
141    }
142
143    Ok(file_scope(path, cursor_line, language_label))
144}
145
146fn best_scope_node<'tree>(node: Node<'tree>, line: usize, scope_matcher: fn(&str) -> bool) -> Option<Node<'tree>> {
147    let mut best: Option<Node<'tree>> = None;
148    best_scope_node_recursive(node, line, scope_matcher, &mut best);
149    best
150}
151
152fn best_scope_node_recursive<'tree>(
153    node: Node<'tree>,
154    line: usize,
155    scope_matcher: fn(&str) -> bool,
156    best: &mut Option<Node<'tree>>,
157) {
158    if scope_matcher(node.kind()) && node_contains_line(node, line) {
159        let replace = match *best {
160            Some(current) => node_span(node) < node_span(current),
161            None => true,
162        };
163
164        if replace {
165            *best = Some(node);
166        }
167    }
168
169    for index in 0..node.child_count() {
170        if let Some(child) = node.child(index) {
171            best_scope_node_recursive(child, line, scope_matcher, best);
172        }
173    }
174}
175
176fn node_contains_line(node: Node<'_>, line: usize) -> bool {
177    let start = node.start_position().row + 1;
178    let end = node.end_position().row + 1;
179    start <= line && line <= end
180}
181
182fn node_span(node: Node<'_>) -> usize {
183    let start = node.start_position().row;
184    let end = node.end_position().row;
185    end.saturating_sub(start)
186}
187
/// Reports whether a Rust grammar node kind is treated as a scope.
fn rust_is_scope_node(kind: &str) -> bool {
    const SCOPE_KINDS: [&str; 6] = [
        "function_item",
        "mod_item",
        "struct_item",
        "enum_item",
        "trait_item",
        "impl_item",
    ];

    SCOPE_KINDS.contains(&kind)
}
194
/// Reports whether a JavaScript grammar node kind is treated as a scope.
fn javascript_is_scope_node(kind: &str) -> bool {
    const SCOPE_KINDS: [&str; 5] = [
        "function_declaration",
        "generator_function_declaration",
        "arrow_function",
        "method_definition",
        "class_declaration",
    ];

    SCOPE_KINDS.contains(&kind)
}
205
/// Reports whether a TypeScript/TSX grammar node kind is treated as a scope.
///
/// Accepts every kind that `javascript_is_scope_node` accepts, including
/// `generator_function_declaration`, plus the TypeScript-only
/// `interface_declaration` and `type_alias_declaration`.
fn typescript_is_scope_node(kind: &str) -> bool {
    matches!(
        kind,
        "function_declaration"
            | "generator_function_declaration"
            | "arrow_function"
            | "method_definition"
            | "class_declaration"
            | "interface_declaration"
            | "type_alias_declaration"
    )
}
217
218fn scope_context_from_node(
219    path: &Path,
220    source: &str,
221    node: Node<'_>,
222    cursor_line: Option<usize>,
223    language: &str,
224    kind_mapper: fn(&str) -> ScopeKind,
225    name_resolver: fn(Node<'_>, &str) -> String,
226) -> ScopeContext {
227    ScopeContext {
228        path: path.to_path_buf(),
229        language: language.to_string(),
230        kind: kind_mapper(node.kind()),
231        name: name_resolver(node, source),
232        start_line: node.start_position().row + 1,
233        end_line: node.end_position().row + 1,
234        cursor_line,
235    }
236}
237
238fn rust_scope_kind(kind: &str) -> ScopeKind {
239    match kind {
240        "function_item" => ScopeKind::Function,
241        "mod_item" => ScopeKind::Module,
242        "struct_item" => ScopeKind::Struct,
243        "enum_item" => ScopeKind::Enum,
244        "trait_item" => ScopeKind::Trait,
245        "impl_item" => ScopeKind::Impl,
246        _ => ScopeKind::Unknown,
247    }
248}
249
250fn javascript_scope_kind(kind: &str) -> ScopeKind {
251    match kind {
252        "function_declaration" | "generator_function_declaration" => ScopeKind::Function,
253        "arrow_function" => ScopeKind::ArrowFunction,
254        "method_definition" => ScopeKind::Method,
255        "class_declaration" => ScopeKind::Class,
256        _ => ScopeKind::Unknown,
257    }
258}
259
260fn typescript_scope_kind(kind: &str) -> ScopeKind {
261    match kind {
262        "function_declaration" => ScopeKind::Function,
263        "arrow_function" => ScopeKind::ArrowFunction,
264        "method_definition" => ScopeKind::Method,
265        "class_declaration" => ScopeKind::Class,
266        "interface_declaration" => ScopeKind::Interface,
267        "type_alias_declaration" => ScopeKind::TypeAlias,
268        _ => ScopeKind::Unknown,
269    }
270}
271
272fn rust_scope_name(node: Node<'_>, source: &str) -> String {
273    match node.kind() {
274        "function_item" | "struct_item" | "enum_item" | "trait_item" | "mod_item" => {
275            node.child_by_field_name("name")
276                .and_then(|child| node_text(child, source))
277                .unwrap_or_else(|| node.kind().to_string())
278        }
279        "impl_item" => {
280            let target = node
281                .child_by_field_name("type")
282                .and_then(|child| node_text(child, source))
283                .unwrap_or_else(|| "unknown".to_string());
284            format!("impl {target}")
285        }
286        _ => node.kind().to_string(),
287    }
288}
289
290fn javascript_scope_name(node: Node<'_>, source: &str) -> String {
291    match node.kind() {
292        "function_declaration" | "generator_function_declaration" | "class_declaration" => {
293            node.child_by_field_name("name")
294                .and_then(|child| node_text(child, source))
295                .unwrap_or_else(|| node.kind().to_string())
296        }
297        "method_definition" => node
298            .child_by_field_name("name")
299            .and_then(|child| node_text(child, source))
300            .unwrap_or_else(|| node.kind().to_string()),
301        "arrow_function" => ancestor_name(node, source).unwrap_or_else(|| "arrow_function".to_string()),
302        _ => node.kind().to_string(),
303    }
304}
305
306fn typescript_scope_name(node: Node<'_>, source: &str) -> String {
307    match node.kind() {
308        "function_declaration" | "class_declaration" | "interface_declaration" | "type_alias_declaration" => {
309            node.child_by_field_name("name")
310                .and_then(|child| node_text(child, source))
311                .unwrap_or_else(|| node.kind().to_string())
312        }
313        "method_definition" => node
314            .child_by_field_name("name")
315            .and_then(|child| node_text(child, source))
316            .unwrap_or_else(|| node.kind().to_string()),
317        "arrow_function" => ancestor_name(node, source).unwrap_or_else(|| "arrow_function".to_string()),
318        _ => node.kind().to_string(),
319    }
320}
321
322fn ancestor_name(node: Node<'_>, source: &str) -> Option<String> {
323    let mut current = node.parent();
324
325    while let Some(parent) = current {
326        if parent.kind() == "variable_declarator" {
327            if let Some(name_node) = parent.child_by_field_name("name") {
328                if let Some(name) = node_text(name_node, source) {
329                    return Some(name);
330                }
331            }
332        }
333
334        current = parent.parent();
335    }
336
337    None
338}
339
340fn node_text(node: Node<'_>, source: &str) -> Option<String> {
341    source.get(node.byte_range()).map(|text| text.trim().to_string())
342}
343
344fn file_scope(path: &Path, cursor_line: Option<usize>, language: impl Into<String>) -> ScopeContext {
345    let name = path
346        .file_stem()
347        .and_then(|value| value.to_str())
348        .unwrap_or("file")
349        .to_string();
350
351    ScopeContext {
352        path: path.to_path_buf(),
353        language: language.into(),
354        kind: ScopeKind::File,
355        name,
356        start_line: 1,
357        end_line: cursor_line.unwrap_or(1),
358        cursor_line,
359    }
360}
361
/// Derives a language label from the extension of `path`.
///
/// Known extensions (compared case-insensitively) map to `rust`,
/// `javascript` or `typescript`. Any other extension is returned lowercased,
/// and a path with no extension, or one that is not valid UTF-8, yields
/// `unknown`.
fn inferred_language_label(path: &Path) -> String {
    let extension = match path.extension().and_then(|raw| raw.to_str()) {
        Some(text) => text.to_ascii_lowercase(),
        None => return "unknown".to_string(),
    };

    let known_label = match extension.as_str() {
        "rs" => Some("rust"),
        "js" | "jsx" | "mjs" | "cjs" => Some("javascript"),
        "ts" | "tsx" | "mts" | "cts" => Some("typescript"),
        _ => None,
    };

    match known_label {
        Some(label) => label.to_string(),
        None => extension,
    }
}
376
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use uuid::Uuid;

    /// A source file inside a uniquely named temporary directory.
    ///
    /// The directory is removed when the fixture is dropped, so test runs do
    /// not leave files behind even when an assertion fails.
    struct Fixture {
        root: std::path::PathBuf,
        file: std::path::PathBuf,
    }

    impl Fixture {
        /// Creates `<temp dir>/<dir_prefix>-<uuid>/<file_name>` holding `contents`.
        fn new(dir_prefix: &str, file_name: &str, contents: &str) -> Self {
            let root = std::env::temp_dir().join(format!("{dir_prefix}-{}", Uuid::new_v4()));
            // Build the guard first so a failed write still cleans up the directory.
            let fixture = Self {
                file: root.join(file_name),
                root,
            };
            fs::create_dir_all(&fixture.root).unwrap();
            fs::write(&fixture.file, contents).unwrap();
            fixture
        }
    }

    impl Drop for Fixture {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover temp directory must not fail a test.
            let _ = fs::remove_dir_all(&self.root);
        }
    }

    #[test]
    fn detects_rust_function_scope() {
        let fixture = Fixture::new(
            "git-lore-parser-test",
            "lib.rs",
            r#"
pub fn outer() {
    fn inner() {
        println!("hi");
    }
}
"#,
        );

        let scope = detect_scope(&fixture.file, Some(3)).unwrap();
        assert_eq!(scope.kind, ScopeKind::Function);
        assert_eq!(scope.name, "inner");
        assert_eq!(scope.language, "rust");
    }

    #[test]
    fn falls_back_to_file_scope_for_unsupported_extension() {
        let fixture = Fixture::new("git-lore-parser-test", "notes.txt", "hello\nworld\n");

        let scope = detect_scope(&fixture.file, Some(2)).unwrap();
        assert_eq!(scope.kind, ScopeKind::File);
        assert_eq!(scope.name, "notes");
    }

    #[test]
    fn detects_javascript_function_scope() {
        let fixture = Fixture::new(
            "git-lore-parser-js-test",
            "index.js",
            r#"
function outer() {
    function inner() {
        return 1;
    }
}
"#,
        );

        let scope = detect_scope(&fixture.file, Some(3)).unwrap();
        assert_eq!(scope.language, "javascript");
        assert_eq!(scope.kind, ScopeKind::Function);
        assert_eq!(scope.name, "inner");
    }

    #[test]
    fn detects_typescript_method_scope() {
        let fixture = Fixture::new(
            "git-lore-parser-ts-test",
            "service.ts",
            r#"
class Service {
    run(): void {
        return;
    }
}
"#,
        );

        let scope = detect_scope(&fixture.file, Some(3)).unwrap();
        assert_eq!(scope.language, "typescript");
        assert_eq!(scope.kind, ScopeKind::Method);
        assert_eq!(scope.name, "run");
    }
}
463}