#include "megbrain_build_config.h"
#if MGB_CUDA
#pragma once
#include <memory>
#include "NvOF.h"
#include "cuda.h"
#include "nvOpticalFlowCommon.h"
#include "nvOpticalFlowCuda.h"
/*
 * CUDA_DRVAPI_CALL(call)
 *
 * Evaluates the CUDA driver API expression `call` exactly once. If the result
 * is not CUDA_SUCCESS, the macro looks up the symbolic error name, prints
 * "Exception: <file>:<line>:CUDA driver API error <name>" to std::cout and
 * raises MegBrainError through mgb_throw.
 *
 * The expansion uses std::ostringstream, std::cout and mgb_throw; this header
 * does not include their declarations directly, so they must already be
 * visible at the point of use (presumably via the includes above or the
 * including translation unit -- confirm).
 *
 * Only block comments are used inside the definition: a line comment ending
 * in a continuation backslash would swallow the following line.
 */
#define CUDA_DRVAPI_CALL(call)                                                \
    do {                                                                      \
        CUresult err__ = (call);                                              \
        if (err__ != CUDA_SUCCESS) {                                          \
            const char* szErrName = nullptr;                                  \
            /* cuGetErrorName fails and leaves the name null for codes it */  \
            /* does not recognize; never stream a null char pointer.      */  \
            if (cuGetErrorName(err__, &szErrName) != CUDA_SUCCESS ||          \
                szErrName == nullptr) {                                       \
                szErrName = "<unrecognized CUDA error code>";                 \
            }                                                                 \
            std::ostringstream errorLog;                                      \
            errorLog << "CUDA driver API error " << szErrName;                \
            std::cout << "Exception: " << __FILE__ << ":" << __LINE__ << ":"  \
                      << errorLog.str() << std::endl;                         \
            mgb_throw(MegBrainError, "CUDA_DRVAPI_CALL ERROR");               \
        }                                                                     \
    } while (0)
/**
 * CUDA-specific NvOFAPI subclass.
 *
 * Stores a CUDA context, an input and an output stream, an optical-flow
 * handle and an owned NV_OF_CUDA_API_FUNCTION_LIST, and exposes them through
 * accessors. The constructor, destructor and GetCudaStream() are only
 * declared here; their definitions live elsewhere.
 */
class NvOFCudaAPI : public NvOFAPI {
public:
    /**
     * \param cuContext    CUDA context (returned by GetCudaContext();
     *                     presumably stored by the out-of-line constructor)
     * \param inputStream  optional stream, defaults to nullptr
     * \param outputStream optional stream, defaults to nullptr
     */
    NvOFCudaAPI(
            CUcontext cuContext, CUstream inputStream = nullptr,
            CUstream outputStream = nullptr);

    ~NvOFCudaAPI();

    /**
     * Returns the raw pointer to the owned function list.
     *
     * m_lock is not declared in this class (presumably inherited from
     * NvOFAPI -- confirm). It is held only while the pointer is read; the
     * returned pointer is used by the caller without the lock.
     */
    NV_OF_CUDA_API_FUNCTION_LIST* GetAPI() {
        std::lock_guard<std::mutex> guard(m_lock);
        NV_OF_CUDA_API_FUNCTION_LIST* const functionList = m_ofAPI.get();
        return functionList;
    }

    //! CUDA context held by this object.
    CUcontext GetCudaContext() {
        return m_cuContext;
    }

    //! Optical-flow handle held by this object.
    NvOFHandle GetHandle() {
        return m_hOF;
    }

