use std::mem;
use log::{debug, warn};
use hv_fonseca_et_al_2006_sys::calculate_hv;
use crate::core::{Individual, Individuals, OError};
use crate::metrics::hypervolume::{check_args, check_ref_point_coordinate};
use crate::utils::fast_non_dominated_sort;
/// Hyper-volume metric computed with the algorithm by Fonseca et al. (2006), called through the
/// `calculate_hv` function of the `hv_fonseca_et_al_2006_sys` bindings.
///
/// The metric can only be used on 3-objective problems; the constructor rejects reference
/// points with a different number of coordinates.
#[derive(Debug)]
pub struct HyperVolumeFonseca2006 {
    /// The objective values of the individuals in the first non-dominated front, one vector per
    /// individual as returned by `Individual::get_objective_values`.
    individuals: Vec<Vec<f64>>,
    /// The reference point, where the coordinates of maximised objectives have their sign flipped.
    reference_point: Vec<f64>,
}
impl HyperVolumeFonseca2006 {
    /// Prepare the hyper-volume calculation with the algorithm by Fonseca et al. (2006).
    ///
    /// The steps are:
    /// 1. The arguments are validated with `check_args`.
    /// 2. The reference point must have exactly three coordinates.
    /// 3. Each coordinate is checked against the values of its objective with
    ///    `check_ref_point_coordinate`.
    /// 4. The individuals are sorted with `fast_non_dominated_sort`. Only the first front is
    ///    kept, and a warning is logged when dominated individuals are discarded.
    ///
    /// The coordinates of the reference point linked to maximised objectives are negated.
    /// NOTE(review): this assumes `Individual::get_objective_values` already returns maximised
    /// objectives with a flipped sign, so that all the data is expressed as a minimisation
    /// problem — confirm.
    ///
    /// # Arguments
    ///
    /// * `individuals`: The individuals whose objective values are used to calculate the metric.
    ///   The slice is mutable because it is handed to `fast_non_dominated_sort`.
    /// * `reference_point`: The reference point, with one coordinate per objective.
    ///
    /// # Errors
    ///
    /// Returns an [`OError::Metric`] when the arguments are not valid, the reference point does
    /// not have exactly 3 coordinates, or one of its coordinates is not valid for its objective.
    /// Errors raised while reading the objective values or sorting the individuals are
    /// propagated as they are.
    pub fn new(individuals: &mut [Individual], reference_point: &[f64]) -> Result<Self, OError> {
        let metric_name = "Hyper-volume Fonseca et al. (2006)".to_string();
        check_args(individuals, reference_point)
            .map_err(|e| OError::Metric(metric_name.clone(), e))?;
        if reference_point.len() != 3 {
            return Err(OError::Metric(
                metric_name,
                "This can only be used on a 3-objective problem.".to_string(),
            ));
        }

        // NOTE(review): indexing assumes `check_args` rejects an empty slice — confirm
        let problem = individuals[0].problem();

        // every coordinate of the reference point must be valid for its objective
        for (obj_idx, (obj_name, obj)) in problem.objectives().iter().enumerate() {
            check_ref_point_coordinate(
                &individuals.objective_values(obj_name)?,
                obj,
                reference_point[obj_idx],
                obj_idx + 1,
            )
            .map_err(|e| OError::Metric(metric_name.clone(), e))?;
        }

        // only the non-dominated individuals (the first front) contribute to the hyper-volume
        let num_individuals = individuals.len();
        let mut front_data = fast_non_dominated_sort(individuals, true)?;
        let first_front = mem::take(&mut front_data.fronts[0]);
        if num_individuals != first_front.len() {
            // a point is left out of the first front when at least one other point dominates it
            warn!(
                "{} individuals were removed from the given data because they are dominated by at least one other point",
                num_individuals - first_front.len()
            );
        }

        let objective_values = first_front
            .iter()
            .map(|ind| ind.get_objective_values())
            .collect::<Result<Vec<Vec<f64>>, _>>()?;

        // express the reference point as for a minimisation problem by flipping the sign of
        // the coordinates of maximised objectives
        let mut ref_point = reference_point.to_vec();
        for (obj_idx, obj_name) in problem.objective_names().iter().enumerate() {
            if !problem.is_objective_minimised(obj_name)? {
                ref_point[obj_idx] *= -1.0;
            }
        }

        debug!("Using non-dominated front {:?}", objective_values);
        debug!("Reference point is {:?}", ref_point);

        Ok(Self {
            individuals: objective_values,
            reference_point: ref_point,
        })
    }

