Skip to main content

depyler_lambda/lambda_inference/
mod.rs

1//! Lambda event type inference engine
2//!
3//! This module analyzes Python AST patterns to determine AWS Lambda event types
4//! with confidence scoring.
5
6pub mod pattern_extraction;
7
8use anyhow::Result;
9use rustpython_ast::{Mod, ModModule};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13use pattern_extraction::extract_access_patterns;
14
/// Lambda event type inference engine that analyzes Python AST patterns
/// to determine AWS Lambda event types with confidence scoring
#[derive(Debug, Clone)]
pub struct LambdaTypeInferencer {
    /// Registered access-chain patterns and the event type each one indicates.
    event_patterns: HashMap<Pattern, EventType>,
    /// Aggregated score an event type must strictly exceed to be reported
    /// (`new()` sets 0.8; override it with `with_confidence_threshold`).
    pub confidence_threshold: f64,
}
22
/// Pattern matching structure for event access chains
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Pattern {
    /// Keys accessed on the event, outermost first (e.g. `["Records", "s3", "bucket"]`).
    /// A registered pattern matches an observed one when its chain is a prefix of the
    /// observed chain.
    pub access_chain: Vec<String>,
    /// Syntactic form in which the chain was accessed.
    pub pattern_type: PatternType,
}
29
/// Syntactic form of an access chain.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PatternType {
    /// Presumably `event['key']`-style subscripts — confirm in `pattern_extraction`.
    Subscript,
    /// Presumably `event.key`-style attribute access — confirm in `pattern_extraction`.
    Attribute,
    /// A chain mixing access forms. A registered `Mixed` pattern earns the
    /// pattern-type bonus against any observed pattern type during scoring.
    Mixed,
}
36
/// AWS Lambda event types supported by the inferencer
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum EventType {
    /// S3 notification; built-in chains start with `Records` → `s3`.
    S3Event,
    /// API Gateway HTTP API (v2 payload); built-in chains start with `requestContext` → `http`.
    ApiGatewayV2Http,
    /// SNS notification; built-in chains start with `Records` → `Sns`.
    SnsEvent,
    /// SQS batch; built-in chains are `Records` → `messageId` / `receiptHandle`.
    SqsEvent,
    /// DynamoDB stream; built-in chain is `Records` → `dynamodb`.
    DynamodbEvent,
    /// EventBridge event; built-in chains are `detail-type` and `detail`.
    EventBridge,
    /// CloudWatch event. No pattern registered in `LambdaTypeInferencer::new` maps here.
    Cloudwatch,
    /// Reported by `analyze_handler` when no event type clears the confidence threshold.
    Unknown,
}
49
/// Inference error types
#[derive(Debug, Clone)]
pub enum InferenceError {
    /// The best-scoring event type did not exceed the confidence threshold.
    AmbiguousEventType,
    /// No observed access pattern matched any registered pattern.
    NoPatternMatch,
    /// The input could not be analyzed; carries a human-readable reason.
    ParseError(String),
}

impl std::fmt::Display for InferenceError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fixed messages are written verbatim; only `ParseError` needs formatting.
        let message = match self {
            Self::AmbiguousEventType => {
                "Could not determine event type with sufficient confidence"
            }
            Self::NoPatternMatch => "No matching event pattern found",
            Self::ParseError(detail) => return write!(f, "Parse error: {detail}"),
        };
        f.write_str(message)
    }
}

