1use crate::{KnowledgeError, Result, TypeFact, TypeFactKind};
7use rustpython_ast::{self as ast, Stmt};
8use rustpython_parser::{parse, Mode};
9use std::path::Path;
10use tracing::{debug, warn};
11
/// Extracts type facts from Python stub source.
///
/// Construct one with `Extractor::new` (or `Default`), optionally enable
/// underscore-prefixed symbols with `with_private`, then call
/// `extract_file` or `extract_source`.
#[derive(Debug, Clone)]
pub struct Extractor {
    /// When `true`, symbols whose names start with `_` are extracted as well.
    include_private: bool,
}
17
18impl Default for Extractor {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24impl Extractor {
25 pub fn new() -> Self {
27 Self {
28 include_private: false,
29 }
30 }
31
32 pub fn with_private(mut self) -> Self {
34 self.include_private = true;
35 self
36 }
37
38 pub fn extract_file(&self, path: &Path, module: &str) -> Result<Vec<TypeFact>> {
40 let source = std::fs::read_to_string(path)?;
41 self.extract_source(&source, module, path.to_string_lossy().as_ref())
42 }
43
44 pub fn extract_source(&self, source: &str, module: &str, filename: &str) -> Result<Vec<TypeFact>> {
46 let parsed = parse(source, Mode::Module, filename).map_err(|e| KnowledgeError::StubParseError {
47 file: filename.to_string(),
48 message: e.to_string(),
49 })?;
50
51 let mut facts = Vec::new();
52
53 if let ast::Mod::Module(module_ast) = parsed {
55 for stmt in module_ast.body {
56 self.extract_stmt(&stmt, module, &mut facts);
57 }
58 }
59
60 debug!(
61 module = %module,
62 facts = facts.len(),
63 "Extracted type facts"
64 );
65
66 Ok(facts)
67 }
68
69 fn extract_stmt(&self, stmt: &Stmt, module: &str, facts: &mut Vec<TypeFact>) {
71 match stmt {
72 Stmt::FunctionDef(func) => {
73 if self.should_include(&func.name) {
74 if let Some(fact) = self.extract_function(func, module) {
75 facts.push(fact);
76 }
77 }
78 }
79 Stmt::AsyncFunctionDef(func) => {
80 if self.should_include(&func.name) {
81 if let Some(fact) = self.extract_async_function(func, module) {
82 facts.push(fact);
83 }
84 }
85 }
86 Stmt::ClassDef(class) => {
87 if self.should_include(&class.name) {
88 self.extract_class(class, module, facts);
89 }
90 }
91 Stmt::AnnAssign(assign) => {
92 if let Some(fact) = self.extract_annotated_assign(assign, module) {
93 facts.push(fact);
94 }
95 }
96 _ => {}
97 }
98 }
99
100 fn should_include(&self, name: &str) -> bool {
102 self.include_private || !name.starts_with('_')
103 }
104
105 fn extract_function(&self, func: &ast::StmtFunctionDef, module: &str) -> Option<TypeFact> {
107 let signature = self.build_signature(&func.args, &func.returns);
108 let return_type = self.type_to_string(&func.returns);
109
110 Some(TypeFact {
111 module: module.to_string(),
112 symbol: func.name.to_string(),
113 kind: TypeFactKind::Function,
114 signature,
115 return_type,
116 })
117 }
118
119 fn extract_async_function(&self, func: &ast::StmtAsyncFunctionDef, module: &str) -> Option<TypeFact> {
121 let signature = self.build_signature(&func.args, &func.returns);
122 let return_type = self.type_to_string(&func.returns);
123
124 Some(TypeFact {
125 module: module.to_string(),
126 symbol: func.name.to_string(),
127 kind: TypeFactKind::Function,
128 signature: format!("async {signature}"),
129 return_type,
130 })
131 }
132
133 fn extract_class(&self, class: &ast::StmtClassDef, module: &str, facts: &mut Vec<TypeFact>) {
135 facts.push(TypeFact::class(module, &class.name));
137
138 for stmt in &class.body {
140 match stmt {
141 Stmt::FunctionDef(method) => {
142 if self.should_include(&method.name) {
143 if let Some(fact) = self.extract_method(method, module, &class.name) {
144 facts.push(fact);
145 }
146 }
147 }
148 Stmt::AsyncFunctionDef(method) => {
149 if self.should_include(&method.name) {
150 if let Some(fact) = self.extract_async_method(method, module, &class.name) {
151 facts.push(fact);
152 }
153 }
154 }
155 Stmt::AnnAssign(assign) => {
156 if let Some(fact) = self.extract_class_attribute(assign, module, &class.name) {
157 facts.push(fact);
158 }
159 }
160 _ => {}
161 }
162 }
163 }
164
165 fn extract_method(
167 &self,
168 method: &ast::StmtFunctionDef,
169 module: &str,
170 class_name: &str,
171 ) -> Option<TypeFact> {
172 let signature = self.build_signature(&method.args, &method.returns);
173 let return_type = self.type_to_string(&method.returns);
174
175 Some(TypeFact::method(
176 module,
177 class_name,
178 &method.name,
179 &signature,
180 &return_type,
181 ))
182 }
183
184 fn extract_async_method(
186 &self,
187 method: &ast::StmtAsyncFunctionDef,
188 module: &str,
189 class_name: &str,
190 ) -> Option<TypeFact> {
191 let signature = self.build_signature(&method.args, &method.returns);
192 let return_type = self.type_to_string(&method.returns);
193
194 Some(TypeFact::method(
195 module,
196 class_name,
197 &method.name,
198 &format!("async {signature}"),
199 &return_type,
200 ))
201 }
202
203 fn extract_annotated_assign(&self, assign: &ast::StmtAnnAssign, module: &str) -> Option<TypeFact> {
205 let target = match assign.target.as_ref() {
206 ast::Expr::Name(name) => name.id.to_string(),
207 _ => return None,
208 };
209
210 if !self.should_include(&target) {
211 return None;
212 }
213
214 let type_str = self.expr_to_string(&assign.annotation);
215
216 Some(TypeFact {
217 module: module.to_string(),
218 symbol: target,
219 kind: TypeFactKind::Attribute,
220 signature: String::new(),
221 return_type: type_str,
222 })
223 }
224
225 fn extract_class_attribute(
227 &self,
228 assign: &ast::StmtAnnAssign,
229 module: &str,
230 class_name: &str,
231 ) -> Option<TypeFact> {
232 let target = match assign.target.as_ref() {
233 ast::Expr::Name(name) => name.id.to_string(),
234 _ => return None,
235 };
236
237 if !self.should_include(&target) {
238 return None;
239 }
240
241 let type_str = self.expr_to_string(&assign.annotation);
242
243 Some(TypeFact {
244 module: module.to_string(),
245 symbol: format!("{class_name}.{target}"),
246 kind: TypeFactKind::Attribute,
247 signature: String::new(),
248 return_type: type_str,
249 })
250 }
251
252 fn build_signature(
254 &self,
255 args: &ast::Arguments,
256 returns: &Option<Box<ast::Expr>>,
257 ) -> String {
258 let mut parts = Vec::new();
259
260 for param in &args.posonlyargs {
262 parts.push(self.arg_with_default_to_string(param));
263 }
264
265 if !args.posonlyargs.is_empty() && !args.args.is_empty() {
266 parts.push("/".to_string());
267 }
268
269 for param in &args.args {
271 parts.push(self.arg_with_default_to_string(param));
272 }
273
274 if let Some(vararg) = &args.vararg {
276 parts.push(format!("*{}", self.arg_to_string(vararg)));
277 }
278
279 for param in &args.kwonlyargs {
281 parts.push(self.arg_with_default_to_string(param));
282 }
283
284 if let Some(kwarg) = &args.kwarg {
286 parts.push(format!("**{}", self.arg_to_string(kwarg)));
287 }
288
289 let params_str = parts.join(", ");
290 let return_str = self.type_to_string(returns);
291
292 format!("({params_str}) -> {return_str}")
293 }
294
295 fn arg_with_default_to_string(&self, arg: &ast::ArgWithDefault) -> String {
297 let name = &arg.def.arg;
298 let type_str = arg
299 .def
300 .annotation
301 .as_ref()
302 .map(|a| self.expr_to_string(a))
303 .unwrap_or_default();
304
305 if type_str.is_empty() {
306 if arg.default.is_some() {
307 format!("{name} = ...")
308 } else {
309 name.to_string()
310 }
311 } else if arg.default.is_some() {
312 format!("{name}: {type_str} = ...")
313 } else {
314 format!("{name}: {type_str}")
315 }
316 }
317
318 fn arg_to_string(&self, arg: &ast::Arg) -> String {
320 let name = &arg.arg;
321 let type_str = arg
322 .annotation
323 .as_ref()
324 .map(|a| self.expr_to_string(a))
325 .unwrap_or_default();
326
327 if type_str.is_empty() {
328 name.to_string()
329 } else {
330 format!("{name}: {type_str}")
331 }
332 }
333
334 fn type_to_string(&self, returns: &Option<Box<ast::Expr>>) -> String {
336 match returns {
337 Some(expr) => self.expr_to_string(expr),
338 None => "None".to_string(),
339 }
340 }
341
342 fn expr_to_string(&self, expr: &ast::Expr) -> String {
344 match expr {
345 ast::Expr::Name(name) => name.id.to_string(),
346 ast::Expr::Attribute(attr) => {
347 let value = self.expr_to_string(&attr.value);
348 format!("{value}.{}", attr.attr)
349 }
350 ast::Expr::Subscript(sub) => {
351 let value = self.expr_to_string(&sub.value);
352 let slice = self.expr_to_string(&sub.slice);
353 format!("{value}[{slice}]")
354 }
355 ast::Expr::Tuple(tuple) => {
356 let elts: Vec<_> = tuple.elts.iter().map(|e| self.expr_to_string(e)).collect();
357 elts.join(", ")
358 }
359 ast::Expr::BinOp(binop) => {
360 if matches!(binop.op, ast::Operator::BitOr) {
362 let left = self.expr_to_string(&binop.left);
363 let right = self.expr_to_string(&binop.right);
364 format!("{left} | {right}")
365 } else {
366 "Unknown".to_string()
367 }
368 }
369 ast::Expr::Constant(c) => match &c.value {
370 ast::Constant::None => "None".to_string(),
371 ast::Constant::Str(s) => format!("\"{s}\""),
372 ast::Constant::Int(i) => i.to_string(),
373 ast::Constant::Float(f) => f.to_string(),
374 ast::Constant::Bool(b) => b.to_string(),
375 ast::Constant::Ellipsis => "...".to_string(),
376 _ => "Unknown".to_string(),
377 },
378 ast::Expr::List(list) => {
379 let elts: Vec<_> = list.elts.iter().map(|e| self.expr_to_string(e)).collect();
380 format!("[{}]", elts.join(", "))
381 }
382 _ => {
383 warn!("Unknown expression type in type annotation");
384 "Unknown".to_string()
385 }
386 }
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `extractor` over an inline stub, using the fixed file name
    /// `test.pyi`, and unwraps the resulting facts.
    fn run(extractor: &Extractor, stub: &str, module: &str) -> Vec<TypeFact> {
        extractor.extract_source(stub, module, "test.pyi").unwrap()
    }

    #[test]
    fn test_extract_simple_function() {
        let stub = r#"
def get(url: str) -> Response: ...
"#;
        let facts = run(&Extractor::new(), stub, "requests");

        assert_eq!(facts.len(), 1);
        let get = &facts[0];
        assert_eq!(get.symbol, "get");
        assert_eq!(get.kind, TypeFactKind::Function);
        assert!(get.signature.contains("url: str"));
        assert_eq!(get.return_type, "Response");
    }

    #[test]
    fn test_extract_function_with_optional() {
        let stub = r#"
def get(url: str, params: dict | None = ...) -> Response: ...
"#;
        let facts = run(&Extractor::new(), stub, "requests");

        assert_eq!(facts.len(), 1);
        assert!(facts[0].signature.contains("params: dict | None"));
    }

    #[test]
    fn test_extract_class_with_methods() {
        let stub = r#"
class Response:
    status_code: int
    def json(self) -> dict: ...
    def text(self) -> str: ...
"#;
        let facts = run(&Extractor::new(), stub, "requests.models");

        // One class fact, one attribute fact and two method facts.
        assert_eq!(facts.len(), 4);

        let class_fact = facts
            .iter()
            .find(|fact| fact.symbol == "Response")
            .unwrap();
        assert_eq!(class_fact.kind, TypeFactKind::Class);

        let json_fact = facts
            .iter()
            .find(|fact| fact.symbol == "Response.json")
            .unwrap();
        assert_eq!(json_fact.kind, TypeFactKind::Method);
        assert_eq!(json_fact.return_type, "dict");
    }

    #[test]
    fn test_excludes_private_by_default() {
        let stub = r#"
def _private(): ...
def public(): ...
"#;
        let facts = run(&Extractor::new(), stub, "test");

        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].symbol, "public");
    }

    #[test]
    fn test_includes_private_when_enabled() {
        let stub = r#"
def _private(): ...
def public(): ...
"#;
        let facts = run(&Extractor::new().with_private(), stub, "test");

        assert_eq!(facts.len(), 2);
    }

    #[test]
    fn test_extract_kwargs() {
        let stub = r#"
def get(url: str, **kwargs) -> Response: ...
"#;
        let facts = run(&Extractor::new(), stub, "requests");

        assert!(facts[0].signature.contains("**kwargs"));
    }
}