depyler_knowledge/
extractor.rs

//! Extractor: parse `.pyi` stub files to extract type facts.
//!
//! Uses `rustpython_parser` to parse Python stub (or source) files and extract
//! function signatures, class definitions, method types, and annotated
//! module- and class-level attributes.
5
6use crate::{KnowledgeError, Result, TypeFact, TypeFactKind};
7use rustpython_ast::{self as ast, Stmt};
8use rustpython_parser::{parse, Mode};
9use std::path::Path;
10use tracing::{debug, warn};
11
/// Extractor for parsing Python stub files into type facts.
///
/// By default, symbols whose names start with `_` are skipped; call
/// `Extractor::with_private` to keep them.
#[derive(Debug, Clone)]
pub struct Extractor {
    /// Whether to include private symbols (names starting with `_`).
    include_private: bool,
}
17
18impl Default for Extractor {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl Extractor {
25    /// Create a new Extractor.
26    pub fn new() -> Self {
27        Self {
28            include_private: false,
29        }
30    }
31
32    /// Include private symbols (starting with _).
33    pub fn with_private(mut self) -> Self {
34        self.include_private = true;
35        self
36    }
37
38    /// Extract type facts from a single .pyi or .py file.
39    pub fn extract_file(&self, path: &Path, module: &str) -> Result<Vec<TypeFact>> {
40        let source = std::fs::read_to_string(path)?;
41        self.extract_source(&source, module, path.to_string_lossy().as_ref())
42    }
43
44    /// Extract type facts from source code.
45    pub fn extract_source(&self, source: &str, module: &str, filename: &str) -> Result<Vec<TypeFact>> {
46        let parsed = parse(source, Mode::Module, filename).map_err(|e| KnowledgeError::StubParseError {
47            file: filename.to_string(),
48            message: e.to_string(),
49        })?;
50
51        let mut facts = Vec::new();
52
53        // The parsed result is a Module containing statements
54        if let ast::Mod::Module(module_ast) = parsed {
55            for stmt in module_ast.body {
56                self.extract_stmt(&stmt, module, &mut facts);
57            }
58        }
59
60        debug!(
61            module = %module,
62            facts = facts.len(),
63            "Extracted type facts"
64        );
65
66        Ok(facts)
67    }
68
69    /// Extract type facts from a statement.
70    fn extract_stmt(&self, stmt: &Stmt, module: &str, facts: &mut Vec<TypeFact>) {
71        match stmt {
72            Stmt::FunctionDef(func) => {
73                if self.should_include(&func.name) {
74                    if let Some(fact) = self.extract_function(func, module) {
75                        facts.push(fact);
76                    }
77                }
78            }
79            Stmt::AsyncFunctionDef(func) => {
80                if self.should_include(&func.name) {
81                    if let Some(fact) = self.extract_async_function(func, module) {
82                        facts.push(fact);
83                    }
84                }
85            }
86            Stmt::ClassDef(class) => {
87                if self.should_include(&class.name) {
88                    self.extract_class(class, module, facts);
89                }
90            }
91            Stmt::AnnAssign(assign) => {
92                if let Some(fact) = self.extract_annotated_assign(assign, module) {
93                    facts.push(fact);
94                }
95            }
96            _ => {}
97        }
98    }
99
100    /// Check if a symbol should be included based on privacy settings.
101    fn should_include(&self, name: &str) -> bool {
102        self.include_private || !name.starts_with('_')
103    }
104
105    /// Extract a function definition.
106    fn extract_function(&self, func: &ast::StmtFunctionDef, module: &str) -> Option<TypeFact> {
107        let signature = self.build_signature(&func.args, &func.returns);
108        let return_type = self.type_to_string(&func.returns);
109
110        Some(TypeFact {
111            module: module.to_string(),
112            symbol: func.name.to_string(),
113            kind: TypeFactKind::Function,
114            signature,
115            return_type,
116        })
117    }
118
119    /// Extract an async function definition.
120    fn extract_async_function(&self, func: &ast::StmtAsyncFunctionDef, module: &str) -> Option<TypeFact> {
121        let signature = self.build_signature(&func.args, &func.returns);
122        let return_type = self.type_to_string(&func.returns);
123
124        Some(TypeFact {
125            module: module.to_string(),
126            symbol: func.name.to_string(),
127            kind: TypeFactKind::Function,
128            signature: format!("async {signature}"),
129            return_type,
130        })
131    }
132
133    /// Extract a class and its methods.
134    fn extract_class(&self, class: &ast::StmtClassDef, module: &str, facts: &mut Vec<TypeFact>) {
135        // Add the class itself
136        facts.push(TypeFact::class(module, &class.name));
137
138        // Extract methods
139        for stmt in &class.body {
140            match stmt {
141                Stmt::FunctionDef(method) => {
142                    if self.should_include(&method.name) {
143                        if let Some(fact) = self.extract_method(method, module, &class.name) {
144                            facts.push(fact);
145                        }
146                    }
147                }
148                Stmt::AsyncFunctionDef(method) => {
149                    if self.should_include(&method.name) {
150                        if let Some(fact) = self.extract_async_method(method, module, &class.name) {
151                            facts.push(fact);
152                        }
153                    }
154                }
155                Stmt::AnnAssign(assign) => {
156                    if let Some(fact) = self.extract_class_attribute(assign, module, &class.name) {
157                        facts.push(fact);
158                    }
159                }
160                _ => {}
161            }
162        }
163    }
164
165    /// Extract a method from a class.
166    fn extract_method(
167        &self,
168        method: &ast::StmtFunctionDef,
169        module: &str,
170        class_name: &str,
171    ) -> Option<TypeFact> {
172        let signature = self.build_signature(&method.args, &method.returns);
173        let return_type = self.type_to_string(&method.returns);
174
175        Some(TypeFact::method(
176            module,
177            class_name,
178            &method.name,
179            &signature,
180            &return_type,
181        ))
182    }
183
184    /// Extract an async method from a class.
185    fn extract_async_method(
186        &self,
187        method: &ast::StmtAsyncFunctionDef,
188        module: &str,
189        class_name: &str,
190    ) -> Option<TypeFact> {
191        let signature = self.build_signature(&method.args, &method.returns);
192        let return_type = self.type_to_string(&method.returns);
193
194        Some(TypeFact::method(
195            module,
196            class_name,
197            &method.name,
198            &format!("async {signature}"),
199            &return_type,
200        ))
201    }
202
203    /// Extract an annotated assignment (module-level attribute).
204    fn extract_annotated_assign(&self, assign: &ast::StmtAnnAssign, module: &str) -> Option<TypeFact> {
205        let target = match assign.target.as_ref() {
206            ast::Expr::Name(name) => name.id.to_string(),
207            _ => return None,
208        };
209
210        if !self.should_include(&target) {
211            return None;
212        }
213
214        let type_str = self.expr_to_string(&assign.annotation);
215
216        Some(TypeFact {
217            module: module.to_string(),
218            symbol: target,
219            kind: TypeFactKind::Attribute,
220            signature: String::new(),
221            return_type: type_str,
222        })
223    }
224
225    /// Extract a class attribute.
226    fn extract_class_attribute(
227        &self,
228        assign: &ast::StmtAnnAssign,
229        module: &str,
230        class_name: &str,
231    ) -> Option<TypeFact> {
232        let target = match assign.target.as_ref() {
233            ast::Expr::Name(name) => name.id.to_string(),
234            _ => return None,
235        };
236
237        if !self.should_include(&target) {
238            return None;
239        }
240
241        let type_str = self.expr_to_string(&assign.annotation);
242
243        Some(TypeFact {
244            module: module.to_string(),
245            symbol: format!("{class_name}.{target}"),
246            kind: TypeFactKind::Attribute,
247            signature: String::new(),
248            return_type: type_str,
249        })
250    }
251
252    /// Build a signature string from arguments and return type.
253    fn build_signature(
254        &self,
255        args: &ast::Arguments,
256        returns: &Option<Box<ast::Expr>>,
257    ) -> String {
258        let mut parts = Vec::new();
259
260        // Positional-only params
261        for param in &args.posonlyargs {
262            parts.push(self.arg_with_default_to_string(param));
263        }
264
265        if !args.posonlyargs.is_empty() && !args.args.is_empty() {
266            parts.push("/".to_string());
267        }
268
269        // Regular args
270        for param in &args.args {
271            parts.push(self.arg_with_default_to_string(param));
272        }
273
274        // *args
275        if let Some(vararg) = &args.vararg {
276            parts.push(format!("*{}", self.arg_to_string(vararg)));
277        }
278
279        // Keyword-only args
280        for param in &args.kwonlyargs {
281            parts.push(self.arg_with_default_to_string(param));
282        }
283
284        // **kwargs
285        if let Some(kwarg) = &args.kwarg {
286            parts.push(format!("**{}", self.arg_to_string(kwarg)));
287        }
288
289        let params_str = parts.join(", ");
290        let return_str = self.type_to_string(returns);
291
292        format!("({params_str}) -> {return_str}")
293    }
294
295    /// Convert an ArgWithDefault to string.
296    fn arg_with_default_to_string(&self, arg: &ast::ArgWithDefault) -> String {
297        let name = &arg.def.arg;
298        let type_str = arg
299            .def
300            .annotation
301            .as_ref()
302            .map(|a| self.expr_to_string(a))
303            .unwrap_or_default();
304
305        if type_str.is_empty() {
306            if arg.default.is_some() {
307                format!("{name} = ...")
308            } else {
309                name.to_string()
310            }
311        } else if arg.default.is_some() {
312            format!("{name}: {type_str} = ...")
313        } else {
314            format!("{name}: {type_str}")
315        }
316    }
317
318    /// Convert an Arg to string (for vararg/kwarg).
319    fn arg_to_string(&self, arg: &ast::Arg) -> String {
320        let name = &arg.arg;
321        let type_str = arg
322            .annotation
323            .as_ref()
324            .map(|a| self.expr_to_string(a))
325            .unwrap_or_default();
326
327        if type_str.is_empty() {
328            name.to_string()
329        } else {
330            format!("{name}: {type_str}")
331        }
332    }
333
334    /// Convert return type to string.
335    fn type_to_string(&self, returns: &Option<Box<ast::Expr>>) -> String {
336        match returns {
337            Some(expr) => self.expr_to_string(expr),
338            None => "None".to_string(),
339        }
340    }
341
342    /// Convert an expression to a type string.
343    fn expr_to_string(&self, expr: &ast::Expr) -> String {
344        match expr {
345            ast::Expr::Name(name) => name.id.to_string(),
346            ast::Expr::Attribute(attr) => {
347                let value = self.expr_to_string(&attr.value);
348                format!("{value}.{}", attr.attr)
349            }
350            ast::Expr::Subscript(sub) => {
351                let value = self.expr_to_string(&sub.value);
352                let slice = self.expr_to_string(&sub.slice);
353                format!("{value}[{slice}]")
354            }
355            ast::Expr::Tuple(tuple) => {
356                let elts: Vec<_> = tuple.elts.iter().map(|e| self.expr_to_string(e)).collect();
357                elts.join(", ")
358            }
359            ast::Expr::BinOp(binop) => {
360                // Handle Union types written as X | Y
361                if matches!(binop.op, ast::Operator::BitOr) {
362                    let left = self.expr_to_string(&binop.left);
363                    let right = self.expr_to_string(&binop.right);
364                    format!("{left} | {right}")
365                } else {
366                    "Unknown".to_string()
367                }
368            }
369            ast::Expr::Constant(c) => match &c.value {
370                ast::Constant::None => "None".to_string(),
371                ast::Constant::Str(s) => format!("\"{s}\""),
372                ast::Constant::Int(i) => i.to_string(),
373                ast::Constant::Float(f) => f.to_string(),
374                ast::Constant::Bool(b) => b.to_string(),
375                ast::Constant::Ellipsis => "...".to_string(),
376                _ => "Unknown".to_string(),
377            },
378            ast::Expr::List(list) => {
379                let elts: Vec<_> = list.elts.iter().map(|e| self.expr_to_string(e)).collect();
380                format!("[{}]", elts.join(", "))
381            }
382            _ => {
383                warn!("Unknown expression type in type annotation");
384                "Unknown".to_string()
385            }
386        }
387    }
388}
389
#[cfg(test)]
mod tests {
    use super::*;

    /// Run `extractor` over `stub`, treating any parse failure as a test bug.
    fn run(extractor: &Extractor, stub: &str, module: &str) -> Vec<TypeFact> {
        extractor
            .extract_source(stub, module, "test.pyi")
            .unwrap()
    }

    #[test]
    fn test_extract_simple_function() {
        let stub = r#"
def get(url: str) -> Response: ...
"#;
        let facts = run(&Extractor::new(), stub, "requests");

        assert_eq!(facts.len(), 1);
        let get = &facts[0];
        assert_eq!(get.symbol, "get");
        assert_eq!(get.kind, TypeFactKind::Function);
        assert!(get.signature.contains("url: str"));
        assert_eq!(get.return_type, "Response");
    }

    #[test]
    fn test_extract_function_with_optional() {
        let stub = r#"
def get(url: str, params: dict | None = ...) -> Response: ...
"#;
        let facts = run(&Extractor::new(), stub, "requests");

        assert_eq!(facts.len(), 1);
        let signature = &facts[0].signature;
        assert!(signature.contains("params: dict | None"));
    }

    #[test]
    fn test_extract_class_with_methods() {
        let stub = r#"
class Response:
    status_code: int
    def json(self) -> dict: ...
    def text(self) -> str: ...
"#;
        let facts = run(&Extractor::new(), stub, "requests.models");

        // One class fact, one attribute (status_code), two methods (json, text).
        assert_eq!(facts.len(), 4);

        let class_fact = facts
            .iter()
            .find(|fact| fact.symbol == "Response")
            .unwrap();
        assert_eq!(class_fact.kind, TypeFactKind::Class);

        let json_method = facts
            .iter()
            .find(|fact| fact.symbol == "Response.json")
            .unwrap();
        assert_eq!(json_method.kind, TypeFactKind::Method);
        assert_eq!(json_method.return_type, "dict");
    }

    #[test]
    fn test_excludes_private_by_default() {
        let stub = r#"
def _private(): ...
def public(): ...
"#;
        let facts = run(&Extractor::new(), stub, "test");

        assert_eq!(facts.len(), 1);
        let only = &facts[0];
        assert_eq!(only.symbol, "public");
    }

    #[test]
    fn test_includes_private_when_enabled() {
        let stub = r#"
def _private(): ...
def public(): ...
"#;
        let with_private = Extractor::new().with_private();
        let facts = run(&with_private, stub, "test");

        assert_eq!(facts.len(), 2);
    }

    #[test]
    fn test_extract_kwargs() {
        let stub = r#"
def get(url: str, **kwargs) -> Response: ...
"#;
        let facts = run(&Extractor::new(), stub, "requests");

        let signature = &facts[0].signature;
        assert!(signature.contains("**kwargs"));
    }
}