// mistralrs-quant 0.8.1
//
// Fast, flexible LLM inference.
// Documentation
#include "utils.metal"
#include <metal_stdlib>

using namespace metal;

// Decodes one 4-bit FP4 code into a float scaled by `absmax`.
//
// Bit 3 of `val` is the sign (set means negative); bits 2..0 pick one of
// eight fixed magnitudes. Any bits above the low nibble are ignored.
//
// Returns magnitude * absmax * sign, so code 0b1000 yields a negative zero
// for a positive `absmax`.
float dequantize_fp4_tree(unsigned char val, float absmax) {
  const float sign = (val & 0b1000) != 0 ? -1.0f : 1.0f;

  float magnitude;
  switch (val & 0b0111) {
  case 0b000:
    magnitude = 0.00000000f;
    break;
  case 0b001:
    magnitude = 5.208333333e-03f;
    break;
  case 0b010:
    magnitude = 0.66666667f;
    break;
  case 0b011:
    magnitude = 1.00000000f;
    break;
  case 0b100:
    magnitude = 0.33333333f;
    break;
  case 0b101:
    magnitude = 0.50000000f;
    break;
  case 0b110:
    magnitude = 0.16666667f;
    break;
  default: // 0b111
    magnitude = 0.25000000f;
    break;
  }

  // Same multiplication order as the tree form: (magnitude * absmax) * sign.
  return magnitude * absmax * sign;
}

// Maps one 4-bit NF4 code to its unscaled float level.
//
// Only the low nibble of `val` is examined. The sixteen levels increase
// monotonically from -1.0f (code 0) to 1.0f (code 15), with code 7 mapping
// to exactly 0.0f. The caller applies the block scale.
float dequantize_nf4(unsigned char val) {
  switch (val & 0x0F) {
  case 0x0:
    return -1.0f;
  case 0x1:
    return -0.6961928009986877f;
  case 0x2:
    return -0.5250730514526367f;
  case 0x3:
    return -0.39491748809814453f;
  case 0x4:
    return -0.28444138169288635f;
  case 0x5:
    return -0.18477343022823334f;
  case 0x6:
    return -0.09105003625154495f;
  case 0x7:
    return 0.0f;
  case 0x8:
    return 0.07958029955625534f;
  case 0x9:
    return 0.16093020141124725f;
  case 0xA:
    return 0.24611230194568634f;
  case 0xB:
    return 0.33791524171829224f;
  case 0xC:
    return 0.44070982933044434f;
  case 0xD:
    return 0.5626170039176941f;
  case 0xE:
    return 0.7229568362236023f;
  default: // 0xF
    return 1.0f;
  }
}

// Dequantizes packed NF4 data; every input byte carries two 4-bit codes.
//
// Thread `id` processes input bytes [id * blocksize, id * blocksize +
// blocksize), clipped so the range never passes `n`. For each byte `i` the
// high nibble is written to out[2 * i] and the low nibble to out[2 * i + 1],
// both multiplied by a single scale read from `absmax`.
//
// Parameters:
//   code      - not read by this kernel; the parameter list matches the other
//               dequantize kernels in this file.
//   input     - packed 4-bit codes, indexed by byte.
//   absmax    - per-block scale factors.
//   out       - destination, two elements per input byte.
//   blocksize - number of input bytes handled per thread.
//   n         - upper bound (exclusive) on the input byte index.
//   id        - thread position in the grid.
template <typename T>
[[kernel]] void kernel_dequantize_nf4(const device float *code [[buffer(0)]],
                                      const device uchar *input [[buffer(1)]],
                                      const device float *absmax [[buffer(2)]],
                                      device T *out [[buffer(3)]],
                                      device const int &blocksize,
                                      device const int &n,
                                      uint id [[thread_position_in_grid]]) {
  const int block_idx = id * blocksize;
  const int valid_items =
      (n > blocksize + block_idx) ? blocksize : (n - block_idx);
  const int block_end = block_idx + valid_items;

  // Threads whose range is empty (past the tail of the data) do nothing and,
  // in particular, must not touch `absmax`.
  if (block_end <= block_idx) {
    return;
  }

  // blocksize == 1 would make the divisor below zero; bail out instead of
  // indexing `absmax` with the result of an integer division by zero.
  const int half_block = blocksize / 2;
  if (half_block == 0) {
    return;
  }

  // The scale index depends only on block_idx, so load it once instead of on
  // every loop iteration (assumes `out` does not alias `absmax`).
  // NOTE(review): every byte in this thread's range shares this one scale;
  // confirm against the host-side launch that a thread's byte range never
  // spans more than one absmax entry.
  const float local_abs_max = absmax[block_idx / half_block];

  for (int i = block_idx; i < block_end; ++i) {
    const uchar packed = input[i];
    const float high_nibble = dequantize_nf4(packed >> 4);
    const float low_nibble = dequantize_nf4(packed & 0x0F);

    out[i * 2] = static_cast<T>(high_nibble * local_abs_max);
    out[i * 2 + 1] = static_cast<T>(low_nibble * local_abs_max);
  }
}

