mod combination;
mod input;
mod spec_builder;
pub use crate::error::DiagramError;
pub use combination::Combination;
pub use input::InputType;
pub use spec_builder::DiagramSpecBuilder;
use std::collections::HashMap;
/// Area specification for a set diagram.
///
/// Areas are stored in exclusive form: each entry is the area of the region
/// belonging to exactly the sets of its [`Combination`]. Inclusive areas are
/// derived on demand (see `get_inclusive` / `inclusive_areas`). Instances are
/// obtained from [`DiagramSpecBuilder`].
#[derive(Debug, Clone)]
pub struct DiagramSpec {
    /// Exclusive area of each region, keyed by the combination of sets that
    /// the region belongs to.
    pub(crate) exclusive_areas: HashMap<Combination, f64>,
    /// How the areas were supplied to the builder (e.g. inclusive or exclusive).
    pub(crate) input_type: InputType,
    /// Names of the sets in this spec.
    pub(crate) set_names: Vec<String>,
}
impl DiagramSpec {
    /// Absolute tolerance used by [`Self::preprocess`],
    /// [`Self::compute_pairwise_relations`] and
    /// [`Self::inclusive_to_exclusive_static`].
    ///
    /// It is deliberately separate from `crate::constants::EPSILON` (which
    /// the public inclusive-area accessors use) so that these internal
    /// thresholds stay at exactly `1e-10`.
    const PREPROCESS_EPSILON: f64 = 1e-10;

    /// Returns the [`InputType`] this spec was built with.
    pub fn input_type(&self) -> InputType {
        self.input_type
    }

    /// Returns the names of the sets in this spec.
    pub fn set_names(&self) -> &[String] {
        &self.set_names
    }

    /// Returns the exclusive area of each region, keyed by the combination of
    /// sets the region belongs to.
    pub fn exclusive_areas(&self) -> &HashMap<Combination, f64> {
        &self.exclusive_areas
    }

    /// Computes the inclusive area of every combination implied by the
    /// exclusive areas.
    ///
    /// Each exclusive region with `|area| >= EPSILON` contributes its area to
    /// every non-empty subset of its sets. Combinations whose resulting area
    /// is not greater than `EPSILON` are omitted from the result.
    ///
    /// The work is exponential in combination size: a region over `k` sets
    /// contributes to `2^k - 1` subsets.
    pub fn inclusive_areas(&self) -> HashMap<Combination, f64> {
        let mut inclusive: HashMap<Combination, f64> = HashMap::new();
        for (combo_super, &area) in self.exclusive_areas.iter() {
            if area.abs() < crate::constants::EPSILON {
                continue;
            }
            let sets = combo_super.sets();
            let k = sets.len();
            // One buffer per region, cleared for each subset, instead of a
            // fresh allocation for every mask.
            let mut subset_sets: Vec<&str> = Vec::with_capacity(k);
            // Bit `i` of `mask` selects `sets[i]`; masks 1..2^k enumerate
            // every non-empty subset exactly once.
            for mask in 1u64..(1u64 << k) {
                subset_sets.clear();
                subset_sets.extend(
                    sets.iter()
                        .enumerate()
                        .filter(|&(i, _)| (mask >> i) & 1 == 1)
                        .map(|(_, name)| name.as_str()),
                );
                let subset_combo = Combination::new(&subset_sets);
                *inclusive.entry(subset_combo).or_insert(0.0) += area;
            }
        }
        inclusive.retain(|_, area| *area > crate::constants::EPSILON);
        inclusive
    }

    /// Returns the stored exclusive area of `combination`, if any.
    pub fn get_exclusive(&self, combination: &Combination) -> Option<f64> {
        self.exclusive_areas.get(combination).copied()
    }

    /// Returns the inclusive area of `combination`.
    ///
    /// This is the sum of the exclusive areas of every region whose
    /// combination contains all of `combination`'s sets. Returns `None` when
    /// that sum is not greater than `EPSILON`, including when no region
    /// contains the combination.
    pub fn get_inclusive(&self, combination: &Combination) -> Option<f64> {
        let sum: f64 = self
            .exclusive_areas
            .iter()
            .filter(|(combo, _)| combo.contains_all(combination))
            .map(|(_, &area)| area)
            .sum();
        if sum > crate::constants::EPSILON {
            Some(sum)
        } else {
            None
        }
    }

