forust_ml/
utils.rs

1use crate::constraints::Constraint;
2use crate::data::FloatData;
3use crate::errors::ForustError;
4use std::cmp::Ordering;
5use std::collections::VecDeque;
6use std::convert::TryInto;
7
/// Create a comma separated string of all available items.
///
/// Items are joined with `", "`. No separator is emitted after the last
/// item, and an empty input yields an empty string.
///
/// * `items` - The item names to list.
pub fn items_to_strings(items: Vec<&str>) -> String {
    items.join(", ")
}
17
18pub fn fmt_vec_output<T: FloatData<T>>(v: &[T]) -> String {
19    let mut res = String::new();
20    if let Some(last) = v.len().checked_sub(1) {
21        if last == 0 {
22            return format!("{:.4}", v[0]);
23        }
24        for n in &v[..last] {
25            res.push_str(format!("{:.4}", n).as_str());
26            res.push_str(", ");
27        }
28        res.push_str(format!("{:.4}", &v[last]).as_str());
29    }
30    res
31}
32
33// Validation
34pub fn validate_positive_float_parameter<T: FloatData<T>>(
35    value: T,
36    parameter: &str,
37) -> Result<(), ForustError> {
38    validate_float_parameter(value, T::ZERO, T::INFINITY, parameter)
39}
40pub fn validate_float_parameter<T: FloatData<T>>(
41    value: T,
42    min: T,
43    max: T,
44    parameter: &str,
45) -> Result<(), ForustError> {
46    let mut msg = String::new();
47    if value.is_nan() || value < min || max < value {
48        msg.push_str(&value.to_string());
49        let ex_msg = format!("real value within rang {} and {}", min, max);
50        Err(ForustError::InvalidParameter(
51            parameter.to_string(),
52            ex_msg,
53            value.to_string(),
54        ))
55    } else {
56        Ok(())
57    }
58}
59
60pub fn validate_not_nan_vec(v: &[f64], name: String) -> Result<(), ForustError> {
61    if v.iter().any(|i| i.is_nan()) {
62        Err(ForustError::MissingValuesFound(name))
63    } else {
64        Ok(())
65    }
66}
67
68pub fn validate_positive_vec(v: &[f64], name: String) -> Result<(), ForustError> {
69    if v.iter().any(|i| i < &0.) {
70        Err(ForustError::NegativeValuesFound(name))
71    } else {
72        Ok(())
73    }
74}
75
76pub fn validate_positive_not_nan_vec(v: &[f64], name: String) -> Result<(), ForustError> {
77    let remainder = v
78        .iter()
79        .filter(|i| i.is_nan() || i < &&0.)
80        .copied()
81        .collect::<Vec<f64>>();
82    validate_positive_vec(&remainder, name.clone())?;
83    validate_not_nan_vec(&remainder, name)?;
84    Ok(())
85}
86
/// Validate a float field with `validate_positive_float_parameter`,
/// propagating any error with `?`.
///
/// The parameter name used in the error is the second `.`-separated segment
/// of the field expression's source text, e.g. `self.learning_rate` becomes
/// `learning_rate`. An expression without a `.` makes this panic.
macro_rules! validate_positive_float_field {
    ($var: expr) => {
        let field_name = stringify!($var).split('.').nth(1).unwrap();
        crate::utils::validate_positive_float_parameter($var, field_name)?
    };
}

