use lex_ast::canonicalize_program;
use lex_bytecode::vm::Vm;
use lex_bytecode::Value;
use lex_runtime::{check_program, DefaultHandler, Policy};
use lex_syntax::parse_source;
/// Parses, canonicalizes and compiles the Lex program in `src`, then
/// type-checks the resulting bytecode under [`Policy::pure`].
///
/// # Panics
/// Panics if `src` fails to parse or if `check_program` rejects the compiled
/// program under the pure policy.
fn build(src: &str) -> lex_bytecode::Program {
    let ast = parse_source(src).unwrap();
    let program = lex_bytecode::compile_program(&canonicalize_program(&ast));
    check_program(&program, &Policy::pure()).expect("program type-checks under pure policy");
    program
}
/// Calls `entry` in `bc` with `args` on a fresh [`Vm`] driven by a
/// pure-policy [`DefaultHandler`], and returns the call's result.
///
/// # Panics
/// Panics if the call returns an error. The message names `entry` and
/// includes the error, so a failure points at the Lex function that failed.
fn run(bc: &lex_bytecode::Program, entry: &str, args: Vec<Value>) -> Value {
    let handler = DefaultHandler::new(Policy::pure());
    let mut vm = Vm::with_handler(bc, Box::new(handler));
    vm.call(entry, args)
        .unwrap_or_else(|err| panic!("calling `{entry}` failed: {err:?}"))
}
/// `par_map` hands back exactly one result per input, in input order.
#[test]
fn par_map_returns_results_in_input_order() {
    let src = r#"
import "std.list" as list
fn doubled(xs :: List[Int]) -> List[Int] {
list.par_map(xs, fn(x :: Int) -> Int { x + x })
}
"#;
    let program = build(src);
    let inputs = (0..8).map(Value::Int).collect::<Vec<Value>>();
    let expected = (0..8i64).map(|i| Value::Int(2 * i)).collect::<Vec<Value>>();
    let actual = run(&program, "doubled", vec![Value::List(inputs)]);
    assert_eq!(actual, Value::List(expected));
}
/// Mapping over an empty list yields an empty list instead of an error.
#[test]
fn par_map_on_empty_list_yields_empty_list() {
    let program = build(
        r#"
import "std.list" as list
fn run_(xs :: List[Int]) -> List[Int] {
list.par_map(xs, fn(x :: Int) -> Int { x })
}
"#,
    );
    let empty = || Value::List(Vec::new());
    assert_eq!(run(&program, "run_", vec![empty()]), empty());
}
/// Lex program used by the `par_map` timing test.
///
/// `spin` folds over its list adding one per element, so it counts the
/// elements and its cost grows with the list's length. `par_spins` applies
/// `spin` to every inner list ("bucket") of its argument through
/// `list.par_map`.
const SPIN_SRC: &str = r#"
import "std.list" as list
fn spin(xs :: List[Int]) -> Int {
list.fold(xs, 0, fn(acc :: Int, x :: Int) -> Int { acc + 1 })
}
fn par_spins(buckets :: List[List[Int]]) -> List[Int] {
list.par_map(buckets, fn(b :: List[Int]) -> Int { spin(b) })
}
"#;
fn measure_par_spin(n_workers: usize, items_per_bucket: usize) -> std::time::Duration {
let bc = build(SPIN_SRC);
let bucket: Vec<Value> = (0..items_per_bucket as i64).map(Value::Int).collect();
let buckets: Vec<Value> = (0..n_workers).map(|_| Value::List(bucket.clone())).collect();
let t0 = std::time::Instant::now();
let _ = run(&bc, "par_spins", vec![Value::List(buckets)]);
t0.elapsed()
}
/// Environment variable the tests below set to cap `par_map` worker concurrency.
const PAR_CAP_VAR: &str = "LEX_PAR_MAX_CONCURRENCY";

/// Serializes every test in this file that mutates [`PAR_CAP_VAR`].
///
/// The environment is process-global, but the default test harness runs
/// tests on several threads. Without this lock, one test's `set_var` can
/// overwrite another's mid-run, and a "cap of one" test may silently execute
/// under a different cap.
static PAR_CAP_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

