diffusionx 0.12.0

A multi-threaded crate for random number generation and stochastic process simulation, with optional GPU acceleration.
use crate::{
    FloatExt, XResult,
    gpu::{CUDA_CTX, GPUMoment, OU_PTX},
    simulation::continuous::OrnsteinUhlenbeck as OU,
};
use cudarc::{
    driver::{CudaFunction, CudaModule},
    nvrtc::Ptx,
};
use std::sync::{Arc, LazyLock};

/// PTX module containing the Ornstein-Uhlenbeck GPU kernels (`OU_PTX`), loaded into the shared
/// CUDA context on first access.
///
/// `LazyLock` caches the outcome, so if the context or the module load fails, every later
/// access observes the same error.
static MODULE: LazyLock<XResult<Arc<CudaModule>>> = LazyLock::new(|| {
    let context = CUDA_CTX.as_ref()?;
    let ptx = Ptx::from(OU_PTX);
    Ok(context.load_module(ptx)?)
});

/// Looks up the kernel called `name` in the shared OU [`MODULE`].
///
/// # Errors
///
/// Returns an error if [`MODULE`] failed to initialise (the cached module error, converted via
/// `?`). It also returns an error if `load_function` fails, for example when the module has no
/// kernel with that name.
fn load_kernel(name: &str) -> XResult<CudaFunction> {
    let module = MODULE.as_ref()?;
    Ok(module.load_function(name)?)
}

// Per-kernel handles. Each one is resolved lazily on first access, and `LazyLock` caches the
// result (success or error). The string literals must match the kernel names exported by
// `OU_PTX`.

/// Handle to the `"mean"` kernel.
static MEAN_KERNEL: LazyLock<XResult<CudaFunction>> = LazyLock::new(|| load_kernel("mean"));

/// Handle to the `"msd"` kernel.
static MSD_KERNEL: LazyLock<XResult<CudaFunction>> = LazyLock::new(|| load_kernel("msd"));

/// Handle to the `"raw_moment"` kernel (integer order).
static RAW_MOMENT_KERNEL: LazyLock<XResult<CudaFunction>> =
    LazyLock::new(|| load_kernel("raw_moment"));

/// Handle to the `"frac_raw_moment"` kernel (fractional order).
static FRAC_RAW_KERNEL: LazyLock<XResult<CudaFunction>> =
    LazyLock::new(|| load_kernel("frac_raw_moment"));

/// Handle to the `"central_moment"` kernel (integer order).
static CENTRAL_MOMENT_KERNEL: LazyLock<XResult<CudaFunction>> =
    LazyLock::new(|| load_kernel("central_moment"));

/// Handle to the `"frac_central_moment"` kernel (fractional order).
static FRAC_CENTRAL_KERNEL: LazyLock<XResult<CudaFunction>> =
    LazyLock::new(|| load_kernel("frac_central_moment"));

// The macro invocations below are defined elsewhere in the crate. Judging from how the
// `GPUMoment` impl in this file calls the results, each invocation generates a free function
// named by its second argument. That function takes the parenthesised parameters in the
// listed order, followed by `particles: usize`, and returns `XResult<f32>`.
// NOTE(review): presumably the generated function launches the given kernel static
// (loaded from `MODULE`) and reduces over `particles` trajectories; confirm against the
// macro definitions.

// `mean(start_position, theta, sigma, duration, time_step, particles)`, backed by `MEAN_KERNEL`.
subscribe_gpu_function!(MODULE, mean, MEAN_KERNEL, (start_position: f32, theta: f32, sigma: f32, duration: f32, time_step: f32));

// `msd(start_position, theta, sigma, duration, time_step, particles)`, backed by `MSD_KERNEL`.
subscribe_gpu_function!(MODULE, msd, MSD_KERNEL, (start_position: f32, theta: f32, sigma: f32, duration: f32, time_step: f32));

// `raw_moment(..., order: i32, duration, time_step, particles)`: here `order` sits before
// `duration`, unlike the central-moment variants below.
subscribe_gpu_function!(MODULE, raw_moment, RAW_MOMENT_KERNEL, (start_position: f32, theta: f32, sigma: f32, order: i32, duration: f32, time_step: f32));

// Fractional-order counterpart of `raw_moment` (`order: f32`, same argument position).
subscribe_gpu_function!(MODULE, frac_raw_moment, FRAC_RAW_KERNEL, (start_position: f32, theta: f32, sigma: f32, order: f32, duration: f32, time_step: f32));

// `subscribe_central_moment_gpu_function!` takes the order's type as its trailing argument.
// Per the call sites in the impl, the generated function takes the listed parameters, then
// `order` (of that type), then `particles: usize`.
subscribe_central_moment_gpu_function!(MODULE, central_moment, CENTRAL_MOMENT_KERNEL, (start_position: f32, theta: f32, sigma: f32, duration: f32, time_step: f32), i32);

// Fractional-order counterpart of `central_moment` (`order: f32`).
subscribe_central_moment_gpu_function!(MODULE, frac_central_moment, FRAC_CENTRAL_KERNEL, (start_position: f32, theta: f32, sigma: f32, duration: f32, time_step: f32), f32);

