radiate_core/objectives/front.rs

1use crate::{
2    Chromosome, Epoch, Executor, Phenotype,
3    objectives::{Objective, pareto},
4};
5use std::{
6    cmp::Ordering,
7    collections::HashSet,
8    hash::Hash,
9    ops::Range,
10    sync::{Arc, RwLock},
11};
12
13/// A front is a collection of scores that are non-dominated with respect to each other.
14/// This is useful for multi-objective optimization problems where the goal is to find
15/// the best solutions that are not dominated by any other solution.
16/// This results in what is called the Pareto front.
17#[derive(Clone)]
18pub struct Front<T>
19where
20    T: AsRef<[f32]>,
21{
22    values: Vec<Arc<T>>,
23    ord: Arc<dyn Fn(&T, &T) -> Ordering + Send + Sync>,
24    range: Range<usize>,
25    objective: Objective,
26    thread_pool: Arc<Executor>,
27}
28
29impl<T> Front<T>
30where
31    T: AsRef<[f32]>,
32{
33    pub fn new<F>(
34        range: Range<usize>,
35        objective: Objective,
36        thread_pool: Arc<Executor>,
37        comp: F,
38    ) -> Self
39    where
40        F: Fn(&T, &T) -> Ordering + Send + Sync + 'static,
41    {
42        Front {
43            values: Vec::new(),
44            range,
45            objective,
46            ord: Arc::new(comp),
47            thread_pool,
48        }
49    }
50
51    pub fn range(&self) -> Range<usize> {
52        self.range.clone()
53    }
54
55    pub fn objective(&self) -> Objective {
56        self.objective.clone()
57    }
58
59    pub fn values(&self) -> &[Arc<T>] {
60        &self.values
61    }
62
63    pub fn add_all(&mut self, items: &[T]) -> usize
64    where
65        T: Eq + Hash + Clone + Send + Sync + 'static,
66    {
67        let ord = Arc::clone(&self.ord);
68        let values = Arc::new(RwLock::new(self.values.clone()));
69        let dominating_values = Arc::new(RwLock::new(vec![false; items.len()]));
70        let remove_values = Arc::new(RwLock::new(HashSet::new()));
71        let values_to_add = Arc::new(RwLock::new(Vec::new()));
72
73        let mut jobs = Vec::new();
74        for (idx, member) in items.iter().enumerate() {
75            let ord_clone = Arc::clone(&ord);
76            let values_clone = Arc::clone(&values);
77            let doms_vector = Arc::clone(&dominating_values);
78            let remove_vector = Arc::clone(&remove_values);
79            let new_member = member.clone();
80            let values_to_add = Arc::clone(&values_to_add);
81
82            // self.thread_pool.group_submit(&wg, move || {
83            jobs.push(move || {
84                let mut is_dominated = true;
85
86                for existing_val in values_clone.read().unwrap().iter() {
87                    if (ord_clone)(existing_val, &new_member) == Ordering::Greater {
88                        // If an existing value dominates the new value, return false
89                        is_dominated = false;
90                        break;
91                    } else if (ord_clone)(&new_member, existing_val) == Ordering::Greater {
92                        // If the new value dominates an existing value, continue checking
93                        // to_remove.push(Arc::clone(existing_val));
94                        remove_vector.write().unwrap().insert(existing_val.clone());
95                        continue;
96                    } else if &new_member == existing_val.as_ref() {
97                        // If they are equal, we consider it dominated
98                        is_dominated = false;
99                        break;
100                    }
101                }
102
103                if is_dominated {
104                    doms_vector.write().unwrap().get_mut(idx).map(|v| *v = true);
105                    let mut writer = values_to_add.write().unwrap();
106                    writer.push(new_member);
107                }
108            });
109        }
110
111        let count = jobs.len();
112
113        self.thread_pool.submit_batch(jobs);
114
115        self.values
116            .retain(|x| !remove_values.read().unwrap().contains(x));
117        self.values
118            .extend(values_to_add.write().unwrap().drain(..).map(Arc::new));
119
120        if self.values.len() > self.range.end {
121            self.filter();
122        }
123
124        count
125    }
126
127    pub fn filter(&mut self) {
128        let values = self.values.iter().map(|s| s.as_ref()).collect::<Vec<_>>();
129        let crowding_distances = pareto::crowding_distance(&values, &self.objective);
130
131        let mut enumerated = crowding_distances.iter().enumerate().collect::<Vec<_>>();
132
133        enumerated.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap_or(Ordering::Equal));
134
135        self.values = enumerated
136            .iter()
137            .take(self.range.end)
138            .map(|(i, _)| Arc::clone(&self.values[*i]))
139            .collect::<Vec<Arc<T>>>();
140    }
141}
142
/// An owned list of Pareto-front members, kept in insertion order.
///
/// No dominance checks are performed: [`ParetoFront::add`] simply appends.
#[derive(Clone)]
pub struct ParetoFront<T> {
    front: Vec<T>,
}

// Written by hand because `#[derive(Default)]` would require `T: Default`, which an empty
// `Vec<T>` does not need.
impl<T> Default for ParetoFront<T> {
    fn default() -> Self {
        ParetoFront::new()
    }
}

impl<T> ParetoFront<T> {
    /// Creates an empty `ParetoFront`.
    pub fn new() -> Self {
        ParetoFront { front: Vec::new() }
    }

    /// Appends `item` without comparing it to the existing members.
    pub fn add(&mut self, item: T) {
        self.front.push(item);
    }

    /// Returns the members in the order they were added.
    pub fn values(&self) -> &[T] {
        &self.front
    }
}
161
162impl<C, E> FromIterator<E> for ParetoFront<Phenotype<C>>
163where
164    C: Chromosome + Clone,
165    E: Epoch<Chromosome = C, Value = Front<Phenotype<C>>>,
166{
167    fn from_iter<I: IntoIterator<Item = E>>(iter: I) -> Self {
168        let mut result = ParetoFront::new();
169        let final_epoch = iter.into_iter().last();
170        if let Some(epoch) = final_epoch {
171            let front = epoch.value();
172            for value in front.values() {
173                result.add((*(*value)).clone());
174            }
175        }
176
177        result
178    }
179}