from typing import Tuple, Union
from ..core import _config
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..tensor import Tensor
from ..utils.tuple_function import _pair, _pair_nonzero
from .debug_param import get_execution_strategy
def conv_bias_activation(
    inp: Tensor,
    weight: Tensor,
    bias: Tensor,
    dtype=None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    nonlinear_mode="identity",
    conv_mode="cross_correlation",
    compute_mode="default",
) -> Tensor:
    """Build a ``builtin.ConvBias`` op and apply it to ``inp``, ``weight``, ``bias``.

    Args:
        inp: first operand handed to the op.
        weight: second operand handed to the op.
        bias: third operand handed to the op.
        dtype: forwarded unchanged as the op's ``dtype`` parameter.
        stride: int or pair, expanded with ``_pair_nonzero`` into the
            ``stride_h`` / ``stride_w`` parameters. Default: 1.
        padding: int or pair, expanded with ``_pair`` into the
            ``pad_h`` / ``pad_w`` parameters. Default: 0.
        dilation: int or pair, expanded with ``_pair_nonzero`` into the
            ``dilate_h`` / ``dilate_w`` parameters. Default: 1.
        groups: ``1`` selects the ``"dense"`` sparse mode; any other value
            selects ``"group"``. The number itself is not forwarded.
        nonlinear_mode: forwarded as the op's ``nonlineMode`` parameter.
        conv_mode: forwarded as the op's ``mode`` parameter.
        compute_mode: passed through ``_config._get_actual_op_param`` and
            the result is forwarded as the op's ``compute_mode`` parameter.

    Returns:
        The single output produced by ``apply``.
    """
    # NOTE(review): tensor layouts/dtypes are not visible from this function;
    # confirm against the ConvBias operator documentation.
    pad_h, pad_w = _pair(padding)
    stride_h, stride_w = _pair_nonzero(stride)
    dilate_h, dilate_w = _pair_nonzero(dilation)

    if groups == 1:
        sparsity = "dense"
    else:
        sparsity = "group"

    # Both values go through the config helper; "NCHW" is the format
    # requested here before that helper has its say.
    resolved_compute_mode = _config._get_actual_op_param(
        compute_mode, _config.__compute_mode
    )
    layout = _config._get_actual_op_param("NCHW", _config.__conv_format)

    op_cls = builtin.ConvBias
    op_params = {
        "stride_h": stride_h,
        "stride_w": stride_w,
        "pad_h": pad_h,
        "pad_w": pad_w,
        "dilate_h": dilate_h,
        "dilate_w": dilate_w,
        "dtype": dtype,
        "format": layout,
        "strategy": get_execution_strategy(),
        "nonlineMode": nonlinear_mode,
        "mode": conv_mode,
        "compute_mode": resolved_compute_mode,
        "sparse": sparsity,
    }
    conv_bias_op = op_cls(**op_params)

    # apply() is expected to yield exactly one result; unpacking enforces it.
    [result] = apply(conv_bias_op, inp, weight, bias)
    return result
def batch_conv_bias_activation(
    inp: Tensor,
    weight: Tensor,
    bias: Tensor,
    dtype=None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    nonlinear_mode="identity",
    conv_mode="cross_correlation",
    compute_mode="default",
) -> Tensor:
    """Build a ``builtin.BatchConvBias`` op and apply it to ``inp``, ``weight``, ``bias``.

    Args:
        inp: first operand handed to the op.
        weight: second operand handed to the op.
        bias: third operand handed to the op.
        dtype: forwarded unchanged as the op's ``dtype`` parameter.
        stride: int or pair, expanded with ``_pair_nonzero`` into the
            ``stride_h`` / ``stride_w`` parameters. Default: 1.
        padding: int or pair, expanded with ``_pair`` into the
            ``pad_h`` / ``pad_w`` parameters. Default: 0.
        dilation: int or pair, expanded with ``_pair_nonzero`` into the
            ``dilate_h`` / ``dilate_w`` parameters. Default: 1.
        groups: ``1`` selects the ``"dense"`` sparse mode; any other value
            selects ``"group"``. The number itself is not forwarded.
        nonlinear_mode: forwarded as the op's ``nonlineMode`` parameter.
        conv_mode: forwarded as the op's ``mode`` parameter.
        compute_mode: passed through ``_config._get_actual_op_param`` and
            the result is forwarded as the op's ``compute_mode`` parameter.

    Returns:
        The single output produced by ``apply``.
    """
    # NOTE(review): presumably ``weight`` carries a per-sample batch
    # dimension for this op; that is not visible here, so confirm against
    # the BatchConvBias operator documentation.
    pad_h, pad_w = _pair(padding)
    stride_h, stride_w = _pair_nonzero(stride)
    dilate_h, dilate_w = _pair_nonzero(dilation)

    if groups == 1:
        sparsity = "dense"
    else:
        sparsity = "group"

    # Both values go through the config helper; "NCHW" is the format
    # requested here before that helper has its say.
    resolved_compute_mode = _config._get_actual_op_param(
        compute_mode, _config.__compute_mode
    )
    layout = _config._get_actual_op_param("NCHW", _config.__conv_format)

    op_cls = builtin.BatchConvBias
    op_params = {
        "stride_h": stride_h,
        "stride_w": stride_w,
        "pad_h": pad_h,
        "pad_w": pad_w,
        "dilate_h": dilate_h,
        "dilate_w": dilate_w,
        "dtype": dtype,
        "format": layout,
        "strategy": get_execution_strategy(),
        "nonlineMode": nonlinear_mode,
        "mode": conv_mode,
        "compute_mode": resolved_compute_mode,
        "sparse": sparsity,
    }
    batch_conv_bias_op = op_cls(**op_params)

    # apply() is expected to yield exactly one result; unpacking enforces it.
    [result] = apply(batch_conv_bias_op, inp, weight, bias)
    return result
