#include "ceed-occa-operator.hpp"
#include "ceed-occa-basis.hpp"
#include "ceed-occa-cpu-operator.hpp"
#include "ceed-occa-elem-restriction.hpp"
#include "ceed-occa-gpu-operator.hpp"
#include "ceed-occa-qfunction.hpp"
namespace ceed {
namespace occa {
// Builds an operator with zeroed quadrature/element counts, no QFunction
// attached yet, and the one-time setup still pending.
Operator::Operator()
    : ceedQ(0),
      ceedElementCount(0),
      qfunction(NULL),
      needsInitialSetup(true) {
}
// Intentionally empty: this destructor performs no explicit cleanup.
Operator::~Operator() {
}
// Fetches the backend Operator stored as the data of a CeedOperator.
//
// Returns NULL when `op` is null. The status returned by CeedOperatorGetData
// is handed to CeedOccaFromChk only when `assertValid` is true; otherwise the
// status is ignored and the retrieved pointer (initially NULL) is returned.
Operator *Operator::getOperator(CeedOperator op, const bool assertValid) {
  if (!op) {
    return NULL;
  }

  Operator *backendOp = NULL;
  int ierr = CeedOperatorGetData(op, (void **)&backendOp);

  // NOTE(review): CeedOccaFromChk presumably bails out of this function when
  // ierr is non-zero; confirm against the macro definition.
  if (assertValid) {
    CeedOccaFromChk(ierr);
  }

  return backendOp;
}
// Looks up the backend Operator attached to `op` and refreshes its cached
// state: the owning Ceed, the QFunction wrapper, the quadrature-point and
// element counts, and the operator arguments.
//
// Returns NULL when no backend operator is attached, when no QFunction
// wrapper can be obtained, or when the arguments are not valid after setup.
// NOTE(review): CeedCallOcca presumably returns early from this function when
// the wrapped call fails; confirm against the macro definition.
Operator *Operator::from(CeedOperator op) {
  Operator *self = getOperator(op);
  if (!self) {
    return NULL;
  }

  CeedCallOcca(CeedOperatorGetCeed(op, &self->ceed));

  self->qfunction = QFunction::from(op);
  if (!self->qfunction) {
    return NULL;
  }

  CeedCallOcca(CeedOperatorGetNumQuadraturePoints(op, &self->ceedQ));
  CeedCallOcca(CeedOperatorGetNumElements(op, &self->ceedElementCount));

  // Argument setup runs last, after the counts above have been refreshed.
  self->args.setupArgs(op);
  return self->args.isValid() ? self : NULL;
}
// Reports the `ceedIsIdentity` flag of the attached QFunction.
// `qfunction` is dereferenced without a null check.
bool Operator::isApplyingIdentityFunction() {
  const bool isIdentity = qfunction->ceedIsIdentity;
  return isIdentity;
}
// Applies the operator to `in`, accumulating into `out` through the two-argument
// applyAdd overload. Always returns CEED_ERROR_SUCCESS.
//
// `request` is accepted for interface compatibility but is not used here.
int Operator::applyAdd(Vector *in, Vector *out, CeedRequest *request) {
  // The kernel object is rebuilt on every call, before any one-time setup.
  this->applyAddKernel = this->buildApplyAddKernel();

  // Run the setup hook on the first call only.
  if (this->needsInitialSetup) {
    this->initialSetup();
    this->needsInitialSetup = false;
  }

  this->applyAdd(in, out);

  return CEED_ERROR_SUCCESS;
}
// One-time setup step invoked from applyAdd on first use; does nothing here.
void Operator::initialSetup() {
}
// Registers backend function `f` under the name `fname` for the "Operator"
// object type on `op`, forwarding to CeedSetBackendFunction and returning
// its status code.
int Operator::registerCeedFunction(Ceed ceed, CeedOperator op, const char *fname, ceed::occa::ceedFunction f) {
  const int status = CeedSetBackendFunction(ceed, "Operator", op, fname, f);
  return status;
}
// Backend entry point that creates the OCCA operator for `op`.
//
// Allocates a CpuOperator, stores it as the CeedOperator's backend data, and
// registers the backend callbacks. Returns CEED_ERROR_SUCCESS once every step
// has completed.
// NOTE(review): CeedCallBackend presumably returns early with an error code
// when the wrapped call fails; confirm against the macro definition.
int Operator::ceedCreate(CeedOperator op) {
  // NOTE(review): `ceed` is not referenced directly below; presumably the
  // CeedOccaRegisterFunction macro uses it by name. Confirm before renaming.
  Ceed ceed;
  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));

  // Only the CPU operator is instantiated; no device-specific operator is
  // selected here. The previous `#if 1 ... #else#endif` wrapper was malformed
  // (it left the conditional unterminated) and guarded no alternative code,
  // so it has been removed.
  Operator *operator_ = new CpuOperator();
  CeedCallBackend(CeedOperatorSetData(op, operator_));

