mod combination;
mod input;
mod spec_builder;
pub use crate::error::DiagramError;
pub use combination::Combination;
pub use input::InputType;
pub use spec_builder::DiagramSpecBuilder;
use std::collections::HashMap;
/// Area specification for a set diagram.
///
/// Holds the area of each set combination in both exclusive and inclusive
/// form, together with the form in which the areas were originally supplied.
#[derive(Debug, Clone)]
pub struct DiagramSpec {
/// Area of the region belonging to exactly the sets of each combination.
pub(crate) exclusive_areas: HashMap<Combination, f64>,
/// Total area shared by all sets of each combination (its own exclusive
/// area plus the exclusive areas of every combination containing it).
pub(crate) inclusive_areas: HashMap<Combination, f64>,
/// Which representation the areas were originally given in.
pub(crate) input_type: InputType,
/// Names of every set in the diagram; `preprocess` walks them in this order.
pub(crate) set_names: Vec<String>,
}
impl DiagramSpec {
    /// Tolerance below which an area is treated as zero.
    ///
    /// Every threshold in this impl uses this single value so that
    /// "non-empty set", "non-empty intersection" and "disjoint pair" all agree
    /// on the same boundary.
    const AREA_EPSILON: f64 = 1e-10;

    /// Returns which representation (inclusive or exclusive) the areas were
    /// supplied in.
    pub fn input_type(&self) -> InputType {
        self.input_type
    }

    /// Returns the names of all sets in this specification.
    pub fn set_names(&self) -> &[String] {
        &self.set_names
    }

    /// Returns the exclusive area (region belonging to exactly those sets) of
    /// every stored combination.
    pub fn exclusive_areas(&self) -> &HashMap<Combination, f64> {
        &self.exclusive_areas
    }

    /// Returns the inclusive area (total shared area) of every stored
    /// combination.
    pub fn inclusive_areas(&self) -> &HashMap<Combination, f64> {
        &self.inclusive_areas
    }

    /// Looks up the exclusive area of `combination`, if one is stored.
    pub fn get_exclusive(&self, combination: &Combination) -> Option<f64> {
        self.exclusive_areas.get(combination).copied()
    }

    /// Looks up the inclusive area of `combination`, if one is stored.
    pub fn get_inclusive(&self, combination: &Combination) -> Option<f64> {
        self.inclusive_areas.get(combination).copied()
    }

    /// Reduces the specification to its non-empty sets and re-keys every area
    /// by the region mask produced by `geometry::diagram::combination_to_mask`.
    ///
    /// * Sets whose inclusive area is missing or below `AREA_EPSILON` are
    ///   dropped, as is every combination that mentions one of them.
    /// * Combinations with an inclusive area of at least `AREA_EPSILON` but no
    ///   exclusive entry are given an exclusive area of `0.0`.
    /// * Pairwise subset/disjointness relations are computed for the kept sets.
    ///
    /// # Errors
    ///
    /// Returns `DiagramError::InvalidCombination` when fewer than two
    /// non-empty sets remain.
    pub(crate) fn preprocess(&self) -> Result<PreprocessedSpec, DiagramError> {
        use crate::geometry::diagram;

        // Keep only sets with a non-negligible area, remembering each one's
        // position in the reduced ordering.
        let mut non_empty_sets = Vec::new();
        let mut set_to_idx = HashMap::new();
        for set_name in &self.set_names {
            let combo = Combination::new(&[set_name]);
            if let Some(&area) = self.inclusive_areas.get(&combo) {
                if area >= Self::AREA_EPSILON {
                    set_to_idx.insert(set_name.clone(), non_empty_sets.len());
                    non_empty_sets.push(set_name.clone());
                }
            }
        }

        let n_sets = non_empty_sets.len();
        if n_sets <= 1 {
            return Err(DiagramError::InvalidCombination(
                "Need at least 2 non-empty sets".to_string(),
            ));
        }

        let only_non_empty =
            |combo: &Combination| combo.sets().iter().all(|s| set_to_idx.contains_key(s));

        // First take every exclusive entry over kept sets (plus its inclusive
        // counterpart, if any)...
        let mut filtered_exclusive = HashMap::new();
        let mut filtered_inclusive = HashMap::new();
        for (combo, &area) in &self.exclusive_areas {
            if only_non_empty(combo) {
                filtered_exclusive.insert(combo.clone(), area);
                if let Some(&inclusive_area) = self.inclusive_areas.get(combo) {
                    filtered_inclusive.insert(combo.clone(), inclusive_area);
                }
            }
        }
        // ...then add inclusive-only entries. The same `>=` threshold as the
        // set filter above guarantees every kept set has its area here.
        for (combo, &inclusive_area) in &self.inclusive_areas {
            if only_non_empty(combo)
                && inclusive_area >= Self::AREA_EPSILON
                && !filtered_inclusive.contains_key(combo)
            {
                filtered_inclusive.insert(combo.clone(), inclusive_area);
                filtered_exclusive.entry(combo.clone()).or_insert(0.0);
            }
        }

        let set_areas: Vec<f64> = non_empty_sets
            .iter()
            .map(|name| {
                filtered_inclusive
                    .get(&Combination::new(&[name]))
                    .copied()
                    .unwrap_or(0.0)
            })
            .collect();

        let relationships = Self::compute_pairwise_relations(&non_empty_sets, &filtered_inclusive)?;

        let to_mask_map =
            |areas: &HashMap<Combination, f64>| -> HashMap<diagram::RegionMask, f64> {
                areas
                    .iter()
                    .map(|(combo, &area)| {
                        (diagram::combination_to_mask(combo, &non_empty_sets), area)
                    })
                    .collect()
            };
        let exclusive_areas = to_mask_map(&filtered_exclusive);
        let inclusive_areas = to_mask_map(&filtered_inclusive);

        Ok(PreprocessedSpec {
            set_names: non_empty_sets,
            set_to_idx,
            exclusive_areas,
            inclusive_areas,
            n_sets,
            set_areas,
            relationships,
        })
    }

