use super::{Combination, DiagramSpec, InputType};
use crate::constants::MAX_SETS;
use crate::error::DiagramError;
use std::collections::{HashMap, HashSet};
/// Incrementally collects set and intersection sizes for a [`DiagramSpec`].
///
/// Every builder method consumes `self` and returns the updated builder, so
/// calls are meant to be chained and finished with
/// [`DiagramSpecBuilder::build`].
#[must_use = "builder methods return the updated builder; call `build()` to produce a DiagramSpec"]
#[derive(Debug)]
pub struct DiagramSpecBuilder {
    /// Value supplied for each combination of sets, keyed by the combination.
    combinations: HashMap<Combination, f64>,
    /// How the supplied values are interpreted; `None` until
    /// [`DiagramSpecBuilder::input_type`] is called.
    input_type: Option<InputType>,
    /// Set names in the order the builder first registered them.
    set_order: Vec<String>,
}
impl Default for DiagramSpecBuilder {
fn default() -> Self {
Self::new()
}
}
impl DiagramSpecBuilder {
    /// Creates an empty builder with no sets, no intersections and no explicit
    /// input type.
    pub fn new() -> Self {
        DiagramSpecBuilder {
            combinations: HashMap::new(),
            input_type: None,
            set_order: Vec::new(),
        }
    }

    /// Records `name` in the set ordering if it has not been seen before.
    ///
    /// Ordering is by first mention, so `set_names` of the built spec is
    /// deterministic regardless of `HashMap`/`HashSet` iteration order.
    fn remember_set(&mut self, name: &str) {
        if !self.set_order.iter().any(|known| known.as_str() == name) {
            self.set_order.push(name.to_owned());
        }
    }

    /// Sets the value for the single set `name`, replacing any earlier value.
    pub fn set(mut self, name: impl Into<String>, value: f64) -> Self {
        let name = name.into();
        let combination = Combination::new(&[name.as_str()]);
        self.remember_set(&name);
        self.combinations.insert(combination, value);
        self
    }

    /// Sets the value for the combination of `sets`, replacing any earlier value.
    ///
    /// Member sets that have not been mentioned yet are registered in the set
    /// ordering, in the order given here.
    pub fn intersection(mut self, sets: &[&str], value: f64) -> Self {
        for &name in sets {
            self.remember_set(name);
        }
        self.combinations.insert(Combination::new(sets), value);
        self
    }

    /// Chooses how the supplied values are interpreted. If this is never called,
    /// `build` uses `InputType::default()`.
    pub fn input_type(mut self, input_type: InputType) -> Self {
        self.input_type = Some(input_type);
        self
    }

