hpt-cudakernels 0.1.3

A library that implements CUDA kernels for hpt
Documentation
#pragma once
#include "../utils/loop_progress.cuh"
#include "reduce_template.cuh"
#include "reduce_helper.cuh"
#include <stdio.h>

// Switch arm for the whole-tensor reduction over contiguous input.
//
// Expands to `case BLOCK_SIZE:` followed by one call to
//   all_reduce<ContiguousIndexCalculator<IN_TYPE>, IN_TYPE, OUT_TYPE, REDUCE_OP, BLOCK_SIZE, 32>
// and a `break`. The identifiers `out`, `buffer`, `finished`, `size` and `in`
// are not macro parameters: they must be in scope at the expansion site.
// NOTE(review): the trailing template argument 32 is presumably the warp size;
// confirm against the all_reduce template.
#define CONTIGUOUS_ALL_REDUCE(BLOCK_SIZE, OUT_TYPE, IN_TYPE, REDUCE_OP) \
    case BLOCK_SIZE:                                                    \
    {                                                                   \
        all_reduce<ContiguousIndexCalculator<IN_TYPE>,                  \
                   IN_TYPE,                                             \
                   OUT_TYPE,                                            \
                   REDUCE_OP,                                           \
                   BLOCK_SIZE,                                          \
                   32>(                                                 \
            out,                                                        \
            buffer,                                                     \
            finished,                                                   \
            size,                                                       \
            ContiguousIndexCalculator<IN_TYPE>(in));                    \
        break;                                                          \
    }

// Switch arm for the whole-tensor reduction over non-contiguous input.
//
// Expands to `case BLOCK_SIZE:` followed by one call to
//   all_reduce<UncontiguousIndexCalculator<IN_TYPE>, IN_TYPE, OUT_TYPE, REDUCE_OP, BLOCK_SIZE, 32>
// and a `break`. The index calculator is built from `in`, `shape`, `strides`
// and `ndim`; together with `out`, `buffer`, `finished` and `size`, these must
// be in scope at the expansion site (they are not macro parameters).
// NOTE(review): the trailing template argument 32 is presumably the warp size;
// confirm against the all_reduce template.
#define UNCONTIGUOUS_ALL_REDUCE(BLOCK_SIZE, OUT_TYPE, IN_TYPE, REDUCE_OP)    \
    case BLOCK_SIZE:                                                         \
    {                                                                        \
        all_reduce<UncontiguousIndexCalculator<IN_TYPE>,                     \
                   IN_TYPE,                                                  \
                   OUT_TYPE,                                                 \
                   REDUCE_OP,                                                \
                   BLOCK_SIZE,                                               \
                   32>(                                                      \
            out,                                                             \
            buffer,                                                          \
            finished,                                                        \
            size,                                                            \
            UncontiguousIndexCalculator<IN_TYPE>(in, shape, strides, ndim)); \
        break;                                                               \
    }

// Switch arm that forwards to reduce_fast_dim_include for one block size.
//
// Expands to `case BLOCK_SIZE:` followed by one call to
//   reduce_fast_dim_include<IN_TYPE, OUT_TYPE, REDUCE_OP, BLOCK_SIZE, 32>
// and a `break`. The expansion site must provide `out`, `in`, `buffer`,
// `finished`, `fd`, `strides`, `ndim`, `fast_dim_size` and
// `reduce_size_no_fast_dim`; they are forwarded in that order.
// NOTE(review): the trailing template argument 32 is presumably the warp size;
// confirm against the reduce_fast_dim_include template.
#define CONTIGUOUS_FAST_DIM_INCLUDE_ARMS(BLOCK_SIZE, OUT_TYPE, IN_TYPE, REDUCE_OP) \
    case BLOCK_SIZE:                                                               \
    {                                                                              \
        reduce_fast_dim_include<IN_TYPE, OUT_TYPE, REDUCE_OP, BLOCK_SIZE, 32>(     \
            out,                                                                   \
            in,                                                                    \
            buffer,                                                                \
            finished,                                                              \
            fd,                                                                    \
            strides,                                                               \
            ndim,                                                                  \
            fast_dim_size,                                                         \
            reduce_size_no_fast_dim);                                              \
        break;                                                                     \
    }

// Switch arm that forwards to reduce_fast_dim_only for contiguous input.
//
// Expands to `case BLOCK_SIZE:` followed by one call to
//   reduce_fast_dim_only<IN_TYPE, OUT_TYPE, ContiguousIndexCalculator<IN_TYPE>, REDUCE_OP, BLOCK_SIZE, 32>
// and a `break`. `out`, `in`, `buffer`, `finished`, `fast_dim_size` and
// `output_size` must be in scope at the expansion site. The final argument is
// the literal 1, in the position where the uncontiguous variant of this macro
// passes `last_stride`.
#define CONTIGUOUS_FAST_DIM_ONLY_ARMS(BLOCK_SIZE, OUT_TYPE, IN_TYPE, REDUCE_OP) \
    case BLOCK_SIZE:                                                            \
    {                                                                           \
        reduce_fast_dim_only<IN_TYPE,                                           \
                             OUT_TYPE,                                          \
                             ContiguousIndexCalculator<IN_TYPE>,                \
                             REDUCE_OP,                                         \
                             BLOCK_SIZE,                                        \
                             32>(                                               \
            out,                                                                \
            in,                                                                 \
            buffer,                                                             \
            finished,                                                           \
            fast_dim_size,                                                      \
            output_size,                                                        \
            ContiguousIndexCalculator<IN_TYPE>(in),                             \
            1);                                                                 \
        break;                                                                  \
    }

