use decy_codegen::CodeGenerator;
use decy_hir::{BinaryOperator, HirExpression, HirFunction, HirParameter, HirStatement, HirType};
/// Builds a `HirFunction` that carries a body, for use in the tests below.
///
/// Parameters are taken before the return type here, while
/// `HirFunction::new_with_body` expects the return type first; this helper
/// only reorders the arguments and converts `name` into an owned `String`.
fn create_function(
    name: &str,
    params: Vec<HirParameter>,
    return_type: HirType,
    body: Vec<HirStatement>,
) -> HirFunction {
    let owned_name = String::from(name);
    HirFunction::new_with_body(owned_name, return_type, params, body)
}
/// Checks that `return n * factorial(n - 1);` is emitted with a `*` between
/// `n` and the recursive call, and never as `n - factorial`.
#[test]
fn test_recursive_multiply_correct() {
    // n - 1
    let decremented = HirExpression::BinaryOp {
        op: BinaryOperator::Subtract,
        left: Box::new(HirExpression::Variable("n".to_string())),
        right: Box::new(HirExpression::IntLiteral(1)),
    };
    // factorial(n - 1)
    let recursive_call = HirExpression::FunctionCall {
        function: "factorial".to_string(),
        arguments: vec![decremented],
    };
    // n * factorial(n - 1)
    let product = HirExpression::BinaryOp {
        op: BinaryOperator::Multiply,
        left: Box::new(HirExpression::Variable("n".to_string())),
        right: Box::new(recursive_call),
    };

    let func = create_function(
        "factorial",
        vec![HirParameter::new("n".to_string(), HirType::Int)],
        HirType::Int,
        vec![HirStatement::Return(Some(product))],
    );

    let codegen = CodeGenerator::new();
    let code = codegen.generate_function(&func);

    assert!(
        code.contains("n * factorial"),
        "Should have n * factorial(...):\n{}",
        code
    );
    // The subtraction belongs only inside the call argument, never between
    // `n` and the call.
    assert!(
        !code.contains("n - factorial"),
        "Should NOT have n - factorial (bug):\n{}",
        code
    );
}
/// Checks that `return a * b;` over two `int` parameters is emitted as `a * b`.
#[test]
fn test_simple_multiply_correct() {
    // Small local builders to keep the HIR construction readable.
    let int_param = |name: &str| HirParameter::new(name.to_string(), HirType::Int);
    let var = |name: &str| Box::new(HirExpression::Variable(name.to_string()));

    let product = HirExpression::BinaryOp {
        op: BinaryOperator::Multiply,
        left: var("a"),
        right: var("b"),
    };

    let func = create_function(
        "mul",
        vec![int_param("a"), int_param("b")],
        HirType::Int,
        vec![HirStatement::Return(Some(product))],
    );

    let codegen = CodeGenerator::new();
    let code = codegen.generate_function(&func);

    assert!(code.contains("a * b"), "Should generate a * b:\n{}", code);
}
/// Checks that `return (a * b) * c;` keeps both multiplications in the
/// generated code.
///
/// The assertion counts ` * ` (the operator with surrounding spaces, which is
/// the spacing the other tests in this file assert on, e.g. `a * b`) rather
/// than any `*` character. A `*` used as a pointer sigil or in a comment
/// therefore cannot satisfy the check by accident.
#[test]
fn test_nested_multiply_correct() {
    let func = create_function(
        "triple",
        vec![
            HirParameter::new("a".to_string(), HirType::Int),
            HirParameter::new("b".to_string(), HirType::Int),
            HirParameter::new("c".to_string(), HirType::Int),
        ],
        HirType::Int,
        vec![HirStatement::Return(Some(HirExpression::BinaryOp {
            op: BinaryOperator::Multiply,
            left: Box::new(HirExpression::BinaryOp {
                op: BinaryOperator::Multiply,
                left: Box::new(HirExpression::Variable("a".to_string())),
                right: Box::new(HirExpression::Variable("b".to_string())),
            }),
            right: Box::new(HirExpression::Variable("c".to_string())),
        }))],
    );
    let generator = CodeGenerator::new();
    let code = generator.generate_function(&func);

    // `str::matches` yields non-overlapping occurrences; `a * b * c` and
    // `(a * b) * c` both contain exactly two.
    let multiply_ops = code.matches(" * ").count();
    assert!(
        multiply_ops >= 2,
        "Should have two multiply operators:\n{}",
        code
    );
}
/// Checks that multiplying a parameter by the result of a zero-argument call
/// is emitted as `x * get_value(...)`.
#[test]
fn test_multiply_func_result() {
    // get_value()
    let call = HirExpression::FunctionCall {
        function: "get_value".to_string(),
        arguments: Vec::new(),
    };
    // x * get_value()
    let product = HirExpression::BinaryOp {
        op: BinaryOperator::Multiply,
        left: Box::new(HirExpression::Variable("x".to_string())),
        right: Box::new(call),
    };

    let func = create_function(
        "compute",
        vec![HirParameter::new("x".to_string(), HirType::Int)],
        HirType::Int,
        vec![HirStatement::Return(Some(product))],
    );

    let codegen = CodeGenerator::new();
    let code = codegen.generate_function(&func);

    assert!(
        code.contains("x * get_value"),
        "Should have x * get_value():\n{}",
        code
    );
}