1use anyhow::{Ok, Result};
2use core::{mem, ptr, slice};
3use itertools::{izip, Itertools};
4use memmap2::{Mmap, MmapOptions};
5use num_traits::{Float, FromPrimitive};
6use numpy::{
7 ndarray::{stack, Array1, Array2, ArrayView1, ArrayView2, Axis, ScalarOperand},
8 Element,
9};
10use std::{
11 cell::UnsafeCell,
12 cmp::Ordering,
13 collections::HashMap,
14 fmt::{Debug, Display},
15 fs::File,
16 iter::zip,
17 marker::PhantomData,
18 ops::{AddAssign, MulAssign, SubAssign},
19 thread::available_parallelism,
20};
21
22#[derive(Debug)]
23pub struct ArrayError(String);
24impl ArrayError {
25 fn new(msg: &str) -> Self {
26 Self(msg.to_string())
27 }
28 pub fn data_not_contiguous<T>() -> Result<T> {
29 Err(ArrayError::new("data is not contiguous").into())
30 }
31}
32impl std::error::Error for ArrayError {}
33impl std::fmt::Display for ArrayError {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 write!(f, "error occurred in `array` module: {}", self.0)
36 }
37}
38
/// Evaluates to the slice returned by `$array.as_slice()`.
///
/// When `as_slice()` yields `None`, the *enclosing function* returns
/// `ArrayError::data_not_contiguous()` instead, so this macro may only be used
/// inside functions whose return type accepts that value.
#[macro_export]
macro_rules! as_data_slice_or_err {
    ($array:expr) => {
        match $array.as_slice() {
            None => return $crate::toolkit::array::ArrayError::data_not_contiguous(),
            Some(contiguous) => contiguous,
        }
    };
}
48
/// A handle over a mutable slice that allows writes through shared copies.
///
/// The slice is viewed as `[UnsafeCell<T>]`, so several handles (obtained via
/// `Copy`, [`UnsafeSlice::shadow`] or [`UnsafeSlice::slice`]) can write into the
/// same buffer. No synchronisation is performed: callers are responsible for
/// making sure that concurrent writers never touch the same element.
#[derive(Copy, Clone)]
pub struct UnsafeSlice<'a, T> {
    slice: &'a [UnsafeCell<T>],
}

// SAFETY: the handle only exposes write operations on individual cells; sending
// or sharing it across threads is sound as long as callers uphold the
// "disjoint elements" contract described on the type.
unsafe impl<'a, T: Send + Sync> Send for UnsafeSlice<'a, T> {}
unsafe impl<'a, T: Send + Sync> Sync for UnsafeSlice<'a, T> {}

impl<'a, T> UnsafeSlice<'a, T> {
    /// Wraps an exclusively borrowed slice for the duration of `'a`.
    pub fn new(slice: &'a mut [T]) -> Self {
        let ptr = slice as *mut [T] as *const [UnsafeCell<T>];
        Self {
            // SAFETY: `UnsafeCell<T>` has the same in-memory representation as
            // `T`, and the exclusive borrow guarantees that nothing else refers
            // to the buffer while this view exists.
            slice: unsafe { &*ptr },
        }
    }

    /// Returns another handle to the very same cells.
    pub fn shadow(&mut self) -> Self {
        Self { slice: self.slice }
    }

    /// Returns a handle restricted to `start..end`.
    ///
    /// # Panics
    /// Panics if the range is out of bounds.
    pub fn slice(&self, start: usize, end: usize) -> Self {
        Self {
            slice: &self.slice[start..end],
        }
    }

    /// Overwrites element `i` with `value`.
    ///
    /// The write uses `ptr::write`, so the previous value is *not* dropped.
    ///
    /// # Panics
    /// Panics if `i` is out of bounds.
    pub fn set(&mut self, i: usize, value: T) {
        let ptr = self.slice[i].get();
        // SAFETY: `ptr` comes from an in-bounds cell (indexing above is checked)
        // and is therefore valid and aligned for a write of `T`.
        unsafe {
            ptr::write(ptr, value);
        }
    }

