from functools import lru_cache
from typing import NamedTuple, Optional, Sequence, Tuple, Union
from ..core import _config
from ..core._imperative_rt.core2 import Const, apply, dtype_promotion
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from ..core.ops import builtin
from ..core.ops.builtin import (
BatchNorm,
Dimshuffle,
Dropout,
Elemwise,
GetVarShape,
Identity,
Reduce,
Reshape,
TypeCvt,
)
from ..core.tensor import amp, megbrain_graph
from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.utils import (
astensor1d,
cast_tensors,
convert_single_value,
make_shape_tuple,
subgraph,
subgraph_fn,
)
from ..device import get_default_device
from ..distributed import WORLD, is_distributed
from ..jit import exclude_from_trace
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func
from .debug_param import get_execution_strategy
from .distributed import all_reduce_sum
from .elemwise import _elwise, exp, log, log1p, maximum, minimum
from .math import matmul, max, sum
from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros
# Public API of this module, i.e. the names exported by ``from ... import *``.
# NOTE(review): ``layer_norm`` is defined in this module but is not listed
# here — confirm whether that is intentional.
__all__ = [
    "adaptive_avg_pool2d",
    "adaptive_max_pool2d",
    "avg_pool2d",
    "batch_norm",
    "conv1d",
    "conv2d",
    "conv3d",
    "conv_transpose2d",
    "conv_transpose3d",
    "deformable_conv2d",
    "deformable_psroi_pooling",
    "dropout",
    "embedding",
    "gelu",
    "hsigmoid",
    "hswish",
    "indexing_one_hot",
    "leaky_relu",
    "linear",
    "local_conv2d",
    "local_response_norm",
    "logsigmoid",
    "logsumexp",
    "logsoftmax",
    "max_pool2d",
    "one_hot",
    "prelu",
    "pad",
    "relu",
    "relu6",
    "remap",
    "sigmoid",
    "sliding_window",
    "sliding_window_transpose",
    "silu",
    "softmax",
    "softplus",
    "sync_batch_norm",
    "warp_affine",
    "warp_perspective",
    "pixel_shuffle",
]
def expand_hw(x):
    """Normalize a scalar or a sequence into an ``(h, w)`` pair of ints.

    A sequence contributes its first two items (extra items are ignored);
    any other value is used for both components. Each component goes
    through ``int()``.
    """
    pair = x if isinstance(x, Sequence) else (x, x)
    return int(pair[0]), int(pair[1])
def expand_dhw(x):
    """Normalize a scalar or a sequence into a ``(d, h, w)`` triple of ints.

    A sequence contributes its first three items (extra items are ignored);
    any other value is used for all three components. Each component goes
    through ``int()``.
    """
    triple = x if isinstance(x, Sequence) else (x, x, x)
    return int(triple[0]), int(triple[1]), int(triple[2])
def linear(
    inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None, compute_mode="default",
) -> Tensor:
    """Apply ``inp @ weight^T`` and optionally add ``bias``.

    Args:
        inp: input tensor.
        weight: weight tensor, used transposed (``transpose_b=True``).
        bias: optional tensor added to the product.
        compute_mode: forwarded to ``matmul`` after the ``_config`` override.

    Returns:
        The linear transform of ``inp``.
    """
    mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
    out = matmul(inp, weight, transpose_b=True, compute_mode=mode)
    if bias is None:
        return out
    # Under AMP the bias is lowered to float16 before being accumulated.
    if amp._enabled:
        bias = bias.astype("float16")
    out += bias
    return out
def conv1d(
    inp: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
    groups: int = 1,
    conv_mode="cross_correlation",
    compute_mode="default",
) -> Tensor:
    """1D convolution built on the 2D ``builtin.Convolution`` operator.

    ``inp``, ``weight`` and (optionally) ``bias`` must all be 3-dimensional.
    A trailing unit axis is appended to each, so the 2D operator runs with
    stride 1 / padding 0 / dilation 1 along that axis. The axis is squeezed
    away again before returning.

    Args:
        inp: 3-D input (presumably ``(N, C_in, L)`` — TODO confirm).
        weight: 3-D filter tensor.
        bias: optional 3-D tensor added to the output.
        stride: step along the length axis.
        padding: zero padding along the length axis.
        dilation: dilation along the length axis.
        groups: 1 selects a dense convolution; any other value a grouped one.
        conv_mode: only cross-correlation is supported.
        compute_mode: only ``"default"`` is accepted. Under AMP it is
            overridden to ``"float32"``.

    Returns:
        The 3-D convolution result.
    """
    assert (
        conv_mode.lower() == "cross_correlation"
        or conv_mode.name == "CROSS_CORRELATION"
    )
    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
    assert inp.ndim == 3, "the input dimension of conv1d should be 3"
    assert weight.ndim == 3, "the weight dimension of conv1d should be 3"

    if amp._enabled:
        compute_mode = "float32"
        inp, weight, bias = cast_tensors(inp, weight, bias)
    else:
        # Promote input and weight to a common dtype before the op.
        dtype = dtype_promotion(inp, weight)
        if inp.dtype != dtype:
            inp = inp.astype(dtype)
        if weight.dtype != dtype:
            weight = weight.astype(dtype)

    # The Convolution operator works on 4-D (NCHW-style) operands, so every
    # operand gets a trailing unit axis; it is removed from the result below.
    inp = expand_dims(inp, 3)
    weight = expand_dims(weight, 3)
    if bias is not None:
        assert bias.ndim == 3, "the bias dimension of conv1d should be 3"
        bias = expand_dims(bias, 3)

    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
    sparse_type = "dense" if groups == 1 else "group"
    op = builtin.Convolution(
        stride_h=stride,
        stride_w=1,
        pad_h=padding,
        pad_w=0,
        dilate_h=dilation,
        dilate_w=1,
        strategy=get_execution_strategy(),
        mode=conv_mode,
        compute_mode=compute_mode,
        sparse=sparse_type,
        format=conv_format,
    )
    (output,) = apply(op, inp, weight)
    if bias is not None:
        output += bias
    # Drop the helper axis added above.
    output = squeeze(output, 3)
    return output
def conv2d(
    inp: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    conv_mode="cross_correlation",
    compute_mode="default",
) -> Tensor:
    """2D convolution via ``builtin.Convolution``.

    Args:
        inp: input feature map. Its layout follows the configured conv format
            (``"NCHW"`` unless overridden through ``_config``).
        weight: convolution filter.
        bias: optional tensor added to the output.
        stride: int or ``(h, w)`` pair (normalized by ``expand_hw``).
        padding: int or ``(h, w)`` pair.
        dilation: int or ``(h, w)`` pair.
        groups: 1 selects a dense convolution; any other value a grouped one.
        conv_mode: only cross-correlation is supported.
        compute_mode: forwarded to the operator after the ``_config``
            override. Forced to ``"float32"`` when AMP is enabled.

    Returns:
        The convolution output, plus ``bias`` if given.
    """
    assert (
        conv_mode.lower() == "cross_correlation"
        or conv_mode.name == "CROSS_CORRELATION"
    )
    # Same AMP / dtype handling as conv1d and deformable_conv2d. Without it,
    # AMP never applied to conv2d and mixed-dtype operands reached the op.
    if amp._enabled:
        compute_mode = "float32"
        inp, weight, bias = cast_tensors(inp, weight, bias)
    else:
        dtype = dtype_promotion(inp, weight)
        if inp.dtype != dtype:
            inp = inp.astype(dtype)
        if weight.dtype != dtype:
            weight = weight.astype(dtype)

    stride_h, stride_w = expand_hw(stride)
    pad_h, pad_w = expand_hw(padding)
    dilate_h, dilate_w = expand_hw(dilation)
    sparse_type = "dense" if groups == 1 else "group"
    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
    op = builtin.Convolution(
        stride_h=stride_h,
        stride_w=stride_w,
        pad_h=pad_h,
        pad_w=pad_w,
        dilate_h=dilate_h,
        dilate_w=dilate_w,
        strategy=get_execution_strategy(),
        mode=conv_mode,
        compute_mode=compute_mode,
        sparse=sparse_type,
        format=conv_format,
    )
    (output,) = apply(op, inp, weight)
    if bias is not None:
        output += bias
    return output