/// Converts the process parameters of `ou` to the `f32` values the GPU entry points expect.
///
/// Returns `(start_position, theta, sigma)`, in the order the generated GPU functions take them.
///
/// # Panics
///
/// Panics if a parameter cannot be represented as `f32` (`to_f32` returns `None`). For the
/// primitive float types this conversion does not fail, but a custom `FloatExt` type might.
fn ou_params_f32<T: FloatExt>(ou: &OU<T>) -> (f32, f32, f32) {
    let start_position = ou
        .get_start_position()
        .to_f32()
        .expect("OU start_position must be representable as f32");
    let theta = ou
        .get_theta()
        .to_f32()
        .expect("OU theta must be representable as f32");
    let sigma = ou
        .get_sigma()
        .to_f32()
        .expect("OU sigma must be representable as f32");
    (start_position, theta, sigma)
}

/// GPU-backed moment estimators for the Ornstein-Uhlenbeck process.
///
/// Every method converts the process parameters to `f32` (see `ou_params_f32`). It then
/// forwards them, together with the caller's arguments, to the matching function generated
/// by the `subscribe_*` macros in this file.
impl<T: FloatExt> GPUMoment for OU<T> {
    /// Integer-order central moment at time `duration`, estimated over `particles` paths.
    fn central_moment_gpu(
        &self,
        duration: f32,
        order: i32,
        particles: usize,
        time_step: f32,
    ) -> XResult<f32> {
        let (start_position, theta, sigma) = ou_params_f32(self);
        central_moment(
            start_position,
            theta,
            sigma,
            duration,
            time_step,
            order,
            particles,
        )
    }

    /// Integer-order raw moment at time `duration`, estimated over `particles` paths.
    fn raw_moment_gpu(
        &self,
        duration: f32,
        order: i32,
        particles: usize,
        time_step: f32,
    ) -> XResult<f32> {
        let (start_position, theta, sigma) = ou_params_f32(self);
        // `raw_moment` takes `order` before `duration`/`time_step` (see its macro invocation).
        raw_moment(
            start_position,
            theta,
            sigma,
            order,
            duration,
            time_step,
            particles,
        )
    }

    /// Mean at time `duration`, estimated over `particles` paths.
    fn mean_gpu(&self, duration: f32, particles: usize, time_step: f32) -> XResult<f32> {
        let (start_position, theta, sigma) = ou_params_f32(self);
        mean(start_position, theta, sigma, duration, time_step, particles)
    }

    /// Mean squared displacement at time `duration`, estimated over `particles` paths.
    fn msd_gpu(&self, duration: f32, particles: usize, time_step: f32) -> XResult<f32> {
        let (start_position, theta, sigma) = ou_params_f32(self);
        msd(start_position, theta, sigma, duration, time_step, particles)
    }

    /// Fractional-order central moment at time `duration`, estimated over `particles` paths.
    fn frac_central_moment_gpu(
        &self,
        duration: f32,
        order: f32,
        particles: usize,
        time_step: f32,
    ) -> XResult<f32> {
        let (start_position, theta, sigma) = ou_params_f32(self);
        frac_central_moment(
            start_position,
            theta,
            sigma,
            duration,
            time_step,
            order,
            particles,
        )
    }

    /// Fractional-order raw moment at time `duration`, estimated over `particles` paths.
    fn frac_raw_moment_gpu(
        &self,
        duration: f32,
        order: f32,
        particles: usize,
        time_step: f32,
    ) -> XResult<f32> {
        let (start_position, theta, sigma) = ou_params_f32(self);
        // Like `raw_moment`, the fractional raw moment takes `order` before `duration`.
        frac_raw_moment(
            start_position,
            theta,
            sigma,
            order,
            duration,
            time_step,
            particles,
        )
    }
}

#[cfg(test)]
mod tests {
    use crate::{gpu::GPUMoment, simulation::continuous::Bm};

    /// Smoke test: every `GPUMoment` entry point completes without returning an error.
    ///
    /// NOTE(review): this module sits in the Ornstein-Uhlenbeck GPU file but drives a `Bm`
    /// process, so the `GPUMoment for OU<T>` impl above is not exercised here.
    /// TODO: add the same checks for an `OrnsteinUhlenbeck` instance.
    #[test]
    fn test_gpu_moment() {
        let process = Bm::<f32>::default();
        let (duration, particles, time_step) = (1.0, 100, 0.1);

        process.mean_gpu(duration, particles, time_step).unwrap();
        process.msd_gpu(duration, particles, time_step).unwrap();
        process
            .raw_moment_gpu(duration, 2, particles, time_step)
            .unwrap();
        process
            .frac_raw_moment_gpu(duration, 1.4, particles, time_step)
            .unwrap();
        process
            .central_moment_gpu(duration, 2, particles, time_step)
            .unwrap();
        process
            .frac_central_moment_gpu(duration, 1.5, particles, time_step)
            .unwrap();
    }
}