    /// Prepares this spec for fitting.
    ///
    /// - Sets whose inclusive area is below `PREPROCESS_EPSILON` are dropped.
    /// - The remaining sets keep their order from `set_names` and receive
    ///   consecutive indices.
    /// - Exclusive areas are re-keyed by region mask over those sets. Regions
    ///   that mention a dropped set are omitted.
    /// - Pairwise relations are precomputed.
    ///
    /// # Errors
    ///
    /// Returns [`DiagramError::InvalidCombination`] when fewer than two sets
    /// have a non-empty inclusive area.
    pub(crate) fn preprocess(&self) -> Result<PreprocessedSpec, DiagramError> {
        use crate::geometry::diagram;

        // Inclusive area of each individual set: the sum of every
        // non-negligible exclusive region it participates in.
        let mut singleton_inclusive: HashMap<&str, f64> = HashMap::new();
        for (combo, &area) in self.exclusive_areas.iter() {
            if area.abs() < Self::PREPROCESS_EPSILON {
                continue;
            }
            for set_name in combo.sets() {
                *singleton_inclusive.entry(set_name.as_str()).or_insert(0.0) += area;
            }
        }

        let mut non_empty_sets: Vec<String> = Vec::new();
        let mut set_to_idx: HashMap<String, usize> = HashMap::new();
        for set_name in self.set_names.iter() {
            let inclusive = singleton_inclusive
                .get(set_name.as_str())
                .copied()
                .unwrap_or(0.0);
            if inclusive >= Self::PREPROCESS_EPSILON {
                set_to_idx.insert(set_name.clone(), non_empty_sets.len());
                non_empty_sets.push(set_name.clone());
            }
        }

        let n_sets = non_empty_sets.len();
        if n_sets <= 1 {
            return Err(DiagramError::InvalidCombination(
                "Need at least 2 non-empty sets".to_string(),
            ));
        }

        // Only regions made up entirely of retained sets can be expressed as
        // a mask over `non_empty_sets`.
        let mut exclusive_areas_mask = HashMap::new();
        for (combo, &area) in self.exclusive_areas.iter() {
            if combo.sets().iter().all(|s| set_to_idx.contains_key(s)) {
                let mask = diagram::combination_to_mask(combo, &non_empty_sets);
                exclusive_areas_mask.insert(mask, area);
            }
        }

        let set_areas: Vec<f64> = non_empty_sets
            .iter()
            .map(|s| singleton_inclusive.get(s.as_str()).copied().unwrap_or(0.0))
            .collect();

        let relationships = Self::compute_pairwise_relations(
            &non_empty_sets,
            &set_to_idx,
            &set_areas,
            &self.exclusive_areas,
        );

        Ok(PreprocessedSpec {
            set_names: non_empty_sets,
            set_to_idx,
            exclusive_areas: exclusive_areas_mask,
            n_sets,
            set_areas,
            relationships,
        })
    }