def conv3d(
    inp: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[int, Tuple[int, int, int]] = 1,
    padding: Union[int, Tuple[int, int, int]] = 0,
    dilation: Union[int, Tuple[int, int, int]] = 1,
    groups: int = 1,
    conv_mode: str = "cross_correlation",
) -> Tensor:
    """3D convolution via ``builtin.Convolution3D``.

    Args:
        inp: input volume.
        weight: convolution filter.
        bias: optional tensor added to the output.
        stride: int or ``(d, h, w)`` triple (normalized by ``expand_dhw``).
        padding: int or ``(d, h, w)`` triple.
        dilation: int or ``(d, h, w)`` triple.
        groups: 1 selects a dense convolution; any other value a grouped one.
        conv_mode: only cross-correlation is supported.

    Returns:
        The convolution output, plus ``bias`` if given.
    """
    assert conv_mode.lower() == "cross_correlation"
    pd, ph, pw = expand_dhw(padding)
    sd, sh, sw = expand_dhw(stride)
    dd, dh, dw = expand_dhw(dilation)
    conv = builtin.Convolution3D(
        pad_d=pd,
        pad_h=ph,
        pad_w=pw,
        stride_d=sd,
        stride_h=sh,
        stride_w=sw,
        dilate_d=dd,
        dilate_h=dh,
        dilate_w=dw,
        strategy=get_execution_strategy(),
        mode=conv_mode,
        sparse="dense" if groups == 1 else "group",
    )
    (result,) = apply(conv, inp, weight)
    if bias is not None:
        result += bias
    return result
def conv_transpose2d(
    inp: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    conv_mode="cross_correlation",
    compute_mode="default",
) -> Tensor:
    """2D transposed convolution via ``builtin.ConvolutionBackwardData``.

    Args:
        inp: input feature map.
        weight: convolution filter (passed to the op before ``inp``).
        bias: optional tensor added to the output. Cast under AMP.
        stride: int or ``(h, w)`` pair.
        padding: int or ``(h, w)`` pair.
        dilation: int or ``(h, w)`` pair.
        groups: 1 selects a dense convolution; any other value a grouped one.
        conv_mode: only cross-correlation is supported.
        compute_mode: forwarded to the operator after the ``_config`` override.

    Returns:
        The transposed-convolution output, plus ``bias`` if given.
    """
    assert (
        conv_mode.lower() == "cross_correlation"
        or conv_mode.name == "CROSS_CORRELATION"
    )
    stride_h, stride_w = expand_hw(stride)
    pad_h, pad_w = expand_hw(padding)
    dilate_h, dilate_w = expand_hw(dilation)
    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
    sparse_type = "dense" if groups == 1 else "group"
    op = builtin.ConvolutionBackwardData(
        stride_h=stride_h,
        stride_w=stride_w,
        pad_h=pad_h,
        pad_w=pad_w,
        dilate_h=dilate_h,
        dilate_w=dilate_w,
        strategy=get_execution_strategy(),
        compute_mode=compute_mode,
        sparse=sparse_type,
    )
    (output,) = apply(op, weight, inp)
    if bias is not None:
        if amp._enabled:
            # cast_tensors returns a tuple (see the multi-arg unpacking used
            # elsewhere in this module); unpack the single element instead of
            # binding the tuple itself to ``bias``.
            (bias,) = cast_tensors(bias)
        output += bias
    return output
def deformable_conv2d(
    inp: Tensor,
    weight: Tensor,
    offset: Tensor,
    mask: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    conv_mode="cross_correlation",
    compute_mode="default",
) -> Tensor:
    """Deformable 2D convolution via ``builtin.DeformableConv``.

    Without AMP, ``offset`` and ``mask`` are converted to float32. With AMP,
    all tensor operands go through ``cast_tensors`` and the compute mode is
    forced to ``"float32"``.

    Args:
        inp: input feature map.
        weight: convolution filter.
        offset: sampling offsets fed to the op.
        mask: modulation mask fed to the op.
        bias: optional tensor added to the output.
        stride: int or ``(h, w)`` pair.
        padding: int or ``(h, w)`` pair.
        dilation: int or ``(h, w)`` pair.
        groups: 1 selects a dense convolution; any other value a grouped one.
        conv_mode: only cross-correlation is supported.
        compute_mode: forwarded to the operator after the ``_config`` override.

    Returns:
        The convolution output, plus ``bias`` if given.
    """
    assert (
        conv_mode.lower() == "cross_correlation"
        or conv_mode.name == "CROSS_CORRELATION"
    )
    if not amp._enabled:
        offset = offset.astype("float32")
        mask = mask.astype("float32")
    else:
        compute_mode = "float32"
        inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias)

    sh, sw = expand_hw(stride)
    ph, pw = expand_hw(padding)
    dh, dw = expand_hw(dilation)
    mode_for_op = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
    deform = builtin.DeformableConv(
        stride_h=sh,
        stride_w=sw,
        pad_h=ph,
        pad_w=pw,
        dilate_h=dh,
        dilate_w=dw,
        strategy=get_execution_strategy(),
        mode=conv_mode,
        compute_mode=mode_for_op,
        sparse="dense" if groups == 1 else "group",
    )
    (result,) = apply(deform, inp, weight, offset, mask)
    if bias is not None:
        result += bias
    return result
def local_conv2d(
    inp: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    conv_mode="cross_correlation",
):
    """Locally-connected 2D convolution via ``builtin.GroupLocal`` (dense).

    ``inp`` and ``weight`` are first promoted to a common dtype.

    Args:
        inp: input feature map.
        weight: per-location filters.
        bias: optional tensor added to the output.
        stride: int or ``(h, w)`` pair.
        padding: int or ``(h, w)`` pair.
        dilation: int or ``(h, w)`` pair.
        conv_mode: only cross-correlation is supported.

    Returns:
        The convolution output, plus ``bias`` if given.
    """
    assert (
        conv_mode.lower() == "cross_correlation"
        or conv_mode.name == "CROSS_CORRELATION"
    )
    sh, sw = expand_hw(stride)
    ph, pw = expand_hw(padding)
    dh, dw = expand_hw(dilation)

    common = dtype_promotion(inp, weight)
    inp = inp if inp.dtype == common else inp.astype(common)
    weight = weight if weight.dtype == common else weight.astype(common)

    local = builtin.GroupLocal(
        stride_h=sh,
        stride_w=sw,
        pad_h=ph,
        pad_w=pw,
        dilate_h=dh,
        dilate_w=dw,
        mode=conv_mode,
        sparse="dense",
    )
    (result,) = apply(local, inp, weight)
    if bias is not None:
        result += bias
    return result
