import enum
import os.path
import shutil
import argparse
import platform
from library import *
from manifest import *
def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch=0):
    """Return whether a CUDA toolkit version is at least ``major.minor.patch``.

    Args:
        semantic_ver_string: dot-separated integer components such as
            ``"10.2"``. Components the string does not specify are taken from
            the default version ``11.0.132``; an empty string selects that
            default unchanged. Components beyond the third are kept.
        major: minimum required major version.
        minor: minimum required minor version.
        patch: minimum required patch version.

    Returns:
        True if the filled-in version compares ``>= [major, minor, patch]``
        under Python's lexicographic list comparison.

    Raises:
        ValueError: if a dot-separated component is not an integer literal.
    """
    default_version = [11, 0, 132]
    if semantic_ver_string == "":
        given = []
    else:
        given = [int(part) for part in semantic_ver_string.split(".")]
    # Missing trailing components fall back to the default version's values.
    effective_version = given + default_version[len(given):]
    return effective_version >= [major, minor, patch]
def CreateGemmOperator(
    manifest,
    layouts,
    tile_descriptions,
    data_type,
    alignment_constraints,
    complex_transforms=None,
    epilogue_functor=EpilogueFunctor.LinearCombination,
    swizzling_functor=SwizzlingFunctor.Identity8,
):
    """Build universal GEMM operations and register each one with ``manifest``.

    One ``GemmOperation`` is created per combination of layout, tile
    description, alignment and complex transform. Each operation is passed to
    ``manifest.append`` and also collected into the returned list.

    Args:
        manifest: receives every operation via ``append``; its
            ``args.kernels`` attribute controls pruning (see below).
        layouts: sequence of ``(layout_a, layout_b, layout_c)`` triples.
        tile_descriptions: candidate tile descriptions.
        data_type: ``(element_a, element_b, element_c, element_epilogue)``.
        alignment_constraints: candidate alignments for A and B; C uses
            ``min(8, alignment)``.
        complex_transforms: ``(transform_a, transform_b)`` pairs; ``None``
            means a single pair of ``ComplexTransform.none``.
        epilogue_functor: epilogue functor given to every operation.
        swizzling_functor: swizzling functor given to every operation.

    Returns:
        list of the created operations, in creation order.
    """
    if complex_transforms is not None:
        transforms = complex_transforms
    else:
        transforms = [(ComplexTransform.none, ComplexTransform.none)]

    element_a, element_b, element_c, element_epilogue = data_type

    # With no explicit kernel filter only the first tile description and the
    # first alignment are generated.
    generate_all = manifest.args.kernels != ""
    tiles = tile_descriptions if generate_all else [tile_descriptions[0]]
    alignments = (
        alignment_constraints if generate_all else [alignment_constraints[0]]
    )

    created = []
    for layout in layouts:
        for tile in tiles:
            for align in alignments:
                align_c = min(8, align)
                for transform in transforms:
                    tensor_a = TensorDescription(
                        element_a, layout[0], align, transform[0]
                    )
                    tensor_b = TensorDescription(
                        element_b, layout[1], align, transform[1]
                    )
                    tensor_c = TensorDescription(element_c, layout[2], align_c)
                    operation = GemmOperation(
                        GemmKind.Universal,
                        tile.minimum_compute_capability,
                        tile,
                        tensor_a,
                        tensor_b,
                        tensor_c,
                        element_epilogue,
                        epilogue_functor,
                        swizzling_functor,
                    )
                    manifest.append(operation)
                    created.append(operation)
    return created
def CreateConv2dOperator(
    manifest,
    layout,
    tile_descriptions,
    data_type,
    alignment,
    conv_kinds=(ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad),
    epilogue_functor=EpilogueFunctor.LinearCombination,
):
    """Build 2-D convolution operations and register each one with ``manifest``.

    For every tile description, convolution kind and iterator algorithm
    (Analytic, then Optimized), exactly one ``Conv2dOperation`` is created.
    Optimized Dgrad is created with ``StrideSupport.Unity``. Every other
    combination is created with ``StrideSupport.Strided``.

    Args:
        manifest: receives every operation via ``append``; when its
            ``args.kernels`` is the empty string only the first tile
            description is used.
        layout: ``(layout_a, layout_b, layout_c)`` triple.
        tile_descriptions: candidate tile descriptions.
        data_type: ``(element_a, element_b, element_c, element_epilogue)``.
        alignment: alignment for A and B; C uses ``min(8, alignment)``.
        conv_kinds: sequence of convolution kinds to generate. It is iterated
            once per tile. The default is an immutable tuple, so it cannot be
            mutated across calls.
        epilogue_functor: epilogue functor given to every operation.

    Returns:
        list of the created operations, in creation order.
    """
    element_a, element_b, element_c, element_epilogue = data_type
    alignment_c = min(8, alignment)
    iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]

    if manifest.args.kernels == "":
        tile_descriptions = [tile_descriptions[0]]

    operations = []
    for tile in tile_descriptions:
        for conv_kind in conv_kinds:
            for iterator_algorithm in iterator_algorithms:
                A = TensorDescription(element_a, layout[0], alignment)
                B = TensorDescription(element_b, layout[1], alignment)
                C = TensorDescription(element_c, layout[2], alignment_c)
                # Optimized Dgrad is only generated with unit stride; every
                # other (kind, iterator) pair gets the strided variant.
                optimized_dgrad = (
                    iterator_algorithm == IteratorAlgorithm.Optimized
                    and conv_kind == ConvKind.Dgrad
                )
                stride_support = (
                    StrideSupport.Unity if optimized_dgrad else StrideSupport.Strided
                )
                new_operation = Conv2dOperation(
                    conv_kind,
                    iterator_algorithm,
                    tile.minimum_compute_capability,
                    tile,
                    A,
                    B,
                    C,
                    element_epilogue,
                    stride_support,
                    epilogue_functor,
                )
                manifest.append(new_operation)
                operations.append(new_operation)
    return operations