    /// Derives pairwise relations between `set_names` from their inclusive
    /// areas (missing entries count as zero).
    ///
    /// For every pair `i != j`:
    /// * `overlap_areas[i][j]` is the inclusive area of `{i, j}` (symmetric);
    /// * `disjoint[i][j]` is set when that overlap is below `AREA_EPSILON`;
    /// * `subset[i][j]` is set when the overlap equals the area of set `j`
    ///   (within `AREA_EPSILON`), i.e. set `j` lies inside set `i`.
    ///
    /// Diagonal entries stay `false` / `0.0`. Currently never returns `Err`.
    fn compute_pairwise_relations(
        set_names: &[String],
        inclusive_areas: &HashMap<Combination, f64>,
    ) -> Result<PairwiseRelations, DiagramError> {
        let n = set_names.len();
        let mut subset = vec![vec![false; n]; n];
        let mut disjoint = vec![vec![false; n]; n];
        let mut overlap_areas = vec![vec![0.0; n]; n];

        // Single-set areas are needed for every pair; look each up once.
        let single_areas: Vec<f64> = set_names
            .iter()
            .map(|name| {
                inclusive_areas
                    .get(&Combination::new(&[name]))
                    .copied()
                    .unwrap_or(0.0)
            })
            .collect();

        for i in 0..n {
            for j in (i + 1)..n {
                let pair = Combination::new(&[&set_names[i], &set_names[j]]);
                let area_ij = inclusive_areas.get(&pair).copied().unwrap_or(0.0);

                overlap_areas[i][j] = area_ij;
                overlap_areas[j][i] = area_ij;

                if area_ij < Self::AREA_EPSILON {
                    disjoint[i][j] = true;
                    disjoint[j][i] = true;
                }
                // |i ∩ j| == |j|  =>  j is contained in i.
                if (area_ij - single_areas[j]).abs() < Self::AREA_EPSILON {
                    subset[i][j] = true;
                }
                // |i ∩ j| == |i|  =>  i is contained in j.
                if (area_ij - single_areas[i]).abs() < Self::AREA_EPSILON {
                    subset[j][i] = true;
                }
            }
        }

        Ok(PairwiseRelations {
            n_sets: n,
            subset,
            disjoint,
            overlap_areas,
        })
    }

