1use ndarray::prelude::*;
2use ndarray::{s, RemoveAxis};
3use noisy_float::types::{N32, N64};
4use std::mem;
5
6pub trait MaybeNan: Sized {
8 type NotNan;
10
11 fn is_nan(&self) -> bool;
13
14 fn try_as_not_nan(&self) -> Option<&Self::NotNan>;
18
19 fn from_not_nan(_: Self::NotNan) -> Self;
23
24 fn from_not_nan_opt(_: Option<Self::NotNan>) -> Self;
28
29 fn from_not_nan_ref_opt(_: Option<&Self::NotNan>) -> &Self;
33
34 fn remove_nan_mut(_: ArrayViewMut1<'_, Self>) -> ArrayViewMut1<'_, Self::NotNan>;
41}
42
43fn remove_nan_mut<A: MaybeNan>(mut view: ArrayViewMut1<'_, A>) -> ArrayViewMut1<'_, A> {
47 if view.is_empty() {
48 return view.slice_move(s![..0]);
49 }
50 let mut i = 0;
51 let mut j = view.len() - 1;
52 loop {
53 while i <= j && !view[i].is_nan() {
56 i += 1;
57 }
58 while j > i && view[j].is_nan() {
60 j -= 1;
61 }
62 if i >= j {
64 return view.slice_move(s![..i]);
65 } else {
66 view.swap(i, j);
67 i += 1;
68 j -= 1;
69 }
70 }
71}
72
73unsafe fn cast_view_mut<T, U>(mut view: ArrayViewMut1<'_, T>) -> ArrayViewMut1<'_, U> {
83 assert_eq!(mem::size_of::<T>(), mem::size_of::<U>());
84 assert_eq!(mem::align_of::<T>(), mem::align_of::<U>());
85 let ptr: *mut U = view.as_mut_ptr().cast();
86 let len: usize = view.len_of(Axis(0));
87 let stride: isize = view.stride_of(Axis(0));
88 if len <= 1 {
89 let stride = 0;
91 ArrayViewMut1::from_shape_ptr([len].strides([stride]), ptr)
92 } else if stride >= 0 {
93 let stride = stride as usize;
94 ArrayViewMut1::from_shape_ptr([len].strides([stride]), ptr)
95 } else {
96 let neg_stride = stride.checked_neg().unwrap() as usize;
100 let neg_ptr = ptr.offset((len - 1) as isize * stride);
103 let mut v = ArrayViewMut1::from_shape_ptr([len].strides([neg_stride]), neg_ptr);
104 v.invert_axis(Axis(0));
105 v
106 }
107}
108
109macro_rules! impl_maybenan_for_fxx {
110 ($fxx:ident, $Nxx:ident) => {
111 impl MaybeNan for $fxx {
112 type NotNan = $Nxx;
113
114 fn is_nan(&self) -> bool {
115 $fxx::is_nan(*self)
116 }
117
118 fn try_as_not_nan(&self) -> Option<&$Nxx> {
119 $Nxx::try_borrowed(self)
120 }
121
122 fn from_not_nan(value: $Nxx) -> $fxx {
123 value.raw()
124 }
125
126 fn from_not_nan_opt(value: Option<$Nxx>) -> $fxx {
127 match value {
128 None => ::std::$fxx::NAN,
129 Some(num) => num.raw(),
130 }
131 }
132
133 fn from_not_nan_ref_opt(value: Option<&$Nxx>) -> &$fxx {
134 match value {
135 None => &::std::$fxx::NAN,
136 Some(num) => num.as_ref(),
137 }
138 }
139
140 fn remove_nan_mut(view: ArrayViewMut1<'_, $fxx>) -> ArrayViewMut1<'_, $Nxx> {
141 let not_nan = remove_nan_mut(view);
142 unsafe { cast_view_mut(not_nan) }
145 }
146 }
147 };
148}
149impl_maybenan_for_fxx!(f32, N32);
150impl_maybenan_for_fxx!(f64, N64);
151
152macro_rules! impl_maybenan_for_opt_never_nan {
153 ($ty:ty) => {
154 impl MaybeNan for Option<$ty> {
155 type NotNan = NotNone<$ty>;
156
157 fn is_nan(&self) -> bool {
158 self.is_none()
159 }
160
161 fn try_as_not_nan(&self) -> Option<&NotNone<$ty>> {
162 if self.is_none() {
163 None
164 } else {
165 Some(unsafe { &*(self as *const Option<$ty> as *const NotNone<$ty>) })
168 }
169 }
170
171 fn from_not_nan(value: NotNone<$ty>) -> Option<$ty> {
172 value.into_inner()
173 }
174
175 fn from_not_nan_opt(value: Option<NotNone<$ty>>) -> Option<$ty> {
176 value.and_then(|v| v.into_inner())
177 }
178
179 fn from_not_nan_ref_opt(value: Option<&NotNone<$ty>>) -> &Option<$ty> {
180 match value {
181 None => &None,
182 Some(num) => unsafe { &*(num as *const NotNone<$ty> as *const Option<$ty>) },
185 }
186 }
187
188 fn remove_nan_mut(view: ArrayViewMut1<'_, Self>) -> ArrayViewMut1<'_, Self::NotNan> {
189 let not_nan = remove_nan_mut(view);
190 unsafe {
193 ArrayViewMut1::from_shape_ptr(
194 not_nan.dim(),
195 not_nan.as_ptr() as *mut NotNone<$ty>,
196 )
197 }
198 }
199 }
200 };
201}
202impl_maybenan_for_opt_never_nan!(u8);
203impl_maybenan_for_opt_never_nan!(u16);
204impl_maybenan_for_opt_never_nan!(u32);
205impl_maybenan_for_opt_never_nan!(u64);
206impl_maybenan_for_opt_never_nan!(u128);
207impl_maybenan_for_opt_never_nan!(i8);
208impl_maybenan_for_opt_never_nan!(i16);
209impl_maybenan_for_opt_never_nan!(i32);
210impl_maybenan_for_opt_never_nan!(i64);
211impl_maybenan_for_opt_never_nan!(i128);
212impl_maybenan_for_opt_never_nan!(N32);
213impl_maybenan_for_opt_never_nan!(N64);
214
/// An `Option<T>` that is intended to always be `Some`.
///
/// This is the non-NaN counterpart of `Option<T>` in [`MaybeNan`], where
/// `None` stands for NaN. `#[repr(transparent)]` gives it exactly the layout
/// of `Option<T>`, which pointer casts between the two types depend on.
#[derive(Clone, Copy, Debug)]
#[repr(transparent)]
pub struct NotNone<T>(Option<T>);

impl<T> NotNone<T> {
    /// Wraps `value`, storing it as `Some(value)`.
    pub fn new(value: T) -> NotNone<T> {
        NotNone(Some(value))
    }

    /// Wraps `value` when it is `Some`; returns `None` when it is `None`.
    pub fn try_new(value: Option<T>) -> Option<NotNone<T>> {
        value.map(NotNone::new)
    }

    /// Returns the underlying `Option<T>`, which is expected to be `Some`.
    pub fn into_inner(self) -> Option<T> {
        self.0
    }

    /// Returns the wrapped value without a runtime `None` check.
    pub fn unwrap(self) -> T {
        self.0.unwrap_or_else(|| {
            // SAFETY: relies on the type invariant that a `NotNone` never
            // holds `None`, so this closure is never called.
            unsafe { ::std::hint::unreachable_unchecked() }
        })
    }

    /// Applies `f` to the wrapped value and wraps the result.
    pub fn map<U, F>(self, f: F) -> NotNone<U>
    where
        F: FnOnce(T) -> U,
    {
        let inner = self.unwrap();
        NotNone::new(f(inner))
    }
}
262
/// Extension trait providing NaN-skipping operations on arrays whose element
/// type implements [`MaybeNan`].
pub trait MaybeNanExt<A, D>
where
    A: MaybeNan,
    D: Dimension,
{
    /// Folds `f` over the non-NaN elements, starting from `init`, and returns
    /// the final accumulator. NaN elements are skipped.
    ///
    /// This trait does not specify the order in which elements are visited.
    fn fold_skipnan<'a, F, B>(&'a self, init: B, f: F) -> B
    where
        A: 'a,
        F: FnMut(B, &'a A::NotNan) -> B;

    /// Like [`fold_skipnan`](MaybeNanExt::fold_skipnan), but `f` also receives
    /// the index of each non-NaN element.
    fn indexed_fold_skipnan<'a, F, B>(&'a self, init: B, f: F) -> B
    where
        A: 'a,
        F: FnMut(B, (D::Pattern, &'a A::NotNan)) -> B;

    /// Calls `f` on each non-NaN element; NaN elements are skipped.
    fn visit_skipnan<'a, F>(&'a self, f: F)
    where
        A: 'a,
        F: FnMut(&'a A::NotNan);

    /// Folds each lane along `axis` with `fold`, starting from `init` and
    /// skipping NaN elements. The per-lane results are collected into an
    /// array with `axis` removed.
    fn fold_axis_skipnan<B, F>(&self, axis: Axis, init: B, fold: F) -> Array<B, D::Smaller>
    where
        D: RemoveAxis,
        F: FnMut(&B, &A::NotNan) -> B,
        B: Clone;

    /// Applies `mapping` to each lane along `axis` after moving that lane's
    /// NaN elements to its end; `mapping` receives a mutable view of only the
    /// non-NaN part. The results are collected into an array with `axis`
    /// removed.
    ///
    /// This may reorder elements within each lane of `self` in place.
    fn map_axis_skipnan_mut<'a, B, F>(&'a mut self, axis: Axis, mapping: F) -> Array<B, D::Smaller>
    where
        A: 'a,
        D: RemoveAxis,
        F: FnMut(ArrayViewMut1<'a, A::NotNan>) -> B;

    // NOTE(review): `private_decl!` is defined elsewhere in the crate;
    // presumably it seals the trait against outside implementations — confirm.
    private_decl! {}
}
328
329impl<A, D> MaybeNanExt<A, D> for ArrayRef<A, D>
330where
331 A: MaybeNan,
332 D: Dimension,
333{
334 fn fold_skipnan<'a, F, B>(&'a self, init: B, mut f: F) -> B
335 where
336 A: 'a,
337 F: FnMut(B, &'a A::NotNan) -> B,
338 {
339 self.fold(init, |acc, elem| {
340 if let Some(not_nan) = elem.try_as_not_nan() {
341 f(acc, not_nan)
342 } else {
343 acc
344 }
345 })
346 }
347
348 fn indexed_fold_skipnan<'a, F, B>(&'a self, init: B, mut f: F) -> B
349 where
350 A: 'a,
351 F: FnMut(B, (D::Pattern, &'a A::NotNan)) -> B,
352 {
353 self.indexed_iter().fold(init, |acc, (idx, elem)| {
354 if let Some(not_nan) = elem.try_as_not_nan() {
355 f(acc, (idx, not_nan))
356 } else {
357 acc
358 }
359 })
360 }
361
362 fn visit_skipnan<'a, F>(&'a self, mut f: F)
363 where
364 A: 'a,
365 F: FnMut(&'a A::NotNan),
366 {
367 self.for_each(|elem| {
368 if let Some(not_nan) = elem.try_as_not_nan() {
369 f(not_nan)
370 }
371 })
372 }
373
374 fn fold_axis_skipnan<B, F>(&self, axis: Axis, init: B, mut fold: F) -> Array<B, D::Smaller>
375 where
376 D: RemoveAxis,
377 F: FnMut(&B, &A::NotNan) -> B,
378 B: Clone,
379 {
380 self.fold_axis(axis, init, |acc, elem| {
381 if let Some(not_nan) = elem.try_as_not_nan() {
382 fold(acc, not_nan)
383 } else {
384 acc.clone()
385 }
386 })
387 }
388
389 fn map_axis_skipnan_mut<'a, B, F>(
390 &'a mut self,
391 axis: Axis,
392 mut mapping: F,
393 ) -> Array<B, D::Smaller>
394 where
395 A: 'a,
396 D: RemoveAxis,
397 F: FnMut(ArrayViewMut1<'a, A::NotNan>) -> B,
398 {
399 self.map_axis_mut(axis, |lane| mapping(A::remove_nan_mut(lane)))
400 }
401
402 private_impl! {}
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck_macros::quickcheck;

    /// Builds `Option` test data from a NaN mask: `None` (the NaN value for
    /// `Option`) where the mask is `true`, `Some(1)` elsewhere.
    fn values_from_mask(is_nan: Vec<bool>) -> Vec<Option<i32>> {
        is_nan
            .into_iter()
            .map(|nan| if nan { None } else { Some(1) })
            .collect()
    }

    #[quickcheck]
    fn remove_nan_mut_idempotent(is_nan: Vec<bool>) -> bool {
        let mut values = values_from_mask(is_nan);
        let view = ArrayViewMut1::from_shape(values.len(), &mut values).unwrap();
        let removed = remove_nan_mut(view);
        // Removing NaNs from an already NaN-free copy must change nothing.
        let mut copy = removed.to_owned();
        removed == remove_nan_mut(copy.view_mut())
    }

    #[quickcheck]
    fn remove_nan_mut_only_nan_remaining(is_nan: Vec<bool>) -> bool {
        let mut values = values_from_mask(is_nan);
        let view = ArrayViewMut1::from_shape(values.len(), &mut values).unwrap();
        let kept = remove_nan_mut(view);
        kept.iter().all(|elem| !elem.is_nan())
    }

    #[quickcheck]
    fn remove_nan_mut_keep_all_non_nan(is_nan: Vec<bool>) -> bool {
        let expected = is_nan.iter().filter(|&&nan| !nan).count();
        let mut values = values_from_mask(is_nan);
        let view = ArrayViewMut1::from_shape(values.len(), &mut values).unwrap();
        remove_nan_mut(view).len() == expected
    }
}
442
443mod impl_not_none;