use super::{Chromosome, Codec, Genotype, Score};
use crate::{Objective, error::RadiateResult};
use radiate_error::{RadiateError, radiate_err};
use std::sync::Arc;
/// An optimization problem over genotypes built from chromosome type `C`, whose decoded form
/// (phenotype) is of type `T`.
///
/// The `Send + Sync` supertraits allow a problem to be shared between threads.
pub trait Problem<C: Chromosome, T>: Send + Sync {
    /// Produces a new [`Genotype`].
    fn encode(&self) -> Genotype<C>;

    /// Converts `genotype` into its phenotype `T`.
    fn decode(&self, genotype: &Genotype<C>) -> T;

    /// Computes the fitness [`Score`] of a single individual.
    fn eval(&self, individual: &Genotype<C>) -> Result<Score, RadiateError>;

    /// Scores every individual in `individuals`, returning the scores in input order.
    ///
    /// The default implementation calls [`Problem::eval`] on each individual in turn and
    /// returns the first error encountered, if any.
    fn eval_batch(&self, individuals: &[Genotype<C>]) -> Result<Vec<Score>, RadiateError> {
        let mut results = Vec::with_capacity(individuals.len());
        for genotype in individuals {
            results.push(self.eval(genotype)?);
        }
        Ok(results)
    }
}
/// Fitness function that scores a decoded phenotype `T`.
type FitnessFn<T> = dyn Fn(T) -> Score + Send + Sync;

/// Fitness function that scores a [`Genotype`] directly, without decoding it first.
type RawFitnessFn<C> = dyn Fn(&Genotype<C>) -> Score + Send + Sync;

/// A [`Problem`] that evaluates one individual per fitness call.
///
/// When both fitness functions are set, `raw_fitness_fn` is used and `fitness_fn` is ignored.
pub struct EngineProblem<C: Chromosome, T> {
    /// Objective that every produced score is validated against.
    pub objective: Objective,
    /// Codec used to encode new genotypes and decode genotypes into phenotypes.
    pub codec: Arc<dyn Codec<C, T>>,
    /// Scores decoded phenotypes; consulted only when `raw_fitness_fn` is `None`.
    pub fitness_fn: Option<Arc<FitnessFn<T>>>,
    /// Scores genotypes directly; takes precedence over `fitness_fn` when set.
    pub raw_fitness_fn: Option<Arc<RawFitnessFn<C>>>,
}
impl<C: Chromosome, T> Problem<C, T> for EngineProblem<C, T> {
    /// Delegates genotype creation to the codec.
    fn encode(&self) -> Genotype<C> {
        self.codec.encode()
    }

    /// Delegates phenotype decoding to the codec.
    fn decode(&self, genotype: &Genotype<C>) -> T {
        self.codec.decode(genotype)
    }

    /// Scores one individual and checks the score against `self.objective`.
    ///
    /// `raw_fitness_fn` is preferred. Otherwise the genotype is decoded and passed to
    /// `fitness_fn`.
    ///
    /// # Errors
    /// Returns an evaluation error if neither fitness function is set, or if the produced
    /// score is rejected by the objective's `validate`.
    fn eval(&self, individual: &Genotype<C>) -> RadiateResult<Score> {
        let score = match (&self.raw_fitness_fn, &self.fitness_fn) {
            (Some(raw_fn), _) => raw_fn(individual),
            (None, Some(fitness_fn)) => fitness_fn(self.decode(individual)),
            (None, None) => {
                return Err(radiate_err!(
                    Evaluation: "No fitness function defined for EngineProblem"
                ));
            }
        };

        if !self.objective.validate(&score) {
            return Err(radiate_err!(
                Evaluation: "Invalid fitness score {:?} for objective {:?}",
                score,
                self.objective
            ));
        }

        Ok(score)
    }
}
// SAFETY: NOTE(review) — these impls assert that `EngineProblem<C, T>` may be sent and shared
// across threads for every `C` and `T`, bypassing the compiler's auto-trait check.
// Both fitness closures are `Send + Sync` by their type aliases (`FitnessFn`, `RawFitnessFn`),
// so their `Arc` wrappers are already `Send + Sync` on their own. The auto traits could
// therefore only fail to hold because of `dyn Codec<C, T>` or `Objective`, whose bounds are
// not visible here. TODO confirm:
// - if `Codec` declares `Send + Sync` supertraits, these impls are redundant and could be
//   removed;
// - if it does not, they are sound only as long as every codec stored here is actually
//   thread-safe.
unsafe impl<C: Chromosome, T> Send for EngineProblem<C, T> {}
unsafe impl<C: Chromosome, T> Sync for EngineProblem<C, T> {}
/// Fitness function that scores a whole batch of decoded phenotypes in one call.
type BatchFitnessFn<T> = dyn Fn(Vec<T>) -> Vec<Score> + Send + Sync;

