#include "../blasinfra.h"
// Activity markers for __enzyme_autodiff. Each is passed immediately before an
// argument to tell Enzyme how to treat it, as in the calls below:
//   enzyme_dup   - precedes a primal pointer followed by its shadow (gradient)
//   enzyme_out   - precedes an active scalar passed by value
//   enzyme_const - precedes an argument that is not differentiated
int enzyme_dup;
int enzyme_out;
int enzyme_const;
// Declared but never defined in this file; presumably the Enzyme compiler
// plugin replaces each call with generated reverse-mode code for the function
// passed as the first argument.
template <typename... T> void __enzyme_autodiff(void *, T...);
// Differentiation target for scal2Tests: issues cublasDscal_v2 to scale the
// N-element vector X (stride incx) by alpha, with alpha passed by address.
// Setting inDerivative afterwards presumably makes the infra tag every BLAS
// call recorded later as part of the derivative pass. The test drivers set the
// same flag by hand when replaying the expected trace.
// Keep this body minimal: Enzyme differentiates it, so edits here can change
// the call trace the tests compare against.
void my_dscal_v2(cublasHandle_t *handle, int N, double alpha,
double *__restrict__ X, int incx) {
cublasDscal_v2(handle, N, &alpha, X, incx);
inDerivative = true;
}
// Differentiation target for gemvTests: forwards to cublasDgemv
// (Y = alpha * op(A) * X + beta * Y), passing the scalars by address.
// __restrict__ on A, X and Y promises that the buffers do not alias.
// inDerivative is flipped after the primal call so that later recorded calls
// count as derivative calls (see my_dscal_v2).
void my_dgemv(cublasHandle_t *handle, cublasOperation_t trans, int M, int N,
double alpha, double *__restrict__ A, int lda,
double *__restrict__ X, int incx, double beta,
double *__restrict__ Y, int incy) {
cublasDgemv(handle, trans, M, N, &alpha, A, lda, X, incx, &beta, Y, incy);
inDerivative = true;
}
// Same as my_dgemv, but without __restrict__ on A, X and Y, so the buffers may
// alias. Not called anywhere in this file.
void ow_dgemv(cublasHandle_t *handle, cublasOperation_t trans, int M, int N,
double alpha, double *A, int lda, double *X, int incx,
double beta, double *Y, int incy) {
cublasDgemv(handle, trans, M, N, &alpha, A, lda, X, incx, &beta, Y, incy);
inDerivative = true;
}
// Differentiation target for dotTests. It uses the value-returning cublasDdot
// form (no result pointer) and returns the dot product of X and Y, which is
// the active return value in that test.
double my_ddot(cublasHandle_t *handle, int N, double *__restrict__ X, int incx,
double *__restrict__ Y, int incy) {
double res = cublasDdot(handle, N, X, incx, Y, incy);
inDerivative = true;
return res;
}
// Differentiation target for dot2Tests. It uses the cublasDdot_v2 form, which
// writes the result through a pointer to the local `res`. dot2Tests reads that
// address back from the recorded trace (the first call's pout_arg1).
double my_ddot2(cublasHandle_t *handle, int N, double *__restrict__ X, int incx,
double *__restrict__ Y, int incy) {
double res = 0.0;
cublasDdot_v2(handle, N, X, incx, Y, incy, &res);
inDerivative = true;
return res;
}
// Differentiation target for gemmTests: forwards to cublasDgemm
// (C = alpha * op(A) * op(B) + beta * C), passing the scalars by address.
// gemmTests also calls it directly to record the expected derivative GEMMs.
// In that case inDerivative is already true, so the assignment changes nothing.
void my_dgemm(cublasHandle_t *handle, cublasOperation_t transA,
cublasOperation_t transB, int M, int N, int K, double alpha,
double *__restrict__ A, int lda, double *__restrict__ B, int ldb,
double beta, double *__restrict__ C, int ldc) {
cublasDgemm(handle, transA, transB, M, N, K, &alpha, A, lda, B, ldb, &beta, C,
ldc);
inDerivative = true;
}
// Reverse-mode test of my_dscal_v2 with both alpha (enzyme_out) and X = A
// (enzyme_dup) active.
// Flow: check the primal trace, capture Enzyme's gradient trace in
// foundCalls, replay the expected gradient by hand into calls, then compare
// the two with checkTest.
static void scal2Tests() {
std::string Test = "SCAL2 active both ";
cublasHandle_t *handle = DEFAULT_CUBLAS_HANDLE;
// Buffers checkMemoryTrace is told about. Slot 3 is filled in below once the
// d(alpha) buffer address is known.
BlasInfo inputs[6] = {
BlasInfo(A, N, incA),
BlasInfo(),
BlasInfo(),
BlasInfo(),
BlasInfo(),
BlasInfo(),
};
init();
// Local alpha; presumably this shadows the infra's global alpha used by the
// gemv/gemm tests.
double alpha = 3.14;
my_dscal_v2(handle, N, alpha, A, incA);
checkMemoryTrace(inputs, "Primal " + Test, calls);
init();
__enzyme_autodiff((void *)my_dscal_v2, enzyme_const, handle, enzyme_const, N,
enzyme_out, alpha, enzyme_dup, A, dA, enzyme_const, incA);
foundCalls = calls;
init();
// Expected trace: the primal scal, then the reverse pass.
my_dscal_v2(handle, N, alpha, A, incA);
inDerivative = true;
// Enzyme chooses the buffer that accumulates d(alpha). Take its address from
// the second call in Enzyme's trace so the expected dot writes to the same
// place.
double *dalpha = (double *)foundCalls[1].pout_arg1;
inputs[3] = BlasInfo(dalpha, 1, 1);
// d(alpha) = <A, dA>, computed before dA is rescaled below.
// NOTE(review): A here is the vector the replayed primal scal was applied to.
// If the infra scales A in place, the exact adjoint needs the pre-scal A.
// Confirm this trace is intentionally pinned to Enzyme's current output.
cublasDdot_v2(handle, N, A, incA, dA, incA, dalpha);
// dA_in = alpha * dA_out.
cublasDscal_v2(handle, N, &alpha, dA, incA);
checkTest(Test);
checkMemoryTrace(inputs, "Expected " + Test, calls);
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
}
// Reverse-mode test of my_ddot (value-returning dot) with both X = A and
// Y = B active. The returned value is the active output.
// Flow: check the primal trace, capture Enzyme's gradient trace in
// foundCalls, replay the expected gradient by hand into calls, then compare.
static void dotTests() {
std::string Test = "DOT active both ";
cublasHandle_t *handle = DEFAULT_CUBLAS_HANDLE;
// Buffers checkMemoryTrace is told about. C is listed but not used by this test.
BlasInfo inputs[6] = {
BlasInfo(A, N, incA),
BlasInfo(B, N, incB),
BlasInfo(C, M, incC),
BlasInfo(),
BlasInfo(),
BlasInfo(),
};
init();
my_ddot(handle, N, A, incA, B, incB);
checkMemoryTrace(inputs, "Primal " + Test, calls);
init();
__enzyme_autodiff((void *)my_ddot, enzyme_const, handle, enzyme_const, N,
enzyme_dup, A, dA, enzyme_const, incA, enzyme_dup, B, dB,
enzyme_const, incB);
foundCalls = calls;
init();
// Expected trace: the primal dot, then the reverse pass.
my_ddot(handle, N, A, incA, B, incB);
inDerivative = true;
// dA += 1.0 * B and dB += 1.0 * A, where 1.0 is the seed for the returned
// value's adjoint.
cublasDaxpy(handle, N, 1.0, B, incB, dA, incA);
cublasDaxpy(handle, N, 1.0, A, incA, dB, incB);
checkTest(Test);
checkMemoryTrace(inputs, "Expected " + Test, calls);
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
}
// Reverse-mode test of my_ddot2 (cublasDdot_v2, result written through a
// pointer) with both X = A and Y = B active.
// The result and its adjoint live in scalar buffers whose addresses differ
// between runs. They are therefore read back from the recorded traces and
// registered with checkMemoryTrace, and the expected trace is patched to
// match.
static void dot2Tests() {
std::string Test = "DOTv2 active both ";
cublasHandle_t *handle = DEFAULT_CUBLAS_HANDLE;
BlasInfo inputs[6] = {
BlasInfo(A, N, incA),
BlasInfo(B, N, incB),
BlasInfo(C, M, incC), BlasInfo(), BlasInfo(), BlasInfo(),
};
init();
my_ddot2(handle, N, A, incA, B, incB);
{
// The primal dot wrote its result to my_ddot2's local `res`. Register that
// address, taken from the trace, as a 1-element buffer.
auto primal_stack_ret = (double *)calls[0].pout_arg1;
inputs[3] = BlasInfo(primal_stack_ret, 1, 1);
}
checkMemoryTrace(inputs, "Primal " + Test, calls);
init();
__enzyme_autodiff((void *)my_ddot2, enzyme_const, handle, enzyme_const, N,
enzyme_dup, A, dA, enzyme_const, incA, enzyme_dup, B, dB,
enzyme_const, incB);
{
// Same as above, for the result buffer used in the differentiated run.
auto primal_stack_ret = (double *)calls[0].pout_arg1;
inputs[3] = BlasInfo(primal_stack_ret, 1, 1);
}
foundCalls = calls;
// Scalar read from the second call in Enzyme's trace. The expected trace
// uses it as the alpha pointer of both axpy calls, so it presumably holds
// the adjoint of the returned value.
auto stack_ret = (double*)foundCalls[1].pin_arg2;
inputs[4] = BlasInfo(stack_ret, 1, 1);
init();
my_ddot2(handle, N, A, incA, B, incB);
// The replayed primal used a fresh stack slot. Point it at the address from
// Enzyme's run so the two traces can match.
calls[0].pout_arg1 = (double*)foundCalls[0].pout_arg1;
inDerivative = true;
// dA += d(res) * B and dB += d(res) * A.
cublasDaxpy_v2(handle, N, stack_ret, B, incB, dA, incA);
cublasDaxpy_v2(handle, N, stack_ret, A, incA, dB, incB);
// Zero the adjoint scalar after it has been consumed.
cudaMemset(stack_ret, 0, sizeof(double));
checkTest(Test);
checkMemoryTrace(inputs, "Expected " + Test, calls);
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
}
/// Reverse-mode tests of my_dgemv (y = alpha * op(A) * x + beta * y) for
/// transA in {N, T}. Here x is the buffer B and y is the buffer C.
///  - "GEMV active A, C ...":    A and C active, B constant.
///  - "GEMV active A, B, C ...": A, B and C all active.
/// Each case captures Enzyme's gradient trace in foundCalls, replays the
/// expected trace by hand into calls, and compares the two with checkTest.
static void gemvTests() {
  // Loop-invariant, so look it up once (as gemmTests does).
  auto handle = DEFAULT_CUBLAS_HANDLE;
  for (cublasOperation_t transA :
       {cublasOperation_t::CUBLAS_OP_N, cublasOperation_t::CUBLAS_OP_T}) {
    bool trans = !is_normal(transA);
    // Tag every label with the transpose mode so a failure can be attributed
    // to a specific loop iteration. Previously both iterations printed the
    // same label.
    std::string TransTag = std::string("transA=") + (trans ? "T" : "N") + " ";
    std::string Test = "GEMV active A, C " + TransTag;
    // A is stored M x N. The vector B has length N (M when transposed) and the
    // vector C has length M (N when transposed). Other slots stay empty.
    BlasInfo inputs[6] = {BlasInfo(A, CUBLAS_LAYOUT, M, N, lda),
                          BlasInfo(B, trans ? M : N, incB),
                          BlasInfo(C, trans ? N : M, incC),
                          BlasInfo(),
                          BlasInfo(),
                          BlasInfo()};

    // Primal: exactly one non-derivative GEMV carrying the arguments above.
    init();
    my_dgemv(handle, transA, M, N, alpha, A, lda, B, incB, beta, C, incC);
    assert(calls.size() == 1);
    assert(calls[0].inDerivative == false);
    assert(calls[0].type == CallType::GEMV);
    assert(calls[0].pout_arg1 == C);
    assert(calls[0].pin_arg1 == A);
    assert(calls[0].pin_arg2 == B);
    assert(calls[0].farg1 == alpha);
    assert(calls[0].farg2 == beta);
    assert(calls[0].handle == DEFAULT_CUBLAS_HANDLE);
    assert(calls[0].targ1 == (char)transA);
    assert(calls[0].targ2 == (char)UNUSED_TRANS);
    assert(calls[0].iarg1 == M);
    assert(calls[0].iarg2 == N);
    assert(calls[0].iarg3 == UNUSED_INT);
    assert(calls[0].iarg4 == lda);
    assert(calls[0].iarg5 == incB);
    assert(calls[0].iarg6 == incC);
    checkMemoryTrace(inputs, "Primal " + Test, calls);

    // Case 1: A and C active, B constant.
    init();
    __enzyme_autodiff((void *)my_dgemv, enzyme_const, handle, enzyme_const,
                      transA, enzyme_const, M, enzyme_const, N, enzyme_const,
                      alpha, enzyme_dup, A, dA, enzyme_const, lda,
                      enzyme_const, B, enzyme_const, incB, enzyme_const, beta,
                      enzyme_dup, C, dC, enzyme_const, incC);
    foundCalls = calls;
    init();
    my_dgemv(handle, transA, M, N, alpha, A, lda, B, incB, beta, C, incC);
    inDerivative = true;
    // dA += alpha * dC * B^T (no transpose) or alpha * B * dC^T (transpose).
    cublasDger(handle, M, N, &alpha, trans ? B : dC, trans ? incB : incC,
               trans ? dC : B, trans ? incC : incB, dA, lda);
    // dC *= beta.
    cublasDscal(handle, trans ? N : M, &beta, dC, incC);
    checkTest(Test);
    checkMemoryTrace(inputs, "Expected " + Test, calls);
    checkMemoryTrace(inputs, "Found " + Test, foundCalls);

    // Case 2: A, B and C all active.
    Test = "GEMV active A, B, C " + TransTag;
    init();
    __enzyme_autodiff((void *)my_dgemv, enzyme_const, handle, enzyme_const,
                      transA, enzyme_const, M, enzyme_const, N, enzyme_const,
                      alpha, enzyme_dup, A, dA, enzyme_const, lda, enzyme_dup,
                      B, dB, enzyme_const, incB, enzyme_const, beta,
                      enzyme_dup, C, dC, enzyme_const, incC);
    foundCalls = calls;
    init();
    my_dgemv(handle, transA, M, N, alpha, A, lda, B, incB, beta, C, incC);
    inDerivative = true;
    // dA update, exactly as in case 1.
    cublasDger(handle, M, N, &alpha, trans ? B : dC, trans ? incB : incC,
               trans ? dC : B, trans ? incC : incB, dA, lda);
    // dB += alpha * op(A)^T * dC (beta = 1 accumulates into dB).
    double c1 = 1.0;
    cublasDgemv(handle, transpose(transA), M, N, &alpha, A, lda, dC, incC,
                &c1, dB, incB);
    // dC *= beta.
    cublasDscal(handle, trans ? N : M, &beta, dC, incC);
    checkTest(Test);
    checkMemoryTrace(inputs, "Expected " + Test, calls);
    checkMemoryTrace(inputs, "Found " + Test, foundCalls);
  }
}
/// Reverse-mode tests of my_dgemm (C = alpha * op(A) * op(B) + beta * C) for
/// all four (transA, transB) combinations, with A, B and C all active.
/// op(A) is M x K, op(B) is K x N and C is M x N. For each combination:
///  1. run the primal and assert the single recorded GEMM matches its inputs;
///  2. capture Enzyme's gradient trace in foundCalls;
///  3. replay the expected trace (primal GEMM, dA update, dB update,
///     dC *= beta) into calls and compare the two with checkTest.
static void gemmTests() {
  auto handle = DEFAULT_CUBLAS_HANDLE;
  for (auto transA :
       {cublasOperation_t::CUBLAS_OP_N, cublasOperation_t::CUBLAS_OP_T}) {
    for (auto transB :
         {cublasOperation_t::CUBLAS_OP_N, cublasOperation_t::CUBLAS_OP_T}) {
      bool transA_bool = !is_normal(transA);
      bool transB_bool = !is_normal(transB);
      // The label used to be the bare "GEMM" for all four iterations. It now
      // follows the "<OP> active ..." convention of the other tests and names
      // the transpose combination, so a failure identifies its iteration.
      std::string Test = std::string("GEMM active A, B, C transA=") +
                         (transA_bool ? "T" : "N") + " transB=" +
                         (transB_bool ? "T" : "N") + " ";
      // Stored shapes: A is K x M when transposed (else M x K), B is N x K
      // when transposed (else K x N), and C is M x N. incB and incC serve as
      // the leading dimensions of B and C throughout.
      BlasInfo inputs[6] = {
          BlasInfo(A, CUBLAS_LAYOUT, transA_bool ? K : M, transA_bool ? M : K,
                   lda),
          BlasInfo(B, CUBLAS_LAYOUT, transB_bool ? N : K, transB_bool ? K : N,
                   incB),
          BlasInfo(C, CUBLAS_LAYOUT, M, N, incC),
          BlasInfo(),
          BlasInfo(),
          BlasInfo()};

      // Primal: exactly one non-derivative GEMM carrying the arguments above.
      init();
      my_dgemm(handle, transA, transB, M, N, K, alpha, A, lda, B, incB, beta,
               C, incC);
      assert(calls.size() == 1);
      assert(calls[0].inDerivative == false);
      assert(calls[0].type == CallType::GEMM);
      assert(calls[0].pout_arg1 == C);
      assert(calls[0].pin_arg1 == A);
      assert(calls[0].pin_arg2 == B);
      assert(calls[0].farg1 == alpha);
      assert(calls[0].farg2 == beta);
      assert(calls[0].handle == DEFAULT_CUBLAS_HANDLE);
      assert(calls[0].targ1 == (char)transA);
      assert(calls[0].targ2 == (char)transB);
      assert(calls[0].iarg1 == M);
      assert(calls[0].iarg2 == N);
      assert(calls[0].iarg3 == K);
      assert(calls[0].iarg4 == lda);
      assert(calls[0].iarg5 == incB);
      assert(calls[0].iarg6 == incC);
      checkMemoryTrace(inputs, "Primal " + Test, calls);

      // Gradient trace produced by Enzyme.
      init();
      __enzyme_autodiff((void *)my_dgemm, enzyme_const, handle, enzyme_const,
                        transA, enzyme_const, transB, enzyme_const, M,
                        enzyme_const, N, enzyme_const, K, enzyme_const, alpha,
                        enzyme_dup, A, dA, enzyme_const, lda, enzyme_dup, B,
                        dB, enzyme_const, incB, enzyme_const, beta,
                        enzyme_dup, C, dC, enzyme_const, incC);
      foundCalls = calls;

      // Expected trace.
      init();
      my_dgemm(handle, transA, transB, M, N, K, alpha, A, lda, B, incB, beta,
               C, incC);
      inDerivative = true;
      // dA (M x K) += alpha * dC * op(B)^T            when A is not transposed;
      // dA (K x M) += alpha * op(B) * dC^T            when A is transposed.
      my_dgemm(handle, transA_bool ? transB : cublasOperation_t::CUBLAS_OP_N,
               transA_bool ? cublasOperation_t::CUBLAS_OP_T
                           : transpose(transB),
               transA_bool ? K : M, transA_bool ? M : K, N, alpha,
               transA_bool ? B : dC, transA_bool ? incB : incC,
               transA_bool ? dC : B, transA_bool ? incC : incB, 1.0, dA, lda);
      // dB (K x N) += alpha * op(A)^T * dC            when B is not transposed;
      // dB (N x K) += alpha * dC^T * op(A)            when B is transposed.
      my_dgemm(
          handle,
          transB_bool ? cublasOperation_t::CUBLAS_OP_T : transpose(transA),
          transB_bool ? transA : cublasOperation_t::CUBLAS_OP_N,
          transB_bool ? N : K, transB_bool ? K : N, M, alpha,
          transB_bool ? dC : A, transB_bool ? incC : lda,
          transB_bool ? A : dC, transB_bool ? lda : incC, 1.0, dB, incB);
      // dC *= beta, expressed as a lascl from 1.0 to beta over the M x N dC.
      // The type code 2 presumably selects a full/general matrix in the
      // infra; TODO confirm against blasinfra.h.
      double c10 = 1.0;
      cublasDlascl(handle, (cublasOperation_t)2, 0, 0, &c10, &beta, M, N,
                   dC, incC, 0);
      checkTest(Test);
      checkMemoryTrace(inputs, "Expected " + Test, calls);
      checkMemoryTrace(inputs, "Found " + Test, foundCalls);
    }
  }
}
// Test entry point: runs every suite in a fixed order. Each suite reports
// failures itself, through asserts and the infra's check helpers.
int main() {
  using TestSuite = void (*)();
  const TestSuite suites[] = {gemmTests, gemvTests, dotTests, dot2Tests,
                              scal2Tests};
  for (TestSuite suite : suites)
    suite();
  return 0;
}