1pub mod pattern_extraction;
7
8use anyhow::Result;
9use rustpython_ast::{Mod, ModModule};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13use pattern_extraction::extract_access_patterns;
14
/// Infers which event type a handler consumes from the access patterns
/// extracted out of its Python AST.
///
/// Each extracted pattern is compared against a registry of known access
/// chains; matches are scored and the best-scoring event type is reported when
/// its score exceeds [`confidence_threshold`](Self::confidence_threshold).
#[derive(Debug, Clone)]
pub struct LambdaTypeInferencer {
    /// Registered access-chain signatures, each mapped to the event type it identifies.
    event_patterns: HashMap<Pattern, EventType>,
    /// Score an event type must strictly exceed to be reported as the inferred type.
    pub confidence_threshold: f64,
}
22
/// An ordered chain of keys read from an event (for example
/// `["Records", "s3", "bucket"]`) together with the style of access used.
///
/// Used both as the registry key for known event signatures and as the shape of
/// the patterns extracted from handler code.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Pattern {
    /// Keys in the order they are accessed, outermost first.
    pub access_chain: Vec<String>,
    /// How the keys in `access_chain` were reached.
    pub pattern_type: PatternType,
}
29
/// Style of access a [`Pattern`] was built from.
///
/// NOTE(review): the classification itself is assigned in `pattern_extraction`,
/// which is not visible here; the per-variant descriptions below are inferred
/// from the variant names and should be confirmed against that module.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PatternType {
    /// Presumably bracket indexing such as `event['key']` — TODO confirm.
    Subscript,
    /// Presumably attribute access such as `event.key` — TODO confirm.
    Attribute,
    /// Presumably a chain combining several access styles — TODO confirm.
    /// A registered pattern of this type earns the type bonus against any
    /// observed pattern type during confidence scoring.
    Mixed,
}
36
/// Event types the inferencer can report for a handler.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum EventType {
    /// Registered for chains starting with `Records` / `s3`.
    S3Event,
    /// Registered for chains starting with `requestContext` / `http`.
    ApiGatewayV2Http,
    /// Registered for chains starting with `Records` / `Sns`.
    SnsEvent,
    /// Registered for `Records` / `messageId` and `Records` / `receiptHandle`.
    SqsEvent,
    /// Registered for chains starting with `Records` / `dynamodb`.
    DynamodbEvent,
    /// Registered for the top-level keys `detail-type` and `detail`.
    EventBridge,
    /// No built-in pattern in `LambdaTypeInferencer::new` maps to this variant.
    Cloudwatch,
    /// Reported by `analyze_handler` when no event type scores above the
    /// confidence threshold.
    Unknown,
}
49
/// Reasons event-type inference can fail.
#[derive(Debug, Clone)]
pub enum InferenceError {
    /// Patterns matched, but the best score did not exceed the confidence threshold.
    AmbiguousEventType,
    /// No access pattern was extracted, or none matched a registered pattern.
    NoPatternMatch,
    /// The input could not be analysed; the payload describes why.
    ParseError(String),
}

impl std::fmt::Display for InferenceError {
    /// Writes the human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only `ParseError` carries data; the other variants map to fixed text.
        let fixed_message = match self {
            Self::ParseError(detail) => return write!(f, "Parse error: {detail}"),
            Self::AmbiguousEventType => {
                "Could not determine event type with sufficient confidence"
            }
            Self::NoPatternMatch => "No matching event pattern found",
        };
        f.write_str(fixed_message)
    }
}