def conv_transpose3d(
    inp: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[int, Tuple[int, int, int]] = 1,
    padding: Union[int, Tuple[int, int, int]] = 0,
    dilation: Union[int, Tuple[int, int, int]] = 1,
    groups: int = 1,
) -> Tensor:
    """3D transposed convolution via ``builtin.Convolution3DBackwardData``.

    Args:
        inp: input volume.
        weight: convolution filter (passed to the op before ``inp``).
        bias: optional tensor added to the output.
        stride: int or ``(d, h, w)`` triple.
        padding: int or ``(d, h, w)`` triple.
        dilation: int or ``(d, h, w)`` triple.
        groups: 1 selects a dense convolution; any other value a grouped one.

    Returns:
        The transposed-convolution output, plus ``bias`` if given.
    """
    pd, ph, pw = expand_dhw(padding)
    sd, sh, sw = expand_dhw(stride)
    dd, dh, dw = expand_dhw(dilation)
    deconv = builtin.Convolution3DBackwardData(
        pad_d=pd,
        pad_h=ph,
        pad_w=pw,
        stride_d=sd,
        stride_h=sh,
        stride_w=sw,
        dilate_d=dd,
        dilate_h=dh,
        dilate_w=dw,
        strategy=get_execution_strategy(),
        sparse="dense" if groups == 1 else "group",
    )
    (result,) = apply(deconv, weight, inp)
    if bias is not None:
        result += bias
    return result
def max_pool2d(
    inp: Tensor,
    kernel_size: Union[int, Tuple[int, int]],
    stride: Optional[Union[int, Tuple[int, int]]] = None,
    padding: Union[int, Tuple[int, int]] = 0,
) -> Tensor:
    """2D max pooling via ``builtin.Pooling`` in ``"max"`` mode.

    Args:
        inp: input feature map in the configured conv format (``"NCHW"``
            unless overridden through ``_config``).
        kernel_size: pooling window, int or ``(h, w)``.
        stride: window step. ``None`` means "same as ``kernel_size``".
        padding: int or ``(h, w)`` padding.

    Returns:
        The pooled tensor.
    """
    win_h, win_w = expand_hw(kernel_size)
    step_h, step_w = expand_hw(kernel_size if stride is None else stride)
    pad_h, pad_w = expand_hw(padding)
    pool = builtin.Pooling(
        window_h=win_h,
        window_w=win_w,
        stride_h=step_h,
        stride_w=step_w,
        pad_h=pad_h,
        pad_w=pad_w,
        mode="max",
        format=_config._get_actual_op_param("NCHW", _config.__conv_format),
    )
    (pooled,) = apply(pool, inp)
    return pooled
def avg_pool2d(
    inp: Tensor,
    kernel_size: Union[int, Tuple[int, int]],
    stride: Optional[Union[int, Tuple[int, int]]] = None,
    padding: Union[int, Tuple[int, int]] = 0,
    mode: str = "average_count_exclude_padding",
) -> Tensor:
    """2D average pooling via ``builtin.Pooling``.

    Args:
        inp: input feature map in the configured conv format (``"NCHW"``
            unless overridden through ``_config``).
        kernel_size: pooling window, int or ``(h, w)``.
        stride: window step. ``None`` means "same as ``kernel_size``".
        padding: int or ``(h, w)`` padding.
        mode: pooling mode string forwarded to the op.

    Returns:
        The pooled tensor.
    """
    win_h, win_w = expand_hw(kernel_size)
    step_h, step_w = expand_hw(kernel_size if stride is None else stride)
    pad_h, pad_w = expand_hw(padding)
    pool = builtin.Pooling(
        window_h=win_h,
        window_w=win_w,
        stride_h=step_h,
        stride_w=step_w,
        pad_h=pad_h,
        pad_w=pad_w,
        mode=mode,
        format=_config._get_actual_op_param("NCHW", _config.__conv_format),
    )
    (pooled,) = apply(pool, inp)
    return pooled
def adaptive_max_pool2d(
    inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
) -> Tensor:
    """Adaptive 2D max pooling to the output spatial shape ``oshp``.

    Args:
        inp: input feature map in the configured conv format.
        oshp: target output size. A single int is used for both dimensions.

    Returns:
        The pooled tensor.
    """
    if isinstance(oshp, int):
        oshp = (oshp, oshp)
    layout = _config._get_actual_op_param("NCHW", _config.__conv_format)
    pool = builtin.AdaptivePooling(mode="max", format=layout,)
    target = astensor1d(oshp, inp, dtype="int32", device=inp.device)
    (pooled,) = apply(pool, inp, target)
    return pooled
def adaptive_avg_pool2d(
    inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
) -> Tensor:
    """Adaptive 2D average pooling to the output spatial shape ``oshp``.

    Args:
        inp: input feature map in the configured conv format (``"NCHW"``
            unless overridden through ``_config``).
        oshp: target output size. A single int is used for both dimensions.

    Returns:
        The pooled tensor.
    """
    if isinstance(oshp, int):
        oshp = (oshp, oshp)
    # Honour the global conv format like adaptive_max_pool2d / max_pool2d do;
    # hard-coding "NCHW" mis-read inputs when another format was configured.
    conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
    op = builtin.AdaptivePooling(mode="average", format=conv_format,)
    oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
    (output,) = apply(op, inp, oshp)
    return output
def deformable_psroi_pooling(
    inp: Tensor,
    rois: Tensor,
    trans: Tensor,
    no_trans: bool,
    part_size: int,
    pooled_h: int,
    pooled_w: int,
    sample_per_part: int,
    spatial_scale: float,
    trans_std: float = 0.1,
):
    """Deformable position-sensitive RoI pooling.

    All scalar arguments are forwarded unchanged to
    ``builtin.DeformablePSROIPooling``. The op yields two outputs; only the
    first is returned.
    """
    params = dict(
        no_trans=no_trans,
        part_size=part_size,
        pooled_h=pooled_h,
        pooled_w=pooled_w,
        sample_per_part=sample_per_part,
        spatial_scale=spatial_scale,
        trans_std=trans_std,
    )
    pooling = builtin.DeformablePSROIPooling(**params)
    pooled, _unused = apply(pooling, inp, rois, trans)
    return pooled
def hswish(x):
    """Elementwise hard-swish (``Elemwise`` mode ``H_SWISH``)."""
    mode = Elemwise.Mode.H_SWISH
    return _elwise(x, mode=mode)
def sigmoid(x):
    """Elementwise logistic sigmoid (``Elemwise`` mode ``SIGMOID``)."""
    mode = Elemwise.Mode.SIGMOID
    return _elwise(x, mode=mode)
@lru_cache(maxsize=None)
def _get_hsigmoid_op(dtype=None, device=None):
    """Build, once per (dtype, device), a fused hard-sigmoid subgraph op.

    Forward: ``min(max(x + 3, 0), 6) / 6``. The op is declared with
    ``custom_grad=True``. The generator first yields the forward output,
    then receives the output gradient and yields the input gradient.
    """

    @subgraph_fn(
        "Hsigmoid",
        dtype=dtype,
        device=device,
        nr_inputs=1,
        jit_fusion=True,
        custom_grad=True,
    )
    def hsigmoid(inputs, f, c):
        # f(mode, *args) emits an op into the subgraph; c(v) makes a constant.
        (inp,) = inputs[0:1]
        inp = f("+", inp, c(3))  # from here on ``inp`` holds x + 3
        max_0 = f("max", inp, c(0))
        min_6 = f("min", max_0, c(6))
        oup = f("/", min_6, c(6))
        (oup_grad,) = yield (oup,)
        # Backward: dy / 6 restricted to the linear region 0 <= x + 3 <= 6.
        # "cond_leq_mov" presumably keeps its third operand where
        # arg0 <= arg1 and zeroes it otherwise — op semantics not shown here.
        inp_grad = f("/", oup_grad, c(6))
        inp_grad = f("cond_leq_mov", max_0, c(6), inp_grad)
        inp_grad = f("cond_leq_mov", c(0), inp, inp_grad)
        yield (inp_grad,)

    return hsigmoid
def hsigmoid(x):
    """Elementwise hard sigmoid ``relu6(x + 3) / 6`` using a cached fused op."""
    fused = _get_hsigmoid_op(x.dtype, x.device)
    (out,) = fused(x)
    return out
