1use crate::objectives::{Objective, Scored, pareto};
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::{cmp::Ordering, ops::Range, sync::Arc};
6
/// Number of grid bins per objective used by `Front::entropy`.
const DEFAULT_ENTROPY_BINS: usize = 20_usize;
/// Spread below which an objective is treated as degenerate when binning for entropy.
const EPSILON: f32 = 1.0e-10_f32;
9
/// Reusable working buffers for [`Front`], kept between calls to avoid reallocating.
#[derive(Default, Clone)]
struct FrontScratch {
    /// Row-major score matrix, one row per member; cleared whenever membership changes.
    scores: Vec<f32>,
    /// Crowding distance per member, parallel to the member list.
    dist: Vec<f32>,
    /// Index permutation used while sorting members.
    order: Vec<usize>,
    /// Indices of members selected for removal.
    remove: Vec<usize>,
    /// Indices of members selected to survive trimming.
    keep_idx: Vec<usize>,
}
18
/// Counters describing what a single call to `Front::add_all` did.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FrontAddResult {
    /// Items that were accepted into the front.
    pub added_count: usize,
    /// Existing members removed because an accepted item dominated them
    /// (members dropped by size trimming are not counted here).
    pub removed_count: usize,
    /// Pareto-dominance comparisons performed (equality checks are not counted).
    pub comparisons: usize,
    /// Times a size-trimming pass was attempted because the front exceeded its limit.
    pub filter_count: usize,
    /// Number of members in the front after the call.
    pub size: usize,
}
27
28#[derive(Clone)]
29#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
30pub struct Front<T>
31where
32 T: Scored,
33{
34 values: Vec<Arc<T>>,
35 range: Range<usize>,
36 objective: Objective,
37
38 #[cfg_attr(feature = "serde", serde(skip))]
39 scratch: FrontScratch,
40}
41
42impl<T> Front<T>
43where
44 T: Scored,
45{
46 pub fn new(range: Range<usize>, objective: Objective) -> Self {
47 Front {
48 values: Vec::new(),
49 range,
50 objective,
51 scratch: FrontScratch::default(),
52 }
53 }
54
55 pub fn len(&self) -> usize {
56 self.values.len()
57 }
58
59 pub fn range(&self) -> Range<usize> {
60 self.range.clone()
61 }
62
63 pub fn objective(&self) -> Objective {
64 self.objective.clone()
65 }
66
67 pub fn is_empty(&self) -> bool {
68 self.values.is_empty()
69 }
70
71 pub fn values(&self) -> &[Arc<T>] {
72 &self.values
73 }
74
75 pub fn crowding_distance(&mut self) -> Option<&[f32]> {
76 self.ensure_score_matrix()?;
77 let (n, _) = self.score_dims()?;
78 self.crowding_distance_in_place(n);
79
80 Some(&self.scratch.dist[..n])
81 }
82
83 pub fn entropy(&mut self) -> Option<f32> {
84 self.ensure_score_matrix()?;
85 let (n, m) = self.score_dims()?;
86
87 Some(entropy_flat(
88 &self.scratch.scores,
89 n,
90 m,
91 DEFAULT_ENTROPY_BINS,
92 ))
93 }
94
95 pub fn add_all(&mut self, items: Vec<T>) -> FrontAddResult
96 where
97 T: Eq + Clone + Send + Sync + 'static,
98 {
99 let mut added_count = 0;
100 let mut removed_count = 0;
101 let mut comparisons = 0;
102 let mut filter_count = 0;
103
104 for new_member in items.into_iter() {
105 self.scratch.remove.clear();
106
107 let mut accept = true;
109
110 for (idx, existing) in self.values.iter().enumerate() {
111 if existing.as_ref() == &new_member {
112 accept = false;
113 break;
114 }
115
116 match self.dom_cmp(existing.as_ref(), &new_member) {
118 Ordering::Greater => {
119 accept = false;
121 comparisons += 1;
122 break;
123 }
124 Ordering::Less => {
125 self.scratch.remove.push(idx);
127 comparisons += 1;
128 }
129 Ordering::Equal => comparisons += 1,
130 }
131 }
132
133 if !accept {
134 continue;
135 }
136
137 if !self.scratch.remove.is_empty() {
140 self.scratch.remove.sort_unstable();
141 self.scratch.remove.dedup();
142
143 for &idx in self.scratch.remove.iter().rev() {
144 self.values.swap_remove(idx);
145 removed_count += 1;
146 }
147 }
148
149 self.values.push(Arc::new(new_member));
150 added_count += 1;
151
152 if self.values.len() > self.range.end {
154 self.fast_filter();
155 filter_count += 1;
156 }
157
158 self.scratch.scores.clear();
160 }
161
162 FrontAddResult {
163 added_count,
164 removed_count,
165 comparisons,
166 filter_count,
167 size: self.values.len(),
168 }
169 }
170
171 #[inline]
174 pub fn remove_outliers(&mut self, trim: f32) -> Option<usize> {
175 if self.values.len() < 4 {
176 return None;
177 }
178
179 let trim = trim.clamp(0.0, 0.5);
180 if trim == 0.0 {
181 return None;
182 }
183
184 if self.ensure_score_matrix().is_none() {
185 return None;
186 }
187
188 let (n, _m) = match self.score_dims() {
189 Some(x) => x,
190 None => return None,
191 };
192
193 self.crowding_distance_in_place(n);
194
195 let drop = ((n as f32) * trim).floor() as usize;
196 if drop == 0 {
197 return None;
198 }
199
200 self.scratch.order.clear();
202 self.scratch.order.extend(0..n);
203
204 let dist = &self.scratch.dist;
205 self.scratch.order.sort_unstable_by(|&i, &j| {
206 let a = dist[i];
207 let b = dist[j];
208
209 match (a.is_infinite(), b.is_infinite()) {
210 (true, true) => Ordering::Equal,
211 (true, false) => Ordering::Less,
212 (false, true) => Ordering::Greater,
213 _ => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
214 }
215 });
216
217 self.scratch.remove.clear();
218 self.scratch
219 .remove
220 .extend(self.scratch.order.iter().take(drop).copied());
221
222 self.scratch.remove.sort_unstable();
223 self.scratch.remove.dedup();
224 let removed = self.scratch.remove.len();
225 for &idx in self.scratch.remove.iter().rev() {
226 self.values.swap_remove(idx);
227 }
228
229 self.scratch.scores.clear();
230 Some(removed)
231 }
232
233 #[inline]
234 fn dom_cmp(&self, one: &T, two: &T) -> Ordering {
235 let one_score = one.score();
236 let two_score = two.score();
237
238 if one_score.is_none() || two_score.is_none() {
239 return Ordering::Equal;
240 }
241
242 let (a, b) = (one_score.unwrap(), two_score.unwrap());
243
244 if pareto::dominance(a, b, &self.objective) {
245 Ordering::Greater
246 } else if pareto::dominance(b, a, &self.objective) {
247 Ordering::Less
248 } else {
249 Ordering::Equal
250 }
251 }
252
253 pub fn fronts(&mut self) -> Vec<Front<T>>
254 where
255 T: Clone + Eq + Send + Sync + 'static,
256 {
257 let mut fronts: Vec<Front<T>> = Vec::new();
258 for member in self.values.iter() {
259 let mut updated = false;
260
261 for front in fronts.iter_mut() {
262 let to_insert = (*(*member)).clone();
263 let result = front.add_all(vec![to_insert]);
264
265 if result.added_count > 0 {
266 updated = true;
267 break;
268 }
269 }
270
271 if !updated {
272 let mut new_front = Front::new(self.range.clone(), self.objective.clone());
273 let to_insert = (*(*member)).clone();
274 new_front.add_all(vec![to_insert]);
275 fronts.push(new_front);
276 }
277 }
278
279 fronts
280 }
281
282 fn fast_filter(&mut self) {
283 let keep = self.range.start.min(self.values.len());
284 if keep == 0 || self.values.len() <= keep {
285 return;
286 }
287
288 if self.ensure_score_matrix().is_none() {
290 return;
291 }
292
293 let (n, _m) = match self.score_dims() {
294 Some(x) => x,
295 None => return,
296 };
297
298 self.crowding_distance_in_place(n);
299
300 self.scratch.keep_idx.clear();
302 self.scratch.keep_idx.extend(0..n);
303
304 let dist = &self.scratch.dist;
305
306 self.scratch
308 .keep_idx
309 .select_nth_unstable_by(keep, |&i, &j| {
310 dist[j].partial_cmp(&dist[i]).unwrap_or(Ordering::Equal)
311 });
312
313 self.scratch.keep_idx.truncate(keep);
314
315 let mut new_values = Vec::with_capacity(keep);
316 for &i in self.scratch.keep_idx.iter() {
317 new_values.push(Arc::clone(&self.values[i]));
318 }
319
320 self.values = new_values;
321 self.scratch.scores.clear();
322 }
323
324 #[inline]
325 fn score_dims(&self) -> Option<(usize, usize)> {
326 let n = self.values.len();
327
328 if n == 0 {
329 return None;
330 }
331
332 let first = self.values.iter().find_map(|v| v.score())?;
333 Some((n, first.len()))
334 }
335
336 fn ensure_score_matrix(&mut self) -> Option<()> {
337 let (n, m) = self.score_dims()?;
338
339 if m == 0 {
340 return None;
341 }
342
343 if self.scratch.scores.len() == n * m {
345 return Some(());
346 }
347
348 self.scratch.scores.resize(n * m, 0.0);
349 for (i, v) in self.values.iter().enumerate() {
350 let s = v.score()?;
351 if s.len() != m {
352 return None;
353 }
354
355 let row = &mut self.scratch.scores[i * m..i * m + m];
356 row.copy_from_slice(s.as_slice());
357 }
358
359 Some(())
360 }
361
362 fn crowding_distance_in_place(&mut self, n: usize) {
363 let (_, m) = match self.score_dims() {
364 Some(x) => x,
365 None => return,
366 };
367
368 if n == 0 || m == 0 {
369 return;
370 }
371
372 self.scratch.dist.clear();
373 self.scratch.dist.resize(n, 0.0);
374
375 self.scratch.order.clear();
376 self.scratch.order.extend(0..n);
377
378 for dim in 0..m {
379 let scores = &self.scratch.scores;
380 self.scratch.order.sort_unstable_by(|&i, &j| {
381 let a = scores[i * m + dim];
382 let b = scores[j * m + dim];
383 a.partial_cmp(&b).unwrap_or(Ordering::Equal)
384 });
385
386 let first_idx = self.scratch.order[0];
387 let last_idx = self.scratch.order[n - 1];
388 let min = self.scratch.scores[first_idx * m..first_idx * m + m][dim];
389 let max = self.scratch.scores[last_idx * m..last_idx * m + m][dim];
390 let range = max - min;
391
392 if !range.is_finite() || range == 0.0 {
393 continue;
394 }
395
396 self.scratch.dist[self.scratch.order[0]] = f32::INFINITY;
397 self.scratch.dist[self.scratch.order[n - 1]] = f32::INFINITY;
398
399 for k in 1..(n - 1) {
400 let prev_idx = self.scratch.order[k - 1];
401 let next_idx = self.scratch.order[k + 1];
402 let prev = self.scratch.scores[prev_idx * m..prev_idx * m + m][dim];
403 let next = self.scratch.scores[next_idx * m..next_idx * m + m][dim];
404
405 let contrib = (next - prev).abs() / range;
406 self.scratch.dist[self.scratch.order[k]] += contrib;
407 }
408 }
409 }
410}
411
412impl<T> Default for Front<T>
413where
414 T: Scored,
415{
416 fn default() -> Self {
417 Front::new(0..0, Objective::default())
418 }
419}
420
421fn entropy_flat(scores: &[f32], n: usize, m: usize, bins_per_dim: usize) -> f32 {
437 if n == 0 || m == 0 || bins_per_dim == 0 {
438 return 0.0;
439 }
440
441 let mut mins = vec![f32::INFINITY; m];
443 let mut maxs = vec![f32::NEG_INFINITY; m];
444
445 for i in 0..n {
446 let row = &scores[i * m..i * m + m];
447 for d in 0..m {
448 let x = row[d];
449 if x < mins[d] {
450 mins[d] = x;
451 }
452 if x > maxs[d] {
453 maxs[d] = x;
454 }
455 }
456 }
457
458 for d in 0..m {
459 if (maxs[d] - mins[d]).abs() < EPSILON {
460 maxs[d] = mins[d] + 1.0;
461 }
462 }
463
464 let mut cell_counts: HashMap<Vec<u8>, usize> = HashMap::new();
465
466 for i in 0..n {
467 let row = &scores[i * m..i * m + m];
468 let mut cell = Vec::with_capacity(m);
469
470 for d in 0..m {
471 let norm = (row[d] - mins[d]) / (maxs[d] - mins[d]); let mut idx = (norm * bins_per_dim as f32).floor() as i32;
473 if idx < 0 {
474 idx = 0;
475 }
476 if idx >= bins_per_dim as i32 {
477 idx = bins_per_dim as i32 - 1;
478 }
479 cell.push(idx as u8);
480 }
481
482 *cell_counts.entry(cell).or_insert(0) += 1;
483 }
484
485 let n_f = n as f32;
486 let mut h = 0.0_f32;
487 for &count in cell_counts.values() {
488 let p = count as f32 / n_f;
489 if p > 0.0 {
490 h -= p * p.ln();
491 }
492 }
493
494 let k = cell_counts.len().min(n);
495 if k > 1 { h / (k as f32).ln() } else { 0.0 }
496}