Skip to main content

depyler_lambda/lambda_inference/
pattern_extraction.rs

1//! Pattern extraction from Python AST for Lambda event type inference
2//!
3//! This module extracts access patterns from Python AST expressions to enable
4//! AWS Lambda event type inference.
5
6use crate::lambda_inference::{InferenceError, Pattern, PatternType};
7use rustpython_ast::{Expr, ExprAttribute, ExprSubscript, Stmt, StmtFunctionDef};
8
9/// Extract access patterns from a list of statements
10pub fn extract_access_patterns(statements: &[Stmt]) -> Result<Vec<Pattern>, InferenceError> {
11    let mut patterns = Vec::new();
12
13    for stmt in statements {
14        if let Stmt::FunctionDef(func_def) = stmt {
15            patterns.extend(extract_patterns_from_function(func_def)?);
16        }
17    }
18
19    Ok(patterns)
20}
21
22/// Extract patterns from a function definition
23pub fn extract_patterns_from_function(
24    func_def: &StmtFunctionDef,
25) -> Result<Vec<Pattern>, InferenceError> {
26    let mut patterns = Vec::new();
27
28    for stmt in &func_def.body {
29        patterns.extend(extract_patterns_from_stmt(stmt)?);
30    }
31
32    Ok(patterns)
33}
34
35/// Extract patterns from a single statement
36pub fn extract_patterns_from_stmt(stmt: &Stmt) -> Result<Vec<Pattern>, InferenceError> {
37    let mut patterns = Vec::new();
38
39    match stmt {
40        Stmt::Assign(assign) => {
41            for target in &assign.targets {
42                patterns.extend(extract_patterns_from_expr(&assign.value)?);
43                patterns.extend(extract_patterns_from_expr(target)?);
44            }
45        }
46        Stmt::AnnAssign(ann_assign) => {
47            if let Some(ref value) = ann_assign.value {
48                patterns.extend(extract_patterns_from_expr(value)?);
49            } else {
50                patterns.extend(extract_patterns_from_expr(&ann_assign.target)?);
51            }
52        }
53        Stmt::Return(ret) => {
54            if let Some(value) = &ret.value {
55                patterns.extend(extract_patterns_from_expr(value)?);
56            }
57        }
58        Stmt::If(if_stmt) => {
59            patterns.extend(extract_patterns_from_expr(&if_stmt.test)?);
60            for stmt in &if_stmt.body {
61                patterns.extend(extract_patterns_from_stmt(stmt)?);
62            }
63            for stmt in &if_stmt.orelse {
64                patterns.extend(extract_patterns_from_stmt(stmt)?);
65            }
66        }
67        Stmt::For(for_stmt) => {
68            patterns.extend(extract_patterns_from_expr(&for_stmt.iter)?);
69            for stmt in &for_stmt.body {
70                patterns.extend(extract_patterns_from_stmt(stmt)?);
71            }
72        }
73        Stmt::While(while_stmt) => {
74            patterns.extend(extract_patterns_from_expr(&while_stmt.test)?);
75            for stmt in &while_stmt.body {
76                patterns.extend(extract_patterns_from_stmt(stmt)?);
77            }
78        }
79        Stmt::With(with_stmt) => {
80            for item in &with_stmt.items {
81                patterns.extend(extract_patterns_from_expr(&item.context_expr)?);
82            }
83            for stmt in &with_stmt.body {
84                patterns.extend(extract_patterns_from_stmt(stmt)?);
85            }
86        }
87        Stmt::Expr(expr_stmt) => {
88            patterns.extend(extract_patterns_from_expr(&expr_stmt.value)?);
89        }
90        _ => {}
91    }
92
93    Ok(patterns)
94}
95
96/// Extract patterns from an expression
97pub fn extract_patterns_from_expr(expr: &Expr) -> Result<Vec<Pattern>, InferenceError> {
98    let mut patterns = Vec::new();
99
100    match expr {
101        Expr::Subscript(subscript) => {
102            if let Some(pattern) = extract_subscript_pattern(subscript)? {
103                patterns.push(pattern);
104            }
105            patterns.extend(extract_patterns_from_expr(&subscript.value)?);
106        }
107        Expr::Attribute(attr) => {
108            if let Some(pattern) = extract_attribute_pattern(attr)? {
109                patterns.push(pattern);
110            }
111            patterns.extend(extract_patterns_from_expr(&attr.value)?);
112        }
113        Expr::Call(call) => {
114            patterns.extend(extract_patterns_from_expr(&call.func)?);
115            for arg in &call.args {
116                patterns.extend(extract_patterns_from_expr(arg)?);
117            }
118            for keyword in &call.keywords {
119                patterns.extend(extract_patterns_from_expr(&keyword.value)?);
120            }
121        }
122        Expr::BinOp(binop) => {
123            patterns.extend(extract_patterns_from_expr(&binop.left)?);
124            patterns.extend(extract_patterns_from_expr(&binop.right)?);
125        }
126        Expr::Compare(compare) => {
127            patterns.extend(extract_patterns_from_expr(&compare.left)?);
128            for comp in &compare.comparators {
129                patterns.extend(extract_patterns_from_expr(comp)?);
130            }
131        }
132        Expr::BoolOp(boolop) => {
133            for value in &boolop.values {
134                patterns.extend(extract_patterns_from_expr(value)?);
135            }
136        }
137        Expr::UnaryOp(unaryop) => {
138            patterns.extend(extract_patterns_from_expr(&unaryop.operand)?);
139        }
140        Expr::IfExp(ifexp) => {
141            patterns.extend(extract_patterns_from_expr(&ifexp.test)?);
142            patterns.extend(extract_patterns_from_expr(&ifexp.body)?);
143            patterns.extend(extract_patterns_from_expr(&ifexp.orelse)?);
144        }
145        Expr::Dict(dict) => {
146            for value in &dict.values {
147                patterns.extend(extract_patterns_from_expr(value)?);
148            }
149        }
150        Expr::List(list) => {
151            for elt in &list.elts {
152                patterns.extend(extract_patterns_from_expr(elt)?);
153            }
154        }
155        Expr::Tuple(tuple) => {
156            for elt in &tuple.elts {
157                patterns.extend(extract_patterns_from_expr(elt)?);
158            }
159        }
160        _ => {}
161    }
162
163    Ok(patterns)
164}
165
166/// Extract a pattern from a subscript expression like event['Records']
167pub fn extract_subscript_pattern(
168    subscript: &ExprSubscript,
169) -> Result<Option<Pattern>, InferenceError> {
170    let mut access_chain = Vec::new();
171    let mut current_expr = &subscript.value;
172
173    // Extract the subscript key
174    if let Expr::Constant(constant) = &*subscript.slice {
175        if let Some(key) = constant.value.as_str() {
176            access_chain.insert(0, key.to_string());
177        }
178        // Skip numeric indices - they don't contribute to pattern matching
179    }
180
181    // Walk up the access chain
182    loop {
183        match &**current_expr {
184            Expr::Subscript(inner_subscript) => {
185                if let Expr::Constant(constant) = &*inner_subscript.slice {
186                    if let Some(key) = constant.value.as_str() {
187                        access_chain.insert(0, key.to_string());
188                    }
189                }
190                current_expr = &inner_subscript.value;
191            }
192            Expr::Attribute(attr) => {
193                access_chain.insert(0, attr.attr.to_string());
194                current_expr = &attr.value;
195            }
196            Expr::Name(name) => {
197                if name.id.as_str() == "event" {
198                    return Ok(Some(Pattern {
199                        access_chain,
200                        pattern_type: PatternType::Mixed,
201                    }));
202                }
203                break;
204            }
205            _ => break,
206        }
207    }
208
209    Ok(None)
210}
211
212/// Extract a pattern from an attribute expression like event.body
213pub fn extract_attribute_pattern(attr: &ExprAttribute) -> Result<Option<Pattern>, InferenceError> {
214    let mut access_chain = vec![attr.attr.to_string()];
215    let mut current_expr = &attr.value;
216
217    // Walk up the access chain
218    loop {
219        match &**current_expr {
220            Expr::Attribute(inner_attr) => {
221                access_chain.insert(0, inner_attr.attr.to_string());
222                current_expr = &inner_attr.value;
223            }
224            Expr::Subscript(subscript) => {
225                if let Expr::Constant(constant) = &*subscript.slice {
226                    if let Some(key) = constant.value.as_str() {
227                        access_chain.insert(0, key.to_string());
228                    }
229                }
230                current_expr = &subscript.value;
231            }
232            Expr::Name(name) => {
233                if name.id.as_str() == "event" {
234                    return Ok(Some(Pattern {
235                        access_chain,
236                        pattern_type: PatternType::Attribute,
237                    }));
238                }
239                break;
240            }
241            _ => break,
242        }
243    }
244
245    Ok(None)
246}
247
248#[cfg(test)]
249mod tests {
250    use super::*;
251    use rustpython_ast::{Mod, ModModule};
252    use rustpython_parser::Parse;
253
254    fn parse_python(source: &str) -> Mod {
255        rustpython_ast::Suite::parse(source, "<test>")
256            .map(|statements| {
257                Mod::Module(ModModule {
258                    body: statements,
259                    type_ignores: vec![],
260                    range: Default::default(),
261                })
262            })
263            .unwrap()
264    }
265
266    fn get_patterns(code: &str) -> Vec<Pattern> {
267        let ast = parse_python(code);
268        match ast {
269            Mod::Module(module) => extract_access_patterns(&module.body).unwrap(),
270            _ => vec![],
271        }
272    }
273
274    // ========================================
275    // extract_subscript_pattern tests
276    // ========================================
277
278    #[test]
279    fn test_subscript_simple_event_access() {
280        let patterns = get_patterns(
281            r#"
282def handler(event, context):
283    x = event['Records']
284"#,
285        );
286        assert!(patterns.iter().any(|p| p.access_chain == vec!["Records"]));
287    }
288
289    #[test]
290    fn test_subscript_nested_event_access() {
291        let patterns = get_patterns(
292            r#"
293def handler(event, context):
294    x = event['Records']['s3']
295"#,
296        );
297        assert!(patterns
298            .iter()
299            .any(|p| p.access_chain == vec!["Records", "s3"]));
300    }
301
302    #[test]
303    fn test_subscript_deeply_nested_access() {
304        let patterns = get_patterns(
305            r#"
306def handler(event, context):
307    x = event['Records']['s3']['bucket']['name']
308"#,
309        );
310        assert!(patterns
311            .iter()
312            .any(|p| p.access_chain == vec!["Records", "s3", "bucket", "name"]));
313    }
314
315    #[test]
316    fn test_subscript_numeric_index_skipped() {
317        let patterns = get_patterns(
318            r#"
319def handler(event, context):
320    x = event['Records'][0]['s3']
321"#,
322        );
323        // Numeric indices should be skipped
324        assert!(patterns
325            .iter()
326            .any(|p| p.access_chain == vec!["Records", "s3"]));
327    }
328
329    #[test]
330    fn test_subscript_non_event_ignored() {
331        let patterns = get_patterns(
332            r#"
333def handler(event, context):
334    data = {'foo': 'bar'}
335    x = data['foo']
336"#,
337        );
338        // Non-event subscripts should not produce patterns
339        assert!(patterns.is_empty());
340    }
341
342    #[test]
343    fn test_subscript_mixed_with_attribute() {
344        let patterns = get_patterns(
345            r#"
346def handler(event, context):
347    x = event['Records'][0].data
348"#,
349        );
350        // Should capture mixed patterns
351        assert!(!patterns.is_empty());
352    }
353
354    // ========================================
355    // extract_attribute_pattern tests
356    // ========================================
357
358    #[test]
359    fn test_attribute_simple_access() {
360        let patterns = get_patterns(
361            r#"
362def handler(event, context):
363    x = event.body
364"#,
365        );
366        assert!(patterns.iter().any(|p| p.access_chain == vec!["body"]));
367        assert!(patterns
368            .iter()
369            .any(|p| p.pattern_type == PatternType::Attribute));
370    }
371
372    #[test]
373    fn test_attribute_nested_access() {
374        let patterns = get_patterns(
375            r#"
376def handler(event, context):
377    x = event.body.data
378"#,
379        );
380        assert!(patterns
381            .iter()
382            .any(|p| p.access_chain == vec!["body", "data"]));
383    }
384
385    #[test]
386    fn test_attribute_non_event_ignored() {
387        let patterns = get_patterns(
388            r#"
389def handler(event, context):
390    obj = SomeClass()
391    x = obj.attribute
392"#,
393        );
394        assert!(patterns.is_empty());
395    }
396
397    // ========================================
398    // Statement extraction tests
399    // ========================================
400
401    #[test]
402    fn test_extract_from_assign() {
403        let patterns = get_patterns(
404            r#"
405def handler(event, context):
406    x = event['body']
407"#,
408        );
409        assert!(!patterns.is_empty());
410    }
411
412    #[test]
413    fn test_extract_from_annotated_assign() {
414        let patterns = get_patterns(
415            r#"
416def handler(event, context):
417    x: str = event['body']
418"#,
419        );
420        assert!(patterns.iter().any(|p| p.access_chain == vec!["body"]));
421    }
422
423    #[test]
424    fn test_extract_from_annotated_assign_no_value() {
425        let patterns = get_patterns(
426            r#"
427def handler(event, context):
428    x: str
429"#,
430        );
431        // No event access, should be empty
432        assert!(patterns.is_empty());
433    }
434
435    #[test]
436    fn test_extract_from_return() {
437        let patterns = get_patterns(
438            r#"
439def handler(event, context):
440    return event['data']
441"#,
442        );
443        assert!(patterns.iter().any(|p| p.access_chain == vec!["data"]));
444    }
445
446    #[test]
447    fn test_extract_from_return_none() {
448        let patterns = get_patterns(
449            r#"
450def handler(event, context):
451    return
452"#,
453        );
454        assert!(patterns.is_empty());
455    }
456
457    #[test]
458    fn test_extract_from_if_test() {
459        let patterns = get_patterns(
460            r#"
461def handler(event, context):
462    if event['status']:
463        pass
464"#,
465        );
466        assert!(patterns.iter().any(|p| p.access_chain == vec!["status"]));
467    }
468
469    #[test]
470    fn test_extract_from_if_body() {
471        let patterns = get_patterns(
472            r#"
473def handler(event, context):
474    if True:
475        x = event['body']
476"#,
477        );
478        assert!(patterns.iter().any(|p| p.access_chain == vec!["body"]));
479    }
480
481    #[test]
482    fn test_extract_from_if_else() {
483        let patterns = get_patterns(
484            r#"
485def handler(event, context):
486    if True:
487        pass
488    else:
489        x = event['data']
490"#,
491        );
492        assert!(patterns.iter().any(|p| p.access_chain == vec!["data"]));
493    }
494
495    #[test]
496    fn test_extract_from_for_iter() {
497        let patterns = get_patterns(
498            r#"
499def handler(event, context):
500    for record in event['Records']:
501        pass
502"#,
503        );
504        assert!(patterns.iter().any(|p| p.access_chain == vec!["Records"]));
505    }
506
507    #[test]
508    fn test_extract_from_for_body() {
509        let patterns = get_patterns(
510            r#"
511def handler(event, context):
512    for i in range(10):
513        x = event['body']
514"#,
515        );
516        assert!(patterns.iter().any(|p| p.access_chain == vec!["body"]));
517    }
518
519    #[test]
520    fn test_extract_from_while_test() {
521        let patterns = get_patterns(
522            r#"
523def handler(event, context):
524    while event['status']:
525        pass
526"#,
527        );
528        assert!(patterns.iter().any(|p| p.access_chain == vec!["status"]));
529    }
530
531    #[test]
532    fn test_extract_from_while_body() {
533        let patterns = get_patterns(
534            r#"
535def handler(event, context):
536    while True:
537        x = event['data']
538        break
539"#,
540        );
541        assert!(patterns.iter().any(|p| p.access_chain == vec!["data"]));
542    }
543
544    #[test]
545    fn test_extract_from_with_context() {
546        let patterns = get_patterns(
547            r#"
548def handler(event, context):
549    with event['resource'] as r:
550        pass
551"#,
552        );
553        assert!(patterns.iter().any(|p| p.access_chain == vec!["resource"]));
554    }
555
556    #[test]
557    fn test_extract_from_with_body() {
558        let patterns = get_patterns(
559            r#"
560def handler(event, context):
561    with open('file') as f:
562        x = event['data']
563"#,
564        );
565        assert!(patterns.iter().any(|p| p.access_chain == vec!["data"]));
566    }
567
568    #[test]
569    fn test_extract_from_expr_stmt() {
570        let patterns = get_patterns(
571            r#"
572def handler(event, context):
573    event['action']
574"#,
575        );
576        assert!(patterns.iter().any(|p| p.access_chain == vec!["action"]));
577    }
578
579    // ========================================
580    // Expression extraction tests
581    // ========================================
582
583    #[test]
584    fn test_extract_from_call_func() {
585        let patterns = get_patterns(
586            r#"
587def handler(event, context):
588    event['handler']()
589"#,
590        );
591        assert!(patterns.iter().any(|p| p.access_chain == vec!["handler"]));
592    }
593
594    #[test]
595    fn test_extract_from_call_args() {
596        let patterns = get_patterns(
597            r#"
598def handler(event, context):
599    process(event['data'])
600"#,
601        );
602        assert!(patterns.iter().any(|p| p.access_chain == vec!["data"]));
603    }
604
605    #[test]
606    fn test_extract_from_call_kwargs() {
607        let patterns = get_patterns(
608            r#"
609def handler(event, context):
610    process(data=event['body'])
611"#,
612        );
613        assert!(patterns.iter().any(|p| p.access_chain == vec!["body"]));
614    }
615
616    #[test]
617    fn test_extract_from_binop() {
618        let patterns = get_patterns(
619            r#"
620def handler(event, context):
621    x = event['a'] + event['b']
622"#,
623        );
624        assert!(patterns.iter().any(|p| p.access_chain == vec!["a"]));
625        assert!(patterns.iter().any(|p| p.access_chain == vec!["b"]));
626    }
627
628    #[test]
629    fn test_extract_from_compare() {
630        let patterns = get_patterns(
631            r#"
632def handler(event, context):
633    if event['a'] == event['b']:
634        pass
635"#,
636        );
637        assert!(patterns.iter().any(|p| p.access_chain == vec!["a"]));
638        assert!(patterns.iter().any(|p| p.access_chain == vec!["b"]));
639    }
640
641    #[test]
642    fn test_extract_from_boolop() {
643        let patterns = get_patterns(
644            r#"
645def handler(event, context):
646    if event['a'] and event['b']:
647        pass
648"#,
649        );
650        assert!(patterns.iter().any(|p| p.access_chain == vec!["a"]));
651        assert!(patterns.iter().any(|p| p.access_chain == vec!["b"]));
652    }
653
654    #[test]
655    fn test_extract_from_unaryop() {
656        let patterns = get_patterns(
657            r#"
658def handler(event, context):
659    if not event['flag']:
660        pass
661"#,
662        );
663        assert!(patterns.iter().any(|p| p.access_chain == vec!["flag"]));
664    }
665
666    #[test]
667    fn test_extract_from_ifexp() {
668        let patterns = get_patterns(
669            r#"
670def handler(event, context):
671    x = event['a'] if event['cond'] else event['b']
672"#,
673        );
674        assert!(patterns.iter().any(|p| p.access_chain == vec!["a"]));
675        assert!(patterns.iter().any(|p| p.access_chain == vec!["cond"]));
676        assert!(patterns.iter().any(|p| p.access_chain == vec!["b"]));
677    }
678
679    #[test]
680    fn test_extract_from_dict_values() {
681        let patterns = get_patterns(
682            r#"
683def handler(event, context):
684    return {'result': event['data']}
685"#,
686        );
687        assert!(patterns.iter().any(|p| p.access_chain == vec!["data"]));
688    }
689
690    #[test]
691    fn test_extract_from_list_elements() {
692        let patterns = get_patterns(
693            r#"
694def handler(event, context):
695    return [event['a'], event['b']]
696"#,
697        );
698        assert!(patterns.iter().any(|p| p.access_chain == vec!["a"]));
699        assert!(patterns.iter().any(|p| p.access_chain == vec!["b"]));
700    }
701
702    #[test]
703    fn test_extract_from_tuple_elements() {
704        let patterns = get_patterns(
705            r#"
706def handler(event, context):
707    return (event['a'], event['b'])
708"#,
709        );
710        assert!(patterns.iter().any(|p| p.access_chain == vec!["a"]));
711        assert!(patterns.iter().any(|p| p.access_chain == vec!["b"]));
712    }
713
714    // ========================================
715    // Edge cases and complex scenarios
716    // ========================================
717
718    #[test]
719    fn test_multiple_functions_in_module() {
720        let patterns = get_patterns(
721            r#"
722def helper():
723    pass
724
725def handler(event, context):
726    x = event['data']
727    return x
728
729def another_helper():
730    pass
731"#,
732        );
733        assert!(patterns.iter().any(|p| p.access_chain == vec!["data"]));
734    }
735
736    #[test]
737    fn test_empty_function() {
738        let patterns = get_patterns(
739            r#"
740def handler(event, context):
741    pass
742"#,
743        );
744        assert!(patterns.is_empty());
745    }
746
747    #[test]
748    fn test_no_function_def() {
749        let patterns = get_patterns(
750            r#"
751x = 1
752y = 2
753"#,
754        );
755        assert!(patterns.is_empty());
756    }
757
758    #[test]
759    fn test_nested_function() {
760        let patterns = get_patterns(
761            r#"
762def handler(event, context):
763    def inner():
764        return event['inner_data']
765    return event['outer_data']
766"#,
767        );
768        // Only outer patterns are extracted
769        assert!(patterns
770            .iter()
771            .any(|p| p.access_chain == vec!["outer_data"]));
772    }
773
774    #[test]
775    fn test_complex_s3_pattern() {
776        let patterns = get_patterns(
777            r#"
778def handler(event, context):
779    bucket = event['Records'][0]['s3']['bucket']['name']
780    key = event['Records'][0]['s3']['object']['key']
781    return {'bucket': bucket, 'key': key}
782"#,
783        );
784        assert!(patterns.len() >= 2);
785        assert!(patterns
786            .iter()
787            .any(|p| p.access_chain.contains(&"bucket".to_string())));
788        assert!(patterns
789            .iter()
790            .any(|p| p.access_chain.contains(&"object".to_string())));
791    }
792
793    #[test]
794    fn test_api_gateway_pattern() {
795        let patterns = get_patterns(
796            r#"
797def handler(event, context):
798    method = event['requestContext']['http']['method']
799    path = event['requestContext']['http']['path']
800    body = event['body']
801    return {'method': method, 'path': path}
802"#,
803        );
804        assert!(patterns.len() >= 3);
805        assert!(patterns
806            .iter()
807            .any(|p| p.access_chain.contains(&"requestContext".to_string())));
808        assert!(patterns.iter().any(|p| p.access_chain == vec!["body"]));
809    }
810
811    #[test]
812    fn test_eventbridge_pattern() {
813        let patterns = get_patterns(
814            r#"
815def handler(event, context):
816    detail_type = event['detail-type']
817    detail = event['detail']
818    source = event['source']
819    return None
820"#,
821        );
822        assert!(patterns
823            .iter()
824            .any(|p| p.access_chain == vec!["detail-type"]));
825        assert!(patterns.iter().any(|p| p.access_chain == vec!["detail"]));
826        assert!(patterns.iter().any(|p| p.access_chain == vec!["source"]));
827    }
828
829    #[test]
830    fn test_pattern_type_is_mixed_for_subscript() {
831        let patterns = get_patterns(
832            r#"
833def handler(event, context):
834    x = event['data']
835"#,
836        );
837        assert!(patterns
838            .iter()
839            .all(|p| p.pattern_type == PatternType::Mixed));
840    }
841
842    #[test]
843    fn test_pattern_type_is_attribute_for_dot_access() {
844        let patterns = get_patterns(
845            r#"
846def handler(event, context):
847    x = event.data
848"#,
849        );
850        assert!(patterns
851            .iter()
852            .any(|p| p.pattern_type == PatternType::Attribute));
853    }
854
855    #[test]
856    fn test_multiple_targets_in_assign() {
857        let patterns = get_patterns(
858            r#"
859def handler(event, context):
860    x = y = event['data']
861"#,
862        );
863        assert!(patterns.iter().any(|p| p.access_chain == vec!["data"]));
864    }
865}