1use crate::{KnowledgeError, Result, TypeFact, TypeFactKind};
7use rustpython_ast::{self as ast, Stmt};
8use rustpython_parser::{parse, Mode};
9use std::path::Path;
10use tracing::{debug, warn};
11
/// Extracts [`TypeFact`]s from Python stub (`.pyi`) source code.
///
/// Construct with [`Extractor::new`] (or `Default`) and optionally opt in to
/// private symbols with [`Extractor::with_private`].
#[derive(Debug, Clone)]
pub struct Extractor {
    /// When `false`, symbols whose name starts with `_` are skipped.
    include_private: bool,
}
17
18impl Default for Extractor {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24impl Extractor {
25 pub fn new() -> Self {
27 Self {
28 include_private: false,
29 }
30 }
31
32 pub fn with_private(mut self) -> Self {
34 self.include_private = true;
35 self
36 }
37
38 pub fn extract_file(&self, path: &Path, module: &str) -> Result<Vec<TypeFact>> {
40 let source = std::fs::read_to_string(path)?;
41 self.extract_source(&source, module, path.to_string_lossy().as_ref())
42 }
43
44 pub fn extract_source(
46 &self,
47 source: &str,
48 module: &str,
49 filename: &str,
50 ) -> Result<Vec<TypeFact>> {
51 let parsed =
52 parse(source, Mode::Module, filename).map_err(|e| KnowledgeError::StubParseError {
53 file: filename.to_string(),
54 message: e.to_string(),
55 })?;
56
57 let mut facts = Vec::new();
58
59 if let ast::Mod::Module(module_ast) = parsed {
61 for stmt in module_ast.body {
62 self.extract_stmt(&stmt, module, &mut facts);
63 }
64 }
65
66 debug!(
67 module = %module,
68 facts = facts.len(),
69 "Extracted type facts"
70 );
71
72 Ok(facts)
73 }
74
75 fn extract_stmt(&self, stmt: &Stmt, module: &str, facts: &mut Vec<TypeFact>) {
77 match stmt {
78 Stmt::FunctionDef(func) => {
79 if self.should_include(&func.name) {
80 if let Some(fact) = self.extract_function(func, module) {
81 facts.push(fact);
82 }
83 }
84 }
85 Stmt::AsyncFunctionDef(func) => {
86 if self.should_include(&func.name) {
87 if let Some(fact) = self.extract_async_function(func, module) {
88 facts.push(fact);
89 }
90 }
91 }
92 Stmt::ClassDef(class) => {
93 if self.should_include(&class.name) {
94 self.extract_class(class, module, facts);
95 }
96 }
97 Stmt::AnnAssign(assign) => {
98 if let Some(fact) = self.extract_annotated_assign(assign, module) {
99 facts.push(fact);
100 }
101 }
102 _ => {}
103 }
104 }
105
106 fn should_include(&self, name: &str) -> bool {
108 self.include_private || !name.starts_with('_')
109 }
110
111 fn extract_function(&self, func: &ast::StmtFunctionDef, module: &str) -> Option<TypeFact> {
113 let signature = self.build_signature(&func.args, &func.returns);
114 let return_type = self.type_to_string(&func.returns);
115
116 Some(TypeFact {
117 module: module.to_string(),
118 symbol: func.name.to_string(),
119 kind: TypeFactKind::Function,
120 signature,
121 return_type,
122 })
123 }
124
125 fn extract_async_function(
127 &self,
128 func: &ast::StmtAsyncFunctionDef,
129 module: &str,
130 ) -> Option<TypeFact> {
131 let signature = self.build_signature(&func.args, &func.returns);
132 let return_type = self.type_to_string(&func.returns);
133
134 Some(TypeFact {
135 module: module.to_string(),
136 symbol: func.name.to_string(),
137 kind: TypeFactKind::Function,
138 signature: format!("async {signature}"),
139 return_type,
140 })
141 }
142
143 fn extract_class(&self, class: &ast::StmtClassDef, module: &str, facts: &mut Vec<TypeFact>) {
145 facts.push(TypeFact::class(module, &class.name));
147
148 for stmt in &class.body {
150 match stmt {
151 Stmt::FunctionDef(method) => {
152 if self.should_include(&method.name) {
153 if let Some(fact) = self.extract_method(method, module, &class.name) {
154 facts.push(fact);
155 }
156 }
157 }
158 Stmt::AsyncFunctionDef(method) => {
159 if self.should_include(&method.name) {
160 if let Some(fact) = self.extract_async_method(method, module, &class.name) {
161 facts.push(fact);
162 }
163 }
164 }
165 Stmt::AnnAssign(assign) => {
166 if let Some(fact) = self.extract_class_attribute(assign, module, &class.name) {
167 facts.push(fact);
168 }
169 }
170 _ => {}
171 }
172 }
173 }
174
175 fn extract_method(
177 &self,
178 method: &ast::StmtFunctionDef,
179 module: &str,
180 class_name: &str,
181 ) -> Option<TypeFact> {
182 let signature = self.build_signature(&method.args, &method.returns);
183 let return_type = self.type_to_string(&method.returns);
184
185 Some(TypeFact::method(
186 module,
187 class_name,
188 &method.name,
189 &signature,
190 &return_type,
191 ))
192 }
193
194 fn extract_async_method(
196 &self,
197 method: &ast::StmtAsyncFunctionDef,
198 module: &str,
199 class_name: &str,
200 ) -> Option<TypeFact> {
201 let signature = self.build_signature(&method.args, &method.returns);
202 let return_type = self.type_to_string(&method.returns);
203
204 Some(TypeFact::method(
205 module,
206 class_name,
207 &method.name,
208 &format!("async {signature}"),
209 &return_type,
210 ))
211 }
212
213 fn extract_annotated_assign(
215 &self,
216 assign: &ast::StmtAnnAssign,
217 module: &str,
218 ) -> Option<TypeFact> {
219 let target = match assign.target.as_ref() {
220 ast::Expr::Name(name) => name.id.to_string(),
221 _ => return None,
222 };
223
224 if !self.should_include(&target) {
225 return None;
226 }
227
228 let type_str = self.expr_to_string(&assign.annotation);
229
230 Some(TypeFact {
231 module: module.to_string(),
232 symbol: target,
233 kind: TypeFactKind::Attribute,
234 signature: String::new(),
235 return_type: type_str,
236 })
237 }
238
239 fn extract_class_attribute(
241 &self,
242 assign: &ast::StmtAnnAssign,
243 module: &str,
244 class_name: &str,
245 ) -> Option<TypeFact> {
246 let target = match assign.target.as_ref() {
247 ast::Expr::Name(name) => name.id.to_string(),
248 _ => return None,
249 };
250
251 if !self.should_include(&target) {
252 return None;
253 }
254
255 let type_str = self.expr_to_string(&assign.annotation);
256
257 Some(TypeFact {
258 module: module.to_string(),
259 symbol: format!("{class_name}.{target}"),
260 kind: TypeFactKind::Attribute,
261 signature: String::new(),
262 return_type: type_str,
263 })
264 }
265
266 fn build_signature(&self, args: &ast::Arguments, returns: &Option<Box<ast::Expr>>) -> String {
268 let mut parts = Vec::new();
269
270 for param in &args.posonlyargs {
272 parts.push(self.arg_with_default_to_string(param));
273 }
274
275 if !args.posonlyargs.is_empty() && !args.args.is_empty() {
276 parts.push("/".to_string());
277 }
278
279 for param in &args.args {
281 parts.push(self.arg_with_default_to_string(param));
282 }
283
284 if let Some(vararg) = &args.vararg {
286 parts.push(format!("*{}", self.arg_to_string(vararg)));
287 }
288
289 for param in &args.kwonlyargs {
291 parts.push(self.arg_with_default_to_string(param));
292 }
293
294 if let Some(kwarg) = &args.kwarg {
296 parts.push(format!("**{}", self.arg_to_string(kwarg)));
297 }
298
299 let params_str = parts.join(", ");
300 let return_str = self.type_to_string(returns);
301
302 format!("({params_str}) -> {return_str}")
303 }
304
305 fn arg_with_default_to_string(&self, arg: &ast::ArgWithDefault) -> String {
307 let name = &arg.def.arg;
308 let type_str = arg
309 .def
310 .annotation
311 .as_ref()
312 .map(|a| self.expr_to_string(a))
313 .unwrap_or_default();
314
315 if type_str.is_empty() {
316 if arg.default.is_some() {
317 format!("{name} = ...")
318 } else {
319 name.to_string()
320 }
321 } else if arg.default.is_some() {
322 format!("{name}: {type_str} = ...")
323 } else {
324 format!("{name}: {type_str}")
325 }
326 }
327
328 fn arg_to_string(&self, arg: &ast::Arg) -> String {
330 let name = &arg.arg;
331 let type_str = arg
332 .annotation
333 .as_ref()
334 .map(|a| self.expr_to_string(a))
335 .unwrap_or_default();
336
337 if type_str.is_empty() {
338 name.to_string()
339 } else {
340 format!("{name}: {type_str}")
341 }
342 }
343
344 fn type_to_string(&self, returns: &Option<Box<ast::Expr>>) -> String {
346 match returns {
347 Some(expr) => self.expr_to_string(expr),
348 None => "None".to_string(),
349 }
350 }
351
352 fn expr_to_string(&self, expr: &ast::Expr) -> String {
354 match expr {
355 ast::Expr::Name(name) => name.id.to_string(),
356 ast::Expr::Attribute(attr) => {
357 let value = self.expr_to_string(&attr.value);
358 format!("{value}.{}", attr.attr)
359 }
360 ast::Expr::Subscript(sub) => {
361 let value = self.expr_to_string(&sub.value);
362 let slice = self.expr_to_string(&sub.slice);
363 format!("{value}[{slice}]")
364 }
365 ast::Expr::Tuple(tuple) => {
366 let elts: Vec<_> = tuple.elts.iter().map(|e| self.expr_to_string(e)).collect();
367 elts.join(", ")
368 }
369 ast::Expr::BinOp(binop) => {
370 if matches!(binop.op, ast::Operator::BitOr) {
372 let left = self.expr_to_string(&binop.left);
373 let right = self.expr_to_string(&binop.right);
374 format!("{left} | {right}")
375 } else {
376 "Unknown".to_string()
377 }
378 }
379 ast::Expr::Constant(c) => match &c.value {
380 ast::Constant::None => "None".to_string(),
381 ast::Constant::Str(s) => format!("\"{s}\""),
382 ast::Constant::Int(i) => i.to_string(),
383 ast::Constant::Float(f) => f.to_string(),
384 ast::Constant::Bool(b) => b.to_string(),
385 ast::Constant::Ellipsis => "...".to_string(),
386 _ => "Unknown".to_string(),
387 },
388 ast::Expr::List(list) => {
389 let elts: Vec<_> = list.elts.iter().map(|e| self.expr_to_string(e)).collect();
390 format!("[{}]", elts.join(", "))
391 }
392 _ => {
393 warn!("Unknown expression type in type annotation");
394 "Unknown".to_string()
395 }
396 }
397 }
398}
399
400#[cfg(test)]
401mod tests {
402 use super::*;
403
404 #[test]
405 fn test_extract_simple_function() {
406 let source = r#"
407def get(url: str) -> Response: ...
408"#;
409 let extractor = Extractor::new();
410 let facts = extractor
411 .extract_source(source, "requests", "test.pyi")
412 .unwrap();
413
414 assert_eq!(facts.len(), 1);
415 assert_eq!(facts[0].symbol, "get");
416 assert_eq!(facts[0].kind, TypeFactKind::Function);
417 assert!(facts[0].signature.contains("url: str"));
418 assert_eq!(facts[0].return_type, "Response");
419 }
420
421 #[test]
422 fn test_extract_function_with_optional() {
423 let source = r#"
424def get(url: str, params: dict | None = ...) -> Response: ...
425"#;
426 let extractor = Extractor::new();
427 let facts = extractor
428 .extract_source(source, "requests", "test.pyi")
429 .unwrap();
430
431 assert_eq!(facts.len(), 1);
432 assert!(facts[0].signature.contains("params: dict | None"));
433 }
434
435 #[test]
436 fn test_extract_class_with_methods() {
437 let source = r#"
438class Response:
439 status_code: int
440 def json(self) -> dict: ...
441 def text(self) -> str: ...
442"#;
443 let extractor = Extractor::new();
444 let facts = extractor
445 .extract_source(source, "requests.models", "test.pyi")
446 .unwrap();
447
448 assert_eq!(facts.len(), 4);
450
451 let class_fact = facts.iter().find(|f| f.symbol == "Response").unwrap();
452 assert_eq!(class_fact.kind, TypeFactKind::Class);
453
454 let json_fact = facts.iter().find(|f| f.symbol == "Response.json").unwrap();
455 assert_eq!(json_fact.kind, TypeFactKind::Method);
456 assert_eq!(json_fact.return_type, "dict");
457 }
458
459 #[test]
460 fn test_excludes_private_by_default() {
461 let source = r#"
462def _private(): ...
463def public(): ...
464"#;
465 let extractor = Extractor::new();
466 let facts = extractor
467 .extract_source(source, "test", "test.pyi")
468 .unwrap();
469
470 assert_eq!(facts.len(), 1);
471 assert_eq!(facts[0].symbol, "public");
472 }
473
474 #[test]
475 fn test_includes_private_when_enabled() {
476 let source = r#"
477def _private(): ...
478def public(): ...
479"#;
480 let extractor = Extractor::new().with_private();
481 let facts = extractor
482 .extract_source(source, "test", "test.pyi")
483 .unwrap();
484
485 assert_eq!(facts.len(), 2);
486 }
487
488 #[test]
489 fn test_extract_kwargs() {
490 let source = r#"
491def get(url: str, **kwargs) -> Response: ...
492"#;
493 let extractor = Extractor::new();
494 let facts = extractor
495 .extract_source(source, "requests", "test.pyi")
496 .unwrap();
497
498 assert!(facts[0].signature.contains("**kwargs"));
499 }
500
501 #[test]
502 fn test_extractor_default() {
503 let extractor = Extractor::default();
504 let source = "def _hidden(): ...\ndef visible(): ...";
506 let facts = extractor
507 .extract_source(source, "test", "test.pyi")
508 .unwrap();
509 assert_eq!(facts.len(), 1);
510 assert_eq!(facts[0].symbol, "visible");
511 }
512
513 #[test]
514 fn test_extract_async_function() {
515 let source = r#"
516async def fetch(url: str) -> bytes: ...
517"#;
518 let extractor = Extractor::new();
519 let facts = extractor
520 .extract_source(source, "aiohttp", "test.pyi")
521 .unwrap();
522
523 assert_eq!(facts.len(), 1);
524 assert_eq!(facts[0].symbol, "fetch");
525 assert_eq!(facts[0].kind, TypeFactKind::Function);
526 assert!(facts[0].signature.starts_with("async "));
527 assert_eq!(facts[0].return_type, "bytes");
528 }
529
530 #[test]
531 fn test_extract_function_no_return_type() {
532 let source = r#"
533def setup(): ...
534"#;
535 let extractor = Extractor::new();
536 let facts = extractor
537 .extract_source(source, "test", "test.pyi")
538 .unwrap();
539
540 assert_eq!(facts.len(), 1);
541 assert_eq!(facts[0].return_type, "None");
542 }
543
544 #[test]
545 fn test_extract_annotated_assign() {
546 let source = r#"
547VERSION: str
548DEBUG: bool
549"#;
550 let extractor = Extractor::new();
551 let facts = extractor
552 .extract_source(source, "config", "test.pyi")
553 .unwrap();
554
555 assert_eq!(facts.len(), 2);
556 assert_eq!(facts[0].symbol, "VERSION");
557 assert_eq!(facts[0].kind, TypeFactKind::Attribute);
558 assert_eq!(facts[0].return_type, "str");
559 assert_eq!(facts[1].symbol, "DEBUG");
560 assert_eq!(facts[1].return_type, "bool");
561 }
562
563 #[test]
564 fn test_extract_private_annotated_assign_excluded() {
565 let source = r#"
566_internal: int
567public: str
568"#;
569 let extractor = Extractor::new();
570 let facts = extractor
571 .extract_source(source, "test", "test.pyi")
572 .unwrap();
573
574 assert_eq!(facts.len(), 1);
575 assert_eq!(facts[0].symbol, "public");
576 }
577
578 #[test]
579 fn test_extract_generic_type() {
580 let source = r#"
581def values() -> List[int]: ...
582"#;
583 let extractor = Extractor::new();
584 let facts = extractor
585 .extract_source(source, "test", "test.pyi")
586 .unwrap();
587
588 assert_eq!(facts.len(), 1);
589 assert_eq!(facts[0].return_type, "List[int]");
590 }
591
592 #[test]
593 fn test_extract_nested_generic_type() {
594 let source = r#"
595def nested() -> Dict[str, List[int]]: ...
596"#;
597 let extractor = Extractor::new();
598 let facts = extractor
599 .extract_source(source, "test", "test.pyi")
600 .unwrap();
601
602 assert_eq!(facts.len(), 1);
603 assert_eq!(facts[0].return_type, "Dict[str, List[int]]");
604 }
605
606 #[test]
607 fn test_extract_union_type_pipe() {
608 let source = r#"
609def maybe(x: int) -> str | None: ...
610"#;
611 let extractor = Extractor::new();
612 let facts = extractor
613 .extract_source(source, "test", "test.pyi")
614 .unwrap();
615
616 assert_eq!(facts.len(), 1);
617 assert_eq!(facts[0].return_type, "str | None");
618 }
619
620 #[test]
621 fn test_extract_varargs() {
622 let source = r#"
623def concat(*args: str) -> str: ...
624"#;
625 let extractor = Extractor::new();
626 let facts = extractor
627 .extract_source(source, "test", "test.pyi")
628 .unwrap();
629
630 assert_eq!(facts.len(), 1);
631 assert!(facts[0].signature.contains("*args: str"));
632 }
633
634 #[test]
635 fn test_extract_class_attribute() {
636 let source = r#"
637class Config:
638 timeout: int
639 name: str
640"#;
641 let extractor = Extractor::new();
642 let facts = extractor
643 .extract_source(source, "app", "test.pyi")
644 .unwrap();
645
646 assert_eq!(facts.len(), 3);
648
649 let timeout = facts
650 .iter()
651 .find(|f| f.symbol == "Config.timeout")
652 .unwrap();
653 assert_eq!(timeout.kind, TypeFactKind::Attribute);
654 assert_eq!(timeout.return_type, "int");
655 }
656
657 #[test]
658 fn test_extract_class_private_method_excluded() {
659 let source = r#"
660class MyClass:
661 def _private(self) -> None: ...
662 def public(self) -> int: ...
663"#;
664 let extractor = Extractor::new();
665 let facts = extractor
666 .extract_source(source, "test", "test.pyi")
667 .unwrap();
668
669 assert_eq!(facts.len(), 2);
671 let method = facts
672 .iter()
673 .find(|f| f.kind == TypeFactKind::Method)
674 .unwrap();
675 assert_eq!(method.symbol, "MyClass.public");
676 }
677
678 #[test]
679 fn test_extract_class_private_method_included() {
680 let source = r#"
681class MyClass:
682 def _private(self) -> None: ...
683 def public(self) -> int: ...
684"#;
685 let extractor = Extractor::new().with_private();
686 let facts = extractor
687 .extract_source(source, "test", "test.pyi")
688 .unwrap();
689
690 assert_eq!(facts.len(), 3);
692 }
693
694 #[test]
695 fn test_extract_async_method() {
696 let source = r#"
697class Client:
698 async def fetch(self, url: str) -> bytes: ...
699"#;
700 let extractor = Extractor::new();
701 let facts = extractor
702 .extract_source(source, "http", "test.pyi")
703 .unwrap();
704
705 assert_eq!(facts.len(), 2);
707 let method = facts
708 .iter()
709 .find(|f| f.kind == TypeFactKind::Method)
710 .unwrap();
711 assert_eq!(method.symbol, "Client.fetch");
712 assert!(method.signature.starts_with("async "));
713 }
714
715 #[test]
716 fn test_extract_invalid_syntax() {
717 let source = "def invalid syntax here %%%: ...";
718 let extractor = Extractor::new();
719 let result = extractor.extract_source(source, "test", "test.pyi");
720 assert!(result.is_err());
721 }
722
723 #[test]
724 fn test_extract_empty_source() {
725 let source = "";
726 let extractor = Extractor::new();
727 let facts = extractor
728 .extract_source(source, "empty", "test.pyi")
729 .unwrap();
730 assert!(facts.is_empty());
731 }
732
733 #[test]
734 fn test_extract_multiple_functions() {
735 let source = r#"
736def add(a: int, b: int) -> int: ...
737def sub(a: int, b: int) -> int: ...
738def mul(a: int, b: int) -> int: ...
739"#;
740 let extractor = Extractor::new();
741 let facts = extractor
742 .extract_source(source, "math_ops", "test.pyi")
743 .unwrap();
744
745 assert_eq!(facts.len(), 3);
746 let symbols: Vec<&str> = facts.iter().map(|f| f.symbol.as_str()).collect();
747 assert!(symbols.contains(&"add"));
748 assert!(symbols.contains(&"sub"));
749 assert!(symbols.contains(&"mul"));
750 }
751
752 #[test]
753 fn test_extract_function_with_default_no_type() {
754 let source = r#"
755def func(x, y = ...): ...
756"#;
757 let extractor = Extractor::new();
758 let facts = extractor
759 .extract_source(source, "test", "test.pyi")
760 .unwrap();
761
762 assert_eq!(facts.len(), 1);
763 assert!(facts[0].signature.contains("y = ..."));
764 }
765
766 #[test]
767 fn test_extract_dotted_return_type() {
768 let source = r#"
769def connect() -> http.client.HTTPConnection: ...
770"#;
771 let extractor = Extractor::new();
772 let facts = extractor
773 .extract_source(source, "test", "test.pyi")
774 .unwrap();
775
776 assert_eq!(facts.len(), 1);
777 assert_eq!(facts[0].return_type, "http.client.HTTPConnection");
778 }
779
780 #[test]
785 fn test_s9b7_extract_posonly_args() {
786 let source = r#"
787def func(a: int, b: int, /, c: int) -> int: ...
788"#;
789 let extractor = Extractor::new();
790 let facts = extractor
791 .extract_source(source, "test", "test.pyi")
792 .unwrap();
793 assert_eq!(facts.len(), 1);
794 assert!(facts[0].signature.contains("/"));
796 }
797
798 #[test]
799 fn test_s9b7_extract_kwonly_args() {
800 let source = r#"
801def func(*, key: str) -> None: ...
802"#;
803 let extractor = Extractor::new();
804 let facts = extractor
805 .extract_source(source, "test", "test.pyi")
806 .unwrap();
807 assert_eq!(facts.len(), 1);
808 assert!(facts[0].signature.contains("key: str"));
809 }
810
811 #[test]
812 fn test_s9b7_extract_constant_type_in_annotation() {
813 let source = r#"
814def func() -> None: ...
815"#;
816 let extractor = Extractor::new();
817 let facts = extractor
818 .extract_source(source, "test", "test.pyi")
819 .unwrap();
820 assert_eq!(facts[0].return_type, "None");
821 }
822
823 #[test]
824 fn test_s9b7_extract_list_type_annotation() {
825 let source = r#"
826def func() -> [int, str]: ...
827"#;
828 let extractor = Extractor::new();
829 let facts = extractor
830 .extract_source(source, "test", "test.pyi")
831 .unwrap();
832 assert_eq!(facts.len(), 1);
833 assert_eq!(facts[0].return_type, "[int, str]");
834 }
835
836 #[test]
837 fn test_s9b7_extract_class_private_attribute_excluded() {
838 let source = r#"
839class MyClass:
840 _private: int
841 public: str
842"#;
843 let extractor = Extractor::new();
844 let facts = extractor
845 .extract_source(source, "test", "test.pyi")
846 .unwrap();
847 assert_eq!(facts.len(), 2);
849 assert!(!facts.iter().any(|f| f.symbol.contains("_private")));
850 }
851
852 #[test]
853 fn test_s9b7_extract_class_private_attribute_included() {
854 let source = r#"
855class MyClass:
856 _private: int
857 public: str
858"#;
859 let extractor = Extractor::new().with_private();
860 let facts = extractor
861 .extract_source(source, "test", "test.pyi")
862 .unwrap();
863 assert_eq!(facts.len(), 3);
865 }
866
867 #[test]
868 fn test_s9b7_extract_async_method_in_class() {
869 let source = r#"
870class Service:
871 async def process(self, data: bytes) -> str: ...
872"#;
873 let extractor = Extractor::new();
874 let facts = extractor
875 .extract_source(source, "svc", "test.pyi")
876 .unwrap();
877 assert_eq!(facts.len(), 2); let method = facts.iter().find(|f| f.kind == TypeFactKind::Method).unwrap();
879 assert!(method.signature.starts_with("async "));
880 assert_eq!(method.return_type, "str");
881 }
882
883 #[test]
884 fn test_s9b7_extract_function_with_typed_default() {
885 let source = r#"
886def func(x: int = ..., y: str = ...) -> None: ...
887"#;
888 let extractor = Extractor::new();
889 let facts = extractor
890 .extract_source(source, "test", "test.pyi")
891 .unwrap();
892 assert_eq!(facts.len(), 1);
893 assert!(facts[0].signature.contains("x: int = ..."));
894 assert!(facts[0].signature.contains("y: str = ..."));
895 }
896
897 #[test]
898 fn test_s9b7_extractor_with_private_flag() {
899 let extractor = Extractor::new().with_private();
900 let source = "_hidden_func = 1\n";
901 let _ = extractor.extract_source(source, "test", "test.pyi");
903 }
904
905 #[test]
906 fn test_s9b7_extract_multiple_classes() {
907 let source = r#"
908class A:
909 def method_a(self) -> int: ...
910
911class B:
912 def method_b(self) -> str: ...
913"#;
914 let extractor = Extractor::new();
915 let facts = extractor
916 .extract_source(source, "test", "test.pyi")
917 .unwrap();
918 assert_eq!(facts.len(), 4);
920 }
921
922 #[test]
923 fn test_extract_private_class_excluded() {
924 let source = r#"
925class _Internal: ...
926class Public: ...
927"#;
928 let extractor = Extractor::new();
929 let facts = extractor
930 .extract_source(source, "test", "test.pyi")
931 .unwrap();
932
933 assert_eq!(facts.len(), 1);
934 assert_eq!(facts[0].symbol, "Public");
935 }
936
937 #[test]
943 fn test_s11_expr_to_string_constant_int_in_annotation() {
944 let source = r#"
946x: 42
947"#;
948 let extractor = Extractor::new();
949 let facts = extractor
950 .extract_source(source, "test", "test.pyi")
951 .unwrap();
952 assert_eq!(facts.len(), 1);
953 assert_eq!(facts[0].return_type, "42");
954 }
955
956 #[test]
957 fn test_s11_expr_to_string_constant_float_in_annotation() {
958 let source = r#"
960x: 3.5
961"#;
962 let extractor = Extractor::new();
963 let facts = extractor
964 .extract_source(source, "test", "test.pyi")
965 .unwrap();
966 assert_eq!(facts.len(), 1);
967 assert_eq!(facts[0].return_type, "3.5");
968 }
969
970 #[test]
971 fn test_s11_expr_to_string_constant_bool_true() {
972 let source = r#"
973x: True
974"#;
975 let extractor = Extractor::new();
976 let facts = extractor
977 .extract_source(source, "test", "test.pyi")
978 .unwrap();
979 assert_eq!(facts.len(), 1);
980 assert_eq!(facts[0].return_type, "true");
981 }
982
983 #[test]
984 fn test_s11_expr_to_string_constant_bool_false() {
985 let source = r#"
986x: False
987"#;
988 let extractor = Extractor::new();
989 let facts = extractor
990 .extract_source(source, "test", "test.pyi")
991 .unwrap();
992 assert_eq!(facts.len(), 1);
993 assert_eq!(facts[0].return_type, "false");
994 }
995
996 #[test]
997 fn test_s11_expr_to_string_constant_ellipsis() {
998 let source = r#"
999x: ...
1000"#;
1001 let extractor = Extractor::new();
1002 let facts = extractor
1003 .extract_source(source, "test", "test.pyi")
1004 .unwrap();
1005 assert_eq!(facts.len(), 1);
1006 assert_eq!(facts[0].return_type, "...");
1007 }
1008
1009 #[test]
1010 fn test_s11_expr_to_string_constant_string_literal() {
1011 let source = "x: \"hello\"\n";
1012 let extractor = Extractor::new();
1013 let facts = extractor
1014 .extract_source(source, "test", "test.pyi")
1015 .unwrap();
1016 assert_eq!(facts.len(), 1);
1017 assert_eq!(facts[0].return_type, "\"hello\"");
1018 }
1019
1020 #[test]
1021 fn test_s11_expr_to_string_constant_none() {
1022 let source = r#"
1023x: None
1024"#;
1025 let extractor = Extractor::new();
1026 let facts = extractor
1027 .extract_source(source, "test", "test.pyi")
1028 .unwrap();
1029 assert_eq!(facts.len(), 1);
1030 assert_eq!(facts[0].return_type, "None");
1031 }
1032
1033 #[test]
1034 fn test_s11_expr_to_string_non_bitor_binop() {
1035 let source = r#"
1037x: int + str
1038"#;
1039 let extractor = Extractor::new();
1040 let facts = extractor
1041 .extract_source(source, "test", "test.pyi")
1042 .unwrap();
1043 assert_eq!(facts.len(), 1);
1044 assert_eq!(facts[0].return_type, "Unknown");
1045 }
1046
1047 #[test]
1048 fn test_s11_extract_tuple_return_type() {
1049 let source = r#"
1050def func() -> (int, str, float): ...
1051"#;
1052 let extractor = Extractor::new();
1053 let facts = extractor
1054 .extract_source(source, "test", "test.pyi")
1055 .unwrap();
1056 assert_eq!(facts.len(), 1);
1057 assert_eq!(facts[0].return_type, "int, str, float");
1058 }
1059
1060 #[test]
1061 fn test_s11_extract_deeply_nested_subscript() {
1062 let source = r#"
1063def func() -> Dict[str, List[Optional[int]]]: ...
1064"#;
1065 let extractor = Extractor::new();
1066 let facts = extractor
1067 .extract_source(source, "test", "test.pyi")
1068 .unwrap();
1069 assert_eq!(facts.len(), 1);
1070 assert_eq!(facts[0].return_type, "Dict[str, List[Optional[int]]]");
1071 }
1072
1073 #[test]
1074 fn test_s11_extract_attribute_type_nested() {
1075 let source = r#"
1076x: collections.abc.Mapping
1077"#;
1078 let extractor = Extractor::new();
1079 let facts = extractor
1080 .extract_source(source, "test", "test.pyi")
1081 .unwrap();
1082 assert_eq!(facts.len(), 1);
1083 assert_eq!(facts[0].return_type, "collections.abc.Mapping");
1084 }
1085
1086 #[test]
1087 fn test_s11_extract_union_pipe_chained() {
1088 let source = r#"
1089def func() -> int | str | None: ...
1090"#;
1091 let extractor = Extractor::new();
1092 let facts = extractor
1093 .extract_source(source, "test", "test.pyi")
1094 .unwrap();
1095 assert_eq!(facts.len(), 1);
1096 assert!(facts[0].return_type.contains("int | str | None"));
1098 }
1099
1100 #[test]
1101 fn test_s11_extract_posonly_with_kwonly_combined() {
1102 let source = r#"
1103def func(a: int, b: int, /, c: int, *, d: str) -> None: ...
1104"#;
1105 let extractor = Extractor::new();
1106 let facts = extractor
1107 .extract_source(source, "test", "test.pyi")
1108 .unwrap();
1109 assert_eq!(facts.len(), 1);
1110 let sig = &facts[0].signature;
1111 assert!(sig.contains("a: int"));
1112 assert!(sig.contains("/"));
1113 assert!(sig.contains("c: int"));
1114 assert!(sig.contains("d: str"));
1115 }
1116
1117 #[test]
1118 fn test_s11_extract_class_with_async_private_method_included() {
1119 let source = r#"
1120class Service:
1121 async def _internal(self) -> None: ...
1122 async def public(self) -> str: ...
1123"#;
1124 let extractor = Extractor::new().with_private();
1125 let facts = extractor
1126 .extract_source(source, "svc", "test.pyi")
1127 .unwrap();
1128 assert_eq!(facts.len(), 3);
1130 assert!(facts.iter().any(|f| f.symbol == "Service._internal"));
1131 assert!(facts.iter().any(|f| f.symbol == "Service.public"));
1132 }
1133
1134 #[test]
1135 fn test_s11_extract_annotated_assign_private_excluded() {
1136 let source = r#"
1137_PRIVATE_CONST: int
1138PUBLIC_CONST: str
1139"#;
1140 let extractor = Extractor::new();
1141 let facts = extractor
1142 .extract_source(source, "config", "test.pyi")
1143 .unwrap();
1144 assert_eq!(facts.len(), 1);
1145 assert_eq!(facts[0].symbol, "PUBLIC_CONST");
1146 }
1147
1148 #[test]
1149 fn test_s11_extract_annotated_assign_private_included() {
1150 let source = r#"
1151_PRIVATE_CONST: int
1152PUBLIC_CONST: str
1153"#;
1154 let extractor = Extractor::new().with_private();
1155 let facts = extractor
1156 .extract_source(source, "config", "test.pyi")
1157 .unwrap();
1158 assert_eq!(facts.len(), 2);
1159 }
1160
1161 #[test]
1162 fn test_s11_extract_function_only_varargs() {
1163 let source = r#"
1164def func(*args: int, **kwargs: str) -> None: ...
1165"#;
1166 let extractor = Extractor::new();
1167 let facts = extractor
1168 .extract_source(source, "test", "test.pyi")
1169 .unwrap();
1170 assert_eq!(facts.len(), 1);
1171 let sig = &facts[0].signature;
1172 assert!(sig.contains("*args: int"));
1173 assert!(sig.contains("**kwargs: str"));
1174 }
1175
1176 #[test]
1177 fn test_s11_extract_function_untyped_vararg() {
1178 let source = r#"
1179def func(*args) -> None: ...
1180"#;
1181 let extractor = Extractor::new();
1182 let facts = extractor
1183 .extract_source(source, "test", "test.pyi")
1184 .unwrap();
1185 assert_eq!(facts.len(), 1);
1186 assert!(facts[0].signature.contains("*args"));
1187 }
1188
1189 #[test]
1190 fn test_s11_extract_function_untyped_kwarg() {
1191 let source = r#"
1192def func(**kw) -> None: ...
1193"#;
1194 let extractor = Extractor::new();
1195 let facts = extractor
1196 .extract_source(source, "test", "test.pyi")
1197 .unwrap();
1198 assert_eq!(facts.len(), 1);
1199 assert!(facts[0].signature.contains("**kw"));
1200 }
1201
1202 #[test]
1203 fn test_s11_extract_class_with_mixed_statement_types() {
1204 let source = r#"
1206class MyClass:
1207 name: str
1208 async def fetch(self) -> bytes: ...
1209 def process(self) -> int: ...
1210 x = 42
1211"#;
1212 let extractor = Extractor::new();
1213 let facts = extractor
1214 .extract_source(source, "test", "test.pyi")
1215 .unwrap();
1216 assert_eq!(facts.len(), 4);
1219 }
1220
1221 #[test]
1222 fn test_s11_extract_source_with_ignored_statement_types() {
1223 let source = r#"
1225import os
1226from typing import List
1227def real_func() -> int: ...
1228"#;
1229 let extractor = Extractor::new();
1230 let facts = extractor
1231 .extract_source(source, "test", "test.pyi")
1232 .unwrap();
1233 assert_eq!(facts.len(), 1);
1234 assert_eq!(facts[0].symbol, "real_func");
1235 }
1236
1237 #[test]
1238 fn test_s11_extract_file_nonexistent_path() {
1239 let extractor = Extractor::new();
1240 let result = extractor.extract_file(
1241 Path::new("/nonexistent/path/to/file.pyi"),
1242 "nonexistent",
1243 );
1244 assert!(result.is_err());
1245 }
1246
1247 #[test]
1248 fn test_s11_extract_list_type_with_nested_elements() {
1249 let source = r#"
1250def func() -> [List[int], Dict[str, str]]: ...
1251"#;
1252 let extractor = Extractor::new();
1253 let facts = extractor
1254 .extract_source(source, "test", "test.pyi")
1255 .unwrap();
1256 assert_eq!(facts.len(), 1);
1257 assert!(facts[0].return_type.starts_with('['));
1258 assert!(facts[0].return_type.contains("List[int]"));
1259 }
1260
1261 #[test]
1266 fn test_s12_extract_function_no_params() {
1267 let source = r#"
1269def no_args() -> int: ...
1270"#;
1271 let extractor = Extractor::new();
1272 let facts = extractor
1273 .extract_source(source, "test", "test.pyi")
1274 .unwrap();
1275 assert_eq!(facts.len(), 1);
1276 assert!(facts[0].signature.contains("()"), "Expected empty params, got: {}", facts[0].signature);
1277 }
1278
/// Expects a function with no return annotation to report `"None"` as its return type.
#[test]
fn test_s12_extract_function_no_return_type() {
    let stub = r#"
def no_return(x: int): ...
"#;

    let extracted = Extractor::new()
        .extract_source(stub, "test", "test.pyi")
        .unwrap();

    assert_eq!(extracted.len(), 1);
    assert_eq!(extracted[0].return_type, "None");
}
1292
/// Expects the explicit `x: int` parameter to appear in the extracted method signature.
///
/// NOTE(review): despite the test name, nothing here asserts that `self` is absent
/// from the signature; only the presence of `x: int` is checked.
#[test]
fn test_s12_extract_method_self_excluded_from_sig() {
    let stub = r#"
class Foo:
 def bar(self, x: int) -> str: ...
"#;

    let extracted = Extractor::new()
        .extract_source(stub, "test", "test.pyi")
        .unwrap();

    let bar = extracted
        .iter()
        .find(|fact| fact.kind == TypeFactKind::Method)
        .unwrap();
    assert!(bar.signature.contains("x: int"));
}
1308
/// Expects an annotated class attribute to be reported with a class-qualified symbol
/// (`Config.debug`) and its annotation as the return type.
#[test]
fn test_s12_extract_class_attribute_fqn() {
    let stub = r#"
class Config:
 debug: bool
"#;

    let extracted = Extractor::new()
        .extract_source(stub, "test", "test.pyi")
        .unwrap();

    let debug_attr = extracted
        .iter()
        .find(|fact| fact.kind == TypeFactKind::Attribute)
        .unwrap();
    assert_eq!(debug_attr.symbol, "Config.debug");
    assert_eq!(debug_attr.return_type, "bool");
}
1323
/// Expects a module-level annotated name to yield one fact whose symbol is the bare
/// name and whose return type is the annotation.
#[test]
fn test_s12_extract_module_attribute() {
    let stub = r#"
VERSION: str
"#;

    let extracted = Extractor::new()
        .extract_source(stub, "test", "test.pyi")
        .unwrap();
    assert_eq!(extracted.len(), 1);

    let version = &extracted[0];
    assert_eq!(version.symbol, "VERSION");
    assert_eq!(version.return_type, "str");
}
1338
/// Expects a subscripted generic return annotation to keep its outer name (`Callable`).
#[test]
fn test_s12_extract_generic_subscript() {
    let stub = r#"
def func() -> Callable[[int, str], bool]: ...
"#;

    let extracted = Extractor::new()
        .extract_source(stub, "test", "test.pyi")
        .unwrap();

    assert_eq!(extracted.len(), 1);
    assert!(extracted[0].return_type.contains("Callable"));
}
1351
/// Expects an extractor built with `Extractor::new()` to skip the underscore-prefixed
/// function and report only the public one.
#[test]
fn test_s12_extractor_default_no_private() {
    let stub = r#"
def _private() -> None: ...
def public() -> None: ...
"#;

    let extracted = Extractor::new()
        .extract_source(stub, "test", "test.pyi")
        .unwrap();

    assert_eq!(extracted.len(), 1);
    assert_eq!(extracted[0].symbol, "public");
}
1365
/// Expects six facts in total from a stub holding two classes (one method in `A`; one
/// attribute and one method in `B`) plus a standalone function.
#[test]
fn test_s12_extract_multiple_classes_and_functions() {
    let stub = r#"
class A:
 def m(self) -> int: ...

class B:
 x: str
 def n(self) -> str: ...

def standalone(a: int, b: int) -> int: ...
"#;

    let extracted = Extractor::new()
        .extract_source(stub, "test", "test.pyi")
        .unwrap();

    assert_eq!(extracted.len(), 6);
}
1385
/// Writes a one-function stub into a temporary directory and expects `extract_file`
/// to return a single fact named `greet`.
#[test]
fn test_s12_extract_file_valid() {
    // The temp directory handle is kept in a local for the whole test.
    let scratch = tempfile::tempdir().unwrap();
    let stub_file = scratch.path().join("test.pyi");
    std::fs::write(&stub_file, "def greet(name: str) -> str: ...\n").unwrap();

    let extracted = Extractor::new().extract_file(&stub_file, "test").unwrap();

    assert_eq!(extracted.len(), 1);
    assert_eq!(extracted[0].symbol, "greet");
}
1399
/// Expects `extract_file` to return `Err` for a path that does not exist.
#[test]
fn test_s12_extract_file_missing() {
    let absent = Path::new("/nonexistent/file.pyi");

    let outcome = Extractor::new().extract_file(absent, "test");

    assert!(outcome.is_err());
}
1406
/// Writes a stub containing a class with one method and one annotated attribute to a
/// temporary file and expects `extract_file` to return three facts.
#[test]
fn test_s12_extract_file_with_class() {
    let stub = r#"class MyClass:
 def method(self, x: int) -> str: ...
 name: str
"#;

    // The temp directory handle is kept in a local for the whole test.
    let scratch = tempfile::tempdir().unwrap();
    let stub_file = scratch.path().join("mymodule.pyi");
    std::fs::write(&stub_file, stub).unwrap();

    let extracted = Extractor::new()
        .extract_file(&stub_file, "mymodule")
        .unwrap();

    assert_eq!(extracted.len(), 3);
}
1425
/// Expects an empty source string to parse successfully and yield no facts.
#[test]
fn test_s12_extract_source_empty() {
    let extracted = Extractor::new()
        .extract_source("", "empty", "empty.pyi")
        .unwrap();

    assert!(extracted.is_empty());
}
1432
/// Expects a source consisting of a single comment line to yield no facts.
#[test]
fn test_s12_extract_source_comments_only() {
    let stub = "# Just a comment\n";

    let extracted = Extractor::new()
        .extract_source(stub, "comments", "comments.pyi")
        .unwrap();

    assert!(extracted.is_empty());
}
1441
/// Expects the default extractor to drop `_private` and keep only `public`.
#[test]
fn test_s12_extract_source_private_excluded() {
    let stub = "def _private() -> None: ...\ndef public() -> int: ...\n";

    let extracted = Extractor::new()
        .extract_source(stub, "test", "test.pyi")
        .unwrap();

    assert_eq!(extracted.len(), 1);
    assert_eq!(extracted[0].symbol, "public");
}
1452
/// Expects an extractor configured via `with_private()` to report both the
/// underscore-prefixed and the public function.
#[test]
fn test_s12_extract_source_private_included() {
    let stub = "def _private() -> None: ...\ndef public() -> int: ...\n";
    let permissive = Extractor::new().with_private();

    let extracted = permissive
        .extract_source(stub, "test", "test.pyi")
        .unwrap();

    assert_eq!(extracted.len(), 2);
}
1462
/// Expects a class nested inside another class to be accepted, with at least two
/// facts extracted from the stub.
#[test]
fn test_s12_extract_source_nested_class() {
    let extractor = Extractor::new();
    // The stub must keep real nesting: `inner_method` is indented under `Inner`, and
    // `outer_method` sits at `Outer`'s body level. If every body line shares one
    // indent, `class Inner:` has no indented block, the Python parser rejects the
    // source, and the `unwrap` below panics.
    let source = r#"class Outer:
    class Inner:
        def inner_method(self) -> int: ...
    def outer_method(self) -> str: ...
"#;
    let facts = extractor
        .extract_source(source, "test", "test.pyi")
        .unwrap();
    assert!(facts.len() >= 2);
}
1476
/// Expects one fact per function for signatures using defaults, plain annotations,
/// and `*args` / `**kwargs`.
#[test]
fn test_s12_extract_source_complex_signatures() {
    let stub = r#"
def foo(a: int, b: str, c: float = 0.0) -> bool: ...
def bar(items: list, key: str) -> dict: ...
def baz(*args, **kwargs) -> None: ...
"#;

    let extracted = Extractor::new()
        .extract_source(stub, "sigs", "sigs.pyi")
        .unwrap();

    assert_eq!(extracted.len(), 3);
}
1490
/// Expects the fully qualified name of an extracted fact to contain the dotted module
/// path passed to `extract_source`.
#[test]
fn test_s12_extract_source_module_path() {
    let stub = "def greet(name: str) -> str: ...\n";

    let extracted = Extractor::new()
        .extract_source(stub, "my.nested.module", "module.pyi")
        .unwrap();
    assert_eq!(extracted.len(), 1);

    let qualified = extracted[0].fqn();
    assert!(qualified.contains("my.nested.module"));
}
1503}