hanzo-quant 0.6.1

Hanzo Engine - fast, flexible LLM inference engine written in Rust.
Documentation
#include <metal_stdlib>

using namespace metal;

// Fused: out = (rms_norm(x, weight, eps) + residual) [* scale if provided]
// One threadgroup per row; threads cooperate on the sum-of-squares reduction.
template <typename T>
void rmsnorm_residual_impl(device const T *x, device const T *residual,
                           device const T *weight, device const T *scale,
                           device T *out, constant uint &n_cols,
                           constant float &eps, constant uint &has_scale,
                           uint tid [[thread_index_in_threadgroup]],
                           uint row [[threadgroup_position_in_grid]],
                           uint tg_size [[threads_per_threadgroup]],
                           threadgroup float *shared) {
  const uint row_off = row * n_cols;

  float local_sum = 0.0f;
  for (uint i = tid; i < n_cols; i += tg_size) {
    float v = float(x[row_off + i]);
    local_sum += v * v;
  }
  shared[tid] = local_sum;
  threadgroup_barrier(mem_flags::mem_threadgroup);

  for (uint s = tg_size / 2; s > 0; s >>= 1) {
    if (tid < s) {
      shared[tid] += shared[tid + s];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  const float inv_rms = rsqrt(shared[0] / float(n_cols) + eps);
  const float scale_v = has_scale ? float(scale[0]) : 1.0f;

  for (uint i = tid; i < n_cols; i += tg_size) {
    float normed = float(x[row_off + i]) * inv_rms * float(weight[i]);
    float v = (normed + float(residual[row_off + i])) * scale_v;
    out[row_off + i] = T(v);
  }
}

#define INSTANTIATE_RMSNORM_RESIDUAL(NAME, T)                                  \
  kernel void NAME(                                                            \
      device const T *x [[buffer(0)]], device const T *residual [[buffer(1)]], \
      device const T *weight [[buffer(2)]],                                    \
      device const T *scale [[buffer(3)]], device T *out [[buffer(4)]],        \
      constant uint &n_cols [[buffer(5)]], constant float &eps [[buffer(6)]],  \
      constant uint &has_scale [[buffer(7)]],                                  \
      uint tid [[thread_index_in_threadgroup]],                                \
      uint row [[threadgroup_position_in_grid]],                               \
      uint tg_size [[threads_per_threadgroup]],                                \
      threadgroup float *shared [[threadgroup(0)]]) {                          \
    rmsnorm_residual_impl<T>(x, residual, weight, scale, out, n_cols, eps,     \
                             has_scale, tid, row, tg_size, shared);            \
  }

INSTANTIATE_RMSNORM_RESIDUAL(rmsnorm_residual_f32, float)
INSTANTIATE_RMSNORM_RESIDUAL(rmsnorm_residual_f16, half)
INSTANTIATE_RMSNORM_RESIDUAL(rmsnorm_residual_bf16, bfloat)