use lex_ast::canonicalize_program;
use lex_bytecode::{compile_program, vm::Vm, Value};
use lex_runtime::{DefaultHandler, Policy};
use lex_syntax::parse_source;
/// Runs the Lex program `src` end to end and returns the result of calling
/// `func` with `args`.
///
/// The pipeline is: parse, canonicalize, type-check, compile to bytecode, then
/// invoke `func` on a newly constructed `Vm` using a `DefaultHandler` with
/// `Policy::permissive()`.
///
/// # Panics
///
/// Panics if parsing fails, if type checking reports errors, or if the VM call
/// returns an error. The VM panic message names `func`, because many tests call
/// `run` several times on different entry points and need to know which one failed.
fn run(src: &str, func: &str, args: Vec<Value>) -> Value {
    let prog = parse_source(src).unwrap_or_else(|e| panic!("parse failed: {e:?}"));
    let stages = canonicalize_program(&prog);
    if let Err(errs) = lex_types::check_program(&stages) {
        panic!("type errors: {errs:#?}");
    }
    let bc = compile_program(&stages);
    let handler = DefaultHandler::new(Policy::permissive());
    let mut vm = Vm::with_handler(&bc, Box::new(handler));
    vm.call(func, args)
        .unwrap_or_else(|e| panic!("vm call to `{func}` failed: {e:?}"))
}
/// Builds the expected host-side representation of a Lex decimal: an interned
/// record with fields `coefficient` and `exponent`, in that order.
fn decimal(coef: i64, exp: i64) -> Value {
    use indexmap::IndexMap;
    use smol_str::SmolStr;
    let record: IndexMap<SmolStr, Value> = [("coefficient", coef), ("exponent", exp)]
        .into_iter()
        .map(|(name, n)| (SmolStr::from(name), Value::Int(n)))
        .collect();
    Value::record_interned(record)
}
/// Shorthand for an integer `Value`.
fn i(n: i64) -> Value {
    Value::Int(n)
}
/// Shorthand for a boolean `Value`.
fn b(flag: bool) -> Value {
    Value::Bool(flag)
}
/// Shorthand for a string `Value`, converting the borrowed text into the
/// VM's owned string representation.
fn s(text: &str) -> Value {
    Value::Str(text.into())
}
/// Lex header prepended to every test program; binds the `std.decimal`
/// module to the alias `d`.
const PRELUDE: &str = "import \"std.decimal\" as d\n";
fn src(body: &str) -> String { format!("{PRELUDE}{body}") }
/// `d.decimal(c, e)` yields the record `{ coefficient: c, exponent: e }` unchanged.
#[test]
fn decimal_constructor() {
    let program = src("fn t() -> { coefficient :: Int, exponent :: Int } { d.decimal(12345, -2) }");
    let expected = decimal(12345, -2);
    assert_eq!(run(&program, "t", Vec::new()), expected);
}
/// `d.zero()` and `d.one()` both use exponent 0.
#[test]
fn zero_and_one() {
    let program = src(r#"
fn z() -> { coefficient :: Int, exponent :: Int } { d.zero() }
fn o() -> { coefficient :: Int, exponent :: Int } { d.one() }
"#);
    for (entry, coef) in [("z", 0), ("o", 1)] {
        assert_eq!(run(&program, entry, Vec::new()), decimal(coef, 0));
    }
}
/// `d.from_int(n)` maps an integer to `{ coefficient: n, exponent: 0 }`,
/// checked for positive, negative and zero inputs.
#[test]
fn from_int() {
    let program = src("fn t(n :: Int) -> { coefficient :: Int, exponent :: Int } { d.from_int(n) }");
    for n in [42, -7, 0] {
        assert_eq!(run(&program, "t", vec![i(n)]), decimal(n, 0));
    }
}
/// `d.pow10(n)` returns 10^n, up to 10^18 (which still fits in an `i64`).
#[test]
fn pow10() {
    let program = src("fn t(n :: Int) -> Int { d.pow10(n) }");
    let cases: [(i64, i64); 4] = [(0, 1), (1, 10), (3, 1000), (18, 1_000_000_000_000_000_000)];
    for (exponent, want) in cases {
        assert_eq!(run(&program, "t", vec![i(exponent)]), i(want));
    }
}
/// 1.25 + 0.75 = 2.00; the shared exponent -2 is kept.
#[test]
fn add_same_exponent() {
    let program = src(r#"
fn t(a :: { coefficient :: Int, exponent :: Int },
b :: { coefficient :: Int, exponent :: Int })
-> { coefficient :: Int, exponent :: Int } { d.add(a, b) }
"#);
    let sum = run(&program, "t", vec![decimal(125, -2), decimal(75, -2)]);
    assert_eq!(sum, decimal(200, -2));
}
/// 1.50 + 0.005 = 1.505; the result takes the finer exponent (-3).
#[test]
fn add_different_exponents() {
    let program = src(r#"
fn t(a :: { coefficient :: Int, exponent :: Int },
b :: { coefficient :: Int, exponent :: Int })
-> { coefficient :: Int, exponent :: Int } { d.add(a, b) }
"#);
    let sum = run(&program, "t", vec![decimal(150, -2), decimal(5, -3)]);
    assert_eq!(sum, decimal(1505, -3));
}
/// 2.00 - 0.75 = 1.25 at exponent -2.
#[test]
fn sub_basic() {
    let program = src(r#"
fn t(a :: { coefficient :: Int, exponent :: Int },
b :: { coefficient :: Int, exponent :: Int })
-> { coefficient :: Int, exponent :: Int } { d.sub(a, b) }
"#);
    let difference = run(&program, "t", vec![decimal(200, -2), decimal(75, -2)]);
    assert_eq!(difference, decimal(125, -2));
}
/// 1.25 × 0.0005 = 0.000625: coefficients multiply (125 × 5) and
/// exponents add (-2 + -4).
#[test]
fn mul_basic() {
    let program = src(r#"
fn t(a :: { coefficient :: Int, exponent :: Int },
b :: { coefficient :: Int, exponent :: Int })
-> { coefficient :: Int, exponent :: Int } { d.mul(a, b) }
"#);
    let product = run(&program, "t", vec![decimal(125, -2), decimal(5, -4)]);
    assert_eq!(product, decimal(625, -6));
}
/// End-to-end check: notional 1250.00 × rate 0.0005 = 0.625, which rounds
/// HalfUp to two places as 0.63.
#[test]
fn fee_calculation() {
    let program = src(r#"
fn t() -> { coefficient :: Int, exponent :: Int } {
let notional := d.decimal(125000, -2)
let rate := d.decimal(5, -4)
let fee := d.mul(notional, rate)
d.round_to(fee, -2, "HalfUp")
}
"#);
    let fee = run(&program, "t", Vec::new());
    assert_eq!(fee, decimal(63, -2));
}
/// 1.00 and 1 compare equal (0) even though their representations differ.
#[test]
fn compare_equal() {
    let program = src(r#"
fn t(a :: { coefficient :: Int, exponent :: Int },
b :: { coefficient :: Int, exponent :: Int }) -> Int { d.compare(a, b) }
"#);
    let ordering = run(&program, "t", vec![decimal(100, -2), decimal(1, 0)]);
    assert_eq!(ordering, i(0));
}
/// Against 1: 0.99 compares as -1 and 1.01 compares as 1.
#[test]
fn compare_lt_gt() {
    let program = src(r#"
fn t(a :: { coefficient :: Int, exponent :: Int },
b :: { coefficient :: Int, exponent :: Int }) -> Int { d.compare(a, b) }
"#);
    for (hundredths, want) in [(99, -1), (101, 1)] {
        let ordering = run(&program, "t", vec![decimal(hundredths, -2), decimal(1, 0)]);
        assert_eq!(ordering, i(want));
    }
}
/// Sign predicates: `is_positive`, `is_negative` and `is_zero`. Zero with a
/// non-zero exponent (0e3) still counts as zero.
#[test]
fn predicates() {
    let program = src(r#"
fn pos(d1 :: { coefficient :: Int, exponent :: Int }) -> Bool { d.is_positive(d1) }
fn neg(d1 :: { coefficient :: Int, exponent :: Int }) -> Bool { d.is_negative(d1) }
fn zer(d1 :: { coefficient :: Int, exponent :: Int }) -> Bool { d.is_zero(d1) }
"#);
    let cases = [
        ("pos", decimal(5, -1), true),
        ("pos", decimal(-5, -1), false),
        ("neg", decimal(-5, -1), true),
        ("neg", decimal(5, -1), false),
        ("zer", decimal(0, 3), true),
        ("zer", decimal(1, 0), false),
    ];
    for (entry, input, want) in cases {
        assert_eq!(run(&program, entry, vec![input]), b(want));
    }
}
/// `negate` flips the sign in both directions. `abs` clears a negative sign
/// and leaves a positive value unchanged. The exponent is preserved throughout.
#[test]
fn negate_and_abs() {
    let program = src(r#"
fn neg(d1 :: { coefficient :: Int, exponent :: Int })
-> { coefficient :: Int, exponent :: Int } { d.negate(d1) }
fn ab(d1 :: { coefficient :: Int, exponent :: Int })
-> { coefficient :: Int, exponent :: Int } { d.abs(d1) }
"#);
    let cases = [
        ("neg", 125, -125),
        ("neg", -63, 63),
        ("ab", -63, 63),
        ("ab", 63, 63),
    ];
    for (entry, coef_in, coef_out) in cases {
        assert_eq!(run(&program, entry, vec![decimal(coef_in, -2)]), decimal(coef_out, -2));
    }
}
/// `normalize` strips trailing zeros from the coefficient (2.00 → 2 and
/// 1.500 → 1.5) and maps any zero to exponent 0.
#[test]
fn normalize_removes_trailing_zeros() {
    let program = src(r#"
fn t(d1 :: { coefficient :: Int, exponent :: Int })
-> { coefficient :: Int, exponent :: Int } { d.normalize(d1) }
"#);
    let cases = [
        ((200, -2), (2, 0)),
        ((1500, -3), (15, -1)),
        ((0, -5), (0, 0)),
    ];
    for ((c_in, e_in), (c_out, e_out)) in cases {
        assert_eq!(run(&program, "t", vec![decimal(c_in, e_in)]), decimal(c_out, e_out));
    }
}
/// HalfUp: the tie 0.625 rounds up to 0.63.
#[test]
fn round_half_up() {
    let program = src(r#"
fn t(d1 :: { coefficient :: Int, exponent :: Int }) -> { coefficient :: Int, exponent :: Int } {
d.round_to(d1, -2, "HalfUp")
}
"#);
    let rounded = run(&program, "t", vec![decimal(625, -3)]);
    assert_eq!(rounded, decimal(63, -2));
}
/// HalfDown: the tie 0.625 rounds down to 0.62.
#[test]
fn round_half_down() {
    let program = src(r#"
fn t(d1 :: { coefficient :: Int, exponent :: Int }) -> { coefficient :: Int, exponent :: Int } {
d.round_to(d1, -2, "HalfDown")
}
"#);
    let rounded = run(&program, "t", vec![decimal(625, -3)]);
    assert_eq!(rounded, decimal(62, -2));
}
/// HalfEven (banker's rounding): ties go to the even neighbour, so
/// 0.625 → 0.62 and 0.635 → 0.64.
#[test]
fn round_half_even() {
    let program = src(r#"
fn t(d1 :: { coefficient :: Int, exponent :: Int }) -> { coefficient :: Int, exponent :: Int } {
d.round_to(d1, -2, "HalfEven")
}
"#);
    for (thousandths, hundredths) in [(625, 62), (635, 64)] {
        let rounded = run(&program, "t", vec![decimal(thousandths, -3)]);
        assert_eq!(rounded, decimal(hundredths, -2));
    }
}
/// Floor rounds toward negative infinity and Ceiling toward positive infinity,
/// checked on ±0.627.
#[test]
fn round_floor_ceiling() {
    let program = src(r#"
fn floor_fn(d1 :: { coefficient :: Int, exponent :: Int }) -> { coefficient :: Int, exponent :: Int } {
d.round_to(d1, -2, "Floor")
}
fn ceil_fn(d1 :: { coefficient :: Int, exponent :: Int }) -> { coefficient :: Int, exponent :: Int } {
d.round_to(d1, -2, "Ceiling")
}
"#);
    let cases = [
        ("floor_fn", 627, 62),
        ("ceil_fn", 627, 63),
        ("floor_fn", -627, -63),
        ("ceil_fn", -627, -62),
    ];
    for (entry, thousandths, hundredths) in cases {
        let rounded = run(&program, entry, vec![decimal(thousandths, -3)]);
        assert_eq!(rounded, decimal(hundredths, -2));
    }
}
/// Rounding a coarser value to a finer exponent only rescales it:
/// 0.6 becomes 0.60.
#[test]
fn round_exact_no_rounding() {
    let program = src(r#"
fn t(d1 :: { coefficient :: Int, exponent :: Int }) -> { coefficient :: Int, exponent :: Int } {
d.round_to(d1, -2, "HalfUp")
}
"#);
    let rescaled = run(&program, "t", vec![decimal(6, -1)]);
    assert_eq!(rescaled, decimal(60, -2));
}
/// Formatting with a negative exponent: shows the fractional digits, a
/// leading "0" for values below one, the sign, and trailing zeros on zero.
#[test]
fn to_str_fractional() {
    let program = src(r#"
fn t(d1 :: { coefficient :: Int, exponent :: Int }) -> Str { d.to_str(d1) }
"#);
    let cases = [
        (12345, "123.45"),
        (63, "0.63"),
        (-63, "-0.63"),
        (0, "0.00"),
    ];
    for (coef, text) in cases {
        assert_eq!(run(&program, "t", vec![decimal(coef, -2)]), s(text));
    }
}
/// Formatting with a non-negative exponent prints a plain integer; a positive
/// exponent appends zeros (7e2 → "700").
#[test]
fn to_str_integer() {
    let program = src(r#"
fn t(d1 :: { coefficient :: Int, exponent :: Int }) -> Str { d.to_str(d1) }
"#);
    let cases = [((42, 0), "42"), ((7, 2), "700"), ((-5, 0), "-5")];
    for ((coef, exp), text) in cases {
        assert_eq!(run(&program, "t", vec![decimal(coef, exp)]), s(text));
    }
}