mistralrs-core 0.8.1

Fast, flexible LLM inference.
Documentation
// Selective SSM (Mamba-2 style) Metal kernel for mistral.rs
// Ported from CUDA kernel in cuda/ssm.cu
//
// This kernel implements the selective scan operation:
//   state[t] = state[t-1] * exp(dt[t] * A) + dt[t] * B[t] * x[t]
//   y[t] = sum_d(state[t] * C[t]) + D * x[t]

#include <metal_stdlib>
using namespace metal;

#define SIMD_SIZE 32

static inline float simd_sum(float x) {
  // Use simd_shuffle to perform a warp-level reduction
  for (ushort offset = SIMD_SIZE / 2; offset > 0; offset >>= 1) {
    x += simd_shuffle_xor(x, offset);
  }
  return x;
}

// Mamba-2 SSM scan kernel (group-aware)
//
// Each simdgroup (32 threads) processes one (head, head_offset) pair across
// all timesteps. Lanes within the simdgroup process c_factor items each.
//
// Template parameter c_factor = items per lane = ceil(d_state / SIMD_SIZE)

template <int c_factor>
[[kernel]] void ssm_scan_kernel(
    const device float *x [[buffer(0)]], // (batch, seq_len, n_heads * head_dim)
    const device float *dt [[buffer(1)]], // (batch, seq_len, n_heads)
    const device float *A [[buffer(2)]],  // (n_heads,)
    const device float *B [[buffer(3)]],  // (batch, seq_len, n_heads * d_state)
    const device float *C [[buffer(4)]],  // (batch, seq_len, n_heads * d_state)
    const device float *D [[buffer(5)]],  // (n_heads,)
    const device float *dt_bias [[buffer(6)]], // (n_heads,)
    device float *state [[buffer(7)]], // (batch, n_heads, head_dim, d_state)
    device float *y [[buffer(8)]],     // (batch, seq_len, n_heads * head_dim)
    constant int &n_heads [[buffer(9)]], constant int &head_dim [[buffer(10)]],
    constant int &d_state [[buffer(11)]], constant int &seq_len [[buffer(12)]],
    constant float &dt_min [[buffer(13)]],
    constant float &dt_max [[buffer(14)]],
    uint2 tgpig [[threadgroup_position_in_grid]], // (warp_idx, batch_idx)
    uint lane [[thread_index_in_simdgroup]]) {
  const int batch_idx = tgpig.y;
  const int warp_idx = tgpig.x;

  const int head_idx = warp_idx / head_dim;
  const int head_off = warp_idx % head_dim;

  if (head_idx >= n_heads)
    return;

  const int d_inner = n_heads * head_dim;

  // Load A value for this head
  const float a_val = A[head_idx];

  // Load D value for this head
  const float d_val = D[head_idx];

  // Load dt_bias for this head
  const float dt_bias_val = dt_bias[head_idx];

  // Initialize state from input state tensor
  // state layout: (batch, n_heads, head_dim, d_state)
  const int state_base = batch_idx * n_heads * head_dim * d_state +
                         head_idx * head_dim * d_state + head_off * d_state;
  float reg_state[c_factor];
  for (int j = 0; j < c_factor; j++) {
    reg_state[j] = state[state_base + (int)lane * c_factor + j];
  }

  // Process each timestep
  for (int t = 0; t < seq_len; t++) {
    // Load dt and apply bias + softplus + clamp
    float dt_val = dt[batch_idx * seq_len * n_heads + t * n_heads + head_idx];
    dt_val += dt_bias_val;
    // softplus: log(1 + exp(x))
    dt_val = (dt_val > 20.0f) ? dt_val : log(1.0f + exp(dt_val));
    // clamp
    dt_val = min(max(dt_val, dt_min), dt_max);

    // dA = exp(dt * A)
    const float dA = exp(dt_val * a_val);

    // x value for this (head, head_offset) at timestep t
    const float x_val = x[batch_idx * seq_len * d_inner + t * d_inner +
                          head_idx * head_dim + head_off];
    const float x_dt = x_val * dt_val;

    // Compute state update and output dot product
    float state_sum = 0.0f;
    const int bc_base = batch_idx * seq_len * n_heads * d_state +
                        t * n_heads * d_state + head_idx * d_state;
    for (int j = 0; j < c_factor; j++) {
      const int state_off = (int)lane * c_factor + j;
      const float b_val = B[bc_base + state_off];
      const float c_val = C[bc_base + state_off];

      // state = state * dA + B * x_dt
      reg_state[j] = reg_state[j] * dA + b_val * x_dt;
      state_sum += reg_state[j] * c_val;
    }

    // Simd-reduce the output across lanes
    state_sum = simd_sum(state_sum);

    // Lane 0 writes the output (+ D skip connection)
    if (lane == 0) {
      y[batch_idx * seq_len * d_inner + t * d_inner + head_idx * head_dim +
        head_off] = state_sum + d_val * x_val;
    }
  }

  // Write final state back
  for (int j = 0; j < c_factor; j++) {
    state[state_base + (int)lane * c_factor + j] = reg_state[j];
  }
}

// Explicit instantiations for common d_state / SIMD_SIZE ratios
template [[host_name("ssm_scan_c1")]] [[kernel]]
void ssm_scan_kernel<1>(const device float *, const device float *,
                        const device float *, const device float *,
                        const device float *, const device float *,
                        const device float *, device float *, device float *,
                        constant int &, constant int &, constant int &,
                        constant int &, constant float &, constant float &,
                        uint2, uint);

template [[host_name("ssm_scan_c2")]] [[kernel]]
void ssm_scan_kernel<2>(const device float *, const device float *,
                        const device float *, const device float *,
                        const device float *, const device float *,
                        const device float *, device float *, device float *,
                        constant int &, constant int &, constant int &,
                        constant int &, constant float &, constant float &,
                        uint2, uint);

template [[host_name("ssm_scan_c4")]] [[kernel]]
void ssm_scan_kernel<4>(const device float *, const device float *,
                        const device float *, const device float *,
                        const device float *, const device float *,
                        const device float *, device float *, device float *,
                        constant int &, constant int &, constant int &,
                        constant int &, constant float &, constant float &,
                        uint2, uint);

template [[host_name("ssm_scan_c8")]] [[kernel]]
void ssm_scan_kernel<8>(const device float *, const device float *,
                        const device float *, const device float *,
                        const device float *, const device float *,
                        const device float *, device float *, device float *,
                        constant int &, constant int &, constant int &,
                        constant int &, constant float &, constant float &,
                        uint2, uint);