#include <complex>
#include <iostream>
#include "blis.hh"
#include "test.hh"
using namespace blis;
using namespace std;
#define ALPHA 1.0
#define M 5
#define N 4
template< typename T >
void ref_trsm(int64_t m, int64_t n,
T * alpha,
T *A,
T *B
)
{
obj_t obj_a, obj_b;
obj_t obj_alpha;
num_t dt;
if(is_same<T, float>::value)
dt = BLIS_FLOAT;
else if(is_same<T, double>::value)
dt = BLIS_DOUBLE;
else if(is_same<T, complex<float>>::value)
dt = BLIS_SCOMPLEX;
else if(is_same<T, complex<double>>::value)
dt = BLIS_DCOMPLEX;
bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha );
bli_obj_create_with_attached_buffer( dt, m, m, A, 1,m,&obj_a );
bli_obj_create_with_attached_buffer( dt, m, n, B, 1,m,&obj_b );
bli_obj_set_struc( BLIS_TRIANGULAR, &obj_a );
bli_obj_set_uplo( BLIS_LOWER, &obj_a );
bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_a );
bli_obj_set_diag( BLIS_NONUNIT_DIAG, &obj_a );
bli_trsm( BLIS_LEFT,
&obj_alpha,
&obj_a,
&obj_b
);
}
/**
 * Compare blis::trsm (CBLAS-style wrapper) against the object-API reference
 * ref_trsm for element type T.
 *
 * Both solve a left-side, lower-triangular, non-unit, non-transposed system
 * on randomly initialized column-major buffers with leading dimension m.
 * The comparison result is printed as "<function> TEST PASS/FAIL".
 */
template< typename T >
void test_trsm( )
{
    const int m = M;
    const int n = N;
    const int lda = m;
    const int ldb = m;
    const int ldb_ref = m;
    T alpha = ALPHA;

    T *A = nullptr;
    T *B = nullptr;
    T *B_ref = nullptr;

    // Seed the generator, then build the inputs. A is filled first, then B.
    // B_ref starts as a copy of B so both solvers see the same right-hand side.
    srand( static_cast<unsigned>( time( nullptr ) ) );
    allocate_init_buffer( A, m, m );
    allocate_init_buffer( B, m, n );
    copy_buffer( B, B_ref, m, n );

#ifdef PRINT
    printmatrix( A, lda, m, m, (char *)"A" );
    printmatrix( B, ldb, m, n, (char *)"B" );
#endif

    // Solve in place with the wrapper under test.
    blis::trsm( CblasColMajor, CblasLeft, CblasLower, CblasNoTrans, CblasNonUnit,
                m, n, alpha, A, lda, B, ldb );

#ifdef PRINT
    printmatrix( B, ldb, m, n, (char *)"B output" );
#endif

    // Solve the same system through the BLIS object API for comparison.
    ref_trsm( m, n, &alpha, A, B_ref );

#ifdef PRINT
    printmatrix( B_ref, ldb_ref, m, n, (char *)"B ref output" );
#endif

    const bool failed = ( computeErrorM( ldb, ldb_ref, m, n, B, B_ref ) == 1 );
    printf( failed ? "%s TEST FAIL\n" : "%s TEST PASS\n", __PRETTY_FUNCTION__ );

    delete[] A;
    delete[] B;
    delete[] B_ref;
}
// Entry point: run the trsm comparison once for each supported element type.
// Each test prints its own PASS/FAIL line; the exit status is always 0.
int main( int argc, char** argv )
{
    static_cast<void>( argc );   // command-line arguments are not used
    static_cast<void>( argv );

    // Real precisions first, then complex.
    test_trsm< double >( );
    test_trsm< float >( );
    test_trsm< complex<float> >( );
    test_trsm< complex<double> >( );

    return 0;
}