use crate::frontend::ast::{Expr, ExprKind, Param, Pattern};
use std::collections::HashMap;
pub fn inline_small_functions(expr: Expr) -> (Expr, std::collections::HashSet<String>) {
use std::collections::HashSet;
let mut functions = HashMap::new();
collect_inline_candidates(&expr, &mut functions);
let inline_candidates: HashSet<String> = functions.keys().cloned().collect();
let optimized = inline_function_calls(expr, &functions);
(optimized, inline_candidates)
}
fn collect_inline_candidates(expr: &Expr, functions: &mut HashMap<String, FunctionDef>) {
match &expr.kind {
ExprKind::Function {
name,
params,
body,
is_pub,
..
} => {
let body_size = estimate_body_size(body);
let is_recursive = check_recursion(name, body);
let accesses_globals = accesses_global_variables(params, body);
if body_size <= 10 && !is_recursive && !accesses_globals && !is_pub {
functions.insert(
name.clone(),
FunctionDef {
params: params.clone(),
body: body.clone(),
},
);
}
}
ExprKind::Block(exprs) => {
for e in exprs {
collect_inline_candidates(e, functions);
}
}
_ => {}
}
}
fn try_inline_call(
func: &Expr,
args: &[Expr],
functions: &HashMap<String, FunctionDef>,
) -> Option<Expr> {
if let ExprKind::Identifier(func_name) = &func.kind {
if let Some(func_def) = functions.get(func_name) {
return Some(substitute_params(&func_def.body, &func_def.params, args));
}
}
None
}
fn inline_function_calls(expr: Expr, functions: &HashMap<String, FunctionDef>) -> Expr {
match expr.kind {
ExprKind::Call { func, args } => {
if let Some(inlined) = try_inline_call(&func, &args, functions) {
return inlined;
}
Expr::new(
ExprKind::Call {
func: Box::new(inline_function_calls(*func, functions)),
args: args
.into_iter()
.map(|a| inline_function_calls(a, functions))
.collect(),
},
expr.span,
)
}
ExprKind::Block(exprs) => {
let inlined = exprs
.into_iter()
.map(|e| inline_function_calls(e, functions))
.collect();
Expr::new(ExprKind::Block(inlined), expr.span)
}
ExprKind::Let {
name,
type_annotation,
value,
body,
is_mutable,
else_block,
} => Expr::new(
ExprKind::Let {
name,
type_annotation,
value: Box::new(inline_function_calls(*value, functions)),
body: Box::new(inline_function_calls(*body, functions)),
is_mutable,
else_block: else_block.map(|e| Box::new(inline_function_calls(*e, functions))),
},
expr.span,
),
ExprKind::Function {
name,
type_params,
params,
return_type,
body,
is_async,
is_pub,
} => Expr::new(
ExprKind::Function {
name,
type_params,
params,
return_type,
body: Box::new(inline_function_calls(*body, functions)),
is_async,
is_pub,
},
expr.span,
),
ExprKind::Binary { left, op, right } => Expr::new(
ExprKind::Binary {
left: Box::new(inline_function_calls(*left, functions)),
op,
right: Box::new(inline_function_calls(*right, functions)),
},
expr.span,
),
ExprKind::If {
condition,
then_branch,
else_branch,
} => Expr::new(
ExprKind::If {
condition: Box::new(inline_function_calls(*condition, functions)),
then_branch: Box::new(inline_function_calls(*then_branch, functions)),
else_branch: else_branch.map(|e| Box::new(inline_function_calls(*e, functions))),
},
expr.span,
),
_ => expr,
}
}
fn substitute_params(body: &Expr, params: &[Param], args: &[Expr]) -> Expr {
let mut substitutions = HashMap::new();
for (param, arg) in params.iter().zip(args.iter()) {
if let crate::frontend::ast::Pattern::Identifier(param_name) = ¶m.pattern {
substitutions.insert(param_name.clone(), arg.clone());
}
}
substitute_identifiers(body.clone(), &substitutions)
}
fn substitute_identifiers(expr: Expr, subs: &HashMap<String, Expr>) -> Expr {
match &expr.kind {
ExprKind::Identifier(name) => {
subs.get(name).cloned().unwrap_or(expr)
}
ExprKind::Binary { left, op, right } => Expr::new(
ExprKind::Binary {
left: Box::new(substitute_identifiers((**left).clone(), subs)),
op: *op,
right: Box::new(substitute_identifiers((**right).clone(), subs)),
},
expr.span,
),
ExprKind::Call { func, args } => Expr::new(
ExprKind::Call {
func: Box::new(substitute_identifiers((**func).clone(), subs)),
args: args
.iter()
.map(|a| substitute_identifiers(a.clone(), subs))
.collect(),
},
expr.span,
),
ExprKind::Block(exprs) => Expr::new(
ExprKind::Block(
exprs
.iter()
.map(|e| substitute_identifiers(e.clone(), subs))
.collect(),
),
expr.span,
),
ExprKind::If {
condition,
then_branch,
else_branch,
} => {
Expr::new(
ExprKind::If {
condition: Box::new(substitute_identifiers((**condition).clone(), subs)),
then_branch: Box::new(substitute_identifiers((**then_branch).clone(), subs)),
else_branch: else_branch
.as_ref()
.map(|e| Box::new(substitute_identifiers((**e).clone(), subs))),
},
expr.span,
)
}
ExprKind::Return { value } => {
Expr::new(
ExprKind::Return {
value: value
.as_ref()
.map(|v| Box::new(substitute_identifiers((**v).clone(), subs))),
},
expr.span,
)
}
ExprKind::While {
label,
condition,
body,
} => {
Expr::new(
ExprKind::While {
label: label.clone(),
condition: Box::new(substitute_identifiers((**condition).clone(), subs)),
body: Box::new(substitute_identifiers((**body).clone(), subs)),
},
expr.span,
)
}
_ => expr, }
}
fn estimate_body_size(body: &Expr) -> usize {
match &body.kind {
ExprKind::Block(exprs) => {
if exprs.is_empty() {
1 } else {
exprs.iter().map(estimate_body_size).sum()
}
}
ExprKind::Let { body, .. } => 1 + estimate_body_size(body),
ExprKind::If {
then_branch,
else_branch,
..
} => {
1 + estimate_body_size(then_branch)
+ else_branch.as_ref().map_or(0, |e| estimate_body_size(e))
}
ExprKind::For { body, .. } => 1 + estimate_body_size(body),
_ => 1, }
}
fn is_direct_recursion(func_name: &str, func_expr: &Expr) -> bool {
matches!(&func_expr.kind, ExprKind::Identifier(name) if name == func_name)
}
fn check_recursion(func_name: &str, body: &Expr) -> bool {
match &body.kind {
ExprKind::Call { func, .. } => is_direct_recursion(func_name, func),
ExprKind::Binary { left, right, .. } => {
check_recursion(func_name, left) || check_recursion(func_name, right)
}
ExprKind::Block(exprs) => exprs.iter().any(|e| check_recursion(func_name, e)),
ExprKind::If {
condition,
then_branch,
else_branch,
} => {
check_recursion(func_name, condition)
|| check_recursion(func_name, then_branch)
|| else_branch
.as_ref()
.is_some_and(|e| check_recursion(func_name, e))
}
ExprKind::Let { value, body, .. } => {
check_recursion(func_name, value) || check_recursion(func_name, body)
}
ExprKind::Return { value } => value
.as_ref()
.is_some_and(|v| check_recursion(func_name, v)),
ExprKind::Match { expr, arms } => {
check_recursion(func_name, expr)
|| arms.iter().any(|arm| check_recursion(func_name, &arm.body))
}
_ => false,
}
}
fn accesses_global_variables(params: &[Param], body: &Expr) -> bool {
use std::collections::HashSet;
let param_names: HashSet<String> = params
.iter()
.map(|p| match &p.pattern {
Pattern::Identifier(name) => name.clone(),
_ => String::new(),
})
.collect();
check_for_external_refs(body, ¶m_names)
}
fn check_pair_external_refs(
left: &Expr,
right: &Expr,
allowed: &std::collections::HashSet<String>,
) -> bool {
check_for_external_refs(left, allowed) || check_for_external_refs(right, allowed)
}
fn check_let_external_refs(
name: &str,
value: &Expr,
body: &Expr,
allowed: &std::collections::HashSet<String>,
) -> bool {
if check_for_external_refs(value, allowed) {
return true;
}
let mut allowed_with_binding = allowed.clone();
allowed_with_binding.insert(name.to_string());
check_for_external_refs(body, &allowed_with_binding)
}
fn check_for_external_refs(expr: &Expr, allowed: &std::collections::HashSet<String>) -> bool {
match &expr.kind {
ExprKind::Identifier(name) => !allowed.contains(name),
ExprKind::Assign { target, value } => check_pair_external_refs(target, value, allowed),
ExprKind::CompoundAssign { target, value, .. } => {
check_pair_external_refs(target, value, allowed)
}
ExprKind::Binary { left, right, .. } => check_pair_external_refs(left, right, allowed),
ExprKind::Block(exprs) => exprs.iter().any(|e| check_for_external_refs(e, allowed)),
ExprKind::If {
condition,
then_branch,
else_branch,
} => {
check_for_external_refs(condition, allowed)
|| check_for_external_refs(then_branch, allowed)
|| else_branch
.as_ref()
.is_some_and(|e| check_for_external_refs(e, allowed))
}
ExprKind::Let {
name, value, body, ..
} => check_let_external_refs(name, value, body, allowed),
ExprKind::Return { value } => value
.as_ref()
.is_some_and(|v| check_for_external_refs(v, allowed)),
ExprKind::Call { func, args } => {
check_for_external_refs(func, allowed)
|| args.iter().any(|a| check_for_external_refs(a, allowed))
}
_ => false,
}
}
#[derive(Debug, Clone)]
struct FunctionDef {
params: Vec<Param>,
body: Box<Expr>,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::frontend::ast::{BinaryOp, Literal, Pattern, Span, Type, TypeKind};
#[test]
fn test_inline_simple_function() {
let func_def = Expr::new(
ExprKind::Function {
name: "add_one".to_string(),
type_params: vec![],
params: vec![Param {
pattern: Pattern::Identifier("x".to_string()),
ty: Type {
kind: TypeKind::Named("i32".to_string()),
span: Span::default(),
},
span: Span::default(),
is_mutable: false,
default_value: None,
}],
return_type: None,
body: Box::new(Expr::new(
ExprKind::Binary {
left: Box::new(Expr::new(
ExprKind::Identifier("x".to_string()),
Span::default(),
)),
op: BinaryOp::Add,
right: Box::new(Expr::new(
ExprKind::Literal(Literal::Integer(1, None)),
Span::default(),
)),
},
Span::default(),
)),
is_async: false,
is_pub: false,
},
Span::default(),
);
let call = Expr::new(
ExprKind::Call {
func: Box::new(Expr::new(
ExprKind::Identifier("add_one".to_string()),
Span::default(),
)),
args: vec![Expr::new(
ExprKind::Literal(Literal::Integer(5, None)),
Span::default(),
)],
},
Span::default(),
);
let block = Expr::new(ExprKind::Block(vec![func_def, call]), Span::default());
let (result, _inlined) = inline_small_functions(block);
if let ExprKind::Block(exprs) = result.kind {
assert_eq!(exprs.len(), 2);
if let ExprKind::Binary { .. } = exprs[1].kind {
} else {
panic!("Expected inlined binary expression");
}
}
}
#[test]
fn test_inline_inside_function_body() {
let add_one_func = Expr::new(
ExprKind::Function {
name: "add_one".to_string(),
type_params: vec![],
params: vec![Param {
pattern: Pattern::Identifier("x".to_string()),
ty: Type {
kind: TypeKind::Named("i32".to_string()),
span: Span::default(),
},
span: Span::default(),
is_mutable: false,
default_value: None,
}],
return_type: None,
body: Box::new(Expr::new(
ExprKind::Binary {
left: Box::new(Expr::new(
ExprKind::Identifier("x".to_string()),
Span::default(),
)),
op: BinaryOp::Add,
right: Box::new(Expr::new(
ExprKind::Literal(Literal::Integer(1, None)),
Span::default(),
)),
},
Span::default(),
)),
is_async: false,
is_pub: false,
},
Span::default(),
);
let main_func = Expr::new(
ExprKind::Function {
name: "main".to_string(),
type_params: vec![],
params: vec![],
return_type: None,
body: Box::new(Expr::new(
ExprKind::Call {
func: Box::new(Expr::new(
ExprKind::Identifier("add_one".to_string()),
Span::default(),
)),
args: vec![Expr::new(
ExprKind::Literal(Literal::Integer(5, None)),
Span::default(),
)],
},
Span::default(),
)),
is_async: false,
is_pub: false,
},
Span::default(),
);
let block = Expr::new(
ExprKind::Block(vec![add_one_func, main_func]),
Span::default(),
);
let (result, _inlined) = inline_small_functions(block);
if let ExprKind::Block(exprs) = &result.kind {
assert_eq!(exprs.len(), 2);
if let ExprKind::Function { name, body, .. } = &exprs[1].kind {
assert_eq!(name, "main");
if let ExprKind::Binary { .. } = body.kind {
} else {
panic!(
"Expected main's body to have inlined binary expression, got: {:?}",
body.kind
);
}
} else {
panic!("Expected second expression to be main function");
}
} else {
panic!("Expected block result");
}
}
#[test]
fn test_no_inline_recursive() {
let recursive_func = Expr::new(
ExprKind::Function {
name: "factorial".to_string(),
type_params: vec![],
params: vec![],
return_type: None,
body: Box::new(Expr::new(
ExprKind::Call {
func: Box::new(Expr::new(
ExprKind::Identifier("factorial".to_string()),
Span::default(),
)),
args: vec![],
},
Span::default(),
)),
is_async: false,
is_pub: false,
},
Span::default(),
);
let mut functions = HashMap::new();
collect_inline_candidates(&recursive_func, &mut functions);
assert!(!functions.contains_key("factorial"));
}
fn int_lit(n: i64) -> Expr {
Expr::new(
ExprKind::Literal(Literal::Integer(n, None)),
Span::default(),
)
}
fn ident(name: &str) -> Expr {
Expr::new(ExprKind::Identifier(name.to_string()), Span::default())
}
fn binary(left: Expr, op: BinaryOp, right: Expr) -> Expr {
Expr::new(
ExprKind::Binary {
left: Box::new(left),
op,
right: Box::new(right),
},
Span::default(),
)
}
#[test]
fn test_estimate_body_size_single() {
let expr = int_lit(42);
assert_eq!(estimate_body_size(&expr), 1);
}
#[test]
fn test_estimate_body_size_block() {
let block = Expr::new(
ExprKind::Block(vec![int_lit(1), int_lit(2), int_lit(3)]),
Span::default(),
);
assert_eq!(estimate_body_size(&block), 3);
}
#[test]
fn test_estimate_body_size_let() {
let let_expr = Expr::new(
ExprKind::Let {
name: "x".to_string(),
type_annotation: None,
value: Box::new(int_lit(5)),
body: Box::new(int_lit(10)),
is_mutable: false,
else_block: None,
},
Span::default(),
);
assert_eq!(estimate_body_size(&let_expr), 2); }
#[test]
fn test_estimate_body_size_if_else() {
let if_expr = Expr::new(
ExprKind::If {
condition: Box::new(ident("x")),
then_branch: Box::new(int_lit(1)),
else_branch: Some(Box::new(int_lit(2))),
},
Span::default(),
);
assert_eq!(estimate_body_size(&if_expr), 3); }
#[test]
fn test_check_recursion_direct() {
let call = Expr::new(
ExprKind::Call {
func: Box::new(ident("factorial")),
args: vec![],
},
Span::default(),
);
assert!(check_recursion("factorial", &call));
}
#[test]
fn test_check_recursion_binary() {
let expr = binary(
Expr::new(
ExprKind::Call {
func: Box::new(ident("fib")),
args: vec![],
},
Span::default(),
),
BinaryOp::Add,
int_lit(1),
);
assert!(check_recursion("fib", &expr));
}
#[test]
fn test_check_recursion_if() {
let if_expr = Expr::new(
ExprKind::If {
condition: Box::new(ident("x")),
then_branch: Box::new(Expr::new(
ExprKind::Call {
func: Box::new(ident("recurse")),
args: vec![],
},
Span::default(),
)),
else_branch: None,
},
Span::default(),
);
assert!(check_recursion("recurse", &if_expr));
}
#[test]
fn test_check_recursion_none() {
let expr = binary(int_lit(1), BinaryOp::Add, int_lit(2));
assert!(!check_recursion("foo", &expr));
}
#[test]
fn test_accesses_global_variables_none() {
let params = vec![Param {
pattern: Pattern::Identifier("x".to_string()),
ty: Type {
kind: TypeKind::Named("i32".to_string()),
span: Span::default(),
},
span: Span::default(),
is_mutable: false,
default_value: None,
}];
let body = ident("x");
assert!(!accesses_global_variables(¶ms, &body));
}
#[test]
fn test_accesses_global_variables_global() {
let params = vec![Param {
pattern: Pattern::Identifier("x".to_string()),
ty: Type {
kind: TypeKind::Named("i32".to_string()),
span: Span::default(),
},
span: Span::default(),
is_mutable: false,
default_value: None,
}];
let body = binary(ident("x"), BinaryOp::Add, ident("global_var"));
assert!(accesses_global_variables(¶ms, &body));
}
#[test]
fn test_accesses_global_variables_let_binding() {
let params = vec![];
let body = Expr::new(
ExprKind::Let {
name: "y".to_string(),
type_annotation: None,
value: Box::new(int_lit(5)),
body: Box::new(ident("y")), is_mutable: false,
else_block: None,
},
Span::default(),
);
assert!(!accesses_global_variables(¶ms, &body));
}
#[test]
fn test_check_for_external_refs_allowed() {
use std::collections::HashSet;
let mut allowed = HashSet::new();
allowed.insert("x".to_string());
assert!(!check_for_external_refs(&ident("x"), &allowed));
}
#[test]
fn test_check_for_external_refs_external() {
use std::collections::HashSet;
let allowed = HashSet::new();
assert!(check_for_external_refs(&ident("external"), &allowed));
}
#[test]
fn test_check_for_external_refs_compound_assign() {
use std::collections::HashSet;
let mut allowed = HashSet::new();
allowed.insert("total".to_string());
let expr = Expr::new(
ExprKind::CompoundAssign {
target: Box::new(ident("total")),
op: BinaryOp::Add,
value: Box::new(ident("x")), },
Span::default(),
);
assert!(check_for_external_refs(&expr, &allowed));
}
#[test]
fn test_substitute_identifiers_identifier() {
let mut subs = HashMap::new();
subs.insert("x".to_string(), int_lit(42));
let result = substitute_identifiers(ident("x"), &subs);
assert!(matches!(
result.kind,
ExprKind::Literal(Literal::Integer(42, None))
));
}
#[test]
fn test_substitute_identifiers_binary() {
let mut subs = HashMap::new();
subs.insert("x".to_string(), int_lit(5));
let expr = binary(ident("x"), BinaryOp::Add, int_lit(1));
let result = substitute_identifiers(expr, &subs);
if let ExprKind::Binary { left, .. } = result.kind {
assert!(matches!(
left.kind,
ExprKind::Literal(Literal::Integer(5, None))
));
} else {
panic!("Expected binary expression");
}
}
#[test]
fn test_substitute_identifiers_call() {
let mut subs = HashMap::new();
subs.insert("x".to_string(), int_lit(10));
let expr = Expr::new(
ExprKind::Call {
func: Box::new(ident("foo")),
args: vec![ident("x")],
},
Span::default(),
);
let result = substitute_identifiers(expr, &subs);
if let ExprKind::Call { args, .. } = result.kind {
assert!(matches!(
args[0].kind,
ExprKind::Literal(Literal::Integer(10, None))
));
} else {
panic!("Expected call expression");
}
}
#[test]
fn test_substitute_identifiers_block() {
let mut subs = HashMap::new();
subs.insert("x".to_string(), int_lit(7));
let block = Expr::new(
ExprKind::Block(vec![ident("x"), ident("x")]),
Span::default(),
);
let result = substitute_identifiers(block, &subs);
if let ExprKind::Block(exprs) = result.kind {
assert_eq!(exprs.len(), 2);
assert!(matches!(
exprs[0].kind,
ExprKind::Literal(Literal::Integer(7, None))
));
assert!(matches!(
exprs[1].kind,
ExprKind::Literal(Literal::Integer(7, None))
));
} else {
panic!("Expected block");
}
}
#[test]
fn test_substitute_identifiers_if() {
let mut subs = HashMap::new();
subs.insert("x".to_string(), int_lit(1));
let if_expr = Expr::new(
ExprKind::If {
condition: Box::new(ident("x")),
then_branch: Box::new(ident("x")),
else_branch: Some(Box::new(int_lit(0))),
},
Span::default(),
);
let result = substitute_identifiers(if_expr, &subs);
if let ExprKind::If {
condition,
then_branch,
..
} = result.kind
{
assert!(matches!(
condition.kind,
ExprKind::Literal(Literal::Integer(1, None))
));
assert!(matches!(
then_branch.kind,
ExprKind::Literal(Literal::Integer(1, None))
));
} else {
panic!("Expected if expression");
}
}
#[test]
fn test_substitute_identifiers_return() {
let mut subs = HashMap::new();
subs.insert("result".to_string(), int_lit(99));
let ret = Expr::new(
ExprKind::Return {
value: Some(Box::new(ident("result"))),
},
Span::default(),
);
let result = substitute_identifiers(ret, &subs);
if let ExprKind::Return { value } = result.kind {
let val = value.unwrap();
assert!(matches!(
val.kind,
ExprKind::Literal(Literal::Integer(99, None))
));
} else {
panic!("Expected return expression");
}
}
#[test]
fn test_collect_inline_candidates_too_large() {
let mut functions = HashMap::new();
let large_body = Expr::new(
ExprKind::Block(vec![int_lit(1); 15]), Span::default(),
);
let func = Expr::new(
ExprKind::Function {
name: "large".to_string(),
type_params: vec![],
params: vec![],
body: Box::new(large_body),
return_type: None,
is_pub: false,
is_async: false,
},
Span::default(),
);
collect_inline_candidates(&func, &mut functions);
assert!(functions.is_empty()); }
#[test]
fn test_collect_inline_candidates_pub_function() {
let mut functions = HashMap::new();
let func = Expr::new(
ExprKind::Function {
name: "pub_fn".to_string(),
type_params: vec![],
params: vec![],
body: Box::new(int_lit(42)),
return_type: None,
is_pub: true, is_async: false,
},
Span::default(),
);
collect_inline_candidates(&func, &mut functions);
assert!(functions.is_empty()); }
#[test]
fn test_collect_inline_candidates_nested_blocks() {
let mut functions = HashMap::new();
let inner_func = Expr::new(
ExprKind::Function {
name: "inner".to_string(),
type_params: vec![],
params: vec![],
body: Box::new(int_lit(10)),
return_type: None,
is_pub: false,
is_async: false,
},
Span::default(),
);
let block = Expr::new(ExprKind::Block(vec![inner_func]), Span::default());
collect_inline_candidates(&block, &mut functions);
assert_eq!(functions.len(), 1);
assert!(functions.contains_key("inner"));
}
#[test]
fn test_inline_function_calls_method_call() {
let functions = HashMap::new();
let method_call = Expr::new(
ExprKind::Call {
func: Box::new(Expr::new(
ExprKind::FieldAccess {
object: Box::new(ident("obj")),
field: "method".to_string(),
},
Span::default(),
)),
args: vec![],
},
Span::default(),
);
let result = inline_function_calls(method_call, &functions);
assert!(matches!(result.kind, ExprKind::Call { .. }));
}
#[test]
fn test_inline_function_calls_undefined() {
let functions = HashMap::new(); let call = Expr::new(
ExprKind::Call {
func: Box::new(ident("undefined")),
args: vec![int_lit(1)],
},
Span::default(),
);
let result = inline_function_calls(call, &functions);
if let ExprKind::Call { func, .. } = result.kind {
assert!(matches!(func.kind, ExprKind::Identifier(ref name) if name == "undefined"));
} else {
panic!("Expected Call expression");
}
}
#[test]
fn test_substitute_params_empty() {
let body = int_lit(42);
let params = vec![];
let args = vec![];
let result = substitute_params(&body, ¶ms, &args);
assert!(matches!(
result.kind,
ExprKind::Literal(Literal::Integer(42, None))
));
}
#[test]
fn test_substitute_params_multiple() {
use crate::frontend::ast::{Param, Pattern};
let body = binary(ident("a"), BinaryOp::Add, ident("b"));
let params = vec![
Param {
pattern: Pattern::Identifier("a".to_string()),
ty: crate::frontend::ast::Type {
kind: crate::frontend::ast::TypeKind::Named("i32".to_string()),
span: Span::default(),
},
span: Span::default(),
is_mutable: false,
default_value: None,
},
Param {
pattern: Pattern::Identifier("b".to_string()),
ty: crate::frontend::ast::Type {
kind: crate::frontend::ast::TypeKind::Named("i32".to_string()),
span: Span::default(),
},
span: Span::default(),
is_mutable: false,
default_value: None,
},
];
let args = vec![int_lit(10), int_lit(20)];
let result = substitute_params(&body, ¶ms, &args);
if let ExprKind::Binary { left, right, .. } = result.kind {
assert!(matches!(
left.kind,
ExprKind::Literal(Literal::Integer(10, None))
));
assert!(matches!(
right.kind,
ExprKind::Literal(Literal::Integer(20, None))
));
} else {
panic!("Expected binary expression");
}
}
#[test]
fn test_inline_small_functions_returns_inlined_set() {
use crate::frontend::ast::{Param, Pattern, Type, TypeKind};
let x_param = Param {
pattern: Pattern::Identifier("x".to_string()),
ty: Type {
kind: TypeKind::Named("_".to_string()),
span: Span::default(),
},
span: Span::default(),
is_mutable: false,
default_value: None,
};
let add_fn = Expr::new(
ExprKind::Function {
name: "add_one".to_string(),
type_params: vec![],
params: vec![x_param],
body: Box::new(binary(ident("x"), BinaryOp::Add, int_lit(1))),
return_type: None,
is_pub: false,
is_async: false,
},
Span::default(),
);
let call = Expr::new(
ExprKind::Call {
func: Box::new(ident("add_one")),
args: vec![int_lit(5)],
},
Span::default(),
);
let program = Expr::new(ExprKind::Block(vec![add_fn, call]), Span::default());
let (result, inlined) = inline_small_functions(program);
assert!(inlined.contains("add_one"));
assert!(matches!(result.kind, ExprKind::Block(_)));
}
#[test]
fn test_check_for_external_refs_let() {
let mut allowed = std::collections::HashSet::new();
allowed.insert("x".to_string());
let let_expr = Expr::new(
ExprKind::Let {
name: "y".to_string(),
type_annotation: None,
value: Box::new(ident("x")),
body: Box::new(ident("y")),
is_mutable: false,
else_block: None,
},
Span::default(),
);
let result = check_for_external_refs(&let_expr, &allowed);
assert!(!result); }
#[test]
fn test_check_for_external_refs_match() {
use crate::frontend::ast::{MatchArm, Pattern};
let mut allowed = std::collections::HashSet::new();
allowed.insert("value".to_string());
let match_expr = Expr::new(
ExprKind::Match {
expr: Box::new(ident("value")),
arms: vec![MatchArm {
pattern: Pattern::Literal(Literal::Integer(1, None)),
guard: None,
body: Box::new(ident("value")),
span: Span::default(),
}],
},
Span::default(),
);
let result = check_for_external_refs(&match_expr, &allowed);
assert!(!result); }
#[test]
fn test_estimate_body_size_nested() {
let nested = Expr::new(
ExprKind::Block(vec![
Expr::new(
ExprKind::Block(vec![int_lit(1), int_lit(2)]),
Span::default(),
),
int_lit(3),
]),
Span::default(),
);
let size = estimate_body_size(&nested);
assert_eq!(size, 3); }
#[test]
fn test_estimate_body_size_for_loop() {
let for_loop = Expr::new(
ExprKind::For {
label: None,
var: "i".to_string(),
pattern: None,
iter: Box::new(ident("items")),
body: Box::new(int_lit(1)),
},
Span::default(),
);
let size = estimate_body_size(&for_loop);
assert_eq!(size, 2); }
#[test]
fn test_check_recursion_in_match() {
use crate::frontend::ast::{MatchArm, Pattern};
let match_expr = Expr::new(
ExprKind::Match {
expr: Box::new(int_lit(1)),
arms: vec![MatchArm {
pattern: Pattern::Literal(Literal::Integer(1, None)),
guard: None,
body: Box::new(Expr::new(
ExprKind::Call {
func: Box::new(ident("factorial")),
args: vec![int_lit(1)],
},
Span::default(),
)),
span: Span::default(),
}],
},
Span::default(),
);
assert!(check_recursion("factorial", &match_expr));
}
#[test]
fn test_accesses_global_variables_nested_block() {
let params = vec![];
let body = Expr::new(
ExprKind::Block(vec![Expr::new(
ExprKind::Let {
name: "local".to_string(),
type_annotation: None,
value: Box::new(ident("global")),
body: Box::new(ident("local")),
is_mutable: false,
else_block: None,
},
Span::default(),
)]),
Span::default(),
);
assert!(accesses_global_variables(¶ms, &body));
}
#[test]
fn test_substitute_identifiers_while() {
let mut subs = HashMap::new();
subs.insert("n".to_string(), int_lit(10));
let while_loop = Expr::new(
ExprKind::While {
label: None,
condition: Box::new(binary(ident("n"), BinaryOp::Gt, int_lit(0))),
body: Box::new(ident("n")),
},
Span::default(),
);
let result = substitute_identifiers(while_loop, &subs);
if let ExprKind::While {
label: _,
condition,
body,
} = result.kind
{
if let ExprKind::Binary { left, .. } = condition.kind {
assert!(matches!(
left.kind,
ExprKind::Literal(Literal::Integer(10, None))
));
} else {
panic!("Expected binary in condition");
}
assert!(matches!(
body.kind,
ExprKind::Literal(Literal::Integer(10, None))
));
} else {
panic!("Expected while expression");
}
}
#[test]
fn test_estimate_body_size_empty_block() {
let block = Expr::new(ExprKind::Block(vec![]), Span::default());
assert_eq!(estimate_body_size(&block), 1);
}
#[test]
fn test_estimate_body_size_if_without_else() {
let if_expr = Expr::new(
ExprKind::If {
condition: Box::new(ident("cond")),
then_branch: Box::new(int_lit(1)),
else_branch: None,
},
Span::default(),
);
assert_eq!(estimate_body_size(&if_expr), 2); }
#[test]
fn test_check_recursion_in_block() {
let block = Expr::new(
ExprKind::Block(vec![Expr::new(
ExprKind::Call {
func: Box::new(ident("recurse")),
args: vec![],
},
Span::default(),
)]),
Span::default(),
);
assert!(check_recursion("recurse", &block));
}
#[test]
fn test_check_recursion_in_let_value() {
let let_expr = Expr::new(
ExprKind::Let {
name: "x".to_string(),
type_annotation: None,
value: Box::new(Expr::new(
ExprKind::Call {
func: Box::new(ident("recurse")),
args: vec![],
},
Span::default(),
)),
body: Box::new(ident("x")),
is_mutable: false,
else_block: None,
},
Span::default(),
);
assert!(check_recursion("recurse", &let_expr));
}
#[test]
fn test_check_recursion_in_return() {
let ret = Expr::new(
ExprKind::Return {
value: Some(Box::new(Expr::new(
ExprKind::Call {
func: Box::new(ident("factorial")),
args: vec![int_lit(5)],
},
Span::default(),
))),
},
Span::default(),
);
assert!(check_recursion("factorial", &ret));
}
#[test]
fn test_check_for_external_refs_call() {
use std::collections::HashSet;
let mut allowed = HashSet::new();
allowed.insert("x".to_string());
let call_expr = Expr::new(
ExprKind::Call {
func: Box::new(ident("foo")),
args: vec![ident("y")], },
Span::default(),
);
assert!(check_for_external_refs(&call_expr, &allowed));
}
#[test]
fn test_check_for_external_refs_call_all_allowed() {
use std::collections::HashSet;
let mut allowed = HashSet::new();
allowed.insert("foo".to_string());
allowed.insert("x".to_string());
let call_expr = Expr::new(
ExprKind::Call {
func: Box::new(ident("foo")),
args: vec![ident("x")],
},
Span::default(),
);
assert!(!check_for_external_refs(&call_expr, &allowed));
}
#[test]
fn test_check_for_external_refs_if() {
use std::collections::HashSet;
let mut allowed = HashSet::new();
allowed.insert("cond".to_string());
let if_expr = Expr::new(
ExprKind::If {
condition: Box::new(ident("cond")),
then_branch: Box::new(ident("external")),
else_branch: None,
},
Span::default(),
);
assert!(check_for_external_refs(&if_expr, &allowed));
}
#[test]
fn test_check_for_external_refs_return() {
use std::collections::HashSet;
let mut allowed = HashSet::new();
allowed.insert("x".to_string());
let ret_expr = Expr::new(
ExprKind::Return {
value: Some(Box::new(ident("external"))),
},
Span::default(),
);
assert!(check_for_external_refs(&ret_expr, &allowed));
}
#[test]
fn test_inline_function_calls_let_else() {
let functions = HashMap::new();
let let_expr = Expr::new(
ExprKind::Let {
name: "x".to_string(),
type_annotation: None,
value: Box::new(int_lit(5)),
body: Box::new(ident("x")),
is_mutable: false,
else_block: Some(Box::new(int_lit(0))),
},
Span::default(),
);
let result = inline_function_calls(let_expr, &functions);
assert!(matches!(result.kind, ExprKind::Let { .. }));
}
#[test]
fn test_inline_function_calls_if_with_else() {
let functions = HashMap::new();
let if_expr = Expr::new(
ExprKind::If {
condition: Box::new(ident("cond")),
then_branch: Box::new(int_lit(1)),
else_branch: Some(Box::new(int_lit(0))),
},
Span::default(),
);
let result = inline_function_calls(if_expr, &functions);
assert!(matches!(result.kind, ExprKind::If { .. }));
}
#[test]
fn test_collect_inline_candidates_function_with_global() {
let mut functions = HashMap::new();
let func = Expr::new(
ExprKind::Function {
name: "uses_global".to_string(),
type_params: vec![],
params: vec![],
body: Box::new(ident("global_var")), return_type: None,
is_pub: false,
is_async: false,
},
Span::default(),
);
collect_inline_candidates(&func, &mut functions);
assert!(functions.is_empty()); }
#[test]
fn test_accesses_global_variables_if_condition() {
let params = vec![Param {
pattern: Pattern::Identifier("x".to_string()),
ty: Type {
kind: TypeKind::Named("i32".to_string()),
span: Span::default(),
},
span: Span::default(),
is_mutable: false,
default_value: None,
}];
let body = Expr::new(
ExprKind::If {
condition: Box::new(ident("global")), then_branch: Box::new(ident("x")),
else_branch: None,
},
Span::default(),
);
assert!(accesses_global_variables(¶ms, &body));
}
#[test]
fn test_substitute_identifiers_non_matching() {
let subs = HashMap::new();
let expr = ident("x");
let result = substitute_identifiers(expr, &subs);
assert!(matches!(result.kind, ExprKind::Identifier(name) if name == "x"));
}
#[test]
fn test_try_inline_call_not_identifier_func() {
let functions = HashMap::new();
let method_expr = Expr::new(
ExprKind::FieldAccess {
object: Box::new(ident("obj")),
field: "method".to_string(),
},
Span::default(),
);
let result = try_inline_call(&method_expr, &[], &functions);
assert!(result.is_none());
}
#[test]
fn test_is_direct_recursion_true() {
let func_expr = ident("factorial");
assert!(is_direct_recursion("factorial", &func_expr));
}
#[test]
fn test_is_direct_recursion_false() {
let func_expr = ident("other_func");
assert!(!is_direct_recursion("factorial", &func_expr));
}
}