use crate::hir::{
AssignTarget, ConstGeneric, HirExpr, HirFunction, HirModule, HirStmt, Literal, Type,
};
use anyhow::Result;
use std::collections::{HashMap, HashSet};
/// State accumulated while scanning HIR functions for list values whose
/// length is a compile-time constant.
///
/// `Debug` and `Clone` are derived so the pass state can be logged and
/// snapshotted like other public types.
#[derive(Debug, Clone)]
pub struct ConstGenericInferencer {
    /// Maps an assigned symbol name to the fixed size detected for the value
    /// that was assigned to it.
    const_values: HashMap<String, usize>,
    /// Names of const generic parameters, exposed through `get_const_params`.
    /// No code visible in this file inserts into this set.
    const_params: HashSet<String>,
}
impl ConstGenericInferencer {
/// Creates an inferencer with no recorded sizes and no const parameters.
pub fn new() -> Self {
    let const_values = HashMap::new();
    let const_params = HashSet::new();
    Self {
        const_values,
        const_params,
    }
}
/// Runs [`Self::analyze_function`] on every function of `module`, in order.
///
/// # Errors
/// Stops at, and returns, the first error produced for any function.
pub fn analyze_module(&mut self, module: &mut HirModule) -> Result<()> {
    module
        .functions
        .iter_mut()
        .try_for_each(|function| self.analyze_function(function))
}
/// Analyzes a single function in three steps: collect fixed-size
/// assignments, transform the function's types, then walk every statement
/// of the body.
///
/// # Errors
/// Propagates the first error returned by any of the steps.
pub fn analyze_function(&mut self, function: &mut HirFunction) -> Result<()> {
    self.collect_const_values(function)?;
    self.transform_function_types(function)?;
    // Walking the body is exactly what the block helper does for a slice.
    self.transform_stmt_block(&mut function.body)
}
/// Scans the function body and records, in `const_values`, every symbol that
/// is assigned a value with a detectable fixed size.
///
/// Parameters are not inspected: the previous version looped over them with
/// an empty `if let Type::Int` branch, which had no effect and was removed.
///
/// # Errors
/// Propagates any error from scanning a statement.
fn collect_const_values(&mut self, function: &HirFunction) -> Result<()> {
    self.scan_stmt_block(&function.body)
}
/// Dispatches one statement to the matching scanner.
///
/// Only assignments to a plain symbol, `if` branches and `while`/`for`
/// bodies are examined; every other statement is ignored.
fn scan_statement_for_consts(&mut self, stmt: &HirStmt) -> Result<()> {
    match stmt {
        HirStmt::Assign {
            target: AssignTarget::Symbol(name),
            value: rhs,
            ..
        } => self.scan_assign_for_const(name, rhs),
        HirStmt::If {
            then_body: then_stmts,
            else_body: else_stmts,
            ..
        } => self.scan_if_branches(then_stmts, else_stmts),
        HirStmt::While {
            body: loop_body, ..
        } => self.scan_stmt_block(loop_body),
        HirStmt::For {
            body: loop_body, ..
        } => self.scan_stmt_block(loop_body),
        _ => Ok(()),
    }
}
/// Records `symbol -> size` when `value` matches a fixed-size pattern.
///
/// An existing entry for the same symbol is overwritten; a value that does
/// not match leaves the map untouched. Always returns `Ok(())`.
fn scan_assign_for_const(&mut self, symbol: &str, value: &HirExpr) -> Result<()> {
    let entry = self
        .detect_fixed_size_pattern(value)
        .map(|size| (symbol.to_owned(), size));
    // Extending with an `Option` inserts zero or one entry.
    self.const_values.extend(entry);
    Ok(())
}
/// Scans the `then` branch and, when present, the `else` branch of an `if`.
///
/// # Errors
/// Returns the first error from either branch; the `else` branch is not
/// scanned if the `then` branch fails.
fn scan_if_branches(
    &mut self,
    then_body: &[HirStmt],
    else_body: &Option<Vec<HirStmt>>,
) -> Result<()> {
    self.scan_stmt_block(then_body)?;
    match else_body {
        Some(else_stmts) => self.scan_stmt_block(else_stmts),
        None => Ok(()),
    }
}
/// Scans each statement of `stmts` in order, stopping at the first error.
fn scan_stmt_block(&mut self, stmts: &[HirStmt]) -> Result<()> {
    stmts
        .iter()
        .try_for_each(|stmt| self.scan_statement_for_consts(stmt))
}
fn detect_fixed_size_pattern(&self, expr: &HirExpr) -> Option<usize> {
match expr {
HirExpr::Binary {
op: crate::hir::BinOp::Mul,
left,
right,
} => self.detect_multiply_pattern(left, right),
HirExpr::List(elements) => self.detect_literal_list_size(elements),
HirExpr::Call { func, args, .. } => self.detect_array_func_call(func, args),
_ => None,
}
}
/// Recognises `list * int` and, failing that, `int * list`.
///
/// The operands are tried in the given order first, then swapped.
fn detect_multiply_pattern(&self, left: &HirExpr, right: &HirExpr) -> Option<usize> {
    let forward = self.check_list_times_int(left, right);
    if forward.is_some() {
        return forward;
    }
    self.check_list_times_int(right, left)
}
/// Returns the repeat count when `list_side` is a one-element list literal
/// and `int_side` is an integer literal in the range `1..1000`.
///
/// The exclusive upper bound of 1000 mirrors the limit used by
/// `detect_literal_list_size` and `detect_array_func_call`. Previously any
/// positive count was accepted here, so an expression such as `[0] * 10_000_000`
/// was reported as a fixed size although the sibling detectors reject sizes
/// that large.
///
/// Returns `None` for any other operand shapes, for lists that do not have
/// exactly one element, and for counts outside the accepted range.
fn check_list_times_int(&self, list_side: &HirExpr, int_side: &HirExpr) -> Option<usize> {
    match (list_side, int_side) {
        (HirExpr::List(elements), HirExpr::Literal(Literal::Int(size)))
            if elements.len() == 1 && *size > 0 && *size < 1000 =>
        {
            // The guard keeps `size` within 1..1000, so the cast is lossless.
            Some(*size as usize)
        }
        _ => None,
    }
}
/// Returns the element count of a list literal when it is between 1 and 999
/// inclusive; empty lists and lists of 1000 or more elements yield `None`.
fn detect_literal_list_size(&self, elements: &[HirExpr]) -> Option<usize> {
    match elements.len() {
        len @ 1..=999 => Some(len),
        _ => None,
    }
}
/// Returns the size requested by a call to `zeros`, `ones` or `full` whose
/// first argument is an integer literal strictly between 0 and 1000.
///
/// Other function names, a missing or non-literal first argument, and sizes
/// outside that range all yield `None`.
fn detect_array_func_call(&self, func: &str, args: &[HirExpr]) -> Option<usize> {
    if !matches!(func, "zeros" | "ones" | "full") {
        return None;
    }
    match args.first() {
        Some(HirExpr::Literal(Literal::Int(size))) if *size > 0 && *size < 1000 => {
            Some(*size as usize)
        }
        _ => None,
    }
}
/// Placeholder for rewriting a function's parameter and return types.
///
/// The function is currently left untouched and `Ok(())` is always returned.
fn transform_function_types(&mut self, _function: &mut HirFunction) -> Result<()> {
// No type rewriting is performed yet.
Ok(())
}
/// Returns the first size implied for `param_name` by any top-level
/// statement of `function`, or `None` when no statement implies one.
// Not called by the active analysis path in this file, hence the allow.
#[allow(dead_code)]
fn infer_const_size_for_param(
    &self,
    param_name: &str,
    function: &HirFunction,
) -> Option<usize> {
    function
        .body
        .iter()
        .find_map(|stmt| self.find_const_usage_in_stmt(param_name, stmt))
}
/// Infers a fixed size for the function's return value.
///
/// Top-level `return` statements are examined in order:
/// - a returned expression that is itself a fixed-size pattern yields its size;
/// - a returned variable that is the receiver of a mutating method call
///   anywhere in the function yields `None` immediately;
/// - a returned variable with a recorded fixed-size top-level assignment
///   yields that size.
///
/// Any other `return` is skipped and the search continues.
// Not called by the active analysis path in this file, hence the allow.
#[allow(dead_code)]
fn infer_const_size_for_return(&self, function: &HirFunction) -> Option<usize> {
    // Sizes of symbols assigned at the top level; a later assignment to the
    // same symbol replaces the earlier entry.
    let var_sizes: HashMap<_, _> = function
        .body
        .iter()
        .filter_map(|stmt| match stmt {
            HirStmt::Assign {
                target: AssignTarget::Symbol(symbol),
                value,
                ..
            } => self
                .detect_fixed_size_pattern(value)
                .map(|size| (symbol.clone(), size)),
            _ => None,
        })
        .collect();
    let mutated_vars = self.detect_mutated_variables(function);

    for stmt in &function.body {
        let returned = match stmt {
            HirStmt::Return(Some(expr)) => expr,
            _ => continue,
        };
        if let Some(size) = self.detect_fixed_size_pattern(returned) {
            return Some(size);
        }
        match returned {
            HirExpr::Var(name) if mutated_vars.contains(name) => return None,
            HirExpr::Var(name) => {
                if let Some(&size) = var_sizes.get(name) {
                    return Some(size);
                }
            }
            _ => {}
        }
    }
    None
}
/// Collects the names of variables that are the receiver of a mutating
/// method call anywhere in the function body.
// Only reached from other currently-unused helpers, hence the allow.
#[allow(dead_code)]
fn detect_mutated_variables(&self, function: &HirFunction) -> HashSet<String> {
    let mut mutated = HashSet::new();
    function
        .body
        .iter()
        .for_each(|stmt| self.scan_stmt_for_mutations(stmt, &mut mutated));
    mutated
}
/// Walks one statement and adds every mutated variable found in its
/// expressions and nested statements to `mutated`.
///
/// Expression statements, assignment values, `if`/`while` conditions, `for`
/// iterables, loop and branch bodies, and returned expressions are visited.
/// Assignment targets and all other statement kinds are not.
// Only reached from other currently-unused helpers, hence the allow.
#[allow(dead_code)]
fn scan_stmt_for_mutations(&self, stmt: &HirStmt, mutated: &mut HashSet<String>) {
    match stmt {
        HirStmt::Expr(expr) => self.scan_expr_for_mutations(expr, mutated),
        HirStmt::Return(Some(expr)) => self.scan_expr_for_mutations(expr, mutated),
        HirStmt::Assign { value, .. } => self.scan_expr_for_mutations(value, mutated),
        HirStmt::If {
            condition,
            then_body,
            else_body,
        } => {
            self.scan_expr_for_mutations(condition, mutated);
            // `then` statements first, followed by the `else` ones if present.
            let else_stmts = else_body.iter().flatten();
            for nested in then_body.iter().chain(else_stmts) {
                self.scan_stmt_for_mutations(nested, mutated);
            }
        }
        // Both loop kinds have a head expression followed by a body.
        HirStmt::While {
            condition: head,
            body,
        }
        | HirStmt::For {
            iter: head, body, ..
        } => {
            self.scan_expr_for_mutations(head, mutated);
            for nested in body {
                self.scan_stmt_for_mutations(nested, mutated);
            }
        }
        _ => {}
    }
}
/// Walks an expression tree and records variables that are mutated in place.
///
/// A variable counts as mutated when it is the direct receiver of a call to
/// `append`, `extend`, `insert`, `remove`, `pop`, `clear`, `reverse` or
/// `sort`. The walk descends into method receivers and positional
/// arguments, binary and unary operands, call arguments, index base/index
/// and list elements; other expression kinds are not visited.
#[allow(clippy::only_used_in_recursion)]
// Only reached from other currently-unused helpers, hence the allow.
#[allow(dead_code)]
fn scan_expr_for_mutations(&self, expr: &HirExpr, mutated: &mut HashSet<String>) {
    const MUTATING_METHODS: [&str; 8] = [
        "append", "extend", "insert", "remove", "pop", "clear", "reverse", "sort",
    ];

    match expr {
        HirExpr::MethodCall {
            object,
            method,
            args,
            ..
        } => {
            if MUTATING_METHODS.contains(&method.as_str()) {
                if let HirExpr::Var(receiver) = &**object {
                    mutated.insert(receiver.clone());
                }
            }
            self.scan_expr_for_mutations(object, mutated);
            for arg in args {
                self.scan_expr_for_mutations(arg, mutated);
            }
        }
        // Two-operand expressions: visit the first operand, then the second.
        HirExpr::Binary {
            left: first,
            right: second,
            ..
        }
        | HirExpr::Index {
            base: first,
            index: second,
        } => {
            self.scan_expr_for_mutations(first, mutated);
            self.scan_expr_for_mutations(second, mutated);
        }
        HirExpr::Unary { operand, .. } => self.scan_expr_for_mutations(operand, mutated),
        // Expressions that carry a flat sequence of sub-expressions.
        HirExpr::Call { args: items, .. } | HirExpr::List(items) => {
            for item in items {
                self.scan_expr_for_mutations(item, mutated);
            }
        }
        _ => {}
    }
}
/// Looks for a usage of `param_name` in `stmt` that implies a fixed size.
///
/// Assignment values are checked directly. For `if` statements the `then`
/// branch is searched before the `else` branch and the first hit wins; the
/// condition itself is not examined. Other statements yield `None`.
// Only reached from other currently-unused helpers, hence the allow.
#[allow(dead_code)]
fn find_const_usage_in_stmt(&self, param_name: &str, stmt: &HirStmt) -> Option<usize> {
    match stmt {
        HirStmt::Assign { value, .. } => self.find_const_usage_in_expr(param_name, value),
        HirStmt::If {
            then_body,
            else_body,
            ..
        } => then_body
            .iter()
            .chain(else_body.iter().flatten())
            .find_map(|nested| self.find_const_usage_in_stmt(param_name, nested)),
        _ => None,
    }
}
#[allow(dead_code)] fn find_const_usage_in_expr(&self, param_name: &str, expr: &HirExpr) -> Option<usize> {
match expr {
HirExpr::Binary {
op: crate::hir::BinOp::Eq,
left,
right,
} => self.find_len_equality_pattern(param_name, left, right),
HirExpr::Index { base, index } => self.find_index_pattern(param_name, base, index),
_ => None,
}
}
/// Recognises `len(param) == N` and, failing that, `N == len(param)`.
// Only reached from other currently-unused helpers, hence the allow.
#[allow(dead_code)]
fn find_len_equality_pattern(
    &self,
    param_name: &str,
    left: &HirExpr,
    right: &HirExpr,
) -> Option<usize> {
    let forward = self.check_len_eq_side(param_name, left, right);
    if forward.is_some() {
        return forward;
    }
    self.check_len_eq_side(param_name, right, left)
}
/// Returns `N` when `call_side` is `len(param_name)` with exactly one
/// argument and `size_side` is a positive integer literal `N`.
///
/// Anything else yields `None`.
// Only reached from other currently-unused helpers, hence the allow.
#[allow(dead_code)]
fn check_len_eq_side(
    &self,
    param_name: &str,
    call_side: &HirExpr,
    size_side: &HirExpr,
) -> Option<usize> {
    let (func, args, size) = match (call_side, size_side) {
        (HirExpr::Call { func, args, .. }, HirExpr::Literal(Literal::Int(size))) => {
            (func, args, size)
        }
        _ => return None,
    };
    if func != "len" {
        return None;
    }
    // A one-element slice pattern also enforces the single-argument rule.
    match args.as_slice() {
        [HirExpr::Var(var_name)] if var_name == param_name && *size > 0 => Some(*size as usize),
        _ => None,
    }
}
/// Returns the minimum length implied by indexing `param_name` with a
/// non-negative integer literal, i.e. `index + 1`.
///
/// Yields `None` when `base` is not the variable `param_name`, when `index`
/// is not an integer literal, when the literal is negative, or when
/// `index + 1` would overflow. The last case used to panic in debug builds
/// and wrap in release builds.
// Only reached from other currently-unused helpers, hence the allow.
#[allow(dead_code)]
fn find_index_pattern(
    &self,
    param_name: &str,
    base: &HirExpr,
    index: &HirExpr,
) -> Option<usize> {
    match (base, index) {
        (HirExpr::Var(var_name), HirExpr::Literal(Literal::Int(idx)))
            if var_name == param_name && *idx >= 0 =>
        {
            // `checked_add` turns the overflow at the maximum literal into `None`.
            idx.checked_add(1).map(|required| required as usize)
        }
        _ => None,
    }
}
/// Visits the expressions and nested statements of `stmt`.
///
/// Assignment values, returned expressions, `if`, `while` and `for`
/// statements are walked; all other statements are left alone.
///
/// # Errors
/// Propagates the first error from any nested visit.
fn transform_statement(&mut self, stmt: &mut HirStmt) -> Result<()> {
    match stmt {
        // Statements that carry exactly one expression to visit.
        HirStmt::Assign { value: expr, .. } | HirStmt::Return(Some(expr)) => {
            self.transform_expression(expr)
        }
        HirStmt::For { iter, body, .. } => self.transform_for_stmt(iter, body),
        HirStmt::While { condition, body } => self.transform_while_stmt(condition, body),
        HirStmt::If {
            condition,
            then_body,
            else_body,
        } => self.transform_if_stmt(condition, then_body, else_body),
        _ => Ok(()),
    }
}
/// Visits an `if` statement: the condition, then the `then` branch, then
/// the `else` branch when present.
///
/// # Errors
/// Stops at the first failing part and returns its error.
fn transform_if_stmt(
    &mut self,
    condition: &mut HirExpr,
    then_body: &mut [HirStmt],
    else_body: &mut Option<Vec<HirStmt>>,
) -> Result<()> {
    self.transform_expression(condition)?;
    self.transform_stmt_block(then_body)?;
    match else_body {
        Some(else_stmts) => self.transform_stmt_block(else_stmts),
        None => Ok(()),
    }
}
/// Visits a `while` loop: the condition first, then the body.
///
/// The body is not visited when the condition fails.
fn transform_while_stmt(
    &mut self,
    condition: &mut HirExpr,
    body: &mut [HirStmt],
) -> Result<()> {
    self.transform_expression(condition)
        .and_then(|()| self.transform_stmt_block(body))
}
/// Visits a `for` loop: the iterable expression first, then the body.
///
/// The body is not visited when the iterable fails.
fn transform_for_stmt(&mut self, iter: &mut HirExpr, body: &mut [HirStmt]) -> Result<()> {
    self.transform_expression(iter)
        .and_then(|()| self.transform_stmt_block(body))
}
/// Visits each statement of `stmts` in order, stopping at the first error.
fn transform_stmt_block(&mut self, stmts: &mut [HirStmt]) -> Result<()> {
    stmts
        .iter_mut()
        .try_for_each(|stmt| self.transform_statement(stmt))
}
/// Recursively visits `expr` and its sub-expressions.
///
/// Every arm only delegates to a helper that walks the children; no arm in
/// this function rewrites the expression itself. Expression kinds without
/// an arm (the `_` case) return `Ok(())` without being visited.
#[allow(clippy::only_used_in_recursion)]
fn transform_expression(&mut self, expr: &mut HirExpr) -> Result<()> {
    match expr {
        // Wrappers around a single inner expression.
        HirExpr::Unary { operand: inner, .. } | HirExpr::Borrow { expr: inner, .. } => {
            self.transform_expression(inner)
        }
        // Collection literals.
        HirExpr::List(items) => self.transform_list_expr(items),
        HirExpr::Tuple(items) => self.transform_tuple_expr(items),
        HirExpr::Dict(entries) => self.transform_dict_expr(entries),
        // Operators and subscripts.
        HirExpr::Binary {
            left: lhs,
            right: rhs,
            ..
        } => self.transform_binary_expr(lhs, rhs),
        HirExpr::Index { base, index } => self.transform_index_expr(base, index),
        HirExpr::Slice {
            base,
            start,
            stop,
            step,
        } => self.transform_slice_expr(base, start, stop, step),
        // Calls.
        HirExpr::Call { args, .. } => self.transform_call_args(args),
        HirExpr::MethodCall { object, args, .. } => self.transform_method_call(object, args),
        // Comprehensions.
        HirExpr::ListComp {
            element,
            generators,
        } => self.transform_list_comp(element, generators),
        _ => Ok(()),
    }
}
/// Visits every element of a list literal in order, stopping at the first error.
fn transform_list_expr(&mut self, elements: &mut [HirExpr]) -> Result<()> {
    elements
        .iter_mut()
        .try_for_each(|item| self.transform_expression(item))
}
/// Visits the left operand, then the right operand, of a binary expression.
///
/// The right operand is not visited when the left one fails.
fn transform_binary_expr(&mut self, left: &mut HirExpr, right: &mut HirExpr) -> Result<()> {
    self.transform_expression(left)
        .and_then(|()| self.transform_expression(right))
}
/// Visits every positional call argument in order, stopping at the first error.
fn transform_call_args(&mut self, args: &mut [HirExpr]) -> Result<()> {
    args.iter_mut()
        .try_for_each(|argument| self.transform_expression(argument))
}
/// Visits the receiver of a method call, then its positional arguments.
///
/// The arguments are not visited when the receiver fails.
fn transform_method_call(&mut self, object: &mut HirExpr, args: &mut [HirExpr]) -> Result<()> {
    let receiver_result = self.transform_expression(object);
    receiver_result.and_then(|()| self.transform_call_args(args))
}
/// Visits the indexed expression, then the index expression.
///
/// The index is not visited when the base fails.
fn transform_index_expr(&mut self, base: &mut HirExpr, index: &mut HirExpr) -> Result<()> {
    self.transform_expression(base)
        .and_then(|()| self.transform_expression(index))
}
/// Visits a slice expression: the base, then whichever of `start`, `stop`
/// and `step` are present, in that order.
///
/// # Errors
/// Stops at the first failing part and returns its error.
fn transform_slice_expr(
    &mut self,
    base: &mut HirExpr,
    start: &mut Option<Box<HirExpr>>,
    stop: &mut Option<Box<HirExpr>>,
    step: &mut Option<Box<HirExpr>>,
) -> Result<()> {
    self.transform_expression(base)?;
    // Flattening skips the bounds that are `None` while keeping the order.
    for bound in IntoIterator::into_iter([start, stop, step]).flatten() {
        self.transform_expression(bound)?;
    }
    Ok(())
}
/// Visits each key and then its value, pair by pair, stopping at the first error.
fn transform_dict_expr(&mut self, pairs: &mut [(HirExpr, HirExpr)]) -> Result<()> {
    pairs.iter_mut().try_for_each(|(key, value)| {
        self.transform_expression(key)?;
        self.transform_expression(value)
    })
}
/// Visits every element of a tuple in order, stopping at the first error.
fn transform_tuple_expr(&mut self, elements: &mut [HirExpr]) -> Result<()> {
    elements
        .iter_mut()
        .try_for_each(|item| self.transform_expression(item))
}
/// Visits a list comprehension: the element expression first, then for
/// each generator its iterable followed by its conditions.
///
/// # Errors
/// Stops at the first failing sub-expression and returns its error.
fn transform_list_comp(
    &mut self,
    element: &mut HirExpr,
    generators: &mut [crate::hir::HirComprehension],
) -> Result<()> {
    self.transform_expression(element)?;
    for generator in generators.iter_mut() {
        self.transform_expression(&mut generator.iter)?;
        generator
            .conditions
            .iter_mut()
            .try_for_each(|condition| self.transform_expression(condition))?;
    }
    Ok(())
}
/// Returns a shared reference to the set of const generic parameter names.
pub fn get_const_params(&self) -> &HashSet<String> {
&self.const_params
}
/// Decides whether a list type should become a fixed-size array type.
///
/// Currently a stub: the argument is ignored and `None` is always returned.
pub fn should_convert_to_array(&self, _list_type: &Type) -> Option<(Type, ConstGeneric)> {
// No conversion is ever proposed yet.
None }
}
/// `Default` yields the same empty inferencer as `ConstGenericInferencer::new`.
impl Default for ConstGenericInferencer {
/// Delegates to `Self::new()`.
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::hir::{
BinOp, FunctionProperties, HirComprehension, HirExpr, HirFunction, HirModule, HirParam,
HirStmt, UnaryOp,
};
use depyler_annotations::TranspilationAnnotations;
use smallvec::smallvec;
// A freshly constructed inferencer has no recorded sizes or const params.
#[test]
fn test_new() {
let inferencer = ConstGenericInferencer::new();
assert!(inferencer.const_values.is_empty());
assert!(inferencer.const_params.is_empty());
}
// `Default` produces the same empty state as `new`.
#[test]
fn test_default() {
let inferencer = ConstGenericInferencer::default();
assert!(inferencer.const_values.is_empty());
assert!(inferencer.const_params.is_empty());
}
// A three-element list literal is detected with size 3.
#[test]
fn test_detect_fixed_size_list() {
let inferencer = ConstGenericInferencer::new();
let expr = HirExpr::List(vec![
HirExpr::Literal(Literal::Int(1)),
HirExpr::Literal(Literal::Int(2)),
HirExpr::Literal(Literal::Int(3)),
]);
assert_eq!(inferencer.detect_fixed_size_pattern(&expr), Some(3));
}
// An empty list literal has no fixed size.
#[test]
fn test_detect_empty_list() {
let inferencer = ConstGenericInferencer::new();
let expr = HirExpr::List(vec![]);
assert_eq!(inferencer.detect_fixed_size_pattern(&expr), None);
}
// `[0] * 5` is detected with size 5.
#[test]
fn test_detect_multiply_pattern() {
let inferencer = ConstGenericInferencer::new();
let expr = HirExpr::Binary {
op: BinOp::Mul,
left: Box::new(HirExpr::List(vec![HirExpr::Literal(Literal::Int(0))])),
right: Box::new(HirExpr::Literal(Literal::Int(5))),
};
assert_eq!(inferencer.detect_fixed_size_pattern(&expr), Some(5));
}
// `5 * [0]` (operands swapped) is detected with size 5 as well.
#[test]
fn test_detect_multiply_pattern_reverse() {
let inferencer = ConstGenericInferencer::new();
let expr = HirExpr::Binary {
op: BinOp::Mul,
left: Box::new(HirExpr::Literal(Literal::Int(5))),
right: Box::new(HirExpr::List(vec![HirExpr::Literal(Literal::Int(0))])),
};
assert_eq!(inferencer.detect_fixed_size_pattern(&expr), Some(5));
}
// A two-element list times an integer is not recognised.
#[test]
fn test_detect_multiply_invalid() {
let inferencer = ConstGenericInferencer::new();
let expr = HirExpr::Binary {
op: BinOp::Mul,
left: Box::new(HirExpr::List(vec![
HirExpr::Literal(Literal::Int(0)),
HirExpr::Literal(Literal::Int(1)),
])),
right: Box::new(HirExpr::Literal(Literal::Int(5))),
};
assert_eq!(inferencer.detect_fixed_size_pattern(&expr), None);
}
// A repeat count of zero is rejected.
#[test]
fn test_detect_multiply_zero_size() {
let inferencer = ConstGenericInferencer::new();
let expr = HirExpr::Binary {
op: BinOp::Mul,
left: Box::new(HirExpr::List(vec![HirExpr::Literal(Literal::Int(0))])),
right: Box::new(HirExpr::Literal(Literal::Int(0))),
};
assert_eq!(inferencer.detect_fixed_size_pattern(&expr), None);
}
// `zeros(10)` is detected with size 10.
#[test]
fn test_detect_zeros_call() {
let inferencer = ConstGenericInferencer::new();
let expr = HirExpr::Call {
func: "zeros".to_string(),
args: vec![HirExpr::Literal(Literal::Int(10))],
kwargs: vec![],
};
assert_eq!(inferencer.detect_fixed_size_pattern(&expr), Some(10));
}
// `ones(8)` is detected with size 8.
#[test]
fn test_detect_ones_call() {
let inferencer = ConstGenericInferencer::new();
let expr = HirExpr::Call {
func: "ones".to_string(),
args: vec![HirExpr::Literal(Literal::Int(8))],
kwargs: vec![],
};
assert_eq!(inferencer.detect_fixed_size_pattern(&expr), Some(8));
}
// `full(15)` is detected with size 15.
#[test]
fn test_detect_full_call() {
let inferencer = ConstGenericInferencer::new();
let expr = HirExpr::Call {
func: "full".to_string(),
args: vec![HirExpr::Literal(Literal::Int(15))],
kwargs: vec![],
};
assert_eq!(inferencer.detect_fixed_size_pattern(&expr), Some(15));
}
// Size 1000 is rejected: the call detector only accepts sizes below 1000.
#[test]
fn test_detect_array_call_too_large() {
let inferencer = ConstGenericInferencer::new();
let expr = HirExpr::Call {
func: "zeros".to_string(),
args: vec![HirExpr::Literal(Literal::Int(1000))],
kwargs: vec![],
};
assert_eq!(inferencer.detect_fixed_size_pattern(&expr), None);
}
// Calls to functions other than zeros/ones/full are not recognised.
#[test]
fn test_detect_unknown_func_call() {
let inferencer = ConstGenericInferencer::new();
let expr = HirExpr::Call {
func: "unknown_func".to_string(),
args: vec![HirExpr::Literal(Literal::Int(10))],
kwargs: vec![],
};
assert_eq!(inferencer.detect_fixed_size_pattern(&expr), None);
}
// A plain variable reference has no fixed size.
#[test]
fn test_detect_non_matching_pattern() {
let inferencer = ConstGenericInferencer::new();
let expr = HirExpr::Var("x".to_string());
assert_eq!(inferencer.detect_fixed_size_pattern(&expr), None);
}
// `len(arr) == 5` implies size 5 for `arr`.
#[test]
fn test_len_equality_detection() {
let inferencer = ConstGenericInferencer::new();
let expr = HirExpr::Binary {
op: BinOp::Eq,
left: Box::new(HirExpr::Call {
func: "len".to_string(),
args: vec![HirExpr::Var("arr".to_string())],
kwargs: vec![],
}),
right: Box::new(HirExpr::Literal(Literal::Int(5))),
};
assert_eq!(inferencer.find_const_usage_in_expr("arr", &expr), Some(5));
}
// `5 == len(arr)` (operands swapped) implies size 5 as well.
#[test]
fn test_len_equality_detection_reverse() {
let inferencer = ConstGenericInferencer::new();
let expr = HirExpr::Binary {
op: BinOp::Eq,
left: Box::new(HirExpr::Literal(Literal::Int(5))),
right: Box::new(HirExpr::Call {
func: "len".to_string(),
args: vec![HirExpr::Var("arr".to_string())],
kwargs: vec![],
}),
};
assert_eq!(inferencer.find_const_usage_in_expr("arr", &expr), Some(5));
}
// `len(other) == 5` says nothing about `arr`.
#[test]
fn test_len_equality_wrong_param() {
let inferencer = ConstGenericInferencer::new();
let expr = HirExpr::Binary {
op: BinOp::Eq,
left: Box::new(HirExpr::Call {
func: "len".to_string(),
args: vec![HirExpr::Var("other".to_string())],
kwargs: vec![],
}),
right: Box::new(HirExpr::Literal(Literal::Int(5))),
};
assert_eq!(inferencer.find_const_usage_in_expr("arr", &expr), None);
}
// `arr[4]` implies a size of index + 1, i.e. 5.
#[test]
fn test_index_access_detection() {
let inferencer = ConstGenericInferencer::new();
let expr = HirExpr::Index {
base: Box::new(HirExpr::Var("arr".to_string())),
index: Box::new(HirExpr::Literal(Literal::Int(4))),
};
assert_eq!(inferencer.find_const_usage_in_expr("arr", &expr), Some(5));
}
// A negative literal index yields no size.
#[test]
fn test_index_access_negative() {
let inferencer = ConstGenericInferencer::new();
let expr = HirExpr::Index {
base: Box::new(HirExpr::Var("arr".to_string())),
index: Box::new(HirExpr::Literal(Literal::Int(-1))),
};
assert_eq!(inferencer.find_const_usage_in_expr("arr", &expr), None);
}
// Indexing a different variable says nothing about `arr`.
#[test]
fn test_index_access_wrong_param() {
let inferencer = ConstGenericInferencer::new();
let expr = HirExpr::Index {
base: Box::new(HirExpr::Var("other".to_string())),
index: Box::new(HirExpr::Literal(Literal::Int(4))),
};
assert_eq!(inferencer.find_const_usage_in_expr("arr", &expr), None);
}
// Analyzing a module with no functions succeeds.
#[test]
fn test_analyze_empty_module() {
let mut inferencer = ConstGenericInferencer::new();
let mut module = HirModule {
functions: vec![],
classes: vec![],
imports: vec![],
type_aliases: vec![],
protocols: vec![],
constants: vec![],
top_level_stmts: vec![],
};
assert!(inferencer.analyze_module(&mut module).is_ok());
}
// Analyzing a function whose body is a bare `return` succeeds.
#[test]
fn test_analyze_simple_function() {
let mut inferencer = ConstGenericInferencer::new();
let mut function = HirFunction {
name: "test_fn".to_string(),
params: smallvec![],
ret_type: Type::None,
body: vec![HirStmt::Return(None)],
properties: FunctionProperties::default(),
annotations: TranspilationAnnotations::default(),
docstring: None,
};
assert!(inferencer.analyze_function(&mut function).is_ok());
}
// Assigning a two-element list to `x` records size 2 for `x`.
#[test]
fn test_analyze_function_with_list_assign() {
let mut inferencer = ConstGenericInferencer::new();
let mut function = HirFunction {
name: "test_fn".to_string(),
params: smallvec![],
ret_type: Type::None,
body: vec![HirStmt::Assign {
target: AssignTarget::Symbol("x".to_string()),
value: HirExpr::List(vec![
HirExpr::Literal(Literal::Int(1)),
HirExpr::Literal(Literal::Int(2)),
]),
type_annotation: None,
}],
properties: FunctionProperties::default(),
annotations: TranspilationAnnotations::default(),
docstring: None,
};
assert!(inferencer.analyze_function(&mut function).is_ok());
assert!(inferencer.const_values.contains_key("x"));
assert_eq!(inferencer.const_values["x"], 2);
}
// Assignments in both the `then` and the `else` branch are recorded.
#[test]
fn test_scan_if_statement() {
let mut inferencer = ConstGenericInferencer::new();
let stmt = HirStmt::If {
condition: HirExpr::Literal(Literal::Bool(true)),
then_body: vec![HirStmt::Assign {
target: AssignTarget::Symbol("x".to_string()),
value: HirExpr::List(vec![HirExpr::Literal(Literal::Int(1))]),
type_annotation: None,
}],
else_body: Some(vec![HirStmt::Assign {
target: AssignTarget::Symbol("y".to_string()),
value: HirExpr::List(vec![
HirExpr::Literal(Literal::Int(1)),
HirExpr::Literal(Literal::Int(2)),
]),
type_annotation: None,
}]),
};
assert!(inferencer.scan_statement_for_consts(&stmt).is_ok());
assert!(inferencer.const_values.contains_key("x"));
assert!(inferencer.const_values.contains_key("y"));
}
// Assignments inside a `while` body are recorded.
#[test]
fn test_scan_while_statement() {
let mut inferencer = ConstGenericInferencer::new();
let stmt = HirStmt::While {
condition: HirExpr::Literal(Literal::Bool(true)),
body: vec![HirStmt::Assign {
target: AssignTarget::Symbol("x".to_string()),
value: HirExpr::List(vec![HirExpr::Literal(Literal::Int(1))]),
type_annotation: None,
}],
};
assert!(inferencer.scan_statement_for_consts(&stmt).is_ok());
assert!(inferencer.const_values.contains_key("x"));
}
// Assignments inside a `for` body are recorded.
#[test]
fn test_scan_for_statement() {
let mut inferencer = ConstGenericInferencer::new();
let stmt = HirStmt::For {
target: AssignTarget::Symbol("i".to_string()),
iter: HirExpr::Var("items".to_string()),
body: vec![HirStmt::Assign {
target: AssignTarget::Symbol("x".to_string()),
value: HirExpr::List(vec![HirExpr::Literal(Literal::Int(1))]),
type_annotation: None,
}],
};
assert!(inferencer.scan_statement_for_consts(&stmt).is_ok());
assert!(inferencer.const_values.contains_key("x"));
}
// A `return` statement is ignored by the scanner and scanning still succeeds.
#[test]
fn test_scan_return_statement() {
let mut inferencer = ConstGenericInferencer::new();
let stmt = HirStmt::Return(Some(HirExpr::Literal(Literal::Int(42))));
assert!(inferencer.scan_statement_for_consts(&stmt).is_ok());
}
// Smoke test: transforming an assignment statement succeeds.
#[test]
fn test_transform_assign_statement() {
let mut inferencer = ConstGenericInferencer::new();
let mut stmt = HirStmt::Assign {
target: AssignTarget::Symbol("x".to_string()),
value: HirExpr::List(vec![HirExpr::Literal(Literal::Int(1))]),
type_annotation: None,
};
assert!(inferencer.transform_statement(&mut stmt).is_ok());
}
// Smoke test: transforming a `return` with a value succeeds.
#[test]
fn test_transform_return_statement() {
let mut inferencer = ConstGenericInferencer::new();
let mut stmt = HirStmt::Return(Some(HirExpr::Literal(Literal::Int(42))));
assert!(inferencer.transform_statement(&mut stmt).is_ok());
}
// Smoke test: transforming an `if` with both branches succeeds.
#[test]
fn test_transform_if_statement() {
let mut inferencer = ConstGenericInferencer::new();
let mut stmt = HirStmt::If {
condition: HirExpr::Literal(Literal::Bool(true)),
then_body: vec![HirStmt::Return(Some(HirExpr::Literal(Literal::Int(1))))],
else_body: Some(vec![HirStmt::Return(Some(HirExpr::Literal(Literal::Int(
2,
))))]),
};
assert!(inferencer.transform_statement(&mut stmt).is_ok());
}
// Smoke test: transforming a `while` loop succeeds.
#[test]
fn test_transform_while_statement() {
let mut inferencer = ConstGenericInferencer::new();
let mut stmt = HirStmt::While {
condition: HirExpr::Literal(Literal::Bool(true)),
body: vec![HirStmt::Return(None)],
};
assert!(inferencer.transform_statement(&mut stmt).is_ok());
}
// Smoke test: transforming a `for` loop succeeds.
#[test]
fn test_transform_for_statement() {
let mut inferencer = ConstGenericInferencer::new();
let mut stmt = HirStmt::For {
target: AssignTarget::Symbol("i".to_string()),
iter: HirExpr::Var("items".to_string()),
body: vec![HirStmt::Return(None)],
};
assert!(inferencer.transform_statement(&mut stmt).is_ok());
}
// Smoke test: list literal.
#[test]
fn test_transform_list_expr() {
let mut inferencer = ConstGenericInferencer::new();
let mut expr = HirExpr::List(vec![
HirExpr::Literal(Literal::Int(1)),
HirExpr::Literal(Literal::Int(2)),
]);
assert!(inferencer.transform_expression(&mut expr).is_ok());
}
// Smoke test: binary expression.
#[test]
fn test_transform_binary_expr() {
let mut inferencer = ConstGenericInferencer::new();
let mut expr = HirExpr::Binary {
op: BinOp::Add,
left: Box::new(HirExpr::Literal(Literal::Int(1))),
right: Box::new(HirExpr::Literal(Literal::Int(2))),
};
assert!(inferencer.transform_expression(&mut expr).is_ok());
}
// Smoke test: unary expression.
#[test]
fn test_transform_unary_expr() {
let mut inferencer = ConstGenericInferencer::new();
let mut expr = HirExpr::Unary {
op: UnaryOp::Neg,
operand: Box::new(HirExpr::Literal(Literal::Int(1))),
};
assert!(inferencer.transform_expression(&mut expr).is_ok());
}
// Smoke test: function call.
#[test]
fn test_transform_call_expr() {
let mut inferencer = ConstGenericInferencer::new();
let mut expr = HirExpr::Call {
func: "print".to_string(),
args: vec![HirExpr::Literal(Literal::String("hello".to_string()))],
kwargs: vec![],
};
assert!(inferencer.transform_expression(&mut expr).is_ok());
}
// Smoke test: method call.
#[test]
fn test_transform_method_call_expr() {
let mut inferencer = ConstGenericInferencer::new();
let mut expr = HirExpr::MethodCall {
object: Box::new(HirExpr::Var("lst".to_string())),
method: "append".to_string(),
args: vec![HirExpr::Literal(Literal::Int(1))],
kwargs: vec![],
};
assert!(inferencer.transform_expression(&mut expr).is_ok());
}
// Smoke test: index expression.
#[test]
fn test_transform_index_expr() {
let mut inferencer = ConstGenericInferencer::new();
let mut expr = HirExpr::Index {
base: Box::new(HirExpr::Var("arr".to_string())),
index: Box::new(HirExpr::Literal(Literal::Int(0))),
};
assert!(inferencer.transform_expression(&mut expr).is_ok());
}
// Smoke test: slice with start and stop but no step.
#[test]
fn test_transform_slice_expr() {
let mut inferencer = ConstGenericInferencer::new();
let mut expr = HirExpr::Slice {
base: Box::new(HirExpr::Var("arr".to_string())),
start: Some(Box::new(HirExpr::Literal(Literal::Int(0)))),
stop: Some(Box::new(HirExpr::Literal(Literal::Int(5)))),
step: None,
};
assert!(inferencer.transform_expression(&mut expr).is_ok());
}
// Smoke test: dict literal.
#[test]
fn test_transform_dict_expr() {
let mut inferencer = ConstGenericInferencer::new();
let mut expr = HirExpr::Dict(vec![(
HirExpr::Literal(Literal::String("key".to_string())),
HirExpr::Literal(Literal::Int(1)),
)]);
assert!(inferencer.transform_expression(&mut expr).is_ok());
}
// Smoke test: tuple literal.
#[test]
fn test_transform_tuple_expr() {
let mut inferencer = ConstGenericInferencer::new();
let mut expr = HirExpr::Tuple(vec![
HirExpr::Literal(Literal::Int(1)),
HirExpr::Literal(Literal::String("a".to_string())),
]);
assert!(inferencer.transform_expression(&mut expr).is_ok());
}
// Smoke test: borrow expression.
#[test]
fn test_transform_borrow_expr() {
let mut inferencer = ConstGenericInferencer::new();
let mut expr = HirExpr::Borrow {
expr: Box::new(HirExpr::Var("x".to_string())),
mutable: false,
};
assert!(inferencer.transform_expression(&mut expr).is_ok());
}
// Smoke test: list comprehension with one generator and no conditions.
#[test]
fn test_transform_list_comp_expr() {
let mut inferencer = ConstGenericInferencer::new();
let mut expr = HirExpr::ListComp {
element: Box::new(HirExpr::Var("x".to_string())),
generators: vec![HirComprehension {
target: "x".to_string(),
iter: Box::new(HirExpr::Var("items".to_string())),
conditions: vec![],
}],
};
assert!(inferencer.transform_expression(&mut expr).is_ok());
}
// The const-params accessor returns an empty set for a new inferencer.
#[test]
fn test_get_const_params() {
let inferencer = ConstGenericInferencer::new();
let params = inferencer.get_const_params();
assert!(params.is_empty());
}
// The array-conversion stub returns `None` for a list-of-int type.
#[test]
fn test_should_convert_to_array() {
let inferencer = ConstGenericInferencer::new();
let list_type = Type::List(Box::new(Type::Int));
assert!(inferencer.should_convert_to_array(&list_type).is_none());
}
// An assignment whose value is `len(arr) == 10` implies size 10 for `arr`.
#[test]
fn test_find_const_usage_in_assign_stmt() {
let inferencer = ConstGenericInferencer::new();
let stmt = HirStmt::Assign {
target: AssignTarget::Symbol("x".to_string()),
value: HirExpr::Binary {
op: BinOp::Eq,
left: Box::new(HirExpr::Call {
func: "len".to_string(),
args: vec![HirExpr::Var("arr".to_string())],
kwargs: vec![],
}),
right: Box::new(HirExpr::Literal(Literal::Int(10))),
},
type_annotation: None,
};
assert_eq!(inferencer.find_const_usage_in_stmt("arr", &stmt), Some(10));
}
// `x = arr[9]` inside an `if` branch implies size 10 (index + 1) for `arr`.
#[test]
fn test_find_const_usage_in_if_stmt() {
let inferencer = ConstGenericInferencer::new();
let stmt = HirStmt::If {
condition: HirExpr::Literal(Literal::Bool(true)),
then_body: vec![HirStmt::Assign {
target: AssignTarget::Symbol("x".to_string()),
value: HirExpr::Index {
base: Box::new(HirExpr::Var("arr".to_string())),
index: Box::new(HirExpr::Literal(Literal::Int(9))),
},
type_annotation: None,
}],
else_body: None,
};
assert_eq!(inferencer.find_const_usage_in_stmt("arr", &stmt), Some(10));
}
// Expects the return type to become `Type::Array` after analysis.
// Ignored because `transform_function_types` does not rewrite types yet.
#[test]
#[ignore = "Incomplete feature: Const generic array inference not yet implemented"]
fn test_function_analysis() {
let mut inferencer = ConstGenericInferencer::new();
let mut function = HirFunction {
name: "process_array".to_string(),
params: smallvec![HirParam::new(
"arr".to_string(),
Type::List(Box::new(Type::Int))
)],
ret_type: Type::List(Box::new(Type::Int)),
body: vec![
HirStmt::Assign {
target: AssignTarget::Symbol("result".to_string()),
value: HirExpr::List(vec![
HirExpr::Literal(Literal::Int(0)),
HirExpr::Literal(Literal::Int(1)),
HirExpr::Literal(Literal::Int(2)),
]),
type_annotation: None,
},
HirStmt::Return(Some(HirExpr::Var("result".to_string()))),
],
properties: FunctionProperties::default(),
annotations: TranspilationAnnotations::default(),
docstring: None,
};
inferencer.analyze_function(&mut function).unwrap();
assert!(matches!(function.ret_type, Type::Array { .. }));
}
}