/**
* \file dnn/src/rocm/indexing_multi_axis_vec/kern.h.hip
*
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/arch.h"
#include "src/rocm/int_fastdiv.h.hip"
#include "src/rocm/error_info.h.hip"
namespace megdnn {
namespace rocm {
namespace indexing_multi_axis_vec {
//! AxisIndexer equiv in kernel
template <int idx_ndim>
struct KAxisIndexer {
int stride[idx_ndim];
#ifdef WIN32
Uint32Fastdiv shape[idx_ndim];
#else
Uint32Fastdiv shape[idx_ndim - 1];
#endif
const int *ptr;
};
//! param for gen_offset_base
template<int nidx, int idx_ndim>
struct GenOffsetBaseParam {
uint32_t size; //!< number of outputs; also size of each index
int *output; //!< output ptr
KAxisIndexer<idx_ndim> indexer[nidx];
uint32_t data_shape[nidx];
int data_stride[nidx];
void* error_tracker;
megcore::AsyncErrorInfo* error_info;
};
//! tensor layout for fast offset computing
template<int ndim>
struct FastLayout {
int stride[ndim];
#ifdef WIN32
Uint32Fastdiv shape[ndim];
#else
Uint32Fastdiv shape[ndim - 1];
#endif
};
//! param for apply_opr
template<typename ctype, int ndim>
struct ApplyOprParam {
uint32_t tot_size; //!< total output size
//! offset array generated by gen_offset_base for first output axis
const int *offset_base;
ctype *data, *value;
int idx_axis;
int idx_axis_end;
int idx_nelems;
int value_stride;
//! iterate on value, with strides from corresponding axes on data
FastLayout<ndim> value_ly_on_data;
};
//! generate offset bases for first axis in the output
template<int nidx, int idx_ndim>
void gen_offset_base(const GenOffsetBaseParam<nidx, idx_ndim> ¶m,
hipStream_t stream);
struct OprAtomicIncr {
#if MEGDNN_CC_CUDA
template<typename ctype>
__device__ static void apply(ctype &data, ctype value) {
atomicAdd(&data, value);
}
#endif
};
/*!
* \brief forward kernel: copy data to value
* \tparam ndim numer of axes except axis_0 in data,
* range from 0 to max_ndim - 1
*/
template<typename ctype, int ndim, class Opr>
void apply_opr(const ApplyOprParam<ctype, ndim> ¶m,
hipStream_t stream);
} // namespace indexing_multi_axis_vec
} // namespace rocm
} // namespace megdnn
// vim: ft=cpp syntax=cpp.doxygen