use crate::{BatchFitnessFunction, FitnessFunction, Score};
use std::sync::Arc;
/// Lower bound applied to the normalising weight total before dividing, so that a
/// composite with no objectives (or only zero weights) never divides by zero.
const MIN_SCORE: f32 = 1e-8;

/// A fitness function assembled from several weighted objective functions.
///
/// Objectives and weights are stored in parallel vectors: the weight at index `i`
/// belongs to the objective at index `i`. Every objective is evaluated against a
/// borrowed individual (`&T`) and yields a score of type `S`. The weighted combination
/// itself is provided for `S = f32` through the `FitnessFunction` and
/// `BatchFitnessFunction` implementations.
pub struct CompositeFitnessFn<T, S> {
    /// Objective functions, kept in insertion order.
    objectives: Vec<Arc<dyn for<'a> FitnessFunction<&'a T, S>>>,
    /// Weight for each entry of `objectives`, at the same index.
    weights: Vec<f32>,
}

impl<T, S> Default for CompositeFitnessFn<T, S> {
    /// Creates a composite with no objectives and no weights.
    fn default() -> Self {
        Self {
            objectives: Vec::new(),
            weights: Vec::new(),
        }
    }
}

impl<T, S> CompositeFitnessFn<T, S>
where
    // Objectives must produce a score type convertible into the crate's `Score`.
    S: Into<Score>,
{
    /// Creates an empty composite; add objectives with [`add_weighted_fn`] or
    /// [`add_fitness_fn`].
    ///
    /// [`add_weighted_fn`]: Self::add_weighted_fn
    /// [`add_fitness_fn`]: Self::add_fitness_fn
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `fitness_fn` with the given `weight` and returns the updated composite.
    ///
    /// The builder consumes `self`, so the returned value must be kept; dropping it
    /// discards every objective added so far.
    #[must_use]
    pub fn add_weighted_fn(
        mut self,
        fitness_fn: impl for<'a> FitnessFunction<&'a T, S> + 'static,
        weight: f32,
    ) -> Self {
        self.objectives.push(Arc::new(fitness_fn));
        self.weights.push(weight);
        self
    }

    /// Appends `fitness_fn` with a weight of `1.0` and returns the updated composite.
    #[must_use]
    pub fn add_fitness_fn(
        self,
        fitness_fn: impl for<'a> FitnessFunction<&'a T, S> + 'static,
    ) -> Self {
        self.add_weighted_fn(fitness_fn, 1.0)
    }
}
impl<T> CompositeFitnessFn<T, f32> {
    /// Combines every objective's score for `individual` into a single value.
    ///
    /// Each objective's score is multiplied by its weight and the products are summed.
    /// That sum is divided by the sum of the *absolute* weights, clamped below at
    /// `MIN_SCORE` so an empty composite (or all-zero weights) yields `0.0` instead of
    /// dividing by zero.
    ///
    /// For non-negative weights this is the ordinary weighted mean. Using the absolute
    /// sum keeps the divisor positive and proportional to the weights' magnitude.
    /// Otherwise, mixed-sign weights whose signed sum is zero or negative would hit the
    /// `MIN_SCORE` clamp and inflate the result by roughly `1 / MIN_SCORE`. Because the
    /// divisor is a positive constant, the ranking of individuals equals the ranking by
    /// the raw weighted sum.
    fn weighted_score(&self, individual: &T) -> f32 {
        let (weighted_sum, weight_norm) = self
            .objectives
            .iter()
            .zip(&self.weights)
            .fold((0.0_f32, 0.0_f32), |(sum, norm), (objective, &weight)| {
                (
                    sum + objective.evaluate(individual) * weight,
                    norm + weight.abs(),
                )
            });
        weighted_sum / weight_norm.max(MIN_SCORE)
    }
}

impl<T> FitnessFunction<T> for CompositeFitnessFn<T, f32> {
    /// Scores one individual as the normalised weighted sum of all objectives.
    fn evaluate(&self, individual: T) -> f32 {
        self.weighted_score(&individual)
    }
}