def relu(x):
    """Elementwise rectified linear unit (``Elemwise`` mode ``RELU``)."""
    mode = Elemwise.Mode.RELU
    return _elwise(x, mode=mode)
@lru_cache(maxsize=None)
def _get_relu6_op(dtype=None, device=None):
    """Build, once per (dtype, device), a fused ReLU6 subgraph op.

    Forward: ``min(max(x, 0), 6)``. Custom backward passes the output
    gradient through only where the forward is in its linear region.
    """

    @subgraph_fn(
        "ReLU6",
        dtype=dtype,
        device=device,
        nr_inputs=1,
        jit_fusion=True,
        custom_grad=True,
    )
    def relu6(inputs, f, c):
        # f(mode, *args) emits an op into the subgraph; c(v) makes a constant.
        (inp,) = inputs[0:1]
        max_0 = f("max", inp, c(0))
        min_6 = f("min", max_0, c(6))
        oup = min_6
        (oup_grad,) = yield (oup,)
        # Mask the gradient to 0 <= x and max(x, 0) <= 6. "cond_leq_mov" is
        # presumably "third operand where arg0 <= arg1, else 0" — not shown here.
        inp_grad = f("cond_leq_mov", max_0, c(6), oup_grad)
        inp_grad = f("cond_leq_mov", c(0), inp, inp_grad)
        yield (inp_grad,)

    return relu6
def relu6(x):
    """Elementwise ``min(max(x, 0), 6)`` using a cached fused op."""
    fused = _get_relu6_op(x.dtype, x.device)
    (out,) = fused(x)
    return out
@lru_cache(maxsize=None)
def _get_prelu_op(dtype=None, device=None):
    """Build, once per (dtype, device), a fused PReLU subgraph op.

    Inputs are ``(x, weight)`` of matching shape; the caller broadcasts.
    Forward: ``min(x, 0) * weight + max(x, 0)``, assuming ``fma3(a, b, c)``
    computes ``a * b + c``. Custom backward yields gradients for both
    ``x`` and ``weight``.
    """

    @subgraph_fn(
        "PReLU",
        dtype=dtype,
        device=device,
        nr_inputs=2,
        jit_fusion=True,
        custom_grad=True,
    )
    def prelu(inputs, f, c):
        (inp, weight) = inputs[0:2]
        max_0 = f("max", inp, c(0))
        min_0 = f("min", inp, c(0))
        oup = f("fma3", min_0, weight, max_0)
        (oup_grad,) = yield (oup,)
        # d/dx: dy where x >= 0, plus dy * weight where x <= 0. Both terms
        # contribute at x == 0. "cond_leq_mov" presumably keeps its third
        # operand where arg0 <= arg1 and zeroes it otherwise.
        inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad)
        inp_grad_1 = f("*", oup_grad, weight)
        inp_grad_1 = f("cond_leq_mov", inp, c(0), inp_grad_1)
        inp_grad = f("+", inp_grad_0, inp_grad_1)
        # d/dweight = dy * min(x, 0)
        weight_grad = f("*", oup_grad, min_0)
        yield (inp_grad, weight_grad)

    return prelu
def prelu(inp: Tensor, weight: Tensor) -> Tensor:
    """Parametric ReLU: ``max(x, 0) + weight * min(x, 0)``.

    ``weight`` is broadcast to ``inp.shape`` before calling the cached fused op.
    """
    fused = _get_prelu_op(dtype=inp.dtype, device=inp.device)
    full_weight = broadcast_to(weight, inp.shape)
    (out,) = fused(inp, full_weight)
    return out
@lru_cache(maxsize=None)
def _get_leaky_relu_op(negative_slope, *, dtype=None, device=None):
    """Build, once per (negative_slope, dtype, device), a fused LeakyReLU op.

    Forward: ``max(x, 0) + negative_slope * min(x, 0)``. ``negative_slope``
    is baked into the subgraph as a constant, which is why it is part of the
    cache key.
    """

    @subgraph_fn(
        "LeakyReLU",
        dtype=dtype,
        device=device,
        nr_inputs=1,
        jit_fusion=True,
        custom_grad=True,
    )
    def leakyReLU(inputs, f, c):
        (inp,) = inputs[0:1]
        max_0 = f("max", inp, c(0))
        min_0 = f("min", inp, c(0))
        oup = f("+", max_0, f("*", min_0, c(negative_slope)))
        (oup_grad,) = yield (oup,)
        # dy where x >= 0 plus dy * slope where x <= 0. "cond_leq_mov" is
        # presumably "third operand where arg0 <= arg1, else 0".
        inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad)
        inp_grad_1 = f("*", oup_grad, c(negative_slope))
        inp_grad_1 = f("cond_leq_mov", inp, c(0), inp_grad_1)
        inp_grad = f("+", inp_grad_0, inp_grad_1)
        yield (inp_grad,)

    return leakyReLU
def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
    """Leaky ReLU: ``max(x, 0) + negative_slope * min(x, 0)``."""
    fused = _get_leaky_relu_op(negative_slope, dtype=inp.dtype, device=inp.device)
    (out,) = fused(inp)
    return out
def silu(x):
    """Elementwise SiLU (``Elemwise`` mode ``SILU``)."""
    mode = Elemwise.Mode.SILU
    return _elwise(x, mode=mode)
def gelu(x):
    """Elementwise GELU (``Elemwise`` mode ``GELU``)."""
    mode = Elemwise.Mode.GELU
    return _elwise(x, mode=mode)
@lru_cache(maxsize=None)
def _get_softplus_op(dtype=None, device=None):
    """Build, once per (dtype, device), a fused, numerically stable softplus.

    Forward: ``log1p(exp(-|x|)) + relu(x)``, which equals ``log(1 + exp(x))``
    without overflowing ``exp`` for large ``x``.
    """

    @subgraph_fn(
        "Softplus",
        dtype=dtype,
        device=device,
        nr_inputs=1,
        jit_fusion=True,
        custom_grad=True,
    )
    def softplus(inputs, f, c):
        (inp,) = inputs[0:1]
        neg_abs = f("-", f("abs", inp))
        exp = f("exp", neg_abs)
        oup0 = f("log1p", exp)
        oup1 = f("relu", inp)
        oup = f("+", oup0, oup1)
        (oup_grad,) = yield (oup,)
        # Gradient of relu(x): dy where relu(x) > 0.
        inp_grad_0 = f("switch_gt0", oup1, oup_grad)
        # NOTE(review): the next assignment is immediately overwritten (dead store).
        inp_grad_1 = oup_grad
        # Gradient of log1p(exp(-|x|)): dy * e / (1 + e) * d(-|x|)/dx, with
        # e = exp(-|x|); "abs_grad" presumably applies the sign(x) factor.
        inp_grad_1 = f("/", oup_grad, f("+", exp, c(1)))
        inp_grad_1 = f("*", inp_grad_1, exp)
        inp_grad_1 = f("-", inp_grad_1)
        inp_grad_1 = f("abs_grad", inp, inp_grad_1)
        inp_grad = f("+", inp_grad_0, inp_grad_1)
        yield (inp_grad,)

    return softplus
def softplus(inp: Tensor) -> Tensor:
    """Elementwise ``log(1 + exp(x))`` computed via a cached, stable fused op."""
    fused = _get_softplus_op(inp.dtype, inp.device)
    (out,) = fused(inp)
    return out
def logsoftmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
    """Log-softmax along ``axis``: ``x - logsumexp(x, axis, keepdims=True)``."""
    normalizer = logsumexp(inp, axis, keepdims=True)
    return inp - normalizer