def conv_transpose2d(
    inp: Tensor,
    weight: Tensor,
    bias: Tensor = None,
    dtype=None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    conv_mode="cross_correlation",
    compute_mode="default",
) -> Tensor:
    """Build a ``builtin.ConvolutionBackwardData`` op and apply it to ``weight``, ``inp``.

    Args:
        inp: tensor passed as the second operand of the op.
        weight: tensor passed as the first operand of the op.
        bias: must be ``None``; any other value is rejected.
        dtype: forwarded unchanged as the op's ``dtype`` parameter.
        stride: int or pair, expanded with ``_pair_nonzero`` into the
            ``stride_h`` / ``stride_w`` parameters. Default: 1.
        padding: int or pair, expanded with ``_pair`` into the
            ``pad_h`` / ``pad_w`` parameters. Default: 0.
        dilation: int or pair, expanded with ``_pair_nonzero`` into the
            ``dilate_h`` / ``dilate_w`` parameters. Default: 1.
        groups: must be ``1``; any other value is rejected.
        conv_mode: either the string ``"cross_correlation"`` (compared
            case-insensitively) or an enum-like object whose ``name`` is
            ``"CROSS_CORRELATION"``. Forwarded as the op's ``mode``.
        compute_mode: either the string ``"default"`` (compared
            case-insensitively) or an enum-like object whose ``name`` is
            ``"DEFAULT"``. Resolved through ``_config._get_actual_op_param``
            before being forwarded.

    Returns:
        The single output produced by ``apply``.

    Raises:
        AssertionError: if ``conv_mode`` or ``compute_mode`` is not one of
            the accepted values described above.
        NotImplementedError: if ``groups != 1`` or ``bias`` is not ``None``.
    """

    def _is_mode(value, enum_name: str) -> bool:
        # Strings and enum-like values need different checks: calling
        # ``.lower()`` on an enum, or ``.name`` on a str, raises
        # AttributeError instead of producing a clean validation failure.
        if isinstance(value, str):
            return value.lower() == enum_name.lower()
        return getattr(value, "name", None) == enum_name

    assert _is_mode(
        conv_mode, "CROSS_CORRELATION"
    ), "unsupported conv_mode: {!r}".format(conv_mode)
    assert _is_mode(
        compute_mode, "DEFAULT"
    ), "unsupported compute_mode: {!r}".format(compute_mode)

    if groups != 1:
        raise NotImplementedError(
            "group quantized transposed conv2d is not supported yet."
        )
    if bias is not None:
        raise NotImplementedError(
            "bias of quantized transposed conv2d is not supported yet."
        )

    pad_h, pad_w = _pair(padding)
    stride_h, stride_w = _pair_nonzero(stride)
    dilate_h, dilate_w = _pair_nonzero(dilation)
    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)

    op = builtin.ConvolutionBackwardData(
        stride_h=stride_h,
        stride_w=stride_w,
        pad_h=pad_h,
        pad_w=pad_w,
        dilate_h=dilate_h,
        dilate_w=dilate_w,
        strategy=get_execution_strategy(),
        dtype=dtype,
        compute_mode=compute_mode,
        mode=conv_mode,
    )
    # Operand order is (weight, inp), the reverse of the signature.
    # NOTE(review): presumably the backward-data op takes the filter first
    # and the incoming tensor second; confirm against the operator definition.
    (output,) = apply(op, weight, inp)
    return output