megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/rocm/reduce_helper.hipinl
 *
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

 */

#pragma once

#if MEGDNN_CC_CUDA

#include "./reduce_helper/column.hipinl"
#include "./reduce_helper/largeBC.hipinl"

namespace megdnn {
namespace rocm {

namespace reduce_intl {
/*!
 * \brief Predicate selecting the column reduce path for sizes (A, B, C).
 *
 * Returns true only when C is exactly 1 and B is either at most four times A
 * or at most 32; every other combination returns false.
 */
static inline bool use_reduce_column(size_t A, size_t B, size_t C) {
    // The column path is never chosen unless C is exactly 1.
    if (C != 1) {
        return false;
    }
    const bool b_within_4a = B <= A * 4;
    const bool b_within_32 = B <= 32;
    return b_within_4a || b_within_32;
}
}  // namespace reduce_intl

/*!
 * \brief Run a reduction described by \p opr on \p stream, choosing between
 *      the two implementations in reduce_intl.
 *
 * reduce_intl::use_reduce_column(A, B, C) decides the path:
 *  - when it is true, reduce_intl::run_column is invoked with (A, B, stream,
 *    opr); \p workspace and \p C are not forwarded on this path;
 *  - otherwise reduce_intl::run_largeBC is invoked with all arguments.
 *
 * \tparam PublicOperator operator type; must provide a nested \c wtype
 * \tparam sync_within_warp forwarded as a template argument to run_largeBC
 * \param workspace scratch buffer handed to run_largeBC; presumably sized by
 *      get_reduce_workspace_in_bytes() -- confirm against callers
 */
template <class PublicOperator, bool sync_within_warp>
void run_reduce(typename PublicOperator::wtype* workspace, size_t A, size_t B,
                size_t C, hipStream_t stream, const PublicOperator& opr) {
    if (!reduce_intl::use_reduce_column(A, B, C)) {
        // General path: forwards the workspace and all three sizes.
        reduce_intl::run_largeBC<PublicOperator, sync_within_warp>(
                workspace, A, B, C, stream, opr);
        return;
    }
    // Column path: only A and B are forwarded, no workspace.
    reduce_intl::run_column<PublicOperator>::run(A, B, stream, opr);
}

/*!
 * \brief Workspace size to be used for a reduction with sizes (A, B, C).
 *
 * Returns 0 when use_reduce_column(A, B, C) holds; otherwise returns whatever
 * get_workspace_largeBC<typename Op::wtype>(A, B, C) reports.
 *
 * \tparam Op operator type; must provide a nested \c wtype
 */
template <typename Op>
size_t get_reduce_workspace_in_bytes(size_t A, size_t B, size_t C) {
    using namespace reduce_intl;
    if (!use_reduce_column(A, B, C)) {
        return get_workspace_largeBC<typename Op::wtype>(A, B, C);
    }
    // Column path: no workspace requested.
    return 0;
}

}  // namespace rocm
}  // namespace megdnn

#endif

// vim: ft=cpp syntax=cpp.doxygen