cfpyo3_core/toolkit/array.rs

1use anyhow::{Ok, Result};
2use core::{mem, ptr, slice};
3use itertools::{izip, Itertools};
4use memmap2::{Mmap, MmapOptions};
5use num_traits::{Float, FromPrimitive};
6use numpy::{
7    ndarray::{stack, Array1, Array2, ArrayView1, ArrayView2, Axis, ScalarOperand},
8    Element,
9};
10use std::{
11    cell::UnsafeCell,
12    cmp::Ordering,
13    collections::HashMap,
14    fmt::{Debug, Display},
15    fs::File,
16    iter::zip,
17    marker::PhantomData,
18    ops::{AddAssign, MulAssign, SubAssign},
19    thread::available_parallelism,
20};
21
22#[derive(Debug)]
23pub struct ArrayError(String);
24impl ArrayError {
25    fn new(msg: &str) -> Self {
26        Self(msg.to_string())
27    }
28    pub fn data_not_contiguous<T>() -> Result<T> {
29        Err(ArrayError::new("data is not contiguous").into())
30    }
31}
32impl std::error::Error for ArrayError {}
33impl std::fmt::Display for ArrayError {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        write!(f, "error occurred in `array` module: {}", self.0)
36    }
37}
38
/// Evaluates to `$data.as_slice()`'s contiguous slice, or makes the enclosing
/// function return `ArrayError::data_not_contiguous()` when `as_slice()`
/// yields `None`.
///
/// The enclosing function must return `anyhow::Result<_>`.
#[macro_export]
macro_rules! as_data_slice_or_err {
    ($data:expr) => {
        if let Some(data) = $data.as_slice() {
            data
        } else {
            return $crate::toolkit::array::ArrayError::data_not_contiguous();
        }
    };
}
48
/// A `Copy`-able handle to a mutable slice that allows writes through a shared
/// handle.
///
/// This lets several worker tasks write into one output buffer without locking.
/// Nothing here enforces that accesses are disjoint: callers must guarantee that
/// no two handles touch the same element concurrently.
#[derive(Copy, Clone)]
pub struct UnsafeSlice<'a, T> {
    slice: &'a [UnsafeCell<T>],
}
// SAFETY: the handle only performs raw writes of `T`; keeping concurrent
// accesses disjoint is the caller's responsibility (see the type docs).
unsafe impl<'a, T: Send + Sync> Send for UnsafeSlice<'a, T> {}
unsafe impl<'a, T: Send + Sync> Sync for UnsafeSlice<'a, T> {}
impl<'a, T> UnsafeSlice<'a, T> {
    /// Wraps `slice`, consuming its unique borrow for the lifetime `'a`.
    pub fn new(slice: &'a mut [T]) -> Self {
        let ptr = slice as *mut [T] as *const [UnsafeCell<T>];
        // SAFETY: `UnsafeCell<T>` has the same layout as `T`, and the unique
        // borrow we took guarantees no other access to the memory for `'a`.
        Self {
            slice: unsafe { &*ptr },
        }
    }

    /// Returns another handle to the same memory.
    pub fn shadow(&mut self) -> Self {
        Self { slice: self.slice }
    }

    /// Returns a handle restricted to `start..end`.
    ///
    /// # Panics
    /// Panics if the range is out of bounds.
    pub fn slice(&self, start: usize, end: usize) -> Self {
        Self {
            slice: &self.slice[start..end],
        }
    }

    /// Writes `value` at index `i`. The previous value is overwritten without
    /// being dropped.
    ///
    /// # Panics
    /// Panics if `i` is out of bounds.
    pub fn set(&mut self, i: usize, value: T) {
        let ptr = self.slice[i].get();
        // SAFETY: `ptr` points at an in-bounds element (the indexing above is
        // checked); exclusive access to element `i` is the caller's duty.
        unsafe {
            ptr::write(ptr, value);
        }
    }

