1use std::collections::HashMap;
11
12#[derive(Debug, Clone)]
14pub struct StdlibApiMapping {
15 pub module: &'static str,
17 pub class: &'static str,
19 pub python_attr: &'static str,
21 pub rust_pattern: RustPattern,
23}
24
/// Code-generation recipe describing how one Python attribute access is rendered as Rust.
///
/// Patterns are rendered by `RustPattern::generate_rust_code`. They derive `PartialEq`/`Eq`
/// so registries, plugins and tests can compare recipes directly instead of destructuring.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RustPattern {
    /// Render as a method call on the receiver: `base.method(args...)`.
    MethodCall {
        /// Rust method to call.
        method: &'static str,
        /// Literal argument expressions appended after the caller-supplied arguments.
        extra_args: Vec<&'static str>,
        /// When `true`, the rendered call is suffixed with `?`.
        propagate_error: bool,
    },

    /// Render a Python attribute read as a zero-argument method call: `base.method()`.
    PropertyToMethod {
        /// Rust method to call.
        method: &'static str,
        /// When `true`, the rendered call is suffixed with `?`.
        propagate_error: bool,
    },

    /// Render iteration over the receiver as an iterator-producing method call.
    IterationPattern {
        /// Method that produces the iterator.
        iter_method: &'static str,
        /// Optional turbofish type argument, rendered as `iter_method::<T>()`.
        element_type: Option<&'static str>,
        /// Whether the iterator's items are `Result`s (reported by `yields_results`).
        yields_results: bool,
    },

    /// Free-form template; every `{var}` occurrence is replaced with the receiver expression.
    CustomTemplate { template: &'static str },
}
56
57pub trait StdlibPlugin {
87 fn register_mappings(&self, registry: &mut StdlibMappings);
89
90 fn name(&self) -> &str;
92
93 fn version(&self) -> &str {
95 "0.1.0"
96 }
97}
98
99pub struct StdlibMappings {
101 mappings: HashMap<(String, String, String), RustPattern>,
103}
104
105impl StdlibMappings {
106 pub fn new() -> Self {
108 let mut mappings = HashMap::new();
109
110 Self::register_csv_mappings(&mut mappings);
112
113 Self::register_file_mappings(&mut mappings);
115
116 Self { mappings }
117 }
118
119 fn register_csv_mappings(mappings: &mut HashMap<(String, String, String), RustPattern>) {
121 mappings.insert(
123 (
124 "csv".to_string(),
125 "DictReader".to_string(),
126 "fieldnames".to_string(),
127 ),
128 RustPattern::PropertyToMethod {
129 method: "headers",
130 propagate_error: true,
131 },
132 );
133
134 mappings.insert(
136 (
137 "csv".to_string(),
138 "DictReader".to_string(),
139 "__iter__".to_string(),
140 ),
141 RustPattern::IterationPattern {
142 iter_method: "deserialize",
143 element_type: Some("HashMap<String, String>"),
144 yields_results: true,
145 },
146 );
147
148 mappings.insert(
150 (
151 "csv".to_string(),
152 "Reader".to_string(),
153 "fieldnames".to_string(),
154 ),
155 RustPattern::PropertyToMethod {
156 method: "headers",
157 propagate_error: true,
158 },
159 );
160 }
161
162 fn register_file_mappings(mappings: &mut HashMap<(String, String, String), RustPattern>) {
164 mappings.insert(
166 (
167 "builtins".to_string(),
168 "file".to_string(),
169 "__iter__".to_string(),
170 ),
171 RustPattern::CustomTemplate {
172 template: "BufReader::new({var}).lines()",
173 },
174 );
175
176 mappings.insert(
178 (
179 "io".to_string(),
180 "TextIOWrapper".to_string(),
181 "__iter__".to_string(),
182 ),
183 RustPattern::CustomTemplate {
184 template: "BufReader::new({var}).lines()",
185 },
186 );
187 }
188
189 pub fn lookup(&self, module: &str, class: &str, attribute: &str) -> Option<&RustPattern> {
191 self.mappings
192 .get(&(module.to_string(), class.to_string(), attribute.to_string()))
193 }
194
195 pub fn has_iteration_mapping(&self, module: &str, class: &str) -> bool {
197 self.lookup(module, class, "__iter__").is_some()
198 }
199
200 pub fn get_iteration_pattern(&self, module: &str, class: &str) -> Option<&RustPattern> {
202 self.lookup(module, class, "__iter__")
203 }
204
205 pub fn register(&mut self, mapping: StdlibApiMapping) {
224 let key = (
225 mapping.module.to_string(),
226 mapping.class.to_string(),
227 mapping.python_attr.to_string(),
228 );
229 self.mappings.insert(key, mapping.rust_pattern);
230 }
231
232 pub fn register_batch(&mut self, mappings: Vec<StdlibApiMapping>) {
234 for mapping in mappings {
235 self.register(mapping);
236 }
237 }
238
239 pub fn load_plugin(&mut self, plugin: &dyn StdlibPlugin) {
257 plugin.register_mappings(self);
258 }
259
260 pub fn load_plugins(&mut self, plugins: &[&dyn StdlibPlugin]) {
262 for plugin in plugins {
263 self.load_plugin(*plugin);
264 }
265 }
266}
267
268impl Default for StdlibMappings {
269 fn default() -> Self {
270 Self::new()
271 }
272}
273
274impl RustPattern {
276 pub fn generate_rust_code(&self, base_expr: &str, original_args: &[String]) -> String {
282 match self {
283 RustPattern::MethodCall {
284 method,
285 extra_args,
286 propagate_error,
287 } => {
288 let mut all_args = original_args.to_vec();
289 all_args.extend(extra_args.iter().map(|s| s.to_string()));
290 let args_str = all_args.join(", ");
291 let call = if args_str.is_empty() {
292 format!("{}.{}()", base_expr, method)
293 } else {
294 format!("{}.{}({})", base_expr, method, args_str)
295 };
296
297 if *propagate_error {
298 format!("{}?", call)
299 } else {
300 call
301 }
302 }
303
304 RustPattern::PropertyToMethod {
305 method,
306 propagate_error,
307 } => {
308 let call = format!("{}.{}()", base_expr, method);
309 if *propagate_error {
310 format!("{}?", call)
311 } else {
312 call
313 }
314 }
315
316 RustPattern::IterationPattern {
317 iter_method,
318 element_type,
319 yields_results: _,
320 } => {
321 if let Some(elem_type) = element_type {
322 format!("{}.{}::<{}>()", base_expr, iter_method, elem_type)
323 } else {
324 format!("{}.{}()", base_expr, iter_method)
325 }
326 }
327
328 RustPattern::CustomTemplate { template } => template.replace("{var}", base_expr),
329 }
330 }
331
332 pub fn yields_results(&self) -> bool {
334 matches!(
335 self,
336 RustPattern::IterationPattern {
337 yields_results: true,
338 ..
339 }
340 )
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    // ---- fixtures ---------------------------------------------------------

    /// Renders `pattern` for receiver `base` with no call arguments.
    fn render(pattern: &RustPattern, base: &str) -> String {
        pattern.generate_rust_code(base, &[])
    }

    /// Looks up a built-in mapping and renders it; fails the test if the mapping is absent.
    fn render_builtin(module: &str, class: &str, attr: &str, base: &str) -> String {
        let registry = StdlibMappings::new();
        let pattern = registry
            .lookup(module, class, attr)
            .unwrap_or_else(|| panic!("missing mapping {}.{}.{}", module, class, attr));
        render(pattern, base)
    }

    fn method_call(
        method: &'static str,
        extra_args: Vec<&'static str>,
        propagate_error: bool,
    ) -> RustPattern {
        RustPattern::MethodCall {
            method,
            extra_args,
            propagate_error,
        }
    }

    fn property(method: &'static str, propagate_error: bool) -> RustPattern {
        RustPattern::PropertyToMethod {
            method,
            propagate_error,
        }
    }

    fn iteration(
        iter_method: &'static str,
        element_type: Option<&'static str>,
        yields_results: bool,
    ) -> RustPattern {
        RustPattern::IterationPattern {
            iter_method,
            element_type,
            yields_results,
        }
    }

    fn api(
        module: &'static str,
        class: &'static str,
        python_attr: &'static str,
        rust_pattern: RustPattern,
    ) -> StdlibApiMapping {
        StdlibApiMapping {
            module,
            class,
            python_attr,
            rust_pattern,
        }
    }

    // ---- built-in mappings ------------------------------------------------

    #[test]
    fn test_csv_fieldnames_mapping() {
        assert_eq!(
            render_builtin("csv", "DictReader", "fieldnames", "reader"),
            "reader.headers()?"
        );
    }

    #[test]
    fn test_csv_iteration_mapping() {
        let registry = StdlibMappings::new();
        let pattern = registry
            .get_iteration_pattern("csv", "DictReader")
            .expect("DictReader should have an iteration mapping");
        assert_eq!(
            render(pattern, "reader"),
            "reader.deserialize::<HashMap<String, String>>()"
        );
    }

    #[test]
    fn test_file_iteration_mapping() {
        assert_eq!(
            render_builtin("builtins", "file", "__iter__", "f"),
            "BufReader::new(f).lines()"
        );
    }

    // ---- registration -----------------------------------------------------

    #[test]
    fn test_register_custom_mapping() {
        let mut registry = StdlibMappings::new();
        registry.register(api("requests", "Session", "get", method_call("get", vec![], true)));

        let pattern = registry
            .lookup("requests", "Session", "get")
            .expect("custom mapping should be registered");
        assert_eq!(render(pattern, "session"), "session.get()?");
    }

    #[test]
    fn test_register_batch() {
        let mut registry = StdlibMappings::new();
        registry.register_batch(vec![
            api("numpy", "ndarray", "shape", property("shape", false)),
            api("numpy", "ndarray", "dtype", property("dtype", false)),
        ]);

        for attr in &["shape", "dtype"] {
            assert!(registry.lookup("numpy", "ndarray", attr).is_some());
        }
    }

    // ---- plugins ----------------------------------------------------------

    struct TestRequestsPlugin;

    impl StdlibPlugin for TestRequestsPlugin {
        fn register_mappings(&self, registry: &mut StdlibMappings) {
            for &verb in &["get", "post"] {
                registry.register(api("requests", "Session", verb, method_call(verb, vec![], true)));
            }
        }

        fn name(&self) -> &str {
            "requests"
        }

        fn version(&self) -> &str {
            "1.0.0"
        }
    }

    #[test]
    fn test_load_plugin() {
        let mut registry = StdlibMappings::new();
        registry.load_plugin(&TestRequestsPlugin);

        assert!(registry.lookup("requests", "Session", "post").is_some());
        let get = registry
            .lookup("requests", "Session", "get")
            .expect("plugin should register `get`");
        assert_eq!(render(get, "session"), "session.get()?");
    }

    struct TestNumpyPlugin;

    impl StdlibPlugin for TestNumpyPlugin {
        fn register_mappings(&self, registry: &mut StdlibMappings) {
            registry.register(api(
                "numpy",
                "ndarray",
                "reshape",
                method_call("reshape", vec![], false),
            ));
        }

        fn name(&self) -> &str {
            "numpy"
        }
    }

    #[test]
    fn test_load_multiple_plugins() {
        let mut registry = StdlibMappings::new();
        let plugins: [&dyn StdlibPlugin; 2] = [&TestRequestsPlugin, &TestNumpyPlugin];
        registry.load_plugins(&plugins);

        assert!(registry.lookup("requests", "Session", "get").is_some());
        assert!(registry.lookup("numpy", "ndarray", "reshape").is_some());
    }

    #[test]
    fn test_plugin_override_builtin() {
        struct OverridePlugin;

        impl StdlibPlugin for OverridePlugin {
            fn register_mappings(&self, registry: &mut StdlibMappings) {
                registry.register(api(
                    "csv",
                    "DictReader",
                    "fieldnames",
                    property("get_headers", true),
                ));
            }

            fn name(&self) -> &str {
                "csv_override"
            }
        }

        let mut registry = StdlibMappings::new();
        assert!(registry.lookup("csv", "DictReader", "fieldnames").is_some());

        registry.load_plugin(&OverridePlugin);

        let pattern = registry
            .lookup("csv", "DictReader", "fieldnames")
            .expect("override must keep the key registered");
        assert_eq!(render(pattern, "reader"), "reader.get_headers()?");
    }

    #[test]
    fn test_stdlib_mappings_default() {
        let registry = StdlibMappings::default();
        for attr in &["fieldnames", "__iter__"] {
            assert!(registry.lookup("csv", "DictReader", attr).is_some());
        }
    }

    // ---- derived traits ---------------------------------------------------

    #[test]
    fn test_stdlib_api_mapping_clone() {
        let original = api("csv", "Reader", "test", property("test", false));
        let copy = original.clone();
        assert_eq!((copy.module, copy.class), ("csv", "Reader"));
    }

    #[test]
    fn test_stdlib_api_mapping_debug() {
        let text = format!("{:?}", api("csv", "Reader", "test", property("test", false)));
        assert!(text.contains("csv"));
        assert!(text.contains("Reader"));
    }

    #[test]
    fn test_rust_pattern_debug() {
        let text = format!("{:?}", method_call("test", vec![], false));
        assert!(text.contains("MethodCall"));
    }

    #[test]
    fn test_rust_pattern_clone() {
        let original = RustPattern::CustomTemplate {
            template: "test({var})",
        };
        match original.clone() {
            RustPattern::CustomTemplate { template } => assert_eq!(template, "test({var})"),
            _ => panic!("Clone should preserve variant"),
        }
    }

    // ---- code generation --------------------------------------------------

    #[test]
    fn test_method_call_with_args() {
        let code = method_call("fetch", vec![], false)
            .generate_rust_code("client", &["url".to_string()]);
        assert_eq!(code, "client.fetch(url)");
    }

    #[test]
    fn test_method_call_with_extra_args() {
        let code = method_call("fetch", vec!["timeout"], false)
            .generate_rust_code("client", &["url".to_string()]);
        assert_eq!(code, "client.fetch(url, timeout)");
    }

    #[test]
    fn test_method_call_no_propagate_error() {
        assert_eq!(render(&method_call("get", vec![], false), "obj"), "obj.get()");
    }

    #[test]
    fn test_method_call_propagate_error() {
        assert_eq!(render(&method_call("get", vec![], true), "obj"), "obj.get()?");
    }

    #[test]
    fn test_property_to_method_no_error() {
        assert_eq!(render(&property("len", false), "list"), "list.len()");
    }

    #[test]
    fn test_property_to_method_with_error() {
        assert_eq!(render(&property("headers", true), "reader"), "reader.headers()?");
    }

    #[test]
    fn test_iteration_pattern_no_element_type() {
        assert_eq!(
            render(&iteration("iter", None, false), "collection"),
            "collection.iter()"
        );
    }

    #[test]
    fn test_iteration_pattern_with_element_type() {
        assert_eq!(
            render(&iteration("deserialize", Some("Record"), true), "reader"),
            "reader.deserialize::<Record>()"
        );
    }

    #[test]
    fn test_custom_template_with_var() {
        let pattern = RustPattern::CustomTemplate {
            template: "Box::new({var})",
        };
        assert_eq!(render(&pattern, "value"), "Box::new(value)");
    }

    #[test]
    fn test_custom_template_multiple_vars() {
        let pattern = RustPattern::CustomTemplate {
            template: "process({var}).map(|x| x + {var})",
        };
        assert_eq!(render(&pattern, "n"), "process(n).map(|x| x + n)");
    }

    // ---- yields_results ---------------------------------------------------

    #[test]
    fn test_yields_results_true() {
        assert!(iteration("deserialize", Some("Row"), true).yields_results());
    }

    #[test]
    fn test_yields_results_false() {
        assert!(!iteration("iter", None, false).yields_results());
    }

    #[test]
    fn test_yields_results_method_call() {
        assert!(!method_call("test", vec![], true).yields_results());
    }

    #[test]
    fn test_yields_results_property_to_method() {
        assert!(!property("test", true).yields_results());
    }

    #[test]
    fn test_yields_results_custom_template() {
        let pattern = RustPattern::CustomTemplate {
            template: "{var}.iter()",
        };
        assert!(!pattern.yields_results());
    }

    // ---- lookups ----------------------------------------------------------

    #[test]
    fn test_has_iteration_mapping_true() {
        assert!(StdlibMappings::new().has_iteration_mapping("csv", "DictReader"));
    }

    #[test]
    fn test_has_iteration_mapping_false() {
        assert!(!StdlibMappings::new().has_iteration_mapping("unknown", "Unknown"));
    }

    #[test]
    fn test_lookup_nonexistent() {
        assert!(StdlibMappings::new()
            .lookup("nonexistent", "Foo", "bar")
            .is_none());
    }

    #[test]
    fn test_csv_reader_fieldnames() {
        assert_eq!(
            render_builtin("csv", "Reader", "fieldnames", "reader"),
            "reader.headers()?"
        );
    }

    #[test]
    fn test_io_text_wrapper_iteration() {
        assert_eq!(
            render_builtin("io", "TextIOWrapper", "__iter__", "file"),
            "BufReader::new(file).lines()"
        );
    }

    #[test]
    fn test_plugin_default_version() {
        struct MinimalPlugin;

        impl StdlibPlugin for MinimalPlugin {
            fn register_mappings(&self, _registry: &mut StdlibMappings) {}

            fn name(&self) -> &str {
                "minimal"
            }
        }

        let plugin = MinimalPlugin;
        assert_eq!(plugin.version(), "0.1.0");
        assert_eq!(plugin.name(), "minimal");
    }

    #[test]
    fn test_plugin_custom_version() {
        assert_eq!(TestRequestsPlugin.version(), "1.0.0");
    }

    #[test]
    fn test_get_iteration_pattern_nonexistent() {
        assert!(StdlibMappings::new()
            .get_iteration_pattern("unknown", "Unknown")
            .is_none());
    }

    #[test]
    fn test_method_call_with_multiple_extra_args() {
        let code = method_call("request", vec!["headers", "timeout"], true)
            .generate_rust_code("client", &["url".to_string()]);
        assert_eq!(code, "client.request(url, headers, timeout)?");
    }

    #[test]
    fn test_empty_module_lookup() {
        assert!(StdlibMappings::new()
            .lookup("", "DictReader", "fieldnames")
            .is_none());
    }

    #[test]
    fn test_empty_class_lookup() {
        assert!(StdlibMappings::new().lookup("csv", "", "fieldnames").is_none());
    }

    #[test]
    fn test_empty_attribute_lookup() {
        assert!(StdlibMappings::new()
            .lookup("csv", "DictReader", "")
            .is_none());
    }
}