use ndarray::prelude::*;
use ndarray::{s, Data, DataMut, RemoveAxis};
use ordered_float::{NotNan, OrderedFloat};
use std::mem;
pub type N32 = NotNan<f32>;
pub type N64 = NotNan<f64>;
#[inline]
pub fn n32(num: f32) -> N32 {
N32::new(num).expect("NaN")
}
#[inline]
pub fn n64(num: f64) -> N64 {
N64::new(num).expect("NaN")
}
/// A totally ordered `f32` (NaN is allowed and has a defined position in the order).
pub type O32 = OrderedFloat<f32>;
/// A totally ordered `f64` (NaN is allowed and has a defined position in the order).
pub type O64 = OrderedFloat<f64>;

/// Wraps `num` as an [`O32`]; never fails, NaN included.
#[inline]
#[must_use]
pub fn o32(num: f32) -> O32 {
    OrderedFloat(num)
}

/// Wraps `num` as an [`O64`]; never fails, NaN included.
#[inline]
#[must_use]
pub fn o64(num: f64) -> O64 {
    OrderedFloat(num)
}
/// A type that can hold a "missing" value (NaN for floats, `None` for options),
/// paired with a counterpart type [`MaybeNan::NotNan`] that cannot.
pub trait MaybeNan: Sized {
    /// The counterpart of `Self` that is guaranteed not to be NaN / missing.
    type NotNan;

    /// Returns `true` if this value is NaN (or missing).
    #[must_use]
    fn is_nan(&self) -> bool;

    /// Reinterprets `self` as a reference to `Self::NotNan`, or returns `None`
    /// if `self` is NaN (or missing).
    fn try_as_not_nan(&self) -> Option<&Self::NotNan>;

    /// Converts a non-NaN value back into `Self`.
    #[must_use]
    fn from_not_nan(value: Self::NotNan) -> Self;

    /// Converts `Some(value)` into `Self`, and `None` into the NaN / missing value.
    #[must_use]
    fn from_not_nan_opt(value: Option<Self::NotNan>) -> Self;

    /// Reference version of [`MaybeNan::from_not_nan_opt`]: `None` maps to a
    /// reference to the NaN / missing value.
    #[must_use]
    fn from_not_nan_ref_opt(value: Option<&Self::NotNan>) -> &Self;

