#include "ceed-occa-qfunctioncontext.hpp"
#include <cstring>
namespace ceed {
namespace occa {
// Start with an empty context: zero size, no host buffers, and no side
// (host or device) marked as holding the newest values.
QFunctionContext::QFunctionContext()
    : ctxSize(0),
      hostBuffer(nullptr),
      currentHostBuffer(nullptr),
      syncState(SyncState::none) {}
// Release what this object allocated itself: the device allocation held in
// `memory`, then the host allocation held in `hostBuffer`.
QFunctionContext::~QFunctionContext() {
  memory.free();  // device side
  freeHostCtxBuffer();  // host side
}
// Fetch the backend object stored on `ctx`.
//
// Returns NULL for a NULL `ctx`. When `assertValid` is set, a failing
// CeedQFunctionContextGetBackendData call is routed through CeedOccaFromChk;
// otherwise the error code is ignored and whatever pointer was written (NULL
// if none) is returned.
QFunctionContext *QFunctionContext::getQFunctionContext(CeedQFunctionContext ctx, const bool assertValid) {
  if (ctx == nullptr) {
    return nullptr;
  }

  QFunctionContext *backend = nullptr;
  const int ierr = CeedQFunctionContextGetBackendData(ctx, &backend);
  if (assertValid) {
    CeedOccaFromChk(ierr);
  }
  return backend;
}
// Resolve the backend object for `ctx` and refresh its cached metadata.
//
// The cached context size and owning Ceed are copied from `ctx` on every
// call. Returns NULL when no backend object is attached.
QFunctionContext *QFunctionContext::from(CeedQFunctionContext ctx) {
  QFunctionContext *ctx_ = getQFunctionContext(ctx);
  if (ctx_ == nullptr) {
    return nullptr;
  }

  // ctx_ is known to be non-NULL from here on.
  CeedCallOcca(CeedQFunctionContextGetContextSize(ctx, &ctx_->ctxSize));
  CeedCallOcca(CeedQFunctionContextGetCeed(ctx, &ctx_->ceed));
  return ctx_;
}
// Record a new logical context size in bytes. No buffers are touched.
void QFunctionContext::resizeCtx(const size_t ctxSize_) {
  ctxSize = ctxSize_;
}
// Convenience overload: resize the owned device allocation on this
// object's own device.
void QFunctionContext::resizeCtxMemory(const size_t ctxSize_) {
  resizeCtxMemory(getDevice(), ctxSize_);
}
// Make sure the owned device allocation `memory` is exactly `ctxSize_` bytes.
//
// Nothing happens when the size already matches. Otherwise the old allocation
// is freed and a new one is made on `device`; the old contents are not kept.
void QFunctionContext::resizeCtxMemory(::occa::device device, const size_t ctxSize_) {
  const bool sizeMatches = (memory.size() == ctxSize_);
  if (sizeMatches) {
    return;
  }
  memory.free();
  memory = device.malloc(ctxSize_);
}
// (Re)allocate the backend-owned host buffer `hostBuffer` to hold `ctxSize_`
// bytes.
//
// Any previous owned host buffer is released first; its contents are not
// preserved. `currentHostBuffer` is NOT updated here, so a caller that had
// aliased it to the old `hostBuffer` must re-point it
// (setCurrentHostCtxBufferIfNeeded does this).
void QFunctionContext::resizeHostCtxBuffer(const size_t ctxSize_) {
  CeedFree(&hostBuffer);
  // Honor the requested size. Using the member `ctxSize` here would silently
  // ignore the argument.
  // NOTE(review): the CeedMallocArray status is dropped because this function
  // returns void; confirm allocation failure is acceptable to ignore here.
  CeedMallocArray(1, ctxSize_, &hostBuffer);
}
// Lazily provide a device buffer: when no current device buffer exists, size
// the owned allocation to `ctxSize` and make it the current one.
void QFunctionContext::setCurrentCtxMemoryIfNeeded() {
  if (currentMemory.isInitialized()) {
    return;  // already have something to work with
  }
  resizeCtxMemory(ctxSize);
  currentMemory = memory;
}
// Lazily provide a host buffer: when no current host buffer exists, allocate
// an owned one of `ctxSize` bytes and make it the current one.
void QFunctionContext::setCurrentHostCtxBufferIfNeeded() {
  if (currentHostBuffer != nullptr) {
    return;  // already have something to work with
  }
  resizeHostCtxBuffer(ctxSize);
  currentHostBuffer = hostBuffer;
}
// Release the backend-owned host buffer, if one exists.
// `currentHostBuffer` is left as-is; callers re-point it as needed.
void QFunctionContext::freeHostCtxBuffer() {
  if (hostBuffer == nullptr) {
    return;
  }
  CeedFree(&hostBuffer);
}
// Report whether any buffer exists on either side: an owned or current host
// buffer, or an owned or current device allocation. Always succeeds.
int QFunctionContext::hasValidData(bool *has_valid_data) const {
  const bool onHost = (hostBuffer != nullptr) || (currentHostBuffer != nullptr);
  const bool onDevice = memory.isInitialized() || currentMemory.isInitialized();
  *has_valid_data = onHost || onDevice;
  return CEED_ERROR_SUCCESS;
}
// Report whether a current buffer exists for `mem_type`.
//
// NOTE(review): this is true whenever `currentHostBuffer` / `currentMemory` is
// set. That includes buffers the backend allocated or owns, not only ones
// borrowed via CEED_USE_POINTER. Confirm that this is the intended contract.
// For an unrecognized `mem_type` the output is left untouched. Always succeeds.
int QFunctionContext::hasBorrowedDataOfType(CeedMemType mem_type, bool *has_borrowed_data_of_type) const {
  if (mem_type == CEED_MEM_HOST) {
    *has_borrowed_data_of_type = (currentHostBuffer != nullptr);
  } else if (mem_type == CEED_MEM_DEVICE) {
    *has_borrowed_data_of_type = currentMemory.isInitialized();
  }
  return CEED_ERROR_SUCCESS;
}
// Dispatch SetData to the handler for the requested copy mode.
// Returns a CEED error for an unknown `cmode`.
int QFunctionContext::setData(CeedMemType mtype, CeedCopyMode cmode, void *data) {
  if (cmode == CEED_COPY_VALUES) {
    return copyDataValues(mtype, data);
  }
  if (cmode == CEED_OWN_POINTER) {
    return ownDataPointer(mtype, data);
  }
  if (cmode == CEED_USE_POINTER) {
    return useDataPointer(mtype, data);
  }
  return ceedError("Invalid CeedCopyMode passed");
}
// CEED_COPY_VALUES: copy `ctxSize` bytes from `data` (in memory space `mtype`)
// into a buffer owned by this backend, then mark `mtype` as holding the newest
// values.
//
// A buffer is "borrowed" when the current pointer does not alias the owned
// one (`currentHostBuffer != hostBuffer`, `currentMemory != memory`). Such
// buffers belong to the caller (CEED_USE_POINTER), so they are dropped before
// copying. Without this, the copy would overwrite the caller's memory, and a
// later host<->device sync could write into the caller's stale buffer on the
// other side.
//
// Returns a CEED error for an unknown `mtype`.
int QFunctionContext::copyDataValues(CeedMemType mtype, void *data) {
  switch (mtype) {
    case CEED_MEM_HOST:
      if (currentHostBuffer != hostBuffer) currentHostBuffer = NULL;
      if (currentMemory.isInitialized() && !(currentMemory == memory)) currentMemory = ::occa::null;
      setCurrentHostCtxBufferIfNeeded();
      std::memcpy(currentHostBuffer, data, ctxSize);
      syncState = SyncState::host;
      return CEED_ERROR_SUCCESS;
    case CEED_MEM_DEVICE:
      if (currentHostBuffer != hostBuffer) currentHostBuffer = NULL;
      if (currentMemory.isInitialized() && !(currentMemory == memory)) currentMemory = ::occa::null;
      setCurrentCtxMemoryIfNeeded();
      currentMemory.copyFrom(dataToMemory(data));
      syncState = SyncState::device;
      return CEED_ERROR_SUCCESS;
  }
  return ceedError("Invalid CeedMemType passed");
}
// CEED_OWN_POINTER: adopt `data` as both the owned and the current buffer for
// `mtype`, after releasing the previously owned buffer on that side. The
// adopted side is then marked as holding the newest values.
// Returns a CEED error for an unknown `mtype`.
int QFunctionContext::ownDataPointer(CeedMemType mtype, void *data) {
  if (mtype == CEED_MEM_HOST) {
    freeHostCtxBuffer();
    currentHostBuffer = data;
    hostBuffer = currentHostBuffer;
    syncState = SyncState::host;
    return CEED_ERROR_SUCCESS;
  }
  if (mtype == CEED_MEM_DEVICE) {
    memory.free();
    currentMemory = dataToMemory(data);
    memory = currentMemory;
    syncState = SyncState::device;
    return CEED_ERROR_SUCCESS;
  }
  return ceedError("Invalid CeedMemType passed");
}
// CEED_USE_POINTER: make `data` the current buffer for `mtype` without taking
// ownership. The previously owned buffer on that side is released. The
// borrowed side is then marked as holding the newest values.
// Returns a CEED error for an unknown `mtype`.
int QFunctionContext::useDataPointer(CeedMemType mtype, void *data) {
  if (mtype == CEED_MEM_HOST) {
    freeHostCtxBuffer();
    currentHostBuffer = data;
    syncState = SyncState::host;
    return CEED_ERROR_SUCCESS;
  }
  if (mtype == CEED_MEM_DEVICE) {
    memory.free();
    currentMemory = dataToMemory(data);
    syncState = SyncState::device;
    return CEED_ERROR_SUCCESS;
  }
  return ceedError("Invalid CeedMemType passed");
}
// Hand the context buffer for memory space `mtype` to the caller and forget it.
//
// The requested side is first brought up to date, copied from the other side
// if that one holds the newest values. The pointer is then written to
// `*(void **)data` and both the owned and current references on that side are
// cleared.
//
// Any buffer remaining on the other side is synchronized before the hand-off
// and becomes the side holding the newest values (`SyncState::none` if no
// buffer remains). Otherwise `syncState` would still name the side that was
// just emptied. A later GetData on the remaining side would then allocate a
// fresh, uninitialized buffer and copy its garbage over the valid data.
//
// Returns a CEED error if no data is set or `mtype` is unknown.
int QFunctionContext::takeData(CeedMemType mtype, void *data) {
  if (currentHostBuffer == NULL && currentMemory == ::occa::null) return ceedError("No context data set");
  switch (mtype) {
    case CEED_MEM_HOST:
      setCurrentHostCtxBufferIfNeeded();
      if (syncState == SyncState::device) {
        setCurrentCtxMemoryIfNeeded();
        currentMemory.copyTo(currentHostBuffer);
      } else if (syncState == SyncState::host && currentMemory.isInitialized()) {
        // The device copy is about to become the only copy; bring it up to date.
        currentMemory.copyFrom(currentHostBuffer);
      }
      *static_cast<void **>(data) = currentHostBuffer;
      hostBuffer = NULL;
      currentHostBuffer = NULL;
      syncState = currentMemory.isInitialized() ? SyncState::device : SyncState::none;
      return CEED_ERROR_SUCCESS;
    case CEED_MEM_DEVICE:
      setCurrentCtxMemoryIfNeeded();
      if (syncState == SyncState::host) {
        setCurrentHostCtxBufferIfNeeded();
        currentMemory.copyFrom(currentHostBuffer);
      } else if (syncState == SyncState::device && currentHostBuffer != NULL) {
        // The host copy is about to become the only copy; bring it up to date.
        currentMemory.copyTo(currentHostBuffer);
      }
      *static_cast<void **>(data) = memoryToData(currentMemory);
      memory = ::occa::null;
      currentMemory = ::occa::null;
      syncState = (currentHostBuffer != NULL) ? SyncState::host : SyncState::none;
      return CEED_ERROR_SUCCESS;
  }
  return ceedError("Invalid CeedMemType passed");
}
// Write a pointer to the context buffer for memory space `mtype` into
// `*(void **)data`.
//
// If the other side holds the newest values, they are copied over first.
// Afterwards `mtype` is recorded as the side holding the newest values.
// Backs both GetData and GetDataRead.
// Returns a CEED error if no data is set or `mtype` is unknown.
int QFunctionContext::getData(CeedMemType mtype, void *data) {
  const bool nothingSet = (currentHostBuffer == nullptr) && (currentMemory == ::occa::null);
  if (nothingSet) {
    return ceedError("No context data set");
  }

  if (mtype == CEED_MEM_HOST) {
    setCurrentHostCtxBufferIfNeeded();
    if (syncState == SyncState::device) {
      // Device holds the newest values: pull them back first.
      setCurrentCtxMemoryIfNeeded();
      currentMemory.copyTo(currentHostBuffer);
    }
    syncState = SyncState::host;
    *static_cast<void **>(data) = currentHostBuffer;
    return CEED_ERROR_SUCCESS;
  }

  if (mtype == CEED_MEM_DEVICE) {
    setCurrentCtxMemoryIfNeeded();
    if (syncState == SyncState::host) {
      // Host holds the newest values: push them to the device first.
      setCurrentHostCtxBufferIfNeeded();
      currentMemory.copyFrom(currentHostBuffer);
    }
    syncState = SyncState::device;
    *static_cast<void **>(data) = memoryToData(currentMemory);
    return CEED_ERROR_SUCCESS;
  }

  return ceedError("Invalid CeedMemType passed");
}
// Nothing to do on restore: getData already updated `syncState`.
int QFunctionContext::restoreData() {
  return CEED_ERROR_SUCCESS;
}
// Return the device buffer to pass as a kernel argument.
// If the host holds the newest values they are pushed to the device first,
// and the device is then recorded as holding the newest values.
::occa::memory QFunctionContext::getKernelArg() {
  setCurrentCtxMemoryIfNeeded();

  const bool hostIsNewer = (syncState == SyncState::host);
  if (hostIsNewer) {
    setCurrentHostCtxBufferIfNeeded();
    currentMemory.copyFrom(currentHostBuffer);
  }

  syncState = SyncState::device;
  return currentMemory;
}
// Register `f` as the backend implementation named `fname` for the
// "QFunctionContext" object `ctx`, and return the libCEED status code.
int QFunctionContext::registerCeedFunction(Ceed ceed, CeedQFunctionContext ctx, const char *fname, ceed::occa::ceedFunction f) {
  const int status = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, fname, f);
  return status;
}
// Backend "Create" hook.
//
// Installs this backend's implementations on `ctx`, then attaches a new
// QFunctionContext as its backend data. That object is deleted by
// ceedDestroy, which is registered here as "Destroy".
int QFunctionContext::ceedCreate(CeedQFunctionContext ctx) {
  // `ceed` is not referenced directly below; presumably the
  // CeedOccaRegisterFunction macro reads it. Keep the name.
  Ceed ceed;
  CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed));

  CeedOccaRegisterFunction(ctx, "HasValidData", QFunctionContext::ceedHasValidData);
  CeedOccaRegisterFunction(ctx, "HasBorrowedDataOfType", QFunctionContext::ceedHasBorrowedDataOfType);
  CeedOccaRegisterFunction(ctx, "SetData", QFunctionContext::ceedSetData);
  CeedOccaRegisterFunction(ctx, "TakeData", QFunctionContext::ceedTakeData);
  CeedOccaRegisterFunction(ctx, "GetData", QFunctionContext::ceedGetData);
  CeedOccaRegisterFunction(ctx, "GetDataRead", QFunctionContext::ceedGetDataRead);
  CeedOccaRegisterFunction(ctx, "RestoreData", QFunctionContext::ceedRestoreData);
  CeedOccaRegisterFunction(ctx, "Destroy", QFunctionContext::ceedDestroy);

  QFunctionContext *backendData = new QFunctionContext;
  CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, backendData));
  return CEED_ERROR_SUCCESS;
}
int QFunctionContext::ceedHasValidData(const CeedQFunctionContext ctx, bool *has_valid_data) {
QFunctionContext *ctx_ = QFunctionContext::from(ctx);
if (!ctx_) {
return staticCeedError("Invalid CeedQFunctionContext passed");
}
return ctx_->hasValidData(has_valid_data);
}
int QFunctionContext::ceedHasBorrowedDataOfType(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *has_borrowed_data_of_type) {
QFunctionContext *ctx_ = QFunctionContext::from(ctx);
if (!ctx_) {
return staticCeedError("Invalid CeedQFunctionContext passed");
}
return ctx_->hasBorrowedDataOfType(mem_type, has_borrowed_data_of_type);
}
// libCEED entry point for "SetData": forwards to setData, or reports an error
// when `ctx` has no backend object.
int QFunctionContext::ceedSetData(CeedQFunctionContext ctx, CeedMemType mtype, CeedCopyMode cmode, void *data) {
  QFunctionContext *const self = QFunctionContext::from(ctx);
  return self ? self->setData(mtype, cmode, data) : staticCeedError("Invalid CeedQFunctionContext passed");
}
// libCEED entry point for "TakeData": forwards to takeData, or reports an
// error when `ctx` has no backend object.
int QFunctionContext::ceedTakeData(CeedQFunctionContext ctx, CeedMemType mtype, void *data) {
  QFunctionContext *const self = QFunctionContext::from(ctx);
  return self ? self->takeData(mtype, data) : staticCeedError("Invalid CeedQFunctionContext passed");
}
// libCEED entry point for "GetData": forwards to getData, or reports an error
// when `ctx` has no backend object.
int QFunctionContext::ceedGetData(CeedQFunctionContext ctx, CeedMemType mtype, void *data) {
  QFunctionContext *const self = QFunctionContext::from(ctx);
  return self ? self->getData(mtype, data) : staticCeedError("Invalid CeedQFunctionContext passed");
}
// libCEED entry point for "GetDataRead". It shares the read/write getData
// path, so it updates `syncState` exactly as GetData does. Reports an error
// when `ctx` has no backend object.
int QFunctionContext::ceedGetDataRead(CeedQFunctionContext ctx, CeedMemType mtype, void *data) {
  QFunctionContext *const self = QFunctionContext::from(ctx);
  return self ? self->getData(mtype, data) : staticCeedError("Invalid CeedQFunctionContext passed");
}
// libCEED entry point for "RestoreData": forwards to restoreData, or reports
// an error when `ctx` has no backend object.
int QFunctionContext::ceedRestoreData(CeedQFunctionContext ctx) {
  QFunctionContext *const self = QFunctionContext::from(ctx);
  return self ? self->restoreData() : staticCeedError("Invalid CeedQFunctionContext passed");
}
// libCEED entry point for "Destroy": delete the backend object attached to
// `ctx`. Validity is not asserted; if no object is found, the NULL pointer is
// deleted, which is a no-op.
int QFunctionContext::ceedDestroy(CeedQFunctionContext ctx) {
  QFunctionContext *const backendData = getQFunctionContext(ctx, false);
  delete backendData;
  return CEED_ERROR_SUCCESS;
}
} }