#include <ceed-impl.h>
#include <ceed.h>
#include <ceed/backend.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
/**
  @brief Find the position of a registered field label by name.

  @param[in]  ctx         CeedQFunctionContext whose field labels are searched
  @param[in]  field_name  Name to look for (compared with strcmp)
  @param[out] field_index Index into ctx->field_labels of the matching label, or -1 when no label has this name

  @return CEED_ERROR_SUCCESS
**/
int CeedQFunctionContextGetFieldIndex(CeedQFunctionContext ctx, const char *field_name, CeedInt *field_index) {
  // Walk backwards so that, if duplicate names were ever present, the highest index is reported
  for (CeedInt idx = ctx->num_fields - 1; idx >= 0; idx--) {
    if (strcmp(ctx->field_labels[idx]->name, field_name) == 0) {
      *field_index = idx;
      return CEED_ERROR_SUCCESS;
    }
  }
  *field_index = -1;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Register a new named field label describing a region of the context data.

  @param[in,out] ctx               CeedQFunctionContext receiving the label
  @param[in]     field_name        Unique name of the field; an existing name is rejected
  @param[in]     field_offset      Byte offset of the field within the context data
  @param[in]     field_description Human-readable description (copied)
  @param[in]     field_type        Type tag stored on the label
  @param[in]     field_size        Size in bytes of a single value
  @param[in]     num_values        Number of values in the field

  @return An error code: 0 - success, otherwise failure
**/
int CeedQFunctionContextRegisterGeneric(CeedQFunctionContext ctx, const char *field_name, size_t field_offset, const char *field_description,
                                        CeedContextFieldType field_type, size_t field_size, size_t num_values) {
  CeedInt               existing = -1;
  CeedContextFieldLabel label;

  // Names must be unique within one context
  CeedCall(CeedQFunctionContextGetFieldIndex(ctx, field_name, &existing));
  CeedCheck(existing == -1, ctx->ceed, CEED_ERROR_UNSUPPORTED, "QFunctionContext field with name \"%s\" already registered", field_name);

  // Label array capacity grows 1, 2, 4, ... as fields are added
  if (ctx->num_fields == 0) {
    CeedCall(CeedCalloc(1, &ctx->field_labels));
    ctx->max_fields = 1;
  } else if (ctx->num_fields == ctx->max_fields) {
    CeedCall(CeedRealloc(2 * ctx->max_fields, &ctx->field_labels));
    ctx->max_fields *= 2;
  }

  CeedCall(CeedCalloc(1, &ctx->field_labels[ctx->num_fields]));
  label = ctx->field_labels[ctx->num_fields];
  CeedCall(CeedStringAllocCopy(field_name, (char **)&label->name));
  CeedCall(CeedStringAllocCopy(field_description, (char **)&label->description));
  label->type       = field_type;
  label->offset     = field_offset;
  label->size       = field_size * num_values;  // total bytes covered by the field
  label->num_values = num_values;

  ctx->num_fields += 1;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Release the context data via the backend destructor, or else via the user-registered destroy callback.

  If the backend provides ctx->DataDestroy, only that is called. Otherwise, when a user callback was registered,
  the data is obtained with CeedQFunctionContextGetData in the registered memory type, passed to the callback,
  and then restored.

  @param[in,out] ctx CeedQFunctionContext whose data is released

  @return An error code: 0 - success, otherwise failure
**/
static int CeedQFunctionContextDestroyData(CeedQFunctionContext ctx) {
  CeedMemType                         user_mem_type;
  CeedQFunctionContextDataDestroyUser user_destroy;
  void                               *user_data;

  if (ctx->DataDestroy) {
    CeedCall(ctx->DataDestroy(ctx));
    return CEED_ERROR_SUCCESS;
  }

  CeedCall(CeedQFunctionContextGetDataDestroy(ctx, &user_mem_type, &user_destroy));
  if (!user_destroy) return CEED_ERROR_SUCCESS;

  // Hand the callback the data in the memory type it was registered with
  CeedCall(CeedQFunctionContextGetData(ctx, user_mem_type, &user_data));
  CeedCall(user_destroy(user_data));
  CeedCall(CeedQFunctionContextRestoreData(ctx, &user_data));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Retrieve the Ceed object that owns a CeedQFunctionContext.

  @param[in]  ctx  CeedQFunctionContext to query
  @param[out] ceed Set to ctx->ceed (no reference is taken)

  @return CEED_ERROR_SUCCESS
**/
int CeedQFunctionContextGetCeed(CeedQFunctionContext ctx, Ceed *ceed) {
  Ceed owner = ctx->ceed;

  *ceed = owner;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Ask the backend whether the context currently holds valid data.

  @param[in]  ctx            CeedQFunctionContext to query
  @param[out] has_valid_data Result reported by the backend

  @return An error code: CEED_ERROR_UNSUPPORTED if the backend lacks HasValidData
**/
int CeedQFunctionContextHasValidData(CeedQFunctionContext ctx, bool *has_valid_data) {
  CeedCheck(ctx->HasValidData != NULL, ctx->ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support HasValidData");

  CeedCall(ctx->HasValidData(ctx, has_valid_data));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Ask the backend whether it holds borrowed data of the given memory type.

  @param[in]  ctx                       CeedQFunctionContext to query
  @param[in]  mem_type                  Memory type to check
  @param[out] has_borrowed_data_of_type Result reported by the backend

  @return An error code: CEED_ERROR_UNSUPPORTED if the backend lacks HasBorrowedDataOfType
**/
int CeedQFunctionContextHasBorrowedDataOfType(CeedQFunctionContext ctx, CeedMemType mem_type, bool *has_borrowed_data_of_type) {
  CeedCheck(ctx->HasBorrowedDataOfType != NULL, ctx->ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support HasBorrowedDataOfType");

  CeedCall(ctx->HasBorrowedDataOfType(ctx, mem_type, has_borrowed_data_of_type));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Read the context's state counter.

  The counter is incremented when write access is granted and again when it is restored,
  so an odd value means write access is currently outstanding.

  @param[in]  ctx   CeedQFunctionContext to query
  @param[out] state Current value of ctx->state

  @return CEED_ERROR_SUCCESS
**/
int CeedQFunctionContextGetState(CeedQFunctionContext ctx, uint64_t *state) {
  const uint64_t current = ctx->state;

  *state = current;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Retrieve the backend-private data pointer.

  @param[in]  ctx  CeedQFunctionContext to query
  @param[out] data Address of a pointer variable; receives ctx->data

  @return CEED_ERROR_SUCCESS
**/
int CeedQFunctionContextGetBackendData(CeedQFunctionContext ctx, void *data) {
  void **out = (void **)data;

  *out = ctx->data;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Store the backend-private data pointer (not copied, not freed here).

  @param[in,out] ctx  CeedQFunctionContext to update
  @param[in]     data Pointer stored in ctx->data

  @return CEED_ERROR_SUCCESS
**/
int CeedQFunctionContextSetBackendData(CeedQFunctionContext ctx, void *data) {
  void *backend_data = data;

  ctx->data = backend_data;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Look up a field label by name.

  @param[in]  ctx         CeedQFunctionContext to search
  @param[in]  field_name  Name of the registered field
  @param[out] field_label Matching label, or NULL if no field has that name

  @return An error code: 0 - success, otherwise failure
**/
int CeedQFunctionContextGetFieldLabel(CeedQFunctionContext ctx, const char *field_name, CeedContextFieldLabel *field_label) {
  CeedInt index;

  CeedCall(CeedQFunctionContextGetFieldIndex(ctx, field_name, &index));
  *field_label = (index == -1) ? NULL : ctx->field_labels[index];
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Copy values into a labeled field of the host context data if they differ.

  The comparison is done under read access. Write access, which advances ctx->state, is only taken
  when the bytes actually change.

  @param[in,out] ctx         CeedQFunctionContext to update
  @param[in]     field_label Label describing the destination field
  @param[in]     field_type  Expected type; must match the label's registered type
  @param[in]     values      Source buffer of field_label->size bytes

  @return An error code: CEED_ERROR_UNSUPPORTED on type mismatch, otherwise 0 - success
**/
int CeedQFunctionContextSetGeneric(CeedQFunctionContext ctx, CeedContextFieldLabel field_label, CeedContextFieldType field_type, void *values) {
  char *host_data;
  int   cmp;

  CeedCheck(field_label->type == field_type, ctx->ceed, CEED_ERROR_UNSUPPORTED,
            "QFunctionContext field with name \"%s\" registered as %s, not registered as %s", field_label->name,
            CeedContextFieldTypes[field_label->type], CeedContextFieldTypes[field_type]);

  CeedCall(CeedQFunctionContextGetDataRead(ctx, CEED_MEM_HOST, &host_data));
  cmp = memcmp(host_data + field_label->offset, values, field_label->size);
  CeedCall(CeedQFunctionContextRestoreDataRead(ctx, &host_data));
  if (cmp == 0) return CEED_ERROR_SUCCESS;

  CeedCall(CeedQFunctionContextGetData(ctx, CEED_MEM_HOST, &host_data));
  memcpy(host_data + field_label->offset, values, field_label->size);
  CeedCall(CeedQFunctionContextRestoreData(ctx, &host_data));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get read-only host access to a labeled field of the context data.

  Read access to the context is held until CeedQFunctionContextRestoreGenericRead is called.

  @param[in]  ctx         CeedQFunctionContext to read
  @param[in]  field_label Label describing the field
  @param[in]  field_type  Expected type; must match the label's registered type
  @param[out] num_values  Number of values in the field
  @param[out] values      Address of a pointer; receives a pointer to the field inside the host data

  @return An error code: CEED_ERROR_UNSUPPORTED on type mismatch, otherwise 0 - success
**/
int CeedQFunctionContextGetGenericRead(CeedQFunctionContext ctx, CeedContextFieldLabel field_label, CeedContextFieldType field_type,
                                       size_t *num_values, void *values) {
  char *data;

  CeedCheck(field_label->type == field_type, ctx->ceed, CEED_ERROR_UNSUPPORTED,
            "QFunctionContext field with name \"%s\" registered as %s, not registered as %s", field_label->name,
            CeedContextFieldTypes[field_label->type], CeedContextFieldTypes[field_type]);

  CeedCall(CeedQFunctionContextGetDataRead(ctx, CEED_MEM_HOST, &data));
  *(void **)values = &data[field_label->offset];
  switch (field_type) {
    case CEED_CONTEXT_FIELD_INT32:
      *num_values = field_label->size / sizeof(int);
      break;
    case CEED_CONTEXT_FIELD_DOUBLE:
      *num_values = field_label->size / sizeof(double);
      break;
    default:
      // Any other field type: never leave *num_values unset; use the count recorded at registration
      *num_values = field_label->num_values;
      break;
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Release read access obtained with CeedQFunctionContextGetGenericRead.

  @param[in]     ctx         CeedQFunctionContext being read
  @param[in]     field_label Label describing the field
  @param[in]     field_type  Expected type; must match the label's registered type
  @param[in,out] values      Address of the field pointer; set to NULL on return

  @return An error code: CEED_ERROR_UNSUPPORTED on type mismatch, otherwise 0 - success
**/
int CeedQFunctionContextRestoreGenericRead(CeedQFunctionContext ctx, CeedContextFieldLabel field_label, CeedContextFieldType field_type,
                                           void *values) {
  const bool type_matches = field_label->type == field_type;

  CeedCheck(type_matches, ctx->ceed, CEED_ERROR_UNSUPPORTED,
            "QFunctionContext field with name \"%s\" registered as %s, not registered as %s", field_label->name,
            CeedContextFieldTypes[field_label->type], CeedContextFieldTypes[field_type]);
  CeedCall(CeedQFunctionContextRestoreDataRead(ctx, values));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Set a double-typed field from values (see CeedQFunctionContextSetGeneric).

  @return An error code: CEED_ERROR_UNSUPPORTED if field_label is NULL or not a double field
**/
int CeedQFunctionContextSetDouble(CeedQFunctionContext ctx, CeedContextFieldLabel field_label, double *values) {
  CeedCheck(field_label != NULL, ctx->ceed, CEED_ERROR_UNSUPPORTED, "Invalid field label");

  CeedCall(CeedQFunctionContextSetGeneric(ctx, field_label, CEED_CONTEXT_FIELD_DOUBLE, values));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get read-only access to a double-typed field (see CeedQFunctionContextGetGenericRead).

  @return An error code: CEED_ERROR_UNSUPPORTED if field_label is NULL or not a double field
**/
int CeedQFunctionContextGetDoubleRead(CeedQFunctionContext ctx, CeedContextFieldLabel field_label, size_t *num_values, const double **values) {
  CeedCheck(field_label != NULL, ctx->ceed, CEED_ERROR_UNSUPPORTED, "Invalid field label");

  CeedCall(CeedQFunctionContextGetGenericRead(ctx, field_label, CEED_CONTEXT_FIELD_DOUBLE, num_values, values));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Release read access from CeedQFunctionContextGetDoubleRead.

  @return An error code: CEED_ERROR_UNSUPPORTED if field_label is NULL or not a double field
**/
int CeedQFunctionContextRestoreDoubleRead(CeedQFunctionContext ctx, CeedContextFieldLabel field_label, const double **values) {
  CeedCheck(field_label != NULL, ctx->ceed, CEED_ERROR_UNSUPPORTED, "Invalid field label");

  CeedCall(CeedQFunctionContextRestoreGenericRead(ctx, field_label, CEED_CONTEXT_FIELD_DOUBLE, values));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Set an int32-typed field from values (see CeedQFunctionContextSetGeneric).

  @return An error code: CEED_ERROR_UNSUPPORTED if field_label is NULL or not an int32 field
**/
int CeedQFunctionContextSetInt32(CeedQFunctionContext ctx, CeedContextFieldLabel field_label, int *values) {
  CeedCheck(field_label != NULL, ctx->ceed, CEED_ERROR_UNSUPPORTED, "Invalid field label");

  CeedCall(CeedQFunctionContextSetGeneric(ctx, field_label, CEED_CONTEXT_FIELD_INT32, values));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get read-only access to an int32-typed field (see CeedQFunctionContextGetGenericRead).

  @return An error code: CEED_ERROR_UNSUPPORTED if field_label is NULL or not an int32 field
**/
int CeedQFunctionContextGetInt32Read(CeedQFunctionContext ctx, CeedContextFieldLabel field_label, size_t *num_values, const int **values) {
  CeedCheck(field_label != NULL, ctx->ceed, CEED_ERROR_UNSUPPORTED, "Invalid field label");

  CeedCall(CeedQFunctionContextGetGenericRead(ctx, field_label, CEED_CONTEXT_FIELD_INT32, num_values, values));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Release read access from CeedQFunctionContextGetInt32Read.

  @return An error code: CEED_ERROR_UNSUPPORTED if field_label is NULL or not an int32 field
**/
int CeedQFunctionContextRestoreInt32Read(CeedQFunctionContext ctx, CeedContextFieldLabel field_label, const int **values) {
  CeedCheck(field_label != NULL, ctx->ceed, CEED_ERROR_UNSUPPORTED, "Invalid field label");

  CeedCall(CeedQFunctionContextRestoreGenericRead(ctx, field_label, CEED_CONTEXT_FIELD_INT32, values));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Query the user data-destroy callback and its memory type.

  @param[in]  ctx        CeedQFunctionContext to query
  @param[out] f_mem_type Memory type registered with the callback; may be NULL to skip
  @param[out] f          Registered callback (possibly NULL); may be NULL to skip

  @return CEED_ERROR_SUCCESS
**/
int CeedQFunctionContextGetDataDestroy(CeedQFunctionContext ctx, CeedMemType *f_mem_type, CeedQFunctionContextDataDestroyUser *f) {
  if (f_mem_type != NULL) {
    *f_mem_type = ctx->data_destroy_mem_type;
  }
  if (f != NULL) {
    *f = ctx->data_destroy_function;
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Increment the reference count of a CeedQFunctionContext.

  @param[in,out] ctx CeedQFunctionContext to reference

  @return CEED_ERROR_SUCCESS
**/
int CeedQFunctionContextReference(CeedQFunctionContext ctx) {
  ctx->ref_count += 1;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Create a CeedQFunctionContext with reference count 1.

  If the backend has no QFunctionContextCreate, creation is forwarded to the "Context" object delegate.

  @param[in]  ceed Ceed object used to create the context
  @param[out] ctx  Newly created context

  @return An error code: CEED_ERROR_UNSUPPORTED if neither the backend nor a delegate supports it
**/
int CeedQFunctionContextCreate(Ceed ceed, CeedQFunctionContext *ctx) {
  Ceed delegate = NULL;

  if (ceed->QFunctionContextCreate) {
    CeedCall(CeedCalloc(1, ctx));
    CeedCall(CeedReferenceCopy(ceed, &(*ctx)->ceed));
    (*ctx)->ref_count = 1;
    CeedCall(ceed->QFunctionContextCreate(*ctx));
    return CEED_ERROR_SUCCESS;
  }

  // Backend cannot create contexts itself; forward to its delegate
  CeedCall(CeedGetObjectDelegate(ceed, &delegate, "Context"));
  CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support ContextCreate");
  CeedCall(CeedQFunctionContextCreate(delegate, ctx));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Make *ctx_copy refer to ctx, releasing whatever *ctx_copy previously referenced.

  The new reference is taken before the old one is released, so ctx == *ctx_copy is safe.

  @param[in]     ctx      Context to reference
  @param[in,out] ctx_copy Variable receiving the reference

  @return An error code: 0 - success, otherwise failure
**/
int CeedQFunctionContextReferenceCopy(CeedQFunctionContext ctx, CeedQFunctionContext *ctx_copy) {
  CeedCall(CeedQFunctionContextReference(ctx));  // must precede the destroy below (self-copy case)
  CeedCall(CeedQFunctionContextDestroy(ctx_copy));
  *ctx_copy = ctx;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Replace the context data through the backend.

  Any existing data is first released (see CeedQFunctionContextDestroyData). The state counter
  advances by 2, so its parity, which tracks the write lock, is unchanged.

  @param[in,out] ctx       CeedQFunctionContext to update
  @param[in]     mem_type  Memory type of data
  @param[in]     copy_mode Copy mode forwarded to the backend
  @param[in]     size      Size of data in bytes, stored as ctx->ctx_size
  @param[in]     data      Data to set

  @return An error code: 0 - success, otherwise failure
**/
int CeedQFunctionContextSetData(CeedQFunctionContext ctx, CeedMemType mem_type, CeedCopyMode copy_mode, size_t size, void *data) {
  const bool lock_free = (ctx->state % 2) == 0;

  CeedCheck(ctx->SetData, ctx->ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support ContextSetData");
  CeedCheck(lock_free, ctx->ceed, 1, "Cannot grant CeedQFunctionContext data access, the access lock is already in use");

  CeedCall(CeedQFunctionContextDestroyData(ctx));
  ctx->ctx_size = size;
  CeedCall(ctx->SetData(ctx, mem_type, copy_mode, data));
  ctx->state = ctx->state + 2;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Take ownership of borrowed context data of the given memory type from the backend.

  @param[in,out] ctx      CeedQFunctionContext holding borrowed data
  @param[in]     mem_type Memory type of the borrowed data to take
  @param[out]    data     Address of a pointer receiving the data; may be NULL to discard it

  @return An error code: 0 - success, otherwise failure
**/
int CeedQFunctionContextTakeData(CeedQFunctionContext ctx, CeedMemType mem_type, void *data) {
  void *taken        = NULL;
  bool  valid        = true;
  bool  borrowed_ok  = true;

  CeedCall(CeedQFunctionContextHasValidData(ctx, &valid));
  CeedCheck(valid, ctx->ceed, CEED_ERROR_BACKEND, "CeedQFunctionContext has no valid data to take, must set data");

  CeedCheck(ctx->TakeData, ctx->ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support TakeData");
  CeedCheck(ctx->state % 2 == 0, ctx->ceed, 1, "Cannot grant CeedQFunctionContext data access, the access lock is already in use");

  CeedCall(CeedQFunctionContextHasBorrowedDataOfType(ctx, mem_type, &borrowed_ok));
  CeedCheck(borrowed_ok, ctx->ceed, CEED_ERROR_BACKEND,
            "CeedQFunctionContext has no borrowed %s data, must set data with CeedQFunctionContextSetData", CeedMemTypes[mem_type]);

  CeedCall(ctx->TakeData(ctx, mem_type, &taken));
  if (data != NULL) *(void **)data = taken;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Grant exclusive read/write access to the context data.

  Fails if write access is already granted (odd state) or any reader holds access. On success the
  state counter becomes odd until CeedQFunctionContextRestoreData is called.

  @param[in,out] ctx      CeedQFunctionContext to access
  @param[in]     mem_type Memory type requested
  @param[out]    data     Address of a pointer receiving the data

  @return An error code: 0 - success, otherwise failure
**/
int CeedQFunctionContextGetData(CeedQFunctionContext ctx, CeedMemType mem_type, void *data) {
  bool valid = true;

  CeedCheck(ctx->GetData, ctx->ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support GetData");
  CeedCheck(ctx->state % 2 == 0, ctx->ceed, 1, "Cannot grant CeedQFunctionContext data access, the access lock is already in use");
  CeedCheck(ctx->num_readers == 0, ctx->ceed, 1, "Cannot grant CeedQFunctionContext data access, a process has read access");

  CeedCall(CeedQFunctionContextHasValidData(ctx, &valid));
  CeedCheck(valid, ctx->ceed, CEED_ERROR_BACKEND, "CeedQFunctionContext has no valid data to get, must set data");

  CeedCall(ctx->GetData(ctx, mem_type, data));
  ctx->state += 1;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Grant shared read-only access to the context data.

  Fails while write access is outstanding (odd state). Each successful call increments
  ctx->num_readers; release with CeedQFunctionContextRestoreDataRead.

  @param[in,out] ctx      CeedQFunctionContext to access
  @param[in]     mem_type Memory type requested
  @param[out]    data     Address of a pointer receiving the data

  @return An error code: 0 - success, otherwise failure
**/
int CeedQFunctionContextGetDataRead(CeedQFunctionContext ctx, CeedMemType mem_type, void *data) {
  bool valid = true;

  CeedCheck(ctx->GetDataRead, ctx->ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support GetDataRead");
  CeedCheck(ctx->state % 2 == 0, ctx->ceed, 1, "Cannot grant CeedQFunctionContext data access, the access lock is already in use");

  CeedCall(CeedQFunctionContextHasValidData(ctx, &valid));
  CeedCheck(valid, ctx->ceed, CEED_ERROR_BACKEND, "CeedQFunctionContext has no valid data to get, must set data");

  CeedCall(ctx->GetDataRead(ctx, mem_type, data));
  ctx->num_readers += 1;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Release write access granted by CeedQFunctionContextGetData.

  @param[in,out] ctx  CeedQFunctionContext being accessed
  @param[in,out] data Address of the data pointer; set to NULL on return

  @return An error code: 0 - success, otherwise failure (e.g. access was not granted)
**/
int CeedQFunctionContextRestoreData(CeedQFunctionContext ctx, void *data) {
  void **data_ptr = (void **)data;

  CeedCheck(ctx->state % 2 == 1, ctx->ceed, 1, "Cannot restore CeedQFunctionContext array access, access was not granted");
  if (ctx->RestoreData) {
    CeedCall(ctx->RestoreData(ctx));
  }
  *data_ptr = NULL;
  ctx->state += 1;  // back to even: write lock released
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Release one read access granted by CeedQFunctionContextGetDataRead.

  The backend's RestoreDataRead, if any, runs only when the last reader releases.

  @param[in,out] ctx  CeedQFunctionContext being read
  @param[in,out] data Address of the data pointer; set to NULL on return

  @return An error code: 0 - success, otherwise failure (e.g. no read access outstanding)
**/
int CeedQFunctionContextRestoreDataRead(CeedQFunctionContext ctx, void *data) {
  void **data_ptr = (void **)data;

  CeedCheck(ctx->num_readers > 0, ctx->ceed, 1, "Cannot restore CeedQFunctionContext array access, access was not granted");
  if (--ctx->num_readers == 0 && ctx->RestoreDataRead) {
    CeedCall(ctx->RestoreDataRead(ctx));
  }
  *data_ptr = NULL;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Register a field of num_values doubles at field_offset (see CeedQFunctionContextRegisterGeneric).
**/
int CeedQFunctionContextRegisterDouble(CeedQFunctionContext ctx, const char *field_name, size_t field_offset, size_t num_values,
                                       const char *field_description) {
  const size_t value_size = sizeof(double);

  return CeedQFunctionContextRegisterGeneric(ctx, field_name, field_offset, field_description, CEED_CONTEXT_FIELD_DOUBLE, value_size, num_values);
}
/**
  @brief Register a field of num_values ints at field_offset (see CeedQFunctionContextRegisterGeneric).
**/
int CeedQFunctionContextRegisterInt32(CeedQFunctionContext ctx, const char *field_name, size_t field_offset, size_t num_values,
                                      const char *field_description) {
  const size_t value_size = sizeof(int);

  return CeedQFunctionContextRegisterGeneric(ctx, field_name, field_offset, field_description, CEED_CONTEXT_FIELD_INT32, value_size, num_values);
}
/**
  @brief Expose the context's field label array and its length (no copy is made).

  @param[in]  ctx          CeedQFunctionContext to query
  @param[out] field_labels Set to ctx->field_labels
  @param[out] num_fields   Set to ctx->num_fields

  @return CEED_ERROR_SUCCESS
**/
int CeedQFunctionContextGetAllFieldLabels(CeedQFunctionContext ctx, const CeedContextFieldLabel **field_labels, CeedInt *num_fields) {
  *num_fields   = ctx->num_fields;
  *field_labels = ctx->field_labels;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Read the attributes of a field label. Every output pointer may be NULL to skip it.

  @param[in]  label             Field label to query
  @param[out] field_name        Registered name
  @param[out] field_offset      Byte offset within the context data
  @param[out] num_values        Number of values in the field
  @param[out] field_description Registered description
  @param[out] field_type        Registered type

  @return CEED_ERROR_SUCCESS
**/
int CeedContextFieldLabelGetDescription(CeedContextFieldLabel label, const char **field_name, size_t *field_offset, size_t *num_values,
                                        const char **field_description, CeedContextFieldType *field_type) {
  if (field_name != NULL) {
    *field_name = label->name;
  }
  if (field_offset != NULL) {
    *field_offset = label->offset;
  }
  if (num_values != NULL) {
    *num_values = label->num_values;
  }
  if (field_description != NULL) {
    *field_description = label->description;
  }
  if (field_type != NULL) {
    *field_type = label->type;
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Report the context data size in bytes, as last given to CeedQFunctionContextSetData.

  @param[in]  ctx      CeedQFunctionContext to query
  @param[out] ctx_size Set to ctx->ctx_size

  @return CEED_ERROR_SUCCESS
**/
int CeedQFunctionContextGetContextSize(CeedQFunctionContext ctx, size_t *ctx_size) {
  const size_t bytes = ctx->ctx_size;

  *ctx_size = bytes;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Print a short summary of a CeedQFunctionContext: data size and registered field labels.

  @param[in] ctx    CeedQFunctionContext to view
  @param[in] stream Output stream

  @return CEED_ERROR_SUCCESS
**/
int CeedQFunctionContextView(CeedQFunctionContext ctx, FILE *stream) {
  fprintf(stream, "CeedQFunctionContext\n");
  // ctx_size is a size_t: %zu is the matching conversion (%ld is wrong where long is narrower than size_t)
  fprintf(stream, " Context Data Size: %zu\n", ctx->ctx_size);
  for (CeedInt i = 0; i < ctx->num_fields; i++) {
    fprintf(stream, " Labeled %s field: %s\n", CeedContextFieldTypes[ctx->field_labels[i]->type], ctx->field_labels[i]->name);
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Register a user callback used to destroy the context data.

  @param[in,out] ctx        CeedQFunctionContext to update
  @param[in]     f_mem_type Memory type in which the callback receives the data
  @param[in]     f          Callback; must not be NULL

  @return An error code: nonzero if f is NULL
**/
int CeedQFunctionContextSetDataDestroy(CeedQFunctionContext ctx, CeedMemType f_mem_type, CeedQFunctionContextDataDestroyUser f) {
  CeedCheck(f != NULL, ctx->ceed, 1, "Must provide valid callback function for destroying user data");

  ctx->data_destroy_function = f;
  ctx->data_destroy_mem_type = f_mem_type;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Drop one reference to a CeedQFunctionContext, freeing it when the count reaches zero.

  *ctx is set to NULL in every successful path. The final release fails if write access is still outstanding.

  @param[in,out] ctx Address of the context variable (a NULL context is accepted)

  @return An error code: 0 - success, otherwise failure
**/
int CeedQFunctionContextDestroy(CeedQFunctionContext *ctx) {
  CeedQFunctionContext target = *ctx;

  if (target == NULL) return CEED_ERROR_SUCCESS;
  if (--target->ref_count > 0) {
    // Other references remain; only clear the caller's handle
    *ctx = NULL;
    return CEED_ERROR_SUCCESS;
  }

  CeedCheck((target->state % 2) == 0, target->ceed, 1, "Cannot destroy CeedQFunctionContext, the access lock is in use");
  CeedCall(CeedQFunctionContextDestroyData(target));
  if (target->Destroy) {
    CeedCall(target->Destroy(target));
  }

  for (CeedInt i = 0; i < target->num_fields; i++) {
    CeedContextFieldLabel label = target->field_labels[i];

    CeedCall(CeedFree(&label->name));
    CeedCall(CeedFree(&label->description));
    CeedCall(CeedFree(&target->field_labels[i]));
  }
  CeedCall(CeedFree(&target->field_labels));
  CeedCall(CeedDestroy(&target->ceed));
  CeedCall(CeedFree(ctx));
  return CEED_ERROR_SUCCESS;
}