use super::core::{
af_array, dim_t, AfError, Array, CovarianceComputable, HasAfEnum, MedianComputable,
RealFloating, RealNumber, TopkFn, VarianceBias, HANDLE_ERROR,
};
use libc::{c_double, c_int, c_uint};
extern "C" {
fn af_mean(out: *mut af_array, arr: af_array, dim: dim_t) -> c_int;
fn af_stdev(out: *mut af_array, arr: af_array, dim: dim_t) -> c_int;
fn af_median(out: *mut af_array, arr: af_array, dim: dim_t) -> c_int;
fn af_mean_weighted(out: *mut af_array, arr: af_array, wts: af_array, dim: dim_t) -> c_int;
fn af_var_weighted(out: *mut af_array, arr: af_array, wts: af_array, dim: dim_t) -> c_int;
fn af_var(out: *mut af_array, arr: af_array, isbiased: bool, dim: dim_t) -> c_int;
fn af_cov(out: *mut af_array, X: af_array, Y: af_array, isbiased: bool) -> c_int;
fn af_var_all(real: *mut c_double, imag: *mut c_double, arr: af_array, isbiased: bool)
-> c_int;
fn af_mean_all(real: *mut c_double, imag: *mut c_double, arr: af_array) -> c_int;
fn af_stdev_all(real: *mut c_double, imag: *mut c_double, arr: af_array) -> c_int;
fn af_median_all(real: *mut c_double, imag: *mut c_double, arr: af_array) -> c_int;
fn af_mean_all_weighted(
real: *mut c_double,
imag: *mut c_double,
arr: af_array,
wts: af_array,
) -> c_int;
fn af_var_all_weighted(
real: *mut c_double,
imag: *mut c_double,
arr: af_array,
wts: af_array,
) -> c_int;
fn af_corrcoef(real: *mut c_double, imag: *mut c_double, X: af_array, Y: af_array) -> c_int;
fn af_topk(
vals: *mut af_array,
idxs: *mut af_array,
arr: af_array,
k: c_int,
dim: c_int,
order: c_uint,
) -> c_int;
fn af_meanvar(
mean: *mut af_array,
var: *mut af_array,
input: af_array,
weights: af_array,
bias: c_uint,
dim: dim_t,
) -> c_int;
}
pub fn median<T>(input: &Array<T>, dim: i64) -> Array<T>
where
T: HasAfEnum + MedianComputable,
{
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let err_val = af_median(&mut temp as *mut af_array, input.get(), dim);
HANDLE_ERROR(AfError::from(err_val));
temp.into()
}
}
macro_rules! stat_func_def {
($doc_str: expr, $fn_name: ident, $ffi_fn: ident) => {
#[doc=$doc_str]
pub fn $fn_name<T>(input: &Array<T>, dim: i64) -> Array<T::MeanOutType>
where
T: HasAfEnum,
T::MeanOutType: HasAfEnum,
{
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let err_val = $ffi_fn(&mut temp as *mut af_array, input.get(), dim);
HANDLE_ERROR(AfError::from(err_val));
temp.into()
}
}
};
}
stat_func_def!("Mean along specified dimension", mean, af_mean);
stat_func_def!(
"Standard deviation along specified dimension",
stdev,
af_stdev
);
macro_rules! stat_wtd_func_def {
($doc_str: expr, $fn_name: ident, $ffi_fn: ident) => {
#[doc=$doc_str]
pub fn $fn_name<T, W>(
input: &Array<T>,
weights: &Array<W>,
dim: i64,
) -> Array<T::MeanOutType>
where
T: HasAfEnum,
T::MeanOutType: HasAfEnum,
W: HasAfEnum + RealFloating,
{
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let err_val = $ffi_fn(&mut temp as *mut af_array,input.get(), weights.get(), dim);
HANDLE_ERROR(AfError::from(err_val));
temp.into()
}
}
};
}
stat_wtd_func_def!(
"Weighted mean along specified dimension",
mean_weighted,
af_mean_weighted
);
stat_wtd_func_def!(
"Weight variance along specified dimension",
var_weighted,
af_var_weighted
);
pub fn var<T>(arr: &Array<T>, isbiased: bool, dim: i64) -> Array<T::MeanOutType>
where
T: HasAfEnum,
T::MeanOutType: HasAfEnum,
{
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let err_val = af_var(&mut temp as *mut af_array, arr.get(), isbiased, dim);
HANDLE_ERROR(AfError::from(err_val));
temp.into()
}
}
pub fn cov<T>(x: &Array<T>, y: &Array<T>, isbiased: bool) -> Array<T::MeanOutType>
where
T: HasAfEnum + CovarianceComputable,
T::MeanOutType: HasAfEnum,
{
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let err_val = af_cov(&mut temp as *mut af_array, x.get(), y.get(), isbiased);
HANDLE_ERROR(AfError::from(err_val));
temp.into()
}
}
pub fn var_all<T: HasAfEnum>(input: &Array<T>, isbiased: bool) -> (f64, f64) {
let mut real: f64 = 0.0;
let mut imag: f64 = 0.0;
unsafe {
let err_val = af_var_all(
&mut real as *mut c_double,
&mut imag as *mut c_double,
input.get(),
isbiased,
);
HANDLE_ERROR(AfError::from(err_val));
}
(real, imag)
}
macro_rules! stat_all_func_def {
($doc_str: expr, $fn_name: ident, $ffi_fn: ident) => {
#[doc=$doc_str]
pub fn $fn_name<T: HasAfEnum>(input: &Array<T>) -> (f64, f64) {
let mut real: f64 = 0.0;
let mut imag: f64 = 0.0;
unsafe {
let err_val = $ffi_fn(
&mut real as *mut c_double,
&mut imag as *mut c_double,
input.get(),
);
HANDLE_ERROR(AfError::from(err_val));
}
(real, imag)
}
};
}
stat_all_func_def!("Compute mean of all data", mean_all, af_mean_all);
stat_all_func_def!(
"Compute standard deviation of all data",
stdev_all,
af_stdev_all
);
pub fn median_all<T>(input: &Array<T>) -> (f64, f64)
where
T: HasAfEnum + MedianComputable,
{
let mut real: f64 = 0.0;
let mut imag: f64 = 0.0;
unsafe {
let err_val = af_median_all(
&mut real as *mut c_double,
&mut imag as *mut c_double,
input.get(),
);
HANDLE_ERROR(AfError::from(err_val));
}
(real, imag)
}
macro_rules! stat_wtd_all_func_def {
($doc_str: expr, $fn_name: ident, $ffi_fn: ident) => {
#[doc=$doc_str]
pub fn $fn_name<T, W>(input: &Array<T>, weights: &Array<W>) -> (f64, f64)
where
T: HasAfEnum,
W: HasAfEnum + RealFloating,
{
let mut real: f64 = 0.0;
let mut imag: f64 = 0.0;
unsafe {
let err_val = $ffi_fn(
&mut real as *mut c_double,
&mut imag as *mut c_double,
input.get(),
weights.get(),
);
HANDLE_ERROR(AfError::from(err_val));
}
(real, imag)
}
};
}
stat_wtd_all_func_def!(
"Compute weighted mean of all data",
mean_all_weighted,
af_mean_all_weighted
);
stat_wtd_all_func_def!(
"Compute weighted variance of all data",
var_all_weighted,
af_var_all_weighted
);
pub fn corrcoef<T>(x: &Array<T>, y: &Array<T>) -> (f64, f64)
where
T: HasAfEnum + RealNumber,
{
let mut real: f64 = 0.0;
let mut imag: f64 = 0.0;
unsafe {
let err_val = af_corrcoef(
&mut real as *mut c_double,
&mut imag as *mut c_double,
x.get(),
y.get(),
);
HANDLE_ERROR(AfError::from(err_val));
}
(real, imag)
}
pub fn topk<T>(input: &Array<T>, k: u32, dim: i32, order: TopkFn) -> (Array<T>, Array<u32>)
where
T: HasAfEnum,
{
unsafe {
let mut t0: af_array = std::ptr::null_mut();
let mut t1: af_array = std::ptr::null_mut();
let err_val = af_topk(
&mut t0 as *mut af_array,
&mut t1 as *mut af_array,
input.get(),
k as c_int,
dim as c_int,
order as c_uint,
);
HANDLE_ERROR(AfError::from(err_val));
(t0.into(), t1.into())
}
}
pub fn meanvar<T, W>(
input: &Array<T>,
weights: &Array<W>,
bias: VarianceBias,
dim: i64,
) -> (Array<T::MeanOutType>, Array<T::MeanOutType>)
where
T: HasAfEnum,
T::MeanOutType: HasAfEnum,
W: HasAfEnum + RealFloating,
{
unsafe {
let mut mean: af_array = std::ptr::null_mut();
let mut var: af_array = std::ptr::null_mut();
let err_val = af_meanvar(
&mut mean as *mut af_array,
&mut var as *mut af_array,
input.get(),
weights.get(),
bias as c_uint,
dim,
);
HANDLE_ERROR(AfError::from(err_val));
(mean.into(), var.into())
}
}