@lru_cache(maxsize=None)
def _get_logsigmoid_op(dtype=None, device=None):
    """Build, once per (dtype, device), a fused, numerically stable log-sigmoid.

    Forward: ``-(log1p(exp(-|x|)) + relu(-x))``, i.e. ``-softplus(-x)``.
    """

    @subgraph_fn(
        "LogSigmoid",
        dtype=dtype,
        device=device,
        nr_inputs=1,
        jit_fusion=True,
        custom_grad=True,
    )
    def logsigmoid(inputs, f, c):
        (inp,) = inputs[0:1]
        neg_abs = f("-", f("abs", inp))
        exp = f("exp", neg_abs)
        oup0 = f("log1p", exp)
        oup1 = f("relu", f("-", inp))
        oup = f("+", oup0, oup1)
        oup = f("-", oup)
        (oup_grad,) = yield (oup,)
        # Undo the outer negation, then differentiate the two softplus terms.
        oup_grad = f("-", oup_grad)
        # d relu(-x)/dx = -[x < 0]
        inp_grad_0 = f("switch_gt0", oup1, oup_grad)
        inp_grad_0 = f("-", inp_grad_0)
        # d log1p(exp(-|x|))/dx = -e / (1 + e) * sign(x), e = exp(-|x|);
        # "abs_grad" presumably supplies the sign(x) factor.
        inp_grad_1 = oup_grad
        inp_grad_1 = f("/", inp_grad_1, f("+", exp, c(1)))
        inp_grad_1 = f("*", inp_grad_1, exp)
        inp_grad_1 = f("-", inp_grad_1)
        inp_grad_1 = f("abs_grad", inp, inp_grad_1)
        inp_grad = f("+", inp_grad_0, inp_grad_1)
        yield (inp_grad,)

    return logsigmoid
def logsigmoid(inp: Tensor) -> Tensor:
    """Elementwise ``log(sigmoid(x))`` computed via a cached, stable fused op."""
    fused = _get_logsigmoid_op(inp.dtype, inp.device)
    (out,) = fused(inp)
    return out
def logsumexp(
    inp: Tensor, axis: Union[int, Sequence[int]], keepdims: bool = False
) -> Tensor:
    """Numerically stable ``log(sum(exp(inp), axis))``.

    The (gradient-detached) maximum along ``axis`` is subtracted before
    exponentiation and added back afterwards, so ``exp`` cannot overflow.

    Args:
        inp: input tensor.
        axis: axis or axes to reduce over.
        keepdims: keep the reduced axes as size-1 dimensions.

    Returns:
        The reduced log-sum-exp.
    """
    max_value = max(inp.detach(), axis, keepdims=True)
    lse = log(sum(exp(inp - max_value), axis, keepdims))
    if keepdims:
        return max_value + lse
    # Squeeze only the reduced axes. Squeezing every unit axis (axis=None)
    # also dropped unrelated size-1 dims of ``inp``, so ``max_value`` no longer
    # lined up with ``lse`` and broadcasting produced a wrong-shaped result.
    return squeeze(max_value, axis) + lse
def _get_softmax_axis(ndim: int) -> int:
    """Default softmax axis for a tensor of rank ``ndim``.

    Ranks 0, 1 and 3 use axis 0; every other rank uses axis 1.
    """
    return 0 if ndim in (0, 1, 3) else 1
def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
    """Softmax along ``axis``.

    Args:
        inp: input tensor.
        axis: a single axis, which uses the fused ``builtin.Softmax`` op, or a
            list/tuple of axes, which uses an explicit max-shifted
            ``exp / sum`` over those axes. ``None`` picks a default from the
            input rank (see ``_get_softmax_axis``).

    Returns:
        A tensor of the same shape as ``inp``.
    """
    if axis is None:
        axis = _get_softmax_axis(len(inp.shape))
    if isinstance(axis, tuple):
        # Tuples used to fall through to the single-axis builtin op; route
        # them through the same multi-axis path as lists.
        axis = list(axis)
    if isinstance(axis, list):
        # Subtract the detached max for numerical stability.
        offset = inp.max(axis=axis, keepdims=True).detach()
        cached = exp(inp - offset)
        down = sum(cached, axis=axis, keepdims=True)
        return cached / down
    op = builtin.Softmax(axis=axis,)
    (output,) = apply(op, inp)
    return output
def layer_norm(
    inp: Tensor,
    normalized_shape: tuple,
    affine: bool,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    eps: float = 1e-5,
):
    """Layer normalization over the trailing ``normalized_shape`` dimensions.

    Args:
        inp: input tensor (promoted via ``cast_tensors`` under AMP).
        normalized_shape: int or non-empty sequence of normalized extents.
        affine: when true, ``weight`` and ``bias`` are required and applied.
        weight: scale parameter (affine only).
        bias: shift parameter (affine only).
        eps: numerical-stability epsilon forwarded to the op.

    Returns:
        The normalized tensor (first output of ``builtin.LayerNorm``).
    """
    if amp._enabled:
        inp, weight, bias = cast_tensors(inp, weight, bias, promote=True)
    if isinstance(normalized_shape, int):
        normalized_shape = [normalized_shape]
    normalized_dim = len(normalized_shape)
    assert normalized_dim > 0

    # Total number of elements in one normalized group.
    normalized_size = 1
    for extent in normalized_shape:
        normalized_size *= extent

    norm = builtin.LayerNorm(
        affine=affine,
        eps=eps,
        normalized_dim=normalized_dim,
        normalized_size=normalized_size,
    )
    if not affine:
        return apply(norm, inp)[0]
    assert weight is not None and bias is not None
    return apply(norm, inp, weight, bias)[0]
def batch_norm(
    inp: Tensor,
    running_mean: Optional[Tensor] = None,
    running_var: Optional[Tensor] = None,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    *,
    training: bool = False,
    momentum: float = 0.9,
    eps: float = 1e-5,
    inplace: bool = True,
    compute_mode="default",
    param_dim="dim_1c11"
):
    """Batch normalization via ``builtin.BatchNorm``.

    Args:
        inp: input tensor. Channels are read from axis 1 (``param_dim ==
            "dim_1c11"``) or axis 3 (``"dim_111c"``).
        running_mean: running mean; required when ``training`` is false.
        running_var: running variance; required when ``training`` is false.
        weight: optional scale parameter, defaulting to ones.
        bias: optional shift parameter, defaulting to zeros.
        training: use batch statistics (and update running ones) if true.
        momentum: the op's ``avg_factor`` is ``1 - momentum``.
        eps: epsilon forwarded to the op.
        inplace: when training with running statistics, write the new
            statistics into ``running_mean`` / ``running_var`` in place.
        compute_mode: accepted but not referenced in this body.
        param_dim: ``"dim_1c11"`` or ``"dim_111c"``; selects the parameter
            broadcast shape.

    Returns:
        The normalized tensor, except in training mode with running statistics
        and ``inplace=False``, which returns ``(out, new_mean, new_var)``.
        In that tuple, each stat is ``None`` if it was not supplied.

    Raises:
        ValueError: for an unknown ``param_dim``.
    """

    def make_full_if_none(x, value):
        # Turn a missing or 1-D per-channel parameter into a tensor of the
        # parameter shape; tensors of any other rank are returned unchanged.
        x_ndim = None if x is None else x.ndim
        if x_ndim is not None and x_ndim != 1:
            return x
        if param_dim == "dim_1c11":
            C = inp.shape[1]
            pshape = (1, C, 1, 1)
        elif param_dim == "dim_111c":
            C = inp.shape[3]
            pshape = (1, 1, 1, C)
        else:
            raise ValueError("Invalid param_dim {}".format(param_dim))
        if x is None:
            # Broadcast a scalar constant to the parameter shape.
            x = Const(value, inp.dtype, inp.device, None)
            shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
            (result,) = apply(builtin.Broadcast(), x, shape)
            return result
        else:
            # Reshape a 1-D per-channel tensor to the parameter shape.
            assert x_ndim == 1
            shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
            (result,) = apply(builtin.Reshape(), x, shape)
            return result

    has_mean = running_mean is not None
    has_var = running_var is not None
    if not training:
        assert has_mean, "running_mean must be provided in inference mode"
        assert has_var, "running_var must be provided in inference mode"
    weight = make_full_if_none(weight, 1)
    bias = make_full_if_none(bias, 0)
    if not training:
        op = builtin.BatchNorm(
            fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim=param_dim
        )
        # The normalized tensor is the last output of the op.
        ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
        return ret
    else:
        op = builtin.BatchNorm(
            avg_factor=1 - momentum, epsilon=eps, param_dim=param_dim
        )
        if has_mean or has_var:
            # If only one statistic was supplied, fill the other with a
            # placeholder so the op gets both; its update is discarded below.
            running_mean = make_full_if_none(running_mean, 0)
            running_var = make_full_if_none(running_var, 1)
            new_mean, new_var, *_, inp = apply(
                op, inp, weight, bias, running_mean, running_var
            )
            if not has_mean:
                new_mean = None
            if not has_var:
                new_var = None
            if inplace:
                if has_mean:
                    running_mean[...] = new_mean
                if has_var:
                    running_var[...] = new_var
                return inp
            else:
                return inp, new_mean, new_var
        else:
            inp = apply(op, inp, weight, bias)[-1]
            return inp
