use crate::types::{FilterParams, LocalSolution};
use ndarray::Array1;
use thiserror::Error;
/// Errors produced by the filters in this module.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub enum FiltersErrors {
    /// A distance factor was rejected by validation; `value` is the rejected factor.
    #[error("Distance factor must be positive or equal to zero, got {value}")]
    NegativeDistanceFactor { value: f64 },

    /// A distance threshold calculation failed.
    ///
    /// Carries the `threshold` and `distance_factor` that are interpolated into the message.
    #[error(
        "Distance threshold calculation failed: threshold={threshold}, distance_factor={distance_factor}"
    )]
    InvalidDistanceThreshold { threshold: f64, distance_factor: f64 },
}
/// Filter that accepts a value only when it does not exceed a stored threshold.
#[derive(Debug)]
#[cfg_attr(feature = "checkpointing", derive(serde::Serialize, serde::Deserialize))]
pub struct MeritFilter {
    /// Largest value that [`MeritFilter::check`] still accepts.
    pub threshold: f64,
}

impl Default for MeritFilter {
    /// Builds a filter whose threshold is positive infinity.
    fn default() -> Self {
        Self { threshold: f64::INFINITY }
    }
}

impl MeritFilter {
    /// Creates a filter with an infinite threshold; equivalent to [`MeritFilter::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Overwrites the stored threshold with `threshold`.
    pub fn update_threshold(&mut self, threshold: f64) {
        self.threshold = threshold;
    }

    /// Returns `true` when `value` is less than or equal to the current threshold.
    ///
    /// If either `value` or the threshold is NaN the comparison is false.
    pub fn check(&self, value: f64) -> bool {
        self.threshold >= value
    }
}
/// Filter that stores known solutions and tests candidate points against them by distance.
#[derive(Debug)]
#[cfg_attr(feature = "checkpointing", derive(serde::Serialize, serde::Deserialize))]
pub struct DistanceFilter {
    /// Solutions that candidate points are compared against.
    solutions: Vec<LocalSolution>,
    /// Parameters supplied at construction; `distance_factor` is the field read by `check`.
    params: FilterParams,
}
impl DistanceFilter {
    /// Creates a filter with no stored solutions after validating `params.distance_factor`.
    ///
    /// # Errors
    ///
    /// Returns [`FiltersErrors::NegativeDistanceFactor`] when the distance factor is negative
    /// or NaN. The rejected value is carried in the error.
    pub fn new(params: FilterParams) -> Result<Self, FiltersErrors> {
        let factor = params.distance_factor;
        // `NaN < 0.0` is false, so NaN needs its own test. Letting it through would make
        // every comparison in `check` false and reject all points once a solution is stored.
        if factor.is_nan() || factor < 0.0 {
            return Err(FiltersErrors::NegativeDistanceFactor { value: factor });
        }
        Ok(Self { solutions: Vec::new(), params })
    }

    /// Appends `solution` to the list of stored solutions.
    pub fn add_solution(&mut self, solution: LocalSolution) {
        self.solutions.push(solution);
    }

    /// Returns `true` when `point` is strictly farther than `distance_factor` from every
    /// stored solution.
    ///
    /// With no stored solutions the result is `true`. A point lying exactly at
    /// `distance_factor` from some solution is rejected.
    pub fn check(&self, point: &Array1<f64>) -> bool {
        // Squared quantities are compared so no square root is needed; the squared radius
        // does not depend on the solution, so it is computed once up front.
        let min_distance_squared = self.params.distance_factor * self.params.distance_factor;
        self.solutions
            .iter()
            .all(|s| euclidean_distance_squared(point, &s.point) > min_distance_squared)
    }

    /// Borrows the stored solutions (only compiled with the `checkpointing` feature).
    #[cfg(feature = "checkpointing")]
    pub fn get_solutions(&self) -> &Vec<LocalSolution> {
        &self.solutions
    }

    /// Replaces the stored solutions wholesale (only compiled with the `checkpointing` feature).
    #[cfg(feature = "checkpointing")]
    pub fn set_solutions(&mut self, solutions: Vec<LocalSolution>) {
        self.solutions = solutions;
    }
}
/// Returns the squared Euclidean distance between `a` and `b`.
///
/// Computed as the dot product of the element-wise difference `a - b` with itself, so no
/// square root is taken.
#[inline]
fn euclidean_distance_squared(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
    let delta = a - b;
    delta.dot(&delta)
}
#[cfg(test)]
mod test_filters {
    use super::*;
    use ndarray::array;

    /// A negative distance factor is rejected and echoed back in the error.
    #[test]
    fn test_filter_params_invalid_distance_factor() {
        let outcome = DistanceFilter::new(FilterParams {
            distance_factor: -0.5,
            wait_cycle: 10,
            threshold_factor: 0.1,
        });
        let is_expected_error = match outcome {
            Err(FiltersErrors::NegativeDistanceFactor { value }) => value == -0.5,
            _ => false,
        };
        assert!(is_expected_error);
    }

    /// `update_threshold` stores the value it is given.
    #[test]
    fn test_merit_filter_update_threshold() {
        let mut merit = MeritFilter::new();
        merit.update_threshold(10.0);
        assert_eq!(merit.threshold, 10.0);
    }

    /// A non-negative factor yields a filter that keeps the parameters and starts empty.
    #[test]
    fn test_distance_filter_valid() {
        let filter = DistanceFilter::new(FilterParams {
            distance_factor: 1.0,
            wait_cycle: 5,
            threshold_factor: 0.2,
        })
        .unwrap();
        assert_eq!(filter.params.distance_factor, 1.0);
        assert!(filter.solutions.is_empty());
    }

    /// `add_solution` appends the solution to the internal list.
    #[test]
    fn test_distance_filter_add_solution() {
        let mut filter = DistanceFilter::new(FilterParams {
            distance_factor: 1.0,
            wait_cycle: 5,
            threshold_factor: 0.2,
        })
        .unwrap();
        filter.add_solution(LocalSolution { point: array![1.0, 2.0, 3.0], objective: 5.0 });
        assert_eq!(filter.solutions.len(), 1);
        assert_eq!(filter.solutions[0].objective, 5.0);
    }

    /// Points within the distance factor of a stored solution fail; distant points pass.
    #[test]
    fn test_distance_filter_check() {
        let mut filter = DistanceFilter::new(FilterParams {
            distance_factor: 2.0,
            wait_cycle: 5,
            threshold_factor: 0.2,
        })
        .unwrap();
        filter.add_solution(LocalSolution { point: array![0.0, 0.0, 0.0], objective: 5.0 });

        // Squared distance 3 is not greater than 2^2 = 4, so the point is rejected.
        let near = array![1.0, 1.0, 1.0];
        // Squared distance 34 exceeds 4, so the point is accepted.
        let far = array![3.0, 4.0, 3.0];
        assert!(!filter.check(&near));
        assert!(filter.check(&far));
    }
}