megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/common/argmxx_helper.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once
#include "megdnn/dtype.h"

#if MEGDNN_CC_HOST
#include "megdnn/basic_types.h"
#endif

namespace megdnn {
namespace argmxx {

/*!
 * \brief Reduction operator that tracks an extreme key together with an index.
 *
 * \tparam stype_ element type of the source buffer
 * \tparam is_max true to keep the larger key (argmax), false to keep the
 *      smaller key (argmin)
 *
 * The operator exposes read() / write() / apply() plus the INIT value.
 * NOTE(review): this looks like the interface expected by a generic reduce
 * driver defined elsewhere -- confirm against the caller.
 */
template <typename stype_, bool is_max>
struct ArgmxxOp {
    /*!
     * \brief Working value of the reduction: a key and the index stored with it.
     *
     * Copy construction is spelled out for every cv-combination of the source
     * object, and assignment is provided for a volatile destination.
     * NOTE(review): presumably needed so the type can live in volatile
     * storage during device-side reductions -- confirm.
     */
    struct wtype {
        stype_ key;    //!< value being compared
        dt_int32 val;  //!< index associated with \p key

        //! Leaves both members uninitialized on purpose (empty body).
        MEGDNN_HOST MEGDNN_DEVICE wtype() {}

        //! Builds a working value from an explicit key/index pair.
        MEGDNN_HOST MEGDNN_DEVICE wtype(stype_ k, dt_int32 v) : key(k), val(v) {}

        // Member-wise copies, one overload per cv-qualification of the source.
        MEGDNN_HOST MEGDNN_DEVICE wtype(wtype& other)
                : key(other.key), val(other.val) {}
        MEGDNN_HOST MEGDNN_DEVICE wtype(volatile wtype& other)
                : key(other.key), val(other.val) {}
        MEGDNN_HOST MEGDNN_DEVICE wtype(const wtype& other)
                : key(other.key), val(other.val) {}
        MEGDNN_HOST MEGDNN_DEVICE wtype(const volatile wtype& other)
                : key(other.key), val(other.val) {}

        //! Member-wise assignment into a volatile object.
        MEGDNN_HOST MEGDNN_DEVICE volatile wtype& operator=(
                const wtype& other) volatile {
            key = other.key;
            val = other.val;
            return *this;
        }
    };

    /*!
     * \brief Stores the buffers and extents and prepares the initial value.
     *
     * INIT gets index 0 and a key of DTypeTrait<stype_>::min() when is_max is
     * set, DTypeTrait<stype_>::max() otherwise.
     * NOTE(review): presumably these are the lowest / highest representable
     * values so that any real element replaces INIT -- confirm against
     * DTypeTrait.
     */
    MEGDNN_HOST MEGDNN_DEVICE
    ArgmxxOp(stype_* src, dt_int32* dst, uint32_t A, uint32_t B, uint32_t C)
            : src(src),
              dst(dst),
              A(A),
              B(B),
              C(C),
              INIT(is_max ? DTypeTrait<stype_>::min() : DTypeTrait<stype_>::max(),
                   0) {}

    /*!
     * \brief Loads element \p idx of src as a working value.
     *
     * The stored index is (idx / C) % B.
     * NOTE(review): for a flat index into an (A, B, C) layout this is the
     * coordinate along the middle axis -- confirm the layout with the caller.
     */
    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
        return wtype(src[idx], static_cast<dt_int32>(idx / C % B));
    }

    //! Stores only the index part of \p val into dst[idx]; the key is dropped.
    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) {
        dst[idx] = val.val;
    }

    /*!
     * \brief Combines two working values.
     *
     * Returns \p lhs only when its key is strictly greater (is_max) or
     * strictly smaller (!is_max) than the key of \p rhs; otherwise, including
     * when the keys compare equal, returns \p rhs.
     */
    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
        if (is_max ? (lhs.key > rhs.key) : (lhs.key < rhs.key))
            return lhs;
        return rhs;
    }

    stype_* src;    //!< input buffer read by read()
    dt_int32* dst;  //!< output buffer written by write()
    uint32_t A;     //!< extent stored from the constructor; unused in this struct
    uint32_t B;     //!< modulus applied in read()
    uint32_t C;     //!< divisor applied in read()
    const wtype INIT;  //!< starting value of the reduction
};

