Skip to main content

argyph_parse/languages/
python.rs

1use crate::chunker::ast_chunks;
2use crate::error::{ParseError, Result};
3use crate::types::{ByteRange, ChunkKind, Import, ParsedFile, Symbol, SymbolId, SymbolKind};
4use argyph_fs::{FileEntry, Language};
5use tree_sitter::{Parser, Query, QueryCursor, StreamingIterator};
6
/// Tree-sitter query source for Python symbol extraction, embedded at compile
/// time with `include_str!`. `extract_symbols` reads its `def` and `name`
/// captures; any other capture names are ignored there.
static QUERY_SRC: &str = include_str!("../../queries/python.scm");
8
9pub fn parse_python(file: &FileEntry, source: &str, max_chunk_size: usize) -> Result<ParsedFile> {
10    let lang: tree_sitter::Language = tree_sitter_python::LANGUAGE.into();
11
12    let mut parser = Parser::new();
13    parser.set_language(&lang)?;
14
15    let tree = parser
16        .parse(source, None)
17        .ok_or_else(|| ParseError::Parse("tree-sitter returned None".into()))?;
18
19    let root = tree.root_node();
20    let source_bytes = source.as_bytes();
21
22    let symbols = extract_symbols(file, &lang, &root, source_bytes)?;
23    let imports = extract_imports(&root, source_bytes);
24    let chunks = ast_chunks(
25        &file.path,
26        &root,
27        source,
28        Language::Python,
29        max_chunk_size,
30        chunk_kind_for_node,
31        is_chunk_boundary_py,
32    )?;
33
34    Ok(ParsedFile {
35        symbols,
36        chunks,
37        imports,
38    })
39}
40
41fn extract_symbols(
42    file: &FileEntry,
43    lang: &tree_sitter::Language,
44    root: &tree_sitter::Node,
45    source: &[u8],
46) -> Result<Vec<Symbol>> {
47    let query = Query::new(lang, QUERY_SRC)?;
48    let mut cursor = QueryCursor::new();
49    let mut matches_iter = cursor.matches(&query, *root, source);
50    let mut symbols = Vec::new();
51
52    loop {
53        matches_iter.advance();
54        let Some(m) = matches_iter.get() else { break };
55
56        let mut def_node: Option<tree_sitter::Node> = None;
57        let mut name_node: Option<tree_sitter::Node> = None;
58
59        for cap in m.captures {
60            let cap_name = query.capture_names()[cap.index as usize];
61            match cap_name {
62                "def" => def_node = Some(cap.node),
63                "name" => name_node = Some(cap.node),
64                _ => {}
65            }
66        }
67
68        let Some(def) = def_node else { continue };
69        let name = name_node
70            .and_then(|n| n.utf8_text(source).ok())
71            .unwrap_or("");
72        if name.is_empty() {
73            continue;
74        }
75
76        let kind = match def.kind() {
77            "function_definition" => {
78                if is_method_py(&def) {
79                    SymbolKind::Method
80                } else {
81                    SymbolKind::Function
82                }
83            }
84            "class_definition" => SymbolKind::Class,
85            "decorated_definition" => {
86                let inner = find_inner_def(&def);
87                match inner.map(|n| n.kind().to_string()).as_deref() {
88                    Some("class_definition") => SymbolKind::Class,
89                    Some("function_definition") => {
90                        if inner.is_some_and(|n| is_method_py(&n)) {
91                            SymbolKind::Method
92                        } else {
93                            SymbolKind::Function
94                        }
95                    }
96                    _ => SymbolKind::Function,
97                }
98            }
99            _ => continue,
100        };
101
102        let sig = signature_node(&def, source);
103        let id = SymbolId::new(&file.path, name, def.start_byte());
104
105        symbols.push(Symbol {
106            id,
107            name: name.to_string(),
108            kind,
109            file: file.path.clone(),
110            range: ByteRange::new(def.start_byte(), def.end_byte()),
111            signature: sig,
112            parent: None,
113        });
114    }
115
116    Ok(symbols)
117}
118
119fn is_method_py(node: &tree_sitter::Node) -> bool {
120    node.parent()
121        .is_some_and(|p| p.kind() == "block" || p.kind() == "class_body")
122}
123
124fn find_inner_def<'a>(node: &tree_sitter::Node<'a>) -> Option<tree_sitter::Node<'a>> {
125    for i in 0..node.child_count() {
126        if let Some(child) = node.child(i as u32) {
127            match child.kind() {
128                "function_definition" | "class_definition" => return Some(child),
129                _ => continue,
130            }
131        }
132    }
133    None
134}
135
136fn extract_imports(root: &tree_sitter::Node, source: &[u8]) -> Vec<Import> {
137    let mut imports = Vec::new();
138    collect_imports(*root, source, &mut imports);
139    imports
140}
141
142fn collect_imports(node: tree_sitter::Node, source: &[u8], out: &mut Vec<Import>) {
143    match node.kind() {
144        "import_statement" => {
145            if let Ok(raw) = node.utf8_text(source) {
146                let (mod_path, items) = parse_py_import(raw);
147                out.push(Import {
148                    raw: raw.to_string(),
149                    module_path: mod_path,
150                    items,
151                    range: ByteRange::new(node.start_byte(), node.end_byte()),
152                });
153            }
154            return;
155        }
156        "import_from_statement" => {
157            if let Ok(raw) = node.utf8_text(source) {
158                let (mod_path, items) = parse_py_from_import(raw);
159                out.push(Import {
160                    raw: raw.to_string(),
161                    module_path: mod_path,
162                    items,
163                    range: ByteRange::new(node.start_byte(), node.end_byte()),
164                });
165            }
166            return;
167        }
168        _ => {}
169    }
170    for i in 0..node.child_count() {
171        if let Some(child) = node.child(i as u32) {
172            collect_imports(child, source, out);
173        }
174    }
175}
176
/// Splits the text of a plain `import` statement into modules and aliases.
///
/// `import a.b as c, d` yields `(["a.b", "d"], ["c"])`. Every imported
/// module, with its dotted path kept as one string, goes into the first
/// vector, and every `as` alias into the second. The raw text is flattened
/// first, so backslash line continuations do not leak into the names.
fn parse_py_import(raw: &str) -> (Vec<String>, Vec<String>) {
    let flat = normalize_import_text(raw);
    let body = flat.strip_prefix("import ").unwrap_or(flat.as_str());
    let mut mod_path = Vec::new();
    let mut items = Vec::new();
    for part in body.split(',').map(str::trim) {
        if let Some((module, alias)) = part.split_once(" as ") {
            mod_path.push(module.trim().to_string());
            items.push(alias.trim().to_string());
        } else if !part.is_empty() {
            mod_path.push(part.to_string());
        }
    }
    (mod_path, items)
}

