diffusionx 0.12.0

A multi-threaded crate for random number generation and stochastic process simulation, with optional GPU acceleration.
use crate::{
    FloatExt, XResult,
    gpu::{
        GPUMoment,
        metal::{MetalLibrary, OU_METALLIB, load_library},
    },
    simulation::continuous::OrnsteinUhlenbeck as OU,
};
use objc2::rc::Retained;
use std::sync::LazyLock;

/// Metal library handle for the Ornstein–Uhlenbeck kernels, created lazily.
///
/// The closure calls `load_library(OU_METALLIB)` the first time `LIBRARY` is
/// dereferenced. Because the stored value is the whole `XResult`, a loading
/// failure is kept as well and is what every later access observes.
static LIBRARY: LazyLock<XResult<Retained<MetalLibrary>>> =
    LazyLock::new(|| load_library(OU_METALLIB));

// The macro invocations below declare the module-level functions `mean`, `msd`,
// `raw_moment`, `frac_raw_moment`, `central_moment` and `frac_central_moment`
// that the `GPUMoment` impl in this file calls. Judging by those call sites,
// each generated function takes the listed parameters followed by a trailing
// particle count and returns `XResult<f32>`.
// NOTE(review): the string literal is presumably the kernel name looked up in
// `LIBRARY` — confirm against the macro definitions, which are not in this file.

// `mean(start_position, theta, sigma, duration, time_step, particles)`.
subscribe_metal_gpu_function!(LIBRARY, mean, "mean", (start_position: f32, theta: f32, sigma: f32, duration: f32, time_step: f32));

// `msd(start_position, theta, sigma, duration, time_step, particles)`.
subscribe_metal_gpu_function!(LIBRARY, msd, "msd", (start_position: f32, theta: f32, sigma: f32, duration: f32, time_step: f32));

// Integer-order raw moment; note that `order` sits between `sigma` and `duration`.
subscribe_metal_gpu_function!(LIBRARY, raw_moment, "raw_moment", (start_position: f32, theta: f32, sigma: f32, order: i32, duration: f32, time_step: f32));

// Fractional-order raw moment; same parameter layout as `raw_moment`, with an `f32` order.
subscribe_metal_gpu_function!(LIBRARY, frac_raw_moment, "frac_raw_moment", (start_position: f32, theta: f32, sigma: f32, order: f32, duration: f32, time_step: f32));

// Integer-order central moment. The trailing `i32` is presumably the type of the
// `order` argument, which the caller in this file passes after `time_step` and
// before the particle count — TODO confirm against the macro definition.
subscribe_metal_central_moment_gpu_function!(LIBRARY, central_moment, "central_moment", (start_position: f32, theta: f32, sigma: f32, duration: f32, time_step: f32), i32);

// Fractional-order central moment. The caller in this file passes `order` after
// `time_step` and before the particle count.
// NOTE(review): `before_order` / `after_order` look like the argument lists placed
// before and after `order` when the macro forwards to another function — verify
// against the macro definition.
subscribe_metal_frac_central_moment_gpu_function!(
    LIBRARY,
    frac_central_moment,
    "frac_central_moment",
    (
        start_position: f32,
        theta: f32,
        sigma: f32,
        duration: f32,
        time_step: f32
    ),
    before_order = (),
    after_order = (start_position, theta, sigma, duration, time_step)
);

/// Metal-backed implementation of [`GPUMoment`] for the Ornstein–Uhlenbeck process.
///
/// Every method reads the start position, `theta` and `sigma` from `self`,
/// converts them to `f32`, and forwards them together with the caller-supplied
/// arguments to the module-level function of the matching name.
///
/// # Panics
///
/// Each method panics if one of the process parameters cannot be converted to
/// `f32` (the conversion result is unwrapped).
impl<T: FloatExt> GPUMoment for OU<T> {
    /// Integer-order central moment; forwards to `central_moment`.
    fn central_moment_gpu(
        &self,
        duration: f32,
        order: i32,
        particles: usize,
        time_step: f32,
    ) -> XResult<f32> {
        let x0 = self.get_start_position().to_f32().unwrap();
        let theta = self.get_theta().to_f32().unwrap();
        let sigma = self.get_sigma().to_f32().unwrap();
        // `order` goes after `time_step` for the central-moment function.
        central_moment(x0, theta, sigma, duration, time_step, order, particles)
    }