    /// Copies `src` into `self[i..i + src.len()]`.
    ///
    /// # Panics
    /// Panics if the destination range does not fit inside this slice.
    pub fn copy_from_slice(&mut self, i: usize, src: &[T])
    where
        T: Copy,
    {
        // Check the *whole* destination range, not only its first element;
        // otherwise a long `src` would be written past the end of the buffer.
        let dst = &self.slice[i..i + src.len()];
        // Derive the pointer from the range (not a single cell) so it is valid
        // for all `src.len()` elements.
        let dst_ptr = UnsafeCell::raw_get(dst.as_ptr());
        // SAFETY: `dst` spans exactly `src.len()` in-bounds elements; `src` is a
        // shared borrow and cannot alias memory owned by this handle.
        unsafe {
            ptr::copy_nonoverlapping(src.as_ptr(), dst_ptr, src.len());
        }
    }
}
90
91pub struct MmapArray1<T: Element>(Mmap, usize, PhantomData<T>);
92impl<T: Element> MmapArray1<T> {
93    /// # Safety
94    ///
95    /// The use of `mmap` is unsafe, see the documentation of [`MmapOptions`] for more details.
96    pub unsafe fn new(path: &str) -> Result<Self> {
97        let file = File::open(path)?;
98        let mmap = unsafe { MmapOptions::new().map(&file)? };
99        let len = mmap.len() / mem::size_of::<T>();
100        Ok(Self(mmap, len, PhantomData))
101    }
102
103    pub fn len(&self) -> usize {
104        self.1
105    }
106    pub fn is_empty(&self) -> bool {
107        self.1 == 0
108    }
109
110    /// # Safety
111    ///
112    /// The use of [`slice::from_raw_parts`] is unsafe, see its documentation for more details.
113    pub unsafe fn as_slice(&self) -> &[T] {
114        slice::from_raw_parts(self.0.as_ptr() as *const T, self.1)
115    }
116
117    /// # Safety
118    ///
119    /// The use of [`ArrayView1::from_shape_ptr`] is unsafe, see its documentation for more details.
120    pub unsafe fn as_array_view(&self) -> ArrayView1<T> {
121        ArrayView1::from_shape_ptr((self.1,), self.0.as_ptr() as *const T)
122    }
123}
124
125// float ops
126
127pub trait AFloat:
128    Float
129    + AddAssign
130    + SubAssign
131    + MulAssign
132    + FromPrimitive
133    + ScalarOperand
134    + Send
135    + Sync
136    + Debug
137    + Display
138{
139}
140impl<T> AFloat for T where
141    T: Float
142        + AddAssign
143        + SubAssign
144        + MulAssign
145        + FromPrimitive
146        + ScalarOperand
147        + Send
148        + Sync
149        + Debug
150        + Display
151{
152}
153
154// simd ops
155
/// Number of independent partial sums kept by the `simd_*_reduce!` macros.
///
/// Accumulating into a fixed-size array is intended to let the compiler
/// auto-vectorise the inner loop; the partials are folded together at the end.
const LANES: usize = 16;