def GenerateConv2d_Simt(args):
    """Generate SIMT int8 Fprop convolution kernels.

    Input/filter layouts are NC4HW4 / C4RSK4, using a ``[1, 1, 4]`` s8 x s8 ->
    s32 SIMT instruction.

    Output variants (dst type, dst layout), in order:
    - s8 NC4HW4
    - s8 NC32HW32
    - u4 NHWC
    - s4 NHWC
    - f32 NCHW

    4-bit outputs use ``min_cc = 75`` and no special optimization. The others
    use ``min_cc = 61`` and ``SpecialOptimizeDesc.ConvFilterUnity``.

    Tiles are skipped when ``threadblock_shape[0]`` exceeds 32 for NC32HW32
    output, or exceeds 16 for NCHW/NHWC output.

    Args:
        args: parsed command-line arguments (not used here).

    Returns:
        list of operations produced by ``GenerateConv2d``.
    """
    max_cc = 1024
    layouts = [(LayoutType.TensorNC4HW4, LayoutType.TensorC4RSK4)]
    math_instructions = [
        MathInstruction(
            [1, 1, 4],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]
    # (dst_type, dst_layout) pairs, generated in this order.
    outputs = [
        (DataType.s8, LayoutType.TensorNC4HW4),
        (DataType.s8, LayoutType.TensorNC32HW32),
        (DataType.u4, LayoutType.TensorNHWC),
        (DataType.s4, LayoutType.TensorNHWC),
        (DataType.f32, LayoutType.TensorNCHW),
    ]
    # (threadblock_shape, stages, warp_count) of every candidate tile.
    tile_specs = (
        ((128, 128, 32), 2, (2, 4, 1)),
        ((128, 64, 32), 2, (2, 2, 1)),
        ((64, 128, 32), 2, (1, 4, 1)),
        ((128, 32, 32), 2, (2, 1, 1)),
        ((32, 128, 32), 2, (1, 2, 1)),
        ((32, 64, 32), 2, (1, 1, 1)),
        ((64, 32, 32), 2, (1, 1, 1)),
        ((16, 128, 16), 1, (1, 1, 1)),
        ((16, 64, 8), 2, (1, 1, 1)),
    )

    def max_tile_rows(dst_layout):
        # Largest threadblock_shape[0] accepted for this output layout;
        # None means unrestricted.
        if dst_layout == LayoutType.TensorNC32HW32:
            return 32
        if dst_layout in (LayoutType.TensorNCHW, LayoutType.TensorNHWC):
            return 16
        return None

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            for dst_type, dst_layout in outputs:
                if dst_type in (DataType.s4, DataType.u4):
                    min_cc = 75
                    special_opt = SpecialOptimizeDesc.NoneSpecialOpt
                else:
                    min_cc = 61
                    special_opt = SpecialOptimizeDesc.ConvFilterUnity
                tiles = [
                    TileDescription(
                        list(shape), stages, list(warps), math_inst, min_cc, max_cc
                    )
                    for shape, stages, warps in tile_specs
                ]
                row_limit = max_tile_rows(dst_layout)
                for tile in tiles:
                    if row_limit is not None and tile.threadblock_shape[0] > row_limit:
                        continue
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        32,
                        32,
                        32,
                        special_opt,
                    )
    return operations
def GenerateConv2d_TensorOp_8816(args):
    """Generate int8 tensor-op Fprop kernels built on the ``[8, 8, 16]`` MMA.

    Kernels are produced in three groups, in this order:

    1. NC32HW32 input with C32RSK32 filter, writing s8 NC32HW32 (GemmTN) or
       s8 NC4HW4 (GemmNT).
    2. NHWC input with NC4HW4 / NC8HW8 / NC16HW16 filters, writing NHWC in the
       instruction's B element type.
    3. The same NHWC input/filter layouts, writing s4, u4 and f32.

    Every kernel uses ``min_cc = 75``, ``SpecialOptimizeDesc.ConvFilterUnity``
    and CUDA version 10.2.

    Args:
        args: parsed command-line arguments (not used here).

    Returns:
        list of operations produced by ``GenerateConv2d``.
    """
    min_cc = 75
    max_cc = 1024
    cuda_major = 10
    cuda_minor = 2
    use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity
    math_instructions = [
        MathInstruction(
            [8, 8, 16],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        )
    ]

    def make_tiles(specs, math_inst):
        # Fresh TileDescription objects from (shape, stages, warps) specs.
        return [
            TileDescription(list(shape), stages, list(warps), math_inst, min_cc, max_cc)
            for shape, stages, warps in specs
        ]

    def fprop(tiles, layout, dst_layout, dst_type, src_align, flt_align,
              dst_align, gemm_mode, flag):
        # Every kernel here shares conv type/kind, min_cc, special optimization
        # and CUDA version. ``flag`` is the boolean GenerateConv2d receives
        # right after the implicit-GEMM mode.
        return GenerateConv2d(
            ConvType.Convolution,
            ConvKind.Fprop,
            tiles,
            layout[0],
            layout[1],
            dst_layout,
            dst_type,
            min_cc,
            src_align,
            flt_align,
            dst_align,
            use_special_optimization,
            gemm_mode,
            flag,
            cuda_major,
            cuda_minor,
        )

    operations = []

    # Group 1: NC32HW32 input / C32RSK32 filter.
    nc32_layouts = [(LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32)]
    nc32_outputs = [
        (DataType.s8, LayoutType.TensorNC32HW32),
        (DataType.s8, LayoutType.TensorNC4HW4),
    ]
    nc32hw32_out_specs = (
        ((128, 256, 64), 2, (2, 4, 1)),
        ((256, 128, 64), 2, (4, 2, 1)),
        ((128, 128, 64), 2, (2, 2, 1)),
        ((128, 64, 64), 2, (2, 2, 1)),
        ((64, 128, 64), 2, (2, 2, 1)),
        ((128, 64, 32), 1, (2, 2, 1)),
        ((128, 32, 32), 1, (2, 1, 1)),
    )
    nc4hw4_out_specs = (
        ((64, 128, 64), 2, (2, 2, 1)),
        ((32, 128, 32), 1, (1, 2, 1)),
    )
    for math_inst in math_instructions:
        for layout in nc32_layouts:
            for dst_type, dst_layout in nc32_outputs:
                if dst_layout == LayoutType.TensorNC32HW32:
                    operations += fprop(
                        make_tiles(nc32hw32_out_specs, math_inst),
                        layout, dst_layout, dst_type, 128, 128, 64,
                        ImplicitGemmMode.GemmTN, True,
                    )
                else:
                    assert dst_layout == LayoutType.TensorNC4HW4
                    operations += fprop(
                        make_tiles(nc4hw4_out_specs, math_inst),
                        layout, dst_layout, dst_type, 128, 128, 64,
                        ImplicitGemmMode.GemmNT, False,
                    )

    # Groups 2 and 3 share these NHWC layouts (third item: src/filter
    # alignment) and tile candidates.
    nhwc_layouts = [
        (LayoutType.TensorNHWC, LayoutType.TensorNC4HW4, 32),
        (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 64),
        (LayoutType.TensorNHWC, LayoutType.TensorNC16HW16, 128),
    ]
    nhwc_dst_layouts = [LayoutType.TensorNHWC]
    nhwc_specs = (
        ((128, 32, 32), 1, (2, 1, 1)),
        ((64, 16, 32), 2, (1, 1, 1)),
    )

    # Group 2: NHWC output in the instruction's B element type.
    for math_inst in math_instructions:
        for layout in nhwc_layouts:
            for dst_layout in nhwc_dst_layouts:
                dst_type = math_inst.element_b
                for tile in make_tiles(nhwc_specs, math_inst):
                    tile_n = tile.threadblock_shape[1]
                    dst_align = 32 if tile_n == 16 else 64
                    operations += fprop(
                        [tile], layout, dst_layout, dst_type,
                        layout[2], layout[2], dst_align,
                        ImplicitGemmMode.GemmTN, False,
                    )
                    if tile_n in (16, 32):
                        operations += fprop(
                            [tile], layout, dst_layout, dst_type,
                            layout[2], layout[2], dst_align,
                            ImplicitGemmMode.GemmTN, True,
                        )

    # Group 3: NHWC output in s4, u4 and f32.
    out_dtypes = [DataType.s4, DataType.u4, DataType.f32]
    for math_inst in math_instructions:
        for layout in nhwc_layouts:
            for dst_layout in nhwc_dst_layouts:
                for out_dtype in out_dtypes:
                    is_f32 = out_dtype == DataType.f32
                    for tile in make_tiles(nhwc_specs, math_inst):
                        tile_n = tile.threadblock_shape[1]
                        if tile_n == 16 or is_f32:
                            dst_align = 4 * DataTypeSize[out_dtype]
                        else:
                            dst_align = 8 * DataTypeSize[out_dtype]
                        operations += fprop(
                            [tile], layout, dst_layout, out_dtype,
                            layout[2], layout[2], dst_align,
                            ImplicitGemmMode.GemmTN, False,
                        )
                        if tile_n == 16 or (tile_n == 32 and not is_f32):
                            operations += fprop(
                                [tile], layout, dst_layout, out_dtype,
                                layout[2], layout[2], dst_align,
                                ImplicitGemmMode.GemmTN, True,
                            )

    return operations