  // Assembly entry points: each is wired to its own "not implemented" stub so
  // the reported error names the function that was actually requested.
  CeedOccaRegisterFunction(op, "LinearAssembleQFunction", Operator::ceedLinearAssembleQFunction);
  CeedOccaRegisterFunction(op, "LinearAssembleQFunctionUpdate", Operator::ceedLinearAssembleQFunctionUpdate);
  CeedOccaRegisterFunction(op, "LinearAssembleAddDiagonal", Operator::ceedLinearAssembleAddDiagonal);
  CeedOccaRegisterFunction(op, "LinearAssembleAddPointBlockDiagonal", Operator::ceedLinearAssembleAddPointBlockDiagonal);
  CeedOccaRegisterFunction(op, "CreateFDMElementInverse", Operator::ceedCreateFDMElementInverse);

  // Application and teardown.
  CeedOccaRegisterFunction(op, "ApplyAdd", Operator::ceedApplyAdd);
  CeedOccaRegisterFunction(op, "Destroy", Operator::ceedDestroy);

  return CEED_ERROR_SUCCESS;
}
// Backend entry point for composite operators: registers only the diagonal
// and point-block-diagonal assembly callbacks on `op`.
// Returns CEED_ERROR_SUCCESS once both registrations have been issued.
int Operator::ceedCreateComposite(CeedOperator op) {
  // NOTE(review): `ceed` is presumably referenced by name inside the
  // CeedOccaRegisterFunction macro; confirm before renaming.
  Ceed ceed;
  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));

  CeedOccaRegisterFunction(op, "LinearAssembleAddDiagonal", Operator::ceedLinearAssembleAddDiagonal);
  CeedOccaRegisterFunction(op, "LinearAssembleAddPointBlockDiagonal", Operator::ceedLinearAssembleAddPointBlockDiagonal);

  return CEED_ERROR_SUCCESS;
}
// Stub: this backend has no LinearAssembleQFunction implementation, so the
// call always yields the error produced by staticCeedError. `op` is unused.
int Operator::ceedLinearAssembleQFunction(CeedOperator op) {
  return staticCeedError("(OCCA) Backend does not implement LinearAssembleQFunction");
}
// Stub: this backend has no LinearAssembleQFunctionUpdate implementation, so
// the call always yields the error produced by staticCeedError. `op` is unused.
int Operator::ceedLinearAssembleQFunctionUpdate(CeedOperator op) {
  return staticCeedError(
      "(OCCA) Backend does not implement LinearAssembleQFunctionUpdate");
}
// Stub: diagonal assembly is not implemented by this backend, so the call
// always yields the error produced by staticCeedError. `op` is unused.
// NOTE(review): the message says "LinearAssembleDiagonal" while the function
// name is LinearAssembleAddDiagonal; left as-is to preserve emitted text.
int Operator::ceedLinearAssembleAddDiagonal(CeedOperator op) {
  return staticCeedError("(OCCA) Backend does not implement LinearAssembleDiagonal");
}
// Stub: point-block diagonal assembly is not implemented by this backend, so
// the call always yields the error produced by staticCeedError. `op` is unused.
// NOTE(review): the message omits the "Add" in the function name; left as-is
// to preserve emitted text.
int Operator::ceedLinearAssembleAddPointBlockDiagonal(CeedOperator op) {
  return staticCeedError(
      "(OCCA) Backend does not implement LinearAssemblePointBlockDiagonal");
}
// Stub: this backend has no CreateFDMElementInverse implementation, so the
// call always yields the error produced by staticCeedError. `op` is unused.
int Operator::ceedCreateFDMElementInverse(CeedOperator op) {
  return staticCeedError("(OCCA) Backend does not implement CreateFDMElementInverse");
}
// Backend "ApplyAdd" callback: converts the libCEED handles into their OCCA
// wrappers and forwards to the member applyAdd.
//
// Returns the staticCeedError result when no backend operator can be obtained
// from `op`; otherwise returns whatever applyAdd returns. The vector wrappers
// are passed through without a null check.
int Operator::ceedApplyAdd(CeedOperator op, CeedVector invec, CeedVector outvec, CeedRequest *request) {
  // All three conversions run before the operator is validated.
  Operator *self   = Operator::from(op);
  Vector   *input  = Vector::from(invec);
  Vector   *output = Vector::from(outvec);

  if (self) {
    return self->applyAdd(input, output, request);
  }
  return staticCeedError("Incorrect CeedOperator argument: op");
}
// Backend "Destroy" callback: deletes the backend Operator attached to `op`.
// The lookup is done with assertValid = false, so its status is not checked.
// Always returns CEED_ERROR_SUCCESS.
int Operator::ceedDestroy(CeedOperator op) {
  Operator *attached = getOperator(op, false);
  delete attached;  // deleting a null pointer is a no-op
  return CEED_ERROR_SUCCESS;
}
} }