// Portions of this file are adapted from Apple's MLX framework
// (https://github.com/ml-explore/mlx)
// Licensed under the Apache License 2.0
// Copyright © 2023 Apple Inc.
#include "float4.metal"
#include "utils.metal"
#include <metal_common>
#include <metal_math>
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
using namespace metal;
typedef half float16_t;
#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
///////////////////////////////////////////////////////////////////////////////
// Transforms and Epilogues
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <typename OutT, typename InT> struct TransformNone {
static METAL_FUNC OutT apply(InT x) { return static_cast<OutT>(x); }
static METAL_FUNC OutT apply(InT x, OutT) { return static_cast<OutT>(x); }
};
template <typename OutT, typename InT> struct TransformAdd {
TransformAdd(const float, const float) {}
static METAL_FUNC OutT apply(InT x) { return static_cast<OutT>(x); }
static METAL_FUNC OutT apply(InT x, OutT c) {
return static_cast<OutT>(x) + c;
}
};
template <typename OutT, typename InT> struct TransformAxpby {
const float alpha;
const float beta;
TransformAxpby(const float alpha_, const float beta_)
: alpha(alpha_), beta(beta_) {}
static METAL_FUNC OutT apply(InT x) { return static_cast<OutT>(x); }
METAL_FUNC OutT apply(InT x, OutT c) const {
return static_cast<OutT>(x * alpha + (beta * c));
}
};
template <typename T> struct AccumHelper {
typedef float accum_type;
};
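// BlockSwizzle remaps the 2D threadgroup launch grid (a standard GEMM swizzle,
// intended to improve cache reuse across neighboring threadgroups). For
// example, with swizzle_log = 1, tid = (x=5, y=3) maps to
// tid_x = 5 >> 1 = 2 and tid_y = (3 << 1) + (5 & 1) = 7.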
struct BlockSwizzle {
static METAL_FUNC int2 swizzle(uint3 tid [[threadgroup_position_in_grid]],
const int swizzle_log) {
const int tid_x = (tid.x) >> swizzle_log;
const int tid_y =
((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
return int2(tid_x, tid_y);
}
};
} // namespace steel
} // namespace mlx
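// elem_to_loc_broadcast converts a flat output element index into per-array
// offsets for broadcasted inputs: the index is decomposed along `shape`
// (last dimension fastest) and each coordinate is multiplied by that array's
// strides. The two overloads below handle two and three strided inputs.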
METAL_FUNC ulong2 elem_to_loc_broadcast(uint elem, constant const int *shape,
constant const int64_t *a_strides,
constant const int64_t *b_strides,
int ndim) {
ulong loc_a{0};
ulong loc_b{0};
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
int pos_in_dim = (elem % shape[i]);
elem /= shape[i];
loc_a += pos_in_dim * a_strides[i];
loc_b += pos_in_dim * b_strides[i];
}
return ulong2(loc_a, loc_b);
}
METAL_FUNC ulong3 elem_to_loc_broadcast(uint elem, constant const int *shape,
constant const int64_t *a_strides,
constant const int64_t *b_strides,
constant const int64_t *c_strides,
int ndim) {
ulong loc_a{0};
ulong loc_b{0};
ulong loc_c{0};
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
int pos_in_dim = (elem % shape[i]);
elem /= shape[i];
loc_a += pos_in_dim * a_strides[i];
loc_b += pos_in_dim * b_strides[i];
loc_c += pos_in_dim * c_strides[i];
}
return ulong3(loc_a, loc_b, loc_c);
}
template <int val> using Int = integral_constant<int, val>;
#pragma METAL internals : enable
namespace metal {
template <typename T> struct is_empty : metal::bool_constant<__is_empty(T)> {};
#ifdef __cpp_variable_templates
template <typename T> constexpr constant bool is_empty_v = is_empty<T>::value;
#endif
template <typename... Ts> struct make_void {
typedef void type;
};
template <typename... Ts> using void_t = typename make_void<Ts...>::type;
template <class T>
struct is_static : metal::bool_constant<is_empty<remove_cv_t<T>>::value> {};
template <typename T> struct pointer_element {};
template <typename T> struct pointer_element<thread T *> {
using type = remove_cv_t<T>;
};
template <typename T> struct pointer_element<device T *> {
using type = remove_cv_t<T>;
};
template <typename T> struct pointer_element<constant T *> {
using type = remove_cv_t<T>;
};
template <typename T> struct pointer_element<threadgroup T *> {
using type = remove_cv_t<T>;
};
template <typename T>
using pointer_element_t = typename pointer_element<remove_cv_t<T>>::type;
} // namespace metal
#pragma METAL internals : disable
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <typename RInt, typename CInt> struct Shape2D {
RInt r;
CInt c;
Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {}
};
template <typename Shape, typename Layout> struct Layout2D {
Shape shape;
Layout layout;
};
template <typename T, int kFragRows_, int kFragCols_> struct BaseMMAFrag {
static_assert(kFragRows_ == 8,
"Only 8 x 8 fragment matrices are currently supported");
static_assert(kFragCols_ == 8,
"Only 8 x 8 fragment matrices are currently supported");
};
template <typename T> struct BaseMMAFrag<T, 8, 8> {
STEEL_CONST int kFragRows = 8;
STEEL_CONST int kFragCols = 8;
STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;
STEEL_CONST int kElemRows = 1;
STEEL_CONST int kElemCols = 2;
static_assert(kElemRows * kElemCols == kElemsPerFrag,
"MMAFrag shape is not consistent with MMAFrag size");
typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
typedef metal::vec<T, kElemsPerFrag> frag_type;
typedef metal::vec<T, kElemRows> row_frag_type;
typedef metal::vec<T, kElemCols> col_frag_type;
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
[[thread_index_in_simdgroup]]) {
const short qid = simd_lane_id / 4;
const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
return short2{fn, fm};
}
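// get_coord above encodes Apple's (undocumented) simdgroup_matrix lane layout
// that these kernels rely on: each lane owns two consecutive elements of one
// 8x8 matrix row. For example, simd_lane_id = 5 gives qid = 1, fm = 2, fn = 2,
// i.e. the lane holds row 2, columns 2 and 3.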
template <typename SrcPtrType, typename StrX, typename StrY>
METAL_FUNC static constexpr void load(thread frag_type &dst, SrcPtrType src,
StrX str_x, StrY str_y) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
dst[i * kElemCols + j] =
static_cast<T>(src[i * str_x.value + j * str_y.value]);
}
}
}
template <typename SrcPtrType, typename StrX, typename StrY, typename LimX,
typename LimY, typename OffX, typename OffY>
METAL_FUNC static constexpr void
load_safe(thread frag_type &dst, SrcPtrType src, StrX str_x, StrY str_y,
LimX lim_x, LimY lim_y, OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
dst[i * kElemCols + j] =
static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
} else {
dst[i * kElemCols + j] = T(0);
}
}
}
}
template <typename DstPtrType, typename StrX, typename StrY>
METAL_FUNC static constexpr void
store(const thread frag_type &src, DstPtrType dst, StrX str_x, StrY str_y) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
dst[i * str_x + j * str_y.value] =
static_cast<U>(src[i * kElemCols + j]);
}
}
}
template <typename DstPtrType, typename StrX, typename StrY, typename LimX,
typename LimY, typename OffX, typename OffY>
METAL_FUNC static constexpr void
store_safe(const thread frag_type &src, DstPtrType dst, StrX str_x,
StrY str_y, LimX lim_x, LimY lim_y, OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
dst[(off_x + i) * str_x + (off_y + j) * str_y.value] =
static_cast<U>(src[i * kElemCols + j]);
}
}
}
}
METAL_FUNC static constexpr void mma(thread frag_type &D, thread frag_type &A,
thread frag_type &B,
thread frag_type &C) {
mat_type D_mat;
mat_type A_mat;
mat_type B_mat;
mat_type C_mat;
reinterpret_cast<thread frag_type &>(A_mat.thread_elements()) = A;
reinterpret_cast<thread frag_type &>(B_mat.thread_elements()) = B;
reinterpret_cast<thread frag_type &>(C_mat.thread_elements()) = C;
mma(D_mat, A_mat, B_mat, C_mat);
D = reinterpret_cast<thread frag_type &>(D_mat.thread_elements());
}
METAL_FUNC static constexpr void mma(thread mat_type &D, thread mat_type &A,
thread mat_type &B, thread mat_type &C) {
simdgroup_multiply_accumulate(D, A, B, C);
}
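// row_reduce combines each lane's two column elements, then reduces across the
// lanes that hold the same matrix row. Given the get_coord layout, those lanes
// differ only in bits 0 and 3 of the lane id, hence the two shuffle-xor steps
// below (distance 1, then distance 8).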
template <typename Op>
METAL_FUNC static constexpr void row_reduce(thread const frag_type &inp_vals,
thread T *reduced_vals) {
T thr_reduce = Op::apply(inp_vals.x, inp_vals.y);
T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
qgr_reduce = Op::apply(thr_reduce, qgr_reduce);
T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);
reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce);
}
template <typename Op>
METAL_FUNC static constexpr void row_bin_op(thread frag_type &inp_vals,
thread T *row_vals) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
inp_vals[i * kElemCols + j] =
Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
}
}
}
};
template <typename T, int kTileRows_, int kTileCols_,
class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
struct MMATile {
using MMAFrag_t = MMAFrag_;
using elem_type = T;
STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;
STEEL_CONST int kTileRows = kTileRows_;
STEEL_CONST int kTileCols = kTileCols_;
STEEL_CONST int kRows = kTileRows * kFragRows;
STEEL_CONST int kCols = kTileCols * kFragCols;
STEEL_CONST int kNumFrags = kTileRows * kTileCols;
STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows;
STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols;
typedef typename MMAFrag_t::mat_type mat_type;
typedef typename MMAFrag_t::frag_type frag_type;
frag_type val_frags[kNumFrags] = {frag_type(0)};
METAL_FUNC MMATile() thread {}
METAL_FUNC constexpr void clear() {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kNumFrags; ++i) {
val_frags[i] = frag_type(0);
}
}
METAL_FUNC constexpr thread frag_type &frag_at(const short i, const short j) {
return val_frags[i * kTileCols + j];
}
METAL_FUNC constexpr const thread frag_type &frag_at(const short i,
const short j) const {
return val_frags[i * kTileCols + j];
}
METAL_FUNC mat_type mat_at(const short i, const short j) {
mat_type val_mat;
STEEL_PRAGMA_UNROLL
for (short ii = 0; ii < kElemsPerFrag; ++ii) {
val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
}
return val_mat;
}
METAL_FUNC thread elem_type *elems() {
return reinterpret_cast<thread elem_type *>(val_frags);
}
METAL_FUNC const thread elem_type *elems() const {
return reinterpret_cast<const thread elem_type *>(val_frags);
}
template <typename Op>
METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::template row_reduce<Op>(frag_at(i, j),
&vals[i * MMAFrag_t::kElemRows]);
}
}
}
template <typename Op>
METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::template row_bin_op<Op>(frag_at(i, j),
&vals[i * MMAFrag_t::kElemRows]);
}
}
}
template <typename U, int w_x, int w_y, int str_x, int str_y>
METAL_FUNC void load(const threadgroup U *src) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::load(frag_at(i, j),
&(src[(i * kFragRows) * w_x * str_x +
(j * kFragCols) * w_y * str_y]),
Int<str_x>{}, Int<str_y>{});
}
}
}
template <typename U, int w_x, int w_y, int str_x, int str_y>
METAL_FUNC void store(threadgroup U *dst) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::store(frag_at(i, j),
&(dst[(i * kFragRows) * w_x * str_x +
(j * kFragCols) * w_y * str_y]),
Int<str_x>{}, Int<str_y>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void load(const device U *src, const int ld) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::load(
frag_at(i, j),
&(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), ld,
Int<1>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void store(device U *dst, const int ld) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::store(
frag_at(i, j),
&(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), ld,
Int<1>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void load_safe(const device U *src, const int ld,
const short2 src_tile_dims) {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::load_safe(frag_at(i, j), src, ld, Int<1>{}, src_tile_dims.y,
src_tile_dims.x, (i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void store_safe(device U *dst, const int ld,
const short2 dst_tile_dims) const {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::store_safe(frag_at(i, j), dst, ld, Int<1>{}, dst_tile_dims.y,
dst_tile_dims.x, (i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
};
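// tile_matmad computes D = A * B + C over tiles of 8x8 fragments, looping over
// the K fragments. The column index is traversed in serpentine order (reversed
// on odd rows of A), which helps keep recently used B fragments live in
// registers.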
template <typename T, typename U, int M, int N, int K>
METAL_FUNC void
tile_matmad(thread MMATile<T, M, N> &D, thread MMATile<U, M, K> &A,
thread MMATile<U, K, N> &B, thread MMATile<T, M, N> &C) {
STEEL_PRAGMA_UNROLL
for (short k = 0; k < K; ++k) {
STEEL_PRAGMA_UNROLL
for (short m = 0; m < M; ++m) {
STEEL_PRAGMA_UNROLL
for (short n = 0; n < N; ++n) {
short n_serp = (m % 2) ? (N - 1 - n) : n;
MMATile<T, M, N>::MMAFrag_t::mma(D.frag_at(m, n_serp), A.frag_at(m, k),
B.frag_at(k, n_serp),
C.frag_at(m, n_serp));
}
}
}
}
template <typename T, typename U, int BM, int BN, int BK, int WM, int WN,
bool transpose_a, bool transpose_b, short lda_tgp, short ldb_tgp,
typename AccumType = float,
typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
// MMAFrag size
STEEL_CONST short kFragSize = 8;
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TM_stride = kFragSize * WM;
// Warp tile simdgroup matrix strides along N
STEEL_CONST short TN_stride = kFragSize * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / TM_stride;
// Warp tile size along N
STEEL_CONST short TN = BN / TN_stride;
// Threadgroup A strides
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
// Threadgroup B strides
STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
// Threadgroup strides along K
STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
// Simdgroup matrices
MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;
// Offsets within threadgroup
short sm;
short sn;
short As_offset;
short Bs_offset;
/* Constructor */
METAL_FUNC BlockMMA(ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
// Determine thread position in simdgroup matrix
short tm = kFragSize * (simd_group_id / WN);
short tn = kFragSize * (simd_group_id % WN);
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
sm = simd_coord.y;
sn = simd_coord.x;
// Determine thread and simdgroup offset
As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N
sm += tm;
sn += tn;
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T *As, const threadgroup T *Bs) {
// Adjust for simdgroup and thread location
As += As_offset;
Bs += Bs_offset;
// Iterate over BK in blocks of kFragSize
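// Each step loads a TM x 1 column of A fragments and a 1 x TN row of B
// fragments from threadgroup memory, then performs a rank-kFragSize update
// of the TM x TN accumulator Ctile.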
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += kFragSize) {
simdgroup_barrier(mem_flags::mem_none);
Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
simdgroup_barrier(mem_flags::mem_none);
Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
simdgroup_barrier(mem_flags::mem_none);
tile_matmad(Ctile, Atile, Btile, Ctile);
// Progress to next simdgroup tile
As += tile_stride_a;
Bs += tile_stride_b;
}
}
/* Store results from the simdgroup_matrix accumulators into device memory */
METAL_FUNC void store_result(device U *D, const int ldd) {
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}
// Adjust for simdgroup and thread location
D += sm * ldd + sn;
Ctile.template store<U, WM, WN>(D, ldd);
}
METAL_FUNC void store_result_safe(device U *D, const int ldd,
short2 dst_tile_dims) {
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}
// Adjust for simdgroup and thread location
D += sm * ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
}
/* Apply epilogue */
template <typename UnaryEpilogue>
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue &epilogue_op) {
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue(const device U *C, const int ldc,
const int fdc,
thread const BinaryEpilogue &epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto &accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void
apply_epilogue_safe(const device U *C, const int ldc, const int fdc,
short2 dst_tile_dims,
thread const BinaryEpilogue &epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto &accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
// Read C
U c_elems[kelems] = {0};
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
c_elems[k] = C[offset_c + k * fdc];
}
}
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
}
}
}
}
/* Store results from the simdgroup_matrix accumulators into device memory */
METAL_FUNC void store_result(device U *D, const int ldd, const device U *C,
const int ldc, const int fdc,
thread const Epilogue &epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto &accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
METAL_FUNC void store_result_safe(device U *D, const int ldd,
const device U *C, const int ldc,
const int fdc, short2 dst_tile_dims,
thread const Epilogue &epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto &accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
D[offset_d + k] =
epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
}
}
};
} // namespace steel
} // namespace mlx
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <typename T, short BROWS, short BCOLS, short dst_ld,
short reduction_dim, short tgp_size, short alignment = 1,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads, short TROWS = tgp_size / TCOLS>
struct BlockLoader {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T *dst;
const device T *src;
struct alignas(alignment * sizeof(T)) ReadVector {
uint8_t v[sizeof(T) * vec_size];
};
/* Constructor */
METAL_FUNC
BlockLoader(const device T *src_, const int src_ld_, threadgroup T *dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_), tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id), bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)), dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Apply operation to threadgroup without bound checking */
template <typename UnaryOp>
METAL_FUNC void apply_inplace_op(thread const UnaryOp &op) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
}
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
*((threadgroup ReadVector *)(&dst[i * dst_ld])) =
*((const device ReadVector *)(&src[i * src_ld]));
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out unneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() { src += tile_stride; }
};
template <int R, int C> struct CShape {
STEEL_CONST int kRows = R;
STEEL_CONST int kCols = C;
};
template <typename T, short BROWS, short BCOLS, short kDstStrRow,
short kDstStrCol, short reduction_dim, short tgp_size,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads, short TROWS = tgp_size / TCOLS>
struct BlockLoaderT {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T *dst;
const device T *src;
/* Constructor */
METAL_FUNC
BlockLoaderT(const device T *src_, const int src_ld_, threadgroup T *dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_), tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id), bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * kDstStrRow + bj * kDstStrCol),
src(src_ + bi * src_ld + bj) {}
/* Apply operation to threadgroup without bound checking */
template <typename UnaryOp>
METAL_FUNC void apply_inplace_op(thread const UnaryOp &op) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] =
op.apply(dst[i * kDstStrRow + j * kDstStrCol]);
}
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j];
}
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out unneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() { src += tile_stride; }
};
} // namespace steel
} // namespace mlx
MLX_MTL_CONST int SIMD_SIZE = 32;
MLX_MTL_CONST int QUAD_SIZE = 4;
// Helper to load scale based on bit width
template <typename T, typename S, int bits>
inline T load_scale(const device S *scale_ptr) {
if (bits == 40) {
// For mxfp4, the scale is stored as a uint8_t shared exponent (E8M0 format)
const device uint8_t *uint_scale = (const device uint8_t *)scale_ptr;
return static_cast<T>(scale_to_float(*uint_scale));
} else {
return static_cast<T>(*scale_ptr);
}
}
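// For mxfp4 (bits == 40) the per-group scale is a shared 8-bit exponent
// (E8M0); scale_to_float (from the included fp4 helpers) is assumed to decode
// it to a power of two. All other bit widths store the scale directly as S.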
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T *x, thread U *x_thread) {
static_assert(bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8 ||
bits == 40,
"Template undefined for bits not in {2, 3, 4, 6, 8, 40}");
U sum = 0;
if (bits == 2) {
for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 4.0f;
x_thread[i + 2] = x[i + 2] / 16.0f;
x_thread[i + 3] = x[i + 3] / 64.0f;
}
}
else if (bits == 3) {
for (int i = 0; i < values_per_thread; i += 8) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
x[i + 6] + x[i + 7];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 8.0f;
x_thread[i + 2] = x[i + 2] / 64.0f;
x_thread[i + 3] = x[i + 3] / 2.0f;
x_thread[i + 4] = x[i + 4] / 16.0f;
x_thread[i + 5] = x[i + 5] / 128.0f;
x_thread[i + 6] = x[i + 6] / 4.0f;
x_thread[i + 7] = x[i + 7] / 32.0f;
}
}
else if (bits == 4) {
for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 16.0f;
x_thread[i + 2] = x[i + 2] / 256.0f;
x_thread[i + 3] = x[i + 3] / 4096.0f;
}
}
else if (bits == 6) {
for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 64.0f;
x_thread[i + 2] = x[i + 2] / 16.0f;
x_thread[i + 3] = x[i + 3] / 4.0f;
}
}
else if (bits == 8) {
for (int i = 0; i < values_per_thread; i++) {
sum += x[i];
x_thread[i] = x[i];
}
}
else if (bits == 40) {
// mxfp4: block size of 32, no special scaling needed for load_vector
for (int i = 0; i < values_per_thread; i++) {
sum += x[i];
x_thread[i] = x[i];
}
}
return sum;
}
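// For the packed integer formats the x values are stored pre-divided by the
// positional weight of each sub-byte field (e.g. 4, 16, 64 for 2-bit packing),
// so qdot can multiply by the raw masked bits without shifting:
// (w & mask) * (x / weight) == ((w & mask) >> shift) * x.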
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T *x, thread U *x_thread, int N) {
static_assert(bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8 ||
bits == 40,
"Template undefined for bits not in {2, 3, 4, 6, 8, 40}");
U sum = 0;
if (bits == 2) {
for (int i = 0; i < N; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 4.0f;
x_thread[i + 2] = x[i + 2] / 16.0f;
x_thread[i + 3] = x[i + 3] / 64.0f;
}
}
else if (bits == 3) {
for (int i = 0; i < N; i += 8) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
x[i + 6] + x[i + 7];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 8.0f;
x_thread[i + 2] = x[i + 2] / 64.0f;
x_thread[i + 3] = x[i + 3] / 2.0f;
x_thread[i + 4] = x[i + 4] / 16.0f;
x_thread[i + 5] = x[i + 5] / 128.0f;
x_thread[i + 6] = x[i + 6] / 4.0f;
x_thread[i + 7] = x[i + 7] / 32.0f;
}
}
else if (bits == 4) {
for (int i = 0; i < N; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 16.0f;
x_thread[i + 2] = x[i + 2] / 256.0f;
x_thread[i + 3] = x[i + 3] / 4096.0f;
}
}
else if (bits == 6) {
for (int i = 0; i < N; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 64.0f;
x_thread[i + 2] = x[i + 2] / 16.0f;
x_thread[i + 3] = x[i + 3] / 4.0f;
}
}
else if (bits == 8) {
for (int i = 0; i < N; i++) {
sum += x[i];
x_thread[i] = x[i];
}
}
else if (bits == 40) {
// mxfp4: block size of 32, no special scaling needed for load_vector_safe
for (int i = 0; i < N; i++) {
sum += x[i];
x_thread[i] = x[i];
}
}
for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0;
}
return sum;
}
template <typename U, int values_per_thread, int bits>
inline U qdot(const device uint8_t *w, const thread U *x_thread, U scale,
U bias, U sum) {
static_assert(bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8 ||
bits == 40,
"Template undefined for bits not in {2, 3, 4, 6, 8, 40}");
U accum = 0;
if (bits == 2) {
for (int i = 0; i < (values_per_thread / 4); i++) {
accum += (x_thread[4 * i] * (w[i] & 0x03) +
x_thread[4 * i + 1] * (w[i] & 0x0c) +
x_thread[4 * i + 2] * (w[i] & 0x30) +
x_thread[4 * i + 3] * (w[i] & 0xc0));
}
}
else if (bits == 3) {
for (int i = 0; i < (values_per_thread / 8); i++) {
x_thread += 8 * i;
w += 3 * i;
accum += (w[0] & 0x07) * x_thread[0];
accum += (w[0] & 0x38) * x_thread[1];
accum += (w[0] & 0xc0) * x_thread[2];
accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
accum += (w[1] & 0x0e) * x_thread[3];
accum += (w[1] & 0x70) * x_thread[4];
accum += (w[1] & 0x80) * x_thread[5];
accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
accum += (w[2] & 0x1c) * x_thread[6];
accum += (w[2] & 0xe0) * x_thread[7];
}
}
else if (bits == 4) {
const device uint16_t *ws = (const device uint16_t *)w;
for (int i = 0; i < (values_per_thread / 4); i++) {
accum += (x_thread[4 * i] * (ws[i] & 0x000f) +
x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
x_thread[4 * i + 3] * (ws[i] & 0xf000));
}
}
else if (bits == 6) {
for (int i = 0; i < (values_per_thread / 4); i++) {
x_thread += 4 * i;
w += 3 * i;
accum += (w[0] & 0x3f) * x_thread[0];
accum += (w[0] & 0xc0) * x_thread[1];
accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
accum += (w[1] & 0xf0) * x_thread[2];
accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
accum += (w[2] & 0xfc) * x_thread[3];
}
}
else if (bits == 8) {
for (int i = 0; i < values_per_thread; i++) {
accum += x_thread[i] * w[i];
}
}
else if (bits == 40) {
// mxfp4: 4-bit FP4 weights, block size 32
// Each byte contains 2 FP4 values
for (int i = 0; i < values_per_thread; i += 2) {
uint8_t packed = w[i / 2];
U w0 = static_cast<U>(fp4_to_float(packed & 0x0f));
U w1 = static_cast<U>(fp4_to_float((packed >> 4) & 0x0f));
accum += x_thread[i] * w0 + x_thread[i + 1] * w1;
}
}
return scale * accum + sum * bias;
}
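// The return value folds affine dequantization into the dot product:
// sum_i (scale * q_i + bias) * x_i == scale * accum + bias * sum, where `sum`
// is the precomputed sum of x values returned by load_vector.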
template <typename U, int values_per_thread, int bits>
inline U qdot_safe(const device uint8_t *w, const thread U *x_thread, U scale,
U bias, U sum, int N) {
static_assert(bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8 ||
bits == 40,
"Template undefined for bits not in {2, 3, 4, 6, 8, 40}");
U accum = 0;
if (bits == 2) {
for (int i = 0; i < (N / 4); i++) {
accum += (x_thread[4 * i] * (w[i] & 0x03) +
x_thread[4 * i + 1] * (w[i] & 0x0c) +
x_thread[4 * i + 2] * (w[i] & 0x30) +
x_thread[4 * i + 3] * (w[i] & 0xc0));
}
}
else if (bits == 3) {
for (int i = 0; i < (N / 8); i++) {
x_thread += 8 * i;
w += 3 * i;
accum += (w[0] & 0x07) * x_thread[0];
accum += (w[0] & 0x38) * x_thread[1];
accum += (w[0] & 0xc0) * x_thread[2];
accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
accum += (w[1] & 0x0e) * x_thread[3];
accum += (w[1] & 0x70) * x_thread[4];
accum += (w[1] & 0x80) * x_thread[5];
accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
accum += (w[2] & 0x1c) * x_thread[6];
accum += (w[2] & 0xe0) * x_thread[7];
}
}
else if (bits == 4) {
const device uint16_t *ws = (const device uint16_t *)w;
for (int i = 0; i < (N / 4); i++) {
accum += (x_thread[4 * i] * (ws[i] & 0x000f) +
x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
x_thread[4 * i + 3] * (ws[i] & 0xf000));
}
}
else if (bits == 6) {
for (int i = 0; i < (N / 4); i++) {
x_thread += 4 * i;
w += 3 * i;
accum += (w[0] & 0x3f) * x_thread[0];
accum += (w[0] & 0xc0) * x_thread[1];
accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
accum += (w[1] & 0xf0) * x_thread[2];
accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
accum += (w[2] & 0xfc) * x_thread[3];
}
}
else if (bits == 8) {
for (int i = 0; i < N; i++) {
accum += x_thread[i] * w[i];
}
}
else if (bits == 40) {
// mxfp4: 4-bit FP4 weights, block size 32
// Each byte contains 2 FP4 values
for (int i = 0; i < N; i += 2) {
uint8_t packed = w[i / 2];
U w0 = static_cast<U>(fp4_to_float(packed & 0x0f));
U w1 = static_cast<U>(fp4_to_float((packed >> 4) & 0x0f));
accum += x_thread[i] * w0;
if (i + 1 < N) {
accum += x_thread[i + 1] * w1;
}
}
}
return scale * accum + sum * bias;
}
template <typename U, int values_per_thread, int bits>
inline void qouter(const thread uint8_t *w, U x, U scale, U bias,
thread U *result) {
static_assert(bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8 ||
bits == 40,
"Template undefined for bits not in {2, 3, 4, 6, 8, 40}");
if (bits == 2) {
U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
for (int i = 0; i < (values_per_thread / 4); i++) {
result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
}
}
else if (bits == 3) {
for (int i = 0; i < (values_per_thread / 8); i++) {
uint8_t w0 = w[3 * i];
uint8_t w1 = w[3 * i + 1];
uint8_t w2 = w[3 * i + 2];
result[8 * i] += x * ((w0 & 0x7) * scale + bias);
result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);
result[8 * i + 2] +=
x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);
result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);
result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);
result[8 * i + 5] +=
x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);
result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);
result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);
}
}
else if (bits == 4) {
U s[2] = {scale, scale / 16.0f};
for (int i = 0; i < (values_per_thread / 2); i++) {
result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
}
} else if (bits == 6) {
for (int i = 0; i < (values_per_thread / 4); i++) {
uint8_t w0 = w[3 * i];
uint8_t w1 = w[3 * i + 1];
uint8_t w2 = w[3 * i + 2];
result[4 * i] += x * ((w0 & 0x3f) * scale + bias);
result[4 * i + 1] +=
x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);
result[4 * i + 2] +=
x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);
result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);
}
}
else if (bits == 8) {
for (int i = 0; i < values_per_thread; i++) {
result[i] += x * (scale * w[i] + bias);
}
}
else if (bits == 40) {
// mxfp4: 4-bit FP4 weights, block size 32
// Each byte contains 2 FP4 values
for (int i = 0; i < values_per_thread; i += 2) {
uint8_t packed = w[i / 2];
U w0 = static_cast<U>(fp4_to_float(packed & 0x0f));
U w1 = static_cast<U>(fp4_to_float((packed >> 4) & 0x0f));
result[i] += x * (scale * w0); // No bias for mxfp4
result[i + 1] += x * (scale * w1); // No bias for mxfp4
}
}
}
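// qouter is the vector-matrix counterpart of qdot: it accumulates
// x * dequantize(w) into `result`, where x is a single input element shared by
// values_per_thread consecutive output columns (used by qvm_impl below).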
template <typename U, int N, int bits>
inline void dequantize(const device uint8_t *w, U scale, U bias,
threadgroup U *w_local) {
static_assert(bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8 ||
bits == 40,
"Template undefined for bits not in {2, 3, 4, 6, 8, 40}");
if (bits == 2) {
U s[4] = {scale, scale / static_cast<U>(4.0f),
scale / static_cast<U>(16.0f), scale / static_cast<U>(64.0f)};
for (int i = 0; i < (N / 4); i++) {
w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
}
}
else if (bits == 3) {
for (int i = 0; i < (N / 8); i++) {
w_local += 8 * i;
w += 3 * i;
w_local[0] = (w[0] & 0x7) * scale + bias;
w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias;
w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;
w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;
w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
}
}
else if (bits == 4) {
U s[2] = {scale, scale / static_cast<U>(16.0f)};
for (int i = 0; i < (N / 2); i++) {
w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
}
}
else if (bits == 6) {
for (int i = 0; i < (N / 4); i++) {
w_local += 4 * i;
w += 3 * i;
w_local[0] = (w[0] & 0x3f) * scale + bias;
w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
}
}
else if (bits == 8) {
for (int i = 0; i < N; i++) {
w_local[i] = scale * w[i] + bias;
}
}
else if (bits == 40) {
// mxfp4: 4-bit FP4 weights, block size 32
// Each byte contains 2 FP4 values
for (int i = 0; i < N; i += 2) {
uint8_t packed = w[i / 2];
w_local[i] =
scale *
static_cast<U>(fp4_to_float(packed & 0x0f)); // No bias for mxfp4
if (i + 1 < N) {
w_local[i + 1] =
scale * static_cast<U>(fp4_to_float((packed >> 4) &
0x0f)); // No bias for mxfp4
}
}
}
}
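// QuantizedBlockLoader stages a BROWS x BCOLS tile of quantized weights from
// device memory into threadgroup memory, dequantizing on the fly. Each thread
// dequantizes n_reads packs per tile; scales and biases advance whenever the
// loader crosses a group_size boundary along the reduction dimension.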
template <typename T, short BROWS, short BCOLS, short dst_ld,
short reduction_dim, short tgp_size, short group_size, short bits>
struct QuantizedBlockLoader {
static_assert(BCOLS <= group_size,
"The group size should be larger than the columns");
static_assert(group_size % BCOLS == 0,
"The group size should be divisible by the columns");
static_assert(bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8 ||
bits == 40,
"Template undefined for bits not in {2, 3, 4, 6, 8, 40}");
MLX_MTL_CONST short pack_factor = bits == 3 ? 8
: bits == 6 ? 4
: bits == 40 ? 2
: 8 / bits;
MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
MLX_MTL_CONST short n_reads =
(BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
MLX_MTL_CONST short group_steps = group_size / BCOLS;
const int src_ld;
const int tile_stride;
short group_step_cnt;
const int group_stride;
const short thread_idx;
const short bi;
const short bj;
threadgroup T *dst;
const device uint8_t *src;
const device T *scales;
const device T *biases;
QuantizedBlockLoader(const device uint8_t *src_, const device T *scales_,
const device T *biases_, const int src_ld_,
threadgroup T *dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim
? BCOLS_PACKED * bytes_per_pack
: BROWS * src_ld * bytes_per_pack / pack_factor),
group_step_cnt(0), group_stride(BROWS * src_ld / group_size),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(n_reads * thread_idx / BCOLS_PACKED),
bj((n_reads * thread_idx) % BCOLS_PACKED),
dst(dst_ + bi * dst_ld + bj * pack_factor),
src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
bj * bytes_per_pack),
scales(scales_ + bi * src_ld / group_size),
biases(biases_ + bi * src_ld / group_size) {}
void load_unsafe() const {
if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
return;
}
T scale = load_scale<T, T, bits>(scales);
T bias = bits == 40 ? T(0) : *biases; // No bias for mxfp4
for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor, bits>(src + i * bytes_per_pack, scale, bias,
dst + i * pack_factor);
}
}
void load_safe(short2 src_tile_dim) const {
if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
return;
}
if (reduction_dim == 1 && bi >= src_tile_dim.y) {
for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0);
}
return;
}
if (reduction_dim == 0 && bi >= src_tile_dim.x) {
for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0);
}
return;
}
T scale = load_scale<T, T, bits>(scales);
T bias = bits == 40 ? T(0) : *biases; // No bias for mxfp4
for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor, bits>(
(device uint8_t *)(src + i * bytes_per_pack), scale, bias,
dst + i * pack_factor);
}
}
void next() {
src += tile_stride;
if (reduction_dim == 1) {
if (group_steps > 1) {
group_step_cnt++;
if (group_step_cnt == group_steps) {
group_step_cnt = 0;
scales++;
biases++;
}
} else {
scales++;
biases++;
}
} else {
scales += group_stride;
biases += group_stride;
}
}
};
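// qmv_quad_impl: quantized matrix-vector product using quadgroups. Each
// 4-thread quadgroup covers the whole input vector (D / 4 values per thread)
// and produces results_per_quadgroup output rows spaced quads_per_simd rows
// apart; intended for small, compile-time-known input sizes D.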
template <typename T, int group_size, int bits, int D>
METAL_FUNC void qmv_quad_impl(const device uint32_t *w, const device T *scales,
const device T *biases, const device T *x,
device T *y, constant int &in_vec_size,
const constant int &out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
constexpr int pack_factor = bits == 40 ? 2 : 32 / bits;
constexpr int values_per_thread = D / QUAD_SIZE;
constexpr int packs_per_thread = values_per_thread / pack_factor;
constexpr int scale_step_per_thread = group_size / values_per_thread;
constexpr int results_per_quadgroup = 8;
typedef float U;
thread U x_thread[values_per_thread];
thread U result[results_per_quadgroup] = {0};
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;
w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
x += tid.y * in_vec_size + quad_lid * values_per_thread;
y += tid.y * out_vec_size + out_row;
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_quadgroup; row++) {
auto wl =
(const device uint8_t *)(w + row * in_vec_size_w * quads_per_simd);
const device T *sl = scales + row * in_vec_size_g * quads_per_simd;
const device T *bl = biases + row * in_vec_size_g * quads_per_simd;
U s = load_scale<U, T, bits>(sl);
U b = bits == 40 ? U(0) : bl[0];
if (row * quads_per_simd + out_row < out_vec_size) {
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
}
for (int row = 0; row < results_per_quadgroup; row++) {
result[row] = quad_sum(result[row]);
if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
y[row * quads_per_simd] = static_cast<T>(result[row]);
}
}
}
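// qmv_fast_impl: fast-path quantized matrix-vector product that assumes the
// input and output sizes are multiples of the block/tile sizes, so no bounds
// checks are needed. Two simdgroups each produce four output rows.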
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_fast_impl(const device uint32_t *w, const device T *scales,
const device T *biases, const device T *x,
device T *y, const constant int &in_vec_size,
const constant int &out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int packs_per_thread = bits == 2 ? 1 : 2;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = bits == 3 ? 8
: bits == 6 ? 4
: bits == 40 ? 2
: 32 / bits;
// mxfp4 packs two FP4 values per byte, so one pack is a single byte
constexpr int bytes_per_pack = bits == 40 ? 1 : power_of_2_bits ? 4 : 3;
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
const device uint8_t *ws = (const device uint8_t *)w;
typedef float U;
thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0};
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.y * in_vec_size + simd_lid * values_per_thread;
y += tid.y * out_vec_size + out_row;
// Optimization: precompute stride constants and per-row base pointers.
const int ws_block_step =
block_size * bytes_per_pack / pack_factor; // bytes to jump per K-block
const int sb_block_step =
block_size / group_size; // scale/bias elements to jump per K-block
// Cache per-row pointers so we avoid recomputing `row * in_vec_size_*`
thread const device uint8_t *wl_ptrs[results_per_simdgroup];
thread const device T *sl_ptrs[results_per_simdgroup];
thread const device T *bl_ptrs[results_per_simdgroup];
#pragma clang loop unroll(full)
for (int row = 0; row < results_per_simdgroup; ++row) {
wl_ptrs[row] = ws + row * in_vec_size_w;
sl_ptrs[row] = scales + row * in_vec_size_g;
bl_ptrs[row] = biases + row * in_vec_size_g;
}
// Stream over the input vector in blocks of `block_size`, reusing the
// cached row-relative pointers to minimize pointer arithmetic.
#pragma clang loop unroll(enable)
for (int k = 0; k < in_vec_size; k += block_size) {
// Load a block of `x` into registers and compute its running sum.
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
#pragma clang loop unroll(full)
for (int row = 0; row < results_per_simdgroup; ++row) {
U s = load_scale<U, T, bits>(sl_ptrs[row]);
U b = bits == 40 ? U(0) : bl_ptrs[row][0]; // No bias for mxfp4
result[row] +=
qdot<U, values_per_thread, bits>(wl_ptrs[row], x_thread, s, b, sum);
// Advance all cached pointers to the next K-block.
wl_ptrs[row] += ws_block_step;
sl_ptrs[row] += sb_block_step;
bl_ptrs[row] += sb_block_step;
}
// Move `x` to the next K-block (only once per SIMD lane).
x += block_size;
}
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {
y[row] = static_cast<T>(result[row]);
}
}
}
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_impl(const device uint32_t *w, const device T *scales,
const device T *biases, const device T *x, device T *y,
const constant int &in_vec_size,
const constant int &out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int packs_per_thread = 1;
constexpr int pack_factor = bits == 3 ? 8
: bits == 6 ? 4
: bits == 40 ? 2
: 32 / bits;
// mxfp4 packs two FP4 values per byte, so one pack is a single byte
constexpr int bytes_per_pack = bits == 40 ? 1 : power_of_2_bits ? 4 : 3;
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
const device uint8_t *ws = (const device uint8_t *)w;
typedef float U;
thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0};
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
if (out_row >= out_vec_size) {
return;
}
// In this case we need to properly guard all our reads because there isn't
// even 1 tile in the matrix
if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
ws +=
out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.y * in_vec_size + simd_lid * values_per_thread;
y += tid.y * out_vec_size + out_row;
int k = 0;
for (; k < in_vec_size - block_size; k += block_size) {
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; out_row + row < out_vec_size; row++) {
auto wl = (const device uint8_t *)(ws + row * in_vec_size_w);
const device T *sl = scales + row * in_vec_size_g;
const device T *bl = biases + row * in_vec_size_g;
U s = load_scale<U, T, bits>(sl);
U b = bits == 40 ? U(0) : bl[0]; // No bias for mxfp4
result[row] +=
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
ws += block_size * bytes_per_pack / pack_factor;
scales += block_size / group_size;
biases += block_size / group_size;
x += block_size;
}
const int remaining =
clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
0, values_per_thread);
if (remaining > 0) {
U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread,
remaining);
for (int row = 0; out_row + row < out_vec_size; row++) {
auto wl = (const device uint8_t *)(ws + row * in_vec_size_w);
const device T *sl = scales + row * in_vec_size_g;
const device T *bl = biases + row * in_vec_size_g;
U s = load_scale<U, T, bits>(sl);
U b = bits == 40 ? U(0) : bl[0]; // No bias for mxfp4
result[row] +=
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
}
for (int row = 0; out_row + row < out_vec_size; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {
y[row] = static_cast<T>(result[row]);
}
}
}
// In this case the last tile is moved back to redo some output values
else {
ws += used_out_row * in_vec_size_w +
simd_lid * packs_per_thread * bytes_per_pack;
scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.y * in_vec_size + simd_lid * values_per_thread;
y += tid.y * out_vec_size + used_out_row;
int k = 0;
for (; k < in_vec_size - block_size; k += block_size) {
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_simdgroup; row++) {
auto wl = (const device uint8_t *)(ws + row * in_vec_size_w);
const device T *sl = scales + row * in_vec_size_g;
const device T *bl = biases + row * in_vec_size_g;
U s = load_scale<U, T, bits>(sl);
U b = bits == 40 ? U(0) : bl[0]; // No bias for mxfp4
result[row] +=
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
ws += block_size * bytes_per_pack / pack_factor;
scales += block_size / group_size;
biases += block_size / group_size;
x += block_size;
}
const int remaining =
clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
0, values_per_thread);
if (remaining > 0) {
U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread,
remaining);
for (int row = 0; row < results_per_simdgroup; row++) {
auto wl = (const device uint8_t *)(ws + row * in_vec_size_w);
const device T *sl = scales + row * in_vec_size_g;
const device T *bl = biases + row * in_vec_size_g;
U s = load_scale<U, T, bits>(sl);
U b = bits == 40 ? U(0) : bl[0]; // No bias for mxfp4
result[row] += qdot_safe<U, values_per_thread, bits>(wl, x_thread, s, b,
sum, remaining);
}
}
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {
y[row] = static_cast<T>(result[row]);
}
}
}
}
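// qvm_impl: quantized vector-matrix product (x^T * W). Each simdgroup owns
// tn * pack_factor consecutive output columns; lanes stride over the input
// vector in steps of SIMD_SIZE, accumulate partial outer products via qouter,
// and combine them with simd_sum before lane 0 writes the result.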
template <typename T, const int group_size, const int bits>
METAL_FUNC void qvm_impl(const device uint32_t *w, const device T *scales,
const device T *biases, const device T *x, device T *y,
const int in_vec_size, const int out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int num_simdgroups = 2;
constexpr int pack_factor = bits == 3 ? 8
: bits == 6 ? 4
: bits == 40 ? 2
: 32 / bits;
// mxfp4 packs two FP4 values per uint8_t pack
constexpr int bytes_per_pack = (power_of_2_bits || bits == 40) ? 1 : 3;
constexpr int tn = 32 / pack_factor;
constexpr int block_size = SIMD_SIZE;
using W_T =
typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
const device W_T *ws = (const device W_T *)w;
typedef float U;
typedef struct {
W_T wi[tn * bytes_per_pack];
} vec_w;
thread vec_w w_local;
thread U result[tn * pack_factor] = {0};
thread U scale = 1;
thread U bias = 0;
thread U x_local = 0;
// Adjust positions
const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
const int out_vec_size_g = out_vec_size / group_size;
int out_col = pack_factor * tn * (tid.x * num_simdgroups + simd_gid);
ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;
scales += out_col / group_size + simd_lid * out_vec_size_g;
biases += out_col / group_size + simd_lid * out_vec_size_g;
x += tid.y * in_vec_size + simd_lid;
y += tid.y * out_vec_size + out_col;
if (out_col >= out_vec_size) {
return;
}
// Loop over in_vec in blocks of block_size
int remaining = in_vec_size % block_size;
if (remaining == 0) {
for (int i = 0; i < in_vec_size; i += block_size) {
x_local = *x;
scale = load_scale<U, T, bits>(scales);
bias = bits == 40 ? U(0) : *biases;
w_local = *((device vec_w *)ws);
qouter<U, tn * pack_factor, bits>((thread uint8_t *)&w_local, x_local,
scale, bias, result);
x += block_size;
scales += block_size * out_vec_size_g;
biases += block_size * out_vec_size_g;
ws += block_size * out_vec_size_w;
}
} else {
for (int i = block_size; i < in_vec_size; i += block_size) {
x_local = *x;
scale = load_scale<U, T, bits>(scales);
bias = bits == 40 ? U(0) : *biases;
w_local = *((device vec_w *)ws);
qouter<U, tn * pack_factor, bits>((thread uint8_t *)&w_local, x_local,
scale, bias, result);
x += block_size;
scales += block_size * out_vec_size_g;
biases += block_size * out_vec_size_g;
ws += block_size * out_vec_size_w;
}
if (static_cast<int>(simd_lid) < remaining) {
x_local = *x;
scale = load_scale<U, T, bits>(scales);
bias = bits == 40 ? U(0) : *biases;
w_local = *((device vec_w *)ws);
} else {
x_local = 0;
scale = 0;
bias = 0;
}
qouter<U, tn * pack_factor, bits>((thread uint8_t *)&w_local, x_local,
scale, bias, result);
}
// Accumulate in the simdgroup
#pragma clang loop unroll(full)
for (int k = 0; k < tn * pack_factor; k++) {
result[k] = simd_sum(result[k]);
}
// Store the result
if (simd_lid == 0) {
#pragma clang loop unroll(full)
for (int k = 0; k < tn * pack_factor; k++) {
y[k] = static_cast<T>(result[k]);
}
}
}
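// Tiled matmul against a transposed quantized matrix (y = x @ w.T). Each
// threadgroup computes a BM x BN output block: x tiles are staged in Xs, the
// packed weights are dequantized into Ws by QuantizedBlockLoader, and the
// accumulation is done by steel's BlockMMA. Quantization groups run along K.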
template <typename T, const int group_size, const int bits,
const bool aligned_N, const int BM = 32, const int BK = 32,
const int BN = 32>
METAL_FUNC void qmm_t_impl(const device uint32_t *w, const device T *scales,
const device T *biases, const device T *x,
device T *y, threadgroup T *Xs, threadgroup T *Ws,
const constant int &K, const constant int &N,
const constant int &M,
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BK >= SIMD_SIZE, "BK should be at least SIMD_SIZE");
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
(void)lid;
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int pack_factor = bits == 3 ? 8
: bits == 6 ? 4
: bits == 40 ? 2
: 8 / bits;
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true,
BK_padded, BK_padded>;
using loader_x_t =
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
using loader_w_t =
QuantizedBlockLoader<T, BN, BK, BK_padded, 1, WM * WN * SIMD_SIZE,
group_size, bits>;
// Set the block
const int K_w = K * bytes_per_pack / pack_factor;
const int K_g = K / group_size;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
auto wl = (const device uint8_t *)w;
x += y_row * K;
wl += y_col * K_w;
scales += y_col * K_g;
biases += y_col * K_g;
y += y_row * N + y_col;
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
const short num_outs = min(BN, N - y_col);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
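// Pick bounds-checked (safe) or full-tile (unsafe) loads depending on whether
// this block is full along M (num_els) and, when the output is not known to
// be aligned, along N (num_outs).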
if (num_els < BM) {
if (!aligned_N && num_outs < BN) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_safe(short2(BK, num_outs));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
} else {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
}
} else {
if (!aligned_N && num_outs < BN) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_safe(short2(BK, num_outs));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
} else {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
}
}
// Store results to device memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (num_els < BM || num_outs < BN) {
mma_op.store_result_safe(y, N, short2(num_outs, num_els));
} else {
mma_op.store_result(y, N);
}
}
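// Tiled matmul against a non-transposed quantized matrix (y = x @ w), with w
// stored as [K, N] and quantization groups running along N. Mirrors
// qmm_t_impl, except the weight loader walks N and K may be ragged (handled
// by a trailing partial block below).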
template <typename T, const int group_size, const int bits, const int BM = 32,
const int BK = 32, const int BN = 32>
METAL_FUNC void qmm_n_impl(const device uint32_t *w, const device T *scales,
const device T *biases, const device T *x,
device T *y, threadgroup T *Xs, threadgroup T *Ws,
const constant int &K, const constant int &N,
const constant int &M,
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BK >= SIMD_SIZE, "BK should be at least SIMD_SIZE");
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
(void)lid;
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int pack_factor = bits == 3 ? 8
: bits == 6 ? 4
: bits == 40 ? 2
: 8 / bits;
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, false,
BK_padded, BN_padded>;
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1,
WM * WN * SIMD_SIZE, 1, 4>;
using loader_w_t =
QuantizedBlockLoader<T, BK, BN, BN_padded, 0, WM * WN * SIMD_SIZE,
group_size, bits>;
auto wl = (const device uint8_t *)w;
// Set the block
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
x += y_row * K;
wl += y_col * bytes_per_pack / pack_factor;
scales += y_col / group_size;
biases += y_col / group_size;
y += y_row * N + y_col;
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
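// Unlike qmm_t_impl, K is allowed to be ragged here: full BK blocks go
// through the main loop and the leftover num_k columns are loaded with
// bounds checks.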
if (num_els < BM) {
if ((K % BK) != 0) {
const int k_blocks = K / BK;
for (int k = 0; k < k_blocks; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
const short num_k = K - k_blocks * BK;
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(num_k, num_els));
loader_w.load_safe(short2(BN, num_k));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
} else {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
}
} else {
if ((K % BK) != 0) {
const int k_blocks = K / BK;
for (int k = 0; k < k_blocks; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
const short num_k = K - k_blocks * BK;
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(num_k, BM));
loader_w.load_safe(short2(BN, num_k));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
} else {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
}
}
// Store results to device memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (num_els < BM) {
mma_op.store_result_safe(y, N, short2(BN, num_els));
} else {
mma_op.store_result(y, N);
}
}
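// Advance the x/w/scales/biases/y pointers to the batch element selected by
// tid.z, using a single stride for 1-D batches or a full shape/stride walk
// for broadcasted batch dimensions.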
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T *&x, const device uint32_t *&w, const device T *&scales,
const device T *&biases, device T *&y, int output_stride,
const constant int &x_batch_ndims, const constant int *x_shape,
const constant int64_t *x_strides, const constant int &w_batch_ndims,
const constant int *w_shape, const constant int64_t *w_strides,
const constant int64_t *s_strides, const constant int64_t *b_strides,
uint3 tid [[threadgroup_position_in_grid]]) {
// Set the input/output matrices
uint32_t x_idx = tid.z;
uint32_t w_idx = tid.z;
if (x_batch_ndims == 1) {
x += x_idx * x_strides[0];
} else {
x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
}
if (w_batch_ndims == 1) {
w += w_idx * w_strides[0];
scales += w_idx * s_strides[0];
biases += w_idx * b_strides[0];
} else {
ulong3 idx = elem_to_loc_broadcast(w_idx, w_shape, w_strides, s_strides,
b_strides, w_batch_ndims);
w += idx.x;
scales += idx.y;
biases += idx.z;
}
y += tid.z * output_stride;
}
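// Gather variant: the batch elements of x and w are selected indirectly
// through lhs_indices / rhs_indices before the usual per-batch pointer
// adjustment is applied.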
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T *&x, const device uint32_t *&w, const device T *&scales,
const device T *&biases, const device uint32_t *lhs_indices,
const device uint32_t *rhs_indices, device T *&y, int output_stride,
const constant int &batch_ndims, const constant int *batch_shape,
const constant int64_t *lhs_strides, const constant int64_t *rhs_strides,
const constant int &x_batch_ndims, const constant int *x_shape,
const constant int64_t *x_strides, const constant int &w_batch_ndims,
const constant int *w_shape, const constant int64_t *w_strides,
const constant int64_t *s_strides, const constant int64_t *b_strides,
uint3 tid [[threadgroup_position_in_grid]]) {
// Set the input/output matrices
uint32_t x_idx;
uint32_t w_idx;
if (batch_ndims == 1) {
x_idx = lhs_indices[tid.z * lhs_strides[0]];
w_idx = rhs_indices[tid.z * rhs_strides[0]];
} else {
ulong2 idx = elem_to_loc_broadcast(tid.z, batch_shape, lhs_strides,
rhs_strides, batch_ndims);
x_idx = lhs_indices[idx.x];
w_idx = rhs_indices[idx.y];
}
if (x_batch_ndims == 1) {
x += x_idx * x_strides[0];
} else {
x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
}
if (w_batch_ndims == 1) {
w += w_idx * w_strides[0];
scales += w_idx * s_strides[0];
biases += w_idx * b_strides[0];
} else {
ulong3 idx = elem_to_loc_broadcast(w_idx, w_shape, w_strides, s_strides,
b_strides, w_batch_ndims);
w += idx.x;
scales += idx.y;
biases += idx.z;
}
y += tid.z * output_stride;
}
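// Matrix-vector entry point that forwards to qmv_quad_impl with a
// compile-time input length D; the quadgroup indices suggest the reduction is
// done per quadgroup rather than per simdgroup, which suits the short input
// vectors used in the instantiations below (D = 64 or 128).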
template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void qmv_quad(const device uint32_t *w [[buffer(0)]],
const device T *scales [[buffer(1)]],
const device T *biases [[buffer(2)]],
const device T *x [[buffer(3)]],
device T *y [[buffer(4)]],
const constant int &in_vec_size [[buffer(5)]],
const constant int &out_vec_size [[buffer(6)]],
const constant int &x_batch_ndims [[buffer(7)]],
const constant int *x_shape [[buffer(8)]],
const constant int64_t *x_strides [[buffer(9)]],
const constant int &w_batch_ndims [[buffer(10)]],
const constant int *w_shape [[buffer(11)]],
const constant int64_t *w_strides [[buffer(12)]],
const constant int64_t *s_strides [[buffer(13)]],
const constant int64_t *b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
if (batched) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(x, w, scales, biases, y, out_vec_size * M,
x_batch_ndims, x_shape, x_strides, w_batch_ndims,
w_shape, w_strides, s_strides, b_strides, tid);
}
qmv_quad_impl<T, group_size, bits, D>(w, scales, biases, x, y, in_vec_size,
out_vec_size, tid, quad_gid, quad_lid);
}
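// Fast path of the quantized matrix-vector product: applies optional batch
// offsets and forwards to qmv_fast_impl, which is intended for shapes that
// line up with its blocking; the general qmv below covers ragged sizes.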
template <typename T, int group_size, int bits, bool batched>
[[kernel]] void qmv_fast(const device uint32_t *w [[buffer(0)]],
const device T *scales [[buffer(1)]],
const device T *biases [[buffer(2)]],
const device T *x [[buffer(3)]],
device T *y [[buffer(4)]],
const constant int &in_vec_size [[buffer(5)]],
const constant int &out_vec_size [[buffer(6)]],
const constant int &x_batch_ndims [[buffer(7)]],
const constant int *x_shape [[buffer(8)]],
const constant int64_t *x_strides [[buffer(9)]],
const constant int &w_batch_ndims [[buffer(10)]],
const constant int *w_shape [[buffer(11)]],
const constant int64_t *w_strides [[buffer(12)]],
const constant int64_t *s_strides [[buffer(13)]],
const constant int64_t *b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(x, w, scales, biases, y, out_vec_size * M,
x_batch_ndims, x_shape, x_strides, w_batch_ndims,
w_shape, w_strides, s_strides, b_strides, tid);
}
qmv_fast_impl<T, group_size, bits>(w, scales, biases, x, y, in_vec_size,
out_vec_size, tid, simd_gid, simd_lid);
}
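// General quantized matrix-vector entry point; qmv_impl includes a
// bounds-checked tail (load_vector_safe / qdot_safe) for input lengths that
// are not a multiple of the per-lane block.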
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qmv(const device uint32_t *w [[buffer(0)]],
const device T *scales [[buffer(1)]],
const device T *biases [[buffer(2)]],
const device T *x [[buffer(3)]], device T *y [[buffer(4)]],
const constant int &in_vec_size [[buffer(5)]],
const constant int &out_vec_size [[buffer(6)]],
const constant int &x_batch_ndims [[buffer(7)]],
const constant int *x_shape [[buffer(8)]],
const constant int64_t *x_strides [[buffer(9)]],
const constant int &w_batch_ndims [[buffer(10)]],
const constant int *w_shape [[buffer(11)]],
const constant int64_t *w_strides [[buffer(12)]],
const constant int64_t *s_strides [[buffer(13)]],
const constant int64_t *b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(x, w, scales, biases, y, out_vec_size * M,
x_batch_ndims, x_shape, x_strides, w_batch_ndims,
w_shape, w_strides, s_strides, b_strides, tid);
}
qmv_impl<T, group_size, bits>(w, scales, biases, x, y, in_vec_size,
out_vec_size, tid, simd_gid, simd_lid);
}
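// Quantized vector-matrix entry point: thin wrapper that applies the optional
// batch offsets and forwards to qvm_impl.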
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qvm(const device uint32_t *w [[buffer(0)]],
const device T *scales [[buffer(1)]],
const device T *biases [[buffer(2)]],
const device T *x [[buffer(3)]], device T *y [[buffer(4)]],
const constant int &in_vec_size [[buffer(5)]],
const constant int &out_vec_size [[buffer(6)]],
const constant int &x_batch_ndims [[buffer(7)]],
const constant int *x_shape [[buffer(8)]],
const constant int64_t *x_strides [[buffer(9)]],
const constant int &w_batch_ndims [[buffer(10)]],
const constant int *w_shape [[buffer(11)]],
const constant int64_t *w_strides [[buffer(12)]],
const constant int64_t *s_strides [[buffer(13)]],
const constant int64_t *b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(x, w, scales, biases, y, out_vec_size * M,
x_batch_ndims, x_shape, x_strides, w_batch_ndims,
w_shape, w_strides, s_strides, b_strides, tid);
}
qvm_impl<T, group_size, bits>(w, scales, biases, x, y, in_vec_size,
out_vec_size, tid, simd_gid, simd_lid);
}
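// Split-K variant of qvm: the in_vec dimension is split across tid.z, each
// split writing its partial product to its own slice of y (output_stride is
// out_vec_size * M); the last split uses final_block_size when in_vec_size is
// not a multiple of split_k. The partial results are expected to be summed in
// a separate reduction pass.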
template <typename T, const int group_size, const int bits, int split_k = 32>
[[kernel]] void qvm_split_k(const device uint32_t *w [[buffer(0)]],
const device T *scales [[buffer(1)]],
const device T *biases [[buffer(2)]],
const device T *x [[buffer(3)]],
device T *y [[buffer(4)]],
const constant int &in_vec_size [[buffer(5)]],
const constant int &out_vec_size [[buffer(6)]],
const constant int &x_batch_ndims [[buffer(7)]],
const constant int *x_shape [[buffer(8)]],
const constant int64_t *x_strides [[buffer(9)]],
const constant int &w_batch_ndims [[buffer(10)]],
const constant int *w_shape [[buffer(11)]],
const constant int64_t *w_strides [[buffer(12)]],
const constant int64_t *s_strides [[buffer(13)]],
const constant int64_t *b_strides [[buffer(14)]],
const constant int &final_block_size [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(x, w, scales, biases, y, out_vec_size * M,
x_batch_ndims, x_shape, x_strides, w_batch_ndims,
w_shape, w_strides, s_strides, b_strides, tid);
// When (in_vec_size % split_k != 0) the final block needs to be smaller
int in_vec_size_adj =
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
qvm_impl<T, group_size, bits>(w, scales, biases, x, y, in_vec_size_adj,
out_vec_size, tid, simd_gid, simd_lid);
}
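// Kernel entry point for the transposed quantized matmul: allocates the
// threadgroup tiles Xs/Ws, applies batch offsets when `batched` is set, and
// forwards to qmm_t_impl.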
template <typename T, const int group_size, const int bits,
const bool aligned_N, const bool batched, const int BM = 32,
const int BK = 32, const int BN = 32>
[[kernel]] void
qmm_t(const device uint32_t *w [[buffer(0)]],
const device T *scales [[buffer(1)]],
const device T *biases [[buffer(2)]], const device T *x [[buffer(3)]],
device T *y [[buffer(4)]], const constant int &K [[buffer(5)]],
const constant int &N [[buffer(6)]], const constant int &M [[buffer(7)]],
const constant int &x_batch_ndims [[buffer(8)]],
const constant int *x_shape [[buffer(9)]],
const constant int64_t *x_strides [[buffer(10)]],
const constant int &w_batch_ndims [[buffer(11)]],
const constant int *w_shape [[buffer(12)]],
const constant int64_t *w_strides [[buffer(13)]],
const constant int64_t *s_strides [[buffer(14)]],
const constant int64_t *b_strides [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr int BK_padded = (BK + 16 / sizeof(T));
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded];
if (batched) {
adjust_matrix_offsets<T>(x, w, scales, biases, y, M * N, x_batch_ndims,
x_shape, x_strides, w_batch_ndims, w_shape,
w_strides, s_strides, b_strides, tid);
}
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
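// Kernel entry point for the non-transposed quantized matmul; same structure
// as qmm_t above but with a BK x BN weight tile.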
template <typename T, const int group_size, const int bits, const bool batched,
const int BM = 32, const int BK = 32, const int BN = 32>
[[kernel]] void
qmm_n(const device uint32_t *w [[buffer(0)]],
const device T *scales [[buffer(1)]],
const device T *biases [[buffer(2)]], const device T *x [[buffer(3)]],
device T *y [[buffer(4)]], const constant int &K [[buffer(5)]],
const constant int &N [[buffer(6)]], const constant int &M [[buffer(7)]],
const constant int &x_batch_ndims [[buffer(8)]],
const constant int *x_shape [[buffer(9)]],
const constant int64_t *x_strides [[buffer(10)]],
const constant int &w_batch_ndims [[buffer(11)]],
const constant int *w_shape [[buffer(12)]],
const constant int64_t *w_strides [[buffer(13)]],
const constant int64_t *s_strides [[buffer(14)]],
const constant int64_t *b_strides [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BK * BN_padded];
if (batched) {
adjust_matrix_offsets<T>(x, w, scales, biases, y, M * N, x_batch_ndims,
x_shape, x_strides, w_batch_ndims, w_shape,
w_strides, s_strides, b_strides, tid);
}
qmm_n_impl<T, group_size, bits, BM, BK, BN>(
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
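// The bs_* kernels below are the gather (batched-indices) counterparts of
// qmv_fast / qmv / qvm / qmm_t / qmm_n: batch elements of x and w are looked
// up through lhs_indices and rhs_indices via the gather overload of
// adjust_matrix_offsets, and the same impl functions are then reused.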
template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv_fast(const device uint32_t *w [[buffer(0)]],
const device T *scales [[buffer(1)]],
const device T *biases [[buffer(2)]],
const device T *x [[buffer(3)]],
device T *y [[buffer(4)]],
const constant int &in_vec_size [[buffer(5)]],
const constant int &out_vec_size [[buffer(6)]],
const constant int &x_batch_ndims [[buffer(7)]],
const constant int *x_shape [[buffer(8)]],
const constant int64_t *x_strides [[buffer(9)]],
const constant int &w_batch_ndims [[buffer(10)]],
const constant int *w_shape [[buffer(11)]],
const constant int64_t *w_strides [[buffer(12)]],
const constant int64_t *s_strides [[buffer(13)]],
const constant int64_t *b_strides [[buffer(14)]],
const constant int &batch_ndims [[buffer(15)]],
const constant int *batch_shape [[buffer(16)]],
const device uint32_t *lhs_indices [[buffer(17)]],
const device uint32_t *rhs_indices [[buffer(18)]],
const constant int64_t *lhs_strides [[buffer(19)]],
const constant int64_t *rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(x, w, scales, biases, lhs_indices, rhs_indices, y,
out_vec_size * M, batch_ndims, batch_shape,
lhs_strides, rhs_strides, x_batch_ndims, x_shape,
x_strides, w_batch_ndims, w_shape, w_strides,
s_strides, b_strides, tid);
qmv_fast_impl<T, group_size, bits>(w, scales, biases, x, y, in_vec_size,
out_vec_size, tid, simd_gid, simd_lid);
}
template <typename T, int group_size, int bits>
[[kernel]] void
bs_qmv(const device uint32_t *w [[buffer(0)]],
const device T *scales [[buffer(1)]],
const device T *biases [[buffer(2)]], const device T *x [[buffer(3)]],
device T *y [[buffer(4)]], const constant int &in_vec_size [[buffer(5)]],
const constant int &out_vec_size [[buffer(6)]],
const constant int &x_batch_ndims [[buffer(7)]],
const constant int *x_shape [[buffer(8)]],
const constant int64_t *x_strides [[buffer(9)]],
const constant int &w_batch_ndims [[buffer(10)]],
const constant int *w_shape [[buffer(11)]],
const constant int64_t *w_strides [[buffer(12)]],
const constant int64_t *s_strides [[buffer(13)]],
const constant int64_t *b_strides [[buffer(14)]],
const constant int &batch_ndims [[buffer(15)]],
const constant int *batch_shape [[buffer(16)]],
const device uint32_t *lhs_indices [[buffer(17)]],
const device uint32_t *rhs_indices [[buffer(18)]],
const constant int64_t *lhs_strides [[buffer(19)]],
const constant int64_t *rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(x, w, scales, biases, lhs_indices, rhs_indices, y,
out_vec_size * M, batch_ndims, batch_shape,
lhs_strides, rhs_strides, x_batch_ndims, x_shape,
x_strides, w_batch_ndims, w_shape, w_strides,
s_strides, b_strides, tid);
qmv_impl<T, group_size, bits>(w, scales, biases, x, y, in_vec_size,
out_vec_size, tid, simd_gid, simd_lid);
}
template <typename T, int group_size, int bits>
[[kernel]] void
bs_qvm(const device uint32_t *w [[buffer(0)]],
const device T *scales [[buffer(1)]],
const device T *biases [[buffer(2)]], const device T *x [[buffer(3)]],
device T *y [[buffer(4)]], const constant int &in_vec_size [[buffer(5)]],
const constant int &out_vec_size [[buffer(6)]],
const constant int &x_batch_ndims [[buffer(7)]],
const constant int *x_shape [[buffer(8)]],
const constant int64_t *x_strides [[buffer(9)]],
const constant int &w_batch_ndims [[buffer(10)]],
const constant int *w_shape [[buffer(11)]],
const constant int64_t *w_strides [[buffer(12)]],
const constant int64_t *s_strides [[buffer(13)]],
const constant int64_t *b_strides [[buffer(14)]],
const constant int &batch_ndims [[buffer(15)]],
const constant int *batch_shape [[buffer(16)]],
const device uint32_t *lhs_indices [[buffer(17)]],
const device uint32_t *rhs_indices [[buffer(18)]],
const constant int64_t *lhs_strides [[buffer(19)]],
const constant int64_t *rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets<T>(x, w, scales, biases, lhs_indices, rhs_indices, y,
out_vec_size * M, batch_ndims, batch_shape,
lhs_strides, rhs_strides, x_batch_ndims, x_shape,
x_strides, w_batch_ndims, w_shape, w_strides,
s_strides, b_strides, tid);
qvm_impl<T, group_size, bits>(w, scales, biases, x, y, in_vec_size,
out_vec_size, tid, simd_gid, simd_lid);
}
template <typename T, const int group_size, const int bits,
const bool aligned_N, const int BM = 32, const int BK = 32,
const int BN = 32>
[[kernel]] void
bs_qmm_t(const device uint32_t *w [[buffer(0)]],
const device T *scales [[buffer(1)]],
const device T *biases [[buffer(2)]], const device T *x [[buffer(3)]],
device T *y [[buffer(4)]], const constant int &K [[buffer(5)]],
const constant int &N [[buffer(6)]],
const constant int &M [[buffer(7)]],
const constant int &x_batch_ndims [[buffer(8)]],
const constant int *x_shape [[buffer(9)]],
const constant int64_t *x_strides [[buffer(10)]],
const constant int &w_batch_ndims [[buffer(11)]],
const constant int *w_shape [[buffer(12)]],
const constant int64_t *w_strides [[buffer(13)]],
const constant int64_t *s_strides [[buffer(14)]],
const constant int64_t *b_strides [[buffer(15)]],
const constant int &batch_ndims [[buffer(16)]],
const constant int *batch_shape [[buffer(17)]],
const device uint32_t *lhs_indices [[buffer(18)]],
const device uint32_t *rhs_indices [[buffer(19)]],
const constant int64_t *lhs_strides [[buffer(20)]],
const constant int64_t *rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr int BK_padded = (BK + 16 / sizeof(T));
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded];
adjust_matrix_offsets<T>(
x, w, scales, biases, lhs_indices, rhs_indices, y, M * N, batch_ndims,
batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides,
w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, const int group_size, const int bits, const int BM = 32,
const int BK = 32, const int BN = 32>
[[kernel]] void
bs_qmm_n(const device uint32_t *w [[buffer(0)]],
const device T *scales [[buffer(1)]],
const device T *biases [[buffer(2)]], const device T *x [[buffer(3)]],
device T *y [[buffer(4)]], const constant int &K [[buffer(5)]],
const constant int &N [[buffer(6)]],
const constant int &M [[buffer(7)]],
const constant int &x_batch_ndims [[buffer(8)]],
const constant int *x_shape [[buffer(9)]],
const constant int64_t *x_strides [[buffer(10)]],
const constant int &w_batch_ndims [[buffer(11)]],
const constant int *w_shape [[buffer(12)]],
const constant int64_t *w_strides [[buffer(13)]],
const constant int64_t *s_strides [[buffer(14)]],
const constant int64_t *b_strides [[buffer(15)]],
const constant int &batch_ndims [[buffer(16)]],
const constant int *batch_shape [[buffer(17)]],
const device uint32_t *lhs_indices [[buffer(18)]],
const device uint32_t *rhs_indices [[buffer(19)]],
const constant int64_t *lhs_strides [[buffer(20)]],
const constant int64_t *rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BK * BN_padded];
adjust_matrix_offsets<T>(
x, w, scales, biases, lhs_indices, rhs_indices, y, M * N, batch_ndims,
batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides,
w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
qmm_n_impl<T, group_size, bits, BM, BK, BN>(
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
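// Quantizes w to `bits`-bit values with one affine (scale, bias) pair per
// group of group_size elements. Each simdgroup reduces one group
// (values_per_reduce values per lane); the lane sitting on the group boundary
// writes the scale and bias, and packed bytes are assembled across lanes with
// simd_shuffle_down when a pack spans more than one lane.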
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(const device T *w [[buffer(0)]],
device uint8_t *out [[buffer(1)]],
device T *scales [[buffer(2)]],
device T *biases [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
constexpr T eps = T(1e-7);
constexpr int simd_size = 32;
constexpr T n_bins =
bits == 40 ? 15 : (1 << bits) - 1; // mxfp4 has 16 values (0-15)
constexpr int packs_per_int = bits == 3 ? 8
: bits == 6 ? 4
: bits == 40 ? 2
: 8 / bits;
constexpr int values_per_reduce = group_size / simd_size;
constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
constexpr int writes_per_pack =
writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
static_assert(group_size % simd_size == 0,
"Group size must be divisible by simd size.");
size_t offset = index.x + grid_dim.x * size_t(index.y);
size_t in_index = offset * values_per_reduce;
size_t out_index = power_of_2_bits
? offset * writes_per_pack
: offset * bytes_per_pack / writes_per_reduce;
T w_thread[values_per_reduce];
T w_min = Limits<T>::max;
T w_max = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_reduce; i++) {
T val = w[in_index + i];
w_thread[i] = val;
w_min = min(w_min, val);
w_max = max(w_max, val);
}
w_min = simd_min(w_min);
w_max = simd_max(w_max);
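// The group is mapped as w ~= scale * q + bias with q in [0, n_bins]. The
// edge with the larger magnitude becomes the bias and the scale is snapped to
// edge / round(edge / scale), so an exact zero weight falls on an integer
// step of the grid; if the edge rounds to step zero the bias stays 0.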
T scale = max((w_max - w_min) / n_bins, eps);
bool side = abs(w_min) > abs(w_max);
scale = side ? scale : -scale;
T edge = side ? w_min : w_max;
T q0 = round(edge / scale);
bool at_zero = q0 == 0.0f;
scale = at_zero ? scale : edge / q0;
T bias = at_zero ? T(0) : edge;
// Write out the scales and biases
size_t gindex = in_index / group_size;
if (in_index % group_size == 0) {
scales[gindex] = scale;
biases[gindex] = bias;
}
// For 3- and 6-bit packing we accumulate 3 bytes' worth of values, so a
// uint32_t accumulator is needed
uint32_t output = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_reduce; i++) {
uint8_t val;
if (bits == 40) {
// TODO: mxfp4 quantization would need to convert float to FP4 format
// For now, this is primarily for inference with pre-quantized weights
val = 0;
} else {
val = min(round((w_thread[i] - bias) / scale), n_bins);
}
if (bits == 8) {
output = val;
} else {
output += val << (bits * (i % packs_per_int));
}
if (packs_per_int < values_per_reduce &&
i % packs_per_int == packs_per_int - 1) {
out[out_index + i / packs_per_int] = output;
output = 0;
} else {
#pragma clang loop unroll(full)
for (int j = 1; j < writes_per_reduce; j++) {
uint8_t sval = simd_shuffle_down(val, j);
output += sval << (bits * (j * values_per_reduce + i));
}
}
}
if (bits == 3 || bits == 6) {
if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) {
out[out_index] = output & 0xff;
out[out_index + 1] = (output & 0xff00) >> 8;
out[out_index + 2] = (output & 0xff0000) >> 16;
}
} else {
if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
out[out_index / writes_per_reduce] = output;
}
}
}
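// Dequantizes packed weights back to T: each thread unpacks one pack of
// packs_per_int values, with dedicated bit-unpacking paths for the 3- and
// 6-bit layouts (which span 3 bytes) and an mxfp4 path that decodes FP4
// values via fp4_to_float.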
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_dequantize(const device uint8_t *w [[buffer(0)]],
const device T *scales [[buffer(1)]],
const device T *biases [[buffer(2)]],
device T *out [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
constexpr int packs_per_int = bits == 3 ? 8
: bits == 6 ? 4
: bits == 40 ? 2
: 8 / bits;
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
size_t offset = index.x + grid_dim.x * size_t(index.y);
size_t oindex = offset * packs_per_int;
size_t gindex = oindex / group_size;
T scale = load_scale<T, T, bits>(scales + gindex);
T bias = bits == 40 ? T(0) : biases[gindex];
out += oindex;
if (bits == 3) {
w += offset * bytes_per_pack;
out[0] = (w[0] & 0x7) * scale + bias;
out[1] = ((w[0] & 0x38) >> 3) * scale + bias;
out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
out[3] = ((w[1] & 0xe) >> 1) * scale + bias;
out[4] = ((w[1] & 0x70) >> 4) * scale + bias;
out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
} else if (bits == 6) {
w += offset * bytes_per_pack;
out[0] = (w[0] & 0x3f) * scale + bias;
out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
out[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
} else {
uint val = w[offset];
#pragma clang loop unroll(full)
for (int i = 0; i < packs_per_int; i++) {
uint8_t d;
if (bits == 2) {
d = (val >> (bits * i)) & 0x03;
} else if (bits == 4) {
d = (val >> (bits * i)) & 0x0f;
} else if (bits == 8) {
d = val;
} else if (bits == 40) {
// mxfp4: Handle 2 FP4 values per byte
uint8_t byte = w[offset / 2];
if (i == 0) {
d = byte & 0x0f;
out[i] = static_cast<T>(scale * fp4_to_float(d));
} else {
d = (byte >> 4) & 0x0f;
out[i] = static_cast<T>(scale * fp4_to_float(d));
}
continue;
}
out[i] = scale * d + bias;
}
}
}
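// Kernel instantiation macros: expand every kernel above for each supported
// element type (float, float16_t, bfloat16_t), group size (128 / 64 / 32) and
// bit width (2 / 3 / 4 / 6 / 8), plus mxfp4 (bits == 40) with group size 32.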
#define instantiate_quantized(name, type, group_size, bits) \
instantiate_kernel(#name "_" #type "_gs_" #group_size "_b_" #bits, name, \
type, group_size, bits)
#define instantiate_quantized_batched(name, type, group_size, bits, batched) \
instantiate_kernel(#name "_" #type "_gs_" #group_size "_b_" #bits \
"_batch_" #batched, \
name, type, group_size, bits, batched)
#define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \
instantiate_kernel(#name "_" #type "_gs_" #group_size "_b_" #bits \
"_alN_" #aligned, \
name, type, group_size, bits, aligned)
#define instantiate_quantized_aligned_batched(name, type, group_size, bits, \
aligned, batched) \
instantiate_kernel(#name "_" #type "_gs_" #group_size "_b_" #bits \
"_alN_" #aligned "_batch_" #batched, \
name, type, group_size, bits, aligned, batched)
#define instantiate_quantized_quad(name, type, group_size, bits, D, batched) \
instantiate_kernel(#name "_" #type "_gs_" #group_size "_b_" #bits "_d_" #D \
"_batch_" #batched, \
name, type, group_size, bits, D, batched)
#define instantiate_quantized_split_k(name, type, group_size, bits, split_k) \
instantiate_kernel(#name "_" #type "_gs_" #group_size "_b_" #bits \
"_spk_" #split_k, \
name, type, group_size, bits, split_k)
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 0)
#define instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits) \
instantiate_quantized_batched_wrap(qmv, type, group_size, bits) \
instantiate_quantized_batched_wrap(qvm, type, group_size, bits) \
instantiate_quantized_batched_wrap(qmm_n, type, group_size, \
bits)
#define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_quantize, type, group_size, bits) \
instantiate_quantized(affine_dequantize, type, group_size, bits) \
instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
instantiate_quantized(bs_qmv, type, group_size, bits) \
instantiate_quantized(bs_qvm, type, group_size, bits) \
instantiate_quantized(bs_qmm_n, type, group_size, bits)
#define instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, \
true, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, \
bits, true, 0) \
instantiate_quantized_aligned_batched( \
qmm_t, type, group_size, bits, false, 1) \
instantiate_quantized_aligned_batched( \
qmm_t, type, group_size, bits, false, 0)
#define instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 1) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 0) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, \
128, 0)
#define instantiate_quantized_all_splitk(type, group_size, bits) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_all_splitk(type, group_size, bits)
#define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \
instantiate_quantized_funcs(float16_t, group_size, bits) \
instantiate_quantized_funcs(bfloat16_t, group_size, bits)
#define instantiate_quantized_groups(bits) \
instantiate_quantized_types(128, bits) instantiate_quantized_types(64, bits) \
instantiate_quantized_types(32, bits)
#define instantiate_quantized_all() \
instantiate_quantized_groups(2) instantiate_quantized_groups(3) \
instantiate_quantized_groups(4) instantiate_quantized_groups(6) \
instantiate_quantized_groups(8) instantiate_quantized_types( \
32, 40) /* mxfp4 with block size 32 */
instantiate_quantized_all() // clang-format on