/// Sums `$func(x)` over every element `x` of the slice `$a`.
///
/// The caller must have a float-like type `T` in scope (providing `T::zero()`
/// and `+=`); it is the accumulator type. `$a_dtype` is the element type of
/// `$a` and defaults to `T`.
macro_rules! simd_unary_reduce {
    ($a:expr, $a_dtype:ty, $func:expr) => {{
        let chunks = $a.chunks_exact(LANES);
        let remainder = chunks.remainder();

        let sum = chunks.fold([T::zero(); LANES], |mut acc, chunk| {
            let chunk: [$a_dtype; LANES] = chunk.try_into().unwrap();
            (0..LANES).for_each(|i| acc[i] += $func(chunk[i]));
            acc
        });

        let mut reduced = T::zero();
        sum.iter().for_each(|&x| reduced += x);
        remainder.iter().for_each(|&x| reduced += $func(x));
        reduced
    }};
    ($a:expr, $func:expr) => {{
        simd_unary_reduce!($a, T, $func)
    }};
}
/// Sums `$func(a[i], b[i])` over the indices shared by `$a` and `$b`.
///
/// Elements are paired strictly by index; if the lengths differ, only the
/// common prefix is used (like `zip`). The accumulator is the caller's `T`;
/// `$b_dtype` is the element type of `$b` and defaults to `T`.
macro_rules! simd_binary_reduce {
    ($a:expr, $b:expr, $b_dtype:ty, $func:expr) => {{
        // Truncate before chunking: chunking inputs of different lengths
        // independently would pair `a`'s remainder with a different stretch
        // of `b`.
        let (a_full, b_full) = ($a, $b);
        let common = a_full.len().min(b_full.len());
        let a_chunks = a_full[..common].chunks_exact(LANES);
        let b_chunks = b_full[..common].chunks_exact(LANES);
        let remainder_a = a_chunks.remainder();
        let remainder_b = b_chunks.remainder();
        let zip_chunks = zip(a_chunks, b_chunks);

        let sum = zip_chunks.fold([T::zero(); LANES], |mut acc, (a_chunk, b_chunk)| {
            let a_chunk: [T; LANES] = a_chunk.try_into().unwrap();
            let b_chunk: [$b_dtype; LANES] = b_chunk.try_into().unwrap();
            (0..LANES).for_each(|i| acc[i] += $func(a_chunk[i], b_chunk[i]));
            acc
        });

        let mut reduced = T::zero();
        sum.iter().for_each(|&x| reduced += x);
        zip(remainder_a, remainder_b).for_each(|(&x, &y)| reduced += $func(x, y));

        reduced
    }};
    ($a:expr, $b:expr, $func:expr) => {{
        simd_binary_reduce!($a, $b, T, $func)
    }};
}
/// Sums `$func(a[i], b[i], c[i])` over the indices shared by all three slices.
///
/// Elements are paired strictly by index over the common prefix. The
/// accumulator is the caller's `T`; `$c_dtype` is the element type of `$c`
/// and defaults to `T`.
macro_rules! simd_ternary_reduce {
    ($a:expr, $b:expr, $c:expr, $c_dtype:ty, $func:expr) => {{
        // Same common-prefix truncation as `simd_binary_reduce!`.
        let (a_full, b_full, c_full) = ($a, $b, $c);
        let common = a_full.len().min(b_full.len()).min(c_full.len());
        let a_chunks = a_full[..common].chunks_exact(LANES);
        let b_chunks = b_full[..common].chunks_exact(LANES);
        let c_chunks = c_full[..common].chunks_exact(LANES);
        let remainder_a = a_chunks.remainder();
        let remainder_b = b_chunks.remainder();
        let remainder_c = c_chunks.remainder();
        let zip_chunks = izip!(a_chunks, b_chunks, c_chunks);

        let sum = zip_chunks.fold(
            [T::zero(); LANES],
            |mut acc, (a_chunk, b_chunk, c_chunk)| {
                let a_chunk: [T; LANES] = a_chunk.try_into().unwrap();
                let b_chunk: [T; LANES] = b_chunk.try_into().unwrap();
                let c_chunk: [$c_dtype; LANES] = c_chunk.try_into().unwrap();
                (0..LANES).for_each(|i| acc[i] += $func(a_chunk[i], b_chunk[i], c_chunk[i]));
                acc
            },
        );

        let mut reduced = T::zero();
        sum.iter().for_each(|&x| reduced += x);
        izip!(remainder_a, remainder_b, remainder_c).for_each(|(&x, &y, &z)| {
            reduced += $func(x, y, z);
        });

        reduced
    }};
    ($a:expr, $b:expr, $c:expr, $func:expr) => {{
        simd_ternary_reduce!($a, $b, $c, T, $func)
    }};
}
236
237pub fn simd_sum<T: AFloat>(a: &[T]) -> T {
238    simd_unary_reduce!(a, |x| x)
239}
240pub fn simd_mean<T: AFloat>(a: &[T]) -> T {
241    simd_sum(a) / T::from_usize(a.len()).unwrap()
242}
243pub fn simd_nanmean<T: AFloat>(a: &[T]) -> T {
244    let sum = simd_unary_reduce!(a, |x: T| if x.is_nan() { T::zero() } else { x });
245    let num = simd_unary_reduce!(a, |x: T| if x.is_nan() { T::zero() } else { T::one() });
246    sum / num
247}
248pub fn simd_masked_mean<T: AFloat>(a: &[T], valid_mask: &[bool]) -> T {
249    let sum = simd_binary_reduce!(a, valid_mask, bool, |x, y| if y { x } else { T::zero() });
250    let num = simd_unary_reduce!(valid_mask, bool, |x| if x { T::one() } else { T::zero() });
251    sum / num
252}
253pub fn simd_subtract<T: AFloat>(a: &[T], n: T) -> Vec<T> {
254    a.iter().map(|&x| x - n).collect()
255}
256pub fn simd_dot<T: AFloat>(a: &[T], b: &[T]) -> T {
257    simd_binary_reduce!(a, b, |x, y| x * y)
258}
259pub fn simd_inner<T: AFloat>(a: &[T]) -> T {
260    simd_unary_reduce!(a, |x| x * x)
261}
262
263// ops
264
265#[inline]
266fn get_valid_indices<T: AFloat>(a: ArrayView1<T>, b: ArrayView1<T>) -> Vec<usize> {
267    zip(a.iter(), b.iter())
268        .enumerate()
269        .filter_map(|(i, (&x, &y))| {
270            if x.is_nan() || y.is_nan() {
271                None
272            } else {
273                Some(i)
274            }
275        })
276        .collect()
277}
278#[inline]
279pub fn to_valid_indices(valid_mask: ArrayView1<bool>) -> Vec<usize> {
280    valid_mask
281        .iter()
282        .enumerate()
283        .filter_map(|(i, &valid)| if valid { Some(i) } else { None })
284        .collect()
285}
286
287#[inline]
288/// this function will put `NaN` at the end
289fn sorted<T: AFloat>(a: &[T]) -> Vec<&T> {
290    a.iter()
291        .sorted_by(|a, b| {
292            if a.is_nan() {
293                if b.is_nan() {
294                    Ordering::Equal
295                } else {
296                    Ordering::Greater
297                }
298            } else if b.is_nan() {
299                Ordering::Less
300            } else {
301                a.partial_cmp(b).unwrap()
302            }
303        })
304        .collect_vec()
305}
306#[inline]
307fn sorted_quantile<T: AFloat>(a: &[&T], q: T) -> T {
308    if a.is_empty() {
309        return T::nan();
310    }
311    let n = a.len() - 1;
312    let q = q * T::from_f64(n as f64).unwrap();
313    let i = q.floor().to_usize().unwrap();
314    if i == n {
315        return *a[n];
316    }
317    let q = q - T::from_usize(i).unwrap();
318    *a[i] * (T::one() - q) + *a[i + 1] * q
319}
320#[inline]
321fn sorted_median<T: AFloat>(a: &[&T]) -> T {
322    sorted_quantile(a, T::from_f64(0.5).unwrap())
323}
324
325#[inline]
326fn solve_2d<T: AFloat>(x: ArrayView2<T>, y: ArrayView1<T>) -> (T, T) {
327    let xtx = x.t().dot(&x);
328    let xty = x.t().dot(&y);
329    let xtx = xtx.into_raw_vec();
330    let (a, b, c, d) = (xtx[0], xtx[1], xtx[2], xtx[3]);
331    let xtx_inv = Array2::from_shape_vec((2, 2), vec![d, -b, -c, a]).unwrap();
332    let solution = xtx_inv.dot(&xty);
333    let solution = solution / (a * d - b * c).max(T::epsilon());
334    (solution[0], solution[1])
335}
336
337fn simd_corr<T: AFloat>(a: &[T], b: &[T]) -> T {
338    let a_mean = simd_mean(a);
339    let b_mean = simd_mean(b);
340    let a = simd_subtract(a, a_mean);
341    let b = simd_subtract(b, b_mean);
342    let a = a.as_slice();
343    let b = b.as_slice();
344    let cov = simd_dot(a, b);
345    let var1 = simd_inner(a);
346    let var2 = simd_inner(b);
347    cov / (var1.sqrt() * var2.sqrt())
348}
349fn simd_nancorr<T: AFloat>(a: &[T], b: &[T]) -> T {
350    let num = simd_binary_reduce!(a, b, |x: T, y: T| if x.is_nan() || y.is_nan() {
351        T::zero()
352    } else {
353        T::one()
354    });
355    if num == T::zero() || num == T::one() {
356        return T::nan();
357    }
358    let a_sum = simd_binary_reduce!(a, b, |x: T, y: T| if x.is_nan() || y.is_nan() {
359        T::zero()
360    } else {
361        x
362    });
363    let b_sum = simd_binary_reduce!(a, b, |x: T, y: T| if x.is_nan() || y.is_nan() {
364        T::zero()
365    } else {
366        y
367    });
368    let a_mean = a_sum / num;
369    let b_mean = b_sum / num;
370    let a = simd_subtract(a, a_mean);
371    let b = simd_subtract(b, b_mean);
372    let a = a.as_slice();
373    let b = b.as_slice();
374    let cov = simd_binary_reduce!(a, b, |x: T, y: T| if x.is_nan() || y.is_nan() {
375        T::zero()
376    } else {
377        x * y
378    });
379    let var1 = simd_binary_reduce!(a, b, |x: T, y: T| if x.is_nan() || y.is_nan() {
380        T::zero()
381    } else {
382        x * x
383    });
384    let var2 = simd_binary_reduce!(a, b, |x: T, y: T| if x.is_nan() || y.is_nan() {
385        T::zero()
386    } else {
387        y * y
388    });
389    cov / (var1.sqrt() * var2.sqrt())
390}
391fn simd_masked_corr<T: AFloat>(a: &[T], b: &[T], valid_mask: &[bool]) -> T {
392    let num = simd_unary_reduce!(valid_mask, bool, |x| if x { T::one() } else { T::zero() });
393    if num == T::zero() || num == T::one() {
394        return T::nan();
395    }
396    let a_sum = simd_binary_reduce!(a, valid_mask, bool, |x, y| if y { x } else { T::zero() });
397    let b_sum = simd_binary_reduce!(b, valid_mask, bool, |x, y| if y { x } else { T::zero() });
398    let a_mean = a_sum / num;
399    let b_mean = b_sum / num;
400    let a = simd_subtract(a, a_mean);
401    let b = simd_subtract(b, b_mean);
402    let a = a.as_slice();
403    let b = b.as_slice();
404    let cov = simd_ternary_reduce!(a, b, valid_mask, bool, |x, y, z| if z {
405        x * y
406    } else {
407        T::zero()
408    });
409    let var1 = simd_binary_reduce!(a, valid_mask, bool, |x, y| if y {
410        x * x
411    } else {
412        T::zero()
413    });
414    let var2 = simd_binary_reduce!(b, valid_mask, bool, |x, y| if y {
415        x * x
416    } else {
417        T::zero()
418    });
419    cov / (var1.sqrt() * var2.sqrt())
420}
421
422#[inline]
423fn coeff_with<T: AFloat>(
424    x: ArrayView1<T>,
425    y: ArrayView1<T>,
426    valid_indices: Vec<usize>,
427    q: Option<T>,
428) -> (T, T) {
429    if valid_indices.is_empty() {
430        return (T::nan(), T::nan());
431    }
432    let x = x.select(Axis(0), &valid_indices);
433    let mut y = y.select(Axis(0), &valid_indices);
434    let x_sorted = sorted(x.as_slice().unwrap());
435    let x_med = sorted_median(&x_sorted);
436    let x_mad = x_sorted.iter().map(|&x| (*x - x_med).abs()).collect_vec();
437    let x_mad = sorted_median(&sorted(&x_mad));
438    let hundred = T::from_f64(100.0).unwrap();
439    let x_floor = x_med - hundred * x_mad;
440    let x_ceil = x_med + hundred * x_mad;
441    let x = Array1::from_iter(x.iter().map(|&x| x.max(x_floor).min(x_ceil)));
442    let x_mean = x.mean().unwrap();
443    let x_std = x.std(T::zero()).max(T::epsilon());
444    let mut x = (x - x_mean) / x_std;
445    if let Some(q) = q {
446        if q > T::zero() {
447            let x_sorted = sorted(x.as_slice().unwrap());
448            let q_floor = sorted_quantile(&x_sorted, q);
449            let q_ceil = sorted_quantile(&x_sorted, T::one() - q);
450            let picked_indices: Vec<usize> = x
451                .iter()
452                .enumerate()
453                .filter_map(|(i, &x)| {
454                    if x <= q_floor || x >= q_ceil {
455                        Some(i)
456                    } else {
457                        None
458                    }
459                })
460                .collect();
461            x = x.select(Axis(0), &picked_indices);
462            y = y.select(Axis(0), &picked_indices);
463        }
464    }
465    let x = stack![Axis(1), x, Array1::ones(x.len())];
466    solve_2d(x.view(), y.view())
467}
468fn coeff<T: AFloat>(x: ArrayView1<T>, y: ArrayView1<T>, q: Option<T>) -> (T, T) {
469    coeff_with(x, y, get_valid_indices(x, y), q)
470}
471fn masked_coeff<T: AFloat>(
472    x: ArrayView1<T>,
473    y: ArrayView1<T>,
474    valid_mask: ArrayView1<bool>,
475    q: Option<T>,
476) -> (T, T) {
477    coeff_with(x, y, to_valid_indices(valid_mask), q)
478}
479
480// macros
481
/// Evaluates `$func` on every item of `$iter` and stores the `i`-th result at
/// index `i` of the `UnsafeSlice` `$slice`.
///
/// With `$num_threads <= 1` everything runs on the current thread. Otherwise a
/// fresh rayon pool with `$num_threads` threads is built and every item is
/// spawned as its own task inside `pool.scope` (rayon's `scope` waits for the
/// spawned tasks before returning).
macro_rules! parallel_apply {
    ($func:expr, $iter:expr, $slice:expr, $num_threads:expr) => {{
        if $num_threads > 1 {
            let pool = rayon::ThreadPoolBuilder::new()
                .num_threads($num_threads)
                .build()
                .unwrap();
            pool.scope(|scope| {
                for (i, args) in $iter.enumerate() {
                    scope.spawn(move |_| $slice.set(i, $func(args)));
                }
            });
        } else {
            for (i, args) in $iter.enumerate() {
                $slice.set(i, $func(args));
            }
        }
    }};
}
501
502// axis1 wrappers
503
504pub fn sum_axis1<T: AFloat>(a: &ArrayView2<T>, num_threads: usize) -> Vec<T> {
505    let mut res: Vec<T> = vec![T::zero(); a.nrows()];
506    let mut slice = UnsafeSlice::new(&mut res);
507    parallel_apply!(
508        |row: ArrayView1<T>| simd_sum(row.as_slice().unwrap()),
509        a.rows().into_iter(),
510        slice,
511        num_threads
512    );
513    res
514}
515pub fn mean_axis1<T: AFloat>(a: &ArrayView2<T>, num_threads: usize) -> Vec<T> {
516    let mut res: Vec<T> = vec![T::zero(); a.nrows()];
517    let mut slice = UnsafeSlice::new(&mut res);
518    parallel_apply!(
519        |row: ArrayView1<T>| simd_mean(row.as_slice().unwrap()),
520        a.rows().into_iter(),
521        slice,
522        num_threads
523    );
524    res
525}
526pub fn nanmean_axis1<T: AFloat>(a: &ArrayView2<T>, num_threads: usize) -> Vec<T> {
527    let mut res: Vec<T> = vec![T::zero(); a.nrows()];
528    let mut slice = UnsafeSlice::new(&mut res);
529    parallel_apply!(
530        |row: ArrayView1<T>| simd_nanmean(row.as_slice().unwrap()),
531        a.rows().into_iter(),
532        slice,
533        num_threads
534    );
535    res
536}
537pub fn masked_mean_axis1<T: AFloat>(
538    a: &ArrayView2<T>,
539    valid_mask: &ArrayView2<bool>,
540    num_threads: usize,
541) -> Vec<T> {
542    let mut res: Vec<T> = vec![T::zero(); a.nrows()];
543    let mut slice = UnsafeSlice::new(&mut res);
544    parallel_apply!(
545        |(row, valid_mask): (ArrayView1<T>, ArrayView1<bool>)| simd_masked_mean(
546            row.as_slice().unwrap(),
547            valid_mask.as_slice().unwrap()
548        ),
549        zip(a.rows(), valid_mask.rows()),
550        slice,
551        num_threads
552    );
553    res
554}
555
556pub fn corr_axis1<T: AFloat>(a: &ArrayView2<T>, b: &ArrayView2<T>, num_threads: usize) -> Vec<T> {
557    let mut res: Vec<T> = vec![T::zero(); a.nrows()];
558    let mut slice = UnsafeSlice::new(&mut res);
559    parallel_apply!(
560        |(a, b): (ArrayView1<T>, ArrayView1<T>)| simd_corr(
561            a.as_slice().unwrap(),
562            b.as_slice().unwrap()
563        ),
564        zip(a.rows(), b.rows()),
565        slice,
566        num_threads
567    );
568    res
569}
570pub fn nancorr_axis1<T: AFloat>(
571    a: &ArrayView2<T>,
572    b: &ArrayView2<T>,
573    num_threads: usize,
574) -> Vec<T> {
575    let mut res: Vec<T> = vec![T::zero(); a.nrows()];
576    let mut slice = UnsafeSlice::new(&mut res);
577    parallel_apply!(
578        |(a, b): (ArrayView1<T>, ArrayView1<T>)| simd_nancorr(
579            a.as_slice().unwrap(),
580            b.as_slice().unwrap()
581        ),
582        zip(a.rows(), b.rows()),
583        slice,
584        num_threads
585    );
586    res
587}
588pub fn masked_corr_axis1<T: AFloat>(
589    a: &ArrayView2<T>,
590    b: &ArrayView2<T>,
591    valid_mask: &ArrayView2<bool>,
592    num_threads: usize,
593) -> Vec<T> {
594    let mut res: Vec<T> = vec![T::zero(); a.nrows()];
595    let mut slice = UnsafeSlice::new(&mut res);
596    parallel_apply!(
597        |(a, b, valid_mask): (ArrayView1<T>, ArrayView1<T>, ArrayView1<bool>)| simd_masked_corr(
598            a.as_slice().unwrap(),
599            b.as_slice().unwrap(),
600            valid_mask.as_slice().unwrap()
601        ),
602        izip!(a.rows(), b.rows(), valid_mask.rows()),
603        slice,
604        num_threads
605    );
606    res
607}
608
609pub fn coeff_axis1<T: AFloat>(
610    x: &ArrayView2<T>,
611    y: &ArrayView2<T>,
612    q: Option<T>,
613    num_threads: usize,
614) -> (Vec<T>, Vec<T>) {
615    let mut ws: Vec<T> = vec![T::zero(); x.nrows()];
616    let mut bs: Vec<T> = vec![T::zero(); x.nrows()];
617    let mut slice0 = UnsafeSlice::new(&mut ws);
618    let mut slice1 = UnsafeSlice::new(&mut bs);
619    if num_threads <= 1 {
620        izip!(x.rows(), y.rows())
621            .enumerate()
622            .for_each(|(i, (x, y))| {
623                let (w, b) = coeff(x, y, q);
624                slice0.set(i, w);
625                slice1.set(i, b);
626            });
627    } else {
628        let pool = rayon::ThreadPoolBuilder::new()
629            .num_threads(num_threads)
630            .build()
631            .unwrap();
632        pool.scope(move |s| {
633            izip!(x.rows(), y.rows())
634                .enumerate()
635                .for_each(|(i, (x, y))| {
636                    s.spawn(move |_| {
637                        let (w, b) = coeff(x, y, q);
638                        slice0.set(i, w);
639                        slice1.set(i, b);
640                    });
641                });
642        });
643    }
644    (ws, bs)
645}
646pub fn masked_coeff_axis1<T: AFloat>(
647    x: &ArrayView2<T>,
648    y: &ArrayView2<T>,
649    valid_mask: &ArrayView2<bool>,
650    q: Option<T>,
651    num_threads: usize,
652) -> (Vec<T>, Vec<T>) {
653    let mut ws: Vec<T> = vec![T::zero(); x.nrows()];
654    let mut bs: Vec<T> = vec![T::zero(); x.nrows()];
655    let mut slice0 = UnsafeSlice::new(&mut ws);
656    let mut slice1 = UnsafeSlice::new(&mut bs);
657    if num_threads <= 1 {
658        izip!(x.rows(), y.rows(), valid_mask.rows())
659            .enumerate()
660            .for_each(|(i, (x, y, valid_mask))| {
661                let (w, b) = masked_coeff(x, y, valid_mask, q);
662                slice0.set(i, w);
663                slice1.set(i, b);
664            });
665    } else {
666        let pool = rayon::ThreadPoolBuilder::new()
667            .num_threads(num_threads)
668            .build()
669            .unwrap();
670        pool.scope(move |s| {
671            izip!(x.rows(), y.rows(), valid_mask.rows())
672                .enumerate()
673                .for_each(|(i, (x, y, valid_mask))| {
674                    s.spawn(move |_| {
675                        let (w, b) = masked_coeff(x, y, valid_mask, q);
676                        slice0.set(i, w);
677                        slice1.set(i, b);
678                    });
679                });
680        });
681    }
682    (ws, bs)
683}
684
685// misc
686
687pub fn unique(arr: &[i64]) -> (Array1<i64>, Array1<i64>) {
688    let mut counts = HashMap::new();
689
690    for &value in arr.iter() {
691        *counts.entry(value).or_insert(0) += 1;
692    }
693
694    let mut unique_values: Vec<i64> = counts.keys().cloned().collect();
695    unique_values.sort();
696
697    let counts: Vec<i64> = unique_values.iter().map(|&value| counts[&value]).collect();
698
699    (Array1::from(unique_values), Array1::from(counts))
700}
701
702pub fn searchsorted<T: Ord>(arr: &ArrayView1<T>, value: &T) -> usize {
703    arr.as_slice()
704        .unwrap()
705        .binary_search(value)
706        .unwrap_or_else(|x| x)
707}
708
709pub fn batch_searchsorted<T: Ord>(arr: &ArrayView1<T>, values: &ArrayView1<T>) -> Vec<usize> {
710    values
711        .iter()
712        .map(|value| searchsorted(arr, value))
713        .collect()
714}
715
716const CONCAT_GROUP_LIMIT: usize = 4 * 239 * 5000;
717type ConcatTask<'a, 'b, D> = (Vec<usize>, Vec<ArrayView2<'a, D>>, UnsafeSlice<'b, D>);
718#[inline]
719fn fill_concat<D: Copy>((offsets, arrays, mut out): ConcatTask<D>) {
720    offsets.iter().enumerate().for_each(|(i, &offset)| {
721        out.copy_from_slice(offset, arrays[i].as_slice().unwrap());
722    });
723}
724pub fn fast_concat_2d_axis0<D: Copy + Send + Sync>(
725    arrays: Vec<ArrayView2<D>>,
726    num_rows: Vec<usize>,
727    num_columns: usize,
728    limit_multiplier: usize,
729    mut out: UnsafeSlice<D>,
730) {
731    let mut cumsum: usize = 0;
732    let mut offsets: Vec<usize> = vec![0; num_rows.len()];
733    for i in 1..num_rows.len() {
734        cumsum += num_rows[i - 1];
735        offsets[i] = cumsum * num_columns;
736    }
737
738    let bumped_limit = CONCAT_GROUP_LIMIT * 16;
739    let total_bytes = offsets.last().unwrap() + num_rows.last().unwrap() * num_columns;
740    let (mut group_limit, mut tasks_divisor) = if total_bytes <= bumped_limit {
741        (CONCAT_GROUP_LIMIT, 8)
742    } else {
743        (bumped_limit, 1)
744    };
745    group_limit *= limit_multiplier;
746
747    let prior_num_tasks = total_bytes.div_ceil(group_limit);
748    let prior_num_threads = prior_num_tasks / tasks_divisor;
749    if prior_num_threads > 1 {
750        group_limit = total_bytes.div_ceil(prior_num_threads);
751        tasks_divisor = 1;
752    }
753
754    let nbytes = mem::size_of::<D>();
755
756    let mut tasks: Vec<ConcatTask<D>> = Vec::new();
757    let mut current_tasks: Option<ConcatTask<D>> = Some((Vec::new(), Vec::new(), out.shadow()));
758    let mut nbytes_cumsum = 0;
759    izip!(num_rows.iter(), offsets.into_iter(), arrays.into_iter()).for_each(
760        |(&num_row, offset, array)| {
761            nbytes_cumsum += nbytes * num_row * num_columns;
762            if let Some(ref mut current_tasks) = current_tasks {
763                current_tasks.0.push(offset);
764                current_tasks.1.push(array);
765            }
766            if nbytes_cumsum >= group_limit {
767                nbytes_cumsum = 0;
768                if let Some(current_tasks) = current_tasks.take() {
769                    tasks.push(current_tasks);
770                }
771                current_tasks = Some((Vec::new(), Vec::new(), out.shadow()));
772            }
773        },
774    );
775    if let Some(current_tasks) = current_tasks.take() {
776        if !current_tasks.0.is_empty() {
777            tasks.push(current_tasks);
778        }
779    }
780
781    let max_threads = available_parallelism()
782        .expect("failed to get available parallelism")
783        .get();
784    let num_threads = (tasks.len() / tasks_divisor).min(max_threads * 8).min(512);
785    if num_threads <= 1 {
786        tasks.into_iter().for_each(fill_concat);
787    } else {
788        let pool = rayon::ThreadPoolBuilder::new()
789            .num_threads(num_threads)
790            .build()
791            .unwrap();
792
793        pool.scope(move |s| {
794            tasks.into_iter().for_each(|task| {
795                s.spawn(move |_| fill_concat(task));
796            });
797        });
798    }
799}
800
#[cfg(test)]
mod tests {
    use super::*;
    use crate::toolkit::convert::to_bytes;
    use std::io::Write;
    use tempfile::tempdir;

