/**
* \file dnn/src/cuda/relayout/param_visitor.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/basic_types.h"
#include "src/cuda/int_fastdiv.cuh"
#include "src/cuda/integer_subbyte_utils.cuh"
#include "src/cuda/utils.cuh"
#pragma once
namespace megdnn {
namespace cuda {
#define devfunc __device__ __forceinline__
/*!
 * \brief contiguity class of a tensor layout
 *
 * CONTIG_FULL is used when the layout is contiguous; any other layout is
 * classified as CONTIG_OTHER.
 */
enum ContigType {
    CONTIG_OTHER = 0,  //!< layout is not contiguous
    CONTIG_FULL = 1    //!< layout is contiguous
};
/* f{{{ ParamElemVisitor specialization */
/*!
 * \brief visitor to access an element in a tensor at given logic index
 * \tparam ndim number of tensor dimensions handled by the visitor
 * \tparam ctype plain element ctype (i.e. ctype in DTypeTrait)
 * \tparam contig_type contiguity class of the visited tensor layout
 *
 * host interface:
 *      void host_init(
 *              const TensorND &tensor, int grid_size, int block_size)
 *
 * device interface:
 *      void thread_init(uint32_t idx)
 *          called on thread entrance, with logical indexing; the index may
 *          go beyond buffer range
 *
 *      ctype* ptr()
 *          return buffer pointer; can be used by specialized OpCaller
 *
 *      int offset(uint32_t idx)
 *          get physical offset from logical index
 *
 *      ctype& at(uint32_t idx)
 *          ptr()[offset(idx)]
 *
 * This is only the primary template declaration; usable definitions come
 * from its partial specializations.
 */
template <int ndim, typename ctype, ContigType contig_type>
class ParamElemVisitor;
//! Device-side accessors shared by the ParamElemVisitor specializations that
//! expand this macro: ptr() returns the m_ptr member of the enclosing class,
//! and at() indexes m_ptr with the physical offset produced by the enclosing
//! class's own offset(). The expanding class must therefore provide ctype,
//! m_ptr and offset().
#define PARAM_ELEM_VISITOR_COMMON_DEV \
devfunc ctype* ptr() { return m_ptr; } \
devfunc ctype& at(uint32_t idx) { return m_ptr[offset(idx)]; }
//! specialization for CONTIG_OTHER
/*!
 * Translates a logical (row-major) element index into a physical offset by
 * splitting it into per-dimension coordinates and weighting each coordinate
 * with the matching entry of m_stride.
 */
template <int ndim, typename ctype>
class ParamElemVisitor<ndim, ctype, CONTIG_OTHER> {
    ctype* __restrict m_ptr;
    int m_stride[ndim];

    //! m_shape_highdim[i] = original_shape[i + 1]
#ifdef _MSC_VER
    // this branch always declares at least one element, even for ndim == 1
    Uint32Fastdiv m_shape_highdim[ndim > 1 ? ndim - 1 : 1];
#else
    Uint32Fastdiv m_shape_highdim[ndim - 1];
#endif

public:
    static const int NDIM = ndim;

    //! host-side initialisation from \p rv; only declared in this class body
    void host_init(const TensorND& rv, int grid_size, int block_size);

#if MEGDNN_CC_CUDA
    //! no per-thread state is kept, so thread entrance is a no-op
    devfunc void thread_init(uint32_t /* idx */) {}

    //! no-op for this specialization
    devfunc void next() {}

    /*!
     * \brief physical offset of the element at logical index \p idx
     *
     * Walks from the last dimension towards dimension 1: each step divides
     * the remaining index by that dimension's extent, multiplies the
     * remainder by the dimension's stride, and carries the quotient on.
     * Whatever is left afterwards is the coordinate along dimension 0.
     */
    devfunc int offset(uint32_t idx) {
        int phys = 0;
#pragma unroll
        for (int dim = ndim - 1; dim > 0; --dim) {
            Uint32Fastdiv& extent = m_shape_highdim[dim - 1];
            uint32_t carried = idx / extent;
            // idx - carried * extent == idx % extent, reusing the quotient
            phys += (idx - carried * extent.divisor()) * m_stride[dim];
            idx = carried;
        }
        phys += idx * m_stride[0];
        return phys;
    }

