#include <inttypes.h>
#include <stdio.h>
#ifdef USE_INTEL_MKL
#include "mkl.h"
#define COMPLEX64 MKL_Complex16
#define FN_DGESV dgesv_
#define FN_ZGESV zgesv_
#define FN_DLANGE dlange_
#define FN_ZLANGE zlange_
#define FN_DPOTRF dpotrf_
#define FN_ZPOTRF zpotrf_
#define FN_DSYEV dsyev_
#define FN_DGEEV dgeev_
#define FN_ZGEEV zgeev_
#define FN_ZHEEV zheev_
#define FN_DGGEV dggev_
#define FN_ZGGEV zggev_
#define FN_DGESVD dgesvd_
#define FN_ZGESVD zgesvd_
#define FN_DGETRF dgetrf_
#define FN_DGETRI dgetri_
#define FN_ZGETRF zgetrf_
#define FN_ZGETRI zgetri_
#else
#include "cblas.h"
#include "lapack.h"
#define COMPLEX64 lapack_complex_double
#define FN_DGESV LAPACK_dgesv
#define FN_ZGESV LAPACK_zgesv
#define FN_DLANGE LAPACK_dlange
#define FN_ZLANGE LAPACK_zlange
#define FN_DPOTRF LAPACK_dpotrf
#define FN_ZPOTRF LAPACK_zpotrf
#define FN_DSYEV LAPACK_dsyev
#define FN_DGEEV LAPACK_dgeev
#define FN_ZGEEV LAPACK_zgeev
#define FN_ZHEEV LAPACK_zheev
#define FN_DGGEV LAPACK_dggev
#define FN_ZGGEV LAPACK_zggev
#define FN_DGESVD LAPACK_dgesvd
#define FN_ZGESVD LAPACK_zgesvd
#define FN_DGETRF LAPACK_dgetrf
#define FN_DGETRI LAPACK_dgetri
#define FN_ZGETRF LAPACK_zgetrf
#define FN_ZGETRI LAPACK_zgetri
#endif
#include "constants.h"
/**
 * Reports which BLAS/LAPACK backend this translation unit was built against.
 *
 * @return 1 when compiled with USE_INTEL_MKL defined (Intel MKL backend),
 *         0 otherwise (the cblas.h / lapack.h backend).
 *
 * The explicit (void) parameter list makes this a real prototype; an empty
 * "()" in C before C23 is an old-style declaration that accepts any arguments.
 */
int32_t c_using_intel_mkl(void) {
#ifdef USE_INTEL_MKL
    return 1;
#else
    return 0;
#endif
}
/*
 * Sets the thread count of the active BLAS/LAPACK backend:
 * openblas_set_num_threads() unless USE_INTEL_MKL is defined,
 * in which case MKL_Set_Num_Threads() is used. `n` is forwarded unchanged.
 */
void c_set_num_threads(int32_t n)
{
#ifndef USE_INTEL_MKL
    openblas_set_num_threads(n);
#else
    MKL_Set_Num_Threads(n);
#endif
}
/**
 * Returns the thread count reported by the active BLAS/LAPACK backend.
 *
 * Intel MKL build: mkl_get_max_threads(); otherwise: openblas_get_num_threads().
 *
 * The explicit (void) parameter list makes this definition a proper
 * prototype; an empty "()" in C before C23 is an old-style declaration.
 */
int32_t c_get_num_threads(void) {
#ifdef USE_INTEL_MKL
    return mkl_get_max_threads();
#else
    return openblas_get_num_threads();
#endif
}
/*
 * Pass-through to LAPACK DGESV (real general linear solve A*X = B).
 * All arguments are forwarded unchanged; see the LAPACK DGESV
 * documentation for their meaning. The status is written to *info.
 */
void c_dgesv(const int32_t *n, const int32_t *nrhs, double *a, const int32_t *lda,
             int32_t *ipiv, double *b, const int32_t *ldb, int32_t *info)
{
    FN_DGESV(n, nrhs, a, lda, ipiv, b, ldb, info);
}
/*
 * Pass-through to LAPACK ZGESV (complex general linear solve A*X = B).
 * COMPLEX64 is the backend's double-precision complex type
 * (MKL_Complex16 or lapack_complex_double). Arguments are forwarded
 * unchanged; the status is written to *info.
 */