@lru_cache(maxsize=None)
def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
    """Build, once per argument tuple, the subgraph stages of sync batch-norm.

    Returns the tuple ``(stage0, stage1, stage1_inference, stage2,
    concat_stats, split_stats)``:

    * stage0: per-channel ``sum(x)`` / ``sum(x**2)`` and the per-channel
      element count.
    * stage1: normalize with the batch mean / biased variance.
    * stage1_inference: normalize with given mean / variance.
    * stage2: momentum update of the running mean / unbiased variance.
    * concat_stats / split_stats: pack the statistics into one tensor along
      axis 1 (for a single all-reduce) and unpack them again.

    ``eps_mode`` is used directly as the binary op name combining variance
    and eps (e.g. "max"). The second tuple returned by each stage is a
    per-output boolean flag list — NOTE(review): its exact meaning is
    defined by ``subgraph``, not shown here.
    """

    @subgraph("SyncBnStage0", dtype, device, 1)
    def syncbn_stage0(inputs, f, c):
        input = inputs[0]
        # (1, C, 1, ..., 1): reduce everything except the channel axis.
        reduce_shape = c((1, channels) + (1,) * (ndim - 2), dtype="int32", device=device)
        input_shape = f(GetVarShape(), input)
        input_elems = f(Reduce(mode="product", axis=0), input_shape)
        reduce_elems = f(Reduce(mode="product", axis=0), reduce_shape)
        # Number of elements that contribute to each channel's statistics.
        reduce_size = f("//", input_elems, reduce_elems)
        channel_x1s = f(Reduce(mode="sum"), input, reduce_shape)
        channel_x2s = f(Reduce(mode="sum_sqr"), input, reduce_shape)
        reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
        return (reduce_shape, reduce_size_f, channel_x1s, channel_x2s), (False, False, True, True)

    @subgraph("SyncBnStage1", dtype, device, 7)
    def syncbn_stage1(inputs, f, c):
        input, reduce_size, channel_x1s, channel_x2s, eps = inputs[0:5]
        weight, bias = inputs[5:7]
        channel_mean = f("/", channel_x1s, reduce_size)
        # Biased variance: x2s / n - x1s**2 / n**2.
        channel_var =\
            f("+", f("/", f("**", channel_x1s, c(2)),
                          f("-", f("*", reduce_size, reduce_size))),
                   f("/", channel_x2s, reduce_size))
        invsqrt_channel_var = f("**", f(eps_mode, channel_var, eps), c(-0.5))
        inv_var_wt = f("*", invsqrt_channel_var, weight)
        neg_channel_mean = f("-", channel_mean)
        # out = x * w / sqrt(var') + (bias - mean * w / sqrt(var')), assuming
        # fma3(a, b, c) == a * b + c.
        outvar =\
            f("fma3", input, inv_var_wt,
                  f("+", f("*", neg_channel_mean, inv_var_wt),
                         bias))
        return (outvar, channel_mean, channel_var), (True, True, True)

    @subgraph("SyncBnStage1Inference", dtype, device, 6)
    def syncbn_stage1_inference(inputs, f, c):
        input, channel_mean, channel_var, eps = inputs[0:4]
        weight, bias = inputs[4:6]
        invsqrt_channel_var = f("**", f(eps_mode, channel_var, eps), c(-0.5))
        inv_var_wt = f("*", invsqrt_channel_var, weight)
        neg_channel_mean = f("-", channel_mean)
        outvar =\
            f("+", f("*", input, inv_var_wt),
                   f("+", f("*", neg_channel_mean, inv_var_wt),
                          bias))
        return (outvar,), (True,)

    @subgraph("SyncBnStage2", dtype, device, 7)
    def syncbn_stage2(inputs, f, c):
        running_mean, running_var, momentum = inputs[0:3]
        reduce_size, channel_x1s, channel_x2s, channel_mean = inputs[3:7]
        c1_minus_momentum = f("-", c(1), momentum)
        reduce_size_minus_c1 = f("-", reduce_size, c(1))
        # running_mean * momentum + (1 - momentum) * mean, assuming
        # fma4(a, b, c, d) == a * b + c * d.
        running_mean = f("fma4",
            running_mean, momentum,
            c1_minus_momentum, channel_mean,
        )
        # Unbiased variance: (x2s - x1s**2 / n) / (n - 1).
        channel_variance_unbiased =\
            f("+", f("/", f("**", channel_x1s, c(2)),
                          f("*", f("-", reduce_size),
                                 reduce_size_minus_c1)),
                   f("/", channel_x2s,
                          reduce_size_minus_c1))
        running_var = f("fma4",
            running_var, momentum,
            c1_minus_momentum, channel_variance_unbiased
        )
        return (running_mean, running_var), (True, True)

    @subgraph("SyncBnConcatStats", dtype, device, 3)
    def syncbn_concat_stats(inputs, f, c):
        reduce_size, channel_x1s, channel_x2s = inputs[0:3]
        # Lift the count to rank ``ndim`` so it can be concatenated on axis 1.
        reduce_size = f(builtin.Broadcast(), reduce_size, c([1]*ndim, dtype="int32"))
        stats = f(builtin.Concat(axis=1, comp_node=device), reduce_size, channel_x1s, channel_x2s)
        return (stats,), (True,)

    @subgraph("SyncBnSplitStats", dtype, device, 1)
    def syncbn_split_stats(inputs, f, c):
        # Inverse of syncbn_concat_stats: axis-1 layout is
        # [count | x1s (channels) | x2s (channels)].
        stats = inputs[0]
        c_1 = c(1, dtype="int32")
        channel_x1s_end = c(channels+1, dtype="int32")

        def _subtensor(src, axis, begin, end):
            # Slice ``src[begin:end]`` along ``axis``; None bounds are open.
            items = (axis, (begin is not None), (end is not None), False, False),
            args = ()
            if begin is not None:
                args += begin,
            if end is not None:
                args += end,
            return f(builtin.Subtensor(items=items), src, *args)

        reduce_size = _subtensor(stats, 1, None, c_1)
        channel_x1s = _subtensor(stats, 1, c_1, channel_x1s_end)
        channel_x2s = _subtensor(stats, 1, channel_x1s_end, None)
        reduce_size = f(builtin.Reshape(), reduce_size, c_1)
        return (reduce_size, channel_x1s, channel_x2s), (False, True, True)

    return (
        syncbn_stage0,
        syncbn_stage1,
        syncbn_stage1_inference,
        syncbn_stage2,
        syncbn_concat_stats,
        syncbn_split_stats,
    )
