#warning "libCEED OCCA backend is experimental; for best performance, use device native backends"
#include <map>
#include <memory>
#include <occa.hpp>
#include <string>
#include <vector>
#include "ceed-occa-context.hpp"
#include "ceed-occa-elem-restriction.hpp"
#include "ceed-occa-operator.hpp"
#include "ceed-occa-qfunction.hpp"
#include "ceed-occa-qfunctioncontext.hpp"
#include "ceed-occa-simplex-basis.hpp"
#include "ceed-occa-tensor-basis.hpp"
#include "ceed-occa-types.hpp"
#include "ceed-occa-vector.hpp"
namespace ceed {
namespace occa {
// Key/value pairs parsed from the ":key=value,..." query part of a resource string.
using StringMap = std::map<std::string, std::string>;
// '/'-separated components of a resource string.
using StringVector = std::vector<std::string>;

// States of the character-by-character resource parser in splitCeedResource.
enum ResourceParserStep { RESOURCE, QUERY_KEY, QUERY_VALUE };

// Separators recognised inside a resource string such as "/*/occa:device_id=0,platform_id=1".
static constexpr char RESOURCE_DELIMITER = '/';
static constexpr char QUERY_DELIMITER = ':';
static constexpr char QUERY_KEY_VALUE_DELIMITER = '=';
static constexpr char QUERY_ARG_DELIMITER = ',';
// Choose an OCCA mode when the resource only names a device class ("cpu", "gpu" or "*").
// When GPU modes are allowed they are probed in a fixed preference order and the first one
// OCCA reports as enabled wins. The CPU fallback prefers OpenMP and otherwise settles on Serial.
// Returns an empty string when neither class is permitted or no permitted GPU mode is enabled
// (and CPU modes are not permitted).
static std::string getDefaultDeviceMode(const bool cpuMode, const bool gpuMode) {
  static const char *const gpuPreference[] = {"CUDA", "HIP", "dpcpp", "OpenCL"};

  if (gpuMode) {
    for (const char *candidate : gpuPreference) {
      if (::occa::modeIsEnabled(candidate)) {
        return candidate;
      }
    }
  }

  if (!cpuMode) {
    return "";
  }
  return ::occa::modeIsEnabled("OpenMP") ? "OpenMP" : "Serial";
}
// Translate a resource "match" into an OCCA mode string, written to `mode`.
// Explicit backend names ("cuda", "serial", ...) map directly. "cpu", "gpu" and "*" fall back to
// getDefaultDeviceMode. Returns CEED_ERROR_SUCCESS for an explicit name; otherwise returns 1 when
// no mode could be chosen (mode left empty) and 0 when one was.
static int getDeviceMode(const std::string &match, std::string &mode) {
  static const std::map<std::string, std::string> explicitModes = {
      {"cuda", "CUDA"},     {"hip", "HIP"},       {"dpcpp", "dpcpp"},
      {"opencl", "OpenCL"}, {"openmp", "OpenMP"}, {"serial", "Serial"},
  };

  const std::map<std::string, std::string>::const_iterator found = explicitModes.find(match);
  if (found != explicitModes.end()) {
    mode = found->second;
    return CEED_ERROR_SUCCESS;
  }

  // "*" allows both device classes; "cpu"/"gpu" restrict to one; anything else allows none.
  const bool anyClass = (match == "*");
  mode = getDefaultDeviceMode(anyClass || match == "cpu", anyClass || match == "gpu");
  return mode.empty() ? 1 : 0;
}
static int splitCeedResource(const std::string &resource, std::string &match, StringMap &query) {
const int charCount = (int)resource.size();
const char *c_resource = resource.c_str();
StringVector resourceVector;
ResourceParserStep parsingStep = RESOURCE;
int wordStart = 1;
std::string queryKey;
if (resource == "/gpu/cuda/occa") {
match = "cuda";
return CEED_ERROR_SUCCESS;
}
if (resource == "/gpu/hip/occa") {
match = "hip";
return CEED_ERROR_SUCCESS;
}
if (resource == "/gpu/dpcpp/occa") {
match = "dpcpp";
return CEED_ERROR_SUCCESS;
}
if (resource == "/gpu/opencl/occa") {
match = "opencl";
return CEED_ERROR_SUCCESS;
}
if (resource == "/cpu/openmp/occa") {
match = "openmp";
return CEED_ERROR_SUCCESS;
}
if (resource == "/cpu/self/occa") {
match = "serial";
return CEED_ERROR_SUCCESS;
}
for (int i = 1; i <= charCount; ++i) {
const char c = c_resource[i];
if (parsingStep == RESOURCE) {
if (c == RESOURCE_DELIMITER || c == QUERY_DELIMITER || c == '\0') {
resourceVector.push_back(resource.substr(wordStart, i - wordStart));
wordStart = i + 1;
if (c == QUERY_DELIMITER) {
parsingStep = QUERY_KEY;
}
}
} else if (parsingStep == QUERY_KEY) {
if (c == QUERY_KEY_VALUE_DELIMITER) {
queryKey = resource.substr(wordStart, i - wordStart);
wordStart = i + 1;
parsingStep = QUERY_VALUE;
}
} else if (parsingStep == QUERY_VALUE) {
if (c == QUERY_ARG_DELIMITER || c == '\0') {
query[queryKey] = resource.substr(wordStart, i - wordStart);
wordStart = i + 1;
parsingStep = QUERY_KEY;
queryKey = "";
}
}
}
if (resourceVector.size() != 2 || resourceVector[1] != "occa") {
return 1;
}
match = resourceVector[0];
return CEED_ERROR_SUCCESS;
}
// Fill in device properties the caller did not provide.
// An explicit "mode" is kept as-is; otherwise `defaultMode` is stored under "mode".
// GPU-style modes (CUDA, HIP, dpcpp, OpenCL) get "device_id" = 0 if absent, and
// dpcpp/OpenCL additionally get "platform_id" = 0 if absent.
void setDefaultProps(::occa::properties &deviceProps, const std::string &defaultMode) {
  const bool hasExplicitMode = deviceProps.has("mode");
  const std::string mode = hasExplicitMode ? static_cast<std::string>(deviceProps["mode"]) : defaultMode;
  if (!hasExplicitMode) {
    deviceProps.set("mode", mode);
  }

  const bool needsDeviceId = mode == "CUDA" || mode == "HIP" || mode == "dpcpp" || mode == "OpenCL";
  const bool needsPlatformId = mode == "dpcpp" || mode == "OpenCL";

  if (needsDeviceId && !deviceProps.has("device_id")) {
    deviceProps["device_id"] = 0;
  }
  if (needsPlatformId && !deviceProps.has("platform_id")) {
    deviceProps["platform_id"] = 0;
  }
}
static int initCeed(const char *c_resource, Ceed ceed) {
int ierr;
std::string match;
StringMap query;
ierr = splitCeedResource(c_resource, match, query);
if (ierr) {
return CeedError(ceed, CEED_ERROR_BACKEND, "(OCCA) Backend cannot use resource: %s", c_resource);
}
std::string mode;
ierr = getDeviceMode(match, mode);
if (ierr) {
return CeedError(ceed, CEED_ERROR_BACKEND, "(OCCA) Backend cannot use resource: %s", c_resource);
}
std::string devicePropsStr = "{\n";
StringMap::const_iterator it;
for (it = query.begin(); it != query.end(); ++it) {
devicePropsStr += " \"";
devicePropsStr += it->first;
devicePropsStr += "\": ";
devicePropsStr += it->second;
devicePropsStr += ",\n";
}
devicePropsStr += '}';
::occa::properties deviceProps(devicePropsStr);
setDefaultProps(deviceProps, mode);
ceed::occa::Context *context = new Context(::occa::device(deviceProps));
CeedCallBackend(CeedSetData(ceed, context));
return CEED_ERROR_SUCCESS;
}
// Backend "Destroy" hook: deletes the Context obtained from `ceed` via Context::from
// (presumably the one attached in initCeed — confirm against Context::from).
static int destroyCeed(Ceed ceed) {
  Context *context = Context::from(ceed);
  delete context;
  return CEED_ERROR_SUCCESS;
}
// Register `f` as the "Ceed"-object implementation of backend method `fname` on `ceed`.
// Returns the status reported by CeedSetBackendFunction.
static int registerCeedFunction(Ceed ceed, const char *fname, ceed::occa::ceedFunction f) {
  const int status = CeedSetBackendFunction(ceed, "Ceed", ceed, fname, f);
  return status;
}
// "GetPreferredMemType" implementation for devices that share host memory.
static int preferHostMemType(CeedMemType *type) {
  CeedMemType &preferred = *type;
  preferred = CEED_MEM_HOST;
  return CEED_ERROR_SUCCESS;
}

// "GetPreferredMemType" implementation for devices with their own memory space.
static int preferDeviceMemType(CeedMemType *type) {
  CeedMemType &preferred = *type;
  preferred = CEED_MEM_DEVICE;
  return CEED_ERROR_SUCCESS;
}

// Pick the memory-type callback matching the device held in `ceed`'s Context:
// device memory when OCCA reports a separate memory space, host memory otherwise.
static ceed::occa::ceedFunction getPreferredMemType(Ceed ceed) {
  const bool separateMemory = Context::from(ceed)->device.hasSeparateMemorySpace();
  return separateMemory ? reinterpret_cast<ceed::occa::ceedFunction>(reinterpret_cast<void *>(preferDeviceMemType))
                        : reinterpret_cast<ceed::occa::ceedFunction>(reinterpret_cast<void *>(preferHostMemType));
}
// Register the OCCA implementations of the Ceed-level backend methods on `ceed`.
// NOTE(review): CeedOccaRegisterBaseFunction is a macro defined elsewhere; it is called without
// naming `ceed`, so it presumably uses this function's `ceed` parameter (and may return early on
// failure) — confirm its expansion before reordering or restructuring these calls.
static int registerMethods(Ceed ceed) {
CeedOccaRegisterBaseFunction("Destroy", ceed::occa::destroyCeed);
// The memory-type callback is chosen now from the device in ceed's Context, so this must run
// after initCeed has attached that Context (registerBackend calls initCeed first).
CeedOccaRegisterBaseFunction("GetPreferredMemType", getPreferredMemType(ceed));
CeedOccaRegisterBaseFunction("VectorCreate", ceed::occa::Vector::ceedCreate);
CeedOccaRegisterBaseFunction("BasisCreateTensorH1", ceed::occa::TensorBasis::ceedCreate);
CeedOccaRegisterBaseFunction("BasisCreateH1", ceed::occa::SimplexBasis::ceedCreate);
CeedOccaRegisterBaseFunction("ElemRestrictionCreate", ceed::occa::ElemRestriction::ceedCreate);
CeedOccaRegisterBaseFunction("QFunctionCreate", ceed::occa::QFunction::ceedCreate);
CeedOccaRegisterBaseFunction("QFunctionContextCreate", ceed::occa::QFunctionContext::ceedCreate);
CeedOccaRegisterBaseFunction("OperatorCreate", ceed::occa::Operator::ceedCreate);
CeedOccaRegisterBaseFunction("CompositeOperatorCreate", ceed::occa::Operator::ceedCreateComposite);
return CEED_ERROR_SUCCESS;
}
// Backend init callback handed to CeedRegister for every OCCA resource prefix.
// Builds the OCCA Context for `resource` (initCeed), then registers the backend methods
// (registerMethods). CEED error codes from either stage are propagated by CeedCallBackend.
// OCCA exceptions from either stage are passed to CeedHandleOccaException, a macro defined
// elsewhere that presumably converts the exception into a CEED error return — confirm.
static int registerBackend(const char *resource, Ceed ceed) {
try {
CeedCallBackend(ceed::occa::initCeed(resource, ceed));
} catch (const ::occa::exception &e) {
CeedHandleOccaException(e);
}
try {
CeedCallBackend(ceed::occa::registerMethods(ceed));
} catch (const ::occa::exception &e) {
CeedHandleOccaException(e);
}
return CEED_ERROR_SUCCESS;
}
} }
// Register every resource prefix served by the OCCA backend with libCEED.
// All prefixes share ceed::occa::registerBackend as their init callback. Each priority value is
// passed to CeedRegister unchanged, in the listed order, stopping at the first failure.
CEED_INTERN int CeedRegister_Occa(void) {
  struct OccaResource {
    const char *prefix;
    unsigned int priority;
  };
  static const OccaResource resources[] = {
      {"/*/occa", 270},          {"/cpu/self/occa", 260}, {"/cpu/openmp/occa", 250}, {"/gpu/dpcpp/occa", 240},
      {"/gpu/opencl/occa", 230}, {"/gpu/hip/occa", 220},  {"/gpu/cuda/occa", 210},
  };

  for (const OccaResource &entry : resources) {
    CeedCallBackend(CeedRegister(entry.prefix, ceed::occa::registerBackend, entry.priority));
  }
  return CEED_ERROR_SUCCESS;
}