// depyler_lambda/lambda_inference/pattern_extraction.rs
use crate::lambda_inference::{InferenceError, Pattern, PatternType};
use rustpython_ast::{
    Comprehension, ExceptHandler, Expr, ExprAttribute, ExprSubscript, Stmt, StmtFunctionDef,
};
8
9pub fn extract_access_patterns(statements: &[Stmt]) -> Result<Vec<Pattern>, InferenceError> {
11 let mut patterns = Vec::new();
12
13 for stmt in statements {
14 if let Stmt::FunctionDef(func_def) = stmt {
15 patterns.extend(extract_patterns_from_function(func_def)?);
16 }
17 }
18
19 Ok(patterns)
20}
21
22pub fn extract_patterns_from_function(
24 func_def: &StmtFunctionDef,
25) -> Result<Vec<Pattern>, InferenceError> {
26 let mut patterns = Vec::new();
27
28 for stmt in &func_def.body {
29 patterns.extend(extract_patterns_from_stmt(stmt)?);
30 }
31
32 Ok(patterns)
33}
34
35pub fn extract_patterns_from_stmt(stmt: &Stmt) -> Result<Vec<Pattern>, InferenceError> {
37 let mut patterns = Vec::new();
38
39 match stmt {
40 Stmt::Assign(assign) => {
41 for target in &assign.targets {
42 patterns.extend(extract_patterns_from_expr(&assign.value)?);
43 patterns.extend(extract_patterns_from_expr(target)?);
44 }
45 }
46 Stmt::AnnAssign(ann_assign) => {
47 if let Some(ref value) = ann_assign.value {
48 patterns.extend(extract_patterns_from_expr(value)?);
49 } else {
50 patterns.extend(extract_patterns_from_expr(&ann_assign.target)?);
51 }
52 }
53 Stmt::Return(ret) => {
54 if let Some(value) = &ret.value {
55 patterns.extend(extract_patterns_from_expr(value)?);
56 }
57 }
58 Stmt::If(if_stmt) => {
59 patterns.extend(extract_patterns_from_expr(&if_stmt.test)?);
60 for stmt in &if_stmt.body {
61 patterns.extend(extract_patterns_from_stmt(stmt)?);
62 }
63 for stmt in &if_stmt.orelse {
64 patterns.extend(extract_patterns_from_stmt(stmt)?);
65 }
66 }
67 Stmt::For(for_stmt) => {
68 patterns.extend(extract_patterns_from_expr(&for_stmt.iter)?);
69 for stmt in &for_stmt.body {
70 patterns.extend(extract_patterns_from_stmt(stmt)?);
71 }
72 }
73 Stmt::While(while_stmt) => {
74 patterns.extend(extract_patterns_from_expr(&while_stmt.test)?);
75 for stmt in &while_stmt.body {
76 patterns.extend(extract_patterns_from_stmt(stmt)?);
77 }
78 }
79 Stmt::With(with_stmt) => {
80 for item in &with_stmt.items {
81 patterns.extend(extract_patterns_from_expr(&item.context_expr)?);
82 }
83 for stmt in &with_stmt.body {
84 patterns.extend(extract_patterns_from_stmt(stmt)?);
85 }
86 }
87 Stmt::Expr(expr_stmt) => {
88 patterns.extend(extract_patterns_from_expr(&expr_stmt.value)?);
89 }
90 _ => {}
91 }
92
93 Ok(patterns)
94}
95
96pub fn extract_patterns_from_expr(expr: &Expr) -> Result<Vec<Pattern>, InferenceError> {
98 let mut patterns = Vec::new();
99
100 match expr {
101 Expr::Subscript(subscript) => {
102 if let Some(pattern) = extract_subscript_pattern(subscript)? {
103 patterns.push(pattern);
104 }
105 patterns.extend(extract_patterns_from_expr(&subscript.value)?);
106 }
107 Expr::Attribute(attr) => {
108 if let Some(pattern) = extract_attribute_pattern(attr)? {
109 patterns.push(pattern);
110 }
111 patterns.extend(extract_patterns_from_expr(&attr.value)?);
112 }
113 Expr::Call(call) => {
114 patterns.extend(extract_patterns_from_expr(&call.func)?);
115 for arg in &call.args {
116 patterns.extend(extract_patterns_from_expr(arg)?);
117 }
118 for keyword in &call.keywords {
119 patterns.extend(extract_patterns_from_expr(&keyword.value)?);
120 }
121 }
122 Expr::BinOp(binop) => {
123 patterns.extend(extract_patterns_from_expr(&binop.left)?);
124 patterns.extend(extract_patterns_from_expr(&binop.right)?);
125 }
126 Expr::Compare(compare) => {
127 patterns.extend(extract_patterns_from_expr(&compare.left)?);
128 for comp in &compare.comparators {
129 patterns.extend(extract_patterns_from_expr(comp)?);
130 }
131 }
132 Expr::BoolOp(boolop) => {
133 for value in &boolop.values {
134 patterns.extend(extract_patterns_from_expr(value)?);
135 }
136 }
137 Expr::UnaryOp(unaryop) => {
138 patterns.extend(extract_patterns_from_expr(&unaryop.operand)?);
139 }
140 Expr::IfExp(ifexp) => {
141 patterns.extend(extract_patterns_from_expr(&ifexp.test)?);
142 patterns.extend(extract_patterns_from_expr(&ifexp.body)?);
143 patterns.extend(extract_patterns_from_expr(&ifexp.orelse)?);
144 }
145 Expr::Dict(dict) => {
146 for value in &dict.values {
147 patterns.extend(extract_patterns_from_expr(value)?);
148 }
149 }
150 Expr::List(list) => {
151 for elt in &list.elts {
152 patterns.extend(extract_patterns_from_expr(elt)?);
153 }
154 }
155 Expr::Tuple(tuple) => {
156 for elt in &tuple.elts {
157 patterns.extend(extract_patterns_from_expr(elt)?);
158 }
159 }
160 _ => {}
161 }
162
163 Ok(patterns)
164}
165
166pub fn extract_subscript_pattern(
168 subscript: &ExprSubscript,
169) -> Result<Option<Pattern>, InferenceError> {
170 let mut access_chain = Vec::new();
171 let mut current_expr = &subscript.value;
172
173 if let Expr::Constant(constant) = &*subscript.slice {
175 if let Some(key) = constant.value.as_str() {
176 access_chain.insert(0, key.to_string());
177 }
178 }
180
181 loop {
183 match &**current_expr {
184 Expr::Subscript(inner_subscript) => {
185 if let Expr::Constant(constant) = &*inner_subscript.slice {
186 if let Some(key) = constant.value.as_str() {
187 access_chain.insert(0, key.to_string());
188 }
189 }
190 current_expr = &inner_subscript.value;
191 }
192 Expr::Attribute(attr) => {
193 access_chain.insert(0, attr.attr.to_string());
194 current_expr = &attr.value;
195 }
196 Expr::Name(name) => {
197 if name.id.as_str() == "event" {
198 return Ok(Some(Pattern {
199 access_chain,
200 pattern_type: PatternType::Mixed,
201 }));
202 }
203 break;
204 }
205 _ => break,
206 }
207 }
208
209 Ok(None)
210}
211
212pub fn extract_attribute_pattern(attr: &ExprAttribute) -> Result<Option<Pattern>, InferenceError> {
214 let mut access_chain = vec![attr.attr.to_string()];
215 let mut current_expr = &attr.value;
216
217 loop {
219 match &**current_expr {
220 Expr::Attribute(inner_attr) => {
221 access_chain.insert(0, inner_attr.attr.to_string());
222 current_expr = &inner_attr.value;
223 }
224 Expr::Subscript(subscript) => {
225 if let Expr::Constant(constant) = &*subscript.slice {
226 if let Some(key) = constant.value.as_str() {
227 access_chain.insert(0, key.to_string());
228 }
229 }
230 current_expr = &subscript.value;
231 }
232 Expr::Name(name) => {
233 if name.id.as_str() == "event" {
234 return Ok(Some(Pattern {
235 access_chain,
236 pattern_type: PatternType::Attribute,
237 }));
238 }
239 break;
240 }
241 _ => break,
242 }
243 }
244
245 Ok(None)
246}
247
248#[cfg(test)]
249mod tests {
250 use super::*;
251 use rustpython_ast::{Mod, ModModule};
252 use rustpython_parser::Parse;
253
254 fn parse_python(source: &str) -> Mod {
255 rustpython_ast::Suite::parse(source, "<test>")
256 .map(|statements| {
257 Mod::Module(ModModule {
258 body: statements,
259 type_ignores: vec![],
260 range: Default::default(),
261 })
262 })
263 .unwrap()
264 }
265
266 fn get_patterns(code: &str) -> Vec<Pattern> {
267 let ast = parse_python(code);
268 match ast {
269 Mod::Module(module) => extract_access_patterns(&module.body).unwrap(),
270 _ => vec![],
271 }
272 }
273
274 #[test]
279 fn test_subscript_simple_event_access() {
280 let patterns = get_patterns(
281 r#"
282def handler(event, context):
283 x = event['Records']
284"#,
285 );
286 assert!(patterns.iter().any(|p| p.access_chain == vec!["Records"]));
287 }
288
289 #[test]
290 fn test_subscript_nested_event_access() {
291 let patterns = get_patterns(
292 r#"
293def handler(event, context):
294 x = event['Records']['s3']
295"#,
296 );
297 assert!(patterns
298 .iter()
299 .any(|p| p.access_chain == vec!["Records", "s3"]));
300 }
301
302 #[test]
303 fn test_subscript_deeply_nested_access() {
304 let patterns = get_patterns(
305 r#"
306def handler(event, context):
307 x = event['Records']['s3']['bucket']['name']
308"#,
309 );
310 assert!(patterns
311 .iter()
312 .any(|p| p.access_chain == vec!["Records", "s3", "bucket", "name"]));
313 }
314
315 #[test]
316 fn test_subscript_numeric_index_skipped() {
317 let patterns = get_patterns(
318 r#"
319def handler(event, context):
320 x = event['Records'][0]['s3']
321"#,
322 );
323 assert!(patterns
325 .iter()
326 .any(|p| p.access_chain == vec!["Records", "s3"]));
327 }
328
329 #[test]
330 fn test_subscript_non_event_ignored() {
331 let patterns = get_patterns(
332 r#"
333def handler(event, context):
334 data = {'foo': 'bar'}
335 x = data['foo']
336"#,
337 );
338 assert!(patterns.is_empty());
340 }
341
342 #[test]
343 fn test_subscript_mixed_with_attribute() {
344 let patterns = get_patterns(
345 r#"
346def handler(event, context):
347 x = event['Records'][0].data
348"#,
349 );
350 assert!(!patterns.is_empty());
352 }
353
354 #[test]
359 fn test_attribute_simple_access() {
360 let patterns = get_patterns(
361 r#"
362def handler(event, context):
363 x = event.body
364"#,
365 );
366 assert!(patterns.iter().any(|p| p.access_chain == vec!["body"]));
367 assert!(patterns
368 .iter()
369 .any(|p| p.pattern_type == PatternType::Attribute));
370 }
371
372 #[test]
373 fn test_attribute_nested_access() {
374 let patterns = get_patterns(
375 r#"
376def handler(event, context):
377 x = event.body.data
378"#,
379 );
380 assert!(patterns
381 .iter()
382 .any(|p| p.access_chain == vec!["body", "data"]));
383 }
384
385 #[test]
386 fn test_attribute_non_event_ignored() {
387 let patterns = get_patterns(
388 r#"
389def handler(event, context):
390 obj = SomeClass()
391 x = obj.attribute
392"#,
393 );
394 assert!(patterns.is_empty());
395 }
396
397 #[test]
402 fn test_extract_from_assign() {
403 let patterns = get_patterns(
404 r#"
405def handler(event, context):
406 x = event['body']
407"#,
408 );
409 assert!(!patterns.is_empty());
410 }
411
412 #[test]
413 fn test_extract_from_annotated_assign() {
414 let patterns = get_patterns(
415 r#"
416def handler(event, context):
417 x: str = event['body']
418"#,
419 );
420 assert!(patterns.iter().any(|p| p.access_chain == vec!["body"]));
421 }
422
423 #[test]
424 fn test_extract_from_annotated_assign_no_value() {
425 let patterns = get_patterns(
426 r#"
427def handler(event, context):
428 x: str
429"#,
430 );
431 assert!(patterns.is_empty());
433 }
434
435 #[test]
436 fn test_extract_from_return() {
437 let patterns = get_patterns(
438 r#"
439def handler(event, context):
440 return event['data']
441"#,
442 );
443 assert!(patterns.iter().any(|p| p.access_chain == vec!["data"]));
444 }
445
446 #[test]
447 fn test_extract_from_return_none() {
448 let patterns = get_patterns(
449 r#"
450def handler(event, context):
451 return
452"#,
453 );
454 assert!(patterns.is_empty());
455 }
456
457 #[test]
458 fn test_extract_from_if_test() {
459 let patterns = get_patterns(
460 r#"
461def handler(event, context):
462 if event['status']:
463 pass
464"#,
465 );
466 assert!(patterns.iter().any(|p| p.access_chain == vec!["status"]));
467 }
468
469 #[test]
470 fn test_extract_from_if_body() {
471 let patterns = get_patterns(
472 r#"
473def handler(event, context):
474 if True:
475 x = event['body']
476"#,
477 );
478 assert!(patterns.iter().any(|p| p.access_chain == vec!["body"]));
479 }
480
481 #[test]
482 fn test_extract_from_if_else() {
483 let patterns = get_patterns(
484 r#"
485def handler(event, context):
486 if True:
487 pass
488 else:
489 x = event['data']
490"#,
491 );
492 assert!(patterns.iter().any(|p| p.access_chain == vec!["data"]));
493 }
494
495 #[test]
496 fn test_extract_from_for_iter() {
497 let patterns = get_patterns(
498 r#"
499def handler(event, context):
500 for record in event['Records']:
501 pass
502"#,
503 );
504 assert!(patterns.iter().any(|p| p.access_chain == vec!["Records"]));
505 }
506
507 #[test]
508 fn test_extract_from_for_body() {
509 let patterns = get_patterns(
510 r#"
511def handler(event, context):
512 for i in range(10):
513 x = event['body']
514"#,
515 );
516 assert!(patterns.iter().any(|p| p.access_chain == vec!["body"]));
517 }
518
519 #[test]
520 fn test_extract_from_while_test() {
521 let patterns = get_patterns(
522 r#"
523def handler(event, context):
524 while event['status']:
525 pass
526"#,
527 );
528 assert!(patterns.iter().any(|p| p.access_chain == vec!["status"]));
529 }
530
531 #[test]
532 fn test_extract_from_while_body() {
533 let patterns = get_patterns(
534 r#"
535def handler(event, context):
536 while True:
537 x = event['data']
538 break
539"#,
540 );
541 assert!(patterns.iter().any(|p| p.access_chain == vec!["data"]));
542 }
543
544 #[test]
545 fn test_extract_from_with_context() {
546 let patterns = get_patterns(
547 r#"
548def handler(event, context):
549 with event['resource'] as r:
550 pass
551"#,
552 );
553 assert!(patterns.iter().any(|p| p.access_chain == vec!["resource"]));
554 }
555
556 #[test]
557 fn test_extract_from_with_body() {
558 let patterns = get_patterns(
559 r#"
560def handler(event, context):
561 with open('file') as f:
562 x = event['data']
563"#,
564 );
565 assert!(patterns.iter().any(|p| p.access_chain == vec!["data"]));
566 }
567
568 #[test]
569 fn test_extract_from_expr_stmt() {
570 let patterns = get_patterns(
571 r#"
572def handler(event, context):
573 event['action']
574"#,
575 );
576 assert!(patterns.iter().any(|p| p.access_chain == vec!["action"]));
577 }
578
579 #[test]
584 fn test_extract_from_call_func() {
585 let patterns = get_patterns(
586 r#"
587def handler(event, context):
588 event['handler']()
589"#,
590 );
591 assert!(patterns.iter().any(|p| p.access_chain == vec!["handler"]));
592 }
593
594 #[test]
595 fn test_extract_from_call_args() {
596 let patterns = get_patterns(
597 r#"
598def handler(event, context):
599 process(event['data'])
600"#,
601 );
602 assert!(patterns.iter().any(|p| p.access_chain == vec!["data"]));
603 }
604
605 #[test]
606 fn test_extract_from_call_kwargs() {
607 let patterns = get_patterns(
608 r#"
609def handler(event, context):
610 process(data=event['body'])
611"#,
612 );
613 assert!(patterns.iter().any(|p| p.access_chain == vec!["body"]));
614 }
615
616 #[test]
617 fn test_extract_from_binop() {
618 let patterns = get_patterns(
619 r#"
620def handler(event, context):
621 x = event['a'] + event['b']
622"#,
623 );
624 assert!(patterns.iter().any(|p| p.access_chain == vec!["a"]));
625 assert!(patterns.iter().any(|p| p.access_chain == vec!["b"]));
626 }
627
628 #[test]
629 fn test_extract_from_compare() {
630 let patterns = get_patterns(
631 r#"
632def handler(event, context):
633 if event['a'] == event['b']:
634 pass
635"#,
636 );
637 assert!(patterns.iter().any(|p| p.access_chain == vec!["a"]));
638 assert!(patterns.iter().any(|p| p.access_chain == vec!["b"]));
639 }
640
641 #[test]
642 fn test_extract_from_boolop() {
643 let patterns = get_patterns(
644 r#"
645def handler(event, context):
646 if event['a'] and event['b']:
647 pass
648"#,
649 );
650 assert!(patterns.iter().any(|p| p.access_chain == vec!["a"]));
651 assert!(patterns.iter().any(|p| p.access_chain == vec!["b"]));
652 }
653
654 #[test]
655 fn test_extract_from_unaryop() {
656 let patterns = get_patterns(
657 r#"
658def handler(event, context):
659 if not event['flag']:
660 pass
661"#,
662 );
663 assert!(patterns.iter().any(|p| p.access_chain == vec!["flag"]));
664 }
665
666 #[test]
667 fn test_extract_from_ifexp() {
668 let patterns = get_patterns(
669 r#"
670def handler(event, context):
671 x = event['a'] if event['cond'] else event['b']
672"#,
673 );
674 assert!(patterns.iter().any(|p| p.access_chain == vec!["a"]));
675 assert!(patterns.iter().any(|p| p.access_chain == vec!["cond"]));
676 assert!(patterns.iter().any(|p| p.access_chain == vec!["b"]));
677 }
678
679 #[test]
680 fn test_extract_from_dict_values() {
681 let patterns = get_patterns(
682 r#"
683def handler(event, context):
684 return {'result': event['data']}
685"#,
686 );
687 assert!(patterns.iter().any(|p| p.access_chain == vec!["data"]));
688 }
689
690 #[test]
691 fn test_extract_from_list_elements() {
692 let patterns = get_patterns(
693 r#"
694def handler(event, context):
695 return [event['a'], event['b']]
696"#,
697 );
698 assert!(patterns.iter().any(|p| p.access_chain == vec!["a"]));
699 assert!(patterns.iter().any(|p| p.access_chain == vec!["b"]));
700 }
701
702 #[test]
703 fn test_extract_from_tuple_elements() {
704 let patterns = get_patterns(
705 r#"
706def handler(event, context):
707 return (event['a'], event['b'])
708"#,
709 );
710 assert!(patterns.iter().any(|p| p.access_chain == vec!["a"]));
711 assert!(patterns.iter().any(|p| p.access_chain == vec!["b"]));
712 }
713
714 #[test]
719 fn test_multiple_functions_in_module() {
720 let patterns = get_patterns(
721 r#"
722def helper():
723 pass
724
725def handler(event, context):
726 x = event['data']
727 return x
728
729def another_helper():
730 pass
731"#,
732 );
733 assert!(patterns.iter().any(|p| p.access_chain == vec!["data"]));
734 }
735
736 #[test]
737 fn test_empty_function() {
738 let patterns = get_patterns(
739 r#"
740def handler(event, context):
741 pass
742"#,
743 );
744 assert!(patterns.is_empty());
745 }
746
747 #[test]
748 fn test_no_function_def() {
749 let patterns = get_patterns(
750 r#"
751x = 1
752y = 2
753"#,
754 );
755 assert!(patterns.is_empty());
756 }
757
758 #[test]
759 fn test_nested_function() {
760 let patterns = get_patterns(
761 r#"
762def handler(event, context):
763 def inner():
764 return event['inner_data']
765 return event['outer_data']
766"#,
767 );
768 assert!(patterns
770 .iter()
771 .any(|p| p.access_chain == vec!["outer_data"]));
772 }
773
774 #[test]
775 fn test_complex_s3_pattern() {
776 let patterns = get_patterns(
777 r#"
778def handler(event, context):
779 bucket = event['Records'][0]['s3']['bucket']['name']
780 key = event['Records'][0]['s3']['object']['key']
781 return {'bucket': bucket, 'key': key}
782"#,
783 );
784 assert!(patterns.len() >= 2);
785 assert!(patterns
786 .iter()
787 .any(|p| p.access_chain.contains(&"bucket".to_string())));
788 assert!(patterns
789 .iter()
790 .any(|p| p.access_chain.contains(&"object".to_string())));
791 }
792
793 #[test]
794 fn test_api_gateway_pattern() {
795 let patterns = get_patterns(
796 r#"
797def handler(event, context):
798 method = event['requestContext']['http']['method']
799 path = event['requestContext']['http']['path']
800 body = event['body']
801 return {'method': method, 'path': path}
802"#,
803 );
804 assert!(patterns.len() >= 3);
805 assert!(patterns
806 .iter()
807 .any(|p| p.access_chain.contains(&"requestContext".to_string())));
808 assert!(patterns.iter().any(|p| p.access_chain == vec!["body"]));
809 }
810
811 #[test]
812 fn test_eventbridge_pattern() {
813 let patterns = get_patterns(
814 r#"
815def handler(event, context):
816 detail_type = event['detail-type']
817 detail = event['detail']
818 source = event['source']
819 return None
820"#,
821 );
822 assert!(patterns
823 .iter()
824 .any(|p| p.access_chain == vec!["detail-type"]));
825 assert!(patterns.iter().any(|p| p.access_chain == vec!["detail"]));
826 assert!(patterns.iter().any(|p| p.access_chain == vec!["source"]));
827 }
828
829 #[test]
830 fn test_pattern_type_is_mixed_for_subscript() {
831 let patterns = get_patterns(
832 r#"
833def handler(event, context):
834 x = event['data']
835"#,
836 );
837 assert!(patterns
838 .iter()
839 .all(|p| p.pattern_type == PatternType::Mixed));
840 }
841
842 #[test]
843 fn test_pattern_type_is_attribute_for_dot_access() {
844 let patterns = get_patterns(
845 r#"
846def handler(event, context):
847 x = event.data
848"#,
849 );
850 assert!(patterns
851 .iter()
852 .any(|p| p.pattern_type == PatternType::Attribute));
853 }
854
855 #[test]
856 fn test_multiple_targets_in_assign() {
857 let patterns = get_patterns(
858 r#"
859def handler(event, context):
860 x = y = event['data']
861"#,
862 );
863 assert!(patterns.iter().any(|p| p.access_chain == vec!["data"]));
864 }
865}