def GenerateConv2d_TensorOp_8832(args):
    """Generate int4 tensor-op Fprop kernels built on the ``[8, 8, 32]`` MMA.

    Two math instructions are covered: s4 x s4 and s4 x u4, both with s32
    accumulation. Kernels are produced in three groups, in this order:

    1. NC64HW64 input with C64RSK64 filter, writing NC64HW64 in the
       instruction's B element type.
    2. NHWC input with NC8HW8 / NC16HW16 / NC32HW32 filters, writing NHWC in
       the instruction's B element type.
    3. The same NHWC input/filter layouts, writing s8.

    Every kernel uses ``min_cc = 75``, ``SpecialOptimizeDesc.ConvFilterUnity``,
    ``ImplicitGemmMode.GemmTN`` and CUDA version 10.2.

    Args:
        args: parsed command-line arguments (not used here).

    Returns:
        list of operations produced by ``GenerateConv2d``.
    """
    min_cc = 75
    max_cc = 1024
    cuda_major = 10
    cuda_minor = 2
    use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity
    math_instructions = [
        MathInstruction(
            [8, 8, 32],
            DataType.s4,
            element_b,
            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        )
        for element_b in (DataType.s4, DataType.u4)
    ]

    def make_tiles(specs, math_inst):
        # Fresh TileDescription objects from (shape, stages, warps) specs.
        return [
            TileDescription(list(shape), stages, list(warps), math_inst, min_cc, max_cc)
            for shape, stages, warps in specs
        ]

    def fprop(tiles, layout, dst_layout, dst_type, src_align, flt_align,
              dst_align, flag):
        # Every kernel here shares conv type/kind, min_cc, special
        # optimization, GemmTN and CUDA version. ``flag`` is the boolean
        # GenerateConv2d receives right after the implicit-GEMM mode.
        return GenerateConv2d(
            ConvType.Convolution,
            ConvKind.Fprop,
            tiles,
            layout[0],
            layout[1],
            dst_layout,
            dst_type,
            min_cc,
            src_align,
            flt_align,
            dst_align,
            use_special_optimization,
            ImplicitGemmMode.GemmTN,
            flag,
            cuda_major,
            cuda_minor,
        )

    operations = []

    # Group 1: NC64HW64 input / C64RSK64 filter.
    nc64_layouts = [(LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64)]
    nc64_dst_layouts = [LayoutType.TensorNC64HW64]
    nc64_specs = (
        ((128, 256, 128), 2, (2, 4, 1)),
        ((128, 128, 128), 2, (2, 2, 1)),
        ((128, 64, 128), 2, (2, 1, 1)),
        ((128, 64, 64), 1, (2, 1, 1)),
    )
    for math_inst in math_instructions:
        for layout in nc64_layouts:
            for dst_layout in nc64_dst_layouts:
                dst_type = math_inst.element_b
                operations += fprop(
                    make_tiles(nc64_specs, math_inst),
                    layout, dst_layout, dst_type, 128, 128, 64, True,
                )

    # Groups 2 and 3 share these NHWC layouts (third item: src/filter
    # alignment) and tile candidates.
    nhwc_layouts = [
        (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32),
        (LayoutType.TensorNHWC, LayoutType.TensorNC16HW16, 64),
        (LayoutType.TensorNHWC, LayoutType.TensorNC32HW32, 128),
    ]
    nhwc_dst_layouts = [LayoutType.TensorNHWC]
    nhwc_specs = (
        ((128, 16, 64), 2, (1, 1, 1)),
        ((128, 32, 64), 1, (2, 1, 1)),
        ((128, 64, 64), 1, (2, 1, 1)),
    )

    # Group 2: NHWC output in the instruction's B element type.
    for math_inst in math_instructions:
        for layout in nhwc_layouts:
            for dst_layout in nhwc_dst_layouts:
                dst_type = math_inst.element_b
                for tile in make_tiles(nhwc_specs, math_inst):
                    tile_n = tile.threadblock_shape[1]
                    operations += fprop(
                        [tile], layout, dst_layout, dst_type,
                        layout[2], layout[2], 16 if tile_n == 16 else 32, False,
                    )
                    if tile_n in (32, 64):
                        operations += fprop(
                            [tile], layout, dst_layout, dst_type,
                            layout[2], layout[2], 32 if tile_n == 32 else 64, True,
                        )

    # Group 3: NHWC s8 output.
    for math_inst in math_instructions:
        for layout in nhwc_layouts:
            for dst_layout in nhwc_dst_layouts:
                for tile in make_tiles(nhwc_specs, math_inst):
                    tile_n = tile.threadblock_shape[1]
                    operations += fprop(
                        [tile], layout, dst_layout, DataType.s8,
                        layout[2], layout[2], 32 if tile_n == 16 else 64, False,
                    )
                    if tile_n in (32, 64):
                        operations += fprop(
                            [tile], layout, dst_layout, DataType.s8,
                            layout[2], layout[2], 64 if tile_n == 32 else 128, True,
                        )

    return operations
