#pragma once
#include <cmath>

#include "megdnn/dtype.h"
#if MEGDNN_CC_HOST
#include "megdnn/basic_types.h"
#endif
namespace megdnn {
namespace argmxx {
//! Reduction functor that tracks, alongside each source element, an index
//! derived from its flat offset, and keeps the pair whose element is the
//! largest (is_max == true) or the smallest (is_max == false).
template <typename stype_, bool is_max>
struct ArgmxxOp {
    //! Value carried through the reduction: a source element (key) paired
    //! with the index computed for it (val).
    struct wtype {
        stype_ key;
        dt_int32 val;

        //! Intentionally leaves key and val uninitialized.
        MEGDNN_HOST MEGDNN_DEVICE wtype() {}

        MEGDNN_HOST MEGDNN_DEVICE wtype(stype_ k, dt_int32 v) : key(k), val(v) {}

        // Copy construction is spelled out for every cv-qualification of the
        // source object so that volatile instances can be copied as well.
        // NOTE(review): presumably needed for volatile buffers used by the
        // device-side reduction -- confirm against the reduce kernels.
        MEGDNN_HOST MEGDNN_DEVICE wtype(wtype& other)
                : key(other.key), val(other.val) {}
        MEGDNN_HOST MEGDNN_DEVICE wtype(volatile wtype& other)
                : key(other.key), val(other.val) {}
        MEGDNN_HOST MEGDNN_DEVICE wtype(const wtype& other)
                : key(other.key), val(other.val) {}
        MEGDNN_HOST MEGDNN_DEVICE wtype(const volatile wtype& other)
                : key(other.key), val(other.val) {}

        //! Member-wise assignment that is callable on a volatile target.
        MEGDNN_HOST MEGDNN_DEVICE volatile wtype& operator=(
                const wtype& other) volatile {
            key = other.key;
            val = other.val;
            return *this;
        }
    };

    //! Stores the pointers and extents and seeds INIT with the value that
    //! DTypeTrait reports as min() when searching for a maximum, or max()
    //! when searching for a minimum, paired with index 0.
    MEGDNN_HOST MEGDNN_DEVICE
    ArgmxxOp(stype_* src, dt_int32* dst, uint32_t A, uint32_t B, uint32_t C)
            : src(src),
              dst(dst),
              A(A),
              B(B),
              C(C),
              INIT(is_max ? DTypeTrait<stype_>::min() : DTypeTrait<stype_>::max(),
                   0) {}

    //! Loads src[idx] and pairs it with (idx / C) % B, i.e. the middle
    //! coordinate when idx is a row-major offset into an (A, B, C) array.
    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
        return wtype(src[idx], static_cast<dt_int32>(idx / C % B));
    }

    //! Stores only the index part of \p val into dst[idx]; the key is dropped.
    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) {
        dst[idx] = val.val;
    }

    //! Combines two candidates. lhs is kept only when its key is strictly
    //! better (greater for max, less for min); otherwise rhs is returned,
    //! which includes the case of equal keys.
    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
        const bool lhs_wins =
                is_max ? (lhs.key > rhs.key) : (lhs.key < rhs.key);
        if (lhs_wins)
            return lhs;
        return rhs;
    }

    stype_* src;    //!< elements read by read()
    dt_int32* dst;  //!< indices written by write()
    uint32_t A;     //!< stored but not used inside this struct
    uint32_t B;     //!< modulus applied when computing the index in read()
    uint32_t C;     //!< divisor applied when computing the index in read()
    const wtype INIT;  //!< seed value built in the constructor
};
//! dt_float32 specialization of the arg-max / arg-min reduction functor.
//! It differs from the generic version only in apply(), which handles NaN
//! keys explicitly so that a NaN operand is the one returned.
template <bool is_max>
struct ArgmxxOp<dt_float32, is_max> {
    using stype_ = dt_float32;

    //! Value carried through the reduction: a source element (key) paired
    //! with the index computed for it (val).
    struct wtype {
        stype_ key;
        dt_int32 val;

        //! Intentionally leaves key and val uninitialized.
        MEGDNN_HOST MEGDNN_DEVICE wtype() {}

        MEGDNN_HOST MEGDNN_DEVICE wtype(stype_ key, dt_int32 val)
                : key(key), val(val) {}

        // Copy construction is provided for every cv-qualification of the
        // source object so that volatile instances can be copied as well.
        // NOTE(review): presumably needed for volatile buffers used by the
        // device-side reduction -- confirm against the reduce kernels.
        MEGDNN_HOST MEGDNN_DEVICE wtype(wtype& rhs) : key(rhs.key), val(rhs.val) {}
        MEGDNN_HOST MEGDNN_DEVICE wtype(volatile wtype& rhs)
                : key(rhs.key), val(rhs.val) {}
        MEGDNN_HOST MEGDNN_DEVICE wtype(const wtype& rhs)
                : key(rhs.key), val(rhs.val) {}
        MEGDNN_HOST MEGDNN_DEVICE wtype(const volatile wtype& rhs)
                : key(rhs.key), val(rhs.val) {}

        //! Member-wise assignment that is callable on a volatile target.
        MEGDNN_HOST MEGDNN_DEVICE volatile wtype& operator=(
                const wtype& rhs) volatile {
            this->key = rhs.key;
            this->val = rhs.val;
            return *this;
        }
    };

    //! Stores the pointers and extents and seeds INIT with the value that
    //! DTypeTrait reports as min() when searching for a maximum, or max()
    //! when searching for a minimum, paired with index 0.
    MEGDNN_HOST MEGDNN_DEVICE
    ArgmxxOp(stype_* src, dt_int32* dst, uint32_t A, uint32_t B, uint32_t C)
            : src(src),
              dst(dst),
              A(A),
              B(B),
              C(C),
              INIT(wtype(
                      is_max ? DTypeTrait<stype_>::min()
                             : DTypeTrait<stype_>::max(),
                      0)) {}

    //! Loads src[idx] and pairs it with (idx / C) % B, i.e. the middle
    //! coordinate when idx is a row-major offset into an (A, B, C) array.
    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
        wtype res;
        res.key = src[idx];
        res.val = idx / C % B;
        return res;
    }

    //! Stores only the index part of \p val into dst[idx]; the key is dropped.
    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) {
        dst[idx] = val.val;
    }

    //! Combines two candidates.
    //!
    //! A NaN key in lhs makes lhs the result immediately. A NaN key in rhs
    //! (with a non-NaN lhs) also ends up as the result, because both ordered
    //! comparisons below are false for NaN and the fall-through returns rhs.
    //! For non-NaN keys, lhs is kept only when strictly better (greater for
    //! max, less for min); equal keys yield rhs.
    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
        // Only the predicate differs between the device and host builds, so
        // keep the preprocessor split confined to computing it.
#if defined(__CUDA_ARCH__)
        const bool lhs_is_nan = isnan(lhs.key);
#else
        const bool lhs_is_nan = std::isnan(lhs.key);
#endif
        if (lhs_is_nan)
            return lhs;

        if (is_max) {
            if (lhs.key > rhs.key)
                return lhs;
            else
                return rhs;
        } else {
            if (lhs.key < rhs.key)
                return lhs;
            else
                return rhs;
        }
    }

    stype_* src;      //!< elements read by read()
    dt_int32* dst;    //!< indices written by write()
    uint32_t A, B, C; //!< extents; read() uses B and C, A is only stored here
    const wtype INIT; //!< seed value built in the constructor
};
} }