tract-metal 0.23.0-dev.5

Tiny, no-nonsense, self-contained, TensorFlow and ONNX inference
#include <metal_stdlib>
#include <metal_integer>
#include <metal_math>
#include <metal_simdgroup_matrix>  // Available from Metal version 2.3 released with OS X 11.0+

using namespace metal;

namespace utils {

    // Linear offset for a 1-D index: the index scaled by the only stride.
    // The product is computed in size_t and narrowed to uint on return.
    METAL_FUNC uint indices_to_idx_1(uint index, constant const size_t strides[1]) {
        const size_t offset = strides[0] * index;
        return offset;
    }

    // Linear offset for a 2-D index. indices.x is paired with the LAST stride
    // and indices.y with the first one.
    METAL_FUNC uint indices_to_idx_2(uint2 indices, constant const size_t strides[2]) {
        const size_t from_y = indices.y * strides[0];
        const size_t from_x = indices.x * strides[1];
        return from_x + from_y;
    }

    // Linear offset for a 3-D index. indices.x uses strides[2], indices.y uses
    // strides[1] and indices.z uses strides[0].
    METAL_FUNC uint indices_to_idx_3(uint3 indices, constant const size_t strides[3]) {
        const size_t from_z = indices.z * strides[0];
        const size_t from_y = indices.y * strides[1];
        const size_t from_x = indices.x * strides[2];
        return from_x + from_y + from_z;
    }

    // Linear offset for a 4-D tensor addressed with a 3-component index.
    // indices.x / indices.y address the two last axes (strides[3], strides[2]);
    // indices.z packs the two leading axes and is split using shape[1]:
    // the remainder selects axis 1, the quotient selects axis 0.
    METAL_FUNC uint indices_to_idx_4(uint3 indices,
                                     constant const size_t shape[4], 
                                     constant const size_t strides[4]) {
        const size_t packed = indices.z;
        const size_t axis1 = packed % shape[1];
        const size_t axis0 = packed / shape[1];

        const size_t inner = indices.x * strides[3] + indices.y * strides[2];
        const size_t outer = axis1 * strides[1] + axis0 * strides[0];
        return inner + outer;
    }

    // Linear offset for a 5-D tensor addressed with a 3-component index.
    // indices.x / indices.y address the two last axes (strides[4], strides[3]);
    // indices.z packs the three leading axes and is peeled apart from the
    // innermost packed axis outwards using shape[2] and then shape[1].
    METAL_FUNC uint indices_to_idx_5(uint3 indices,
                                     constant const size_t shape[5], 
                                     constant const size_t strides[5]) {
        size_t packed = indices.z;

        const size_t axis2 = packed % shape[2];
        packed /= shape[2];
        const size_t axis1 = packed % shape[1];
        packed /= shape[1];
        const size_t axis0 = packed;

        const size_t inner = indices.x * strides[4] + indices.y * strides[3];
        const size_t outer = axis2 * strides[2] + axis1 * strides[1] + axis0 * strides[0];
        return inner + outer;
    }
}

/*
 * Based on code from:
 * https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/binary_ops.h
 */

// Binary functor: sum of the two operands, converted back to T.
struct Add {
    template <typename T>
    T operator()(T x, T y) {
        const T sum = x + y;
        return sum;
    }
};

// Binary functor: x divided by y, converted back to T.
// No check is made on y; a zero divisor is passed straight to operator/.
struct Div {
    template <typename T>
    T operator()(T x, T y) {
        const T quotient = x / y;
        return quotient;
    }
};

// Binary functor: x minus y, converted back to T.
struct Sub {
    template <typename T>
    T operator()(T x, T y) {
        const T difference = x - y;
        return difference;
    }
};

// Binary functor: product of the two operands, converted back to T.
struct Mul {
    template <typename T>
    T operator()(T x, T y) {
        const T product = x * y;
        return product;
    }
};

// Binary predicate: true when x == y.
struct Equals {
    template <typename T>
    bool operator()(T x, T y) {
        const bool same = (x == y);
        return same;
    }
};

// Binary predicate: true when x != y.
struct NotEquals {
    template <typename T>
    bool operator()(T x, T y) {
        const bool differs = (x != y);
        return differs;
    }
};

// Binary predicate: true when x is strictly greater than y.
struct Greater {
    template <typename T>
    bool operator()(T x, T y) {
        const bool above = (x > y);
        return above;
    }
};