void c_zgesv(const int32_t *n, const int32_t *nrhs, COMPLEX64 *a, const int32_t *lda,
             int32_t *ipiv, COMPLEX64 *b, const int32_t *ldb, int32_t *info)
{
    FN_ZGESV(n, nrhs, a, lda, ipiv, b, ldb, info);
}
/*
 * Returns the value computed by LAPACK DLANGE for the real matrix `a`.
 * norm_code is translated to the LAPACK NORM character, tested in order:
 *   NORM_EUC or NORM_FRO -> "F"
 *   NORM_INF             -> "I"
 *   NORM_MAX             -> "M"
 *   anything else        -> "O"
 * m, n, a, lda and work are forwarded unchanged.
 */
double c_dlange(int32_t norm_code, const int32_t *m, const int32_t *n,
                const double *a, const int32_t *lda, double *work)
{
    const char *norm = "O";
    if (norm_code == NORM_EUC || norm_code == NORM_FRO) {
        norm = "F";
    } else if (norm_code == NORM_INF) {
        norm = "I";
    } else if (norm_code == NORM_MAX) {
        norm = "M";
    }
    return FN_DLANGE(norm, m, n, a, lda, work);
}
/*
 * Returns the value computed by LAPACK ZLANGE for the complex matrix `a`.
 * norm_code is translated to the LAPACK NORM character, tested in order:
 *   NORM_EUC or NORM_FRO -> "F"
 *   NORM_INF             -> "I"
 *   NORM_MAX             -> "M"
 *   anything else        -> "O"
 * m, n, a, lda and work are forwarded unchanged.
 */
double c_zlange(int32_t norm_code, const int32_t *m, const int32_t *n,
                const COMPLEX64 *a, const int32_t *lda, double *work)
{
    const char *norm = "O";
    if (norm_code == NORM_EUC || norm_code == NORM_FRO) {
        norm = "F";
    } else if (norm_code == NORM_INF) {
        norm = "I";
    } else if (norm_code == NORM_MAX) {
        norm = "M";
    }
    return FN_ZLANGE(norm, m, n, a, lda, work);
}
/*
 * Pass-through to LAPACK DPOTRF (Cholesky factorization of a real matrix).
 * UPLO is "U" only when upper == C_TRUE; any other value selects "L".
 */
void c_dpotrf(C_BOOL upper, const int32_t *n, double *a, const int32_t *lda,
              int32_t *info)
{
    const char *uplo = "L";
    if (upper == C_TRUE) {
        uplo = "U";
    }
    FN_DPOTRF(uplo, n, a, lda, info);
}
/*
 * Pass-through to LAPACK ZPOTRF (Cholesky factorization of a complex matrix).
 * UPLO is "U" only when upper == C_TRUE; any other value selects "L".
 */
void c_zpotrf(C_BOOL upper, const int32_t *n, COMPLEX64 *a, const int32_t *lda,
              int32_t *info)
{
    const char *uplo = "L";
    if (upper == C_TRUE) {
        uplo = "U";
    }
    FN_ZPOTRF(uplo, n, a, lda, info);
}
/*
 * Pass-through to LAPACK DSYEV (eigenvalues, optionally eigenvectors,
 * of a real symmetric matrix).
 *   JOBZ = "V" when calc_v == C_TRUE, otherwise "N".
 *   UPLO = "U" when upper  == C_TRUE, otherwise "L".
 * All remaining arguments are forwarded unchanged.
 */
void c_dsyev(C_BOOL calc_v, C_BOOL upper, const int32_t *n, double *a,
             const int32_t *lda, double *w, double *work, const int32_t *lwork,
             int32_t *info)
{
    const char *jobz = "N";
    const char *uplo = "L";
    if (calc_v == C_TRUE) {
        jobz = "V";
    }
    if (upper == C_TRUE) {
        uplo = "U";
    }
    FN_DSYEV(jobz, uplo, n, a, lda, w, work, lwork, info);
}
/*
 * Pass-through to LAPACK DGEEV (eigenvalues and optional left/right
 * eigenvectors of a real general matrix).
 *   JOBVL = "V" when calc_vl == C_TRUE, otherwise "N".
 *   JOBVR = "V" when calc_vr == C_TRUE, otherwise "N".
 * All remaining arguments are forwarded unchanged.
 */
void c_dgeev(C_BOOL calc_vl, C_BOOL calc_vr, const int32_t *n,
             double *a, const int32_t *lda, double *wr, double *wi,
             double *vl, const int32_t *ldvl, double *vr, const int32_t *ldvr,
             double *work, const int32_t *lwork, int32_t *info)
{
    const char *jobvl = "N";
    const char *jobvr = "N";
    if (calc_vl == C_TRUE) {
        jobvl = "V";
    }
    if (calc_vr == C_TRUE) {
        jobvr = "V";
    }
    FN_DGEEV(jobvl, jobvr, n, a, lda, wr, wi, vl, ldvl, vr, ldvr, work, lwork, info);
}
/*
 * Pass-through to LAPACK ZGEEV (eigenvalues and optional left/right
 * eigenvectors of a complex general matrix).
 *   JOBVL = "V" when calc_vl == C_TRUE, otherwise "N".
 *   JOBVR = "V" when calc_vr == C_TRUE, otherwise "N".
 * All remaining arguments are forwarded unchanged.
 */
