mlx-native 0.6.2

Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
Documentation
/// Compute g and beta from alpha_logit, beta_logit, dt_bias, ssm_a.
///
/// g[t, vh]    = softplus(alpha_logit[t, vh] + dt_bias[vh]) * (-ssm_a[vh])
/// beta[t, vh] = sigmoid(beta_logit[t, vh])
///
/// All buffers are f32. Shapes: alpha_logit/beta_logit [seq, nv], g/beta [seq, nv],
/// dt_bias/ssm_a [nv].
///
/// Grid: dispatch `seq * nv` threads. Each thread handles one (t, vh) element.

#include <metal_stdlib>
using namespace metal;

kernel void compute_g_beta_f32(
    device const float *alpha_logit [[buffer(0)]],  // [seq, nv]
    device const float *beta_logit  [[buffer(1)]],  // [seq, nv]
    device const float *dt_bias     [[buffer(2)]],  // [nv]
    device const float *ssm_a       [[buffer(3)]],  // [nv] — stores -exp(A_log)
    device       float *g_out       [[buffer(4)]],  // [seq, nv]
    device       float *beta_out    [[buffer(5)]],  // [seq, nv]
    device const uint  *params      [[buffer(6)]],  // [seq, nv] packed as two u32
    uint tid [[thread_position_in_grid]]
) {
    const uint nv  = params[0];
    const uint seq = params[1];
    const uint n   = seq * nv;
    if (tid >= n) return;

    const uint t  = tid / nv;
    const uint vh = tid % nv;

    // g = softplus(alpha_logit + dt_bias) * (-ssm_a)
    // softplus(x) = log(1 + exp(x))
    // ssm_a stores -exp(A_log), so -ssm_a = exp(A_log) > 0.
    // Use precise:: math to match CPU f32 results (avoids fast-math ULP drift).
    const float a_logit = alpha_logit[t * nv + vh] + dt_bias[vh];
    float sp;
    if (a_logit > 20.0f) {
        sp = a_logit;
    } else if (a_logit < -20.0f) {
        sp = 0.0f;
    } else {
        sp = precise::log(1.0f + precise::exp(a_logit));
    }
    g_out[tid] = sp * (-ssm_a[vh]);

    // beta = sigmoid(beta_logit)
    beta_out[tid] = 1.0f / (1.0f + precise::exp(-beta_logit[t * nv + vh]));
}