// megenginelite-sys 1.8.2
//
// A safe megenginelite wrapper in Rust
// Documentation
/**
 * \file dnn/src/cuda/relayout_format/cuda_post_process.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once
#include "src/cuda/relayout_format/relayout_format.cuh"

namespace megdnn {
namespace cuda {
namespace relayout_format {
namespace internal {
/*!
 * \brief Per-element post-processing functor; only declared here.
 *
 * All behavior comes from the explicit specializations that follow in this
 * header. Each specialization is constructed from
 * (src_scale, src_zero_point, dst_scale, dst_zero_point) and exposes a
 * __device__ operator() that maps one source value to one destination value.
 *
 * \tparam SrcType source dtype tag (e.g. dtype::Uint8, dtype::QuantizedS8)
 * \tparam DstType destination dtype tag
 * \tparam same_scale in the specializations below, `true` variants ignore both
 *      scale arguments, while `false` variants use the scale arguments they
 *      need to build their quantization parameters.
 *      NOTE(review): presumably the caller sets this when the source and
 *      destination scales are equal -- confirm at the call site.
 */
template <typename SrcType, typename DstType, bool same_scale>
struct CudaPostProcess;

/*!
 * \brief Uint8 -> QuantizedS8 post-process for the same_scale == true case.
 *
 * No state is kept: every input is simply shifted down by 128.
 */
template <>
struct CudaPostProcess<dtype::Uint8, dtype::QuantizedS8, true> {
    //! All four arguments are accepted and ignored.
    CudaPostProcess(float, uint8_t, float, uint8_t) {}

    //! Maps [0, 255] onto [-128, 127] by subtracting 128.
    inline __device__ int8_t operator()(uint8_t val) {
        const int shifted = static_cast<int>(val) - 128;
        return static_cast<int8_t>(shifted);
    }
};

/*!
 * \brief Uint8 -> QuantizedS8 post-process for the same_scale == false case.
 *
 * The input is shifted by -128 in float and then quantized with the
 * destination parameter built from dst_scale.
 */
template <>
struct CudaPostProcess<dtype::Uint8, dtype::QuantizedS8, false> {
    CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt;

    //! Only dst_scale (third argument) is used; the rest are ignored.
    CudaPostProcess(float, uint8_t, float dst_scale, uint8_t)
            : m_dst_type_cvt(dst_scale) {}

    //! Returns quantize(val - 128) as int8.
    inline __device__ int8_t operator()(uint8_t val) {
        const float centered = static_cast<float>(val) - 128.f;
        return m_dst_type_cvt.quantize(centered).as_int8();
    }
};

/*!
 * \brief Quantized8Asymm -> QuantizedS8 post-process for same_scale == false.
 *
 * Each value is dequantized with the source parameter (scale + zero point)
 * and then quantized again with the destination parameter (scale only).
 */
template <>
struct CudaPostProcess<dtype::Quantized8Asymm, dtype::QuantizedS8, false> {
    CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_quint8> m_src_type_cvt;

    //! The destination zero point (fourth argument) is ignored.
    CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale, uint8_t)
            : m_dst_type_cvt(dst_scale),
              m_src_type_cvt(src_scale, src_zero_point) {}

    //! Dequantize with the source parameter, re-quantize with the destination one.
    inline __device__ int8_t operator()(uint8_t val) {
        const dt_quint8 src_val(val);
        const float real_val = m_src_type_cvt.dequantize(src_val);
        return m_dst_type_cvt.quantize(real_val).as_int8();
    }
};

/*!
 * \brief Quantized8Asymm -> QuantizedS8 post-process for same_scale == true.
 *
 * Only the source zero point is stored. The output is the input with that
 * zero point subtracted, saturated to the int8 range [-128, 127].
 */
template <>
struct CudaPostProcess<dtype::Quantized8Asymm, dtype::QuantizedS8, true> {
    uint8_t m_src_zero_point = 0;