void c_zgeev(C_BOOL calc_vl, C_BOOL calc_vr, const int32_t *n,
             COMPLEX64 *a, const int32_t *lda, COMPLEX64 *w,
             COMPLEX64 *vl, const int32_t *ldvl, COMPLEX64 *vr, const int32_t *ldvr,
             COMPLEX64 *work, const int32_t *lwork, double *rwork, int32_t *info)
{
    const char *jobvl = "N";
    const char *jobvr = "N";
    if (calc_vl == C_TRUE) {
        jobvl = "V";
    }
    if (calc_vr == C_TRUE) {
        jobvr = "V";
    }
    FN_ZGEEV(jobvl, jobvr, n, a, lda, w, vl, ldvl, vr, ldvr, work, lwork, rwork, info);
}
/*
 * Pass-through to LAPACK ZHEEV (eigenvalues, optionally eigenvectors,
 * of a complex Hermitian matrix).
 *   JOBZ = "V" when calc_v == C_TRUE, otherwise "N".
 *   UPLO = "U" when upper  == C_TRUE, otherwise "L".
 * All remaining arguments are forwarded unchanged.
 */
void c_zheev(C_BOOL calc_v, C_BOOL upper, const int32_t *n, COMPLEX64 *a,
             const int32_t *lda, double *w, COMPLEX64 *work, const int32_t *lwork,
             double *rwork, int32_t *info)
{
    const char *jobz = "N";
    const char *uplo = "L";
    if (calc_v == C_TRUE) {
        jobz = "V";
    }
    if (upper == C_TRUE) {
        uplo = "U";
    }
    FN_ZHEEV(jobz, uplo, n, a, lda, w, work, lwork, rwork, info);
}
/*
 * Pass-through to LAPACK DGGEV (generalized eigenproblem for the real
 * matrix pair (A, B)).
 *   JOBVL = "V" when calc_vl == C_TRUE, otherwise "N".
 *   JOBVR = "V" when calc_vr == C_TRUE, otherwise "N".
 * All remaining arguments are forwarded unchanged.
 */
void c_dggev(C_BOOL calc_vl, C_BOOL calc_vr, const int32_t *n,
             double *a, const int32_t *lda, double *b, const int32_t *ldb,
             double *alphar, double *alphai, double *beta,
             double *vl, const int32_t *ldvl, double *vr, const int32_t *ldvr,
             double *work, const int32_t *lwork, int32_t *info)
{
    const char *jobvl = "N";
    const char *jobvr = "N";
    if (calc_vl == C_TRUE) {
        jobvl = "V";
    }
    if (calc_vr == C_TRUE) {
        jobvr = "V";
    }
    FN_DGGEV(jobvl, jobvr, n, a, lda, b, ldb, alphar, alphai, beta,
             vl, ldvl, vr, ldvr, work, lwork, info);
}
/*
 * Pass-through to LAPACK ZGGEV (generalized eigenproblem for the complex
 * matrix pair (A, B)).
 *   JOBVL = "V" when calc_vl == C_TRUE, otherwise "N".
 *   JOBVR = "V" when calc_vr == C_TRUE, otherwise "N".
 * All remaining arguments are forwarded unchanged.
 */
void c_zggev(C_BOOL calc_vl, C_BOOL calc_vr, const int32_t *n,
             COMPLEX64 *a, const int32_t *lda, COMPLEX64 *b, const int32_t *ldb,
             COMPLEX64 *alpha, COMPLEX64 *beta,
             COMPLEX64 *vl, const int32_t *ldvl, COMPLEX64 *vr, const int32_t *ldvr,
             COMPLEX64 *work, const int32_t *lwork, double *rwork, int32_t *info)
{
    const char *jobvl = "N";
    const char *jobvr = "N";
    if (calc_vl == C_TRUE) {
        jobvl = "V";
    }
    if (calc_vr == C_TRUE) {
        jobvr = "V";
    }
    FN_ZGGEV(jobvl, jobvr, n, a, lda, b, ldb, alpha, beta,
             vl, ldvl, vr, ldvr, work, lwork, rwork, info);
}
/*
 * Pass-through to LAPACK DGESVD (singular value decomposition of a real
 * matrix). jobu_code / jobvt_code are each mapped, in this order, to:
 *   SVD_CODE_A -> "A", SVD_CODE_S -> "S", SVD_CODE_O -> "O", else "N".
 * All remaining arguments are forwarded unchanged.
 */