def sync_batch_norm(
    inp: Tensor,
    running_mean: Tensor,
    running_var: Tensor,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    training: bool = False,
    momentum: Union[float, Tensor] = 0.9,
    eps: float = 1e-5,
    eps_mode="additive",
    group=WORLD,
) -> Tensor:
    """Batch normalization whose batch statistics are summed across ``group``.

    In training mode, the per-channel sums of ``x`` and ``x**2`` and the
    element count are all-reduced over ``group`` when running distributed.
    The input is then normalized with the resulting batch mean / biased
    variance. In inference mode, ``running_mean`` / ``running_var`` are used
    instead. With ``eps_mode="additive"`` outside training and outside
    distributed runs, this simply delegates to ``batch_norm``.

    Args:
        inp: 4-D input; channels are read from axis 1.
        running_mean: running mean. When both running statistics are given
            and ``training`` is true, it is updated in place as
            ``running * momentum + batch_stat * (1 - momentum)``.
        running_var: running variance, updated the same way using the
            unbiased batch variance.
        weight: optional scale parameter (default 1). 1-D values are
            reshaped to ``(1, C, 1, 1)``.
        bias: optional shift parameter (default 0), reshaped the same way.
        training: use batch statistics (True) or running statistics (False).
        momentum: weight given to the old running statistics.
        eps: combined with the variance by the ``eps_mode`` op.
        eps_mode: ``"additive"`` or ``"max"`` (case-insensitive).
        group: process group used for the all-reduce.

    Returns:
        The normalized tensor (cast to float16 under AMP).

    Raises:
        NotImplementedError: if ``inp`` is not 4-D.
    """
    _eps_mode = eps_mode.lower()
    assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode)
    if _eps_mode == "additive" and not (is_distributed() or training):
        return batch_norm(
            inp,
            running_mean,
            running_var,
            weight,
            bias,
            training=training,
            momentum=momentum,
            eps=eps,
        )

    if amp._enabled:
        inp, weight, bias, running_mean, running_var = cast_tensors(
            inp, weight, bias, running_mean, running_var, promote=True
        )

    _channels = make_shape_tuple(inp.shape)[1]
    _ndim = inp.ndim
    _device = inp.device
    _dtype = inp.dtype

    if _ndim != 4:
        raise NotImplementedError("sync_batch_norm for ndim != 4")

    def _make_full_if_none(x, value):
        # Broadcast a missing parameter / reshape a 1-D one to reduce_shape
        # (bound below, before the first call).
        if x is None:
            x = Const(value, inp.dtype, _device, None)
            (result,) = apply(builtin.Broadcast(), x, reduce_shape)
            return result
        elif x.ndim == 1:
            (result,) = apply(builtin.Reshape(), x, reduce_shape)
            return result
        return x

    # Pass the validated, lower-cased mode: it is used verbatim as the op
    # name inside the subgraphs (and as a cache key), so e.g. "MAX" would
    # pass the assertion above but build an invalid op.
    (
        syncbn_stage0,
        syncbn_stage1,
        syncbn_stage1_inference,
        syncbn_stage2,
        syncbn_concat_stats,
        syncbn_split_stats,
    ) = _get_sync_bn_ops(_device, _dtype, _eps_mode, _ndim, _channels)

    reduce_shape, reduce_size, channel_x1s, channel_x2s = apply(syncbn_stage0(), inp)

    eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)

    weight = _make_full_if_none(weight, 1)
    bias = _make_full_if_none(bias, 0)

    if training:
        if is_distributed():
            # Pack count / sums into one tensor so a single all-reduce suffices.
            (stat,) = apply(
                syncbn_concat_stats(), reduce_size, channel_x1s, channel_x2s
            )
            stat = all_reduce_sum(stat, group)
            reduce_size, channel_x1s, channel_x2s = apply(syncbn_split_stats(), stat)

        outvar, channel_mean, *_ = apply(
            syncbn_stage1(),
            inp,
            reduce_size,
            channel_x1s,
            channel_x2s,
            eps,
            weight,
            bias,
        )
    else:
        assert running_var is not None and running_mean is not None
        channel_mean = running_mean
        channel_var = running_var
        outvar, *_ = apply(
            syncbn_stage1_inference(), inp, channel_mean, channel_var, eps, weight, bias
        )

    if training and running_var is not None and running_mean is not None:
        momentum = convert_single_value(momentum, dtype=inp.dtype, device=inp.device)
        running_mean[...], running_var[...] = apply(
            syncbn_stage2(),
            running_mean,
            running_var,
            momentum,
            reduce_size,
            channel_x1s,
            channel_x2s,
            channel_mean,
        )

    if amp._enabled:
        outvar = outvar.astype("float16")
    return outvar
def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
    """Randomly zero elements of ``inp`` with probability ``drop_prob``.

    The input is returned unchanged when not training or when ``drop_prob``
    is 0. The op is seeded from the global RNG seed.

    Args:
        inp: input tensor.
        drop_prob: drop probability in ``[0, 1)``.
        training: whether dropout is active.

    Returns:
        The first output of the ``Dropout`` op, or ``inp`` itself.
    """
    assert 0 <= drop_prob < 1
    if training and drop_prob != 0:
        drop = Dropout(drop_prob=drop_prob, seed=_get_global_rng_seed(), handle=0)
        return apply(drop, inp)[0]
    return inp
def one_hot(inp: Tensor, num_classes: int) -> Tensor:
    """One-hot encode integer ``inp`` into a new trailing axis of ``num_classes``.

    A zero tensor of shape ``inp.shape + (num_classes,)`` is filled with ones
    at the positions given by ``inp`` along the new last axis. The result has
    ``inp``'s dtype and device.
    """
    lead_shape = list(inp.shape)
    canvas = zeros(lead_shape + [num_classes], dtype=inp.dtype, device=inp.device)
    fill = ones(lead_shape + [1], dtype=inp.dtype, device=inp.device)
    setter = builtin.IndexingSetOneHot(axis=inp.ndim)
    (encoded,) = apply(setter, canvas, inp, fill)
    return encoded
def embedding(
    inp: Tensor,
    weight: Tensor,
    padding_idx: Optional[int] = None,
    max_norm: Optional[float] = None,
    norm_type: Optional[float] = None,
):
    """Look up rows of ``weight`` for every index in ``inp``.

    The result shape is ``inp.shape + (weight.shape[-1],)``.

    Raises:
        ValueError: if ``padding_idx``, ``max_norm`` or ``norm_type`` is given
            (unsupported).
    """
    if padding_idx is not None:
        raise ValueError("Not support padding_idx Now!")
    if not (max_norm is None and norm_type is None):
        raise ValueError("Not support weight normlization Now!")
    out_shape = list(inp.shape) + [weight.shape[-1]]
    rows = weight[inp.reshape(-1)]
    return rows.reshape(out_shape)
def indexing_one_hot(
    src: Tensor, index: Tensor, axis: int = 1, keepdims=False
) -> Tensor:
    """Pick one element of ``src`` along ``axis`` at the positions in ``index``.

    ``index`` is converted to an int32 tensor on ``src``'s device. Unless
    ``keepdims`` is true, ``axis`` is squeezed from the result.
    """
    assert isinstance(src, Tensor), "src must be of Tensor type"
    picker = builtin.IndexingOneHot(axis=axis)
    index = convert_single_value(index, dtype="int32", device=src.device)
    (picked,) = apply(picker, src, index)
    return picked if keepdims else squeeze(picked, axis)
