use std::collections::HashMap;
use log::debug;
use crate::core::OError;
use crate::utils::{DasDarren1998, NumberOfPartitions};
/// Adaptive reference-point manager used by NSGA-III.
///
/// Given the current reference points and their niche counts (`rho_j`), it adds new points
/// around crowded original reference points and removes added points that are not used. Both
/// collections are borrowed mutably and updated in place by [`Self::calculate`].
///
/// NOTE(review): this looks like the adaptive scheme of A-NSGA-III (Jain & Deb, 2014) — confirm
/// against the algorithm reference.
pub(crate) struct AdaptiveReferencePoints<'a> {
    /// All reference points. The first `number_of_or_reference_points` entries are the original
    /// points and are never removed; entries after them were added by this struct.
    reference_points: &'a mut Vec<Vec<f64>>,
    /// Niche count per reference point, keyed by the point's index in `reference_points`.
    rho_j: &'a mut HashMap<usize, usize>,
    /// Number of original (non-removable) reference points at the start of `reference_points`.
    number_of_or_reference_points: usize,
    /// Offsets added to a crowded reference point to generate new candidates: the one-partition
    /// `DasDarren1998` weights scaled by `gap` and translated so their centroid is the origin.
    new_points_set: Vec<Vec<f64>>,
}

impl<'a> AdaptiveReferencePoints<'a> {
    /// Creates the manager and precomputes the offsets used to generate new reference points.
    ///
    /// # Arguments
    ///
    /// * `reference_points`: the current reference points; the first
    ///   `number_of_or_reference_points` of them are the original ones.
    /// * `rho_j`: the niche count of each reference point, keyed by its index.
    /// * `number_of_or_reference_points`: how many original reference points there are.
    /// * `gap`: scale applied to the offsets (the tests pass `DasDarren1998::gap()` of the
    ///   original points).
    ///
    /// # Errors
    ///
    /// Returns an error if the one-partition `DasDarren1998` weights cannot be generated.
    ///
    /// # Panics
    ///
    /// Panics if `reference_points` is empty, because the number of objectives is read from the
    /// first point.
    pub fn new(
        reference_points: &'a mut Vec<Vec<f64>>,
        rho_j: &'a mut HashMap<usize, usize>,
        number_of_or_reference_points: usize,
        gap: f64,
    ) -> Result<Self, OError> {
        let number_of_objectives = reference_points
            .first()
            .expect("at least one reference point is needed to infer the number of objectives")
            .len();
        let ds = DasDarren1998::new(number_of_objectives, &NumberOfPartitions::OneLayer(1))?;

        // Scale the one-partition weights by `gap`.
        let scaled: Vec<Vec<f64>> = ds
            .get_weights()
            .iter()
            .map(|weight| weight.iter().map(|w| w * gap).collect())
            .collect();

        // Translate the scaled points so their centroid is the origin; adding them to a
        // reference point then places the new candidates around that point.
        let count = scaled.len() as f64;
        let centroid: Vec<f64> = (0..number_of_objectives)
            .map(|c_idx| scaled.iter().map(|v| v[c_idx]).sum::<f64>() / count)
            .collect();
        let new_points_set: Vec<Vec<f64>> = scaled
            .iter()
            .map(|point| {
                point
                    .iter()
                    .zip(&centroid)
                    .map(|(coord, ci)| coord - ci)
                    .collect()
            })
            .collect();

        Ok(AdaptiveReferencePoints {
            reference_points,
            rho_j,
            number_of_or_reference_points,
            new_points_set,
        })
    }

    /// Updates the reference points in place. It first adds new points around crowded original
    /// points. Then, when the deletion condition holds, it removes added points whose niche count
    /// is zero.
    ///
    /// # Errors
    ///
    /// Propagates any error from the addition step.
    pub fn calculate(&mut self) -> Result<(), OError> {
        self.add()?;
        self.delete();
        Ok(())
    }

