1use crate::{
2 Error, FileUnit, FunctionUnit, LanguageParser, ModuleUnit, PythonParser, Result, StructUnit,
3 Visibility,
4};
5use std::fs;
6use std::ops::{Deref, DerefMut};
7use std::path::Path;
8use tree_sitter::{Node, Parser};
9
10fn get_node_text(node: Node, source_code: &str) -> Option<String> {
12 node.utf8_text(source_code.as_bytes())
13 .ok()
14 .map(String::from)
15}
16
17fn get_child_node_text<'a>(node: Node<'a>, kind: &str, source_code: &'a str) -> Option<String> {
19 node.children(&mut node.walk())
20 .find(|child| child.kind() == kind)
21 .and_then(|child| child.utf8_text(source_code.as_bytes()).ok())
22 .map(String::from)
23}
24
25impl PythonParser {
26 pub fn try_new() -> Result<Self> {
27 let mut parser = Parser::new();
28 let language = tree_sitter_python::LANGUAGE;
29 parser
30 .set_language(&language.into())
31 .map_err(|e| Error::TreeSitter(e.to_string()))?;
32 Ok(Self { parser })
33 }
34
35 fn extract_documentation(&self, node: Node, source_code: &str) -> Option<String> {
37 let mut cursor = node.walk();
38 let mut children = node.children(&mut cursor);
39
40 if node.kind() == "function_definition" || node.kind() == "class_definition" {
42 children.next(); }
44
45 for child in children {
47 match child.kind() {
48 "block" => {
49 let mut body_cursor = child.walk();
51 let mut body_children = child.children(&mut body_cursor);
52 if let Some(first_expr) = body_children.next() {
53 if first_expr.kind() == "expression_statement" {
54 if let Some(string) = first_expr
55 .children(&mut first_expr.walk())
56 .find(|c| c.kind() == "string")
57 {
58 return self.clean_docstring(string, source_code);
59 }
60 }
61 }
62 }
63 "expression_statement" => {
64 if let Some(string) = child
66 .children(&mut child.walk())
67 .find(|c| c.kind() == "string")
68 {
69 return self.clean_docstring(string, source_code);
70 }
71 }
72 "ERROR" => {
73 let mut error_cursor = child.walk();
75 let error_children = child.children(&mut error_cursor);
76 for error_child in error_children {
77 if error_child.kind() == "string" {
78 if let Some(string_content) = error_child
79 .children(&mut error_child.walk())
80 .find(|c| c.kind() == "string_content")
81 {
82 if let Some(content) = get_node_text(string_content, source_code) {
83 return Some(content.trim().to_string());
84 }
85 }
86 }
87 }
88 }
89 _ => continue,
90 }
91 }
92 None
93 }
94
95 fn clean_docstring(&self, node: Node, source_code: &str) -> Option<String> {
97 let doc = get_node_text(node, source_code)?;
98 let doc = if doc.starts_with("\"\"\"") && doc.ends_with("\"\"\"") {
100 doc[3..doc.len() - 3].trim()
102 } else if doc.starts_with("'''") && doc.ends_with("'''") {
103 doc[3..doc.len() - 3].trim()
105 } else {
106 doc.trim_matches('"').trim_matches('\'').trim()
108 };
109 Some(doc.to_string())
110 }
111
112 fn extract_decorators(&self, node: Node, source_code: &str) -> Vec<String> {
114 let mut decorators = Vec::new();
115 let mut cursor = node.walk();
116
117 for child in node.children(&mut cursor) {
119 if child.kind() == "decorator" {
120 if let Some(text) = get_node_text(child, source_code) {
121 decorators.push(text);
122 }
123 }
124 }
125 decorators
126 }
127
128 fn parse_function(&self, node: Node, source_code: &str) -> Result<FunctionUnit> {
130 let function_node = if node.kind() == "decorated_definition" {
132 node.children(&mut node.walk())
133 .find(|child| child.kind() == "function_definition")
134 .unwrap_or(node)
135 } else {
136 node
137 };
138
139 let name = get_child_node_text(function_node, "identifier", source_code)
140 .unwrap_or_else(|| "unknown".to_string());
141 let documentation = self.extract_documentation(function_node, source_code);
142 let attributes = self.extract_decorators(node, source_code);
143 let source = get_node_text(function_node, source_code);
144 let visibility = if name.starts_with('_') {
145 Visibility::Private
146 } else {
147 Visibility::Public
148 };
149
150 let mut signature = None;
151 let mut body = None;
152
153 if let Some(src) = &source {
154 if let Some(body_start_idx) = src.find(':') {
155 signature = Some(src[0..body_start_idx].trim().to_string());
156 body = Some(src[body_start_idx + 1..].trim().to_string());
157 }
158 }
159
160 Ok(FunctionUnit {
161 name,
162 visibility,
163 documentation,
164 source,
165 signature,
166 body,
167 attributes,
168 })
169 }
170
171 fn parse_class(&self, node: Node, source_code: &str) -> Result<StructUnit> {
173 let class_node = if node.kind() == "decorated_definition" {
175 node.children(&mut node.walk())
176 .find(|child| child.kind() == "class_definition")
177 .unwrap_or(node)
178 } else {
179 node
180 };
181
182 let name = get_child_node_text(class_node, "identifier", source_code)
183 .unwrap_or_else(|| "unknown".to_string());
184 let documentation = self.extract_documentation(class_node, source_code);
185 let attributes = self.extract_decorators(node, source_code);
186 let source = get_node_text(class_node, source_code);
187 let visibility = if name.starts_with('_') {
188 Visibility::Private
189 } else {
190 Visibility::Public
191 };
192
193 let head = format!("class {}", name);
195
196 let mut methods = Vec::new();
198 let mut cursor = class_node.walk();
199 for child in class_node.children(&mut cursor) {
200 if child.kind() == "block" {
201 let mut block_cursor = child.walk();
202 for method_node in child.children(&mut block_cursor) {
203 match method_node.kind() {
204 "function_definition" | "decorated_definition" => {
205 if let Ok(method) = self.parse_function(method_node, source_code) {
206 methods.push(method);
207 }
208 }
209 _ => continue,
210 }
211 }
212 }
213 }
214
215 Ok(StructUnit {
216 name,
217 head,
218 visibility,
219 documentation,
220 source,
221 attributes,
222 methods,
223 })
224 }
225
226 #[allow(dead_code)]
227 fn parse_module(&self, node: Node, source_code: &str) -> Result<ModuleUnit> {
229 let name = get_child_node_text(node, "identifier", source_code)
230 .unwrap_or_else(|| "unknown".to_string());
231 let document = self.extract_documentation(node, source_code);
232 let source = get_node_text(node, source_code);
233 let visibility = if name.starts_with('_') {
234 Visibility::Private
235 } else {
236 Visibility::Public
237 };
238
239 Ok(ModuleUnit {
240 name,
241 visibility,
242 document,
243 source,
244 attributes: Vec::new(),
245 declares: Vec::new(),
246 functions: Vec::new(),
247 structs: Vec::new(),
248 traits: Vec::new(),
249 impls: Vec::new(),
250 submodules: Vec::new(),
251 })
252 }
253}
254
255impl LanguageParser for PythonParser {
256 fn parse_file(&mut self, file_path: &Path) -> Result<FileUnit> {
257 let source_code = fs::read_to_string(file_path).map_err(Error::Io)?;
258 let tree = self
259 .parse(source_code.as_bytes(), None)
260 .ok_or_else(|| Error::TreeSitter("Failed to parse Python file".to_string()))?;
261
262 let mut file_unit = FileUnit {
263 path: file_path.to_path_buf(),
264 source: Some(source_code.clone()),
265 document: None,
266 declares: Vec::new(),
267 modules: Vec::new(),
268 functions: Vec::new(),
269 structs: Vec::new(),
270 traits: Vec::new(),
271 impls: Vec::new(),
272 };
273
274 let root_node = tree.root_node();
275
276 {
278 let mut cursor = root_node.walk();
279 let mut children = root_node.children(&mut cursor);
280
281 if let Some(first_expr) = children.next() {
282 if first_expr.kind() == "expression_statement" {
283 if let Some(string) = first_expr
284 .children(&mut first_expr.walk())
285 .find(|c| c.kind() == "string")
286 {
287 if let Some(doc) = get_node_text(string, &source_code) {
288 let doc = doc
290 .trim_start_matches(r#"""""#)
291 .trim_end_matches(r#"""""#)
292 .trim_start_matches(r#"'''"#)
293 .trim_end_matches(r#"'''"#)
294 .trim_start_matches('"')
295 .trim_end_matches('"')
296 .trim_start_matches('\'')
297 .trim_end_matches('\'')
298 .trim();
299 file_unit.document = Some(doc.to_string());
300 }
301 }
302 }
303 }
304 }
305
306 {
308 let mut cursor = root_node.walk();
309 for node in root_node.children(&mut cursor) {
310 if node.kind() == "import_statement" || node.kind() == "import_from_statement" {
311 if let Some(import_text) = get_node_text(node, &source_code) {
312 file_unit.declares.push(crate::DeclareStatements {
313 source: import_text,
314 kind: crate::DeclareKind::Import,
315 });
316 }
317 }
318 }
319 }
320
321 let mut cursor = root_node.walk();
323 for node in root_node.children(&mut cursor) {
324 match node.kind() {
325 "function_definition" => {
326 let func = self.parse_function(node, &source_code)?;
327 file_unit.functions.push(func);
328 }
329 "class_definition" => {
330 let class = self.parse_class(node, &source_code)?;
331 file_unit.structs.push(class);
332 }
333 "decorated_definition" => {
334 let mut node_cursor = node.walk();
335 let children: Vec<_> = node.children(&mut node_cursor).collect();
336 if let Some(def_node) = children.iter().find(|n| {
337 n.kind() == "function_definition" || n.kind() == "class_definition"
338 }) {
339 match def_node.kind() {
340 "function_definition" => {
341 let func = self.parse_function(node, &source_code)?;
342 file_unit.functions.push(func);
343 }
344 "class_definition" => {
345 let class = self.parse_class(node, &source_code)?;
346 file_unit.structs.push(class);
347 }
348 _ => {}
349 }
350 }
351 }
352 _ => continue,
353 }
354 }
355
356 Ok(file_unit)
357 }
358}
359
360impl Deref for PythonParser {
361 type Target = Parser;
362
363 fn deref(&self) -> &Self::Target {
364 &self.parser
365 }
366}
367
368impl DerefMut for PythonParser {
369 fn deref_mut(&mut self) -> &mut Self::Target {
370 &mut self.parser
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Writes `content` to `test.py` inside a fresh temporary directory.
    ///
    /// The returned `TempDir` must be kept alive for as long as the file is
    /// needed, because dropping it removes the directory.
    fn create_test_file(content: &str) -> Result<(tempfile::TempDir, PathBuf)> {
        let dir = tempfile::tempdir().map_err(Error::Io)?;
        let file_path = dir.path().join("test.py");
        fs::write(&file_path, content).map_err(Error::Io)?;
        Ok((dir, file_path))
    }

    /// A plain top-level function yields its name, visibility and docstring.
    #[test]
    fn test_parse_function() -> Result<()> {
        let content = r#"
def hello_world():
    """This is a docstring."""
    print("Hello, World!")
"#;
        let (_dir, file_path) = create_test_file(content)?;
        let mut parser = PythonParser::try_new()?;
        let file_unit = parser.parse_file(&file_path)?;

        assert_eq!(file_unit.functions.len(), 1);
        let func = &file_unit.functions[0];
        assert_eq!(func.name, "hello_world");
        assert_eq!(func.visibility, Visibility::Public);
        assert_eq!(func.documentation, Some("This is a docstring.".to_string()));
        Ok(())
    }

    /// A decorated class yields its name, docstring, decorator and method.
    #[test]
    fn test_parse_class() -> Result<()> {
        let content = r#"
@dataclass
class Person:
    """A person class."""
    def __init__(self, name: str):
        self.name = name
"#;
        let (_dir, file_path) = create_test_file(content)?;
        let mut parser = PythonParser::try_new()?;
        let file_unit = parser.parse_file(&file_path)?;

        assert_eq!(file_unit.structs.len(), 1);
        let class = &file_unit.structs[0];
        assert_eq!(class.name, "Person");
        assert_eq!(class.visibility, Visibility::Public);
        assert_eq!(class.documentation, Some("A person class.".to_string()));
        assert_eq!(class.attributes.len(), 1);
        assert_eq!(class.attributes[0], "@dataclass");
        assert_eq!(class.methods.len(), 1);
        assert_eq!(class.methods[0].name, "__init__");
        Ok(())
    }

    /// Names starting with an underscore are reported as private.
    #[test]
    fn test_parse_private_members() -> Result<()> {
        let content = r#"
def _private_function():
    """A private function."""
    pass

class _PrivateClass:
    """A private class."""
    pass
"#;
        let (_dir, file_path) = create_test_file(content)?;
        let mut parser = PythonParser::try_new()?;
        let file_unit = parser.parse_file(&file_path)?;

        assert_eq!(file_unit.functions[0].visibility, Visibility::Private);
        assert_eq!(file_unit.structs[0].visibility, Visibility::Private);
        Ok(())
    }

    /// A module docstring written with `"""` is extracted without its quotes.
    ///
    /// The `'''` form is covered by the next test; this one previously used
    /// `'''` as well, leaving the double-quoted form untested.
    #[test]
    fn test_parse_module_docstring() -> Result<()> {
        let content = r#""""This is a module docstring."""

def hello_world():
    pass
"#;
        let (_dir, file_path) = create_test_file(content)?;
        let mut parser = PythonParser::try_new()?;
        let file_unit = parser.parse_file(&file_path)?;

        assert_eq!(
            file_unit.document,
            Some("This is a module docstring.".to_string())
        );
        Ok(())
    }

    /// A module docstring written with `'''` is extracted without its quotes.
    #[test]
    fn test_parse_module_docstring_with_triple_quotes() -> Result<()> {
        let content = r#"'''This is a module docstring with triple quotes.'''

def hello_world():
    pass
"#;
        let (_dir, file_path) = create_test_file(content)?;
        let mut parser = PythonParser::try_new()?;
        let file_unit = parser.parse_file(&file_path)?;

        assert_eq!(
            file_unit.document,
            Some("This is a module docstring with triple quotes.".to_string())
        );
        Ok(())
    }
}