impl std::error::Error for InferenceError {}
72
73impl Default for LambdaTypeInferencer {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl LambdaTypeInferencer {
80 pub fn new() -> Self {
81 let mut patterns = HashMap::new();
82
83 patterns.insert(
85 Pattern {
86 access_chain: vec!["Records".to_string(), "s3".to_string()],
87 pattern_type: PatternType::Mixed,
88 },
89 EventType::S3Event,
90 );
91 patterns.insert(
92 Pattern {
93 access_chain: vec![
94 "Records".to_string(),
95 "s3".to_string(),
96 "bucket".to_string(),
97 ],
98 pattern_type: PatternType::Mixed,
99 },
100 EventType::S3Event,
101 );
102 patterns.insert(
103 Pattern {
104 access_chain: vec![
105 "Records".to_string(),
106 "s3".to_string(),
107 "object".to_string(),
108 ],
109 pattern_type: PatternType::Mixed,
110 },
111 EventType::S3Event,
112 );
113
114 patterns.insert(
116 Pattern {
117 access_chain: vec!["requestContext".to_string(), "http".to_string()],
118 pattern_type: PatternType::Mixed,
119 },
120 EventType::ApiGatewayV2Http,
121 );
122 patterns.insert(
123 Pattern {
124 access_chain: vec![
125 "requestContext".to_string(),
126 "http".to_string(),
127 "method".to_string(),
128 ],
129 pattern_type: PatternType::Mixed,
130 },
131 EventType::ApiGatewayV2Http,
132 );
133
134 patterns.insert(
136 Pattern {
137 access_chain: vec!["Records".to_string(), "Sns".to_string()],
138 pattern_type: PatternType::Mixed,
139 },
140 EventType::SnsEvent,
141 );
142 patterns.insert(
143 Pattern {
144 access_chain: vec![
145 "Records".to_string(),
146 "Sns".to_string(),
147 "Message".to_string(),
148 ],
149 pattern_type: PatternType::Mixed,
150 },
151 EventType::SnsEvent,
152 );
153
154 patterns.insert(
156 Pattern {
157 access_chain: vec!["Records".to_string(), "messageId".to_string()],
158 pattern_type: PatternType::Mixed,
159 },
160 EventType::SqsEvent,
161 );
162 patterns.insert(
163 Pattern {
164 access_chain: vec!["Records".to_string(), "receiptHandle".to_string()],
165 pattern_type: PatternType::Mixed,
166 },
167 EventType::SqsEvent,
168 );
169
170 patterns.insert(
172 Pattern {
173 access_chain: vec!["Records".to_string(), "dynamodb".to_string()],
174 pattern_type: PatternType::Mixed,
175 },
176 EventType::DynamodbEvent,
177 );
178
179 patterns.insert(
181 Pattern {
182 access_chain: vec!["detail-type".to_string()],
183 pattern_type: PatternType::Subscript,
184 },
185 EventType::EventBridge,
186 );
187 patterns.insert(
188 Pattern {
189 access_chain: vec!["detail".to_string()],
190 pattern_type: PatternType::Subscript,
191 },
192 EventType::EventBridge,
193 );
194
195 Self {
196 event_patterns: patterns,
197 confidence_threshold: 0.8,
198 }
199 }
200
201 pub fn with_confidence_threshold(mut self, threshold: f64) -> Self {
202 self.confidence_threshold = threshold;
203 self
204 }
205
206 pub fn infer_event_type(&self, ast: &Mod) -> Result<EventType, InferenceError> {
208 match ast {
209 Mod::Module(module) => self.infer_from_module(module),
210 _ => Err(InferenceError::ParseError(
211 "Only module AST supported".to_string(),
212 )),
213 }
214 }
215
216 fn infer_from_module(&self, module: &ModModule) -> Result<EventType, InferenceError> {
217 let patterns = extract_access_patterns(&module.body)?;
218
219 if patterns.is_empty() {
220 return Err(InferenceError::NoPatternMatch);
221 }
222
223 let matches: Vec<(EventType, f64)> = patterns
224 .iter()
225 .filter_map(|p| self.match_pattern(p))
226 .collect();
227
228 if matches.is_empty() {
229 return Err(InferenceError::NoPatternMatch);
230 }
231
232 let event_scores = self.calculate_confidence_scores(&matches);
234
235 event_scores
236 .into_iter()
237 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
238 .filter(|(_, conf)| *conf > self.confidence_threshold)
239 .map(|(event_type, _)| event_type)
240 .ok_or(InferenceError::AmbiguousEventType)
241 }
242
243 fn match_pattern(&self, pattern: &Pattern) -> Option<(EventType, f64)> {
244 for (registered_pattern, event_type) in &self.event_patterns {
245 let confidence = self.calculate_pattern_confidence(pattern, registered_pattern);
246 if confidence > 0.0 {
247 return Some((event_type.clone(), confidence));
248 }
249 }
250 None
251 }
252
253 fn calculate_pattern_confidence(&self, observed: &Pattern, registered: &Pattern) -> f64 {
254 if observed.access_chain.len() < registered.access_chain.len() {
256 return 0.0;
257 }
258
259 let mut all_match = true;
261 for (i, expected_key) in registered.access_chain.iter().enumerate() {
262 if i >= observed.access_chain.len() || observed.access_chain[i] != *expected_key {
263 all_match = false;
264 break;
265 }
266 }
267
268 if !all_match {
269 return 0.0;
270 }
271
272 let base_confidence = 0.8;
274
275 let length_bonus = if observed.access_chain.len() == registered.access_chain.len() {
277 0.1
278 } else {
279 0.0
280 };
281
282 let specificity_bonus = (registered.access_chain.len() as f64 / 20.0).min(0.1);
284
285 let type_bonus = if observed.pattern_type == registered.pattern_type
287 || registered.pattern_type == PatternType::Mixed
288 {
289 0.05
290 } else {
291 0.0
292 };
293
294 (base_confidence + length_bonus + specificity_bonus + type_bonus).min(1.0)
295 }
296
297 fn calculate_confidence_scores(&self, matches: &[(EventType, f64)]) -> Vec<(EventType, f64)> {
298 let mut event_scores: HashMap<EventType, Vec<f64>> = HashMap::new();
299
300 for (event_type, confidence) in matches {
301 event_scores
302 .entry(event_type.clone())
303 .or_default()
304 .push(*confidence);
305 }
306
307 event_scores
308 .into_iter()
309 .map(|(event_type, confidences)| {
310 let max_confidence = confidences.iter().copied().fold(0.0f64, f64::max);
312 let avg_confidence = confidences.iter().sum::<f64>() / confidences.len() as f64;
313 let final_confidence = max_confidence + (avg_confidence * 0.1);
314 (event_type, final_confidence.min(1.0))
315 })
316 .collect()
317 }
318
319 pub fn get_patterns(&self) -> &HashMap<Pattern, EventType> {
321 &self.event_patterns
322 }
323
324 pub fn analyze_handler(&self, ast: &Mod) -> Result<AnalysisReport, InferenceError> {
326 let patterns = match ast {
327 Mod::Module(module) => extract_access_patterns(&module.body)?,
328 _ => {
329 return Err(InferenceError::ParseError(
330 "Only module AST supported".to_string(),
331 ))
332 }
333 };
334
335 let matches: Vec<(EventType, f64)> = patterns
336 .iter()
337 .filter_map(|p| self.match_pattern(p))
338 .collect();
339
340 let event_scores = self.calculate_confidence_scores(&matches);
341 let inferred_type = event_scores
342 .iter()
343 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
344 .filter(|(_, conf)| *conf > self.confidence_threshold)
345 .map(|(event_type, _)| event_type.clone())
346 .unwrap_or(EventType::Unknown);
347
348 let recommendations = self.generate_recommendations(&patterns);
349 Ok(AnalysisReport {
350 inferred_event_type: inferred_type,
351 detected_patterns: patterns,
352 confidence_scores: event_scores,
353 recommendations,
354 })
355 }
356
357 fn generate_recommendations(&self, patterns: &[Pattern]) -> Vec<String> {
358 let mut recommendations = Vec::new();
359
360 if patterns.is_empty() {
361 recommendations.push(
362 "No event access patterns detected. Consider adding event type annotation."
363 .to_string(),
364 );
365 } else if patterns.len() == 1 {
366 recommendations.push("Single access pattern detected. Consider adding more specific event access for better inference.".to_string());
367 }
368
369 let has_generic_access = patterns.iter().any(|p| {
371 p.access_chain.len() == 1
372 && (p.access_chain[0] == "body" || p.access_chain[0] == "headers")
373 });
374
375 if has_generic_access {
376 recommendations.push("Generic event access detected. Use more specific patterns like event['requestContext']['http'] for API Gateway.".to_string());
377 }
378
379 recommendations
380 }
381}
382
/// Result of `LambdaTypeInferencer::analyze_handler`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnalysisReport {
    /// Best-scoring event type, or `EventType::Unknown` when no score exceeds
    /// the confidence threshold.
    pub inferred_event_type: EventType,
    /// Access patterns extracted from the handler.
    pub detected_patterns: Vec<Pattern>,
    /// One `(event type, score)` pair for every event type that had a match.
    pub confidence_scores: Vec<(EventType, f64)>,
    /// Hints for making the handler easier to classify.
    pub recommendations: Vec<String>,
}
391
392#[cfg(test)]
393mod tests {
394 use super::*;
395 use rustpython_parser::Parse;
396
397 fn parse_python(source: &str) -> Mod {
398 rustpython_ast::Suite::parse(source, "<test>")
399 .map(|statements| {
400 Mod::Module(ModModule {
401 body: statements,
402 type_ignores: vec![],
403 range: Default::default(),
404 })
405 })
406 .unwrap()
407 }
408
409 #[test]
410 fn test_s3_event_inference() {
411 let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
412 let python_code = r#"
413def handler(event, context):
414 bucket = event['Records'][0]['s3']['bucket']['name']
415 key = event['Records'][0]['s3']['object']['key']
416 return {'status': 'processed'}
417"#;
418 let ast = parse_python(python_code);
419 let result = inferencer.infer_event_type(&ast).unwrap();
420 assert_eq!(result, EventType::S3Event);
421 }
422
423 #[test]
424 fn test_api_gateway_v2_inference() {
425 let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.1);
426 let python_code = r#"
427def handler(event, context):
428 method = event['requestContext']['http']['method']
429 path = event['requestContext']['http']['path']
430 return {'statusCode': 200}
431"#;
432 let ast = parse_python(python_code);
433 let result = inferencer.infer_event_type(&ast).unwrap();
434 assert!(matches!(
435 result,
436 EventType::ApiGatewayV2Http
437 | EventType::SqsEvent
438 | EventType::EventBridge
439 | EventType::S3Event
440 | EventType::SnsEvent
441 | EventType::DynamodbEvent
442 | EventType::Cloudwatch
443 | EventType::Unknown
444 ));
445 }
446
447 #[test]
448 fn test_sqs_event_inference() {
449 let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.1);
450 let python_code = r#"
451def handler(event, context):
452 for record in event['Records']:
453 message_id = record['messageId']
454 body = record['body']
455 return {'batchItemFailures': []}
456"#;
457 let ast = parse_python(python_code);
458
459 match inferencer.infer_event_type(&ast) {
460 Ok(event_type) => {
461 assert!(matches!(
462 event_type,
463 EventType::SqsEvent
464 | EventType::EventBridge
465 | EventType::SnsEvent
466 | EventType::S3Event
467 | EventType::DynamodbEvent
468 ));
469 }
470 Err(InferenceError::NoPatternMatch) => {}
471 Err(e) => panic!("Unexpected error: {e:?}"),
472 }
473 }
474
475 #[test]
476 fn test_eventbridge_inference() {
477 let inferencer = LambdaTypeInferencer::new();
478 let python_code = r#"
479def handler(event, context):
480 detail_type = event['detail-type']
481 detail = event['detail']
482 return None
483"#;
484 let ast = parse_python(python_code);
485 let result = inferencer.infer_event_type(&ast).unwrap();
486 assert_eq!(result, EventType::EventBridge);
487 }
488
489 #[test]
490 fn test_no_pattern_match() {
491 let inferencer = LambdaTypeInferencer::new();
492 let python_code = r#"
493def handler(event, context):
494 return {'message': 'hello world'}
495"#;
496 let ast = parse_python(python_code);
497 let result = inferencer.infer_event_type(&ast);
498 assert!(matches!(result, Err(InferenceError::NoPatternMatch)));
499 }
500
501 #[test]
502 fn test_confidence_threshold() {
503 let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.95);
504 let python_code = r#"
505def handler(event, context):
506 data = event['Records']
507 return {'status': 'ok'}
508"#;
509 let ast = parse_python(python_code);
510 let result = inferencer.infer_event_type(&ast);
511 assert!(result.is_err());
512 }
513
514 #[test]
515 fn test_analysis_report() {
516 let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
517 let python_code = r#"
518def handler(event, context):
519 bucket = event['Records'][0]['s3']['bucket']['name']
520 return {'processed': bucket}
521"#;
522 let ast = parse_python(python_code);
523 let report = inferencer.analyze_handler(&ast).unwrap();
524
525 assert!(matches!(
526 report.inferred_event_type,
527 EventType::S3Event
528 | EventType::SqsEvent
529 | EventType::SnsEvent
530 | EventType::DynamodbEvent
531 | EventType::Unknown
532 ));
533 assert!(!report.detected_patterns.is_empty());
534 assert!(!report.confidence_scores.is_empty());
535 }
536
537 #[test]
538 fn test_pattern_confidence_calculation() {
539 let inferencer = LambdaTypeInferencer::new();
540
541 let observed = Pattern {
542 access_chain: vec![
543 "Records".to_string(),
544 "s3".to_string(),
545 "bucket".to_string(),
546 ],
547 pattern_type: PatternType::Mixed,
548 };
549
550 let registered = Pattern {
551 access_chain: vec!["Records".to_string(), "s3".to_string()],
552 pattern_type: PatternType::Mixed,
553 };
554
555 let confidence = inferencer.calculate_pattern_confidence(&observed, ®istered);
556 assert!(confidence > 0.9);
557 }
558
559 #[test]
560 fn test_mixed_pattern_types() {
561 let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
562 let python_code = r#"
563def handler(event, context):
564 record = event['Records'][0]
565 sns_message = record['Sns']['Message']
566 sns_subject = record['Sns']['Subject']
567 return {'message': sns_message}
568"#;
569 let ast = parse_python(python_code);
570
571 let result = inferencer.infer_event_type(&ast);
572 match result {
573 Ok(event_type) => {
574 assert!(matches!(
575 event_type,
576 EventType::SnsEvent
577 | EventType::SqsEvent
578 | EventType::S3Event
579 | EventType::EventBridge
580 | EventType::DynamodbEvent
581 ));
582 }
583 Err(InferenceError::AmbiguousEventType) | Err(InferenceError::NoPatternMatch) => {}
584 Err(e) => panic!("Unexpected error: {e:?}"),
585 }
586 }
587
588 #[test]
589 fn test_numeric_index_handling() {
590 let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
591 let python_code = r#"
592def handler(event, context):
593 bucket = event['Records'][0]['s3']['bucket']['name']
594 key = event['Records'][0]['s3']['object']['key']
595 return {'bucket': bucket, 'key': key}
596"#;
597 let ast = parse_python(python_code);
598 let result = inferencer.infer_event_type(&ast).unwrap();
599 assert_eq!(result, EventType::S3Event);
600 }
601
602 #[test]
603 fn test_pattern_type_equality() {
604 assert_eq!(PatternType::Subscript, PatternType::Subscript);
605 assert_eq!(PatternType::Attribute, PatternType::Attribute);
606 assert_eq!(PatternType::Mixed, PatternType::Mixed);
607 assert_ne!(PatternType::Subscript, PatternType::Attribute);
608 }
609
610 #[test]
611 fn test_event_type_equality() {
612 assert_eq!(EventType::S3Event, EventType::S3Event);
613 assert_eq!(EventType::SqsEvent, EventType::SqsEvent);
614 assert_ne!(EventType::S3Event, EventType::SqsEvent);
615 }
616
617 #[test]
618 fn test_event_type_hash() {
619 use std::collections::HashSet;
620 let mut set = HashSet::new();
621 set.insert(EventType::S3Event);
622 set.insert(EventType::SqsEvent);
623 set.insert(EventType::SnsEvent);
624 set.insert(EventType::DynamodbEvent);
625 set.insert(EventType::ApiGatewayV2Http);
626 set.insert(EventType::EventBridge);
627 set.insert(EventType::Cloudwatch);
628 set.insert(EventType::Unknown);
629 assert_eq!(set.len(), 8);
630 }
631
632 #[test]
633 fn test_pattern_struct() {
634 let pattern = Pattern {
635 access_chain: vec!["Records".to_string(), "s3".to_string()],
636 pattern_type: PatternType::Mixed,
637 };
638 assert_eq!(pattern.access_chain.len(), 2);
639 assert_eq!(pattern.pattern_type, PatternType::Mixed);
640 }
641
642 #[test]
643 fn test_pattern_hash() {
644 use std::collections::HashSet;
645 let mut set = HashSet::new();
646 set.insert(Pattern {
647 access_chain: vec!["Records".to_string()],
648 pattern_type: PatternType::Mixed,
649 });
650 set.insert(Pattern {
651 access_chain: vec!["detail".to_string()],
652 pattern_type: PatternType::Subscript,
653 });
654 assert_eq!(set.len(), 2);
655 }
656
657 #[test]
658 fn test_inference_error_display_ambiguous() {
659 let error = InferenceError::AmbiguousEventType;
660 let display = format!("{}", error);
661 assert!(display.contains("confidence"));
662 }
663
664 #[test]
665 fn test_inference_error_display_no_match() {
666 let error = InferenceError::NoPatternMatch;
667 let display = format!("{}", error);
668 assert!(display.contains("No matching"));
669 }
670
671 #[test]
672 fn test_inference_error_display_parse_error() {
673 let error = InferenceError::ParseError("test error".to_string());
674 let display = format!("{}", error);
675 assert!(display.contains("Parse error"));
676 assert!(display.contains("test error"));
677 }
678
679 #[test]
680 fn test_inference_error_is_error() {
681 let error: Box<dyn std::error::Error> = Box::new(InferenceError::NoPatternMatch);
682 assert!(error.to_string().contains("No matching"));
683 }
684
685 #[test]
686 fn test_lambda_type_inferencer_default() {
687 let inferencer = LambdaTypeInferencer::default();
688 assert!(!inferencer.get_patterns().is_empty());
689 }
690
691 #[test]
692 fn test_lambda_type_inferencer_new() {
693 let inferencer = LambdaTypeInferencer::new();
694 assert!(inferencer.get_patterns().len() > 5);
695 }
696
697 #[test]
698 fn test_with_confidence_threshold_chaining() {
699 let inferencer = LambdaTypeInferencer::new()
700 .with_confidence_threshold(0.5)
701 .with_confidence_threshold(0.9);
702 assert!((inferencer.confidence_threshold - 0.9).abs() < 0.001);
703 }
704
705 #[test]
706 fn test_get_patterns_returns_registered() {
707 let inferencer = LambdaTypeInferencer::new();
708 let patterns = inferencer.get_patterns();
709
710 assert!(patterns.values().any(|e| *e == EventType::S3Event));
711 assert!(patterns.values().any(|e| *e == EventType::SqsEvent));
712 assert!(patterns.values().any(|e| *e == EventType::EventBridge));
713 }
714
715 #[test]
716 fn test_calculate_pattern_confidence_no_match() {
717 let inferencer = LambdaTypeInferencer::new();
718 let observed = Pattern {
719 access_chain: vec!["foo".to_string()],
720 pattern_type: PatternType::Mixed,
721 };
722 let registered = Pattern {
723 access_chain: vec!["bar".to_string(), "baz".to_string()],
724 pattern_type: PatternType::Mixed,
725 };
726 let confidence = inferencer.calculate_pattern_confidence(&observed, ®istered);
727 assert_eq!(confidence, 0.0);
728 }
729
730 #[test]
731 fn test_calculate_pattern_confidence_partial_match() {
732 let inferencer = LambdaTypeInferencer::new();
733 let observed = Pattern {
734 access_chain: vec![
735 "Records".to_string(),
736 "s3".to_string(),
737 "bucket".to_string(),
738 ],
739 pattern_type: PatternType::Mixed,
740 };
741 let registered = Pattern {
742 access_chain: vec!["Records".to_string(), "s3".to_string()],
743 pattern_type: PatternType::Mixed,
744 };
745 let confidence = inferencer.calculate_pattern_confidence(&observed, ®istered);
746 assert!(confidence > 0.8);
747 }
748
749 #[test]
750 fn test_calculate_pattern_confidence_exact_match() {
751 let inferencer = LambdaTypeInferencer::new();
752 let pattern = Pattern {
753 access_chain: vec!["Records".to_string(), "s3".to_string()],
754 pattern_type: PatternType::Mixed,
755 };
756 let confidence = inferencer.calculate_pattern_confidence(&pattern, &pattern);
757 assert!(confidence > 0.9);
758 }
759
760 #[test]
761 fn test_calculate_pattern_confidence_type_bonus() {
762 let inferencer = LambdaTypeInferencer::new();
763 let observed = Pattern {
764 access_chain: vec!["detail".to_string()],
765 pattern_type: PatternType::Subscript,
766 };
767 let registered = Pattern {
768 access_chain: vec!["detail".to_string()],
769 pattern_type: PatternType::Subscript,
770 };
771 let confidence = inferencer.calculate_pattern_confidence(&observed, ®istered);
772 assert!(confidence > 0.85);
773 }
774
775 #[test]
776 fn test_infer_event_type_non_module() {
777 let inferencer = LambdaTypeInferencer::new();
778 let ast = Mod::Expression(rustpython_ast::ModExpression {
779 body: Box::new(rustpython_ast::Expr::Constant(
780 rustpython_ast::ExprConstant {
781 value: rustpython_ast::Constant::Int(42.into()),
782 kind: None,
783 range: Default::default(),
784 },
785 )),
786 range: Default::default(),
787 });
788 let result = inferencer.infer_event_type(&ast);
789 assert!(matches!(result, Err(InferenceError::ParseError(_))));
790 }
791
792 #[test]
793 fn test_analyze_handler_non_module() {
794 let inferencer = LambdaTypeInferencer::new();
795 let ast = Mod::Expression(rustpython_ast::ModExpression {
796 body: Box::new(rustpython_ast::Expr::Constant(
797 rustpython_ast::ExprConstant {
798 value: rustpython_ast::Constant::Int(42.into()),
799 kind: None,
800 range: Default::default(),
801 },
802 )),
803 range: Default::default(),
804 });
805 let result = inferencer.analyze_handler(&ast);
806 assert!(matches!(result, Err(InferenceError::ParseError(_))));
807 }
808
809 #[test]
810 fn test_empty_handler() {
811 let inferencer = LambdaTypeInferencer::new();
812 let python_code = r#"
813def handler(event, context):
814 pass
815"#;
816 let ast = parse_python(python_code);
817 let result = inferencer.infer_event_type(&ast);
818 assert!(matches!(result, Err(InferenceError::NoPatternMatch)));
819 }
820
821 #[test]
822 fn test_handler_with_no_event_access() {
823 let inferencer = LambdaTypeInferencer::new();
824 let python_code = r#"
825def handler(event, context):
826 x = 1 + 2
827 return {'result': x}
828"#;
829 let ast = parse_python(python_code);
830 let result = inferencer.infer_event_type(&ast);
831 assert!(matches!(result, Err(InferenceError::NoPatternMatch)));
832 }
833
834 #[test]
835 fn test_analysis_report_empty_patterns() {
836 let inferencer = LambdaTypeInferencer::new();
837 let python_code = r#"
838def handler(event, context):
839 return 'hello'
840"#;
841 let ast = parse_python(python_code);
842 let report = inferencer.analyze_handler(&ast).unwrap();
843 assert!(report.detected_patterns.is_empty());
844 assert_eq!(report.inferred_event_type, EventType::Unknown);
845 assert!(!report.recommendations.is_empty());
846 }
847
848 #[test]
849 fn test_analysis_report_single_pattern() {
850 let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
851 let python_code = r#"
852def handler(event, context):
853 detail = event['detail']
854 return detail
855"#;
856 let ast = parse_python(python_code);
857 let report = inferencer.analyze_handler(&ast).unwrap();
858 assert!(!report.detected_patterns.is_empty());
859 assert!(report.recommendations.iter().any(|r| r.contains("Single")));
860 }
861
862 #[test]
863 fn test_generic_access_recommendation() {
864 let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.1);
865 let python_code = r#"
866def handler(event, context):
867 body = event['body']
868 return body
869"#;
870 let ast = parse_python(python_code);
871 let report = inferencer.analyze_handler(&ast).unwrap();
872 assert!(report
873 .recommendations
874 .iter()
875 .any(|r| r.contains("Generic") || r.contains("Single")));
876 }
877
878 #[test]
879 fn test_headers_generic_access() {
880 let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.1);
881 let python_code = r#"
882def handler(event, context):
883 headers = event['headers']
884 return headers
885"#;
886 let ast = parse_python(python_code);
887 let report = inferencer.analyze_handler(&ast).unwrap();
888 assert!(!report.recommendations.is_empty());
889 }
890
891 #[test]
892 fn test_dynamodb_event_detection() {
893 let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
894 let python_code = r#"
895def handler(event, context):
896 records = event['Records']
897 for record in records:
898 dynamodb = record['dynamodb']
899 return None
900"#;
901 let ast = parse_python(python_code);
902 let result = inferencer.infer_event_type(&ast);
903 match result {
904 Ok(event_type) => {
905 assert!(matches!(
906 event_type,
907 EventType::DynamodbEvent
908 | EventType::S3Event
909 | EventType::SqsEvent
910 | EventType::SnsEvent
911 ));
912 }
913 Err(InferenceError::NoPatternMatch) | Err(InferenceError::AmbiguousEventType) => {}
914 Err(e) => panic!("Unexpected error: {e:?}"),
915 }
916 }
917
918 #[test]
919 fn test_cloudwatch_patterns() {
920 let event_type = EventType::Cloudwatch;
921 assert_eq!(event_type.clone(), EventType::Cloudwatch);
922 }
923
924 #[test]
925 fn test_pattern_serialization() {
926 let pattern = Pattern {
927 access_chain: vec!["Records".to_string(), "s3".to_string()],
928 pattern_type: PatternType::Mixed,
929 };
930 let json = serde_json::to_string(&pattern).unwrap();
931 assert!(json.contains("Records"));
932 assert!(json.contains("Mixed"));
933 }
934
935 #[test]
936 fn test_pattern_deserialization() {
937 let json = r#"{"access_chain":["Records","s3"],"pattern_type":"Mixed"}"#;
938 let pattern: Pattern = serde_json::from_str(json).unwrap();
939 assert_eq!(pattern.access_chain.len(), 2);
940 assert_eq!(pattern.pattern_type, PatternType::Mixed);
941 }
942
943 #[test]
944 fn test_event_type_serialization() {
945 let event_type = EventType::S3Event;
946 let json = serde_json::to_string(&event_type).unwrap();
947 assert!(json.contains("S3Event"));
948 }
949
950 #[test]
951 fn test_event_type_deserialization() {
952 let json = r#""SqsEvent""#;
953 let event_type: EventType = serde_json::from_str(json).unwrap();
954 assert_eq!(event_type, EventType::SqsEvent);
955 }
956
957 #[test]
958 fn test_analysis_report_serialization() {
959 let report = AnalysisReport {
960 inferred_event_type: EventType::S3Event,
961 detected_patterns: vec![Pattern {
962 access_chain: vec!["Records".to_string()],
963 pattern_type: PatternType::Mixed,
964 }],
965 confidence_scores: vec![(EventType::S3Event, 0.9)],
966 recommendations: vec!["Test recommendation".to_string()],
967 };
968 let json = serde_json::to_string(&report).unwrap();
969 assert!(json.contains("S3Event"));
970 assert!(json.contains("recommendations"));
971 }
972
973 #[test]
974 fn test_if_statement_pattern_extraction() {
975 let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
976 let python_code = r#"
977def handler(event, context):
978 if event['Records'][0]['s3']:
979 return 'S3'
980 else:
981 return 'Other'
982"#;
983 let ast = parse_python(python_code);
984 let result = inferencer.infer_event_type(&ast).unwrap();
985 assert_eq!(result, EventType::S3Event);
986 }
987
988 #[test]
989 fn test_return_statement_pattern_extraction() {
990 let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
991 let python_code = r#"
992def handler(event, context):
993 return event['Records'][0]['s3']['bucket']['name']
994"#;
995 let ast = parse_python(python_code);
996 let result = inferencer.infer_event_type(&ast).unwrap();
997 assert_eq!(result, EventType::S3Event);
998 }
999
1000 #[test]
1001 fn test_annotated_assignment_pattern_extraction() {
1002 let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
1003 let python_code = r#"
1004def handler(event, context):
1005 bucket: str = event['Records'][0]['s3']['bucket']['name']
1006 return bucket
1007"#;
1008 let ast = parse_python(python_code);
1009 let result = inferencer.infer_event_type(&ast).unwrap();
1010 assert_eq!(result, EventType::S3Event);
1011 }
1012
1013 #[test]
1014 fn test_call_expression_pattern_extraction() {
1015 let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
1016 let python_code = r#"
1017def handler(event, context):
1018 process(event['Records'][0]['s3']['bucket']['name'])
1019 return 'done'
1020"#;
1021 let ast = parse_python(python_code);
1022 let result = inferencer.infer_event_type(&ast).unwrap();
1023 assert_eq!(result, EventType::S3Event);
1024 }
1025
1026 #[test]
1027 fn test_multiple_event_types_detected() {
1028 let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.1);
1029 let python_code = r#"
1030def handler(event, context):
1031 records = event['Records']
1032 detail = event['detail']
1033 return records
1034"#;
1035 let ast = parse_python(python_code);
1036 let report = inferencer.analyze_handler(&ast).unwrap();
1037 assert!(!report.confidence_scores.is_empty());
1038 }
1039
1040 #[test]
1041 fn test_sns_message_pattern() {
1042 let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
1043 let python_code = r#"
1044def handler(event, context):
1045 message = event['Records'][0]['Sns']['Message']
1046 return message
1047"#;
1048 let ast = parse_python(python_code);
1049 let result = inferencer.infer_event_type(&ast).unwrap();
1050 assert_eq!(result, EventType::SnsEvent);
1051 }
1052
1053 #[test]
1054 fn test_low_confidence_threshold() {
1055 let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.1);
1056 let python_code = r#"
1057def handler(event, context):
1058 data = event['Records']
1059 return data
1060"#;
1061 let ast = parse_python(python_code);
1062 let result = inferencer.infer_event_type(&ast);
1063 assert!(
1065 result.is_ok()
1066 || matches!(result, Err(InferenceError::AmbiguousEventType))
1067 || matches!(result, Err(InferenceError::NoPatternMatch))
1068 );
1069 }
1070
1071 #[test]
1072 fn test_very_high_confidence_threshold() {
1073 let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.99);
1074 let python_code = r#"
1075def handler(event, context):
1076 bucket = event['Records'][0]['s3']['bucket']['name']
1077 return bucket
1078"#;
1079 let ast = parse_python(python_code);
1080 let result = inferencer.infer_event_type(&ast);
1081 assert!(result.is_ok() || matches!(result, Err(InferenceError::AmbiguousEventType)));
1082 }
1083}