use super::array::Array;
use super::defines::AfError;
use super::error::HANDLE_ERROR;
use super::seq::Seq;
use super::util::{af_array, af_index_t, dim_t, HasAfEnum, IndexableType};
use libc::{c_double, c_int, c_uint};
use std::default::Default;
use std::marker::PhantomData;
use std::mem;
// Native indexing routines called through FFI by the wrappers in this file.
// Every routine returns a `c_int` status code. Callers below convert it with
// `AfError::from` and pass it to `HANDLE_ERROR`, except `Drop for Indexer`,
// which compares the status of `af_release_indexers` against 0 directly.
extern "C" {
// Takes an out-pointer; `Indexer::default`/`Indexer::new` store the value
// written through it as the indexer handle.
fn af_create_indexers(indexers: *mut af_index_t) -> c_int;
// Registers an array handle (`idx`) on `indexer` for dimension `dim`.
fn af_set_array_indexer(indexer: af_index_t, idx: af_array, dim: dim_t) -> c_int;
// Registers a sequence (passed by pointer to its C-layout form) on `indexer`
// for dimension `dim`, together with a batch flag.
fn af_set_seq_indexer(
indexer: af_index_t,
idx: *const SeqInternal,
dim: dim_t,
is_batch: bool,
) -> c_int;
// Called from `Drop for Indexer` with the stored handle.
fn af_release_indexers(indexers: af_index_t) -> c_int;
// Sequence-based indexing: `index` points at `ndims` consecutive `SeqInternal`
// values; the result handle is written through `out`.
fn af_index(
out: *mut af_array,
input: af_array,
ndims: c_uint,
index: *const SeqInternal,
) -> c_int;
// Lookup of `arr` using an array of `indices` along dimension `dim`.
fn af_lookup(out: *mut af_array, arr: af_array, indices: af_array, dim: c_uint) -> c_int;
// Sequence-based assignment of `rhs` into `lhs`; the result handle is written
// through `out`.
fn af_assign_seq(
out: *mut af_array,
lhs: af_array,
ndims: c_uint,
indices: *const SeqInternal,
rhs: af_array,
) -> c_int;
// Indexing driven by an indexer handle rather than a raw sequence array.
fn af_index_gen(
out: *mut af_array,
input: af_array,
ndims: dim_t,
indices: af_index_t,
) -> c_int;
// Assignment driven by an indexer handle rather than a raw sequence array.
fn af_assign_gen(
out: *mut af_array,
lhs: af_array,
ndims: dim_t,
indices: af_index_t,
rhs: af_array,
) -> c_int;
}
/// Owner of a native indexer handle plus a running count of the index objects
/// registered on it through [`Indexer::set_index`].
///
/// The handle is released when the `Indexer` is dropped. The `'object`
/// lifetime is the lifetime of the references handed to `set_index`.
pub struct Indexer<'object> {
// Raw handle obtained from `af_create_indexers`.
handle: af_index_t,
// Number of `set_index` calls made so far; reported by `len()`.
count: usize,
// Zero-sized marker that carries the `'object` lifetime on the struct.
marker: PhantomData<&'object ()>,
}
// Marks `Indexer` as transferable between threads.
// NOTE(review): soundness relies on the native handle being usable from a
// thread other than the one that created it — confirm against the C library.
unsafe impl<'object> Send for Indexer<'object> {}
/// Implemented by types that can register themselves on an [`Indexer`].
///
/// In this file the trait is implemented for `Array<T>` and `Seq<T>`.
pub trait Indexable {
/// Registers `self` on `idxr` for dimension `dim`.
///
/// `is_batch` is forwarded to the native call only by the `Seq<T>`
/// implementation (where `None` is treated as `false`); the `Array<T>`
/// implementation ignores it.
fn set(&self, idxr: &mut Indexer, dim: u32, is_batch: Option<bool>);
}
/// Lets an `Array<T>` act as the index object for one dimension of an
/// [`Indexer`].
impl<T> Indexable for Array<T>
where
    T: HasAfEnum + IndexableType,
{
    /// Registers this array's handle on `idxr` for dimension `dim`.
    ///
    /// The batch flag has no meaning for array indexers and is ignored.
    fn set(&self, idxr: &mut Indexer, dim: u32, _is_batch: Option<bool>) {
        let target_dim = dim as dim_t;
        unsafe {
            let status = af_set_array_indexer(idxr.get(), self.get(), target_dim);
            HANDLE_ERROR(AfError::from(status));
        }
    }
}
/// Lets a `Seq<T>` act as the index object for one dimension of an
/// [`Indexer`].
impl<T> Indexable for Seq<T>
where
    c_double: From<T>,
    T: Copy + IndexableType,
{
    /// Registers this sequence on `idxr` for dimension `dim`.
    ///
    /// The sequence is converted to its C-layout form first; a missing
    /// `is_batch` is passed to the native call as `false`.
    fn set(&self, idxr: &mut Indexer, dim: u32, is_batch: Option<bool>) {
        // Named local so the pointer handed to the native call stays valid
        // for the whole call.
        let raw_seq = SeqInternal::from_seq(self);
        let batch_flag = is_batch.unwrap_or(false);
        unsafe {
            let status = af_set_seq_indexer(idxr.get(), &raw_seq, dim as dim_t, batch_flag);
            HANDLE_ERROR(AfError::from(status));
        }
    }
}
impl<'object> Default for Indexer<'object> {
    /// Creates an empty indexer: a fresh native handle with no index objects
    /// registered yet (`len() == 0`).
    fn default() -> Self {
        let mut raw_handle: af_index_t = std::ptr::null_mut();
        unsafe {
            let status = af_create_indexers(&mut raw_handle);
            HANDLE_ERROR(AfError::from(status));
        }
        Self {
            handle: raw_handle,
            count: 0,
            marker: PhantomData,
        }
    }
}
impl<'object> Indexer<'object> {
#[deprecated(since = "3.7.0", note = "Use Indexer::default() instead")]
pub fn new() -> Self {
unsafe {
let mut temp: af_index_t = std::ptr::null_mut();
let err_val = af_create_indexers(&mut temp as *mut af_index_t);
HANDLE_ERROR(AfError::from(err_val));
Self {
handle: temp,
count: 0,
marker: PhantomData,
}
}
}
pub fn set_index<'s, T>(&'s mut self, idx: &'object T, dim: u32, is_batch: Option<bool>)
where
T: Indexable + 'object,
{
idx.set(self, dim, is_batch);
self.count += 1;
}
pub fn len(&self) -> usize {
self.count
}
pub fn is_empty(&self) -> bool {
self.count == 0
}
pub unsafe fn get(&self) -> af_index_t {
self.handle
}
}
impl<'object> Drop for Indexer<'object> {
    /// Releases the native handle.
    ///
    /// # Panics
    ///
    /// Panics when the native release call reports a non-zero status.
    fn drop(&mut self) {
        let status = unsafe { af_release_indexers(self.handle) };
        if status != 0 {
            panic!("Failed to release indexers resource: {}", status);
        }
    }
}
/// Indexes `input` with one sequence per dimension.
///
/// # Parameters
///
/// - `input`: the array to index.
/// - `seqs`: the sequences, converted to their C-layout form and passed to
///   the native `af_index` call along with their count.
///
/// # Return Values
///
/// An array wrapping the handle produced by the native call.
pub fn index<IO, T>(input: &Array<IO>, seqs: &[Seq<T>]) -> Array<IO>
where
    c_double: From<T>,
    IO: HasAfEnum,
    T: Copy + HasAfEnum + IndexableType,
{
    let raw_seqs: Vec<SeqInternal> = seqs.iter().map(SeqInternal::from_seq).collect();
    let mut out_handle: af_array = std::ptr::null_mut();
    unsafe {
        let status = af_index(
            &mut out_handle,
            input.get(),
            raw_seqs.len() as u32,
            raw_seqs.as_ptr(),
        );
        HANDLE_ERROR(AfError::from(status));
    }
    out_handle.into()
}
/// Extracts row `row_num` of `input`.
///
/// Built on [`index`] with a single-element sequence for the first dimension
/// and a default sequence for the second.
pub fn row<T>(input: &Array<T>, row_num: i64) -> Array<T>
where
    T: HasAfEnum,
{
    let position = row_num as f64;
    let seqs = [Seq::new(position, position, 1.0), Seq::default()];
    index(input, &seqs)
}
/// Replaces row `row_num` of `inout` with `new_row`.
///
/// A default sequence for the second dimension is supplied only when `inout`
/// has more than one dimension; otherwise just the row sequence is passed to
/// [`assign_seq`].
pub fn set_row<T>(inout: &mut Array<T>, new_row: &Array<T>, row_num: i64)
where
    T: HasAfEnum,
{
    let position = row_num as f64;
    let row_seq = Seq::new(position, position, 1.0);
    if inout.dims().ndims() > 1 {
        assign_seq(inout, &[row_seq, Seq::default()], new_row)
    } else {
        assign_seq(inout, &[row_seq], new_row)
    }
}
/// Extracts the rows from `first` to `last` of `input`.
///
/// The sequence step is `-1.0` only when `first > last` and `last` is
/// negative; in every other case it is `1.0`.
pub fn rows<T>(input: &Array<T>, first: i64, last: i64) -> Array<T>
where
    T: HasAfEnum,
{
    let descending = first > last && last < 0;
    let step: f64 = if descending { -1.0 } else { 1.0 };
    let seqs = [Seq::new(first as f64, last as f64, step), Seq::default()];
    index(input, &seqs)
}
/// Replaces the rows from `first` to `last` of `inout` with `new_rows`.
///
/// The sequence step is `-1.0` only when `first > last` and `last` is
/// negative; in every other case it is `1.0`.
pub fn set_rows<T>(inout: &mut Array<T>, new_rows: &Array<T>, first: i64, last: i64)
where
    T: HasAfEnum,
{
    let descending = first > last && last < 0;
    let step: f64 = if descending { -1.0 } else { 1.0 };
    let row_range = Seq::new(first as f64, last as f64, step);
    assign_seq(inout, &[row_range, Seq::default()], new_rows)
}
/// Extracts column `col_num` of `input`.
///
/// Built on [`index`] with a default sequence for the first dimension and a
/// single-element sequence for the second.
pub fn col<T>(input: &Array<T>, col_num: i64) -> Array<T>
where
    T: HasAfEnum,
{
    let position = col_num as f64;
    let seqs = [Seq::default(), Seq::new(position, position, 1.0)];
    index(input, &seqs)
}
/// Replaces column `col_num` of `inout` with `new_col`.
///
/// Built on [`assign_seq`] with a default sequence for the first dimension
/// and a single-element sequence for the second.
pub fn set_col<T>(inout: &mut Array<T>, new_col: &Array<T>, col_num: i64)
where
    T: HasAfEnum,
{
    let position = col_num as f64;
    let col_seq = Seq::new(position, position, 1.0);
    assign_seq(inout, &[Seq::default(), col_seq], new_col)
}
/// Extracts the columns from `first` to `last` of `input`.
///
/// The sequence step is `-1.0` only when `first > last` and `last` is
/// negative; in every other case it is `1.0`.
pub fn cols<T>(input: &Array<T>, first: i64, last: i64) -> Array<T>
where
    T: HasAfEnum,
{
    let descending = first > last && last < 0;
    let step: f64 = if descending { -1.0 } else { 1.0 };
    let seqs = [Seq::default(), Seq::new(first as f64, last as f64, step)];
    index(input, &seqs)
}
/// Replaces the columns from `first` to `last` of `inout` with `new_cols`.
///
/// The sequence step is `-1.0` only when `first > last` and `last` is
/// negative; in every other case it is `1.0`.
pub fn set_cols<T>(inout: &mut Array<T>, new_cols: &Array<T>, first: i64, last: i64)
where
    T: HasAfEnum,
{
    let descending = first > last && last < 0;
    let step: f64 = if descending { -1.0 } else { 1.0 };
    let col_range = Seq::new(first as f64, last as f64, step);
    assign_seq(inout, &[Seq::default(), col_range], new_cols)
}
/// Extracts slice `slice_num` (third dimension) of `input`.
///
/// Built on [`index`] with default sequences for the first two dimensions
/// and a single-element sequence for the third.
pub fn slice<T>(input: &Array<T>, slice_num: i64) -> Array<T>
where
    T: HasAfEnum,
{
    let position = slice_num as f64;
    let slice_seq = Seq::new(position, position, 1.0);
    index(input, &[Seq::default(), Seq::default(), slice_seq])
}
/// Replaces slice `slice_num` (third dimension) of `inout` with `new_slice`.
///
/// Built on [`assign_seq`] with default sequences for the first two
/// dimensions and a single-element sequence for the third.
pub fn set_slice<T>(inout: &mut Array<T>, new_slice: &Array<T>, slice_num: i64)
where
    T: HasAfEnum,
{
    let position = slice_num as f64;
    let slice_seq = Seq::new(position, position, 1.0);
    assign_seq(inout, &[Seq::default(), Seq::default(), slice_seq], new_slice)
}
/// Extracts the slices from `first` to `last` (third dimension) of `input`.
///
/// The sequence step is `-1.0` only when `first > last` and `last` is
/// negative; in every other case it is `1.0`.
pub fn slices<T>(input: &Array<T>, first: i64, last: i64) -> Array<T>
where
    T: HasAfEnum,
{
    let descending = first > last && last < 0;
    let step: f64 = if descending { -1.0 } else { 1.0 };
    let slice_range = Seq::new(first as f64, last as f64, step);
    index(input, &[Seq::default(), Seq::default(), slice_range])
}
/// Replaces the slices from `first` to `last` (third dimension) of `inout`
/// with `new_slices`.
///
/// The sequence step is `-1.0` only when `first > last` and `last` is
/// negative; in every other case it is `1.0`.
pub fn set_slices<T>(inout: &mut Array<T>, new_slices: &Array<T>, first: i64, last: i64)
where
    T: HasAfEnum,
{
    let descending = first > last && last < 0;
    let step: f64 = if descending { -1.0 } else { 1.0 };
    let slice_range = Seq::new(first as f64, last as f64, step);
    assign_seq(
        inout,
        &[Seq::default(), Seq::default(), slice_range],
        new_slices,
    )
}
/// Looks up the elements of `input` selected by `indices` along `seq_dim`.
///
/// # Parameters
///
/// - `input`: the array to read from.
/// - `indices`: the array of positions to look up.
/// - `seq_dim`: the dimension passed to the native `af_lookup` call (cast to
///   `c_uint`).
///
/// # Return Values
///
/// An array wrapping the handle produced by the native call.
pub fn lookup<T, I>(input: &Array<T>, indices: &Array<I>, seq_dim: i32) -> Array<T>
where
    T: HasAfEnum,
    I: HasAfEnum + IndexableType,
{
    let mut out_handle: af_array = std::ptr::null_mut();
    unsafe {
        let status = af_lookup(
            &mut out_handle,
            input.get(),
            indices.get(),
            seq_dim as c_uint,
        );
        HANDLE_ERROR(AfError::from(status));
    }
    out_handle.into()
}
/// Assigns `rhs` into the region of `lhs` selected by `seqs`.
///
/// # Parameters
///
/// - `lhs`: the array being modified; on return it holds the array produced
///   by the native `af_assign_seq` call and the previous value is dropped.
/// - `seqs`: one sequence per indexed dimension.
/// - `rhs`: the values to assign.
pub fn assign_seq<T, I>(lhs: &mut Array<I>, seqs: &[Seq<T>], rhs: &Array<I>)
where
    c_double: From<T>,
    I: HasAfEnum,
    T: Copy + IndexableType,
{
    let raw_seqs: Vec<SeqInternal> = seqs.iter().map(SeqInternal::from_seq).collect();
    let mut out_handle: af_array = std::ptr::null_mut();
    unsafe {
        let status = af_assign_seq(
            &mut out_handle,
            lhs.get(),
            raw_seqs.len() as c_uint,
            raw_seqs.as_ptr(),
            rhs.get(),
        );
        HANDLE_ERROR(AfError::from(status));
    }
    // Install the result, then drop the array `lhs` held before.
    drop(mem::replace(lhs, out_handle.into()));
}
/// Indexes `input` using the index objects registered on `indices`.
///
/// # Parameters
///
/// - `input`: the array to index.
/// - `indices`: the indexer, taken by value; its `len()` is passed to the
///   native `af_index_gen` call as the number of dimensions, and it is
///   dropped when this function returns.
///
/// # Return Values
///
/// An array wrapping the handle produced by the native call.
pub fn index_gen<T>(input: &Array<T>, indices: Indexer) -> Array<T>
where
    T: HasAfEnum,
{
    let indexed_dims = indices.len() as dim_t;
    let mut out_handle: af_array = std::ptr::null_mut();
    unsafe {
        let status = af_index_gen(&mut out_handle, input.get(), indexed_dims, indices.get());
        HANDLE_ERROR(AfError::from(status));
    }
    out_handle.into()
}
/// Assigns `rhs` into the region of `lhs` selected by the index objects
/// registered on `indices`.
///
/// # Parameters
///
/// - `lhs`: the array being modified; on return it holds the array produced
///   by the native `af_assign_gen` call and the previous value is dropped.
/// - `indices`: the indexer; its `len()` is passed to the native call as the
///   number of dimensions.
/// - `rhs`: the values to assign.
pub fn assign_gen<T>(lhs: &mut Array<T>, indices: &Indexer, rhs: &Array<T>)
where
    T: HasAfEnum,
{
    let indexed_dims = indices.len() as dim_t;
    let mut out_handle: af_array = std::ptr::null_mut();
    unsafe {
        let status = af_assign_gen(
            &mut out_handle,
            lhs.get(),
            indexed_dims,
            indices.get(),
            rhs.get(),
        );
        HANDLE_ERROR(AfError::from(status));
    }
    // Install the result, then drop the array `lhs` held before.
    drop(mem::replace(lhs, out_handle.into()));
}
// C-layout (`repr(C)`) form of a `Seq`, with all three components stored as
// `c_double`. Pointers to values of this type are what the native sequence
// routines (`af_index`, `af_assign_seq`, `af_set_seq_indexer`) receive.
#[repr(C)]
struct SeqInternal {
// Value of `Seq::begin()`, converted to `c_double`.
begin: c_double,
// Value of `Seq::end()`, converted to `c_double`.
end: c_double,
// Value of `Seq::step()`, converted to `c_double`.
step: c_double,
}
impl SeqInternal {
fn from_seq<T>(s: &Seq<T>) -> Self
where
c_double: From<T>,
T: Copy + IndexableType,
{
Self {
begin: From::from(s.begin()),
end: From::from(s.end()),
step: From::from(s.step()),
}
}
}
#[cfg(test)]
mod tests {
    use super::super::array::Array;
    use super::super::data::constant;
    use super::super::device::set_device;
    use super::super::dim4::Dim4;
    use super::super::index::{assign_gen, assign_seq, col, index, index_gen, row, Indexer};
    use super::super::index::{cols, rows};
    use super::super::random::randu;
    use super::super::seq::Seq;
    use crate::{dim4, seq, view};

