// Copyright © 2023-2024 Apple Inc.
#include <metal_stdlib>
// clang-format off
#include "utils.metal"
#include "sort_impl.metal"
#define instantiate_block_sort( \
name, itname, itype, otname, otype, arg_sort, bn, tn) \
instantiate_kernel("c" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \
block_sort, itype, otype, arg_sort, bn, tn) \
instantiate_kernel("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \
block_sort_nc, itype, otype, arg_sort, bn, tn)
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
instantiate_block_sort( \
arg_block_sort, itname, itype, uint32, uint32_t, true, bn, tn)
#define instantiate_block_sort_base(itname, itype, bn, tn) \
instantiate_block_sort( \
_block_sort, itname, itype, itname, itype, false, bn, tn)
#define instantiate_block_sort_tn(itname, itype, bn) \
instantiate_block_sort_base(itname, itype, bn, 4) \
instantiate_arg_block_sort_base(itname, itype, bn, 4)
#define instantiate_block_sort_bn(itname, itype) \
instantiate_block_sort_tn(itname, itype, 32) \
instantiate_block_sort_tn(itname, itype, 64) \
instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 256) \
instantiate_block_sort_tn(itname, itype, 512)
instantiate_block_sort_bn(uint8, uint8_t);
instantiate_block_sort_bn(uint16, uint16_t);
instantiate_block_sort_bn(uint32, uint32_t);
instantiate_block_sort_bn(int8, int8_t);
instantiate_block_sort_bn(int16, int16_t);
instantiate_block_sort_bn(int32, int32_t);
instantiate_block_sort_bn(float16, half);
instantiate_block_sort_bn(float32, float);
instantiate_block_sort_bn(bfloat16, bfloat16_t);
#define instantiate_block_sort_long(itname, itype) \
instantiate_block_sort_tn(itname, itype, 32) \
instantiate_block_sort_tn(itname, itype, 64) \
instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 256)
instantiate_block_sort_long(uint64, uint64_t);
instantiate_block_sort_long(int64, int64_t);
#define instantiate_multi_block_sort( \
vtname, vtype, itname, itype, arg_sort, bn, tn) \
instantiate_kernel("sort_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
mb_block_sort, vtype, itype, arg_sort, bn, tn) \
instantiate_kernel("partition_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
mb_block_partition, vtype, itype, arg_sort, bn, tn) \
instantiate_kernel("merge_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
mb_block_merge, vtype, itype, arg_sort, bn, tn)
#define instantiate_multi_block_sort_base(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 4)
instantiate_multi_block_sort_base(uint8, uint8_t);
instantiate_multi_block_sort_base(uint16, uint16_t);
instantiate_multi_block_sort_base(uint32, uint32_t);
instantiate_multi_block_sort_base(int8, int8_t);
instantiate_multi_block_sort_base(int16, int16_t);
instantiate_multi_block_sort_base(int32, int32_t);
instantiate_multi_block_sort_base(float16, half);
instantiate_multi_block_sort_base(float32, float);
instantiate_multi_block_sort_base(bfloat16, bfloat16_t);
#define instantiate_multi_block_sort_long(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 4)
instantiate_multi_block_sort_long(uint64, uint64_t);
instantiate_multi_block_sort_long(int64, int64_t); // clang-format on