// Switch arm that forwards to reduce_fast_dim_only for non-contiguous input.
//
// Expands to `case BLOCK_SIZE:` followed by one call to
//   reduce_fast_dim_only<IN_TYPE, OUT_TYPE, UncontiguousIndexCalculator<IN_TYPE>, REDUCE_OP, BLOCK_SIZE, 32>
// and a `break`. The index calculator is built from `in`, `shape`, `strides`
// and `ndim`; `out`, `buffer`, `finished`, `fast_dim_size`, `output_size` and
// `last_stride` must also be in scope at the expansion site.
#define UNCONTIGUOUS_FAST_DIM_ONLY_ARMS(BLOCK_SIZE, OUT_TYPE, IN_TYPE, REDUCE_OP) \
    case BLOCK_SIZE:                                                              \
    {                                                                             \
        reduce_fast_dim_only<IN_TYPE,                                             \
                             OUT_TYPE,                                            \
                             UncontiguousIndexCalculator<IN_TYPE>,                \
                             REDUCE_OP,                                           \
                             BLOCK_SIZE,                                          \
                             32>(                                                 \
            out,                                                                  \
            in,                                                                   \
            buffer,                                                               \
            finished,                                                             \
            fast_dim_size,                                                        \
            output_size,                                                          \
            UncontiguousIndexCalculator<IN_TYPE>(in, shape, strides, ndim),       \
            last_stride);                                                         \
        break;                                                                    \
    }

// Switch arm that forwards to reduce_small_fast_dim_only for contiguous input.
//
// Expands to `case BLOCK_SIZE:` followed by one call to
//   reduce_small_fast_dim_only<IN_TYPE, OUT_TYPE, ContiguousIndexCalculator<IN_TYPE>, REDUCE_OP, BLOCK_SIZE, 32>
// and a `break`. Unlike the non-"small" arms, no `buffer` / `finished`
// arguments are forwarded. `out`, `in`, `fast_dim_size` and `output_size` must
// be in scope at the expansion site; the final argument is the literal 1.
#define CONTIGUOUS_SMALL_FAST_DIM_ONLY_ARMS(BLOCK_SIZE, OUT_TYPE, IN_TYPE, REDUCE_OP) \
    case BLOCK_SIZE:                                                                  \
    {                                                                                 \
        reduce_small_fast_dim_only<IN_TYPE,                                           \
                                   OUT_TYPE,                                          \
                                   ContiguousIndexCalculator<IN_TYPE>,                \
                                   REDUCE_OP,                                         \
                                   BLOCK_SIZE,                                        \
                                   32>(                                               \
            out,                                                                      \
            in,                                                                       \
            fast_dim_size,                                                            \
            output_size,                                                              \
            ContiguousIndexCalculator<IN_TYPE>(in),                                   \
            1);                                                                       \
        break;                                                                        \
    }

// Switch arm that forwards to reduce_small_fast_dim_only for non-contiguous input.
//
// Expands to `case BLOCK_SIZE:` followed by one call to
//   reduce_small_fast_dim_only<IN_TYPE, OUT_TYPE, UncontiguousIndexCalculator<IN_TYPE>, REDUCE_OP, BLOCK_SIZE, 32>
// and a `break`. The index calculator is built from `in`, `shape`, `strides`
// and `ndim`; `out`, `fast_dim_size`, `output_size` and `last_stride` must
// also be in scope at the expansion site.
#define UNCONTIGUOUS_SMALL_FAST_DIM_ONLY_ARMS(BLOCK_SIZE, OUT_TYPE, IN_TYPE, REDUCE_OP) \
    case BLOCK_SIZE:                                                                    \
    {                                                                                   \
        reduce_small_fast_dim_only<IN_TYPE,                                             \
                                   OUT_TYPE,                                            \
                                   UncontiguousIndexCalculator<IN_TYPE>,                \
                                   REDUCE_OP,                                           \
                                   BLOCK_SIZE,                                          \
                                   32>(                                                 \
            out,                                                                        \
            in,                                                                         \
            fast_dim_size,                                                              \
            output_size,                                                                \
            UncontiguousIndexCalculator<IN_TYPE>(in, shape, strides, ndim),             \
            last_stride);                                                               \
        break;                                                                          \
    }

