radiate_core/objectives/
front.rs1use crate::{
2 Chromosome, Epoch, Executor, Phenotype,
3 objectives::{Objective, pareto},
4};
5use std::{
6 cmp::Ordering,
7 collections::HashSet,
8 hash::Hash,
9 ops::Range,
10 sync::{Arc, RwLock},
11};
12
13#[derive(Clone)]
18pub struct Front<T>
19where
20 T: AsRef<[f32]>,
21{
22 values: Vec<Arc<T>>,
23 ord: Arc<dyn Fn(&T, &T) -> Ordering + Send + Sync>,
24 range: Range<usize>,
25 objective: Objective,
26 thread_pool: Arc<Executor>,
27}
28
29impl<T> Front<T>
30where
31 T: AsRef<[f32]>,
32{
33 pub fn new<F>(
34 range: Range<usize>,
35 objective: Objective,
36 thread_pool: Arc<Executor>,
37 comp: F,
38 ) -> Self
39 where
40 F: Fn(&T, &T) -> Ordering + Send + Sync + 'static,
41 {
42 Front {
43 values: Vec::new(),
44 range,
45 objective,
46 ord: Arc::new(comp),
47 thread_pool,
48 }
49 }
50
51 pub fn range(&self) -> Range<usize> {
52 self.range.clone()
53 }
54
55 pub fn objective(&self) -> Objective {
56 self.objective.clone()
57 }
58
59 pub fn values(&self) -> &[Arc<T>] {
60 &self.values
61 }
62
63 pub fn add_all(&mut self, items: &[T]) -> usize
64 where
65 T: Eq + Hash + Clone + Send + Sync + 'static,
66 {
67 let ord = Arc::clone(&self.ord);
68 let values = Arc::new(RwLock::new(self.values.clone()));
69 let dominating_values = Arc::new(RwLock::new(vec![false; items.len()]));
70 let remove_values = Arc::new(RwLock::new(HashSet::new()));
71 let values_to_add = Arc::new(RwLock::new(Vec::new()));
72
73 let mut jobs = Vec::new();
74 for (idx, member) in items.iter().enumerate() {
75 let ord_clone = Arc::clone(&ord);
76 let values_clone = Arc::clone(&values);
77 let doms_vector = Arc::clone(&dominating_values);
78 let remove_vector = Arc::clone(&remove_values);
79 let new_member = member.clone();
80 let values_to_add = Arc::clone(&values_to_add);
81
82 jobs.push(move || {
84 let mut is_dominated = true;
85
86 for existing_val in values_clone.read().unwrap().iter() {
87 if (ord_clone)(existing_val, &new_member) == Ordering::Greater {
88 is_dominated = false;
90 break;
91 } else if (ord_clone)(&new_member, existing_val) == Ordering::Greater {
92 remove_vector.write().unwrap().insert(existing_val.clone());
95 continue;
96 } else if &new_member == existing_val.as_ref() {
97 is_dominated = false;
99 break;
100 }
101 }
102
103 if is_dominated {
104 doms_vector.write().unwrap().get_mut(idx).map(|v| *v = true);
105 let mut writer = values_to_add.write().unwrap();
106 writer.push(new_member);
107 }
108 });
109 }
110
111 let count = jobs.len();
112
113 self.thread_pool.submit_batch(jobs);
114
115 self.values
116 .retain(|x| !remove_values.read().unwrap().contains(x));
117 self.values
118 .extend(values_to_add.write().unwrap().drain(..).map(Arc::new));
119
120 if self.values.len() > self.range.end {
121 self.filter();
122 }
123
124 count
125 }
126
127 pub fn filter(&mut self) {
128 let values = self.values.iter().map(|s| s.as_ref()).collect::<Vec<_>>();
129 let crowding_distances = pareto::crowding_distance(&values, &self.objective);
130
131 let mut enumerated = crowding_distances.iter().enumerate().collect::<Vec<_>>();
132
133 enumerated.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap_or(Ordering::Equal));
134
135 self.values = enumerated
136 .iter()
137 .take(self.range.end)
138 .map(|(i, _)| Arc::clone(&self.values[*i]))
139 .collect::<Vec<Arc<T>>>();
140 }
141}
142
/// An owned list of Pareto-front members, kept in insertion order.
///
/// No dominance checks are performed: [`ParetoFront::add`] appends unconditionally.
#[derive(Clone)]
pub struct ParetoFront<T> {
    front: Vec<T>,
}

// Implemented by hand: `#[derive(Default)]` would require `T: Default`, which is not
// needed to build an empty front (and which `ParetoFront::new` does not require).
impl<T> Default for ParetoFront<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> ParetoFront<T> {
    /// Creates an empty front.
    pub fn new() -> Self {
        ParetoFront { front: Vec::new() }
    }

    /// Appends `item` to the end of the front.
    pub fn add(&mut self, item: T) {
        self.front.push(item);
    }

    /// Returns the members in insertion order.
    pub fn values(&self) -> &[T] {
        &self.front
    }
}
161
162impl<C, E> FromIterator<E> for ParetoFront<Phenotype<C>>
163where
164 C: Chromosome + Clone,
165 E: Epoch<Chromosome = C, Value = Front<Phenotype<C>>>,
166{
167 fn from_iter<I: IntoIterator<Item = E>>(iter: I) -> Self {
168 let mut result = ParetoFront::new();
169 let final_epoch = iter.into_iter().last();
170 if let Some(epoch) = final_epoch {
171 let front = epoch.value();
172 for value in front.values() {
173 result.add((*(*value)).clone());
174 }
175 }
176
177 result
178 }
179}