opt_solver 0.1.1

Common optimization algorithms
Documentation
use super::*;

#[inline]
pub fn min_par<
    const I8: usize,
    const I16: usize,
    const I32: usize,
    const I64: usize,
    const I128: usize,
    const I: usize,
    const F32: usize,
    const F64: usize,
    const C64: usize,
    const C128: usize,
    const D: usize,
    O: PartialOrd + Clone + Sync + Send,
>(
    pso: &PSO,
    f: impl Fn(X<I8, I16, I32, I64, I128, I, F32, F64, C64, C128, D>) -> O + Sync,
    x_min: X<I8, I16, I32, I64, I128, I, F32, F64, C64, C128, D>,
    x_max: X<I8, I16, I32, I64, I128, I, F32, F64, C64, C128, D>,
) -> (O, X<I8, I16, I32, I64, I128, I, F32, F64, C64, C128, D>) {
    let mut rng: ThreadRng = thread_rng();

    let x_bounds = XBounds::new(x_min.clone(), x_max.clone());

    let mut v_max = x_max.clone() - x_min.clone();
    v_max.mul_assign_f64(pso.max_v);
    let mut v_min = -v_max.clone();

    struct ParState<
        const I8: usize,
        const I16: usize,
        const I32: usize,
        const I64: usize,
        const I128: usize,
        const I: usize,
        const F32: usize,
        const F64: usize,
        const C64: usize,
        const C128: usize,
        const D: usize,
        O: PartialOrd + Clone,
    >(
        pub X<I8, I16, I32, I64, I128, I, F32, F64, C64, C128, D>,
        pub X<I8, I16, I32, I64, I128, I, F32, F64, C64, C128, D>,
        pub X<I8, I16, I32, I64, I128, I, F32, F64, C64, C128, D>,
        pub O,
    );
    unsafe impl<
            const I8: usize,
            const I16: usize,
            const I32: usize,
            const I64: usize,
            const I128: usize,
            const I: usize,
            const F32: usize,
            const F64: usize,
            const C64: usize,
            const C128: usize,
            const D: usize,
            O: PartialOrd + Clone,
        > Send for ParState<I8, I16, I32, I64, I128, I, F32, F64, C64, C128, D, O>
    {
    }

    unsafe impl<
            const I8: usize,
            const I16: usize,
            const I32: usize,
            const I64: usize,
            const I128: usize,
            const I: usize,
            const F32: usize,
            const F64: usize,
            const C64: usize,
            const C128: usize,
            const D: usize,
            O: PartialOrd + Clone,
        > Sync for ParState<I8, I16, I32, I64, I128, I, F32, F64, C64, C128, D, O>
    {
    }
    let mut status: Vec<ParState<I8, I16, I32, I64, I128, I, F32, F64, C64, C128, D, O>> = x_bounds
        .sample_iter(&mut rng)
        .take(pso.n_particles)
        .map(|x| {
            ParState(
                x.clone(),
                X::<I8, I16, I32, I64, I128, I, F32, F64, C64, C128, D>::zero(),
                x.clone(),
                f(x),
            )
        })
        .collect();

    let (mut x_best, mut out_best) = {
        let best = status
            .par_iter()
            .min_by(|&a, &b| a.3.partial_cmp(&b.3).expect("failed to compare"))
            .expect("n_particles is 0!");
        (best.2.clone(), best.3.clone())
    };

    for _ in 0..pso.n_iters {
        (x_best, out_best) = status
            .par_iter_mut()
            .fold(
                || (x_best.clone(), out_best.clone()),
                |(mut x_best, mut out_best), ParState(xt, vt, x_best_pso, out_best_pso)| {
                    *vt = (x_best.clone() - xt.clone()).mul_f64(pso.c1)
                        + (x_best_pso.clone() - xt.clone()).mul_f64(pso.c2);

                    *xt += vt.clone();

                    macro_rules! if_out_off_bounds {
                        ($t:tt) => {
                            xt.$t
                                .iter_mut()
                                .zip(vt.$t.iter_mut())
                                .zip(x_min.$t.into_iter())
                                .zip(x_max.$t.into_iter())
                                .zip(v_min.$t.into_iter())
                                .zip(v_max.$t.into_iter())
                                .for_each(
                                    |(((((xt_i, vt_i), x_min_i), x_max_i), v_min_i), v_max_i)| {
                                        if *xt_i < x_min_i {
                                            *xt_i = x_min_i;
                                            *vt_i = 0;
                                        } else if *xt_i > x_max_i {
                                            *xt_i = x_max_i;
                                            *vt_i = 0;
                                        } else if *vt_i < v_min_i {
                                            *vt_i = v_min_i;
                                        } else if *xt_i > v_max_i {
                                            *vt_i = v_max_i;
                                        }
                                    },
                                )
                        };
                        ($t:tt,$zero:expr) => {
                            xt.$t
                                .iter_mut()
                                .zip(vt.$t.iter_mut())
                                .zip(x_min.$t.into_iter())
                                .zip(x_max.$t.into_iter())
                                .zip(v_min.$t.into_iter())
                                .zip(v_max.$t.into_iter())
                                .for_each(
                                    |(((((xt_i, vt_i), x_min_i), x_max_i), v_min_i), v_max_i)| {
                                        if *xt_i < x_min_i {
                                            *xt_i = x_min_i;
                                            *vt_i = $zero;
                                        } else if *xt_i > x_max_i {
                                            *xt_i = x_max_i;
                                            *vt_i = $zero;
                                        } else if *vt_i < v_min_i {
                                            *vt_i = v_min_i;
                                        } else if *xt_i > v_max_i {
                                            *vt_i = v_max_i;
                                        }
                                    },
                                )
                        };
                    }
                    if_out_off_bounds!(i8);
                    if_out_off_bounds!(i16);
                    if_out_off_bounds!(i32);
                    if_out_off_bounds!(i64);
                    if_out_off_bounds!(i128);
                    xt.i.iter_mut()
                        .zip(vt.i.iter_mut())
                        .zip(x_min.i.iter())
                        .zip(x_max.i.iter())
                        .zip(v_min.i.iter())
                        .zip(v_max.i.iter())
                        .for_each(|(((((xt_i, vt_i), x_min_i), x_max_i), v_min_i), v_max_i)| {
                            if *xt_i < *x_min_i {
                                *xt_i = x_min_i.clone();
                                *vt_i = BigInt::from(0);
                            } else if *xt_i > *x_max_i {
                                *xt_i = x_max_i.clone();
                                *vt_i = BigInt::from(0);
                            } else if *vt_i < *v_min_i {
                                *vt_i = v_max_i.clone();
                            } else if *xt_i > *v_max_i {
                                *vt_i = v_max_i.clone();
                            }
                        });
                    if_out_off_bounds!(f32, 0.0);
                    if_out_off_bounds!(f64, 0.0);
                    if_out_off_bounds!(d, Decimal::from(0));
                    let out_t = f(xt.clone());
                    if out_t < *out_best_pso {
                        if out_t < out_best {
                            *out_best_pso = out_t.clone();
                            *x_best_pso = xt.clone();
                            out_best = out_t;
                            x_best = xt.clone();
                        } else {
                            *out_best_pso = out_t;
                            *x_best_pso = xt.clone();
                        }
                    }
                    (x_best, out_best)
                },
            )
            .min_by(|a, b| a.1.partial_cmp(&b.1).expect("failed to compare"))
            .expect("n_particles is 0!");
        status
            .par_iter_mut()
            .for_each(|state| state.1.mul_assign_f64(pso.inertia))
    }

    (out_best, x_best)
}