#define DECLARE_KERNEL(IN_TYPE, RUST_TYPE, OP, REDUCE_OP)                                                                                                                                                                                                      \
    using Output##OP##_##RUST_TYPE = REDUCE_OP<IN_TYPE>::R;                                                                                                                                                                                                    \
    /* Kernel contiguous_<OP>_<RUST_TYPE>: dispatches on blockDim.x to the all_reduce arm   */ \
    /* generated by CONTIGUOUS_ALL_REDUCE for block sizes 32, 64, 128, 256 and 512.         */ \
    /* Parameters out, buffer, in, finished and size are forwarded by that macro unchanged. */ \
    /* Any other block size has no matching arm and nothing is reduced. That case is now    */ \
    /* reported instead of being ignored, as the sibling kernels in this macro already do.  */ \
    extern "C" __global__ void contiguous_##OP##_##RUST_TYPE(Output##OP##_##RUST_TYPE *out, Output##OP##_##RUST_TYPE *buffer, IN_TYPE *in, int32_t *finished, size_t size) \
    { \
        switch (blockDim.x) \
        { \
            CONTIGUOUS_ALL_REDUCE(32, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            CONTIGUOUS_ALL_REDUCE(64, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            CONTIGUOUS_ALL_REDUCE(128, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            CONTIGUOUS_ALL_REDUCE(256, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            CONTIGUOUS_ALL_REDUCE(512, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
        default: \
            /* Report from thread x=0,y=0 of block x=0,y=0 only, not from every thread. */ \
            if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.y == 0) \
            { \
                printf("Error: Invalid block size %d for %s\n", blockDim.x, "contiguous_" #OP "_" #RUST_TYPE); \
            } \
            break; \
        } \
    } \
    /* Kernel uncontiguous_<OP>_<RUST_TYPE>: selects, by blockDim.x, the all_reduce arm */ \
    /* generated by UNCONTIGUOUS_ALL_REDUCE for block sizes 32, 64, 128, 256 and 512.   */ \
    /* For any other block size nothing is reduced and an error line is printed.        */ \
    extern "C" __global__ void uncontiguous_##OP##_##RUST_TYPE( \
        Output##OP##_##RUST_TYPE *out, \
        Output##OP##_##RUST_TYPE *buffer, \
        IN_TYPE *in, \
        int32_t *finished, \
        FastDivmod *shape, \
        int *strides, \
        size_t ndim, \
        size_t size) \
    { \
        const unsigned int launched_threads = blockDim.x; \
        switch (launched_threads) \
        { \
            UNCONTIGUOUS_ALL_REDUCE(32, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            UNCONTIGUOUS_ALL_REDUCE(64, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            UNCONTIGUOUS_ALL_REDUCE(128, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            UNCONTIGUOUS_ALL_REDUCE(256, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            UNCONTIGUOUS_ALL_REDUCE(512, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
        default: \
        { \
            /* Report from thread x=0,y=0 of block x=0,y=0 only, not from every thread. */ \
            const bool is_reporter = blockIdx.x == 0 && blockIdx.y == 0 && \
                                     threadIdx.x == 0 && threadIdx.y == 0; \
            if (is_reporter) \
            { \
                printf("Error: Invalid block size %d for %s\n", launched_threads, "uncontiguous_" #OP "_" #RUST_TYPE); \
            } \
            break; \
        } \
        } \
    } \
    /* Kernel <OP>_fast_dim_include_<RUST_TYPE>: dispatches on blockDim.x * blockDim.y to  */ \
    /* the reduce_fast_dim_include arm generated by CONTIGUOUS_FAST_DIM_INCLUDE_ARMS for   */ \
    /* products 32, 64, 128, 256 and 512. Any other product reduces nothing and prints an  */ \
    /* error line.                                                                         */ \
    extern "C" __global__ void OP##_fast_dim_include_##RUST_TYPE( \
        Output##OP##_##RUST_TYPE *out, \
        IN_TYPE *in, \
        Output##OP##_##RUST_TYPE *buffer, \
        int *finished, \
        FastDivmod *fd, \
        int *strides, \
        size_t ndim, \
        size_t fast_dim_size, \
        size_t reduce_size_no_fast_dim) \
    { \
        /* Product of the x and y block extents; the z extent is not considered here. */ \
        const unsigned int threads_xy = blockDim.x * blockDim.y; \
        switch (threads_xy) \
        { \
            CONTIGUOUS_FAST_DIM_INCLUDE_ARMS(32, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            CONTIGUOUS_FAST_DIM_INCLUDE_ARMS(64, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            CONTIGUOUS_FAST_DIM_INCLUDE_ARMS(128, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            CONTIGUOUS_FAST_DIM_INCLUDE_ARMS(256, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            CONTIGUOUS_FAST_DIM_INCLUDE_ARMS(512, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
        default: \
        { \
            /* Report from thread x=0,y=0 of block x=0,y=0 only, not from every thread. */ \
            const bool is_reporter = blockIdx.x == 0 && blockIdx.y == 0 && \
                                     threadIdx.x == 0 && threadIdx.y == 0; \
            if (is_reporter) \
            { \
                printf("Error: Invalid block size %d for %s\n", threads_xy, #OP "_fast_dim_include_" #RUST_TYPE); \
            } \
            break; \
        } \
        } \
    } \
    /* Kernel contiguous_<OP>_fast_dim_only_<RUST_TYPE>: dispatches on blockDim.x to the  */ \
    /* reduce_fast_dim_only arm generated by CONTIGUOUS_FAST_DIM_ONLY_ARMS for block      */ \
    /* sizes 32, 64, 128, 256 and 512. Any other block size reduces nothing and prints    */ \
    /* an error line.                                                                     */ \
    extern "C" __global__ void contiguous_##OP##_fast_dim_only_##RUST_TYPE( \
        Output##OP##_##RUST_TYPE *out, IN_TYPE *in, Output##OP##_##RUST_TYPE *buffer, \
        int32_t *finished, size_t fast_dim_size, size_t output_size) \
    { \
        const unsigned int launched_threads = blockDim.x; \
        switch (launched_threads) \
        { \
            CONTIGUOUS_FAST_DIM_ONLY_ARMS(32, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            CONTIGUOUS_FAST_DIM_ONLY_ARMS(64, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            CONTIGUOUS_FAST_DIM_ONLY_ARMS(128, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            CONTIGUOUS_FAST_DIM_ONLY_ARMS(256, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            CONTIGUOUS_FAST_DIM_ONLY_ARMS(512, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
        default: \
        { \
            /* Report from thread x=0,y=0 of block x=0,y=0 only, not from every thread. */ \
            const bool is_reporter = blockIdx.x == 0 && blockIdx.y == 0 && \
                                     threadIdx.x == 0 && threadIdx.y == 0; \
            if (is_reporter) \
            { \
                printf("Error: Invalid block size %d for %s\n", launched_threads, "contiguous_" #OP "_fast_dim_only_" #RUST_TYPE); \
            } \
            break; \
        } \
        } \
    } \
    /* Kernel uncontiguous_<OP>_fast_dim_only_<RUST_TYPE>: dispatches on blockDim.x to    */ \
    /* the reduce_fast_dim_only arm generated by UNCONTIGUOUS_FAST_DIM_ONLY_ARMS for      */ \
    /* block sizes 32, 64, 128, 256 and 512. shape, strides and ndim feed the index       */ \
    /* calculator built by that macro; last_stride is forwarded as its final argument.    */ \
    /* Any other block size reduces nothing and prints an error line.                     */ \
    extern "C" __global__ void uncontiguous_##OP##_fast_dim_only_##RUST_TYPE( \
        Output##OP##_##RUST_TYPE *out, IN_TYPE *in, Output##OP##_##RUST_TYPE *buffer, \
        int32_t *finished, FastDivmod *shape, int *strides, size_t ndim, \
        size_t fast_dim_size, size_t output_size, int64_t last_stride) \
    { \
        const unsigned int launched_threads = blockDim.x; \
        switch (launched_threads) \
        { \
            UNCONTIGUOUS_FAST_DIM_ONLY_ARMS(32, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            UNCONTIGUOUS_FAST_DIM_ONLY_ARMS(64, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            UNCONTIGUOUS_FAST_DIM_ONLY_ARMS(128, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            UNCONTIGUOUS_FAST_DIM_ONLY_ARMS(256, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            UNCONTIGUOUS_FAST_DIM_ONLY_ARMS(512, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
        default: \
        { \
            /* Report from thread x=0,y=0 of block x=0,y=0 only, not from every thread. */ \
            const bool is_reporter = blockIdx.x == 0 && blockIdx.y == 0 && \
                                     threadIdx.x == 0 && threadIdx.y == 0; \
            if (is_reporter) \
            { \
                printf("Error: Invalid block size %d for %s\n", launched_threads, "uncontiguous_" #OP "_fast_dim_only_" #RUST_TYPE); \
            } \
            break; \
        } \
        } \
    } \
    /* Kernel contiguous_<OP>_small_fast_dim_only_<RUST_TYPE>: dispatches on blockDim.x   */ \
    /* to the reduce_small_fast_dim_only arm generated by                                 */ \
    /* CONTIGUOUS_SMALL_FAST_DIM_ONLY_ARMS for block sizes 32, 64, 128, 256 and 512.      */ \
    /* This kernel takes no buffer or finished parameters. Any other block size reduces   */ \
    /* nothing and prints an error line.                                                  */ \
    extern "C" __global__ void contiguous_##OP##_small_fast_dim_only_##RUST_TYPE( \
        Output##OP##_##RUST_TYPE *out, \
        IN_TYPE *in, \
        size_t fast_dim_size, \
        size_t output_size) \
    { \
        const unsigned int launched_threads = blockDim.x; \
        switch (launched_threads) \
        { \
            CONTIGUOUS_SMALL_FAST_DIM_ONLY_ARMS(32, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            CONTIGUOUS_SMALL_FAST_DIM_ONLY_ARMS(64, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            CONTIGUOUS_SMALL_FAST_DIM_ONLY_ARMS(128, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            CONTIGUOUS_SMALL_FAST_DIM_ONLY_ARMS(256, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
            CONTIGUOUS_SMALL_FAST_DIM_ONLY_ARMS(512, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP); \
        default: \
        { \
            /* Report from thread x=0,y=0 of block x=0,y=0 only, not from every thread. */ \
            const bool is_reporter = blockIdx.x == 0 && blockIdx.y == 0 && \
                                     threadIdx.x == 0 && threadIdx.y == 0; \
            if (is_reporter) \
            { \
                printf("Error: Invalid block size %d for %s\n", launched_threads, "contiguous_" #OP "_small_fast_dim_only_" #RUST_TYPE); \
            } \
            break; \
        } \
        } \
    } \
    extern "C" __global__ void uncontiguous_##OP##_small_fast_dim_only_##RUST_TYPE(Output##OP##_##RUST_TYPE *out, IN_TYPE *in, FastDivmod *shape, int *strides, size_t ndim, size_t fast_dim_size, size_t output_size, int64_t last_stride)                    \
    {                                                                                                                                                                                                                                                          \
        switch (blockDim.x)                                                                                                                                                                                                                                    \
        {                                                                                                                                                                                                                                                      \
            UNCONTIGUOUS_SMALL_FAST_DIM_ONLY_ARMS(32, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP);                                                                                                                                                           \
            UNCONTIGUOUS_SMALL_FAST_DIM_ONLY_ARMS(64, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP);                                                                                                                                                           \
            UNCONTIGUOUS_SMALL_FAST_DIM_ONLY_ARMS(128, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP);                                                                                                                                                          \
            UNCONTIGUOUS_SMALL_FAST_DIM_ONLY_ARMS(256, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP);                                                                                                                                                          \
            UNCONTIGUOUS_SMALL_FAST_DIM_ONLY_ARMS(512, Output##OP##_##RUST_TYPE, IN_TYPE, REDUCE_OP);                                                                                                                                                          \
        default:                                                                                                                                                                                                                                               \
            if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.y == 0)                                                                                                                                                                    \
            {                                                                                                                                                                                                                                                  \
                printf("Error: Invalid block size %d for %s\n", blockDim.x, "uncontiguous_" #OP "_small_fast_dim_only_" #RUST_TYPE);                                                                                                                           \
            }                                                                                                                                                                                                                                                  \
            break;                                                                                                                                                                                                                                             \
        }                                                                                                                                                                                                                                                      \
    }                                                                                                                                                                                                                                                          \
    extern "C" __global__ void OP##_fast_dim_no_include_##RUST_TYPE(Output##OP##_##RUST_TYPE *out, Output##OP##_##RUST_TYPE *buffer, IN_TYPE *in, int *finished, FastDivmod *shape, int *strides, size_t ndim, size_t reduce_size, size_t output_size)         \
    {                                                                                                                                                                                                                                                          \
        extern __shared__ Output##OP##_##RUST_TYPE OP##shared##RUST_TYPE[];                                                                                                                                                                                    \
        reduce_fast_dim_not_include<IN_TYPE, Output##OP##_##RUST_TYPE, REDUCE_OP, 32>(OP##shared##RUST_TYPE, out, buffer, in, finished, shape, strides, ndim, reduce_size, output_size);                                                                       \
    }

// Case arm for an arg-style all-reduce over contiguous input.
// Expands to `case BLOCK_SIZE:` followed by one call to all_reduce and a `break`.
//   BLOCK_SIZE - case label value; also forwarded as a template argument.
//   IN_TYPE    - input element type.
//   REDUCE_OP  - reduction operator; the output type is REDUCE_OP##Result<IN_TYPE>.
// The expansion site must have `out`, `buffer`, `finished`, `size` and `in` in scope.
// NOTE(review): the trailing template argument 32 is presumably the warp size - confirm against all_reduce.
#define CONTIGUOUS_ARG_ALL_REDUCE(BLOCK_SIZE, IN_TYPE, REDUCE_OP)               \
    case BLOCK_SIZE:                                                            \
    {                                                                           \
        using IndexCalcT = ContiguousIndexCalculator<IN_TYPE>;                  \
        using ArgResultT = REDUCE_OP##Result<IN_TYPE>;                          \
        all_reduce<IndexCalcT, IN_TYPE, ArgResultT, REDUCE_OP, BLOCK_SIZE, 32>( \
            out, buffer, finished, size, IndexCalcT(in));                       \
    }                                                                           \
    break;

// Case arm for an arg-style all-reduce over non-contiguous input.
// Expands to `case BLOCK_SIZE:` followed by one call to all_reduce and a `break`.
//   BLOCK_SIZE - case label value; also forwarded as a template argument.
//   IN_TYPE    - input element type.
//   REDUCE_OP  - reduction operator; the output type is REDUCE_OP##Result<IN_TYPE>.
// The index calculator is built from `in`, `shape`, `strides` and `ndim`.
// The expansion site must have those names plus `out`, `buffer`, `finished` and `size` in scope.
// NOTE(review): the trailing template argument 32 is presumably the warp size - confirm against all_reduce.
#define UNCONTIGUOUS_ARG_ALL_REDUCE(BLOCK_SIZE, IN_TYPE, REDUCE_OP)              \
    case BLOCK_SIZE:                                                             \
    {                                                                            \
        using IndexCalcT = UncontiguousIndexCalculator<IN_TYPE>;                 \
        using ArgResultT = REDUCE_OP##Result<IN_TYPE>;                           \
        all_reduce<IndexCalcT, IN_TYPE, ArgResultT, REDUCE_OP, BLOCK_SIZE, 32>(  \
            out, buffer, finished, size, IndexCalcT(in, shape, strides, ndim));  \
    }                                                                            \
    break;

// Case arm for an arg-style reduction along the fast dimension of contiguous input.
// Expands to `case BLOCK_SIZE:` followed by one call to reduce_fast_dim_only and a `break`.
//   BLOCK_SIZE - case label value; also forwarded as a template argument.
//   IN_TYPE    - input element type.
//   REDUCE_OP  - reduction operator; the output type is REDUCE_OP##Result<IN_TYPE>.
// The final call argument is the literal 1, in the slot where the uncontiguous
// variant passes `last_stride`.
// The expansion site must have `out`, `in`, `buffer`, `finished`, `fast_dim_size`
// and `output_size` in scope.
// NOTE(review): the trailing template argument 32 is presumably the warp size - confirm against reduce_fast_dim_only.
#define CONTIGUOUS_ARG_FAST_DIM_ONLY_ARMS(BLOCK_SIZE, IN_TYPE, REDUCE_OP)                 \
    case BLOCK_SIZE:                                                                      \
    {                                                                                     \
        using IndexCalcT = ContiguousIndexCalculator<IN_TYPE>;                            \
        using ArgResultT = REDUCE_OP##Result<IN_TYPE>;                                    \
        reduce_fast_dim_only<IN_TYPE, ArgResultT, IndexCalcT, REDUCE_OP, BLOCK_SIZE, 32>( \
            out, in, buffer, finished, fast_dim_size, output_size,                        \
            IndexCalcT(in), 1);                                                           \
    }                                                                                     \
    break;

// Case arm for an arg-style reduction along the fast dimension of non-contiguous input.
// Expands to `case BLOCK_SIZE:` followed by one call to reduce_fast_dim_only and a `break`.
//   BLOCK_SIZE - case label value; also forwarded as a template argument.
//   IN_TYPE    - input element type.
//   REDUCE_OP  - reduction operator; the output type is REDUCE_OP##Result<IN_TYPE>.
// The index calculator is built from `in`, `shape`, `strides` and `ndim`, and
// `last_stride` is forwarded as the final call argument.
// The expansion site must have those names plus `out`, `buffer`, `finished`,
// `fast_dim_size` and `output_size` in scope.
// NOTE(review): the trailing template argument 32 is presumably the warp size - confirm against reduce_fast_dim_only.
#define UNCONTIGUOUS_ARG_FAST_DIM_ONLY_ARMS(BLOCK_SIZE, IN_TYPE, REDUCE_OP)               \
    case BLOCK_SIZE:                                                                      \
    {                                                                                     \
        using IndexCalcT = UncontiguousIndexCalculator<IN_TYPE>;                          \
        using ArgResultT = REDUCE_OP##Result<IN_TYPE>;                                    \
        reduce_fast_dim_only<IN_TYPE, ArgResultT, IndexCalcT, REDUCE_OP, BLOCK_SIZE, 32>( \
            out, in, buffer, finished, fast_dim_size, output_size,                        \
            IndexCalcT(in, shape, strides, ndim), last_stride);                           \
    }                                                                                     \
    break;

// Case arm for the "small fast dimension" arg-style reduction over contiguous input.
// Expands to `case BLOCK_SIZE:` followed by one call to reduce_small_fast_dim_only and a `break`.
//   BLOCK_SIZE - case label value; also forwarded as a template argument.
//   IN_TYPE    - input element type.
//   REDUCE_OP  - reduction operator; the output type is REDUCE_OP##Result<IN_TYPE>.
// The final call argument is the literal 1, in the slot where the uncontiguous
// variant passes `last_stride`.
// The expansion site must have `out`, `in`, `fast_dim_size` and `output_size` in scope.
// NOTE(review): the trailing template argument 32 is presumably the warp size - confirm against reduce_small_fast_dim_only.
#define CONTIGUOUS_ARG_SMALL_FAST_DIM_ONLY_ARMS(BLOCK_SIZE, IN_TYPE, REDUCE_OP)                 \
    case BLOCK_SIZE:                                                                            \
    {                                                                                           \
        using IndexCalcT = ContiguousIndexCalculator<IN_TYPE>;                                  \
        using ArgResultT = REDUCE_OP##Result<IN_TYPE>;                                          \
        reduce_small_fast_dim_only<IN_TYPE, ArgResultT, IndexCalcT, REDUCE_OP, BLOCK_SIZE, 32>( \
            out, in, fast_dim_size, output_size, IndexCalcT(in), 1);                            \
    }                                                                                           \
    break;

// Case arm for the "small fast dimension" arg-style reduction over non-contiguous input.
// Expands to `case BLOCK_SIZE:` followed by one call to reduce_small_fast_dim_only and a `break`.
//   BLOCK_SIZE - case label value; also forwarded as a template argument.
//   IN_TYPE    - input element type.
//   REDUCE_OP  - reduction operator; the output type is REDUCE_OP##Result<IN_TYPE>.
// The index calculator is built from `in`, `shape`, `strides` and `ndim`, and
// `last_stride` is forwarded as the final call argument.
// The expansion site must have those names plus `out`, `fast_dim_size` and
// `output_size` in scope.
// NOTE(review): the trailing template argument 32 is presumably the warp size - confirm against reduce_small_fast_dim_only.
#define UNCONTIGUOUS_ARG_SMALL_FAST_DIM_ONLY_ARMS(BLOCK_SIZE, IN_TYPE, REDUCE_OP)               \
    case BLOCK_SIZE:                                                                            \
    {                                                                                           \
        using IndexCalcT = UncontiguousIndexCalculator<IN_TYPE>;                                \
        using ArgResultT = REDUCE_OP##Result<IN_TYPE>;                                          \
        reduce_small_fast_dim_only<IN_TYPE, ArgResultT, IndexCalcT, REDUCE_OP, BLOCK_SIZE, 32>( \
            out, in, fast_dim_size, output_size,                                                \
            IndexCalcT(in, shape, strides, ndim), last_stride);                                 \
    }                                                                                           \
    break;

#define DECLARE_ARG_KERNEL(IN_TYPE, RUST_TYPE, OP, REDUCE_OP)                                                                                                                                                                                                               \
    extern "C" __global__ void contiguous_##OP##_##RUST_TYPE(int64_t *out, REDUCE_OP##Result<IN_TYPE> *buffer, IN_TYPE *in, int32_t *finished, size_t size)                                                                                                                 \
    {                                                                                                                                                                                                                                                                       \
        switch (blockDim.x)                                                                                                                                                                                                                                                 \
        {                                                                                                                                                                                                                                                                   \
            CONTIGUOUS_ARG_ALL_REDUCE(32, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                              \
            CONTIGUOUS_ARG_ALL_REDUCE(64, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                              \
            CONTIGUOUS_ARG_ALL_REDUCE(128, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                             \
            CONTIGUOUS_ARG_ALL_REDUCE(256, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                             \
            CONTIGUOUS_ARG_ALL_REDUCE(512, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                             \
        default:                                                                                                                                                                                                                                                            \
            break;                                                                                                                                                                                                                                                          \
        }                                                                                                                                                                                                                                                                   \
    }                                                                                                                                                                                                                                                                       \
    extern "C" __global__ void uncontiguous_##OP##_##RUST_TYPE(int64_t *out, REDUCE_OP##Result<IN_TYPE> *buffer, IN_TYPE *in, int32_t *finished, FastDivmod *shape, int *strides, size_t ndim, size_t size)                                                                 \
    {                                                                                                                                                                                                                                                                       \
        switch (blockDim.x)                                                                                                                                                                                                                                                 \
        {                                                                                                                                                                                                                                                                   \
            UNCONTIGUOUS_ARG_ALL_REDUCE(32, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                            \
            UNCONTIGUOUS_ARG_ALL_REDUCE(64, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                            \
            UNCONTIGUOUS_ARG_ALL_REDUCE(128, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                           \
            UNCONTIGUOUS_ARG_ALL_REDUCE(256, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                           \
            UNCONTIGUOUS_ARG_ALL_REDUCE(512, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                           \
        default:                                                                                                                                                                                                                                                            \
            if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.y == 0)                                                                                                                                                                                 \
            {                                                                                                                                                                                                                                                               \
                printf("Error: Invalid block size %d for %s\n", blockDim.x, "uncontiguous_" #OP "_" #RUST_TYPE);                                                                                                                                                            \
            }                                                                                                                                                                                                                                                               \
            break;                                                                                                                                                                                                                                                          \
        }                                                                                                                                                                                                                                                                   \
    }                                                                                                                                                                                                                                                                       \
    extern "C" __global__ void contiguous_##OP##_fast_dim_only_##RUST_TYPE(int64_t *out, IN_TYPE *in, REDUCE_OP##Result<IN_TYPE> *buffer, int32_t *finished, size_t fast_dim_size, size_t output_size)                                                                      \
    {                                                                                                                                                                                                                                                                       \
        switch (blockDim.x)                                                                                                                                                                                                                                                 \
        {                                                                                                                                                                                                                                                                   \
            CONTIGUOUS_ARG_FAST_DIM_ONLY_ARMS(32, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                      \
            CONTIGUOUS_ARG_FAST_DIM_ONLY_ARMS(64, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                      \
            CONTIGUOUS_ARG_FAST_DIM_ONLY_ARMS(128, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                     \
            CONTIGUOUS_ARG_FAST_DIM_ONLY_ARMS(256, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                     \
            CONTIGUOUS_ARG_FAST_DIM_ONLY_ARMS(512, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                     \
        default:                                                                                                                                                                                                                                                            \
            if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.y == 0)                                                                                                                                                                                 \
            {                                                                                                                                                                                                                                                               \
                printf("Error: Invalid block size %d for %s\n", blockDim.x, "contiguous_" #OP "_fast_dim_only_" #RUST_TYPE);                                                                                                                                                \
            }                                                                                                                                                                                                                                                               \
            break;                                                                                                                                                                                                                                                          \
        }                                                                                                                                                                                                                                                                   \
    }                                                                                                                                                                                                                                                                       \
    extern "C" __global__ void uncontiguous_##OP##_fast_dim_only_##RUST_TYPE(int64_t *out, IN_TYPE *in, REDUCE_OP##Result<IN_TYPE> *buffer, int32_t *finished, FastDivmod *shape, int *strides, size_t ndim, size_t fast_dim_size, size_t output_size, int64_t last_stride) \
    {                                                                                                                                                                                                                                                                       \
        switch (blockDim.x)                                                                                                                                                                                                                                                 \
        {                                                                                                                                                                                                                                                                   \
            UNCONTIGUOUS_ARG_FAST_DIM_ONLY_ARMS(32, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                    \
            UNCONTIGUOUS_ARG_FAST_DIM_ONLY_ARMS(64, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                    \
            UNCONTIGUOUS_ARG_FAST_DIM_ONLY_ARMS(128, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                   \
            UNCONTIGUOUS_ARG_FAST_DIM_ONLY_ARMS(256, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                   \
            UNCONTIGUOUS_ARG_FAST_DIM_ONLY_ARMS(512, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                   \
        default:                                                                                                                                                                                                                                                            \
            if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.y == 0)                                                                                                                                                                                 \
            {                                                                                                                                                                                                                                                               \
                printf("Error: Invalid block size %d for %s\n", blockDim.x, "uncontiguous_" #OP "_fast_dim_only_" #RUST_TYPE);                                                                                                                                              \
            }                                                                                                                                                                                                                                                               \
            break;                                                                                                                                                                                                                                                          \
        }                                                                                                                                                                                                                                                                   \
    }                                                                                                                                                                                                                                                                       \
    extern "C" __global__ void contiguous_##OP##_small_fast_dim_only_##RUST_TYPE(int64_t *out, IN_TYPE *in, size_t fast_dim_size, size_t output_size)                                                                                                                       \
    {                                                                                                                                                                                                                                                                       \
        switch (blockDim.x)                                                                                                                                                                                                                                                 \
        {                                                                                                                                                                                                                                                                   \
            CONTIGUOUS_ARG_SMALL_FAST_DIM_ONLY_ARMS(32, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                \
            CONTIGUOUS_ARG_SMALL_FAST_DIM_ONLY_ARMS(64, IN_TYPE, REDUCE_OP);                                                                                                                                                                                                \
            CONTIGUOUS_ARG_SMALL_FAST_DIM_ONLY_ARMS(128, IN_TYPE, REDUCE_OP);                                                                                                                                                                                               \
            CONTIGUOUS_ARG_SMALL_FAST_DIM_ONLY_ARMS(256, IN_TYPE, REDUCE_OP);                                                                                                                                                                                               \
            CONTIGUOUS_ARG_SMALL_FAST_DIM_ONLY_ARMS(512, IN_TYPE, REDUCE_OP);                                                                                                                                                                                               \
        default:                                                                                                                                                                                                                                                            \
            if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.y == 0)                                                                                                                                                                                 \
            {                                                                                                                                                                                                                                                               \
                printf("Error: Invalid block size %d for %s\n", blockDim.x, "contiguous_" #OP "_small_fast_dim_only_" #RUST_TYPE);                                                                                                                                          \
            }                                                                                                                                                                                                                                                               \
            break;                                                                                                                                                                                                                                                          \
        }                                                                                                                                                                                                                                                                   \
    }                                                                                                                                                                                                                                                                       \
    /* Kernel entry point: non-contiguous-input, small-fast-dim-only variant of OP for RUST_TYPE. */ \
    /* Picks the UNCONTIGUOUS_ARG_SMALL_FAST_DIM_ONLY_ARMS case whose label equals blockDim.x (32, 64, 128, 256 or 512). */ \
    /* For any other blockDim.x the body only prints a diagnostic and then leaves the switch. */ \
    /* NOTE(review): the arm macro is defined elsewhere; it presumably refers to the kernel parameters by name - confirm before renaming any of them. */ \
    extern "C" __global__ void uncontiguous_##OP##_small_fast_dim_only_##RUST_TYPE(int64_t *out, IN_TYPE *in, FastDivmod *shape, int *strides, size_t ndim, size_t fast_dim_size, size_t output_size, int64_t last_stride) \
    { \
        switch (blockDim.x) \
        { \
            UNCONTIGUOUS_ARG_SMALL_FAST_DIM_ONLY_ARMS(32, IN_TYPE, REDUCE_OP); \
            UNCONTIGUOUS_ARG_SMALL_FAST_DIM_ONLY_ARMS(64, IN_TYPE, REDUCE_OP); \
            UNCONTIGUOUS_ARG_SMALL_FAST_DIM_ONLY_ARMS(128, IN_TYPE, REDUCE_OP); \
            UNCONTIGUOUS_ARG_SMALL_FAST_DIM_ONLY_ARMS(256, IN_TYPE, REDUCE_OP); \
            UNCONTIGUOUS_ARG_SMALL_FAST_DIM_ONLY_ARMS(512, IN_TYPE, REDUCE_OP); \
        default: \
        { \
            /* Unsupported block width: only threads whose x and y thread and block indices are all zero print, */ \
            /* which keeps the message from being repeated by every thread of the launch. */ \
            const bool is_reporting_thread = (blockIdx.x == 0) && (blockIdx.y == 0) && (threadIdx.x == 0) && (threadIdx.y == 0); \
            if (is_reporting_thread) \
            { \
                printf("Error: Invalid block size %d for %s\n", blockDim.x, "uncontiguous_" #OP "_small_fast_dim_only_" #RUST_TYPE); \
            } \
            break; \
        } \
        } \
    } \
    /* Kernel entry point: OP for RUST_TYPE, implemented by reduce_fast_dim_not_include. */ \
    /* Every kernel argument is forwarded unchanged, preceded by this block's shared-memory array. */ \
    /* NOTE(review): the trailing template argument 32 is presumably the warp size - confirm against reduce_fast_dim_not_include. */ \
    extern "C" __global__ void OP##_fast_dim_no_include_##RUST_TYPE(int64_t *out, REDUCE_OP##Result<IN_TYPE> *buffer, IN_TYPE *in, int *finished, FastDivmod *shape, int *strides, size_t ndim, size_t reduce_size, size_t output_size) \
    { \
        /* Element type used for both the shared array and the buffer argument. */ \
        using SharedElem = REDUCE_OP##Result<IN_TYPE>; \
        /* Unsized extern __shared__ array: its byte size is supplied by the launch configuration. */ \
        extern __shared__ SharedElem OP##shared##RUST_TYPE[]; \
        reduce_fast_dim_not_include<IN_TYPE, SharedElem, REDUCE_OP, 32>( \
            OP##shared##RUST_TYPE, out, buffer, in, finished, \
            shape, strides, ndim, reduce_size, output_size); \
    }