/// Splits the text of a `from … import …` statement into the module path
/// segments and the imported names.
///
/// `from .pkg.mod import a, b as c` yields `(["pkg", "mod"], ["a", "b"])`.
/// Leading dots of relative imports are dropped, and for `x as y` the
/// original name `x` is kept. Parenthesised, multi-line and commented name
/// lists are handled. If the text has no ` import ` separator, both vectors
/// are empty.
fn parse_py_from_import(raw: &str) -> (Vec<String>, Vec<String>) {
    let flat = normalize_import_text(raw);
    let body = flat.strip_prefix("from ").unwrap_or(flat.as_str());
    let Some((module_part, items_part)) = body.split_once(" import ") else {
        return (vec![], vec![]);
    };

    let mod_path: Vec<String> = module_part
        .split('.')
        .map(str::trim)
        .filter(|segment| !segment.is_empty())
        .map(str::to_string)
        .collect();

    let items: Vec<String> = items_part
        .split(',')
        .map(|item| {
            let item = item.trim();
            item.split_once(" as ").map_or(item, |(name, _)| name.trim())
        })
        .filter(|name| !name.is_empty())
        .map(str::to_string)
        .collect();

    (mod_path, items)
}

/// Flattens an import statement's source text onto a single line.
///
/// `#` comments are dropped; they can appear inside a parenthesised
/// `from … import (…)` list. Backslash continuations and grouping parentheses
/// are treated as whitespace, and every run of whitespace collapses to one
/// space. `#` cannot otherwise occur in an import statement, so cutting each
/// line at its first `#` is safe.
fn normalize_import_text(raw: &str) -> String {
    raw.lines()
        .map(|line| line.split_once('#').map_or(line, |(code, _)| code))
        .flat_map(|code| {
            code.split(|c: char| c.is_whitespace() || matches!(c, '\\' | '(' | ')'))
        })
        .filter(|token| !token.is_empty())
        .collect::<Vec<_>>()
        .join(" ")
}
225
226fn signature_node(node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
227    let sig_end = node
228        .child_by_field_name("body")
229        .map(|b| b.start_byte())
230        .unwrap_or(node.end_byte());
231
232    let sig_bytes = &source[node.start_byte()..sig_end];
233    let sig = std::str::from_utf8(sig_bytes).unwrap_or("").to_string();
234    let sig = sig.trim().to_string();
235    if sig.is_empty() {
236        None
237    } else {
238        Some(sig)
239    }
240}
241
242fn chunk_kind_for_node(kind: &str) -> ChunkKind {
243    match kind {
244        "function_definition" | "decorated_definition" => ChunkKind::FunctionBody,
245        "class_definition" => ChunkKind::TypeDef,
246        _ => ChunkKind::TopLevel,
247    }
248}
249
/// Boundary predicate handed to the chunker. It returns `true` exactly for
/// plain functions, classes and decorated definitions.
fn is_chunk_boundary_py(kind: &str) -> bool {
    const BOUNDARY_KINDS: [&str; 3] = [
        "function_definition",
        "class_definition",
        "decorated_definition",
    ];
    BOUNDARY_KINDS.contains(&kind)
}
256
#[cfg(test)]
// Panicking on an unexpected `Err` is exactly what a failing test should do.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use camino::Utf8PathBuf;
    use std::time::UNIX_EPOCH;

    /// Builds a minimal Python `FileEntry` at `path` (zeroed hash, size 0).
    fn make_file(path: &str) -> FileEntry {
        FileEntry {
            path: Utf8PathBuf::from(path),
            hash: argyph_fs::Blake3Hash::from([0u8; 32]),
            language: Some(Language::Python),
            size: 0,
            modified: UNIX_EPOCH,
        }
    }

    /// Parses `source` as though it lived at `path`, with a 4096 chunk limit.
    fn parse(path: &str, source: &str) -> ParsedFile {
        parse_python(&make_file(path), source, 4096).unwrap()
    }

    /// True when every name in `names` appears among the symbols' names.
    fn symbols_contain(symbols: &[Symbol], names: &[&str]) -> bool {
        names
            .iter()
            .all(|wanted| symbols.iter().any(|sym| sym.name == *wanted))
    }

    #[test]
    fn parse_py_function() {
        let parsed = parse(
            "src/math.py",
            "def add(a: int, b: int) -> int:\n    return a + b\n",
        );
        let [only] = parsed.symbols.as_slice() else {
            panic!("expected exactly one symbol, got {}", parsed.symbols.len());
        };
        assert_eq!(only.name, "add");
        assert_eq!(only.kind, SymbolKind::Function);
    }

    #[test]
    fn parse_py_class_and_method() {
        let source = r#"class Greeter:
    """A friendly greeter."""

    greeting: str

    def __init__(self, message: str) -> None:
        self.greeting = message

    def greet(self, user: str) -> str:
        return f"{self.greeting}, {user}!"
"#;
        let parsed = parse("src/greeter.py", source);
        assert!(symbols_contain(&parsed.symbols, &["Greeter", "__init__", "greet"]));
    }

    #[test]
    fn parse_py_dataclass() {
        let source = r#"from dataclasses import dataclass

@dataclass
class User:
    name: str
    age: int
"#;
        let parsed = parse("src/types.py", source);
        assert!(symbols_contain(&parsed.symbols, &["User"]));
    }

    #[test]
    fn parse_py_imports() {
        let source = r#"import os
from .math import add, multiply
from typing import List, Optional

def f(): pass
"#;
        let parsed = parse("src/main.py", source);
        assert_eq!(parsed.imports.len(), 3);
    }

    #[test]
    fn parse_py_chunks_produced() {
        let parsed = parse("src/app.py", "def a(): pass\ndef b(): pass\nclass C: pass\n");
        assert!(!parsed.chunks.is_empty());
    }

    #[test]
    fn parse_py_enum() {
        let source = r#"from enum import Enum

class Status(Enum):
    ACTIVE = "ACTIVE"
    INACTIVE = "INACTIVE"
"#;
        let parsed = parse("src/status.py", source);
        assert!(symbols_contain(&parsed.symbols, &["Status"]));
    }
}