1use crate::objectives::{Objective, Scored, pareto};
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::{cmp::Ordering, ops::Range, sync::Arc};
6
/// Histogram bins per objective dimension used by `Front::entropy`.
const DEFAULT_ENTROPY_BINS: usize = 20;
/// Spread below which `entropy_flat` treats a score dimension as constant.
const EPSILON: f32 = 1.0e-10;
9
/// Working buffers owned by a `Front` and reused across calls so that
/// repeated filtering / distance computations do not reallocate.
#[derive(Default, Clone)]
struct FrontScratch {
    /// Indices of members scheduled for removal.
    remove: Vec<usize>,
    /// Candidate member indices when thinning the front down to size.
    keep_idx: Vec<usize>,
    /// Permutation of member indices used while sorting.
    order: Vec<usize>,
    /// Cached row-major score matrix: row `i` holds the `m` scores of member `i`.
    /// An empty buffer means "not computed / stale".
    scores: Vec<f32>,
    /// Crowding distance per member, indexed like the front's values.
    dist: Vec<f32>,
}
18
/// Counters describing what a single `Front::add_all` call did.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct FrontAddResult {
    /// Number of items inserted into the front.
    pub added_count: usize,
    /// Number of existing members evicted because a new item dominated them.
    pub removed_count: usize,
    /// Number of pairwise dominance comparisons performed.
    pub comparisons: usize,
    /// Number of times the front exceeded its upper size bound and thinning
    /// was attempted.
    pub filter_count: usize,
    /// Size of the front after all items were processed.
    pub size: usize,
}
27
28#[derive(Clone)]
29#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
30pub struct Front<T>
31where
32 T: Scored,
33{
34 values: Vec<Arc<T>>,
35 range: Range<usize>,
36 objective: Objective,
37
38 #[cfg_attr(feature = "serde", serde(skip))]
39 scratch: FrontScratch,
40}
41
42impl<T> Front<T>
43where
44 T: Scored,
45{
46 pub fn new(range: Range<usize>, objective: Objective) -> Self {
47 Front {
48 values: Vec::new(),
49 range,
50 objective,
51 scratch: FrontScratch::default(),
52 }
53 }
54
55 pub fn len(&self) -> usize {
56 self.values.len()
57 }
58
59 pub fn range(&self) -> Range<usize> {
60 self.range.clone()
61 }
62
63 pub fn objective(&self) -> Objective {
64 self.objective.clone()
65 }
66
67 pub fn is_empty(&self) -> bool {
68 self.values.is_empty()
69 }
70
71 pub fn values(&self) -> &[Arc<T>] {
72 &self.values
73 }
74
75 pub fn crowding_distance(&mut self) -> Option<&[f32]> {
76 self.ensure_score_matrix()?;
77 let (n, _) = self.score_dims()?;
78 self.crowding_distance_in_place(n);
79
80 Some(&self.scratch.dist[..n])
81 }
82
83 pub fn entropy(&mut self) -> Option<f32> {
84 self.ensure_score_matrix()?;
85 let (n, m) = self.score_dims()?;
86
87 Some(entropy_flat(
88 &self.scratch.scores,
89 n,
90 m,
91 DEFAULT_ENTROPY_BINS,
92 ))
93 }
94
95 pub fn add_all(&mut self, items: Vec<T>) -> FrontAddResult
96 where
97 T: Eq + Clone + Send + Sync + 'static,
98 {
99 let mut added_count = 0;
100 let mut removed_count = 0;
101 let mut comparisons = 0;
102 let mut filter_count = 0;
103
104 for new_member in items.into_iter() {
105 self.scratch.remove.clear();
106
107 let mut accept = true;
109
110 for (idx, existing) in self.values.iter().enumerate() {
111 if existing.as_ref() == &new_member {
112 accept = false;
113 break;
114 }
115
116 match self.dom_cmp(existing.as_ref(), &new_member) {
118 Ordering::Greater => {
119 accept = false;
121 comparisons += 1;
122 break;
123 }
124 Ordering::Less => {
125 self.scratch.remove.push(idx);
127 comparisons += 1;
128 }
129 Ordering::Equal => comparisons += 1,
130 }
131 }
132
133 if !accept {
134 continue;
135 }
136
137 if !self.scratch.remove.is_empty() {
140 self.scratch.remove.sort_unstable();
141 self.scratch.remove.dedup();
142
143 for &idx in self.scratch.remove.iter().rev() {
144 self.values.swap_remove(idx);
145 removed_count += 1;
146 }
147 }
148
149 self.values.push(Arc::new(new_member));
150 added_count += 1;
151
152 if self.values.len() > self.range.end {
154 self.fast_filter();
155 filter_count += 1;
156 }
157
158 self.scratch.scores.clear();
160 }
161
162 FrontAddResult {
163 added_count,
164 removed_count,
165 comparisons,
166 filter_count,
167 size: self.values.len(),
168 }
169 }
170
171 #[inline]
174 pub fn remove_outliers(&mut self, trim: f32) -> Option<usize> {
175 if self.values.len() < 4 {
176 return None;
177 }
178
179 let trim = trim.clamp(0.0, 0.5);
180 if trim == 0.0 {
181 return None;
182 }
183
184 self.ensure_score_matrix()?;
185 let (n, _m) = self.score_dims()?;
186
187 self.crowding_distance_in_place(n);
188
189 let drop = ((n as f32) * trim).floor() as usize;
190 if drop == 0 {
191 return None;
192 }
193
194 self.scratch.order.clear();
196 self.scratch.order.extend(0..n);
197
198 let dist = &self.scratch.dist;
199 self.scratch.order.sort_unstable_by(|&i, &j| {
200 let a = dist[i];
201 let b = dist[j];
202
203 match (a.is_infinite(), b.is_infinite()) {
204 (true, true) => Ordering::Equal,
205 (true, false) => Ordering::Less,
206 (false, true) => Ordering::Greater,
207 _ => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
208 }
209 });
210
211 self.scratch.remove.clear();
212 self.scratch
213 .remove
214 .extend(self.scratch.order.iter().take(drop).copied());
215
216 self.scratch.remove.sort_unstable();
217 self.scratch.remove.dedup();
218 let removed = self.scratch.remove.len();
219 for &idx in self.scratch.remove.iter().rev() {
220 self.values.swap_remove(idx);
221 }
222
223 self.scratch.scores.clear();
224 Some(removed)
225 }
226
227 #[inline]
228 fn dom_cmp(&self, one: &T, two: &T) -> Ordering {
229 let one_score = one.score();
230 let two_score = two.score();
231
232 if one_score.is_none() || two_score.is_none() {
233 return Ordering::Equal;
234 }
235
236 let (a, b) = (one_score.unwrap(), two_score.unwrap());
237
238 if pareto::dominance(a, b, &self.objective) {
239 Ordering::Greater
240 } else if pareto::dominance(b, a, &self.objective) {
241 Ordering::Less
242 } else {
243 Ordering::Equal
244 }
245 }
246
247 pub fn fronts(&mut self) -> Vec<Front<T>>
248 where
249 T: Clone + Eq + Send + Sync + 'static,
250 {
251 let mut fronts: Vec<Front<T>> = Vec::new();
252 for member in self.values.iter() {
253 let mut updated = false;
254
255 for front in fronts.iter_mut() {
256 let to_insert = (*(*member)).clone();
257 let result = front.add_all(vec![to_insert]);
258
259 if result.added_count > 0 {
260 updated = true;
261 break;
262 }
263 }
264
265 if !updated {
266 let mut new_front = Front::new(self.range.clone(), self.objective.clone());
267 let to_insert = (*(*member)).clone();
268 new_front.add_all(vec![to_insert]);
269 fronts.push(new_front);
270 }
271 }
272
273 fronts
274 }
275
276 fn fast_filter(&mut self) {
277 let keep = self.range.start.min(self.values.len());
278 if keep == 0 || self.values.len() <= keep {
279 return;
280 }
281
282 if self.ensure_score_matrix().is_none() {
284 return;
285 }
286
287 let (n, _m) = match self.score_dims() {
288 Some(x) => x,
289 None => return,
290 };
291
292 self.crowding_distance_in_place(n);
293
294 self.scratch.keep_idx.clear();
296 self.scratch.keep_idx.extend(0..n);
297
298 let dist = &self.scratch.dist;
299
300 self.scratch
302 .keep_idx
303 .select_nth_unstable_by(keep, |&i, &j| {
304 dist[j].partial_cmp(&dist[i]).unwrap_or(Ordering::Equal)
305 });
306
307 self.scratch.keep_idx.truncate(keep);
308
309 let mut new_values = Vec::with_capacity(keep);
310 for &i in self.scratch.keep_idx.iter() {
311 new_values.push(Arc::clone(&self.values[i]));
312 }
313
314 self.values = new_values;
315 self.scratch.scores.clear();
316 }
317
318 #[inline]
319 fn score_dims(&self) -> Option<(usize, usize)> {
320 let n = self.values.len();
321
322 if n == 0 {
323 return None;
324 }
325
326 let first = self.values.iter().find_map(|v| v.score())?;
327 Some((n, first.len()))
328 }
329
330 fn ensure_score_matrix(&mut self) -> Option<()> {
331 let (n, m) = self.score_dims()?;
332
333 if m == 0 {
334 return None;
335 }
336
337 if self.scratch.scores.len() == n * m {
339 return Some(());
340 }
341
342 self.scratch.scores.resize(n * m, 0.0);
343 for (i, v) in self.values.iter().enumerate() {
344 let s = v.score()?;
345 if s.len() != m {
346 return None;
347 }
348
349 let row = &mut self.scratch.scores[i * m..i * m + m];
350 row.copy_from_slice(s.as_slice());
351 }
352
353 Some(())
354 }
355
356 fn crowding_distance_in_place(&mut self, n: usize) {
357 let (_, m) = match self.score_dims() {
358 Some(x) => x,
359 None => return,
360 };
361
362 if n == 0 || m == 0 {
363 return;
364 }
365
366 self.scratch.dist.clear();
367 self.scratch.dist.resize(n, 0.0);
368
369 self.scratch.order.clear();
370 self.scratch.order.extend(0..n);
371
372 for dim in 0..m {
373 let scores = &self.scratch.scores;
374 self.scratch.order.sort_unstable_by(|&i, &j| {
375 let a = scores[i * m + dim];
376 let b = scores[j * m + dim];
377 a.partial_cmp(&b).unwrap_or(Ordering::Equal)
378 });
379
380 let first_idx = self.scratch.order[0];
381 let last_idx = self.scratch.order[n - 1];
382 let min = self.scratch.scores[first_idx * m..first_idx * m + m][dim];
383 let max = self.scratch.scores[last_idx * m..last_idx * m + m][dim];
384 let range = max - min;
385
386 if !range.is_finite() || range == 0.0 {
387 continue;
388 }
389
390 self.scratch.dist[self.scratch.order[0]] = f32::INFINITY;
391 self.scratch.dist[self.scratch.order[n - 1]] = f32::INFINITY;
392
393 for k in 1..(n - 1) {
394 let prev_idx = self.scratch.order[k - 1];
395 let next_idx = self.scratch.order[k + 1];
396 let prev = self.scratch.scores[prev_idx * m..prev_idx * m + m][dim];
397 let next = self.scratch.scores[next_idx * m..next_idx * m + m][dim];
398
399 let contrib = (next - prev).abs() / range;
400 self.scratch.dist[self.scratch.order[k]] += contrib;
401 }
402 }
403 }
404}
405
406impl<T> Default for Front<T>
407where
408 T: Scored,
409{
410 fn default() -> Self {
411 Front::new(0..0, Objective::default())
412 }
413}
414
415fn entropy_flat(scores: &[f32], n: usize, m: usize, bins_per_dim: usize) -> f32 {
431 if n == 0 || m == 0 || bins_per_dim == 0 {
432 return 0.0;
433 }
434
435 let mut mins = vec![f32::INFINITY; m];
437 let mut maxs = vec![f32::NEG_INFINITY; m];
438
439 for i in 0..n {
440 let row = &scores[i * m..i * m + m];
441 for d in 0..m {
442 let x = row[d];
443 if x < mins[d] {
444 mins[d] = x;
445 }
446 if x > maxs[d] {
447 maxs[d] = x;
448 }
449 }
450 }
451
452 for d in 0..m {
453 if (maxs[d] - mins[d]).abs() < EPSILON {
454 maxs[d] = mins[d] + 1.0;
455 }
456 }
457
458 let mut cell_counts: HashMap<Vec<u8>, usize> = HashMap::new();
459
460 for i in 0..n {
461 let row = &scores[i * m..i * m + m];
462 let mut cell = Vec::with_capacity(m);
463
464 for d in 0..m {
465 let norm = (row[d] - mins[d]) / (maxs[d] - mins[d]); let mut idx = (norm * bins_per_dim as f32).floor() as i32;
467 if idx < 0 {
468 idx = 0;
469 }
470 if idx >= bins_per_dim as i32 {
471 idx = bins_per_dim as i32 - 1;
472 }
473 cell.push(idx as u8);
474 }
475
476 *cell_counts.entry(cell).or_insert(0) += 1;
477 }
478
479 let n_f = n as f32;
480 let mut h = 0.0_f32;
481 for &count in cell_counts.values() {
482 let p = count as f32 / n_f;
483 if p > 0.0 {
484 h -= p * p.ln();
485 }
486 }
487
488 let k = cell_counts.len().min(n);
489 if k > 1 { h / (k as f32).ln() } else { 0.0 }
490}