    //! Stream lookup by buffer usage; defined out of line (presumably picks
    //! between the input and output stream -- confirm against the .cpp).
    CUstream GetCudaStream(NV_OF_BUFFER_USAGE usage);

private:
    // Declaration order is kept as-is: it fixes member initialization order.
    CUstream m_inputStream;
    CUstream m_outputStream;
    NvOFHandle m_hOF;
    std::unique_ptr<NV_OF_CUDA_API_FUNCTION_LIST> m_ofAPI;
    CUcontext m_cuContext;
};
class NvOFCuda : public NvOF {
public:
static NvOFObj Create(
CUcontext cuContext, uint32_t nWidth, uint32_t nHeight,
NV_OF_BUFFER_FORMAT eInBufFmt, NV_OF_CUDA_BUFFER_TYPE eInBufType,
NV_OF_CUDA_BUFFER_TYPE eOutBufType, NV_OF_MODE eMode,
NV_OF_PERF_LEVEL preset, CUstream inputStream = nullptr,
CUstream outputStream = nullptr);
~NvOFCuda(){};
private:
NvOFCuda(
CUcontext cuContext, uint32_t nWidth, uint32_t nHeight,
NV_OF_BUFFER_FORMAT eInBufFmt, NV_OF_CUDA_BUFFER_TYPE eInBufType,
NV_OF_CUDA_BUFFER_TYPE eOutBufType, NV_OF_MODE eMode,
NV_OF_PERF_LEVEL preset, CUstream inputStream = nullptr,
CUstream outputStream = nullptr);
virtual void DoGetOutputGridSizes(uint32_t* vals, uint32_t* size) override;
virtual void DoInit(const NV_OF_INIT_PARAMS& initParams) override;
virtual void DoExecute(
const NV_OF_EXECUTE_INPUT_PARAMS& executeInParams,
NV_OF_EXECUTE_OUTPUT_PARAMS& executeOutParams) override;
virtual std::vector<NvOFBufferObj> DoAllocBuffers(
NV_OF_BUFFER_DESCRIPTOR ofBufferDesc, uint32_t elementSize,
uint32_t numBuffers) override;
std::unique_ptr<NvOFBuffer> CreateOFBufferObject(
const NV_OF_BUFFER_DESCRIPTOR& desc, uint32_t elementSize,
NV_OF_CUDA_BUFFER_TYPE bufferType);
NV_OF_CUDA_BUFFER_TYPE GetBufferType(NV_OF_BUFFER_USAGE usage);
private:
CUcontext m_cuContext;
std::shared_ptr<NvOFCudaAPI> m_NvOFAPI;
NV_OF_CUDA_BUFFER_TYPE m_eInBufType;
NV_OF_CUDA_BUFFER_TYPE m_eOutBufType;
uint32_t _QuerySupportCaps(const NV_OF_CAPS& cap);
};
class NvOFBufferCudaDevicePtr : public NvOFBuffer {
public:
~NvOFBufferCudaDevicePtr();
CUdeviceptr getCudaDevicePtr() { return m_devPtr; }
virtual void UploadData(const void* pData, CUmemorytype mem_type) override;
virtual void DownloadData(void* pData, CUmemorytype mem_type) override;
NV_OF_CUDA_BUFFER_STRIDE_INFO getStrideInfo() { return m_strideInfo; }
private:
NvOFBufferCudaDevicePtr(
std::shared_ptr<NvOFCudaAPI> ofAPI, const NV_OF_BUFFER_DESCRIPTOR& desc,
uint32_t elementSize);
CUdeviceptr m_devPtr;
CUcontext m_cuContext;
NV_OF_CUDA_BUFFER_STRIDE_INFO m_strideInfo;
std::shared_ptr<NvOFCudaAPI> m_NvOFAPI;
friend class NvOFCuda;
};
class NvOFBufferCudaArray : public NvOFBuffer {
public:
~NvOFBufferCudaArray();
virtual void UploadData(const void* pData, CUmemorytype mem_type) override;
virtual void DownloadData(void* pData, CUmemorytype mem_type) override;
CUarray getCudaArray() { return m_cuArray; }
private:
NvOFBufferCudaArray(
std::shared_ptr<NvOFCudaAPI> ofAPI, const NV_OF_BUFFER_DESCRIPTOR& desc,
uint32_t elementSize);
CUarray m_cuArray;
CUcontext m_cuContext;
NV_OF_CUDA_BUFFER_STRIDE_INFO m_strideInfo;
std::shared_ptr<NvOFCudaAPI> m_NvOFAPI;
friend class NvOFCuda;
};
#endif