#include <ceed-impl.h>
#include <ceed.h>
#include <ceed/backend.h>
#include <stdbool.h>
#include <stdio.h>
#include <string.h>
/**
  @brief Check that the ElemRestriction and Basis supplied for an operator field agree with the QFunction field

  The checks run in a fixed order and the first failing one is reported:
  restriction/eval-mode pairing, basis/eval-mode pairing, restriction vs. basis component count,
  then the QFunction field size against the component count implied by the eval mode.

  @param[in] ceed     Ceed object used for error reporting
  @param[in] qf_field QFunction field providing the expected size and eval mode
  @param[in] r        ElemRestriction for the field, or CEED_ELEMRESTRICTION_NONE
  @param[in] b        Basis for the field, or CEED_BASIS_NONE

  @return CEED_ERROR_SUCCESS when every check passes, otherwise the error raised by the failing check
**/
static int CeedOperatorCheckField(Ceed ceed, CeedQFunctionField qf_field, CeedElemRestriction r, CeedBasis b) {
  const CeedInt      field_size = qf_field->size;
  const CeedEvalMode mode       = qf_field->eval_mode;
  const bool         no_rstr    = r == CEED_ELEMRESTRICTION_NONE;
  const bool         no_basis   = b == CEED_BASIS_NONE;
  CeedInt            basis_dim = 1, basis_num_comp = 1, basis_q_comp = 1, rstr_comp = 1;

  // Restriction: "none" is only valid for quadrature weights, and weights require "none"
  CeedCheck(no_rstr == (mode == CEED_EVAL_WEIGHT), ceed, CEED_ERROR_INCOMPATIBLE,
            "CEED_ELEMRESTRICTION_NONE and CEED_EVAL_WEIGHT must be used together.");
  if (!no_rstr) {
    CeedCall(CeedElemRestrictionGetNumComponents(r, &rstr_comp));
  }

  // Basis: "none" is only valid for CEED_EVAL_NONE, and CEED_EVAL_NONE requires "none"
  CeedCheck(no_basis == (mode == CEED_EVAL_NONE), ceed, CEED_ERROR_INCOMPATIBLE, "CEED_BASIS_NONE and CEED_EVAL_NONE must be used together.");
  if (!no_basis) {
    // The dimension is queried (and may raise an error) but is not otherwise used here
    CeedCall(CeedBasisGetDimension(b, &basis_dim));
    CeedCall(CeedBasisGetNumComponents(b, &basis_num_comp));
    CeedCall(CeedBasisGetNumQuadratureComponents(b, mode, &basis_q_comp));
    CeedCheck(no_rstr || rstr_comp == basis_num_comp, ceed, CEED_ERROR_DIMENSION,
              "Field '%s' of size %" CeedInt_FMT " and EvalMode %s: ElemRestriction has %" CeedInt_FMT " components, but Basis has %" CeedInt_FMT
              " components",
              qf_field->field_name, field_size, CeedEvalModes[mode], rstr_comp, basis_num_comp);
  }

  // Field size; nothing to verify for CEED_EVAL_WEIGHT
  if (mode == CEED_EVAL_NONE) {
    CeedCheck(field_size == rstr_comp, ceed, CEED_ERROR_DIMENSION,
              "Field '%s' of size %" CeedInt_FMT " and EvalMode %s: ElemRestriction has %" CeedInt_FMT " components", qf_field->field_name,
              field_size, CeedEvalModes[mode], rstr_comp);
  } else if (mode == CEED_EVAL_INTERP || mode == CEED_EVAL_GRAD || mode == CEED_EVAL_DIV || mode == CEED_EVAL_CURL) {
    const CeedInt expected_size = basis_num_comp * basis_q_comp;

    CeedCheck(field_size == expected_size, ceed, CEED_ERROR_DIMENSION,
              "Field '%s' of size %" CeedInt_FMT " and EvalMode %s: ElemRestriction/Basis has %" CeedInt_FMT " components", qf_field->field_name,
              field_size, CeedEvalModes[mode], expected_size);
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Print a description of one operator field to a stream

  @param[in] field        Operator field to describe (its basis and vector are inspected)
  @param[in] qf_field     Matching QFunction field (name, size, and eval mode are printed)
  @param[in] field_number Index printed in the field heading
  @param[in] sub          True to prefix each line with the sub-operator indent
  @param[in] input        True to label the field "Input", false for "Output"
  @param[in] stream       Stream to write to

  @return CEED_ERROR_SUCCESS
**/
static int CeedOperatorFieldView(CeedOperatorField field, CeedQFunctionField qf_field, CeedInt field_number, bool sub, bool input, FILE *stream) {
  const char *indent    = "";
  const char *direction = "Output";

  if (sub) indent = " ";
  if (input) direction = "Input";

  // Heading, then the QFunction field properties
  fprintf(stream, "%s %s field %" CeedInt_FMT ":\n", indent, direction, field_number);
  fprintf(stream, "%s Name: \"%s\"\n", indent, qf_field->field_name);
  fprintf(stream, "%s Size: %" CeedInt_FMT "\n", indent, qf_field->size);
  fprintf(stream, "%s EvalMode: %s\n", indent, CeedEvalModes[qf_field->eval_mode]);

  // Only the sentinel basis/vector values get an extra line
  if (field->basis == CEED_BASIS_NONE) {
    fprintf(stream, "%s No basis\n", indent);
  }
  if (field->vec == CEED_VECTOR_ACTIVE) {
    fprintf(stream, "%s Active vector\n", indent);
  } else if (field->vec == CEED_VECTOR_NONE) {
    fprintf(stream, "%s No vector\n", indent);
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Print a summary of a single (non-composite) operator and each of its fields

  @param[in] op     Operator to view; the element, quadrature point, and field count getters used here reject composite operators
  @param[in] sub    True to prefix each line with the sub-operator indent
  @param[in] stream Stream to write to

  @return CEED_ERROR_SUCCESS, or the error from a failing getter or field view
**/
int CeedOperatorSingleView(CeedOperator op, bool sub, FILE *stream) {
  const char *indent = sub ? " " : "";
  CeedInt     num_elem, num_qpts, num_set_fields = 0;

  CeedCall(CeedOperatorGetNumElements(op, &num_elem));
  CeedCall(CeedOperatorGetNumQuadraturePoints(op, &num_qpts));
  CeedCall(CeedOperatorGetNumArgs(op, &num_set_fields));

  fprintf(stream, "%s %" CeedInt_FMT " elements with %" CeedInt_FMT " quadrature points each\n", indent, num_elem, num_qpts);
  fprintf(stream, "%s %" CeedInt_FMT " field%s\n", indent, num_set_fields, num_set_fields > 1 ? "s" : "");

  {
    const CeedInt num_inputs = op->qf->num_input_fields;

    fprintf(stream, "%s %" CeedInt_FMT " input field%s:\n", indent, num_inputs, num_inputs > 1 ? "s" : "");
    for (CeedInt f = 0; f < num_inputs; f++) {
      CeedCall(CeedOperatorFieldView(op->input_fields[f], op->qf->input_fields[f], f, sub, true, stream));
    }
  }
  {
    const CeedInt num_outputs = op->qf->num_output_fields;

    fprintf(stream, "%s %" CeedInt_FMT " output field%s:\n", indent, num_outputs, num_outputs > 1 ? "s" : "");
    for (CeedInt f = 0; f < num_outputs; f++) {
      CeedCall(CeedOperatorFieldView(op->output_fields[f], op->qf->output_fields[f], f, sub, false, stream));
    }
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the active input Basis of an operator

  Forwards to CeedOperatorGetActiveBases() without requesting the output basis.

  @param[in]  op           Operator to query
  @param[out] active_basis Location that receives the active input basis

  @return CEED_ERROR_SUCCESS, or the error from CeedOperatorGetActiveBases()
**/
int CeedOperatorGetActiveBasis(CeedOperator op, CeedBasis *active_basis) {
  CeedBasis *const output_basis_not_requested = NULL;

  CeedCall(CeedOperatorGetActiveBases(op, active_basis, output_basis_not_requested));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the Bases attached to the active input and output fields of an operator

  For each requested side, the result is first reset to NULL.
  For a non-composite operator every field whose vector is CEED_VECTOR_ACTIVE is examined;
  all of them must share one basis, and at least one must exist.
  For a composite operator the result stays NULL.

  @param[in]  op                  Operator to query
  @param[out] active_input_basis  Receives the active input basis, or NULL to skip the input side
  @param[out] active_output_basis Receives the active output basis, or NULL to skip the output side

  @return CEED_ERROR_SUCCESS, CEED_ERROR_MINOR when active fields disagree, or CEED_ERROR_INCOMPLETE when no active field exists
**/
int CeedOperatorGetActiveBases(CeedOperator op, CeedBasis *active_input_basis, CeedBasis *active_output_basis) {
  Ceed ceed;

  CeedCall(CeedOperatorGetCeed(op, &ceed));

  // Input side
  if (active_input_basis) *active_input_basis = NULL;
  if (active_input_basis && !op->is_composite) {
    const CeedInt num_inputs = op->qf->num_input_fields;

    for (CeedInt f = 0; f < num_inputs; f++) {
      const CeedOperatorField field = op->input_fields[f];

      if (field->vec != CEED_VECTOR_ACTIVE) continue;
      CeedCheck(*active_input_basis == NULL || *active_input_basis == field->basis, ceed, CEED_ERROR_MINOR,
                "Multiple active input CeedBases found");
      *active_input_basis = field->basis;
    }
    CeedCheck(*active_input_basis != NULL, ceed, CEED_ERROR_INCOMPLETE, "No active input CeedBasis found");
  }

  // Output side
  if (active_output_basis) *active_output_basis = NULL;
  if (active_output_basis && !op->is_composite) {
    const CeedInt num_outputs = op->qf->num_output_fields;

    for (CeedInt f = 0; f < num_outputs; f++) {
      const CeedOperatorField field = op->output_fields[f];

      if (field->vec != CEED_VECTOR_ACTIVE) continue;
      CeedCheck(*active_output_basis == NULL || *active_output_basis == field->basis, ceed, CEED_ERROR_MINOR,
                "Multiple active output CeedBases found");
      *active_output_basis = field->basis;
    }
    CeedCheck(*active_output_basis != NULL, ceed, CEED_ERROR_INCOMPLETE, "No active output CeedBasis found");
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the active input ElemRestriction of an operator

  Forwards to CeedOperatorGetActiveElemRestrictions() without requesting the output restriction.

  @param[in]  op          Operator to query
  @param[out] active_rstr Location that receives the active input restriction

  @return CEED_ERROR_SUCCESS, or the error from CeedOperatorGetActiveElemRestrictions()
**/
int CeedOperatorGetActiveElemRestriction(CeedOperator op, CeedElemRestriction *active_rstr) {
  CeedElemRestriction *const output_rstr_not_requested = NULL;

  CeedCall(CeedOperatorGetActiveElemRestrictions(op, active_rstr, output_rstr_not_requested));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the ElemRestrictions attached to the active input and output fields of an operator

  For each requested side, the result is first reset to NULL.
  For a non-composite operator every field whose vector is CEED_VECTOR_ACTIVE is examined;
  all of them must share one restriction, and at least one must exist.
  For a composite operator the result stays NULL.

  @param[in]  op                 Operator to query
  @param[out] active_input_rstr  Receives the active input restriction, or NULL to skip the input side
  @param[out] active_output_rstr Receives the active output restriction, or NULL to skip the output side

  @return CEED_ERROR_SUCCESS, CEED_ERROR_MINOR when active fields disagree, or CEED_ERROR_INCOMPLETE when no active field exists
**/
int CeedOperatorGetActiveElemRestrictions(CeedOperator op, CeedElemRestriction *active_input_rstr, CeedElemRestriction *active_output_rstr) {
  Ceed ceed;

  CeedCall(CeedOperatorGetCeed(op, &ceed));

  // Input side
  if (active_input_rstr) *active_input_rstr = NULL;
  if (active_input_rstr && !op->is_composite) {
    const CeedInt num_inputs = op->qf->num_input_fields;

    for (CeedInt f = 0; f < num_inputs; f++) {
      const CeedOperatorField field = op->input_fields[f];

      if (field->vec != CEED_VECTOR_ACTIVE) continue;
      CeedCheck(*active_input_rstr == NULL || *active_input_rstr == field->elem_rstr, ceed, CEED_ERROR_MINOR,
                "Multiple active input CeedElemRestrictions found");
      *active_input_rstr = field->elem_rstr;
    }
    CeedCheck(*active_input_rstr != NULL, ceed, CEED_ERROR_INCOMPLETE, "No active input CeedElemRestriction found");
  }

  // Output side
  if (active_output_rstr) *active_output_rstr = NULL;
  if (active_output_rstr && !op->is_composite) {
    const CeedInt num_outputs = op->qf->num_output_fields;

    for (CeedInt f = 0; f < num_outputs; f++) {
      const CeedOperatorField field = op->output_fields[f];

      if (field->vec != CEED_VECTOR_ACTIVE) continue;
      CeedCheck(*active_output_rstr == NULL || *active_output_rstr == field->elem_rstr, ceed, CEED_ERROR_MINOR,
                "Multiple active output CeedElemRestrictions found");
      *active_output_rstr = field->elem_rstr;
    }
    CeedCheck(*active_output_rstr != NULL, ceed, CEED_ERROR_INCOMPLETE, "No active output CeedElemRestriction found");
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Set QFunctionContext field values through an operator

  A label flagged `from_op` must be one of the labels recorded on this operator.
  For a composite operator the values are forwarded to every sub-operator that has both a
  sub-label and a context; otherwise they are set on the operator's own QFunction context.
  On success the operator is flagged as needing a QFunction assembly data update.

  @param[in,out] op          Operator whose context(s) are updated
  @param[in]     field_label Label of the field to set
  @param[in]     field_type  Type of the field
  @param[in]     values      Values to set

  @return CEED_ERROR_SUCCESS, or an error code (CEED_ERROR_UNSUPPORTED for the checks in this function)
**/
static int CeedOperatorContextSetGeneric(CeedOperator op, CeedContextFieldLabel field_label, CeedContextFieldType field_type, void *values) {
  bool is_composite = false;

  CeedCheck(field_label, op->ceed, CEED_ERROR_UNSUPPORTED, "Invalid field label");

  // Labels handed out by an operator are only valid on that operator
  if (field_label->from_op) {
    bool is_registered = false;

    for (CeedInt l = 0; l < op->num_context_labels && !is_registered; l++) {
      is_registered = op->context_labels[l] == field_label;
    }
    CeedCheck(is_registered, op->ceed, CEED_ERROR_UNSUPPORTED, "ContextFieldLabel does not correspond to the operator");
  }

  CeedCall(CeedOperatorIsComposite(op, &is_composite));
  if (!is_composite) {
    CeedCheck(op->qf->ctx, op->ceed, CEED_ERROR_UNSUPPORTED, "QFunction does not have context data");
    CeedCall(CeedQFunctionContextSetGeneric(op->qf->ctx, field_label, field_type, values));
  } else {
    CeedInt       num_sub;
    CeedOperator *sub_operators;

    CeedCall(CeedCompositeOperatorGetNumSub(op, &num_sub));
    CeedCall(CeedCompositeOperatorGetSubList(op, &sub_operators));
    CeedCheck(num_sub == field_label->num_sub_labels, op->ceed, CEED_ERROR_UNSUPPORTED,
              "Composite operator modified after ContextFieldLabel created");

    for (CeedInt s = 0; s < num_sub; s++) {
      // Skip sub-operators without this field or without a context
      if (!field_label->sub_labels[s] || !sub_operators[s]->qf->ctx) continue;
      CeedCall(CeedQFunctionContextSetGeneric(sub_operators[s]->qf->ctx, field_label->sub_labels[s], field_type, values));
    }
  }

  CeedCall(CeedOperatorSetQFunctionAssemblyDataUpdateNeeded(op, true));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get read access to QFunctionContext field values through an operator

  The outputs are reset (`*values` to NULL, `*num_values` to 0) before any lookup.
  A label flagged `from_op` must be one of the labels recorded on this operator.
  For a composite operator the values come from the first sub-operator that has both a
  sub-label and a context; if none does, the outputs keep their reset values.

  @param[in]  op          Operator to query
  @param[in]  field_label Label of the field to read
  @param[in]  field_type  Type of the field
  @param[out] num_values  Receives the number of values
  @param[out] values      Address of a pointer that receives the values

  @return CEED_ERROR_SUCCESS, or an error code (CEED_ERROR_UNSUPPORTED for the checks in this function)
**/
static int CeedOperatorContextGetGenericRead(CeedOperator op, CeedContextFieldLabel field_label, CeedContextFieldType field_type, size_t *num_values,
                                             void *values) {
  bool is_composite = false;

  CeedCheck(field_label, op->ceed, CEED_ERROR_UNSUPPORTED, "Invalid field label");

  // Default result when nothing is found
  *(void **)values = NULL;
  *num_values      = 0;

  // Labels handed out by an operator are only valid on that operator
  if (field_label->from_op) {
    bool is_registered = false;

    for (CeedInt l = 0; l < op->num_context_labels && !is_registered; l++) {
      is_registered = op->context_labels[l] == field_label;
    }
    CeedCheck(is_registered, op->ceed, CEED_ERROR_UNSUPPORTED, "ContextFieldLabel does not correspond to the operator");
  }

  CeedCall(CeedOperatorIsComposite(op, &is_composite));
  if (!is_composite) {
    CeedCheck(op->qf->ctx, op->ceed, CEED_ERROR_UNSUPPORTED, "QFunction does not have context data");
    CeedCall(CeedQFunctionContextGetGenericRead(op->qf->ctx, field_label, field_type, num_values, values));
    return CEED_ERROR_SUCCESS;
  }

  {
    CeedInt       num_sub;
    CeedOperator *sub_operators;

    CeedCall(CeedCompositeOperatorGetNumSub(op, &num_sub));
    CeedCall(CeedCompositeOperatorGetSubList(op, &sub_operators));
    CeedCheck(num_sub == field_label->num_sub_labels, op->ceed, CEED_ERROR_UNSUPPORTED,
              "Composite operator modified after ContextFieldLabel created");

    for (CeedInt s = 0; s < num_sub; s++) {
      if (!field_label->sub_labels[s] || !sub_operators[s]->qf->ctx) continue;
      // First sub-operator providing the field supplies the values
      CeedCall(CeedQFunctionContextGetGenericRead(sub_operators[s]->qf->ctx, field_label->sub_labels[s], field_type, num_values, values));
      return CEED_ERROR_SUCCESS;
    }
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Restore QFunctionContext field values previously obtained for reading through an operator

  A label flagged `from_op` must be one of the labels recorded on this operator.
  For a composite operator the restore is applied to the first sub-operator that has both a
  sub-label and a context, mirroring the lookup order of the matching read.

  @param[in]  op          Operator to query
  @param[in]  field_label Label of the field being restored
  @param[in]  field_type  Type of the field
  @param[out] values      Address of the pointer being restored

  @return CEED_ERROR_SUCCESS, or an error code (CEED_ERROR_UNSUPPORTED for the checks in this function)
**/
static int CeedOperatorContextRestoreGenericRead(CeedOperator op, CeedContextFieldLabel field_label, CeedContextFieldType field_type, void *values) {
  bool is_composite = false;

  CeedCheck(field_label, op->ceed, CEED_ERROR_UNSUPPORTED, "Invalid field label");

  // Labels handed out by an operator are only valid on that operator
  if (field_label->from_op) {
    bool is_registered = false;

    for (CeedInt l = 0; l < op->num_context_labels && !is_registered; l++) {
      is_registered = op->context_labels[l] == field_label;
    }
    CeedCheck(is_registered, op->ceed, CEED_ERROR_UNSUPPORTED, "ContextFieldLabel does not correspond to the operator");
  }

  CeedCall(CeedOperatorIsComposite(op, &is_composite));
  if (!is_composite) {
    CeedCheck(op->qf->ctx, op->ceed, CEED_ERROR_UNSUPPORTED, "QFunction does not have context data");
    CeedCall(CeedQFunctionContextRestoreGenericRead(op->qf->ctx, field_label, field_type, values));
    return CEED_ERROR_SUCCESS;
  }

  {
    CeedInt       num_sub;
    CeedOperator *sub_operators;

    CeedCall(CeedCompositeOperatorGetNumSub(op, &num_sub));
    CeedCall(CeedCompositeOperatorGetSubList(op, &sub_operators));
    CeedCheck(num_sub == field_label->num_sub_labels, op->ceed, CEED_ERROR_UNSUPPORTED,
              "Composite operator modified after ContextFieldLabel created");

    for (CeedInt s = 0; s < num_sub; s++) {
      if (!field_label->sub_labels[s] || !sub_operators[s]->qf->ctx) continue;
      CeedCall(CeedQFunctionContextRestoreGenericRead(sub_operators[s]->qf->ctx, field_label->sub_labels[s], field_type, values));
      return CEED_ERROR_SUCCESS;
    }
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the number of fields that have been set on a non-composite operator

  @param[in]  op       Operator to query; must not be composite
  @param[out] num_args Receives `op->num_fields`

  @return CEED_ERROR_SUCCESS, or CEED_ERROR_MINOR for a composite operator
**/
int CeedOperatorGetNumArgs(CeedOperator op, CeedInt *num_args) {
  const bool is_single = !op->is_composite;

  CeedCheck(is_single, op->ceed, CEED_ERROR_MINOR, "Not defined for composite operators");
  *num_args = op->num_fields;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Report whether the backend setup flag is set on an operator

  @param[in]  op            Operator to query
  @param[out] is_setup_done Receives the value of `op->is_backend_setup`

  @return CEED_ERROR_SUCCESS
**/
int CeedOperatorIsSetupDone(CeedOperator op, bool *is_setup_done) {
  const bool backend_ready = op->is_backend_setup;

  *is_setup_done = backend_ready;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the QFunction of a non-composite operator

  The stored handle is returned as-is; this function does not touch any reference count.

  @param[in]  op Operator to query; must not be composite
  @param[out] qf Receives `op->qf`

  @return CEED_ERROR_SUCCESS, or CEED_ERROR_MINOR for a composite operator
**/
int CeedOperatorGetQFunction(CeedOperator op, CeedQFunction *qf) {
  const bool is_single = !op->is_composite;

  CeedCheck(is_single, op->ceed, CEED_ERROR_MINOR, "Not defined for composite operator");
  *qf = op->qf;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Report whether an operator is a composite operator

  @param[in]  op           Operator to query
  @param[out] is_composite Receives the value of `op->is_composite`

  @return CEED_ERROR_SUCCESS
**/
int CeedOperatorIsComposite(CeedOperator op, bool *is_composite) {
  const bool composite_flag = op->is_composite;

  *is_composite = composite_flag;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the data pointer stored on an operator

  @param[in]  op   Operator to query
  @param[out] data Address of a pointer (written through as `void **`) that receives `op->data`

  @return CEED_ERROR_SUCCESS
**/
int CeedOperatorGetData(CeedOperator op, void *data) {
  void **data_out = (void **)data;

  *data_out = op->data;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Store a data pointer on an operator

  The previous value of `op->data` is overwritten without being freed.

  @param[in,out] op   Operator to update
  @param[in]     data Pointer to store in `op->data`

  @return CEED_ERROR_SUCCESS
**/
int CeedOperatorSetData(CeedOperator op, void *data) {
  void *const new_data = data;

  op->data = new_data;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Increment the reference counter of an operator

  @param[in,out] op Operator whose `ref_count` is increased by one

  @return CEED_ERROR_SUCCESS
**/
int CeedOperatorReference(CeedOperator op) {
  op->ref_count += 1;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Mark the backend setup of an operator as done

  @param[in,out] op Operator whose `is_backend_setup` flag is set to true

  @return CEED_ERROR_SUCCESS
**/
int CeedOperatorSetSetupDone(CeedOperator op) {
  const bool setup_done = true;

  op->is_backend_setup = setup_done;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Create an operator from a QFunction and optional derivative QFunctions

  When the Ceed has no OperatorCreate implementation, creation is forwarded to its "Operator" delegate.
  Otherwise a zeroed operator is allocated, references to the Ceed and the QFunction(s) are taken,
  the field arrays are allocated with CEED_FIELD_MAX entries, and the backend OperatorCreate is called.

  @param[in]  ceed Ceed object used to create the operator
  @param[in]  qf   QFunction for the operator; must be non-NULL and not CEED_QFUNCTION_NONE
  @param[in]  dqf  Optional second QFunction stored as `dqf`; NULL or CEED_QFUNCTION_NONE to skip
  @param[in]  dqfT Optional third QFunction stored as `dqfT`; NULL or CEED_QFUNCTION_NONE to skip
  @param[out] op   Receives the new operator

  @return CEED_ERROR_SUCCESS, CEED_ERROR_UNSUPPORTED when no backend or delegate can create operators,
          CEED_ERROR_MINOR for an invalid QFunction, or the error from a failing call
**/
int CeedOperatorCreate(Ceed ceed, CeedQFunction qf, CeedQFunction dqf, CeedQFunction dqfT, CeedOperator *op) {
  // No native implementation: hand the request to the delegate
  if (ceed->OperatorCreate == NULL) {
    Ceed delegate;

    CeedCall(CeedGetObjectDelegate(ceed, &delegate, "Operator"));
    CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support OperatorCreate");
    CeedCall(CeedOperatorCreate(delegate, qf, dqf, dqfT, op));
    return CEED_ERROR_SUCCESS;
  }

  CeedCheck(qf && qf != CEED_QFUNCTION_NONE, ceed, CEED_ERROR_MINOR, "Operator must have a valid QFunction.");

  CeedCall(CeedCalloc(1, op));
  {
    const CeedOperator created = *op;

    CeedCall(CeedReferenceCopy(ceed, &created->ceed));
    created->ref_count = 1;
    // -1 marks the active vector lengths as not yet known
    created->input_size  = -1;
    created->output_size = -1;

    CeedCall(CeedQFunctionReferenceCopy(qf, &created->qf));
    if (dqf && dqf != CEED_QFUNCTION_NONE) {
      CeedCall(CeedQFunctionReferenceCopy(dqf, &created->dqf));
    }
    if (dqfT && dqfT != CEED_QFUNCTION_NONE) {
      CeedCall(CeedQFunctionReferenceCopy(dqfT, &created->dqfT));
    }
    CeedCall(CeedQFunctionAssemblyDataCreate(ceed, &created->qf_assembled));
    CeedCall(CeedCalloc(CEED_FIELD_MAX, &created->input_fields));
    CeedCall(CeedCalloc(CEED_FIELD_MAX, &created->output_fields));
    CeedCall(ceed->OperatorCreate(created));
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Create a composite operator, i.e. a container for sub-operators

  When the Ceed has no CompositeOperatorCreate implementation and an "Operator" delegate exists,
  creation is forwarded to the delegate. Otherwise the operator is allocated here, with room for
  CEED_COMPOSITE_MAX sub-operators, and the backend hook is called only if it is present.

  @param[in]  ceed Ceed object used to create the operator
  @param[out] op   Receives the new composite operator

  @return CEED_ERROR_SUCCESS, or the error from a failing call
**/
int CeedCompositeOperatorCreate(Ceed ceed, CeedOperator *op) {
  const bool has_backend_impl = ceed->CompositeOperatorCreate != NULL;

  if (!has_backend_impl) {
    Ceed delegate;

    CeedCall(CeedGetObjectDelegate(ceed, &delegate, "Operator"));
    if (delegate) {
      CeedCall(CeedCompositeOperatorCreate(delegate, op));
      return CEED_ERROR_SUCCESS;
    }
  }

  CeedCall(CeedCalloc(1, op));
  {
    const CeedOperator created = *op;

    CeedCall(CeedReferenceCopy(ceed, &created->ceed));
    created->ref_count    = 1;
    created->is_composite = true;
    CeedCall(CeedCalloc(CEED_COMPOSITE_MAX, &created->sub_operators));
    // -1 marks the active vector lengths as not yet known
    created->input_size  = -1;
    created->output_size = -1;

    if (has_backend_impl) {
      CeedCall(ceed->CompositeOperatorCreate(created));
    }
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Make `*op_copy` refer to `op`, sharing ownership

  The reference count of `op` is incremented first, then whatever `*op_copy` currently holds is
  passed to CeedOperatorDestroy(), and finally `*op_copy` is pointed at `op`.

  @param[in]     op      Operator to share
  @param[in,out] op_copy Handle that is destroyed and then replaced with `op`

  @return CEED_ERROR_SUCCESS, or the error from a failing call
**/
int CeedOperatorReferenceCopy(CeedOperator op, CeedOperator *op_copy) {
  // Take the new reference before releasing the old handle
  CeedCall(CeedOperatorReference(op));
  CeedCall(CeedOperatorDestroy(op_copy));

  *op_copy = op;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Attach an ElemRestriction, Basis, and Vector to a named field of a non-composite operator

  The field is looked up by name among the QFunction's input fields first, then its output fields.
  The restriction and basis are validated against the QFunction field, the element count and
  quadrature point count are checked against values recorded from earlier fields, and for
  CEED_VECTOR_ACTIVE the L-vector size is checked against (or recorded as) the operator's
  active input/output size.

  A field may be set only once: setting it again would overwrite the earlier field allocation
  and count the field twice in `op->num_fields`, so this is reported as an error.

  @param[in,out] op         Operator to add the field to; must be neither composite nor immutable
  @param[in]     field_name Name of the QFunction field to set
  @param[in]     r          ElemRestriction, or CEED_ELEMRESTRICTION_NONE; must be non-NULL
  @param[in]     b          Basis, or CEED_BASIS_NONE; must be non-NULL
  @param[in]     v          Vector, CEED_VECTOR_ACTIVE, or CEED_VECTOR_NONE; must be non-NULL

  @return CEED_ERROR_SUCCESS, or an error code describing the first failed check
**/
int CeedOperatorSetField(CeedOperator op, const char *field_name, CeedElemRestriction r, CeedBasis b, CeedVector v) {
  bool               is_input = true;
  CeedInt            num_elem = 0, num_qpts = 0;
  CeedQFunctionField qf_field;
  CeedOperatorField *op_field;

  CeedCheck(!op->is_composite, op->ceed, CEED_ERROR_INCOMPATIBLE, "Cannot add field to composite operator.");
  CeedCheck(!op->is_immutable, op->ceed, CEED_ERROR_MAJOR, "Operator cannot be changed after set as immutable");
  CeedCheck(r, op->ceed, CEED_ERROR_INCOMPATIBLE, "ElemRestriction r for field \"%s\" must be non-NULL.", field_name);
  CeedCheck(b, op->ceed, CEED_ERROR_INCOMPATIBLE, "Basis b for field \"%s\" must be non-NULL.", field_name);
  CeedCheck(v, op->ceed, CEED_ERROR_INCOMPATIBLE, "Vector v for field \"%s\" must be non-NULL.", field_name);

  // Every real restriction on the operator must cover the same number of elements
  CeedCall(CeedElemRestrictionGetNumElements(r, &num_elem));
  CeedCheck(r == CEED_ELEMRESTRICTION_NONE || !op->has_restriction || op->num_elem == num_elem, op->ceed, CEED_ERROR_DIMENSION,
            "ElemRestriction with %" CeedInt_FMT " elements incompatible with prior %" CeedInt_FMT " elements", num_elem, op->num_elem);
  {
    CeedRestrictionType rstr_type;

    CeedCall(CeedElemRestrictionGetType(r, &rstr_type));
    CeedCheck(rstr_type != CEED_RESTRICTION_POINTS, op->ceed, CEED_ERROR_UNSUPPORTED,
              "CeedElemRestrictionAtPoints not supported for standard operator fields");
  }

  // Without a basis, the restriction's element size stands in for the quadrature point count
  if (b == CEED_BASIS_NONE) CeedCall(CeedElemRestrictionGetElementSize(r, &num_qpts));
  else CeedCall(CeedBasisGetNumQuadraturePoints(b, &num_qpts));
  CeedCheck(op->num_qpts == 0 || op->num_qpts == num_qpts, op->ceed, CEED_ERROR_DIMENSION,
            "%s must correspond to the same number of quadrature points as previously added Bases. Found %" CeedInt_FMT
            " quadrature points but expected %" CeedInt_FMT " quadrature points.",
            b == CEED_BASIS_NONE ? "ElemRestriction" : "Basis", num_qpts, op->num_qpts);

  // Locate the QFunction field: inputs first, then outputs
  for (CeedInt i = 0; i < op->qf->num_input_fields; i++) {
    if (!strcmp(field_name, (*op->qf->input_fields[i]).field_name)) {
      qf_field = op->qf->input_fields[i];
      op_field = &op->input_fields[i];
      goto found;
    }
  }
  is_input = false;
  for (CeedInt i = 0; i < op->qf->num_output_fields; i++) {
    if (!strcmp(field_name, (*op->qf->output_fields[i]).field_name)) {
      qf_field = op->qf->output_fields[i];
      op_field = &op->output_fields[i];
      goto found;
    }
  }
  return CeedError(op->ceed, CEED_ERROR_INCOMPLETE, "QFunction has no knowledge of field '%s'", field_name);

found:
  // Refuse to overwrite a field: the old allocation would leak and num_fields would be over-counted
  CeedCheck(!*op_field, op->ceed, CEED_ERROR_INCOMPATIBLE, "Field \"%s\" has already been set on this operator", field_name);
  CeedCall(CeedOperatorCheckField(op->ceed, qf_field, r, b));
  CeedCall(CeedCalloc(1, op_field));

  // Active vectors fix the operator's input/output L-vector size (-1 means not yet recorded)
  if (v == CEED_VECTOR_ACTIVE) {
    CeedSize l_size;

    CeedCall(CeedElemRestrictionGetLVectorSize(r, &l_size));
    if (is_input) {
      if (op->input_size == -1) op->input_size = l_size;
      CeedCheck(l_size == op->input_size, op->ceed, CEED_ERROR_INCOMPATIBLE, "LVector size %td does not match previous size %td", l_size,
                op->input_size);
    } else {
      if (op->output_size == -1) op->output_size = l_size;
      CeedCheck(l_size == op->output_size, op->ceed, CEED_ERROR_INCOMPATIBLE, "LVector size %td does not match previous size %td", l_size,
                op->output_size);
    }
  }

  CeedCall(CeedVectorReferenceCopy(v, &(*op_field)->vec));
  CeedCall(CeedElemRestrictionReferenceCopy(r, &(*op_field)->elem_rstr));
  if (r != CEED_ELEMRESTRICTION_NONE && !op->has_restriction) {
    // First real restriction defines the element count for the operator
    op->num_elem        = num_elem;
    op->has_restriction = true;
  }
  CeedCall(CeedBasisReferenceCopy(b, &(*op_field)->basis));
  if (op->num_qpts == 0) op->num_qpts = num_qpts;
  op->num_fields += 1;
  CeedCall(CeedStringAllocCopy(field_name, (char **)&(*op_field)->field_name));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the field arrays and field counts of a non-composite operator

  CeedOperatorCheckReady() is run first, so an incompletely configured operator is reported as an error.
  Each output argument is optional and is skipped when NULL.

  @param[in]  op                Operator to query; must not be composite
  @param[out] num_input_fields  Receives the QFunction's number of input fields, or NULL
  @param[out] input_fields      Receives the operator's input field array, or NULL
  @param[out] num_output_fields Receives the QFunction's number of output fields, or NULL
  @param[out] output_fields     Receives the operator's output field array, or NULL

  @return CEED_ERROR_SUCCESS, CEED_ERROR_MINOR for a composite operator, or the error from CeedOperatorCheckReady()
**/
int CeedOperatorGetFields(CeedOperator op, CeedInt *num_input_fields, CeedOperatorField **input_fields, CeedInt *num_output_fields,
                          CeedOperatorField **output_fields) {
  CeedCheck(!op->is_composite, op->ceed, CEED_ERROR_MINOR, "Not defined for composite operator");
  CeedCall(CeedOperatorCheckReady(op));

  if (num_input_fields != NULL) {
    *num_input_fields = op->qf->num_input_fields;
  }
  if (input_fields != NULL) {
    *input_fields = op->input_fields;
  }
  if (num_output_fields != NULL) {
    *num_output_fields = op->qf->num_output_fields;
  }
  if (output_fields != NULL) {
    *output_fields = op->output_fields;
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Find an operator field by name

  Input fields are searched before output fields; the first field whose name matches is returned.

  @param[in]  op         Operator to search (must satisfy the requirements of CeedOperatorGetFields())
  @param[in]  field_name Name to look for
  @param[out] op_field   Receives the matching field

  @return CEED_ERROR_SUCCESS when found, CEED_ERROR_MINOR when no field has that name, or the error from a failing call
**/
int CeedOperatorGetFieldByName(CeedOperator op, const char *field_name, CeedOperatorField *op_field) {
  CeedInt            num_input_fields, num_output_fields;
  CeedOperatorField *input_fields, *output_fields;

  CeedCall(CeedOperatorGetFields(op, &num_input_fields, &input_fields, &num_output_fields, &output_fields));

  {
    // Pass 0 scans the inputs, pass 1 the outputs
    CeedOperatorField *const field_lists[2]  = {input_fields, output_fields};
    const CeedInt            field_counts[2] = {num_input_fields, num_output_fields};

    for (int pass = 0; pass < 2; pass++) {
      for (CeedInt f = 0; f < field_counts[pass]; f++) {
        char *candidate_name;

        CeedCall(CeedOperatorFieldGetName(field_lists[pass][f], &candidate_name));
        if (strcmp(candidate_name, field_name) == 0) {
          *op_field = field_lists[pass][f];
          return CEED_ERROR_SUCCESS;
        }
      }
    }
  }

  {
    // Quote the operator name in the message only when one is set
    const bool  has_name    = op->name;
    const char *open_quote  = has_name ? " \"" : "";
    const char *shown_name  = has_name ? op->name : "";
    const char *close_quote = has_name ? "\"" : "";

    return CeedError(op->ceed, CEED_ERROR_MINOR, "The field \"%s\" not found in CeedOperator%s%s%s.\n", field_name, open_quote, shown_name,
                     close_quote);
  }
}
/**
  @brief Get the name of an operator field

  The returned pointer refers to the string stored on the field; no copy is made.

  @param[in]  op_field   Operator field to query
  @param[out] field_name Receives the field's name pointer

  @return CEED_ERROR_SUCCESS
**/
int CeedOperatorFieldGetName(CeedOperatorField op_field, char **field_name) {
  char *const stored_name = (char *)op_field->field_name;

  *field_name = stored_name;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the ElemRestriction of an operator field

  The stored handle is returned as-is; this function does not touch any reference count.

  @param[in]  op_field Operator field to query
  @param[out] rstr     Receives `op_field->elem_rstr`

  @return CEED_ERROR_SUCCESS
**/
int CeedOperatorFieldGetElemRestriction(CeedOperatorField op_field, CeedElemRestriction *rstr) {
  const CeedElemRestriction stored_rstr = op_field->elem_rstr;

  *rstr = stored_rstr;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the Basis of an operator field

  The stored handle is returned as-is; this function does not touch any reference count.

  @param[in]  op_field Operator field to query
  @param[out] basis    Receives `op_field->basis`

  @return CEED_ERROR_SUCCESS
**/
int CeedOperatorFieldGetBasis(CeedOperatorField op_field, CeedBasis *basis) {
  const CeedBasis stored_basis = op_field->basis;

  *basis = stored_basis;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the Vector of an operator field

  The stored handle is returned as-is; this function does not touch any reference count.

  @param[in]  op_field Operator field to query
  @param[out] vec      Receives `op_field->vec`

  @return CEED_ERROR_SUCCESS
**/
int CeedOperatorFieldGetVector(CeedOperatorField op_field, CeedVector *vec) {
  const CeedVector stored_vec = op_field->vec;

  *vec = stored_vec;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Append a sub-operator to a composite operator

  The sub-operator's active vector lengths must agree with those already recorded on the composite
  operator; a length of -1 on either side is treated as "not set" and never conflicts.
  On success the composite operator holds a new reference to the sub-operator.

  @param[in,out] composite_op Composite operator to extend; must be composite, mutable, and below CEED_COMPOSITE_MAX entries
  @param[in]     sub_op       Sub-operator to add

  @return CEED_ERROR_SUCCESS, or an error code describing the first failed check
**/
int CeedCompositeOperatorAddSub(CeedOperator composite_op, CeedOperator sub_op) {
  CeedSize sub_in_len, sub_out_len;

  CeedCheck(composite_op->is_composite, composite_op->ceed, CEED_ERROR_MINOR, "CeedOperator is not a composite operator");
  CeedCheck(composite_op->num_suboperators < CEED_COMPOSITE_MAX, composite_op->ceed, CEED_ERROR_UNSUPPORTED, "Cannot add additional sub-operators");
  CeedCheck(!composite_op->is_immutable, composite_op->ceed, CEED_ERROR_MAJOR, "Operator cannot be changed after set as immutable");

  // Adopt the sub-operator's sizes if the composite has none yet, then require agreement
  CeedCall(CeedOperatorGetActiveVectorLengths(sub_op, &sub_in_len, &sub_out_len));
  if (composite_op->input_size == -1) composite_op->input_size = sub_in_len;
  if (composite_op->output_size == -1) composite_op->output_size = sub_out_len;
  {
    const bool input_ok  = sub_in_len == -1 || sub_in_len == composite_op->input_size;
    const bool output_ok = sub_out_len == -1 || sub_out_len == composite_op->output_size;

    CeedCheck(input_ok && output_ok, composite_op->ceed, CEED_ERROR_MAJOR,
              "Sub-operators must have compatible dimensions; composite operator of shape (%td, %td) not compatible with sub-operator of "
              "shape (%td, %td)",
              composite_op->input_size, composite_op->output_size, sub_in_len, sub_out_len);
  }

  // Store the handle, take a reference, and only then count it
  composite_op->sub_operators[composite_op->num_suboperators] = sub_op;
  CeedCall(CeedOperatorReference(sub_op));
  composite_op->num_suboperators += 1;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the number of sub-operators of a composite operator

  @param[in]  op               Operator to query; must be composite
  @param[out] num_suboperators Receives `op->num_suboperators`

  @return CEED_ERROR_SUCCESS, or CEED_ERROR_MINOR for a non-composite operator
**/
int CeedCompositeOperatorGetNumSub(CeedOperator op, CeedInt *num_suboperators) {
  const bool is_composite = op->is_composite;

  CeedCheck(is_composite, op->ceed, CEED_ERROR_MINOR, "Only defined for a composite operator");
  *num_suboperators = op->num_suboperators;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the sub-operator array of a composite operator

  The returned pointer is the operator's own array, not a copy.

  @param[in]  op            Operator to query; must be composite
  @param[out] sub_operators Receives `op->sub_operators`

  @return CEED_ERROR_SUCCESS, or CEED_ERROR_MINOR for a non-composite operator
**/
int CeedCompositeOperatorGetSubList(CeedOperator op, CeedOperator **sub_operators) {
  const bool is_composite = op->is_composite;

  CeedCheck(is_composite, op->ceed, CEED_ERROR_MINOR, "Only defined for a composite operator");
  *sub_operators = op->sub_operators;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Verify that an operator is fully configured and mark its interface setup as done

  Returns immediately when the operator was already marked as set up.
  A composite operator with no sub-operators gets active sizes of 0; otherwise every sub-operator
  is checked recursively and the composite's active vector lengths are evaluated.
  A non-composite operator must have all QFunction fields set, at least one restriction, and a
  positive number of quadrature points.
  On success, the operator's QFunction(s) are flagged immutable.

  @param[in,out] op Operator to check

  @return CEED_ERROR_SUCCESS, CEED_ERROR_INCOMPLETE for a missing piece of configuration, or the error from a failing call
**/
int CeedOperatorCheckReady(CeedOperator op) {
  Ceed ceed;

  CeedCall(CeedOperatorGetCeed(op, &ceed));
  if (!op->is_interface_setup) {
    if (!op->is_composite) {
      const CeedQFunction qf = op->qf;

      CeedCheck(op->num_fields > 0, ceed, CEED_ERROR_INCOMPLETE, "No operator fields set");
      CeedCheck(op->num_fields == qf->num_input_fields + qf->num_output_fields, ceed, CEED_ERROR_INCOMPLETE, "Not all operator fields set");
      CeedCheck(op->has_restriction, ceed, CEED_ERROR_INCOMPLETE, "At least one restriction required");
      CeedCheck(op->num_qpts > 0, ceed, CEED_ERROR_INCOMPLETE,
                "At least one non-collocated basis is required or the number of quadrature points must be set");
    } else if (op->num_suboperators == 0) {
      // Empty composite operator
      op->input_size  = 0;
      op->output_size = 0;
    } else {
      CeedSize in_len, out_len;

      for (CeedInt s = 0; s < op->num_suboperators; s++) {
        CeedCall(CeedOperatorCheckReady(op->sub_operators[s]));
      }
      // Called for its consistency check and size bookkeeping; the returned lengths are not used
      CeedCall(CeedOperatorGetActiveVectorLengths(op, &in_len, &out_len));
    }

    op->is_interface_setup = true;
    // QFunctions attached to a ready operator may no longer be modified
    if (op->qf && op->qf != CEED_QFUNCTION_NONE) {
      op->qf->is_immutable = true;
    }
    if (op->dqf && op->dqf != CEED_QFUNCTION_NONE) {
      op->dqf->is_immutable = true;
    }
    if (op->dqfT && op->dqfT != CEED_QFUNCTION_NONE) {
      op->dqfT->is_immutable = true;
    }
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the active input and output vector lengths recorded on an operator

  The requested outputs are filled from `op->input_size` / `op->output_size` as they are on entry.
  For a composite operator with a size still at -1, the sub-operators are then visited to record
  the sizes on the operator and to verify that all sub-operators agree; sizes resolved during this
  pass are stored on `op` but are not written back to the output arguments of the same call.

  @param[in,out] op          Operator to query
  @param[out]    input_size  Receives the recorded active input length, or NULL to skip
  @param[out]    output_size Receives the recorded active output length, or NULL to skip

  @return CEED_ERROR_SUCCESS, CEED_ERROR_MAJOR when sub-operator sizes conflict, or the error from a failing call
**/
int CeedOperatorGetActiveVectorLengths(CeedOperator op, CeedSize *input_size, CeedSize *output_size) {
  bool is_composite;

  if (input_size) *input_size = op->input_size;
  if (output_size) *output_size = op->output_size;

  CeedCall(CeedOperatorIsComposite(op, &is_composite));
  if (is_composite && (op->input_size == -1 || op->output_size == -1)) {
    for (CeedInt i = 0; i < op->num_suboperators; i++) {
      CeedSize sub_input_size, sub_output_size;

      CeedCall(CeedOperatorGetActiveVectorLengths(op->sub_operators[i], &sub_input_size, &sub_output_size));
      if (op->input_size == -1) op->input_size = sub_input_size;
      if (op->output_size == -1) op->output_size = sub_output_size;
      // A sub-operator size of -1 means no active vector on that side, so it cannot conflict.
      // The message must print the sub-operator sizes, not the caller's output pointers.
      CeedCheck((sub_input_size == -1 || sub_input_size == op->input_size) && (sub_output_size == -1 || sub_output_size == op->output_size), op->ceed,
                CEED_ERROR_MAJOR,
                "Sub-operators must have compatible dimensions; composite operator of shape (%td, %td) not compatible with sub-operator of "
                "shape (%td, %td)",
                op->input_size, op->output_size, sub_input_size, sub_output_size);
    }
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Set the reuse flag on an operator's QFunction assembly data

  For a composite operator the setting is applied recursively to every sub-operator.

  @param[in,out] op                  Operator to update
  @param[in]     reuse_assembly_data Value forwarded to CeedQFunctionAssemblyDataSetReuse()

  @return CEED_ERROR_SUCCESS, or the error from a failing call
**/
int CeedOperatorSetQFunctionAssemblyReuse(CeedOperator op, bool reuse_assembly_data) {
  bool is_composite;

  CeedCall(CeedOperatorIsComposite(op, &is_composite));
  if (!is_composite) {
    CeedCall(CeedQFunctionAssemblyDataSetReuse(op->qf_assembled, reuse_assembly_data));
    return CEED_ERROR_SUCCESS;
  }

  for (CeedInt s = 0; s < op->num_suboperators; s++) {
    CeedCall(CeedOperatorSetQFunctionAssemblyReuse(op->sub_operators[s], reuse_assembly_data));
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Set the update-needed flag on an operator's QFunction assembly data

  For a composite operator the setting is applied recursively to every sub-operator.

  @param[in,out] op                Operator to update
  @param[in]     needs_data_update Value forwarded to CeedQFunctionAssemblyDataSetUpdateNeeded()

  @return CEED_ERROR_SUCCESS, or the error from a failing call
**/
int CeedOperatorSetQFunctionAssemblyDataUpdateNeeded(CeedOperator op, bool needs_data_update) {
  bool is_composite;

  CeedCall(CeedOperatorIsComposite(op, &is_composite));
  if (!is_composite) {
    CeedCall(CeedQFunctionAssemblyDataSetUpdateNeeded(op->qf_assembled, needs_data_update));
    return CEED_ERROR_SUCCESS;
  }

  for (CeedInt s = 0; s < op->num_suboperators; s++) {
    CeedCall(CeedOperatorSetQFunctionAssemblyDataUpdateNeeded(op->sub_operators[s], needs_data_update));
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Set (or clear) the name of an operator

  The current name is always freed. A NULL or empty `name` leaves the operator without a
  newly assigned name; otherwise a private, NUL-terminated copy of `name` is stored.

  @param[in,out] op   Operator to rename
  @param[in]     name New name, or NULL / "" to clear

  @return CEED_ERROR_SUCCESS, or the error from a failing allocation or free
**/
int CeedOperatorSetName(CeedOperator op, const char *name) {
  size_t length = 0;

  // Measure before releasing the old name
  if (name) length = strlen(name);
  CeedCall(CeedFree(&op->name));
  if (length == 0) return CEED_ERROR_SUCCESS;

  {
    char *copy;

    // Zeroed allocation of length + 1 provides the terminator
    CeedCall(CeedCalloc(length + 1, &copy));
    memcpy(copy, name, length);
    op->name = copy;
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Print a description of an operator to a stream

  A non-composite operator prints a heading followed by its single-operator view.
  A composite operator prints a heading followed by a numbered heading and single-operator view
  for each sub-operator. Names are appended to headings after " - " when set.

  @param[in] op     Operator to view
  @param[in] stream Stream to write to

  @return CEED_ERROR_SUCCESS, or the error from CeedOperatorSingleView()
**/
int CeedOperatorView(CeedOperator op, FILE *stream) {
  const bool op_named = op->name;

  if (!op->is_composite) {
    fprintf(stream, "CeedOperator%s%s\n", op_named ? " - " : "", op_named ? op->name : "");
    CeedCall(CeedOperatorSingleView(op, false, stream));
    return CEED_ERROR_SUCCESS;
  }

  fprintf(stream, "Composite CeedOperator%s%s\n", op_named ? " - " : "", op_named ? op->name : "");
  for (CeedInt s = 0; s < op->num_suboperators; s++) {
    const CeedOperator sub_op    = op->sub_operators[s];
    const bool         sub_named = sub_op->name;

    fprintf(stream, " SubOperator %" CeedInt_FMT "%s%s:\n", s, sub_named ? " - " : "", sub_named ? sub_op->name : "");
    CeedCall(CeedOperatorSingleView(sub_op, true, stream));
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the Ceed object associated with an operator

  The stored handle is returned as-is; this function does not touch any reference count.

  @param[in]  op   Operator to query
  @param[out] ceed Receives `op->ceed`

  @return CEED_ERROR_SUCCESS
**/
int CeedOperatorGetCeed(CeedOperator op, Ceed *ceed) {
  const Ceed owner = op->ceed;

  *ceed = owner;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the number of elements of a non-composite operator

  @param[in]  op       Operator to query; must not be composite
  @param[out] num_elem Receives `op->num_elem`

  @return CEED_ERROR_SUCCESS, or CEED_ERROR_MINOR for a composite operator
**/
int CeedOperatorGetNumElements(CeedOperator op, CeedInt *num_elem) {
  const bool is_single = !op->is_composite;

  CeedCheck(is_single, op->ceed, CEED_ERROR_MINOR, "Not defined for composite operator");
  *num_elem = op->num_elem;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the number of quadrature points recorded on a non-composite operator

  @param[in]  op       Operator to query; must not be composite
  @param[out] num_qpts Receives `op->num_qpts`

  @return CEED_ERROR_SUCCESS, or CEED_ERROR_MINOR for a composite operator
**/
int CeedOperatorGetNumQuadraturePoints(CeedOperator op, CeedInt *num_qpts) {
  const bool is_single = !op->is_composite;

  CeedCheck(is_single, op->ceed, CEED_ERROR_MINOR, "Not defined for composite operator");
  *num_qpts = op->num_qpts;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Estimate the number of floating point operations for applying an operator

  For a composite operator the estimate is the sum over its sub-operators.
  For a non-composite operator it is the sum of:
  - for each active input field: the restriction estimate plus the basis estimate times the number of elements,
  - the QFunction estimate times the number of elements times the number of quadrature points,
  - for each active output field: the transpose restriction estimate plus the transpose basis estimate times the number of elements.
  Fields whose vector is not CEED_VECTOR_ACTIVE do not contribute.

  @param[in]  op    Operator to estimate; CeedOperatorCheckReady() is run on it first
  @param[out] flops Receives the estimate

  @return CEED_ERROR_SUCCESS, or the error from a failing call
**/
int CeedOperatorGetFlopsEstimate(CeedOperator op, CeedSize *flops) {
  bool is_composite;

  CeedCall(CeedOperatorCheckReady(op));

  *flops = 0;
  CeedCall(CeedOperatorIsComposite(op, &is_composite));
  if (is_composite) {
    CeedInt       num_suboperators;
    CeedOperator *sub_operators;

    CeedCall(CeedCompositeOperatorGetNumSub(op, &num_suboperators));
    CeedCall(CeedCompositeOperatorGetSubList(op, &sub_operators));

    for (CeedInt i = 0; i < num_suboperators; i++) {
      CeedSize suboperator_flops;

      CeedCall(CeedOperatorGetFlopsEstimate(sub_operators[i], &suboperator_flops));
      *flops += suboperator_flops;
    }
  } else {
    CeedInt            num_input_fields, num_output_fields, num_elem = 0;
    CeedOperatorField *input_fields, *output_fields;

    CeedCall(CeedOperatorGetFields(op, &num_input_fields, &input_fields, &num_output_fields, &output_fields));
    CeedCall(CeedOperatorGetNumElements(op, &num_elem));

    // Active inputs: restriction + basis application
    for (CeedInt i = 0; i < num_input_fields; i++) {
      if (input_fields[i]->vec == CEED_VECTOR_ACTIVE) {
        CeedSize rstr_flops, basis_flops;

        CeedCall(CeedElemRestrictionGetFlopsEstimate(input_fields[i]->elem_rstr, CEED_NOTRANSPOSE, &rstr_flops));
        *flops += rstr_flops;
        CeedCall(CeedBasisGetFlopsEstimate(input_fields[i]->basis, CEED_NOTRANSPOSE, op->qf->input_fields[i]->eval_mode, &basis_flops));
        *flops += basis_flops * num_elem;
      }
    }

    // QFunction evaluated at every quadrature point of every element
    {
      CeedInt  num_qpts;
      CeedSize qf_flops;

      CeedCall(CeedOperatorGetNumQuadraturePoints(op, &num_qpts));
      CeedCall(CeedQFunctionGetFlopsEstimate(op->qf, &qf_flops));
      // Widen before multiplying: num_elem * num_qpts can exceed the range of CeedInt
      *flops += (CeedSize)num_elem * (CeedSize)num_qpts * qf_flops;
    }

    // Active outputs: transpose basis application + transpose restriction
    for (CeedInt i = 0; i < num_output_fields; i++) {
      if (output_fields[i]->vec == CEED_VECTOR_ACTIVE) {
        CeedSize rstr_flops, basis_flops;

        CeedCall(CeedElemRestrictionGetFlopsEstimate(output_fields[i]->elem_rstr, CEED_TRANSPOSE, &rstr_flops));
        *flops += rstr_flops;
        CeedCall(CeedBasisGetFlopsEstimate(output_fields[i]->basis, CEED_TRANSPOSE, op->qf->output_fields[i]->eval_mode, &basis_flops));
        *flops += basis_flops * num_elem;
      }
    }
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the QFunctionContext of a non-composite operator

  When the operator's QFunction has a context, it is handed out through
  CeedQFunctionContextReferenceCopy(); otherwise `*ctx` is set to NULL.

  @param[in]  op  Operator to query; must not be composite
  @param[out] ctx Receives the context, or NULL when there is none

  @return CEED_ERROR_SUCCESS, CEED_ERROR_INCOMPATIBLE for a composite operator, or the error from a failing call
**/
int CeedOperatorGetContext(CeedOperator op, CeedQFunctionContext *ctx) {
  CeedCheck(!op->is_composite, op->ceed, CEED_ERROR_INCOMPATIBLE, "Cannot retrieve QFunctionContext for composite operator");

  if (!op->qf->ctx) {
    *ctx = NULL;
    return CEED_ERROR_SUCCESS;
  }
  CeedCall(CeedQFunctionContextReferenceCopy(op->qf->ctx, ctx));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get a context field label for a named QFunctionContext field, looked up through an operator

  Non-composite operator: the label comes from the operator's QFunction context, or is NULL when
  the QFunction has no context.

  Composite operator: a label created by an earlier call with the same name is returned directly.
  Otherwise a new label is allocated that collects the matching label of every sub-operator context;
  all sub-operators providing the field must agree on its type and number of values. If no
  sub-operator provides the field, the result is NULL.

  Any non-NULL label is flagged `from_op` and recorded once in `op->context_labels`, which is what
  the context set/get helpers use to verify that a label belongs to this operator.

  @param[in,out] op          Operator to query
  @param[in]     field_name  Name of the context field
  @param[out]    field_label Receives the label, or NULL when the field was not found

  @return CEED_ERROR_SUCCESS, CEED_ERROR_INCOMPATIBLE when sub-operator contexts disagree, or the error from a failing call
**/
int CeedOperatorGetContextFieldLabel(CeedOperator op, const char *field_name, CeedContextFieldLabel *field_label) {
  bool is_composite, field_found = false;

  CeedCall(CeedOperatorIsComposite(op, &is_composite));
  if (is_composite) {
    // Reuse a composite label created by an earlier call
    for (CeedInt i = 0; i < op->num_context_labels; i++) {
      if (!strcmp(op->context_labels[i]->name, field_name)) {
        *field_label = op->context_labels[i];
        return CEED_ERROR_SUCCESS;
      }
    }

    CeedInt               num_sub;
    CeedOperator         *sub_operators;
    CeedContextFieldLabel new_field_label;

    CeedCall(CeedCalloc(1, &new_field_label));
    CeedCall(CeedCompositeOperatorGetNumSub(op, &num_sub));
    CeedCall(CeedCompositeOperatorGetSubList(op, &sub_operators));
    CeedCall(CeedCalloc(num_sub, &new_field_label->sub_labels));
    new_field_label->num_sub_labels = num_sub;

    for (CeedInt i = 0; i < num_sub; i++) {
      if (sub_operators[i]->qf->ctx) {
        CeedContextFieldLabel new_field_label_i;

        CeedCall(CeedQFunctionContextGetFieldLabel(sub_operators[i]->qf->ctx, field_name, &new_field_label_i));
        if (new_field_label_i) {
          field_found                    = true;
          new_field_label->sub_labels[i] = new_field_label_i;
          new_field_label->name          = new_field_label_i->name;
          new_field_label->description   = new_field_label_i->description;

          if (new_field_label->type && new_field_label->type != new_field_label_i->type) {
            // Copy what the message needs before releasing the label, and release its sub-label array too
            const CeedContextFieldType composite_type = new_field_label->type;
            const CeedContextFieldType sub_type       = new_field_label_i->type;

            CeedCall(CeedFree(&new_field_label->sub_labels));
            CeedCall(CeedFree(&new_field_label));
            return CeedError(op->ceed, CEED_ERROR_INCOMPATIBLE, "Incompatible field types on sub-operator contexts. %s != %s",
                             CeedContextFieldTypes[composite_type], CeedContextFieldTypes[sub_type]);
          }
          new_field_label->type = new_field_label_i->type;

          if (new_field_label->num_values != 0 && new_field_label->num_values != new_field_label_i->num_values) {
            // Same as above: read the values first, then free everything owned by the label
            const size_t composite_num_values = (size_t)new_field_label->num_values;
            const size_t sub_num_values       = (size_t)new_field_label_i->num_values;

            CeedCall(CeedFree(&new_field_label->sub_labels));
            CeedCall(CeedFree(&new_field_label));
            return CeedError(op->ceed, CEED_ERROR_INCOMPATIBLE, "Incompatible field number of values on sub-operator contexts. %zu != %zu",
                             composite_num_values, sub_num_values);
          }
          new_field_label->num_values = new_field_label_i->num_values;
        }
      }
    }

    if (field_found) {
      *field_label = new_field_label;
    } else {
      CeedCall(CeedFree(&new_field_label->sub_labels));
      CeedCall(CeedFree(&new_field_label));
      *field_label = NULL;
    }
  } else {
    if (op->qf->ctx) {
      CeedCall(CeedQFunctionContextGetFieldLabel(op->qf->ctx, field_name, field_label));
    } else {
      *field_label = NULL;
    }
  }

  // Record the label on the operator so later set/get calls can validate it
  if (*field_label) {
    (*field_label)->from_op = true;

    // A label that is already recorded must not be appended again on repeated lookups
    for (CeedInt i = 0; i < op->num_context_labels; i++) {
      if (op->context_labels[i] == *field_label) return CEED_ERROR_SUCCESS;
    }

    // Grow the label array geometrically
    if (op->num_context_labels == 0) {
      CeedCall(CeedCalloc(1, &op->context_labels));
      op->max_context_labels = 1;
    } else if (op->num_context_labels == op->max_context_labels) {
      CeedCall(CeedRealloc(2 * op->num_context_labels, &op->context_labels));
      op->max_context_labels *= 2;
    }
    op->context_labels[op->num_context_labels] = *field_label;
    op->num_context_labels++;
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Set double-precision QFunctionContext field values through an operator

  @param[in,out] op          Operator whose context(s) are updated
  @param[in]     field_label Label of the field to set
  @param[in]     values      Values to set

  @return The result of CeedOperatorContextSetGeneric() with CEED_CONTEXT_FIELD_DOUBLE
**/
int CeedOperatorSetContextDouble(CeedOperator op, CeedContextFieldLabel field_label, double *values) {
  const int status = CeedOperatorContextSetGeneric(op, field_label, CEED_CONTEXT_FIELD_DOUBLE, values);

  return status;
}
/**
  @brief Get read access to double-precision QFunctionContext field values through an operator

  @param[in]  op          Operator to query
  @param[in]  field_label Label of the field to read
  @param[out] num_values  Receives the number of values
  @param[out] values      Receives a pointer to the values

  @return The result of CeedOperatorContextGetGenericRead() with CEED_CONTEXT_FIELD_DOUBLE
**/
int CeedOperatorGetContextDoubleRead(CeedOperator op, CeedContextFieldLabel field_label, size_t *num_values, const double **values) {
  const int status = CeedOperatorContextGetGenericRead(op, field_label, CEED_CONTEXT_FIELD_DOUBLE, num_values, values);

  return status;
}
/**
  @brief Give back read access to a double precision context field obtained through a `CeedOperator`.

  Typed front end that forwards to `CeedOperatorContextRestoreGenericRead()` with the field type fixed to `CEED_CONTEXT_FIELD_DOUBLE`.

  @param op          `CeedOperator` the request is issued on
  @param field_label Label identifying the context field
  @param values      Address of the data pointer being restored, passed through unchanged

  @return The error code produced by `CeedOperatorContextRestoreGenericRead()`
**/
int CeedOperatorRestoreContextDoubleRead(CeedOperator op, CeedContextFieldLabel field_label, const double **values) {
  const int status = CeedOperatorContextRestoreGenericRead(op, field_label, CEED_CONTEXT_FIELD_DOUBLE, values);

  return status;
}
/**
  @brief Set a 32-bit integer context field through a `CeedOperator`.

  Typed front end that forwards to `CeedOperatorContextSetGeneric()` with the field type fixed to `CEED_CONTEXT_FIELD_INT32`.

  @param op          `CeedOperator` the request is issued on
  @param field_label Label identifying the context field
  @param values      Pointer to the values, passed through unchanged

  @return The error code produced by `CeedOperatorContextSetGeneric()`
**/
int CeedOperatorSetContextInt32(CeedOperator op, CeedContextFieldLabel field_label, int *values) {
  const int status = CeedOperatorContextSetGeneric(op, field_label, CEED_CONTEXT_FIELD_INT32, values);

  return status;
}
/**
  @brief Request read access to a 32-bit integer context field through a `CeedOperator`.

  Typed front end that forwards to `CeedOperatorContextGetGenericRead()` with the field type fixed to `CEED_CONTEXT_FIELD_INT32`.

  @param op          `CeedOperator` the request is issued on
  @param field_label Label identifying the context field
  @param num_values  Forwarded to the generic getter (presumably receives the number of values - confirm against the helper)
  @param values      Forwarded to the generic getter (presumably receives the read-only data pointer - confirm against the helper)

  @return The error code produced by `CeedOperatorContextGetGenericRead()`
**/
int CeedOperatorGetContextInt32Read(CeedOperator op, CeedContextFieldLabel field_label, size_t *num_values, const int **values) {
  const int status = CeedOperatorContextGetGenericRead(op, field_label, CEED_CONTEXT_FIELD_INT32, num_values, values);

  return status;
}
/**
  @brief Give back read access to a 32-bit integer context field obtained through a `CeedOperator`.

  Typed front end that forwards to `CeedOperatorContextRestoreGenericRead()` with the field type fixed to `CEED_CONTEXT_FIELD_INT32`.

  @param op          `CeedOperator` the request is issued on
  @param field_label Label identifying the context field
  @param values      Address of the data pointer being restored, passed through unchanged

  @return The error code produced by `CeedOperatorContextRestoreGenericRead()`
**/
int CeedOperatorRestoreContextInt32Read(CeedOperator op, CeedContextFieldLabel field_label, const int **values) {
  const int status = CeedOperatorContextRestoreGenericRead(op, field_label, CEED_CONTEXT_FIELD_INT32, values);

  return status;
}
/**
  @brief Apply a `CeedOperator`, overwriting its outputs.

  The operator is first validated with `CeedOperatorCheckReady()`.
  If the operator provides a dedicated hook (`Apply`, or `ApplyComposite` for a composite operator), that hook does all the work.
  Otherwise the outputs are set to zero and the result is accumulated:
    - single operator: every output field vector is zeroed (`out` stands in for `CEED_VECTOR_ACTIVE`), then `ApplyAdd` is invoked when `num_elem` is nonzero;
    - composite operator: `out` and every passive output vector of every sub-operator are zeroed, then `CeedOperatorApplyAdd()` is called on each sub-operator in order.

  @param op      `CeedOperator` to apply
  @param in      Input vector handed to the hooks / sub-operators
  @param out     Output vector; zeroed in the fallback paths unless it is `CEED_VECTOR_NONE`
  @param request Request handle, forwarded unchanged

  @return `CEED_ERROR_SUCCESS` on success, otherwise the error raised by a failing call
**/
int CeedOperatorApply(CeedOperator op, CeedVector in, CeedVector out, CeedRequest *request) {
  CeedCall(CeedOperatorCheckReady(op));

  if (!op->is_composite) {
    // Single operator with a dedicated Apply hook
    if (op->Apply) {
      CeedCall(op->Apply(op, in, out, request));
      return CEED_ERROR_SUCCESS;
    }

    // No Apply hook: clear every output, then accumulate into it
    CeedQFunction qfunction = op->qf;

    for (CeedInt f = 0; f < qfunction->num_output_fields; f++) {
      CeedVector target = op->output_fields[f]->vec;

      // The active output is the vector supplied by the caller
      if (target == CEED_VECTOR_ACTIVE) target = out;
      if (target == CEED_VECTOR_NONE) continue;
      CeedCall(CeedVectorSetValue(target, 0.0));
    }
    if (op->num_elem) CeedCall(op->ApplyAdd(op, in, out, request));
    return CEED_ERROR_SUCCESS;
  }

  // Composite operator with a dedicated hook
  if (op->ApplyComposite) {
    CeedCall(op->ApplyComposite(op, in, out, request));
    return CEED_ERROR_SUCCESS;
  }

  // Composite fallback: zero all outputs once, then sum the sub-operators
  CeedInt       sub_count = 0;
  CeedOperator *subs      = NULL;

  CeedCall(CeedCompositeOperatorGetNumSub(op, &sub_count));
  CeedCall(CeedCompositeOperatorGetSubList(op, &subs));

  if (out != CEED_VECTOR_NONE) CeedCall(CeedVectorSetValue(out, 0.0));
  for (CeedInt s = 0; s < sub_count; s++) {
    for (CeedInt f = 0; f < subs[s]->qf->num_output_fields; f++) {
      CeedVector passive = subs[s]->output_fields[f]->vec;

      // The active output was zeroed above through `out`; skip it and the NONE sentinel
      if (passive == CEED_VECTOR_ACTIVE || passive == CEED_VECTOR_NONE) continue;
      CeedCall(CeedVectorSetValue(passive, 0.0));
    }
  }
  for (CeedInt s = 0; s < sub_count; s++) {
    CeedCall(CeedOperatorApplyAdd(subs[s], in, out, request));
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Apply a `CeedOperator` without zeroing its outputs first.

  The operator is first validated with `CeedOperatorCheckReady()`.
    - Single operator: `ApplyAdd` is invoked when `num_elem` is nonzero; otherwise nothing is done.
    - Composite operator: `ApplyAddComposite` is used when set; otherwise `CeedOperatorApplyAdd()` is called on each sub-operator in order.

  @param op      `CeedOperator` to apply
  @param in      Input vector handed to the hooks / sub-operators
  @param out     Output vector handed to the hooks / sub-operators
  @param request Request handle, forwarded unchanged

  @return `CEED_ERROR_SUCCESS` on success, otherwise the error raised by a failing call
**/
int CeedOperatorApplyAdd(CeedOperator op, CeedVector in, CeedVector out, CeedRequest *request) {
  CeedCall(CeedOperatorCheckReady(op));

  // Single operator: nothing to do when it has no elements
  if (!op->is_composite) {
    if (op->num_elem) CeedCall(op->ApplyAdd(op, in, out, request));
    return CEED_ERROR_SUCCESS;
  }

  // Composite operator with a dedicated hook
  if (op->ApplyAddComposite) {
    CeedCall(op->ApplyAddComposite(op, in, out, request));
    return CEED_ERROR_SUCCESS;
  }

  // Composite fallback: accumulate each sub-operator in turn
  CeedInt       sub_count = 0;
  CeedOperator *subs      = NULL;

  CeedCall(CeedCompositeOperatorGetNumSub(op, &sub_count));
  CeedCall(CeedCompositeOperatorGetSubList(op, &subs));
  for (CeedInt s = 0; s < sub_count; s++) {
    CeedCall(CeedOperatorApplyAdd(subs[s], in, out, request));
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Drop one reference to a `CeedOperator` and destroy it when no references remain.

  If `*op` is `NULL`, or other references remain after the decrement, only `*op` is set to `NULL`.
  Otherwise the `Destroy` hook (if set) is called, followed by release of the `Ceed` handle, every input and output field (restriction, basis, passive vector, name), the sub-operators, the QFunctions, the context labels, the fallback operator, the assembly data, and finally the operator itself.
  Context label objects are only freed here for composite operators; for other operators only the label array is freed.

  @param op Address of the `CeedOperator` handle; set to `NULL` on return

  @return `CEED_ERROR_SUCCESS` on success, otherwise the error raised by a failing call
**/
int CeedOperatorDestroy(CeedOperator *op) {
  if (!*op || --(*op)->ref_count > 0) {
    *op = NULL;
    return CEED_ERROR_SUCCESS;
  }
  if ((*op)->Destroy) CeedCall((*op)->Destroy(*op));
  CeedCall(CeedDestroy(&(*op)->ceed));

  // Input fields
  for (CeedInt i = 0; i < (*op)->num_fields; i++) {
    if (!(*op)->input_fields[i]) continue;
    // Sentinel objects are shared placeholders and must never be destroyed
    if ((*op)->input_fields[i]->elem_rstr != CEED_ELEMRESTRICTION_NONE) {
      CeedCall(CeedElemRestrictionDestroy(&(*op)->input_fields[i]->elem_rstr));
    }
    if ((*op)->input_fields[i]->basis != CEED_BASIS_NONE) {
      CeedCall(CeedBasisDestroy(&(*op)->input_fields[i]->basis));
    }
    if ((*op)->input_fields[i]->vec != CEED_VECTOR_ACTIVE && (*op)->input_fields[i]->vec != CEED_VECTOR_NONE) {
      CeedCall(CeedVectorDestroy(&(*op)->input_fields[i]->vec));
    }
    CeedCall(CeedFree(&(*op)->input_fields[i]->field_name));
    CeedCall(CeedFree(&(*op)->input_fields[i]));
  }

  // Output fields
  for (CeedInt i = 0; i < (*op)->num_fields; i++) {
    if (!(*op)->output_fields[i]) continue;
    // Same sentinel guard as for the input fields, so CEED_ELEMRESTRICTION_NONE is never handed to the destructor
    if ((*op)->output_fields[i]->elem_rstr != CEED_ELEMRESTRICTION_NONE) {
      CeedCall(CeedElemRestrictionDestroy(&(*op)->output_fields[i]->elem_rstr));
    }
    if ((*op)->output_fields[i]->basis != CEED_BASIS_NONE) {
      CeedCall(CeedBasisDestroy(&(*op)->output_fields[i]->basis));
    }
    if ((*op)->output_fields[i]->vec != CEED_VECTOR_ACTIVE && (*op)->output_fields[i]->vec != CEED_VECTOR_NONE) {
      CeedCall(CeedVectorDestroy(&(*op)->output_fields[i]->vec));
    }
    CeedCall(CeedFree(&(*op)->output_fields[i]->field_name));
    CeedCall(CeedFree(&(*op)->output_fields[i]));
  }

  // Sub-operators
  for (CeedInt i = 0; i < (*op)->num_suboperators; i++) {
    if ((*op)->sub_operators[i]) {
      CeedCall(CeedOperatorDestroy(&(*op)->sub_operators[i]));
    }
  }

  // QFunctions
  CeedCall(CeedQFunctionDestroy(&(*op)->qf));
  CeedCall(CeedQFunctionDestroy(&(*op)->dqf));
  CeedCall(CeedQFunctionDestroy(&(*op)->dqfT));

  // Context labels: the label objects themselves are freed only for composite operators
  if ((*op)->is_composite) {
    for (CeedInt i = 0; i < (*op)->num_context_labels; i++) {
      CeedCall(CeedFree(&(*op)->context_labels[i]->sub_labels));
      CeedCall(CeedFree(&(*op)->context_labels[i]));
    }
  }
  CeedCall(CeedFree(&(*op)->context_labels));

  // Fallback operator and assembly data
  CeedCall(CeedOperatorDestroy(&(*op)->op_fallback));
  CeedCall(CeedQFunctionAssemblyDataDestroy(&(*op)->qf_assembled));
  CeedCall(CeedOperatorAssemblyDataDestroy(&(*op)->op_assembled));

  // Containers and the operator itself
  CeedCall(CeedFree(&(*op)->input_fields));
  CeedCall(CeedFree(&(*op)->output_fields));
  CeedCall(CeedFree(&(*op)->sub_operators));
  CeedCall(CeedFree(&(*op)->name));
  CeedCall(CeedFree(op));
  return CEED_ERROR_SUCCESS;
}