Skip to main content

argyph_parse/languages/
rust.rs

1use crate::chunker::ast_chunks;
2use crate::error::{ParseError, Result};
3use crate::types::{ByteRange, ChunkKind, Import, ParsedFile, Symbol, SymbolId, SymbolKind};
4use argyph_fs::{FileEntry, Language};
5use tree_sitter::{Parser, Query, QueryCursor, StreamingIterator};
6
// Tree-sitter query text for Rust, embedded at compile time from
// `../../queries/rust.scm` (path relative to this source file).
// `extract_symbols` compiles it with `Query::new` and reads the captures
// named `def` and `name` from every match.
static QUERY_SRC: &str = include_str!("../../queries/rust.scm");
8
9pub fn parse_rust(file: &FileEntry, source: &str, max_chunk_size: usize) -> Result<ParsedFile> {
10    let lang: tree_sitter::Language = tree_sitter_rust::LANGUAGE.into();
11
12    let mut parser = Parser::new();
13    parser.set_language(&lang)?;
14
15    let tree = parser
16        .parse(source, None)
17        .ok_or_else(|| ParseError::Parse("tree-sitter returned None".into()))?;
18
19    let root = tree.root_node();
20    let source_bytes = source.as_bytes();
21
22    let symbols = extract_symbols(file, &lang, &root, source_bytes)?;
23    let imports = extract_imports(&root, source_bytes);
24    let chunks = ast_chunks(
25        &file.path,
26        &root,
27        source,
28        Language::Rust,
29        max_chunk_size,
30        chunk_kind_for_node,
31        is_chunk_boundary_rust,
32    )?;
33
34    Ok(ParsedFile {
35        symbols,
36        chunks,
37        imports,
38    })
39}
40
41fn extract_symbols(
42    file: &FileEntry,
43    lang: &tree_sitter::Language,
44    root: &tree_sitter::Node,
45    source: &[u8],
46) -> Result<Vec<Symbol>> {
47    let query = Query::new(lang, QUERY_SRC)?;
48    let mut cursor = QueryCursor::new();
49    let mut matches_iter = cursor.matches(&query, *root, source);
50    let mut symbols = Vec::new();
51
52    loop {
53        matches_iter.advance();
54        let Some(m) = matches_iter.get() else { break };
55
56        let mut def_node: Option<tree_sitter::Node> = None;
57        let mut name_node: Option<tree_sitter::Node> = None;
58
59        for cap in m.captures {
60            let cap_name = query.capture_names()[cap.index as usize];
61            match cap_name {
62                "def" => def_node = Some(cap.node),
63                "name" => name_node = Some(cap.node),
64                _ => {}
65            }
66        }
67
68        let Some(def) = def_node else { continue };
69        let name = name_node
70            .and_then(|n| n.utf8_text(source).ok())
71            .unwrap_or("");
72        if name.is_empty() {
73            continue;
74        }
75
76        let kind = match def.kind() {
77            "function_item" => SymbolKind::Function,
78            "struct_item" => SymbolKind::Struct,
79            "enum_item" => SymbolKind::Enum,
80            "trait_item" => SymbolKind::Trait,
81            "impl_item" => SymbolKind::Impl,
82            "mod_item" => SymbolKind::Module,
83            "macro_definition" => SymbolKind::Macro,
84            "const_item" => SymbolKind::Constant,
85            "static_item" => SymbolKind::Static,
86            "type_item" => SymbolKind::TypeAlias,
87            _ => continue,
88        };
89
90        let sig = signature_node(&def, source);
91        let id = SymbolId::new(&file.path, name, def.start_byte());
92
93        symbols.push(Symbol {
94            id,
95            name: name.to_string(),
96            kind,
97            file: file.path.clone(),
98            range: ByteRange::new(def.start_byte(), def.end_byte()),
99            signature: sig,
100            parent: None,
101        });
102    }
103
104    Ok(symbols)
105}
106
107fn extract_imports(root: &tree_sitter::Node, source: &[u8]) -> Vec<Import> {
108    let mut imports = Vec::new();
109    collect_imports(*root, source, &mut imports);
110    imports
111}
112
113fn collect_imports(node: tree_sitter::Node, source: &[u8], out: &mut Vec<Import>) {
114    match node.kind() {
115        "use_declaration" => {
116            if let Ok(raw) = node.utf8_text(source) {
117                let (mod_path, items) = parse_rust_use(raw);
118                out.push(Import {
119                    raw: raw.to_string(),
120                    module_path: mod_path,
121                    items,
122                    range: ByteRange::new(node.start_byte(), node.end_byte()),
123                });
124            }
125            return;
126        }
127        "extern_crate_declaration" => {
128            if let Ok(raw) = node.utf8_text(source) {
129                let mod_path = raw
130                    .strip_prefix("extern crate ")
131                    .unwrap_or("")
132                    .trim_end_matches(';')
133                    .trim()
134                    .to_string();
135                out.push(Import {
136                    raw: raw.to_string(),
137                    module_path: if mod_path.is_empty() {
138                        vec![]
139                    } else {
140                        vec![mod_path]
141                    },
142                    items: vec![],
143                    range: ByteRange::new(node.start_byte(), node.end_byte()),
144                });
145            }
146            return;
147        }
148        _ => {}
149    }
150    for i in 0..node.child_count() {
151        if let Some(child) = node.child(i as u32) {
152            collect_imports(child, source, out);
153        }
154    }
155}
156
157fn parse_rust_use(raw: &str) -> (Vec<String>, Vec<String>) {
158    let trimmed = raw.trim_start_matches("use ").trim_end_matches(';').trim();
159    let trimmed = trimmed.strip_prefix("pub ").unwrap_or(trimmed);
160    let trimmed = trimmed.strip_prefix("crate::").unwrap_or(trimmed);
161    let trimmed = trimmed.strip_prefix("self::").unwrap_or(trimmed);
162
163    let mut mod_parts: Vec<String> = Vec::new();
164    let mut items: Vec<String> = Vec::new();
165
166    if let Some(brace_pos) = trimmed.find('{') {
167        let path_part = trimmed[..brace_pos].trim();
168        let items_part = &trimmed[brace_pos..];
169
170        for segment in path_part.split("::") {
171            let seg = segment.trim();
172            if !seg.is_empty() {
173                mod_parts.push(seg.to_string());
174            }
175        }
176
177        let inner = items_part
178            .trim_start_matches('{')
179            .trim_end_matches('}')
180            .trim();
181        for item in inner.split(',') {
182            let item = item.trim();
183            if !item.is_empty() {
184                if let Some((alias, _)) = item.split_once(" as ") {
185                    items.push(alias.trim().to_string());
186                } else if let Some((first, _rest)) = item.split_once("::") {
187                    items.push(first.trim().to_string());
188                } else {
189                    items.push(item.to_string());
190                }
191            }
192        }
193    } else {
194        for segment in trimmed.split("::") {
195            let seg = segment.trim();
196            if !seg.is_empty() {
197                mod_parts.push(seg.to_string());
198            }
199        }
200    }
201
202    (mod_parts, items)
203}
204
205fn signature_node(node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
206    let sig_end = node
207        .child_by_field_name("body")
208        .map(|b| b.start_byte())
209        .unwrap_or(node.end_byte());
210
211    let sig_bytes = &source[node.start_byte()..sig_end];
212    let sig = std::str::from_utf8(sig_bytes).unwrap_or("").to_string();
213    let sig = sig.trim().to_string();
214    if sig.is_empty() {
215        None
216    } else {
217        Some(sig)
218    }
219}
220
221fn chunk_kind_for_node(kind: &str) -> ChunkKind {
222    match kind {
223        "function_item" | "impl_item" => ChunkKind::FunctionBody,
224        "struct_item" | "enum_item" | "trait_item" | "mod_item" | "macro_definition"
225        | "const_item" | "static_item" | "type_item" => ChunkKind::TypeDef,
226        _ => ChunkKind::TopLevel,
227    }
228}
229
/// Reports whether a tree-sitter node kind is one of the Rust item kinds
/// treated as a chunk boundary.
fn is_chunk_boundary_rust(kind: &str) -> bool {
    const BOUNDARY_KINDS: [&str; 10] = [
        "function_item",
        "struct_item",
        "enum_item",
        "trait_item",
        "impl_item",
        "mod_item",
        "macro_definition",
        "const_item",
        "static_item",
        "type_item",
    ];
    BOUNDARY_KINDS.contains(&kind)
}
245
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use camino::Utf8PathBuf;
    use std::time::UNIX_EPOCH;

    /// Builds a placeholder `FileEntry` for `path` with a zeroed hash, zero
    /// size and the Unix epoch as modification time.
    fn make_file(path: &str, lang: Language) -> FileEntry {
        FileEntry {
            path: Utf8PathBuf::from(path),
            hash: argyph_fs::Blake3Hash::from([0u8; 32]),
            language: Some(lang),
            size: 0,
            modified: UNIX_EPOCH,
        }
    }

    /// True when every name in `expected` appears among `symbols`.
    fn count_expected(symbols: &[Symbol], expected: &[&str]) -> bool {
        expected
            .iter()
            .all(|wanted| symbols.iter().any(|sym| sym.name == *wanted))
    }

    /// Parses `source` as the Rust file at `path` with a 4096 chunk limit.
    fn parse(path: &str, source: &str) -> ParsedFile {
        let file = make_file(path, Language::Rust);
        parse_rust(&file, source, 4096).unwrap()
    }

    #[test]
    fn parse_rust_main_fn() {
        let parsed = parse("src/main.rs", "fn main() {\n    println!(\"hello\");\n}\n");
        assert_eq!(parsed.symbols.len(), 1);

        let only = &parsed.symbols[0];
        assert_eq!(only.name, "main");
        assert_eq!(only.kind, SymbolKind::Function);
    }

    #[test]
    fn parse_rust_struct_and_fn() {
        let source = r#"pub struct Foo {
    x: i32,
}