pub(crate) use validate_positive_float_field;
95
/// Determine whether `value` should be treated as missing.
///
/// When `missing` is NaN, any NaN `value` is missing. Otherwise a value is
/// missing only if it equals `missing`.
///
/// # Panics
///
/// Panics if `missing` is not NaN but `value` is NaN, since that NaN cannot
/// be interpreted.
#[inline]
pub fn is_missing(value: &f64, missing: &f64) -> bool {
    match (missing.is_nan(), value.is_nan()) {
        (true, value_is_nan) => value_is_nan,
        (false, true) => panic!(
            "Missing value is {}, however NAN value found in data.",
            missing
        ),
        (false, false) => value == missing,
    }
}
110
111/// Calculate the constraint weight given bounds
112/// and a constraint.
113#[allow(clippy::too_many_arguments)]
114#[inline]
115pub fn constrained_weight(
116    l1: &f32,
117    l2: &f32,
118    max_delta_step: &f32,
119    gradient_sum: f32,
120    hessian_sum: f32,
121    lower_bound: f32,
122    upper_bound: f32,
123    constraint: Option<&Constraint>,
124) -> f32 {
125    let weight = weight(l1, l2, max_delta_step, gradient_sum, hessian_sum);
126    match constraint {
127        None | Some(Constraint::Unconstrained) => weight,
128        _ => {
129            if weight > upper_bound {
130                upper_bound
131            } else if weight < lower_bound {
132                lower_bound
133            } else {
134                weight
135            }
136        }
137    }
138}
139
/// Test whether `v` lies within the closed range spanned by `i` and `j`.
///
/// The endpoints may be given in either order. Any NaN input yields `false`.
#[inline]
pub fn between(i: f32, j: f32, v: f32) -> bool {
    let (lo, hi) = if i > j { (j, i) } else { (i, j) };
    lo <= v && v <= hi
}
149
150#[inline]
151pub fn bound_to_parent(parent_weight: f32, left_weight: f32, right_weight: f32) -> (f32, f32) {
152    if between(left_weight, right_weight, parent_weight) {
153        (left_weight, right_weight)
154    } else {
155        // If we are here, we know, the parent weight is above or bellow
156        // the right and left weights range, because of the between check.
157        if left_weight > right_weight {
158            // Here is what it looks like on the number line if we are here
159            // right...left
160            // Is the parent above the range?
161            // i.e. right...left...parent?
162            if left_weight < parent_weight {
163                (parent_weight, right_weight)
164            } else {
165                // Otherwise if we are here, it must be outside of the range on the other side..
166                // i.e. parent...right...left
167                // In which case make parent equal right.
168                (left_weight, parent_weight)
169            }
170        } else {
171            // Here is what the number line looks like at this point...
172            // left_weight..right_weight
173            // Is the parent above the range?
174            // i.e. left...right...parent?
175            if right_weight < parent_weight {
176                // In which case set right equal to parent.
177                (left_weight, parent_weight)
178            } else {
179                // Is the parent bellow the range?
180                // i.e. parent...left...right...
181                // In which case set the left equal to the parent.
182                (parent_weight, right_weight)
183            }
184        }
185    }
186}
187
/// Convert log odds to a probability with the logistic function
/// `1 / (1 + e^(-v))`.
#[inline]
pub fn odds(v: f64) -> f64 {
    let denominator = 1.0 + (-v).exp();
    denominator.recip()
}
193
/// Calculate the gain of a node from its gradient and hessian sums:
/// `gradient_sum^2 / (hessian_sum + l2)`.
#[inline]
pub fn gain(l2: &f32, gradient_sum: f32, hessian_sum: f32) -> f32 {
    let regularized_hessian = hessian_sum + *l2;
    gradient_sum * gradient_sum / regularized_hessian
}
199
/// Calculate the gain for a node whose weight has been fixed to `weight`,
/// for example after clamping for a monotonicity constraint:
/// `-(2 * gradient_sum * weight + (hessian_sum + l2) * weight^2)`.
#[inline]
pub fn gain_given_weight(l2: &f32, gradient_sum: f32, hessian_sum: f32, weight: f32) -> f32 {
    let linear_term = 2.0 * gradient_sum * weight;
    let quadratic_term = (hessian_sum + *l2) * (weight * weight);
    -(linear_term + quadratic_term)
}
207
208/// Cull gain, if it does not conform to constraints.
209#[inline]
210pub fn cull_gain(
211    gain: f32,
212    left_weight: f32,
213    right_weight: f32,
214    constraint: Option<&Constraint>,
215) -> f32 {
216    match constraint {
217        None | Some(Constraint::Unconstrained) => gain,
218        Some(Constraint::Negative) => {
219            if left_weight <= right_weight {
220                f32::NEG_INFINITY
221            } else {
222                gain
223            }
224        }
225        Some(Constraint::Positive) => {
226            if left_weight >= right_weight {
227                f32::NEG_INFINITY
228            } else {
229                gain
230            }
231        }
232    }
233}
234
/// Apply L1 soft-thresholding to `w`.
///
/// * `l1 == 0` returns `w` unchanged.
/// * Otherwise `w` is moved towards zero by `l1`.
/// * Values within `[-l1, l1]` (and NaN) become `0.0`.
#[inline]
pub fn l1_regularization(w: &f32, l1: &f32) -> f32 {
    let (w, l1) = (*w, *l1);
    if l1 == 0.0 {
        return w;
    }
    if w > l1 {
        w - l1
    } else if w < -l1 {
        w + l1
    } else {
        0.0
    }
}
248
249/// Calculate the weight of a given node, given the sum
250/// of the gradients, and the hessians in a node.
251#[inline]
252pub fn weight(
253    l1: &f32,
254    l2: &f32,
255    max_delta_step: &f32,
256    gradient_sum: f32,
257    hessian_sum: f32,
258) -> f32 {
259    let w = -(l1_regularization(&gradient_sum, l1) / (hessian_sum + l2));
260    if (max_delta_step != &0.) && (&w.abs() > max_delta_step) {
261        return max_delta_step.copysign(w);
262    }
263    w
264}
265
266const LANES: usize = 16;
267
268/// Fast summation, ends up being roughly 8 to 10 times faster
269/// than values.iter().copied().sum().
270/// Shamelessly stolen from https://stackoverflow.com/a/67191480
271#[inline]
272pub fn fast_sum<T: FloatData<T>>(values: &[T]) -> T {
273    let chunks = values.chunks_exact(LANES);
274    let remainder = chunks.remainder();
275
276    let sum = chunks.fold([T::ZERO; LANES], |mut acc, chunk| {
277        let chunk: [T; LANES] = chunk.try_into().unwrap();
278        for i in 0..LANES {
279            acc[i] += chunk[i];
280        }
281        acc
282    });
283
284    let remainder: T = remainder.iter().copied().sum();
285
286    let mut reduced = T::ZERO;
287    for s in sum.iter().take(LANES) {
288        reduced += *s;
289    }
290    reduced + remainder
291}
292
293/// Fast summation, but using f64 as the internal representation so that
294/// we don't have issues with the precision.
295/// This way, we can still work with f32 values, but get the correct sum
296/// value.
297#[inline]
298pub fn fast_f64_sum(values: &[f32]) -> f32 {
299    let chunks = values.chunks_exact(LANES);
300    let remainder = chunks.remainder();
301
302    let sum = chunks.fold([f64::ZERO; LANES], |mut acc, chunk| {
303        let chunk: [f32; LANES] = chunk.try_into().unwrap();
304        for i in 0..LANES {
305            acc[i] += f64::from(chunk[i]);
306        }
307        acc
308    });
309
310    let remainder: f64 = remainder
311        .iter()
312        .fold(f64::ZERO, |acc, b| acc + f64::from(*b));
313
314    let mut reduced: f64 = 0.;
315    for s in sum.iter().take(LANES) {
316        reduced += *s;
317    }
318    (reduced + remainder) as f32
319}
320
321pub fn naive_sum<T: FloatData<T>>(values: &[T]) -> T {
322    values.iter().copied().sum()
323}
324
325/// Naive weighted percentiles calculation.
326///
327/// Currently this function does not support missing values.
328///   
329/// * `v` - A Vector of which to find percentiles for.
330/// * `sample_weight` - Sample weights for the instances of the vector.
331/// * `percentiles` - Percentiles to look for in the data. This should be
332///     values from 0 to 1, and in sorted order.
333pub fn percentiles<T>(v: &[T], sample_weight: &[T], percentiles: &[T]) -> Vec<T>
334where
335    T: FloatData<T>,
336{
337    let mut idx: Vec<usize> = (0..v.len()).collect();
338    idx.sort_unstable_by(|a, b| v[*a].partial_cmp(&v[*b]).unwrap());
339
340    // Setup percentiles
341    let mut pcts = VecDeque::from_iter(percentiles.iter());
342    let mut current_pct = *pcts.pop_front().expect("No percentiles were provided");
343
344    // Prepare a vector to put the percentiles in...
345    let mut p = Vec::new();
346    let mut cuml_pct = T::ZERO;
347    let mut current_value = v[idx[0]];
348    let total_values = fast_sum(sample_weight);
349
350    for i in idx.iter() {
351        if current_value != v[*i] {
352            current_value = v[*i];
353        }
354        cuml_pct += sample_weight[*i] / total_values;
355        if (current_pct == T::ZERO) || (cuml_pct >= current_pct) {
356            // We loop here, because the same number might be a valid
357            // value to make the percentile several times.
358            while cuml_pct >= current_pct {
359                p.push(current_value);
360                match pcts.pop_front() {
361                    Some(p_) => current_pct = *p_,
362                    None => return p,
363                }
364            }
365        } else if current_pct == T::ONE {
366            if let Some(i_) = idx.last() {
367                p.push(v[*i_]);
368                break;
369            }
370        }
371    }
372    p
373}
374
375// Return the index of the first value in a slice that
376// is less another number. This will return the first index for
377// missing values.
378/// Return the index of the first value in a sorted
379/// vector that is greater than a provided value.
380///
381/// * `x` - The sorted slice of values.
382/// * `v` - The value used to calculate the first
383///   value larger than it.
384#[inline]
385pub fn map_bin<T: FloatData<T>>(x: &[T], v: &T, missing: &T) -> Option<u16> {
386    if v.is_nan() || (v == missing) {
387        return Some(0);
388    }
389    let mut low = 0;
390    let mut high = x.len();
391    while low != high {
392        let mid = (low + high) / 2;
393        // This will always be false for NaNs.
394        // This it will force us to the bottom,
395        // and thus Zero.
396        if x[mid] <= *v {
397            low = mid + 1;
398        } else {
399            high = mid;
400        }
401    }
402    u16::try_from(low).ok()
403}
404
405/// Provided a list of index values, pivot those values
406/// around a specific split value so all of the values less
407/// than the split value are on one side, and then all of the
408/// values greater than or equal to the split value are above.
409///
410/// * `index` - The index values to sort.
411/// * `feature` - The feature vector to use to sort the index by.
412/// * `split_value` - the split value to use to pivot on.
413/// * `missing_right` - Should missing values go to the left, or
414///    to the right of the split value.
415#[inline]
416pub fn pivot_on_split(
417    index: &mut [usize],
418    feature: &[u16],
419    split_value: u16,
420    missing_right: bool,
421) -> usize {
422    // I think we can do this in O(n) time...
423    let mut low = 0;
424    let mut high = index.len() - 1;
425    let max_idx = high;
426    while low < high {
427        // Go until we find a low value that needs to
428        // be swapped, this will be the first value
429        // that our split value is less or equal to.
430        while low < max_idx {
431            let l = feature[index[low]];
432            match missing_compare(&split_value, l, missing_right) {
433                Ordering::Less | Ordering::Equal => break,
434                Ordering::Greater => low += 1,
435            }
436        }
437        while high > low {
438            let h = feature[index[high]];
439            // Go until we find a high value that needs to be
440            // swapped, this will be the first value that our
441            // split_value is greater than.
442            match missing_compare(&split_value, h, missing_right) {
443                Ordering::Less | Ordering::Equal => high -= 1,
444                Ordering::Greater => break,
445            }
446        }
447        if low < high {
448            index.swap(high, low);
449        }
450    }
451    low
452}
453
/// Provided a list of index values, pivot those values into three groups,
/// in place:
/// 1. Missing values (a feature value of zero) are pushed to the bottom.
/// 2. Then come the values less than the split value.
/// 3. Last come the values greater than or equal to the split value.
///
/// Returns a tuple of two positions:
/// * the index of the first non-missing value (the number of missing
///   values),
/// * the index of the first value that is greater than or equal to
///   `split_value`.
///
/// Inputs that are empty, entirely missing, or contain no right-hand
/// values are handled: e.g. all missing values give `(len, len)`.
///
/// * `index` - The index values to sort.
/// * `feature` - The feature vector to use to sort the index by.
/// * `split_value` - the split value to use to pivot on.
#[inline]
pub fn pivot_on_split_exclude_missing(
    index: &mut [usize],
    feature: &[u16],
    split_value: u16,
) -> (usize, usize) {
    // Nothing to partition; this also avoids underflowing `index.len() - 1`.
    if index.is_empty() {
        return (0, 0);
    }
    let mut low = 0;
    let mut high = index.len() - 1;
    // The index of the first value, that is not
    // missing.
    let mut missing = 0;
    let max_idx = high;
    while low < high {
        // Go until we find a low value that needs to
        // be swapped, this will be the first value
        // that our split value is less or equal to.
        while low < max_idx {
            let l = feature[index[low]];
            if l == 0 {
                index.swap(missing, low);
                missing += 1;
            }
            match &split_value.cmp(&l) {
                Ordering::Less | Ordering::Equal => break,
                Ordering::Greater => low += 1,
            }
        }
        while high > low {
            let h = feature[index[high]];
            // If this is missing, we need to
            // swap this value with missing, and
            // then that value with low.
            if h == 0 {
                index.swap(missing, high);
                missing += 1;
                // Low must be at least equal to
                // missing. Otherwise, we would get
                // stuck, because low will be zero
                // then...
                if missing > low {
                    low = missing;
                }
            }
            // Go until we find a high value that needs to be
            // swapped, this will be the first value that our
            // split_value is greater than.
            match &split_value.cmp(&h) {
                Ordering::Less | Ordering::Equal => high -= 1,
                Ordering::Greater => break,
            }
        }
        if low < high {
            index.swap(high, low);
        }
    }
    // The low scan stops at `max_idx` without classifying that element. Once
    // `low` has reached it, the high scan never looks at it either. This
    // happens when no right-hand value comes before it: all missing, all
    // left, or a single element. In every other exit the element at `low`
    // belongs on the right and nothing changes.
    let last = feature[index[low]];
    if last == 0 {
        // Move it into the missing block. The value displaced from
        // `missing` is a non-missing left-hand value, so it stays left of
        // the new split.
        index.swap(missing, low);
        missing += 1;
        low += 1;
    } else if split_value > last {
        low += 1;
    }
    (missing, low)
}
528
/// Compare `split_value` against `cmp_value`, where a `cmp_value` of zero
/// means missing.
///
/// The split value itself is never missing. A missing value is placed by
/// `missing_right`:
/// * `true`: the split value compares `Less`, so missing goes right.
/// * `false`: the split value compares `Greater`, so missing goes left.
///
/// Non-missing values are compared normally.
#[inline]
pub fn missing_compare(split_value: &u16, cmp_value: u16, missing_right: bool) -> Ordering {
    match (cmp_value, missing_right) {
        (0, true) => Ordering::Less,
        (0, false) => Ordering::Greater,
        _ => split_value.cmp(&cmp_value),
    }
}
548
/// Round `n` to `precision` decimal places, rounding halfway cases away
/// from zero (the behaviour of `f64::round`).
#[inline]
pub fn precision_round(n: f64, precision: i32) -> f64 {
    let scale = 10_f64.powi(precision);
    let scaled = n * scale;
    scaled.round() / scale
}
554
555#[cfg(test)]
556mod tests {
557    use super::*;
558    use rand::rngs::StdRng;
559    use rand::seq::SliceRandom;
560    use rand::Rng;
561    use rand::SeedableRng;
562    #[test]
563    fn test_round() {
564        assert_eq!(0.3, precision_round(0.3333, 1));
565        assert_eq!(0.2343, precision_round(0.2343123123123, 4));
566    }
567    #[test]
568    fn test_percentiles() {
569        let v = vec![4., 5., 6., 1., 2., 3., 7., 8., 9., 10.];
570        let w = vec![1.; v.len()];
571        let p = vec![0.3, 0.5, 0.75, 1.0];
572        let p = percentiles(&v, &w, &p);
573        assert_eq!(p, vec![3.0, 5.0, 8.0, 10.0]);
574    }
575
576    #[test]
577    fn test_percentiles_weighted() {
578        let v = vec![10., 8., 9., 1., 2., 3., 6., 7., 4., 5.];
579        let w = vec![1., 1., 1., 1., 1., 2., 1., 1., 5., 1.];
580        let p = vec![0.3, 0.5, 0.75, 1.0];
581        let p = percentiles(&v, &w, &p);
582        assert_eq!(p, vec![4.0, 4.0, 7.0, 10.0]);
583    }
584
585    #[test]
586    fn test_map_bin_or_equal() {
587        let v = vec![f64::MIN, 1., 4., 8., 9.];
588        assert_eq!(1, map_bin(&v, &0., &f64::NAN).unwrap());
589        assert_eq!(2, map_bin(&v, &1., &f64::NAN).unwrap());
590        // Less than the bin value of 2, means the value is less
591        // than 4...
592        assert_eq!(2, map_bin(&v, &2., &f64::NAN).unwrap());
593        assert_eq!(3, map_bin(&v, &4., &f64::NAN).unwrap());
594        assert_eq!(5, map_bin(&v, &9., &f64::NAN).unwrap());
595        assert_eq!(5, map_bin(&v, &10., &f64::NAN).unwrap());
596        assert_eq!(2, map_bin(&v, &1., &f64::NAN).unwrap());
597        assert_eq!(0, map_bin(&v, &f64::NAN, &f64::NAN).unwrap());
598    }
599
600    #[test]
601    fn test_map_bin_or_equal_num_miss() {
602        let v = vec![f64::MIN, 1., 4., 8., 9.];
603        assert_eq!(1, map_bin(&v, &0., &-99.).unwrap());
604        assert_eq!(2, map_bin(&v, &1., &-99.).unwrap());
605        // Less than the bin value of 2, means the value is less
606        // than 4...
607        assert_eq!(2, map_bin(&v, &2., &-99.).unwrap());
608        assert_eq!(3, map_bin(&v, &4., &-99.).unwrap());
609        assert_eq!(5, map_bin(&v, &9., &-99.).unwrap());
610        assert_eq!(5, map_bin(&v, &10., &-99.).unwrap());
611        assert_eq!(2, map_bin(&v, &1., &-99.).unwrap());
612        assert_eq!(0, map_bin(&v, &-99., &-99.).unwrap());
613    }
614
615    #[test]
616    fn test_missing_compare() {
617        assert_eq!(missing_compare(&10, 0, true), Ordering::Less);
618        assert_eq!(missing_compare(&10, 0, false), Ordering::Greater);
619        assert_eq!(missing_compare(&10, 11, true), Ordering::Less);
620        assert_eq!(missing_compare(&10, 1, true), Ordering::Greater);
621    }
622
623    #[test]
624    fn test_pivot() {
625        fn pivot_assert(
626            f: &[u16],
627            idx: &[usize],
628            split_i: usize,
629            missing_right: bool,
630            split_val: u16,
631        ) {
632            if missing_right {
633                for i in 0..split_i {
634                    assert!((f[idx[i]] < split_val) && f[idx[i]] != 0);
635                }
636                for i in idx[split_i..].iter() {
637                    assert!((f[*i] >= split_val) || (f[*i] == 0));
638                }
639            } else {
640                for i in 0..split_i {
641                    assert!((f[idx[i]] < split_val) || (f[idx[i]] == 0));
642                }
643                for i in idx[split_i..].iter() {
644                    assert!((f[*i] >= split_val) || (f[*i] != 0));
645                }
646            }
647        }
648
649        let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7];
650        let f = vec![15, 10, 10, 11, 3, 18, 0, 9, 3, 5, 2, 6, 13, 19, 14];
651        let split_i = pivot_on_split(&mut idx, &f, 10, true);
652        pivot_assert(&f, &idx, split_i, true, 10);
653
654        let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7];
655        let f = vec![15, 10, 10, 11, 3, 18, 0, 9, 3, 5, 2, 6, 13, 19, 14];
656        let split_i = pivot_on_split(&mut idx, &f, 10, false);
657        pivot_assert(&f, &idx, split_i, false, 10);
658
659        // Test Minimum value...
660        let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7];
661        let f = vec![15, 10, 10, 11, 3, 18, 0, 9, 3, 5, 2, 6, 13, 19, 14];
662        let split_i = pivot_on_split(&mut idx, &f, 1, true);
663        pivot_assert(&f, &idx, split_i, true, 1);
664
665        let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7];
666        let f = vec![15, 10, 10, 11, 3, 18, 0, 9, 3, 5, 2, 6, 13, 19, 14];
667        let split_i = pivot_on_split(&mut idx, &f, 1, false);
668        pivot_assert(&f, &idx, split_i, false, 1);
669
670        // Test Maximum value...
671        let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7];
672        let f = vec![15, 10, 10, 11, 3, 18, 0, 9, 3, 5, 2, 6, 13, 19, 14];
673        let split_i = pivot_on_split(&mut idx, &f, 19, true);
674        pivot_assert(&f, &idx, split_i, true, 19);
675
676        let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7];
677        let f = vec![15, 10, 10, 11, 3, 18, 0, 9, 3, 5, 2, 6, 13, 19, 14];
678        let split_i = pivot_on_split(&mut idx, &f, 19, false);
679        pivot_assert(&f, &idx, split_i, false, 19);
680
681        // Random tests... right...
682        // With missing
683        let index = (0..100).collect::<Vec<usize>>();
684        let mut rng = StdRng::seed_from_u64(0);
685        let f = (0..100).map(|_| rng.gen_range(0..15)).collect::<Vec<u16>>();
686        let mut idx = index
687            .choose_multiple(&mut rng, 73)
688            .copied()
689            .collect::<Vec<usize>>();
690        let split_i = pivot_on_split(&mut idx, &f, 7, true);
691        pivot_assert(&f, &idx, split_i, true, 7);
692
693        // Already sorted...
694        let split_i = pivot_on_split(&mut idx, &f, 7, true);
695        pivot_assert(&f, &idx, split_i, true, 7);
696
697        // Reversed
698        idx.reverse();
699        let split_i = pivot_on_split(&mut idx, &f, 7, true);
700        pivot_assert(&f, &idx, split_i, true, 7);
701
702        // Without missing...
703        let index = (0..100).collect::<Vec<usize>>();
704        let mut rng = StdRng::seed_from_u64(0);
705        let f = (0..100).map(|_| rng.gen_range(1..15)).collect::<Vec<u16>>();
706        let mut idx = index
707            .choose_multiple(&mut rng, 73)
708            .copied()
709            .collect::<Vec<usize>>();
710        let split_i = pivot_on_split(&mut idx, &f, 5, true);
711        pivot_assert(&f, &idx, split_i, true, 5);
712
713        // Using max...
714        let index = (0..100).collect::<Vec<usize>>();
715        let mut rng = StdRng::seed_from_u64(0);
716        let f = (0..100).map(|_| rng.gen_range(0..15)).collect::<Vec<u16>>();
717        let mut idx = index
718            .choose_multiple(&mut rng, 73)
719            .copied()
720            .collect::<Vec<usize>>();
721        let sv = idx.iter().map(|i| f[*i]).max().unwrap();
722        let split_i = pivot_on_split(&mut idx, &f, sv, true);
723        pivot_assert(&f, &idx, split_i, true, sv);
724
725        // Using non-0 minimum...
726        let index = (0..100).collect::<Vec<usize>>();
727        let mut rng = StdRng::seed_from_u64(0);
728        let f = (0..100).map(|_| rng.gen_range(0..15)).collect::<Vec<u16>>();
729        let mut idx = index
730            .choose_multiple(&mut rng, 73)
731            .copied()
732            .collect::<Vec<usize>>();
733        let sv = idx
734            .iter()
735            .filter(|i| f[**i] > 0)
736            .map(|i| f[*i])
737            .min()
738            .unwrap();
739        let split_i = pivot_on_split(&mut idx, &f, sv, true);
740        pivot_assert(&f, &idx, split_i, true, sv);
741
742        // Using non-0 minimum with no missing...
743        let index = (0..100).collect::<Vec<usize>>();
744        let mut rng = StdRng::seed_from_u64(0);
745        let f = (0..100).map(|_| rng.gen_range(1..15)).collect::<Vec<u16>>();
746        let mut idx = index
747            .choose_multiple(&mut rng, 73)
748            .copied()
749            .collect::<Vec<usize>>();
750        let sv = idx
751            .iter()
752            .filter(|i| f[**i] > 0)
753            .map(|i| f[*i])
754            .min()
755            .unwrap();
756        let split_i = pivot_on_split(&mut idx, &f, sv, true);
757        pivot_assert(&f, &idx, split_i, true, sv);
758
759        // Left
760        let index = (0..100).collect::<Vec<usize>>();
761        let mut rng = StdRng::seed_from_u64(0);
762        let f = (0..100).map(|_| rng.gen_range(0..15)).collect::<Vec<u16>>();
763        let mut idx = index
764            .choose_multiple(&mut rng, 73)
765            .copied()
766            .collect::<Vec<usize>>();
767        let split_i = pivot_on_split(&mut idx, &f, 7, false);
768        pivot_assert(&f, &idx, split_i, false, 7);
769
770        // Already sorted...
771        let split_i = pivot_on_split(&mut idx, &f, 7, false);
772        pivot_assert(&f, &idx, split_i, false, 7);
773
774        // Reversed
775        idx.reverse();
776        let split_i = pivot_on_split(&mut idx, &f, 7, false);
777        pivot_assert(&f, &idx, split_i, false, 7);
778
779        // Without missing...
780        let index = (0..100).collect::<Vec<usize>>();
781        let mut rng = StdRng::seed_from_u64(0);
782        let f = (0..100).map(|_| rng.gen_range(1..15)).collect::<Vec<u16>>();
783        let mut idx = index
784            .choose_multiple(&mut rng, 73)
785            .copied()
786            .collect::<Vec<usize>>();
787        let split_i = pivot_on_split(&mut idx, &f, 5, false);
788        pivot_assert(&f, &idx, split_i, false, 5);
789
790        // Using max...
791        let index = (0..100).collect::<Vec<usize>>();
792        let mut rng = StdRng::seed_from_u64(0);
793        let f = (0..100).map(|_| rng.gen_range(0..15)).collect::<Vec<u16>>();
794        let mut idx = index
795            .choose_multiple(&mut rng, 73)
796            .copied()
797            .collect::<Vec<usize>>();
798        let sv = idx.iter().map(|i| f[*i]).max().unwrap();
799        let split_i = pivot_on_split(&mut idx, &f, sv, false);
800        pivot_assert(&f, &idx, split_i, false, sv);
801
802        // Using non-0 minimum...
803        let index = (0..100).collect::<Vec<usize>>();
804        let mut rng = StdRng::seed_from_u64(0);
805        let f = (0..100).map(|_| rng.gen_range(0..15)).collect::<Vec<u16>>();
806        let mut idx = index
807            .choose_multiple(&mut rng, 73)
808            .copied()
809            .collect::<Vec<usize>>();
810        let sv = idx
811            .iter()
812            .filter(|i| f[**i] > 0)
813            .map(|i| f[*i])
814            .min()
815            .unwrap();
816        let split_i = pivot_on_split(&mut idx, &f, sv, false);
817        pivot_assert(&f, &idx, split_i, false, sv);
818
819        // Using non-0 minimum with no missing...
820        let index = (0..100).collect::<Vec<usize>>();
821        let mut rng = StdRng::seed_from_u64(0);
822        let f = (0..100).map(|_| rng.gen_range(1..15)).collect::<Vec<u16>>();
823        let mut idx = index
824            .choose_multiple(&mut rng, 73)
825            .copied()
826            .collect::<Vec<usize>>();
827        let sv = idx
828            .iter()
829            .filter(|i| f[**i] > 0)
830            .map(|i| f[*i])
831            .min()
832            .unwrap();
833        let split_i = pivot_on_split(&mut idx, &f, sv, false);
834        pivot_assert(&f, &idx, split_i, false, sv);
835    }
836
837    #[test]
838    fn test_pivot_missing() {
839        fn pivot_missing_assert(
840            split_value: u16,
841            idx: &[usize],
842            f: &[u16],
843            split_i: &(usize, usize),
844        ) {
845            // Check they are lower than..
846            for i in 0..split_i.1 {
847                assert!(f[idx[i]] < split_value);
848            }
849            // Check missing got moved
850            for i in 0..split_i.0 {
851                assert!(f[idx[i]] == 0);
852            }
853            // Check none are less than...
854            for i in split_i.1..(idx.len()) {
855                assert!(!(f[idx[i]] < split_value));
856            }
857            // Check none other are missing...
858            for i in split_i.0..(idx.len()) {
859                assert!(f[idx[i]] != 0);
860            }
861        }
862        // TODO: Add more tests for this...
863        // Using minimum value...
864        let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7];
865        let f = vec![15, 10, 10, 0, 3, 0, 0, 9, 3, 5, 2, 6, 13, 19, 14];
866        let split_i = pivot_on_split_exclude_missing(&mut idx, &f, 1);
867        // let map_ = idx.iter().map(|i| f[*i]).collect::<Vec<u16>>();
868        // println!("{:?}, {:?}, {:?}", split_i, idx, map_);
869        pivot_missing_assert(1, &idx, &f, &split_i);
870
871        // Higher value...
872        let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7];
873        let f = vec![15, 10, 10, 0, 3, 0, 0, 9, 3, 5, 2, 6, 13, 19, 14];
874        let split_i = pivot_on_split_exclude_missing(&mut idx, &f, 10);
875        //let split_i = pivot_on_split(&mut idx, &f, 10, false);
876        // let map_ = idx.iter().map(|i| f[*i]).collect::<Vec<u16>>();
877        // println!("{:?}, {:?}, {:?}", split_i, idx, map_);
878        pivot_missing_assert(10, &idx, &f, &split_i);
879
880        // Run it again, and ensure it works on an already sorted list...
881        let split_i = pivot_on_split_exclude_missing(&mut idx, &f, 10);
882        //let split_i = pivot_on_split(&mut idx, &f, 10, false);
883        // let map_ = idx.iter().map(|i| f[*i]).collect::<Vec<u16>>();
884        // println!("{:?}, {:?}, {:?}", split_i, idx, map_);
885        pivot_missing_assert(10, &idx, &f, &split_i);
886
887        // Run it again, and ensure it works on reversed list...
888        idx.reverse();
889        let split_i = pivot_on_split_exclude_missing(&mut idx, &f, 10);
890        //let split_i = pivot_on_split(&mut idx, &f, 10, false);
891        // let map_ = idx.iter().map(|i| f[*i]).collect::<Vec<u16>>();
892        // println!("{:?}, {:?}, {:?}", split_i, idx, map_);
893        pivot_missing_assert(10, &idx, &f, &split_i);
894
895        // Small test done with python
896        let mut idx = vec![0, 1, 2, 3, 4, 5];
897        let f = vec![1, 0, 1, 3, 0, 4];
898        let split_i = pivot_on_split_exclude_missing(&mut idx, &f, 2);
899        // let map_ = idx.iter().map(|i| f[*i]).collect::<Vec<u16>>();
900        // println!("{:?}, {:?}, {:?}", split_i, idx, map_);
901        pivot_missing_assert(2, &idx, &f, &split_i);
902
903        // Ensure it works on all missing...
904        // let mut idx = vec![0, 1, 2, 3, 4, 5];
905        // let f: Vec<u16> = vec![3; idx.len()];
906        // let split_i = pivot_on_split_exclude_missing(&mut idx, &f, 2);
907        // // let map_ = idx.iter().map(|i| f[*i]).collect::<Vec<u16>>();
908        // // println!("{:?}, {:?}, {:?}", split_i, idx, map_);
909        // pivot_missing_assert(2, &idx, &f, &split_i);
910
911        // Check if none missing...
912        // TODO: Add more tests for this...
913        let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7];
914        let f = vec![15, 10, 10, 2, 3, 5, 7, 9, 3, 5, 2, 6, 13, 19, 14];
915        let split_i = pivot_on_split_exclude_missing(&mut idx, &f, 10);
916        //let split_i = pivot_on_split(&mut idx, &f, 10, false);
917        // println!("{:?}, {:?}, {:?}", split_i, idx, map_);
918        // let map_ = idx.iter().map(|i| f[*i]).collect::<Vec<u16>>();
919        // println!("{:?}, {:?}, {:?}", split_i, idx, map_);
920        pivot_missing_assert(10, &idx, &f, &split_i);
921
922        // Random tests...
923        // With missing
924        let index = (0..100).collect::<Vec<usize>>();
925        let mut rng = StdRng::seed_from_u64(0);
926        let f = (0..100).map(|_| rng.gen_range(0..15)).collect::<Vec<u16>>();
927        let mut idx = index
928            .choose_multiple(&mut rng, 73)
929            .copied()
930            .collect::<Vec<usize>>();
931        let split_i = pivot_on_split_exclude_missing(&mut idx, &f, 10);
932        pivot_missing_assert(10, &idx, &f, &split_i);
933
934        // Already sorted...
935        let split_i = pivot_on_split_exclude_missing(&mut idx, &f, 10);
936        pivot_missing_assert(10, &idx, &f, &split_i);
937
938        // Without missing...
939        let index = (0..100).collect::<Vec<usize>>();
940        let mut rng = StdRng::seed_from_u64(0);
941        let f = (0..100).map(|_| rng.gen_range(1..15)).collect::<Vec<u16>>();
942        let mut idx = index
943            .choose_multiple(&mut rng, 73)
944            .copied()
945            .collect::<Vec<usize>>();
946        let split_i = pivot_on_split_exclude_missing(&mut idx, &f, 5);
947        // let map_ = idx.iter().map(|i| f[*i]).collect::<Vec<u16>>();
948        // println!("{:?}, {:?}, {:?}", split_i, idx, map_);
949        pivot_missing_assert(5, &idx, &f, &split_i);
950
951        // Using max...
952        let index = (0..100).collect::<Vec<usize>>();
953        let mut rng = StdRng::seed_from_u64(0);
954        let f = (0..100).map(|_| rng.gen_range(0..15)).collect::<Vec<u16>>();
955        let mut idx = index
956            .choose_multiple(&mut rng, 73)
957            .copied()
958            .collect::<Vec<usize>>();
959        let sv = idx.iter().map(|i| f[*i]).max().unwrap();
960        let split_i = pivot_on_split_exclude_missing(&mut idx, &f, sv);
961        pivot_missing_assert(sv, &idx, &f, &split_i);
962
963        // Using non-0 minimum...
964        let index = (0..100).collect::<Vec<usize>>();
965        let mut rng = StdRng::seed_from_u64(0);
966        let f = (0..100).map(|_| rng.gen_range(0..15)).collect::<Vec<u16>>();
967        let mut idx = index
968            .choose_multiple(&mut rng, 73)
969            .copied()
970            .collect::<Vec<usize>>();
971        let sv = idx
972            .iter()
973            .filter(|i| f[**i] > 0)
974            .map(|i| f[*i])
975            .min()
976            .unwrap();
977        let split_i = pivot_on_split_exclude_missing(&mut idx, &f, sv);
978        pivot_missing_assert(sv, &idx, &f, &split_i);
979
980        // Using non-0 minimum with no missing...
981        let index = (0..100).collect::<Vec<usize>>();
982        let mut rng = StdRng::seed_from_u64(0);
983        let f = (0..100).map(|_| rng.gen_range(1..15)).collect::<Vec<u16>>();
984        let mut idx = index
985            .choose_multiple(&mut rng, 73)
986            .copied()
987            .collect::<Vec<usize>>();
988        let sv = idx
989            .iter()
990            .filter(|i| f[**i] > 0)
991            .map(|i| f[*i])
992            .min()
993            .unwrap();
994        let split_i = pivot_on_split_exclude_missing(&mut idx, &f, sv);
995        pivot_missing_assert(sv, &idx, &f, &split_i);
996    }
997
998    #[test]
999    fn test_fast_f64_sum() {
1000        let records = 300000;
1001        let vec = vec![0.23500371; records];
1002        assert_ne!(vec.iter().sum::<f32>(), vec[0] * (records as f32));
1003        assert_eq!(vec[0] * (records as f32), fast_f64_sum(&vec));
1004        // println!("Sum Result: {}", vec.iter().sum::<f32>());
1005        // println!("Multiplication Results {}", vec[0] * (records as f32));
1006        // println!("f64_sum Results {}", f64_sum(&vec));
1007    }
1008
1009    #[test]
1010    fn test_fmt_vec_output() {
1011        let v = Vec::<f32>::new();
1012        assert_eq!(fmt_vec_output(&v), String::from(""));
1013        let v: Vec<f32> = vec![0.1, 1.0];
1014        assert_eq!(fmt_vec_output(&v), String::from("0.1000, 1.0000"));
1015    }
1016}