    /// Computes pairwise overlap areas and subset/disjointness flags between
    /// the retained sets.
    ///
    /// `overlap_areas[i][j]` is the sum of the non-negligible exclusive areas
    /// of every region that contains both set `i` and set `j`. Regions with
    /// fewer than two sets, or that mention a set missing from `set_to_idx`,
    /// are ignored.
    ///
    /// For each pair `i != j`:
    /// - `disjoint[i][j]` (and `[j][i]`) is set when the overlap is below the
    ///   tolerance.
    /// - `subset[i][j]` is set when the overlap equals `set_areas[j]` within
    ///   the tolerance, i.e. set `j` lies entirely inside set `i`.
    ///
    /// Diagonal entries are never set by the pairwise pass.
    ///
    /// `set_areas` must have at least `set_names.len()` entries.
    fn compute_pairwise_relations(
        set_names: &[String],
        set_to_idx: &HashMap<String, usize>,
        set_areas: &[f64],
        exclusive_areas: &HashMap<Combination, f64>,
    ) -> PairwiseRelations {
        let eps = Self::PREPROCESS_EPSILON;
        let n = set_names.len();
        let mut subset = vec![vec![false; n]; n];
        let mut disjoint = vec![vec![false; n]; n];
        let mut overlap_areas = vec![vec![0.0; n]; n];

        for (combo, &area) in exclusive_areas.iter() {
            if area.abs() < eps {
                continue;
            }
            // `None` as soon as any set of the region is not a retained set.
            let indices: Option<Vec<usize>> = combo
                .sets()
                .iter()
                .map(|s| set_to_idx.get(s).copied())
                .collect();
            let indices = match indices {
                Some(indices) if indices.len() >= 2 => indices,
                _ => continue,
            };
            for (a, &i) in indices.iter().enumerate() {
                for &j in &indices[a + 1..] {
                    overlap_areas[i][j] += area;
                    overlap_areas[j][i] += area;
                }
            }
        }

        // Index loops: each step reads and writes both halves of symmetric matrices.
        for i in 0..n {
            for j in (i + 1)..n {
                let overlap = overlap_areas[i][j];
                if overlap < eps {
                    disjoint[i][j] = true;
                    disjoint[j][i] = true;
                }
                if (overlap - set_areas[j]).abs() < eps {
                    subset[i][j] = true;
                }
                if (overlap - set_areas[i]).abs() < eps {
                    subset[j][i] = true;
                }
            }
        }

        PairwiseRelations {
            n_sets: n,
            subset,
            disjoint,
            overlap_areas,
        }
    }

    /// Converts inclusive areas into exclusive areas.
    ///
    /// Combinations are resolved from largest to smallest. The exclusive area
    /// of a combination is its inclusive area minus the exclusive areas of
    /// every already-resolved combination that contains all of its sets (its
    /// strict supersets present in `inclusive`). Results that are negative
    /// only within the tolerance are clamped to `0.0`.
    ///
    /// Ties in size are broken by the combination's `Display` text. As a
    /// result, the processing order, the floating-point subtraction order and
    /// the combination named in an error do not depend on `HashMap`
    /// iteration order.
    ///
    /// # Errors
    ///
    /// Returns [`DiagramError::InvalidValue`] for the first combination, in
    /// that order, whose exclusive area falls below `-PREPROCESS_EPSILON`.
    /// This happens when its inclusive area is smaller than the sum of its
    /// supersets' exclusive areas.
    fn inclusive_to_exclusive_static(
        inclusive: &HashMap<Combination, f64>,
    ) -> Result<HashMap<Combination, f64>, DiagramError> {
        let mut sorted_combos: Vec<&Combination> = inclusive.keys().collect();
        sorted_combos.sort_by_cached_key(|c| (std::cmp::Reverse(c.len()), c.to_string()));

        // Resolved entries in processing order. Iterating this Vec rather
        // than a HashMap keeps the subtraction order deterministic.
        let mut resolved: Vec<(&Combination, f64)> = Vec::with_capacity(sorted_combos.len());
        for combo in sorted_combos {
            let mut exclusive_area = inclusive[combo];
            // `combo` itself is not in `resolved` yet, so every match is a
            // different key.
            for &(other_combo, other_excl) in &resolved {
                if other_combo.contains_all(combo) {
                    exclusive_area -= other_excl;
                }
            }
            if exclusive_area < -Self::PREPROCESS_EPSILON {
                return Err(DiagramError::InvalidValue {
                    combination: combo.to_string(),
                    value: exclusive_area,
                });
            }
            resolved.push((combo, exclusive_area.max(0.0)));
        }

        Ok(resolved
            .into_iter()
            .map(|(combo, area)| (combo.clone(), area))
            .collect())
    }
}
/// Fitting-ready view of a [`DiagramSpec`], produced by `DiagramSpec::preprocess`.
///
/// Only sets whose inclusive area is at least `1e-10` are retained. They are
/// indexed `0..n_sets` in the order they appear in the spec's `set_names`.
#[derive(Clone)]
pub(crate) struct PreprocessedSpec {
    /// Names of the retained sets, in index order.
    // NOTE(review): not read anywhere visible from this module, hence the
    // allow — confirm there are no users before removing the field.
    #[allow(dead_code)] pub(crate) set_names: Vec<String>,
    /// Maps each retained set name to its index.
    pub(crate) set_to_idx: HashMap<String, usize>,
    /// Exclusive areas keyed by region mask over the retained sets; regions
    /// that mention a dropped set are omitted.
    pub(crate) exclusive_areas: HashMap<crate::geometry::diagram::RegionMask, f64>,
    /// Number of retained sets.
    pub(crate) n_sets: usize,
    /// Inclusive area of each retained set, indexed like `set_names`.
    pub(crate) set_areas: Vec<f64>,
    /// Precomputed pairwise overlaps and subset/disjointness flags.
    pub(crate) relationships: PairwiseRelations,
}
/// Pairwise relationships between the retained sets of a preprocessed spec,
/// stored as square matrices indexed by set position.
#[derive(Clone)]
pub(crate) struct PairwiseRelations {
    /// Number of sets covered by the matrices.
    // NOTE(review): not read anywhere visible from this module, hence the allow.
    #[allow(dead_code)]
    pub(crate) n_sets: usize,
    /// `subset[i][j]`: set `j` lies entirely inside set `i` (as filled by
    /// `DiagramSpec::compute_pairwise_relations`).
    pub(crate) subset: Vec<Vec<bool>>,
    /// `disjoint[i][j]`: sets `i` and `j` share no (non-negligible) area.
    pub(crate) disjoint: Vec<Vec<bool>>,
    /// `overlap_areas[i][j]`: total area shared by sets `i` and `j`.
    pub(crate) overlap_areas: Vec<Vec<f64>>,
}