    /// Converts exclusive areas to inclusive ones.
    ///
    /// The inclusive area of a combination `C` is the sum of the exclusive
    /// areas of every stored combination containing all of `C`'s sets.
    /// Equivalently, each stored combination contributes its exclusive area to
    /// each of its non-empty sub-combinations, which is how the sums are
    /// accumulated here. The cost is therefore proportional to the total
    /// number of sub-combinations of the input keys, not to
    /// `2^(number of distinct sets)`. Masks are `usize`, so nothing overflows
    /// for fewer than `usize::BITS` sets per combination.
    ///
    /// Only results strictly greater than `AREA_EPSILON` are kept.
    ///
    /// # Errors
    ///
    /// Returns `DiagramError::InvalidCombination` if a single combination
    /// names `usize::BITS` or more sets.
    fn exclusive_to_inclusive_static(
        exclusive: &HashMap<Combination, f64>,
    ) -> Result<HashMap<Combination, f64>, DiagramError> {
        let mut inclusive: HashMap<Combination, f64> = HashMap::new();
        for (combo, &area) in exclusive {
            let sets = combo.sets();
            let members: Vec<&str> = sets.iter().map(|s| s.as_str()).collect();
            let k = members.len();
            if k >= usize::BITS as usize {
                return Err(DiagramError::InvalidCombination(format!(
                    "combination {} has too many sets ({}) to expand",
                    combo, k
                )));
            }
            // Every non-empty bitmask over `members` is one sub-combination.
            for mask in 1..(1usize << k) {
                let sub: Vec<&str> = members
                    .iter()
                    .enumerate()
                    .filter(|&(bit, _)| mask & (1 << bit) != 0)
                    .map(|(_, &name)| name)
                    .collect();
                *inclusive.entry(Combination::new(&sub)).or_insert(0.0) += area;
            }
        }
        inclusive.retain(|_, total| *total > Self::AREA_EPSILON);
        Ok(inclusive)
    }

    /// Converts inclusive areas to exclusive ones.
    ///
    /// Longer combinations are processed first, so by the time a combination
    /// is reached the exclusive areas of its strict supersets are known and
    /// can be subtracted from its inclusive area. Small negative round-off
    /// (down to `-AREA_EPSILON`) is clamped to `0.0`.
    ///
    /// # Errors
    ///
    /// Returns `DiagramError::InvalidValue` when a combination's exclusive
    /// area would be below `-AREA_EPSILON`, i.e. the inclusive input is
    /// inconsistent.
    fn inclusive_to_exclusive_static(
        inclusive: &HashMap<Combination, f64>,
    ) -> Result<HashMap<Combination, f64>, DiagramError> {
        let mut by_size: Vec<&Combination> = inclusive.keys().collect();
        by_size.sort_by_key(|c| std::cmp::Reverse(c.len()));

        let mut exclusive: HashMap<Combination, f64> = HashMap::with_capacity(inclusive.len());
        for combo in by_size {
            let mut remaining = inclusive[combo];
            for (other, &other_excl) in &exclusive {
                if other != combo && other.contains_all(combo) {
                    remaining -= other_excl;
                }
            }
            if remaining < -Self::AREA_EPSILON {
                return Err(DiagramError::InvalidValue {
                    combination: combo.to_string(),
                    value: remaining,
                });
            }
            exclusive.insert(combo.clone(), remaining.max(0.0));
        }
        Ok(exclusive)
    }
}
/// Reduced, bitmask-keyed view of a [`DiagramSpec`], built by
/// `DiagramSpec::preprocess`.
///
/// Only sets with a non-negligible inclusive area are kept; every index and
/// mask refers to positions in `set_names`.
// NOTE(review): the `allow(dead_code)` suggests some fields are not read by
// any consumer yet — confirm before removing it.
#[allow(dead_code)] #[derive(Clone)]
pub(crate) struct PreprocessedSpec {
/// Names of the kept (non-empty) sets, in their original spec order.
pub(crate) set_names: Vec<String>,
/// Maps each name in `set_names` to its index there.
pub(crate) set_to_idx: HashMap<String, usize>,
/// Exclusive area per region, keyed by the mask from `combination_to_mask`.
pub(crate) exclusive_areas: HashMap<crate::geometry::diagram::RegionMask, f64>,
/// Inclusive area per region, keyed the same way as `exclusive_areas`.
pub(crate) inclusive_areas: HashMap<crate::geometry::diagram::RegionMask, f64>,
/// Number of kept sets; equals `set_names.len()`.
pub(crate) n_sets: usize,
/// `set_areas[i]` is the inclusive area of `set_names[i]` (0.0 if missing).
pub(crate) set_areas: Vec<f64>,
/// Pairwise subset / disjointness / overlap facts between the kept sets.
pub(crate) relationships: PairwiseRelations,
}
/// Square `n_sets × n_sets` matrices describing how each pair of sets relates.
///
/// Populated by `DiagramSpec::compute_pairwise_relations`:
/// * `subset[i][j]` is `true` when set `j` lies inside set `i` (their overlap
///   equals the area of `j`);
/// * `disjoint[i][j]` is `true` when the two sets do not overlap;
/// * `overlap_areas[i][j]` is the inclusive area shared by sets `i` and `j`.
///
/// Diagonal entries are `false` / `0.0`.
#[derive(Debug, Clone)]
pub(crate) struct PairwiseRelations {
    /// Dimension of every matrix below.
    // Kept for completeness even though nothing reads it yet.
    #[allow(dead_code)]
    pub(crate) n_sets: usize,
    /// `subset[i][j]`: set `j` is contained in set `i`.
    pub(crate) subset: Vec<Vec<bool>>,
    /// `disjoint[i][j]`: sets `i` and `j` have negligible overlap (symmetric).
    pub(crate) disjoint: Vec<Vec<bool>>,
    /// `overlap_areas[i][j]`: inclusive area of `i ∩ j` (symmetric).
    pub(crate) overlap_areas: Vec<Vec<f64>>,
}
impl PairwiseRelations {
    /// Reads entry `(row, col)` of one of the square relation matrices.
    ///
    /// Panics on an out-of-range index, exactly like direct indexing.
    // Only reachable through the optional accessors below.
    #[allow(dead_code)]
    fn cell<T: Copy>(matrix: &[Vec<T>], row: usize, col: usize) -> T {
        matrix[row][col]
    }

