use crate::convert::RewriteResult;
use crate::frontend::DefaultParserDispatch;
use crate::frontend::Parser;
use crate::frontend::ParserDispatch;
use crate::init_subscriber;
use crate::ir::Op;
use crate::shared::Shared;
use crate::shared::SharedExt;
use crate::transform;
use crate::DefaultTransformDispatch;
use crate::Passes;
use crate::TransformDispatch;
use crate::TransformOptions;
use anyhow::Result;
use std::cmp::max;
use std::marker::PhantomData;
use std::panic::Location;
use tracing::info;
use wasmtime::Engine;
use wasmtime::Instance;
use wasmtime::Module;
use wasmtime::Store;
pub struct Tester<P: ParserDispatch, T: TransformDispatch> {
_marker: PhantomData<P>,
_marker2: PhantomData<T>,
}
pub type DefaultTester = Tester<DefaultParserDispatch, DefaultTransformDispatch>;
impl<P: ParserDispatch, T: TransformDispatch> Tester<P, T> {
pub fn init_tracing() {
let level = tracing::Level::INFO;
match init_subscriber(level) {
Ok(_) => (),
Err(_e) => (),
}
}
fn point_to_missing_line(expected: &str, index: usize) -> String {
let mut result = String::new();
result.push_str("A line is missing from the output:\n");
result.push_str("```");
for (i, line) in expected.lines().enumerate() {
if i == index {
let msg = format!("{line} <== missing");
result.push_str(&format!("\n{msg}"));
} else {
result.push_str(&format!("\n{line}"));
}
}
result.push_str("\n```");
result
}
pub fn check_lines_exact(actual: &str, expected: &str, caller: &Location<'_>) {
let actual = actual.trim();
let expected = expected.trim();
let l = max(actual.lines().count(), expected.lines().count());
for i in 0..l {
let actual_line = match actual.lines().nth(i) {
None => {
panic!(
"Expected line {i} not found in output: called from {}",
caller
);
}
Some(actual_line) => actual_line,
};
let expected_line = match expected.lines().nth(i) {
None => {
panic!(
"Expected line {i} not found in output: called from {}",
caller
);
}
Some(expected_line) => expected_line,
};
assert_eq!(actual_line, expected_line, "called from {}", caller);
}
}
pub fn check_lines_contain(actual: &str, expected: &str, caller: &Location<'_>) {
let actual = actual.trim();
let expected = expected.trim();
let mut actual_index = 0;
'outer: for i in 0..expected.lines().count() {
let expected_line = expected.lines().nth(i).unwrap().trim();
if expected_line.is_empty() {
continue;
}
let start = actual_index;
for j in start..actual.lines().count() {
let actual_line = actual.lines().nth(j).unwrap();
if actual_line.trim() == expected_line.trim() {
actual_index = j + 1;
continue 'outer;
}
}
let msg = Self::point_to_missing_line(expected, i);
panic!("{msg}\nwhen called from {caller}");
}
}
fn print_heading(msg: &str, src: &str) {
info!("{msg}:\n```\n{src}\n```\n");
}
pub fn preprocess(src: &str) -> String {
Self::print_heading("Before preprocessing", src.trim());
let actual = P::preprocess(src);
Self::print_heading("After preprocessing", &actual);
actual
}
pub fn parse(src: &str) -> (Shared<dyn Op>, String) {
Self::print_heading("Before parse", src.trim());
let module = Parser::<P>::parse(&src).unwrap();
let actual = format!("{}", module.rd());
Self::print_heading("After parse", &actual);
(module, actual)
}
pub fn transform(arguments: Vec<&str>, src: &str) -> (Shared<dyn Op>, String) {
let src = src.trim();
let module = Parser::<P>::parse(src).unwrap();
let should_print = !arguments.contains(&"--tester-no-print");
if should_print {
let msg = format!("Before (transform {arguments:?})");
Self::print_heading(&msg, &src);
}
for arg in arguments.clone() {
if arg.starts_with("convert-") {
panic!("conversion passes should be prefixed with `--convert-`");
}
}
let passes = Passes::from_convert_vec(arguments.clone());
let options = TransformOptions::from_passes(passes);
let result = transform::<T>(module.clone(), &options).unwrap();
let new_root_op = match result {
RewriteResult::Changed(changed_op) => changed_op.op,
RewriteResult::Unchanged => {
panic!("Expected changes");
}
};
let actual = format!("{}", new_root_op.rd());
if should_print {
let msg = format!("After (transform {arguments:?})");
Self::print_heading(&msg, &actual);
}
(new_root_op, actual)
}
fn verify_core(op: Shared<dyn Op>) {
let op = op.rd();
if !op.name().to_string().contains("module") {
let parent = op.operation().rd().parent();
assert!(
op.operation().rd().parent().is_some(),
"Parent was not set for:\n{}",
op
);
let parent = parent.unwrap();
let operation = op.operation();
let operation = operation.rd();
assert!(parent.rd().index_of(&operation).is_some(),
"Could not find the following op in parent. Is the parent field pointing to the wrong block?\n{}",
op);
for result in operation.results().vec().rd().iter() {
let value = result.rd();
assert!(value.typ().is_ok(), "type was not set for {value}");
}
}
}
pub fn verify(op: Shared<dyn Op>) {
Self::verify_core(op.clone());
let ops = op.rd().ops();
for op in ops {
Self::verify(op);
}
}
pub fn load_wat(wat: &str) -> Result<(Store<()>, Instance)> {
let engine = Engine::default();
let module = Module::new(&engine, wat).unwrap();
let mut store = Store::new(&engine, ());
let instance = Instance::new(&mut store, &module, &[]).unwrap();
Ok((store, instance))
}
}