    /// Copies all of `src` into the cells starting at index `i`.
    ///
    /// `src` must not alias the destination cells.
    ///
    /// # Panics
    /// Panics if `i..i + src.len()` does not lie within the slice. An empty
    /// `src` is accepted for any `i <= len`.
    pub fn copy_from_slice(&mut self, i: usize, src: &[T])
    where
        T: Copy,
    {
        // Bounds-check the *whole* destination range, not just its first cell;
        // checking only `self.slice[i]` would allow writes past the end.
        let dst = &self.slice[i..][..src.len()];
        // SAFETY: `dst` covers exactly `src.len()` in-bounds cells, and
        // `UnsafeCell::raw_get` yields a pointer that may be written through.
        unsafe {
            ptr::copy_nonoverlapping(src.as_ptr(), UnsafeCell::raw_get(dst.as_ptr()), src.len());
        }
    }
}
90
91pub struct MmapArray1<T: Element>(Mmap, usize, PhantomData<T>);
92impl<T: Element> MmapArray1<T> {
93 pub unsafe fn new(path: &str) -> Result<Self> {
97 let file = File::open(path)?;
98 let mmap = unsafe { MmapOptions::new().map(&file)? };
99 let len = mmap.len() / mem::size_of::<T>();
100 Ok(Self(mmap, len, PhantomData))
101 }
102
103 pub fn len(&self) -> usize {
104 self.1
105 }
106 pub fn is_empty(&self) -> bool {
107 self.1 == 0
108 }
109
110 pub unsafe fn as_slice(&self) -> &[T] {
114 slice::from_raw_parts(self.0.as_ptr() as *const T, self.1)
115 }
116
117 pub unsafe fn as_array_view(&self) -> ArrayView1<T> {
121 ArrayView1::from_shape_ptr((self.1,), self.0.as_ptr() as *const T)
122 }
123}
124
125pub trait AFloat:
128 Float
129 + AddAssign
130 + SubAssign
131 + MulAssign
132 + FromPrimitive
133 + ScalarOperand
134 + Send
135 + Sync
136 + Debug
137 + Display
138{
139}
140impl<T> AFloat for T where
141 T: Float
142 + AddAssign
143 + SubAssign
144 + MulAssign
145 + FromPrimitive
146 + ScalarOperand
147 + Send
148 + Sync
149 + Debug
150 + Display
151{
152}
153
/// Number of independent accumulators used by the `simd_*_reduce!` macros;
/// inputs are consumed in `chunks_exact(LANES)` blocks, with the leftover
/// elements handled separately.
const LANES: usize = 16;
157
/// Sums `$func(element)` over a slice using `LANES` running accumulators.
///
/// Full blocks of `LANES` elements feed one accumulator per position; the
/// accumulators are then added up in order, followed by the leftover elements.
/// The two-argument form assumes the element type is the `T` in scope at the
/// call site; the three-argument form names the element type explicitly.
macro_rules! simd_unary_reduce {
    ($a:expr, $a_dtype:ty, $func:expr) => {{
        let blocks = $a.chunks_exact(LANES);
        let tail = blocks.remainder();

        let mut lanes = [T::zero(); LANES];
        for block in blocks {
            let block: [$a_dtype; LANES] = block.try_into().unwrap();
            for (lane, acc) in lanes.iter_mut().enumerate() {
                *acc += $func(block[lane]);
            }
        }

        let mut total = T::zero();
        for &partial in lanes.iter() {
            total += partial;
        }
        for &item in tail.iter() {
            total += $func(item);
        }
        total
    }};
    ($a:expr, $func:expr) => {{
        simd_unary_reduce!($a, T, $func)
    }};
}
/// Sums `$func(a_i, b_i)` over two slices walked in lock-step, using `LANES`
/// running accumulators.
///
/// The first slice is read as `T`; the second as `$b_dtype` (or `T` in the
/// three-argument form). Blocks and leftovers are each paired with `zip`, so
/// iteration stops at the shorter side.
macro_rules! simd_binary_reduce {
    ($a:expr, $b:expr, $b_dtype:ty, $func:expr) => {{
        let lhs_blocks = $a.chunks_exact(LANES);
        let rhs_blocks = $b.chunks_exact(LANES);
        let lhs_tail = lhs_blocks.remainder();
        let rhs_tail = rhs_blocks.remainder();

        let mut lanes = [T::zero(); LANES];
        for (lhs, rhs) in zip(lhs_blocks, rhs_blocks) {
            let lhs: [T; LANES] = lhs.try_into().unwrap();
            let rhs: [$b_dtype; LANES] = rhs.try_into().unwrap();
            for (lane, acc) in lanes.iter_mut().enumerate() {
                *acc += $func(lhs[lane], rhs[lane]);
            }
        }

        let mut total = T::zero();
        for &partial in lanes.iter() {
            total += partial;
        }
        for (&x, &y) in zip(lhs_tail, rhs_tail) {
            total += $func(x, y);
        }

        total
    }};
    ($a:expr, $b:expr, $func:expr) => {{
        simd_binary_reduce!($a, $b, T, $func)
    }};
}
/// Sums `$func(a_i, b_i, c_i)` over three slices walked in lock-step, using
/// `LANES` running accumulators.
///
/// The first two slices are read as `T`; the third as `$c_dtype` (or `T` in
/// the four-argument form).
macro_rules! simd_ternary_reduce {
    ($a:expr, $b:expr, $c:expr, $c_dtype:ty, $func:expr) => {{
        let a_blocks = $a.chunks_exact(LANES);
        let b_blocks = $b.chunks_exact(LANES);
        let c_blocks = $c.chunks_exact(LANES);
        let a_tail = a_blocks.remainder();
        let b_tail = b_blocks.remainder();
        let c_tail = c_blocks.remainder();

        let mut lanes = [T::zero(); LANES];
        for (a_block, b_block, c_block) in izip!(a_blocks, b_blocks, c_blocks) {
            let a_block: [T; LANES] = a_block.try_into().unwrap();
            let b_block: [T; LANES] = b_block.try_into().unwrap();
            let c_block: [$c_dtype; LANES] = c_block.try_into().unwrap();
            for (lane, acc) in lanes.iter_mut().enumerate() {
                *acc += $func(a_block[lane], b_block[lane], c_block[lane]);
            }
        }

        let mut total = T::zero();
        for &partial in lanes.iter() {
            total += partial;
        }
        for (&x, &y, &z) in izip!(a_tail, b_tail, c_tail) {
            total += $func(x, y, z);
        }

        total
    }};
    ($a:expr, $b:expr, $c:expr, $func:expr) => {{
        simd_ternary_reduce!($a, $b, $c, T, $func)
    }};
}
236
237pub fn simd_sum<T: AFloat>(a: &[T]) -> T {
238 simd_unary_reduce!(a, |x| x)
239}
240pub fn simd_mean<T: AFloat>(a: &[T]) -> T {
241 simd_sum(a) / T::from_usize(a.len()).unwrap()
242}
243pub fn simd_nanmean<T: AFloat>(a: &[T]) -> T {
244 let sum = simd_unary_reduce!(a, |x: T| if x.is_nan() { T::zero() } else { x });
245 let num = simd_unary_reduce!(a, |x: T| if x.is_nan() { T::zero() } else { T::one() });
246 sum / num
247}
248pub fn simd_masked_mean<T: AFloat>(a: &[T], valid_mask: &[bool]) -> T {
249 let sum = simd_binary_reduce!(a, valid_mask, bool, |x, y| if y { x } else { T::zero() });
250 let num = simd_unary_reduce!(valid_mask, bool, |x| if x { T::one() } else { T::zero() });
251 sum / num
252}
253pub fn simd_subtract<T: AFloat>(a: &[T], n: T) -> Vec<T> {
254 a.iter().map(|&x| x - n).collect()
255}
256pub fn simd_dot<T: AFloat>(a: &[T], b: &[T]) -> T {
257 simd_binary_reduce!(a, b, |x, y| x * y)
258}
259pub fn simd_inner<T: AFloat>(a: &[T]) -> T {
260 simd_unary_reduce!(a, |x| x * x)
261}
262
263#[inline]
266fn get_valid_indices<T: AFloat>(a: ArrayView1<T>, b: ArrayView1<T>) -> Vec<usize> {
267 zip(a.iter(), b.iter())
268 .enumerate()
269 .filter_map(|(i, (&x, &y))| {
270 if x.is_nan() || y.is_nan() {
271 None
272 } else {
273 Some(i)
274 }
275 })
276 .collect()
277}
278#[inline]
279pub fn to_valid_indices(valid_mask: ArrayView1<bool>) -> Vec<usize> {
280 valid_mask
281 .iter()
282 .enumerate()
283 .filter_map(|(i, &valid)| if valid { Some(i) } else { None })
284 .collect()
285}
286
287#[inline]
288fn sorted<T: AFloat>(a: &[T]) -> Vec<&T> {
290 a.iter()
291 .sorted_by(|a, b| {
292 if a.is_nan() {
293 if b.is_nan() {
294 Ordering::Equal
295 } else {
296 Ordering::Greater
297 }
298 } else if b.is_nan() {
299 Ordering::Less
300 } else {
301 a.partial_cmp(b).unwrap()
302 }
303 })
304 .collect_vec()
305}
306#[inline]
307fn sorted_quantile<T: AFloat>(a: &[&T], q: T) -> T {
308 if a.is_empty() {
309 return T::nan();
310 }
311 let n = a.len() - 1;
312 let q = q * T::from_f64(n as f64).unwrap();
313 let i = q.floor().to_usize().unwrap();
314 if i == n {
315 return *a[n];
316 }
317 let q = q - T::from_usize(i).unwrap();
318 *a[i] * (T::one() - q) + *a[i + 1] * q
319}
320#[inline]
321fn sorted_median<T: AFloat>(a: &[&T]) -> T {
322 sorted_quantile(a, T::from_f64(0.5).unwrap())
323}
324
325#[inline]
326fn solve_2d<T: AFloat>(x: ArrayView2<T>, y: ArrayView1<T>) -> (T, T) {
327 let xtx = x.t().dot(&x);
328 let xty = x.t().dot(&y);
329 let xtx = xtx.into_raw_vec();
330 let (a, b, c, d) = (xtx[0], xtx[1], xtx[2], xtx[3]);
331 let xtx_inv = Array2::from_shape_vec((2, 2), vec![d, -b, -c, a]).unwrap();
332 let solution = xtx_inv.dot(&xty);
333 let solution = solution / (a * d - b * c).max(T::epsilon());
334 (solution[0], solution[1])
335}
336
337fn simd_corr<T: AFloat>(a: &[T], b: &[T]) -> T {
338 let a_mean = simd_mean(a);
339 let b_mean = simd_mean(b);
340 let a = simd_subtract(a, a_mean);
341 let b = simd_subtract(b, b_mean);
342 let a = a.as_slice();
343 let b = b.as_slice();
344 let cov = simd_dot(a, b);
345 let var1 = simd_inner(a);
346 let var2 = simd_inner(b);
347 cov / (var1.sqrt() * var2.sqrt())
348}
349fn simd_nancorr<T: AFloat>(a: &[T], b: &[T]) -> T {
350 let num = simd_binary_reduce!(a, b, |x: T, y: T| if x.is_nan() || y.is_nan() {
351 T::zero()
352 } else {
353 T::one()
354 });
355 if num == T::zero() || num == T::one() {
356 return T::nan();
357 }
358 let a_sum = simd_binary_reduce!(a, b, |x: T, y: T| if x.is_nan() || y.is_nan() {
359 T::zero()
360 } else {
361 x
362 });
363 let b_sum = simd_binary_reduce!(a, b, |x: T, y: T| if x.is_nan() || y.is_nan() {
364 T::zero()
365 } else {
366 y
367 });
368 let a_mean = a_sum / num;
369 let b_mean = b_sum / num;
370 let a = simd_subtract(a, a_mean);
371 let b = simd_subtract(b, b_mean);
372 let a = a.as_slice();
373 let b = b.as_slice();
374 let cov = simd_binary_reduce!(a, b, |x: T, y: T| if x.is_nan() || y.is_nan() {
375 T::zero()
376 } else {
377 x * y
378 });
379 let var1 = simd_binary_reduce!(a, b, |x: T, y: T| if x.is_nan() || y.is_nan() {
380 T::zero()
381 } else {
382 x * x
383 });
384 let var2 = simd_binary_reduce!(a, b, |x: T, y: T| if x.is_nan() || y.is_nan() {
385 T::zero()
386 } else {
387 y * y
388 });
389 cov / (var1.sqrt() * var2.sqrt())
390}
391fn simd_masked_corr<T: AFloat>(a: &[T], b: &[T], valid_mask: &[bool]) -> T {
392 let num = simd_unary_reduce!(valid_mask, bool, |x| if x { T::one() } else { T::zero() });
393 if num == T::zero() || num == T::one() {
394 return T::nan();
395 }
396 let a_sum = simd_binary_reduce!(a, valid_mask, bool, |x, y| if y { x } else { T::zero() });
397 let b_sum = simd_binary_reduce!(b, valid_mask, bool, |x, y| if y { x } else { T::zero() });
398 let a_mean = a_sum / num;
399 let b_mean = b_sum / num;
400 let a = simd_subtract(a, a_mean);
401 let b = simd_subtract(b, b_mean);
402 let a = a.as_slice();
403 let b = b.as_slice();
404 let cov = simd_ternary_reduce!(a, b, valid_mask, bool, |x, y, z| if z {
405 x * y
406 } else {
407 T::zero()
408 });
409 let var1 = simd_binary_reduce!(a, valid_mask, bool, |x, y| if y {
410 x * x
411 } else {
412 T::zero()
413 });
414 let var2 = simd_binary_reduce!(b, valid_mask, bool, |x, y| if y {
415 x * x
416 } else {
417 T::zero()
418 });
419 cov / (var1.sqrt() * var2.sqrt())
420}
421
422#[inline]
423fn coeff_with<T: AFloat>(
424 x: ArrayView1<T>,
425 y: ArrayView1<T>,
426 valid_indices: Vec<usize>,
427 q: Option<T>,
428) -> (T, T) {
429 if valid_indices.is_empty() {
430 return (T::nan(), T::nan());
431 }
432 let x = x.select(Axis(0), &valid_indices);
433 let mut y = y.select(Axis(0), &valid_indices);
434 let x_sorted = sorted(x.as_slice().unwrap());
435 let x_med = sorted_median(&x_sorted);
436 let x_mad = x_sorted.iter().map(|&x| (*x - x_med).abs()).collect_vec();
437 let x_mad = sorted_median(&sorted(&x_mad));
438 let hundred = T::from_f64(100.0).unwrap();
439 let x_floor = x_med - hundred * x_mad;
440 let x_ceil = x_med + hundred * x_mad;
441 let x = Array1::from_iter(x.iter().map(|&x| x.max(x_floor).min(x_ceil)));
442 let x_mean = x.mean().unwrap();
443 let x_std = x.std(T::zero()).max(T::epsilon());
444 let mut x = (x - x_mean) / x_std;
445 if let Some(q) = q {
446 if q > T::zero() {
447 let x_sorted = sorted(x.as_slice().unwrap());
448 let q_floor = sorted_quantile(&x_sorted, q);
449 let q_ceil = sorted_quantile(&x_sorted, T::one() - q);
450 let picked_indices: Vec<usize> = x
451 .iter()
452 .enumerate()
453 .filter_map(|(i, &x)| {
454 if x <= q_floor || x >= q_ceil {
455 Some(i)
456 } else {
457 None
458 }
459 })
460 .collect();
461 x = x.select(Axis(0), &picked_indices);
462 y = y.select(Axis(0), &picked_indices);
463 }
464 }
465 let x = stack![Axis(1), x, Array1::ones(x.len())];
466 solve_2d(x.view(), y.view())
467}
468fn coeff<T: AFloat>(x: ArrayView1<T>, y: ArrayView1<T>, q: Option<T>) -> (T, T) {
469 coeff_with(x, y, get_valid_indices(x, y), q)
470}
471fn masked_coeff<T: AFloat>(
472 x: ArrayView1<T>,
473 y: ArrayView1<T>,
474 valid_mask: ArrayView1<bool>,
475 q: Option<T>,
476) -> (T, T) {
477 coeff_with(x, y, to_valid_indices(valid_mask), q)
478}
479
/// Applies `$func` to every item of `$iter` and stores the result for the
/// `i`-th item at index `i` of `$slice` (via its `set` method).
///
/// With `$num_threads <= 1` the work runs on the calling thread. Otherwise a
/// rayon thread pool with `$num_threads` threads is built for this invocation
/// and one task is spawned per item.
macro_rules! parallel_apply {
    ($func:expr, $iter:expr, $slice:expr, $num_threads:expr) => {{
        if $num_threads > 1 {
            let pool = rayon::ThreadPoolBuilder::new()
                .num_threads($num_threads)
                .build()
                .unwrap();
            pool.scope(|scope| {
                for (idx, item) in $iter.enumerate() {
                    scope.spawn(move |_| $slice.set(idx, $func(item)));
                }
            });
        } else {
            for (idx, item) in $iter.enumerate() {
                $slice.set(idx, $func(item));
            }
        }
    }};
}
501
502pub fn sum_axis1<T: AFloat>(a: &ArrayView2<T>, num_threads: usize) -> Vec<T> {
505 let mut res: Vec<T> = vec![T::zero(); a.nrows()];
506 let mut slice = UnsafeSlice::new(&mut res);
507 parallel_apply!(
508 |row: ArrayView1<T>| simd_sum(row.as_slice().unwrap()),
509 a.rows().into_iter(),
510 slice,
511 num_threads
512 );
513 res
514}
515pub fn mean_axis1<T: AFloat>(a: &ArrayView2<T>, num_threads: usize) -> Vec<T> {
516 let mut res: Vec<T> = vec![T::zero(); a.nrows()];
517 let mut slice = UnsafeSlice::new(&mut res);
518 parallel_apply!(
519 |row: ArrayView1<T>| simd_mean(row.as_slice().unwrap()),
520 a.rows().into_iter(),
521 slice,
522 num_threads
523 );
524 res
525}
526pub fn nanmean_axis1<T: AFloat>(a: &ArrayView2<T>, num_threads: usize) -> Vec<T> {
527 let mut res: Vec<T> = vec![T::zero(); a.nrows()];
528 let mut slice = UnsafeSlice::new(&mut res);
529 parallel_apply!(
530 |row: ArrayView1<T>| simd_nanmean(row.as_slice().unwrap()),
531 a.rows().into_iter(),
532 slice,
533 num_threads
534 );
535 res
536}
537pub fn masked_mean_axis1<T: AFloat>(
538 a: &ArrayView2<T>,
539 valid_mask: &ArrayView2<bool>,
540 num_threads: usize,
541) -> Vec<T> {
542 let mut res: Vec<T> = vec![T::zero(); a.nrows()];
543 let mut slice = UnsafeSlice::new(&mut res);
544 parallel_apply!(
545 |(row, valid_mask): (ArrayView1<T>, ArrayView1<bool>)| simd_masked_mean(
546 row.as_slice().unwrap(),
547 valid_mask.as_slice().unwrap()
548 ),
549 zip(a.rows(), valid_mask.rows()),
550 slice,
551 num_threads
552 );
553 res
554}
555
556pub fn corr_axis1<T: AFloat>(a: &ArrayView2<T>, b: &ArrayView2<T>, num_threads: usize) -> Vec<T> {
557 let mut res: Vec<T> = vec![T::zero(); a.nrows()];
558 let mut slice = UnsafeSlice::new(&mut res);
559 parallel_apply!(
560 |(a, b): (ArrayView1<T>, ArrayView1<T>)| simd_corr(
561 a.as_slice().unwrap(),
562 b.as_slice().unwrap()
563 ),
564 zip(a.rows(), b.rows()),
565 slice,
566 num_threads
567 );
568 res
569}
570pub fn nancorr_axis1<T: AFloat>(
571 a: &ArrayView2<T>,
572 b: &ArrayView2<T>,
573 num_threads: usize,
574) -> Vec<T> {
575 let mut res: Vec<T> = vec![T::zero(); a.nrows()];
576 let mut slice = UnsafeSlice::new(&mut res);
577 parallel_apply!(
578 |(a, b): (ArrayView1<T>, ArrayView1<T>)| simd_nancorr(
579 a.as_slice().unwrap(),
580 b.as_slice().unwrap()
581 ),
582 zip(a.rows(), b.rows()),
583 slice,
584 num_threads
585 );
586 res
587}
588pub fn masked_corr_axis1<T: AFloat>(
589 a: &ArrayView2<T>,
590 b: &ArrayView2<T>,
591 valid_mask: &ArrayView2<bool>,
592 num_threads: usize,
593) -> Vec<T> {
594 let mut res: Vec<T> = vec![T::zero(); a.nrows()];
595 let mut slice = UnsafeSlice::new(&mut res);
596 parallel_apply!(
597 |(a, b, valid_mask): (ArrayView1<T>, ArrayView1<T>, ArrayView1<bool>)| simd_masked_corr(
598 a.as_slice().unwrap(),
599 b.as_slice().unwrap(),
600 valid_mask.as_slice().unwrap()
601 ),
602 izip!(a.rows(), b.rows(), valid_mask.rows()),
603 slice,
604 num_threads
605 );
606 res
607}
608
609pub fn coeff_axis1<T: AFloat>(
610 x: &ArrayView2<T>,
611 y: &ArrayView2<T>,
612 q: Option<T>,
613 num_threads: usize,
614) -> (Vec<T>, Vec<T>) {
615 let mut ws: Vec<T> = vec![T::zero(); x.nrows()];
616 let mut bs: Vec<T> = vec![T::zero(); x.nrows()];
617 let mut slice0 = UnsafeSlice::new(&mut ws);
618 let mut slice1 = UnsafeSlice::new(&mut bs);
619 if num_threads <= 1 {
620 izip!(x.rows(), y.rows())
621 .enumerate()
622 .for_each(|(i, (x, y))| {
623 let (w, b) = coeff(x, y, q);
624 slice0.set(i, w);
625 slice1.set(i, b);
626 });
627 } else {
628 let pool = rayon::ThreadPoolBuilder::new()
629 .num_threads(num_threads)
630 .build()
631 .unwrap();
632 pool.scope(move |s| {
633 izip!(x.rows(), y.rows())
634 .enumerate()
635 .for_each(|(i, (x, y))| {
636 s.spawn(move |_| {
637 let (w, b) = coeff(x, y, q);
638 slice0.set(i, w);
639 slice1.set(i, b);
640 });
641 });
642 });
643 }
644 (ws, bs)
645}
646pub fn masked_coeff_axis1<T: AFloat>(
647 x: &ArrayView2<T>,
648 y: &ArrayView2<T>,
649 valid_mask: &ArrayView2<bool>,
650 q: Option<T>,
651 num_threads: usize,
652) -> (Vec<T>, Vec<T>) {
653 let mut ws: Vec<T> = vec![T::zero(); x.nrows()];
654 let mut bs: Vec<T> = vec![T::zero(); x.nrows()];
655 let mut slice0 = UnsafeSlice::new(&mut ws);
656 let mut slice1 = UnsafeSlice::new(&mut bs);
657 if num_threads <= 1 {
658 izip!(x.rows(), y.rows(), valid_mask.rows())
659 .enumerate()
660 .for_each(|(i, (x, y, valid_mask))| {
661 let (w, b) = masked_coeff(x, y, valid_mask, q);
662 slice0.set(i, w);
663 slice1.set(i, b);
664 });
665 } else {
666 let pool = rayon::ThreadPoolBuilder::new()
667 .num_threads(num_threads)
668 .build()
669 .unwrap();
670 pool.scope(move |s| {
671 izip!(x.rows(), y.rows(), valid_mask.rows())
672 .enumerate()
673 .for_each(|(i, (x, y, valid_mask))| {
674 s.spawn(move |_| {
675 let (w, b) = masked_coeff(x, y, valid_mask, q);
676 slice0.set(i, w);
677 slice1.set(i, b);
678 });
679 });
680 });
681 }
682 (ws, bs)
683}
684
685pub fn unique(arr: &[i64]) -> (Array1<i64>, Array1<i64>) {
688 let mut counts = HashMap::new();
689
690 for &value in arr.iter() {
691 *counts.entry(value).or_insert(0) += 1;
692 }
693
694 let mut unique_values: Vec<i64> = counts.keys().cloned().collect();
695 unique_values.sort();
696
697 let counts: Vec<i64> = unique_values.iter().map(|&value| counts[&value]).collect();
698
699 (Array1::from(unique_values), Array1::from(counts))
700}
701
702pub fn searchsorted<T: Ord>(arr: &ArrayView1<T>, value: &T) -> usize {
703 arr.as_slice()
704 .unwrap()
705 .binary_search(value)
706 .unwrap_or_else(|x| x)
707}
708
709pub fn batch_searchsorted<T: Ord>(arr: &ArrayView1<T>, values: &ArrayView1<T>) -> Vec<usize> {
710 values
711 .iter()
712 .map(|value| searchsorted(arr, value))
713 .collect()
714}
715
/// Base size threshold used by `fast_concat_2d_axis0` when grouping arrays
/// into copy tasks.
/// NOTE(review): the factors look like bytes-per-element * columns * rows of a
/// typical input — confirm the intended unit.
const CONCAT_GROUP_LIMIT: usize = 4 * 239 * 5000;
/// One unit of copy work: destination offsets (in elements of the output),
/// the source arrays (index-aligned with the offsets), and a handle to the
/// shared output buffer.
type ConcatTask<'a, 'b, D> = (Vec<usize>, Vec<ArrayView2<'a, D>>, UnsafeSlice<'b, D>);
718#[inline]
719fn fill_concat<D: Copy>((offsets, arrays, mut out): ConcatTask<D>) {
720 offsets.iter().enumerate().for_each(|(i, &offset)| {
721 out.copy_from_slice(offset, arrays[i].as_slice().unwrap());
722 });
723}
724pub fn fast_concat_2d_axis0<D: Copy + Send + Sync>(
725 arrays: Vec<ArrayView2<D>>,
726 num_rows: Vec<usize>,
727 num_columns: usize,
728 limit_multiplier: usize,
729 mut out: UnsafeSlice<D>,
730) {
731 let mut cumsum: usize = 0;
732 let mut offsets: Vec<usize> = vec![0; num_rows.len()];
733 for i in 1..num_rows.len() {
734 cumsum += num_rows[i - 1];
735 offsets[i] = cumsum * num_columns;
736 }
737
738 let bumped_limit = CONCAT_GROUP_LIMIT * 16;
739 let total_bytes = offsets.last().unwrap() + num_rows.last().unwrap() * num_columns;
740 let (mut group_limit, mut tasks_divisor) = if total_bytes <= bumped_limit {
741 (CONCAT_GROUP_LIMIT, 8)
742 } else {
743 (bumped_limit, 1)
744 };
745 group_limit *= limit_multiplier;
746
747 let prior_num_tasks = total_bytes.div_ceil(group_limit);
748 let prior_num_threads = prior_num_tasks / tasks_divisor;
749 if prior_num_threads > 1 {
750 group_limit = total_bytes.div_ceil(prior_num_threads);
751 tasks_divisor = 1;
752 }
753
754 let nbytes = mem::size_of::<D>();
755
756 let mut tasks: Vec<ConcatTask<D>> = Vec::new();
757 let mut current_tasks: Option<ConcatTask<D>> = Some((Vec::new(), Vec::new(), out.shadow()));
758 let mut nbytes_cumsum = 0;
759 izip!(num_rows.iter(), offsets.into_iter(), arrays.into_iter()).for_each(
760 |(&num_row, offset, array)| {
761 nbytes_cumsum += nbytes * num_row * num_columns;
762 if let Some(ref mut current_tasks) = current_tasks {
763 current_tasks.0.push(offset);
764 current_tasks.1.push(array);
765 }
766 if nbytes_cumsum >= group_limit {
767 nbytes_cumsum = 0;
768 if let Some(current_tasks) = current_tasks.take() {
769 tasks.push(current_tasks);
770 }
771 current_tasks = Some((Vec::new(), Vec::new(), out.shadow()));
772 }
773 },
774 );
775 if let Some(current_tasks) = current_tasks.take() {
776 if !current_tasks.0.is_empty() {
777 tasks.push(current_tasks);
778 }
779 }
780
781 let max_threads = available_parallelism()
782 .expect("failed to get available parallelism")
783 .get();
784 let num_threads = (tasks.len() / tasks_divisor).min(max_threads * 8).min(512);
785 if num_threads <= 1 {
786 tasks.into_iter().for_each(fill_concat);
787 } else {
788 let pool = rayon::ThreadPoolBuilder::new()
789 .num_threads(num_threads)
790 .build()
791 .unwrap();
792
793 pool.scope(move |s| {
794 tasks.into_iter().for_each(|task| {
795 s.spawn(move |_| fill_concat(task));
796 });
797 });
798 }
799}
800
#[cfg(test)]
mod tests {
    use super::*;
    use crate::toolkit::convert::to_bytes;
    use std::io::Write;
    use tempfile::tempdir;

