use std::{collections::HashMap, ffi::OsStr, fs, path::Path};
use csv::Reader;
use assembly_theory::{
assembly::{index_search, ParallelMode},
bounds::Bound,
canonize::CanonizeMode,
kernels::KernelMode,
loader::parse_molfile_str,
memoize::MemoizeMode,
};
fn load_ma_index(dataset: &str) -> HashMap<String, u32> {
let ma_index_path = Path::new("data").join(dataset).join("ma-index.csv");
let mut reader =
Reader::from_path(ma_index_path).expect(&format!("{dataset}/ma-index.csv does not exist."));
let mut ma_index: HashMap<String, u32> = HashMap::new();
for result in reader.records() {
let record = result.expect("ma-index.csv is malformed.");
let record = record.iter().collect::<Vec<_>>();
ma_index.insert(
record[0].to_string(),
record[1]
.to_string()
.parse::<u32>()
.expect("non-integer index"),
);
}
ma_index
}
fn test_reference_dataset(
dataset: &str,
canonize_mode: CanonizeMode,
parallel_mode: ParallelMode,
memoize_mode: MemoizeMode,
kernel_mode: KernelMode,
bounds: &[Bound],
) {
let ma_index = load_ma_index(dataset);
let mut incorrect_mols: Vec<(String, u32, u32)> = Vec::new();
let mut paths: Vec<_> = fs::read_dir(Path::new("data").join(dataset))
.unwrap()
.filter_map(|r| r.ok())
.collect();
paths.sort_by_key(|p| p.path());
for path in paths {
let name = path.path();
if name.extension().and_then(OsStr::to_str) != Some("mol") {
continue;
}
let mol = parse_molfile_str(
&fs::read_to_string(name.clone()).expect(&format!("Could not read file {name:?}")),
)
.expect(&format!("Failed to parse {name:?}"));
let (index, _, _) = index_search(
&mol,
None,
canonize_mode,
parallel_mode,
memoize_mode,
kernel_mode,
bounds,
);
let molname = name.file_name().unwrap().to_str().unwrap().to_string();
let true_index = ma_index[&molname];
if index != true_index {
incorrect_mols.push((molname, index, true_index));
}
}
let mut error_details = String::new();
for (molname, index, true_index) in &incorrect_mols {
error_details.push_str(&format!(
"{molname}: assembly index {index} (assembly-theory) != {true_index} (ground truth)\n"
));
}
assert!(incorrect_mols.is_empty(), "{}", error_details);
}
#[test]
fn canonization() {
for canonize_mode in [CanonizeMode::Nauty, CanonizeMode::TreeNauty] {
test_reference_dataset(
"gdb13_1201",
canonize_mode,
ParallelMode::None,
MemoizeMode::None,
KernelMode::None,
&[],
);
}
}
#[test]
fn memoization() {
for (memoize_mode, canonize_mode) in [
(MemoizeMode::None, CanonizeMode::TreeNauty),
(MemoizeMode::CanonIndex, CanonizeMode::Nauty),
(MemoizeMode::CanonIndex, CanonizeMode::TreeNauty),
] {
test_reference_dataset(
"gdb13_1201",
canonize_mode,
ParallelMode::DepthOne,
memoize_mode,
KernelMode::None,
&[],
);
}
}
#[test]
fn individual_bounds() {
for bound in [
Bound::Log,
Bound::Int,
Bound::VecSimple,
Bound::VecSmallFrags,
Bound::MatchableEdges,
] {
test_reference_dataset(
"checks",
CanonizeMode::TreeNauty,
ParallelMode::DepthOne,
MemoizeMode::None,
KernelMode::None,
&[bound],
);
}
}
#[test]
fn all_bounds() {
test_reference_dataset(
"coconut_55",
CanonizeMode::TreeNauty,
ParallelMode::DepthOne,
MemoizeMode::None,
KernelMode::None,
&[
Bound::Log,
Bound::Int,
Bound::VecSimple,
Bound::VecSmallFrags,
Bound::MatchableEdges,
],
);
}
#[test]
fn memoize_bounds() {
test_reference_dataset(
"coconut_55",
CanonizeMode::TreeNauty,
ParallelMode::DepthOne,
MemoizeMode::CanonIndex,
KernelMode::None,
&[
Bound::Log,
Bound::Int,
Bound::VecSimple,
Bound::VecSmallFrags,
Bound::MatchableEdges,
],
);
}