diffusionx 0.12.0

A multi-threaded crate for random number generation and stochastic process simulation, with optional GPU acceleration.
/**
 * Metal Kernel for Random Number Generation
 *
 * This file implements various Metal compute kernels for generating
 * random numbers from different distributions.
 */

#include <metal_stdlib>
#include "utils.metal"
using namespace metal;

/**
 * @brief Fill a buffer with exponential random numbers
 *
 * Every thread whose grid position lies below `len` builds its own Philox
 * state from (`seed`, thread index), draws a single uniform variate and
 * stores the negated natural logarithm of that draw.
 *
 * @param out  Destination buffer; thread `idx` writes element `idx`.
 * @param len  Number of elements to generate.
 * @param seed Seed forwarded to `philox_init`.
 * @param idx  Position of this thread in the grid.
 */
kernel void randexp(device float* out [[buffer(0)]],
                    constant uint& len [[buffer(1)]],
                    constant ulong& seed [[buffer(2)]],
                    uint idx [[thread_position_in_grid]]) {
    // Threads positioned past the end of the buffer have nothing to write.
    if (idx >= len) {
        return;
    }

    PhiloxState rng = philox_init(seed, idx);
    const float uniform_draw = philox_uniform(rng);
    // Inverse-transform step: -log of the uniform draw.
    out[idx] = -log(uniform_draw);
}

/**
 * @brief Sample from standard stable distribution when α = 1
 *
 * Chambers-Mallows-Stuck formula for the α = 1 case:
 *
 *   X = (2/π) · [ (π/2 + βV)·tan(V) − β·log( (π/2 · W · cos V) / (π/2 + βV) ) ]
 *
 * where V is built from one uniform draw mapped onto an interval of width π
 * centred at 0, and W = -log(U) is built from a second uniform draw.
 *
 * @param alpha Stability index. Not used by this branch (α is treated as 1);
 *              kept so the signature stays unchanged for callers.
 * @param beta  Skewness parameter.
 * @param state Per-thread Philox state; advanced by two uniform draws.
 * @return One sample computed by the formula above.
 */
inline float sample_standard_alpha_one(float alpha, float beta, thread PhiloxState& state) {
    // Draw order matters for reproducibility: V first, then W.
    float v = philox_uniform(state) * M_PI_F - M_PI_2_F;
    float w = -log(philox_uniform(state));
    float half_pi_plus_beta_v = M_PI_2_F + beta * v;
    float c1 = half_pi_plus_beta_v * tan(v);
    // The logarithm applies to the whole ratio (π/2·W·cos V)/(π/2 + βV);
    // dividing by log(π/2 + βV) instead yields a different distribution.
    float c2 = beta * log((M_PI_2_F * w * cos(v)) / half_pi_plus_beta_v);
    // M_2_PI_F is the constant 2/π.
    return (c1 - c2) * M_2_PI_F;
}

/**
 * @brief Sample from standard stable distribution when α ≠ 1
 *
 * Chambers-Mallows-Stuck formula for the α ≠ 1 case:
 *
 *   X = S · sin(α(V + B)) / cos(V)^(1/α) · [ cos(V − α(V + B)) / W ]^((1−α)/α)
 *
 * where V is built from one uniform draw mapped onto an interval of width π
 * centred at 0, and W = -log(U) is built from a second uniform draw.
 *
 * @param alpha                     Stability index.
 * @param inv_alpha                 Exponent applied to cos(V); presumably the
 *                                  precomputed 1/α — confirm against the caller.
 * @param one_minus_alpha_div_alpha Exponent applied to the bracketed ratio;
 *                                  presumably the precomputed (1−α)/α — confirm
 *                                  against the caller.
 * @param b                         Shift added to V; presumably the CMS constant
 *                                  B(α, β) — confirm against the caller.
 * @param s                         Multiplicative scale; presumably the CMS
 *                                  constant S(α, β) — confirm against the caller.
 * @param state                     Per-thread Philox state; advanced by two
 *                                  uniform draws.
 * @return One sample computed by the formula above.
 */
inline float sample_standard_alpha_with_constants(
    float alpha, float inv_alpha, float one_minus_alpha_div_alpha,
    float b, float s, thread PhiloxState& state) {
    // Draw order matters for reproducibility: V first, then W.
    float v = philox_uniform(state) * M_PI_F - M_PI_2_F;
    float w = -log(philox_uniform(state));
    // α(V + B) appears in both factors, so compute it once.
    float alpha_v_plus_b = alpha * (v + b);
    float cos_v = cos(v);
    // α scales the argument of the sine, not its value: sin(α(V + B)).
    float c1 = sin(alpha_v_plus_b) / pow(cos_v, inv_alpha);
    float c2 = pow(cos(v - alpha_v_plus_b) / w, one_minus_alpha_div_alpha);
    return s * c1 * c2;
}