    PARAM_ELEM_VISITOR_COMMON_DEV
#endif
};
//! specialization for CONTIG_FULL
/*!
 * The logical index is used directly as the physical offset, so the visitor
 * stores nothing besides the buffer pointer.
 */
template <int ndim, typename ctype>
class ParamElemVisitor<ndim, ctype, CONTIG_FULL> {
    ctype* __restrict m_ptr;

public:
    static const int NDIM = ndim;

    //! host-side initialisation from \p rv; only declared in this class body
    void host_init(const TensorND& rv, int grid_size, int block_size);

#if MEGDNN_CC_CUDA
    //! no per-thread state is kept, so thread entrance is a no-op
    devfunc void thread_init(uint32_t /* idx */) {}

    //! no-op for this specialization
    devfunc void next() {}

    //! identity mapping: the physical offset equals the logical index
    devfunc int offset(uint32_t idx) {
        return static_cast<int>(idx);
    }

    PARAM_ELEM_VISITOR_COMMON_DEV
#endif
};
#undef PARAM_ELEM_VISITOR_COMMON_DEV
/*!
 * \brief specialization for dt_quint4 elements with CONTIG_OTHER layout
 *
 * The buffer is addressed as uint8_t storage holding two 4-bit elements per
 * byte: at() reads byte (offset >> 1) and selects the nibble with
 * (offset & 1). Besides the logical element index, the class also works with
 * an "access index", which get_shape_from_access() decomposes using
 * m_align_shape_highdim instead of m_shape_highdim.
 * NOTE(review): all members are filled in by host_init(), which is not
 * defined in this class body; the exact meaning of the aligned extents and
 * of the three bool flags should be confirmed against that definition.
 */
template <int ndim>
class ParamElemVisitor<ndim, dt_quint4, CONTIG_OTHER> {
    using Storage = uint8_t;

    Storage* __restrict m_ptr;
    int m_stride[ndim];
    int m_shape[ndim];
    bool m_is_contiguous;
    bool m_is_physical_contiguous;
    bool m_is_min_stride_2;

    //! m_shape_highdim[i] = original_shape[i + 1]
#ifdef _MSC_VER
    Uint32Fastdiv m_shape_highdim[ndim > 1 ? ndim - 1 : 1];
    Uint32Fastdiv m_align_shape_highdim[ndim > 1 ? ndim - 1 : 1];
#else
    Uint32Fastdiv m_shape_highdim[ndim];
    Uint32Fastdiv m_align_shape_highdim[ndim];
#endif

public:
    static const Storage kMask = 0xf;
    static const Storage kBits = 4;
    static const int NDIM = ndim;

    //! host-side initialisation from \p rv; only declared in this class body
    void host_init(const TensorND& rv, int grid_size, int block_size);

#if MEGDNN_CC_CUDA
    //! no per-thread state is kept, so thread entrance is a no-op
    devfunc void thread_init(uint32_t) {}

    //! no-op for this specialization
    devfunc void next() {}

    /*!
     * \brief split an access index into per-dimension coordinates
     * \param access_idx flat index over the extents in m_align_shape_highdim
     * \param[out] shape_idx receives one coordinate per dimension
     *
     * Dimensions are peeled from the last one towards dimension 1; the
     * remaining quotient becomes the coordinate of dimension 0.
     */
    devfunc void get_shape_from_access(uint32_t access_idx, int (&shape_idx)[ndim]) {
#pragma unroll
        for (int i = ndim - 1; i >= 1; --i) {
            Uint32Fastdiv& align_shp = m_align_shape_highdim[i - 1];
            uint32_t access_idx_div = access_idx / align_shp;
            // remainder of the division, computed from the quotient
            shape_idx[i] = access_idx - access_idx_div * align_shp.divisor();
            access_idx = access_idx_div;
        }
        shape_idx[0] = access_idx;
    }

