#![cfg(feature = "wasm")]
use metadol::parse_file;
use metadol::wasm::{WasmCompiler, WasmError, WasmRuntime};
use std::path::Path;
use std::process::Command;
/// Result type used by the WASM test helpers in this module, with [`WasmError`] as the error.
pub type WasmTestResult<T> = Result<T, WasmError>;
/// Parses DOL `source` and compiles it to WASM bytes with a default `WasmCompiler`.
///
/// Returns `None` if either parsing or compilation fails. Use
/// `compile_dol_to_wasm_verbose` when the reason for the failure matters.
pub fn compile_dol_to_wasm(source: &str) -> Option<Vec<u8>> {
    match parse_file(source) {
        Ok(ast) => WasmCompiler::new().compile(&ast).ok(),
        Err(_) => None,
    }
}
/// Parses and compiles DOL `source` to WASM, reporting failures as text.
///
/// # Errors
///
/// - `"Parse error: …"` (with the parser error's `Debug` form) if parsing fails.
/// - `"Compile error: …"` (with the compiler error's `message`) if compilation fails.
pub fn compile_dol_to_wasm_verbose(source: &str) -> Result<Vec<u8>, String> {
    let ast = match parse_file(source) {
        Ok(ast) => ast,
        Err(err) => return Err(format!("Parse error: {:?}", err)),
    };
    WasmCompiler::new()
        .compile(&ast)
        .map_err(|err| format!("Compile error: {}", err.message))
}
/// Like `compile_dol_to_wasm`, but with optimization enabled on the
/// compiler (`with_optimization(true)`).
///
/// Returns `None` if parsing or compilation fails.
pub fn compile_dol_optimized(source: &str) -> Option<Vec<u8>> {
    parse_file(source)
        .ok()
        .and_then(|ast| WasmCompiler::new().with_optimization(true).compile(&ast).ok())
}
/// Reads the DOL file at `path` and compiles it with `compile_dol_to_wasm`.
///
/// Returns `None` if the file can't be read as UTF-8 text, or if compilation fails.
pub fn compile_dol_file(path: &Path) -> Option<Vec<u8>> {
    std::fs::read_to_string(path)
        .ok()
        .and_then(|text| compile_dol_to_wasm(&text))
}
/// Returns `true` if `wasm_bytes` begins with the WASM binary preamble.
///
/// The preamble is the magic number `00 61 73 6D` (`\0asm`) followed by the
/// version field `01 00 00 00`. Only these first eight bytes are checked;
/// anything after them is not validated here.
pub fn validate_wasm_bytes(wasm_bytes: &[u8]) -> bool {
    const PREAMBLE: [u8; 8] = [0x00, 0x61, 0x73, 0x6D, 0x01, 0x00, 0x00, 0x00];
    // `starts_with` is false for inputs shorter than the preamble.
    wasm_bytes.starts_with(&PREAMBLE)
}
/// Validates `wasm_bytes` with the header check and, when available, the
/// external `wasm-tools validate` command.
///
/// The bytes are written to a temp file whose name is unique per process and
/// per call. Tests that run in parallel therefore never overwrite or delete
/// each other's input. The file is removed after `wasm-tools` exits.
///
/// # Errors
///
/// Returns `Err` if:
/// - the magic-number/version check fails;
/// - the temp file cannot be written; or
/// - `wasm-tools` runs and reports failure (its stderr is included).
///
/// If `wasm-tools` cannot be launched at all (for example, it is not
/// installed), the external check is skipped and `Ok(())` is returned.
/// The external check is best-effort.
pub fn validate_wasm_with_tools(wasm_bytes: &[u8]) -> Result<(), String> {
    // Per-call counter: the pid alone can't separate concurrent tests that
    // share one test binary, and the counter alone can't separate binaries.
    static NEXT_ID: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);

    if !validate_wasm_bytes(wasm_bytes) {
        return Err("Basic WASM validation failed (magic number or version)".to_string());
    }

    let id = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let temp_file = std::env::temp_dir().join(format!(
        "test_module_{}_{}.wasm",
        std::process::id(),
        id
    ));
    if std::fs::write(&temp_file, wasm_bytes).is_err() {
        return Err("Failed to write temp file for validation".to_string());
    }

    // Pass the path as an OsStr: a non-UTF-8 temp dir must not panic.
    let result = Command::new("wasm-tools")
        .arg("validate")
        .arg(&temp_file)
        .output();
    let _ = std::fs::remove_file(&temp_file);

    match result {
        Ok(output) if output.status.success() => Ok(()),
        Ok(output) => {
            let stderr = String::from_utf8_lossy(&output.stderr);
            Err(format!("wasm-tools validation failed: {}", stderr))
        }
        // Spawning failed (typically wasm-tools is not installed): skip the
        // external check rather than failing the caller.
        Err(_) => Ok(()),
    }
}
/// Loads `wasm_bytes` into a fresh `WasmRuntime` and calls the export
/// `function_name`, passing each of `args` as a `wasmtime::Val::I64`.
///
/// # Errors
///
/// Propagates runtime-creation, load, and call errors. Also returns a
/// `WasmError` if the call produced no values, or if its first value is not an `i64`.
pub fn execute_wasm_function(
    wasm_bytes: &[u8],
    function_name: &str,
    args: &[i64],
) -> WasmTestResult<i64> {
    let runtime = WasmRuntime::new()?;
    let mut instance = runtime.load(wasm_bytes)?;
    let params: Vec<wasmtime::Val> = args.iter().copied().map(wasmtime::Val::I64).collect();
    let returned = instance.call(function_name, &params)?;

    let first = returned
        .first()
        .ok_or_else(|| WasmError::new("Function returned no values"))?;
    first
        .i64()
        .ok_or_else(|| WasmError::new("Function did not return i64"))
}
/// Creates a fresh `WasmRuntime` and loads `wasm_bytes` into it.
///
/// # Errors
///
/// Propagates errors from runtime creation and from `load`.
pub fn load_wasm_module(wasm_bytes: &[u8]) -> WasmTestResult<metadol::wasm::WasmModule> {
    WasmRuntime::new()?.load(wasm_bytes)
}
pub fn read_test_case(level: &str, filename: &str) -> Option<String> {
let path = format!(
"{}/test-cases/{}/{}",
env!("CARGO_MANIFEST_DIR"),
level,
filename
);
std::fs::read_to_string(&path).ok()
}
pub fn get_test_cases(level: &str) -> Vec<(String, String)> {
let dir_path = format!("{}/test-cases/{}", env!("CARGO_MANIFEST_DIR"), level);
let mut results = Vec::new();
if let Ok(entries) = std::fs::read_dir(&dir_path) {
for entry in entries.flatten() {
let path = entry.path();
if path.extension().map_or(false, |ext| ext == "dol") {
if let Ok(contents) = std::fs::read_to_string(&path) {
let filename = path.file_name().unwrap().to_string_lossy().to_string();
results.push((filename, contents));
}
}
}
}
results
}
/// Returns the test cases for the `working` level, via `get_test_cases`.
pub fn get_working_test_cases() -> Vec<(String, String)> {
    const WORKING_LEVEL: &str = "working";
    get_test_cases(WORKING_LEVEL)
}
/// Panics unless `source` compiles and the output passes `validate_wasm_bytes`.
///
/// `context` prefixes each panic message to identify the failing case.
///
/// # Panics
///
/// Panics if compilation fails, or if the compiled bytes lack a valid WASM header.
// NOTE(review): presumably silenced because not every test crate that
// includes these helpers calls this one — confirm.
#[allow(dead_code)]
pub fn assert_compiles(source: &str, context: &str) {
    let wasm = compile_dol_to_wasm_verbose(source)
        .unwrap_or_else(|e| panic!("{}: Compilation failed - {}", context, e));
    assert!(
        validate_wasm_bytes(&wasm),
        "{}: Compiled but produced invalid WASM",
        context
    );
}
/// Panics if `source` compiles successfully via `compile_dol_to_wasm`.
///
/// `context` prefixes the panic message to identify the failing case.
// NOTE(review): presumably silenced because not every test crate that
// includes these helpers calls this one — confirm.
#[allow(dead_code)]
pub fn assert_compile_fails(source: &str, context: &str) {
    assert!(
        compile_dol_to_wasm(source).is_none(),
        "{}: Expected compilation to fail but it succeeded",
        context
    );
}
/// Runs `function_name(args…)` from `wasm_bytes` and asserts that it returns `expected`.
///
/// `context` prefixes each panic message to identify the failing case.
///
/// # Panics
///
/// Panics if execution fails (including the error's `message`), or if the
/// returned value differs from `expected`.
// NOTE(review): presumably silenced because not every test crate that
// includes these helpers calls this one — confirm.
#[allow(dead_code)]
pub fn assert_wasm_result(
    wasm_bytes: &[u8],
    function_name: &str,
    args: &[i64],
    expected: i64,
    context: &str,
) {
    let actual = execute_wasm_function(wasm_bytes, function_name, args)
        .unwrap_or_else(|err| panic!("{}: Execution failed - {}", context, err.message));
    assert_eq!(
        actual, expected,
        "{}: Expected {} but got {}",
        context, expected, actual
    );
}
/// Description of a function exported from a WASM module.
///
/// NOTE(review): nothing in these helpers constructs this type yet.
/// `inspect_wasm_exports` returns only export names.
#[derive(Debug, Clone)]
pub struct ExportedFunction {
    /// Export name of the function.
    pub name: String,
    /// Number of parameters, presumably taken from the function's signature — confirm.
    pub param_count: usize,
    /// Number of results, presumably taken from the function's signature — confirm.
    pub result_count: usize,
}
/// Returns the names of all function exports in `wasm_bytes`, in the order
/// wasmtime reports them.
///
/// Non-function exports (memories, tables, globals) are omitted. If the bytes
/// cannot be compiled into a `wasmtime::Module`, an empty `Vec` is returned.
pub fn inspect_wasm_exports(wasm_bytes: &[u8]) -> Vec<String> {
    let engine = wasmtime::Engine::default();
    wasmtime::Module::from_binary(&engine, wasm_bytes)
        .map(|compiled| {
            compiled
                .exports()
                .filter(|export| export.ty().func().is_some())
                .map(|export| export.name().to_string())
                .collect::<Vec<String>>()
        })
        .unwrap_or_default()
}
/// DOL source for `add(a, b)`, which returns `a + b`.
///
/// The snippet is written one line per literal so its exact whitespace is visible.
pub fn gen_add_function() -> &'static str {
    concat!(
        "fun add(a: i64, b: i64) -> i64 {\n",
        "return a + b\n",
        "}"
    )
}
/// DOL source for `compute(x)`, which uses two typed `let` locals
/// (`doubled = x + x`, `result = doubled + 1`) and returns `result`.
pub fn gen_function_with_locals() -> &'static str {
    concat!(
        "fun compute(x: i64) -> i64 {\n",
        "let doubled: i64 = x + x\n",
        "let result: i64 = doubled + 1\n",
        "return result\n",
        "}"
    )
}
/// DOL source for `max(a, b)`, using an `if`/`else` with a `return` in each branch.
pub fn gen_function_with_if() -> &'static str {
    concat!(
        "fun max(a: i64, b: i64) -> i64 {\n",
        "if a > b {\n",
        "return a\n",
        "} else {\n",
        "return b\n",
        "}\n",
        "}"
    )
}
}
/// DOL source for `countdown(n)`: a `while` loop that decrements a local
/// `count` while it is greater than zero, then returns it.
pub fn gen_function_with_while() -> &'static str {
    concat!(
        "fun countdown(n: i64) -> i64 {\n",
        "let count: i64 = n\n",
        "while count > 0 {\n",
        "count = count - 1\n",
        "}\n",
        "return count\n",
        "}"
    )
}
/// DOL source for `sum_to(n)`: a `for i in 0..n` loop that accumulates `i`
/// into a local `total`, then returns it.
pub fn gen_function_with_for() -> &'static str {
    concat!(
        "fun sum_to(n: i64) -> i64 {\n",
        "let total: i64 = 0\n",
        "for i in 0..n {\n",
        "total = total + i\n",
        "}\n",
        "return total\n",
        "}"
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal valid preamble: `\0asm` magic followed by version `01 00 00 00`.
    const HEADER: [u8; 8] = [0x00, 0x61, 0x73, 0x6D, 0x01, 0x00, 0x00, 0x00];

    #[test]
    fn test_validate_wasm_bytes_valid() {
        assert!(validate_wasm_bytes(&HEADER));
    }

    #[test]
    fn test_validate_wasm_bytes_valid_with_trailing_bytes() {
        // Only the 8-byte preamble is checked; later bytes are ignored.
        let mut bytes = HEADER.to_vec();
        bytes.extend_from_slice(&[0x01, 0x04, 0x01, 0x60, 0x00, 0x00]);
        assert!(validate_wasm_bytes(&bytes));
    }

    #[test]
    fn test_validate_wasm_bytes_invalid_magic() {
        let invalid = [0xFF, 0xFF, 0xFF, 0xFF, 0x01, 0x00, 0x00, 0x00];
        assert!(!validate_wasm_bytes(&invalid));
    }

    #[test]
    fn test_validate_wasm_bytes_invalid_version() {
        let invalid = [0x00, 0x61, 0x73, 0x6D, 0xFF, 0xFF, 0xFF, 0xFF];
        assert!(!validate_wasm_bytes(&invalid));
    }

    #[test]
    fn test_validate_wasm_bytes_too_short() {
        let short = [0x00, 0x61, 0x73, 0x6D];
        assert!(!validate_wasm_bytes(&short));
    }

    #[test]
    fn test_validate_wasm_bytes_one_byte_short() {
        // Boundary: a correct-but-truncated 7-byte preamble must be rejected.
        assert!(!validate_wasm_bytes(&HEADER[..7]));
    }

    #[test]
    fn test_validate_wasm_bytes_empty() {
        assert!(!validate_wasm_bytes(&[]));
    }

    #[test]
    fn test_gen_add_function_is_valid() {
        let source = gen_add_function();
        assert!(source.contains("fun add"));
        assert!(source.contains("return"));
    }
}