    /// Adds new reference points around every original reference point whose niche count is at
    /// least 2.
    ///
    /// Each candidate is the reference point plus one of the `new_points_set` offsets. A
    /// candidate is skipped if it is within tolerance of an existing point or if any coordinate
    /// lies outside `[0, 1]`. Accepted points are appended to `reference_points` and registered
    /// in `rho_j` with a niche count of 0, keyed by their index.
    fn add(&mut self) -> Result<(), OError> {
        // Walk the original points in index order instead of iterating `rho_j`. `HashMap`
        // iteration order is unspecified, which would make the order of added points
        // non-deterministic between runs.
        let mut candidates: Vec<Vec<f64>> = Vec::new();
        for ref_point_index in 0..self.number_of_or_reference_points {
            let is_crowded =
                matches!(self.rho_j.get(&ref_point_index), Some(counter) if *counter >= 2);
            if !is_crowded {
                continue;
            }
            let base = match self.reference_points.get(ref_point_index) {
                Some(point) => point,
                None => continue,
            };
            let new_points: Vec<Vec<f64>> = self
                .new_points_set
                .iter()
                .map(|offset| {
                    offset
                        .iter()
                        .zip(base)
                        .map(|(new_c, old_c)| new_c + old_c)
                        .collect()
                })
                .collect();
            debug!(
                "Selected new ref points {:?} for #{ref_point_index} = {:?}",
                new_points, base
            );
            candidates.extend(new_points);
        }

        for point in candidates {
            if self.has_point(&point, None) {
                debug!("Disregarded already existing ref point {:?}", point);
                continue;
            }
            if point.iter().any(|c| *c < 0.0 || *c > 1.0) {
                debug!("Disregarded outside-range ref point {:?}", point);
                continue;
            }
            debug!("Added new ref points {:?}", point);
            // The point's index is the current length, i.e. its position once it is pushed.
            let new_index = self.reference_points.len();
            self.reference_points.push(point);
            self.rho_j.insert(new_index, 0);
        }
        Ok(())
    }

    /// Removes added reference points with a niche count of 0. This happens only when the number
    /// of reference points (original or added) whose niche count is exactly 1 equals the number
    /// of original points.
    ///
    /// Original reference points are never removed, and neither are added points that have no
    /// entry in `rho_j`. The `rho_j` entries of removed points are dropped.
    ///
    /// NOTE(review): the remaining `rho_j` keys are not re-indexed. After a removal, the keys of
    /// later added points no longer match their positions in `reference_points`, and the tests
    /// assert this. Presumably callers rebuild `rho_j` before the next call — confirm.
    fn delete(&mut self) {
        let perfect_assoc_counter = self
            .rho_j
            .values()
            .filter(|&&counter| counter == 1)
            .count();
        if perfect_assoc_counter != self.number_of_or_reference_points {
            return;
        }

        // Indices of unused added points, in ascending order.
        let to_delete: Vec<usize> = (0..self.reference_points.len())
            .filter(|&idx| {
                !self.is_original_ref_point(idx) && self.rho_j.get(&idx) == Some(&0)
            })
            .collect();
        if to_delete.is_empty() {
            return;
        }

        for idx in &to_delete {
            debug!(
                "Deleting unused ref point #{} = {:?}",
                idx, self.reference_points[*idx]
            );
            self.rho_j.remove(idx);
        }

        // Remove by position rather than by value, so that only the selected points are
        // dropped. `retain` visits each element once, in order, so `position` tracks each
        // element's original index.
        let mut position = 0;
        self.reference_points.retain(|_| {
            let keep = to_delete.binary_search(&position).is_err();
            position += 1;
            keep
        });
    }

    /// Returns `true` if `ref_point_index` refers to one of the original reference points.
    fn is_original_ref_point(&self, ref_point_index: usize) -> bool {
        ref_point_index < self.number_of_or_reference_points
    }

    /// Returns whether `reference_points` already contains a point that matches `point` within
    /// `tolerance` on every coordinate.
    ///
    /// `tolerance` defaults to `0.000001`. Coordinates are compared pairwise up to the shorter
    /// of the two lengths.
    pub fn has_point(&self, point: &[f64], tolerance: Option<f64>) -> bool {
        let tolerance = tolerance.unwrap_or(0.000001);
        self.reference_points.iter().any(|ref_point| {
            ref_point
                .iter()
                .zip(point)
                .all(|(r, p)| (r - p).abs() < tolerance)
        })
    }
}
#[cfg(test)]
mod test {
    use std::collections::HashMap;
    use std::panic;

    use crate::algorithms::nsga3::adaptive_ref_points::AdaptiveReferencePoints;
    use crate::core::test_utils::assert_approx_array_eq;
    use crate::utils::{DasDarren1998, NumberOfPartitions};

    /// Builds a niche-count map holding `value` for every index in `0..len`.
    fn uniform_rho_j(len: usize, value: usize) -> HashMap<usize, usize> {
        (0..len).map(|idx| (idx, value)).collect()
    }