    /// Returns a view of the non-NaN elements of `view`, reinterpreted as
    /// `Self::NotNan`. Implementations may reorder `view` in place to do this.
    #[must_use]
    fn remove_nan_mut(view: ArrayViewMut1<'_, Self>) -> ArrayViewMut1<'_, Self::NotNan>;
}
/// Partitions `view` in place so that all non-NaN elements (per
/// [`MaybeNan::is_nan`]) come first, and returns the non-NaN prefix.
///
/// Two cursors are used: `i` scans forward past non-NaN elements and `j` scans
/// backward past NaN elements; when both stop with `i < j`, the NaN at `i` is
/// swapped with the non-NaN at `j`. `i` only moves forward and `j` only moves
/// backward, so this is a single O(len) pass. The relative order of the
/// retained elements may change.
fn remove_nan_mut<A: MaybeNan>(mut view: ArrayViewMut1<'_, A>) -> ArrayViewMut1<'_, A> {
    if view.is_empty() {
        // Handled separately: `view.len() - 1` below would underflow.
        return view.slice_move(s![..0]);
    }
    let mut i = 0;
    let mut j = view.len() - 1;
    loop {
        // Advance `i` to the next NaN, or one past `j`.
        while i <= j && !view[i].is_nan() {
            i += 1;
        }
        // Retreat `j` to the previous non-NaN without crossing `i`.
        while j > i && view[j].is_nan() {
            j -= 1;
        }
        if i >= j {
            // Everything before `i` is non-NaN and everything from `i` on is NaN.
            return view.slice_move(s![..i]);
        } else {
            view.swap(i, j);
            i += 1;
            // Cannot underflow: `j > i >= 0` held before the swap.
            j -= 1;
        }
    }
}
/// Reinterprets a 1-D mutable view of `T` as a 1-D mutable view of `U`, keeping
/// the same length, element stride, and logical element order.
///
/// # Panics
///
/// Panics if `T` and `U` differ in size or alignment, or if the view's stride is
/// `isize::MIN` (it cannot be negated).
///
/// # Safety
///
/// Every element of `view` must be a valid value of `U` (e.g. `U` is a
/// transparent wrapper around `T` whose invariant holds for every element). The
/// returned view borrows the same memory for the same lifetime as `view`.
unsafe fn cast_view_mut<T, U>(mut view: ArrayViewMut1<'_, T>) -> ArrayViewMut1<'_, U> {
    assert_eq!(mem::size_of::<T>(), mem::size_of::<U>());
    assert_eq!(mem::align_of::<T>(), mem::align_of::<U>());
    let ptr: *mut U = view.as_mut_ptr().cast();
    let len: usize = view.len_of(Axis(0));
    let stride: isize = view.stride_of(Axis(0));
    if len <= 1 {
        // With zero or one element the stride is irrelevant, so use 0.
        let stride = 0;
        ArrayViewMut1::from_shape_ptr([len].strides([stride]), ptr)
    } else if stride >= 0 {
        let stride = stride as usize;
        ArrayViewMut1::from_shape_ptr([len].strides([stride]), ptr)
    } else {
        // The strides passed here are unsigned, so build the view from the
        // lowest-addressed element (logical index `len - 1`) with a positive
        // stride, then flip the axis to restore the original logical order.
        let neg_stride = stride.checked_neg().unwrap() as usize;
        let neg_ptr = ptr.offset((len - 1) as isize * stride);
        let mut v = ArrayViewMut1::from_shape_ptr([len].strides([neg_stride]), neg_ptr);
        v.invert_axis(Axis(0));
        v
    }
}
/// Implements [`MaybeNan`] for the primitive float `$fxx`, with `$Nxx`
/// (`NotNan<$fxx>`) as its non-NaN counterpart.
macro_rules! impl_maybenan_for_fxx {
    ($fxx:ident, $Nxx:ident) => {
        impl MaybeNan for $fxx {
            type NotNan = $Nxx;

            #[inline]
            fn is_nan(&self) -> bool {
                $fxx::is_nan(*self)
            }

            #[inline]
            fn try_as_not_nan(&self) -> Option<&$Nxx> {
                if self.is_nan() {
                    return None;
                }
                // SAFETY: relies on `$Nxx` being a transparent wrapper around `$fxx`;
                // the value has just been checked to be non-NaN.
                Some(unsafe { &*(self as *const $fxx as *const $Nxx) })
            }

            #[inline]
            fn from_not_nan(value: $Nxx) -> $fxx {
                *value
            }

            #[inline]
            fn from_not_nan_opt(value: Option<$Nxx>) -> $fxx {
                value.map_or(::std::$fxx::NAN, |num| *num)
            }

            #[inline]
            fn from_not_nan_ref_opt(value: Option<&$Nxx>) -> &$fxx {
                match value {
                    Some(num) => num.as_ref(),
                    None => &::std::$fxx::NAN,
                }
            }

            #[inline]
            fn remove_nan_mut(view: ArrayViewMut1<'_, $fxx>) -> ArrayViewMut1<'_, $Nxx> {
                // SAFETY: the prefix returned by the free `remove_nan_mut` holds no NaN,
                // and `cast_view_mut` checks that both element types have the same size
                // and alignment.
                unsafe { cast_view_mut(remove_nan_mut(view)) }
            }
        }
    };
}

impl_maybenan_for_fxx!(f32, N32);
impl_maybenan_for_fxx!(f64, N64);
impl MaybeNan for O32 {
type NotNan = N32;
#[inline]
fn is_nan(&self) -> bool {
self.0.is_nan()
}
#[inline]
fn try_as_not_nan(&self) -> Option<&N32> {
(!self.is_nan()).then(|| unsafe { mem::transmute::<&O32, &N32>(self) })
}
#[inline]
fn from_not_nan(value: N32) -> O32 {
o32(*value)
}
#[inline]
fn from_not_nan_opt(value: Option<N32>) -> O32 {
match value {
None => o32(::std::f32::NAN),
Some(num) => o32(*num),
}
}
#[inline]
fn from_not_nan_ref_opt(value: Option<&N32>) -> &O32 {
match value {
None => unsafe { mem::transmute::<&f64, &O32>(&::std::f64::NAN) },
Some(num) => unsafe { mem::transmute::<&N32, &O32>(num) },
}
}
#[inline]
fn remove_nan_mut(view: ArrayViewMut1<'_, O32>) -> ArrayViewMut1<'_, N32> {
let not_nan = remove_nan_mut(view);
unsafe { cast_view_mut(not_nan) }
}
}
impl MaybeNan for O64 {
type NotNan = N64;
#[inline]
fn is_nan(&self) -> bool {
self.0.is_nan()
}
#[inline]
fn try_as_not_nan(&self) -> Option<&N64> {
(!self.is_nan()).then(|| unsafe { mem::transmute::<&O64, &N64>(self) })
}
#[inline]
fn from_not_nan(value: N64) -> O64 {
o64(*value)
}
#[inline]
fn from_not_nan_opt(value: Option<N64>) -> O64 {
match value {
None => o64(::std::f64::NAN),
Some(num) => o64(*num),
}
}
#[inline]
fn from_not_nan_ref_opt(value: Option<&N64>) -> &O64 {
match value {
None => unsafe { mem::transmute::<&f64, &O64>(&::std::f64::NAN) },
Some(num) => unsafe { mem::transmute::<&N64, &O64>(num) },
}
}
#[inline]
fn remove_nan_mut(view: ArrayViewMut1<'_, O64>) -> ArrayViewMut1<'_, N64> {
let not_nan = remove_nan_mut(view);
unsafe { cast_view_mut(not_nan) }
}
}
macro_rules! impl_maybenan_for_opt_never_nan {
($ty:ty) => {
impl MaybeNan for Option<$ty> {
type NotNan = NotNone<$ty>;
#[inline]
fn is_nan(&self) -> bool {
self.is_none()
}
#[inline]
fn try_as_not_nan(&self) -> Option<&NotNone<$ty>> {
if self.is_none() {
None
} else {
Some(unsafe { &*(self as *const Option<$ty> as *const NotNone<$ty>) })
}
}
#[inline]
fn from_not_nan(value: NotNone<$ty>) -> Option<$ty> {
value.into_inner()
}
#[inline]
fn from_not_nan_opt(value: Option<NotNone<$ty>>) -> Option<$ty> {
value.and_then(|v| v.into_inner())
}
#[inline]
fn from_not_nan_ref_opt(value: Option<&NotNone<$ty>>) -> &Option<$ty> {
match value {
None => &None,
Some(num) => unsafe { &*(num as *const NotNone<$ty> as *const Option<$ty>) },
}
}
#[inline]
fn remove_nan_mut(view: ArrayViewMut1<'_, Self>) -> ArrayViewMut1<'_, Self::NotNan> {
let not_nan = remove_nan_mut(view);
unsafe {
ArrayViewMut1::from_shape_ptr(
not_nan.dim(),
not_nan.as_ptr() as *mut NotNone<$ty>,
)
}
}
}
};
}
impl_maybenan_for_opt_never_nan!(u8);
impl_maybenan_for_opt_never_nan!(u16);
impl_maybenan_for_opt_never_nan!(u32);
impl_maybenan_for_opt_never_nan!(u64);
impl_maybenan_for_opt_never_nan!(u128);
impl_maybenan_for_opt_never_nan!(i8);
impl_maybenan_for_opt_never_nan!(i16);
impl_maybenan_for_opt_never_nan!(i32);
impl_maybenan_for_opt_never_nan!(i64);
impl_maybenan_for_opt_never_nan!(i128);
impl_maybenan_for_opt_never_nan!(N32);
impl_maybenan_for_opt_never_nan!(N64);
/// An `Option<T>` that is guaranteed to be `Some`.
///
/// `#[repr(transparent)]` gives it the same layout as `Option<T>`, so references
/// and views can be reinterpreted between the two. The field is private, so code
/// outside this module cannot build a `NotNone` that holds `None`.
#[derive(Clone, Copy, Debug)]
#[repr(transparent)]
pub struct NotNone<T>(Option<T>);

impl<T> NotNone<T> {
    /// Wraps `value`.
    #[inline]
    pub fn new(value: T) -> NotNone<T> {
        NotNone(Some(value))
    }

    /// Returns `Some(NotNone)` when `value` is `Some`, and `None` otherwise.
    #[inline]
    pub fn try_new(value: Option<T>) -> Option<NotNone<T>> {
        value.map(NotNone::new)
    }

    /// Returns the wrapped `Option`, which is always `Some`.
    #[inline]
    pub fn into_inner(self) -> Option<T> {
        self.0
    }

    /// Returns the wrapped value without a runtime check.
    #[inline]
    pub fn unwrap(self) -> T {
        if let Some(inner) = self.0 {
            inner
        } else {
            // SAFETY: relies on the type invariant that the field is always `Some`.
            // `new` and `try_new` uphold it, and any other code that creates a
            // `NotNone` (e.g. by pointer casts) must as well.
            unsafe { ::std::hint::unreachable_unchecked() }
        }
    }

    /// Applies `f` to the wrapped value; the result is still guaranteed present.
    #[inline]
    pub fn map<U, F>(self, f: F) -> NotNone<U>
    where
        F: FnOnce(T) -> U,
    {
        let mapped = f(self.unwrap());
        NotNone(Some(mapped))
    }
}
/// Extension methods for arrays whose elements implement [`MaybeNan`]; every
/// method skips NaN (or missing) elements.
pub trait MaybeNanExt<A, S, D>
where
    A: MaybeNan,
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Folds `f` over the non-NaN elements, starting from `init`.
    fn fold_skipnan<'a, F, B>(&'a self, init: B, f: F) -> B
    where
        F: FnMut(B, &'a A::NotNan) -> B,
        A: 'a;

    /// Like [`MaybeNanExt::fold_skipnan`], but `f` also receives each element's index.
    fn indexed_fold_skipnan<'a, F, B>(&'a self, init: B, f: F) -> B
    where
        F: FnMut(B, (D::Pattern, &'a A::NotNan)) -> B,
        A: 'a;

    /// Calls `f` on every non-NaN element.
    fn visit_skipnan<'a, F>(&'a self, f: F)
    where
        F: FnMut(&'a A::NotNan),
        A: 'a;

    /// Folds each lane along `axis`, skipping NaN elements; returns one result
    /// per lane, each starting from a clone of `init`.
    fn fold_axis_skipnan<B, F>(&self, axis: Axis, init: B, fold: F) -> Array<B, D::Smaller>
    where
        B: Clone,
        F: FnMut(&B, &A::NotNan) -> B,
        D: RemoveAxis;

    /// Calls `mapping` on each lane along `axis`, passing a view of only that
    /// lane's non-NaN elements. The lanes are reordered in place to build those
    /// views, so this mutates the array.
    fn map_axis_skipnan_mut<'a, B, F>(&'a mut self, axis: Axis, mapping: F) -> Array<B, D::Smaller>
    where
        F: FnMut(ArrayViewMut1<'a, A::NotNan>) -> B,
        D: RemoveAxis,
        S: DataMut,
        A: 'a;

    private_decl! {}
}
impl<A, S, D> MaybeNanExt<A, S, D> for ArrayBase<S, D>
where
    A: MaybeNan,
    S: Data<Elem = A>,
    D: Dimension,
{
    fn fold_skipnan<'a, F, B>(&'a self, init: B, mut f: F) -> B
    where
        A: 'a,
        F: FnMut(B, &'a A::NotNan) -> B,
    {
        // Delegates to `ArrayBase::fold`; a NaN element passes the accumulator through.
        self.fold(init, |acc, elem| match elem.try_as_not_nan() {
            Some(not_nan) => f(acc, not_nan),
            None => acc,
        })
    }

    fn indexed_fold_skipnan<'a, F, B>(&'a self, init: B, f: F) -> B
    where
        A: 'a,
        F: FnMut(B, (D::Pattern, &'a A::NotNan)) -> B,
    {
        // Drop NaN elements first, then fold over the surviving `(index, value)` pairs.
        self.indexed_iter()
            .filter_map(|(idx, elem)| elem.try_as_not_nan().map(|not_nan| (idx, not_nan)))
            .fold(init, f)
    }

    fn visit_skipnan<'a, F>(&'a self, mut f: F)
    where
        A: 'a,
        F: FnMut(&'a A::NotNan),
    {
        self.for_each(|elem| {
            // The `Option` yields no item for a NaN element and one item otherwise.
            elem.try_as_not_nan().into_iter().for_each(&mut f);
        })
    }

    fn fold_axis_skipnan<B, F>(&self, axis: Axis, init: B, mut fold: F) -> Array<B, D::Smaller>
    where
        D: RemoveAxis,
        F: FnMut(&B, &A::NotNan) -> B,
        B: Clone,
    {
        self.fold_axis(axis, init, |acc, elem| match elem.try_as_not_nan() {
            Some(not_nan) => fold(acc, not_nan),
            // A NaN leaves this lane's accumulator unchanged.
            None => acc.clone(),
        })
    }

    fn map_axis_skipnan_mut<'a, B, F>(
        &'a mut self,
        axis: Axis,
        mut mapping: F,
    ) -> Array<B, D::Smaller>
    where
        A: 'a,
        S: DataMut,
        D: RemoveAxis,
        F: FnMut(ArrayViewMut1<'a, A::NotNan>) -> B,
    {
        self.map_axis_mut(axis, |lane| {
            // Reorders the lane in place so that `mapping` sees only its non-NaN prefix.
            let not_nan = A::remove_nan_mut(lane);
            mapping(not_nan)
        })
    }

    private_impl! {}
}
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck_macros::quickcheck;

    /// Builds test values from a NaN mask: `true` becomes `None` (the "NaN"),
    /// `false` becomes `Some(1)`.
    fn values_from_mask(is_nan: Vec<bool>) -> Vec<Option<i32>> {
        is_nan
            .into_iter()
            .map(|nan| if nan { None } else { Some(1) })
            .collect()
    }

    #[quickcheck]
    fn remove_nan_mut_idempotent(is_nan: Vec<bool>) -> bool {
        let mut values = values_from_mask(is_nan);
        let view = ArrayViewMut1::from_shape(values.len(), &mut values).unwrap();
        let removed = remove_nan_mut(view);
        let mut copy = removed.to_owned();
        removed == remove_nan_mut(copy.view_mut())
    }

    #[quickcheck]
    fn remove_nan_mut_no_nan_remaining(is_nan: Vec<bool>) -> bool {
        let mut values = values_from_mask(is_nan);
        let view = ArrayViewMut1::from_shape(values.len(), &mut values).unwrap();
        let removed = remove_nan_mut(view);
        !removed.iter().any(|elem| elem.is_nan())
    }

    #[quickcheck]
    fn remove_nan_mut_keep_all_non_nan(is_nan: Vec<bool>) -> bool {
        let expected = is_nan.iter().filter(|nan| !**nan).count();
        let mut values = values_from_mask(is_nan);
        let view = ArrayViewMut1::from_shape(values.len(), &mut values).unwrap();
        remove_nan_mut(view).len() == expected
    }
}
mod impl_not_none;