/**
 * @brief Generate standard stable random numbers
 *
 * Every thread whose grid position lies below `len` builds its own Philox
 * state from (`seed`, thread index) and writes one sample. When `alpha` is
 * within 1e-3 of 1 the α = 1 sampler is used; otherwise the α ≠ 1 sampler is
 * used with the constants supplied by the caller.
 *
 * @param out                       Destination buffer; thread `idx` writes
 *                                  element `idx`.
 * @param alpha                     Stability index.
 * @param beta                      Skewness parameter (read only by the α = 1
 *                                  sampler).
 * @param inv_alpha                 Forwarded to the α ≠ 1 sampler.
 * @param one_minus_alpha_div_alpha Forwarded to the α ≠ 1 sampler.
 * @param b                         Forwarded to the α ≠ 1 sampler.
 * @param s                         Forwarded to the α ≠ 1 sampler.
 * @param len                       Number of elements to generate.
 * @param seed                      Seed forwarded to `philox_init`.
 * @param idx                       Position of this thread in the grid.
 */
kernel void standard_stable_rand(device float* out [[buffer(0)]],
                                 constant float& alpha [[buffer(1)]],
                                 constant float& beta [[buffer(2)]],
                                 constant float& inv_alpha [[buffer(3)]],
                                 constant float& one_minus_alpha_div_alpha [[buffer(4)]],
                                 constant float& b [[buffer(5)]],
                                 constant float& s [[buffer(6)]],
                                 constant uint& len [[buffer(7)]],
                                 constant ulong& seed [[buffer(8)]],
                                 uint idx [[thread_position_in_grid]]) {
    // Threads positioned past the end of the buffer have nothing to write.
    if (idx >= len) {
        return;
    }

    PhiloxState rng = philox_init(seed, idx);

    // alpha is the same for every thread, so this choice is uniform across the grid.
    const bool alpha_is_near_one = abs(alpha - 1.0f) < 1e-3f;

    out[idx] = alpha_is_near_one
        ? sample_standard_alpha_one(alpha, beta, rng)
        : sample_standard_alpha_with_constants(
              alpha, inv_alpha, one_minus_alpha_div_alpha, b, s, rng);
}

/**
 * @brief Generate uniform random numbers in (0, 1]
 *
 * Every thread whose grid position lies below `len` builds its own Philox
 * state from (`seed`, thread index) and stores one `philox_uniform` draw.
 * NOTE(review): the (0, 1] range is whatever `philox_uniform` in utils.metal
 * produces — confirm there.
 *
 * @param out  Destination buffer; thread `idx` writes element `idx`.
 * @param len  Number of elements to generate.
 * @param seed Seed forwarded to `philox_init`.
 * @param idx  Position of this thread in the grid.
 */
kernel void randuniform(device float* out [[buffer(0)]],
                        constant uint& len [[buffer(1)]],
                        constant ulong& seed [[buffer(2)]],
                        uint idx [[thread_position_in_grid]]) {
    // Threads positioned past the end of the buffer have nothing to write.
    if (idx >= len) {
        return;
    }

    PhiloxState rng = philox_init(seed, idx);
    const float uniform_draw = philox_uniform(rng);
    out[idx] = uniform_draw;
}

/**
 * @brief Generate normal random numbers
 *
 * Every thread whose grid position lies below `len` builds its own Philox
 * state from (`seed`, thread index), takes one `philox_normal` draw and
 * stores `mu + sigma * draw`.
 *
 * @param out   Destination buffer; thread `idx` writes element `idx`.
 * @param len   Number of elements to generate.
 * @param mu    Offset added to the scaled draw.
 * @param sigma Factor multiplying the draw.
 * @param seed  Seed forwarded to `philox_init`.
 * @param idx   Position of this thread in the grid.
 */
kernel void randnormal(device float* out [[buffer(0)]],
                       constant uint& len [[buffer(1)]],
                       constant float& mu [[buffer(2)]],
                       constant float& sigma [[buffer(3)]],
                       constant ulong& seed [[buffer(4)]],
                       uint idx [[thread_position_in_grid]]) {
    // Threads positioned past the end of the buffer have nothing to write.
    if (idx >= len) {
        return;
    }

    PhiloxState rng = philox_init(seed, idx);
    const float normal_draw = philox_normal(rng);
    // Shift and scale the draw.
    out[idx] = mu + sigma * normal_draw;
}