    // Asserts that `a` and `b` have equal length and are element-wise close:
    // `|a[i] - b[i]| <= atol + rtol * |b[i]|` with `atol = rtol = 1e-6`.
    fn assert_allclose<T: AFloat>(a: &[T], b: &[T]) {
        // `zip` below stops at the shorter slice, so without this check a
        // missing or extra element would go unnoticed.
        assert_eq!(
            a.len(),
            b.len(),
            "length mismatch - a: {:?}, b: {:?}",
            a,
            b,
        );
        let atol = T::from_f64(1e-6).unwrap();
        let rtol = T::from_f64(1e-6).unwrap();
        a.iter().zip(b.iter()).for_each(|(&x, &y)| {
            assert!(
                (x - y).abs() <= atol + rtol * y.abs(),
                "not close - a: {:?}, b: {:?}",
                a,
                b,
            );
        });
    }

    // Writes an f32 array to a temp file and reads it back through the mmap
    // wrapper, both as a slice and as an array view.
    #[test]
    fn test_mmap() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.cfy");
        let array = Array1::<f32>::from_shape_vec(3, vec![1., 2., 3.]).unwrap();
        let bytes = unsafe { to_bytes(array.as_slice().unwrap()) };
        let mut file = File::create(&file_path).unwrap();
        file.write_all(bytes).unwrap();
        let file_path = file_path.to_str().unwrap();
        let mmap_array = unsafe { MmapArray1::<f32>::new(file_path).unwrap() };
        assert_eq!(array.len(), mmap_array.len());
        assert_allclose(array.as_slice().unwrap(), unsafe { mmap_array.as_slice() });
        assert_allclose(
            array.as_slice().unwrap(),
            unsafe { mmap_array.as_array_view() }.as_slice().unwrap(),
        );
    }

    // Concatenates a 1x3 and a 2x3 array into a 3x3 row-major buffer.
    macro_rules! test_fast_concat_2d_axis0 {
        ($dtype:ty) => {
            let array_2d_u = ArrayView2::<$dtype>::from_shape((1, 3), &[1., 2., 3.]).unwrap();
            let array_2d_l =
                ArrayView2::<$dtype>::from_shape((2, 3), &[4., 5., 6., 7., 8., 9.]).unwrap();
            let arrays = vec![array_2d_u, array_2d_l];
            let mut out: Vec<$dtype> = vec![0.; 3 * 3];
            let out_slice = UnsafeSlice::new(&mut out);
            fast_concat_2d_axis0(arrays, vec![1, 2], 3, 1, out_slice);
            assert_eq!(out.as_slice(), &[1., 2., 3., 4., 5., 6., 7., 8., 9.]);
        };
    }

    // Row-wise `nanmean` of a NaN-free 2x3 array, serially and with 2 threads.
    macro_rules! test_mean_axis1 {
        ($dtype:ty) => {
            let array =
                ArrayView2::<$dtype>::from_shape((2, 3), &[1., 2., 3., 4., 5., 6.]).unwrap();
            let out = nanmean_axis1(&array, 1);
            assert_allclose(out.as_slice(), &[2., 5.]);
            let out = nanmean_axis1(&array, 2);
            assert_allclose(out.as_slice(), &[2., 5.]);
        };
    }

    // Row-wise `nancorr` of an array against itself shifted by 1 (expected 1).
    macro_rules! test_corr_axis1 {
        ($dtype:ty) => {
            let array =
                ArrayView2::<$dtype>::from_shape((2, 3), &[1., 2., 3., 4., 5., 6.]).unwrap();
            let out = nancorr_axis1(&array, &(&array + 1.).view(), 1);
            assert_allclose(out.as_slice(), &[1., 1.]);
            let out = nancorr_axis1(&array, &(&array + 1.).view(), 2);
            assert_allclose(out.as_slice(), &[1., 1.]);
        };
    }

    #[test]
    fn test_fast_concat_2d_axis0_f32() {
        test_fast_concat_2d_axis0!(f32);
    }
    #[test]
    fn test_fast_concat_2d_axis0_f64() {
        test_fast_concat_2d_axis0!(f64);
    }

    #[test]
    fn test_mean_axis1_f32() {
        test_mean_axis1!(f32);
    }
    #[test]
    fn test_mean_axis1_f64() {
        test_mean_axis1!(f64);
    }

    #[test]
    fn test_corr_axis1_f32() {
        test_corr_axis1!(f32);
    }
    #[test]
    fn test_corr_axis1_f64() {
        test_corr_axis1!(f64);
    }

    // `y = 2x` row-wise. Because `x` is standardised (population std of three
    // evenly spaced values), the fitted weight is `2 * sqrt(2/3)` and the bias
    // is `2 * mean(x)`.
    #[test]
    fn test_coeff_axis1() {
        let x = ArrayView2::<f64>::from_shape((2, 3), &[2., 1., 3., 6., 4., 5.]).unwrap();
        let y = ArrayView2::<f64>::from_shape((2, 3), &[4., 2., 6., 12., 8., 10.]).unwrap();
        let scale = 2. * (2. / 3.).sqrt();
        let (ws, bs) = coeff_axis1(&x, &y, None, 1);
        assert_allclose(ws.as_slice(), &[scale, scale]);
        assert_allclose(bs.as_slice(), &[4., 10.]);
        let (ws, bs) = coeff_axis1(&x, &y, None, 2);
        assert_allclose(ws.as_slice(), &[scale, scale]);
        assert_allclose(bs.as_slice(), &[4., 10.]);
    }

    // Insertion points for values below, inside, between and above the data.
    #[test]
    fn test_searchsorted() {
        let array = ArrayView1::<i64>::from_shape(5, &[1, 2, 3, 5, 6]).unwrap();
        assert_eq!(searchsorted(&array, &0), 0);
        assert_eq!(searchsorted(&array, &1), 0);
        assert_eq!(searchsorted(&array, &3), 2);
        assert_eq!(searchsorted(&array, &4), 3);
        assert_eq!(searchsorted(&array, &5), 3);
        assert_eq!(searchsorted(&array, &6), 4);
        assert_eq!(searchsorted(&array, &7), 5);
        assert_eq!(batch_searchsorted(&array, &array), vec![0, 1, 2, 3, 4]);
    }
}