1use crate::constraints::Constraint;
2use crate::data::FloatData;
3use crate::errors::ForustError;
4use std::cmp::Ordering;
5use std::collections::VecDeque;
6use std::convert::TryInto;
7
/// Concatenates `items`, appending `", "` after every entry (the last one included).
///
/// An empty `items` yields an empty string.
pub fn items_to_strings(items: Vec<&str>) -> String {
    items
        .into_iter()
        .flat_map(|item| [item, ", "])
        .collect()
}
17
18pub fn fmt_vec_output<T: FloatData<T>>(v: &[T]) -> String {
19 let mut res = String::new();
20 if let Some(last) = v.len().checked_sub(1) {
21 if last == 0 {
22 return format!("{:.4}", v[0]);
23 }
24 for n in &v[..last] {
25 res.push_str(format!("{:.4}", n).as_str());
26 res.push_str(", ");
27 }
28 res.push_str(format!("{:.4}", &v[last]).as_str());
29 }
30 res
31}
32
33pub fn validate_positive_float_parameter<T: FloatData<T>>(
35 value: T,
36 parameter: &str,
37) -> Result<(), ForustError> {
38 validate_float_parameter(value, T::ZERO, T::INFINITY, parameter)
39}
40pub fn validate_float_parameter<T: FloatData<T>>(
41 value: T,
42 min: T,
43 max: T,
44 parameter: &str,
45) -> Result<(), ForustError> {
46 let mut msg = String::new();
47 if value.is_nan() || value < min || max < value {
48 msg.push_str(&value.to_string());
49 let ex_msg = format!("real value within rang {} and {}", min, max);
50 Err(ForustError::InvalidParameter(
51 parameter.to_string(),
52 ex_msg,
53 value.to_string(),
54 ))
55 } else {
56 Ok(())
57 }
58}
59
60pub fn validate_not_nan_vec(v: &[f64], name: String) -> Result<(), ForustError> {
61 if v.iter().any(|i| i.is_nan()) {
62 Err(ForustError::MissingValuesFound(name))
63 } else {
64 Ok(())
65 }
66}
67
68pub fn validate_positive_vec(v: &[f64], name: String) -> Result<(), ForustError> {
69 if v.iter().any(|i| i < &0.) {
70 Err(ForustError::NegativeValuesFound(name))
71 } else {
72 Ok(())
73 }
74}
75
76pub fn validate_positive_not_nan_vec(v: &[f64], name: String) -> Result<(), ForustError> {
77 let remainder = v
78 .iter()
79 .filter(|i| i.is_nan() || i < &&0.)
80 .copied()
81 .collect::<Vec<f64>>();
82 validate_positive_vec(&remainder, name.clone())?;
83 validate_not_nan_vec(&remainder, name)?;
84 Ok(())
85}
86
/// Validates a float field with `crate::utils::validate_positive_float_parameter`,
/// using the field's own name in the error.
///
/// Expands to that call followed by `?`, so it can only be used inside a
/// function returning a `Result` whose error type `ForustError` converts into.
///
/// The reported name is the last `.`-separated segment of the stringified
/// expression, e.g. `self.eta` is reported as `eta`.
macro_rules! validate_positive_float_field {
    ($var: expr) => {
        // `rsplit` always yields at least one segment, so an expression without a
        // `.` is reported as-is instead of panicking on a missing second segment.
        let var_name = stringify!($var).rsplit('.').next().unwrap_or_default();
        crate::utils::validate_positive_float_parameter($var, var_name)?
    };
}