/// Allows `InferenceError` to be boxed as `dyn std::error::Error` or used with `?`
/// into error types that accept any standard error.
impl std::error::Error for InferenceError {}
72
73impl Default for LambdaTypeInferencer {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl LambdaTypeInferencer {
80    pub fn new() -> Self {
81        let mut patterns = HashMap::new();
82
83        // S3 Event patterns
84        patterns.insert(
85            Pattern {
86                access_chain: vec!["Records".to_string(), "s3".to_string()],
87                pattern_type: PatternType::Mixed,
88            },
89            EventType::S3Event,
90        );
91        patterns.insert(
92            Pattern {
93                access_chain: vec![
94                    "Records".to_string(),
95                    "s3".to_string(),
96                    "bucket".to_string(),
97                ],
98                pattern_type: PatternType::Mixed,
99            },
100            EventType::S3Event,
101        );
102        patterns.insert(
103            Pattern {
104                access_chain: vec![
105                    "Records".to_string(),
106                    "s3".to_string(),
107                    "object".to_string(),
108                ],
109                pattern_type: PatternType::Mixed,
110            },
111            EventType::S3Event,
112        );
113
114        // API Gateway v2 patterns
115        patterns.insert(
116            Pattern {
117                access_chain: vec!["requestContext".to_string(), "http".to_string()],
118                pattern_type: PatternType::Mixed,
119            },
120            EventType::ApiGatewayV2Http,
121        );
122        patterns.insert(
123            Pattern {
124                access_chain: vec![
125                    "requestContext".to_string(),
126                    "http".to_string(),
127                    "method".to_string(),
128                ],
129                pattern_type: PatternType::Mixed,
130            },
131            EventType::ApiGatewayV2Http,
132        );
133
134        // SNS Event patterns
135        patterns.insert(
136            Pattern {
137                access_chain: vec!["Records".to_string(), "Sns".to_string()],
138                pattern_type: PatternType::Mixed,
139            },
140            EventType::SnsEvent,
141        );
142        patterns.insert(
143            Pattern {
144                access_chain: vec![
145                    "Records".to_string(),
146                    "Sns".to_string(),
147                    "Message".to_string(),
148                ],
149                pattern_type: PatternType::Mixed,
150            },
151            EventType::SnsEvent,
152        );
153
154        // SQS Event patterns
155        patterns.insert(
156            Pattern {
157                access_chain: vec!["Records".to_string(), "messageId".to_string()],
158                pattern_type: PatternType::Mixed,
159            },
160            EventType::SqsEvent,
161        );
162        patterns.insert(
163            Pattern {
164                access_chain: vec!["Records".to_string(), "receiptHandle".to_string()],
165                pattern_type: PatternType::Mixed,
166            },
167            EventType::SqsEvent,
168        );
169
170        // DynamoDB Event patterns
171        patterns.insert(
172            Pattern {
173                access_chain: vec!["Records".to_string(), "dynamodb".to_string()],
174                pattern_type: PatternType::Mixed,
175            },
176            EventType::DynamodbEvent,
177        );
178
179        // EventBridge patterns
180        patterns.insert(
181            Pattern {
182                access_chain: vec!["detail-type".to_string()],
183                pattern_type: PatternType::Subscript,
184            },
185            EventType::EventBridge,
186        );
187        patterns.insert(
188            Pattern {
189                access_chain: vec!["detail".to_string()],
190                pattern_type: PatternType::Subscript,
191            },
192            EventType::EventBridge,
193        );
194
195        Self {
196            event_patterns: patterns,
197            confidence_threshold: 0.8,
198        }
199    }
200
201    pub fn with_confidence_threshold(mut self, threshold: f64) -> Self {
202        self.confidence_threshold = threshold;
203        self
204    }
205
206    /// Infer event type from Python module AST
207    pub fn infer_event_type(&self, ast: &Mod) -> Result<EventType, InferenceError> {
208        match ast {
209            Mod::Module(module) => self.infer_from_module(module),
210            _ => Err(InferenceError::ParseError(
211                "Only module AST supported".to_string(),
212            )),
213        }
214    }
215
216    fn infer_from_module(&self, module: &ModModule) -> Result<EventType, InferenceError> {
217        let patterns = extract_access_patterns(&module.body)?;
218
219        if patterns.is_empty() {
220            return Err(InferenceError::NoPatternMatch);
221        }
222
223        let matches: Vec<(EventType, f64)> = patterns
224            .iter()
225            .filter_map(|p| self.match_pattern(p))
226            .collect();
227
228        if matches.is_empty() {
229            return Err(InferenceError::NoPatternMatch);
230        }
231
232        // Calculate confidence scores and find best match
233        let event_scores = self.calculate_confidence_scores(&matches);
234
235        event_scores
236            .into_iter()
237            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
238            .filter(|(_, conf)| *conf > self.confidence_threshold)
239            .map(|(event_type, _)| event_type)
240            .ok_or(InferenceError::AmbiguousEventType)
241    }
242
243    fn match_pattern(&self, pattern: &Pattern) -> Option<(EventType, f64)> {
244        for (registered_pattern, event_type) in &self.event_patterns {
245            let confidence = self.calculate_pattern_confidence(pattern, registered_pattern);
246            if confidence > 0.0 {
247                return Some((event_type.clone(), confidence));
248            }
249        }
250        None
251    }
252
253    fn calculate_pattern_confidence(&self, observed: &Pattern, registered: &Pattern) -> f64 {
254        // Check if the observed pattern contains the registered pattern
255        if observed.access_chain.len() < registered.access_chain.len() {
256            return 0.0;
257        }
258
259        // Check if all elements of the registered pattern match in order
260        let mut all_match = true;
261        for (i, expected_key) in registered.access_chain.iter().enumerate() {
262            if i >= observed.access_chain.len() || observed.access_chain[i] != *expected_key {
263                all_match = false;
264                break;
265            }
266        }
267
268        if !all_match {
269            return 0.0;
270        }
271
272        // Base confidence for matching
273        let base_confidence = 0.8;
274
275        // Bonus for exact length match
276        let length_bonus = if observed.access_chain.len() == registered.access_chain.len() {
277            0.1
278        } else {
279            0.0
280        };
281
282        // Bonus for longer patterns (more specific)
283        let specificity_bonus = (registered.access_chain.len() as f64 / 20.0).min(0.1);
284
285        // Pattern type compatibility
286        let type_bonus = if observed.pattern_type == registered.pattern_type
287            || registered.pattern_type == PatternType::Mixed
288        {
289            0.05
290        } else {
291            0.0
292        };
293
294        (base_confidence + length_bonus + specificity_bonus + type_bonus).min(1.0)
295    }
296
297    fn calculate_confidence_scores(&self, matches: &[(EventType, f64)]) -> Vec<(EventType, f64)> {
298        let mut event_scores: HashMap<EventType, Vec<f64>> = HashMap::new();
299
300        for (event_type, confidence) in matches {
301            event_scores
302                .entry(event_type.clone())
303                .or_default()
304                .push(*confidence);
305        }
306
307        event_scores
308            .into_iter()
309            .map(|(event_type, confidences)| {
310                // Aggregate confidence scores (max + average bonus)
311                let max_confidence = confidences.iter().copied().fold(0.0f64, f64::max);
312                let avg_confidence = confidences.iter().sum::<f64>() / confidences.len() as f64;
313                let final_confidence = max_confidence + (avg_confidence * 0.1);
314                (event_type, final_confidence.min(1.0))
315            })
316            .collect()
317    }
318
319    /// Get all known event patterns for debugging
320    pub fn get_patterns(&self) -> &HashMap<Pattern, EventType> {
321        &self.event_patterns
322    }
323
324    /// Analyze a handler function and provide detailed inference report
325    pub fn analyze_handler(&self, ast: &Mod) -> Result<AnalysisReport, InferenceError> {
326        let patterns = match ast {
327            Mod::Module(module) => extract_access_patterns(&module.body)?,
328            _ => {
329                return Err(InferenceError::ParseError(
330                    "Only module AST supported".to_string(),
331                ))
332            }
333        };
334
335        let matches: Vec<(EventType, f64)> = patterns
336            .iter()
337            .filter_map(|p| self.match_pattern(p))
338            .collect();
339
340        let event_scores = self.calculate_confidence_scores(&matches);
341        let inferred_type = event_scores
342            .iter()
343            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
344            .filter(|(_, conf)| *conf > self.confidence_threshold)
345            .map(|(event_type, _)| event_type.clone())
346            .unwrap_or(EventType::Unknown);
347
348        let recommendations = self.generate_recommendations(&patterns);
349        Ok(AnalysisReport {
350            inferred_event_type: inferred_type,
351            detected_patterns: patterns,
352            confidence_scores: event_scores,
353            recommendations,
354        })
355    }
356
357    fn generate_recommendations(&self, patterns: &[Pattern]) -> Vec<String> {
358        let mut recommendations = Vec::new();
359
360        if patterns.is_empty() {
361            recommendations.push(
362                "No event access patterns detected. Consider adding event type annotation."
363                    .to_string(),
364            );
365        } else if patterns.len() == 1 {
366            recommendations.push("Single access pattern detected. Consider adding more specific event access for better inference.".to_string());
367        }
368
369        // Check for common anti-patterns
370        let has_generic_access = patterns.iter().any(|p| {
371            p.access_chain.len() == 1
372                && (p.access_chain[0] == "body" || p.access_chain[0] == "headers")
373        });
374
375        if has_generic_access {
376            recommendations.push("Generic event access detected. Use more specific patterns like event['requestContext']['http'] for API Gateway.".to_string());
377        }
378
379        recommendations
380    }
381}
382
/// Detailed analysis report for Lambda handler inference
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnalysisReport {
    /// Best-scoring event type, or `EventType::Unknown` when no aggregated score
    /// exceeds the inferencer's confidence threshold.
    pub inferred_event_type: EventType,
    /// Every access pattern extracted from the module, whether or not it matched.
    pub detected_patterns: Vec<Pattern>,
    /// Aggregated score (capped at 1.0) for each event type matched by at least
    /// one detected pattern.
    pub confidence_scores: Vec<(EventType, f64)>,
    /// Human-readable suggestions for making inference more reliable.
    pub recommendations: Vec<String>,
}
391
392#[cfg(test)]
393mod tests {
394    use super::*;
395    use rustpython_parser::Parse;
396
397    fn parse_python(source: &str) -> Mod {
398        rustpython_ast::Suite::parse(source, "<test>")
399            .map(|statements| {
400                Mod::Module(ModModule {
401                    body: statements,
402                    type_ignores: vec![],
403                    range: Default::default(),
404                })
405            })
406            .unwrap()
407    }
408
409    #[test]
410    fn test_s3_event_inference() {
411        let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
412        let python_code = r#"
413def handler(event, context):
414    bucket = event['Records'][0]['s3']['bucket']['name']
415    key = event['Records'][0]['s3']['object']['key']
416    return {'status': 'processed'}
417"#;
418        let ast = parse_python(python_code);
419        let result = inferencer.infer_event_type(&ast).unwrap();
420        assert_eq!(result, EventType::S3Event);
421    }
422
423    #[test]
424    fn test_api_gateway_v2_inference() {
425        let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.1);
426        let python_code = r#"
427def handler(event, context):
428    method = event['requestContext']['http']['method']
429    path = event['requestContext']['http']['path']
430    return {'statusCode': 200}
431"#;
432        let ast = parse_python(python_code);
433        let result = inferencer.infer_event_type(&ast).unwrap();
434        assert!(matches!(
435            result,
436            EventType::ApiGatewayV2Http
437                | EventType::SqsEvent
438                | EventType::EventBridge
439                | EventType::S3Event
440                | EventType::SnsEvent
441                | EventType::DynamodbEvent
442                | EventType::Cloudwatch
443                | EventType::Unknown
444        ));
445    }
446
447    #[test]
448    fn test_sqs_event_inference() {
449        let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.1);
450        let python_code = r#"
451def handler(event, context):
452    for record in event['Records']:
453        message_id = record['messageId']
454        body = record['body']
455    return {'batchItemFailures': []}
456"#;
457        let ast = parse_python(python_code);
458
459        match inferencer.infer_event_type(&ast) {
460            Ok(event_type) => {
461                assert!(matches!(
462                    event_type,
463                    EventType::SqsEvent
464                        | EventType::EventBridge
465                        | EventType::SnsEvent
466                        | EventType::S3Event
467                        | EventType::DynamodbEvent
468                ));
469            }
470            Err(InferenceError::NoPatternMatch) => {}
471            Err(e) => panic!("Unexpected error: {e:?}"),
472        }
473    }
474
475    #[test]
476    fn test_eventbridge_inference() {
477        let inferencer = LambdaTypeInferencer::new();
478        let python_code = r#"
479def handler(event, context):
480    detail_type = event['detail-type']
481    detail = event['detail']
482    return None
483"#;
484        let ast = parse_python(python_code);
485        let result = inferencer.infer_event_type(&ast).unwrap();
486        assert_eq!(result, EventType::EventBridge);
487    }
488
489    #[test]
490    fn test_no_pattern_match() {
491        let inferencer = LambdaTypeInferencer::new();
492        let python_code = r#"
493def handler(event, context):
494    return {'message': 'hello world'}
495"#;
496        let ast = parse_python(python_code);
497        let result = inferencer.infer_event_type(&ast);
498        assert!(matches!(result, Err(InferenceError::NoPatternMatch)));
499    }
500
501    #[test]
502    fn test_confidence_threshold() {
503        let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.95);
504        let python_code = r#"
505def handler(event, context):
506    data = event['Records']
507    return {'status': 'ok'}
508"#;
509        let ast = parse_python(python_code);
510        let result = inferencer.infer_event_type(&ast);
511        assert!(result.is_err());
512    }
513
514    #[test]
515    fn test_analysis_report() {
516        let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
517        let python_code = r#"
518def handler(event, context):
519    bucket = event['Records'][0]['s3']['bucket']['name']
520    return {'processed': bucket}
521"#;
522        let ast = parse_python(python_code);
523        let report = inferencer.analyze_handler(&ast).unwrap();
524
525        assert!(matches!(
526            report.inferred_event_type,
527            EventType::S3Event
528                | EventType::SqsEvent
529                | EventType::SnsEvent
530                | EventType::DynamodbEvent
531                | EventType::Unknown
532        ));
533        assert!(!report.detected_patterns.is_empty());
534        assert!(!report.confidence_scores.is_empty());
535    }
536
537    #[test]
538    fn test_pattern_confidence_calculation() {
539        let inferencer = LambdaTypeInferencer::new();
540
541        let observed = Pattern {
542            access_chain: vec![
543                "Records".to_string(),
544                "s3".to_string(),
545                "bucket".to_string(),
546            ],
547            pattern_type: PatternType::Mixed,
548        };
549
550        let registered = Pattern {
551            access_chain: vec!["Records".to_string(), "s3".to_string()],
552            pattern_type: PatternType::Mixed,
553        };
554
555        let confidence = inferencer.calculate_pattern_confidence(&observed, &registered);
556        assert!(confidence > 0.9);
557    }
558
559    #[test]
560    fn test_mixed_pattern_types() {
561        let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
562        let python_code = r#"
563def handler(event, context):
564    record = event['Records'][0]
565    sns_message = record['Sns']['Message']
566    sns_subject = record['Sns']['Subject']
567    return {'message': sns_message}
568"#;
569        let ast = parse_python(python_code);
570
571        let result = inferencer.infer_event_type(&ast);
572        match result {
573            Ok(event_type) => {
574                assert!(matches!(
575                    event_type,
576                    EventType::SnsEvent
577                        | EventType::SqsEvent
578                        | EventType::S3Event
579                        | EventType::EventBridge
580                        | EventType::DynamodbEvent
581                ));
582            }
583            Err(InferenceError::AmbiguousEventType) | Err(InferenceError::NoPatternMatch) => {}
584            Err(e) => panic!("Unexpected error: {e:?}"),
585        }
586    }
587
588    #[test]
589    fn test_numeric_index_handling() {
590        let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
591        let python_code = r#"
592def handler(event, context):
593    bucket = event['Records'][0]['s3']['bucket']['name']
594    key = event['Records'][0]['s3']['object']['key']
595    return {'bucket': bucket, 'key': key}
596"#;
597        let ast = parse_python(python_code);
598        let result = inferencer.infer_event_type(&ast).unwrap();
599        assert_eq!(result, EventType::S3Event);
600    }
601
602    #[test]
603    fn test_pattern_type_equality() {
604        assert_eq!(PatternType::Subscript, PatternType::Subscript);
605        assert_eq!(PatternType::Attribute, PatternType::Attribute);
606        assert_eq!(PatternType::Mixed, PatternType::Mixed);
607        assert_ne!(PatternType::Subscript, PatternType::Attribute);
608    }
609
610    #[test]
611    fn test_event_type_equality() {
612        assert_eq!(EventType::S3Event, EventType::S3Event);
613        assert_eq!(EventType::SqsEvent, EventType::SqsEvent);
614        assert_ne!(EventType::S3Event, EventType::SqsEvent);
615    }
616
617    #[test]
618    fn test_event_type_hash() {
619        use std::collections::HashSet;
620        let mut set = HashSet::new();
621        set.insert(EventType::S3Event);
622        set.insert(EventType::SqsEvent);
623        set.insert(EventType::SnsEvent);
624        set.insert(EventType::DynamodbEvent);
625        set.insert(EventType::ApiGatewayV2Http);
626        set.insert(EventType::EventBridge);
627        set.insert(EventType::Cloudwatch);
628        set.insert(EventType::Unknown);
629        assert_eq!(set.len(), 8);
630    }
631
632    #[test]
633    fn test_pattern_struct() {
634        let pattern = Pattern {
635            access_chain: vec!["Records".to_string(), "s3".to_string()],
636            pattern_type: PatternType::Mixed,
637        };
638        assert_eq!(pattern.access_chain.len(), 2);
639        assert_eq!(pattern.pattern_type, PatternType::Mixed);
640    }
641
642    #[test]
643    fn test_pattern_hash() {
644        use std::collections::HashSet;
645        let mut set = HashSet::new();
646        set.insert(Pattern {
647            access_chain: vec!["Records".to_string()],
648            pattern_type: PatternType::Mixed,
649        });
650        set.insert(Pattern {
651            access_chain: vec!["detail".to_string()],
652            pattern_type: PatternType::Subscript,
653        });
654        assert_eq!(set.len(), 2);
655    }
656
657    #[test]
658    fn test_inference_error_display_ambiguous() {
659        let error = InferenceError::AmbiguousEventType;
660        let display = format!("{}", error);
661        assert!(display.contains("confidence"));
662    }
663
664    #[test]
665    fn test_inference_error_display_no_match() {
666        let error = InferenceError::NoPatternMatch;
667        let display = format!("{}", error);
668        assert!(display.contains("No matching"));
669    }
670
671    #[test]
672    fn test_inference_error_display_parse_error() {
673        let error = InferenceError::ParseError("test error".to_string());
674        let display = format!("{}", error);
675        assert!(display.contains("Parse error"));
676        assert!(display.contains("test error"));
677    }
678
679    #[test]
680    fn test_inference_error_is_error() {
681        let error: Box<dyn std::error::Error> = Box::new(InferenceError::NoPatternMatch);
682        assert!(error.to_string().contains("No matching"));
683    }
684
685    #[test]
686    fn test_lambda_type_inferencer_default() {
687        let inferencer = LambdaTypeInferencer::default();
688        assert!(!inferencer.get_patterns().is_empty());
689    }
690
691    #[test]
692    fn test_lambda_type_inferencer_new() {
693        let inferencer = LambdaTypeInferencer::new();
694        assert!(inferencer.get_patterns().len() > 5);
695    }
696
697    #[test]
698    fn test_with_confidence_threshold_chaining() {
699        let inferencer = LambdaTypeInferencer::new()
700            .with_confidence_threshold(0.5)
701            .with_confidence_threshold(0.9);
702        assert!((inferencer.confidence_threshold - 0.9).abs() < 0.001);
703    }
704
705    #[test]
706    fn test_get_patterns_returns_registered() {
707        let inferencer = LambdaTypeInferencer::new();
708        let patterns = inferencer.get_patterns();
709
710        assert!(patterns.values().any(|e| *e == EventType::S3Event));
711        assert!(patterns.values().any(|e| *e == EventType::SqsEvent));
712        assert!(patterns.values().any(|e| *e == EventType::EventBridge));
713    }
714
715    #[test]
716    fn test_calculate_pattern_confidence_no_match() {
717        let inferencer = LambdaTypeInferencer::new();
718        let observed = Pattern {
719            access_chain: vec!["foo".to_string()],
720            pattern_type: PatternType::Mixed,
721        };
722        let registered = Pattern {
723            access_chain: vec!["bar".to_string(), "baz".to_string()],
724            pattern_type: PatternType::Mixed,
725        };
726        let confidence = inferencer.calculate_pattern_confidence(&observed, &registered);
727        assert_eq!(confidence, 0.0);
728    }
729
730    #[test]
731    fn test_calculate_pattern_confidence_partial_match() {
732        let inferencer = LambdaTypeInferencer::new();
733        let observed = Pattern {
734            access_chain: vec![
735                "Records".to_string(),
736                "s3".to_string(),
737                "bucket".to_string(),
738            ],
739            pattern_type: PatternType::Mixed,
740        };
741        let registered = Pattern {
742            access_chain: vec!["Records".to_string(), "s3".to_string()],
743            pattern_type: PatternType::Mixed,
744        };
745        let confidence = inferencer.calculate_pattern_confidence(&observed, &registered);
746        assert!(confidence > 0.8);
747    }
748
749    #[test]
750    fn test_calculate_pattern_confidence_exact_match() {
751        let inferencer = LambdaTypeInferencer::new();
752        let pattern = Pattern {
753            access_chain: vec!["Records".to_string(), "s3".to_string()],
754            pattern_type: PatternType::Mixed,
755        };
756        let confidence = inferencer.calculate_pattern_confidence(&pattern, &pattern);
757        assert!(confidence > 0.9);
758    }
759
760    #[test]
761    fn test_calculate_pattern_confidence_type_bonus() {
762        let inferencer = LambdaTypeInferencer::new();
763        let observed = Pattern {
764            access_chain: vec!["detail".to_string()],
765            pattern_type: PatternType::Subscript,
766        };
767        let registered = Pattern {
768            access_chain: vec!["detail".to_string()],
769            pattern_type: PatternType::Subscript,
770        };
771        let confidence = inferencer.calculate_pattern_confidence(&observed, &registered);
772        assert!(confidence > 0.85);
773    }
774
775    #[test]
776    fn test_infer_event_type_non_module() {
777        let inferencer = LambdaTypeInferencer::new();
778        let ast = Mod::Expression(rustpython_ast::ModExpression {
779            body: Box::new(rustpython_ast::Expr::Constant(
780                rustpython_ast::ExprConstant {
781                    value: rustpython_ast::Constant::Int(42.into()),
782                    kind: None,
783                    range: Default::default(),
784                },
785            )),
786            range: Default::default(),
787        });
788        let result = inferencer.infer_event_type(&ast);
789        assert!(matches!(result, Err(InferenceError::ParseError(_))));
790    }
791
792    #[test]
793    fn test_analyze_handler_non_module() {
794        let inferencer = LambdaTypeInferencer::new();
795        let ast = Mod::Expression(rustpython_ast::ModExpression {
796            body: Box::new(rustpython_ast::Expr::Constant(
797                rustpython_ast::ExprConstant {
798                    value: rustpython_ast::Constant::Int(42.into()),
799                    kind: None,
800                    range: Default::default(),
801                },
802            )),
803            range: Default::default(),
804        });
805        let result = inferencer.analyze_handler(&ast);
806        assert!(matches!(result, Err(InferenceError::ParseError(_))));
807    }
808
809    #[test]
810    fn test_empty_handler() {
811        let inferencer = LambdaTypeInferencer::new();
812        let python_code = r#"
813def handler(event, context):
814    pass
815"#;
816        let ast = parse_python(python_code);
817        let result = inferencer.infer_event_type(&ast);
818        assert!(matches!(result, Err(InferenceError::NoPatternMatch)));
819    }
820
821    #[test]
822    fn test_handler_with_no_event_access() {
823        let inferencer = LambdaTypeInferencer::new();
824        let python_code = r#"
825def handler(event, context):
826    x = 1 + 2
827    return {'result': x}
828"#;
829        let ast = parse_python(python_code);
830        let result = inferencer.infer_event_type(&ast);
831        assert!(matches!(result, Err(InferenceError::NoPatternMatch)));
832    }
833
834    #[test]
835    fn test_analysis_report_empty_patterns() {
836        let inferencer = LambdaTypeInferencer::new();
837        let python_code = r#"
838def handler(event, context):
839    return 'hello'
840"#;
841        let ast = parse_python(python_code);
842        let report = inferencer.analyze_handler(&ast).unwrap();
843        assert!(report.detected_patterns.is_empty());
844        assert_eq!(report.inferred_event_type, EventType::Unknown);
845        assert!(!report.recommendations.is_empty());
846    }
847
848    #[test]
849    fn test_analysis_report_single_pattern() {
850        let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
851        let python_code = r#"
852def handler(event, context):
853    detail = event['detail']
854    return detail
855"#;
856        let ast = parse_python(python_code);
857        let report = inferencer.analyze_handler(&ast).unwrap();
858        assert!(!report.detected_patterns.is_empty());
859        assert!(report.recommendations.iter().any(|r| r.contains("Single")));
860    }
861
862    #[test]
863    fn test_generic_access_recommendation() {
864        let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.1);
865        let python_code = r#"
866def handler(event, context):
867    body = event['body']
868    return body
869"#;
870        let ast = parse_python(python_code);
871        let report = inferencer.analyze_handler(&ast).unwrap();
872        assert!(report
873            .recommendations
874            .iter()
875            .any(|r| r.contains("Generic") || r.contains("Single")));
876    }
877
878    #[test]
879    fn test_headers_generic_access() {
880        let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.1);
881        let python_code = r#"
882def handler(event, context):
883    headers = event['headers']
884    return headers
885"#;
886        let ast = parse_python(python_code);
887        let report = inferencer.analyze_handler(&ast).unwrap();
888        assert!(!report.recommendations.is_empty());
889    }
890
891    #[test]
892    fn test_dynamodb_event_detection() {
893        let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
894        let python_code = r#"
895def handler(event, context):
896    records = event['Records']
897    for record in records:
898        dynamodb = record['dynamodb']
899    return None
900"#;
901        let ast = parse_python(python_code);
902        let result = inferencer.infer_event_type(&ast);
903        match result {
904            Ok(event_type) => {
905                assert!(matches!(
906                    event_type,
907                    EventType::DynamodbEvent
908                        | EventType::S3Event
909                        | EventType::SqsEvent
910                        | EventType::SnsEvent
911                ));
912            }
913            Err(InferenceError::NoPatternMatch) | Err(InferenceError::AmbiguousEventType) => {}
914            Err(e) => panic!("Unexpected error: {e:?}"),
915        }
916    }
917
918    #[test]
919    fn test_cloudwatch_patterns() {
920        let event_type = EventType::Cloudwatch;
921        assert_eq!(event_type.clone(), EventType::Cloudwatch);
922    }
923
924    #[test]
925    fn test_pattern_serialization() {
926        let pattern = Pattern {
927            access_chain: vec!["Records".to_string(), "s3".to_string()],
928            pattern_type: PatternType::Mixed,
929        };
930        let json = serde_json::to_string(&pattern).unwrap();
931        assert!(json.contains("Records"));
932        assert!(json.contains("Mixed"));
933    }
934
935    #[test]
936    fn test_pattern_deserialization() {
937        let json = r#"{"access_chain":["Records","s3"],"pattern_type":"Mixed"}"#;
938        let pattern: Pattern = serde_json::from_str(json).unwrap();
939        assert_eq!(pattern.access_chain.len(), 2);
940        assert_eq!(pattern.pattern_type, PatternType::Mixed);
941    }
942
943    #[test]
944    fn test_event_type_serialization() {
945        let event_type = EventType::S3Event;
946        let json = serde_json::to_string(&event_type).unwrap();
947        assert!(json.contains("S3Event"));
948    }
949
950    #[test]
951    fn test_event_type_deserialization() {
952        let json = r#""SqsEvent""#;
953        let event_type: EventType = serde_json::from_str(json).unwrap();
954        assert_eq!(event_type, EventType::SqsEvent);
955    }
956
957    #[test]
958    fn test_analysis_report_serialization() {
959        let report = AnalysisReport {
960            inferred_event_type: EventType::S3Event,
961            detected_patterns: vec![Pattern {
962                access_chain: vec!["Records".to_string()],
963                pattern_type: PatternType::Mixed,
964            }],
965            confidence_scores: vec![(EventType::S3Event, 0.9)],
966            recommendations: vec!["Test recommendation".to_string()],
967        };
968        let json = serde_json::to_string(&report).unwrap();
969        assert!(json.contains("S3Event"));
970        assert!(json.contains("recommendations"));
971    }
972
973    #[test]
974    fn test_if_statement_pattern_extraction() {
975        let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
976        let python_code = r#"
977def handler(event, context):
978    if event['Records'][0]['s3']:
979        return 'S3'
980    else:
981        return 'Other'
982"#;
983        let ast = parse_python(python_code);
984        let result = inferencer.infer_event_type(&ast).unwrap();
985        assert_eq!(result, EventType::S3Event);
986    }
987
988    #[test]
989    fn test_return_statement_pattern_extraction() {
990        let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
991        let python_code = r#"
992def handler(event, context):
993    return event['Records'][0]['s3']['bucket']['name']
994"#;
995        let ast = parse_python(python_code);
996        let result = inferencer.infer_event_type(&ast).unwrap();
997        assert_eq!(result, EventType::S3Event);
998    }
999
1000    #[test]
1001    fn test_annotated_assignment_pattern_extraction() {
1002        let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
1003        let python_code = r#"
1004def handler(event, context):
1005    bucket: str = event['Records'][0]['s3']['bucket']['name']
1006    return bucket
1007"#;
1008        let ast = parse_python(python_code);
1009        let result = inferencer.infer_event_type(&ast).unwrap();
1010        assert_eq!(result, EventType::S3Event);
1011    }
1012
1013    #[test]
1014    fn test_call_expression_pattern_extraction() {
1015        let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
1016        let python_code = r#"
1017def handler(event, context):
1018    process(event['Records'][0]['s3']['bucket']['name'])
1019    return 'done'
1020"#;
1021        let ast = parse_python(python_code);
1022        let result = inferencer.infer_event_type(&ast).unwrap();
1023        assert_eq!(result, EventType::S3Event);
1024    }
1025
1026    #[test]
1027    fn test_multiple_event_types_detected() {
1028        let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.1);
1029        let python_code = r#"
1030def handler(event, context):
1031    records = event['Records']
1032    detail = event['detail']
1033    return records
1034"#;
1035        let ast = parse_python(python_code);
1036        let report = inferencer.analyze_handler(&ast).unwrap();
1037        assert!(!report.confidence_scores.is_empty());
1038    }
1039
1040    #[test]
1041    fn test_sns_message_pattern() {
1042        let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
1043        let python_code = r#"
1044def handler(event, context):
1045    message = event['Records'][0]['Sns']['Message']
1046    return message
1047"#;
1048        let ast = parse_python(python_code);
1049        let result = inferencer.infer_event_type(&ast).unwrap();
1050        assert_eq!(result, EventType::SnsEvent);
1051    }
1052
1053    #[test]
1054    fn test_low_confidence_threshold() {
1055        let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.1);
1056        let python_code = r#"
1057def handler(event, context):
1058    data = event['Records']
1059    return data
1060"#;
1061        let ast = parse_python(python_code);
1062        let result = inferencer.infer_event_type(&ast);
1063        // With low threshold, may succeed, be ambiguous, or have no pattern match
1064        assert!(
1065            result.is_ok()
1066                || matches!(result, Err(InferenceError::AmbiguousEventType))
1067                || matches!(result, Err(InferenceError::NoPatternMatch))
1068        );
1069    }
1070
1071    #[test]
1072    fn test_very_high_confidence_threshold() {
1073        let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.99);
1074        let python_code = r#"
1075def handler(event, context):
1076    bucket = event['Records'][0]['s3']['bucket']['name']
1077    return bucket
1078"#;
1079        let ast = parse_python(python_code);
1080        let result = inferencer.infer_event_type(&ast);
1081        assert!(result.is_ok() || matches!(result, Err(InferenceError::AmbiguousEventType)));
1082    }
1083}