use crate::{Best, Contextful, FitCalc, Group, Particle, ParticleMover, ParticleRefMut, Unit};
use rand::rngs::SmallRng;
use rand::{Rng, RngExt as _, SeedableRng as _};
/// Builds a fresh [`SmallRng`] whose seed is a single `u64` drawn from `source`.
///
/// Drawing the seed advances `source` by one `u64` sample.
pub(crate) fn reseed_small_rng<R: Rng>(source: &mut R) -> SmallRng {
    let seed: u64 = source.random();
    SmallRng::seed_from_u64(seed)
}
/// Applies `f` to every particle in `slice`, splitting the work recursively
/// with [`rayon::join`].
///
/// The slice is halved until a piece holds at most `leaf_size` particles; each
/// such piece is then processed sequentially, in order. Separate pieces may be
/// processed in any order relative to one another.
///
/// A `leaf_size` of `0` is treated as `1`. Without that clamp a one-element
/// slice would split into an empty half and itself forever, recursing until
/// the stack overflows.
pub(crate) fn par_for_each_mut<T, F>(slice: &mut [Particle<T>], leaf_size: usize, f: &F)
where
    T: Unit,
    F: Fn(ParticleRefMut<'_, T>) + Sync,
{
    // Guarantee the base case is reachable for every non-empty slice.
    let leaf_size = leaf_size.max(1);

    if slice.len() <= leaf_size {
        for p in slice.iter_mut() {
            f(p.as_ref_mut());
        }
        return;
    }

    // len >= 2 here, so both halves are non-empty and strictly shorter.
    let mid = slice.len() / 2;
    let (left, right) = slice.split_at_mut(mid);
    rayon::join(
        || par_for_each_mut(left, leaf_size, f),
        || par_for_each_mut(right, leaf_size, f),
    );
}
/// Applies `f` to every particle in `slice`, pairing each particle with the
/// random number generator at the same position in `rngs`.
///
/// `f` receives the particle's index within `slice` (starting at zero), a
/// mutable view of the particle, and its generator. Work is split recursively
/// into pieces of at most `leaf_size` particles.
///
/// `slice` and `rngs` are expected to have the same length; this is checked
/// with a debug assertion in the underlying helper.
pub fn par_for_each_mut_rng<T, F>(
    slice: &mut [Particle<T>],
    rngs: &mut [SmallRng],
    leaf_size: usize,
    f: &F,
) where
    T: Unit,
    F: Fn(usize, ParticleRefMut<'_, T>, &mut SmallRng) + Sync,
{
    // The whole slice is being visited, so indices start from zero.
    let base_offset = 0;
    par_for_each_mut_rng_offset(slice, rngs, base_offset, leaf_size, f);
}
/// Recursive worker behind the indexed, per-particle-RNG parallel traversal.
///
/// `slice` and `rngs` are split at the same midpoint so that element `i` of
/// one always travels with element `i` of the other. `offset` is the index
/// reported to `f` for the first element of `slice`; later elements get
/// `offset + 1`, `offset + 2`, and so on.
///
/// Pieces of at most `leaf_size` particles are processed sequentially; larger
/// ones are halved and handed to [`rayon::join`]. A `leaf_size` of `0` is
/// treated as `1`, because otherwise a one-element slice would split into an
/// empty half and itself forever, recursing until the stack overflows.
///
/// In debug builds a length mismatch between `slice` and `rngs` panics.
fn par_for_each_mut_rng_offset<T, F>(
    slice: &mut [Particle<T>],
    rngs: &mut [SmallRng],
    offset: usize,
    leaf_size: usize,
    f: &F,
) where
    T: Unit,
    F: Fn(usize, ParticleRefMut<'_, T>, &mut SmallRng) + Sync,
{
    debug_assert_eq!(
        slice.len(),
        rngs.len(),
        "particle slice and rng slice must be split in lock-step",
    );

    // Guarantee the base case is reachable for every non-empty slice.
    let leaf_size = leaf_size.max(1);

    if slice.len() <= leaf_size {
        for (i, (p, rng)) in slice.iter_mut().zip(rngs.iter_mut()).enumerate() {
            f(offset + i, p.as_ref_mut(), rng);
        }
        return;
    }

    // len >= 2 here, so both halves are non-empty and strictly shorter.
    let mid = slice.len() / 2;
    let (left, right) = slice.split_at_mut(mid);
    let (rng_left, rng_right) = rngs.split_at_mut(mid);
    rayon::join(
        || par_for_each_mut_rng_offset(left, rng_left, offset, leaf_size, f),
        // The right half starts `mid` elements further along.
        || par_for_each_mut_rng_offset(right, rng_right, offset + mid, leaf_size, f),
    );
}
/// Runs `mover` over every particle in `particles`, reporting particle
/// indices that start at zero.
///
/// This is a convenience wrapper around [`move_particles_with_offset`] with a
/// base index of `0`. `particle_rngs` supplies one generator per particle, at
/// matching positions.
pub fn move_particles<TUnit, TMover>(
    mover: &TMover,
    common: &TMover::TCommon,
    particles: &mut Group<TUnit>,
    particle_rngs: &mut [SmallRng],
) where
    TUnit: Unit,
    TMover: ParticleMover<TUnit = TUnit>,
    TMover::TCommon: Sync,
{
    /// Index reported for the first particle of the group.
    const FIRST_INDEX: usize = 0;

    move_particles_with_offset(mover, common, particles, particle_rngs, FIRST_INDEX);
}
/// Runs `mover.update` on every particle in `particles`, in parallel.
///
/// Each particle is paired with the generator at the same position in
/// `particle_rngs`. The index handed to `mover.update` for the particle at
/// position `i` is `base_idx + i`. The parallel split granularity is
/// `TMover::PAR_LEAF_SIZE`, and `common` is shared by reference with every
/// call.
pub fn move_particles_with_offset<TUnit, TMover>(
    mover: &TMover,
    common: &TMover::TCommon,
    particles: &mut [Particle<TUnit>],
    particle_rngs: &mut [SmallRng],
    base_idx: usize,
) where
    TUnit: Unit,
    TMover: ParticleMover<TUnit = TUnit>,
    TMover::TCommon: Sync,
{
    let leaf_size = TMover::PAR_LEAF_SIZE;
    par_for_each_mut_rng_offset(
        particles,
        particle_rngs,
        base_idx,
        leaf_size,
        &|particle_idx, mut particle, particle_rng| {
            mover.update(common, particle_rng, particle_idx, &mut particle);
        },
    );
}
/// Recomputes each particle's fitness and refreshes its personal best.
///
/// For every particle, `fit_calc.calculate_fit` is evaluated at the current
/// position and stored as the particle's fitness. When that value is strictly
/// greater than the recorded personal-best fitness, the personal-best position
/// and fitness are overwritten with the current ones. Particles are processed
/// in parallel with a split granularity of `TFit::PAR_LEAF_SIZE`.
pub fn evaluate_pbests<TUnit, TFit>(fit_calc: &TFit, particles: &mut Group<TUnit>)
where
    TUnit: Unit,
    TFit: FitCalc<T = TUnit>,
{
    par_for_each_mut(particles, TFit::PAR_LEAF_SIZE, &|p| {
        let score = fit_calc.calculate_fit(*p.pos);
        *p.fit = score;

        // Only a strict improvement replaces the personal best.
        let improved = score > *p.best_fit;
        if improved {
            *p.best_fit = score;
            *p.best_pos = *p.pos;
        }
    });
}
/// Folds the personal bests of `particles` into `best`.
///
/// Particles are visited in order; whenever a particle's personal-best fitness
/// is strictly greater than the fitness currently held in `best`, `best` takes
/// over that particle's personal-best position and fitness. Ties leave `best`
/// unchanged.
pub fn reduce_best<TUnit: Unit>(particles: &[Particle<TUnit>], best: &mut Best<TUnit>) {
    particles.iter().for_each(|candidate| {
        if candidate.best_fit > best.best_fit {
            best.best_pos = candidate.best_pos;
            best.best_fit = candidate.best_fit;
        }
    });
}
/// Ensures `particle_rngs` holds exactly `n` generators.
///
/// When the length already equals `n` nothing happens and `master` is left
/// untouched. Otherwise every existing generator is discarded and `n` new ones
/// are created, each seeded in turn from `master`.
pub(crate) fn sync_particle_rngs(
    master: &mut SmallRng,
    particle_rngs: &mut Vec<SmallRng>,
    n: usize,
) {
    if particle_rngs.len() == n {
        return;
    }

    // Rebuild from scratch rather than growing or truncating in place.
    particle_rngs.clear();
    particle_rngs.extend(std::iter::repeat_with(|| reseed_small_rng(master)).take(n));
}
/// State bundle for a searcher: the fitness calculator, the particle mover,
/// the best result recorded so far, and the random number generators.
#[derive(Clone, Debug)]
pub struct SearcherCore<TUnit, TFit, TMover>
where
TUnit: Unit,
TFit: FitCalc,
TMover: ParticleMover,
{
/// Fitness calculator.
pub fit_calc: TFit,
/// Particle mover.
pub mover: TMover,
/// Best position and fitness recorded for the swarm; starts as `Best::new()`.
pub swarm_best: Best<TUnit>,
/// Generator from which the entries of `particle_rngs` are seeded.
pub master_rng: SmallRng,
/// Generators sized by `sync_rngs`; empty until the first sync and after a
/// reseed. NOTE(review): presumably one per particle - confirm against callers.
pub particle_rngs: Vec<SmallRng>,
}
impl<TUnit, TFit, TMover> SearcherCore<TUnit, TFit, TMover>
where
TUnit: Unit,
TFit: FitCalc<T = TUnit>,
TMover: ParticleMover<TUnit = TUnit>,
{
pub fn new(fit_calc: TFit, mover: TMover) -> Self {
Self {
fit_calc,
mover,
swarm_best: Best::new(),
master_rng: rand::make_rng(),
particle_rngs: Vec::new(),
}
}
pub fn sync_rngs(&mut self, n: usize) {
sync_particle_rngs(&mut self.master_rng, &mut self.particle_rngs, n);
}
pub fn reseed<R: Rng>(&mut self, rng: &mut R) {
self.master_rng = reseed_small_rng(rng);
self.particle_rngs.clear();
}
}
/// Forwards context and iteration updates to both the mover and the fitness
/// calculator, which must share the same context type.
impl<TUnit, TFit, TMover, TContext> Contextful for SearcherCore<TUnit, TFit, TMover>
where
    TFit: FitCalc<T = TUnit, TContext = TContext>,
    TMover: ParticleMover<TUnit = TUnit, TContext = TContext>,
    TUnit: Unit,
    TContext: Copy,
{
    type TContext = TContext;

    /// Passes a copy of `context` to the mover first, then to the fitness
    /// calculator.
    fn set_context(&mut self, context: Self::TContext) {
        let Self {
            mover, fit_calc, ..
        } = self;
        mover.set_context(context);
        fit_calc.set_context(context);
    }

    /// Passes the iteration counters to the mover first, then to the fitness
    /// calculator.
    fn set_iteration(&mut self, iteration: usize, max_iteration: usize) {
        let Self {
            mover, fit_calc, ..
        } = self;
        mover.set_iteration(iteration, max_iteration);
        fit_calc.set_iteration(iteration, max_iteration);
    }
}