    /// Asserts that `a` and `b` have the same length and are element-wise
    /// close, using `|x - y| <= atol + rtol * |y|` with both tolerances 1e-6.
    fn assert_allclose<T: AFloat>(a: &[T], b: &[T]) {
        // Without this check, zipping would silently ignore surplus elements.
        assert_eq!(
            a.len(),
            b.len(),
            "length mismatch - a: {:?}, b: {:?}",
            a,
            b,
        );
        let atol = T::from_f64(1e-6).unwrap();
        let rtol = T::from_f64(1e-6).unwrap();
        a.iter().zip(b.iter()).for_each(|(&x, &y)| {
            assert!(
                (x - y).abs() <= atol + rtol * y.abs(),
                "not close - a: {:?}, b: {:?}",
                a,
                b,
            );
        });
    }

    /// Writes an array to a temporary file and reads it back through the
    /// memory-mapped wrapper, both as a slice and as an array view.
    #[test]
    fn test_mmap() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.cfy");
        let array = Array1::<f32>::from_shape_vec(3, vec![1., 2., 3.]).unwrap();
        let bytes = unsafe { to_bytes(array.as_slice().unwrap()) };
        let mut file = File::create(&file_path).unwrap();
        file.write_all(bytes).unwrap();
        let file_path = file_path.to_str().unwrap();
        let mmap_array = unsafe { MmapArray1::<f32>::new(file_path).unwrap() };
        assert_eq!(array.len(), mmap_array.len());
        assert_allclose(array.as_slice().unwrap(), unsafe { mmap_array.as_slice() });
        assert_allclose(
            array.as_slice().unwrap(),
            unsafe { mmap_array.as_array_view() }.as_slice().unwrap(),
        );
    }

    // Concatenates a 1x3 and a 2x3 array and checks the flattened 3x3 result.
    macro_rules! test_fast_concat_2d_axis0 {
        ($dtype:ty) => {
            let array_2d_u = ArrayView2::<$dtype>::from_shape((1, 3), &[1., 2., 3.]).unwrap();
            let array_2d_l =
                ArrayView2::<$dtype>::from_shape((2, 3), &[4., 5., 6., 7., 8., 9.]).unwrap();
            let arrays = vec![array_2d_u, array_2d_l];
            let mut out: Vec<$dtype> = vec![0.; 3 * 3];
            let out_slice = UnsafeSlice::new(&mut out);
            fast_concat_2d_axis0(arrays, vec![1, 2], 3, 1, out_slice);
            assert_eq!(out.as_slice(), &[1., 2., 3., 4., 5., 6., 7., 8., 9.]);
        };
    }

    // Exercises `nanmean_axis1` on NaN-free data, single- and multi-threaded.
    macro_rules! test_mean_axis1 {
        ($dtype:ty) => {
            let array =
                ArrayView2::<$dtype>::from_shape((2, 3), &[1., 2., 3., 4., 5., 6.]).unwrap();
            let out = nanmean_axis1(&array, 1);
            assert_allclose(out.as_slice(), &[2., 5.]);
            let out = nanmean_axis1(&array, 2);
            assert_allclose(out.as_slice(), &[2., 5.]);
        };
    }

    // Exercises `nancorr_axis1` against a shifted copy of the same data,
    // single- and multi-threaded; the expected correlation is 1.
    macro_rules! test_corr_axis1 {
        ($dtype:ty) => {
            let array =
                ArrayView2::<$dtype>::from_shape((2, 3), &[1., 2., 3., 4., 5., 6.]).unwrap();
            let out = nancorr_axis1(&array, &(&array + 1.).view(), 1);
            assert_allclose(out.as_slice(), &[1., 1.]);
            let out = nancorr_axis1(&array, &(&array + 1.).view(), 2);
            assert_allclose(out.as_slice(), &[1., 1.]);
        };
    }

    #[test]
    fn test_fast_concat_2d_axis0_f32() {
        test_fast_concat_2d_axis0!(f32);
    }
    #[test]
    fn test_fast_concat_2d_axis0_f64() {
        test_fast_concat_2d_axis0!(f64);
    }

    #[test]
    fn test_mean_axis1_f32() {
        test_mean_axis1!(f32);
    }
    #[test]
    fn test_mean_axis1_f64() {
        test_mean_axis1!(f64);
    }

    #[test]
    fn test_corr_axis1_f32() {
        test_corr_axis1!(f32);
    }
    #[test]
    fn test_corr_axis1_f64() {
        test_corr_axis1!(f64);
    }

    /// Fits `y = 2x` row by row, single- and multi-threaded.
    #[test]
    fn test_coeff_axis1() {
        let x = ArrayView2::<f64>::from_shape((2, 3), &[2., 1., 3., 6., 4., 5.]).unwrap();
        let y = ArrayView2::<f64>::from_shape((2, 3), &[4., 2., 6., 12., 8., 10.]).unwrap();
        let scale = 2. * (2. / 3.).sqrt();
        let (ws, bs) = coeff_axis1(&x, &y, None, 1);
        assert_allclose(ws.as_slice(), &[scale, scale]);
        assert_allclose(bs.as_slice(), &[4., 10.]);
        let (ws, bs) = coeff_axis1(&x, &y, None, 2);
        assert_allclose(ws.as_slice(), &[scale, scale]);
        assert_allclose(bs.as_slice(), &[4., 10.]);
    }

    /// Checks insertion points for present, absent and out-of-range values.
    #[test]
    fn test_searchsorted() {
        let array = ArrayView1::<i64>::from_shape(5, &[1, 2, 3, 5, 6]).unwrap();
        assert_eq!(searchsorted(&array, &0), 0);
        assert_eq!(searchsorted(&array, &1), 0);
        assert_eq!(searchsorted(&array, &3), 2);
        assert_eq!(searchsorted(&array, &4), 3);
        assert_eq!(searchsorted(&array, &5), 3);
        assert_eq!(searchsorted(&array, &6), 4);
        assert_eq!(searchsorted(&array, &7), 5);
        assert_eq!(batch_searchsorted(&array, &array), vec![0, 1, 2, 3, 4]);
    }
}