// hdbscan 0.12.0
//
// HDBSCAN clustering in pure Rust: a substantial improvement on DBSCAN, capable of
// identifying clusters of varying densities.
use crate::data_wrappers::MSTEdge;
use crate::DistanceMetric;
use num_traits::Float;

/// A builder that produces the spanning tree of the mutual-reachability graph over a dataset.
pub(crate) trait MinSpanningTree<'a, T> {
    /// Computes the spanning-tree edges.
    ///
    /// The implementations in this module return the edges sorted by ascending distance.
    fn compute(&self) -> Vec<MSTEdge<T>>;
}

/// Inputs shared by every spanning-tree builder: the sample rows, the distance metric
/// and each sample's core distance.
#[derive(Clone, Debug)]
struct MinSpanningTreeCommon<'a, T> {
    // With `DistanceMetric::Precalculated`, rows are read as a distance matrix (`data[a][b]`).
    // Otherwise each row is one sample passed to `DistanceMetric::calc_dist`.
    data: &'a [Vec<T>],
    // Metric used for the raw pairwise distance between two samples.
    dist_metric: DistanceMetric,
    // Core distance of each sample, indexed like `data`.
    core_distances: &'a [T],
    // Number of samples; the constructor sets it to `data.len()`.
    n_samples: usize,
}

impl<'a, T: Float> MinSpanningTreeCommon<'a, T> {
    /// Bundles the builder inputs, recording the sample count as the number of rows in `data`.
    fn new(data: &'a [Vec<T>], dist_metric: DistanceMetric, core_distances: &'a [T]) -> Self {
        let n_samples = data.len();
        Self {
            data,
            dist_metric,
            core_distances,
            n_samples,
        }
    }

    /// Mutual-reachability distance between samples `a` and `b`.
    ///
    /// This is the largest of the two samples' core distances and the raw distance between them.
    /// For `DistanceMetric::Precalculated` the raw distance is looked up as `data[a][b]`.
    /// For any other metric it is computed from rows `a` and `b` via `calc_dist`.
    ///
    /// # Panics
    /// Panics if `a` or `b` is out of bounds for `core_distances` or `data`.
    fn calc_mutual_reachability_dist(&self, a: usize, b: usize) -> T {
        let (core_a, core_b) = (self.core_distances[a], self.core_distances[b]);
        let larger_core = core_a.max(core_b);

        if self.dist_metric == DistanceMetric::Precalculated {
            return larger_core.max(self.data[a][b]);
        }
        larger_core.max(self.dist_metric.calc_dist(&self.data[a], &self.data[b]))
    }

    /// Sorts the edges in place by ascending distance.
    ///
    /// The sort is stable, so edges with equal distances keep their relative order.
    ///
    /// # Panics
    /// Panics with "Invalid floats" if a comparison meets a NaN distance.
    fn sort_mst_by_dist(&self, min_spanning_tree: &mut [MSTEdge<T>]) {
        let by_distance = |lhs: &MSTEdge<T>, rhs: &MSTEdge<T>| {
            let ordering = lhs.distance.partial_cmp(&rhs.distance);
            ordering.expect("Invalid floats")
        };
        min_spanning_tree.sort_by(by_distance);
    }
}

#[cfg(feature = "serial")]
pub(crate) mod serial {
    use super::*;
    use crate::data_wrappers::MSTEdge;
    use num_traits::Float;

    /// Single-threaded Prim's-algorithm builder for the mutual-reachability spanning tree.
    #[derive(Clone, Debug)]
    pub(crate) struct PrimsMinSpanningTree<'a, T> {
        common: MinSpanningTreeCommon<'a, T>,
    }

    impl<'a, T: Float> PrimsMinSpanningTree<'a, T> {
        /// Creates a builder that borrows `data` and `core_distances`.
        ///
        /// `core_distances` is indexed by sample, like the rows of `data`.
        pub(crate) fn new(
            data: &'a [Vec<T>],
            dist_metric: DistanceMetric,
            core_distances: &'a [T],
        ) -> Self {
            let common = MinSpanningTreeCommon::new(data, dist_metric, core_distances);
            PrimsMinSpanningTree { common }
        }
    }

    impl<'a, T: Float> MinSpanningTree<'a, T> for PrimsMinSpanningTree<'a, T> {
        /// Grows the tree from sample 0 and returns the edges sorted by ascending distance.
        ///
        /// Each step attaches the outside node with the smallest mutual-reachability distance
        /// to the tree. The result has `n_samples - 1` edges, and is empty for zero or one sample.
        fn compute(&self) -> Vec<MSTEdge<T>> {
            let n_samples = self.common.n_samples;

            let mut in_tree = vec![false; n_samples];
            // Smallest mutual-reachability distance seen so far from each node to the tree.
            // Node 0 needs no seed value: it is marked in-tree before its entry is ever read.
            // Not writing `distances[0]` also keeps empty input from panicking.
            let mut distances = vec![T::infinity(); n_samples];

            let mut mst = Vec::with_capacity(n_samples.saturating_sub(1));

            let mut left_node_id = 0;

            for _ in 1..n_samples {
                in_tree[left_node_id] = true;

                // Nearest node outside the tree, as (index, distance). It starts as `None` so that
                // an outside node is still chosen when every remaining distance is infinite.
                // Otherwise a stale, already-attached node would be reused.
                let mut nearest: Option<(usize, T)> = None;

                for (i, (dist, &attached)) in distances.iter_mut().zip(&in_tree).enumerate() {
                    if attached {
                        continue;
                    }
                    let mrd = self.common.calc_mutual_reachability_dist(left_node_id, i);
                    if mrd < *dist {
                        *dist = mrd;
                    }
                    // Strict `<` keeps the first node on ties.
                    let closer = match nearest {
                        Some((_, best)) => *dist < best,
                        None => true,
                    };
                    if closer {
                        nearest = Some((i, *dist));
                    }
                }

                // Each step attaches exactly one new node, so at least one node remains outside.
                let (right_node_id, distance) =
                    nearest.expect("a node outside the tree remains while the tree is incomplete");

                // NOTE(review): `left_node_id` is the node attached in the previous step, not
                // necessarily the tree node that produced `distance`. The edges form a path in
                // attachment order weighted by Prim's attachment distances. Confirm this before
                // relying on edge endpoints for anything beyond merging components in sorted order.
                mst.push(MSTEdge {
                    left_node_id,
                    right_node_id,
                    distance,
                });
                left_node_id = right_node_id;
            }
            self.common.sort_mst_by_dist(&mut mst);
            mst
        }
    }
}

