1use super::{node_text, LanguageExtractor};
7use crate::ast::{
8 ExtractedSymbol, FunctionCall, Import, ImportedName, Parameter, SymbolKind, Visibility,
9};
10use crate::error::Result;
11use tree_sitter::{Language, Node, Tree};
12
/// [`LanguageExtractor`] implementation for Python sources (`.py` and `.pyi` files),
/// built on the `tree_sitter_python` grammar.
///
/// The extractor is stateless, so it is a cheap `Copy` unit value.
#[derive(Debug, Clone, Copy, Default)]
pub struct PythonExtractor;
15
16impl LanguageExtractor for PythonExtractor {
17 fn language(&self) -> Language {
18 tree_sitter_python::LANGUAGE.into()
19 }
20
21 fn name(&self) -> &'static str {
22 "python"
23 }
24
25 fn extensions(&self) -> &'static [&'static str] {
26 &["py", "pyi"]
27 }
28
29 fn extract_symbols(&self, tree: &Tree, source: &str) -> Result<Vec<ExtractedSymbol>> {
30 let mut symbols = Vec::new();
31 let root = tree.root_node();
32 self.extract_symbols_recursive(&root, source, &mut symbols, None);
33 Ok(symbols)
34 }
35
36 fn extract_imports(&self, tree: &Tree, source: &str) -> Result<Vec<Import>> {
37 let mut imports = Vec::new();
38 let root = tree.root_node();
39 self.extract_imports_recursive(&root, source, &mut imports);
40 Ok(imports)
41 }
42
43 fn extract_calls(
44 &self,
45 tree: &Tree,
46 source: &str,
47 current_function: Option<&str>,
48 ) -> Result<Vec<FunctionCall>> {
49 let mut calls = Vec::new();
50 let root = tree.root_node();
51 self.extract_calls_recursive(&root, source, &mut calls, current_function);
52 Ok(calls)
53 }
54
55 fn extract_doc_comment(&self, node: &Node, source: &str) -> Option<String> {
56 let body = node.child_by_field_name("body")?;
58 let mut cursor = body.walk();
59 let first_child = body.children(&mut cursor).next()?;
61 if first_child.kind() == "expression_statement" {
62 if let Some(string) = first_child.child(0) {
63 if string.kind() == "string" {
64 let text = node_text(&string, source);
65 return Some(Self::clean_docstring(text));
66 }
67 }
68 }
69 None
70 }
71}
72
73impl PythonExtractor {
74 fn extract_symbols_recursive(
75 &self,
76 node: &Node,
77 source: &str,
78 symbols: &mut Vec<ExtractedSymbol>,
79 parent: Option<&str>,
80 ) {
81 match node.kind() {
82 "function_definition" => {
83 if let Some(sym) = self.extract_function(node, source, parent) {
84 symbols.push(sym);
85 }
88 }
89
90 "class_definition" => {
91 if let Some(sym) = self.extract_class(node, source, parent) {
92 let class_name = sym.name.clone();
93 symbols.push(sym);
94
95 if let Some(body) = node.child_by_field_name("body") {
97 self.extract_class_members(&body, source, symbols, Some(&class_name));
98 }
99 return; }
101 }
102
103 "decorated_definition" => {
104 let decorator_start = node.start_position().row + 1;
107
108 let mut cursor = node.walk();
109 for child in node.children(&mut cursor) {
110 if child.kind() == "function_definition" {
111 if let Some(mut sym) = self.extract_function(&child, source, parent) {
112 sym.definition_start_line = Some(decorator_start);
114 symbols.push(sym);
115 }
116 } else if child.kind() == "class_definition" {
117 if let Some(mut sym) = self.extract_class(&child, source, parent) {
118 let class_name = sym.name.clone();
119 sym.definition_start_line = Some(decorator_start);
120 symbols.push(sym);
121
122 if let Some(body) = child.child_by_field_name("body") {
124 self.extract_class_members(
125 &body,
126 source,
127 symbols,
128 Some(&class_name),
129 );
130 }
131 }
132 }
133 }
134 return;
135 }
136
137 _ => {}
138 }
139
140 let mut cursor = node.walk();
142 for child in node.children(&mut cursor) {
143 self.extract_symbols_recursive(&child, source, symbols, parent);
144 }
145 }
146
147 fn extract_function(
148 &self,
149 node: &Node,
150 source: &str,
151 parent: Option<&str>,
152 ) -> Option<ExtractedSymbol> {
153 let name_node = node.child_by_field_name("name")?;
154 let name = node_text(&name_node, source).to_string();
155
156 let mut sym = ExtractedSymbol::new(
157 name.clone(),
158 SymbolKind::Function,
159 node.start_position().row + 1,
160 node.end_position().row + 1,
161 )
162 .with_columns(node.start_position().column, node.end_position().column);
163
164 let text = node_text(node, source);
166 if text.starts_with("async") {
167 sym = sym.async_fn();
168 }
169
170 if name.starts_with("__") && !name.ends_with("__") {
172 sym.visibility = Visibility::Private;
173 } else if name.starts_with('_') {
174 sym.visibility = Visibility::Protected;
175 } else {
176 sym = sym.exported();
177 }
178
179 if let Some(params) = node.child_by_field_name("parameters") {
181 self.extract_parameters(¶ms, source, &mut sym);
182 }
183
184 if let Some(ret_type) = node.child_by_field_name("return_type") {
186 sym.return_type = Some(
187 node_text(&ret_type, source)
188 .trim_start_matches("->")
189 .trim()
190 .to_string(),
191 );
192 }
193
194 sym.doc_comment = self.extract_doc_comment(node, source);
196
197 if let Some(p) = parent {
198 sym = sym.with_parent(p);
199 sym.kind = SymbolKind::Method;
200 }
201
202 sym.signature = Some(self.build_function_signature(node, source));
203
204 if sym.definition_start_line.is_none() {
206 sym.definition_start_line = Some(node.start_position().row + 1);
207 }
208
209 Some(sym)
210 }
211
212 fn extract_class(
213 &self,
214 node: &Node,
215 source: &str,
216 parent: Option<&str>,
217 ) -> Option<ExtractedSymbol> {
218 let name_node = node.child_by_field_name("name")?;
219 let name = node_text(&name_node, source).to_string();
220
221 let mut sym = ExtractedSymbol::new(
222 name.clone(),
223 SymbolKind::Class,
224 node.start_position().row + 1,
225 node.end_position().row + 1,
226 )
227 .with_columns(node.start_position().column, node.end_position().column);
228
229 if name.starts_with('_') {
231 sym.visibility = Visibility::Protected;
232 } else {
233 sym = sym.exported();
234 }
235
236 sym.doc_comment = self.extract_doc_comment(node, source);
238
239 if let Some(p) = parent {
240 sym = sym.with_parent(p);
241 }
242
243 if sym.definition_start_line.is_none() {
245 sym.definition_start_line = Some(node.start_position().row + 1);
246 }
247
248 Some(sym)
249 }
250
251 fn extract_class_members(
252 &self,
253 body: &Node,
254 source: &str,
255 symbols: &mut Vec<ExtractedSymbol>,
256 class_name: Option<&str>,
257 ) {
258 let mut cursor = body.walk();
259 for child in body.children(&mut cursor) {
260 match child.kind() {
261 "function_definition" => {
262 if let Some(sym) = self.extract_function(&child, source, class_name) {
263 symbols.push(sym);
264 }
265 }
266 "decorated_definition" => {
267 let decorator_start = child.start_position().row + 1;
269
270 let mut inner_cursor = child.walk();
271 for inner in child.children(&mut inner_cursor) {
272 if inner.kind() == "function_definition" {
273 if let Some(mut sym) = self.extract_function(&inner, source, class_name)
274 {
275 sym.definition_start_line = Some(decorator_start);
277
278 let deco_text = node_text(&child, source);
280 if deco_text.contains("@staticmethod") {
281 sym = sym.static_fn();
282 }
283 symbols.push(sym);
284 }
285 }
286 }
287 }
288 _ => {}
289 }
290 }
291 }
292
293 fn extract_parameters(&self, params: &Node, source: &str, sym: &mut ExtractedSymbol) {
294 let mut cursor = params.walk();
295 for child in params.children(&mut cursor) {
296 match child.kind() {
297 "identifier" => {
298 let name = node_text(&child, source);
299 if name != "self" && name != "cls" {
301 sym.add_parameter(Parameter {
302 name: name.to_string(),
303 type_info: None,
304 default_value: None,
305 is_rest: false,
306 is_optional: false,
307 });
308 }
309 }
310 "typed_parameter" => {
311 let name = child
312 .child_by_field_name("name")
313 .map(|n| node_text(&n, source).to_string())
314 .unwrap_or_default();
315
316 if name != "self" && name != "cls" {
317 let type_info = child
318 .child_by_field_name("type")
319 .map(|n| node_text(&n, source).to_string());
320
321 sym.add_parameter(Parameter {
322 name,
323 type_info,
324 default_value: None,
325 is_rest: false,
326 is_optional: false,
327 });
328 }
329 }
330 "default_parameter" | "typed_default_parameter" => {
331 let name = child
332 .child_by_field_name("name")
333 .map(|n| node_text(&n, source).to_string())
334 .unwrap_or_default();
335
336 if name != "self" && name != "cls" {
337 let type_info = child
338 .child_by_field_name("type")
339 .map(|n| node_text(&n, source).to_string());
340 let default_value = child
341 .child_by_field_name("value")
342 .map(|n| node_text(&n, source).to_string());
343
344 sym.add_parameter(Parameter {
345 name,
346 type_info,
347 default_value,
348 is_rest: false,
349 is_optional: true,
350 });
351 }
352 }
353 "list_splat_pattern" | "dictionary_splat_pattern" => {
354 let text = node_text(&child, source);
355 let name = text.trim_start_matches('*').to_string();
356 let is_kwargs = text.starts_with("**");
357
358 sym.add_parameter(Parameter {
359 name,
360 type_info: None,
361 default_value: None,
362 is_rest: !is_kwargs,
363 is_optional: true,
364 });
365 }
366 _ => {}
367 }
368 }
369 }
370
371 fn extract_imports_recursive(&self, node: &Node, source: &str, imports: &mut Vec<Import>) {
372 match node.kind() {
373 "import_statement" => {
374 if let Some(import) = self.parse_import(node, source) {
375 imports.push(import);
376 }
377 }
378 "import_from_statement" => {
379 if let Some(import) = self.parse_from_import(node, source) {
380 imports.push(import);
381 }
382 }
383 _ => {}
384 }
385
386 let mut cursor = node.walk();
387 for child in node.children(&mut cursor) {
388 self.extract_imports_recursive(&child, source, imports);
389 }
390 }
391
392 fn parse_import(&self, node: &Node, source: &str) -> Option<Import> {
393 let mut import = Import {
394 source: String::new(),
395 names: Vec::new(),
396 is_default: false,
397 is_namespace: false,
398 line: node.start_position().row + 1,
399 };
400
401 let mut cursor = node.walk();
402 for child in node.children(&mut cursor) {
403 match child.kind() {
404 "dotted_name" => {
405 let name = node_text(&child, source).to_string();
406 import.source = name.clone();
407 import.names.push(ImportedName { name, alias: None });
408 }
409 "aliased_import" => {
410 let name = child
411 .child_by_field_name("name")
412 .map(|n| node_text(&n, source).to_string())
413 .unwrap_or_default();
414 let alias = child
415 .child_by_field_name("alias")
416 .map(|n| node_text(&n, source).to_string());
417
418 import.source = name.clone();
419 import.names.push(ImportedName { name, alias });
420 }
421 _ => {}
422 }
423 }
424
425 Some(import)
426 }
427
428 fn parse_from_import(&self, node: &Node, source: &str) -> Option<Import> {
429 let module = node
430 .child_by_field_name("module_name")
431 .map(|n| node_text(&n, source).to_string())
432 .unwrap_or_default();
433
434 let mut import = Import {
435 source: module,
436 names: Vec::new(),
437 is_default: false,
438 is_namespace: false,
439 line: node.start_position().row + 1,
440 };
441
442 let mut cursor = node.walk();
443 for child in node.children(&mut cursor) {
444 match child.kind() {
445 "wildcard_import" => {
446 import.is_namespace = true;
447 import.names.push(ImportedName {
448 name: "*".to_string(),
449 alias: None,
450 });
451 }
452 "dotted_name" | "identifier" => {
453 import.names.push(ImportedName {
454 name: node_text(&child, source).to_string(),
455 alias: None,
456 });
457 }
458 "aliased_import" => {
459 let name = child
460 .child_by_field_name("name")
461 .map(|n| node_text(&n, source).to_string())
462 .unwrap_or_default();
463 let alias = child
464 .child_by_field_name("alias")
465 .map(|n| node_text(&n, source).to_string());
466
467 import.names.push(ImportedName { name, alias });
468 }
469 _ => {}
470 }
471 }
472
473 Some(import)
474 }
475
476 fn extract_calls_recursive(
477 &self,
478 node: &Node,
479 source: &str,
480 calls: &mut Vec<FunctionCall>,
481 current_function: Option<&str>,
482 ) {
483 if node.kind() == "call" {
484 if let Some(call) = self.parse_call(node, source, current_function) {
485 calls.push(call);
486 }
487 }
488
489 let func_name = if node.kind() == "function_definition" {
490 node.child_by_field_name("name")
491 .map(|n| node_text(&n, source))
492 } else {
493 None
494 };
495
496 let current = func_name
497 .map(String::from)
498 .or_else(|| current_function.map(String::from));
499
500 let mut cursor = node.walk();
501 for child in node.children(&mut cursor) {
502 self.extract_calls_recursive(&child, source, calls, current.as_deref());
503 }
504 }
505
506 fn parse_call(
507 &self,
508 node: &Node,
509 source: &str,
510 current_function: Option<&str>,
511 ) -> Option<FunctionCall> {
512 let function = node.child_by_field_name("function")?;
513
514 let (callee, is_method, receiver) = match function.kind() {
515 "attribute" => {
516 let object = function
517 .child_by_field_name("object")
518 .map(|n| node_text(&n, source).to_string());
519 let attr = function
520 .child_by_field_name("attribute")
521 .map(|n| node_text(&n, source).to_string())?;
522 (attr, true, object)
523 }
524 "identifier" => (node_text(&function, source).to_string(), false, None),
525 _ => return None,
526 };
527
528 Some(FunctionCall {
529 caller: current_function.unwrap_or("<module>").to_string(),
530 callee,
531 line: node.start_position().row + 1,
532 is_method,
533 receiver,
534 })
535 }
536
537 fn build_function_signature(&self, node: &Node, source: &str) -> String {
538 let async_kw = if node_text(node, source).starts_with("async") {
539 "async "
540 } else {
541 ""
542 };
543
544 let name = node
545 .child_by_field_name("name")
546 .map(|n| node_text(&n, source))
547 .unwrap_or("unknown");
548
549 let params = node
550 .child_by_field_name("parameters")
551 .map(|n| node_text(&n, source))
552 .unwrap_or("()");
553
554 let return_type = node
555 .child_by_field_name("return_type")
556 .map(|n| format!(" {}", node_text(&n, source)))
557 .unwrap_or_default();
558
559 format!("{}def {}{}{}", async_kw, name, params, return_type)
560 }
561
562 fn clean_docstring(text: &str) -> String {
563 let text = text
565 .trim_start_matches("\"\"\"")
566 .trim_start_matches("'''")
567 .trim_end_matches("\"\"\"")
568 .trim_end_matches("'''")
569 .trim();
570
571 text.to_string()
572 }
573}
574
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as Python and returns the syntax tree with an owned copy of the text.
    fn parse_py(source: &str) -> (Tree, String) {
        let mut parser = tree_sitter::Parser::new();
        parser
            .set_language(&tree_sitter_python::LANGUAGE.into())
            .expect("the Python grammar should load");
        let tree = parser
            .parse(source, None)
            .expect("the parser should produce a tree");
        (tree, String::from(source))
    }

    /// Parses `source` and runs symbol extraction over it.
    fn symbols_of(source: &str) -> Vec<ExtractedSymbol> {
        let (tree, text) = parse_py(source);
        PythonExtractor
            .extract_symbols(&tree, &text)
            .expect("symbol extraction should succeed")
    }

    /// `true` when some extracted symbol has exactly this name and kind.
    fn has_symbol(symbols: &[ExtractedSymbol], name: &str, kind: SymbolKind) -> bool {
        symbols.iter().any(|s| s.name == name && s.kind == kind)
    }

    #[test]
    fn test_extract_function() {
        let source = r#"
def greet(name: str) -> str:
    """Greet someone."""
    return f"Hello, {name}!"
"#;
        let symbols = symbols_of(source);

        assert_eq!(symbols.len(), 1);
        let greet = &symbols[0];
        assert_eq!(greet.name, "greet");
        assert_eq!(greet.kind, SymbolKind::Function);
    }

    #[test]
    fn test_extract_class() {
        let source = r#"
class UserService:
    """A service for managing users."""

    def __init__(self, name: str):
        self.name = name

    def greet(self) -> str:
        return f"Hello, {self.name}!"

    @staticmethod
    def create():
        return UserService("default")
"#;
        let symbols = symbols_of(source);

        assert!(has_symbol(&symbols, "UserService", SymbolKind::Class));
        for method in ["__init__", "greet", "create"] {
            assert!(
                has_symbol(&symbols, method, SymbolKind::Method),
                "missing method {}",
                method
            );
        }
    }

    #[test]
    fn test_extract_async_function() {
        let source = r#"
async def fetch_data(url: str) -> dict:
    """Fetch data from URL."""
    pass
"#;
        let symbols = symbols_of(source);

        assert_eq!(symbols.len(), 1);
        let fetch = &symbols[0];
        assert_eq!(fetch.name, "fetch_data");
        assert!(fetch.is_async);
    }

    #[test]
    fn test_decorated_function_definition_start_line() {
        let source = r#"
@decorator1
@decorator2
def my_function():
    pass
"#;
        let symbols = symbols_of(source);

        assert_eq!(symbols.len(), 1);
        let func = &symbols[0];
        assert_eq!(func.name, "my_function");
        // The definition begins at the first decorator; the symbol itself at `def`.
        assert_eq!(func.definition_start_line, Some(2));
        assert_eq!(func.start_line, 4);
    }

    #[test]
    fn test_non_decorated_function_definition_start_line() {
        let source = r#"
def simple_function():
    pass
"#;
        let symbols = symbols_of(source);

        assert_eq!(symbols.len(), 1);
        let func = &symbols[0];
        // Without decorators, both lines point at the `def`.
        assert_eq!(func.definition_start_line, Some(2));
        assert_eq!(func.start_line, 2);
    }
}