use tl_compiler::{Vm, VmValue, compile, compile_with_source};
use tl_parser::parse;
/// Parses, compiles and executes `src` on a fresh [`Vm`].
///
/// # Errors
///
/// Returns a `String` prefixed with `Parse error:`, `Compile error:` or
/// `Runtime error:` depending on which stage failed.
fn run(src: &str) -> Result<VmValue, String> {
    let ast = match parse(src) {
        Ok(ast) => ast,
        Err(err) => return Err(format!("Parse error: {err}")),
    };
    let chunk = match compile(&ast) {
        Ok(chunk) => chunk,
        Err(err) => return Err(format!("Compile error: {err}")),
    };
    let mut machine = Vm::new();
    machine
        .execute(&chunk)
        .map_err(|err| format!("Runtime error: {err}"))
}
/// Parses and compiles `src` (passing the source text to
/// `compile_with_source`) and returns the prototype's disassembly listing.
///
/// # Panics
///
/// Panics if `src` fails to parse or compile. The message names the failing
/// stage, so a broken test fixture is easy to tell apart from a failed
/// assertion on the listing.
fn disasm(src: &str) -> String {
    let program = parse(src).expect("disasm: test source should parse");
    let proto =
        compile_with_source(&program, src).expect("disasm: test source should compile");
    proto.disassemble()
}
/// Asserts that `val` is `VmValue::Int` holding exactly `expected`.
///
/// # Panics
///
/// Panics with a descriptive message when `val` is a different variant or
/// holds a different integer.
fn assert_int(val: &VmValue, expected: i64) {
    match *val {
        VmValue::Int(n) => assert_eq!(n, expected, "Expected Int({expected}), got Int({n})"),
        _ => panic!("Expected Int({expected}), got {val}"),
    }
}
/// Asserts that `val` is `VmValue::Bool` holding exactly `expected`.
///
/// # Panics
///
/// Panics with a descriptive message when `val` is a different variant or
/// holds the opposite boolean.
fn assert_bool(val: &VmValue, expected: bool) {
    match *val {
        VmValue::Bool(b) => assert_eq!(b, expected, "Expected Bool({expected}), got Bool({b})"),
        _ => panic!("Expected Bool({expected}), got {val}"),
    }
}
/// `2 + 3 * 4` should compile to the literal 14 with no `Mul`/`Add`
/// instructions left in the listing, and still evaluate to 14.
#[test]
fn test_constant_folding_arithmetic() {
    let listing = disasm("let x = 2 + 3 * 4\nprint(x)");
    let has_folded_constant = listing.contains("14");
    let has_mul = listing.contains("Mul");
    let has_add = listing.contains(" Add");

    assert!(has_folded_constant, "Should fold 2 + 3 * 4 to 14");
    assert!(!has_mul, "Should not have Mul instruction");
    assert!(!has_add, "Should not have Add instruction");

    assert_int(&run("2 + 3 * 4").unwrap(), 14);
}
/// Concatenating two string literals should be folded into a single
/// `"hello world"` constant, and evaluating the same expression must yield
/// that string at runtime.
#[test]
fn test_constant_folding_string_concat() {
    let bytecode = disasm("let x = \"hello\" + \" world\"\nprint(x)");
    assert!(
        bytecode.contains("hello world"),
        "Should fold string concatenation"
    );

    // Like the arithmetic/nested folding tests, also check that folding did not
    // change the observable result.
    let val = run("\"hello\" + \" world\"").unwrap();
    assert!(
        matches!(val, VmValue::String(ref s) if s.as_ref() == "hello world"),
        "Expected String(\"hello world\"), got {val}"
    );
}
/// Boolean operators applied to literal operands evaluate to the expected
/// value (`not true` and `true and false` are both `false`).
#[test]
fn test_constant_folding_boolean() {
    for src in ["not true", "true and false"] {
        assert_bool(&run(src).unwrap(), false);
    }
}
/// An addition that reads a variable must keep its runtime `Add`
/// instruction instead of being folded away.
#[test]
fn test_constant_folding_does_not_fold_variables() {
    let listing = disasm("let x = 5\nlet y = x + 1\nprint(y)");
    let kept_add = listing.contains("Add");
    assert!(kept_add, "Should have Add when variable is involved");
}
/// Nested constant sub-expressions `(1 + 2) * (3 + 4)` evaluate to 21 and
/// the listing contains the folded constant 21.
#[test]
fn test_constant_folding_nested() {
    assert_int(&run("(1 + 2) * (3 + 4)").unwrap(), 21);

    let listing = disasm("let x = (1 + 2) * (3 + 4)\nprint(x)");
    let has_folded_constant = listing.contains("21");
    assert!(has_folded_constant, "Should fold nested constants to 21");
}
/// Comparisons between integer literals produce the expected booleans.
#[test]
fn test_constant_folding_comparison() {
    let cases = [("10 > 5", true), ("3 == 3", true)];
    for (src, expected) in cases {
        assert_bool(&run(src).unwrap(), expected);
    }
}
/// Statements after an unconditional `return` must not be compiled into the
/// function's bytecode, and the function must still return its value.
#[test]
fn test_dce_after_return() {
    let bytecode = disasm(
        r#"
fn foo() {
return 1
print("unreachable")
}
print(foo())
"#,
    );
    // Restrict the search to foo's listing when its header is present. If the
    // header is missing, fall back to the whole listing rather than "": an
    // empty section would make the assertion below pass vacuously.
    let foo_section = bytecode
        .split("=== foo ===")
        .nth(1)
        .unwrap_or(bytecode.as_str());
    assert!(
        !foo_section.contains("unreachable"),
        "Should not compile code after return"
    );
    let val = run(r#"
fn foo() {
return 1
print("unreachable")
}
foo()
"#)
    .unwrap();
    assert_int(&val, 1);
}
/// A `break` inside a loop stops iteration, so the counter only reaches 5.
#[test]
fn test_dce_after_break() {
    let program = r#"
let mut result = 0
for i in range(10) {
if i == 5 {
break
}
result = result + 1
}
result
"#;
    let outcome = run(program).unwrap();
    assert_int(&outcome, 5);
}
/// A function whose `if`/`else` branches both `return` produces the value
/// of the branch taken (`"positive"` for 42).
#[test]
fn test_dce_if_both_branches_return() {
    let val = run(r#"
fn classify(x: int) -> string {
if x > 0 {
return "positive"
} else {
return "non-positive"
}
}
classify(42)
"#)
    .unwrap();
    // Report the actual value on failure, like `assert_int`/`assert_bool` do.
    assert!(
        matches!(val, VmValue::String(ref s) if s.as_ref() == "positive"),
        "Expected String(\"positive\"), got {val}"
    );
}
/// Struct field initialisers and a top-level binding built from constant
/// arithmetic combine to (1 + 2) + (3 + 4) + (10 + 20 + 30) = 70.
#[test]
fn test_complex_struct_with_constant_folding() {
    let program = r#"
struct Point { x: int, y: int }
let offset = 10 + 20 + 30
let p = Point { x: 1 + 2, y: 3 + 4 }
p.x + p.y + offset
"#;
    assert_int(&run(program).unwrap(), 70);
}
/// Operator precedence: multiplication binds tighter than addition.
#[test]
fn test_backward_compat_arithmetic() {
    assert_int(&run("1 + 2 * 3").unwrap(), 7);
}
/// `len` on a string bound to a variable returns its length.
#[test]
fn test_backward_compat_strings() {
    let program = "let s = \"hello\"\nlen(s)";
    assert_int(&run(program).unwrap(), 5);
}
/// `len` on a list literal bound to a variable returns its element count.
#[test]
fn test_backward_compat_lists() {
    let program = "let xs = [1, 2, 3]\nlen(xs)";
    assert_int(&run(program).unwrap(), 3);
}
/// A user-defined function with an expression body returns that expression.
#[test]
fn test_backward_compat_functions() {
    let program = "fn add(a, b) { a + b }\nadd(1, 2)";
    assert_int(&run(program).unwrap(), 3);
}
/// `if`/`else` used as an expression yields the taken branch's value.
#[test]
fn test_backward_compat_if_else() {
    assert_int(&run("if true { 1 } else { 2 }").unwrap(), 1);
}
/// A `for` loop over `range(5)` accumulates 0 + 1 + 2 + 3 + 4 = 10.
#[test]
fn test_backward_compat_for_loop() {
    let program = "let mut sum = 0\nfor i in range(5) { sum = sum + i }\nsum";
    assert_int(&run(program).unwrap(), 10);
}
/// A literal `match` arm takes precedence over the wildcard arm.
#[test]
fn test_backward_compat_match() {
    assert_bool(&run("match 42 { 42 => true, _ => false }").unwrap(), true);
}
/// An arrow-function closure stored in a variable can be called.
#[test]
fn test_backward_compat_closures() {
    let program = "let f = (x) => x * 2\nf(21)";
    assert_int(&run(program).unwrap(), 42);
}
/// A `while` loop runs until its condition becomes false.
#[test]
fn test_backward_compat_while_loop() {
    let program = r#"
let mut i = 0
while i < 10 {
i = i + 1
}
i
"#;
    let outcome = run(program).unwrap();
    assert_int(&outcome, 10);
}
/// A `throw` inside `try` skips the rest of the block and runs `catch`.
#[test]
fn test_backward_compat_try_catch() {
    let program = r#"
let mut result = 0
try {
throw "error"
result = 999
} catch e {
result = 42
}
result
"#;
    let outcome = run(program).unwrap();
    assert_int(&outcome, 42);
}