    //! Only src_zero_point (second argument) is used; the rest are ignored.
    CudaPostProcess(float, uint8_t src_zero_point, float, uint8_t)
            : m_src_zero_point(src_zero_point) {}

    /*!
     * \param val raw asymmetric 8-bit source value
     * \return val - m_src_zero_point, clamped to [-128, 127]
     */
    inline __device__ int8_t operator()(uint8_t val) {
        // The difference spans [-255, 255]. Narrowing it straight to int8_t
        // would wrap around (e.g. 200 - 0 -> -56) instead of saturating, so
        // clamp first, as the Quantized4Asymm same-scale specialization in
        // this file does for its own destination range.
        int centered = static_cast<int>(val) - static_cast<int>(m_src_zero_point);
        centered = centered < -128 ? -128 : centered;
        centered = centered > 127 ? 127 : centered;
        return static_cast<int8_t>(centered);
    }
};

/*!
 * \brief QuantizedS8 -> QuantizedS8 post-process for same_scale == false.
 *
 * Each value is dequantized with the source scale and quantized again with
 * the destination scale.
 */
template <>
struct CudaPostProcess<dtype::QuantizedS8, dtype::QuantizedS8, false> {
    CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_qint8> m_src_type_cvt;

    //! Both zero-point arguments (second and fourth) are ignored.
    CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t)
            : m_dst_type_cvt(dst_scale), m_src_type_cvt(src_scale) {}

    //! Dequantize with the source parameter, re-quantize with the destination one.
    inline __device__ int8_t operator()(int8_t val) {
        const dt_qint8 src_val(val);
        const float real_val = m_src_type_cvt.dequantize(src_val);
        return m_dst_type_cvt.quantize(real_val).as_int8();
    }
};

/*!
 * \brief QuantizedS8 -> QuantizedS8 post-process for same_scale == true.
 *
 * Stateless identity mapping: values pass through unchanged.
 */
template <>
struct CudaPostProcess<dtype::QuantizedS8, dtype::QuantizedS8, true> {
    //! Default construction is allowed since there is no state to set up.
    CudaPostProcess() {}

    //! All four arguments are accepted and ignored.
    CudaPostProcess(float, uint8_t, float, uint8_t) {}

    //! Returns the input as-is.
    inline __device__ int8_t operator()(int8_t val) {
        return val;
    }
};

/*!
 * \brief QuantizedS32 -> QuantizedS32 post-process for same_scale == false.
 *
 * Each value is dequantized with the source scale and quantized again with
 * the destination scale.
 */
template <>
struct CudaPostProcess<dtype::QuantizedS32, dtype::QuantizedS32, false> {
    CudaDTypeParamImpl<dt_qint32> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_qint32> m_src_type_cvt;

    //! The two int arguments (second and fourth) are ignored.
    CudaPostProcess(float src_scale, int, float dst_scale, int)
            : m_dst_type_cvt(dst_scale), m_src_type_cvt(src_scale) {}

    //! Dequantize with the source parameter, re-quantize with the destination one.
    inline __device__ int operator()(int val) {
        const dt_qint32 src_val(val);
        const float real_val = m_src_type_cvt.dequantize(src_val);
        return m_dst_type_cvt.quantize(real_val).as_int32();
    }
};
/*!
 * \brief QuantizedS32 -> QuantizedS32 post-process for same_scale == true.
 *
 * Stateless identity mapping: values pass through unchanged.
 */
template <>
struct CudaPostProcess<dtype::QuantizedS32, dtype::QuantizedS32, true> {
    //! All four arguments are accepted and ignored.
    CudaPostProcess(float, int, float, int) {}

    //! Returns the input as-is.
    inline __device__ int operator()(int val) {
        return val;
    }
};

/*!
 * \brief QuantizedS4 -> QuantizedS4 post-process for same_scale == false.
 *
 * Each value (carried in an int8_t) is dequantized with the source scale and
 * quantized again with the destination scale.
 */