def GenerateDeconv_Simt(args):
    """Generate SIMT int8 deconvolution kernels (``ConvKind.Dgrad``).

    Uses NC4HW4 input, K4RSC4 filter and s8 NC4HW4 output, with the
    ``[1, 1, 4]`` s8 SIMT instruction, ``min_cc = 61`` and
    ``SpecialOptimizeDesc.DeconvDoubleUpsampling``. All candidate tiles are
    passed to ``GenerateConv2d`` in one call.

    Args:
        args: parsed command-line arguments (not used here).

    Returns:
        list of operations produced by ``GenerateConv2d``.
    """
    min_cc = 61
    max_cc = 1024
    special_opt = SpecialOptimizeDesc.DeconvDoubleUpsampling
    layouts = [(LayoutType.TensorNC4HW4, LayoutType.TensorK4RSC4)]
    outputs = [(DataType.s8, LayoutType.TensorNC4HW4)]
    math_instructions = [
        MathInstruction(
            [1, 1, 4],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]
    # (threadblock_shape, stages, warp_count) of every candidate tile.
    tile_specs = (
        ((32, 128, 32), 2, (1, 2, 1)),
        ((16, 128, 16), 2, (1, 2, 1)),
        ((16, 128, 16), 1, (1, 1, 1)),
        ((16, 64, 8), 2, (1, 1, 1)),
    )

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            for dst_type, dst_layout in outputs:
                tiles = [
                    TileDescription(
                        list(shape), stages, list(warps), math_inst, min_cc, max_cc
                    )
                    for shape, stages, warps in tile_specs
                ]
                operations += GenerateConv2d(
                    ConvType.Convolution,
                    ConvKind.Dgrad,
                    tiles,
                    layout[0],
                    layout[1],
                    dst_layout,
                    dst_type,
                    min_cc,
                    32,
                    32,
                    32,
                    special_opt,
                )
    return operations
def GenerateDeconv_TensorOp_8816(args):
    """Generate int8 tensor-op deconvolution kernels (``ConvKind.Dgrad``).

    Uses NHWC input with CK4RS4 / CK8RS8 / CK16RS16 filters and s8 NHWC
    output, with the ``[8, 8, 16]`` s8 MMA, ``min_cc = 75``,
    ``SpecialOptimizeDesc.DeconvDoubleUpsampling``, GemmTN and CUDA version
    10.2. Each tile is generated separately; the destination alignment is 32
    for tiles with ``threadblock_shape[1] == 16`` and 64 otherwise.

    Args:
        args: parsed command-line arguments (not used here).

    Returns:
        list of operations produced by ``GenerateConv2d``.
    """
    min_cc = 75
    max_cc = 1024
    cuda_major = 10
    cuda_minor = 2
    special_opt = SpecialOptimizeDesc.DeconvDoubleUpsampling
    # (input layout, filter layout, alignment used for input and filter)
    layouts = [
        (LayoutType.TensorNHWC, LayoutType.TensorCK4RS4, 32),
        (LayoutType.TensorNHWC, LayoutType.TensorCK8RS8, 64),
        (LayoutType.TensorNHWC, LayoutType.TensorCK16RS16, 128),
    ]
    outputs = [(DataType.s8, LayoutType.TensorNHWC)]
    math_instructions = [
        MathInstruction(
            [8, 8, 16],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        )
    ]
    # (threadblock_shape, stages, warp_count) of every candidate tile.
    tile_specs = (
        ((128, 32, 32), 1, (2, 1, 1)),
        ((64, 16, 32), 2, (1, 1, 1)),
    )

    operations = []
    for math_inst in math_instructions:
        for src_layout, flt_layout, io_align in layouts:
            for dst_type, dst_layout in outputs:
                tiles = [
                    TileDescription(
                        list(shape), stages, list(warps), math_inst, min_cc, max_cc
                    )
                    for shape, stages, warps in tile_specs
                ]
                for tile in tiles:
                    dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Dgrad,
                        [tile],
                        src_layout,
                        flt_layout,
                        dst_layout,
                        dst_type,
                        min_cc,
                        io_align,
                        io_align,
                        dst_align,
                        special_opt,
                        ImplicitGemmMode.GemmTN,
                        False,
                        cuda_major,
                        cuda_minor,
                    )
    return operations
# Search-space limits for the SIMT tile enumeration in the generators below.
warpsPerThreadblockEdge = [1, 2, 4, 8, 16]  # candidate warp counts per edge
warpsPerThreadblockRatio = 2  # max ratio between the two warp-grid edges
warpsPerThreadblockMax = 16  # max total warps per threadblock
warpShapeEdges = [8, 16, 32, 64, 128, 256]  # candidate warp-tile edge lengths
warpShapeRatio = 4  # max ratio between the two warp-tile edges
warpShapeMax = 64 * 64  # inclusive upper bound on warp-tile area
warpShapeMin = 8 * 8  # exclusive lower bound on warp-tile area
threadblockEdgeMax = 256  # max threadblock-tile edge length

