use super::{Combination, DiagramSpec, InputType};
use crate::error::DiagramError;
use std::collections::{HashMap, HashSet};
/// Incrementally collects single-set and intersection values and turns them into a
/// [`DiagramSpec`] via [`DiagramSpecBuilder::build`].
#[derive(Debug)]
pub struct DiagramSpecBuilder {
    /// Value recorded for each combination (single set or intersection). Recording the
    /// same combination again overwrites the earlier value.
    combinations: HashMap<Combination, f64>,
    /// How the recorded values are to be interpreted. `None` means `build` falls back to
    /// `InputType::default()`.
    input_type: Option<InputType>,
    /// Set names in the order they were first passed to `set`, without duplicates.
    set_order: Vec<String>,
}
impl Default for DiagramSpecBuilder {
fn default() -> Self {
Self::new()
}
}
impl DiagramSpecBuilder {
    /// Creates an empty builder: no combinations, no explicit input type.
    pub fn new() -> Self {
        DiagramSpecBuilder {
            combinations: HashMap::new(),
            input_type: None,
            set_order: Vec::new(),
        }
    }

    /// Records `value` for the single set `name`.
    ///
    /// Calling this again with the same name replaces the value but keeps the position the
    /// name received on its first call in the final set ordering.
    #[must_use]
    pub fn set(mut self, name: impl Into<String>, value: f64) -> Self {
        let name = name.into();
        let combination = Combination::new(&[name.as_str()]);
        if !self.set_order.contains(&name) {
            self.set_order.push(name);
        }
        self.combinations.insert(combination, value);
        self
    }

    /// Records `value` for the combination made up of every set in `sets`.
    ///
    /// Set names that only ever appear here still end up in the built spec; see
    /// [`build`](Self::build).
    #[must_use]
    pub fn intersection(mut self, sets: &[&str], value: f64) -> Self {
        self.combinations.insert(Combination::new(sets), value);
        self
    }

    /// Chooses how the recorded values are interpreted. If never called,
    /// [`build`](Self::build) uses `InputType::default()`.
    #[must_use]
    pub fn input_type(mut self, input_type: InputType) -> Self {
        self.input_type = Some(input_type);
        self
    }

    /// Validates the recorded values and produces a [`DiagramSpec`].
    ///
    /// Every set name that appears in any combination gets a single-set entry. Names that
    /// never received one explicitly are added with a value of `0.0`.
    ///
    /// The resulting set order has two parts:
    /// - names passed to [`set`](Self::set), in first-call order;
    /// - then names that only appear in intersections, sorted lexicographically so the
    ///   order is the same on every run.
    ///
    /// # Errors
    ///
    /// - [`DiagramError::EmptySets`] if nothing was recorded.
    /// - [`DiagramError::InvalidValue`] if any value is negative, NaN, or infinite.
    /// - Any error returned by the exclusive/inclusive conversion for the chosen input type.
    pub fn build(self) -> Result<DiagramSpec, DiagramError> {
        let DiagramSpecBuilder {
            mut combinations,
            input_type,
            set_order,
        } = self;

        if combinations.is_empty() {
            return Err(DiagramError::EmptySets);
        }

        // `value < 0.0` alone is false for NaN, so non-finite values are rejected explicitly.
        for (combination, &value) in &combinations {
            if !value.is_finite() || value < 0.0 {
                return Err(DiagramError::InvalidValue {
                    combination: combination.to_string(),
                    value,
                });
            }
        }

        // Every set name mentioned anywhere, plus the subset that already has its own entry.
        let mut all_set_names = HashSet::new();
        let mut single_sets = HashSet::new();
        for combination in combinations.keys() {
            for set_name in combination.sets() {
                all_set_names.insert(set_name.clone());
            }
            if combination.len() == 1 {
                single_sets.insert(combination.sets()[0].clone());
            }
        }

        // Sets referenced only through intersections get an explicit zero-valued entry.
        for set_name in all_set_names.difference(&single_sets) {
            combinations.insert(Combination::new(&[set_name.as_str()]), 0.0);
        }

        // Names given via `set` keep their call order. Whatever is left in `unordered`
        // afterwards appears only in intersections.
        let mut unordered = all_set_names;
        let mut set_names: Vec<String> = set_order
            .into_iter()
            .filter(|name| unordered.remove(name))
            .collect();
        let mut implicit_names: Vec<String> = unordered.into_iter().collect();
        // HashSet iteration order is randomized per process; sort for a stable result.
        implicit_names.sort_unstable();
        set_names.extend(implicit_names);

        let input_type = input_type.unwrap_or_default();
        let (exclusive_areas, inclusive_areas) = match input_type {
            InputType::Exclusive => {
                let inclusive = DiagramSpec::exclusive_to_inclusive_static(&combinations)?;
                (combinations, inclusive)
            }
            InputType::Inclusive => {
                let exclusive = DiagramSpec::inclusive_to_exclusive_static(&combinations)?;
                (exclusive, combinations)
            }
        };

        Ok(DiagramSpec {
            exclusive_areas,
            inclusive_areas,
            input_type,
            set_names,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a [`Combination`] from a list of set names.
    fn combo(names: &[&str]) -> Combination {
        Combination::new(names)
    }

    /// Returns `true` when `spec` lists `name` among its set names.
    fn has_set(spec: &DiagramSpec, name: &str) -> bool {
        spec.set_names().contains(&name.to_string())
    }

    #[test]
    fn test_builder_simple() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 5.0)
            .set("B", 2.0)
            .build()
            .unwrap();

        assert_eq!(spec.set_names().len(), 2);
        assert!(has_set(&spec, "A"));
        assert!(has_set(&spec, "B"));

        let a = combo(&["A"]);
        assert!(spec.get_inclusive(&a).is_some());
        assert!(spec.get_exclusive(&a).is_some());
    }

    #[test]
    fn test_builder_with_intersection() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 5.0)
            .set("B", 2.0)
            .intersection(&["A", "B"], 1.0)
            .input_type(InputType::Exclusive)
            .build()
            .unwrap();