/// Fitness function that scores a whole batch of genotypes directly, without decoding them.
type RawBatchFitnessFn<C> = dyn Fn(Vec<&Genotype<C>>) -> Vec<Score> + Send + Sync;

/// A [`Problem`] whose fitness functions receive many individuals in a single call.
///
/// When both batch functions are set, `raw_batch_fitness_fn` is used and `batch_fitness_fn`
/// is ignored. NOTE(review): batch functions are presumably expected to return exactly one
/// score per input, in input order — confirm with callers.
pub struct BatchEngineProblem<C: Chromosome, T> {
    /// Objective that produced scores are validated against.
    pub objective: Objective,
    /// Codec used to encode new genotypes and decode genotypes into phenotypes.
    pub codec: Arc<dyn Codec<C, T>>,
    /// Scores a batch of decoded phenotypes; consulted only when `raw_batch_fitness_fn` is `None`.
    pub batch_fitness_fn: Option<Arc<BatchFitnessFn<T>>>,
    /// Scores a batch of genotypes directly; takes precedence over `batch_fitness_fn` when set.
    pub raw_batch_fitness_fn: Option<Arc<RawBatchFitnessFn<C>>>,
}
impl<C: Chromosome, T> Problem<C, T> for BatchEngineProblem<C, T> {
    /// Delegates genotype creation to the codec.
    fn encode(&self) -> Genotype<C> {
        self.codec.encode()
    }

    /// Delegates phenotype decoding to the codec.
    fn decode(&self, genotype: &Genotype<C>) -> T {
        self.codec.decode(genotype)
    }

    /// Scores a single individual by running it through the batch path as a batch of one.
    ///
    /// Routing through [`Problem::eval_batch`] guarantees that `eval` and `eval_batch` apply
    /// the same objective validation. Previously `eval` skipped validation, and it panicked
    /// when the batch function returned no scores.
    ///
    /// # Errors
    /// Same conditions as [`Problem::eval_batch`].
    fn eval(&self, individual: &Genotype<C>) -> RadiateResult<Score> {
        self.eval_batch(std::slice::from_ref(individual))?
            .into_iter()
            .next()
            // eval_batch already enforces one score per individual; this is purely defensive.
            .ok_or_else(|| {
                radiate_err!(Evaluation: "Batch fitness function returned no score for individual")
            })
    }