# Per-precision parameters keyed by a one-letter precision code. Each value is
# [C++ element type, element size in bits, max threadblock-tile elements,
#  threadblockTilesL0 (presumably [M, N] tile shapes; unused here)].
precisions = dict(
    c=["cutlass::complex<float>", 64, 64 * 128, [[64, 128], [64, 32]]],
    d=["double", 64, 64 * 64, [[64, 64], [32, 32]]],
    h=["cutlass::half_t", 16, 128 * 256, [[256, 128], [64, 128], [64, 32]]],
    i=["int", 32, 128 * 128, [[128, 64], [16, 32]]],
    s=["float", 32, 128 * 128, [[128, 256], [128, 128], [64, 64]]],
    z=["cutlass::complex<double>", 128, 64 * 64, [[32, 64], [16, 32]]],
)
def GenerateGemm_Simt(args):
    """Generate SIMT fp32 GEMM kernels.

    Candidate tiles are enumerated from the module-level search limits and the
    ``"s"`` entry of ``precisions``. They are pruned by shape, per-thread tile
    and shared-memory heuristics. A candidate is kept only if its
    threadblock shape, warp count and stage count match one entry of a fixed
    whitelist. Each kept tile is emitted for all four row/column-major
    combinations of A and B; C is always row-major.

    Args:
        args: parsed command-line arguments (not used here).

    Returns:
        list of operations produced by ``GeneratesGemm``.
    """
    warpNumThreads = 32

    # Warp grids [warps along dim 0, warps along dim 1] within the allowed
    # aspect ratio and total warp count.
    warpsPerThreadblocks = [
        [w0, w1]
        for w0 in warpsPerThreadblockEdge
        for w1 in warpsPerThreadblockEdge
        if w0 / w1 <= warpsPerThreadblockRatio
        and w1 / w0 <= warpsPerThreadblockRatio
        and w0 * w1 <= warpsPerThreadblockMax
    ]
    # Warp tiles within the allowed aspect ratio and area range.
    warpShapes = [
        [e0, e1]
        for e0 in warpShapeEdges
        for e1 in warpShapeEdges
        if e0 / e1 <= warpShapeRatio
        and e1 / e0 <= warpShapeRatio
        and e0 * e1 <= warpShapeMax
        and e0 * e1 > warpShapeMin
    ]
    _, precisionBits, threadblockMaxElements, _ = precisions["s"]

    def tile_key(t):
        # Identity used to match a candidate tile against the whitelist.
        return (
            t.threadblock_shape[0],
            t.threadblock_shape[1],
            t.threadblock_shape[2],
            t.warp_count[0],
            t.warp_count[1],
            t.warp_count[2],
            t.stages,
        )

    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),
    ]
    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32,
            DataType.f32,
            DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]
    min_cc = 50
    max_cc = 1024
    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            data_type = [
                math_inst.element_a,
                math_inst.element_b,
                math_inst.element_accumulator,
                math_inst.element_accumulator,
            ]
            # Whitelist of tiles that may be emitted.
            tile_descriptions = [
                TileDescription([64, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc),
                TileDescription([256, 64, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc),
                TileDescription([256, 32, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([64, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([8, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([16, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([16, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
            ]
            allowed_tiles = {tile_key(t) for t in tile_descriptions}
            for warpsPerThreadblock in warpsPerThreadblocks:
                for warpShape in warpShapes:
                    # The warp's threads form a warpThreadsM x warpThreadsN grid
                    # (8 x 4 when warpShape[0] > warpShape[1], else 4 x 8).
                    warpThreadsM = 8 if warpShape[0] > warpShape[1] else 4
                    # Integer division: the thread grid edge is a whole count.
                    warpThreadsN = warpNumThreads // warpThreadsM
                    blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1]
                    blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1]
                    warpG = warpShape[0] > warpShape[1]
                    warpL = warpShape[0] < warpShape[1]
                    blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1] * 2
                    blockL2 = warpsPerThreadblock[0] * 2 < warpsPerThreadblock[1]
                    warpG2 = warpShape[0] > warpShape[1] * 2
                    warpL2 = warpShape[0] * 2 < warpShape[1]
                    # Reject warp grids and warp tiles elongated in opposite
                    # directions.
                    if blockG2 and warpL:
                        continue
                    if blockL2 and warpG:
                        continue
                    if warpG2 and blockL:
                        continue
                    if warpL2 and blockG:
                        continue
                    threadblockTile = [
                        warpShape[0] * warpsPerThreadblock[0],
                        warpShape[1] * warpsPerThreadblock[1],
                    ]
                    if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements:
                        continue
                    if threadblockTile[0] > threadblockEdgeMax:
                        continue
                    if threadblockTile[1] > threadblockEdgeMax:
                        continue
                    totalThreads = (
                        warpNumThreads * warpsPerThreadblock[0] * warpsPerThreadblock[1]
                    )
                    # K extent: at least 8, and at least totalThreads divided
                    # by either threadblock edge.
                    unroll = max(
                        8,
                        totalThreads // threadblockTile[0],
                        totalThreads // threadblockTile[1],
                    )
                    threadTileM = warpShape[0] // warpThreadsM
                    threadTileN = warpShape[1] // warpThreadsN
                    if threadTileM < 2 or threadTileN < 2:
                        continue
                    if threadTileM * threadTileN * precisionBits > 8 * 8 * 32:
                        continue
                    if threadblockTile[1] < warpNumThreads:
                        continue
                    # Shared-memory estimate in KiB for the A and B tiles (the
                    # factor 2 presumably covers the two stages); >48 skipped.
                    smemBitsA = threadblockTile[0] * unroll * 2 * precisionBits
                    smemBitsB = threadblockTile[1] * unroll * 2 * precisionBits
                    smemKBytes = (smemBitsA + smemBitsB) / 8 / 1024
                    if smemKBytes > 48:
                        continue
                    tile = TileDescription(
                        [threadblockTile[0], threadblockTile[1], unroll],
                        2,
                        [
                            threadblockTile[0] // warpShape[0],
                            threadblockTile[1] // warpShape[1],
                            1,
                        ],
                        math_inst,
                        min_cc,
                        max_cc,
                    )
                    if tile_key(tile) not in allowed_tiles:
                        continue
                    operations += GeneratesGemm(
                        tile, data_type, layout[0], layout[1], layout[2], min_cc
                    )
    return operations
def GenerateDwconv2d_Simt(args, conv_kind):
    """Generate SIMT fp32 depthwise-convolution kernels for ``conv_kind``.

    Candidate tiles are enumerated from the module-level search limits and the
    ``"s"`` entry of ``precisions``. They are pruned by shape, per-thread tile
    and shared-memory heuristics. A candidate is kept only if its
    threadblock shape, warp count and stage count match one entry of a fixed
    whitelist. Each kept tile is emitted for NCHW f32 tensors with source
    alignment 32 (Wgrad) or 128 and 32 (other kinds).

    Args:
        args: parsed command-line arguments (not used here).
        conv_kind: the ``ConvKind`` to generate. Wgrad uses
            ``ImplicitGemmMode.GemmNT``; other kinds use ``GemmTN``.

    Returns:
        list of operations produced by ``GenerateConv2d``.
    """
    warpNumThreads = 32

    # Warp grids [warps along dim 0, warps along dim 1] within the allowed
    # aspect ratio and total warp count.
    warpsPerThreadblocks = [
        [w0, w1]
        for w0 in warpsPerThreadblockEdge
        for w1 in warpsPerThreadblockEdge
        if w0 / w1 <= warpsPerThreadblockRatio
        and w1 / w0 <= warpsPerThreadblockRatio
        and w0 * w1 <= warpsPerThreadblockMax
    ]
    # Warp tiles within the allowed aspect ratio and area range.
    warpShapes = [
        [e0, e1]
        for e0 in warpShapeEdges
        for e1 in warpShapeEdges
        if e0 / e1 <= warpShapeRatio
        and e1 / e0 <= warpShapeRatio
        and e0 * e1 <= warpShapeMax
        and e0 * e1 > warpShapeMin
    ]
    _, precisionBits, threadblockMaxElements, _ = precisions["s"]

    def tile_key(t):
        # Identity used to match a candidate tile against the whitelist.
        return (
            t.threadblock_shape[0],
            t.threadblock_shape[1],
            t.threadblock_shape[2],
            t.warp_count[0],
            t.warp_count[1],
            t.warp_count[2],
            t.stages,
        )

    layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)]
    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32,
            DataType.f32,
            DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]
    min_cc = 50
    max_cc = 1024
    dst_layouts = [LayoutType.TensorNCHW]
    dst_types = [DataType.f32]
    if conv_kind == ConvKind.Wgrad:
        alignment_constraints = [32]
    else:
        alignment_constraints = [128, 32]
    # Loop-invariant: the implicit-GEMM formulation depends only on conv_kind.
    implicit_gemm_mode = (
        ImplicitGemmMode.GemmNT
        if conv_kind == ConvKind.Wgrad
        else ImplicitGemmMode.GemmTN
    )

    operations = []
    for math_inst in math_instructions:
        # Whitelist of tiles that may be emitted.
        tile_descriptions = [
            TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([32, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([32, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        ]
        allowed_tiles = {tile_key(t) for t in tile_descriptions}
        for warpsPerThreadblock in warpsPerThreadblocks:
            for warpShape in warpShapes:
                # The warp's threads form a warpThreadsM x warpThreadsN grid
                # (8 x 4 when warpShape[0] > warpShape[1], else 4 x 8).
                warpThreadsM = 8 if warpShape[0] > warpShape[1] else 4
                # Integer division: the thread grid edge is a whole count.
                warpThreadsN = warpNumThreads // warpThreadsM
                blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1]
                blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1]
                warpG = warpShape[0] > warpShape[1]
                warpL = warpShape[0] < warpShape[1]
                blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1] * 2
                blockL2 = warpsPerThreadblock[0] * 2 < warpsPerThreadblock[1]
                warpG2 = warpShape[0] > warpShape[1] * 2
                warpL2 = warpShape[0] * 2 < warpShape[1]
                # Reject warp grids and warp tiles elongated in opposite
                # directions.
                if blockG2 and warpL:
                    continue
                if blockL2 and warpG:
                    continue
                if warpG2 and blockL:
                    continue
                if warpL2 and blockG:
                    continue
                threadblockTile = [
                    warpShape[0] * warpsPerThreadblock[0],
                    warpShape[1] * warpsPerThreadblock[1],
                ]
                if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements:
                    continue
                if threadblockTile[0] > threadblockEdgeMax:
                    continue
                if threadblockTile[1] > threadblockEdgeMax:
                    continue
                totalThreads = (
                    warpNumThreads * warpsPerThreadblock[0] * warpsPerThreadblock[1]
                )
                # K extent: at least 8, and at least totalThreads divided by
                # either threadblock edge.
                unroll = max(
                    8,
                    totalThreads // threadblockTile[0],
                    totalThreads // threadblockTile[1],
                )
                threadTileM = warpShape[0] // warpThreadsM
                threadTileN = warpShape[1] // warpThreadsN
                if threadTileM < 2 or threadTileN < 2:
                    continue
                if threadTileM * threadTileN * precisionBits > 8 * 8 * 32:
                    continue
                if threadblockTile[1] < warpNumThreads:
                    continue
                # Shared-memory estimate in KiB for the A and B tiles (the
                # factor 2 presumably covers the two stages); >48 skipped.
                smemBitsA = threadblockTile[0] * unroll * 2 * precisionBits
                smemBitsB = threadblockTile[1] * unroll * 2 * precisionBits
                smemKBytes = (smemBitsA + smemBitsB) / 8 / 1024
                if smemKBytes > 48:
                    continue
                tile = TileDescription(
                    [threadblockTile[0], threadblockTile[1], unroll],
                    2,
                    [
                        threadblockTile[0] // warpShape[0],
                        threadblockTile[1] // warpShape[1],
                        1,
                    ],
                    math_inst,
                    min_cc,
                    max_cc,
                )
                if tile_key(tile) not in allowed_tiles:
                    continue
                for layout in layouts:
                    for dst_type, dst_layout in zip(dst_types, dst_layouts):
                        for alignment_src in alignment_constraints:
                            operations += GenerateConv2d(
                                ConvType.DepthwiseConvolution,
                                conv_kind,
                                [tile],
                                layout[0],
                                layout[1],
                                dst_layout,
                                dst_type,
                                min_cc,
                                alignment_src,
                                32,
                                32,
                                SpecialOptimizeDesc.NoneSpecialOpt,
                                implicit_gemm_mode,
                            )
    return operations
def GenerateDwconv2d_TensorOp_884(args, conv_kind):
    """Generate fp16 tensor-op depthwise-convolution kernels (``[8, 8, 4]`` MMA).

    Both f32- and f16-accumulating instructions are considered. NCHW tensors
    are used throughout, with ``min_cc = 70``, ``max_cc = 75`` and CUDA
    version 10.1.

    - Wgrad: only the f32-accumulating instruction is used, with f32 output
      and GemmNT. Kernels are generated for every (source, diff) alignment
      pair from ``[128, 32, 16]``, with destination alignment 32.
    - Other kinds: f16 output and GemmTN, for each source alignment in
      ``[128, 32, 16]``, with the second and third alignments both 16.

    Args:
        args: parsed command-line arguments (not used here).
        conv_kind: the ``ConvKind`` to generate.

    Returns:
        list of operations produced by ``GenerateConv2d``.
    """
    min_cc = 70
    max_cc = 75
    cuda_major = 10
    cuda_minor = 1
    is_wgrad = conv_kind == ConvKind.Wgrad
    layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)]
    dst_layouts = [LayoutType.TensorNCHW]
    dst_types = [DataType.f32] if is_wgrad else [DataType.f16]
    alignment_constraints = [128, 32, 16]
    math_instructions = [
        MathInstruction(
            [8, 8, 4],
            DataType.f16,
            DataType.f16,
            accumulator,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        )
        for accumulator in (DataType.f32, DataType.f16)
    ]
    # (threadblock_shape, stages, warp_count) of every candidate tile.
    tile_specs = (
        ((128, 256, 32), 2, (2, 4, 1)),
        ((128, 128, 32), 2, (4, 4, 1)),
        ((64, 128, 32), 2, (2, 4, 1)),
        ((128, 64, 32), 2, (4, 2, 1)),
        ((64, 64, 32), 2, (2, 2, 1)),
    )

    operations = []
    for math_inst in math_instructions:
        tiles = [
            TileDescription(list(shape), stages, list(warps), math_inst, min_cc, max_cc)
            for shape, stages, warps in tile_specs
        ]
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                for alignment_src in alignment_constraints:
                    if not is_wgrad:
                        operations += GenerateConv2d(
                            ConvType.DepthwiseConvolution,
                            conv_kind,
                            tiles,
                            layout[0],
                            layout[1],
                            dst_layout,
                            dst_type,
                            min_cc,
                            alignment_src,
                            16,
                            16,
                            SpecialOptimizeDesc.NoneSpecialOpt,
                            ImplicitGemmMode.GemmTN,
                            False,
                            cuda_major,
                            cuda_minor,
                        )
                        continue
                    # Wgrad kernels are generated only for f32 accumulation.
                    if math_inst.element_accumulator == DataType.f16:
                        continue
                    for alignment_diff in alignment_constraints:
                        operations += GenerateConv2d(
                            ConvType.DepthwiseConvolution,
                            conv_kind,
                            tiles,
                            layout[0],
                            layout[1],
                            dst_layout,
                            dst_type,
                            min_cc,
                            alignment_src,
                            alignment_diff,
                            32,
                            SpecialOptimizeDesc.NoneSpecialOpt,
                            ImplicitGemmMode.GemmNT,
                            False,
                            cuda_major,
                            cuda_minor,
                        )
    return operations
def GenerateGemv_Simt(args):
    """Generate SIMT fp32 GEMV kernels (row-major A, B and C).

    A kernel is generated for every threadblock N extent in ``[128, 64, 32]``
    and every pair of A/B alignments (in bits) in ``[128, 64, 32]``. With
    ``elems_x = align_x // bits(element_x)``:

    - threadblock shape: ``[1, n, (256 * elems_a) // (n // elems_b)]``
    - thread shape: ``[1, elems_b, elems_a]``

    Args:
        args: parsed command-line arguments (not used here).

    Returns:
        list with one ``GeneratesGemv`` result per generated kernel.
    """
    min_cc = 50
    block_n_choices = [128, 64, 32]
    align_bits_a = [128, 64, 32]
    align_bits_b = [128, 64, 32]
    layouts = [(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor)]
    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32,
            DataType.f32,
            DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    operations = []
    for math_inst in math_instructions:
        element_bits_a = DataTypeSize[math_inst.element_a]
        element_bits_b = DataTypeSize[math_inst.element_b]
        for layout in layouts:
            data_type = [
                math_inst.element_a,
                math_inst.element_b,
                math_inst.element_accumulator,
                math_inst.element_accumulator,
            ]
            for block_n in block_n_choices:
                for align_a in align_bits_a:
                    elems_a = align_a // element_bits_a
                    for align_b in align_bits_b:
                        elems_b = align_b // element_bits_b
                        block_k = (256 * elems_a) // (block_n // elems_b)
                        operations.append(
                            GeneratesGemv(
                                math_inst,
                                [1, block_n, block_k],
                                [1, elems_b, elems_a],
                                data_type,
                                layout[0],
                                layout[1],
                                layout[2],
                                min_cc,
                                align_a,
                                align_b,
                            )
                        )
    return operations
def GeneratesGemm_TensorOp_1688(args):
    """Generate fp16 tensor-op GEMM kernels using the ``[16, 8, 8]`` MMA shape.

    Both f32- and f16-accumulating instructions are used, for all four
    row/column-major combinations of A and B (C row-major). Alignments 8, 4
    and 2 are covered, each passed to ``GeneratesGemm`` multiplied by 16
    (presumably elements x 16-bit f16 width — confirm). Kernels use
    ``min_cc = 75`` and CUDA version 10.2.

    Args:
        args: parsed command-line arguments (not used here).

    Returns:
        list of operations produced by ``GeneratesGemm``.
    """
    min_cc = 75
    max_cc = 1024
    cuda_major = 10
    cuda_minor = 2
    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),
    ]
    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.f16,
            DataType.f16,
            accumulator,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        )
        for accumulator in (DataType.f32, DataType.f16)
    ]
    alignment_constraints = [8, 4, 2]
    # (threadblock_shape, stages, warp_count) of every candidate tile.
    tile_specs = (
        ((256, 128, 32), 2, (4, 2, 1)),
        ((128, 256, 32), 2, (2, 4, 1)),
        ((128, 128, 32), 2, (2, 2, 1)),
    )

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            for align in alignment_constraints:
                tiles = [
                    TileDescription(
                        list(shape), stages, list(warps), math_inst, min_cc, max_cc
                    )
                    for shape, stages, warps in tile_specs
                ]
                # [A, B, C, epilogue] element types; C uses A's element type.
                data_type = [
                    math_inst.element_a,
                    math_inst.element_b,
                    math_inst.element_a,
                    math_inst.element_accumulator,
                ]
                align_bits = align * 16
                for tile in tiles:
                    operations += GeneratesGemm(
                        tile,
                        data_type,
                        layout[0],
                        layout[1],
                        layout[2],
                        min_cc,
                        align_bits,
                        align_bits,
                        align_bits,
                        cuda_major,
                        cuda_minor,
                    )
    return operations
def GeneratesGemm_TensorOp_884(args):
    """Generate fp16 tensor-op GEMM kernels using the ``[8, 8, 4]`` MMA shape.

    Both f32- and f16-accumulating instructions are used, for all four
    row/column-major combinations of A and B (C row-major). Alignments 8, 4
    and 2 are covered, each passed to ``GeneratesGemm`` multiplied by 16
    (presumably elements x 16-bit f16 width — confirm). Kernels use
    ``min_cc = 70``, ``max_cc = 75`` and CUDA version 10.1.

    Args:
        args: parsed command-line arguments (not used here).

    Returns:
        list of operations produced by ``GeneratesGemm``.
    """
    min_cc = 70
    max_cc = 75
    cuda_major = 10
    cuda_minor = 1
    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),
    ]
    math_instructions = [
        MathInstruction(
            [8, 8, 4],
            DataType.f16,
            DataType.f16,
            accumulator,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        )
        for accumulator in (DataType.f32, DataType.f16)
    ]
    alignment_constraints = [8, 4, 2]
    # (threadblock_shape, stages, warp_count) of every candidate tile.
    tile_specs = (
        ((256, 128, 32), 2, (4, 2, 1)),
        ((128, 256, 32), 2, (2, 4, 1)),
        ((128, 128, 32), 2, (2, 2, 1)),
    )

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            for align in alignment_constraints:
                tiles = [
                    TileDescription(
                        list(shape), stages, list(warps), math_inst, min_cc, max_cc
                    )
                    for shape, stages, warps in tile_specs
                ]
                # [A, B, C, epilogue] element types; C uses A's element type.
                data_type = [
                    math_inst.element_a,
                    math_inst.element_b,
                    math_inst.element_a,
                    math_inst.element_accumulator,
                ]
                align_bits = align * 16
                for tile in tiles:
                    operations += GeneratesGemm(
                        tile,
                        data_type,
                        layout[0],
                        layout[1],
                        layout[2],
                        min_cc,
                        align_bits,
                        align_bits,
                        align_bits,
                        cuda_major,
                        cuda_minor,
                    )
    return operations