impl<T> BatchFitnessFunction<T> for CompositeFitnessFn<T, f32> {
    /// Scores each individual independently with the same rule as the single-individual
    /// path, returning results in input order.
    fn evaluate(&self, individuals: Vec<T>) -> Vec<f32> {
        individuals
            .into_iter()
            .map(|individual| self.weighted_score(&individual))
            .collect()
    }
}
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's `FitnessFunction` / `BatchFitnessFunction`
    // imports, so no separate trait import is needed.
    use super::*;

    fn mock_accuracy_fn(individual: &i32) -> f32 {
        *individual as f32 * 0.1
    }

    fn mock_complexity_fn(individual: &i32) -> f32 {
        -*individual as f32 * 0.05
    }

    /// Asserts two scores agree to within single-precision rounding noise.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-6,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_add_weighted_fn() {
        let composite = CompositeFitnessFn::new()
            .add_weighted_fn(mock_accuracy_fn, 0.7)
            .add_weighted_fn(mock_complexity_fn, 0.3);
        assert_eq!(composite.objectives.len(), 2);
        assert_eq!(composite.weights, vec![0.7, 0.3]);
    }

    #[test]
    fn test_add_fitness_fn() {
        let composite = CompositeFitnessFn::new()
            .add_fitness_fn(mock_accuracy_fn)
            .add_fitness_fn(mock_complexity_fn);
        assert_eq!(composite.objectives.len(), 2);
        assert_eq!(composite.weights, vec![1.0, 1.0]);
    }

    #[test]
    fn test_evaluate_single() {
        let composite = CompositeFitnessFn::new()
            .add_weighted_fn(mock_accuracy_fn, 0.7)
            .add_weighted_fn(mock_complexity_fn, 0.3);
        let individual = 10;
        let fitness = FitnessFunction::evaluate(&composite, individual);
        assert_close(fitness, 0.55);
    }

    #[test]
    fn test_evaluate_batch() {
        let composite = CompositeFitnessFn::new()
            .add_weighted_fn(mock_accuracy_fn, 0.7)
            .add_weighted_fn(mock_complexity_fn, 0.3);
        let individuals = vec![10, 20, 30];
        let fitness_scores = BatchFitnessFunction::evaluate(&composite, individuals);
        assert_eq!(fitness_scores.len(), 3);
        // Every entry is checked, not just the first, so ordering is pinned too.
        for (score, expected) in fitness_scores.iter().zip([0.55, 1.1, 1.65]) {
            assert_close(*score, expected);
        }
    }

    #[test]
    fn test_batch_matches_single() {
        let composite = CompositeFitnessFn::new()
            .add_weighted_fn(mock_accuracy_fn, 0.7)
            .add_weighted_fn(mock_complexity_fn, 0.3);
        let batch = BatchFitnessFunction::evaluate(&composite, vec![1, 7, 42]);
        assert_eq!(batch.len(), 3);
        for (score, individual) in batch.into_iter().zip([1, 7, 42]) {
            assert_close(score, FitnessFunction::evaluate(&composite, individual));
        }
    }

    #[test]
    fn test_unit_weights_average() {
        let composite = CompositeFitnessFn::new()
            .add_fitness_fn(mock_accuracy_fn)
            .add_fitness_fn(mock_complexity_fn);
        // (1.0 + -0.5) / 2
        assert_close(FitnessFunction::evaluate(&composite, 10), 0.25);
    }

    #[test]
    fn test_empty_composite() {
        let composite = CompositeFitnessFn::new();
        let individual = 10;
        let fitness = FitnessFunction::evaluate(&composite, individual);
        assert_eq!(fitness, 0.0);
    }

    #[test]
    fn test_empty_batch() {
        let composite = CompositeFitnessFn::new().add_fitness_fn(mock_accuracy_fn);
        let individuals: Vec<i32> = Vec::new();
        assert!(BatchFitnessFunction::evaluate(&composite, individuals).is_empty());
    }
}