megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/cuda/cond_take/kern.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include <cuda_runtime_api.h>
#include "megdnn/dtype.h"
#include "src/common/cond_take/predicate.cuh"

namespace megdnn {
namespace cuda {
namespace cond_take {

typedef dt_int32 IdxType;

/*!
 * \brief generate indices to take according to mask
 * \param dest_idx output index, must be size+1 long
 * \param size number of elements in mask
 * \return output size; i.e. number of elements taken
 */
template <typename T>
size_t gen_idx(
        void* workspace, size_t workspace_size, IdxType* dest_idx, const T* mask,
        size_t size, uint32_t mode, const megdnn::cond_take::KParam& kparam,
        cudaStream_t stream);

//! get workspace size in bytes for gen_idx()
size_t gen_idx_get_workspace_size(size_t size);

/*!
 * \brief copy to final output
 * \param[out] dest_data data output, size is returned by gen_idx()
 * \param[out] dest_idx index output, size is returned by gen_idx()
 * \param src_data data input
 * \param src_idx index input, must have been filled by gen_idx()
 * \param size size of original mask
 */
template <typename T>
void copy_output(
        T* dest_data, IdxType* dest_idx, const T* src_data, IdxType* src_idx,
        uint32_t size, cudaStream_t stream);

}  // namespace cond_take
}  // namespace cuda
}  // namespace megdnn

// vim: ft=cpp syntax=cpp.doxygen