use ndarray::prelude::*;
use ndarray::{Data, RemoveAxis, Zip};
use rawpointer::PointerExt;
use std::cmp::Ordering;
use std::ptr::copy_nonoverlapping;
/// A reordering of the indices `0..n` of one array axis.
///
/// Intended invariant: every index in `0..indices.len()` appears exactly once.
/// `from_indices` enforces it for caller-supplied vectors.
#[derive(Clone, Debug)]
pub struct Permutation {
    indices: Vec<usize>,
}

impl Permutation {
    /// Wraps `v` as a permutation after validating it.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` when `v` contains an index `>= v.len()` or contains
    /// the same index more than once.
    pub fn from_indices(v: Vec<usize>) -> Result<Self, ()> {
        let candidate = Permutation { indices: v };
        if !candidate.correct() {
            return Err(());
        }
        Ok(candidate)
    }

    /// Reports whether `indices` is a valid permutation of `0..indices.len()`.
    fn correct(&self) -> bool {
        let mut visited = vec![false; self.indices.len()];
        self.indices.iter().all(|&idx| match visited.get_mut(idx) {
            // The index lies outside `0..len`.
            None => false,
            // Mark the slot; finding it already marked means a duplicate.
            Some(slot) => !std::mem::replace(slot, true),
        })
    }
}
/// Computes index permutations for one axis of an array.
///
/// Both methods take `&self` and return a `Permutation`; applying the
/// permutation to the data is a separate step (see `PermuteArray`).
pub trait SortArray {
/// Returns the identity permutation for `axis`.
///
/// The implementation in this file yields `0..len_of(axis)` in order.
fn identity(&self, axis: Axis) -> Permutation;
/// Returns the permutation that sorts the indices of `axis`.
///
/// `less_than(i, j)` receives two indices along `axis` and should return
/// `true` when index `i` must be ordered before index `j`.
fn sort_axis_by<F>(&self, axis: Axis, less_than: F) -> Permutation
where
F: FnMut(usize, usize) -> bool;
}
/// Applies a `Permutation` to one axis of an array, consuming the array.
pub trait PermuteArray {
/// Element type of the returned array.
type Elem;
/// Dimension type of the returned array.
type Dim;
/// Consumes `self` and returns an array rearranged along `axis` according to `perm`.
///
/// NOTE(review): the `Self::Elem: Clone` bound is not needed by the
/// `Array<A, D>` implementation in this file, which only requires
/// `D: RemoveAxis` and moves elements rather than cloning them; confirm
/// whether other implementors rely on it before relaxing it.
fn permute_axis(self, axis: Axis, perm: &Permutation) -> Array<Self::Elem, Self::Dim>
where
Self::Elem: Clone,
Self::Dim: RemoveAxis;
}
impl<A, S, D> SortArray for ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Returns the identity permutation `0, 1, .., len_of(axis) - 1`.
    fn identity(&self, axis: Axis) -> Permutation {
        let indices: Vec<usize> = (0..self.len_of(axis)).collect();
        Permutation { indices }
    }

    /// Sorts the indices of `axis` using the strict ordering `less_than`.
    ///
    /// `less_than(i, j)` is asked first; only when it answers `false` is
    /// `less_than(j, i)` consulted. Two indices for which both answers are
    /// `false` compare as equal, and `sort_by` (a stable sort) then keeps
    /// them in their original relative order.
    fn sort_axis_by<F>(&self, axis: Axis, mut less_than: F) -> Permutation
    where
        F: FnMut(usize, usize) -> bool,
    {
        let mut sorted = self.identity(axis);
        sorted.indices.sort_by(move |&lhs, &rhs| {
            if less_than(lhs, rhs) {
                return Ordering::Less;
            }
            if less_than(rhs, lhs) {
                return Ordering::Greater;
            }
            Ordering::Equal
        });
        sorted
    }
}
/// Permutes an owned array by moving its elements into a freshly allocated
/// array; no element is cloned or dropped in the process.
impl<A, D> PermuteArray for Array<A, D>
where
D: Dimension,
{
type Elem = A;
type Dim = D;
/// Consumes `self` and returns an array of the same shape in which pane `i`
/// along `axis` is the source pane `perm.indices[i]`.
///
/// # Panics
///
/// Panics if the length of `axis` differs from the number of indices in
/// `perm`. In debug builds it also asserts that `perm` is a valid
/// permutation.
fn permute_axis(self, axis: Axis, perm: &Permutation) -> Array<A, D>
where
D: RemoveAxis,
{
let axis_len = self.len_of(axis);
let axis_stride = self.stride_of(axis);
assert_eq!(axis_len, perm.indices.len());
debug_assert!(perm.correct());
// Nothing to move when the array holds no elements; return it unchanged.
if self.is_empty() {
return self;
}
let mut result = Array::uninit(self.dim());
// SAFETY: relies on `perm` being a valid permutation of `0..axis_len`
// (length asserted above, validity checked in debug builds). Under that
// invariant every source element is read exactly once and every element of
// `result` is written exactly once before `assume_init` is called.
unsafe {
// Ownership of every element is logically moved from `self` into
// `result`; `result` takes it over at `assume_init()` below.
let mut moved_elements = 0;
// The permutation is used like this:
//
// index:  0 1 2 3   (pane index in the result)
// permut: 2 3 0 1   (pane index in the source)
//
// i.e. source pane 2 -> result pane 0, source pane 3 -> result pane 1, ...
//
// `source_0` is a raw view of source pane 0; every other source pane is
// reached from it by pointer arithmetic along `axis`.
let source_0 = self.raw_view().index_axis_move(axis, 0);
Zip::from(&perm.indices)
.and(result.axis_iter_mut(axis))
.for_each(|&perm_i, result_pane| {
// For each element pointer of pane 0, offsetting it by
// `axis_stride * perm_i` elements gives the matching element of
// source pane `perm_i`, without indexing the source array again.
Zip::from(result_pane)
.and(source_0.clone())
.for_each(|to, from_0| {
let from = from_0.stride_offset(axis_stride, perm_i);
// Bitwise move of one element from the source into the result.
copy_nonoverlapping(from, to.as_mut_ptr(), 1);
moved_elements += 1;
});
});
debug_assert_eq!(result.len(), moved_elements);
// Forget the old elements but keep the allocation: with the length set
// to 0 the vector frees its buffer on drop without dropping the
// elements that were moved out above.
let mut old_storage = self.into_raw_vec();
old_storage.set_len(0);
// All elements have been written; `result` now owns them.
result.assume_init()
}
}
}
/// Demo: builds an 8x8 array, computes the permutation that orders its rows by
/// descending first-column value, and applies it to the rows of the numeric
/// array and to the columns of a stringified copy.
#[cfg(feature = "std")]
fn main() {
    let a = Array::linspace(0., 63., 64).into_shape((8, 8)).unwrap();
    let strings = a.map(|x| x.to_string());

    // The comparator indexes rows (`a[[i, 0]]`), so the permutation has to
    // range over Axis(0). Using Axis(1) here only worked because the array is
    // square; the resulting permutation is the same for this 8x8 input.
    let perm = a.sort_axis_by(Axis(0), |i, j| a[[i, 0]] > a[[j, 0]]);
    println!("{:?}", perm);

    // Reorder the rows of the numeric array.
    let b = a.permute_axis(Axis(0), &perm);
    println!("{:?}", b);

    // Apply the same permutation to the columns of the string array; this is
    // valid because both axes have length 8.
    println!("{:?}", strings);
    let c = strings.permute_axis(Axis(1), &perm);
    println!("{:?}", c);
}
/// Empty entry point used when the `std` feature is disabled; the demo `main`
/// above is compiled only with that feature enabled.
#[cfg(not(feature = "std"))]
fn main() {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sorting the rows by their first column must give the same result for
    /// the array as written, a copy built with the `.f()` shape, and a
    /// transposed copy permuted along its second axis.
    #[test]
    fn test_permute_axis() {
        let input = array![
            [107998.96, 1.],
            [107999.08, 2.],
            [107999.20, 3.],
            [108000.33, 4.],
            [107999.45, 5.],
            [107999.57, 6.],
            [108010.69, 7.],
            [107999.81, 8.],
            [107999.94, 9.],
            [75600.09, 10.],
            [75600.21, 11.],
            [75601.33, 12.],
            [75600.45, 13.],
            [75600.58, 14.],
            [109000.70, 15.],
            [75600.82, 16.],
            [75600.94, 17.],
            [75601.06, 18.],
        ];
        // Rows of `input` ordered by ascending first column.
        let expected = array![
            [75600.09, 10.],
            [75600.21, 11.],
            [75600.45, 13.],
            [75600.58, 14.],
            [75600.82, 16.],
            [75600.94, 17.],
            [75601.06, 18.],
            [75601.33, 12.],
            [107998.96, 1.],
            [107999.08, 2.],
            [107999.20, 3.],
            [107999.45, 5.],
            [107999.57, 6.],
            [107999.81, 8.],
            [107999.94, 9.],
            [108000.33, 4.],
            [108010.69, 7.],
            [109000.70, 15.],
        ];

        // One permutation, computed once from the first column, is reused for
        // every variant below.
        let row_order = input.sort_axis_by(Axis(0), |i, j| input[[i, 0]] < input[[j, 0]]);

        // Same contents, allocated with the `.f()` shape.
        let mut input_f = Array::zeros(input.dim().f());
        input_f.assign(&input);

        // Owned transpose: the rows of `input` become its columns.
        let input_t = input.t().to_owned();

        let sorted = input.permute_axis(Axis(0), &row_order);
        assert_eq!(sorted, expected);

        let sorted_f = input_f.permute_axis(Axis(0), &row_order);
        assert_eq!(sorted_f, expected);

        let sorted_t = input_t.permute_axis(Axis(1), &row_order);
        assert_eq!(sorted_t, expected.t());
    }
}