use crate::searcher_impl::{
SearcherCore, evaluate_pbests, move_particles_with_offset, reduce_best,
};
use crate::{Best, Contextful, FitCalc, Group, ParticleMover, Searcher, Unit};
use rand::Rng;
use std::ops::Range;
/// Swarm searcher whose particles are split into independent niches.
///
/// Each niche is a contiguous range of particle indices taken from
/// `partition` and keeps its own best. Particles in a niche are moved using
/// that niche's best. The swarm-wide best kept in the core is raised to the
/// fittest niche best after every evaluation.
#[derive(Clone, Debug)]
pub struct NichedSearcher<TUnit: Unit, TFit: FitCalc, TMover: ParticleMover> {
    /// Shared searcher state (fitness calculator, mover, per-particle RNGs,
    /// swarm best).
    core: SearcherCore<TUnit, TFit, TMover>,
    /// Best found so far in each niche; `niche_bests[k]` belongs to
    /// `partition[k]`.
    niche_bests: Vec<Best<TUnit>>,
    /// Ordered, gap-free particle index ranges, one per niche.
    partition: Vec<Range<usize>>,
}
impl<TUnit, TFit, TMover> NichedSearcher<TUnit, TFit, TMover>
where
    TUnit: Unit,
    TFit: FitCalc<T = TUnit>,
    TMover: ParticleMover<TUnit = TUnit>,
{
    /// Builds a searcher whose particles are split into niches by `partition`.
    ///
    /// Niche `k` owns the particle indices in `partition[k]`. The ranges must
    /// be listed in order, must start at 0, and each must begin exactly where
    /// the previous one ended. Empty ranges are accepted. Every niche best
    /// starts out as `Best::new()`.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - `partition` is empty;
    /// - a range does not begin where the previous one ended (the first must
    ///   begin at 0);
    /// - a range has `end < start`.
    #[must_use]
    #[track_caller]
    pub fn new(fit_calc: TFit, mover: TMover, partition: Vec<Range<usize>>) -> Self {
        if partition.is_empty() {
            panic!("NichedSearcher requires a non-empty partition");
        }

        // `expected` is the index at which the next range is required to start.
        let mut expected = 0;
        for (i, range) in partition.iter().enumerate() {
            if range.start != expected {
                panic!(
                    "partition[{i}] = {range:?} is non-contiguous; \
                     expected start = {expected}",
                );
            }
            if range.start > range.end {
                panic!("partition[{i}] = {range:?} has end < start");
            }
            expected = range.end;
        }

        Self {
            core: SearcherCore::new(fit_calc, mover),
            // Evaluated before `partition` is moved into its field below.
            niche_bests: vec![Best::new(); partition.len()],
            partition,
        }
    }

    /// Best found so far in each niche, indexed like [`Self::partition`].
    pub fn niche_bests(&self) -> &[Best<TUnit>] {
        self.niche_bests.as_slice()
    }

    /// The particle index range owned by each niche.
    pub fn partition(&self) -> &[Range<usize>] {
        &self.partition[..]
    }

    /// Raises the swarm best to the fittest niche best, if any niche beats it.
    ///
    /// Niches are scanned in order, and a candidate replaces the incumbent
    /// only if its fitness is strictly greater. On ties the earlier value is
    /// kept.
    fn update_swarm_best(&mut self) {
        self.core.swarm_best =
            self.niche_bests
                .iter()
                .fold(self.core.swarm_best, |incumbent, candidate| {
                    if candidate.best_fit > incumbent.best_fit {
                        *candidate
                    } else {
                        incumbent
                    }
                });
    }
}
/// Extension trait that turns a particle mover into a [`NichedSearcher`].
pub trait IntoNichedSearcher: Sized + ParticleMover {
    /// Wraps `self` and `fit_calc` in a [`NichedSearcher`] whose niches are
    /// given by `partition`.
    ///
    /// This forwards to `NichedSearcher::new`, so it panics on the same
    /// invalid partitions.
    #[must_use]
    fn into_niched_searcher<TFit: FitCalc<T = Self::TUnit>>(
        self,
        fit_calc: TFit,
        partition: Vec<Range<usize>>,
    ) -> NichedSearcher<Self::TUnit, TFit, Self> {
        NichedSearcher::<Self::TUnit, TFit, Self>::new(fit_calc, self, partition)
    }
}
/// Any mover whose `TCommon` is a [`Best`] of its own unit type can be turned
/// into a [`NichedSearcher`].
impl<TUnit: Unit, M: ParticleMover<TUnit = TUnit, TCommon = Best<TUnit>>> IntoNichedSearcher
    for M
{
}
/// Context and iteration bookkeeping are delegated unchanged to the shared
/// searcher core.
impl<TUnit, TFit, TMover, TContext> Contextful for NichedSearcher<TUnit, TFit, TMover>
where
    TUnit: Unit,
    TContext: Copy,
    TFit: FitCalc<T = TUnit, TContext = TContext>,
    TMover: ParticleMover<TUnit = TUnit, TCommon = Best<TUnit>, TContext = TContext>,
{
    type TContext = TContext;

    /// Hands `context` to the core.
    fn set_context(&mut self, context: Self::TContext) {
        self.core.set_context(context);
    }

    /// Hands the current and maximum iteration numbers to the core.
    fn set_iteration(&mut self, iteration: usize, max_iteration: usize) {
        self.core.set_iteration(iteration, max_iteration);
    }
}
impl<TUnit, TFit, TMover, TContext> Searcher for NichedSearcher<TUnit, TFit, TMover>
where
TFit: FitCalc<T = TUnit, TContext = TContext>,
TMover: ParticleMover<TUnit = TUnit, TCommon = Best<TUnit>, TContext = TContext>,
TUnit: Unit,
TContext: Copy,
{
type TUnit = TUnit;
fn init(&mut self, particles: &mut Group<Self::TUnit>) {
let expected = self.partition.last().map_or(0, |r| r.end);
assert!(
particles.len() == expected,
"particle count {} does not match partition coverage {expected}",
particles.len(),
);
for nb in &mut self.niche_bests {
*nb = Best::new();
}
self.core.swarm_best = Best::new();
self.core.sync_rngs(particles.len());
evaluate_pbests(&self.core.fit_calc, particles);
for (i, range) in self.partition.iter().enumerate() {
reduce_best(&particles[range.clone()], &mut self.niche_bests[i]);
}
self.update_swarm_best();
}
fn next(&mut self, particles: &mut Group<Self::TUnit>) -> &Best<TUnit> {
self.core.sync_rngs(particles.len());
for (i, range) in self.partition.iter().enumerate() {
let niche_best = &self.niche_bests[i];
let particles_slice = &mut particles[range.clone()];
let rngs_slice = &mut self.core.particle_rngs[range.clone()];
move_particles_with_offset(
&self.core.mover,
niche_best,
particles_slice,
rngs_slice,
range.start,
);
}
evaluate_pbests(&self.core.fit_calc, particles);
for (i, range) in self.partition.iter().enumerate() {
reduce_best(&particles[range.clone()], &mut self.niche_bests[i]);
}
self.update_swarm_best();
&self.core.swarm_best
}
fn reseed<R: Rng>(&mut self, rng: &mut R) {
self.core.reseed(rng);
}
}