    /*!
     * \brief offset, counted in 4-bit elements, of logical index \p idx
     *
     * Same decomposition as get_shape_from_access() but over
     * m_shape_highdim, with every coordinate weighted by its stride.
     */
    devfunc int offset(uint32_t idx) {
        int offset = 0;
#pragma unroll
        for (int i = ndim - 1; i >= 1; --i) {
            Uint32Fastdiv& shp = m_shape_highdim[i - 1];
            uint32_t idx_div = idx / shp;
            offset += (idx - idx_div * shp.divisor()) * m_stride[i];
            idx = idx_div;
        }
        offset += idx * m_stride[0];
        return offset;
    }

    /*!
     * \brief stride-weighted offset for access index \p access_idx
     *
     * Returns \p access_idx unchanged when m_is_contiguous is set; otherwise
     * decomposes it with get_shape_from_access() and sums
     * coordinate * stride over all dimensions.
     */
    devfunc int offset_from_access(uint32_t access_idx) {
        int offset = 0;
        if (m_is_contiguous) {
            offset = access_idx;
        } else {
            int shape_idx[ndim];
            get_shape_from_access(access_idx, shape_idx);
#pragma unroll
            for (int i = ndim - 1; i >= 0; --i) {
                offset += shape_idx[i] * m_stride[i];
            }
        }
        return offset;
    }

    /*!
     * \brief logical element index for \p access_idx, or -1 if there is none
     *
     * - m_is_physical_contiguous: the access index is the logical index.
     * - otherwise, when m_is_min_stride_2 is not set: the access index is
     *   decomposed into coordinates; if any coordinate reaches or exceeds
     *   m_shape the result is -1, else the coordinates are re-linearised
     *   row-major over m_shape.
     * - m_is_min_stride_2: even access indices map to access_idx / 2, odd
     *   ones to -1.
     */
    devfunc int idx(uint32_t access_idx) {
        int logical_idx = 0;
        if (m_is_physical_contiguous) {
            logical_idx = access_idx;
        } else if (!m_is_min_stride_2) {
            int shape_idx[ndim];
            bool valid = true;
            get_shape_from_access(access_idx, shape_idx);
#pragma unroll
            for (int i = 0; i < ndim; ++i) {
                valid &= (shape_idx[i] < m_shape[i]);
            }
            // Horner-style accumulation of the row-major index; unrolled
            // like every other ndim-bounded loop in this class
#pragma unroll
            for (int i = 0; i < ndim - 1; ++i) {
                logical_idx = (logical_idx + shape_idx[i]) * m_shape[i + 1];
            }
            logical_idx = valid ? logical_idx + shape_idx[ndim - 1] : -1;
        } else {  // min_stride == 2
            logical_idx = ((access_idx & 0x1) == 0) ? ((int)access_idx >> 1) : -1;
        }
        return logical_idx;
    }

    //! raw storage pointer (bytes, two 4-bit elements each)
    devfunc Storage* ptr() { return m_ptr; }

    /*!
     * \brief read the 4-bit element at logical index \p idx
     * \return the unpacked value, by value (a nibble cannot be referenced)
     */
    devfunc Storage at(uint32_t idx) {
        int offset_ = offset(idx);
        int vec_idx = offset_ >> 1;    // byte that holds the element
        int lane_idx = offset_ & 0x1;  // which nibble inside that byte
        return Storage(integer_subbyte::unpack_integer_4bits<false>(
                m_ptr[vec_idx], lane_idx * 4));
    }

    using rwtype = typename elemwise_intl::VectTypeTrait<dt_quint4>::vect_type;

    //! build an rwtype from two values via VectTypeTrait<dt_quint4>
    devfunc rwtype make_vector(Storage x, Storage y) {
        return elemwise_intl::VectTypeTrait<dt_quint4>::make_vector(x, y);
    }
#endif
};
} // namespace cuda
} // namespace megdnn
// vim: ft=cpp syntax=cpp.doxygen