use radiate_core::{
AlterContext, AlterResult, Chromosome, Crossover, Rate, Valid, random_provider,
};
/// Crossover operator that exchanges alternating gene segments between two
/// chromosomes, splitting them at `num_points` cut points.
///
/// When `num_points` is `1` the single-point routine [`crossover_single_point`]
/// is used; any other value is handed to [`crossover_multi_point`].
#[derive(Debug, Clone)]
pub struct MultiPointCrossover {
    /// Requested number of cut points per crossover.
    num_points: usize,
    /// Probability of applying this crossover, as reported by `Crossover::rate`.
    rate: Rate,
}
impl MultiPointCrossover {
    /// Builds a multi-point crossover operator.
    ///
    /// * `rate` - probability of applying the crossover; anything convertible
    ///   into a [`Rate`].
    /// * `num_points` - number of cut points per crossover. A value of `1`
    ///   selects single-point crossover. Other values are clamped when the
    ///   crossover runs to the number of cut positions the chromosomes allow.
    ///
    /// # Panics
    ///
    /// Panics if the converted rate is not valid (`Rate::is_valid` is `false`).
    pub fn new(rate: impl Into<Rate>, num_points: usize) -> Self {
        let rate: Rate = rate.into();
        assert!(
            rate.is_valid(),
            "Rate {rate:?} is not valid. Must be between 0.0 and 1.0"
        );

        Self { rate, num_points }
    }
}
impl<C: Chromosome> Crossover<C> for MultiPointCrossover {
    /// Probability with which this crossover is applied.
    fn rate(&self) -> Rate {
        self.rate.clone()
    }

    /// Crosses the genes of two chromosomes in place.
    ///
    /// Exactly one configured point dispatches to [`crossover_single_point`];
    /// any other count goes to [`crossover_multi_point`]. The number of cuts
    /// reported by the chosen routine is wrapped into the [`AlterResult`].
    #[inline]
    fn cross_chromosomes(
        &self,
        chrom_one: &mut C,
        chrom_two: &mut C,
        _: &mut AlterContext,
    ) -> AlterResult {
        let (genes_one, genes_two) = (chrom_one.as_mut_slice(), chrom_two.as_mut_slice());

        let crossed = match self.num_points {
            1 => crossover_single_point(genes_one, genes_two),
            points => crossover_multi_point(genes_one, genes_two, points),
        };

        AlterResult::from(crossed)
    }
}
/// Performs an in-place multi-point crossover between two gene slices.
///
/// Let `length` be the length of the shorter slice. `num_points` is clamped to
/// `1..=length - 1`, which is the number of interior cut positions. The cut
/// points are sampled from `1..length` via `random_provider::sample_indices`
/// (assumed to return distinct indices) and sorted.
///
/// The shared prefix `[0, length)` is then split at those cuts. Segments
/// alternate between being swapped and being left in place, starting with the
/// first segment `[0, first_cut)` being swapped. Genes beyond `length` in the
/// longer slice are never touched.
///
/// Returns the number of cut points used, or `0` (leaving both slices
/// unchanged) if either slice has fewer than two genes.
#[inline]
pub fn crossover_multi_point<G>(
    chrom_one: &mut [G],
    chrom_two: &mut [G],
    num_points: usize,
) -> usize {
    let length = std::cmp::min(chrom_one.len(), chrom_two.len());
    if length < 2 {
        return 0;
    }

    // There are exactly `length - 1` interior boundaries (1..length). Sampling
    // from that range guarantees every cut really splits the slices; index 0
    // would only produce an empty leading segment, i.e. a wasted cut.
    let num_points = num_points.clamp(1, length - 1);
    let mut cut_points = random_provider::sample_indices(1..length, num_points);
    cut_points.sort_unstable();

    let mut swap = true;
    let mut start = 0;
    for end in cut_points {
        if swap {
            chrom_one[start..end].swap_with_slice(&mut chrom_two[start..end]);
        }
        swap = !swap;
        start = end;
    }

    // Bound the tail by `length`: `swap_with_slice` panics on slices of
    // different lengths, which an open-ended `[start..]` would produce for
    // chromosomes of unequal size.
    if swap {
        chrom_one[start..length].swap_with_slice(&mut chrom_two[start..length]);
    }

    num_points
}
/// Performs an in-place single-point crossover between two gene slices.
///
/// Let `length` be the length of the shorter slice. A cut point is drawn from
/// `1..length` via `random_provider::range`, and every gene in
/// `[cut, length)` is swapped between the two slices. Genes beyond `length` in
/// the longer slice are left untouched.
///
/// Returns `1` when a crossover was performed, or `0` (leaving both slices
/// unchanged) if either slice has fewer than two genes.
#[inline]
pub fn crossover_single_point<G>(chrom_one: &mut [G], chrom_two: &mut [G]) -> usize {
    let length = std::cmp::min(chrom_one.len(), chrom_two.len());
    if length < 2 {
        return 0;
    }

    let crossover_point = random_provider::range(1..length);
    // Bound both tails by `length`: `swap_with_slice` panics when the slices
    // differ in length, which an open-ended `[point..]` would cause for
    // chromosomes of unequal size.
    chrom_one[crossover_point..length].swap_with_slice(&mut chrom_two[crossover_point..length]);
    1
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every position holds one gene from each parent, i.e. the
    /// crossover only swapped genes position-for-position.
    fn assert_complementary(one: &[i32], two: &[i32]) {
        assert_eq!(one.len(), two.len());
        for (a, b) in one.iter().zip(two) {
            assert_eq!(a + b, 1, "genes at a position must come from both parents");
        }
    }

    #[test]
    fn test_crossover_multi_point() {
        let mut chrom_one = vec![0; 10];
        let mut chrom_two = vec![1; 10];
        let points = crossover_multi_point(&mut chrom_one, &mut chrom_two, 2);
        assert_eq!(chrom_one.len(), 10);
        assert_eq!(chrom_two.len(), 10);
        assert_eq!(points, 2);
        assert_complementary(&chrom_one, &chrom_two);
    }

    #[test]
    fn test_crossover_multi_point_clamps_points() {
        let mut chrom_one = vec![0; 4];
        let mut chrom_two = vec![1; 4];
        // Only 3 interior cut positions exist for 4 genes.
        let points = crossover_multi_point(&mut chrom_one, &mut chrom_two, 100);
        assert_eq!(points, 3);
        assert_complementary(&chrom_one, &chrom_two);
    }

    #[test]
    fn test_crossover_single_point() {
        let mut chrom_one = vec![0; 10];
        let mut chrom_two = vec![1; 10];
        assert_eq!(crossover_single_point(&mut chrom_one, &mut chrom_two), 1);
        assert_complementary(&chrom_one, &chrom_two);
        // The cut is drawn from 1..len and the tail is swapped, so the first
        // gene never moves and the last one always does.
        assert_eq!(chrom_one[0], 0);
        assert_eq!(chrom_one[9], 1);
    }

    #[test]
    fn test_crossover_too_short_is_noop() {
        let mut chrom_one = vec![0];
        let mut chrom_two = vec![1, 1];
        assert_eq!(crossover_single_point(&mut chrom_one, &mut chrom_two), 0);
        assert_eq!(crossover_multi_point(&mut chrom_one, &mut chrom_two, 2), 0);
        assert_eq!(chrom_one, vec![0]);
        assert_eq!(chrom_two, vec![1, 1]);
    }
}