        assert_eq!(spec.input_type(), InputType::Exclusive);
        assert_eq!(spec.exclusive_areas().len(), 3);
        assert_eq!(spec.inclusive_areas().len(), 3);
    }

    #[test]
    fn test_builder_implicit_zero_set() {
        let spec = DiagramSpecBuilder::new()
            .set("B", 5.0)
            .intersection(&["A", "B"], 1.0)
            .input_type(InputType::Exclusive)
            .build()
            .unwrap();

        let a = combo(&["A"]);
        assert_eq!(spec.get_exclusive(&a), Some(0.0));
        assert_eq!(spec.get_inclusive(&a), Some(1.0));
    }

    #[test]
    fn test_contained_set_exclusive() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 0.0)
            .set("B", 5.0)
            .intersection(&["A", "B"], 1.0)
            .input_type(InputType::Exclusive)
            .build()
            .unwrap();

        let (a, b, ab) = (combo(&["A"]), combo(&["B"]), combo(&["A", "B"]));

        assert_eq!(spec.get_exclusive(&a), Some(0.0));
        assert_eq!(spec.get_exclusive(&b), Some(5.0));
        assert_eq!(spec.get_exclusive(&ab), Some(1.0));

        assert_eq!(spec.get_inclusive(&a), Some(1.0));
        assert_eq!(spec.get_inclusive(&b), Some(6.0));
        assert_eq!(spec.get_inclusive(&ab), Some(1.0));
    }

    #[test]
    fn test_implicit_set_from_three_way() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 0.0)
            .set("B", 5.0)
            .intersection(&["A", "B"], 1.0)
            .intersection(&["A", "B", "C"], 0.1)
            .input_type(InputType::Exclusive)
            .build()
            .unwrap();

        assert_eq!(spec.set_names().len(), 3);
        for name in ["A", "B", "C"] {
            assert!(has_set(&spec, name));
        }

        let c = combo(&["C"]);
        assert_eq!(spec.get_exclusive(&c), Some(0.0));
        assert_eq!(spec.get_inclusive(&c), Some(0.1));
    }

    #[test]
    fn test_nested_containment() {
        let spec = DiagramSpecBuilder::new()
            .set("B", 5.0)
            .intersection(&["A", "B"], 2.0)
            .intersection(&["A", "B", "C"], 1.0)
            .input_type(InputType::Exclusive)
            .build()
            .unwrap();

        assert_eq!(spec.set_names().len(), 3);

        let (a, b, c) = (combo(&["A"]), combo(&["B"]), combo(&["C"]));
        assert_eq!(spec.get_exclusive(&a), Some(0.0));
        assert_eq!(spec.get_exclusive(&b), Some(5.0));
        assert_eq!(spec.get_exclusive(&c), Some(0.0));
        assert_eq!(spec.get_exclusive(&combo(&["A", "B"])), Some(2.0));
        assert_eq!(spec.get_exclusive(&combo(&["A", "B", "C"])), Some(1.0));

        assert_eq!(spec.get_inclusive(&a), Some(3.0));
        assert_eq!(spec.get_inclusive(&b), Some(8.0));
        assert_eq!(spec.get_inclusive(&c), Some(1.0));
    }

    #[test]
    fn test_builder_negative_value_error() {
        let outcome = DiagramSpecBuilder::new().set("A", -5.0).build();
        assert!(matches!(outcome, Err(DiagramError::InvalidValue { .. })));
    }

    #[test]
    fn test_builder_empty_error() {
        let outcome = DiagramSpecBuilder::new().build();
        assert!(matches!(outcome, Err(DiagramError::EmptySets)));
    }

    #[test]
    fn test_three_way_intersection() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 10.0)
            .set("B", 8.0)
            .set("C", 12.0)
            .intersection(&["A", "B"], 2.0)
            .intersection(&["A", "C"], 3.0)
            .intersection(&["B", "C"], 1.0)
            .intersection(&["A", "B", "C"], 0.5)
            .build()
            .unwrap();

        assert_eq!(spec.set_names().len(), 3);
        assert_eq!(spec.exclusive_areas().len(), 7);
        assert_eq!(spec.inclusive_areas().len(), 7);
    }

    #[test]
    fn test_get_combination() {
        let spec = DiagramSpecBuilder::new()
            .set("A", 5.0)
            .set("B", 2.0)
            .intersection(&["A", "B"], 1.0)
            .build()
            .unwrap();

        assert_eq!(spec.get_inclusive(&combo(&["A", "B"])), Some(1.0));
        assert_eq!(spec.get_inclusive(&combo(&["A", "C"])), None);
    }
}