megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/common/rounding_converter.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once
#include "megdnn/dtype.h"

namespace megdnn {
namespace rounding {

template <typename T>
struct RoundingConverter;

template <>
struct RoundingConverter<float> {
    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE float operator()(float x) const {
        return x;
    }
};

#ifndef MEGDNN_DISABLE_FLOAT16

template <>
struct RoundingConverter<half_float::half> {
    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE half_float::half operator()(
            float x) const {
        return static_cast<half_float::half>(x);
    }
};

template <>
struct RoundingConverter<half_bfloat16::bfloat16> {
    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE half_bfloat16::bfloat16 operator()(
            float x) const {
        return static_cast<half_bfloat16::bfloat16>(x);
    }
};

#endif  // #ifdef MEGDNN_DISABLE_FLOAT16

template <>
struct RoundingConverter<int8_t> {
    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE int8_t operator()(float x) const {
#if MEGDNN_CC_HOST
        using std::round;
#endif
        return static_cast<int8_t>(round(x));
    }
};

template <>
struct RoundingConverter<uint8_t> {
    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE uint8_t operator()(float x) const {
#if MEGDNN_CC_HOST
        using std::max;
        using std::min;
        using std::round;
#endif
        x = min(255.0f, max(0.0f, x));  //! FIXME!!! check other places
        return static_cast<uint8_t>(round(x));
    }
};

template <>
struct RoundingConverter<dt_qint4> {
    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE dt_qint4 operator()(float x) const {
#if MEGDNN_CC_HOST
        using std::round;
#endif
        return static_cast<dt_qint4>(round(x));
    }
};

template <>
struct RoundingConverter<dt_quint4> {
    MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE dt_quint4 operator()(float x) const {
#if MEGDNN_CC_HOST
        using std::round;
#endif
        return static_cast<dt_quint4>(round(x));
    }
};

}  // namespace rounding
}  // namespace megdnn

/* vim: set ft=cpp: */