use crate::*;
use indexmap::{IndexMap, IndexSet};
use leo_ast::{NetworkName, Stub};
use leo_errors::{BufferEmitter, Handler, Result};
use leo_parser::{parse_library, parse_program};
use leo_span::{Symbol, create_session_if_not_set_then, source_map::FileName, with_session_globals};
use serial_test::serial;
use std::rc::Rc;
/// Single table describing every passes-test runner.
///
/// Invokes the callback macro `$macro` with a comma-separated list of
/// `(runner_name, [(Pass, input), ...])` entries. For each entry, the listed passes are run in
/// order on the same compiler state, each receiving its `input` expression.
///
/// Non-unit inputs are wrapped in an extra pair of parentheses so every input is a single token
/// tree. This lets callbacks match it as `$input:expr` (`make_runner`) or as `$input:tt`
/// (`make_all_tests`). `make_all_tests_inner` uses only the LAST pass of each entry to pick the
/// test directory (`passes/<last pass in snake_case>`). The earlier passes are setup.
macro_rules! compiler_passes {
($macro:ident) => {
$macro! {
(common_subexpression_elimination_runner, [
(GlobalVarsCollection, ()),
(PathResolution, ()),
(GlobalItemsCollection, ()),
(TypeChecking, (TypeCheckingInput::new(NetworkName::TestnetV0))),
(Disambiguate, ()),
(CommonSubexpressionEliminating, ())
]),
(const_prop_unroll_and_morphing_runner, [
(GlobalVarsCollection, ()),
(PathResolution, ()),
(GlobalItemsCollection, ()),
(TypeChecking, (TypeCheckingInput::new(NetworkName::TestnetV0))),
(Disambiguate, ()),
(ConstPropUnrollAndMorphing, (TypeCheckingInput::new(NetworkName::TestnetV0)))
]),
(destructuring_runner, [
(GlobalVarsCollection, ()),
(PathResolution, ()),
(GlobalItemsCollection, ()),
(TypeChecking, (TypeCheckingInput::new(NetworkName::TestnetV0))),
(Disambiguate, ()),
(Destructuring, ())
]),
(dead_code_elimination_runner, [
(GlobalVarsCollection, ()),
(PathResolution, ()),
(GlobalItemsCollection, ()),
(TypeChecking, (TypeCheckingInput::new(NetworkName::TestnetV0))),
(Disambiguate, ()),
(DeadCodeEliminating, ())
]),
(flattening_runner, [
(GlobalVarsCollection, ()),
(PathResolution, ()),
(GlobalItemsCollection, ()),
(TypeChecking, (TypeCheckingInput::new(NetworkName::TestnetV0))),
(Disambiguate, ()),
(SsaForming, (SsaFormingInput { rename_defs: true })),
(Flattening, ())
]),
(function_inlining_runner, [
(GlobalVarsCollection, ()),
(PathResolution, ()),
(GlobalItemsCollection, ()),
(TypeChecking, (TypeCheckingInput::new(NetworkName::TestnetV0))),
(Disambiguate, ()),
(FunctionInlining, ())
]),
(option_lowering_runner, [
(GlobalVarsCollection, ()),
(PathResolution, ()),
(GlobalItemsCollection, ()),
(TypeChecking, (TypeCheckingInput::new(NetworkName::TestnetV0))),
(Disambiguate, ()),
(OptionLowering, (TypeCheckingInput::new(NetworkName::TestnetV0)))
]),
(processing_async_runner, [
(GlobalVarsCollection, ()),
(PathResolution, ()),
(GlobalItemsCollection, ()),
(TypeChecking, (TypeCheckingInput::new(NetworkName::TestnetV0))),
(Disambiguate, ()),
(ProcessingAsync, (TypeCheckingInput::new(NetworkName::TestnetV0)))
]),
(ssa_forming_runner, [
(GlobalVarsCollection, ()),
(PathResolution, ()),
(GlobalItemsCollection, ()),
(TypeChecking, (TypeCheckingInput::new(NetworkName::TestnetV0))),
(Disambiguate, ()),
(SsaForming, (SsaFormingInput { rename_defs: true }))
]),
(storage_lowering_runner, [
(GlobalVarsCollection, ()),
(PathResolution, ()),
(GlobalItemsCollection, ()),
(TypeChecking, (TypeCheckingInput::new(NetworkName::TestnetV0))),
(Disambiguate, ()),
(StorageLowering, (TypeCheckingInput::new(NetworkName::TestnetV0)))
]),
(write_transforming_runner, [
(GlobalVarsCollection, ()),
(PathResolution, ()),
(GlobalItemsCollection, ()),
(TypeChecking, (TypeCheckingInput::new(NetworkName::TestnetV0))),
(Disambiguate, ()),
(WriteTransforming, ())
]),
// Runs RemoveUnreachable directly on the parsed AST, with no preceding passes.
(remove_unreachable_runner, [
(RemoveUnreachable, ())
]),
(ssa_const_propagation_runner, [
(GlobalVarsCollection, ()),
(PathResolution, ()),
(GlobalItemsCollection, ()),
(TypeChecking, (TypeCheckingInput::new(NetworkName::TestnetV0))),
(Disambiguate, ()),
(SsaForming, (SsaFormingInput { rename_defs: true })),
(SsaConstPropagation, ()),
]),
(disambiguate_runner, [
(GlobalVarsCollection, ()),
(PathResolution, ()),
(GlobalItemsCollection, ()),
(TypeChecking, (TypeCheckingInput::new(NetworkName::TestnetV0))),
(Disambiguate, ()),
]),
// Only the collection/resolution passes run before CheckInterfaces; there is no
// TypeChecking or Disambiguate step in this runner.
(check_interfaces_runner, [
(GlobalVarsCollection, ()),
(PathResolution, ()),
(GlobalItemsCollection, ()),
(CheckInterfaces, ()),
]),
}
};
}
/// Defines one test-runner function from a runner name and its ordered pass list.
///
/// The generated `fn $runner_name(source: &str) -> String` does three things:
/// - It parses `source` with `parse_passes_test_source`.
/// - It runs each `$pass::do_pass($input, &mut state)` in the order listed.
/// - It renders the outcome as a string.
///
/// On success the output is the buffered warnings followed by the final AST. If parsing or any
/// pass reports an error, the output is the buffered errors followed by the buffered warnings,
/// and no later pass runs.
macro_rules! make_runner {
    ($runner_name:ident, [$(($pass:ident, $input:expr)),* $(,)?]) => {
        fn $runner_name(source: &str) -> String {
            let emitter = BufferEmitter::new();
            let handler = Handler::new(emitter.clone());
            let node_builder = Rc::new(leo_ast::NodeBuilder::default());
            create_session_if_not_set_then(|_| {
                // Output used whenever parsing or a pass fails: errors first, then warnings.
                let failure_report = || format!("{}{}", emitter.extract_errs(), emitter.extract_warnings());
                let mut state = CompilerState {
                    handler: handler.clone(),
                    node_builder: Rc::clone(&node_builder),
                    ..Default::default()
                };
                let parsed = parse_passes_test_source(source, &handler, &node_builder, NetworkName::TestnetV0);
                match handler.extend_if_error(parsed) {
                    Ok(ast) => state.ast = ast,
                    Err(()) => return failure_report(),
                }
                $(
                    if let Err(()) = handler.extend_if_error($pass::do_pass($input, &mut state)) {
                        return failure_report();
                    }
                )*
                format!("{}{}", emitter.extract_warnings(), state.ast)
            })
        }
    };
}
/// Callback for `compiler_passes!` that defines one runner function per table entry.
/// Each `(runner, [passes...])` pair is forwarded unchanged to `make_runner!`.
macro_rules! make_all_runners {
($(($runner:ident, $passes:tt)),* $(,)?) => {
$(
make_runner!($runner, $passes);
)*
};
}
// Expand the runner table into the `*_runner` functions that the generated tests below call.
compiler_passes!(make_all_runners);
/// Callback for `compiler_passes!` that emits one `#[test]` function per runner.
///
/// Each test is named `<runner>_test` (built with `paste`) and marked `#[serial]` (from
/// `serial_test`). Its body is produced by `make_all_tests_inner!`, which receives the runner's
/// full pass list. Inputs are matched as `tt`, which relies on the table wrapping each non-unit
/// input in parentheses.
macro_rules! make_all_tests {
($(($runner:ident, [$(($pass:ident, $input:tt)),* $(,)?])),* $(,)?) => {
$(
paste::paste! {
#[test]
#[serial]
fn [<$runner _test>]() {
make_all_tests_inner!($runner, [$(($pass, $input)),*]);
}
}
)*
};
}
/// Expands to `leo_test_framework::run_tests("passes/<pass>", runner)` for the LAST pass in the
/// list.
///
/// The second arm drops the head of the list and recurses until a single pass remains. The
/// first arm then converts that pass name to snake_case to form the test directory. Earlier
/// passes in the list do not affect the directory.
macro_rules! make_all_tests_inner {
($runner:ident, [($pass:ident, $input:tt)]) => {
paste::paste! {
leo_test_framework::run_tests(
concat!("passes/", stringify!([<$pass:snake>])),
$runner,
);
}
};
($runner:ident, [($pass:ident, $input:tt), $(($rest_pass:ident, $rest_input:tt)),+ $(,)?]) => {
make_all_tests_inner!($runner, [$(($rest_pass, $rest_input)),+]);
};
}
// Generate one `#[test]` per runner in the table.
compiler_passes!(make_all_tests);
/// Line separating the sections of a multi-section passes-test source.
/// The text after the last delimiter is the main program; earlier sections may hold libraries.
const PASSES_PROGRAM_DELIMITER: &str = "// --- Next Program --- //";
/// Parses a passes-test source into an AST.
///
/// If `source` does not contain [`PASSES_PROGRAM_DELIMITER`], the whole text is parsed as a
/// single program. Otherwise the text is split on the delimiter:
/// - The last section is the main program.
/// - Each earlier section containing a `// --- library: NAME --- //` header line is parsed as a
///   library, optionally with `// --- Next Module:` submodules. It is attached to the main
///   program as a `Stub::FromLibrary`.
/// - Earlier sections without such a header are skipped silently.
///
/// Each library's `parents` set receives every library that appears after it, then the main
/// program's symbol.
///
/// # Errors
/// Propagates any error returned by `parse_library` or `parse_program`.
///
/// # Panics
/// Panics if the parsed main program has no program scopes.
fn parse_passes_test_source(
    source: &str,
    handler: &Handler,
    node_builder: &Rc<leo_ast::NodeBuilder>,
    network: NetworkName,
) -> Result<leo_ast::Ast> {
    if !source.contains(PASSES_PROGRAM_DELIMITER) {
        let single_file =
            with_session_globals(|globals| globals.source_map.new_source(source, FileName::Custom("test".into())));
        return parse_program(handler.clone(), node_builder, &single_file, &[], network).map(leo_ast::Ast::Program);
    }

    let mut sections: Vec<&str> = source.split(PASSES_PROGRAM_DELIMITER).collect();
    // The final section is always the program under test; the rest are dependencies.
    let main_source = sections.pop().expect("split always yields at least one element").trim();

    let mut stubs: IndexMap<Symbol, Stub> = IndexMap::new();
    for section in &sections {
        let Some((lib_name, lib_source)) = extract_passes_library_header(section.trim()) else {
            continue;
        };
        let (lib_main_source, lib_modules) = split_passes_modules(lib_source);
        let lib_file = with_session_globals(|globals| {
            globals.source_map.new_source(&lib_main_source, FileName::Custom("compiler-test".into()))
        });
        let mut module_files = Vec::with_capacity(lib_modules.len());
        for (module_source, module_name) in &lib_modules {
            module_files.push(with_session_globals(|globals| {
                globals.source_map.new_source(module_source, FileName::Custom(module_name.clone()))
            }));
        }
        let library = parse_library(handler.clone(), node_builder, lib_name, &lib_file, &module_files, network)?;
        stubs.insert(lib_name, Stub::FromLibrary { library, parents: IndexSet::new() });
    }

    let main_file =
        with_session_globals(|globals| globals.source_map.new_source(main_source, FileName::Custom("test".into())));
    let mut program = parse_program(handler.clone(), node_builder, &main_file, &[], network)?;
    let first_scope = program.program_scopes.values().next().expect("a program must have at least one scope");
    let main_symbol = first_scope.program_id.as_symbol();

    // Every library lists all libraries declared after it, plus the main program, as parents.
    let lib_symbols: Vec<Symbol> = stubs.keys().copied().collect();
    for (index, stub) in stubs.values_mut().enumerate() {
        if let Stub::FromLibrary { parents, .. } = stub {
            parents.extend(lib_symbols.iter().skip(index + 1).copied());
            parents.insert(main_symbol);
        }
    }

    program.stubs = stubs;
    Ok(leo_ast::Ast::Program(program))
}
/// Splits a library's source into its main file and any named submodules.
///
/// A submodule starts at a line beginning (after optional leading whitespace) with
/// `// --- Next Module: NAME --- //`. Parsing details:
/// - Text before the first such line is the library's main source.
/// - Text after each delimiter, up to the next one, belongs to the named module.
/// - Surrounding whitespace and the trailing `--- //` are stripped from the name, matching how
///   library headers are parsed.
/// - Each collected line is re-terminated with `\n`.
///
/// Returns `(main_source, modules)`, where each module is `(source, name)`. If `source` does not
/// contain the delimiter at all, it is returned unchanged with no modules.
fn split_passes_modules(source: &str) -> (String, Vec<(String, String)>) {
    const MOD_DELIM: &str = "// --- Next Module:";
    const MOD_DELIM_END: &str = "--- //";
    if !source.contains(MOD_DELIM) {
        return (source.to_string(), Vec::new());
    }
    let mut main_source = String::new();
    let mut modules: Vec<(String, String)> = Vec::new();
    // Name of the module being collected; `None` while still inside the library's main source.
    let mut current_name: Option<String> = None;
    let mut current_src = String::new();
    for line in source.lines() {
        if let Some(rest) = line.trim_start().strip_prefix(MOD_DELIM) {
            // Hand off the text collected so far and start a fresh buffer without cloning.
            let finished = std::mem::take(&mut current_src);
            match current_name.take() {
                Some(name) => modules.push((finished, name)),
                None => main_source = finished,
            }
            // Trim again after stripping the suffix so irregular spacing such as
            // `foo  --- //` or `foo--- //` still yields `foo`.
            current_name = Some(rest.trim().trim_end_matches(MOD_DELIM_END).trim().to_string());
        } else {
            current_src.push_str(line);
            current_src.push('\n');
        }
    }
    match current_name {
        Some(name) => modules.push((current_src, name)),
        None => main_source = current_src,
    }
    (main_source, modules)
}
/// Finds the first `// --- library: NAME --- //` header line in `source`.
///
/// Lines before the header are skipped. Returns the interned library name together with the
/// text after the header line, with leading line breaks removed. Returns `None` if no line is
/// such a header.
fn extract_passes_library_header(source: &str) -> Option<(Symbol, &str)> {
    const HEADER_PREFIX: &str = "// --- library:";
    const HEADER_SUFFIX: &str = "--- //";
    // Byte offset just past the line currently being examined.
    let mut offset = 0;
    // `split_inclusive` keeps each line's real terminator (`\n` or `\r\n`), so `offset` is always
    // an exact byte position. Summing `line.len() + 1` over `lines()` drifts by one byte per CRLF
    // line and can slice into the middle of the header.
    for raw_line in source.split_inclusive('\n') {
        offset += raw_line.len();
        let trimmed = raw_line.trim();
        if trimmed.starts_with(HEADER_PREFIX) && trimmed.ends_with(HEADER_SUFFIX) {
            let name = trimmed.trim_start_matches(HEADER_PREFIX).trim_end_matches(HEADER_SUFFIX).trim();
            let rest = source[offset..].trim_start_matches(|c: char| c == '\r' || c == '\n');
            return Some((Symbol::intern(name), rest));
        }
    }
    None
}