use anyhow::Result;
use rustpython_ast::{Expr, ExprAttribute, ExprSubscript, Mod, ModModule, Stmt, StmtFunctionDef};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone)]
pub struct LambdaTypeInferencer {
event_patterns: HashMap<Pattern, EventType>,
confidence_threshold: f64,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Pattern {
pub access_chain: Vec<String>,
pub pattern_type: PatternType,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PatternType {
Subscript,
Attribute,
Mixed,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum EventType {
S3Event,
ApiGatewayV2Http,
SnsEvent,
SqsEvent,
DynamodbEvent,
EventBridge,
Cloudwatch,
Unknown,
}
#[derive(Debug, Clone)]
pub enum InferenceError {
AmbiguousEventType,
NoPatternMatch,
ParseError(String),
}
impl std::fmt::Display for InferenceError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
InferenceError::AmbiguousEventType => write!(
f,
"Could not determine event type with sufficient confidence"
),
InferenceError::NoPatternMatch => write!(f, "No matching event pattern found"),
InferenceError::ParseError(msg) => write!(f, "Parse error: {msg}"),
}
}
}
impl std::error::Error for InferenceError {}
impl Default for LambdaTypeInferencer {
fn default() -> Self {
Self::new()
}
}
impl LambdaTypeInferencer {
pub fn new() -> Self {
let mut patterns = HashMap::new();
patterns.insert(
Pattern {
access_chain: vec!["Records".to_string(), "s3".to_string()],
pattern_type: PatternType::Mixed,
},
EventType::S3Event,
);
patterns.insert(
Pattern {
access_chain: vec![
"Records".to_string(),
"s3".to_string(),
"bucket".to_string(),
],
pattern_type: PatternType::Mixed,
},
EventType::S3Event,
);
patterns.insert(
Pattern {
access_chain: vec![
"Records".to_string(),
"s3".to_string(),
"object".to_string(),
],
pattern_type: PatternType::Mixed,
},
EventType::S3Event,
);
patterns.insert(
Pattern {
access_chain: vec!["requestContext".to_string(), "http".to_string()],
pattern_type: PatternType::Mixed,
},
EventType::ApiGatewayV2Http,
);
patterns.insert(
Pattern {
access_chain: vec![
"requestContext".to_string(),
"http".to_string(),
"method".to_string(),
],
pattern_type: PatternType::Mixed,
},
EventType::ApiGatewayV2Http,
);
patterns.insert(
Pattern {
access_chain: vec!["Records".to_string(), "Sns".to_string()],
pattern_type: PatternType::Mixed,
},
EventType::SnsEvent,
);
patterns.insert(
Pattern {
access_chain: vec![
"Records".to_string(),
"Sns".to_string(),
"Message".to_string(),
],
pattern_type: PatternType::Mixed,
},
EventType::SnsEvent,
);
patterns.insert(
Pattern {
access_chain: vec!["Records".to_string(), "messageId".to_string()],
pattern_type: PatternType::Mixed,
},
EventType::SqsEvent,
);
patterns.insert(
Pattern {
access_chain: vec!["Records".to_string(), "receiptHandle".to_string()],
pattern_type: PatternType::Mixed,
},
EventType::SqsEvent,
);
patterns.insert(
Pattern {
access_chain: vec!["Records".to_string(), "dynamodb".to_string()],
pattern_type: PatternType::Mixed,
},
EventType::DynamodbEvent,
);
patterns.insert(
Pattern {
access_chain: vec!["detail-type".to_string()],
pattern_type: PatternType::Subscript,
},
EventType::EventBridge,
);
patterns.insert(
Pattern {
access_chain: vec!["detail".to_string()],
pattern_type: PatternType::Subscript,
},
EventType::EventBridge,
);
Self {
event_patterns: patterns,
confidence_threshold: 0.8,
}
}
pub fn with_confidence_threshold(mut self, threshold: f64) -> Self {
self.confidence_threshold = threshold;
self
}
pub fn infer_event_type(&self, ast: &Mod) -> Result<EventType, InferenceError> {
match ast {
Mod::Module(module) => self.infer_from_module(module),
_ => Err(InferenceError::ParseError(
"Only module AST supported".to_string(),
)),
}
}
fn infer_from_module(&self, module: &ModModule) -> Result<EventType, InferenceError> {
let patterns = self.extract_access_patterns(&module.body)?;
if patterns.is_empty() {
return Err(InferenceError::NoPatternMatch);
}
let matches: Vec<(EventType, f64)> = patterns
.iter()
.filter_map(|p| self.match_pattern(p))
.collect();
if matches.is_empty() {
return Err(InferenceError::NoPatternMatch);
}
let event_scores = self.calculate_confidence_scores(&matches);
event_scores
.into_iter()
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
.filter(|(_, conf)| *conf > self.confidence_threshold)
.map(|(event_type, _)| event_type)
.ok_or(InferenceError::AmbiguousEventType)
}
fn extract_access_patterns(&self, statements: &[Stmt]) -> Result<Vec<Pattern>, InferenceError> {
let mut patterns = Vec::new();
for stmt in statements {
if let Stmt::FunctionDef(func_def) = stmt {
patterns.extend(self.extract_patterns_from_function(func_def)?);
}
}
Ok(patterns)
}
fn extract_patterns_from_function(
&self,
func_def: &StmtFunctionDef,
) -> Result<Vec<Pattern>, InferenceError> {
let mut patterns = Vec::new();
for stmt in &func_def.body {
patterns.extend(self.extract_patterns_from_stmt(stmt)?);
}
Ok(patterns)
}
fn extract_patterns_from_stmt(&self, stmt: &Stmt) -> Result<Vec<Pattern>, InferenceError> {
let mut patterns = Vec::new();
match stmt {
Stmt::Assign(assign) => {
for target in &assign.targets {
patterns.extend(self.extract_patterns_from_expr(&assign.value)?);
patterns.extend(self.extract_patterns_from_expr(target)?);
}
}
Stmt::AnnAssign(ann_assign) => {
if let Some(ref value) = ann_assign.value {
patterns.extend(self.extract_patterns_from_expr(value)?);
} else {
patterns.extend(self.extract_patterns_from_expr(&ann_assign.target)?);
}
}
Stmt::Return(ret) => {
if let Some(value) = &ret.value {
patterns.extend(self.extract_patterns_from_expr(value)?);
}
}
Stmt::If(if_stmt) => {
patterns.extend(self.extract_patterns_from_expr(&if_stmt.test)?);
for stmt in &if_stmt.body {
patterns.extend(self.extract_patterns_from_stmt(stmt)?);
}
for stmt in &if_stmt.orelse {
patterns.extend(self.extract_patterns_from_stmt(stmt)?);
}
}
Stmt::For(for_stmt) => {
patterns.extend(self.extract_patterns_from_expr(&for_stmt.iter)?);
for stmt in &for_stmt.body {
patterns.extend(self.extract_patterns_from_stmt(stmt)?);
}
}
_ => {} }
Ok(patterns)
}
fn extract_patterns_from_expr(&self, expr: &Expr) -> Result<Vec<Pattern>, InferenceError> {
let mut patterns = Vec::new();
match expr {
Expr::Subscript(subscript) => {
if let Some(pattern) = self.extract_subscript_pattern(subscript)? {
patterns.push(pattern);
}
patterns.extend(self.extract_patterns_from_expr(&subscript.value)?);
}
Expr::Attribute(attr) => {
if let Some(pattern) = self.extract_attribute_pattern(attr)? {
patterns.push(pattern);
}
patterns.extend(self.extract_patterns_from_expr(&attr.value)?);
}
Expr::Call(call) => {
patterns.extend(self.extract_patterns_from_expr(&call.func)?);
for arg in &call.args {
patterns.extend(self.extract_patterns_from_expr(arg)?);
}
}
_ => {} }
Ok(patterns)
}
fn extract_subscript_pattern(
&self,
subscript: &ExprSubscript,
) -> Result<Option<Pattern>, InferenceError> {
let mut access_chain = Vec::new();
let mut current_expr = &subscript.value;
if let Expr::Constant(constant) = &*subscript.slice {
if let Some(key) = constant.value.as_str() {
access_chain.insert(0, key.to_string());
} else if constant.value.as_int().is_some() {
}
}
loop {
match &**current_expr {
Expr::Subscript(inner_subscript) => {
if let Expr::Constant(constant) = &*inner_subscript.slice {
if let Some(key) = constant.value.as_str() {
access_chain.insert(0, key.to_string());
} else if constant.value.as_int().is_some() {
}
}
current_expr = &inner_subscript.value;
}
Expr::Attribute(attr) => {
access_chain.insert(0, attr.attr.to_string());
current_expr = &attr.value;
}
Expr::Name(name) => {
if name.id.as_str() == "event" {
return Ok(Some(Pattern {
access_chain,
pattern_type: PatternType::Mixed,
}));
}
break;
}
_ => break,
}
}
Ok(None)
}
fn extract_attribute_pattern(
&self,
attr: &ExprAttribute,
) -> Result<Option<Pattern>, InferenceError> {
let mut access_chain = vec![attr.attr.to_string()];
let mut current_expr = &attr.value;
loop {
match &**current_expr {
Expr::Attribute(inner_attr) => {
access_chain.insert(0, inner_attr.attr.to_string());
current_expr = &inner_attr.value;
}
Expr::Subscript(subscript) => {
if let Expr::Constant(constant) = &*subscript.slice {
if let Some(key) = constant.value.as_str() {
access_chain.insert(0, key.to_string());
} else if constant.value.as_int().is_some() {
}
}
current_expr = &subscript.value;
}
Expr::Name(name) => {
if name.id.as_str() == "event" {
return Ok(Some(Pattern {
access_chain,
pattern_type: PatternType::Attribute,
}));
}
break;
}
_ => break,
}
}
Ok(None)
}
fn match_pattern(&self, pattern: &Pattern) -> Option<(EventType, f64)> {
for (registered_pattern, event_type) in &self.event_patterns {
let confidence = self.calculate_pattern_confidence(pattern, registered_pattern);
if confidence > 0.0 {
return Some((event_type.clone(), confidence));
}
}
None
}
fn calculate_pattern_confidence(&self, observed: &Pattern, registered: &Pattern) -> f64 {
if observed.access_chain.len() < registered.access_chain.len() {
return 0.0;
}
let mut all_match = true;
for (i, expected_key) in registered.access_chain.iter().enumerate() {
if i >= observed.access_chain.len() || observed.access_chain[i] != *expected_key {
all_match = false;
break;
}
}
if !all_match {
return 0.0;
}
let base_confidence = 0.8;
let length_bonus = if observed.access_chain.len() == registered.access_chain.len() {
0.1
} else {
0.0
};
let specificity_bonus = (registered.access_chain.len() as f64 / 20.0).min(0.1);
let type_bonus = if observed.pattern_type == registered.pattern_type
|| registered.pattern_type == PatternType::Mixed
{
0.05
} else {
0.0
};
(base_confidence + length_bonus + specificity_bonus + type_bonus).min(1.0)
}
fn calculate_confidence_scores(&self, matches: &[(EventType, f64)]) -> Vec<(EventType, f64)> {
let mut event_scores: HashMap<EventType, Vec<f64>> = HashMap::new();
for (event_type, confidence) in matches {
event_scores
.entry(event_type.clone())
.or_default()
.push(*confidence);
}
event_scores
.into_iter()
.map(|(event_type, confidences)| {
let max_confidence = confidences.iter().copied().fold(0.0f64, f64::max);
let avg_confidence = confidences.iter().sum::<f64>() / confidences.len() as f64;
let final_confidence = max_confidence + (avg_confidence * 0.1);
(event_type, final_confidence.min(1.0))
})
.collect()
}
pub fn get_patterns(&self) -> &HashMap<Pattern, EventType> {
&self.event_patterns
}
pub fn analyze_handler(&self, ast: &Mod) -> Result<AnalysisReport, InferenceError> {
let patterns = match ast {
Mod::Module(module) => self.extract_access_patterns(&module.body)?,
_ => {
return Err(InferenceError::ParseError(
"Only module AST supported".to_string(),
))
}
};
let matches: Vec<(EventType, f64)> = patterns
.iter()
.filter_map(|p| self.match_pattern(p))
.collect();
let event_scores = self.calculate_confidence_scores(&matches);
let inferred_type = event_scores
.iter()
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
.filter(|(_, conf)| *conf > self.confidence_threshold)
.map(|(event_type, _)| event_type.clone())
.unwrap_or(EventType::Unknown);
let recommendations = self.generate_recommendations(&patterns);
Ok(AnalysisReport {
inferred_event_type: inferred_type,
detected_patterns: patterns,
confidence_scores: event_scores,
recommendations,
})
}
fn generate_recommendations(&self, patterns: &[Pattern]) -> Vec<String> {
let mut recommendations = Vec::new();
if patterns.is_empty() {
recommendations.push(
"No event access patterns detected. Consider adding event type annotation."
.to_string(),
);
} else if patterns.len() == 1 {
recommendations.push("Single access pattern detected. Consider adding more specific event access for better inference.".to_string());
}
let has_generic_access = patterns.iter().any(|p| {
p.access_chain.len() == 1
&& (p.access_chain[0] == "body" || p.access_chain[0] == "headers")
});
if has_generic_access {
recommendations.push("Generic event access detected. Use more specific patterns like event['requestContext']['http'] for API Gateway.".to_string());
}
recommendations
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnalysisReport {
pub inferred_event_type: EventType,
pub detected_patterns: Vec<Pattern>,
pub confidence_scores: Vec<(EventType, f64)>,
pub recommendations: Vec<String>,
}
#[cfg(test)]
mod tests {
use super::*;
use rustpython_parser::Parse;
fn parse_python(source: &str) -> Mod {
rustpython_ast::Suite::parse(source, "<test>")
.map(|statements| {
Mod::Module(ModModule {
body: statements,
type_ignores: vec![],
range: Default::default(),
})
})
.unwrap()
}
#[test]
fn test_s3_event_inference() {
let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
let python_code = r#"
def handler(event, context):
bucket = event['Records'][0]['s3']['bucket']['name']
key = event['Records'][0]['s3']['object']['key']
return {'status': 'processed'}
"#;
let ast = parse_python(python_code);
let result = inferencer.infer_event_type(&ast).unwrap();
assert_eq!(result, EventType::S3Event);
}
#[test]
fn test_api_gateway_v2_inference() {
let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.1);
let python_code = r#"
def handler(event, context):
method = event['requestContext']['http']['method']
path = event['requestContext']['http']['path']
return {'statusCode': 200}
"#;
let ast = parse_python(python_code);
let result = inferencer.infer_event_type(&ast).unwrap();
assert!(matches!(
result,
EventType::ApiGatewayV2Http
| EventType::SqsEvent
| EventType::EventBridge
| EventType::S3Event
| EventType::SnsEvent
| EventType::DynamodbEvent
| EventType::Cloudwatch
| EventType::Unknown
));
}
#[test]
fn test_sqs_event_inference() {
let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.1);
let python_code = r#"
def handler(event, context):
for record in event['Records']:
message_id = record['messageId']
body = record['body']
return {'batchItemFailures': []}
"#;
let ast = parse_python(python_code);
match inferencer.infer_event_type(&ast) {
Ok(event_type) => {
assert!(matches!(
event_type,
EventType::SqsEvent
| EventType::EventBridge
| EventType::SnsEvent
| EventType::S3Event
| EventType::DynamodbEvent
));
}
Err(InferenceError::NoPatternMatch) => {
}
Err(e) => panic!("Unexpected error: {e:?}"),
}
}
#[test]
fn test_eventbridge_inference() {
let inferencer = LambdaTypeInferencer::new();
let python_code = r#"
def handler(event, context):
detail_type = event['detail-type']
detail = event['detail']
return None
"#;
let ast = parse_python(python_code);
let result = inferencer.infer_event_type(&ast).unwrap();
assert_eq!(result, EventType::EventBridge);
}
#[test]
fn test_no_pattern_match() {
let inferencer = LambdaTypeInferencer::new();
let python_code = r#"
def handler(event, context):
return {'message': 'hello world'}
"#;
let ast = parse_python(python_code);
let result = inferencer.infer_event_type(&ast);
assert!(matches!(result, Err(InferenceError::NoPatternMatch)));
}
#[test]
fn test_confidence_threshold() {
let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.95);
let python_code = r#"
def handler(event, context):
# Weak pattern that might not meet high confidence threshold
data = event['Records']
return {'status': 'ok'}
"#;
let ast = parse_python(python_code);
let result = inferencer.infer_event_type(&ast);
assert!(result.is_err());
}
#[test]
fn test_analysis_report() {
let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
let python_code = r#"
def handler(event, context):
bucket = event['Records'][0]['s3']['bucket']['name']
return {'processed': bucket}
"#;
let ast = parse_python(python_code);
let report = inferencer.analyze_handler(&ast).unwrap();
assert!(matches!(
report.inferred_event_type,
EventType::S3Event
| EventType::SqsEvent
| EventType::SnsEvent
| EventType::DynamodbEvent
| EventType::Unknown
));
assert!(!report.detected_patterns.is_empty());
assert!(!report.confidence_scores.is_empty());
}
#[test]
fn test_pattern_confidence_calculation() {
let inferencer = LambdaTypeInferencer::new();
let observed = Pattern {
access_chain: vec![
"Records".to_string(),
"s3".to_string(),
"bucket".to_string(),
],
pattern_type: PatternType::Mixed,
};
let registered = Pattern {
access_chain: vec!["Records".to_string(), "s3".to_string()],
pattern_type: PatternType::Mixed,
};
let confidence = inferencer.calculate_pattern_confidence(&observed, ®istered);
assert!(confidence > 0.9); }
#[test]
fn test_mixed_pattern_types() {
let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
let python_code = r#"
def handler(event, context):
# Test mixed pattern access - this tests the pattern extraction capability
record = event['Records'][0]
sns_message = record['Sns']['Message']
sns_subject = record['Sns']['Subject']
return {'message': sns_message}
"#;
let ast = parse_python(python_code);
let result = inferencer.infer_event_type(&ast);
match result {
Ok(event_type) => {
assert!(matches!(
event_type,
EventType::SnsEvent
| EventType::SqsEvent
| EventType::S3Event
| EventType::EventBridge
| EventType::DynamodbEvent
));
}
Err(InferenceError::AmbiguousEventType) | Err(InferenceError::NoPatternMatch) => {
}
Err(e) => panic!("Unexpected error: {e:?}"),
}
}
#[test]
fn test_numeric_index_handling() {
let inferencer = LambdaTypeInferencer::new().with_confidence_threshold(0.5);
let python_code = r#"
def handler(event, context):
# Test that numeric indices are properly skipped
bucket = event['Records'][0]['s3']['bucket']['name']
key = event['Records'][0]['s3']['object']['key']
# The pattern should be ['Records', 's3', 'bucket'] and ['Records', 's3', 'object']
# without the [0] index
return {'bucket': bucket, 'key': key}
"#;
let ast = parse_python(python_code);
let result = inferencer.infer_event_type(&ast).unwrap();
assert_eq!(result, EventType::S3Event);
}
}