// Binary predicate: true when x is greater than or equal to y.
struct GreaterEqual {
    template <typename T>
    bool operator()(T x, T y) {
        const bool at_least = (x >= y);
        return at_least;
    }
};

// Binary predicate: true when x is strictly less than y.
struct Less {
    template <typename T>
    bool operator()(T x, T y) {
        const bool below = (x < y);
        return below;
    }
};

// Binary predicate: true when x is less than or equal to y.
struct LessEqual {
    template <typename T>
    bool operator()(T x, T y) {
        const bool at_most = (x <= y);
        return at_most;
    }
};

// Binary functor: logical conjunction. The boolean result of `x && y` is
// converted to T on return.
struct And {
    template <typename T>
    T operator()(T x, T y) {
        const bool both = x && y;
        return both;
    }
};

// Binary functor: logical disjunction. The boolean result of `x || y` is
// converted to T on return.
struct Or {
    template <typename T>
    T operator()(T x, T y) {
        const bool either = x || y;
        return either;
    }
};

// Binary functor: returns x when x < y, otherwise y (so y is returned when
// the comparison is false for any reason, including equal operands).
struct Min {
    template <typename T>
    T operator()(T x, T y) {
        if (x < y) {
            return x;
        }
        return y;
    }
};

// Binary functor: returns x when x > y, otherwise y (so y is returned when
// the comparison is false for any reason, including equal operands).
struct Max {
    template <typename T>
    T operator()(T x, T y) {
        if (x > y) {
            return x;
        }
        return y;
    }
};

// Binary functor: bitwise AND of the operands, converted back to T.
struct BitAnd {
    template <typename T>
    T operator()(T x, T y) {
        const T masked = x & y;
        return masked;
    }
};

// Binary functor: bitwise OR of the operands, converted back to T.
struct BitOr {
    template <typename T>
    T operator()(T x, T y) {
        const T merged = x | y;
        return merged;
    }
};

// Binary functor: bitwise XOR of the operands, converted back to T.
struct BitXor {
    template <typename T>
    T operator()(T x, T y) {
        const T toggled = x ^ y;
        return toggled;
    }
};

// Binary functor: `base` raised to the power `exp`.
// Two overloads are selected on whether T is an integral type.
struct Pow {
    // Non-integral types: defer to metal::pow.
    template <typename T>
    metal::enable_if_t<!metal::is_integral_v<T>, T>
    operator()(T base, T exp) {
        return metal::pow(base, exp);
    }
    
    // Integral types: exponentiation by repeated squaring.
    //
    // A negative exponent is handled up front. Without this guard the loop
    // below never terminates for signed types: an arithmetic right shift of a
    // negative value never reaches zero, which would hang the kernel.
    // The result for a negative exponent is the reciprocal truncated toward
    // zero: 1 for base 1, +/-1 for base -1 (sign from the exponent's parity),
    // and 0 for every other base (including 0, where the mathematical result
    // is undefined).
    template <typename T>
    metal::enable_if_t<metal::is_integral_v<T>, T>
    operator()(T base, T exp) {
        // Always false for unsigned T, so unsigned behaviour is unchanged.
        if (exp < T(0)) {
            if (base == T(1)) {
                return T(1);
            }
            if (base == T(-1)) {
                // Lowest bit is the parity, also for negative two's-complement values.
                return (exp & T(1)) ? T(-1) : T(1);
            }
            return T(0);
        }

        T res = 1;
        while (exp) {
            // Multiply in the current power of base when this exponent bit is set.
            if (exp & 1) {
                res *= base;
            }
            exp >>= 1;
            base *= base;
        }
        return res;
    }
};

// Explicitly instantiates bin_op_1row for Add/Sub/Div/Mul on float4 and half4.
// Each instantiation is exported under the host name
// "bin_ops::<op>_1row_<f32|f16>". The f16 division kernel was previously
// exported as "dib_1row_f16"; it now follows the same naming as the others.
#define INSTANTIATE_1ROW_BIN_OP()                             \
template [[host_name("bin_ops::add_1row_f32")]] [[kernel]]     \
bin_op_1row_t bin_op_1row<float4, Add>;                         \
template [[host_name("bin_ops::sub_1row_f32")]] [[kernel]]     \
bin_op_1row_t bin_op_1row<float4, Sub>;                         \
template [[host_name("bin_ops::div_1row_f32")]] [[kernel]]     \
bin_op_1row_t bin_op_1row<float4, Div>;                         \
template [[host_name("bin_ops::mul_1row_f32")]] [[kernel]]     \
bin_op_1row_t bin_op_1row<float4, Mul>;                         \
template [[host_name("bin_ops::add_1row_f16")]] [[kernel]]     \
bin_op_1row_t bin_op_1row<half4, Add>;                         \
template [[host_name("bin_ops::sub_1row_f16")]] [[kernel]]     \
bin_op_1row_t bin_op_1row<half4, Sub>;                         \
template [[host_name("bin_ops::div_1row_f16")]] [[kernel]]     \
bin_op_1row_t bin_op_1row<half4, Div>;                         \
template [[host_name("bin_ops::mul_1row_f16")]] [[kernel]]     \
bin_op_1row_t bin_op_1row<half4, Mul>;