    /// Integer-order raw moment; forwards to `raw_moment`.
    fn raw_moment_gpu(
        &self,
        duration: f32,
        order: i32,
        particles: usize,
        time_step: f32,
    ) -> XResult<f32> {
        let x0 = self.get_start_position().to_f32().unwrap();
        let theta = self.get_theta().to_f32().unwrap();
        let sigma = self.get_sigma().to_f32().unwrap();
        // `order` goes before `duration` for the raw-moment function.
        raw_moment(x0, theta, sigma, order, duration, time_step, particles)
    }

    /// Mean; forwards to `mean`.
    fn mean_gpu(&self, duration: f32, particles: usize, time_step: f32) -> XResult<f32> {
        let x0 = self.get_start_position().to_f32().unwrap();
        let theta = self.get_theta().to_f32().unwrap();
        let sigma = self.get_sigma().to_f32().unwrap();
        mean(x0, theta, sigma, duration, time_step, particles)
    }

    /// Mean squared displacement; forwards to `msd`.
    fn msd_gpu(&self, duration: f32, particles: usize, time_step: f32) -> XResult<f32> {
        let x0 = self.get_start_position().to_f32().unwrap();
        let theta = self.get_theta().to_f32().unwrap();
        let sigma = self.get_sigma().to_f32().unwrap();
        msd(x0, theta, sigma, duration, time_step, particles)
    }

    /// Fractional-order central moment; forwards to `frac_central_moment`.
    fn frac_central_moment_gpu(
        &self,
        duration: f32,
        order: f32,
        particles: usize,
        time_step: f32,
    ) -> XResult<f32> {
        let x0 = self.get_start_position().to_f32().unwrap();
        let theta = self.get_theta().to_f32().unwrap();
        let sigma = self.get_sigma().to_f32().unwrap();
        // `order` goes after `time_step` for the central-moment function.
        frac_central_moment(x0, theta, sigma, duration, time_step, order, particles)
    }

    /// Fractional-order raw moment; forwards to `frac_raw_moment`.
    fn frac_raw_moment_gpu(
        &self,
        duration: f32,
        order: f32,
        particles: usize,
        time_step: f32,
    ) -> XResult<f32> {
        let x0 = self.get_start_position().to_f32().unwrap();
        let theta = self.get_theta().to_f32().unwrap();
        let sigma = self.get_sigma().to_f32().unwrap();
        // `order` goes before `duration` for the raw-moment function.
        frac_raw_moment(x0, theta, sigma, order, duration, time_step, particles)
    }
}

#[cfg(test)]
mod tests {
    use crate::{gpu::GPUMoment, simulation::continuous::OrnsteinUhlenbeck};

    /// Smoke test: every GPU moment method must return `Ok` for a
    /// default-constructed process. The returned values are not inspected.
    #[test]
    fn test_gpu_moment() {
        const DURATION: f32 = 1.0;
        const PARTICLES: usize = 100;
        const TIME_STEP: f32 = 0.1;

        let process = OrnsteinUhlenbeck::<f32>::default();

        process.mean_gpu(DURATION, PARTICLES, TIME_STEP).unwrap();
        process.msd_gpu(DURATION, PARTICLES, TIME_STEP).unwrap();
        process
            .raw_moment_gpu(DURATION, 2, PARTICLES, TIME_STEP)
            .unwrap();
        process
            .frac_raw_moment_gpu(DURATION, 1.4, PARTICLES, TIME_STEP)
            .unwrap();
        process
            .central_moment_gpu(DURATION, 2, PARTICLES, TIME_STEP)
            .unwrap();
        process
            .frac_central_moment_gpu(DURATION, 1.5, PARTICLES, TIME_STEP)
            .unwrap();
    }
}