// radiate_core/objectives/front.rs
use crate::objectives::{Objective, Scored, pareto};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::{cmp::Ordering, hash::Hash, ops::Range, sync::Arc};

/// Bin count that `Front::entropy` passes to `pareto::entropy`.
const DEFAULT_ENTROPY_BINS: usize = 20;

/// Summary of a single [`Front::add_all`] call.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct FrontAddResult {
    /// Number of items that were inserted into the front.
    pub added_count: usize,
    /// Number of existing members removed because a newly inserted item dominated them
    /// (members dropped by size trimming are not counted).
    pub removed_count: usize,
    /// Number of (existing member, candidate) pairs found to be equal or in a dominance
    /// relation; pairs where neither dominates are not counted.
    pub comparisons: usize,
}

14#[derive(Clone)]
19#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
20pub struct Front<T>
21where
22 T: Scored,
23{
24 values: Vec<Arc<T>>,
25 range: Range<usize>,
26 objective: Objective,
27}
28
29impl<T> Front<T>
30where
31 T: Scored,
32{
33 pub fn new(range: Range<usize>, objective: Objective) -> Self {
34 Front {
35 values: Vec::new(),
36 range,
37 objective: objective.clone(),
38 }
39 }
40
41 pub fn range(&self) -> Range<usize> {
42 self.range.clone()
43 }
44
45 pub fn objective(&self) -> Objective {
46 self.objective.clone()
47 }
48
49 pub fn is_empty(&self) -> bool {
50 self.values.is_empty()
51 }
52
53 pub fn values(&self) -> &[Arc<T>] {
54 &self.values
55 }
56
57 pub fn crowding_distance(&self) -> Option<Vec<f32>> {
58 let scores = self
59 .values
60 .iter()
61 .filter_map(|s| s.score())
62 .collect::<Vec<_>>();
63
64 if scores.is_empty() {
65 return None;
66 }
67
68 Some(pareto::crowding_distance(&scores))
69 }
70
71 pub fn entropy(&self) -> Option<f32> {
72 let scores = self
73 .values
74 .iter()
75 .filter_map(|s| s.score())
76 .collect::<Vec<_>>();
77
78 if scores.is_empty() {
79 return None;
80 }
81
82 Some(pareto::entropy(&scores, DEFAULT_ENTROPY_BINS))
83 }
84
85 pub fn add_all(&mut self, items: &[T]) -> FrontAddResult
86 where
87 T: Eq + Hash + Clone + Send + Sync + 'static,
88 {
89 let mut updated = false;
90 let mut to_remove = Vec::new();
91 let mut added_count = 0;
92 let mut removed_count = 0;
93 let mut comparisons = 0;
94
95 for i in 0..items.len() {
96 let new_member = &items[i];
97 let mut is_dominated = true;
98
99 for existing_val in self.values.iter() {
100 let equals = new_member == existing_val.as_ref();
101 if self.dom_cmp(existing_val.as_ref(), new_member) == Ordering::Greater || equals {
102 is_dominated = false;
104 comparisons += 1;
105 break;
106 } else if self.dom_cmp(new_member, existing_val.as_ref()) == Ordering::Greater {
107 to_remove.push(Arc::clone(existing_val));
109 comparisons += 1;
110 continue;
111 }
112 }
113
114 if is_dominated {
115 updated = true;
116 self.values.push(Arc::new(new_member.clone()));
117 added_count += 1;
118 for rem in to_remove.drain(..) {
119 self.values.retain(|x| x.as_ref() != rem.as_ref());
120 removed_count += 1;
121 }
122 }
123
124 if updated && self.values.len() > self.range.end {
125 self.filter();
126 }
127
128 to_remove.clear();
129 updated = false;
130 }
131
132 FrontAddResult {
133 added_count,
134 removed_count,
135 comparisons,
136 }
137 }
138
139 fn dom_cmp(&self, one: &T, two: &T) -> Ordering {
140 let one_score = one.score();
141 let two_score = two.score();
142
143 if one_score.is_none() || two_score.is_none() {
144 return Ordering::Equal;
145 }
146
147 if let (Some(one), Some(two)) = (one_score, two_score) {
148 if pareto::dominance(one, two, &self.objective) {
149 return Ordering::Greater;
150 } else if pareto::dominance(two, one, &self.objective) {
151 return Ordering::Less;
152 }
153 }
154
155 Ordering::Equal
156 }
157
158 fn filter(&mut self) {
159 if let Some(crowding_distances) = self.crowding_distance() {
160 let mut enumerated = crowding_distances.iter().enumerate().collect::<Vec<_>>();
161
162 enumerated.sort_unstable_by(|a, b| b.1.partial_cmp(a.1).unwrap_or(Ordering::Equal));
163
164 self.values = enumerated
165 .iter()
166 .take(self.range.start)
167 .map(|(i, _)| Arc::clone(&self.values[*i]))
168 .collect::<Vec<Arc<T>>>();
169 }
170 }
171}
172
173impl<T> Default for Front<T>
174where
175 T: Scored,
176{
177 fn default() -> Self {
178 Front::new(0..0, Objective::default())
179 }
180}