#include "ceed-occa-simplex-basis.hpp"
#include "ceed-occa-kernels.hpp"
namespace ceed {
namespace occa {
SimplexBasis::SimplexBasis(CeedBasis basis, CeedInt dim_, CeedInt P_, CeedInt Q_, const CeedScalar *interp_, const CeedScalar *grad_,
const CeedScalar *qWeight_) {
setCeedFields(basis);
dim = dim_;
P = P_;
Q = Q_;
::occa::device device = getDevice();
interp = device.malloc<CeedScalar>(P * Q, interp_);
grad = device.malloc<CeedScalar>(P * Q * dim, grad_);
qWeight = device.malloc<CeedScalar>(Q, qWeight_);
setKernelProperties();
}
SimplexBasis::~SimplexBasis() {}
bool SimplexBasis::isTensorBasis() const { return false; }
const char *SimplexBasis::getFunctionSource() const {
return occa_simplex_basis_cpu_function_source;
}
void SimplexBasis::setKernelProperties() {
kernelProperties["defines/CeedInt"] = ::occa::dtype::get<CeedInt>().name();
kernelProperties["defines/CeedScalar"] = ::occa::dtype::get<CeedScalar>().name();
kernelProperties["defines/DIM"] = dim;
kernelProperties["defines/Q"] = Q;
kernelProperties["defines/P"] = P;
kernelProperties["defines/MAX_PQ"] = P > Q ? P : Q;
kernelProperties["defines/BASIS_COMPONENT_COUNT"] = ceedComponentCount;
if (usingGpuDevice()) {
kernelProperties["defines/ELEMENTS_PER_BLOCK"] = (Q <= 1024) ? (1024 / Q) : 1;
}
}
::occa::kernel SimplexBasis::buildKernel(const std::string &kernelName) {
std::string kernelSource;
if (usingGpuDevice()) {
kernelSource = occa_simplex_basis_gpu_source;
} else {
kernelSource = occa_simplex_basis_cpu_function_source;
kernelSource += '\n';
kernelSource += occa_simplex_basis_cpu_kernel_source;
}
return getDevice().buildKernelFromString(kernelSource, kernelName, kernelProperties);
}
int SimplexBasis::applyInterp(const CeedInt elementCount, const bool transpose, Vector &U, Vector &V) {
if (transpose) {
if (!interpTKernel.isInitialized()) {
kernelProperties["defines/TRANSPOSE"] = transpose;
interpTKernel = buildKernel("interp");
}
interpTKernel(elementCount, interp, U.getConstKernelArg(), V.getKernelArg());
} else {
if (!interpKernel.isInitialized()) {
kernelProperties["defines/TRANSPOSE"] = transpose;
interpKernel = buildKernel("interp");
}
interpKernel(elementCount, interp, U.getConstKernelArg(), V.getKernelArg());
}
return CEED_ERROR_SUCCESS;
}
int SimplexBasis::applyGrad(const CeedInt elementCount, const bool transpose, Vector &U, Vector &V) {
if (transpose) {
if (!gradTKernel.isInitialized()) {
kernelProperties["defines/TRANSPOSE"] = transpose;
gradTKernel = buildKernel("grad");
}
gradTKernel(elementCount, grad, U.getConstKernelArg(), V.getKernelArg());
} else {
if (!gradKernel.isInitialized()) {
kernelProperties["defines/TRANSPOSE"] = transpose;
gradKernel = buildKernel("grad");
}
gradKernel(elementCount, grad, U.getConstKernelArg(), V.getKernelArg());
}
return CEED_ERROR_SUCCESS;
}
int SimplexBasis::applyWeight(const CeedInt elementCount, Vector &W) {
if (!weightKernel.isInitialized()) {
weightKernel = buildKernel("weight");
}
weightKernel(elementCount, qWeight, W.getKernelArg());
return CEED_ERROR_SUCCESS;
}
int SimplexBasis::apply(const CeedInt elementCount, CeedTransposeMode tmode, CeedEvalMode emode, Vector *U, Vector *V) {
const bool transpose = tmode == CEED_TRANSPOSE;
if ((dim < 1) || (3 < dim)) {
return ceedError("Backend only supports dimensions: 1, 2, and 3");
}
if (emode != CEED_EVAL_WEIGHT) {
if (!U) {
return ceedError("Incorrect CeedVector input: U");
}
}
if (!V) {
return ceedError("Incorrect CeedVector input: V");
}
try {
switch (emode) {
case CEED_EVAL_INTERP:
return applyInterp(elementCount, transpose, *U, *V);
case CEED_EVAL_GRAD:
return applyGrad(elementCount, transpose, *U, *V);
case CEED_EVAL_WEIGHT:
return applyWeight(elementCount, *V);
default:
return ceedError("Backend does not support given simplex eval mode");
}
} catch (::occa::exception &exc) {
CeedHandleOccaException(exc);
}
return CEED_ERROR_SUCCESS;
}
int SimplexBasis::ceedCreate(CeedElemTopology topology, CeedInt dim, CeedInt ndof, CeedInt nquad, const CeedScalar *interp, const CeedScalar *grad,
const CeedScalar *qref, const CeedScalar *qWeight, CeedBasis basis) {
Ceed ceed;
CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
SimplexBasis *basis_ = new SimplexBasis(basis, dim, ndof, nquad, interp, grad, qWeight);
CeedCallBackend(CeedBasisSetData(basis, basis_));
CeedOccaRegisterFunction(basis, "Apply", Basis::ceedApply);
CeedOccaRegisterFunction(basis, "Destroy", Basis::ceedDestroy);
return CEED_ERROR_SUCCESS;
}
} }