radiate_core/objectives/
front.rs

1use crate::{
2    Optimize,
3    objectives::{Objective, pareto},
4};
5use std::{cmp::Ordering, hash::Hash, ops::Range, sync::Arc};
6
7/// A `Front<T>` is a collection of `T`'s that are non-dominated with respect to each other.
8/// This is useful for multi-objective optimization problems where the goal is to find
9/// the best solutions that are not dominated by any other solution.
10/// This results in what is called the Pareto front.
11#[derive(Clone)]
12pub struct Front<T>
13where
14    T: AsRef<[f32]>,
15{
16    values: Vec<Arc<T>>,
17    ord: Arc<dyn Fn(&T, &T) -> Ordering + Send + Sync>,
18    range: Range<usize>,
19    objective: Objective,
20}
21
22impl<T> Front<T>
23where
24    T: AsRef<[f32]>,
25{
26    pub fn new<F>(range: Range<usize>, objective: Objective, comp: F) -> Self
27    where
28        F: Fn(&T, &T) -> Ordering + Send + Sync + 'static,
29    {
30        Front {
31            values: Vec::new(),
32            range,
33            objective,
34            ord: Arc::new(comp),
35        }
36    }
37
38    pub fn range(&self) -> Range<usize> {
39        self.range.clone()
40    }
41
42    pub fn objective(&self) -> Objective {
43        self.objective.clone()
44    }
45
46    pub fn values(&self) -> &[Arc<T>] {
47        &self.values
48    }
49
50    pub fn add_all(&mut self, items: &[T]) -> usize
51    where
52        T: Eq + Hash + Clone + Send + Sync + 'static,
53    {
54        let mut updated = false;
55        let mut to_remove = Vec::new();
56        let mut added_count = 0;
57
58        for i in 0..items.len() {
59            let new_member = &items[i];
60            let mut is_dominated = true;
61
62            for existing_val in self.values.iter() {
63                let equals = new_member == existing_val.as_ref();
64                if (self.ord)(existing_val.as_ref(), new_member) == Ordering::Greater || equals {
65                    // If an existing value dominates the new value, return false
66                    is_dominated = false;
67                    break;
68                } else if (self.ord)(new_member, existing_val.as_ref()) == Ordering::Greater {
69                    // If the new value dominates an existing value, continue checking
70                    to_remove.push(Arc::clone(existing_val));
71                    continue;
72                }
73            }
74
75            if is_dominated {
76                updated = true;
77                self.values.push(Arc::new(new_member.clone()));
78                added_count += 1;
79                for rem in to_remove.drain(..) {
80                    self.values.retain(|x| x.as_ref() != rem.as_ref());
81                }
82            }
83
84            if updated && self.values.len() > self.range.end {
85                self.filter();
86            }
87
88            to_remove.clear();
89            updated = false;
90        }
91
92        added_count
93    }
94
95    pub fn filter(&mut self) {
96        let values = self.values.iter().map(|s| s.as_ref()).collect::<Vec<_>>();
97        let crowding_distances = pareto::crowding_distance(&values);
98
99        let mut enumerated = crowding_distances.iter().enumerate().collect::<Vec<_>>();
100
101        enumerated.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap_or(Ordering::Equal));
102
103        self.values = enumerated
104            .iter()
105            .take(self.range.end)
106            .map(|(i, _)| Arc::clone(&self.values[*i]))
107            .collect::<Vec<Arc<T>>>();
108    }
109}
110
111impl<T> Default for Front<T>
112where
113    T: AsRef<[f32]>,
114{
115    fn default() -> Self {
116        Front::new(0..0, Objective::Single(Optimize::Minimize), |_, _| {
117            std::cmp::Ordering::Equal
118        })
119    }
120}