#![allow(dead_code)]
mod puc_51;
mod puc_52;
mod puc_53;
mod puc_54;
mod puc_55;
use crate::runtime::function::Proto;
use crate::runtime::heap::{Gc, Heap};
use crate::vm::isa::{Inst, Op};
/// Loads a PUC Lua binary chunk by dispatching on its version byte.
///
/// The byte at offset 4 selects which version-specific loader receives the
/// whole `bytes` slice and `heap`. The bytes before offset 4 are not inspected
/// here (presumably the chunk signature, validated by the caller or the
/// per-version loader — confirm).
///
/// # Errors
///
/// * `"truncated PUC binary chunk"` when `bytes` is shorter than five bytes.
/// * An "unsupported PUC Lua version byte" message when the byte at offset 4
///   is not one of `0x51..=0x55`.
/// * Whatever error the selected per-version loader returns.
pub(super) fn undump_puc(bytes: &[u8], heap: &mut Heap) -> Result<Gc<Proto>, String> {
    // `get(4)` is `None` exactly when the slice holds fewer than five bytes.
    match bytes.get(4).copied() {
        None => Err("truncated PUC binary chunk".to_string()),
        Some(0x51) => puc_51::undump(bytes, heap),
        Some(0x52) => puc_52::undump(bytes, heap),
        Some(0x53) => puc_53::undump_puc_53(bytes, heap),
        Some(0x54) => puc_54::undump(bytes, heap),
        Some(0x55) => puc_55::undump_puc_55(bytes, heap),
        Some(v) => Err(format!(
            "unsupported PUC Lua version byte 0x{v:02x} (expected 0x51..0x55)"
        )),
    }
}
/// Lowers an instruction whose B operand is a constant-pool entry into a
/// two-instruction sequence that first stages the constant in a scratch
/// register.
///
/// The returned pair is:
///
/// 1. `LoadK tmp, k_idx` (built with `Inst::iabx`), then
/// 2. `op a, tmp, c` (built with `Inst::iabc`, forwarding `c_is_k` as its flag).
///
/// On success `max_temp_bump` is raised to at least `tmp + 1`; it is never
/// lowered, and it is left untouched when an error is returned.
///
/// # Errors
///
/// * `tmp` is greater than 255.
/// * `k_idx` is greater than `0x1FFFF` (does not fit the 17-bit Bx field).
/// * `tmp` is exactly 255, so the required `tmp + 1` (256) cannot be stored in
///   the `u8` counter. Previously this case overflowed `tmp as u8 + 1`:
///   a panic in debug builds and a silent wrap to 0 in release builds.
pub(super) fn lower_k_via_tmp(
    op: Op,
    a: u32,
    k_idx: u32,
    c: u32,
    c_is_k: bool,
    tmp: u32,
    max_temp_bump: &mut u8,
) -> Result<[Inst; 2], String> {
    if tmp > 0xFF {
        return Err(format!(
            "lower_k_via_tmp: temp register {tmp} exceeds 255 (op={op:?}, a={a}, k_idx={k_idx})"
        ));
    }
    if k_idx > 0x1FFFF {
        return Err(format!(
            "lower_k_via_tmp: K-pool index {k_idx} exceeds 17-bit Bx field"
        ));
    }
    // `tmp <= 255` here, so the cast is lossless; only the `+ 1` can overflow.
    // Report that as an error rather than recording a too-small bump.
    let needed = (tmp as u8).checked_add(1).ok_or_else(|| {
        format!(
            "lower_k_via_tmp: temp register {tmp} needs a 256-slot frame, which overflows the u8 temp bump (op={op:?}, a={a}, k_idx={k_idx})"
        )
    })?;
    *max_temp_bump = (*max_temp_bump).max(needed);
    Ok([
        Inst::iabx(Op::LoadK, tmp, k_idx),
        Inst::iabc(op, a, tmp, c, c_is_k),
    ])
}
/// Lowers an instruction with an immediate integer C operand into a
/// two-instruction sequence that first materialises the immediate in a
/// scratch register.
///
/// The returned pair is:
///
/// 1. `LoadI tmp, sc` (built with `Inst::iasbx`), then
/// 2. `op a, b, tmp` (built with `Inst::iabc`, always with its flag `false`).
///
/// On success `max_temp_bump` is raised to at least `tmp + 1`; it is never
/// lowered, and it is left untouched when an error is returned.
///
/// NOTE(review): `sc` is forwarded to `Inst::iasbx` without a range check
/// here — confirm the encoder validates it.
///
/// # Errors
///
/// * `tmp` is greater than 255.
/// * `tmp` is exactly 255, so the required `tmp + 1` (256) cannot be stored in
///   the `u8` counter. Previously this case overflowed `tmp as u8 + 1`:
///   a panic in debug builds and a silent wrap to 0 in release builds.
pub(super) fn lower_i_imm(
    op: Op,
    a: u32,
    b: u32,
    sc: i32,
    tmp: u32,
    max_temp_bump: &mut u8,
) -> Result<[Inst; 2], String> {
    if tmp > 0xFF {
        return Err(format!(
            "lower_i_imm: temp register {tmp} exceeds 255 (op={op:?}, a={a}, b={b}, sc={sc})"
        ));
    }
    // `tmp <= 255` here, so the cast is lossless; only the `+ 1` can overflow.
    // Report that as an error rather than recording a too-small bump.
    let needed = (tmp as u8).checked_add(1).ok_or_else(|| {
        format!(
            "lower_i_imm: temp register {tmp} needs a 256-slot frame, which overflows the u8 temp bump (op={op:?}, a={a}, b={b}, sc={sc})"
        )
    })?;
    *max_temp_bump = (*max_temp_bump).max(needed);
    Ok([
        Inst::iasbx(Op::LoadI, tmp, sc),
        Inst::iabc(op, a, b, tmp, false),
    ])
}
/// Finds every jump that lands on a `tforcall_op` instruction immediately
/// followed by a `tforloop_op` instruction.
///
/// For each word in `words` whose opcode (per `decode_op`) equals `jmp_op`,
/// the jump target is computed as `pc + 1 + decode_sbx(word)`. The jump is
/// recorded only when that target is inside `words`, the word at the target
/// decodes to `tforcall_op`, and the word right after it exists and decodes
/// to `tforloop_op`.
///
/// Returns a map from the pc of each qualifying jump to the A field
/// (`decode_a`) of the instruction it lands on. Jumps that fail any check are
/// skipped silently.
pub(super) fn scan_tforprep_sites(
    words: &[u32],
    tforcall_op: u8,
    tforloop_op: u8,
    jmp_op: u8,
    decode_op: impl Fn(u32) -> u8,
    decode_a: impl Fn(u32) -> u32,
    decode_sbx: impl Fn(u32) -> i32,
) -> std::collections::HashMap<usize, u32> {
    // Classifies one instruction: `Some((pc, a))` if it is a qualifying jump.
    let classify = |pc: usize, word: u32| -> Option<(usize, u32)> {
        if decode_op(word) != jmp_op {
            return None;
        }
        // Offsets are relative to the instruction after the jump; use i64 so
        // a backward jump past the start shows up as a negative value.
        let landing = pc as i64 + 1 + i64::from(decode_sbx(word));
        if landing < 0 {
            return None;
        }
        let landing = landing as usize;
        let call_word = *words.get(landing)?;
        if decode_op(call_word) != tforcall_op {
            return None;
        }
        let loop_word = *words.get(landing + 1)?;
        if decode_op(loop_word) != tforloop_op {
            return None;
        }
        Some((pc, decode_a(call_word)))
    };
    words
        .iter()
        .enumerate()
        .filter_map(|(pc, &word)| classify(pc, word))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- lower_k_via_tmp -------------------------------------------------

    /// The pair is `LoadK tmp, k_idx` followed by `op a, tmp, c`, and the
    /// bump becomes `tmp + 1`.
    #[test]
    fn lower_k_via_tmp_emits_pair() {
        let mut bump = 0u8;
        let [load, arith] = lower_k_via_tmp(Op::Add, 5, 7, 3, false, 6, &mut bump).unwrap();

        assert_eq!(load.op(), Op::LoadK);
        assert_eq!(load.a(), 6);
        assert_eq!(load.bx(), 7);

        assert_eq!(arith.op(), Op::Add);
        assert_eq!(arith.a(), 5);
        assert_eq!(arith.b(), 6);
        assert_eq!(arith.c(), 3);
        assert!(!arith.k());

        assert_eq!(bump, 7);
    }

    /// `c_is_k` must reach the second instruction's k bit.
    #[test]
    fn lower_k_via_tmp_propagates_c_k_flag() {
        let mut bump = 0u8;
        let [_, arith] = lower_k_via_tmp(Op::Sub, 0, 2, 4, true, 1, &mut bump).unwrap();
        assert_eq!(arith.op(), Op::Sub);
        assert!(arith.k(), "c_is_k=true must propagate to k bit");
    }

    /// A temp register above 255 is an error and leaves the bump alone.
    #[test]
    fn lower_k_via_tmp_rejects_oversized_tmp() {
        let mut bump = 0u8;
        let outcome = lower_k_via_tmp(Op::Add, 0, 1, 2, false, 256, &mut bump);
        let err = outcome.unwrap_err();
        assert!(err.contains("exceeds 255"), "got: {err}");
        assert_eq!(bump, 0, "bump must not change on error");
    }

    /// A smaller temp never lowers an already-larger bump.
    #[test]
    fn lower_k_via_tmp_keeps_max_bump_monotonic() {
        let mut bump = 10u8;
        let lowered = lower_k_via_tmp(Op::Add, 0, 1, 2, false, 3, &mut bump);
        lowered.unwrap();
        assert_eq!(bump, 10, "bump must stay at running max");
    }

    // ---- lower_i_imm -----------------------------------------------------

    /// The pair is `LoadI tmp, sc` followed by `op a, b, tmp`, and the bump
    /// becomes `tmp + 1`.
    #[test]
    fn lower_i_imm_emits_pair() {
        let mut bump = 0u8;
        let [load, arith] = lower_i_imm(Op::Add, 5, 3, 42, 6, &mut bump).unwrap();

        assert_eq!(load.op(), Op::LoadI);
        assert_eq!(load.a(), 6);
        assert_eq!(load.sbx(), 42);

        assert_eq!(arith.op(), Op::Add);
        assert_eq!(arith.a(), 5);
        assert_eq!(arith.b(), 3);
        assert_eq!(arith.c(), 6);
        assert!(!arith.k(), "I-imm arith never sets the k bit");

        assert_eq!(bump, 7);
    }

    /// Negative immediates survive the round trip through `LoadI`.
    #[test]
    fn lower_i_imm_handles_negative_imm() {
        let mut bump = 0u8;
        let [load, arith] = lower_i_imm(Op::Shr, 0, 1, -5, 2, &mut bump).unwrap();
        assert_eq!(load.sbx(), -5);
        assert_eq!(arith.op(), Op::Shr);
    }

    /// A temp register above 255 is an error and leaves the bump alone.
    #[test]
    fn lower_i_imm_rejects_oversized_tmp() {
        let mut bump = 0u8;
        let outcome = lower_i_imm(Op::Add, 0, 1, 0, 256, &mut bump);
        let err = outcome.unwrap_err();
        assert!(err.contains("exceeds 255"), "got: {err}");
        assert_eq!(bump, 0, "bump must not change on error");
    }

    /// A smaller temp never lowers an already-larger bump.
    #[test]
    fn lower_i_imm_keeps_max_bump_monotonic() {
        let mut bump = 8u8;
        let lowered = lower_i_imm(Op::Add, 0, 1, 0, 3, &mut bump);
        lowered.unwrap();
        assert_eq!(bump, 8, "bump must stay at running max");
    }

    // ---- scan_tforprep_sites ---------------------------------------------

    // Test-only instruction layout: 6-bit opcode, 8-bit A at bit 6, and an
    // 18-bit biased sBx at bit 14.
    const TEST_BIAS_SBX: i32 = (1 << 17) - 1;
    const TEST_JMP: u8 = 30;
    const TEST_TFORCALL: u8 = 41;
    const TEST_TFORLOOP: u8 = 42;

    /// Encodes an instruction carrying only an opcode and an A field.
    fn enc_iabc(op: u8, a: u32) -> u32 {
        let opcode_bits = op as u32 & 0x3F;
        let a_bits = (a & 0xFF) << 6;
        opcode_bits | a_bits
    }

    /// Encodes an instruction with an opcode, an A field and a biased sBx.
    fn enc_iasbx(op: u8, a: u32, sbx: i32) -> u32 {
        let bx = (sbx + TEST_BIAS_SBX) as u32;
        let opcode_bits = op as u32 & 0x3F;
        let a_bits = (a & 0xFF) << 6;
        let bx_bits = (bx & 0x3FFFF) << 14;
        opcode_bits | a_bits | bx_bits
    }

    fn dec_op(w: u32) -> u8 {
        (w & 0x3F) as u8
    }

    fn dec_a(w: u32) -> u32 {
        (w >> 6) & 0xFF
    }

    fn dec_sbx(w: u32) -> i32 {
        let biased = ((w >> 14) & 0x3FFFF) as i32;
        biased - TEST_BIAS_SBX
    }

    /// Runs the scanner over `code` with the test opcodes and decoders.
    fn scan(code: &[u32]) -> std::collections::HashMap<usize, u32> {
        scan_tforprep_sites(
            code,
            TEST_TFORCALL,
            TEST_TFORLOOP,
            TEST_JMP,
            dec_op,
            dec_a,
            dec_sbx,
        )
    }

    /// A jump at pc 0 with offset 2 lands on pc 3, where the call/loop pair sits.
    #[test]
    fn scan_finds_jmp_then_tforcall_pair() {
        let code = [
            enc_iasbx(TEST_JMP, 0, 2),
            enc_iabc(0, 0),
            enc_iabc(0, 0),
            enc_iabc(TEST_TFORCALL, 5),
            enc_iabc(TEST_TFORLOOP, 7),
        ];
        let sites = scan(&code);
        assert_eq!(sites.len(), 1);
        assert_eq!(sites.get(&0), Some(&5));
    }

    /// A jump whose target is an ordinary instruction is not recorded.
    #[test]
    fn scan_ignores_jmp_not_targeting_tforcall() {
        let code = [enc_iasbx(TEST_JMP, 0, 1), enc_iabc(0, 0), enc_iabc(0, 0)];
        let sites = scan(&code);
        assert!(sites.is_empty());
    }
}