1use deagle_core::{DeagleError, EdgeKind, Language, Node, NodeKind, Result};
4use std::path::Path;
5
6use crate::ParseResult;
7
8pub fn parse(path: &Path, content: &str) -> Result<Vec<Node>> {
10 parse_with_edges(path, content).map(|r| r.nodes)
11}
12
13pub fn parse_with_edges(path: &Path, content: &str) -> Result<ParseResult> {
15 let mut parser = tree_sitter::Parser::new();
16 let language = tree_sitter_python::LANGUAGE;
17 parser.set_language(&language.into()).map_err(|e| DeagleError::Parse {
18 file: path.display().to_string(),
19 message: format!("Failed to set language: {}", e),
20 })?;
21
22 let tree = parser.parse(content, None).ok_or_else(|| DeagleError::Parse {
23 file: path.display().to_string(),
24 message: "Failed to parse file".into(),
25 })?;
26
27 let mut nodes = Vec::new();
28 let file_path = path.to_string_lossy().to_string();
29
30 nodes.push(Node {
32 id: 0,
33 name: path.file_name().and_then(|n| n.to_str()).unwrap_or("unknown").to_string(),
34 kind: NodeKind::File,
35 language: Language::Python,
36 file_path: file_path.clone(),
37 line_start: 1,
38 line_end: content.lines().count() as u32,
39 content: None,
40 });
41
42 extract_definitions(tree.root_node(), content, &file_path, &mut nodes, false);
43
44 let mut edges = Vec::new();
46 for i in 1..nodes.len() {
47 edges.push((0, i, EdgeKind::Contains));
48 }
49
50 Ok(ParseResult { nodes, edges })
51}
52
53fn extract_definitions(
54 node: tree_sitter::Node,
55 source: &str,
56 file_path: &str,
57 results: &mut Vec<Node>,
58 inside_class: bool,
59) {
60 let kind = match node.kind() {
61 "function_definition" => {
62 if inside_class {
63 Some(NodeKind::Method)
64 } else {
65 Some(NodeKind::Function)
66 }
67 }
68 "class_definition" => Some(NodeKind::Class),
69 "import_statement" | "import_from_statement" => Some(NodeKind::Import),
70 "global_statement" => None, "expression_statement" => {
72 if !inside_class {
74 if let Some(child) = node.child(0) {
75 if child.kind() == "assignment" {
76 if let Some(name) = extract_assignment_name(child, source) {
78 if name.chars().all(|c| c.is_uppercase() || c == '_' || c.is_ascii_digit()) && !name.is_empty() {
79 let start = node.start_position();
80 let end = node.end_position();
81 let content = node.utf8_text(source.as_bytes()).ok().map(|s| {
82 if s.len() > 500 { format!("{}...", &s[..500]) } else { s.to_string() }
83 });
84 results.push(Node {
85 id: 0,
86 name,
87 kind: NodeKind::Constant,
88 language: Language::Python,
89 file_path: file_path.to_string(),
90 line_start: (start.row + 1) as u32,
91 line_end: (end.row + 1) as u32,
92 content,
93 });
94 }
95 }
96 }
97 }
98 }
99 None
100 }
101 _ => None,
102 };
103
104 if let Some(kind) = kind {
105 if let Some(name) = extract_name(node, source, kind) {
106 let start = node.start_position();
107 let end = node.end_position();
108 let content = node.utf8_text(source.as_bytes()).ok().map(|s| {
109 if s.len() > 500 { format!("{}...", &s[..500]) } else { s.to_string() }
110 });
111
112 results.push(Node {
113 id: 0,
114 name,
115 kind,
116 language: Language::Python,
117 file_path: file_path.to_string(),
118 line_start: (start.row + 1) as u32,
119 line_end: (end.row + 1) as u32,
120 content,
121 });
122 }
123
124 if kind == NodeKind::Class {
126 if let Some(body) = node.child_by_field_name("body") {
127 let mut cursor = body.walk();
128 for child in body.children(&mut cursor) {
129 extract_definitions(child, source, file_path, results, true);
130 }
131 }
132 return; }
134 }
135
136 if node.kind() != "class_definition" {
138 let mut cursor = node.walk();
139 for child in node.children(&mut cursor) {
140 extract_definitions(child, source, file_path, results, inside_class);
141 }
142 }
143}
144
145fn extract_name(node: tree_sitter::Node, source: &str, kind: NodeKind) -> Option<String> {
146 match kind {
147 NodeKind::Import => {
148 node.utf8_text(source.as_bytes())
150 .ok()
151 .map(|s| s.trim().to_string())
152 }
153 _ => {
154 node.child_by_field_name("name")
156 .and_then(|n| n.utf8_text(source.as_bytes()).ok())
157 .map(|s| s.to_string())
158 }
159 }
160}
161
162fn extract_assignment_name(node: tree_sitter::Node, source: &str) -> Option<String> {
163 node.child_by_field_name("left")
165 .and_then(|n| {
166 if n.kind() == "identifier" {
167 n.utf8_text(source.as_bytes()).ok().map(|s| s.to_string())
168 } else {
169 None
170 }
171 })
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    const SAMPLE_PYTHON: &str = r#"
import os
from pathlib import Path

MAX_SIZE = 1024
DEBUG = True

class Config:
    """Configuration holder."""

    def __init__(self, name: str):
        self.name = name
        self.values = {}

    def get(self, key: str) -> str:
        return self.values.get(key, "")

    @staticmethod
    def default() -> "Config":
        return Config("default")

class Status:
    ACTIVE = "active"
    INACTIVE = "inactive"

def process(data: list) -> dict:
    result = {}
    for item in data:
        result[item] = True
    return result

def main():
    config = Config("test")
    print(config.get("key"))
"#;

    #[test]
    fn test_parse_python_finds_all_definitions() {
        let path = PathBuf::from("test.py");
        let nodes = parse(&path, SAMPLE_PYTHON).unwrap();

        let kinds: Vec<_> = nodes.iter().map(|n| n.kind).collect();
        assert!(kinds.contains(&NodeKind::Import), "should find import");
        assert!(kinds.contains(&NodeKind::Constant), "should find constant");
        assert!(kinds.contains(&NodeKind::Class), "should find class");
        assert!(kinds.contains(&NodeKind::Function), "should find function");
    }

    #[test]
    fn test_parse_python_finds_methods() {
        let path = PathBuf::from("test.py");
        let nodes = parse(&path, SAMPLE_PYTHON).unwrap();

        let methods: Vec<_> = nodes.iter().filter(|n| n.kind == NodeKind::Method).collect();
        assert!(methods.len() >= 3, "should find methods (__init__, get, default), got {}", methods.len());
        assert!(methods.iter().any(|m| m.name == "__init__"));
        assert!(methods.iter().any(|m| m.name == "get"));
        assert!(methods.iter().any(|m| m.name == "default"));
    }

    #[test]
    fn test_parse_python_class_name() {
        let path = PathBuf::from("test.py");
        let nodes = parse(&path, SAMPLE_PYTHON).unwrap();

        let classes: Vec<_> = nodes.iter().filter(|n| n.kind == NodeKind::Class).collect();
        assert_eq!(classes.len(), 2);
        assert!(classes.iter().any(|c| c.name == "Config"));
        assert!(classes.iter().any(|c| c.name == "Status"));
        assert_eq!(classes[0].language, Language::Python);
    }

    #[test]
    fn test_parse_python_constants() {
        let path = PathBuf::from("test.py");
        let nodes = parse(&path, SAMPLE_PYTHON).unwrap();

        let constants: Vec<_> = nodes.iter().filter(|n| n.kind == NodeKind::Constant).collect();
        assert!(constants.iter().any(|c| c.name == "MAX_SIZE"), "should find MAX_SIZE");
        assert!(constants.iter().any(|c| c.name == "DEBUG"), "should find DEBUG");
    }

    #[test]
    fn test_parse_python_line_numbers() {
        let path = PathBuf::from("test.py");
        let nodes = parse(&path, SAMPLE_PYTHON).unwrap();

        let main_fn = nodes.iter().find(|n| n.name == "main" && n.kind == NodeKind::Function);
        assert!(main_fn.is_some(), "should find main function");
        assert!(main_fn.unwrap().line_start > 0, "line numbers should be 1-indexed");
    }

    #[test]
    fn test_parse_python_imports() {
        let path = PathBuf::from("test.py");
        let nodes = parse(&path, SAMPLE_PYTHON).unwrap();

        let imports: Vec<_> = nodes.iter().filter(|n| n.kind == NodeKind::Import).collect();
        assert_eq!(imports.len(), 2, "should find 2 import statements");
        assert!(imports.iter().any(|i| i.name.contains("os")));
        assert!(imports.iter().any(|i| i.name.contains("pathlib")));
    }

    #[test]
    fn test_parse_python_edges() {
        let path = PathBuf::from("test.py");
        let result = parse_with_edges(&path, SAMPLE_PYTHON).unwrap();

        assert!(!result.edges.is_empty(), "should have CONTAINS edges");
        // Every non-file node is contained by the file node, exactly once.
        assert_eq!(result.edges.len(), result.nodes.len() - 1);
        for &(from_idx, to_idx, ref kind) in &result.edges {
            assert_eq!(from_idx, 0);
            assert!(to_idx > 0 && to_idx < result.nodes.len(), "edge target out of range");
            assert_eq!(*kind, EdgeKind::Contains);
        }
    }

    #[test]
    fn test_parse_empty_python_file() {
        let path = PathBuf::from("empty.py");
        let nodes = parse(&path, "").unwrap();
        // The file node is always emitted, even when there are no definitions.
        assert_eq!(nodes.len(), 1);
        assert!(nodes[0].kind == NodeKind::File);
        assert_eq!(nodes[0].name, "empty.py");
    }

    #[test]
    fn test_parse_python_decorated_function() {
        let source = r#"
import functools

def decorator(f):
    return f

@decorator
def decorated():
    pass

class MyClass:
    @staticmethod
    def static_method():
        pass

    @classmethod
    def class_method(cls):
        pass
"#;
        let path = PathBuf::from("deco.py");
        let nodes = parse(&path, source).unwrap();

        let fns: Vec<_> = nodes.iter().filter(|n| n.kind == NodeKind::Function).collect();
        assert!(fns.iter().any(|f| f.name == "decorator"));
        assert!(fns.iter().any(|f| f.name == "decorated"));

        let methods: Vec<_> = nodes.iter().filter(|n| n.kind == NodeKind::Method).collect();
        assert!(methods.iter().any(|m| m.name == "static_method"));
        assert!(methods.iter().any(|m| m.name == "class_method"));
    }

    #[test]
    fn test_parse_python_nested_class() {
        let source = r#"
class Outer:
    class Inner:
        def inner_method(self):
            pass

    def outer_method(self):
        pass
"#;
        let path = PathBuf::from("nested.py");
        let nodes = parse(&path, source).unwrap();

        let classes: Vec<_> = nodes.iter().filter(|n| n.kind == NodeKind::Class).collect();
        assert!(classes.iter().any(|c| c.name == "Outer"));
        assert!(classes.iter().any(|c| c.name == "Inner"), "should find nested class");

        let methods: Vec<_> = nodes.iter().filter(|n| n.kind == NodeKind::Method).collect();
        assert!(methods.iter().any(|m| m.name == "inner_method"));
        assert!(methods.iter().any(|m| m.name == "outer_method"));
    }

    #[test]
    fn test_parse_python_async_function() {
        let source = r#"
import asyncio

async def fetch_data(url: str) -> dict:
    return {}

class Client:
    async def connect(self):
        pass
"#;
        let path = PathBuf::from("async.py");
        let nodes = parse(&path, source).unwrap();

        let fns: Vec<_> = nodes.iter().filter(|n| n.kind == NodeKind::Function).collect();
        assert!(fns.iter().any(|f| f.name == "fetch_data"), "should find async function");

        let methods: Vec<_> = nodes.iter().filter(|n| n.kind == NodeKind::Method).collect();
        assert!(methods.iter().any(|m| m.name == "connect"), "should find async method");
    }

    #[test]
    fn test_parse_python_lowercase_not_constant() {
        let source = r#"
MAX_SIZE = 100
lowercase_var = "not a constant"
_private = True
"#;
        let path = PathBuf::from("vars.py");
        let nodes = parse(&path, source).unwrap();

        let constants: Vec<_> = nodes.iter().filter(|n| n.kind == NodeKind::Constant).collect();
        assert!(constants.iter().any(|c| c.name == "MAX_SIZE"));
        assert!(!constants.iter().any(|c| c.name == "lowercase_var"));
        assert!(!constants.iter().any(|c| c.name == "_private"));
    }

    #[test]
    fn test_parse_python_multiple_imports() {
        let source = r#"
import os
import sys
from typing import Dict, List, Optional
from pathlib import Path
from collections import defaultdict
"#;
        let path = PathBuf::from("imports.py");
        let nodes = parse(&path, source).unwrap();

        let imports: Vec<_> = nodes.iter().filter(|n| n.kind == NodeKind::Import).collect();
        assert_eq!(imports.len(), 5, "should find all 5 import statements");
    }

    #[test]
    fn test_parse_python_long_content_truncated() {
        let body = "    x = 1\n".repeat(100);
        let source = format!("def long_function():\n{}", body);
        let nodes = parse(&PathBuf::from("long.py"), &source).unwrap();

        let func = nodes
            .iter()
            .find(|n| n.name == "long_function")
            .expect("should find long_function");
        let content = func.content.as_deref().expect("function should carry content");
        // ASCII snippets are cut at exactly 500 bytes, then "..." is appended.
        assert!(content.ends_with("..."));
        assert_eq!(content.len(), 503);
    }
}