#include <ceed-impl.h>
#include <ceed.h>
#include <ceed/backend.h>
#include <stddef.h>
/**
  @brief Create a CeedTensorContract object.

  If @a ceed does not provide a TensorContractCreate implementation, creation is
  forwarded to the delegate obtained from CeedGetObjectDelegate(ceed, ..., "TensorContract");
  an error with code CEED_ERROR_UNSUPPORTED is raised when no such delegate exists.

  @param[in]  ceed      Ceed context used to create the object
  @param[out] contract  Address of the variable that receives the new CeedTensorContract

  @return CEED_ERROR_SUCCESS on success; otherwise the error code propagated by CeedCall/CeedCheck
**/
int CeedTensorContractCreate(Ceed ceed, CeedTensorContract *contract) {
  if (!ceed->TensorContractCreate) {
    Ceed delegate;

    CeedCall(CeedGetObjectDelegate(ceed, &delegate, "TensorContract"));
    CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not support TensorContractCreate");
    // NOTE(review): if CeedGetObjectDelegate hands back a new reference in this libCEED
    // version, it should be released here after use — confirm against its implementation.
    CeedCall(CeedTensorContractCreate(delegate, contract));
    return CEED_ERROR_SUCCESS;
  }

  CeedCall(CeedCalloc(1, contract));
  // The creator owns one reference. CeedTensorContractDestroy frees the object once the
  // decremented count is no longer positive, so leaving the zero from calloc would let the
  // first Destroy free an object that a CeedTensorContractReferenceCopy holder still uses.
  (*contract)->ref_count = 1;
  CeedCall(CeedReferenceCopy(ceed, &(*contract)->ceed));
  CeedCall(ceed->TensorContractCreate(*contract));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Apply a tensor contraction through the backend's Apply callback.

  All arguments are forwarded unchanged to contract->Apply; the precise contraction
  semantics (index layout, meaning of @a add) are defined by that backend implementation.

  @param[in]  contract  CeedTensorContract whose Apply callback is invoked
  @param[in]  A, B, C, J  Contraction dimensions forwarded to the backend
  @param[in]  t         Contraction matrix
  @param[in]  t_mode    Orientation of @a t
  @param[in]  add       Accumulation flag forwarded to the backend
  @param[in]  u         Input array
  @param[out] v         Output array

  @return CEED_ERROR_SUCCESS on success; otherwise the error code propagated by CeedCall
**/
int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                            CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  // Pure dispatch: the interface layer adds nothing beyond error propagation.
  CeedCall(contract->Apply(contract,
                           A, B, C, J,
                           t, t_mode,
                           add,
                           u, v));
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Apply a tensor contraction once per slice d in [0, D) of a strided matrix stack.

  Slice d of @a t starts at offset d*B*J.
  - CEED_TRANSPOSE: slice d of @a u (offset d*A*J*C) is contracted into the single
    shared output @a v, with Apply called as (A, J, C, B).
  - otherwise: the shared input @a u is contracted into slice d of @a v (offset d*A*J*C),
    with Apply called as (A, B, C, J).

  @param[in]  contract      CeedTensorContract whose Apply callback is invoked
  @param[in]  A, B, C, D, J Dimensions; D is the number of slices
  @param[in]  t             Stack of D contraction matrices, each B*J entries
  @param[in]  t_mode        Orientation of the contraction
  @param[in]  add           Accumulation flag; in transpose mode it applies to the whole
                            sum over slices (see below)
  @param[in]  u             Input array
  @param[out] v             Output array

  @return CEED_ERROR_SUCCESS on success; otherwise the error code propagated by CeedCall
**/
int CeedTensorContractStridedApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt D, CeedInt J, const CeedScalar *restrict t,
                                   CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  // Slice strides computed in ptrdiff_t so the pointer offsets cannot overflow CeedInt.
  const ptrdiff_t t_stride  = (ptrdiff_t)B * J;
  const ptrdiff_t uv_stride = (ptrdiff_t)A * J * C;

  if (t_mode == CEED_TRANSPOSE) {
    for (CeedInt d = 0; d < D; d++) {
      // Every slice lands in the same v: only the first slice may honour add == 0
      // (overwrite); the remaining slices must accumulate, or they would discard
      // the contributions of the earlier slices.
      const CeedInt add_d = (d == 0) ? add : 1;

      CeedCall(contract->Apply(contract, A, J, C, B, t + d * t_stride, t_mode, add_d, u + d * uv_stride, v));
    }
  } else {
    for (CeedInt d = 0; d < D; d++) {
      // Each slice writes its own region of v, so add is forwarded unchanged.
      CeedCall(contract->Apply(contract, A, B, C, J, t + d * t_stride, t_mode, add, u, v + d * uv_stride));
    }
  }
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Retrieve the Ceed context stored in a CeedTensorContract.

  The stored handle is returned as-is; no additional reference is taken here.

  @param[in]  contract  CeedTensorContract to query
  @param[out] ceed      Receives the stored Ceed

  @return CEED_ERROR_SUCCESS
**/
int CeedTensorContractGetCeed(CeedTensorContract contract, Ceed *ceed) {
  Ceed owner = contract->ceed;

  *ceed = owner;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Retrieve the backend data pointer stored in a CeedTensorContract.

  @param[in]  contract  CeedTensorContract to query
  @param[out] data      Address of a pointer variable (passed as void *) that receives the
                        stored backend data pointer

  @return CEED_ERROR_SUCCESS
**/
int CeedTensorContractGetData(CeedTensorContract contract, void *data) {
  // data really points at the caller's pointer variable.
  void **slot = (void **)data;

  *slot = contract->data;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Store a backend data pointer in a CeedTensorContract.

  Only the pointer is stored; any previously stored pointer is overwritten without being
  released by this function.

  @param[in,out] contract  CeedTensorContract to update
  @param[in]     data      Backend data pointer to store

  @return CEED_ERROR_SUCCESS
**/
int CeedTensorContractSetData(CeedTensorContract contract, void *data) {
  void *backend_data = data;

  contract->data = backend_data;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Increment the reference count of a CeedTensorContract.

  Each call must be balanced by a CeedTensorContractDestroy.

  @param[in,out] contract  CeedTensorContract to reference

  @return CEED_ERROR_SUCCESS
**/
int CeedTensorContractReference(CeedTensorContract contract) {
  contract->ref_count += 1;
  return CEED_ERROR_SUCCESS;
}
/**
  @brief Make @a tensor_copy refer to @a tensor, releasing whatever it referred to before.

  @param[in]     tensor       CeedTensorContract to share
  @param[in,out] tensor_copy  Variable to (re)point at @a tensor; may hold NULL or an
                              existing CeedTensorContract, which is destroyed/released

  @return CEED_ERROR_SUCCESS on success; otherwise the error code propagated by CeedCall
**/
int CeedTensorContractReferenceCopy(CeedTensorContract tensor, CeedTensorContract *tensor_copy) {
  // Order matters: acquire the new reference before releasing the old one, so that
  // self-assignment (*tensor_copy == tensor) never drops the count before re-acquiring it.
  CeedCall(CeedTensorContractReference(tensor));

  // Release the previous target (a NULL target is a no-op), then alias the shared object.
  CeedCall(CeedTensorContractDestroy(tensor_copy));
  *tensor_copy = tensor;

  return CEED_ERROR_SUCCESS;
}
/**
  @brief Release one reference to a CeedTensorContract and free it when none remain.

  The caller's handle is always NULL afterwards. When the last reference is dropped, the
  backend Destroy callback (if any) runs, the owned Ceed reference is released, and the
  object itself is freed.

  @param[in,out] contract  Address of the CeedTensorContract handle to release (may hold NULL)

  @return CEED_ERROR_SUCCESS on success; otherwise the error code propagated by CeedCall
**/
int CeedTensorContractDestroy(CeedTensorContract *contract) {
  CeedTensorContract obj = *contract;

  // A NULL handle needs no work (and is already NULL).
  if (!obj) {
    return CEED_ERROR_SUCCESS;
  }

  // Drop this holder's reference; if others remain, just detach the caller's handle.
  obj->ref_count--;
  if (obj->ref_count > 0) {
    *contract = NULL;
    return CEED_ERROR_SUCCESS;
  }

  // Last reference: tear down backend state, the owned Ceed, and the object itself.
  if (obj->Destroy) {
    CeedCall(obj->Destroy(obj));
  }
  CeedCall(CeedDestroy(&obj->ceed));
  CeedCall(CeedFree(contract));
  return CEED_ERROR_SUCCESS;
}