1use super::{node_text, LanguageExtractor};
7use crate::ast::{
8 ExtractedSymbol, FunctionCall, Import, ImportedName, Parameter, SymbolKind, Visibility,
9};
10use crate::error::Result;
11use tree_sitter::{Language, Node, Tree};
12
/// Language extractor for Python sources (`.py` and `.pyi`).
///
/// The type carries no fields; everything it works on is passed in per call
/// (a parsed `Tree` plus the source text).
#[derive(Debug, Clone, Copy, Default)]
pub struct PythonExtractor;
15
16impl LanguageExtractor for PythonExtractor {
17 fn language(&self) -> Language {
18 tree_sitter_python::LANGUAGE.into()
19 }
20
21 fn name(&self) -> &'static str {
22 "python"
23 }
24
25 fn extensions(&self) -> &'static [&'static str] {
26 &["py", "pyi"]
27 }
28
29 fn extract_symbols(&self, tree: &Tree, source: &str) -> Result<Vec<ExtractedSymbol>> {
30 let mut symbols = Vec::new();
31 let root = tree.root_node();
32 self.extract_symbols_recursive(&root, source, &mut symbols, None);
33 Ok(symbols)
34 }
35
36 fn extract_imports(&self, tree: &Tree, source: &str) -> Result<Vec<Import>> {
37 let mut imports = Vec::new();
38 let root = tree.root_node();
39 self.extract_imports_recursive(&root, source, &mut imports);
40 Ok(imports)
41 }
42
43 fn extract_calls(
44 &self,
45 tree: &Tree,
46 source: &str,
47 current_function: Option<&str>,
48 ) -> Result<Vec<FunctionCall>> {
49 let mut calls = Vec::new();
50 let root = tree.root_node();
51 self.extract_calls_recursive(&root, source, &mut calls, current_function);
52 Ok(calls)
53 }
54
55 fn extract_doc_comment(&self, node: &Node, source: &str) -> Option<String> {
56 let body = node.child_by_field_name("body")?;
58 let mut cursor = body.walk();
59 let first_child = body.children(&mut cursor).next()?;
61 if first_child.kind() == "expression_statement" {
62 if let Some(string) = first_child.child(0) {
63 if string.kind() == "string" {
64 let text = node_text(&string, source);
65 return Some(Self::clean_docstring(text));
66 }
67 }
68 }
69 None
70 }
71}
72
73impl PythonExtractor {
74 fn extract_symbols_recursive(
75 &self,
76 node: &Node,
77 source: &str,
78 symbols: &mut Vec<ExtractedSymbol>,
79 parent: Option<&str>,
80 ) {
81 match node.kind() {
82 "function_definition" => {
83 if let Some(sym) = self.extract_function(node, source, parent) {
84 symbols.push(sym);
85 }
88 }
89
90 "class_definition" => {
91 if let Some(sym) = self.extract_class(node, source, parent) {
92 let class_name = sym.name.clone();
93 symbols.push(sym);
94
95 if let Some(body) = node.child_by_field_name("body") {
97 self.extract_class_members(&body, source, symbols, Some(&class_name));
98 }
99 return; }
101 }
102
103 "decorated_definition" => {
104 let mut cursor = node.walk();
106 for child in node.children(&mut cursor) {
107 if child.kind() == "function_definition" || child.kind() == "class_definition" {
108 self.extract_symbols_recursive(&child, source, symbols, parent);
109 }
110 }
111 return;
112 }
113
114 _ => {}
115 }
116
117 let mut cursor = node.walk();
119 for child in node.children(&mut cursor) {
120 self.extract_symbols_recursive(&child, source, symbols, parent);
121 }
122 }
123
124 fn extract_function(
125 &self,
126 node: &Node,
127 source: &str,
128 parent: Option<&str>,
129 ) -> Option<ExtractedSymbol> {
130 let name_node = node.child_by_field_name("name")?;
131 let name = node_text(&name_node, source).to_string();
132
133 let mut sym = ExtractedSymbol::new(
134 name.clone(),
135 SymbolKind::Function,
136 node.start_position().row + 1,
137 node.end_position().row + 1,
138 )
139 .with_columns(node.start_position().column, node.end_position().column);
140
141 let text = node_text(node, source);
143 if text.starts_with("async") {
144 sym = sym.async_fn();
145 }
146
147 if name.starts_with("__") && !name.ends_with("__") {
149 sym.visibility = Visibility::Private;
150 } else if name.starts_with('_') {
151 sym.visibility = Visibility::Protected;
152 } else {
153 sym = sym.exported();
154 }
155
156 if let Some(params) = node.child_by_field_name("parameters") {
158 self.extract_parameters(¶ms, source, &mut sym);
159 }
160
161 if let Some(ret_type) = node.child_by_field_name("return_type") {
163 sym.return_type = Some(
164 node_text(&ret_type, source)
165 .trim_start_matches("->")
166 .trim()
167 .to_string(),
168 );
169 }
170
171 sym.doc_comment = self.extract_doc_comment(node, source);
173
174 if let Some(p) = parent {
175 sym = sym.with_parent(p);
176 sym.kind = SymbolKind::Method;
177 }
178
179 sym.signature = Some(self.build_function_signature(node, source));
180
181 Some(sym)
182 }
183
184 fn extract_class(
185 &self,
186 node: &Node,
187 source: &str,
188 parent: Option<&str>,
189 ) -> Option<ExtractedSymbol> {
190 let name_node = node.child_by_field_name("name")?;
191 let name = node_text(&name_node, source).to_string();
192
193 let mut sym = ExtractedSymbol::new(
194 name.clone(),
195 SymbolKind::Class,
196 node.start_position().row + 1,
197 node.end_position().row + 1,
198 )
199 .with_columns(node.start_position().column, node.end_position().column);
200
201 if name.starts_with('_') {
203 sym.visibility = Visibility::Protected;
204 } else {
205 sym = sym.exported();
206 }
207
208 sym.doc_comment = self.extract_doc_comment(node, source);
210
211 if let Some(p) = parent {
212 sym = sym.with_parent(p);
213 }
214
215 Some(sym)
216 }
217
218 fn extract_class_members(
219 &self,
220 body: &Node,
221 source: &str,
222 symbols: &mut Vec<ExtractedSymbol>,
223 class_name: Option<&str>,
224 ) {
225 let mut cursor = body.walk();
226 for child in body.children(&mut cursor) {
227 match child.kind() {
228 "function_definition" => {
229 if let Some(sym) = self.extract_function(&child, source, class_name) {
230 symbols.push(sym);
231 }
232 }
233 "decorated_definition" => {
234 let mut inner_cursor = child.walk();
235 for inner in child.children(&mut inner_cursor) {
236 if inner.kind() == "function_definition" {
237 if let Some(mut sym) = self.extract_function(&inner, source, class_name)
238 {
239 let deco_text = node_text(&child, source);
241 if deco_text.contains("@staticmethod") {
242 sym = sym.static_fn();
243 }
244 symbols.push(sym);
245 }
246 }
247 }
248 }
249 _ => {}
250 }
251 }
252 }
253
254 fn extract_parameters(&self, params: &Node, source: &str, sym: &mut ExtractedSymbol) {
255 let mut cursor = params.walk();
256 for child in params.children(&mut cursor) {
257 match child.kind() {
258 "identifier" => {
259 let name = node_text(&child, source);
260 if name != "self" && name != "cls" {
262 sym.add_parameter(Parameter {
263 name: name.to_string(),
264 type_info: None,
265 default_value: None,
266 is_rest: false,
267 is_optional: false,
268 });
269 }
270 }
271 "typed_parameter" => {
272 let name = child
273 .child_by_field_name("name")
274 .map(|n| node_text(&n, source).to_string())
275 .unwrap_or_default();
276
277 if name != "self" && name != "cls" {
278 let type_info = child
279 .child_by_field_name("type")
280 .map(|n| node_text(&n, source).to_string());
281
282 sym.add_parameter(Parameter {
283 name,
284 type_info,
285 default_value: None,
286 is_rest: false,
287 is_optional: false,
288 });
289 }
290 }
291 "default_parameter" | "typed_default_parameter" => {
292 let name = child
293 .child_by_field_name("name")
294 .map(|n| node_text(&n, source).to_string())
295 .unwrap_or_default();
296
297 if name != "self" && name != "cls" {
298 let type_info = child
299 .child_by_field_name("type")
300 .map(|n| node_text(&n, source).to_string());
301 let default_value = child
302 .child_by_field_name("value")
303 .map(|n| node_text(&n, source).to_string());
304
305 sym.add_parameter(Parameter {
306 name,
307 type_info,
308 default_value,
309 is_rest: false,
310 is_optional: true,
311 });
312 }
313 }
314 "list_splat_pattern" | "dictionary_splat_pattern" => {
315 let text = node_text(&child, source);
316 let name = text.trim_start_matches('*').to_string();
317 let is_kwargs = text.starts_with("**");
318
319 sym.add_parameter(Parameter {
320 name,
321 type_info: None,
322 default_value: None,
323 is_rest: !is_kwargs,
324 is_optional: true,
325 });
326 }
327 _ => {}
328 }
329 }
330 }
331
332 fn extract_imports_recursive(&self, node: &Node, source: &str, imports: &mut Vec<Import>) {
333 match node.kind() {
334 "import_statement" => {
335 if let Some(import) = self.parse_import(node, source) {
336 imports.push(import);
337 }
338 }
339 "import_from_statement" => {
340 if let Some(import) = self.parse_from_import(node, source) {
341 imports.push(import);
342 }
343 }
344 _ => {}
345 }
346
347 let mut cursor = node.walk();
348 for child in node.children(&mut cursor) {
349 self.extract_imports_recursive(&child, source, imports);
350 }
351 }
352
353 fn parse_import(&self, node: &Node, source: &str) -> Option<Import> {
354 let mut import = Import {
355 source: String::new(),
356 names: Vec::new(),
357 is_default: false,
358 is_namespace: false,
359 line: node.start_position().row + 1,
360 };
361
362 let mut cursor = node.walk();
363 for child in node.children(&mut cursor) {
364 match child.kind() {
365 "dotted_name" => {
366 let name = node_text(&child, source).to_string();
367 import.source = name.clone();
368 import.names.push(ImportedName { name, alias: None });
369 }
370 "aliased_import" => {
371 let name = child
372 .child_by_field_name("name")
373 .map(|n| node_text(&n, source).to_string())
374 .unwrap_or_default();
375 let alias = child
376 .child_by_field_name("alias")
377 .map(|n| node_text(&n, source).to_string());
378
379 import.source = name.clone();
380 import.names.push(ImportedName { name, alias });
381 }
382 _ => {}
383 }
384 }
385
386 Some(import)
387 }
388
389 fn parse_from_import(&self, node: &Node, source: &str) -> Option<Import> {
390 let module = node
391 .child_by_field_name("module_name")
392 .map(|n| node_text(&n, source).to_string())
393 .unwrap_or_default();
394
395 let mut import = Import {
396 source: module,
397 names: Vec::new(),
398 is_default: false,
399 is_namespace: false,
400 line: node.start_position().row + 1,
401 };
402
403 let mut cursor = node.walk();
404 for child in node.children(&mut cursor) {
405 match child.kind() {
406 "wildcard_import" => {
407 import.is_namespace = true;
408 import.names.push(ImportedName {
409 name: "*".to_string(),
410 alias: None,
411 });
412 }
413 "dotted_name" | "identifier" => {
414 import.names.push(ImportedName {
415 name: node_text(&child, source).to_string(),
416 alias: None,
417 });
418 }
419 "aliased_import" => {
420 let name = child
421 .child_by_field_name("name")
422 .map(|n| node_text(&n, source).to_string())
423 .unwrap_or_default();
424 let alias = child
425 .child_by_field_name("alias")
426 .map(|n| node_text(&n, source).to_string());
427
428 import.names.push(ImportedName { name, alias });
429 }
430 _ => {}
431 }
432 }
433
434 Some(import)
435 }
436
437 fn extract_calls_recursive(
438 &self,
439 node: &Node,
440 source: &str,
441 calls: &mut Vec<FunctionCall>,
442 current_function: Option<&str>,
443 ) {
444 if node.kind() == "call" {
445 if let Some(call) = self.parse_call(node, source, current_function) {
446 calls.push(call);
447 }
448 }
449
450 let func_name = if node.kind() == "function_definition" {
451 node.child_by_field_name("name")
452 .map(|n| node_text(&n, source))
453 } else {
454 None
455 };
456
457 let current = func_name
458 .map(String::from)
459 .or_else(|| current_function.map(String::from));
460
461 let mut cursor = node.walk();
462 for child in node.children(&mut cursor) {
463 self.extract_calls_recursive(&child, source, calls, current.as_deref());
464 }
465 }
466
467 fn parse_call(
468 &self,
469 node: &Node,
470 source: &str,
471 current_function: Option<&str>,
472 ) -> Option<FunctionCall> {
473 let function = node.child_by_field_name("function")?;
474
475 let (callee, is_method, receiver) = match function.kind() {
476 "attribute" => {
477 let object = function
478 .child_by_field_name("object")
479 .map(|n| node_text(&n, source).to_string());
480 let attr = function
481 .child_by_field_name("attribute")
482 .map(|n| node_text(&n, source).to_string())?;
483 (attr, true, object)
484 }
485 "identifier" => (node_text(&function, source).to_string(), false, None),
486 _ => return None,
487 };
488
489 Some(FunctionCall {
490 caller: current_function.unwrap_or("<module>").to_string(),
491 callee,
492 line: node.start_position().row + 1,
493 is_method,
494 receiver,
495 })
496 }
497
498 fn build_function_signature(&self, node: &Node, source: &str) -> String {
499 let async_kw = if node_text(node, source).starts_with("async") {
500 "async "
501 } else {
502 ""
503 };
504
505 let name = node
506 .child_by_field_name("name")
507 .map(|n| node_text(&n, source))
508 .unwrap_or("unknown");
509
510 let params = node
511 .child_by_field_name("parameters")
512 .map(|n| node_text(&n, source))
513 .unwrap_or("()");
514
515 let return_type = node
516 .child_by_field_name("return_type")
517 .map(|n| format!(" {}", node_text(&n, source)))
518 .unwrap_or_default();
519
520 format!("{}def {}{}{}", async_kw, name, params, return_type)
521 }
522
523 fn clean_docstring(text: &str) -> String {
524 let text = text
526 .trim_start_matches("\"\"\"")
527 .trim_start_matches("'''")
528 .trim_end_matches("\"\"\"")
529 .trim_end_matches("'''")
530 .trim();
531
532 text.to_string()
533 }
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` with the Python grammar and returns every symbol the extractor finds.
    fn symbols_in(source: &str) -> Vec<ExtractedSymbol> {
        let mut parser = tree_sitter::Parser::new();
        parser
            .set_language(&tree_sitter_python::LANGUAGE.into())
            .unwrap();
        let tree = parser.parse(source, None).unwrap();
        PythonExtractor.extract_symbols(&tree, source).unwrap()
    }

    /// True if `symbols` contains an entry with exactly this `name` and `kind`.
    fn contains(symbols: &[ExtractedSymbol], name: &str, kind: SymbolKind) -> bool {
        symbols.iter().any(|s| s.name == name && s.kind == kind)
    }

    #[test]
    fn test_extract_function() {
        let source = r#"
def greet(name: str) -> str:
    """Greet someone."""
    return f"Hello, {name}!"
"#;
        let symbols = symbols_in(source);

        assert_eq!(symbols.len(), 1);
        let greet = &symbols[0];
        assert_eq!(greet.name, "greet");
        assert_eq!(greet.kind, SymbolKind::Function);
    }

    #[test]
    fn test_extract_class() {
        let source = r#"
class UserService:
    """A service for managing users."""

    def __init__(self, name: str):
        self.name = name

    def greet(self) -> str:
        return f"Hello, {self.name}!"

    @staticmethod
    def create():
        return UserService("default")
"#;
        let symbols = symbols_in(source);

        assert!(contains(&symbols, "UserService", SymbolKind::Class));
        for method in ["__init__", "greet", "create"] {
            assert!(contains(&symbols, method, SymbolKind::Method));
        }
    }

    #[test]
    fn test_extract_async_function() {
        let source = r#"
async def fetch_data(url: str) -> dict:
    """Fetch data from URL."""
    pass
"#;
        let symbols = symbols_in(source);

        assert_eq!(symbols.len(), 1);
        let fetch = &symbols[0];
        assert_eq!(fetch.name, "fetch_data");
        assert!(fetch.is_async);
    }
}