use super::core::{
af_array, AfError, Array, FloatingPoint, HasAfEnum, MatProp, NormType, HANDLE_ERROR,
};
use libc::{c_double, c_int, c_uint};
// Raw declarations of the C functions wrapped by the public functions in this file.
//
// Every function returns a `c_int` status code. The wrappers below convert that
// code with `AfError::from` and hand it to `HANDLE_ERROR`; the wrapper around
// `af_is_lapack_available` is the exception and discards it.
//
// Parameters typed `*mut af_array`, `*mut c_int`, `*mut c_uint`, `*mut c_double`
// or `*mut bool` are out-parameters: the wrappers pass pointers to local
// variables and read those variables back after the call.
// NOTE(review): these are presumably the ArrayFire C API LAPACK entry points —
// confirm signatures against the ArrayFire headers when changing anything here.
extern "C" {
// Three array outputs (`u`, `s`, `vt`) produced from one input array.
fn af_svd(u: *mut af_array, s: *mut af_array, vt: *mut af_array, input: af_array) -> c_int;
// Same outputs as `af_svd`; the wrapper passes a handle taken from a `&mut Array`.
fn af_svd_inplace(
u: *mut af_array,
s: *mut af_array,
vt: *mut af_array,
input: af_array,
) -> c_int;
// Three array outputs (`lower`, `upper`, `pivot`) produced from one input array.
fn af_lu(
lower: *mut af_array,
upper: *mut af_array,
pivot: *mut af_array,
input: af_array,
) -> c_int;
// Only `pivot` is returned as a new array; `is_lapack_piv` is forwarded as given.
fn af_lu_inplace(pivot: *mut af_array, input: af_array, is_lapack_piv: bool) -> c_int;
// Three array outputs (`q`, `r`, `tau`) produced from one input array.
fn af_qr(q: *mut af_array, r: *mut af_array, tau: *mut af_array, input: af_array) -> c_int;
// Only `tau` is returned as a new array.
fn af_qr_inplace(tau: *mut af_array, input: af_array) -> c_int;
// One array output plus an integer `info` value written through a pointer.
fn af_cholesky(out: *mut af_array, info: *mut c_int, input: af_array, is_upper: bool) -> c_int;
// Only the integer `info` value is written back.
fn af_cholesky_inplace(info: *mut c_int, input: af_array, is_upper: bool) -> c_int;
// `options` receives a `MatProp` value cast to `c_uint` by the wrapper.
fn af_solve(x: *mut af_array, a: af_array, b: af_array, options: c_uint) -> c_int;
// Like `af_solve` with an additional `piv` array (an `Array<i32>` in the wrapper).
fn af_solve_lu(
x: *mut af_array,
a: af_array,
piv: af_array,
b: af_array,
options: c_uint,
) -> c_int;
// `options` receives a `MatProp` value cast to `c_uint` by the wrapper.
fn af_inverse(out: *mut af_array, input: af_array, options: c_uint) -> c_int;
// Scalar result written to `rank`; `tol` is forwarded from the wrapper's `f64`.
fn af_rank(rank: *mut c_uint, input: af_array, tol: c_double) -> c_int;
// Two scalar results, returned by the wrapper as a `(real, imag)` pair.
fn af_det(det_real: *mut c_double, det_imag: *mut c_double, input: af_array) -> c_int;
// Scalar result; `ntype` receives a `NormType` value cast to `c_uint`.
fn af_norm(
out: *mut c_double,
input: af_array,
ntype: c_uint,
p: c_double,
q: c_double,
) -> c_int;
// Boolean result written to `out`; takes no array input.
fn af_is_lapack_available(out: *mut bool) -> c_int;
// `options` receives a `MatProp` value cast to `c_uint` by the wrapper.
fn af_pinverse(out: *mut af_array, input: af_array, tol: c_double, options: c_uint) -> c_int;
}
/// Calls the C routine `af_svd` on `input` and returns the three arrays it produces.
///
/// The returned tuple is `(u, s, vt)`, in the same order as the C function's
/// out-parameters. The middle element has element type `T::BaseType`.
///
/// NOTE(review): going by the names, these are presumably the factors of a
/// singular value decomposition (U, the singular values, V transposed) — confirm
/// against the ArrayFire documentation.
///
/// The status code returned by the C call is converted with `AfError::from` and
/// passed to `HANDLE_ERROR`, which decides how a failure is reported.
pub fn svd<T>(input: &Array<T>) -> (Array<T>, Array<T::BaseType>, Array<T>)
where
    T: HasAfEnum + FloatingPoint,
    T::BaseType: HasAfEnum,
{
    // Output handles start out null; the C call writes the real handles into them.
    let mut u_handle: af_array = std::ptr::null_mut();
    let mut s_handle: af_array = std::ptr::null_mut();
    let mut vt_handle: af_array = std::ptr::null_mut();
    // SAFETY: the out-pointers refer to locals that outlive the call; the input
    // handle comes straight from `input`, which is borrowed for the whole call.
    unsafe {
        let status = af_svd(&mut u_handle, &mut s_handle, &mut vt_handle, input.get());
        HANDLE_ERROR(AfError::from(status));
    }
    (u_handle.into(), s_handle.into(), vt_handle.into())
}
/// Calls the C routine `af_svd_inplace` on `input` and returns the three arrays
/// it produces as `(u, s, vt)`.
///
/// `input` is borrowed mutably.
/// NOTE(review): presumably the C routine uses `input` as scratch space and its
/// contents should not be relied on afterwards — confirm against the ArrayFire
/// documentation.
///
/// The status code returned by the C call is converted with `AfError::from` and
/// passed to `HANDLE_ERROR`.
pub fn svd_inplace<T>(input: &mut Array<T>) -> (Array<T>, Array<T::BaseType>, Array<T>)
where
    T: HasAfEnum + FloatingPoint,
    T::BaseType: HasAfEnum,
{
    // Output handles start out null; the C call writes the real handles into them.
    let mut u_handle: af_array = std::ptr::null_mut();
    let mut s_handle: af_array = std::ptr::null_mut();
    let mut vt_handle: af_array = std::ptr::null_mut();
    // SAFETY: the out-pointers refer to locals that outlive the call; the input
    // handle comes from `input`, which is exclusively borrowed for the whole call.
    unsafe {
        let status = af_svd_inplace(&mut u_handle, &mut s_handle, &mut vt_handle, input.get());
        HANDLE_ERROR(AfError::from(status));
    }
    (u_handle.into(), s_handle.into(), vt_handle.into())
}
/// Calls the C routine `af_lu` on `input` and returns `(lower, upper, pivot)`.
///
/// The first two arrays share the element type of `input`; the pivot array is
/// returned as `Array<i32>`.
///
/// NOTE(review): by name this is presumably an LU decomposition with the
/// triangular factors and the pivot indices — confirm against the ArrayFire
/// documentation.
///
/// The status code returned by the C call is converted with `AfError::from` and
/// passed to `HANDLE_ERROR`.
pub fn lu<T>(input: &Array<T>) -> (Array<T>, Array<T>, Array<i32>)
where
    T: HasAfEnum + FloatingPoint,
{
    // Output handles start out null; the C call writes the real handles into them.
    let mut lower_handle: af_array = std::ptr::null_mut();
    let mut upper_handle: af_array = std::ptr::null_mut();
    let mut pivot_handle: af_array = std::ptr::null_mut();
    // SAFETY: the out-pointers refer to locals that outlive the call; the input
    // handle comes straight from `input`, which is borrowed for the whole call.
    unsafe {
        let status = af_lu(
            &mut lower_handle,
            &mut upper_handle,
            &mut pivot_handle,
            input.get(),
        );
        HANDLE_ERROR(AfError::from(status));
    }
    (lower_handle.into(), upper_handle.into(), pivot_handle.into())
}
/// Calls the C routine `af_lu_inplace` on `input` and returns the pivot array.
///
/// `is_lapack_piv` is forwarded to the C function unchanged.
/// NOTE(review): presumably `input` is overwritten with the packed factorization
/// and the flag selects a LAPACK-compatible pivot layout — confirm against the
/// ArrayFire documentation.
///
/// The status code returned by the C call is converted with `AfError::from` and
/// passed to `HANDLE_ERROR`.
pub fn lu_inplace<T>(input: &mut Array<T>, is_lapack_piv: bool) -> Array<i32>
where
    T: HasAfEnum + FloatingPoint,
{
    // Starts out null; the C call writes the pivot array's handle into it.
    let mut pivot_handle: af_array = std::ptr::null_mut();
    // SAFETY: the out-pointer refers to a local that outlives the call; the input
    // handle comes from `input`, which is exclusively borrowed for the whole call.
    unsafe {
        let status = af_lu_inplace(&mut pivot_handle, input.get(), is_lapack_piv);
        HANDLE_ERROR(AfError::from(status));
    }
    pivot_handle.into()
}
/// Calls the C routine `af_qr` on `input` and returns `(q, r, tau)`.
///
/// All three arrays share the element type of `input`.
///
/// NOTE(review): by name this is presumably a QR decomposition, with `tau`
/// carrying auxiliary data from the factorization — confirm against the
/// ArrayFire documentation.
///
/// The status code returned by the C call is converted with `AfError::from` and
/// passed to `HANDLE_ERROR`.
pub fn qr<T>(input: &Array<T>) -> (Array<T>, Array<T>, Array<T>)
where
    T: HasAfEnum + FloatingPoint,
{
    // Output handles start out null; the C call writes the real handles into them.
    let mut q_handle: af_array = std::ptr::null_mut();
    let mut r_handle: af_array = std::ptr::null_mut();
    let mut tau_handle: af_array = std::ptr::null_mut();
    // SAFETY: the out-pointers refer to locals that outlive the call; the input
    // handle comes straight from `input`, which is borrowed for the whole call.
    unsafe {
        let status = af_qr(&mut q_handle, &mut r_handle, &mut tau_handle, input.get());
        HANDLE_ERROR(AfError::from(status));
    }
    (q_handle.into(), r_handle.into(), tau_handle.into())
}
/// Calls the C routine `af_qr_inplace` on `input` and returns the `tau` array.
///
/// NOTE(review): presumably `input` is overwritten with the packed QR
/// factorization — confirm against the ArrayFire documentation.
///
/// The status code returned by the C call is converted with `AfError::from` and
/// passed to `HANDLE_ERROR`.
pub fn qr_inplace<T>(input: &mut Array<T>) -> Array<T>
where
    T: HasAfEnum + FloatingPoint,
{
    // Starts out null; the C call writes the `tau` array's handle into it.
    let mut tau_handle: af_array = std::ptr::null_mut();
    // SAFETY: the out-pointer refers to a local that outlives the call; the input
    // handle comes from `input`, which is exclusively borrowed for the whole call.
    unsafe {
        let status = af_qr_inplace(&mut tau_handle, input.get());
        HANDLE_ERROR(AfError::from(status));
    }
    tau_handle.into()
}
/// Calls the C routine `af_cholesky` on `input` and returns the resulting array
/// together with the integer `info` value the C function reports.
///
/// `is_upper` is forwarded to the C function unchanged. `info` is initialised to
/// `0` before the call and returned as whatever the C function left in it.
///
/// NOTE(review): presumably `is_upper` selects an upper- versus lower-triangular
/// factor and a non-zero `info` signals where the factorization failed — confirm
/// against the ArrayFire documentation.
///
/// The status code returned by the C call is converted with `AfError::from` and
/// passed to `HANDLE_ERROR`.
pub fn cholesky<T>(input: &Array<T>, is_upper: bool) -> (Array<T>, i32)
where
    T: HasAfEnum + FloatingPoint,
{
    // Starts out null; the C call writes the factor's handle into it.
    let mut factor_handle: af_array = std::ptr::null_mut();
    let mut info = 0_i32;
    // SAFETY: both out-pointers refer to locals that outlive the call; the input
    // handle comes straight from `input`, which is borrowed for the whole call.
    unsafe {
        let status = af_cholesky(&mut factor_handle, &mut info, input.get(), is_upper);
        HANDLE_ERROR(AfError::from(status));
    }
    (factor_handle.into(), info)
}
/// Calls the C routine `af_cholesky_inplace` on `input` and returns the integer
/// `info` value the C function reports.
///
/// `is_upper` is forwarded unchanged. `info` is initialised to `0` before the
/// call.
///
/// NOTE(review): presumably `input` is overwritten with the triangular factor
/// and a non-zero `info` signals a failed factorization — confirm against the
/// ArrayFire documentation.
///
/// The status code returned by the C call is converted with `AfError::from` and
/// passed to `HANDLE_ERROR`.
pub fn cholesky_inplace<T>(input: &mut Array<T>, is_upper: bool) -> i32
where
    T: HasAfEnum + FloatingPoint,
{
    // SAFETY: the out-pointer refers to a local that outlives the call; the input
    // handle comes from `input`, which is exclusively borrowed for the whole call.
    unsafe {
        let mut info = 0_i32;
        let status = af_cholesky_inplace(&mut info, input.get(), is_upper);
        HANDLE_ERROR(AfError::from(status));
        info
    }
}
/// Calls the C routine `af_solve` with `a`, `b` and `options`, returning the
/// array it produces.
///
/// `options` is cast to `c_uint` and forwarded as-is.
///
/// NOTE(review): presumably this solves the linear system `a * x = b` for `x`,
/// with `options` describing structure of `a` — confirm against the ArrayFire
/// documentation.
///
/// The status code returned by the C call is converted with `AfError::from` and
/// passed to `HANDLE_ERROR`.
pub fn solve<T>(a: &Array<T>, b: &Array<T>, options: MatProp) -> Array<T>
where
    T: HasAfEnum + FloatingPoint,
{
    // Starts out null; the C call writes the solution's handle into it.
    let mut x_handle: af_array = std::ptr::null_mut();
    // SAFETY: the out-pointer refers to a local that outlives the call; both input
    // handles come from arrays that are borrowed for the whole call.
    unsafe {
        let status = af_solve(&mut x_handle, a.get(), b.get(), options as c_uint);
        HANDLE_ERROR(AfError::from(status));
    }
    x_handle.into()
}
/// Calls the C routine `af_solve_lu` with `a`, `piv`, `b` and `options`,
/// returning the array it produces.
///
/// `options` is cast to `c_uint` and forwarded as-is.
///
/// NOTE(review): presumably `a` and `piv` are the packed factorization and pivot
/// array obtained from `lu_inplace` — confirm against the ArrayFire
/// documentation.
///
/// The status code returned by the C call is converted with `AfError::from` and
/// passed to `HANDLE_ERROR`.
pub fn solve_lu<T>(a: &Array<T>, piv: &Array<i32>, b: &Array<T>, options: MatProp) -> Array<T>
where
    T: HasAfEnum + FloatingPoint,
{
    // Starts out null; the C call writes the solution's handle into it.
    let mut x_handle: af_array = std::ptr::null_mut();
    // SAFETY: the out-pointer refers to a local that outlives the call; all input
    // handles come from arrays that are borrowed for the whole call.
    unsafe {
        let status = af_solve_lu(
            &mut x_handle,
            a.get(),
            piv.get(),
            b.get(),
            options as c_uint,
        );
        HANDLE_ERROR(AfError::from(status));
    }
    x_handle.into()
}
/// Calls the C routine `af_inverse` on `input` and returns the array it produces.
///
/// `options` is cast to `c_uint` and forwarded as-is.
///
/// NOTE(review): by name this presumably computes the matrix inverse — confirm
/// which `MatProp` values the C function accepts in the ArrayFire documentation.
///
/// The status code returned by the C call is converted with `AfError::from` and
/// passed to `HANDLE_ERROR`.
pub fn inverse<T>(input: &Array<T>, options: MatProp) -> Array<T>
where
    T: HasAfEnum + FloatingPoint,
{
    // Starts out null; the C call writes the result's handle into it.
    let mut inv_handle: af_array = std::ptr::null_mut();
    // SAFETY: the out-pointer refers to a local that outlives the call; the input
    // handle comes straight from `input`, which is borrowed for the whole call.
    unsafe {
        let status = af_inverse(&mut inv_handle, input.get(), options as c_uint);
        HANDLE_ERROR(AfError::from(status));
    }
    inv_handle.into()
}
/// Calls the C routine `af_rank` on `input` and returns the unsigned integer it
/// reports.
///
/// `tol` is forwarded unchanged. The result is initialised to `0` before the
/// call.
///
/// NOTE(review): presumably this is the matrix rank, with `tol` the threshold
/// below which values count as zero — confirm against the ArrayFire
/// documentation.
///
/// The status code returned by the C call is converted with `AfError::from` and
/// passed to `HANDLE_ERROR`.
pub fn rank<T>(input: &Array<T>, tol: f64) -> u32
where
    T: HasAfEnum + FloatingPoint,
{
    // SAFETY: the out-pointer refers to a local that outlives the call; the input
    // handle comes straight from `input`, which is borrowed for the whole call.
    unsafe {
        let mut rank_out = 0_u32;
        let status = af_rank(&mut rank_out, input.get(), tol);
        HANDLE_ERROR(AfError::from(status));
        rank_out
    }
}
/// Calls the C routine `af_det` on `input` and returns the two `f64` values it
/// reports as `(real, imag)`.
///
/// Both values are initialised to `0.0` before the call.
///
/// NOTE(review): presumably these are the real and imaginary parts of the
/// determinant — confirm against the ArrayFire documentation.
///
/// The status code returned by the C call is converted with `AfError::from` and
/// passed to `HANDLE_ERROR`.
pub fn det<T>(input: &Array<T>) -> (f64, f64)
where
    T: HasAfEnum + FloatingPoint,
{
    // SAFETY: both out-pointers refer to locals that outlive the call; the input
    // handle comes straight from `input`, which is borrowed for the whole call.
    unsafe {
        let (mut re_part, mut im_part) = (0.0_f64, 0.0_f64);
        let status = af_det(&mut re_part, &mut im_part, input.get());
        HANDLE_ERROR(AfError::from(status));
        (re_part, im_part)
    }
}
/// Calls the C routine `af_norm` on `input` and returns the `f64` it reports.
///
/// `ntype` is cast to `c_uint`; `p` and `q` are forwarded unchanged. The result
/// is initialised to `0.0` before the call.
///
/// NOTE(review): presumably `ntype` selects the kind of norm and `p`/`q` are
/// only consulted for the parameterised norm types — confirm against the
/// ArrayFire documentation.
///
/// The status code returned by the C call is converted with `AfError::from` and
/// passed to `HANDLE_ERROR`.
pub fn norm<T>(input: &Array<T>, ntype: NormType, p: f64, q: f64) -> f64
where
    T: HasAfEnum + FloatingPoint,
{
    // SAFETY: the out-pointer refers to a local that outlives the call; the input
    // handle comes straight from `input`, which is borrowed for the whole call.
    unsafe {
        let mut norm_value = 0.0_f64;
        let status = af_norm(&mut norm_value, input.get(), ntype as c_uint, p, q);
        HANDLE_ERROR(AfError::from(status));
        norm_value
    }
}
/// Calls the C routine `af_is_lapack_available` and returns the flag it writes.
///
/// The flag is initialised to `false` before the call. Unlike the other wrappers
/// in this file, the status code returned by the C function is deliberately not
/// inspected here, so a failing call is not reported through `HANDLE_ERROR`.
pub fn is_lapack_available() -> bool {
    // SAFETY: the out-pointer refers to a local that outlives the call.
    unsafe {
        let mut available = false;
        // Status code intentionally discarded, matching the existing behaviour.
        af_is_lapack_available(&mut available);
        available
    }
}
/// Calls the C routine `af_pinverse` on `input` and returns the array it
/// produces.
///
/// `tolerance` is forwarded unchanged and `option` is cast to `c_uint`.
///
/// NOTE(review): by name this presumably computes a pseudo-inverse, with
/// `tolerance` the cut-off for values treated as zero — confirm against the
/// ArrayFire documentation.
///
/// The status code returned by the C call is converted with `AfError::from` and
/// passed to `HANDLE_ERROR`.
pub fn pinverse<T>(input: &Array<T>, tolerance: f64, option: MatProp) -> Array<T>
where
    T: HasAfEnum + FloatingPoint,
{
    // Starts out null; the C call writes the result's handle into it.
    let mut pinv_handle: af_array = std::ptr::null_mut();
    // SAFETY: the out-pointer refers to a local that outlives the call; the input
    // handle comes straight from `input`, which is borrowed for the whole call.
    unsafe {
        let status = af_pinverse(&mut pinv_handle, input.get(), tolerance, option as c_uint);
        HANDLE_ERROR(AfError::from(status));
    }
    pinv_handle.into()
}