// mistralrs-quant 0.8.1
//
// Fast, flexible LLM inference.
// Documentation
#include <metal_stdlib>
using namespace metal;

// Fused softmax with sinks kernel for GPT-OSS attention.
//
// Each threadgroup processes one row (batch, head, query) of length k_len.
// The sink value for the head is included in the softmax denominator
// but NOT written to the output.
//
// Input:
//   logits: [batch * heads * q_len, k_len] (contiguous)
//   sinks:  [heads] - per-head sink values
// Output:
//   output: [batch * heads * q_len, k_len] - softmax probabilities; the
//   sink's share of the probability mass is not written out.

// Fused softmax-with-sinks for one attention row per threadgroup.
//
// Launch contract (as required by the code below):
//   - grid: one threadgroup per row; the row index is the threadgroup
//     position, and the head is recovered as (row / q_len) % num_heads.
//   - threadgroup(0) memory: at least (2 + ceil(ntg / 32)) floats, i.e. one
//     slot for the broadcast max, one for the broadcast sum, and one scratch
//     slot per simdgroup.
//   - the cross-simdgroup stage is performed by simdgroup 0 alone, so the
//     code assumes 32-wide simdgroups and at most 32 simdgroups per
//     threadgroup (ntg <= 1024).
// All arithmetic is done in float regardless of T; results are cast back to T.
template <typename T>
[[kernel]] void softmax_with_sinks(
    const device T *logits [[buffer(0)]], const device T *sinks [[buffer(1)]],
    device T *output [[buffer(2)]], constant uint &num_heads [[buffer(3)]],
    constant uint &q_len [[buffer(4)]], constant uint &k_len [[buffer(5)]],
    threadgroup float *shared_mem [[threadgroup(0)]],
    uint tgpig [[threadgroup_position_in_grid]],
    uint tpitg [[thread_position_in_threadgroup]],
    uint sgitg [[simdgroup_index_in_threadgroup]],
    uint tiisg [[thread_index_in_simdgroup]],
    uint ntg [[threads_per_threadgroup]]) {

  // Each threadgroup handles one row: (batch, head, query).
  const uint row_idx = tgpig;
  const uint h = (row_idx / q_len) % num_heads;

  // Row offset computed in 64-bit: the total element count
  // (batch * heads * q_len * k_len) can exceed 2^32 for long contexts, and a
  // 32-bit product would silently wrap to the wrong row.
  const ulong row_offset = ulong(row_idx) * ulong(k_len);
  const device T *row_logits = logits + row_offset;
  device T *row_output = output + row_offset;

  // Per-head sink logit: participates in max/sum but is never written out.
  const float sink_val = float(sinks[h]);

  // Threadgroup memory layout: [0] = row max, [1] = row sum,
  // [2 .. 2 + num_simdgroups) = per-simdgroup partials.
  threadgroup float *s_max = shared_mem;
  threadgroup float *s_sum = shared_mem + 1;
  threadgroup float *warp_scratch = shared_mem + 2;

  const uint tid = tpitg;
  constexpr uint SIMD_SIZE = 32;
  const uint num_simdgroups = (ntg + SIMD_SIZE - 1) / SIMD_SIZE;

  // ---- Step 1: row max (logits plus sink) ----
  float local_max = -INFINITY;
  for (uint k = tid; k < k_len; k += ntg) {
    local_max = max(local_max, float(row_logits[k]));
  }
  // Exactly one thread folds the sink into the max.
  if (tid == 0) {
    local_max = max(local_max, sink_val);
  }

  // simd_max only combines *active* lanes. A fixed xor-shuffle ladder would
  // read undefined values from inactive lanes whenever ntg is not a multiple
  // of the simdgroup width (or is smaller than it).
  local_max = simd_max(local_max);
  if (tiisg == 0) {
    warp_scratch[sgitg] = local_max;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Simdgroup 0 reduces the per-simdgroup partials; unused lanes contribute
  // the identity for max.
  if (sgitg == 0) {
    float partial = (tiisg < num_simdgroups) ? warp_scratch[tiisg] : -INFINITY;
    partial = simd_max(partial);
    if (tiisg == 0) {
      *s_max = partial;
    }
  }
  // Also guarantees every read of warp_scratch above completes before
  // step 2 overwrites it.
  threadgroup_barrier(mem_flags::mem_threadgroup);
  const float row_max = *s_max;

  // ---- Step 2: sum of exp(x - max) (logits plus sink) ----
  float local_sum = 0.0f;
  for (uint k = tid; k < k_len; k += ntg) {
    local_sum += exp(float(row_logits[k]) - row_max);
  }
  // Exactly one thread folds the sink into the denominator.
  if (tid == 0) {
    local_sum += exp(sink_val - row_max);
  }

  local_sum = simd_sum(local_sum);
  if (tiisg == 0) {
    warp_scratch[sgitg] = local_sum;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Simdgroup 0 reduces the per-simdgroup partials; unused lanes contribute
  // the identity for addition.
  if (sgitg == 0) {
    float partial = (tiisg < num_simdgroups) ? warp_scratch[tiisg] : 0.0f;
    partial = simd_sum(partial);
    if (tiisg == 0) {
      *s_sum = partial;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  const float row_sum = *s_sum;

  // ---- Step 3: write normalized probabilities (sink is NOT written) ----
  const float inv_sum = 1.0f / row_sum;
  for (uint k = tid; k < k_len; k += ntg) {
    const float val = float(row_logits[k]);
    row_output[k] = T(exp(val - row_max) * inv_sum);
  }
}

// Explicitly instantiates softmax_with_sinks<type> and exports it to the host
// under the function name "softmax_with_sinks_<type>" (e.g.
// "softmax_with_sinks_half"). The parameter list repeated here must match the
// template's signature exactly, including every [[buffer]]/[[threadgroup]]
// binding index and built-in attribute.
#define instantiate_softmax_with_sinks(type)                                   \
  template [[host_name("softmax_with_sinks_" #type)]] [[kernel]] void          \
  softmax_with_sinks<type>(const device type *logits [[buffer(0)]],            \
                           const device type *sinks [[buffer(1)]],             \
                           device type *output [[buffer(2)]],                  \
                           constant uint &num_heads [[buffer(3)]],             \
                           constant uint &q_len [[buffer(4)]],                 \
                           constant uint &k_len [[buffer(5)]],                 \
                           threadgroup float *shared_mem [[threadgroup(0)]],   \
                           uint tgpig [[threadgroup_position_in_grid]],        \
                           uint tpitg [[thread_position_in_threadgroup]],      \
                           uint sgitg [[simdgroup_index_in_threadgroup]],      \
                           uint tiisg [[thread_index_in_simdgroup]],           \
                           uint ntg [[threads_per_threadgroup]]);

// float and half variants are always built.
instantiate_softmax_with_sinks(float);
instantiate_softmax_with_sinks(half);
// The bfloat variant is only built when compiling with
// __METAL_VERSION__ >= 310.
#if __METAL_VERSION__ >= 310
instantiate_softmax_with_sinks(bfloat);
#endif