impl Foo {
    pub fn new(x: i32) -> Self {
        Self { x }
    }
}

pub fn add(a: i32, b: i32) -> i32 {
    a + b
}
"#;
        let parsed = parse("src/lib.rs", source);
        assert!(
            count_expected(&parsed.symbols, &["Foo", "new", "add"]),
            "expected Foo, new, add; got: {:?}",
            parsed.symbols.iter().map(|s| &s.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn parse_rust_use_import() {
        let parsed = parse("src/lib.rs", "use std::collections::HashMap;\n\nfn f() {}\n");
        assert_eq!(parsed.imports.len(), 1);
    }

    #[test]
    fn parse_rust_extern_crate() {
        let parsed = parse("src/lib.rs", "extern crate serde;\n\nfn f() {}\n");
        assert_eq!(parsed.imports.len(), 1);
    }

    #[test]
    fn parse_rust_trait_and_enum() {
        let source = r#"pub trait Summary {
    fn summarize(&self) -> String;
}

pub enum Color {
    Red,
    Green,
    Blue,
}
"#;
        let parsed = parse("src/lib.rs", source);
        assert!(count_expected(&parsed.symbols, &["Summary", "Color"]));
    }

    #[test]
    fn parse_rust_chunks_produced() {
        let parsed = parse("src/lib.rs", "fn one() {}\nfn two() {}\nfn three() {}\n");
        assert!(!parsed.chunks.is_empty(), "should produce chunks");
    }

    #[test]
    fn all_symbols_have_valid_ranges_rust() {
        let source = r#"pub fn add(a: i32, b: i32) -> i32 {
    a + b
}

pub struct Point {
    x: f64,
    y: f64,
}
"#;
        let parsed = parse("src/lib.rs", source);
        assert!(parsed.symbols.len() >= 2);

        let limit = source.len();
        for sym in &parsed.symbols {
            assert!(
                sym.range.end <= limit,
                "range {:?} exceeds source length {}",
                sym.range,
                limit
            );
        }
    }
}