radiate_core/objectives/
front.rs

1use crate::objectives::{Objective, Scored, pareto};
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4use std::{cmp::Ordering, hash::Hash, ops::Range, sync::Arc};
5
/// Number of bins passed to `pareto::entropy` by [`Front::entropy`].
const DEFAULT_ENTROPY_BINS: usize = 20;
7
/// Summary of a single [`Front::add_all`] call.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FrontAddResult {
    /// Number of offered items that were inserted into the front.
    pub added_count: usize,
    /// Number of existing members evicted because a newly inserted item dominated them.
    ///
    /// Members dropped by the size-limit trim are not included here.
    pub removed_count: usize,
    /// Number of member/candidate checks that produced a decision: the candidate was
    /// rejected (dominated by, or equal to, a member) or the member was marked for removal.
    /// Checks where neither side dominates the other are not counted.
    pub comparisons: usize,
}
13
/// A `Front<T>` is a collection of `T`s that are non-dominated with respect to each other.
///
/// This is useful for multi-objective optimization problems, where the goal is to find
/// the solutions that are not dominated by any other solution. The resulting set is
/// what is called the Pareto front.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Front<T>
where
    T: Scored,
{
    /// Current members of the front, each held behind an `Arc`.
    values: Vec<Arc<T>>,
    /// Size bounds of the front: once an insertion grows the front past `range.end`,
    /// it is trimmed to at most `range.start` members.
    range: Range<usize>,
    /// Objective handed to `pareto::dominance` when two members are compared.
    objective: Objective,
}
28
29impl<T> Front<T>
30where
31    T: Scored,
32{
33    pub fn new(range: Range<usize>, objective: Objective) -> Self {
34        Front {
35            values: Vec::new(),
36            range,
37            objective: objective.clone(),
38        }
39    }
40
41    pub fn range(&self) -> Range<usize> {
42        self.range.clone()
43    }
44
45    pub fn objective(&self) -> Objective {
46        self.objective.clone()
47    }
48
49    pub fn is_empty(&self) -> bool {
50        self.values.is_empty()
51    }
52
53    pub fn values(&self) -> &[Arc<T>] {
54        &self.values
55    }
56
57    pub fn crowding_distance(&self) -> Option<Vec<f32>> {
58        let scores = self
59            .values
60            .iter()
61            .filter_map(|s| s.score())
62            .collect::<Vec<_>>();
63
64        if scores.is_empty() {
65            return None;
66        }
67
68        Some(pareto::crowding_distance(&scores))
69    }
70
71    pub fn entropy(&self) -> Option<f32> {
72        let scores = self
73            .values
74            .iter()
75            .filter_map(|s| s.score())
76            .collect::<Vec<_>>();
77
78        if scores.is_empty() {
79            return None;
80        }
81
82        Some(pareto::entropy(&scores, DEFAULT_ENTROPY_BINS))
83    }
84
85    pub fn add_all(&mut self, items: &[T]) -> FrontAddResult
86    where
87        T: Eq + Hash + Clone + Send + Sync + 'static,
88    {
89        let mut updated = false;
90        let mut to_remove = Vec::new();
91        let mut added_count = 0;
92        let mut removed_count = 0;
93        let mut comparisons = 0;
94
95        for i in 0..items.len() {
96            let new_member = &items[i];
97            let mut is_dominated = true;
98
99            for existing_val in self.values.iter() {
100                let equals = new_member == existing_val.as_ref();
101                if self.dom_cmp(existing_val.as_ref(), new_member) == Ordering::Greater || equals {
102                    // If an existing value dominates the new value, return false
103                    is_dominated = false;
104                    comparisons += 1;
105                    break;
106                } else if self.dom_cmp(new_member, existing_val.as_ref()) == Ordering::Greater {
107                    // If the new value dominates an existing value, continue checking
108                    to_remove.push(Arc::clone(existing_val));
109                    comparisons += 1;
110                    continue;
111                }
112            }
113
114            if is_dominated {
115                updated = true;
116                self.values.push(Arc::new(new_member.clone()));
117                added_count += 1;
118                for rem in to_remove.drain(..) {
119                    self.values.retain(|x| x.as_ref() != rem.as_ref());
120                    removed_count += 1;
121                }
122            }
123
124            if updated && self.values.len() > self.range.end {
125                self.filter();
126            }
127
128            to_remove.clear();
129            updated = false;
130        }
131
132        FrontAddResult {
133            added_count,
134            removed_count,
135            comparisons,
136        }
137    }
138
139    fn dom_cmp(&self, one: &T, two: &T) -> Ordering {
140        let one_score = one.score();
141        let two_score = two.score();
142
143        if one_score.is_none() || two_score.is_none() {
144            return Ordering::Equal;
145        }
146
147        if let (Some(one), Some(two)) = (one_score, two_score) {
148            if pareto::dominance(one, two, &self.objective) {
149                return Ordering::Greater;
150            } else if pareto::dominance(two, one, &self.objective) {
151                return Ordering::Less;
152            }
153        }
154
155        Ordering::Equal
156    }
157
158    fn filter(&mut self) {
159        if let Some(crowding_distances) = self.crowding_distance() {
160            let mut enumerated = crowding_distances.iter().enumerate().collect::<Vec<_>>();
161
162            enumerated.sort_unstable_by(|a, b| b.1.partial_cmp(a.1).unwrap_or(Ordering::Equal));
163
164            self.values = enumerated
165                .iter()
166                .take(self.range.start)
167                .map(|(i, _)| Arc::clone(&self.values[*i]))
168                .collect::<Vec<Arc<T>>>();
169        }
170    }
171}
172
173impl<T> Default for Front<T>
174where
175    T: Scored,
176{
177    fn default() -> Self {
178        Front::new(0..0, Objective::default())
179    }
180}