#[cfg(feature = "parallel")]
pub(crate) mod parallel {
    use super::*;
    use crate::data_wrappers::MSTEdge;
    use num_traits::Float;
    use rayon::prelude::*;

    /// Prim's-algorithm builder for the mutual-reachability spanning tree.
    ///
    /// Each step's scan over candidate nodes runs in parallel with rayon.
    #[derive(Clone, Debug)]
    pub(crate) struct PrimsMinSpanningTreePar<'a, T> {
        common: MinSpanningTreeCommon<'a, T>,
    }

    impl<'a, T: Float + Send + Sync> PrimsMinSpanningTreePar<'a, T> {
        /// Creates a builder that borrows `data` and `core_distances`.
        ///
        /// `core_distances` is indexed by sample, like the rows of `data`.
        pub(crate) fn new(
            data: &'a [Vec<T>],
            dist_metric: DistanceMetric,
            core_distances: &'a [T],
        ) -> Self {
            let common = MinSpanningTreeCommon::new(data, dist_metric, core_distances);
            PrimsMinSpanningTreePar { common }
        }
    }

    impl<'a, T: Float + Send + Sync> MinSpanningTree<'a, T> for PrimsMinSpanningTreePar<'a, T> {
        /// Grows the tree from sample 0 and returns the edges sorted by ascending distance.
        ///
        /// Each step attaches the outside node with the smallest mutual-reachability distance
        /// to the tree. The result has `n_samples - 1` edges, and is empty for zero or one sample.
        ///
        /// # Panics
        /// Panics with "Invalid floats" if two candidate distances cannot be compared.
        fn compute(&self) -> Vec<MSTEdge<T>> {
            let n_samples = self.common.n_samples;

            let mut in_tree = vec![false; n_samples];
            // Smallest mutual-reachability distance seen so far from each node to the tree.
            // Node 0 needs no seed value: it is marked in-tree before its entry is ever read.
            // Not writing `distances[0]` also keeps empty input from panicking.
            let mut distances = vec![T::infinity(); n_samples];

            let mut mst = Vec::with_capacity(n_samples.saturating_sub(1));

            let mut left_node_id = 0;

            for _ in 1..n_samples {
                in_tree[left_node_id] = true;

                // Relax every outside node against the newly attached node and pick the nearest.
                // Each closure call writes only its own `distances` slot.
                let (min_idx, min_dist) = distances
                    .par_iter_mut()
                    .enumerate()
                    .filter_map(|(i, dist)| {
                        if in_tree[i] {
                            None
                        } else {
                            let mrd = self.common.calc_mutual_reachability_dist(left_node_id, i);
                            if mrd < *dist {
                                *dist = mrd;
                            }
                            Some((i, *dist))
                        }
                    })
                    .min_by(|(_, dist_a), (_, dist_b)| {
                        dist_a.partial_cmp(dist_b).expect("Invalid floats")
                    })
                    // Each step attaches one node, so at least one node is still outside.
                    .expect("Malformed distance array");

                // NOTE(review): `left_node_id` is the node attached in the previous step, not
                // necessarily the tree node that produced `min_dist`. The edges form a path in
                // attachment order weighted by Prim's attachment distances. Confirm this before
                // relying on edge endpoints for anything beyond merging components in sorted order.
                mst.push(MSTEdge {
                    left_node_id,
                    right_node_id: min_idx,
                    distance: min_dist,
                });
                left_node_id = min_idx;
            }

            self.common.sort_mst_by_dist(&mut mst);
            mst
        }
    }
}