void c_dgesvd(int32_t jobu_code, int32_t jobvt_code,
              const int32_t *m, const int32_t *n, double *a, const int32_t *lda,
              double *s, double *u, const int32_t *ldu, double *vt, const int32_t *ldvt,
              double *work, const int32_t *lwork, int32_t *info)
{
    const char *jobu = "N";
    if (jobu_code == SVD_CODE_A) {
        jobu = "A";
    } else if (jobu_code == SVD_CODE_S) {
        jobu = "S";
    } else if (jobu_code == SVD_CODE_O) {
        jobu = "O";
    }

    const char *jobvt = "N";
    if (jobvt_code == SVD_CODE_A) {
        jobvt = "A";
    } else if (jobvt_code == SVD_CODE_S) {
        jobvt = "S";
    } else if (jobvt_code == SVD_CODE_O) {
        jobvt = "O";
    }

    FN_DGESVD(jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, info);
}
/*
 * Pass-through to LAPACK ZGESVD (singular value decomposition of a complex
 * matrix). jobu_code / jobvt_code are each mapped, in this order, to:
 *   SVD_CODE_A -> "A", SVD_CODE_S -> "S", SVD_CODE_O -> "O", else "N".
 * `vh` is passed in LAPACK's VT slot with leading dimension `ldvt`.
 * All remaining arguments are forwarded unchanged.
 */
void c_zgesvd(int32_t jobu_code, int32_t jobvt_code,
              const int32_t *m, const int32_t *n, COMPLEX64 *a, const int32_t *lda,
              double *s, COMPLEX64 *u, const int32_t *ldu, COMPLEX64 *vh, const int32_t *ldvt,
              COMPLEX64 *work, const int32_t *lwork, double *rwork, int32_t *info)
{
    const char *jobu = "N";
    if (jobu_code == SVD_CODE_A) {
        jobu = "A";
    } else if (jobu_code == SVD_CODE_S) {
        jobu = "S";
    } else if (jobu_code == SVD_CODE_O) {
        jobu = "O";
    }

    const char *jobvt = "N";
    if (jobvt_code == SVD_CODE_A) {
        jobvt = "A";
    } else if (jobvt_code == SVD_CODE_S) {
        jobvt = "S";
    } else if (jobvt_code == SVD_CODE_O) {
        jobvt = "O";
    }

    FN_ZGESVD(jobu, jobvt, m, n, a, lda, s, u, ldu, vh, ldvt, work, lwork, rwork, info);
}
/*
 * Pass-through to LAPACK DGETRF (LU factorization with partial pivoting
 * of a real m-by-n matrix). Arguments are forwarded unchanged; the
 * status is written to *info.
 */
void c_dgetrf(const int32_t *m, const int32_t *n, double *a, const int32_t *lda,
              int32_t *ipiv, int32_t *info)
{
    FN_DGETRF(m, n, a, lda, ipiv, info);
}
/*
 * Pass-through to LAPACK DGETRI (inverse of a real matrix from the LU
 * factors and pivots produced by DGETRF). Arguments are forwarded
 * unchanged; the status is written to *info.
 */
void c_dgetri(const int32_t *n, double *a, const int32_t *lda, const int32_t *ipiv,
              double *work, const int32_t *lwork, int32_t *info)
{
    FN_DGETRI(n, a, lda, ipiv, work, lwork, info);
}
/*
 * Pass-through to LAPACK ZGETRF (LU factorization with partial pivoting
 * of a complex m-by-n matrix). Arguments are forwarded unchanged; the
 * status is written to *info.
 */
void c_zgetrf(const int32_t *m, const int32_t *n, COMPLEX64 *a, const int32_t *lda,
              int32_t *ipiv, int32_t *info)
{
    FN_ZGETRF(m, n, a, lda, ipiv, info);
}
/*
 * Pass-through to LAPACK ZGETRI (inverse of a complex matrix from the LU
 * factors and pivots produced by ZGETRF). Arguments are forwarded
 * unchanged; the status is written to *info.
 */
void c_zgetri(const int32_t *n, COMPLEX64 *a, const int32_t *lda, const int32_t *ipiv,
              COMPLEX64 *work, const int32_t *lwork, int32_t *info)
{
    FN_ZGETRI(n, a, lda, ipiv, work, lwork, info);
}