// radiate_core/objectives/front.rs
1use crate::objectives::{Objective, Scored, pareto};
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::{cmp::Ordering, ops::Range, sync::Arc};
6
// Number of bins per objective dimension when discretizing scores for entropy.
const DEFAULT_ENTROPY_BINS: usize = 20;
// Tolerance below which an objective's min-max spread is treated as degenerate.
const EPSILON: f32 = 1e-10;
9
/// Reusable scratch buffers so repeated front maintenance avoids reallocating.
#[derive(Clone, Default)]
struct FrontScratch {
    // Indices of members scheduled for removal (sorted + deduped before use).
    remove: Vec<usize>,
    // Indices of members to keep after size filtering.
    keep_idx: Vec<usize>,
    // Flattened n*m score matrix, row-major (one row per member).
    scores: Vec<f32>,
    // Per-member crowding distance, parallel to `values`.
    dist: Vec<f32>,
    // Index permutation reused for per-dimension sorts.
    order: Vec<usize>,
}
18
/// Summary statistics returned by a batch insertion into a [`Front`].
#[derive(Debug)]
pub struct FrontAddResult {
    // Candidates accepted into the front.
    pub added_count: usize,
    // Existing members evicted because a new candidate dominated them.
    pub removed_count: usize,
    // Dominance comparisons performed during the batch.
    pub comparisons: usize,
    // Number of times the size-cap filter ran.
    pub filter_count: usize,
    // Front size after the whole batch was processed.
    pub size: usize,
}
27
/// A bounded Pareto front: a set of mutually non-dominated values ranked under
/// an objective. When the front grows past `range.end` members, it filters
/// itself back down to `range.start` by crowding distance.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Front<T>
where
    T: Scored,
{
    // Current members; `Arc` keeps filtering cheap (pointer clones only).
    values: Vec<Arc<T>>,
    // Soft size bounds: exceed `end` -> shrink down to `start`.
    range: Range<usize>,
    objective: Objective,

    // Transient work buffers; never serialized, rebuilt on demand.
    #[cfg_attr(feature = "serde", serde(skip))]
    scratch: FrontScratch,
}
41
42impl<T> Front<T>
43where
44    T: Scored,
45{
46    pub fn new(range: Range<usize>, objective: Objective) -> Self {
47        Front {
48            values: Vec::new(),
49            range,
50            objective,
51            scratch: FrontScratch::default(),
52        }
53    }
54
    /// Number of members currently on the front.
    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// The configured `min..max` size bounds of the front.
    pub fn range(&self) -> Range<usize> {
        self.range.clone()
    }

    /// The objective the front ranks members against.
    pub fn objective(&self) -> Objective {
        self.objective.clone()
    }

    /// `true` when the front holds no members.
    pub fn is_empty(&self) -> bool {
        self.values.is_empty()
    }

    /// Borrow of the current members.
    pub fn values(&self) -> &[Arc<T>] {
        &self.values
    }
74
75    pub fn crowding_distance(&mut self) -> Option<&[f32]> {
76        self.ensure_score_matrix()?;
77        let (n, _) = self.score_dims()?;
78        self.crowding_distance_in_place(n);
79
80        Some(&self.scratch.dist[..n])
81    }
82
83    pub fn entropy(&mut self) -> Option<f32> {
84        self.ensure_score_matrix()?;
85        let (n, m) = self.score_dims()?;
86
87        Some(entropy_flat(
88            &self.scratch.scores,
89            n,
90            m,
91            DEFAULT_ENTROPY_BINS,
92        ))
93    }
94
    /// Attempts to insert every item into the Pareto front.
    ///
    /// Each candidate is compared against all current members: it is rejected
    /// if it is an exact duplicate or is dominated by an existing member, and
    /// any existing members it dominates are removed. When the front grows
    /// past `range.end`, `fast_filter` shrinks it back down. Returns counters
    /// describing what happened across the whole batch.
    pub fn add_all(&mut self, items: Vec<T>) -> FrontAddResult
    where
        T: Eq + Clone + Send + Sync + 'static,
    {
        let mut added_count = 0;
        let mut removed_count = 0;
        let mut comparisons = 0;
        let mut filter_count = 0;

        for new_member in items.into_iter() {
            self.scratch.remove.clear();

            // Decide accept/reject without mutating self.values
            let mut accept = true;

            for (idx, existing) in self.values.iter().enumerate() {
                // Exact duplicates are never re-added.
                if existing.as_ref() == &new_member {
                    accept = false;
                    break;
                }

                // dominance checks
                match self.dom_cmp(existing.as_ref(), &new_member) {
                    Ordering::Greater => {
                        // existing dominates new -> reject
                        accept = false;
                        comparisons += 1;
                        break;
                    }
                    Ordering::Less => {
                        // new dominates existing -> mark for removal
                        self.scratch.remove.push(idx);
                        comparisons += 1;
                    }
                    Ordering::Equal => comparisons += 1,
                }
            }

            if !accept {
                continue;
            }

            // Remove dominated existing values efficiently (swap_remove).
            // Iterating the (sorted, deduped) indices in *descending* order is
            // what makes swap_remove safe: the element swapped into the hole
            // always comes from an index larger than any pending removal index,
            // so no marked index is ever invalidated.
            if !self.scratch.remove.is_empty() {
                self.scratch.remove.sort_unstable();
                self.scratch.remove.dedup();

                for &idx in self.scratch.remove.iter().rev() {
                    self.values.swap_remove(idx);
                    removed_count += 1;
                }
            }

            self.values.push(Arc::new(new_member));
            added_count += 1;

            // Filter if we exceed max
            if self.values.len() > self.range.end {
                self.fast_filter();
                filter_count += 1;
            }

            // Invalidate the cached score matrix (membership changed)
            self.scratch.scores.clear();
        }

        FrontAddResult {
            added_count,
            removed_count,
            comparisons,
            filter_count,
            size: self.values.len(),
        }
    }
170
171    /// Remove points with crowding distance in the top `trim` fraction.
172    /// Example: trim=0.02 removes the top 2% most isolated points.
173    #[inline]
174    pub fn remove_outliers(&mut self, trim: f32) -> Option<usize> {
175        if self.values.len() < 4 {
176            return None;
177        }
178
179        let trim = trim.clamp(0.0, 0.5);
180        if trim == 0.0 {
181            return None;
182        }
183
184        if self.ensure_score_matrix().is_none() {
185            return None;
186        }
187
188        let (n, _m) = match self.score_dims() {
189            Some(x) => x,
190            None => return None,
191        };
192
193        self.crowding_distance_in_place(n);
194
195        let drop = ((n as f32) * trim).floor() as usize;
196        if drop == 0 {
197            return None;
198        }
199
200        // We want to drop the *largest* distances (most isolated).
201        self.scratch.order.clear();
202        self.scratch.order.extend(0..n);
203
204        let dist = &self.scratch.dist;
205        self.scratch.order.sort_unstable_by(|&i, &j| {
206            let a = dist[i];
207            let b = dist[j];
208
209            match (a.is_infinite(), b.is_infinite()) {
210                (true, true) => Ordering::Equal,
211                (true, false) => Ordering::Less,
212                (false, true) => Ordering::Greater,
213                _ => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
214            }
215        });
216
217        self.scratch.remove.clear();
218        self.scratch
219            .remove
220            .extend(self.scratch.order.iter().take(drop).copied());
221
222        self.scratch.remove.sort_unstable();
223        self.scratch.remove.dedup();
224        let removed = self.scratch.remove.len();
225        for &idx in self.scratch.remove.iter().rev() {
226            self.values.swap_remove(idx);
227        }
228
229        self.scratch.scores.clear();
230        Some(removed)
231    }
232
233    #[inline]
234    fn dom_cmp(&self, one: &T, two: &T) -> Ordering {
235        let one_score = one.score();
236        let two_score = two.score();
237
238        if one_score.is_none() || two_score.is_none() {
239            return Ordering::Equal;
240        }
241
242        let (a, b) = (one_score.unwrap(), two_score.unwrap());
243
244        if pareto::dominance(a, b, &self.objective) {
245            Ordering::Greater
246        } else if pareto::dominance(b, a, &self.objective) {
247            Ordering::Less
248        } else {
249            Ordering::Equal
250        }
251    }
252
253    pub fn fronts(&mut self) -> Vec<Front<T>>
254    where
255        T: Clone + Eq + Send + Sync + 'static,
256    {
257        let mut fronts: Vec<Front<T>> = Vec::new();
258        for member in self.values.iter() {
259            let mut updated = false;
260
261            for front in fronts.iter_mut() {
262                let to_insert = (*(*member)).clone();
263                let result = front.add_all(vec![to_insert]);
264
265                if result.added_count > 0 {
266                    updated = true;
267                    break;
268                }
269            }
270
271            if !updated {
272                let mut new_front = Front::new(self.range.clone(), self.objective.clone());
273                let to_insert = (*(*member)).clone();
274                new_front.add_all(vec![to_insert]);
275                fronts.push(new_front);
276            }
277        }
278
279        fronts
280    }
281
    /// Shrinks the front down to `range.start` members, keeping the members
    /// with the *largest* crowding distances (i.e. the most spread-out
    /// subset). No-op when the front already fits, `range.start` is 0, or a
    /// score matrix cannot be built.
    fn fast_filter(&mut self) {
        let keep = self.range.start.min(self.values.len());
        if keep == 0 || self.values.len() <= keep {
            return;
        }

        // Build score matrix + crowding distances into scratch
        if self.ensure_score_matrix().is_none() {
            return;
        }

        let (n, _m) = match self.score_dims() {
            Some(x) => x,
            None => return,
        };

        self.crowding_distance_in_place(n);

        // Pick top `keep` by crowding distance without sorting all n.
        self.scratch.keep_idx.clear();
        self.scratch.keep_idx.extend(0..n);

        let dist = &self.scratch.dist;

        // Partition so that [0..keep] are the best (in any order).
        // The comparator is descending on distance, so the front of the slice
        // ends up holding the least-crowded (most diverse) members.
        self.scratch
            .keep_idx
            .select_nth_unstable_by(keep, |&i, &j| {
                dist[j].partial_cmp(&dist[i]).unwrap_or(Ordering::Equal)
            });

        self.scratch.keep_idx.truncate(keep);

        let mut new_values = Vec::with_capacity(keep);
        for &i in self.scratch.keep_idx.iter() {
            new_values.push(Arc::clone(&self.values[i]));
        }

        self.values = new_values;
        // Membership changed; force a matrix rebuild on next use.
        self.scratch.scores.clear();
    }
323
324    #[inline]
325    fn score_dims(&self) -> Option<(usize, usize)> {
326        let n = self.values.len();
327
328        if n == 0 {
329            return None;
330        }
331
332        let first = self.values.iter().find_map(|v| v.score())?;
333        Some((n, first.len()))
334    }
335
336    fn ensure_score_matrix(&mut self) -> Option<()> {
337        let (n, m) = self.score_dims()?;
338
339        if m == 0 {
340            return None;
341        }
342
343        // If already built and size matches, keep it.
344        if self.scratch.scores.len() == n * m {
345            return Some(());
346        }
347
348        self.scratch.scores.resize(n * m, 0.0);
349        for (i, v) in self.values.iter().enumerate() {
350            let s = v.score()?;
351            if s.len() != m {
352                return None;
353            }
354
355            let row = &mut self.scratch.scores[i * m..i * m + m];
356            row.copy_from_slice(s.as_slice());
357        }
358
359        Some(())
360    }
361
362    fn crowding_distance_in_place(&mut self, n: usize) {
363        let (_, m) = match self.score_dims() {
364            Some(x) => x,
365            None => return,
366        };
367
368        if n == 0 || m == 0 {
369            return;
370        }
371
372        self.scratch.dist.clear();
373        self.scratch.dist.resize(n, 0.0);
374
375        self.scratch.order.clear();
376        self.scratch.order.extend(0..n);
377
378        for dim in 0..m {
379            let scores = &self.scratch.scores;
380            self.scratch.order.sort_unstable_by(|&i, &j| {
381                let a = scores[i * m + dim];
382                let b = scores[j * m + dim];
383                a.partial_cmp(&b).unwrap_or(Ordering::Equal)
384            });
385
386            let first_idx = self.scratch.order[0];
387            let last_idx = self.scratch.order[n - 1];
388            let min = self.scratch.scores[first_idx * m..first_idx * m + m][dim];
389            let max = self.scratch.scores[last_idx * m..last_idx * m + m][dim];
390            let range = max - min;
391
392            if !range.is_finite() || range == 0.0 {
393                continue;
394            }
395
396            self.scratch.dist[self.scratch.order[0]] = f32::INFINITY;
397            self.scratch.dist[self.scratch.order[n - 1]] = f32::INFINITY;
398
399            for k in 1..(n - 1) {
400                let prev_idx = self.scratch.order[k - 1];
401                let next_idx = self.scratch.order[k + 1];
402                let prev = self.scratch.scores[prev_idx * m..prev_idx * m + m][dim];
403                let next = self.scratch.scores[next_idx * m..next_idx * m + m][dim];
404
405                let contrib = (next - prev).abs() / range;
406                self.scratch.dist[self.scratch.order[k]] += contrib;
407            }
408        }
409    }
410}
411
impl<T> Default for Front<T>
where
    T: Scored,
{
    /// An empty front with a degenerate `0..0` size range.
    fn default() -> Self {
        Front::new(0..0, Objective::default())
    }
}
420
421/// Calculate the Shannon entropy of a set of scores in multidimensional space.
422/// The scores are discretized into a grid of bins, and the entropy is computed
423/// based on the distribution of scores across these bins. Higher entropy indicates
424/// a more diverse set of scores. This can be interpreted as a measure of how well
425/// the solutions are spread out in the objective space.
426///
427/// It works by:
428/// 1. Determining the min and max values for each objective dimension.
429/// 2. Mapping each score to a discrete bin index based on its normalized position
430///    within the min-max range for each dimension.
431/// 3. Counting the number of scores in each bin (cell).
432/// 4. Calculating the probabilities of each occupied bin and computing the
433///    Shannon entropy using these probabilities.
434/// 5. Optionally normalizing the entropy by the maximum possible entropy given
435///    the number of occupied bins and total scores.
436fn entropy_flat(scores: &[f32], n: usize, m: usize, bins_per_dim: usize) -> f32 {
437    if n == 0 || m == 0 || bins_per_dim == 0 {
438        return 0.0;
439    }
440
441    // mins/maxs per dim
442    let mut mins = vec![f32::INFINITY; m];
443    let mut maxs = vec![f32::NEG_INFINITY; m];
444
445    for i in 0..n {
446        let row = &scores[i * m..i * m + m];
447        for d in 0..m {
448            let x = row[d];
449            if x < mins[d] {
450                mins[d] = x;
451            }
452            if x > maxs[d] {
453                maxs[d] = x;
454            }
455        }
456    }
457
458    for d in 0..m {
459        if (maxs[d] - mins[d]).abs() < EPSILON {
460            maxs[d] = mins[d] + 1.0;
461        }
462    }
463
464    let mut cell_counts: HashMap<Vec<u8>, usize> = HashMap::new();
465
466    for i in 0..n {
467        let row = &scores[i * m..i * m + m];
468        let mut cell = Vec::with_capacity(m);
469
470        for d in 0..m {
471            let norm = (row[d] - mins[d]) / (maxs[d] - mins[d]); // [0,1]
472            let mut idx = (norm * bins_per_dim as f32).floor() as i32;
473            if idx < 0 {
474                idx = 0;
475            }
476            if idx >= bins_per_dim as i32 {
477                idx = bins_per_dim as i32 - 1;
478            }
479            cell.push(idx as u8);
480        }
481
482        *cell_counts.entry(cell).or_insert(0) += 1;
483    }
484
485    let n_f = n as f32;
486    let mut h = 0.0_f32;
487    for &count in cell_counts.values() {
488        let p = count as f32 / n_f;
489        if p > 0.0 {
490            h -= p * p.ln();
491        }
492    }
493
494    let k = cell_counts.len().min(n);
495    if k > 1 { h / (k as f32).ln() } else { 0.0 }
496}