    /// Returns `subset[i][j]`: whether set `j` was found to lie inside set `i`.
    // NOTE(review): not every accessor is currently called; the allow keeps
    // the build warning-free.
    #[allow(dead_code)]
    pub(crate) fn is_subset(&self, i: usize, j: usize) -> bool {
        Self::cell(&self.subset, i, j)
    }

    /// Returns `disjoint[i][j]`: whether sets `i` and `j` have negligible overlap.
    #[allow(dead_code)]
    pub(crate) fn is_disjoint(&self, i: usize, j: usize) -> bool {
        Self::cell(&self.disjoint, i, j)
    }

    /// Returns the inclusive area shared by sets `i` and `j`.
    #[allow(dead_code)]
    pub(crate) fn overlap_area(&self, i: usize, j: usize) -> f64 {
        Self::cell(&self.overlap_areas, i, j)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_both_representations_available() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 10.0)
            .set("B", 8.0)
            .intersection(&["A", "B"], 2.0)
            .input_type(InputType::Inclusive)
            .build()
            .unwrap();

        assert_eq!(spec.get_inclusive(&Combination::new(&["A"])), Some(10.0));
        assert_eq!(spec.get_inclusive(&Combination::new(&["B"])), Some(8.0));
        assert_eq!(
            spec.get_inclusive(&Combination::new(&["A", "B"])),
            Some(2.0)
        );

