//! pso 0.2.0 — Particle Swarm Optimizer
//!
//! Documentation: this module defines `Particle`, a single member of the swarm.
use super::minima::OwnedRecord;
use std::sync::Arc;

/// A single member of the swarm.
///
/// Tracks the particle's current position and velocity, the best position it
/// has evaluated so far (its personal best), and that position's cost.
#[derive(Debug, Clone)]
pub struct Particle {
    /// Current position, one entry per decision variable.
    pos: Vec<f64>,
    /// Current velocity, indexed in step with `pos`.
    vel: Vec<f64>,
    /// Personal-best position: a copy of `pos` taken at the lowest cost seen.
    min: Vec<f64>,
    /// Dimensionality, fixed to the length of the initial position.
    num_vars: usize,
    /// Cost at `min`; `f64::MAX` until an evaluation beats it.
    pub min_cost: f64,
}

impl Particle {
    /// Creates a particle at `pos` moving with velocity `vel`.
    ///
    /// The initial position doubles as the personal best, with a cost of
    /// `f64::MAX`. Any evaluation returning a lower cost therefore replaces it.
    ///
    /// `vel` should be at least as long as `pos`. A shorter `vel` makes
    /// [`Particle::update_syn`] / [`Particle::update_sol`] panic.
    pub fn new(pos: &[f64], vel: &[f64]) -> Self {
        Self {
            pos: pos.to_vec(),
            vel: vel.to_vec(),
            min: pos.to_vec(),
            num_vars: pos.len(),
            min_cost: f64::MAX,
        }
    }

    /// Evaluates `objective` at the current position and updates the
    /// personal best if the cost is strictly lower than `min_cost`.
    ///
    /// Returns `Some(cost)` when a new personal best was recorded. Otherwise
    /// it returns `None`, including for a NaN cost, which never compares lower.
    pub fn cost_fn_mut(&mut self, objective: &mut impl FnMut(&[f64]) -> f64) -> Option<f64> {
        let cost = objective(&self.pos);
        self.record_if_better(cost)
    }

    /// Same as [`Particle::cost_fn_mut`], but for a shared, immutable
    /// objective behind an `Arc`. Trait objects such as
    /// `Arc<dyn Fn(&[f64]) -> f64>` are accepted as well as concrete closures.
    pub fn cost_fn<T>(&mut self, objective: &Arc<T>) -> Option<f64>
    where
        T: ?Sized + Fn(&[f64]) -> f64,
    {
        let cost = objective(&self.pos);
        self.record_if_better(cost)
    }

    /// Moves the particle, then recomputes its velocity. The velocity is
    /// pulled toward three attractors: the personal best, `tribal_best`, and
    /// `global_best`.
    ///
    /// For every dimension `i`, after `pos[i] += vel[i]`:
    ///
    /// ```text
    /// vel[i] = momentum * vel[i]
    ///        + rand_vecs[0][i] * motion_coeffs[0] * (min[i]         - pos[i])
    ///        + rand_vecs[1][i] * motion_coeffs[1] * (tribal_best[i] - pos[i])
    ///        + rand_vecs[2][i] * motion_coeffs[2] * (global_best[i] - pos[i])
    /// ```
    ///
    /// # Panics
    /// Panics if `rand_vecs` or `motion_coeffs` has fewer than 3 entries. It
    /// also panics if any per-dimension slice is shorter than the position.
    pub fn update_syn(
        &mut self,
        momentum: f64,
        rand_vecs: Vec<Vec<f64>>,
        motion_coeffs: &[f64],
        tribal_best: &[f64],
        global_best: &[f64],
    ) {
        self.advance(momentum, &rand_vecs, motion_coeffs, &[tribal_best, global_best]);
    }

    /// Like [`Particle::update_syn`], but without the `global_best` term. The
    /// velocity is pulled only toward the personal best and `tribal_best`.
    ///
    /// # Panics
    /// Panics if `rand_vecs` or `motion_coeffs` has fewer than 2 entries. It
    /// also panics if any per-dimension slice is shorter than the position.
    pub fn update_sol(
        &mut self,
        momentum: f64,
        rand_vecs: Vec<Vec<f64>>,
        motion_coeffs: &[f64],
        tribal_best: &[f64],
    ) {
        self.advance(momentum, &rand_vecs, motion_coeffs, &[tribal_best]);
    }

    /// Shared move-then-accelerate step behind `update_syn` / `update_sol`.
    ///
    /// Term 0 always pulls toward the personal best (`self.min`). Term `k + 1`
    /// pulls toward `attractors[k]`, using `rand_vecs[k + 1]` and
    /// `motion_coeffs[k + 1]`.
    fn advance(
        &mut self,
        momentum: f64,
        rand_vecs: &[Vec<f64>],
        motion_coeffs: &[f64],
        attractors: &[&[f64]],
    ) {
        let n = self.num_vars;
        // Disjoint field borrows. Slicing up front panics before anything is
        // mutated if `vel` is shorter than the position.
        let pos = &mut self.pos[..n];
        let vel = &mut self.vel[..n];
        let best = &self.min[..n];

        for (i, ((p, v), &m)) in pos.iter_mut().zip(vel.iter_mut()).zip(best).enumerate() {
            // Move first, so every pull below is measured from the new position.
            *p += *v;
            let here = *p;

            // Accumulate left to right so rounding matches
            // `momentum * v + t0 + t1 + ...` exactly.
            let mut next = momentum * *v + rand_vecs[0][i] * motion_coeffs[0] * (m - here);
            for (k, target) in attractors.iter().enumerate() {
                next += rand_vecs[k + 1][i] * motion_coeffs[k + 1] * (target[i] - here);
            }
            *v = next;
        }
    }

    /// Records `cost`, measured at the current position, as the new personal
    /// best if it is strictly lower than `min_cost`.
    fn record_if_better(&mut self, cost: f64) -> Option<f64> {
        if cost < self.min_cost {
            // Reuse `min`'s allocation rather than cloning a fresh Vec.
            self.min.clone_from(&self.pos);
            self.min_cost = cost;
            Some(cost)
        } else {
            None
        }
    }

    /// Clamps position and velocity to the given per-dimension bounds.
    ///
    /// `bound_state` is `(position_bounds, velocity_bounds, factor)`. Each
    /// bound is `[lower, upper]`.
    /// - Whenever a position component is clamped, the matching velocity
    ///   component is multiplied by `factor`. A negative factor reverses it;
    ///   zero stops it.
    /// - The velocity is then clamped to its bounds.
    ///
    /// Dimensions beyond the shorter of the vector and its bounds slice are
    /// left untouched.
    pub fn enforce_bounds(&mut self, bound_state: (&[[f64; 2]], &[[f64; 2]], f64)) {
        let (pos_bounds, vel_bounds, factor) = bound_state;

        for (i, (p, &[lo, hi])) in self.pos.iter_mut().zip(pos_bounds).enumerate() {
            if *p < lo {
                *p = lo;
                self.vel[i] *= factor;
            } else if *p > hi {
                *p = hi;
                self.vel[i] *= factor;
            }
        }

        // Explicit comparisons rather than `f64::clamp`, which would panic on
        // inverted or NaN bounds.
        for (v, &[lo, hi]) in self.vel.iter_mut().zip(vel_bounds) {
            if *v < lo {
                *v = lo;
            } else if *v > hi {
                *v = hi;
            }
        }
    }

    /// Returns an `OwnedRecord` built from a copy of the personal-best
    /// position and its cost.
    pub fn get_record(&self) -> OwnedRecord {
        OwnedRecord::build(self.min.clone(), self.min_cost)
    }
}