def GenerateConv2dOperations(args):
    """Dispatch conv2d kernel generation on ``args.type``.

    Args:
        args: object with a ``type`` attribute: ``"simt"``,
            ``"tensorop8816"`` or ``"tensorop8832"``.

    Returns:
        the operation list produced by the selected generator.

    Raises:
        AssertionError: (when assertions are enabled) for any other type.
    """
    if args.type == "simt":
        return GenerateConv2d_Simt(args)
    elif args.type == "tensorop8816":
        return GenerateConv2d_TensorOp_8816(args)
    else:
        # Adjacent literals are concatenated before .format(); the trailing
        # space keeps the words separated.
        assert args.type == "tensorop8832", (
            "operation conv2d only supports "
            "simt, tensorop8816 and tensorop8832. (got:{})".format(args.type)
        )
        return GenerateConv2d_TensorOp_8832(args)
def GenerateDeconvOperations(args):
    """Dispatch deconvolution kernel generation on ``args.type``.

    Args:
        args: object with a ``type`` attribute: ``"simt"`` or
            ``"tensorop8816"``.

    Returns:
        the operation list produced by the selected generator.

    Raises:
        AssertionError: (when assertions are enabled) for any other type.
    """
    if args.type == "simt":
        return GenerateDeconv_Simt(args)
    else:
        # Adjacent literals are concatenated before .format(); the trailing
        # space keeps the words separated.
        assert args.type == "tensorop8816", (
            "operation deconv only supports "
            "simt and tensorop8816. (got:{})".format(args.type)
        )
        return GenerateDeconv_TensorOp_8816(args)