template <>
struct CudaPostProcess<dtype::QuantizedS4, dtype::QuantizedS4, false> {
    using SrcType = dtype::QuantizedS4;
    using DstType = dtype::QuantizedS4;

    CudaDTypeParamImpl<dt_qint4> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_qint4> m_src_type_cvt;

    //! Both zero-point arguments (second and fourth) are ignored.
    CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t)
            : m_dst_type_cvt(dst_scale), m_src_type_cvt(src_scale) {}

    //! Dequantize with the source parameter, re-quantize with the destination one.
    inline __device__ int8_t operator()(int8_t val) {
        const dt_qint4 src_val(val);
        const float real_val = m_src_type_cvt.dequantize(src_val);
        return m_dst_type_cvt.quantize(real_val).as_int8();
    }
};

/*!
 * \brief QuantizedS4 -> QuantizedS4 post-process for same_scale == true.
 *
 * Stateless identity mapping: values pass through unchanged.
 */
template <>
struct CudaPostProcess<dtype::QuantizedS4, dtype::QuantizedS4, true> {
    using SrcType = dtype::QuantizedS4;
    using DstType = dtype::QuantizedS4;

    //! All four arguments are accepted and ignored.
    CudaPostProcess(float, uint8_t, float, uint8_t) {}

    //! Returns the input as-is.
    inline __device__ int8_t operator()(int8_t val) {
        return val;
    }
};

/*!
 * \brief Quantized4Asymm -> Quantized4Asymm post-process for same_scale == false.
 *
 * Each value (carried in a uint8_t) is dequantized with the source scale and
 * zero point, then quantized again with the destination scale and zero point.
 */
template <>
struct CudaPostProcess<dtype::Quantized4Asymm, dtype::Quantized4Asymm, false> {
    using SrcType = dtype::Quantized4Asymm;
    using DstType = dtype::Quantized4Asymm;

    CudaDTypeParamImpl<dt_quint4> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_quint4> m_src_type_cvt;

    //! All four arguments are used to build the two conversion parameters.
    CudaPostProcess(
            float src_scale, uint8_t src_zero_point, float dst_scale,
            uint8_t dst_zero_point)
            : m_dst_type_cvt(dst_scale, dst_zero_point),
              m_src_type_cvt(src_scale, src_zero_point) {}

    //! Dequantize with the source parameter, re-quantize with the destination one.
    inline __device__ uint8_t operator()(uint8_t val) {
        const dt_quint4 src_val(val);
        const float real_val = m_src_type_cvt.dequantize(src_val);
        return m_dst_type_cvt.quantize(real_val).as_uint8();
    }
};

/*!
 * \brief Quantized4Asymm -> Quantized4Asymm post-process for same_scale == true.
 *
 * Only the two zero points are stored. The value is re-centered from the
 * source zero point to the destination zero point and clamped to [0, 15].
 */
template <>
struct CudaPostProcess<dtype::Quantized4Asymm, dtype::Quantized4Asymm, true> {
    using SrcType = dtype::Quantized4Asymm;
    using DstType = dtype::Quantized4Asymm;

    uint8_t m_src_zero_point = 0;
    uint8_t m_dst_zero_point = 0;

    //! Both scale arguments (first and third) are ignored.
    CudaPostProcess(float, uint8_t src_zero_point, float, uint8_t dst_zero_point)
            : m_src_zero_point(src_zero_point),
              m_dst_zero_point(dst_zero_point) {}

    //! Returns val - src_zero_point + dst_zero_point, clamped to [0, 15].
    inline __device__ uint8_t operator()(uint8_t val) {
        int shifted = static_cast<int>(val) - static_cast<int>(m_src_zero_point) +
                      static_cast<int>(m_dst_zero_point);
        if (shifted < 0) {
            shifted = 0;
        } else if (shifted > 15) {
            shifted = 15;
        }
        return static_cast<uint8_t>(shifted);
    }
};

}  // namespace internal
}  // namespace relayout_format
}  // namespace cuda
}  // namespace megdnn