pub(crate) use validate_positive_float_field;
95
/// Reports whether `value` counts as missing given the `missing` sentinel.
///
/// If `missing` is NaN, a value is missing exactly when it is NaN. Otherwise a
/// value is missing only when it equals `missing`.
///
/// # Panics
///
/// Panics if `missing` is not NaN but `value` is NaN.
#[inline]
pub fn is_missing(value: &f64, missing: &f64) -> bool {
    match (missing.is_nan(), value.is_nan()) {
        (true, value_is_nan) => value_is_nan,
        (false, true) => panic!(
            "Missing value is {}, however NAN value found in data.",
            missing
        ),
        (false, false) => value == missing,
    }
}
110
111#[allow(clippy::too_many_arguments)]
114#[inline]
115pub fn constrained_weight(
116 l1: &f32,
117 l2: &f32,
118 max_delta_step: &f32,
119 gradient_sum: f32,
120 hessian_sum: f32,
121 lower_bound: f32,
122 upper_bound: f32,
123 constraint: Option<&Constraint>,
124) -> f32 {
125 let weight = weight(l1, l2, max_delta_step, gradient_sum, hessian_sum);
126 match constraint {
127 None | Some(Constraint::Unconstrained) => weight,
128 _ => {
129 if weight > upper_bound {
130 upper_bound
131 } else if weight < lower_bound {
132 lower_bound
133 } else {
134 weight
135 }
136 }
137 }
138}
139
/// Returns `true` if `v` lies in the closed interval spanned by `i` and `j`,
/// whichever of the two endpoints is larger.
#[inline]
pub fn between(i: f32, j: f32, v: f32) -> bool {
    let (lo, hi) = if i > j { (j, i) } else { (i, j) };
    (lo..=hi).contains(&v)
}
149
150#[inline]
151pub fn bound_to_parent(parent_weight: f32, left_weight: f32, right_weight: f32) -> (f32, f32) {
152 if between(left_weight, right_weight, parent_weight) {
153 (left_weight, right_weight)
154 } else {
155 if left_weight > right_weight {
158 if left_weight < parent_weight {
163 (parent_weight, right_weight)
164 } else {
165 (left_weight, parent_weight)
169 }
170 } else {
171 if right_weight < parent_weight {
176 (left_weight, parent_weight)
178 } else {
179 (parent_weight, right_weight)
183 }
184 }
185 }
186}
187
/// Returns the logistic sigmoid `1 / (1 + e^(-v))` of `v`.
///
/// Despite the name, the result is a probability-like value, not an odds ratio.
#[inline]
pub fn odds(v: f64) -> f64 {
    let denominator = 1.0 + (-v).exp();
    denominator.recip()
}
193
/// Returns `gradient_sum² / (hessian_sum + l2)`.
#[inline]
pub fn gain(l2: &f32, gradient_sum: f32, hessian_sum: f32) -> f32 {
    let numerator = gradient_sum * gradient_sum;
    let denominator = hessian_sum + *l2;
    numerator / denominator
}
199
/// Returns `-(2·G·w + (H + l2)·w²)` for gradient sum `G`, hessian sum `H` and a
/// fixed leaf weight `w`.
///
/// Mathematically this equals `gain(l2, G, H)` when `w = -G / (H + l2)`.
#[inline]
pub fn gain_given_weight(l2: &f32, gradient_sum: f32, hessian_sum: f32, weight: f32) -> f32 {
    let linear = 2.0 * gradient_sum * weight;
    let quadratic = (hessian_sum + l2) * (weight * weight);
    -(linear + quadratic)
}
207
208#[inline]
210pub fn cull_gain(
211 gain: f32,
212 left_weight: f32,
213 right_weight: f32,
214 constraint: Option<&Constraint>,
215) -> f32 {
216 match constraint {
217 None | Some(Constraint::Unconstrained) => gain,
218 Some(Constraint::Negative) => {
219 if left_weight <= right_weight {
220 f32::NEG_INFINITY
221 } else {
222 gain
223 }
224 }
225 Some(Constraint::Positive) => {
226 if left_weight >= right_weight {
227 f32::NEG_INFINITY
228 } else {
229 gain
230 }
231 }
232 }
233}
234
/// Soft-thresholds `w` by `l1`.
///
/// If `l1` is zero, `w` is returned unchanged. Otherwise values above `l1` are
/// shifted down by `l1`, values below `-l1` are shifted up by `l1`, and
/// anything in `[-l1, l1]` becomes `0.0`.
#[inline]
pub fn l1_regularization(w: &f32, l1: &f32) -> f32 {
    let (value, threshold) = (*w, *l1);
    if threshold == 0.0 {
        return value;
    }
    if value > threshold {
        value - threshold
    } else if value < -threshold {
        value + threshold
    } else {
        0.0
    }
}
248
249#[inline]
252pub fn weight(
253 l1: &f32,
254 l2: &f32,
255 max_delta_step: &f32,
256 gradient_sum: f32,
257 hessian_sum: f32,
258) -> f32 {
259 let w = -(l1_regularization(&gradient_sum, l1) / (hessian_sum + l2));
260 if (max_delta_step != &0.) && (&w.abs() > max_delta_step) {
261 return max_delta_step.copysign(w);
262 }
263 w
264}
265
266const LANES: usize = 16;
267
268#[inline]
272pub fn fast_sum<T: FloatData<T>>(values: &[T]) -> T {
273 let chunks = values.chunks_exact(LANES);
274 let remainder = chunks.remainder();
275
276 let sum = chunks.fold([T::ZERO; LANES], |mut acc, chunk| {
277 let chunk: [T; LANES] = chunk.try_into().unwrap();
278 for i in 0..LANES {
279 acc[i] += chunk[i];
280 }
281 acc
282 });
283
284 let remainder: T = remainder.iter().copied().sum();
285
286 let mut reduced = T::ZERO;
287 for s in sum.iter().take(LANES) {
288 reduced += *s;
289 }
290 reduced + remainder
291}
292
293#[inline]
298pub fn fast_f64_sum(values: &[f32]) -> f32 {
299 let chunks = values.chunks_exact(LANES);
300 let remainder = chunks.remainder();
301
302 let sum = chunks.fold([f64::ZERO; LANES], |mut acc, chunk| {
303 let chunk: [f32; LANES] = chunk.try_into().unwrap();
304 for i in 0..LANES {
305 acc[i] += f64::from(chunk[i]);
306 }
307 acc
308 });
309
310 let remainder: f64 = remainder
311 .iter()
312 .fold(f64::ZERO, |acc, b| acc + f64::from(*b));
313
314 let mut reduced: f64 = 0.;
315 for s in sum.iter().take(LANES) {
316 reduced += *s;
317 }
318 (reduced + remainder) as f32
319}
320
321pub fn naive_sum<T: FloatData<T>>(values: &[T]) -> T {
322 values.iter().copied().sum()
323}
324
325pub fn percentiles<T>(v: &[T], sample_weight: &[T], percentiles: &[T]) -> Vec<T>
334where
335 T: FloatData<T>,
336{
337 let mut idx: Vec<usize> = (0..v.len()).collect();
338 idx.sort_unstable_by(|a, b| v[*a].partial_cmp(&v[*b]).unwrap());
339
340 let mut pcts = VecDeque::from_iter(percentiles.iter());
342 let mut current_pct = *pcts.pop_front().expect("No percentiles were provided");
343
344 let mut p = Vec::new();
346 let mut cuml_pct = T::ZERO;
347 let mut current_value = v[idx[0]];
348 let total_values = fast_sum(sample_weight);
349
350 for i in idx.iter() {
351 if current_value != v[*i] {
352 current_value = v[*i];
353 }
354 cuml_pct += sample_weight[*i] / total_values;
355 if (current_pct == T::ZERO) || (cuml_pct >= current_pct) {
356 while cuml_pct >= current_pct {
359 p.push(current_value);
360 match pcts.pop_front() {
361 Some(p_) => current_pct = *p_,
362 None => return p,
363 }
364 }
365 } else if current_pct == T::ONE {
366 if let Some(i_) = idx.last() {
367 p.push(v[*i_]);
368 break;
369 }
370 }
371 }
372 p
373}
374
375#[inline]
385pub fn map_bin<T: FloatData<T>>(x: &[T], v: &T, missing: &T) -> Option<u16> {
386 if v.is_nan() || (v == missing) {
387 return Some(0);
388 }
389 let mut low = 0;
390 let mut high = x.len();
391 while low != high {
392 let mid = (low + high) / 2;
393 if x[mid] <= *v {
397 low = mid + 1;
398 } else {
399 high = mid;
400 }
401 }
402 u16::try_from(low).ok()
403}
404
405#[inline]
416pub fn pivot_on_split(
417 index: &mut [usize],
418 feature: &[u16],
419 split_value: u16,
420 missing_right: bool,
421) -> usize {
422 let mut low = 0;
424 let mut high = index.len() - 1;
425 let max_idx = high;
426 while low < high {
427 while low < max_idx {
431 let l = feature[index[low]];
432 match missing_compare(&split_value, l, missing_right) {
433 Ordering::Less | Ordering::Equal => break,
434 Ordering::Greater => low += 1,
435 }
436 }
437 while high > low {
438 let h = feature[index[high]];
439 match missing_compare(&split_value, h, missing_right) {
443 Ordering::Less | Ordering::Equal => high -= 1,
444 Ordering::Greater => break,
445 }
446 }
447 if low < high {
448 index.swap(high, low);
449 }
450 }
451 low
452}
453
/// Three-way partitions `index` in place into missing rows, rows below the split,
/// and the remaining rows.
///
/// Returns `(n_missing, split)`. After the call:
/// - `index[..n_missing]` holds rows whose bin `feature[row]` is `0` (missing);
/// - `index[n_missing..split]` holds non-missing rows with bin `< split_value`;
/// - `index[split..]` holds rows with bin `>= split_value`.
///
/// An empty `index` returns `(0, 0)`.
///
/// NOTE(review): assumes `split_value >= 1`, since bin `0` denotes missing here.
#[inline]
pub fn pivot_on_split_exclude_missing(
    index: &mut [usize],
    feature: &[u16],
    split_value: u16,
) -> (usize, usize) {
    if index.is_empty() {
        return (0, 0);
    }
    let mut low = 0;
    let mut high = index.len() - 1;
    // Number of missing rows gathered so far at the front of `index`.
    let mut missing = 0;
    let max_idx = high;
    while low < high {
        while low < max_idx {
            let l = feature[index[low]];
            if l == 0 {
                // Move the missing row into the missing prefix.
                index.swap(missing, low);
                missing += 1;
            }
            match &split_value.cmp(&l) {
                Ordering::Less | Ordering::Equal => break,
                Ordering::Greater => low += 1,
            }
        }
        while high > low {
            let h = feature[index[high]];
            if h == 0 {
                index.swap(missing, high);
                missing += 1;
                // Keep `low` outside the missing prefix.
                if missing > low {
                    low = missing;
                }
            }
            match &split_value.cmp(&h) {
                Ordering::Less | Ordering::Equal => high -= 1,
                Ordering::Greater => break,
            }
        }
        if low < high {
            index.swap(high, low);
        }
    }
    // The front scan is bounded by `max_idx` and never inspects that row itself.
    // If `low` ended there, classify the last row now so that inputs where every
    // row goes left (including a single row) are counted correctly.
    if low == max_idx {
        let l = feature[index[low]];
        if l == 0 {
            index.swap(missing, low);
            missing += 1;
            low += 1;
        } else if split_value > l {
            low += 1;
        }
    }
    (missing, low)
}
528
/// Compares `split_value` with a bin value, treating bin `0` as missing.
///
/// A missing `cmp_value` compares as `Ordering::Less` when `missing_right` is
/// `true` and as `Ordering::Greater` otherwise. Any other bin yields
/// `split_value.cmp(&cmp_value)`.
#[inline]
pub fn missing_compare(split_value: &u16, cmp_value: u16, missing_right: bool) -> Ordering {
    match (cmp_value, missing_right) {
        (0, true) => Ordering::Less,
        (0, false) => Ordering::Greater,
        (_, _) => split_value.cmp(&cmp_value),
    }
}
548
/// Rounds `n` to `precision` decimal places, rounding the scaled value half
/// away from zero. A negative `precision` rounds to tens, hundreds, and so on.
#[inline]
pub fn precision_round(n: f64, precision: i32) -> f64 {
    let scale = 10_f64.powi(precision);
    let scaled = n * scale;
    scaled.round() / scale
}
554
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::seq::SliceRandom;
    use rand::Rng;
    use rand::SeedableRng;

    /// Draws 100 pseudo-random bins in `lo..15` and a random sample of 73 row
    /// indices from `0..100`. The RNG is reseeded with `0` on every call, so the
    /// same `lo` always yields the same data.
    fn random_bins_and_rows(lo: u16) -> (Vec<u16>, Vec<usize>) {
        let index = (0..100).collect::<Vec<usize>>();
        let mut rng = StdRng::seed_from_u64(0);
        let f = (0..100).map(|_| rng.gen_range(lo..15)).collect::<Vec<u16>>();
        let idx = index
            .choose_multiple(&mut rng, 73)
            .copied()
            .collect::<Vec<usize>>();
        (f, idx)
    }

    /// Largest bin among the rows listed in `idx`.
    fn max_bin(idx: &[usize], f: &[u16]) -> u16 {
        idx.iter().map(|i| f[*i]).max().unwrap()
    }

    /// Smallest non-missing (non-zero) bin among the rows listed in `idx`.
    fn min_present_bin(idx: &[usize], f: &[u16]) -> u16 {
        idx.iter().map(|i| f[*i]).filter(|b| *b > 0).min().unwrap()
    }

    #[test]
    fn test_round() {
        assert_eq!(0.3, precision_round(0.3333, 1));
        assert_eq!(0.2343, precision_round(0.2343123123123, 4));
    }

    #[test]
    fn test_percentiles() {
        let v = vec![4., 5., 6., 1., 2., 3., 7., 8., 9., 10.];
        let w = vec![1.; v.len()];
        let p = vec![0.3, 0.5, 0.75, 1.0];
        let p = percentiles(&v, &w, &p);
        assert_eq!(p, vec![3.0, 5.0, 8.0, 10.0]);
    }

    #[test]
    fn test_percentiles_weighted() {
        let v = vec![10., 8., 9., 1., 2., 3., 6., 7., 4., 5.];
        let w = vec![1., 1., 1., 1., 1., 2., 1., 1., 5., 1.];
        let p = vec![0.3, 0.5, 0.75, 1.0];
        let p = percentiles(&v, &w, &p);
        assert_eq!(p, vec![4.0, 4.0, 7.0, 10.0]);
    }

    #[test]
    fn test_map_bin_or_equal() {
        let v = vec![f64::MIN, 1., 4., 8., 9.];
        assert_eq!(1, map_bin(&v, &0., &f64::NAN).unwrap());
        assert_eq!(2, map_bin(&v, &1., &f64::NAN).unwrap());
        assert_eq!(2, map_bin(&v, &2., &f64::NAN).unwrap());
        assert_eq!(3, map_bin(&v, &4., &f64::NAN).unwrap());
        assert_eq!(5, map_bin(&v, &9., &f64::NAN).unwrap());
        assert_eq!(5, map_bin(&v, &10., &f64::NAN).unwrap());
        assert_eq!(2, map_bin(&v, &1., &f64::NAN).unwrap());
        assert_eq!(0, map_bin(&v, &f64::NAN, &f64::NAN).unwrap());
    }

    #[test]
    fn test_map_bin_or_equal_num_miss() {
        let v = vec![f64::MIN, 1., 4., 8., 9.];
        assert_eq!(1, map_bin(&v, &0., &-99.).unwrap());
        assert_eq!(2, map_bin(&v, &1., &-99.).unwrap());
        assert_eq!(2, map_bin(&v, &2., &-99.).unwrap());
        assert_eq!(3, map_bin(&v, &4., &-99.).unwrap());
        assert_eq!(5, map_bin(&v, &9., &-99.).unwrap());
        assert_eq!(5, map_bin(&v, &10., &-99.).unwrap());
        assert_eq!(2, map_bin(&v, &1., &-99.).unwrap());
        assert_eq!(0, map_bin(&v, &-99., &-99.).unwrap());
    }

    #[test]
    fn test_missing_compare() {
        assert_eq!(missing_compare(&10, 0, true), Ordering::Less);
        assert_eq!(missing_compare(&10, 0, false), Ordering::Greater);
        assert_eq!(missing_compare(&10, 11, true), Ordering::Less);
        assert_eq!(missing_compare(&10, 1, true), Ordering::Greater);
    }

    #[test]
    fn test_pivot() {
        /// Checks that `idx[..split_i]` holds exactly the rows that belong left.
        fn pivot_assert(
            f: &[u16],
            idx: &[usize],
            split_i: usize,
            missing_right: bool,
            split_val: u16,
        ) {
            if missing_right {
                for i in idx[..split_i].iter() {
                    assert!((f[*i] < split_val) && (f[*i] != 0));
                }
                for i in idx[split_i..].iter() {
                    assert!((f[*i] >= split_val) || (f[*i] == 0));
                }
            } else {
                for i in idx[..split_i].iter() {
                    assert!((f[*i] < split_val) || (f[*i] == 0));
                }
                for i in idx[split_i..].iter() {
                    // Missing rows go left, so no right-hand row may be missing.
                    assert!((f[*i] >= split_val) && (f[*i] != 0));
                }
            }
        }

        // Small fixed case, for several split values and both missing sides.
        let f = vec![15, 10, 10, 11, 3, 18, 0, 9, 3, 5, 2, 6, 13, 19, 14];
        for &split_val in &[10, 1, 19] {
            for &missing_right in &[true, false] {
                let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7];
                let split_i = pivot_on_split(&mut idx, &f, split_val, missing_right);
                pivot_assert(&f, &idx, split_i, missing_right, split_val);
            }
        }

        for &missing_right in &[true, false] {
            // Repeated and reversed pivots on the same rows.
            let (f, mut idx) = random_bins_and_rows(0);
            let split_i = pivot_on_split(&mut idx, &f, 7, missing_right);
            pivot_assert(&f, &idx, split_i, missing_right, 7);
            let split_i = pivot_on_split(&mut idx, &f, 7, missing_right);
            pivot_assert(&f, &idx, split_i, missing_right, 7);
            idx.reverse();
            let split_i = pivot_on_split(&mut idx, &f, 7, missing_right);
            pivot_assert(&f, &idx, split_i, missing_right, 7);

            // No missing values.
            let (f, mut idx) = random_bins_and_rows(1);
            let split_i = pivot_on_split(&mut idx, &f, 5, missing_right);
            pivot_assert(&f, &idx, split_i, missing_right, 5);

            // Split on the largest bin present.
            let (f, mut idx) = random_bins_and_rows(0);
            let sv = max_bin(&idx, &f);
            let split_i = pivot_on_split(&mut idx, &f, sv, missing_right);
            pivot_assert(&f, &idx, split_i, missing_right, sv);

            // Split on the smallest non-missing bin, with and without missing values.
            for &lo in &[0, 1] {
                let (f, mut idx) = random_bins_and_rows(lo);
                let sv = min_present_bin(&idx, &f);
                let split_i = pivot_on_split(&mut idx, &f, sv, missing_right);
                pivot_assert(&f, &idx, split_i, missing_right, sv);
            }
        }
    }

    #[test]
    fn test_pivot_missing() {
        /// Checks the three regions: missing rows first, then rows below
        /// `split_value`, then the rest.
        fn pivot_missing_assert(
            split_value: u16,
            idx: &[usize],
            f: &[u16],
            split_i: &(usize, usize),
        ) {
            let (n_missing, split) = *split_i;
            for (k, i) in idx.iter().enumerate() {
                assert_eq!(f[*i] == 0, k < n_missing);
                assert_eq!(f[*i] < split_value, k < split);
            }
        }

        let with_missing = vec![15, 10, 10, 0, 3, 0, 0, 9, 3, 5, 2, 6, 13, 19, 14];

        let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7];
        let split_i = pivot_on_split_exclude_missing(&mut idx, &with_missing, 1);
        pivot_missing_assert(1, &idx, &with_missing, &split_i);

        // Repeated and reversed pivots on the same rows.
        let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7];
        let split_i = pivot_on_split_exclude_missing(&mut idx, &with_missing, 10);
        pivot_missing_assert(10, &idx, &with_missing, &split_i);
        let split_i = pivot_on_split_exclude_missing(&mut idx, &with_missing, 10);
        pivot_missing_assert(10, &idx, &with_missing, &split_i);
        idx.reverse();
        let split_i = pivot_on_split_exclude_missing(&mut idx, &with_missing, 10);
        pivot_missing_assert(10, &idx, &with_missing, &split_i);

        let mut idx = vec![0, 1, 2, 3, 4, 5];
        let f = vec![1, 0, 1, 3, 0, 4];
        let split_i = pivot_on_split_exclude_missing(&mut idx, &f, 2);
        pivot_missing_assert(2, &idx, &f, &split_i);

        // No missing values.
        let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7];
        let f = vec![15, 10, 10, 2, 3, 5, 7, 9, 3, 5, 2, 6, 13, 19, 14];
        let split_i = pivot_on_split_exclude_missing(&mut idx, &f, 10);
        pivot_missing_assert(10, &idx, &f, &split_i);

        let (f, mut idx) = random_bins_and_rows(0);
        let split_i = pivot_on_split_exclude_missing(&mut idx, &f, 10);
        pivot_missing_assert(10, &idx, &f, &split_i);
        let split_i = pivot_on_split_exclude_missing(&mut idx, &f, 10);
        pivot_missing_assert(10, &idx, &f, &split_i);

        let (f, mut idx) = random_bins_and_rows(1);
        let split_i = pivot_on_split_exclude_missing(&mut idx, &f, 5);
        pivot_missing_assert(5, &idx, &f, &split_i);

        let (f, mut idx) = random_bins_and_rows(0);
        let sv = max_bin(&idx, &f);
        let split_i = pivot_on_split_exclude_missing(&mut idx, &f, sv);
        pivot_missing_assert(sv, &idx, &f, &split_i);

        for &lo in &[0, 1] {
            let (f, mut idx) = random_bins_and_rows(lo);
            let sv = min_present_bin(&idx, &f);
            let split_i = pivot_on_split_exclude_missing(&mut idx, &f, sv);
            pivot_missing_assert(sv, &idx, &f, &split_i);
        }
    }

    #[test]
    fn test_fast_f64_sum() {
        let records = 300000;
        let vec = vec![0.23500371; records];
        assert_ne!(vec.iter().sum::<f32>(), vec[0] * (records as f32));
        assert_eq!(vec[0] * (records as f32), fast_f64_sum(&vec));
    }

    #[test]
    fn test_fmt_vec_output() {
        let v = Vec::<f32>::new();
        assert_eq!(fmt_vec_output(&v), String::from(""));
        let v: Vec<f32> = vec![0.1, 1.0];
        assert_eq!(fmt_vec_output(&v), String::from("0.1000, 1.0000"));
    }
}
1016}