use crate::searcher_impl::{SearcherCore, evaluate_pbests, par_for_each_mut_rng, reduce_best};
use crate::{Best, Contextful, FitCalc, Group, ParticleMover, Searcher, Unit};
use rand::Rng;
/// Neighbourhood topology that decides which particles' personal bests each
/// particle can see in an [`LBestSearcher`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LBestKind {
/// Ring lattice: each particle is linked to the `k` particles before it and
/// the `k` particles after it, with indices wrapping around.
Ring {
/// Neighbours taken on each side. For swarms of `n >= 3` particles the
/// graph builder clamps this to `(n - 1) / 2`; [`LBestSearcher::new`]
/// rejects `k == 0`.
k: u8,
},
/// Wrap-around grid: each particle is linked to its up, down, left and
/// right cells. Falls back to a `k = 1` ring when the particle count has no
/// factorisation with both grid dimensions at least 3.
VonNeumann,
}
/// Builds the adjacency list for a swarm of `n` particles laid out as `kind`.
///
/// Entry `i` of the result holds the indices of particle `i`'s neighbours;
/// the result always has exactly `n` entries.
pub(crate) fn build_neighbours(n: usize, kind: LBestKind) -> Vec<Vec<usize>> {
    match kind {
        LBestKind::VonNeumann => build_von_neumann(n),
        // `u8 -> usize` is a lossless widening.
        LBestKind::Ring { k } => build_ring(n, usize::from(k)),
    }
}
/// Builds a ring lattice over `n` particles with `k` neighbours on each side.
///
/// For `n >= 3`, `k` is clamped to `(n - 1) / 2` so that the left and right
/// arcs never overlap. Each adjacency list is ordered nearest-first,
/// alternating left then right: `[i-1, i+1, i-2, i+2, ...]` (mod `n`).
///
/// Swarms smaller than three particles ignore `k`: every particle is simply
/// linked to every *other* particle (none for `n <= 1`, the single partner
/// for `n == 2`).
fn build_ring(n: usize, k: usize) -> Vec<Vec<usize>> {
    if n < 3 {
        // Covers 0 -> [], 1 -> [[]], 2 -> [[1], [0]] in one expression and
        // keeps `n - 1` below from underflowing.
        return (0..n)
            .map(|me| (0..n).filter(|&other| other != me).collect())
            .collect();
    }

    let reach = k.min((n - 1) / 2);
    (0..n)
        .map(|centre| {
            (1..=reach)
                // `+ n` keeps the left index non-negative before the modulo.
                .flat_map(|step| [(centre + n - step) % n, (centre + step) % n])
                .collect()
        })
        .collect()
}
/// Builds a von Neumann (up/down/left/right) neighbourhood on a wrap-around
/// `rows x cols` grid covering exactly `n` particles in row-major order.
///
/// `rows` is the largest divisor of `n` in `3..=floor(sqrt(n))`. If there is
/// no such divisor (so one grid side would be shorter than 3 and the four
/// neighbours could coincide), the function falls back to `build_ring(n, 1)`.
///
/// Each adjacency list is ordered `[up, down, left, right]`.
fn build_von_neumann(n: usize) -> Vec<Vec<usize>> {
    #[expect(
        clippy::cast_sign_loss,
        reason = "n is usize, sqrt is non-negative, the as usize cannot drop a sign"
    )]
    let sqrt_n = (n as f64).sqrt() as usize;

    // Scan downwards from sqrt(n) so the grid is as close to square as possible.
    let Some(rows) = (3..=sqrt_n).rev().find(|&candidate| n % candidate == 0) else {
        return build_ring(n, 1);
    };
    let cols = n / rows;
    debug_assert!(rows >= 3 && cols >= 3, "torus dims must be >= 3");
    debug_assert_eq!(rows * cols, n, "factorization must cover n exactly");

    (0..n)
        .map(|cell| {
            let (row, col) = (cell / cols, cell % cols);
            // Adding the dimension before subtracting 1 avoids usize underflow.
            let row_above = (row + rows - 1) % rows;
            let row_below = (row + 1) % rows;
            let col_left = (col + cols - 1) % cols;
            let col_right = (col + 1) % cols;
            vec![
                row_above * cols + col,
                row_below * cols + col,
                row * cols + col_left,
                row * cols + col_right,
            ]
        })
        .collect()
}
/// Swarm searcher in which each particle is guided by the best personal best
/// among itself and its neighbours in a fixed topology ([`LBestKind`]),
/// rather than by a single swarm-wide best.
#[derive(Clone, Debug)]
pub struct LBestSearcher<TUnit, TFit, TMover>
where
TUnit: Unit,
TFit: FitCalc,
TMover: ParticleMover,
{
// Shared searcher state: fitness calculator, mover, per-particle RNGs and
// the swarm-wide best.
core: SearcherCore<TUnit, TFit, TMover>,
// Topology chosen at construction; never changes afterwards.
kind: LBestKind,
// Adjacency list indexed by particle index; rebuilt whenever the particle
// count differs from its length.
neighbours: Vec<Vec<usize>>,
// Scratch buffer holding a copy of every particle's personal best, refilled
// at the start of each `next` call.
pbests_snapshot: Vec<Best<TUnit>>,
}
impl<TUnit, TFit, TMover> LBestSearcher<TUnit, TFit, TMover>
where
    TUnit: Unit,
    TFit: FitCalc<T = TUnit>,
    TMover: ParticleMover<TUnit = TUnit>,
{
    /// Creates a searcher that evaluates particles with `fit_calc`, moves them
    /// with `mover`, and shares personal bests along the `kind` topology.
    ///
    /// The neighbour graph starts empty and is built once the particle count
    /// is known.
    ///
    /// # Panics
    ///
    /// Panics if `kind` is `LBestKind::Ring { k: 0 }`.
    #[must_use]
    #[track_caller]
    pub fn new(fit_calc: TFit, mover: TMover, kind: LBestKind) -> Self {
        assert_ne!(
            kind,
            LBestKind::Ring { k: 0 },
            "LBestKind::Ring requires k >= 1",
        );
        let core = SearcherCore::new(fit_calc, mover);
        Self {
            core,
            kind,
            neighbours: Vec::new(),
            pbests_snapshot: Vec::new(),
        }
    }

    /// Returns the current adjacency list, one entry per particle.
    ///
    /// Empty until the graph has been built for a non-empty swarm.
    pub fn neighbours(&self) -> &[Vec<usize>] {
        self.neighbours.as_slice()
    }

    /// Returns the topology this searcher was constructed with.
    pub fn kind(&self) -> LBestKind {
        self.kind
    }

    /// Makes sure the neighbour graph matches a swarm of `n` particles,
    /// rebuilding it only when the size has changed.
    fn sync_neighbours(&mut self, n: usize) {
        if self.neighbours.len() == n {
            return;
        }
        self.neighbours = build_neighbours(n, self.kind);
    }
}
/// Extension trait that turns a [`ParticleMover`] into an [`LBestSearcher`].
pub trait IntoLBestSearcher: ParticleMover + Sized {
/// Consumes the mover and combines it with `fit_calc` and the topology
/// `kind` by calling [`LBestSearcher::new`].
///
/// # Panics
///
/// Panics if `kind` is `LBestKind::Ring { k: 0 }`, as enforced by
/// [`LBestSearcher::new`].
#[must_use]
fn into_lbest_searcher<TFit>(
self,
fit_calc: TFit,
kind: LBestKind,
) -> LBestSearcher<Self::TUnit, TFit, Self>
where
TFit: FitCalc<T = Self::TUnit>,
{
LBestSearcher::new(fit_calc, self, kind)
}
}
/// Blanket implementation: every [`ParticleMover`] whose `TCommon` is
/// `Best<TUnit>` gets `into_lbest_searcher` through the trait's default method.
impl<TUnit, M> IntoLBestSearcher for M
where
TUnit: Unit,
M: ParticleMover<TUnit = TUnit, TCommon = Best<TUnit>>,
{
}
/// Context handling for [`LBestSearcher`]: both methods forward unchanged to
/// the inner `SearcherCore`. The fitness calculator and the mover must agree
/// on the same `TContext` type.
impl<TUnit, TFit, TMover, TContext> Contextful for LBestSearcher<TUnit, TFit, TMover>
where
TFit: FitCalc<T = TUnit, TContext = TContext>,
TMover: ParticleMover<TUnit = TUnit, TCommon = Best<TUnit>, TContext = TContext>,
TUnit: Unit,
TContext: Copy,
{
type TContext = TContext;
// Forwards the new context to the core.
fn set_context(&mut self, context: Self::TContext) {
self.core.set_context(context);
}
// Forwards the current and maximum iteration counters to the core.
fn set_iteration(&mut self, iteration: usize, max_iteration: usize) {
self.core.set_iteration(iteration, max_iteration);
}
}
impl<TUnit, TFit, TMover, TContext> Searcher for LBestSearcher<TUnit, TFit, TMover>
where
    TFit: FitCalc<T = TUnit, TContext = TContext>,
    TMover: ParticleMover<TUnit = TUnit, TCommon = Best<TUnit>, TContext = TContext>,
    TUnit: Unit,
    TContext: Copy,
{
    type TUnit = TUnit;

    /// Prepares the searcher for `particles`: sizes the neighbour graph and
    /// the per-particle RNGs to the swarm, resets the swarm best, then
    /// evaluates personal bests and folds them into the swarm best.
    fn init(&mut self, particles: &mut Group<Self::TUnit>) {
        let count = particles.len();
        self.sync_neighbours(count);
        self.core.sync_rngs(count);

        self.core.swarm_best = Best::new();
        evaluate_pbests(&self.core.fit_calc, particles);
        reduce_best(particles, &mut self.core.swarm_best);
    }

    /// Runs one iteration and returns the swarm-wide best after it.
    ///
    /// Every particle is moved toward its *local* best: the entry with the
    /// greatest `best_fit` among its own personal best and those of its
    /// neighbours. Ties keep the earlier candidate (the particle itself
    /// first, then neighbours in adjacency-list order).
    fn next(&mut self, particles: &mut Group<Self::TUnit>) -> &Best<TUnit> {
        let count = particles.len();
        self.sync_neighbours(count);
        self.core.sync_rngs(count);

        // Copy all personal bests up front: the update closure below reads
        // only this snapshot, never the particles being mutated.
        let scratch = &mut self.pbests_snapshot;
        scratch.clear();
        scratch.reserve(count);
        for particle in particles.iter() {
            scratch.push(Best {
                best_pos: particle.best_pos,
                best_fit: particle.best_fit,
            });
        }

        let snapshot = self.pbests_snapshot.as_slice();
        let neighbours = self.neighbours.as_slice();
        let mover = &self.core.mover;
        par_for_each_mut_rng(
            particles,
            &mut self.core.particle_rngs,
            TMover::PAR_LEAF_SIZE,
            &|idx, mut p, rng| {
                // Start from the particle's own best; a neighbour replaces it
                // only when strictly better.
                let lbest = neighbours[idx].iter().fold(snapshot[idx], |best, &j| {
                    if snapshot[j].best_fit > best.best_fit {
                        snapshot[j]
                    } else {
                        best
                    }
                });
                mover.update(&lbest, rng, idx, &mut p);
            },
        );

        evaluate_pbests(&self.core.fit_calc, particles);
        reduce_best(particles, &mut self.core.swarm_best);
        &self.core.swarm_best
    }

    /// Reseeds the searcher's random state from `rng` by delegating to the core.
    fn reseed<R: Rng>(&mut self, rng: &mut R) {
        self.core.reseed(rng);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `k = 1` ring links every particle to its predecessor and successor.
    #[test]
    fn ring_k1_has_two_neighbours_per_particle() {
        let graph = build_neighbours(8, LBestKind::Ring { k: 1 });
        assert_eq!(graph.len(), 8);
        for adjacency in &graph {
            assert_eq!(adjacency.len(), 2);
        }
        assert_eq!(graph[0], vec![7, 1]);
        assert_eq!(graph[4], vec![3, 5]);
    }

    /// A `k = 2` ring lists neighbours nearest-first, left before right.
    #[test]
    fn ring_k2_has_four_neighbours_per_particle() {
        let graph = build_neighbours(10, LBestKind::Ring { k: 2 });
        assert_eq!(graph.len(), 10);
        for adjacency in &graph {
            assert_eq!(adjacency.len(), 4);
        }
        assert_eq!(graph[0], vec![9, 1, 8, 2]);
    }

    /// 16 particles form a 4x4 torus; order is up, down, left, right.
    #[test]
    fn von_neumann_lays_out_on_a_grid() {
        let graph = build_neighbours(16, LBestKind::VonNeumann);
        assert_eq!(graph.len(), 16);
        for adjacency in &graph {
            assert_eq!(adjacency.len(), 4);
        }
        assert_eq!(graph[0], vec![12, 4, 3, 1]);
    }

    /// A prime particle count has no grid factorisation, so a ring is used.
    #[test]
    fn von_neumann_falls_back_to_ring_for_primes() {
        let graph = build_neighbours(7, LBestKind::VonNeumann);
        assert_eq!(graph.len(), 7);
        for adjacency in &graph {
            assert_eq!(adjacency.len(), 2);
        }
    }

    /// Counts whose only usable factorisation has a side of 2 also fall back.
    #[test]
    fn von_neumann_falls_back_to_ring_when_a_dim_would_be_2() {
        for n in [4, 6, 10, 14] {
            let graph = build_neighbours(n, LBestKind::VonNeumann);
            assert_eq!(graph.len(), n, "n={n}");
            for adjacency in &graph {
                assert_eq!(
                    adjacency.len(),
                    2,
                    "n={n}: expected ring fallback (2 neighbours)"
                );
            }
        }
    }

    /// When a torus is built, the four neighbours are distinct and never
    /// include the particle itself.
    #[test]
    fn von_neumann_has_four_distinct_neighbours_when_built() {
        for n in [9, 12, 15, 16, 20, 25] {
            let graph = build_neighbours(n, LBestKind::VonNeumann);
            assert_eq!(graph.len(), n, "n={n}");
            for (i, v) in graph.iter().enumerate() {
                assert_eq!(v.len(), 4, "n={n}, i={i}");
                let unique: std::collections::BTreeSet<usize> = v.iter().copied().collect();
                assert_eq!(
                    unique.len(),
                    4,
                    "n={n}, i={i}: neighbours not distinct: {v:?}"
                );
                assert!(!v.contains(&i), "n={n}, i={i}: self-loop in {v:?}");
            }
        }
    }

    /// Zero particles produce an empty graph for both topologies.
    #[test]
    fn empty_swarm_yields_empty_graph() {
        let kinds = [LBestKind::Ring { k: 1 }, LBestKind::VonNeumann];
        for kind in kinds {
            assert!(build_neighbours(0, kind).is_empty());
        }
    }

    /// A lone particle gets one empty adjacency list.
    #[test]
    fn single_particle_has_no_neighbours_and_no_self_loop() {
        let expected = vec![Vec::<usize>::new()];
        let ring = build_neighbours(1, LBestKind::Ring { k: 1 });
        assert_eq!(ring, expected);
        let torus = build_neighbours(1, LBestKind::VonNeumann);
        assert_eq!(torus, expected);
    }

    /// Two particles see each other exactly once.
    #[test]
    fn two_particles_pair_with_each_other_without_duplicates() {
        let expected = vec![vec![1], vec![0]];
        let ring = build_neighbours(2, LBestKind::Ring { k: 1 });
        assert_eq!(ring, expected);
        let torus = build_neighbours(2, LBestKind::VonNeumann);
        assert_eq!(torus, expected);
    }
}