// mistralrs-quant 0.8.1
//
// Fast, flexible LLM inference.
// Documentation
#include <metal_stdlib>

// Element-wise bitwise OR: the thread at grid position `tid` stores
// a[tid] | b[tid] into output[tid].
//
//   a, b    read-only operands, bound at buffer indices 0 and 1
//   output  destination, bound at buffer index 2
//   tid     thread position in the grid, used directly as the element index
//
// The kernel takes no length argument and performs no bounds check.
// NOTE(review): presumably the host dispatches exactly one thread per
// element -- confirm against the launch code.
template <typename T>
[[kernel]] void bitwise_or(const device T *a [[buffer(0)]],
                           const device T *b [[buffer(1)]],
                           device T *output [[buffer(2)]],
                           uint tid [[thread_position_in_grid]]) {
  const T lhs = a[tid];
  const T rhs = b[tid];
  output[tid] = lhs | rhs;
}

// Explicitly instantiates bitwise_or<dtype> and exports it under the host
// name "bitwise_or_" followed by the literal spelling of dtype (for example
// "bitwise_or_uint8_t").
#define instantiate_bitwise_or(dtype)                                          \
  template [[host_name("bitwise_or_" #dtype)]] [[kernel]] void                 \
  bitwise_or<dtype>(const device dtype *a [[buffer(0)]],                       \
                    const device dtype *b [[buffer(1)]],                       \
                    device dtype *out [[buffer(2)]],                           \
                    uint tid [[thread_position_in_grid]]);

// Element types for which a bitwise_or kernel is exported.
instantiate_bitwise_or(uint8_t);
instantiate_bitwise_or(uint32_t);
instantiate_bitwise_or(int64_t);
instantiate_bitwise_or(int);

// Element-wise bitwise XOR: the thread at grid position `tid` stores
// a[tid] ^ b[tid] into output[tid].
//
//   a, b    read-only operands, bound at buffer indices 0 and 1
//   output  destination, bound at buffer index 2
//   tid     thread position in the grid, used directly as the element index
//
// The kernel takes no length argument and performs no bounds check.
// NOTE(review): presumably the host dispatches exactly one thread per
// element -- confirm against the launch code.
template <typename T>
[[kernel]] void bitwise_xor(const device T *a [[buffer(0)]],
                            const device T *b [[buffer(1)]],
                            device T *output [[buffer(2)]],
                            uint tid [[thread_position_in_grid]]) {
  const T lhs = a[tid];
  const T rhs = b[tid];
  output[tid] = lhs ^ rhs;
}

// Explicitly instantiates bitwise_xor<dtype> and exports it under the host
// name "bitwise_xor_" followed by the literal spelling of dtype (for example
// "bitwise_xor_uint8_t").
#define instantiate_bitwise_xor(dtype)                                         \
  template [[host_name("bitwise_xor_" #dtype)]] [[kernel]] void                \
  bitwise_xor<dtype>(const device dtype *a [[buffer(0)]],                      \
                     const device dtype *b [[buffer(1)]],                      \
                     device dtype *out [[buffer(2)]],                          \
                     uint tid [[thread_position_in_grid]]);

// Element types for which a bitwise_xor kernel is exported.
instantiate_bitwise_xor(uint8_t);
instantiate_bitwise_xor(uint32_t);
instantiate_bitwise_xor(int64_t);
instantiate_bitwise_xor(int);

// Element-wise bitwise AND: the thread at grid position `tid` stores
// a[tid] & b[tid] into output[tid].
//
//   a, b    read-only operands, bound at buffer indices 0 and 1
//   output  destination, bound at buffer index 2
//   tid     thread position in the grid, used directly as the element index
//
// The kernel takes no length argument and performs no bounds check.
// NOTE(review): presumably the host dispatches exactly one thread per
// element -- confirm against the launch code.
template <typename T>
[[kernel]] void bitwise_and(const device T *a [[buffer(0)]],
                            const device T *b [[buffer(1)]],
                            device T *output [[buffer(2)]],
                            uint tid [[thread_position_in_grid]]) {
  const T lhs = a[tid];
  const T rhs = b[tid];
  output[tid] = lhs & rhs;
}

// Explicitly instantiates bitwise_and<dtype> and exports it under the host
// name "bitwise_and_" followed by the literal spelling of dtype (for example
// "bitwise_and_uint8_t").
#define instantiate_bitwise_and(dtype)                                         \
  template [[host_name("bitwise_and_" #dtype)]] [[kernel]] void                \
  bitwise_and<dtype>(const device dtype *a [[buffer(0)]],                      \
                     const device dtype *b [[buffer(1)]],                      \
                     device dtype *out [[buffer(2)]],                          \
                     uint tid [[thread_position_in_grid]]);

// Element types for which a bitwise_and kernel is exported.
instantiate_bitwise_and(uint8_t);
instantiate_bitwise_and(uint32_t);
instantiate_bitwise_and(int64_t);
instantiate_bitwise_and(int);

// Element-wise left shift: the thread at grid position `tid` stores
// a[tid] << k into output[tid].
//
//   a       read-only input, bound at buffer index 0
//   output  destination, bound at buffer index 1
//   k       shift amount shared by every thread, bound at buffer index 2
//   tid     thread position in the grid, used directly as the element index
//
// `k` carries an explicit [[buffer(2)]] index, like every other buffer
// argument in this file, so its binding slot does not depend on implicit
// index assignment.
// NOTE(review): confirm the host encoder binds `k` at index 2.
//
// The kernel takes no length argument, performs no bounds check, and does not
// range-check `k`.
// NOTE(review): presumably the host dispatches exactly one thread per element
// and passes a shift amount smaller than the operand's bit width -- confirm
// against the launch code.
template <typename T>
[[kernel]] void bitwise_leftshift(const device T *a [[buffer(0)]],
                                  device T *output [[buffer(1)]],
                                  device const uint &k [[buffer(2)]],
                                  uint tid [[thread_position_in_grid]]) {
  output[tid] = a[tid] << k;
}

// Explicitly instantiates bitwise_leftshift<type> and exports it under the
// host name "bitwise_leftshift_" followed by the literal spelling of type
// (for example "bitwise_leftshift_uint8_t"). The argument attributes mirror
// the template declaration above.
#define instantiate_bitwise_leftshift(type)                                    \
  template [[host_name("bitwise_leftshift_" #type)]] [[kernel]] void           \
  bitwise_leftshift<type>(                                                     \
      const device type *a [[buffer(0)]], device type *out [[buffer(1)]],      \
      device const uint &k [[buffer(2)]],                                      \
      uint tid [[thread_position_in_grid]]);

// Element types for which a bitwise_leftshift kernel is exported.
instantiate_bitwise_leftshift(uint8_t);
instantiate_bitwise_leftshift(uint32_t);
instantiate_bitwise_leftshift(int64_t);
instantiate_bitwise_leftshift(int);

// Element-wise bitwise NOT: the thread at grid position `tid` stores
// ~a[tid] into output[tid].
//
//   a       read-only input, bound at buffer index 0
//   output  destination, bound at buffer index 1
//   tid     thread position in the grid, used directly as the element index
//
// The kernel takes no length argument and performs no bounds check.
// NOTE(review): presumably the host dispatches exactly one thread per
// element -- confirm against the launch code.
template <typename T>
[[kernel]] void bitwise_not(const device T *a [[buffer(0)]],
                            device T *output [[buffer(1)]],
                            uint tid [[thread_position_in_grid]]) {
  const T value = a[tid];
  output[tid] = ~value;
}

// Explicitly instantiates bitwise_not<dtype> and exports it under the host
// name "bitwise_not_" followed by the literal spelling of dtype (for example
// "bitwise_not_uint8_t").
#define instantiate_bitwise_not(dtype)                                         \
  template [[host_name("bitwise_not_" #dtype)]] [[kernel]] void                \
  bitwise_not<dtype>(                                                          \
      const device dtype *a [[buffer(0)]], device dtype *out [[buffer(1)]],    \
      uint tid [[thread_position_in_grid]]);

// Element types for which a bitwise_not kernel is exported.
instantiate_bitwise_not(uint8_t);
instantiate_bitwise_not(uint32_t);
instantiate_bitwise_not(int64_t);
instantiate_bitwise_not(int);