radiate_core/fitness/
composite.rs1use crate::{BatchFitnessFunction, FitnessFunction, Score};
2use std::sync::Arc;
3
/// Smallest denominator used when normalizing a weighted sum.
///
/// Despite the name, this is applied to the *sum of weights* (not to a score):
/// the evaluators in this file divide by `total_weight.max(MIN_SCORE)`, so a
/// composite with no objectives (weight sum `0.0`) never divides by zero.
const MIN_SCORE: f32 = 1e-8;
5
6pub struct CompositeFitnessFn<T, S> {
38 objectives: Vec<Arc<dyn for<'a> FitnessFunction<&'a T, S>>>,
39 weights: Vec<f32>,
40}
41
42impl<T, S> CompositeFitnessFn<T, S>
43where
44 S: Into<Score> + Clone,
45{
46 pub fn new() -> Self {
47 Self {
48 objectives: Vec::new(),
49 weights: Vec::new(),
50 }
51 }
52
53 pub fn add_weighted_fn(
54 mut self,
55 fitness_fn: impl for<'a> FitnessFunction<&'a T, S> + 'static,
56 weight: f32,
57 ) -> Self
58 where
59 S: Into<Score>,
60 {
61 self.objectives.push(Arc::new(fitness_fn));
62 self.weights.push(weight);
63 self
64 }
65
66 pub fn add_fitness_fn(
67 mut self,
68 fitness_fn: impl for<'a> FitnessFunction<&'a T, S> + 'static,
69 ) -> Self
70 where
71 S: Into<Score>,
72 {
73 self.objectives.push(Arc::new(fitness_fn));
74 self.weights.push(1.0);
75 self
76 }
77}
78
79impl<T> FitnessFunction<T> for CompositeFitnessFn<T, f32> {
90 fn evaluate(&self, individual: T) -> f32 {
91 let mut total_score = 0.0;
92 let mut total_weight = 0.0;
93 for (objective, weight) in self.objectives.iter().zip(&self.weights) {
94 let score = objective.evaluate(&individual);
95 total_score += score * weight;
96 total_weight += weight;
97 }
98
99 total_score / total_weight.max(MIN_SCORE)
100 }
101}
102
103impl<T> BatchFitnessFunction<T> for CompositeFitnessFn<T, f32> {
108 fn evaluate(&self, individuals: &[T]) -> Vec<f32> {
109 let mut results = Vec::with_capacity(individuals.len());
110
111 for individual in individuals {
112 let mut total_score = 0.0;
113 let mut total_weight = 0.0;
114
115 for (objective, weight) in self.objectives.iter().zip(&self.weights) {
116 let score = objective.evaluate(individual);
117 total_score += score * weight;
118 total_weight += weight;
119 }
120
121 results.push(total_score / total_weight.max(MIN_SCORE));
122 }
123
124 results
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fitness::FitnessFunction;

    /// Absolute tolerance used when comparing `f32` fitness values.
    const TOLERANCE: f32 = 1e-6;

    /// Mock objective: one tenth of the individual's value.
    fn mock_accuracy_fn(individual: &i32) -> f32 {
        *individual as f32 * 0.1
    }

    /// Mock objective: a penalty of 0.05 per unit of the individual's value.
    fn mock_complexity_fn(individual: &i32) -> f32 {
        // Parenthesized so the cast visibly happens before the negation.
        -(*individual as f32) * 0.05
    }

    /// Composite used by most tests: 0.7 * accuracy + 0.3 * complexity.
    fn weighted_composite() -> CompositeFitnessFn<i32, f32> {
        CompositeFitnessFn::new()
            .add_weighted_fn(mock_accuracy_fn, 0.7)
            .add_weighted_fn(mock_complexity_fn, 0.3)
    }

    /// Asserts that `actual` is within `TOLERANCE` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_add_weighted_fn() {
        let composite = weighted_composite();

        assert_eq!(composite.objectives.len(), 2);
        assert_eq!(composite.weights, vec![0.7, 0.3]);
    }

    #[test]
    fn test_add_fitness_fn() {
        let composite = CompositeFitnessFn::new()
            .add_fitness_fn(mock_accuracy_fn)
            .add_fitness_fn(mock_complexity_fn);

        assert_eq!(composite.objectives.len(), 2);
        assert_eq!(composite.weights, vec![1.0, 1.0]);
    }

    #[test]
    fn test_evaluate_single() {
        let composite = weighted_composite();

        let individual = 10;
        let fitness = FitnessFunction::evaluate(&composite, individual);

        // (1.0 * 0.7 + -0.5 * 0.3) / (0.7 + 0.3) = 0.55
        assert_close(fitness, 0.55);
    }

    #[test]
    fn test_evaluate_batch() {
        let composite = weighted_composite();

        let individuals = vec![10, 20, 30];
        let fitness_scores = BatchFitnessFunction::evaluate(&composite, &individuals);

        assert_eq!(fitness_scores.len(), 3);

        // Every element is checked, not just the first one.
        assert_close(fitness_scores[0], 0.55);
        assert_close(fitness_scores[1], 1.1);
        assert_close(fitness_scores[2], 1.65);
    }

    #[test]
    fn test_batch_matches_single() {
        let composite = weighted_composite();

        let individuals = vec![-5, 0, 7, 42];
        let batch_scores = BatchFitnessFunction::evaluate(&composite, &individuals);

        assert_eq!(batch_scores.len(), individuals.len());
        for (individual, batch_score) in individuals.iter().zip(&batch_scores) {
            let single_score = FitnessFunction::evaluate(&composite, *individual);
            assert_close(*batch_score, single_score);
        }
    }

    #[test]
    fn test_empty_composite() {
        let composite = CompositeFitnessFn::new();
        let individual = 10;
        let fitness = FitnessFunction::evaluate(&composite, individual);

        // No objectives: 0.0 divided by the clamped weight sum is still 0.0.
        assert_eq!(fitness, 0.0);
    }

    #[test]
    fn test_empty_composite_batch() {
        let composite: CompositeFitnessFn<i32, f32> = CompositeFitnessFn::new();

        let individuals = vec![1, 2, 3];
        let fitness_scores = BatchFitnessFunction::evaluate(&composite, &individuals);

        assert_eq!(fitness_scores, vec![0.0, 0.0, 0.0]);
    }
}