#include <ceed-impl.h>
#include <ceed.h>
#include <ceed/backend.h>
#include <stdbool.h>
#include <stdio.h>
#include <string.h>
/**
  @brief Permute and pad element offsets into a blocked (node-interlaced) layout.

  Within each block of `block_size` elements, node `k` of lane `j` is stored at index `k * block_size + j`
    relative to the start of the block.
  Lanes past the last element (padding in the final block) repeat the offsets of element `num_elem - 1`.

  @param[in]  offsets       Source offsets, `elem_size` entries per element
  @param[out] block_offsets Destination array, `num_block * block_size * elem_size` entries
  @param[in]  num_block     Number of blocks
  @param[in]  num_elem      Number of real (unpadded) elements
  @param[in]  block_size    Number of elements per block
  @param[in]  elem_size     Number of entries per element

  @return An error code: 0 - success, otherwise - failure

  @ref Utility
**/
int CeedPermutePadOffsets(const CeedInt *offsets, CeedInt *block_offsets, CeedInt num_block, CeedInt num_elem, CeedInt block_size,
                          CeedInt elem_size) {
  const CeedInt num_padded = num_block * block_size;

  for (CeedInt first = 0; first < num_padded; first += block_size) {
    CeedInt *dest = &block_offsets[first * elem_size];

    for (CeedInt lane = 0; lane < block_size; lane++) {
      // Padding lanes reuse the last real element
      const CeedInt *src = &offsets[CeedIntMin(first + lane, num_elem - 1) * elem_size];

      for (CeedInt node = 0; node < elem_size; node++) dest[node * block_size + lane] = src[node];
    }
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Permute and pad element orientations into a blocked (node-interlaced) layout.

  Uses the same layout as CeedPermutePadOffsets: node `k` of lane `j` in a block is stored at `k * block_size + j`.
  Padding lanes repeat the orientations of element `num_elem - 1`.

  @param[in]  orients       Source orientations, `elem_size` entries per element
  @param[out] block_orients Destination array, `num_block * block_size * elem_size` entries
  @param[in]  num_block     Number of blocks
  @param[in]  num_elem      Number of real (unpadded) elements
  @param[in]  block_size    Number of elements per block
  @param[in]  elem_size     Number of entries per element

  @return An error code: 0 - success, otherwise - failure

  @ref Utility
**/
int CeedPermutePadOrients(const bool *orients, bool *block_orients, CeedInt num_block, CeedInt num_elem, CeedInt block_size, CeedInt elem_size) {
  const CeedInt num_padded = num_block * block_size;

  for (CeedInt first = 0; first < num_padded; first += block_size) {
    bool *dest = &block_orients[first * elem_size];

    for (CeedInt lane = 0; lane < block_size; lane++) {
      // Padding lanes reuse the last real element
      const bool *src = &orients[CeedIntMin(first + lane, num_elem - 1) * elem_size];

      for (CeedInt node = 0; node < elem_size; node++) dest[node * block_size + lane] = src[node];
    }
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Permute and pad element curl-orientation entries into a blocked (node-interlaced) layout.

  Uses the same layout as CeedPermutePadOffsets; `elem_size` here is the number of curl-orientation entries per element
    (callers in this file pass `3 * elem_size`).
  Padding lanes repeat the entries of element `num_elem - 1`.

  @param[in]  curl_orients       Source entries, `elem_size` per element
  @param[out] block_curl_orients Destination array, `num_block * block_size * elem_size` entries
  @param[in]  num_block          Number of blocks
  @param[in]  num_elem           Number of real (unpadded) elements
  @param[in]  block_size         Number of elements per block
  @param[in]  elem_size          Number of entries per element

  @return An error code: 0 - success, otherwise - failure

  @ref Utility
**/
int CeedPermutePadCurlOrients(const CeedInt8 *curl_orients, CeedInt8 *block_curl_orients, CeedInt num_block, CeedInt num_elem, CeedInt block_size,
                              CeedInt elem_size) {
  const CeedInt num_padded = num_block * block_size;

  for (CeedInt first = 0; first < num_padded; first += block_size) {
    CeedInt8 *dest = &block_curl_orients[first * elem_size];

    for (CeedInt lane = 0; lane < block_size; lane++) {
      // Padding lanes reuse the last real element
      const CeedInt8 *src = &curl_orients[CeedIntMin(first + lane, num_elem - 1) * elem_size];

      for (CeedInt node = 0; node < elem_size; node++) dest[node * block_size + lane] = src[node];
    }
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the restriction type (standard, strided, oriented, curl-oriented, points) of a CeedElemRestriction.

  @param[in]  rstr      CeedElemRestriction
  @param[out] rstr_type Variable to store the type

  @return An error code: 0 - success, otherwise - failure

  @ref Backend
**/
int CeedElemRestrictionGetType(CeedElemRestriction rstr, CeedRestrictionType *rstr_type) {
  const CeedRestrictionType type = rstr->rstr_type;

  *rstr_type = type;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Report whether a CeedElemRestriction has type CEED_RESTRICTION_STRIDED.

  @param[in]  rstr       CeedElemRestriction
  @param[out] is_strided Set to true for strided restrictions, false otherwise

  @return An error code: 0 - success, otherwise - failure

  @ref Backend
**/
int CeedElemRestrictionIsStrided(CeedElemRestriction rstr, bool *is_strided) {
  bool strided = false;

  if (rstr->rstr_type == CEED_RESTRICTION_STRIDED) strided = true;
  *is_strided = strided;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Report whether a CeedElemRestriction has type CEED_RESTRICTION_POINTS.

  @param[in]  rstr      CeedElemRestriction
  @param[out] is_points Set to true for points restrictions, false otherwise

  @return An error code: 0 - success, otherwise - failure

  @ref Backend
**/
int CeedElemRestrictionIsPoints(CeedElemRestriction rstr, bool *is_points) {
  bool points = false;

  if (rstr->rstr_type == CEED_RESTRICTION_POINTS) points = true;
  *is_points = points;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Copy the three stride values of a strided CeedElemRestriction.

  @param[in]  rstr    CeedElemRestriction
  @param[out] strides Array of three entries that receives the strides

  @return An error code: 0 - success, CEED_ERROR_MINOR if the restriction has no stride data

  @ref Backend
**/
int CeedElemRestrictionGetStrides(CeedElemRestriction rstr, CeedInt (*strides)[3]) {
  const CeedInt *src = rstr->strides;

  CeedCheck(src, rstr->ceed, CEED_ERROR_MINOR, "ElemRestriction has no stride data");
  (*strides)[0] = src[0];
  (*strides)[1] = src[1];
  (*strides)[2] = src[2];
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Report whether a strided CeedElemRestriction uses the CEED_STRIDES_BACKEND sentinel strides.

  @param[in]  rstr                CeedElemRestriction
  @param[out] has_backend_strides Set to true when all three strides equal CEED_STRIDES_BACKEND

  @return An error code: 0 - success, CEED_ERROR_MINOR if the restriction has no stride data

  @ref Backend
**/
int CeedElemRestrictionHasBackendStrides(CeedElemRestriction rstr, bool *has_backend_strides) {
  bool matches = true;

  CeedCheck(rstr->strides, rstr->ceed, CEED_ERROR_MINOR, "ElemRestriction has no stride data");
  for (CeedInt i = 0; i < 3 && matches; i++) matches = (rstr->strides[i] == CEED_STRIDES_BACKEND[i]);
  *has_backend_strides = matches;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get read-only access to the offsets array of a CeedElemRestriction.

  Restrictions created as copies (with `rstr_base` set) share the offsets of their base, so the request is
    served by, and the reader counted on, the root of the `rstr_base` chain.
  Release access with CeedElemRestrictionRestoreOffsets.

  @param[in]  rstr     CeedElemRestriction
  @param[in]  mem_type Memory type on which to access the array
  @param[out] offsets  Variable to store the array pointer

  @return An error code: 0 - success, CEED_ERROR_UNSUPPORTED if the backend has no GetOffsets

  @ref Backend
**/
int CeedElemRestrictionGetOffsets(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) {
  CeedElemRestriction owner = rstr;

  while (owner->rstr_base) owner = owner->rstr_base;
  CeedCheck(owner->GetOffsets, owner->ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support GetOffsets");
  CeedCall(owner->GetOffsets(owner, mem_type, offsets));
  owner->num_readers++;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Release access obtained with CeedElemRestrictionGetOffsets.

  The reader count is decremented on the root of the `rstr_base` chain, matching CeedElemRestrictionGetOffsets.

  @param[in]     rstr    CeedElemRestriction
  @param[in,out] offsets Array pointer to release; set to NULL

  @return An error code: 0 - success, otherwise - failure

  @ref Backend
**/
int CeedElemRestrictionRestoreOffsets(CeedElemRestriction rstr, const CeedInt **offsets) {
  CeedElemRestriction owner = rstr;

  while (owner->rstr_base) owner = owner->rstr_base;
  *offsets = NULL;
  owner->num_readers--;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get read-only access to the orientation array of a CeedElemRestriction.

  Release access with CeedElemRestrictionRestoreOrientations.

  @param[in]  rstr     CeedElemRestriction
  @param[in]  mem_type Memory type on which to access the array
  @param[out] orients  Variable to store the array pointer

  @return An error code: 0 - success, CEED_ERROR_UNSUPPORTED if the backend has no GetOrientations

  @ref Backend
**/
int CeedElemRestrictionGetOrientations(CeedElemRestriction rstr, CeedMemType mem_type, const bool **orients) {
  const bool supported = rstr->GetOrientations != NULL;

  CeedCheck(supported, rstr->ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support GetOrientations");
  CeedCall(rstr->GetOrientations(rstr, mem_type, orients));
  rstr->num_readers += 1;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Release access obtained with CeedElemRestrictionGetOrientations.

  @param[in]     rstr    CeedElemRestriction
  @param[in,out] orients Array pointer to release; set to NULL

  @return An error code: 0 - success, otherwise - failure

  @ref Backend
**/
int CeedElemRestrictionRestoreOrientations(CeedElemRestriction rstr, const bool **orients) {
  rstr->num_readers -= 1;
  *orients = NULL;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get read-only access to the curl-orientation array of a CeedElemRestriction.

  Release access with CeedElemRestrictionRestoreCurlOrientations.

  @param[in]  rstr         CeedElemRestriction
  @param[in]  mem_type     Memory type on which to access the array
  @param[out] curl_orients Variable to store the array pointer

  @return An error code: 0 - success, CEED_ERROR_UNSUPPORTED if the backend has no GetCurlOrientations

  @ref Backend
**/
int CeedElemRestrictionGetCurlOrientations(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt8 **curl_orients) {
  const bool supported = rstr->GetCurlOrientations != NULL;

  CeedCheck(supported, rstr->ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support GetCurlOrientations");
  CeedCall(rstr->GetCurlOrientations(rstr, mem_type, curl_orients));
  rstr->num_readers += 1;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Release access obtained with CeedElemRestrictionGetCurlOrientations.

  @param[in]     rstr         CeedElemRestriction
  @param[in,out] curl_orients Array pointer to release; set to NULL

  @return An error code: 0 - success, otherwise - failure

  @ref Backend
**/
int CeedElemRestrictionRestoreCurlOrientations(CeedElemRestriction rstr, const CeedInt8 **curl_orients) {
  rstr->num_readers -= 1;
  *curl_orients = NULL;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Copy the E-vector layout of a CeedElemRestriction.

  A zero first entry is treated as "no layout set".

  @param[in]  rstr   CeedElemRestriction
  @param[out] layout Array of three entries that receives the layout

  @return An error code: 0 - success, CEED_ERROR_MINOR if no layout has been set

  @ref Backend
**/
int CeedElemRestrictionGetELayout(CeedElemRestriction rstr, CeedInt (*layout)[3]) {
  const bool has_layout = rstr->layout[0] != 0;

  CeedCheck(has_layout, rstr->ceed, CEED_ERROR_MINOR, "ElemRestriction has no layout data");
  (*layout)[0] = rstr->layout[0];
  (*layout)[1] = rstr->layout[1];
  (*layout)[2] = rstr->layout[2];
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Set the E-vector layout of a CeedElemRestriction.

  @param[in,out] rstr   CeedElemRestriction
  @param[in]     layout Three layout values to copy into the restriction

  @return An error code: 0 - success, otherwise - failure

  @ref Backend
**/
int CeedElemRestrictionSetELayout(CeedElemRestriction rstr, CeedInt layout[3]) {
  rstr->layout[0] = layout[0];
  rstr->layout[1] = layout[1];
  rstr->layout[2] = layout[2];
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the backend data pointer of a CeedElemRestriction.

  @param[in]  rstr CeedElemRestriction
  @param[out] data Address of a pointer variable (passed as void *) that receives the backend data

  @return An error code: 0 - success, otherwise - failure

  @ref Backend
**/
int CeedElemRestrictionGetData(CeedElemRestriction rstr, void *data) {
  void **out = (void **)data;

  *out = rstr->data;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Set the backend data pointer of a CeedElemRestriction.

  @param[in,out] rstr CeedElemRestriction
  @param[in]     data Backend data pointer to store (not copied)

  @return An error code: 0 - success, otherwise - failure

  @ref Backend
**/
int CeedElemRestrictionSetData(CeedElemRestriction rstr, void *data) {
  void *backend_data = data;

  rstr->data = backend_data;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Increment the reference count of a CeedElemRestriction.

  @param[in,out] rstr CeedElemRestriction

  @return An error code: 0 - success, otherwise - failure

  @ref Backend
**/
int CeedElemRestrictionReference(CeedElemRestriction rstr) {
  rstr->ref_count += 1;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Estimate the number of FLOPs performed by one application of a CeedElemRestriction.

  The estimate is (E-vector size) * (per-entry cost). The per-entry cost depends on the restriction type and transpose mode.
  Points restrictions use `num_points * num_comp` as their E-vector size.

  @param[in]  rstr   CeedElemRestriction
  @param[in]  t_mode Apply restriction or transpose
  @param[out] flops  Variable to store the estimate

  @return An error code: 0 - success, otherwise - failure

  @ref Backend
**/
int CeedElemRestrictionGetFlopsEstimate(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedSize *flops) {
  CeedRestrictionType rstr_type;
  CeedSize            e_size;
  CeedInt             scale = 0;

  CeedCall(CeedElemRestrictionGetType(rstr, &rstr_type));
  // Widen to CeedSize before multiplying so large E-vectors do not overflow CeedInt
  if (rstr_type == CEED_RESTRICTION_POINTS) e_size = (CeedSize)rstr->num_points * rstr->num_comp;
  else e_size = (CeedSize)rstr->num_block * rstr->block_size * rstr->elem_size * rstr->num_comp;

  if (t_mode == CEED_TRANSPOSE) {
    switch (rstr_type) {
      case CEED_RESTRICTION_POINTS:
        scale = 0;
        break;
      case CEED_RESTRICTION_STRIDED:
      case CEED_RESTRICTION_STANDARD:
        scale = 1;
        break;
      case CEED_RESTRICTION_ORIENTED:
        scale = 2;
        break;
      case CEED_RESTRICTION_CURL_ORIENTED:
        scale = 6;
        break;
      default:
        scale = 0;
        break;
    }
  } else {
    switch (rstr_type) {
      case CEED_RESTRICTION_STRIDED:
      case CEED_RESTRICTION_STANDARD:
      case CEED_RESTRICTION_POINTS:
        scale = 0;
        break;
      case CEED_RESTRICTION_ORIENTED:
        scale = 1;
        break;
      case CEED_RESTRICTION_CURL_ORIENTED:
        scale = 5;
        break;
      default:
        scale = 0;
        break;
    }
  }
  *flops = e_size * scale;
  return CEED_ERROR_SUCCESS;
}
// Statically allocated object backing CEED_ELEMRESTRICTION_NONE.
// CeedElemRestrictionReferenceCopy and CeedElemRestrictionDestroy skip it, so it is never reference counted or freed.
static struct CeedElemRestriction_private ceed_elemrestriction_none;
// All-zero sentinel stride triple; CeedElemRestrictionHasBackendStrides compares a restriction's strides against it.
// NOTE(review): presumably signals that the backend chooses the stride layout — confirm against backend implementations.
const CeedInt CEED_STRIDES_BACKEND[3] = {0};
// Sentinel handle meaning "no restriction"; points at the static object above.
const CeedElemRestriction CEED_ELEMRESTRICTION_NONE = &ceed_elemrestriction_none;
/**
  @brief Create a standard (offset-based) CeedElemRestriction.

  If the backend has no ElemRestrictionCreate, creation is forwarded to the "ElemRestriction" object delegate.

  @param[in]  ceed        Ceed object used to create the CeedElemRestriction
  @param[in]  num_elem    Number of elements, must be non-negative
  @param[in]  elem_size   Number of nodes per element, must be at least 1
  @param[in]  num_comp    Number of components, must be at least 1
  @param[in]  comp_stride Component stride in the L-vector; must be at least 1 when num_comp > 1
  @param[in]  l_size      Size of the L-vector
  @param[in]  mem_type    Memory type of `offsets`, forwarded to the backend
  @param[in]  copy_mode   Copy mode for `offsets`, forwarded to the backend
  @param[in]  offsets     Element node offsets, forwarded to the backend
  @param[out] rstr        Address of the variable where the new CeedElemRestriction will be stored

  @return An error code: 0 - success, otherwise - failure

  @ref User
**/
int CeedElemRestrictionCreate(Ceed ceed, CeedInt num_elem, CeedInt elem_size, CeedInt num_comp, CeedInt comp_stride, CeedSize l_size,
                              CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *offsets, CeedElemRestriction *rstr) {
  CeedElemRestriction r;

  if (ceed->ElemRestrictionCreate == NULL) {
    Ceed fallback;

    CeedCall(CeedGetObjectDelegate(ceed, &fallback, "ElemRestriction"));
    CeedCheck(fallback, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement ElemRestrictionCreate");
    CeedCall(CeedElemRestrictionCreate(fallback, num_elem, elem_size, num_comp, comp_stride, l_size, mem_type, copy_mode, offsets, rstr));
    return CEED_ERROR_SUCCESS;
  }

  // Validate the requested shape before allocating anything
  CeedCheck(num_elem >= 0, ceed, CEED_ERROR_DIMENSION, "Number of elements must be non-negative");
  CeedCheck(elem_size > 0, ceed, CEED_ERROR_DIMENSION, "Element size must be at least 1");
  CeedCheck(num_comp > 0, ceed, CEED_ERROR_DIMENSION, "ElemRestriction must have at least 1 component");
  CeedCheck(num_comp == 1 || comp_stride > 0, ceed, CEED_ERROR_DIMENSION, "ElemRestriction component stride must be at least 1");

  CeedCall(CeedCalloc(1, rstr));
  r = *rstr;
  CeedCall(CeedReferenceCopy(ceed, &r->ceed));
  r->ref_count   = 1;
  r->rstr_type   = CEED_RESTRICTION_STANDARD;
  r->num_elem    = num_elem;
  r->elem_size   = elem_size;
  r->num_comp    = num_comp;
  r->comp_stride = comp_stride;
  r->l_size      = l_size;
  r->e_size      = num_elem * elem_size * num_comp;
  // Unblocked: every element is its own block
  r->block_size = 1;
  r->num_block  = num_elem;
  CeedCall(ceed->ElemRestrictionCreate(mem_type, copy_mode, offsets, NULL, NULL, r));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Create an oriented CeedElemRestriction (offsets plus a per-node boolean orientation).

  If the backend has no ElemRestrictionCreate, creation is forwarded to the "ElemRestriction" object delegate.

  @param[in]  ceed        Ceed object used to create the CeedElemRestriction
  @param[in]  num_elem    Number of elements, must be non-negative
  @param[in]  elem_size   Number of nodes per element, must be at least 1
  @param[in]  num_comp    Number of components, must be at least 1
  @param[in]  comp_stride Component stride in the L-vector; must be at least 1 when num_comp > 1
  @param[in]  l_size      Size of the L-vector
  @param[in]  mem_type    Memory type of the arrays, forwarded to the backend
  @param[in]  copy_mode   Copy mode for the arrays, forwarded to the backend
  @param[in]  offsets     Element node offsets, forwarded to the backend
  @param[in]  orients     Per-node orientation flags, forwarded to the backend
  @param[out] rstr        Address of the variable where the new CeedElemRestriction will be stored

  @return An error code: 0 - success, otherwise - failure

  @ref User
**/
int CeedElemRestrictionCreateOriented(Ceed ceed, CeedInt num_elem, CeedInt elem_size, CeedInt num_comp, CeedInt comp_stride, CeedSize l_size,
                                      CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *offsets, const bool *orients,
                                      CeedElemRestriction *rstr) {
  CeedElemRestriction r;

  if (ceed->ElemRestrictionCreate == NULL) {
    Ceed fallback;

    CeedCall(CeedGetObjectDelegate(ceed, &fallback, "ElemRestriction"));
    CeedCheck(fallback, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement ElemRestrictionCreate");
    CeedCall(CeedElemRestrictionCreateOriented(fallback, num_elem, elem_size, num_comp, comp_stride, l_size, mem_type, copy_mode, offsets,
                                               orients, rstr));
    return CEED_ERROR_SUCCESS;
  }

  // Validate the requested shape before allocating anything
  CeedCheck(num_elem >= 0, ceed, CEED_ERROR_DIMENSION, "Number of elements must be non-negative");
  CeedCheck(elem_size > 0, ceed, CEED_ERROR_DIMENSION, "Element size must be at least 1");
  CeedCheck(num_comp > 0, ceed, CEED_ERROR_DIMENSION, "ElemRestriction must have at least 1 component");
  CeedCheck(num_comp == 1 || comp_stride > 0, ceed, CEED_ERROR_DIMENSION, "ElemRestriction component stride must be at least 1");

  CeedCall(CeedCalloc(1, rstr));
  r = *rstr;
  CeedCall(CeedReferenceCopy(ceed, &r->ceed));
  r->ref_count   = 1;
  r->rstr_type   = CEED_RESTRICTION_ORIENTED;
  r->num_elem    = num_elem;
  r->elem_size   = elem_size;
  r->num_comp    = num_comp;
  r->comp_stride = comp_stride;
  r->l_size      = l_size;
  r->e_size      = num_elem * elem_size * num_comp;
  // Unblocked: every element is its own block
  r->block_size = 1;
  r->num_block  = num_elem;
  CeedCall(ceed->ElemRestrictionCreate(mem_type, copy_mode, offsets, orients, NULL, r));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Create a curl-oriented CeedElemRestriction (offsets plus a CeedInt8 curl-orientation array).

  If the backend has no ElemRestrictionCreate, creation is forwarded to the "ElemRestriction" object delegate.

  @param[in]  ceed         Ceed object used to create the CeedElemRestriction
  @param[in]  num_elem     Number of elements, must be non-negative
  @param[in]  elem_size    Number of nodes per element, must be at least 1
  @param[in]  num_comp     Number of components, must be at least 1
  @param[in]  comp_stride  Component stride in the L-vector; must be at least 1 when num_comp > 1
  @param[in]  l_size       Size of the L-vector
  @param[in]  mem_type     Memory type of the arrays, forwarded to the backend
  @param[in]  copy_mode    Copy mode for the arrays, forwarded to the backend
  @param[in]  offsets      Element node offsets, forwarded to the backend
  @param[in]  curl_orients Curl-orientation entries, forwarded to the backend
                           (presumably 3 * elem_size per element, as the blocked variant assumes — confirm)
  @param[out] rstr         Address of the variable where the new CeedElemRestriction will be stored

  @return An error code: 0 - success, otherwise - failure

  @ref User
**/
int CeedElemRestrictionCreateCurlOriented(Ceed ceed, CeedInt num_elem, CeedInt elem_size, CeedInt num_comp, CeedInt comp_stride, CeedSize l_size,
                                          CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *offsets, const CeedInt8 *curl_orients,
                                          CeedElemRestriction *rstr) {
  CeedElemRestriction r;

  if (ceed->ElemRestrictionCreate == NULL) {
    Ceed fallback;

    CeedCall(CeedGetObjectDelegate(ceed, &fallback, "ElemRestriction"));
    CeedCheck(fallback, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement ElemRestrictionCreate");
    CeedCall(CeedElemRestrictionCreateCurlOriented(fallback, num_elem, elem_size, num_comp, comp_stride, l_size, mem_type, copy_mode, offsets,
                                                   curl_orients, rstr));
    return CEED_ERROR_SUCCESS;
  }

  // Validate the requested shape before allocating anything
  CeedCheck(num_elem >= 0, ceed, CEED_ERROR_DIMENSION, "Number of elements must be non-negative");
  CeedCheck(elem_size > 0, ceed, CEED_ERROR_DIMENSION, "Element size must be at least 1");
  CeedCheck(num_comp > 0, ceed, CEED_ERROR_DIMENSION, "ElemRestriction must have at least 1 component");
  CeedCheck(num_comp == 1 || comp_stride > 0, ceed, CEED_ERROR_DIMENSION, "ElemRestriction component stride must be at least 1");

  CeedCall(CeedCalloc(1, rstr));
  r = *rstr;
  CeedCall(CeedReferenceCopy(ceed, &r->ceed));
  r->ref_count   = 1;
  r->rstr_type   = CEED_RESTRICTION_CURL_ORIENTED;
  r->num_elem    = num_elem;
  r->elem_size   = elem_size;
  r->num_comp    = num_comp;
  r->comp_stride = comp_stride;
  r->l_size      = l_size;
  r->e_size      = num_elem * elem_size * num_comp;
  // Unblocked: every element is its own block
  r->block_size = 1;
  r->num_block  = num_elem;
  CeedCall(ceed->ElemRestrictionCreate(mem_type, copy_mode, offsets, NULL, curl_orients, r));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Create a strided CeedElemRestriction (no offsets array; the L-vector is addressed by three strides).

  If the backend has no ElemRestrictionCreate, creation is forwarded to the "ElemRestriction" object delegate.

  @param[in]  ceed      Ceed object used to create the CeedElemRestriction
  @param[in]  num_elem  Number of elements, must be non-negative
  @param[in]  elem_size Number of nodes per element, must be at least 1
  @param[in]  num_comp  Number of components, must be at least 1
  @param[in]  l_size    Size of the L-vector, at least num_elem * elem_size * num_comp
  @param[in]  strides   Three stride values, copied into the restriction
  @param[out] rstr      Address of the variable where the new CeedElemRestriction will be stored

  @return An error code: 0 - success, otherwise - failure

  @ref User
**/
int CeedElemRestrictionCreateStrided(Ceed ceed, CeedInt num_elem, CeedInt elem_size, CeedInt num_comp, CeedSize l_size, const CeedInt strides[3],
                                     CeedElemRestriction *rstr) {
  CeedElemRestriction r;

  if (ceed->ElemRestrictionCreate == NULL) {
    Ceed fallback;

    CeedCall(CeedGetObjectDelegate(ceed, &fallback, "ElemRestriction"));
    CeedCheck(fallback, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement ElemRestrictionCreate");
    CeedCall(CeedElemRestrictionCreateStrided(fallback, num_elem, elem_size, num_comp, l_size, strides, rstr));
    return CEED_ERROR_SUCCESS;
  }

  // Validate the requested shape before allocating anything
  CeedCheck(num_elem >= 0, ceed, CEED_ERROR_DIMENSION, "Number of elements must be non-negative");
  CeedCheck(elem_size > 0, ceed, CEED_ERROR_DIMENSION, "Element size must be at least 1");
  CeedCheck(num_comp > 0, ceed, CEED_ERROR_DIMENSION, "ElemRestriction must have at least 1 component");
  CeedCheck(l_size >= num_elem * elem_size * num_comp, ceed, CEED_ERROR_DIMENSION, "L-vector size must be at least num_elem * elem_size * num_comp");

  CeedCall(CeedCalloc(1, rstr));
  r = *rstr;
  CeedCall(CeedReferenceCopy(ceed, &r->ceed));
  r->ref_count = 1;
  r->rstr_type = CEED_RESTRICTION_STRIDED;
  r->num_elem  = num_elem;
  r->elem_size = elem_size;
  r->num_comp  = num_comp;
  r->l_size    = l_size;
  r->e_size    = num_elem * elem_size * num_comp;
  // Unblocked: every element is its own block
  r->block_size = 1;
  r->num_block  = num_elem;
  CeedCall(CeedMalloc(3, &r->strides));
  r->strides[0] = strides[0];
  r->strides[1] = strides[1];
  r->strides[2] = strides[2];
  // No user arrays to hand over; the backend sees NULL offsets
  CeedCall(ceed->ElemRestrictionCreate(CEED_MEM_HOST, CEED_OWN_POINTER, NULL, NULL, NULL, r));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Create a points CeedElemRestriction, where each element holds a variable number of points.

  For this type, element `e` owns `offsets[e + 1] - offsets[e]` points (see CeedElemRestrictionGetNumPointsInElement).
  If the backend has no ElemRestrictionCreateAtPoints, creation is forwarded to the "ElemRestriction" object delegate.

  @param[in]  ceed       Ceed object used to create the CeedElemRestriction
  @param[in]  num_elem   Number of elements, must be non-negative
  @param[in]  num_points Total number of points, must be non-negative
  @param[in]  num_comp   Number of components, must be at least 1
  @param[in]  l_size     Size of the L-vector, at least num_points * num_comp
  @param[in]  mem_type   Memory type of `offsets`, forwarded to the backend
  @param[in]  copy_mode  Copy mode for `offsets`, forwarded to the backend
  @param[in]  offsets    Offsets array, forwarded to the backend
  @param[out] rstr       Address of the variable where the new CeedElemRestriction will be stored

  @return An error code: 0 - success, otherwise - failure

  @ref User
**/
int CeedElemRestrictionCreateAtPoints(Ceed ceed, CeedInt num_elem, CeedInt num_points, CeedInt num_comp, CeedSize l_size, CeedMemType mem_type,
                                      CeedCopyMode copy_mode, const CeedInt *offsets, CeedElemRestriction *rstr) {
  CeedElemRestriction r;

  if (ceed->ElemRestrictionCreateAtPoints == NULL) {
    Ceed fallback;

    CeedCall(CeedGetObjectDelegate(ceed, &fallback, "ElemRestriction"));
    CeedCheck(fallback, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement ElemRestrictionCreateAtPoints");
    CeedCall(CeedElemRestrictionCreateAtPoints(fallback, num_elem, num_points, num_comp, l_size, mem_type, copy_mode, offsets, rstr));
    return CEED_ERROR_SUCCESS;
  }

  // Validate the requested shape before allocating anything
  CeedCheck(num_elem >= 0, ceed, CEED_ERROR_DIMENSION, "Number of elements must be non-negative");
  CeedCheck(num_points >= 0, ceed, CEED_ERROR_DIMENSION, "Number of points must be non-negative");
  CeedCheck(num_comp > 0, ceed, CEED_ERROR_DIMENSION, "ElemRestriction must have at least 1 component");
  CeedCheck(l_size >= num_points * num_comp, ceed, CEED_ERROR_DIMENSION, "L-vector must be at least num_points * num_comp");

  CeedCall(CeedCalloc(1, rstr));
  r = *rstr;
  CeedCall(CeedReferenceCopy(ceed, &r->ceed));
  r->ref_count  = 1;
  r->rstr_type  = CEED_RESTRICTION_POINTS;
  r->num_elem   = num_elem;
  r->num_points = num_points;
  r->num_comp   = num_comp;
  r->l_size     = l_size;
  r->e_size     = num_points * num_comp;
  // Unblocked: every element is its own block
  r->block_size = 1;
  r->num_block  = num_elem;
  CeedCall(ceed->ElemRestrictionCreateAtPoints(mem_type, copy_mode, offsets, NULL, NULL, r));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Create a blocked standard CeedElemRestriction.

  Elements are grouped into blocks of `block_size`. The last block is padded with copies of the last element
    (see CeedPermutePadOffsets).
  The permuted offsets are handed to the backend with CEED_OWN_POINTER. When `copy_mode` is CEED_OWN_POINTER,
    the user `offsets` array is freed here.

  @param[in]  ceed        Ceed object used to create the CeedElemRestriction
  @param[in]  num_elem    Number of elements, must be non-negative
  @param[in]  elem_size   Number of nodes per element, must be at least 1
  @param[in]  block_size  Number of elements per block, must be at least 1
  @param[in]  num_comp    Number of components, must be at least 1
  @param[in]  comp_stride Component stride in the L-vector; must be at least 1 when num_comp > 1
  @param[in]  l_size      Size of the L-vector
  @param[in]  mem_type    Memory type of `offsets` (NOTE(review): offsets are read on the host here regardless)
  @param[in]  copy_mode   Copy mode for `offsets`
  @param[in]  offsets     Element node offsets, `elem_size` per element
  @param[out] rstr        Address of the variable where the new CeedElemRestriction will be stored

  @return An error code: 0 - success, otherwise - failure

  @ref Backend
**/
int CeedElemRestrictionCreateBlocked(Ceed ceed, CeedInt num_elem, CeedInt elem_size, CeedInt block_size, CeedInt num_comp, CeedInt comp_stride,
                                     CeedSize l_size, CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *offsets,
                                     CeedElemRestriction *rstr) {
  CeedInt *block_offsets, num_block;

  if (!ceed->ElemRestrictionCreateBlocked) {
    Ceed delegate;

    CeedCall(CeedGetObjectDelegate(ceed, &delegate, "ElemRestriction"));
    CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement ElemRestrictionCreateBlocked");
    CeedCall(CeedElemRestrictionCreateBlocked(delegate, num_elem, elem_size, block_size, num_comp, comp_stride, l_size, mem_type, copy_mode, offsets,
                                              rstr));
    return CEED_ERROR_SUCCESS;
  }

  CeedCheck(num_elem >= 0, ceed, CEED_ERROR_DIMENSION, "Number of elements must be non-negative");
  CeedCheck(elem_size > 0, ceed, CEED_ERROR_DIMENSION, "Element size must be at least 1");
  CeedCheck(block_size > 0, ceed, CEED_ERROR_DIMENSION, "Block size must be at least 1");
  CeedCheck(num_comp > 0, ceed, CEED_ERROR_DIMENSION, "ElemRestriction must have at least 1 component");
  CeedCheck(num_comp == 1 || comp_stride > 0, ceed, CEED_ERROR_DIMENSION, "ElemRestriction component stride must be at least 1");

  // Computed only after block_size is validated, avoiding division by zero; rounds up to cover a partial last block
  num_block = (num_elem / block_size) + !!(num_elem % block_size);

  CeedCall(CeedCalloc(num_block * block_size * elem_size, &block_offsets));
  CeedCall(CeedPermutePadOffsets(offsets, block_offsets, num_block, num_elem, block_size, elem_size));

  CeedCall(CeedCalloc(1, rstr));
  CeedCall(CeedReferenceCopy(ceed, &(*rstr)->ceed));
  (*rstr)->ref_count   = 1;
  (*rstr)->num_elem    = num_elem;
  (*rstr)->elem_size   = elem_size;
  (*rstr)->num_comp    = num_comp;
  (*rstr)->comp_stride = comp_stride;
  (*rstr)->l_size      = l_size;
  (*rstr)->e_size      = num_block * block_size * elem_size * num_comp;
  (*rstr)->num_block   = num_block;
  (*rstr)->block_size  = block_size;
  (*rstr)->rstr_type   = CEED_RESTRICTION_STANDARD;
  // The backend takes ownership of the permuted copy
  CeedCall(ceed->ElemRestrictionCreateBlocked(CEED_MEM_HOST, CEED_OWN_POINTER, (const CeedInt *)block_offsets, NULL, NULL, *rstr));
  if (copy_mode == CEED_OWN_POINTER) CeedCall(CeedFree(&offsets));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Create a blocked oriented CeedElemRestriction.

  Offsets and orientations are permuted and padded into the blocked layout (see CeedPermutePadOffsets and
    CeedPermutePadOrients). The permuted copies are handed to the backend with CEED_OWN_POINTER.
  When `copy_mode` is CEED_OWN_POINTER, the user arrays `offsets` and `orients` are freed here,
    because the caller transferred ownership of both.

  @param[in]  ceed        Ceed object used to create the CeedElemRestriction
  @param[in]  num_elem    Number of elements, must be non-negative
  @param[in]  elem_size   Number of nodes per element, must be at least 1
  @param[in]  block_size  Number of elements per block, must be at least 1
  @param[in]  num_comp    Number of components, must be at least 1
  @param[in]  comp_stride Component stride in the L-vector; must be at least 1 when num_comp > 1
  @param[in]  l_size      Size of the L-vector
  @param[in]  mem_type    Memory type of the arrays (NOTE(review): they are read on the host here regardless)
  @param[in]  copy_mode   Copy mode for the arrays
  @param[in]  offsets     Element node offsets, `elem_size` per element
  @param[in]  orients     Per-node orientation flags, `elem_size` per element
  @param[out] rstr        Address of the variable where the new CeedElemRestriction will be stored

  @return An error code: 0 - success, otherwise - failure

  @ref Backend
**/
int CeedElemRestrictionCreateBlockedOriented(Ceed ceed, CeedInt num_elem, CeedInt elem_size, CeedInt block_size, CeedInt num_comp,
                                             CeedInt comp_stride, CeedSize l_size, CeedMemType mem_type, CeedCopyMode copy_mode,
                                             const CeedInt *offsets, const bool *orients, CeedElemRestriction *rstr) {
  bool    *block_orients;
  CeedInt *block_offsets, num_block;

  if (!ceed->ElemRestrictionCreateBlocked) {
    Ceed delegate;

    CeedCall(CeedGetObjectDelegate(ceed, &delegate, "ElemRestriction"));
    CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement ElemRestrictionCreateBlocked");
    CeedCall(CeedElemRestrictionCreateBlockedOriented(delegate, num_elem, elem_size, block_size, num_comp, comp_stride, l_size, mem_type, copy_mode,
                                                      offsets, orients, rstr));
    return CEED_ERROR_SUCCESS;
  }

  // Same validation as the other blocked creators (num_elem check was previously missing here)
  CeedCheck(num_elem >= 0, ceed, CEED_ERROR_DIMENSION, "Number of elements must be non-negative");
  CeedCheck(elem_size > 0, ceed, CEED_ERROR_DIMENSION, "Element size must be at least 1");
  CeedCheck(block_size > 0, ceed, CEED_ERROR_DIMENSION, "Block size must be at least 1");
  CeedCheck(num_comp > 0, ceed, CEED_ERROR_DIMENSION, "ElemRestriction must have at least 1 component");
  CeedCheck(num_comp == 1 || comp_stride > 0, ceed, CEED_ERROR_DIMENSION, "ElemRestriction component stride must be at least 1");

  // Computed only after block_size is validated, avoiding division by zero; rounds up to cover a partial last block
  num_block = (num_elem / block_size) + !!(num_elem % block_size);

  CeedCall(CeedCalloc(num_block * block_size * elem_size, &block_offsets));
  CeedCall(CeedCalloc(num_block * block_size * elem_size, &block_orients));
  CeedCall(CeedPermutePadOffsets(offsets, block_offsets, num_block, num_elem, block_size, elem_size));
  CeedCall(CeedPermutePadOrients(orients, block_orients, num_block, num_elem, block_size, elem_size));

  CeedCall(CeedCalloc(1, rstr));
  CeedCall(CeedReferenceCopy(ceed, &(*rstr)->ceed));
  (*rstr)->ref_count   = 1;
  (*rstr)->num_elem    = num_elem;
  (*rstr)->elem_size   = elem_size;
  (*rstr)->num_comp    = num_comp;
  (*rstr)->comp_stride = comp_stride;
  (*rstr)->l_size      = l_size;
  (*rstr)->e_size      = num_block * block_size * elem_size * num_comp;
  (*rstr)->num_block   = num_block;
  (*rstr)->block_size  = block_size;
  (*rstr)->rstr_type   = CEED_RESTRICTION_ORIENTED;
  // The backend takes ownership of the permuted copies
  CeedCall(
      ceed->ElemRestrictionCreateBlocked(CEED_MEM_HOST, CEED_OWN_POINTER, (const CeedInt *)block_offsets, (const bool *)block_orients, NULL, *rstr));
  if (copy_mode == CEED_OWN_POINTER) {
    CeedCall(CeedFree(&offsets));
    CeedCall(CeedFree(&orients));
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Create a blocked curl-oriented CeedElemRestriction.

  Offsets (`elem_size` per element) and curl orientations (`3 * elem_size` per element) are permuted and padded into
    the blocked layout. The permuted copies are handed to the backend with CEED_OWN_POINTER.
  When `copy_mode` is CEED_OWN_POINTER, the user arrays `offsets` and `curl_orients` are freed here,
    because the caller transferred ownership of both.

  @param[in]  ceed         Ceed object used to create the CeedElemRestriction
  @param[in]  num_elem     Number of elements, must be non-negative
  @param[in]  elem_size    Number of nodes per element, must be at least 1
  @param[in]  block_size   Number of elements per block, must be at least 1
  @param[in]  num_comp     Number of components, must be at least 1
  @param[in]  comp_stride  Component stride in the L-vector; must be at least 1 when num_comp > 1
  @param[in]  l_size       Size of the L-vector
  @param[in]  mem_type     Memory type of the arrays (NOTE(review): they are read on the host here regardless)
  @param[in]  copy_mode    Copy mode for the arrays
  @param[in]  offsets      Element node offsets, `elem_size` per element
  @param[in]  curl_orients Curl-orientation entries, `3 * elem_size` per element
  @param[out] rstr         Address of the variable where the new CeedElemRestriction will be stored

  @return An error code: 0 - success, otherwise - failure

  @ref Backend
**/
int CeedElemRestrictionCreateBlockedCurlOriented(Ceed ceed, CeedInt num_elem, CeedInt elem_size, CeedInt block_size, CeedInt num_comp,
                                                 CeedInt comp_stride, CeedSize l_size, CeedMemType mem_type, CeedCopyMode copy_mode,
                                                 const CeedInt *offsets, const CeedInt8 *curl_orients, CeedElemRestriction *rstr) {
  CeedInt8 *block_curl_orients;
  CeedInt  *block_offsets, num_block;

  if (!ceed->ElemRestrictionCreateBlocked) {
    Ceed delegate;

    CeedCall(CeedGetObjectDelegate(ceed, &delegate, "ElemRestriction"));
    CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement ElemRestrictionCreateBlocked");
    CeedCall(CeedElemRestrictionCreateBlockedCurlOriented(delegate, num_elem, elem_size, block_size, num_comp, comp_stride, l_size, mem_type,
                                                          copy_mode, offsets, curl_orients, rstr));
    return CEED_ERROR_SUCCESS;
  }

  CeedCheck(num_elem >= 0, ceed, CEED_ERROR_DIMENSION, "Number of elements must be non-negative");
  CeedCheck(elem_size > 0, ceed, CEED_ERROR_DIMENSION, "Element size must be at least 1");
  CeedCheck(block_size > 0, ceed, CEED_ERROR_DIMENSION, "Block size must be at least 1");
  CeedCheck(num_comp > 0, ceed, CEED_ERROR_DIMENSION, "ElemRestriction must have at least 1 component");
  CeedCheck(num_comp == 1 || comp_stride > 0, ceed, CEED_ERROR_DIMENSION, "ElemRestriction component stride must be at least 1");

  // Computed only after block_size is validated, avoiding division by zero; rounds up to cover a partial last block
  num_block = (num_elem / block_size) + !!(num_elem % block_size);

  CeedCall(CeedCalloc(num_block * block_size * elem_size, &block_offsets));
  CeedCall(CeedCalloc(num_block * block_size * 3 * elem_size, &block_curl_orients));
  CeedCall(CeedPermutePadOffsets(offsets, block_offsets, num_block, num_elem, block_size, elem_size));
  CeedCall(CeedPermutePadCurlOrients(curl_orients, block_curl_orients, num_block, num_elem, block_size, 3 * elem_size));

  CeedCall(CeedCalloc(1, rstr));
  CeedCall(CeedReferenceCopy(ceed, &(*rstr)->ceed));
  (*rstr)->ref_count   = 1;
  (*rstr)->num_elem    = num_elem;
  (*rstr)->elem_size   = elem_size;
  (*rstr)->num_comp    = num_comp;
  (*rstr)->comp_stride = comp_stride;
  (*rstr)->l_size      = l_size;
  (*rstr)->e_size      = num_block * block_size * elem_size * num_comp;
  (*rstr)->num_block   = num_block;
  (*rstr)->block_size  = block_size;
  (*rstr)->rstr_type   = CEED_RESTRICTION_CURL_ORIENTED;
  // The backend takes ownership of the permuted copies
  CeedCall(ceed->ElemRestrictionCreateBlocked(CEED_MEM_HOST, CEED_OWN_POINTER, (const CeedInt *)block_offsets, NULL,
                                              (const CeedInt8 *)block_curl_orients, *rstr));
  if (copy_mode == CEED_OWN_POINTER) {
    CeedCall(CeedFree(&offsets));
    CeedCall(CeedFree(&curl_orients));
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Create a blocked strided CeedElemRestriction.

  @param[in]  ceed       Ceed object used to create the CeedElemRestriction
  @param[in]  num_elem   Number of elements, must be non-negative
  @param[in]  elem_size  Number of nodes per element, must be at least 1
  @param[in]  block_size Number of elements per block, must be at least 1
  @param[in]  num_comp   Number of components, must be at least 1
  @param[in]  l_size     Size of the L-vector, at least num_elem * elem_size * num_comp
  @param[in]  strides    Three stride values, copied into the restriction
  @param[out] rstr       Address of the variable where the new CeedElemRestriction will be stored

  @return An error code: 0 - success, otherwise - failure

  @ref User
**/
int CeedElemRestrictionCreateBlockedStrided(Ceed ceed, CeedInt num_elem, CeedInt elem_size, CeedInt block_size, CeedInt num_comp, CeedSize l_size,
                                            const CeedInt strides[3], CeedElemRestriction *rstr) {
  CeedInt num_block;

  if (!ceed->ElemRestrictionCreateBlocked) {
    Ceed delegate;

    CeedCall(CeedGetObjectDelegate(ceed, &delegate, "ElemRestriction"));
    CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement ElemRestrictionCreateBlocked");
    CeedCall(CeedElemRestrictionCreateBlockedStrided(delegate, num_elem, elem_size, block_size, num_comp, l_size, strides, rstr));
    return CEED_ERROR_SUCCESS;
  }

  CeedCheck(num_elem >= 0, ceed, CEED_ERROR_DIMENSION, "Number of elements must be non-negative");
  CeedCheck(elem_size > 0, ceed, CEED_ERROR_DIMENSION, "Element size must be at least 1");
  CeedCheck(block_size > 0, ceed, CEED_ERROR_DIMENSION, "Block size must be at least 1");
  CeedCheck(num_comp > 0, ceed, CEED_ERROR_DIMENSION, "ElemRestriction must have at least 1 component");
  CeedCheck(l_size >= num_elem * elem_size * num_comp, ceed, CEED_ERROR_DIMENSION, "L-vector size must be at least num_elem * elem_size * num_comp");

  // Computed only after block_size is validated, avoiding division by zero; rounds up to cover a partial last block
  num_block = (num_elem / block_size) + !!(num_elem % block_size);

  CeedCall(CeedCalloc(1, rstr));
  CeedCall(CeedReferenceCopy(ceed, &(*rstr)->ceed));
  (*rstr)->ref_count  = 1;
  (*rstr)->num_elem   = num_elem;
  (*rstr)->elem_size  = elem_size;
  (*rstr)->num_comp   = num_comp;
  (*rstr)->l_size     = l_size;
  (*rstr)->e_size     = num_block * block_size * elem_size * num_comp;
  (*rstr)->num_block  = num_block;
  (*rstr)->block_size = block_size;
  (*rstr)->rstr_type  = CEED_RESTRICTION_STRIDED;
  CeedCall(CeedMalloc(3, &(*rstr)->strides));
  for (CeedInt i = 0; i < 3; i++) (*rstr)->strides[i] = strides[i];
  CeedCall(ceed->ElemRestrictionCreateBlocked(CEED_MEM_HOST, CEED_OWN_POINTER, NULL, NULL, NULL, *rstr));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Create a copy of a CeedElemRestriction whose Apply is the source's ApplyUnsigned.

  The copy shares backend data with `rstr` and holds a reference to it through `rstr_base`.
  CeedElemRestrictionGetOffsets and CeedElemRestrictionDestroy route through that base.

  @param[in]  rstr          CeedElemRestriction to copy
  @param[out] rstr_unsigned Address of the variable where the copy will be stored

  @return An error code: 0 - success, otherwise - failure

  @ref User
**/
int CeedElemRestrictionCreateUnsignedCopy(CeedElemRestriction rstr, CeedElemRestriction *rstr_unsigned) {
  CeedCall(CeedCalloc(1, rstr_unsigned));

  // Start from a shallow copy; fields that must not be shared are reset below
  memcpy(*rstr_unsigned, rstr, sizeof(struct CeedElemRestriction_private));
  (*rstr_unsigned)->ceed = NULL;
  CeedCall(CeedReferenceCopy(rstr->ceed, &(*rstr_unsigned)->ceed));
  (*rstr_unsigned)->ref_count = 1;
  // Readers of a copy are counted on its base (see CeedElemRestrictionGetOffsets), so the copy starts with none;
  // inheriting rstr's count would make the copy undestroyable
  (*rstr_unsigned)->num_readers = 0;
  (*rstr_unsigned)->strides     = NULL;
  if (rstr->strides) {
    CeedCall(CeedMalloc(3, &(*rstr_unsigned)->strides));
    for (CeedInt i = 0; i < 3; i++) (*rstr_unsigned)->strides[i] = rstr->strides[i];
  }
  // Clear the base inherited from memcpy: ReferenceCopy destroys the previous value, and the copy never took a
  // reference on rstr's own base
  (*rstr_unsigned)->rstr_base = NULL;
  CeedCall(CeedElemRestrictionReferenceCopy(rstr, &(*rstr_unsigned)->rstr_base));

  // Override Apply
  (*rstr_unsigned)->Apply = rstr->ApplyUnsigned;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Create a copy of a CeedElemRestriction whose Apply is the source's ApplyUnoriented.

  The copy shares backend data with `rstr` and holds a reference to it through `rstr_base`.
  CeedElemRestrictionGetOffsets and CeedElemRestrictionDestroy route through that base.

  @param[in]  rstr            CeedElemRestriction to copy
  @param[out] rstr_unoriented Address of the variable where the copy will be stored

  @return An error code: 0 - success, otherwise - failure

  @ref User
**/
int CeedElemRestrictionCreateUnorientedCopy(CeedElemRestriction rstr, CeedElemRestriction *rstr_unoriented) {
  CeedCall(CeedCalloc(1, rstr_unoriented));

  // Start from a shallow copy; fields that must not be shared are reset below
  memcpy(*rstr_unoriented, rstr, sizeof(struct CeedElemRestriction_private));
  (*rstr_unoriented)->ceed = NULL;
  CeedCall(CeedReferenceCopy(rstr->ceed, &(*rstr_unoriented)->ceed));
  (*rstr_unoriented)->ref_count = 1;
  // Readers of a copy are counted on its base (see CeedElemRestrictionGetOffsets), so the copy starts with none;
  // inheriting rstr's count would make the copy undestroyable
  (*rstr_unoriented)->num_readers = 0;
  (*rstr_unoriented)->strides     = NULL;
  if (rstr->strides) {
    CeedCall(CeedMalloc(3, &(*rstr_unoriented)->strides));
    for (CeedInt i = 0; i < 3; i++) (*rstr_unoriented)->strides[i] = rstr->strides[i];
  }
  // Clear the base inherited from memcpy: ReferenceCopy destroys the previous value, and the copy never took a
  // reference on rstr's own base
  (*rstr_unoriented)->rstr_base = NULL;
  CeedCall(CeedElemRestrictionReferenceCopy(rstr, &(*rstr_unoriented)->rstr_base));

  // Override Apply
  (*rstr_unoriented)->Apply = rstr->ApplyUnoriented;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Make `*rstr_copy` a new reference to `rstr`, releasing whatever `*rstr_copy` held before.

  CEED_ELEMRESTRICTION_NONE is stored without touching a reference count.

  @param[in]     rstr      CeedElemRestriction to reference
  @param[in,out] rstr_copy Variable holding the reference; its previous value is destroyed

  @return An error code: 0 - success, otherwise - failure

  @ref User
**/
int CeedElemRestrictionReferenceCopy(CeedElemRestriction rstr, CeedElemRestriction *rstr_copy) {
  const bool is_counted = (rstr != CEED_ELEMRESTRICTION_NONE);

  // Take the new reference before dropping the old one so that copying a restriction onto itself is safe
  if (is_counted) CeedCall(CeedElemRestrictionReference(rstr));
  CeedCall(CeedElemRestrictionDestroy(rstr_copy));
  *rstr_copy = rstr;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Create CeedVectors sized for the L-vector and/or E-vector of a CeedElemRestriction.

  The E-vector holds `num_block * block_size * elem_size * num_comp` entries, or `num_points * num_comp` for points
    restrictions. These match the `e_size` assigned by the create routines and checked by CeedElemRestrictionApply.

  @param[in]  rstr  CeedElemRestriction
  @param[out] l_vec Variable to store the L-vector, or NULL to skip
  @param[out] e_vec Variable to store the E-vector, or NULL to skip

  @return An error code: 0 - success, otherwise - failure

  @ref User
**/
int CeedElemRestrictionCreateVector(CeedElemRestriction rstr, CeedVector *l_vec, CeedVector *e_vec) {
  const CeedSize l_size = rstr->l_size;
  CeedSize       e_size;

  // Widen to CeedSize before multiplying so large restrictions do not overflow CeedInt.
  // Points restrictions have no fixed element size (elem_size is never set), so size them by their points instead.
  if (rstr->rstr_type == CEED_RESTRICTION_POINTS) e_size = (CeedSize)rstr->num_points * rstr->num_comp;
  else e_size = (CeedSize)rstr->num_block * rstr->block_size * rstr->elem_size * rstr->num_comp;

  if (l_vec) CeedCall(CeedVectorCreate(rstr->ceed, l_size, l_vec));
  if (e_vec) CeedCall(CeedVectorCreate(rstr->ceed, e_size, e_vec));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Apply a CeedElemRestriction (L-vector to E-vector) or its transpose (E-vector to L-vector).

  Input and output vectors must be at least as long as the corresponding restriction dimension.
  Nothing is applied when the restriction has no elements.

  @param[in]  rstr    CeedElemRestriction
  @param[in]  t_mode  CEED_NOTRANSPOSE (L to E) or CEED_TRANSPOSE (E to L)
  @param[in]  u       Input vector
  @param[out] ru      Output vector
  @param[in]  request Request or CEED_REQUEST_IMMEDIATE

  @return An error code: 0 - success, otherwise - failure

  @ref User
**/
int CeedElemRestrictionApply(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector ru, CeedRequest *request) {
  // CeedSize: L-vector and vector lengths may exceed the CeedInt range
  CeedSize m, n;

  if (t_mode == CEED_NOTRANSPOSE) {
    m = rstr->e_size;
    n = rstr->l_size;
  } else {
    m = rstr->l_size;
    n = rstr->e_size;
  }
  // CeedSize values are printed with %td, matching CeedElemRestrictionView
  CeedCheck(n <= u->length, rstr->ceed, CEED_ERROR_DIMENSION, "Input vector size %td not compatible with element restriction (%td, %td)",
            (CeedSize)u->length, m, n);
  CeedCheck(m <= ru->length, rstr->ceed, CEED_ERROR_DIMENSION, "Output vector size %td not compatible with element restriction (%td, %td)",
            (CeedSize)ru->length, m, n);
  if (rstr->num_elem > 0) {
    CeedCheck(rstr->Apply, rstr->ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement ElemRestrictionApply");
    CeedCall(rstr->Apply(rstr, t_mode, u, ru, request));
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Apply a points CeedElemRestriction, or its transpose, for a single element.

  @param[in]  rstr    CeedElemRestriction (points type)
  @param[in]  elem    Element index, must be in [0, num_elem)
  @param[in]  t_mode  CEED_NOTRANSPOSE (L to E) or CEED_TRANSPOSE (E to L)
  @param[in]  u       Input vector
  @param[out] ru      Output vector
  @param[in]  request Request or CEED_REQUEST_IMMEDIATE

  @return An error code: 0 - success, otherwise - failure

  @ref User
**/
int CeedElemRestrictionApplyAtPointsInElement(CeedElemRestriction rstr, CeedInt elem, CeedTransposeMode t_mode, CeedVector u, CeedVector ru,
                                              CeedRequest *request) {
  CeedInt  num_points;
  CeedSize m, n;

  CeedCheck(rstr->ApplyAtPointsInElement, rstr->ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement ElemRestrictionApplyAtPointsInElement");
  // Validate elem before it is used to index the offsets array in CeedElemRestrictionGetNumPointsInElement
  CeedCheck(elem >= 0 && elem < rstr->num_elem, rstr->ceed, CEED_ERROR_DIMENSION,
            "Cannot retrieve element %" CeedInt_FMT ", element must be in [0, %" CeedInt_FMT ")", elem, rstr->num_elem);
  CeedCall(CeedElemRestrictionGetNumPointsInElement(rstr, elem, &num_points));
  if (t_mode == CEED_NOTRANSPOSE) {
    m = num_points;
    n = rstr->l_size;
  } else {
    m = rstr->l_size;
    n = num_points;
  }
  // CeedSize values are printed with %td, matching CeedElemRestrictionView
  CeedCheck(n <= u->length, rstr->ceed, CEED_ERROR_DIMENSION,
            "Input vector size %td not compatible with element restriction (%td, %td) for element %" CeedInt_FMT, (CeedSize)u->length, m, n, elem);
  CeedCheck(m <= ru->length, rstr->ceed, CEED_ERROR_DIMENSION,
            "Output vector size %td not compatible with element restriction (%td, %td) for element %" CeedInt_FMT, (CeedSize)ru->length, m, n, elem);
  CeedCall(rstr->ApplyAtPointsInElement(rstr, elem, t_mode, u, ru, request));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Apply one block of a blocked CeedElemRestriction, or its transpose.

  The E-side vector must hold exactly `block_size * elem_size * num_comp` entries and the L-side vector exactly `l_size`.

  @param[in]  rstr    CeedElemRestriction
  @param[in]  block   Block index; its first element `block * block_size` must be in [0, num_elem)
  @param[in]  t_mode  CEED_NOTRANSPOSE (L to E) or CEED_TRANSPOSE (E to L)
  @param[in]  u       Input vector
  @param[out] ru      Output vector
  @param[in]  request Request or CEED_REQUEST_IMMEDIATE

  @return An error code: 0 - success, otherwise - failure

  @ref Backend
**/
int CeedElemRestrictionApplyBlock(CeedElemRestriction rstr, CeedInt block, CeedTransposeMode t_mode, CeedVector u, CeedVector ru,
                                  CeedRequest *request) {
  const CeedSize block_e_size = (CeedSize)rstr->block_size * rstr->elem_size * rstr->num_comp;
  CeedSize       m, n;

  CeedCheck(rstr->ApplyBlock, rstr->ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement ElemRestrictionApplyBlock");
  if (t_mode == CEED_NOTRANSPOSE) {
    m = block_e_size;
    n = rstr->l_size;
  } else {
    m = rstr->l_size;
    n = block_e_size;
  }
  // CeedSize values are printed with %td, matching CeedElemRestrictionView
  CeedCheck(n == u->length, rstr->ceed, CEED_ERROR_DIMENSION, "Input vector size %td not compatible with element restriction (%td, %td)",
            (CeedSize)u->length, m, n);
  CeedCheck(m == ru->length, rstr->ceed, CEED_ERROR_DIMENSION, "Output vector size %td not compatible with element restriction (%td, %td)",
            (CeedSize)ru->length, m, n);
  // The block must start at an existing element: a block starting exactly at num_elem is out of range
  CeedCheck(block >= 0 && rstr->block_size * block < rstr->num_elem, rstr->ceed, CEED_ERROR_DIMENSION,
            "Cannot retrieve block %" CeedInt_FMT ", first element %" CeedInt_FMT " is not in [0, %" CeedInt_FMT ")", block,
            rstr->block_size * block, rstr->num_elem);
  CeedCall(rstr->ApplyBlock(rstr, block, t_mode, u, ru, request));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the Ceed associated with a CeedElemRestriction (no reference is taken).

  @param[in]  rstr CeedElemRestriction
  @param[out] ceed Variable to store the Ceed

  @return An error code: 0 - success, otherwise - failure

  @ref Advanced
**/
int CeedElemRestrictionGetCeed(CeedElemRestriction rstr, Ceed *ceed) {
  Ceed owner = rstr->ceed;

  *ceed = owner;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the L-vector component stride of a CeedElemRestriction.

  @param[in]  rstr        CeedElemRestriction
  @param[out] comp_stride Variable to store the component stride

  @return An error code: 0 - success, otherwise - failure

  @ref Advanced
**/
int CeedElemRestrictionGetCompStride(CeedElemRestriction rstr, CeedInt *comp_stride) {
  const CeedInt stride = rstr->comp_stride;

  *comp_stride = stride;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the number of elements of a CeedElemRestriction.

  @param[in]  rstr     CeedElemRestriction
  @param[out] num_elem Variable to store the number of elements

  @return An error code: 0 - success, otherwise - failure

  @ref Advanced
**/
int CeedElemRestrictionGetNumElements(CeedElemRestriction rstr, CeedInt *num_elem) {
  const CeedInt count = rstr->num_elem;

  *num_elem = count;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the number of nodes per element of a CeedElemRestriction.

  @param[in]  rstr      CeedElemRestriction
  @param[out] elem_size Variable to store the element size

  @return An error code: 0 - success, otherwise - failure

  @ref Advanced
**/
int CeedElemRestrictionGetElementSize(CeedElemRestriction rstr, CeedInt *elem_size) {
  const CeedInt size = rstr->elem_size;

  *elem_size = size;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the total number of points of a points CeedElemRestriction.

  @param[in]  rstr       CeedElemRestriction (points type)
  @param[out] num_points Variable to store the number of points

  @return An error code: 0 - success, CEED_ERROR_INCOMPATIBLE if rstr is not a points restriction

  @ref Backend
**/
int CeedElemRestrictionGetNumPoints(CeedElemRestriction rstr, CeedInt *num_points) {
  Ceed       ceed;
  const bool is_points = (rstr->rstr_type == CEED_RESTRICTION_POINTS);

  CeedCall(CeedElemRestrictionGetCeed(rstr, &ceed));
  CeedCheck(is_points, ceed, CEED_ERROR_INCOMPATIBLE, "Can only retrieve the number of points for a points CeedElemRestriction");
  *num_points = rstr->num_points;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the number of points in one element of a points CeedElemRestriction.

  The count is `offsets[elem + 1] - offsets[elem]`.

  @param[in]  rstr       CeedElemRestriction (points type)
  @param[in]  elem       Element index, must be in [0, num_elem)
  @param[out] num_points Variable to store the number of points in the element

  @return An error code: 0 - success, otherwise - failure

  @ref Backend
**/
int CeedElemRestrictionGetNumPointsInElement(CeedElemRestriction rstr, CeedInt elem, CeedInt *num_points) {
  Ceed           ceed;
  const CeedInt *offsets;

  CeedCall(CeedElemRestrictionGetCeed(rstr, &ceed));
  CeedCheck(rstr->rstr_type == CEED_RESTRICTION_POINTS, ceed, CEED_ERROR_INCOMPATIBLE,
            "Can only retrieve the number of points for a points CeedElemRestriction");
  // offsets[elem + 1] is read below, so elem must name an existing element
  CeedCheck(elem >= 0 && elem < rstr->num_elem, ceed, CEED_ERROR_DIMENSION,
            "Cannot retrieve element %" CeedInt_FMT ", element must be in [0, %" CeedInt_FMT ")", elem, rstr->num_elem);
  CeedCall(CeedElemRestrictionGetOffsets(rstr, CEED_MEM_HOST, &offsets));
  *num_points = offsets[elem + 1] - offsets[elem];
  CeedCall(CeedElemRestrictionRestoreOffsets(rstr, &offsets));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the largest number of points in any element of a points CeedElemRestriction.

  Yields 0 when the restriction has no elements.

  @param[in]  rstr       CeedElemRestriction (points type)
  @param[out] max_points Variable to store the maximum

  @return An error code: 0 - success, otherwise - failure

  @ref Backend
**/
int CeedElemRestrictionGetMaxPointsInElement(CeedElemRestriction rstr, CeedInt *max_points) {
  Ceed                ceed;
  CeedInt             num_elem, max = 0;
  CeedRestrictionType rstr_type;

  CeedCall(CeedElemRestrictionGetCeed(rstr, &ceed));
  CeedCall(CeedElemRestrictionGetType(rstr, &rstr_type));
  CeedCheck(rstr_type == CEED_RESTRICTION_POINTS, ceed, CEED_ERROR_INCOMPATIBLE,
            "Cannot compute max points for a CeedElemRestriction that does not use points");
  CeedCall(CeedElemRestrictionGetNumElements(rstr, &num_elem));
  if (num_elem > 0) {
    const CeedInt *offsets;

    // Acquire the offsets once rather than once per element.
    // As in CeedElemRestrictionGetNumPointsInElement, element e holds offsets[e + 1] - offsets[e] points.
    CeedCall(CeedElemRestrictionGetOffsets(rstr, CEED_MEM_HOST, &offsets));
    for (CeedInt e = 0; e < num_elem; e++) max = CeedIntMax(offsets[e + 1] - offsets[e], max);
    CeedCall(CeedElemRestrictionRestoreOffsets(rstr, &offsets));
  }
  *max_points = max;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the L-vector size of a CeedElemRestriction.

  @param[in]  rstr   CeedElemRestriction
  @param[out] l_size Variable to store the L-vector size

  @return An error code: 0 - success, otherwise - failure

  @ref Advanced
**/
int CeedElemRestrictionGetLVectorSize(CeedElemRestriction rstr, CeedSize *l_size) {
  const CeedSize size = rstr->l_size;

  *l_size = size;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the number of components of a CeedElemRestriction.

  @param[in]  rstr     CeedElemRestriction
  @param[out] num_comp Variable to store the number of components

  @return An error code: 0 - success, otherwise - failure

  @ref Advanced
**/
int CeedElemRestrictionGetNumComponents(CeedElemRestriction rstr, CeedInt *num_comp) {
  const CeedInt count = rstr->num_comp;

  *num_comp = count;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the number of blocks of a CeedElemRestriction (equal to num_elem for unblocked restrictions).

  @param[in]  rstr      CeedElemRestriction
  @param[out] num_block Variable to store the number of blocks

  @return An error code: 0 - success, otherwise - failure

  @ref Advanced
**/
int CeedElemRestrictionGetNumBlocks(CeedElemRestriction rstr, CeedInt *num_block) {
  const CeedInt count = rstr->num_block;

  *num_block = count;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Get the block size of a CeedElemRestriction (1 for unblocked restrictions).

  @param[in]  rstr       CeedElemRestriction
  @param[out] block_size Variable to store the block size

  @return An error code: 0 - success, otherwise - failure

  @ref Advanced
**/
int CeedElemRestrictionGetBlockSize(CeedElemRestriction rstr, CeedInt *block_size) {
  const CeedInt size = rstr->block_size;

  *block_size = size;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Compute, for each L-vector entry, how many element entries reference it.

  Computes `mult = E^T (E 1)`: a vector of ones is restricted to an E-vector and then transpose-applied into
    `mult`, which is zeroed first.

  @param[in]  rstr CeedElemRestriction
  @param[out] mult L-vector that receives the multiplicity

  @return An error code: 0 - success, otherwise - failure

  @ref User
**/
int CeedElemRestrictionGetMultiplicity(CeedElemRestriction rstr, CeedVector mult) {
  CeedVector e_scratch;

  CeedCall(CeedElemRestrictionCreateVector(rstr, NULL, &e_scratch));

  // Forward: E-vector = E * ones
  CeedCall(CeedVectorSetValue(mult, 1.0));
  CeedCall(CeedElemRestrictionApply(rstr, CEED_NOTRANSPOSE, mult, e_scratch, CEED_REQUEST_IMMEDIATE));

  // Transpose into a zeroed L-vector
  CeedCall(CeedVectorSetValue(mult, 0.0));
  CeedCall(CeedElemRestrictionApply(rstr, CEED_TRANSPOSE, e_scratch, mult, CEED_REQUEST_IMMEDIATE));

  CeedCall(CeedVectorDestroy(&e_scratch));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Print a one-line human-readable summary of a CeedElemRestriction to a stream.

  @param[in] rstr   CeedElemRestriction to view
  @param[in] stream Stream to write to, e.g. stdout

  @return An error code: 0 - success, otherwise - failure

  @ref User
**/
int CeedElemRestrictionView(CeedElemRestriction rstr, FILE *stream) {
  CeedRestrictionType rstr_type;

  CeedCall(CeedElemRestrictionGetType(rstr, &rstr_type));
  if (rstr_type == CEED_RESTRICTION_POINTS) {
    CeedInt max_points;

    CeedCall(CeedElemRestrictionGetMaxPointsInElement(rstr, &max_points));
    fprintf(stream,
            "CeedElemRestriction at points from (%td, %" CeedInt_FMT ") to %" CeedInt_FMT " elements with a maximum of %" CeedInt_FMT
            " points on an element\n",
            rstr->l_size, rstr->num_comp, rstr->num_elem, max_points);
  } else {
    char strides_str[500];

    // Bounded formatting: snprintf never writes past strides_str
    if (rstr->strides) {
      snprintf(strides_str, sizeof strides_str, "[%" CeedInt_FMT ", %" CeedInt_FMT ", %" CeedInt_FMT "]", rstr->strides[0], rstr->strides[1],
               rstr->strides[2]);
    } else {
      snprintf(strides_str, sizeof strides_str, "%" CeedInt_FMT, rstr->comp_stride);
    }
    fprintf(stream, "%sCeedElemRestriction from (%td, %" CeedInt_FMT ") to %" CeedInt_FMT " elements with %" CeedInt_FMT " nodes each and %s %s\n",
            rstr->block_size > 1 ? "Blocked " : "", rstr->l_size, rstr->num_comp, rstr->num_elem, rstr->elem_size,
            rstr->strides ? "strides" : "component stride", strides_str);
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Release a reference to a CeedElemRestriction, freeing it when the last reference is dropped.

  NULL and CEED_ELEMRESTRICTION_NONE are accepted and ignored. If other references remain, only the count is
    decremented. In those cases `*rstr` is set to NULL and nothing is freed.

  @param[in,out] rstr CeedElemRestriction to destroy

  @return An error code: 0 - success, CEED_ERROR_ACCESS if offset data is still being read

  @ref User
**/
int CeedElemRestrictionDestroy(CeedElemRestriction *rstr) {
  // Nothing to free for NULL, the static NONE sentinel, or an object that still has other references
  if (!*rstr || *rstr == CEED_ELEMRESTRICTION_NONE || --(*rstr)->ref_count > 0) {
    *rstr = NULL;
    return CEED_ERROR_SUCCESS;
  }
  // NOTE(review): ref_count has already been decremented to 0 when this check fails, so the object is not freed
  // and a later Destroy call would decrement it again — confirm this is acceptable
  CeedCheck((*rstr)->num_readers == 0, (*rstr)->ceed, CEED_ERROR_ACCESS,
            "Cannot destroy CeedElemRestriction, a process has read access to the offset data");
  // Copies (rstr_base set) release their reference on the base instead of calling the backend Destroy,
  // since their backend data pointer was shallow-copied from that base
  if ((*rstr)->rstr_base) CeedCall(CeedElemRestrictionDestroy(&(*rstr)->rstr_base));
  else if ((*rstr)->Destroy) CeedCall((*rstr)->Destroy(*rstr));
  CeedCall(CeedFree(&(*rstr)->strides));
  CeedCall(CeedDestroy(&(*rstr)->ceed));
  CeedCall(CeedFree(rstr));
  return CEED_ERROR_SUCCESS;
}