/*!
 * \brief dt_float32 specialization of the argmax/argmin reduction operator.
 *
 * \tparam is_max true to keep the larger key (argmax), false to keep the
 *      smaller key (argmin)
 *
 * Members match the primary template; the only behavioural difference
 * is the explicit NaN check at the top of apply().
 */
template <bool is_max>
struct ArgmxxOp<dt_float32, is_max> {
    using stype_ = dt_float32;

    /*!
     * \brief Working value of the reduction: a key and the index stored with it.
     *
     * Copy construction is spelled out for every cv-combination of the source
     * object, and assignment is provided for a volatile destination.
     * NOTE(review): presumably needed so the type can live in volatile
     * storage during device-side reductions -- confirm.
     */
    struct wtype {
        stype_ key;    //!< value being compared
        dt_int32 val;  //!< index associated with \p key

        //! Leaves both members uninitialized on purpose (empty body).
        MEGDNN_HOST MEGDNN_DEVICE wtype() {}

        //! Builds a working value from an explicit key/index pair.
        MEGDNN_HOST MEGDNN_DEVICE wtype(stype_ k, dt_int32 v) : key(k), val(v) {}

        // Member-wise copies, one overload per cv-qualification of the source.
        MEGDNN_HOST MEGDNN_DEVICE wtype(wtype& other)
                : key(other.key), val(other.val) {}
        MEGDNN_HOST MEGDNN_DEVICE wtype(volatile wtype& other)
                : key(other.key), val(other.val) {}
        MEGDNN_HOST MEGDNN_DEVICE wtype(const wtype& other)
                : key(other.key), val(other.val) {}
        MEGDNN_HOST MEGDNN_DEVICE wtype(const volatile wtype& other)
                : key(other.key), val(other.val) {}

        //! Member-wise assignment into a volatile object.
        MEGDNN_HOST MEGDNN_DEVICE volatile wtype& operator=(
                const wtype& other) volatile {
            key = other.key;
            val = other.val;
            return *this;
        }
    };

    /*!
     * \brief Stores the buffers and extents and prepares the initial value.
     *
     * INIT gets index 0 and a key of DTypeTrait<stype_>::min() when is_max is
     * set, DTypeTrait<stype_>::max() otherwise.
     * NOTE(review): presumably these are the lowest / highest representable
     * floats so that any real element replaces INIT -- confirm against
     * DTypeTrait.
     */
    MEGDNN_HOST MEGDNN_DEVICE
    ArgmxxOp(stype_* src, dt_int32* dst, uint32_t A, uint32_t B, uint32_t C)
            : src(src),
              dst(dst),
              A(A),
              B(B),
              C(C),
              INIT(is_max ? DTypeTrait<stype_>::min() : DTypeTrait<stype_>::max(),
                   0) {}

    /*!
     * \brief Loads element \p idx of src as a working value.
     *
     * The stored index is (idx / C) % B.
     * NOTE(review): for a flat index into an (A, B, C) layout this is the
     * coordinate along the middle axis -- confirm the layout with the caller.
     */
    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
        return wtype(src[idx], static_cast<dt_int32>(idx / C % B));
    }

    //! Stores only the index part of \p val into dst[idx]; the key is dropped.
    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) {
        dst[idx] = val.val;
    }

    /*!
     * \brief Combines two working values, letting NaN keys win.
     *
     * A NaN key in \p lhs is returned immediately. A NaN key in \p rhs makes
     * both ordered comparisons below false, so \p rhs is returned; either way
     * a NaN operand is the result. Otherwise \p lhs is returned only when its
     * key is strictly greater (is_max) or strictly smaller (!is_max); equal
     * keys yield \p rhs.
     */
    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
        // Device compilation uses the unqualified isnan, host uses std::isnan.
#if defined(__CUDA_ARCH__)
        const bool lhs_is_nan = isnan(lhs.key);
#else
        const bool lhs_is_nan = std::isnan(lhs.key);
#endif
        if (lhs_is_nan)
            return lhs;
        if (is_max ? (lhs.key > rhs.key) : (lhs.key < rhs.key))
            return lhs;
        return rhs;
    }

    stype_* src;    //!< input buffer read by read()
    dt_int32* dst;  //!< output buffer written by write()
    uint32_t A;     //!< extent stored from the constructor; unused in this struct
    uint32_t B;     //!< modulus applied in read()
    uint32_t C;     //!< divisor applied in read()
    const wtype INIT;  //!< starting value of the reduction
};

}  // namespace argmxx
}  // namespace megdnn
// vim: syntax=cpp.doxygen