    /// Crowding a single original point appends the expected new points at the end of the set.
    #[test]
    fn test_ref_point_addition() {
        let ds = DasDarren1998::new(3, &NumberOfPartitions::OneLayer(5)).unwrap();
        // Each entry is (crowded reference point index, (number of points expected to be
        // added, expected trailing points)). A `Vec` keeps the scenario numbering stable
        // between runs.
        let test_data: Vec<(usize, (usize, Vec<Vec<f64>>))> = vec![
            (3, (1, vec![vec![0.13333333, 0.53333333, 0.33333333]])),
            (
                12,
                (
                    3,
                    vec![
                        vec![0.33333333, 0.13333333, 0.53333333],
                        vec![0.33333333, 0.33333333, 0.33333333],
                        vec![0.53333333, 0.13333333, 0.33333333],
                    ],
                ),
            ),
        ];
        for (sc_id, (ref_point_index, (expected_added, expected_points))) in
            test_data.iter().enumerate()
        {
            let mut ref_points = ds.get_weights();
            let mut rho_j = uniform_rho_j(ref_points.len(), 0);
            *rho_j.get_mut(ref_point_index).unwrap() = 2_usize;
            let counter = ref_points.len();
            let mut a =
                AdaptiveReferencePoints::new(&mut ref_points, &mut rho_j, counter, ds.gap())
                    .unwrap();
            a.calculate().unwrap();
            assert_eq!(ref_points.len(), counter + *expected_added);
            assert_eq!(rho_j.len(), ref_points.len());

            let first_new = ref_points.len() - expected_points.len();
            for (pos, expected) in expected_points.iter().enumerate() {
                let result = panic::catch_unwind(|| {
                    assert_approx_array_eq(&ref_points[first_new + pos], expected, None);
                });
                if let Err(payload) = result {
                    match payload.downcast::<String>() {
                        Ok(reason) => panic!(
                            "Scenario {} failed because:\n {:?}.\n Reference points were: {:?}",
                            sc_id + 1,
                            reason,
                            ref_points
                        ),
                        // Re-raise any other payload (e.g. a `&str`) instead of swallowing it,
                        // so a failed assertion can never make the test pass.
                        Err(other) => panic::resume_unwind(other),
                    }
                }
            }
        }
    }

    /// Crowding two nearby points adds 5 points, so shared candidates are not duplicated.
    #[test]
    fn test_ref_point_addition_overlap() {
        let ds = DasDarren1998::new(3, &NumberOfPartitions::OneLayer(5)).unwrap();
        let mut ref_points = ds.get_weights();
        let mut rho_j = uniform_rho_j(ref_points.len(), 0);
        assert_approx_array_eq(&ref_points[7], &[0.2, 0.2, 0.6], Some(0.0));
        *rho_j.get_mut(&7).unwrap() = 2_usize;
        assert_approx_array_eq(&ref_points[12], &[0.4, 0.2, 0.4], Some(0.0));
        *rho_j.get_mut(&12).unwrap() = 2_usize;
        let counter = ref_points.len();
        let mut a =
            AdaptiveReferencePoints::new(&mut ref_points, &mut rho_j, counter, ds.gap()).unwrap();
        a.calculate().unwrap();
        assert_eq!(ref_points.len(), counter + 5);
        assert_eq!(rho_j.len(), ref_points.len());
    }

    /// An unused added point is removed when the number of points with exactly one associated
    /// member equals the number of original points. Used added points are kept.
    #[test]
    fn test_ref_point_deletion() {
        let ds = DasDarren1998::new(3, &NumberOfPartitions::OneLayer(5)).unwrap();
        let mut ref_points = ds.get_weights();
        let or_counter = ref_points.len();
        let mut rho_j = uniform_rho_j(or_counter, 1);
        rho_j.insert(11, 0);
        // Two added points: #21 is unused (it must be deleted) and #22 has one associated
        // member (it must be kept).
        ref_points.push(vec![0.33, 0.13, 0.53]);
        rho_j.insert(21, 0);
        ref_points.push(vec![0.33, 0.33, 0.33]);
        rho_j.insert(22, 1);
        let mut a =
            AdaptiveReferencePoints::new(&mut ref_points, &mut rho_j, or_counter, ds.gap())
                .unwrap();
        a.calculate().unwrap();
        assert_eq!(a.reference_points.len(), or_counter + 1);
        assert!(!ref_points.contains(&vec![0.33, 0.13, 0.53]));
        assert!(!rho_j.contains_key(&21));
        assert!(ref_points.contains(&vec![0.33, 0.33, 0.33]));
        assert!(rho_j.contains_key(&22));
        assert!(rho_j.contains_key(&11));
    }
}