    /// Calculate the hyper-volume of the stored non-dominated front with respect to the stored
    /// reference point, using `calculate_hv` from the `hv_fonseca_et_al_2006_sys` bindings.
    ///
    /// # Returns
    ///
    /// The hyper-volume value.
    pub fn compute(&self) -> f64 {
        calculate_hv(&self.individuals, &self.reference_point)
    }
}
#[cfg(test)]
mod test {
    use float_cmp::approx_eq;

    use crate::core::test_utils::individuals_from_obj_values_dummy;
    use crate::core::ObjectiveDirection;
    use crate::metrics::test_utils::parse_pagmo_test_data_file;
    use crate::metrics::HyperVolumeFonseca2006;

    /// The first coordinate of the reference point (2.0) equals the largest value of the first
    /// objective, so the constructor must reject it.
    #[test]
    fn test_wrong_ref_point() {
        let objective_values = vec![vec![1.0, 2.0, 1.0], vec![2.0, 1.0, 1.0]];
        let objective_direction = [ObjectiveDirection::Minimise; 3];
        let mut individuals =
            individuals_from_obj_values_dummy(&objective_values, &objective_direction, None);
        let ref_point = [2.0, 2.0, 2.0];

        let hv = HyperVolumeFonseca2006::new(&mut individuals, &ref_point);
        assert!(hv
            .unwrap_err()
            .to_string()
            .contains("The coordinate (2) of the reference point #1 must be strictly larger"));
    }

    /// The point (2, 2, 2) is dominated by (1, 1, 1), so only the box between (1, 1, 1) and the
    /// reference point (3, 3, 3) is measured: 2 * 2 * 2 = 8.
    #[test]
    fn test_simple_front() {
        let objective_values = vec![vec![1.0, 1.0, 1.0], vec![2.0, 2.0, 2.0]];
        let objective_direction = [ObjectiveDirection::Minimise; 3];
        let ref_point = [3.0, 3.0, 3.0];
        let mut individuals =
            individuals_from_obj_values_dummy(&objective_values, &objective_direction, None);

        let hv = HyperVolumeFonseca2006::new(&mut individuals, &ref_point).unwrap();
        assert_eq!(hv.compute(), 8.0);
    }

    /// Check the hyper-volume calculated for every test case stored in a Pagmo test data file.
    ///
    /// All the objectives are minimised. The number of objectives is taken from the size of the
    /// reference point of the first test case. Each calculated value must be approximately equal
    /// (`epsilon = 0.001`) to the expected one.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be parsed, the metric cannot be created, or a value does not
    /// match.
    pub(crate) fn assert_test_file(file: &str) {
        let all_test_data = parse_pagmo_test_data_file(file).unwrap();
        let obj_count = all_test_data.first().unwrap().reference_point.len();
        let objective_direction = vec![ObjectiveDirection::Minimise; obj_count];

        for (ti, test_data) in all_test_data.iter().enumerate() {
            let mut individuals = individuals_from_obj_values_dummy(
                &test_data.objective_values,
                &objective_direction,
                None,
            );
            let hv =
                HyperVolumeFonseca2006::new(&mut individuals, &test_data.reference_point).unwrap();
            let calculated = hv.compute();
            let expected = test_data.hyper_volume;
            assert!(
                approx_eq!(f64, calculated, expected, epsilon = 0.001),
                r#"assertion failed for test #{}: `(left approx_eq right)` left: `{:?}`, right: `{:?}`"#,
                ti + 1,
                calculated,
                expected,
            );
        }
    }

    #[test]
    fn test_c_max_t100_d3_n128() {
        assert_test_file("c_max_t100_d3_n128");
    }

    #[test]
    fn test_c_max_t1_d3_n2048() {
        assert_test_file("c_max_t1_d3_n2048");
    }
}