    // Sequence indexing with hand-built `Seq` values.
    #[test]
    fn non_macro_seq_index() {
        set_device(0);
        let shape = Dim4::new(&[5, 5, 1, 1]);
        let a = randu::<f32>(shape);
        let seqs = &[Seq::new(1u32, 3, 1), Seq::default()];
        let _sub = index(&a, seqs);
    }

    // Sequence indexing through the `seq!` and `view!` macros.
    #[test]
    fn seq_index() {
        set_device(0);
        let shape = dim4!(5, 5, 1, 1);
        let a = randu::<f32>(shape);
        let first3 = seq!(1:3:1);
        let allindim2 = seq!();
        let _sub = view!(a[first3, allindim2]);
    }

    // Sequence assignment through `assign_seq`.
    #[test]
    fn non_macro_seq_assign() {
        set_device(0);
        let mut target = constant(2.0_f32, dim4!(5, 3));
        let source = constant(1.0_f32, dim4!(3, 3));
        let seqs = [seq!(1:3:1), seq!()];
        assign_seq(&mut target, &seqs, &source);
    }

    // Mixed array/sequence indexing through an explicit `Indexer`.
    #[test]
    fn non_macro_seq_array_index() {
        set_device(0);
        let values: [f32; 3] = [1.0, 2.0, 3.0];
        let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
        let seq4gen = Seq::new(0.0, 2.0, 1.0);
        let a = randu::<f32>(Dim4::new(&[5, 3, 1, 1]));
        let mut idxrs = Indexer::default();
        idxrs.set_index(&indices, 0, None);
        idxrs.set_index(&seq4gen, 1, Some(false));
        let _sub2 = index_gen(&a, idxrs);
    }

