#pragma once
#include <string>
#ifdef CT2_USE_HIP
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <thrust/execution_policy.h>
#include <hipcub/hipcub.hpp>
#ifdef CT2_WITH_TENSOR_PARALLEL
#include <cuda/mpi_stub.h>
#include <rccl/rccl.h>
#endif
#define cub hipcub
#define cudaError_t hipError_t
#define cudaSuccess hipSuccess
#define cudaGetErrorString hipGetErrorString
#define cublasStatus_t hipblasStatus_t
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define cudaStream_t hipStream_t
#define cublasHandle_t hipblasHandle_t
#define cudaDeviceProp hipDeviceProp_t
#define cudaGetDevice hipGetDevice
#define cudaSetDevice hipSetDevice
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaStreamSynchronize hipStreamSynchronize
#else
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <thrust/execution_policy.h>
#ifdef CT2_WITH_TENSOR_PARALLEL
# include <cuda/mpi_stub.h>
# include <nccl.h>
#endif
#ifdef CT2_WITH_CUDNN
# include <cudnn.h>
#endif
#endif
#include "ctranslate2/types.h"
#include "ctranslate2/utils.h"
namespace ctranslate2 {
namespace cuda {
#ifdef CT2_WITH_TENSOR_PARALLEL
#define MPI_CHECK(ans) \
{ \
int e = ans; \
if( e != MPI_SUCCESS ) \
THROW_RUNTIME_ERROR("MPI failed with error " \
+ std::to_string(e)); \
}
#define NCCL_CHECK(ans) \
{ \
ncclResult_t r = ans; \
if( r != ncclSuccess ) \
THROW_RUNTIME_ERROR("NCCL failed with error " \
+ std::to_string(r)); \
}
#endif
#define CUDA_CHECK(ans) \
{ \
cudaError_t code = (ans); \
if (code != cudaSuccess) \
THROW_RUNTIME_ERROR("CUDA failed with error " \
+ std::string(cudaGetErrorString(code))); \
}
#define CUBLAS_CHECK(ans) \
{ \
cublasStatus_t status = (ans); \
if (status != CUBLAS_STATUS_SUCCESS) \
THROW_RUNTIME_ERROR("cuBLAS failed with status " \
+ std::string(ctranslate2::cuda::cublasGetStatusName(status))); \
}
#define CUDNN_CHECK(ans) \
{ \
cudnnStatus_t status = (ans); \
if (status != CUDNN_STATUS_SUCCESS) \
THROW_RUNTIME_ERROR("cuDNN failed with status " \
+ std::string(cudnnGetErrorString(status))); \
}
#define TENSOR_CHECK(ans, message) \
{ \
if (!ans) \
THROW_RUNTIME_ERROR(message); \
}
const char* cublasGetStatusName(cublasStatus_t status);
cudaStream_t get_cuda_stream();
cublasHandle_t get_cublas_handle();
#ifdef CT2_WITH_CUDNN
cudnnHandle_t get_cudnn_handle();
cudnnDataType_t get_cudnn_data_type(DataType dtype);
#endif
int get_gpu_count();
bool has_gpu();
const cudaDeviceProp& get_device_properties(int device = -1);
bool gpu_supports_int8(int device = -1);
bool gpu_has_int8_tensor_cores(int device = -1);
bool gpu_has_fp16_tensor_cores(int device = -1);
bool have_same_compute_capability(const std::vector<int>& devices);
bool use_true_fp16_gemm();
void use_true_fp16_gemm(bool use);
class UseTrueFp16GemmInScope {
public:
UseTrueFp16GemmInScope(const bool use)
: _previous_value(use_true_fp16_gemm())
, _scope_value(use)
{
use_true_fp16_gemm(_scope_value);
}
~UseTrueFp16GemmInScope() {
use_true_fp16_gemm(_previous_value);
}
private:
const bool _previous_value;
const bool _scope_value;
};
#ifdef CT2_USE_HIP
#define THRUST_CALL(FUN, ...) FUN(thrust::hip::par_nosync.on(ctranslate2::cuda::get_cuda_stream()), __VA_ARGS__)
#else
#define THRUST_CALL(FUN, ...) FUN(thrust::cuda::par_nosync.on(ctranslate2::cuda::get_cuda_stream()), __VA_ARGS__)
#endif
}
}