// radiate_core/objectives/front.rs

1use crate::objectives::{Objective, Scored, pareto};
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::{cmp::Ordering, ops::Range, sync::Arc};
6
/// Score spans narrower than this are treated as degenerate when scores are
/// binned for the entropy estimate (the span is then widened to 1.0).
const EPSILON: f32 = 1.0e-10;
/// Bins per objective dimension used by `Front::entropy`.
const DEFAULT_ENTROPY_BINS: usize = 20;

/// Reusable buffers for the crowding-distance, filtering and outlier routines,
/// kept on the front so repeated calls can reuse their allocations.
#[derive(Default)]
#[derive(Clone)]
struct FrontScratch {
    /// Row-major `n x m` score matrix cache; cleared whenever it may be stale.
    scores: Vec<f32>,
    /// Crowding distance per member, indexed like the front's values.
    dist: Vec<f32>,
    /// Index permutation used while sorting members.
    order: Vec<usize>,
    /// Indices of members queued for removal.
    remove: Vec<usize>,
    /// Candidate indices considered when trimming the front.
    keep_idx: Vec<usize>,
}

/// Summary statistics returned by `Front::add_all`.
///
/// This is a plain value type, so it derives the usual value traits in
/// addition to `Debug`; callers can copy, compare and default it.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct FrontAddResult {
    /// Number of candidates accepted into the front. A later filter pass may
    /// still trim some of them away.
    pub added_count: usize,
    /// Number of existing members evicted because an accepted candidate
    /// dominated them.
    pub removed_count: usize,
    /// Number of pairwise dominance comparisons performed.
    pub comparisons: usize,
    /// Number of times the front grew beyond its upper size bound and a filter
    /// pass was run (the pass itself may decide not to trim).
    pub filter_count: usize,
    /// Number of members on the front after all candidates were processed.
    pub size: usize,
}