    // Mixed array/sequence indexing through the `view!` macro.
    #[test]
    fn seq_array_index() {
        set_device(0);
        let values: [f32; 3] = [1.0, 2.0, 3.0];
        let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
        let seq4gen = seq!(0:2:1);
        let a = randu::<f32>(Dim4::new(&[5, 3, 1, 1]));
        let _sub2 = view!(a[indices, seq4gen]);
    }

    // Mixed array/sequence assignment through an explicit `Indexer`.
    #[test]
    fn non_macro_seq_array_assign() {
        set_device(0);
        let values: [f32; 3] = [1.0, 2.0, 3.0];
        let indices = Array::new(&values, dim4!(3, 1, 1, 1));
        let seq4gen = seq!(0:2:1);
        let mut target = randu::<f32>(dim4!(5, 3, 1, 1));
        let source = constant(2.0_f32, dim4!(3, 3, 1, 1));
        let mut idxrs = Indexer::default();
        idxrs.set_index(&indices, 0, None);
        idxrs.set_index(&seq4gen, 1, Some(false));
        assign_gen(&mut target, &idxrs, &source);
    }

    // Despite the name, this exercises the `row`/`col` getters with a
    // non-negative position.
    #[test]
    fn setrow() {
        set_device(0);
        let a = randu::<f32>(dim4!(5, 5, 1, 1));
        let _r = row(&a, 4);
        let _c = col(&a, 4);
    }

    // `row`/`col` with a negative position.
    #[test]
    fn get_row() {
        set_device(0);
        let a = randu::<f32>(dim4!(5, 5));
        let _r = row(&a, -1);
        let _c = col(&a, -1);
    }

    // `rows`/`cols` with negative bounds where `first > last`.
    #[test]
    fn get_rows() {
        set_device(0);
        let a = randu::<f32>(dim4!(5, 5));
        let _r = rows(&a, -1, -2);
        let _c = cols(&a, -1, -3);
    }
}