// Explicitly instantiates bin_op<itype, otype, op> as a kernel exported under
// the host name "bin_ops::<name>_<itname>" (built by stringifying `name` and
// `itname` and concatenating the adjacent string literals).
#define INSTANTIATE_BIN_OP(name, op, itname, itype, otype)                    \
template [[host_name("bin_ops::" #name "_" #itname)]] [[kernel]]      \
bin_op_t bin_op<itype, otype, op>;                            \

// Instantiates `op` for float and half, with the output type equal to the input type.
#define INSTANTIATE_FLOAT(name, op)                     \
INSTANTIATE_BIN_OP(name, op, f32, float, float)         \
INSTANTIATE_BIN_OP(name, op, f16, half, half)          

// Instantiates `op` for float and half inputs producing bool outputs.
#define INSTANTIATE_FLOAT_BOOL(name, op)                \
INSTANTIATE_BIN_OP(name, op, f32, float, bool)          \
INSTANTIATE_BIN_OP(name, op, f16, half, bool)          

// Instantiates `op` for the 8/16/32/64-bit unsigned and signed integer types,
// with the output type equal to the input type.
#define INSTANTIATE_INTEGER(name, op)                    \
INSTANTIATE_BIN_OP(name, op, u8,  uint8_t, uint8_t)      \
INSTANTIATE_BIN_OP(name, op, u16, uint16_t, uint16_t)    \
INSTANTIATE_BIN_OP(name, op, u32, uint32_t, uint32_t)    \
INSTANTIATE_BIN_OP(name, op, u64, uint64_t, uint64_t)    \
INSTANTIATE_BIN_OP(name, op, i8,  int8_t, int8_t)        \
INSTANTIATE_BIN_OP(name, op, i16, int16_t, int16_t)      \
INSTANTIATE_BIN_OP(name, op, i32, int32_t, int32_t)      \
INSTANTIATE_BIN_OP(name, op, i64, int64_t, int64_t)       

// Instantiates `op` for the same integer input types, producing bool outputs.
#define INSTANTIATE_INTEGER_BOOL(name, op)               \
INSTANTIATE_BIN_OP(name, op, u8,  uint8_t, bool)         \
INSTANTIATE_BIN_OP(name, op, u16, uint16_t, bool)        \
INSTANTIATE_BIN_OP(name, op, u32, uint32_t, bool)        \
INSTANTIATE_BIN_OP(name, op, u64, uint64_t, bool)        \
INSTANTIATE_BIN_OP(name, op, i8,  int8_t, bool)          \
INSTANTIATE_BIN_OP(name, op, i16, int16_t, bool)         \
INSTANTIATE_BIN_OP(name, op, i32, int32_t, bool)         \
INSTANTIATE_BIN_OP(name, op, i64, int64_t, bool)        

// Float + integer instantiations, output type equal to input type.
#define INSTANTIATE_ALL_TYPES(name, op)                  \
INSTANTIATE_FLOAT(name, op)                              \
INSTANTIATE_INTEGER(name, op)  

// Float + integer instantiations, bool output.
#define INSTANTIATE_ALL_TYPES_BOOL(name, op)             \
INSTANTIATE_FLOAT_BOOL(name, op)                         \
INSTANTIATE_INTEGER_BOOL(name, op)                

// Generic element-wise binary kernel over operands described by 4 strides.
//
// The threadgroup position selects one row: tgpig.z, tgpig.y and tgpig.x are
// scaled by strides[0], strides[1] and strides[2] of each tensor to obtain the
// row's base offset. The threads of the threadgroup then walk the out_shape[3]
// elements of that row, starting at tpitg.x and advancing by ntg.x.
//
// Inputs are read with their own innermost stride (strides[3]); the output is
// written at consecutive positions from its row base (out_strides[3] is not
// read). lhs_shape and rhs_shape are bound but unused here.
//
// In  - element type both input buffers are reinterpreted as
// Out - element type the output buffer is reinterpreted as
// Op  - default-constructible functor applied as Op()(lhs, rhs)
template<typename In, typename Out, typename Op>
[[kernel]] void bin_op(device const void *lhs_b [[buffer(0)]],
                    constant const size_t * lhs_shape [[buffer(1)]],
                    constant const size_t * lhs_strides [[buffer(2)]],
                    device const void *rhs_b [[buffer(3)]],
                    constant const size_t * rhs_shape [[buffer(4)]],
                    constant const size_t * rhs_strides [[buffer(5)]],
                    device void *output_b [[buffer(6)]],
                    constant const size_t * out_shape [[buffer(7)]],
                    constant const size_t * out_strides [[buffer(8)]],
                    uint3   tgpig[[threadgroup_position_in_grid]],
                    ushort3 tpitg[[thread_position_in_threadgroup]],
                    ushort3   ntg[[threads_per_threadgroup]]) {
        // Typed views over the untyped buffers.
        device const In * lhs_data = (device const In *)lhs_b;
        device const In * rhs_data = (device const In *)rhs_b;
        device  Out * out_data = (device Out *)output_b;

        // Base offset of the row owned by this threadgroup, per tensor.
        const size_t lhs_row = tgpig.x * lhs_strides[2] + tgpig.y * lhs_strides[1] + tgpig.z * lhs_strides[0];
        const size_t rhs_row = tgpig.x * rhs_strides[2] + tgpig.y * rhs_strides[1] + tgpig.z * rhs_strides[0];
        const size_t out_row = tgpig.x * out_strides[2] + tgpig.y * out_strides[1] + tgpig.z * out_strides[0];

        const size_t row_len = out_shape[3];
        const size_t thread_step = ntg.x;

        Op apply;
        size_t col = tpitg.x;
        while (col < row_len) {
            const In a = lhs_data[lhs_row + col * lhs_strides[3]];
            const In b = rhs_data[rhs_row + col * rhs_strides[3]];
            out_data[out_row + col] = apply(a, b);
            col += thread_step;
        }
}

// Function type shared by every bin_op instantiation (the signature does not
// depend on the template arguments); used by the explicit instantiation macros.
typedef decltype(bin_op<float, float, Mul>) bin_op_t;


// Element-wise binary kernel where the right-hand side is a single row that is
// reused cyclically: each thread handles one T4 element at position tpig and
// combines lhs[tpig] with rhs[tpig % (n / 4)].
//
// T4 - vector element type the three buffers are reinterpreted as
// Op - default-constructible functor applied as Op()(lhs, rhs)
// n  - read from buffer(3); n / 4 is used as the rhs period in T4 units
//      NOTE(review): n / 4 must be non-zero for the modulo to be defined, and
//      there is no bounds check on tpig - presumably guaranteed by the host
//      dispatch; confirm against the caller.
template<typename T4, typename Op>
[[kernel]] void bin_op_1row(device const void *lhs_b [[buffer(0)]],
                           device const void *rhs_b [[buffer(1)]],
                           device void *output_b [[buffer(2)]],
                           device const size_t & n [[buffer(3)]],
                           uint tpig[[thread_position_in_grid]]) {
    // Period of the rhs row, in T4 elements (narrowed to uint).
    const uint rhs_period = n/4;
    const uint rhs_pos = tpig % rhs_period;

    device const T4 * lhs_vec = (device const T4 *)lhs_b;
    device const T4 * rhs_vec = (device const T4 *)rhs_b;
    device  T4 * out_vec = (device  T4 *)output_b;

    Op apply;
    out_vec[tpig] = apply(lhs_vec[tpig], rhs_vec[rhs_pos]);
}

// Function type shared by every bin_op_1row instantiation; used by
// INSTANTIATE_1ROW_BIN_OP.
typedef decltype(bin_op_1row<float4, Mul>) bin_op_1row_t;

// Arithmetic operators: float, half and all integer widths, output type == input type.
INSTANTIATE_ALL_TYPES(mul, Mul)
INSTANTIATE_ALL_TYPES(div, Div)
INSTANTIATE_ALL_TYPES(add, Add)
INSTANTIATE_ALL_TYPES(sub, Sub)
INSTANTIATE_ALL_TYPES(pow, Pow)
// Comparisons: same input types, bool output.
INSTANTIATE_ALL_TYPES_BOOL(lt, Less)
INSTANTIATE_ALL_TYPES_BOOL(gt, Greater)
INSTANTIATE_ALL_TYPES_BOOL(lte, LessEqual)
INSTANTIATE_ALL_TYPES_BOOL(gte, GreaterEqual)
INSTANTIATE_ALL_TYPES_BOOL(eq, Equals)
INSTANTIATE_ALL_TYPES_BOOL(ne, NotEquals)
INSTANTIATE_ALL_TYPES(min, Min)
INSTANTIATE_ALL_TYPES(max, Max)
// Bitwise operators: integer types only.
INSTANTIATE_INTEGER(bitand, BitAnd)
INSTANTIATE_INTEGER(bitor, BitOr)
INSTANTIATE_INTEGER(bitxor, BitXor)
// Logical operators: bool only (exported as "bin_ops::and_bool" / "bin_ops::or_bool").
INSTANTIATE_BIN_OP(and, And, bool, bool, bool)
INSTANTIATE_BIN_OP(or, Or, bool, bool, bool)

// Vectorized single-row variants (float4 / half4) of add, sub, div, mul.
INSTANTIATE_1ROW_BIN_OP()

// --- Iff (select) kernel ---

// Element-wise select over 5-D operands: for every output element, reads the
// condition and copies the matching element from `then_values` (condition
// true) or `else_values` (condition false).
//
// One thread handles one output element. The flat thread position is
// decomposed into five coordinates using out_shape, with axis 4 varying
// fastest; each tensor is then addressed through its own five strides.
// Threads whose position is at or beyond the product of the five output
// dimensions return without doing anything.
template <typename T>
[[kernel]] void iff_generic(
    device const bool *cond [[buffer(0)]],
    device const T *then_values [[buffer(1)]],
    device const T *else_values [[buffer(2)]],
    device T *out [[buffer(3)]],
    constant const size_t *out_shape [[buffer(4)]],
    constant const size_t *cond_strides [[buffer(5)]],
    constant const size_t *then_strides [[buffer(6)]],
    constant const size_t *else_strides [[buffer(7)]],
    constant const size_t *out_strides [[buffer(8)]],
    uint tpig [[thread_position_in_grid]])
{
    constexpr int RANK = 5;

    // Number of output elements; surplus threads exit early.
    size_t element_count = 1;
    for (int axis = 0; axis < RANK; ++axis) {
        element_count *= out_shape[axis];
    }
    if (tpig >= element_count) return;

    // Unflatten the thread position, innermost axis first.
    size_t coord[RANK];
    size_t remaining = tpig;
    for (int axis = RANK - 1; axis > 0; --axis) {
        coord[axis] = remaining % out_shape[axis];
        remaining /= out_shape[axis];
    }
    coord[0] = remaining;

    // Locate and read the condition for this element.
    size_t cond_offset = 0;
    for (int axis = 0; axis < RANK; ++axis) {
        cond_offset += coord[axis] * cond_strides[axis];
    }
    const bool pick = cond[cond_offset];

    // Only the selected branch is addressed and read.
    constant const size_t *src_strides = pick ? then_strides : else_strides;
    device const T *src_values = pick ? then_values : else_values;

    size_t src_offset = 0;
    size_t out_offset = 0;
    for (int axis = 0; axis < RANK; ++axis) {
        src_offset += coord[axis] * src_strides[axis];
        out_offset += coord[axis] * out_strides[axis];
    }

    out[out_offset] = src_values[src_offset];
}

// Explicitly instantiates iff_generic<type> as a kernel exported under the
// host name "bin_ops::iff_generic_<tname>". The parameter list restates the
// kernel signature and must stay in sync with iff_generic.
#define INSTANTIATE_IFF(tname, type) \
    template [[host_name("bin_ops::iff_generic_" #tname)]] [[kernel]] \
    void iff_generic<type>( \
        device const bool*, device const type*, device const type*, device type*, \
        constant const size_t*, constant const size_t*, constant const size_t*, \
        constant const size_t*, constant const size_t*, uint);

// Instantiated for the four unsigned integer widths only.
// NOTE(review): presumably other element types of the same size reuse these
// kernels, since iff_generic only copies values - confirm against the host code.
INSTANTIATE_IFF(u8, uint8_t)
INSTANTIATE_IFF(u16, uint16_t)
INSTANTIATE_IFF(u32, uint32_t)
INSTANTIATE_IFF(u64, uint64_t)