use std::collections::HashMap;
use std::path::PathBuf;
use difi::cifi::analyze_observations;
use difi::difi::analyze_linkages;
use difi::io::{read_linkage_members, read_observations};
use difi::metrics::singleton::SingletonMetric;
/// Returns the directory that holds the fixture files read by these tests.
///
/// The path is anchored at the crate root (`CARGO_MANIFEST_DIR`, resolved at
/// compile time) and points at `python/difi/tests/testdata` beneath it.
fn test_data_dir() -> PathBuf {
    let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    // Push each component separately so the platform path separator is used.
    dir.extend(["python", "difi", "tests", "testdata"]);
    dir
}
/// Runs the whole analysis pipeline against the bundled fixture files.
///
/// Steps, in order:
/// 1. read `observations.parquet` (which also yields the ID interner),
/// 2. read `linkage_members.parquet`, interning into that same interner,
/// 3. run `analyze_observations` with a [`SingletonMetric`],
/// 4. run `analyze_linkages`, which updates the objects and the first
///    partition summary in place.
///
/// Returns `(all_objects, all_linkages, summaries, interner)`. The interner is
/// the one used for both input files, so it can resolve the IDs stored in the
/// returned tables.
///
/// # Panics
///
/// Panics if a fixture cannot be read, if either analysis step fails, or if
/// `analyze_observations` yields no partition summary.
fn run_pipeline() -> (
    difi::types::AllObjects,
    difi::types::AllLinkages,
    Vec<difi::partitions::PartitionSummary>,
    difi::types::StringInterner,
) {
    let data_dir = test_data_dir();
    let obs_path = data_dir.join("observations.parquet");
    let lm_path = data_dir.join("linkage_members.parquet");

    // The interner is owned here and not needed in its pre-linkage state, so it
    // is extended in place rather than cloned first.
    let (obs, mut id_interner, _) =
        read_observations(&obs_path).expect("failed to read observations.parquet");
    let lm = read_linkage_members(&lm_path, &mut id_interner)
        .expect("failed to read linkage_members.parquet");

    let metric = SingletonMetric {
        min_obs: 6,
        min_nights: 3,
        min_nightly_obs_in_min_nights: 1,
    };
    let (mut all_objects, _findable, mut summaries) =
        analyze_observations(&obs, None, &metric).expect("analyze_observations failed");

    let first_summary = summaries
        .first_mut()
        .expect("analyze_observations returned no partition summaries");
    // NOTE(review): `6` and `20.0` presumably are the minimum observation count
    // and the contamination threshold in percent; confirm against
    // `analyze_linkages`.
    let (all_linkages, _ignored) =
        analyze_linkages(&obs, &lm, &mut all_objects, first_summary, 6, 20.0)
            .expect("analyze_linkages failed");

    (all_objects, all_linkages, summaries, id_interner)
}
/// Checks the object-level results of the pipeline: five objects, all of them
/// findable, each with 30 observations.
#[test]
fn test_cifi_counts() {
    let (all_objects, _, summaries, _) = run_pipeline();
    assert_eq!(all_objects.len(), 5, "5 objects");
    assert_eq!(summaries[0].findable, Some(5), "5 findable");
    for i in 0..all_objects.len() {
        // Include the index so a failure identifies the offending object.
        assert_eq!(all_objects.num_obs[i], 30, "object index {i}: num_obs");
        assert_eq!(
            all_objects.findable[i],
            Some(true),
            "object index {i}: findable"
        );
    }
}
/// Checks the aggregate linkage classification: the per-linkage flag columns
/// are tallied and compared with fixed expectations, then the first partition
/// summary is checked against the same totals.
#[test]
fn test_difi_summary_counts() {
    /// Counts how many entries of a boolean flag column are `true`.
    fn count_true<'a>(flags: impl Iterator<Item = &'a bool>) -> usize {
        flags.filter(|&&flag| flag).count()
    }

    let (_, all_linkages, summaries, _) = run_pipeline();
    assert_eq!(all_linkages.len(), 20);

    // Tally every flag column up front, then assert in a fixed order.
    let n_pure: usize = count_true(all_linkages.pure.iter());
    let n_pure_complete: usize = count_true(all_linkages.pure_complete.iter());
    let n_contaminated: usize = count_true(all_linkages.contaminated.iter());
    let n_mixed: usize = count_true(all_linkages.mixed.iter());
    let n_found_pure: usize = count_true(all_linkages.found_pure.iter());
    let n_found_contaminated: usize = count_true(all_linkages.found_contaminated.iter());

    assert_eq!(n_pure, 10, "10 pure linkages");
    assert_eq!(n_pure_complete, 5, "5 pure complete linkages");
    assert_eq!(n_contaminated, 2, "2 contaminated linkages");
    assert_eq!(n_mixed, 8, "8 mixed linkages");
    assert_eq!(n_found_pure, 10, "10 found pure");
    assert_eq!(n_found_contaminated, 2, "2 found contaminated");

    let first = &summaries[0];
    assert_eq!(first.found, Some(5));
    // Completeness is a float, so compare with a small tolerance.
    assert!((summaries[0].completeness.unwrap() - 100.0).abs() < 0.01);
    assert_eq!(first.pure_known, Some(10));
    assert_eq!(first.pure_unknown, Some(0));
    assert_eq!(first.contaminated, Some(2));
    assert_eq!(first.mixed, Some(8));
}
/// Checks every linkage individually against a fixed table: observation
/// count, contamination, classification, `pure_complete` flag and linked
/// object.
///
/// The classification check requires exactly one of the `pure`,
/// `contaminated` and `mixed` flags to be set, so a linkage with no flag (or
/// several flags) set is reported instead of silently counting as "mixed".
#[test]
fn test_difi_per_linkage() {
    let (_, all_linkages, _, interner) = run_pipeline();

    // Columns: linkage id, linked object id, num_obs, contamination, type,
    // pure_complete. Contamination values look like percentages rounded to one
    // decimal (e.g. 8.3 for 1 of 12 observations).
    // The tuple type is only a local fixture table, hence the lint exemption.
    #[allow(clippy::type_complexity)]
    let expected: Vec<(&str, Option<&str>, i64, f64, &str, bool)> = vec![
        ("linkage_pure_00000", Some("00000"), 30, 0.0, "pure", true),
        ("linkage_pure_00001", Some("00001"), 30, 0.0, "pure", true),
        ("linkage_pure_00002", Some("00002"), 30, 0.0, "pure", true),
        ("linkage_pure_00003", Some("00003"), 30, 0.0, "pure", true),
        ("linkage_pure_00004", Some("00004"), 30, 0.0, "pure", true),
        (
            "linkage_pure_incomplete_00000",
            Some("00000"),
            6,
            0.0,
            "pure",
            false,
        ),
        (
            "linkage_pure_incomplete_00001",
            Some("00001"),
            10,
            0.0,
            "pure",
            false,
        ),
        (
            "linkage_pure_incomplete_00002",
            Some("00002"),
            7,
            0.0,
            "pure",
            false,
        ),
        (
            "linkage_pure_incomplete_00003",
            Some("00003"),
            8,
            0.0,
            "pure",
            false,
        ),
        (
            "linkage_pure_incomplete_00004",
            Some("00004"),
            10,
            0.0,
            "pure",
            false,
        ),
        (
            "linkage_partial_00000",
            Some("00000"),
            12,
            8.3,
            "contaminated",
            false,
        ),
        (
            "linkage_partial_00001",
            Some("00001"),
            12,
            16.7,
            "contaminated",
            false,
        ),
        ("linkage_partial_00002", None, 12, 25.0, "mixed", false),
        ("linkage_partial_00003", None, 12, 41.7, "mixed", false),
        ("linkage_partial_00004", None, 12, 50.0, "mixed", false),
        ("linkage_mixed_00000", None, 9, 66.7, "mixed", false),
        ("linkage_mixed_00001", None, 7, 57.1, "mixed", false),
        ("linkage_mixed_00002", None, 9, 66.7, "mixed", false),
        ("linkage_mixed_00003", None, 9, 66.7, "mixed", false),
        ("linkage_mixed_00004", None, 8, 62.5, "mixed", false),
    ];

    // Map each resolved linkage name to its row index in `all_linkages`.
    let rust_linkages: HashMap<String, usize> = (0..all_linkages.len())
        .map(|i| {
            let name = interner
                .resolve(all_linkages.linkage_id[i])
                .unwrap()
                .to_string();
            (name, i)
        })
        .collect();

    for (name, exp_linked, exp_nobs, exp_contam, exp_type, exp_pure_complete) in &expected {
        let i = *rust_linkages
            .get(*name)
            .unwrap_or_else(|| panic!("Missing linkage: {name}"));

        assert_eq!(
            all_linkages.num_obs[i], *exp_nobs,
            "{name}: num_obs mismatch"
        );

        // Read all three flags so that "mixed" is verified rather than assumed,
        // and so that overlapping flags are caught.
        let flags = (
            all_linkages.pure[i],
            all_linkages.contaminated[i],
            all_linkages.mixed[i],
        );
        let actual_type = match flags {
            (true, false, false) => "pure",
            (false, true, false) => "contaminated",
            (false, false, true) => "mixed",
            _ => panic!(
                "{name}: expected exactly one of pure/contaminated/mixed to be set, got {flags:?}"
            ),
        };
        assert_eq!(actual_type, *exp_type, "{name}: type mismatch");

        assert_eq!(
            all_linkages.pure_complete[i], *exp_pure_complete,
            "{name}: pure_complete mismatch"
        );

        // Expected values are rounded to one decimal, hence the loose tolerance.
        assert!(
            (all_linkages.contamination[i] - exp_contam).abs() < 0.15,
            "{name}: contamination {:.1} != {:.1}",
            all_linkages.contamination[i],
            exp_contam
        );

        let actual_linked = interner
            .resolve(all_linkages.linked_object_id[i])
            .map(|s| s.to_string());
        let exp_linked_str = exp_linked.map(|s| s.to_string());
        assert_eq!(
            actual_linked, exp_linked_str,
            "{name}: linked_object_id mismatch"
        );
    }
}
/// Checks the per-object linkage tallies produced by the pipeline against a
/// fixed table keyed by object id.
#[test]
fn test_difi_per_object() {
    let (all_objects, _, _, interner) = run_pipeline();

    // Columns: object id, found_pure, found_contaminated, pure, pure_complete,
    // contaminated, contaminant, mixed.
    // The tuple type is only a local fixture table, hence the lint exemption.
    #[allow(clippy::type_complexity)]
    let expected: Vec<(&str, i64, i64, i64, i64, i64, i64, i64)> = vec![
        ("00000", 2, 1, 2, 1, 1, 0, 6),
        ("00001", 2, 1, 2, 1, 1, 0, 5),
        ("00002", 2, 0, 2, 1, 0, 0, 4),
        ("00003", 2, 0, 2, 1, 0, 2, 7),
        ("00004", 2, 0, 2, 1, 0, 1, 3),
    ];

    // Map each resolved object id to its row index in `all_objects`.
    let rust_objects: HashMap<String, usize> = (0..all_objects.len())
        .map(|row| {
            let id = interner
                .resolve(all_objects.object_id[row])
                .unwrap()
                .to_string();
            (id, row)
        })
        .collect();

    for &(oid, fp, fc, p, pc, c, ct, m) in &expected {
        let i = *rust_objects
            .get(oid)
            .unwrap_or_else(|| panic!("Missing object: {oid}"));

        // One entry per column, checked in the listed order; the column is only
        // indexed when its turn comes.
        let checks = [
            ("found_pure", &all_objects.found_pure, fp),
            ("found_contaminated", &all_objects.found_contaminated, fc),
            ("pure", &all_objects.pure, p),
            ("pure_complete", &all_objects.pure_complete, pc),
            ("contaminated", &all_objects.contaminated, c),
            ("contaminant", &all_objects.contaminant, ct),
            ("mixed", &all_objects.mixed, m),
        ];
        for (label, column, want) in checks {
            assert_eq!(column[i], want, "{oid}: {label}");
        }
    }
}