impl PairwiseRelations {
    /// Looks up `subset[i][j]`. Panics if either index is out of range.
    // NOTE(review): accessors may be unused by the rest of the crate, hence the allows.
    #[allow(dead_code)]
    pub(crate) fn is_subset(&self, i: usize, j: usize) -> bool {
        let row = &self.subset[i];
        row[j]
    }

    /// Looks up `disjoint[i][j]`. Panics if either index is out of range.
    #[allow(dead_code)]
    pub(crate) fn is_disjoint(&self, i: usize, j: usize) -> bool {
        let row = &self.disjoint[i];
        row[j]
    }

    /// Looks up `overlap_areas[i][j]`. Panics if either index is out of range.
    #[allow(dead_code)]
    pub(crate) fn overlap_area(&self, i: usize, j: usize) -> f64 {
        let row = &self.overlap_areas[i];
        row[j]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`Combination`] from a slice of set names.
    fn combo_of(names: &[&str]) -> Combination {
        Combination::new(names)
    }

    #[test]
    fn test_both_representations_available() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 10.0)
            .set("B", 8.0)
            .intersection(&["A", "B"], 2.0)
            .input_type(InputType::Inclusive)
            .build()
            .unwrap();

        // Inclusive input is reported back unchanged...
        assert_eq!(spec.get_inclusive(&combo_of(&["A"])), Some(10.0));
        assert_eq!(spec.get_inclusive(&combo_of(&["B"])), Some(8.0));
        assert_eq!(spec.get_inclusive(&combo_of(&["A", "B"])), Some(2.0));

        // ...while the exclusive view removes the shared region from each set.
        assert_eq!(spec.get_exclusive(&combo_of(&["A"])), Some(8.0));
        assert_eq!(spec.get_exclusive(&combo_of(&["B"])), Some(6.0));
        assert_eq!(spec.get_exclusive(&combo_of(&["A", "B"])), Some(2.0));
    }

    #[test]
    fn test_inclusive_three_set_decomposition() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 10.0)
            .set("B", 8.0)
            .set("C", 6.0)
            .intersection(&["A", "B"], 3.0)
            .intersection(&["A", "C"], 2.0)
            .intersection(&["B", "C"], 1.0)
            .intersection(&["A", "B", "C"], 0.0)
            .input_type(InputType::Inclusive)
            .build()
            .unwrap();

        let exclusive = |names: &[&str]| {
            spec.get_exclusive(&combo_of(names))
                .expect("exclusive area should be defined")
        };
        let approx = |actual: f64, expected: f64| (actual - expected).abs() < 1e-10;

        assert!(approx(exclusive(&["A"]), 5.0));
        assert!(approx(exclusive(&["B"]), 4.0));
        assert!(approx(exclusive(&["C"]), 3.0));
        assert!(approx(exclusive(&["A", "B"]), 3.0));
        assert!(approx(exclusive(&["A", "C"]), 2.0));
        assert!(approx(exclusive(&["B", "C"]), 1.0));
    }

    #[test]
    fn test_inclusive_rejects_negative_disjoint_area() {
        // The intersection is larger than either set, so A-only would be negative.
        let result = DiagramSpecBuilder::new()
            .set("A", 5.0)
            .set("B", 5.0)
            .intersection(&["A", "B"], 10.0)
            .input_type(InputType::Inclusive)
            .build();

        let rejected = matches!(result, Err(DiagramError::InvalidValue { .. }));
        assert!(
            rejected,
            "expected InvalidValue for negative-disjoint inclusive input, got {:?}",
            result
        );
    }

    #[test]
    fn test_exclusive_input() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 5.0)
            .set("B", 2.0)
            .intersection(&["A", "B"], 1.0)
            .input_type(InputType::Exclusive)
            .build()
            .unwrap();

        // Exclusive input is reported back unchanged...
        assert_eq!(spec.get_exclusive(&combo_of(&["A"])), Some(5.0));
        assert_eq!(spec.get_exclusive(&combo_of(&["B"])), Some(2.0));
        assert_eq!(spec.get_exclusive(&combo_of(&["A", "B"])), Some(1.0));

        // ...while the inclusive view adds the shared region to each set.
        assert_eq!(spec.get_inclusive(&combo_of(&["A"])), Some(6.0));
        assert_eq!(spec.get_inclusive(&combo_of(&["B"])), Some(3.0));
        assert_eq!(spec.get_inclusive(&combo_of(&["A", "B"])), Some(1.0));
    }

    #[test]
    fn test_get_inclusive_recovers_implicit_subsets() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 0.0)
            .set("B", 5.0)
            .intersection(&["A", "B"], 1.0)
            .intersection(&["A", "B", "C"], 0.1)
            .input_type(InputType::Exclusive)
            .build()
            .unwrap();

        let inclusive = |names: &[&str]| spec.get_inclusive(&combo_of(names));

        // Pairs that were never given explicitly are recovered from A&B&C.
        assert_eq!(inclusive(&["A", "C"]), Some(0.1));
        assert_eq!(inclusive(&["B", "C"]), Some(0.1));

        assert!((inclusive(&["A"]).unwrap() - 1.1).abs() < 1e-10);
        assert!((inclusive(&["B"]).unwrap() - 6.1).abs() < 1e-10);
        assert!((inclusive(&["C"]).unwrap() - 0.1).abs() < 1e-10);

        // No region contains D, so the combination has no area.
        assert_eq!(inclusive(&["A", "B", "C", "D"]), None);
    }

    #[test]
    fn test_inclusive_areas_matches_get_inclusive() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 10.0)
            .set("B", 8.0)
            .set("C", 6.0)
            .intersection(&["A", "B"], 2.0)
            .intersection(&["A", "C"], 3.0)
            .intersection(&["B", "C"], 1.0)
            .intersection(&["A", "B", "C"], 0.5)
            .input_type(InputType::Exclusive)
            .build()
            .unwrap();

        let bulk = spec.inclusive_areas();
        for (combo, area) in &bulk {
            let got = spec
                .get_inclusive(combo)
                .expect("get_inclusive should agree with inclusive_areas keys");
            assert!(
                (got - *area).abs() < 1e-10,
                "mismatch for {combo}: bulk={area}, get_inclusive={got}"
            );
        }
    }
}