    /// Validates the collected values and produces a [`DiagramSpec`].
    ///
    /// Any set that appears only inside an intersection receives an explicit
    /// singleton entry of `0.0`. The spec's set names follow the order in which
    /// sets were first mentioned via [`set`](Self::set) and
    /// [`intersection`](Self::intersection). Any names not found there are
    /// appended in sorted order, so the result is always deterministic.
    ///
    /// # Errors
    ///
    /// - [`DiagramError::EmptySets`] if no values were supplied.
    /// - [`DiagramError::TooManySets`] if more than `MAX_SETS` distinct sets are
    ///   referenced.
    /// - [`DiagramError::InvalidValue`] if any value is negative, NaN or infinite.
    /// - Any error from `DiagramSpec::inclusive_to_exclusive_static` when the
    ///   input type is [`InputType::Inclusive`].
    pub fn build(self) -> Result<DiagramSpec, DiagramError> {
        if self.combinations.is_empty() {
            return Err(DiagramError::EmptySets);
        }

        // Every set name referenced by any combination.
        let mut all_set_names: HashSet<String> = HashSet::new();
        for combination in self.combinations.keys() {
            for set_name in combination.sets() {
                all_set_names.insert(set_name.clone());
            }
        }

        // Sets mentioned only inside intersections get an implicit zero singleton.
        // `or_insert` leaves explicitly supplied singletons untouched.
        let mut combinations = self.combinations;
        for set_name in &all_set_names {
            combinations
                .entry(Combination::new(&[set_name.as_str()]))
                .or_insert(0.0);
        }

        // First-mention order, followed by any stragglers in sorted order (never
        // raw HashSet order, which would make `set_names` nondeterministic).
        let mut ordered_set_names: Vec<String> = self
            .set_order
            .into_iter()
            .filter(|name| all_set_names.contains(name))
            .collect();
        let already_ordered: HashSet<&str> =
            ordered_set_names.iter().map(String::as_str).collect();
        let mut leftovers: Vec<String> = all_set_names
            .iter()
            .filter(|name| !already_ordered.contains(name.as_str()))
            .cloned()
            .collect();
        leftovers.sort_unstable();
        ordered_set_names.extend(leftovers);

        if ordered_set_names.len() > MAX_SETS {
            return Err(DiagramError::TooManySets {
                requested: ordered_set_names.len(),
                max: MAX_SETS,
            });
        }

        for (combination, &value) in &combinations {
            // `value < 0.0` alone accepts NaN (every comparison with NaN is
            // false), and an infinite area cannot be laid out either.
            if !value.is_finite() || value < 0.0 {
                return Err(DiagramError::InvalidValue {
                    combination: combination.to_string(),
                    value,
                });
            }
        }

        let input_type = self.input_type.unwrap_or_default();
        let exclusive_areas = match input_type {
            InputType::Exclusive => combinations,
            InputType::Inclusive => DiagramSpec::inclusive_to_exclusive_static(&combinations)?,
        };
        Ok(DiagramSpec {
            exclusive_areas,
            input_type,
            set_names: ordered_set_names,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the combination key of the given sets.
    fn combo(sets: &[&str]) -> Combination {
        Combination::new(sets)
    }

    /// Unwraps `actual` (panicking with `what` if absent) and checks it is within
    /// 1e-10 of `expected`.
    fn assert_close(actual: Option<f64>, expected: f64, what: &str) {
        let got = actual.expect(what);
        assert!((got - expected).abs() < 1e-10);
    }

    #[test]
    fn test_builder_simple() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 5.0)
            .set("B", 2.0)
            .build()
            .unwrap();
        let names = spec.set_names();
        assert_eq!(names.len(), 2);
        for expected in ["A", "B"] {
            assert!(names.contains(&expected.to_string()));
        }
        let a = combo(&["A"]);
        assert!(spec.get_inclusive(&a).is_some());
        assert!(spec.get_exclusive(&a).is_some());
    }

    #[test]
    fn test_builder_with_intersection() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 5.0)
            .set("B", 2.0)
            .intersection(&["A", "B"], 1.0)
            .input_type(InputType::Exclusive)
            .build()
            .unwrap();
        assert_eq!(spec.input_type(), InputType::Exclusive);
        assert_eq!(spec.exclusive_areas().len(), 3);
        assert_eq!(spec.inclusive_areas().len(), 3);
    }

    #[test]
    fn test_builder_implicit_zero_set() {
        let spec = DiagramSpecBuilder::new()
            .set("B", 5.0)
            .intersection(&["A", "B"], 1.0)
            .input_type(InputType::Exclusive)
            .build()
            .unwrap();
        let a = combo(&["A"]);
        assert_eq!(spec.get_exclusive(&a), Some(0.0));
        assert_eq!(spec.get_inclusive(&a), Some(1.0));
    }

    #[test]
    fn test_contained_set_exclusive() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 0.0)
            .set("B", 5.0)
            .intersection(&["A", "B"], 1.0)
            .input_type(InputType::Exclusive)
            .build()
            .unwrap();
        let (a, b, ab) = (combo(&["A"]), combo(&["B"]), combo(&["A", "B"]));
        assert_eq!(spec.get_exclusive(&a), Some(0.0));
        assert_eq!(spec.get_exclusive(&b), Some(5.0));
        assert_eq!(spec.get_exclusive(&ab), Some(1.0));
        assert_eq!(spec.get_inclusive(&a), Some(1.0));
        assert_eq!(spec.get_inclusive(&b), Some(6.0));
        assert_eq!(spec.get_inclusive(&ab), Some(1.0));
    }

    #[test]
    fn test_implicit_set_from_three_way() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 0.0)
            .set("B", 5.0)
            .intersection(&["A", "B"], 1.0)
            .intersection(&["A", "B", "C"], 0.1)
            .input_type(InputType::Exclusive)
            .build()
            .unwrap();
        let names = spec.set_names();
        assert_eq!(names.len(), 3);
        for expected in ["A", "B", "C"] {
            assert!(names.contains(&expected.to_string()));
        }
        let c = combo(&["C"]);
        assert_eq!(spec.get_exclusive(&c), Some(0.0));
        assert_eq!(spec.get_inclusive(&c), Some(0.1));
    }

    #[test]
    fn test_nested_containment() {
        let spec = DiagramSpecBuilder::new()
            .set("B", 5.0)
            .intersection(&["A", "B"], 2.0)
            .intersection(&["A", "B", "C"], 1.0)
            .input_type(InputType::Exclusive)
            .build()
            .unwrap();
        assert_eq!(spec.set_names().len(), 3);

        let (a, b, c) = (combo(&["A"]), combo(&["B"]), combo(&["C"]));
        let (ab, abc) = (combo(&["A", "B"]), combo(&["A", "B", "C"]));

        assert_eq!(spec.get_exclusive(&a), Some(0.0));
        assert_eq!(spec.get_exclusive(&b), Some(5.0));
        assert_eq!(spec.get_exclusive(&c), Some(0.0));
        assert_eq!(spec.get_exclusive(&ab), Some(2.0));
        assert_eq!(spec.get_exclusive(&abc), Some(1.0));

        assert_eq!(spec.get_inclusive(&a), Some(3.0));
        assert_eq!(spec.get_inclusive(&b), Some(8.0));
        assert_eq!(spec.get_inclusive(&c), Some(1.0));
    }

    #[test]
    fn test_builder_negative_value_error() {
        let outcome = DiagramSpecBuilder::new().set("A", -5.0).build();
        assert!(matches!(outcome, Err(DiagramError::InvalidValue { .. })));
    }

    #[test]
    fn test_builder_empty_error() {
        let outcome = DiagramSpecBuilder::new().build();
        assert!(matches!(outcome, Err(DiagramError::EmptySets)));
    }

    #[test]
    fn test_three_way_intersection() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 10.0)
            .set("B", 8.0)
            .set("C", 12.0)
            .intersection(&["A", "B"], 2.0)
            .intersection(&["A", "C"], 3.0)
            .intersection(&["B", "C"], 1.0)
            .intersection(&["A", "B", "C"], 0.5)
            .build()
            .unwrap();
        assert_eq!(spec.set_names().len(), 3);
        assert_eq!(spec.exclusive_areas().len(), 7);
        assert_eq!(spec.inclusive_areas().len(), 7);
    }

    #[test]
    fn test_too_many_sets_rejected() {
        let builder = (0..=MAX_SETS).fold(DiagramSpecBuilder::new(), |acc, i| {
            acc.set(format!("S{i}"), 1.0)
        });
        let result = builder.build();
        assert!(
            matches!(
                result,
                Err(DiagramError::TooManySets { requested, max })
                    if requested == MAX_SETS + 1 && max == MAX_SETS
            ),
            "expected TooManySets for n = MAX_SETS + 1, got {:?}",
            result
        );
    }

    #[test]
    fn test_max_sets_accepted() {
        let builder = (0..MAX_SETS).fold(DiagramSpecBuilder::new(), |acc, i| {
            acc.set(format!("S{i}"), 1.0)
        });
        let spec = builder.build().expect("MAX_SETS singletons should build");
        assert_eq!(spec.set_names().len(), MAX_SETS);
        assert_eq!(spec.inclusive_areas().len(), MAX_SETS);
    }

    #[test]
    fn test_sparse_exclusive_to_inclusive_no_power_set_blowup() {
        let builder = (0..25).fold(DiagramSpecBuilder::new(), |acc, i| {
            acc.set(format!("S{i}"), 1.0)
        });
        let spec = builder
            .intersection(&["S0", "S1"], 0.5)
            .input_type(InputType::Exclusive)
            .build()
            .expect("sparse 25-set spec should build");

        assert_close(spec.get_inclusive(&combo(&["S0"])), 1.5, "S0 inclusive");
        assert_close(spec.get_inclusive(&combo(&["S5"])), 1.0, "S5 inclusive");
        assert_close(
            spec.get_inclusive(&combo(&["S0", "S1"])),
            0.5,
            "S0&S1 inclusive",
        );
        assert!(spec.get_inclusive(&combo(&["S0", "S1", "S2"])).is_none());
        assert_eq!(spec.inclusive_areas().len(), 26);
    }

    #[test]
    fn test_get_combination() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 5.0)
            .set("B", 2.0)
            .intersection(&["A", "B"], 1.0)
            .build()
            .unwrap();
        assert_eq!(spec.get_inclusive(&combo(&["A", "B"])), Some(1.0));
        assert_eq!(spec.get_inclusive(&combo(&["A", "C"])), None);
    }

    #[test]
    fn test_build_scales_with_large_kway_intersection() {
        let n = 30;
        let names: Vec<String> = (0..n).map(|i| format!("S{i}")).collect();
        let members: Vec<&str> = names.iter().map(String::as_str).collect();
        let builder = members
            .iter()
            .fold(DiagramSpecBuilder::new(), |acc, &name| acc.set(name, 1.0));

        let timer = std::time::Instant::now();
        let spec = builder
            .intersection(&members, 0.5)
            .input_type(InputType::Exclusive)
            .build()
            .expect("30-way spec should build");
        let elapsed = timer.elapsed();

        assert_eq!(spec.exclusive_areas().len(), n + 1);
        assert!(
            elapsed < std::time::Duration::from_secs(1),
            "30-way build took {:?}; expected sub-second",
            elapsed
        );
    }
}