/// Applies `value` to [`PAR_CAP_VAR`]: `Some` sets it, `None` removes it.
fn set_par_cap(value: Option<&str>) {
    match value {
        Some(v) => std::env::set_var(PAR_CAP_VAR, v),
        None => std::env::remove_var(PAR_CAP_VAR),
    }
}

/// RAII guard over [`PAR_CAP_VAR`].
///
/// Holds [`PAR_CAP_LOCK`] for its whole lifetime. On drop, it restores
/// whatever value the variable had before [`ParCapGuard::acquire`], including
/// when the owning test panics partway through.
struct ParCapGuard {
    previous: Option<std::ffi::OsString>,
    // Fields are dropped after `Drop::drop` runs, so the lock is released
    // only once the variable has been restored.
    _lock: std::sync::MutexGuard<'static, ()>,
}

impl ParCapGuard {
    /// Takes the lock, remembers the current value, then applies `value`.
    fn acquire(value: Option<&str>) -> Self {
        // A test that panics while holding the guard poisons the lock. The only
        // state the lock protects is the env var, which the guard's `Drop`
        // restored during unwinding, so the poison carries no information.
        let lock = PAR_CAP_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let previous = std::env::var_os(PAR_CAP_VAR);
        set_par_cap(value);
        ParCapGuard {
            previous,
            _lock: lock,
        }
    }

    /// Changes the cap while continuing to hold the lock.
    fn set(&self, value: Option<&str>) {
        set_par_cap(value);
    }
}

impl Drop for ParCapGuard {
    fn drop(&mut self) {
        match self.previous.take() {
            Some(v) => std::env::set_var(PAR_CAP_VAR, v),
            None => std::env::remove_var(PAR_CAP_VAR),
        }
    }
}

/// Timing check, ignored by default and skipped on single-core hosts.
///
/// `par_map` over `n_tasks` equal buckets must finish in under 70% of
/// `n_tasks` times the single-bucket time. With concurrency capped at 1, the
/// same run must take more than 1.4x the uncapped parallel time.
#[test]
#[ignore]
fn par_map_speedup_and_concurrency_cap() {
    let cores = std::thread::available_parallelism()
        .map(|n| n.get())
        .unwrap_or(1);
    if cores < 2 {
        eprintln!("skipping: single-core host can't demonstrate parallelism");
        return;
    }
    const ITEMS_PER_BUCKET: usize = 8_000;
    let n_tasks = cores.min(4);

    // Baseline and parallel runs are uncapped; the third run is capped at 1.
    let cap = ParCapGuard::acquire(None);
    let one = measure_par_spin(1, ITEMS_PER_BUCKET);
    let parallel = measure_par_spin(n_tasks, ITEMS_PER_BUCKET);
    cap.set(Some("1"));
    let capped = measure_par_spin(n_tasks, ITEMS_PER_BUCKET);
    drop(cap);

    // n_tasks <= 4, so the cast cannot truncate.
    let serial_equiv = one * (n_tasks as u32);
    let ceiling = serial_equiv.mul_f64(0.70);
    assert!(
        parallel < ceiling,
        "par_map should beat 70% of serial wall-clock: one={one:?}, \
parallel({n_tasks} tasks)={parallel:?}, ceiling={ceiling:?}"
    );
    assert!(
        capped > parallel.mul_f64(1.4),
        "cap=1 must dominate parallel run: parallel={parallel:?}, capped={capped:?}"
    );
}

/// Results stay correct and in order when concurrency is capped at one worker.
#[test]
fn par_map_results_are_correct_under_concurrency_cap_of_one() {
    // Held until the end of the test so the cap stays at 1 for the whole call.
    let _cap = ParCapGuard::acquire(Some("1"));
    let src = r#"
import "std.list" as list
fn squared(xs :: List[Int]) -> List[Int] {
list.par_map(xs, fn(x :: Int) -> Int { x * x })
}
"#;
    let bc = build(src);
    let xs: Vec<Value> = (0..16).map(Value::Int).collect();
    let r = run(&bc, "squared", vec![Value::List(xs)]);
    let expected: Vec<Value> = (0..16).map(|i: i64| Value::Int(i * i)).collect();
    assert_eq!(r, Value::List(expected));
}