// Dequantizes packed FP4 data; every input byte carries two 4-bit codes.
//
// Thread `id` processes input bytes [id * blocksize, id * blocksize +
// blocksize), clipped so the range never passes `n`. For each byte `i` the
// high nibble is written to out[2 * i] and the low nibble to out[2 * i + 1].
// Scaling by the block's absmax happens inside dequantize_fp4_tree.
//
// Parameters:
//   code      - not read by this kernel; the parameter list matches the other
//               dequantize kernels in this file.
//   input     - packed 4-bit codes, indexed by byte.
//   absmax    - per-block scale factors.
//   out       - destination, two elements per input byte.
//   blocksize - number of input bytes handled per thread.
//   n         - upper bound (exclusive) on the input byte index.
//   id        - thread position in the grid.
template <typename T>
[[kernel]] void kernel_dequantize_fp4(const device float *code [[buffer(0)]],
                                      const device uchar *input [[buffer(1)]],
                                      const device float *absmax [[buffer(2)]],
                                      device T *out [[buffer(3)]],
                                      device const int &blocksize,
                                      device const int &n,
                                      uint id [[thread_position_in_grid]]) {
  const int block_idx = id * blocksize;
  const int valid_items =
      (n > blocksize + block_idx) ? blocksize : (n - block_idx);
  const int block_end = block_idx + valid_items;

  // Threads whose range is empty (past the tail of the data) do nothing and,
  // in particular, must not touch `absmax`.
  if (block_end <= block_idx) {
    return;
  }

  // blocksize == 1 would make the divisor below zero; bail out instead of
  // indexing `absmax` with the result of an integer division by zero.
  const int half_block = blocksize / 2;
  if (half_block == 0) {
    return;
  }

  // The scale index depends only on block_idx, so load it once instead of on
  // every loop iteration (assumes `out` does not alias `absmax`).
  // NOTE(review): every byte in this thread's range shares this one scale;
  // confirm against the host-side launch that a thread's byte range never
  // spans more than one absmax entry.
  const float local_abs_max = absmax[block_idx / half_block];

  for (int i = block_idx; i < block_end; ++i) {
    // Split the byte into its high and low 4-bit codes.
    const uchar packed = input[i];
    const float high_nibble = dequantize_fp4_tree(packed >> 4, local_abs_max);
    const float low_nibble = dequantize_fp4_tree(packed & 0x0F, local_abs_max);

    out[i * 2] = static_cast<T>(high_nibble);
    out[i * 2 + 1] = static_cast<T>(low_nibble);
  }
}

