/**
* \file dnn/src/cuda/reduce_helper.cuinl
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#if MEGDNN_CC_CUDA
#include "./reduce_helper/largeBC.cuinl"
#include "./reduce_helper/column.cuinl"
namespace megdnn {
namespace cuda {
namespace reduce_intl {
/**
 * \brief Predicate used by this header to choose between the two reduce
 *      back-ends (run_column vs. run_largeBC).
 *
 * NOTE(review): (A, B, C) presumably describe the input viewed as a 3-d
 * shape with B as the reduced axis -- confirm against the callers.
 *
 * \return true iff C equals 1 and B is either at most 32 or at most 4 * A.
 */
static inline bool use_reduce_column(size_t A, size_t B, size_t C) {
    // Any C other than 1 rules the column path out.
    if (C != 1) {
        return false;
    }
    // B qualifies when it is small in absolute terms or small relative to A.
    const bool b_is_tiny = B <= 32;
    const bool b_within_4a = B <= A * 4;
    return b_is_tiny || b_within_4a;
}
} // namespace reduce_intl
/**
 * \brief Entry point that dispatches a reduction to one of two back-ends.
 *
 * When reduce_intl::use_reduce_column(A, B, C) holds, the work is handed to
 * run_column<PublicOperator>::run, which receives only (A, B, stream, opr);
 * \p workspace and \p C are not forwarded on that path. Otherwise everything
 * is forwarded to run_largeBC.
 *
 * \tparam PublicOperator operator type; must expose a nested \c wtype.
 * \tparam sync_within_warp forwarded unchanged as a template argument to
 *      run_largeBC; its meaning is defined there.
 * \param workspace scratch buffer, only passed on to run_largeBC.
 * \param A,B,C size parameters forwarded to the selected back-end.
 * \param stream CUDA stream forwarded to the selected back-end.
 * \param opr operator instance forwarded to the selected back-end.
 */
template <class PublicOperator, bool sync_within_warp>
void run_reduce(typename PublicOperator::wtype *workspace,
        size_t A, size_t B, size_t C,
        cudaStream_t stream, const PublicOperator &opr)
{
    using namespace reduce_intl;

    const bool take_column_path = use_reduce_column(A, B, C);
    if (!take_column_path) {
        // General path: the only one that consumes the workspace.
        run_largeBC<PublicOperator, sync_within_warp>(
                workspace, A, B, C, stream, opr);
        return;
    }

    // Column path: workspace and C are intentionally not passed.
    run_column<PublicOperator>::run(A, B, stream, opr);
}
/**
 * \brief Workspace size query matching the dispatch done by run_reduce.
 *
 * \tparam Op operator type; its nested \c wtype is passed to
 *      get_workspace_largeBC.
 * \param A,B,C size parameters, the same triple later given to run_reduce.
 * \return 0 when reduce_intl::use_reduce_column(A, B, C) holds (that path
 *      takes no workspace); otherwise whatever
 *      get_workspace_largeBC<typename Op::wtype>(A, B, C) reports.
 */
template <typename Op>
size_t get_reduce_workspace_in_bytes(size_t A, size_t B, size_t C)
{
    using namespace reduce_intl;

    size_t required = 0;
    if (!use_reduce_column(A, B, C)) {
        required = get_workspace_largeBC<typename Op::wtype>(A, B, C);
    }
    return required;
}
} // namespace cuda
} // namespace megdnn
#endif
// vim: ft=cpp syntax=cpp.doxygen