        // Exclusive regions: A only = 10 - 2, B only = 8 - 2.
        assert_eq!(spec.get_exclusive(&Combination::new(&["A"])), Some(8.0));
        assert_eq!(spec.get_exclusive(&Combination::new(&["B"])), Some(6.0));
        assert_eq!(
            spec.get_exclusive(&Combination::new(&["A", "B"])),
            Some(2.0)
        );
    }

    #[test]
    fn test_inclusive_three_set_decomposition() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 10.0)
            .set("B", 8.0)
            .set("C", 6.0)
            .intersection(&["A", "B"], 3.0)
            .intersection(&["A", "C"], 2.0)
            .intersection(&["B", "C"], 1.0)
            .intersection(&["A", "B", "C"], 0.0)
            .input_type(InputType::Inclusive)
            .build()
            .unwrap();

        let g = |names: &[&str]| {
            spec.get_exclusive(&Combination::new(names))
                .expect("exclusive area should be defined")
        };
        assert!((g(&["A"]) - 5.0).abs() < 1e-10);
        assert!((g(&["B"]) - 4.0).abs() < 1e-10);
        assert!((g(&["C"]) - 3.0).abs() < 1e-10);
        assert!((g(&["A", "B"]) - 3.0).abs() < 1e-10);
        assert!((g(&["A", "C"]) - 2.0).abs() < 1e-10);
        assert!((g(&["B", "C"]) - 1.0).abs() < 1e-10);

        // The triple overlap was given as zero: whether or not it is stored,
        // it must not carry any area.
        let abc = spec
            .get_exclusive(&Combination::new(&["A", "B", "C"]))
            .unwrap_or(0.0);
        assert!(abc.abs() < 1e-10);
    }

    #[test]
    fn test_inclusive_rejects_negative_disjoint_area() {
        let result = DiagramSpecBuilder::new()
            .set("A", 5.0)
            .set("B", 5.0)
            .intersection(&["A", "B"], 10.0)
            .input_type(InputType::Inclusive)
            .build();
        assert!(
            matches!(result, Err(DiagramError::InvalidValue { .. })),
            "expected InvalidValue for negative-disjoint inclusive input, got {:?}",
            result
        );
    }

    #[test]
    fn test_exclusive_input() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 5.0)
            .set("B", 2.0)
            .intersection(&["A", "B"], 1.0)
            .input_type(InputType::Exclusive)
            .build()
            .unwrap();

        assert_eq!(spec.get_exclusive(&Combination::new(&["A"])), Some(5.0));
        assert_eq!(spec.get_exclusive(&Combination::new(&["B"])), Some(2.0));
        assert_eq!(
            spec.get_exclusive(&Combination::new(&["A", "B"])),
            Some(1.0)
        );

        // Inclusive areas: A = 5 + 1, B = 2 + 1.
        assert_eq!(spec.get_inclusive(&Combination::new(&["A"])), Some(6.0));
        assert_eq!(spec.get_inclusive(&Combination::new(&["B"])), Some(3.0));
        assert_eq!(
            spec.get_inclusive(&Combination::new(&["A", "B"])),
            Some(1.0)
        );
    }

    #[test]
    fn test_exclusive_to_inclusive_counts_higher_order_regions() {
        let mut exclusive = HashMap::new();
        exclusive.insert(Combination::new(&["A"]), 4.0);
        exclusive.insert(Combination::new(&["B"]), 3.0);
        exclusive.insert(Combination::new(&["C"]), 2.0);
        exclusive.insert(Combination::new(&["A", "B", "C"]), 1.0);

        let inclusive = DiagramSpec::exclusive_to_inclusive_static(&exclusive).unwrap();

        // Every non-empty sub-combination of {A, B, C} picks up the triple region.
        assert_eq!(inclusive.len(), 7);
        assert_eq!(inclusive.get(&Combination::new(&["A"])), Some(&5.0));
        assert_eq!(inclusive.get(&Combination::new(&["B"])), Some(&4.0));
        assert_eq!(inclusive.get(&Combination::new(&["C"])), Some(&3.0));
        assert_eq!(inclusive.get(&Combination::new(&["A", "B"])), Some(&1.0));
        assert_eq!(inclusive.get(&Combination::new(&["B", "C"])), Some(&1.0));
        assert_eq!(
            inclusive.get(&Combination::new(&["A", "B", "C"])),
            Some(&1.0)
        );
    }

    #[test]
    fn test_inclusive_exclusive_round_trip_with_triple_overlap() {
        let mut inclusive = HashMap::new();
        inclusive.insert(Combination::new(&["A"]), 10.0);
        inclusive.insert(Combination::new(&["B"]), 8.0);
        inclusive.insert(Combination::new(&["C"]), 6.0);
        inclusive.insert(Combination::new(&["A", "B"]), 3.0);
        inclusive.insert(Combination::new(&["A", "C"]), 2.0);
        inclusive.insert(Combination::new(&["B", "C"]), 1.0);
        inclusive.insert(Combination::new(&["A", "B", "C"]), 1.0);

        let exclusive = DiagramSpec::inclusive_to_exclusive_static(&inclusive).unwrap();
        let g = |names: &[&str]| exclusive[&Combination::new(names)];
        assert!((g(&["A", "B", "C"]) - 1.0).abs() < 1e-10);
        assert!((g(&["A", "B"]) - 2.0).abs() < 1e-10);
        assert!((g(&["A", "C"]) - 1.0).abs() < 1e-10);
        assert!(g(&["B", "C"]).abs() < 1e-10);
        assert!((g(&["A"]) - 6.0).abs() < 1e-10);
        assert!((g(&["B"]) - 5.0).abs() < 1e-10);
        assert!((g(&["C"]) - 4.0).abs() < 1e-10);

        let back = DiagramSpec::exclusive_to_inclusive_static(&exclusive).unwrap();
        assert_eq!(back.len(), inclusive.len());
        for (combo, &area) in &inclusive {
            assert!(
                (back[combo] - area).abs() < 1e-10,
                "round trip mismatch for {}",
                combo
            );
        }
    }
}