use super::array::Array;
use super::defines::{AfError, BorderType};
use super::dim4::Dim4;
use super::error::HANDLE_ERROR;
use super::util::{af_array, c32, c64, dim_t, u64_t, HasAfEnum};
use libc::{c_double, c_int, c_uint};
use std::option::Option;
use std::vec::Vec;
// Raw FFI declarations for the ArrayFire C functions wrapped in this module.
// Each returns a `c_int` status that the safe wrappers convert with
// `AfError::from` and pass to `HANDLE_ERROR`. For every function except the
// `af_replace*` pair, the first pointer parameter (`out`/`o`) receives the
// handle of the newly created array.
extern "C" {
// Constant fill; `val` is stored using the element type selected by `afdtype`.
fn af_constant(
out: *mut af_array,
val: c_double,
ndims: c_uint,
dims: *const dim_t,
afdtype: c_uint,
) -> c_int;
// Complex constant fill from separate real and imaginary parts.
fn af_constant_complex(
out: *mut af_array,
real: c_double,
imag: c_double,
ndims: c_uint,
dims: *const dim_t,
afdtype: c_uint,
) -> c_int;
// 64-bit integer constant fills. The value is passed as an integer, with no
// `c_double` round-trip, and there is no dtype argument.
fn af_constant_long(out: *mut af_array, val: dim_t, ndims: c_uint, dims: *const dim_t)
-> c_int;
fn af_constant_ulong(
out: *mut af_array,
val: u64_t,
ndims: c_uint,
dims: *const dim_t,
) -> c_int;
// Sequence generators. Callers pass `afdtype` as a `HasAfEnum` dtype cast to
// `c_uint`.
fn af_range(
out: *mut af_array,
ndims: c_uint,
dims: *const dim_t,
seq_dim: c_int,
afdtype: c_uint,
) -> c_int;
fn af_iota(
out: *mut af_array,
ndims: c_uint,
dims: *const dim_t,
t_ndims: c_uint,
tdims: *const dim_t,
afdtype: c_uint,
) -> c_int;
fn af_identity(out: *mut af_array, ndims: c_uint, dims: *const dim_t, afdtype: c_uint)
-> c_int;
// Diagonal creation/extraction; `num` is forwarded from the wrappers' `dim` argument.
fn af_diag_create(out: *mut af_array, arr: af_array, num: c_int) -> c_int;
fn af_diag_extract(out: *mut af_array, arr: af_array, num: c_int) -> c_int;
// Concatenation. For `af_join_many`, `inpts` points at `n_arrays` consecutive handles.
fn af_join(out: *mut af_array, dim: c_int, first: af_array, second: af_array) -> c_int;
fn af_join_many(
out: *mut af_array,
dim: c_int,
n_arrays: c_uint,
inpts: *const af_array,
) -> c_int;
// Per-dimension operations: one argument for each of the four dimensions.
fn af_tile(
out: *mut af_array,
arr: af_array,
x: c_uint,
y: c_uint,
z: c_uint,
w: c_uint,
) -> c_int;
fn af_reorder(
o: *mut af_array,
a: af_array,
x: c_uint,
y: c_uint,
z: c_uint,
w: c_uint,
) -> c_int;
fn af_shift(o: *mut af_array, a: af_array, x: c_int, y: c_int, z: c_int, w: c_int) -> c_int;
fn af_moddims(out: *mut af_array, arr: af_array, ndims: c_uint, dims: *const dim_t) -> c_int;
fn af_flat(out: *mut af_array, arr: af_array) -> c_int;
fn af_flip(out: *mut af_array, arr: af_array, dim: c_uint) -> c_int;
// Triangular extraction; `is_unit_diag` is passed straight through as a Rust `bool`.
fn af_lower(out: *mut af_array, arr: af_array, is_unit_diag: bool) -> c_int;
fn af_upper(out: *mut af_array, arr: af_array, is_unit_diag: bool) -> c_int;
// Selection. `cond` comes before the value operands here, unlike in the Rust wrappers.
fn af_select(out: *mut af_array, cond: af_array, a: af_array, b: af_array) -> c_int;
fn af_select_scalar_l(out: *mut af_array, cond: af_array, a: c_double, b: af_array) -> c_int;
fn af_select_scalar_r(out: *mut af_array, cond: af_array, a: af_array, b: c_double) -> c_int;
// In-place replacement; there is no separate output handle.
// NOTE(review): `a` is declared as `*mut af_array`, but the wrappers `replace` and
// `replace_scalar` pass `a.get()` cast to that type. The handle value itself is what
// crosses the boundary. The C header is believed to take `af_array` by value here;
// confirm against af/data.h before changing either side.
fn af_replace(a: *mut af_array, cond: af_array, b: af_array) -> c_int;
fn af_replace_scalar(a: *mut af_array, cond: af_array, b: c_double) -> c_int;
// Padding. The caller in this file passes 4 for both ndims and a `BorderType`
// cast to `c_uint` for `pad_fill_type`.
fn af_pad(
out: *mut af_array,
input: af_array,
begin_ndims: c_uint,
begin_dims: *const dim_t,
end_ndims: c_uint,
end_dims: *const dim_t,
pad_fill_type: c_uint,
) -> c_int;
}
/// Element types whose values can be used to fill a new ArrayFire array.
///
/// This file implements it for `i64`, `u64`, `c32`, `c64`, `bool` and, through the
/// `cnst!` macro, for `f32`, `f64`, `i32`, `u32`, `u8`, `i16` and `u16`.
pub trait ConstGenerator: HasAfEnum {
/// Element type of the generated array. Every impl in this file sets it to `Self`.
type OutType: HasAfEnum;
/// Creates an array of shape `dims` with every element set to `self`.
fn generate(&self, dims: Dim4) -> Array<Self::OutType>;
}
/// Fills an `i64` array through `af_constant_long`. The value is passed as a 64-bit
/// integer rather than being converted to `f64`, which could lose precision.
impl ConstGenerator for i64 {
    type OutType = i64;

    fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
        let fill = *self;
        unsafe {
            let mut handle: af_array = std::ptr::null_mut();
            let ndims = dims.ndims() as c_uint;
            let status = af_constant_long(
                &mut handle,
                fill,
                ndims,
                dims.get().as_ptr() as *const dim_t,
            );
            HANDLE_ERROR(AfError::from(status));
            handle.into()
        }
    }
}
/// Fills a `u64` array through `af_constant_ulong`. The value is passed as an unsigned
/// 64-bit integer instead of being converted to `f64`.
impl ConstGenerator for u64 {
    type OutType = u64;

    fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
        let fill = *self;
        unsafe {
            let mut handle: af_array = std::ptr::null_mut();
            let ndims = dims.ndims() as c_uint;
            let status = af_constant_ulong(
                &mut handle,
                fill,
                ndims,
                dims.get().as_ptr() as *const dim_t,
            );
            HANDLE_ERROR(AfError::from(status));
            handle.into()
        }
    }
}
impl ConstGenerator for c32 {
type OutType = c32;
fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let err_val = af_constant_complex(
&mut temp as *mut af_array,
(*self).re as c_double,
(*self).im as c_double,
dims.ndims() as c_uint,
dims.get().as_ptr() as *const dim_t,
1,
);
HANDLE_ERROR(AfError::from(err_val));
temp.into()
}
}
}
impl ConstGenerator for c64 {
type OutType = c64;
fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let err_val = af_constant_complex(
&mut temp as *mut af_array,
(*self).re as c_double,
(*self).im as c_double,
dims.ndims() as c_uint,
dims.get().as_ptr() as *const dim_t,
3,
);
HANDLE_ERROR(AfError::from(err_val));
temp.into()
}
}
}
impl ConstGenerator for bool {
type OutType = bool;
fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let err_val = af_constant(
&mut temp as *mut af_array,
*self as c_int as c_double,
dims.ndims() as c_uint,
dims.get().as_ptr() as *const dim_t,
4,
);
HANDLE_ERROR(AfError::from(err_val));
temp.into()
}
}
}
macro_rules! cnst {
($rust_type:ty, $ffi_type:expr) => {
impl ConstGenerator for $rust_type {
type OutType = $rust_type;
fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let err_val = af_constant(
&mut temp as *mut af_array,
*self as c_double,
dims.ndims() as c_uint,
dims.get().as_ptr() as *const dim_t,
$ffi_type,
);
HANDLE_ERROR(AfError::from(err_val));
temp.into()
}
}
}
};
}
cnst!(f32, 0);
cnst!(f64, 2);
cnst!(i32, 5);
cnst!(u32, 6);
cnst!(u8, 7);
cnst!(i16, 10);
cnst!(u16, 11);
/// Creates an array of shape `dims` with every element equal to `cnst`.
///
/// Dispatches to the [`ConstGenerator`] impl for `T`.
pub fn constant<T: ConstGenerator<OutType = T>>(cnst: T, dims: Dim4) -> Array<T> {
    T::generate(&cnst, dims)
}
/// Creates an array of shape `dims` filled by `af_range`, with the sequence running
/// along `seq_dim`.
///
/// The element type is the dtype reported by `T`'s `HasAfEnum` impl.
pub fn range<T>(dims: Dim4, seq_dim: i32) -> Array<T>
where
    T: HasAfEnum,
{
    let dtype = T::get_af_dtype() as c_uint;
    unsafe {
        let mut handle: af_array = std::ptr::null_mut();
        let ndims = dims.ndims() as c_uint;
        let status = af_range(
            &mut handle,
            ndims,
            dims.get().as_ptr() as *const dim_t,
            seq_dim as c_int,
            dtype,
        );
        HANDLE_ERROR(AfError::from(status));
        handle.into()
    }
}
/// Creates an array through `af_iota` from the shape `dims` and the second shape
/// `tdims`.
///
/// NOTE(review): `tdims` is presumably the tiling applied to the generated sequence;
/// confirm against the ArrayFire `iota` documentation.
pub fn iota<T>(dims: Dim4, tdims: Dim4) -> Array<T>
where
    T: HasAfEnum,
{
    let dtype = T::get_af_dtype() as c_uint;
    unsafe {
        let mut handle: af_array = std::ptr::null_mut();
        let ndims = dims.ndims() as c_uint;
        let tile_ndims = tdims.ndims() as c_uint;
        let status = af_iota(
            &mut handle,
            ndims,
            dims.get().as_ptr() as *const dim_t,
            tile_ndims,
            tdims.get().as_ptr() as *const dim_t,
            dtype,
        );
        HANDLE_ERROR(AfError::from(status));
        handle.into()
    }
}
/// Creates an array of shape `dims` through `af_identity`, using `T`'s dtype.
pub fn identity<T>(dims: Dim4) -> Array<T>
where
    T: HasAfEnum,
{
    let dtype = T::get_af_dtype() as c_uint;
    unsafe {
        let mut handle: af_array = std::ptr::null_mut();
        let ndims = dims.ndims() as c_uint;
        let status = af_identity(&mut handle, ndims, dims.get().as_ptr() as *const dim_t, dtype);
        HANDLE_ERROR(AfError::from(status));
        handle.into()
    }
}
/// Builds a new array from `input` through `af_diag_create`. `dim` is forwarded as
/// the C `num` argument.
pub fn diag_create<T: HasAfEnum>(input: &Array<T>, dim: i32) -> Array<T> {
    unsafe {
        let mut created: af_array = std::ptr::null_mut();
        let status = af_diag_create(&mut created, input.get(), dim);
        HANDLE_ERROR(AfError::from(status));
        created.into()
    }
}
/// Extracts a diagonal of `input` through `af_diag_extract`. `dim` is forwarded as
/// the C `num` argument.
pub fn diag_extract<T: HasAfEnum>(input: &Array<T>, dim: i32) -> Array<T> {
    unsafe {
        let mut extracted: af_array = std::ptr::null_mut();
        let status = af_diag_extract(&mut extracted, input.get(), dim);
        HANDLE_ERROR(AfError::from(status));
        extracted.into()
    }
}
/// Concatenates `first` and `second` along dimension `dim` through `af_join`.
pub fn join<T: HasAfEnum>(dim: i32, first: &Array<T>, second: &Array<T>) -> Array<T> {
    unsafe {
        let mut joined: af_array = std::ptr::null_mut();
        let status = af_join(&mut joined, dim, first.get(), second.get());
        HANDLE_ERROR(AfError::from(status));
        joined.into()
    }
}
pub fn join_many<T>(dim: i32, inputs: Vec<&Array<T>>) -> Array<T>
where
T: HasAfEnum,
{
unsafe {
let mut v = Vec::new();
for i in inputs {
v.push(i.get());
}
let mut temp: af_array = std::ptr::null_mut();
let err_val = af_join_many(
&mut temp as *mut af_array,
dim,
v.len() as u32,
v.as_ptr() as *const af_array,
);
HANDLE_ERROR(AfError::from(err_val));
temp.into()
}
}
/// Tiles `input` through `af_tile`, using `dims[k]` as the repeat count for
/// dimension `k`.
///
/// Each count goes through `as c_uint`, which truncates silently if a count does not
/// fit in `c_uint`.
pub fn tile<T: HasAfEnum>(input: &Array<T>, dims: Dim4) -> Array<T> {
    unsafe {
        let (x, y, z, w) = (
            dims[0] as c_uint,
            dims[1] as c_uint,
            dims[2] as c_uint,
            dims[3] as c_uint,
        );
        let mut tiled: af_array = std::ptr::null_mut();
        let status = af_tile(&mut tiled, input.get(), x, y, z, w);
        HANDLE_ERROR(AfError::from(status));
        tiled.into()
    }
}
/// Reorders (permutes) the axes of `input` through `af_reorder`.
///
/// The permutation is `[new_axis0, new_axis1, a2, a3]`. `a2` and `a3` default to `2`
/// and `3`. Entries supplied in `next_axes` overwrite them in order, so
/// `Some(vec![0])` yields `[new_axis0, new_axis1, 0, 3]`. The four values are passed
/// to `af_reorder` as its `x, y, z, w` arguments. Per the ArrayFire documentation,
/// output dimension `k` is taken from input dimension `a_k`.
///
/// # Panics
///
/// Panics if `next_axes` holds more than two entries.
pub fn reorder_v2<T>(
    input: &Array<T>,
    new_axis0: u64,
    new_axis1: u64,
    next_axes: Option<Vec<u64>>,
) -> Array<T>
where
    T: HasAfEnum,
{
    // Trailing axes start at the identity. Explicit entries only overwrite a prefix of
    // them, so the `None` case needs no extra work.
    let mut new_axes: [u64; 4] = [new_axis0, new_axis1, 2, 3];
    if let Some(left_over_new_axes) = next_axes {
        assert!(
            left_over_new_axes.len() <= 2,
            "reorder_v2: at most two trailing axes may be given, got {}",
            left_over_new_axes.len()
        );
        new_axes[2..2 + left_over_new_axes.len()].copy_from_slice(&left_over_new_axes);
    }
    unsafe {
        let mut temp: af_array = std::ptr::null_mut();
        let err_val = af_reorder(
            &mut temp as *mut af_array,
            input.get(),
            new_axes[0] as c_uint,
            new_axes[1] as c_uint,
            new_axes[2] as c_uint,
            new_axes[3] as c_uint,
        );
        HANDLE_ERROR(AfError::from(err_val));
        temp.into()
    }
}
/// Deprecated form of [`reorder_v2`] that takes all four target axes packed in a
/// `Dim4`.
#[deprecated(since = "3.6.3", note = "Please use new reorder API")]
pub fn reorder<T: HasAfEnum>(input: &Array<T>, dims: Dim4) -> Array<T> {
    let trailing_axes = vec![dims[2], dims[3]];
    reorder_v2(input, dims[0], dims[1], Some(trailing_axes))
}
/// Shifts `input` by `offsets[k]` along dimension `k` through `af_shift`.
pub fn shift<T: HasAfEnum>(input: &Array<T>, offsets: &[i32; 4]) -> Array<T> {
    let [dx, dy, dz, dw] = *offsets;
    unsafe {
        let mut shifted: af_array = std::ptr::null_mut();
        let status = af_shift(&mut shifted, input.get(), dx, dy, dz, dw);
        HANDLE_ERROR(AfError::from(status));
        shifted.into()
    }
}
/// Returns `input` with its dimensions changed to `dims` through `af_moddims`.
pub fn moddims<T: HasAfEnum>(input: &Array<T>, dims: Dim4) -> Array<T> {
    unsafe {
        let mut reshaped: af_array = std::ptr::null_mut();
        let ndims = dims.ndims() as c_uint;
        let status = af_moddims(
            &mut reshaped,
            input.get(),
            ndims,
            dims.get().as_ptr() as *const dim_t,
        );
        HANDLE_ERROR(AfError::from(status));
        reshaped.into()
    }
}
/// Flattens `input` through `af_flat`.
pub fn flat<T: HasAfEnum>(input: &Array<T>) -> Array<T> {
    unsafe {
        let mut flattened: af_array = std::ptr::null_mut();
        let status = af_flat(&mut flattened, input.get());
        HANDLE_ERROR(AfError::from(status));
        flattened.into()
    }
}
/// Flips `input` along dimension `dim` through `af_flip`.
pub fn flip<T: HasAfEnum>(input: &Array<T>, dim: u32) -> Array<T> {
    unsafe {
        let mut flipped: af_array = std::ptr::null_mut();
        let status = af_flip(&mut flipped, input.get(), dim);
        HANDLE_ERROR(AfError::from(status));
        flipped.into()
    }
}
/// Returns the lower-triangular part of `input` through `af_lower`.
/// `is_unit_diag` is forwarded unchanged.
pub fn lower<T: HasAfEnum>(input: &Array<T>, is_unit_diag: bool) -> Array<T> {
    unsafe {
        let mut triangle: af_array = std::ptr::null_mut();
        let status = af_lower(&mut triangle, input.get(), is_unit_diag);
        HANDLE_ERROR(AfError::from(status));
        triangle.into()
    }
}
/// Returns the upper-triangular part of `input` through `af_upper`.
/// `is_unit_diag` is forwarded unchanged.
pub fn upper<T: HasAfEnum>(input: &Array<T>, is_unit_diag: bool) -> Array<T> {
    unsafe {
        let mut triangle: af_array = std::ptr::null_mut();
        let status = af_upper(&mut triangle, input.get(), is_unit_diag);
        HANDLE_ERROR(AfError::from(status));
        triangle.into()
    }
}
/// Element-wise choice between `a` and `b` driven by `cond`, through `af_select`.
///
/// Per the ArrayFire documentation, the result takes `a` where `cond` is true and `b`
/// elsewhere. Note that the C call receives `cond` first.
pub fn select<T: HasAfEnum>(a: &Array<T>, cond: &Array<bool>, b: &Array<T>) -> Array<T> {
    unsafe {
        let mut chosen: af_array = std::ptr::null_mut();
        let status = af_select(&mut chosen, cond.get(), a.get(), b.get());
        HANDLE_ERROR(AfError::from(status));
        chosen.into()
    }
}
/// Like [`select`], but with a scalar left operand `a`. Calls `af_select_scalar_l`.
pub fn selectl<T: HasAfEnum>(a: f64, cond: &Array<bool>, b: &Array<T>) -> Array<T> {
    unsafe {
        let mut chosen: af_array = std::ptr::null_mut();
        let status = af_select_scalar_l(&mut chosen, cond.get(), a, b.get());
        HANDLE_ERROR(AfError::from(status));
        chosen.into()
    }
}
/// Like [`select`], but with a scalar right operand `b`. Calls `af_select_scalar_r`.
pub fn selectr<T: HasAfEnum>(a: &Array<T>, cond: &Array<bool>, b: f64) -> Array<T> {
    unsafe {
        let mut chosen: af_array = std::ptr::null_mut();
        let status = af_select_scalar_r(&mut chosen, cond.get(), a.get(), b);
        HANDLE_ERROR(AfError::from(status));
        chosen.into()
    }
}
/// Modifies `a` in place through `af_replace`, using `cond` and the values of `b`.
///
/// Per the ArrayFire documentation, elements of `a` where `cond` is false are replaced
/// by the matching elements of `b`. The handle of `a` is cast to `*mut af_array` to
/// match the extern declaration.
pub fn replace<T: HasAfEnum>(a: &mut Array<T>, cond: &Array<bool>, b: &Array<T>) {
    unsafe {
        let target = a.get() as *mut af_array;
        let status = af_replace(target, cond.get(), b.get());
        HANDLE_ERROR(AfError::from(status));
    }
}
/// Scalar form of [`replace`]: modifies `a` in place through `af_replace_scalar`,
/// using `b` as the replacement value.
pub fn replace_scalar<T: HasAfEnum>(a: &mut Array<T>, cond: &Array<bool>, b: f64) {
    unsafe {
        let target = a.get() as *mut af_array;
        let status = af_replace_scalar(target, cond.get(), b);
        HANDLE_ERROR(AfError::from(status));
    }
}
/// Pads `input` through `af_pad`. `begin` and `end` supply the padding amounts for the
/// start and end of each dimension, and `fill_type` selects how the border is filled.
pub fn pad<T>(input: &Array<T>, begin: Dim4, end: Dim4, fill_type: BorderType) -> Array<T>
where
    T: HasAfEnum,
{
    // All four entries of both `begin` and `end` are always forwarded.
    const PAD_NDIMS: c_uint = 4;
    unsafe {
        let mut padded: af_array = std::ptr::null_mut();
        let status = af_pad(
            &mut padded,
            input.get(),
            PAD_NDIMS,
            begin.get().as_ptr() as *const dim_t,
            PAD_NDIMS,
            end.get().as_ptr() as *const dim_t,
            fill_type as c_uint,
        );
        HANDLE_ERROR(AfError::from(status));
        padded.into()
    }
}
#[cfg(test)]
mod tests {
    use super::super::defines::BorderType;
    use super::super::device::set_device;
    use super::super::random::randu;
    use super::{pad, reorder_v2};
    use crate::dim4;

    /// Smoke test: `reorder_v2` accepts `None`, one trailing axis and two trailing axes.
    #[test]
    fn check_reorder_api() {
        set_device(0);
        let input = randu::<f32>(dim4!(4, 5, 2, 3));
        // Permutation [1, 0, 2, 3].
        let _transposed = reorder_v2(&input, 1, 0, None);
        // Permutation [2, 1, 0, 3].
        let _swap_0_2 = reorder_v2(&input, 2, 1, Some(vec![0]));
        // Permutation [0, 2, 1, 3].
        let _swap_1_2 = reorder_v2(&input, 0, 2, Some(vec![1]));
        // Permutation [3, 1, 2, 0].
        let _swap_0_3 = reorder_v2(&input, 3, 1, Some(vec![2, 0]));
    }

    /// Smoke test: zero-fill padding of two extra elements at the end of each of the
    /// first two dimensions.
    #[test]
    fn check_pad_api() {
        set_device(0);
        let input = randu::<f32>(dim4!(3, 3));
        let leading = dim4!(0, 0, 0, 0);
        let trailing = dim4!(2, 2, 0, 0);
        let _padded = pad(&input, leading, trailing, BorderType::ZERO);
    }
}