    /// Scores all `individuals` with a single call to the configured batch function.
    ///
    /// `raw_batch_fitness_fn` is preferred. Otherwise every genotype is decoded and the
    /// phenotypes are passed to `batch_fitness_fn`.
    ///
    /// # Errors
    /// Returns an evaluation error in any of these cases:
    /// - no batch function is set;
    /// - the batch function returns a different number of scores than there are individuals;
    /// - any score is rejected by the objective's `validate`.
    fn eval_batch(&self, individuals: &[Genotype<C>]) -> RadiateResult<Vec<Score>> {
        let scores = if let Some(raw_batch_fn) = &self.raw_batch_fitness_fn {
            raw_batch_fn(individuals.iter().collect())
        } else if let Some(batch_fn) = &self.batch_fitness_fn {
            let phenotypes = individuals
                .iter()
                .map(|genotype| self.decode(genotype))
                .collect::<Vec<T>>();
            batch_fn(phenotypes)
        } else {
            return Err(radiate_err!(
                Evaluation: "No batch fitness function defined for BatchEngineProblem"
            ));
        };

        // Scores are matched to individuals by position, so a short or long result would
        // silently misassign fitness (or cause an out-of-bounds panic downstream).
        if scores.len() != individuals.len() {
            return Err(radiate_err!(
                Evaluation: "Batch fitness function returned {} scores for {} individuals",
                scores.len(),
                individuals.len()
            ));
        }

        if let Some(invalid) = scores.iter().find(|&score| !self.objective.validate(score)) {
            return Err(radiate_err!(
                Evaluation: "Invalid fitness score {:?} for objective {:?}",
                invalid,
                self.objective
            ));
        }

        Ok(scores)
    }
}
// SAFETY: NOTE(review) — these impls assert that `BatchEngineProblem<C, T>` may be sent and
// shared across threads for every `C` and `T`, bypassing the compiler's auto-trait check.
// Both batch closures are `Send + Sync` by their type aliases (`BatchFitnessFn`,
// `RawBatchFitnessFn`), so their `Arc` wrappers are already `Send + Sync` on their own. The
// auto traits could therefore only fail to hold because of `dyn Codec<C, T>` or `Objective`,
// whose bounds are not visible here. TODO confirm:
// - if `Codec` declares `Send + Sync` supertraits, these impls are redundant;
// - otherwise they are sound only if every codec stored here is thread-safe.
unsafe impl<C: Chromosome, T> Send for BatchEngineProblem<C, T> {}
unsafe impl<C: Chromosome, T> Sync for BatchEngineProblem<C, T> {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Chromosome, Codec, FloatChromosome, FloatGene, Gene, Genotype, Score};

    /// Phenotype decoded from the first gene of each of two chromosomes.
    #[derive(Debug, Clone)]
    struct MockPhenotype {
        x: f32,
        y: f32,
    }

    /// Codec that always encodes two single-gene chromosomes holding 1.0 and 2.0, and decodes
    /// them back into `MockPhenotype { x, y }`.
    struct MockCodec;

    impl Codec<FloatChromosome<f32>, MockPhenotype> for MockCodec {
        fn encode(&self) -> Genotype<FloatChromosome<f32>> {
            Genotype::new(vec![
                FloatChromosome::from(FloatGene::from(1.0)),
                FloatChromosome::from(FloatGene::from(2.0)),
            ])
        }

        fn decode(&self, genotype: &Genotype<FloatChromosome<f32>>) -> MockPhenotype {
            MockPhenotype {
                x: *genotype[0].get(0).allele(),
                y: *genotype[1].get(0).allele(),
            }
        }
    }

    type MockProblem = EngineProblem<FloatChromosome<f32>, MockPhenotype>;
    type MockBatchProblem = BatchEngineProblem<FloatChromosome<f32>, MockPhenotype>;

    /// Per-individual problem whose fitness is `x + y` (3.0 for the mock genotype).
    fn sum_problem() -> MockProblem {
        let fitness_fn =
            Arc::new(|phenotype: MockPhenotype| Score::from(phenotype.x + phenotype.y));
        EngineProblem {
            objective: Objective::default(),
            codec: Arc::new(MockCodec),
            fitness_fn: Some(fitness_fn),
            raw_fitness_fn: None,
        }
    }

    /// Batch problem whose per-individual fitness is `x * y` (2.0 for the mock genotype).
    fn product_problem() -> MockBatchProblem {
        let batch_fitness_fn = Arc::new(|phenotypes: Vec<MockPhenotype>| -> Vec<Score> {
            phenotypes.iter().map(|p| Score::from(p.x * p.y)).collect()
        });
        BatchEngineProblem {
            objective: Objective::default(),
            codec: Arc::new(MockCodec),
            batch_fitness_fn: Some(batch_fitness_fn),
            raw_batch_fitness_fn: None,
        }
    }

    #[test]
    fn test_engine_problem_basic_functionality() {
        let problem = sum_problem();

        let genotype = problem.encode();
        assert_eq!(genotype.len(), 2);

        let phenotype = problem.decode(&genotype);
        assert_eq!(phenotype.x, 1.0);
        assert_eq!(phenotype.y, 2.0);

        let fitness = problem.eval(&genotype).unwrap();
        assert_eq!(fitness.as_f32(), 3.0);
    }

    #[test]
    fn test_engine_problem_batch_evaluation() {
        let problem = sum_problem();
        let genotypes = [problem.encode(), problem.encode()];

        let scores = problem.eval_batch(&genotypes).unwrap();

        assert_eq!(scores.len(), 2);
        for score in &scores {
            assert_eq!(score.as_f32(), 3.0);
        }
    }

    #[test]
    fn test_batch_engine_problem_basic_functionality() {
        let problem = product_problem();

        let genotype = problem.encode();
        assert_eq!(genotype.len(), 2);

        let phenotype = problem.decode(&genotype);
        assert_eq!(phenotype.x, 1.0);
        assert_eq!(phenotype.y, 2.0);

        let fitness = problem.eval(&genotype).unwrap();
        assert_eq!(fitness.as_f32(), 2.0);
    }

    #[test]
    fn test_batch_engine_problem_batch_evaluation() {
        let problem = product_problem();
        let genotypes = [problem.encode(), problem.encode()];

        let scores = problem.eval_batch(&genotypes).unwrap();

        assert_eq!(scores.len(), 2);
        for score in &scores {
            assert_eq!(score.as_f32(), 2.0);
        }
    }

    #[test]
    fn test_consistency_between_eval_and_eval_batch() {
        let problem = product_problem();
        let genotype = problem.encode();

        let individual_fitness = problem.eval(&genotype).unwrap();
        let batch_scores = problem
            .eval_batch(std::slice::from_ref(&genotype))
            .unwrap();

        assert_eq!(individual_fitness.as_f32(), batch_scores[0].as_f32());
    }
}