def GenerateDwconv2dFpropOperations(args):
    """Dispatch depthwise-conv2d Fprop kernel generation on ``args.type``.

    Args:
        args: object with a ``type`` attribute: ``"simt"`` or
            ``"tensorop884"``.

    Returns:
        the operation list produced by the selected generator.

    Raises:
        AssertionError: (when assertions are enabled) for any other type.
    """
    if args.type == "simt":
        return GenerateDwconv2d_Simt(args, ConvKind.Fprop)
    else:
        # Adjacent literals are concatenated before .format(); the trailing
        # space keeps the words separated.
        assert args.type == "tensorop884", (
            "operation dwconv2d fprop only supports "
            "simt and tensorop884. (got:{})".format(args.type)
        )
        return GenerateDwconv2d_TensorOp_884(args, ConvKind.Fprop)
def GenerateDwconv2dDgradOperations(args):
    """Dispatch depthwise-conv2d Dgrad kernel generation on ``args.type``.

    Args:
        args: object with a ``type`` attribute: ``"simt"`` or
            ``"tensorop884"``.

    Returns:
        the operation list produced by the selected generator.

    Raises:
        AssertionError: (when assertions are enabled) for any other type.
    """
    if args.type == "simt":
        return GenerateDwconv2d_Simt(args, ConvKind.Dgrad)
    else:
        # The message names this dispatcher's own operation (it previously
        # said "fprop"); the trailing space keeps the words separated.
        assert args.type == "tensorop884", (
            "operation dwconv2d dgrad only supports "
            "simt and tensorop884. (got:{})".format(args.type)
        )
        return GenerateDwconv2d_TensorOp_884(args, ConvKind.Dgrad)