// Dequantizes 8-bit codes through a lookup table.
//
// Thread `id` processes input bytes [id * blocksize, id * blocksize +
// blocksize), clipped so the range never passes `n`. Each byte indexes `code`
// and the looked-up value is multiplied by one scale read from `absmax`.
//
// Parameters:
//   code      - lookup table indexed by the raw input byte.
//   input     - quantized bytes, one per output element.
//   absmax    - per-block scale factors.
//   out       - destination, one element per input byte.
//   blocksize - number of input bytes handled per thread.
//   n         - upper bound (exclusive) on the input byte index.
//   id        - thread position in the grid.
template <typename T>
[[kernel]] void kernel_dequantize_int8(const device float *code [[buffer(0)]],
                                       const device uchar *input [[buffer(1)]],
                                       const device float *absmax [[buffer(2)]],
                                       device T *out [[buffer(3)]],
                                       device const int &blocksize,
                                       device const int &n,
                                       uint id [[thread_position_in_grid]]) {
  const int block_idx = id * blocksize;
  const int valid_items =
      (n > blocksize + block_idx) ? blocksize : (n - block_idx);
  const int block_end = block_idx + valid_items;

  // Threads whose range is empty do nothing; returning here also keeps the
  // hoisted load below from running (and dividing) when there is no work.
  if (block_end <= block_idx) {
    return;
  }

  // The scale index depends only on block_idx, so load it once instead of on
  // every loop iteration (assumes `out` does not alias `absmax`).
  const float local_abs_max = absmax[block_idx / blocksize];

  for (int i = block_idx; i < block_end; ++i) {
    out[i] = static_cast<T>(code[input[i]] * local_abs_max);
  }
}

// Explicitly instantiates kernel_dequantize_nf4 for `type` and exposes it
// under the host-visible name "kernel_dequantize_nf4_<type>".
// The expansion already ends with ';', so invocations carry no trailing one.
#define instantiate_dequantize_nf4(type)                                       \
  template [[host_name("kernel_dequantize_nf4_" #type)]] [[kernel]] void       \
  kernel_dequantize_nf4<type>(const device float *code [[buffer(0)]],          \
                              const device uchar *input [[buffer(1)]],         \
                              const device float *absmax [[buffer(2)]],        \
                              device type *out [[buffer(3)]],                  \
                              device const int &blocksize,                     \
                              device const int &n,                             \
                              uint id [[thread_position_in_grid]]);

// Output element types for which the NF4 kernel is instantiated.
instantiate_dequantize_nf4(float)
instantiate_dequantize_nf4(bfloat16_t)
instantiate_dequantize_nf4(half)

// Explicitly instantiates kernel_dequantize_fp4 for `type` and exposes it
// under the host-visible name "kernel_dequantize_fp4_<type>".
// The expansion already ends with ';', so invocations carry no trailing one.
#define instantiate_dequantize_fp4(type)                                       \
  template [[host_name("kernel_dequantize_fp4_" #type)]] [[kernel]] void       \
  kernel_dequantize_fp4<type>(const device float *code [[buffer(0)]],          \
                              const device uchar *input [[buffer(1)]],         \
                              const device float *absmax [[buffer(2)]],        \
                              device type *out [[buffer(3)]],                  \
                              device const int &blocksize,                     \
                              device const int &n,                             \
                              uint id [[thread_position_in_grid]]);

// Output element types for which the FP4 kernel is instantiated.
instantiate_dequantize_fp4(float)
instantiate_dequantize_fp4(bfloat16_t)
instantiate_dequantize_fp4(half)

// Explicitly instantiates kernel_dequantize_int8 for `type` and exposes it
// under the host-visible name "kernel_dequantize_int8_<type>".
// The expansion already ends with ';', so invocations carry no trailing one.
#define instantiate_dequantize_int8(type)                                      \
  template [[host_name("kernel_dequantize_int8_" #type)]] [[kernel]] void      \
  kernel_dequantize_int8<type>(const device float *code [[buffer(0)]],         \
                               const device uchar *input [[buffer(1)]],        \
                               const device float *absmax [[buffer(2)]],       \
                               device type *out [[buffer(3)]],                 \
                               device const int &blocksize,                    \
                               device const int &n,                            \
                               uint id [[thread_position_in_grid]]);

// Output element types for which the int8 kernel is instantiated.
instantiate_dequantize_int8(float)
instantiate_dequantize_int8(bfloat16_t)
instantiate_dequantize_int8(half)