use anyhow::Result;
use float_cmp::approx_eq;
use itertools::Itertools;
use std::env;
use std::fs::{File, read_dir};
use std::io::{BufRead, BufReader};
use std::path::{Path, PathBuf};
use tempfile::{TempDir, tempdir};
mod common;
use common::*;
// Regression test declarations. Each macro invocation presumably expands to a `#[test]` that
// calls `run_regression_test` for the named example — the macros live in `common` (not shown
// here), so confirm there. Reference output is read from `tests/data/<example>`.
// NOTE(review): the `_with_debug_files` variant presumably passes `--debug-model` so that
// `debug_*` CSV files are compared too; confirm against `common`.
define_regression_test_with_debug_files!(simple);
define_regression_test!(missing_commodity);
define_regression_test!(muse1_default);
define_regression_test!(two_outputs);
define_regression_test!(circularity);
// Only compiled on x86_64 targets.
// NOTE(review): presumably because results differ numerically on other architectures; confirm.
#[cfg(target_arch = "x86_64")]
define_regression_test!(two_regions);
// Variants of the `simple` example.
// NOTE(review): `_with_patches` presumably runs a patched copy of a base model; confirm against
// the macro definition in `common`.
define_regression_test_with_patches!(simple_divisible);
define_regression_test_with_patches!(simple_npv);
define_regression_test_with_patches!(simple_marginal);
define_regression_test_with_patches!(simple_marginal_average);
define_regression_test_with_patches!(simple_full);
define_regression_test_with_patches!(simple_full_average);
define_regression_test_with_patches!(simple_ironing_out);
/// Absolute tolerance (`epsilon` for `approx_eq!`) used when comparing numeric CSV fields.
const FLOAT_CMP_TOLERANCE: f64 = 1.0e-10;
/// Run the named example model and compare its CSV output with the reference data.
///
/// The model is invoked via `assert_muse2_runs` with
/// `example run <example> --output-dir <dir>` followed by `extra_args`.
///
/// If the `MUSE2_TEST_OUTPUT_DIR` environment variable is set, output goes to
/// `$MUSE2_TEST_OUTPUT_DIR/<example>`. Otherwise a fresh temporary directory is used.
///
/// The output is then compared against `tests/data/<example>`. `debug_*` files are included in
/// the comparison only when `extra_args` contains `--debug-model`.
fn run_regression_test(example: &str, extra_args: &[&str]) {
    // Declared up front so the temporary directory stays alive for both the run and the
    // comparison. tempfile's `TempDir` removes the directory when it is dropped.
    let tmp: TempDir;
    let output_dir = match env::var("MUSE2_TEST_OUTPUT_DIR") {
        Ok(base) => PathBuf::from(base).join(example),
        Err(_) => {
            tmp = tempdir().unwrap();
            tmp.path().to_path_buf()
        }
    };

    let output_dir_str = output_dir.to_string_lossy();
    let base_args = ["example", "run", example, "--output-dir", &output_dir_str];
    let args = [&base_args[..], extra_args].concat();
    assert_muse2_runs(&args);

    let debug_model = extra_args.contains(&"--debug-model");
    let test_data_dir = PathBuf::from(format!("tests/data/{example}"));
    compare_output_dirs(&output_dir, &test_data_dir, debug_model);
}
/// Compare every CSV file in `cur_output_dir1` against the reference files in `test_data_dir`.
///
/// Files in the output directory whose names start with `debug_` are ignored unless
/// `debug_model` is set. Both directories must contain exactly the same set of CSV file names.
/// Each pair of files is then compared line by line, and all mismatches are reported together.
///
/// # Panics
///
/// Panics if the sets of CSV file names differ, or if any file contents differ.
fn compare_output_dirs(cur_output_dir1: &Path, test_data_dir: &Path, debug_model: bool) {
    let mut output_files = get_csv_file_names(cur_output_dir1);
    if !debug_model {
        output_files.retain(|name| !name.starts_with("debug_"));
    }
    let reference_files = get_csv_file_names(test_data_dir);

    // `assert_eq!` prints both lists on failure, so missing or unexpected files are visible
    // without re-running the test by hand.
    assert_eq!(
        output_files,
        reference_files,
        "CSV files in output directory {} do not match those in test data directory {}",
        cur_output_dir1.display(),
        test_data_dir.display()
    );

    // Collect every mismatch across all files before failing, rather than stopping at the first.
    let mut errors = Vec::new();
    for file_name in &output_files {
        compare_lines(cur_output_dir1, test_data_dir, file_name, &mut errors);
    }
    assert!(
        errors.is_empty(),
        "The following errors occurred:\n * {}",
        errors.join("\n * ")
    );
}
/// Compare the file `file_name` in `output_dir1` with the file of the same name in `output_dir2`.
///
/// A line-count mismatch is recorded first. Then each pair of corresponding lines, up to the
/// shorter file's length, is compared with `compare_line`. Any differing line is recorded with
/// both versions: `+` is the first directory and `-` is the second. Line numbers are 1-based.
/// Errors are appended to `errors`; nothing is returned.
fn compare_lines(
    output_dir1: &Path,
    output_dir2: &Path,
    file_name: &str,
    errors: &mut Vec<String>,
) {
    let actual = read_lines(&output_dir1.join(file_name));
    let expected = read_lines(&output_dir2.join(file_name));

    let (n_actual, n_expected) = (actual.len(), expected.len());
    if n_actual != n_expected {
        errors.push(format!(
            "{file_name}: Different number of lines: {} vs {}",
            n_actual, n_expected
        ));
    }

    for (line_num, (line1, line2)) in (1..).zip(actual.iter().zip(expected.iter())) {
        let lines_match = compare_line(line_num, line1, line2, file_name, errors);
        if lines_match {
            continue;
        }
        errors.push(format!(
            "{file_name}: line {line_num}:\n + \"{line1}\"\n - \"{line2}\""
        ));
    }
}
/// Compare two CSV lines field by field, returning `true` if every compared field matches.
///
/// Lines are split on `,`. A pair of fields that both parse as finite floats is compared
/// approximately via `try_compare_floats`. Any other pair must be byte-for-byte equal.
///
/// If the two lines have different numbers of fields, an error naming `file_name` and line `num`
/// is appended to `errors`. In that case only the common leading fields are compared.
fn compare_line(
    num: usize,
    line1: &str,
    line2: &str,
    file_name: &str,
    errors: &mut Vec<String>,
) -> bool {
    let left = line1.split(',').collect_vec();
    let right = line2.split(',').collect_vec();

    let (n_left, n_right) = (left.len(), right.len());
    if n_left != n_right {
        errors.push(format!(
            "{}: line {}: Different number of fields: {} vs {}",
            file_name, num, n_left, n_right
        ));
    }

    // `zip` stops at the shorter line, so surplus fields are covered only by the error above.
    left.into_iter()
        .zip(right)
        .all(|(a, b)| match try_compare_floats(a, b) {
            Some(equal) => equal,
            None => a == b,
        })
}
/// Parse `s` as an `f64`.
///
/// Returns `None` if `s` is not a number or parses to an infinite or NaN value.
fn parse_finite(s: &str) -> Option<f64> {
    match s.parse::<f64>() {
        Ok(value) if value.is_finite() => Some(value),
        _ => None,
    }
}
/// Approximately compare two strings as floating-point numbers.
///
/// Returns `None` unless both `s1` and `s2` parse as finite floats (see `parse_finite`).
/// Otherwise returns whether they are equal according to `approx_eq!`, with
/// `epsilon = FLOAT_CMP_TOLERANCE`.
fn try_compare_floats(s1: &str, s2: &str) -> Option<bool> {
    parse_finite(s1)
        .zip(parse_finite(s2))
        .map(|(x, y)| approx_eq!(f64, x, y, epsilon = FLOAT_CMP_TOLERANCE))
}
/// Return the sorted names of all entries in `dir_path` whose extension is `csv`.
///
/// The extension match is ASCII case-insensitive. Only the immediate contents of `dir_path` are
/// listed, not subdirectories' contents.
///
/// # Panics
///
/// Panics, naming `dir_path` in the message, if:
/// - the directory cannot be read;
/// - one of its entries cannot be read;
/// - any entry's file name is not valid UTF-8.
fn get_csv_file_names(dir_path: &Path) -> Vec<String> {
    let entries = read_dir(dir_path)
        .unwrap_or_else(|err| panic!("Failed to read directory {}: {err}", dir_path.display()));

    let mut file_names = Vec::new();
    for entry in entries {
        let entry = entry.unwrap_or_else(|err| {
            panic!("Failed to read entry in directory {}: {err}", dir_path.display())
        });
        let os_name = entry.file_name();
        let name = os_name.to_str().unwrap_or_else(|| {
            panic!(
                "Non-UTF-8 file name {:?} in directory {}",
                os_name,
                dir_path.display()
            )
        });
        if Path::new(name)
            .extension()
            .is_some_and(|ext| ext.eq_ignore_ascii_case("csv"))
        {
            file_names.push(name.to_string());
        }
    }

    // `read_dir` order is platform-dependent; sort so the two directories can be compared.
    file_names.sort();
    file_names
}
/// Read all lines of the file at `path`, without line terminators.
///
/// # Panics
///
/// Panics, naming `path` in the message, if the file cannot be opened or if any line cannot be
/// read (e.g. invalid UTF-8). Read errors are deliberately not swallowed: stopping at the first
/// bad line would silently truncate the file, and a truncated comparison could pass.
fn read_lines(path: &Path) -> Vec<String> {
    let file = File::open(path)
        .unwrap_or_else(|err| panic!("Failed to open {}: {err}", path.display()));
    BufReader::new(file)
        .lines()
        .collect::<Result<Vec<_>, _>>()
        .unwrap_or_else(|err| panic!("Failed to read {}: {err}", path.display()))
}