// rumus 0.2.0
//
// A native-Rust deep learning framework with explicit memory safety
// and hardware acceleration.
// Documentation
// Row-wise Softmax with Log-Sum-Exp stability.
//
// One workgroup per row.  Three phases:
//   1. Reduce → max
//   2. Each thread: exp(x - max), reduce → sum_exp
//   3. Each thread: output = exp(x - max) / sum_exp
//
// Reuses unary_layout: input(read) + output(rw) + uniform.

// Uniform parameters for the softmax kernel. Field order and padding
// must match the host-side (Rust) struct byte-for-byte.
struct SoftmaxParams {
    num_rows: u32,  // number of independent rows (one workgroup each)
    row_size: u32,  // elements per row (reduction length D)
    _pad0: u32,     // pad to 16 bytes for uniform-buffer layout
    _pad1: u32,
}
// 16 bytes ✓

// Bindings follow the shared unary_layout: input(read) + output(rw) + uniform.
// Slot numbers are part of the host-side pipeline-layout contract.
@group(0) @binding(0) var<storage, read>       sm_input:  array<f32>;
@group(0) @binding(1) var<storage, read_write> sm_output: array<f32>;
@group(0) @binding(2) var<uniform>             sm_params: SoftmaxParams;

// Workgroup-shared scratch for the tree reductions (max, then sum).
// Length must equal the workgroup size (64) used by the kernel below.
var<workgroup> shared_val: array<f32, 64>;

// Row-wise numerically stable softmax: out[i] = exp(x[i] - max) / sum_exp.
// One workgroup handles one row; each of the 64 threads strides over the
// row's elements, then the partial results are tree-reduced in shared_val.
@compute @workgroup_size(64)
fn softmax_forward_kernel(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wgid: vec3<u32>,
) {
    // Single source of truth for the thread count; must match
    // @workgroup_size above and the shared_val array length.
    const WG_SIZE: u32 = 64u;
    // Smallest finite f32 — identity element for the max reduction.
    const NEG_F32_MAX: f32 = -3.402823e+38;

    let row = wgid.x;
    // wgid is workgroup-uniform, so this early return keeps control flow
    // uniform and the barriers below remain legal.
    if (row >= sm_params.num_rows) { return; }
    let D = sm_params.row_size;
    let tid = lid.x;
    let base = row * D;

    // ---- Phase 1: row maximum (log-sum-exp stability shift) ----
    var local_max: f32 = NEG_F32_MAX;
    var j = tid;
    while (j < D) {
        local_max = max(local_max, sm_input[base + j]);
        j += WG_SIZE;
    }
    shared_val[tid] = local_max;
    workgroupBarrier();
    // Tree reduction: halve the active range each step.
    var s: u32 = WG_SIZE / 2u;
    while (s > 0u) {
        if (tid < s) { shared_val[tid] = max(shared_val[tid], shared_val[tid + s]); }
        workgroupBarrier();
        s = s >> 1u;
    }
    let max_val = shared_val[0];
    // Every thread must read shared_val[0] before phase 2 overwrites it.
    workgroupBarrier();

    // ---- Phase 2: sum of exp(x - max) ----
    var local_sum: f32 = 0.0;
    j = tid;
    while (j < D) {
        local_sum += exp(sm_input[base + j] - max_val);
        j += WG_SIZE;
    }
    shared_val[tid] = local_sum;
    workgroupBarrier();
    s = WG_SIZE / 2u;
    while (s > 0u) {
        if (tid < s) { shared_val[tid] += shared_val[tid + s]; }
        workgroupBarrier();
        s = s >> 1u;
    }
    // Safe to read without a further barrier: the loop's final iteration
    // (s == 1) already barriered after the last write to shared_val[0],
    // and shared memory is not written again in this kernel.
    // sum_exp >= 1 whenever D > 0 (the max element contributes exp(0) = 1),
    // so the reciprocal is always finite.
    let inv_sum = 1.0 / shared_val[0];

    // ---- Phase 3: normalize ----
    // Hoist the reciprocal: one divide per row instead of one per element.
    j = tid;
    while (j < D) {
        sm_output[base + j] = exp(sm_input[base + j] - max_val) * inv_sum;
        j += WG_SIZE;
    }
}