#include <complex>
#include <iostream>
#include "blis.hh"
#include "test.hh"
using namespace blis;
using namespace std;
#define M 5
#define N 6
template< typename T >
void ref_trsv(int64_t n,
T *A,
T *X
)
{
obj_t obj_a, obj_x;
num_t dt;
if(is_same<T, float>::value)
dt = BLIS_FLOAT;
else if(is_same<T, double>::value)
dt = BLIS_DOUBLE;
else if(is_same<T, complex<float>>::value)
dt = BLIS_SCOMPLEX;
else if(is_same<T, complex<double>>::value)
dt = BLIS_DCOMPLEX;
bli_obj_create_with_attached_buffer( dt, n, n, A, 1,n,&obj_a );
bli_obj_create_with_attached_buffer( dt, n, 1, X, 1,n,&obj_x );
bli_obj_set_struc( BLIS_TRIANGULAR, &obj_a );
bli_obj_set_uplo( BLIS_LOWER, &obj_a );
bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &obj_a );
bli_obj_set_diag( BLIS_NONUNIT_DIAG, &obj_a );
bli_trsv( &BLIS_ONE,
&obj_a,
&obj_x
);
}
/**
 * Runs one trsv comparison for element type T.
 *
 * Builds a random N x N matrix and a random right-hand-side vector, then
 * solves the system twice:
 *   - with blis::trsv (CBLAS-style arguments: column-major, lower,
 *     no-transpose, non-unit diagonal);
 *   - with ref_trsv, on a copy of the same right-hand side.
 * The two results are compared with computeErrorV, and a
 * "TEST PASS" / "TEST FAIL" line tagged with this function's signature
 * is printed.
 */
template< typename T >
void test_trsv( )
{
    const int n        = N;
    const int lda      = n;   // A is stored densely, so the leading dimension equals n
    const int incx     = 1;
    const int incx_ref = 1;

    srand (time(NULL));

    // Random inputs. vec_x_ref starts as a copy of vec_x so both solvers
    // see the same right-hand side.
    T *mat_a     = nullptr;
    T *vec_x     = nullptr;
    T *vec_x_ref = nullptr;
    allocate_init_buffer(mat_a, n, n);
    allocate_init_buffer(vec_x, n, 1);
    copy_buffer(vec_x, vec_x_ref, n, 1);

#ifdef PRINT
    printmatrix(mat_a, lda, n, n, (char *) "A");
    printvector(vec_x, n, (char *) "X");
#endif

    // Routine under test: vec_x is overwritten with the solution.
    blis::trsv(CblasColMajor, CblasLower, CblasNoTrans, CblasNonUnit,
               n, mat_a, lda, vec_x, incx);

#ifdef PRINT
    printvector(vec_x, n, (char *) "X output");
#endif

    // Reference solution computed in place on the copy.
    ref_trsv(n, mat_a, vec_x_ref);

#ifdef PRINT
    printvector(vec_x_ref, n, (char *) "X ref output");
#endif

    const bool failed = (computeErrorV(incx, incx_ref, n, vec_x, vec_x_ref) == 1);
    if (!failed)
        printf("%s TEST PASS\n" , __PRETTY_FUNCTION__);
    else
        printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__);

    delete[] mat_a;
    delete[] vec_x;
    delete[] vec_x_ref;
}
// Test entry point. It runs the trsv comparison once per supported element
// type, in this order: double, float, complex<float>, complex<double>.
// Each run prints its own PASS/FAIL line. The process always exits with 0.
int main( int argc, char** argv )
{
    (void) argc;   // command-line arguments are not used
    (void) argv;

    test_trsv< double >();
    test_trsv< float >();
    test_trsv< complex<float> >();
    test_trsv< complex<double> >();

    return 0;
}