def sliding_window(
    inp: Tensor,
    kernel_size: Union[int, Tuple[int, int]],
    padding: Union[int, Tuple[int, int]] = 0,
    stride: Union[int, Tuple[int, int]] = 1,
    dilation: Union[int, Tuple[int, int]] = 1,
) -> Tensor:
    """Extract sliding ``kernel_size`` windows via ``builtin.Images2Neibs``.

    All spatial arguments accept an int or an ``(h, w)`` pair.
    """
    ph, pw = expand_hw(padding)
    sh, sw = expand_hw(stride)
    dh, dw = expand_hw(dilation)
    wh, ww = expand_hw(kernel_size)
    unfold = builtin.Images2Neibs(
        pad_h=ph,
        pad_w=pw,
        stride_h=sh,
        stride_w=sw,
        dilate_h=dh,
        dilate_w=dw,
        window_h=wh,
        window_w=ww,
    )
    (windows,) = apply(unfold, inp)
    return windows
def sliding_window_transpose(
    inp: Tensor,
    output_size: Union[int, Tuple[int, int]],
    kernel_size: Union[int, Tuple[int, int]],
    padding: Union[int, Tuple[int, int]] = 0,
    stride: Union[int, Tuple[int, int]] = 1,
    dilation: Union[int, Tuple[int, int]] = 1,
) -> Tensor:
    """Fold sliding windows back into an ``output_size`` image.

    This is the inverse layout of ``sliding_window``. ``inp`` must be 6-D,
    and its dims 2 and 3 must equal the window counts implied by
    ``output_size`` and the other hyper-parameters.
    """
    out_h, out_w = expand_hw(output_size)
    ph, pw = expand_hw(padding)
    sh, sw = expand_hw(stride)
    dh, dw = expand_hw(dilation)
    wh, ww = expand_hw(kernel_size)

    def _num_windows(size, pad, dil, win, step):
        # Standard convolution output-length formula.
        return (size + 2 * pad - dil * (win - 1) - 1) // step + 1

    expected_h = _num_windows(out_h, ph, dh, wh, sh)
    expected_w = _num_windows(out_w, pw, dw, ww, sw)
    assert inp.ndim == 6, "the input dimension of sliding_window_transpose should be 6"
    assert (
        inp.shape[2] == expected_h and inp.shape[3] == expected_w
    ), "the input shape and output size do not match"
    fold = builtin.SlidingWindowTranspose(
        out_h=out_h,
        out_w=out_w,
        pad_h=ph,
        pad_w=pw,
        stride_h=sh,
        stride_w=sw,
        dilate_h=dh,
        dilate_w=dw,
        window_h=wh,
        window_w=ww,
    )
    (image,) = apply(fold, inp)
    return image
def pad(
    src: Tensor,
    pad_width: Tuple[Tuple[int, int], ...],
    mode: str = "constant",
    constant_value: float = 0.0,
) -> Tensor:
    """Pad ``src`` per dimension via ``builtin.Padding``.

    Args:
        src: input tensor (the op exposes offsets for up to 7 dims).
        pad_width: one ``(front, back)`` pair per leading dimension.
        mode: ``"constant"``, ``"edge"`` (alias of ``"replicate"``),
            ``"replicate"`` or ``"reflect"`` (case-insensitive).
        constant_value: fill value for constant padding.

    Returns:
        The padded tensor.
    """
    offsets = [0] * 14
    assert mode.lower() in ["constant", "edge", "replicate", "reflect"]
    if mode.lower() == "edge":
        mode = "replicate"
    for dim, entry in enumerate(pad_width):
        offsets[2 * dim] = entry[0]
        offsets[2 * dim + 1] = entry[1]

    # Even slots hold front offsets, odd slots back offsets, for dims 0..6.
    kwargs = {}
    for dim in range(7):
        kwargs["front_offset_dim{}".format(dim)] = offsets[2 * dim]
    for dim in range(7):
        kwargs["back_offset_dim{}".format(dim)] = offsets[2 * dim + 1]
    padder = builtin.Padding(
        **kwargs, padding_val=constant_value, padding_mode=mode.upper(),
    )
    (padded,) = apply(padder, src)
    return padded
def local_response_norm(
    inp: Tensor,
    kernel_size: int = 5,
    k: float = 2.0,
    alpha: float = 1e-4,
    beta: float = 0.75,
) -> Tensor:
    """Local response normalization via ``builtin.LRN``.

    ``kernel_size`` is passed as the op's ``n``; ``k``, ``alpha`` and
    ``beta`` are forwarded unchanged.
    """
    lrn = builtin.LRN(n=kernel_size, k=k, alpha=alpha, beta=beta,)
    (normalized,) = apply(lrn, inp)
    return normalized
@lru_cache(maxsize=None)
def _get_layerPixelShuffle(device, dtype, dim_order):
    """Build, once per (device, dtype, dim_order), a pixel-shuffle subgraph.

    The subgraph takes ``(inp, shape_0, shape_1)``. It reshapes ``inp`` to
    ``shape_0``, permutes the axes by ``dim_order``, then reshapes to
    ``shape_1``.
    """

    @subgraph("LayerPixelShuffle", dtype, device, 3)
    def layerPixelShuffle(inputs, f, c):
        src, first_shape, final_shape = inputs
        staged = f(Reshape(), src, first_shape)
        permuted = f(Dimshuffle(dim_order), staged)
        result = f(Reshape(), permuted, final_shape)
        return (result,), (True,)

    return layerPixelShuffle
def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
    """Rearrange ``(*, C * r**2, H, W)`` into ``(*, C, H * r, W * r)``.

    Here ``r`` is ``upscale_factor``. All leading dims are flattened into
    one, the tensor is viewed as ``(n, C, r, r, H, W)``, permuted to
    ``(n, C, H, r, W, r)`` and reshaped to the output shape.

    Args:
        inp: tensor with at least 3 dims; ``inp.shape[-3]`` must be divisible
            by ``upscale_factor ** 2``.
        upscale_factor: positive spatial upscaling factor ``r``.

    Returns:
        The shuffled tensor.
    """
    assert upscale_factor > 0, "upscale_factor should larger than 0"
    # The check is ``>= 3``; the message previously said "larger than 3".
    assert inp.ndim >= 3, "the input dimension of pixel_shuffle should be no less than 3"
    assert (
        inp.shape[-3] % (upscale_factor ** 2) == 0
    ), "the -3 dimension should be divided by (upscale_factor ** 2)"

    _device = inp.device
    _dtype = inp.dtype
    shape_ori = inp.shape
    high_dim = shape_ori[:-3]
    square = upscale_factor ** 2
    n = 1
    for item in high_dim:
        n *= item
    # Exact integer division (divisibility is asserted above); the old
    # int(a / b) went through float and could lose precision for huge sizes.
    out_channels = int(shape_ori[-3]) // square
    shape_0 = (
        n,
        out_channels,
        upscale_factor,
        upscale_factor,
        shape_ori[-2],
        shape_ori[-1],
    )
    shape_1 = (
        *high_dim,
        out_channels,
        shape_ori[-2] * upscale_factor,
        shape_ori[-1] * upscale_factor,
    )
    # (n, C, r, r, H, W) -> (n, C, H, r, W, r)
    dim_order = (0, 1, 4, 2, 5, 3)

    layerPixelShuffle = _get_layerPixelShuffle(_device, _dtype, dim_order)

    shape_0 = convert_single_value(shape_0, device=inp.device)
    shape_1 = convert_single_value(shape_1, device=inp.device)
    outvar, *_ = apply(layerPixelShuffle(), inp, shape_0, shape_1)

    return outvar
# Re-exports placed at the end of the module; four import statements had been
# collapsed onto a single line, which is a syntax error.
from .quantized import conv_bias_activation
from .loss import *
from .metric import *
from .vision import *