28#[derive(Clone)]
29#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
30pub struct Front<T>
31where
32    T: Scored,
33{
34    values: Vec<Arc<T>>,
35    range: Range<usize>,
36    objective: Objective,
37
38    #[cfg_attr(feature = "serde", serde(skip))]
39    scratch: FrontScratch,
40}
41
42impl<T> Front<T>
43where
44    T: Scored,
45{
46    pub fn new(range: Range<usize>, objective: Objective) -> Self {
47        Front {
48            values: Vec::new(),
49            range,
50            objective,
51            scratch: FrontScratch::default(),
52        }
53    }
54
55    pub fn len(&self) -> usize {
56        self.values.len()
57    }
58
59    pub fn range(&self) -> Range<usize> {
60        self.range.clone()
61    }
62
63    pub fn objective(&self) -> Objective {
64        self.objective.clone()
65    }
66
67    pub fn is_empty(&self) -> bool {
68        self.values.is_empty()
69    }
70
71    pub fn values(&self) -> &[Arc<T>] {
72        &self.values
73    }
74
75    pub fn crowding_distance(&mut self) -> Option<&[f32]> {
76        self.ensure_score_matrix()?;
77        let (n, _) = self.score_dims()?;
78        self.crowding_distance_in_place(n);
79
80        Some(&self.scratch.dist[..n])
81    }
82
83    pub fn entropy(&mut self) -> Option<f32> {
84        self.ensure_score_matrix()?;
85        let (n, m) = self.score_dims()?;
86
87        Some(entropy_flat(
88            &self.scratch.scores,
89            n,
90            m,
91            DEFAULT_ENTROPY_BINS,
92        ))
93    }
94
95    pub fn add_all(&mut self, items: Vec<T>) -> FrontAddResult
96    where
97        T: Eq + Clone + Send + Sync + 'static,
98    {
99        let mut added_count = 0;
100        let mut removed_count = 0;
101        let mut comparisons = 0;
102        let mut filter_count = 0;
103
104        for new_member in items.into_iter() {
105            self.scratch.remove.clear();
106
107            // Decide accept/reject without mutating self.values
108            let mut accept = true;
109
110            for (idx, existing) in self.values.iter().enumerate() {
111                if existing.as_ref() == &new_member {
112                    accept = false;
113                    break;
114                }
115
116                // dominance checks
117                match self.dom_cmp(existing.as_ref(), &new_member) {
118                    Ordering::Greater => {
119                        // existing dominates new -> reject
120                        accept = false;
121                        comparisons += 1;
122                        break;
123                    }
124                    Ordering::Less => {
125                        // new dominates existing -> mark for removal
126                        self.scratch.remove.push(idx);
127                        comparisons += 1;
128                    }
129                    Ordering::Equal => comparisons += 1,
130                }
131            }
132
133            if !accept {
134                continue;
135            }
136
137            // Remove dominated existing values efficiently (swap_remove).
138            // Need stable removal: remove in descending index order.
139            if !self.scratch.remove.is_empty() {
140                self.scratch.remove.sort_unstable();
141                self.scratch.remove.dedup();
142
143                for &idx in self.scratch.remove.iter().rev() {
144                    self.values.swap_remove(idx);
145                    removed_count += 1;
146                }
147            }
148
149            self.values.push(Arc::new(new_member));
150            added_count += 1;
151
152            // Filter if we exceed max
153            if self.values.len() > self.range.end {
154                self.fast_filter();
155                filter_count += 1;
156            }
157
158            // Invalidate the cached score matrix
159            self.scratch.scores.clear();
160        }
161
162        FrontAddResult {
163            added_count,
164            removed_count,
165            comparisons,
166            filter_count,
167            size: self.values.len(),
168        }
169    }
170
171    /// Remove points with crowding distance in the top `trim` fraction.
172    /// Example: trim=0.02 removes the top 2% most isolated points.
173    #[inline]
174    pub fn remove_outliers(&mut self, trim: f32) -> Option<usize> {
175        if self.values.len() < 4 {
176            return None;
177        }
178
179        let trim = trim.clamp(0.0, 0.5);
180        if trim == 0.0 {
181            return None;
182        }
183
184        self.ensure_score_matrix()?;
185        let (n, _m) = self.score_dims()?;
186
187        self.crowding_distance_in_place(n);
188
189        let drop = ((n as f32) * trim).floor() as usize;
190        if drop == 0 {
191            return None;
192        }
193
194        // We want to drop the *largest* distances (most isolated).
195        self.scratch.order.clear();
196        self.scratch.order.extend(0..n);
197
198        let dist = &self.scratch.dist;
199        self.scratch.order.sort_unstable_by(|&i, &j| {
200            let a = dist[i];
201            let b = dist[j];
202
203            match (a.is_infinite(), b.is_infinite()) {
204                (true, true) => Ordering::Equal,
205                (true, false) => Ordering::Less,
206                (false, true) => Ordering::Greater,
207                _ => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
208            }
209        });
210
211        self.scratch.remove.clear();
212        self.scratch
213            .remove
214            .extend(self.scratch.order.iter().take(drop).copied());
215
216        self.scratch.remove.sort_unstable();
217        self.scratch.remove.dedup();
218        let removed = self.scratch.remove.len();
219        for &idx in self.scratch.remove.iter().rev() {
220            self.values.swap_remove(idx);
221        }
222
223        self.scratch.scores.clear();
224        Some(removed)
225    }
226
227    #[inline]
228    fn dom_cmp(&self, one: &T, two: &T) -> Ordering {
229        let one_score = one.score();
230        let two_score = two.score();
231
232        if one_score.is_none() || two_score.is_none() {
233            return Ordering::Equal;
234        }
235
236        let (a, b) = (one_score.unwrap(), two_score.unwrap());
237
238        if pareto::dominance(a, b, &self.objective) {
239            Ordering::Greater
240        } else if pareto::dominance(b, a, &self.objective) {
241            Ordering::Less
242        } else {
243            Ordering::Equal
244        }
245    }
246
247    pub fn fronts(&mut self) -> Vec<Front<T>>
248    where
249        T: Clone + Eq + Send + Sync + 'static,
250    {
251        let mut fronts: Vec<Front<T>> = Vec::new();
252        for member in self.values.iter() {
253            let mut updated = false;
254
255            for front in fronts.iter_mut() {
256                let to_insert = (*(*member)).clone();
257                let result = front.add_all(vec![to_insert]);
258
259                if result.added_count > 0 {
260                    updated = true;
261                    break;
262                }
263            }
264
265            if !updated {
266                let mut new_front = Front::new(self.range.clone(), self.objective.clone());
267                let to_insert = (*(*member)).clone();
268                new_front.add_all(vec![to_insert]);
269                fronts.push(new_front);
270            }
271        }
272
273        fronts
274    }
275
276    fn fast_filter(&mut self) {
277        let keep = self.range.start.min(self.values.len());
278        if keep == 0 || self.values.len() <= keep {
279            return;
280        }
281
282        // Build score matrix + crowding distances into scratch
283        if self.ensure_score_matrix().is_none() {
284            return;
285        }
286
287        let (n, _m) = match self.score_dims() {
288            Some(x) => x,
289            None => return,
290        };
291
292        self.crowding_distance_in_place(n);
293
294        // Pick top `keep` by crowding distance without sorting all n.
295        self.scratch.keep_idx.clear();
296        self.scratch.keep_idx.extend(0..n);
297
298        let dist = &self.scratch.dist;
299
300        // Partition so that [0..keep] are the best (in any order)
301        self.scratch
302            .keep_idx
303            .select_nth_unstable_by(keep, |&i, &j| {
304                dist[j].partial_cmp(&dist[i]).unwrap_or(Ordering::Equal)
305            });
306
307        self.scratch.keep_idx.truncate(keep);
308
309        let mut new_values = Vec::with_capacity(keep);
310        for &i in self.scratch.keep_idx.iter() {
311            new_values.push(Arc::clone(&self.values[i]));
312        }
313
314        self.values = new_values;
315        self.scratch.scores.clear();
316    }
317
318    #[inline]
319    fn score_dims(&self) -> Option<(usize, usize)> {
320        let n = self.values.len();
321
322        if n == 0 {
323            return None;
324        }
325
326        let first = self.values.iter().find_map(|v| v.score())?;
327        Some((n, first.len()))
328    }
329
330    fn ensure_score_matrix(&mut self) -> Option<()> {
331        let (n, m) = self.score_dims()?;
332
333        if m == 0 {
334            return None;
335        }
336
337        // If already built and size matches, keep it.
338        if self.scratch.scores.len() == n * m {
339            return Some(());
340        }
341
342        self.scratch.scores.resize(n * m, 0.0);
343        for (i, v) in self.values.iter().enumerate() {
344            let s = v.score()?;
345            if s.len() != m {
346                return None;
347            }
348
349            let row = &mut self.scratch.scores[i * m..i * m + m];
350            row.copy_from_slice(s.as_slice());
351        }
352
353        Some(())
354    }
355
356    fn crowding_distance_in_place(&mut self, n: usize) {
357        let (_, m) = match self.score_dims() {
358            Some(x) => x,
359            None => return,
360        };
361
362        if n == 0 || m == 0 {
363            return;
364        }
365
366        self.scratch.dist.clear();
367        self.scratch.dist.resize(n, 0.0);
368
369        self.scratch.order.clear();
370        self.scratch.order.extend(0..n);
371
372        for dim in 0..m {
373            let scores = &self.scratch.scores;
374            self.scratch.order.sort_unstable_by(|&i, &j| {
375                let a = scores[i * m + dim];
376                let b = scores[j * m + dim];
377                a.partial_cmp(&b).unwrap_or(Ordering::Equal)
378            });
379
380            let first_idx = self.scratch.order[0];
381            let last_idx = self.scratch.order[n - 1];
382            let min = self.scratch.scores[first_idx * m..first_idx * m + m][dim];
383            let max = self.scratch.scores[last_idx * m..last_idx * m + m][dim];
384            let range = max - min;
385
386            if !range.is_finite() || range == 0.0 {
387                continue;
388            }
389
390            self.scratch.dist[self.scratch.order[0]] = f32::INFINITY;
391            self.scratch.dist[self.scratch.order[n - 1]] = f32::INFINITY;
392
393            for k in 1..(n - 1) {
394                let prev_idx = self.scratch.order[k - 1];
395                let next_idx = self.scratch.order[k + 1];
396                let prev = self.scratch.scores[prev_idx * m..prev_idx * m + m][dim];
397                let next = self.scratch.scores[next_idx * m..next_idx * m + m][dim];
398
399                let contrib = (next - prev).abs() / range;
400                self.scratch.dist[self.scratch.order[k]] += contrib;
401            }
402        }
403    }
404}
405
406impl<T> Default for Front<T>
407where
408    T: Scored,
409{
410    fn default() -> Self {
411        Front::new(0..0, Objective::default())
412    }
413}
414
415/// Calculate the Shannon entropy of a set of scores in multidimensional space.
416/// The scores are discretized into a grid of bins, and the entropy is computed
417/// based on the distribution of scores across these bins. Higher entropy indicates
418/// a more diverse set of scores. This can be interpreted as a measure of how well
419/// the solutions are spread out in the objective space.
420///
421/// It works by:
422/// 1. Determining the min and max values for each objective dimension.
423/// 2. Mapping each score to a discrete bin index based on its normalized position
424///    within the min-max range for each dimension.
425/// 3. Counting the number of scores in each bin (cell).
426/// 4. Calculating the probabilities of each occupied bin and computing the
427///    Shannon entropy using these probabilities.
428/// 5. Optionally normalizing the entropy by the maximum possible entropy given
429///    the number of occupied bins and total scores.
430fn entropy_flat(scores: &[f32], n: usize, m: usize, bins_per_dim: usize) -> f32 {
431    if n == 0 || m == 0 || bins_per_dim == 0 {
432        return 0.0;
433    }
434
435    // mins/maxs per dim
436    let mut mins = vec![f32::INFINITY; m];
437    let mut maxs = vec![f32::NEG_INFINITY; m];
438
439    for i in 0..n {
440        let row = &scores[i * m..i * m + m];
441        for d in 0..m {
442            let x = row[d];
443            if x < mins[d] {
444                mins[d] = x;
445            }
446            if x > maxs[d] {
447                maxs[d] = x;
448            }
449        }
450    }
451
452    for d in 0..m {
453        if (maxs[d] - mins[d]).abs() < EPSILON {
454            maxs[d] = mins[d] + 1.0;
455        }
456    }
457
458    let mut cell_counts: HashMap<Vec<u8>, usize> = HashMap::new();
459
460    for i in 0..n {
461        let row = &scores[i * m..i * m + m];
462        let mut cell = Vec::with_capacity(m);
463
464        for d in 0..m {
465            let norm = (row[d] - mins[d]) / (maxs[d] - mins[d]); // [0,1]
466            let mut idx = (norm * bins_per_dim as f32).floor() as i32;
467            if idx < 0 {
468                idx = 0;
469            }
470            if idx >= bins_per_dim as i32 {
471                idx = bins_per_dim as i32 - 1;
472            }
473            cell.push(idx as u8);
474        }
475
476        *cell_counts.entry(cell).or_insert(0) += 1;
477    }
478
479    let n_f = n as f32;
480    let mut h = 0.0_f32;
481    for &count in cell_counts.values() {
482        let p = count as f32 / n_f;
483        if p > 0.0 {
484            h -= p * p.ln();
485        }
486    }
487
488    let k = cell_counts.len().min(n);
489    if k > 1 { h / (k as f32).ln() } else { 0.0 }
490}