def GenerateDwconv2dWgradOperations(args):
    """Dispatch depthwise-conv2d Wgrad kernel generation on ``args.type``.

    Args:
        args: object with a ``type`` attribute: ``"simt"`` or
            ``"tensorop884"``.

    Returns:
        the operation list produced by the selected generator.

    Raises:
        AssertionError: (when assertions are enabled) for any other type.
    """
    if args.type == "simt":
        return GenerateDwconv2d_Simt(args, ConvKind.Wgrad)
    else:
        # The message names this dispatcher's own operation (it previously
        # said "fprop"); the trailing space keeps the words separated.
        assert args.type == "tensorop884", (
            "operation dwconv2d wgrad only supports "
            "simt and tensorop884. (got:{})".format(args.type)
        )
        return GenerateDwconv2d_TensorOp_884(args, ConvKind.Wgrad)
def GenerateGemmOperations(args):
    """Dispatch GEMM kernel generation on ``args.type``.

    Args:
        args: object with a ``type`` attribute: ``"tensorop884"``,
            ``"tensorop1688"`` or ``"simt"``.

    Returns:
        the operation list produced by the selected generator.

    Raises:
        AssertionError: (when assertions are enabled) for any other type.
    """
    if args.type == "tensorop884":
        return GeneratesGemm_TensorOp_884(args)
    elif args.type == "tensorop1688":
        return GeneratesGemm_TensorOp_1688(args)
    else:
        # List every supported type and keep the words separated across the
        # implicitly concatenated literals.
        assert args.type == "simt", (
            "operation gemm only supports "
            "simt, tensorop884 and tensorop1688. (got:{})".format(args.type)
        )
        return GenerateGemm_Simt(args)
def GenerateGemvOperations(args):
    """Dispatch GEMV kernel generation; only ``args.type == "simt"`` is supported.

    Args:
        args: object with a ``type`` attribute.

    Returns:
        the operation list produced by ``GenerateGemv_Simt``.

    Raises:
        AssertionError: (when assertions are enabled) for any other type.
    """
    # The trailing space keeps the words separated across the implicitly
    # concatenated literals.
    assert args.type == "simt", (
        "operation gemv only supports " "simt. (got:{})".format(args.type)
    )
    return GenerateGemv_Simt(args)
if __name__ == "__main__":
    # Command-line entry point: pick a generator from --operations/--type,
    # emit one wrapper file per generated operation into ``output``, and
    # write a manifest for every operation kind except gemv.
    parser = argparse.ArgumentParser(
        description="Generates device kernel registration code for CUTLASS Kernels"
    )
    parser.add_argument(
        "--operations",
        type=str,
        choices=[
            "gemm",
            "gemv",
            "conv2d",
            "deconv",
            "dwconv2d_fprop",
            "dwconv2d_dgrad",
            "dwconv2d_wgrad",
        ],
        required=True,
        help="Specifies the operation to generate (gemm, gemv, conv2d, deconv, dwconv2d_fprop, dwconv2d_dgrad, dwconv2d_wgrad)",
    )
    parser.add_argument(
        "output", type=str, help="output directory for CUTLASS kernel files"
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["simt", "tensorop8816", "tensorop8832", "tensorop884", "tensorop1688"],
        default="simt",
        help="kernel type of CUTLASS kernel generator",
    )

    gemv_wrapper_path = (
        "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
    )
    # Short output paths are used on Windows-like systems unless the
    # CUTLASS_WITH_LONG_PATH environment variable is "true" (case-insensitive).
    system_name = platform.system()
    windows_like = system_name == "Windows" or system_name.find("NT") >= 0
    long_path_requested = (
        os.getenv("CUTLASS_WITH_LONG_PATH", default="False").lower() == "true"
    )
    short_path = windows_like and not long_path_requested

    args = parser.parse_args()

    generators = {
        "gemm": GenerateGemmOperations,
        "gemv": GenerateGemvOperations,
        "conv2d": GenerateConv2dOperations,
        "deconv": GenerateDeconvOperations,
        "dwconv2d_fprop": GenerateDwconv2dFpropOperations,
        "dwconv2d_dgrad": GenerateDwconv2dDgradOperations,
        "dwconv2d_wgrad": GenerateDwconv2dWgradOperations,
    }
    generate = generators.get(args.operations)
    assert generate is not None, "invalid operation"
    operations = generate(args)

    conv_like_operations = (
        "conv2d",
        "deconv",
        "dwconv2d_fprop",
        "dwconv2d_dgrad",
        "dwconv2d_wgrad",
    )
    for operation in operations:
        if args.operations in conv_like_operations:
            wrapper = EmitConvSingleKernelWrapper(args.output, operation, short_path)
        elif args.operations == "gemm":
            wrapper = EmitGemmSingleKernelWrapper(args.output, operation, short_path)
        else:
            # Only "gemv" remains after the dispatch check above.
            wrapper = EmitGemvSingleKernelWrapper(
                args.output, operation, gemv_wrapper_path, short_path
            )
        with wrapper as emitter:
            emitter.emit()

    if args.operations != "gemv":
        GenerateManifest(args, operations, args.output)