/// With more inputs (32) than the cap (4), every input is still mapped, in order.
#[test]
fn par_map_distributes_when_n_exceeds_cap() {
    // Held until the end of the test so the cap stays at 4 for the whole call.
    let _cap = ParCapGuard::acquire(Some("4"));
    let src = r#"
import "std.list" as list
fn run_(xs :: List[Int]) -> List[Int] {
list.par_map(xs, fn(x :: Int) -> Int { x + 1000 })
}
"#;
    let bc = build(src);
    let xs: Vec<Value> = (0..32).map(Value::Int).collect();
    let r = run(&bc, "run_", vec![Value::List(xs)]);
    let expected: Vec<Value> = (0..32).map(|i: i64| Value::Int(i + 1000)).collect();
    assert_eq!(r, Value::List(expected));
}
/// A closure carrying the `[io]` effect can be passed to `list.par_map`. It
/// runs to completion under a `DefaultHandler` whose policy allows `io`,
/// producing one `Unit` per input string.
#[test]
fn par_map_effectful_closure_works_with_default_handler() {
    use std::sync::{Arc, Mutex};

    /// Test `IoSink` that records each printed line in a shared buffer.
    struct SharedSink(Arc<Mutex<Vec<String>>>);

    impl lex_runtime::IoSink for SharedSink {
        fn print_line(&mut self, s: &str) {
            let mut lines = self.0.lock().unwrap();
            lines.push(s.to_owned());
        }
    }

    let src = r#"
import "std.list" as list
import "std.io" as io
fn echo_par(xs :: List[Str]) -> [io] List[Unit] {
list.par_map(xs, fn(s :: Str) -> [io] Unit { io.print(s) })
}
"#;
    let ast = parse_source(src).unwrap();
    let canonical = canonicalize_program(&ast);
    let program = lex_bytecode::compile_program(&canonical);

    let mut policy = Policy::pure();
    policy.allow_effects.insert("io".into());
    check_program(&program, &policy).expect("type-checks under io policy");

    // NOTE(review): `captured` is never inspected below, so this test does not
    // verify where the prints end up. TODO: confirm whether par_map workers
    // write through this sink and, if so, assert on its contents.
    let captured = Arc::new(Mutex::new(Vec::<String>::new()));
    let sink = SharedSink(Arc::clone(&captured));
    let handler = DefaultHandler::new(policy).with_sink(Box::new(sink));
    let mut vm = Vm::with_handler(&program, Box::new(handler));

    let inputs: Vec<Value> = ["a", "b", "c"]
        .iter()
        .map(|&s| Value::Str(s.into()))
        .collect();
    let result = vm
        .call("echo_par", vec![Value::List(inputs)])
        .expect("effectful par_map runs under DefaultHandler");
    assert_eq!(
        result,
        Value::List(vec![Value::Unit; 3]),
        "result list shape: one Unit per input"
    );
}
/// `par_map` workers draw on one shared budget pool.
///
/// Each `step()` call costs `budget(10)`, so four of them (40) overrun a
/// ceiling of 25 even though any single call fits. The call must therefore
/// fail with a budget error.
#[test]
fn par_map_workers_share_budget_pool() {
    let src = r#"
import "std.list" as list
fn step() -> [budget(10)] Int { 1 }
fn par_steps(xs :: List[Int]) -> List[Int] {
list.par_map(xs, fn(x :: Int) -> Int { step() })
}
"#;
    let ast = parse_source(src).unwrap();
    let canonical = canonicalize_program(&ast);
    let program = lex_bytecode::compile_program(&canonical);

    let mut policy = Policy::pure();
    policy.allow_effects.insert("budget".into());
    check_program(&program, &policy).expect("pure-with-budget policy accepts the program");
    // The 25-unit ceiling is installed only after type-checking, so it
    // constrains the run below. NOTE(review): presumably it is kept out of
    // the static check on purpose; confirm.
    policy.budget = Some(25);

    let mut vm = Vm::with_handler(&program, Box::new(DefaultHandler::new(policy)));
    let outcome = vm.call("par_steps", vec![Value::List(vec![Value::Int(0); 4])]);
    let err = match outcome {
        Err(err) => err,
        unexpected => panic!(
            "shared budget pool must reject the over-ceiling par_map: {unexpected:?}"
        ),
    };
    let msg = format!("{err:?}");
    assert!(
        msg.contains("budget"),
        "expected a budget-exceeded error, got: {msg}"
    );
}