use crate::core_distances::NnAlgorithm;
use crate::distance::DistanceMetric;
use num_traits::Num;
use std::fmt::Display;
// Fallback values used by `HyperParamBuilder::build` for parameters that were never set.
// Note: `min_samples` has no constant here; `build` falls back to the resolved
// `min_cluster_size` instead.

// Fallback for `min_cluster_size`.
const MIN_CLUSTER_SIZE_DEFAULT: usize = 5;
// Fallbacks for `max_cluster_size` (the largest representable `usize`) and for
// `allow_single_cluster`.
const MAX_CLUSTER_SIZE_DEFAULT: usize = usize::MAX; const ALLOW_SINGLE_CLUSTER_DEFAULT: bool = false;
// Fallback for `epsilon`.
const EPSILON_DEFAULT: f64 = 0.0;
// Fallback for `dist_metric`.
const DISTANCE_METRIC_DEFAULT: DistanceMetric = DistanceMetric::Euclidean;
// Fallback for `nn_algo`.
const NN_ALGORITHM_DEFAULT: NnAlgorithm = NnAlgorithm::Auto;

// Lower bounds enforced by the builder setters: an input below its bound is replaced
// by the bound (see `HyperParamBuilder::validate_input_left_bound`).

// Smallest accepted `min_cluster_size`.
const MIN_CLUSTER_SIZE_MINIMUM: usize = 2;
// Smallest accepted `max_cluster_size`.
const MAX_CLUSTER_SIZE_MINIMUM: usize = 2;
// Smallest accepted `min_samples`.
const MIN_SAMPLES_MINIMUM: usize = 1;
// Smallest accepted `epsilon`.
const EPSILON_MINIMUM: f64 = 0.0;
/// Fully resolved set of hyper parameters, as produced by [`HyperParamBuilder::build`].
///
/// All fields are crate-visible only; from outside the crate a value is obtained
/// through [`HdbscanHyperParams::builder`].
#[derive(Debug, Clone, PartialEq)]
pub struct HdbscanHyperParams {
// At least `MIN_CLUSTER_SIZE_MINIMUM` when set through the builder;
// `MIN_CLUSTER_SIZE_DEFAULT` when unset.
pub(crate) min_cluster_size: usize,
// At least `MAX_CLUSTER_SIZE_MINIMUM` when set through the builder;
// `MAX_CLUSTER_SIZE_DEFAULT` when unset.
pub(crate) max_cluster_size: usize,
// `ALLOW_SINGLE_CLUSTER_DEFAULT` when unset.
pub(crate) allow_single_cluster: bool,
// At least `MIN_SAMPLES_MINIMUM` when set through the builder; the resolved
// `min_cluster_size` when unset.
pub(crate) min_samples: usize,
// At least `EPSILON_MINIMUM` when set through the builder; `EPSILON_DEFAULT` when unset.
pub(crate) epsilon: f64,
// `DISTANCE_METRIC_DEFAULT` when unset.
pub(crate) dist_metric: DistanceMetric,
// `NN_ALGORITHM_DEFAULT` when unset.
pub(crate) nn_algo: NnAlgorithm,
}
/// Builder for [`HdbscanHyperParams`].
///
/// Every field starts as `None` (see [`HdbscanHyperParams::builder`]) and becomes
/// `Some` once the matching setter has been called. [`HyperParamBuilder::build`]
/// substitutes a fallback for every field that is still `None`.
#[derive(Debug, Clone, PartialEq)]
pub struct HyperParamBuilder {
// `None` means "not set by the caller" for each of the fields below.
min_cluster_size: Option<usize>,
max_cluster_size: Option<usize>,
allow_single_cluster: Option<bool>,
min_samples: Option<usize>,
epsilon: Option<f64>,
dist_metric: Option<DistanceMetric>,
nn_algo: Option<NnAlgorithm>,
}
impl HdbscanHyperParams {
    /// Returns the hyper parameters obtained by building an untouched builder, so
    /// every field carries the fallback chosen by [`HyperParamBuilder::build`].
    pub(crate) fn default() -> Self {
        let untouched = Self::builder();
        untouched.build()
    }

    /// Creates a [`HyperParamBuilder`] on which no parameter has been set yet.
    pub fn builder() -> HyperParamBuilder {
        // Every option starts out as `None`, i.e. "not provided by the caller".
        HyperParamBuilder {
            nn_algo: None,
            dist_metric: None,
            epsilon: None,
            min_samples: None,
            allow_single_cluster: None,
            max_cluster_size: None,
            min_cluster_size: None,
        }
    }
}
impl HyperParamBuilder {
    /// Sets `min_cluster_size`.
    ///
    /// A value below `MIN_CLUSTER_SIZE_MINIMUM` is replaced by that minimum and a
    /// warning is printed to stdout.
    pub fn min_cluster_size(mut self, min_cluster_size: usize) -> HyperParamBuilder {
        let valid_min_cluster_size = HyperParamBuilder::validate_input_left_bound(
            min_cluster_size,
            MIN_CLUSTER_SIZE_MINIMUM,
            "min_cluster_size",
        );
        self.min_cluster_size = Some(valid_min_cluster_size);
        self
    }

    /// Sets `max_cluster_size`.
    ///
    /// A value below `MAX_CLUSTER_SIZE_MINIMUM` is replaced by that minimum and a
    /// warning is printed to stdout.
    pub fn max_cluster_size(mut self, max_cluster_size: usize) -> HyperParamBuilder {
        let valid_max_cluster_size = HyperParamBuilder::validate_input_left_bound(
            max_cluster_size,
            MAX_CLUSTER_SIZE_MINIMUM,
            "max_cluster_size",
        );
        self.max_cluster_size = Some(valid_max_cluster_size);
        self
    }

    /// Sets `allow_single_cluster`. The value is stored without validation.
    pub fn allow_single_cluster(mut self, allow_single_cluster: bool) -> HyperParamBuilder {
        self.allow_single_cluster = Some(allow_single_cluster);
        self
    }

    /// Sets `min_samples`.
    ///
    /// A value below `MIN_SAMPLES_MINIMUM` is replaced by that minimum and a warning
    /// is printed to stdout. When this setter is never called, [`Self::build`] uses
    /// the resolved `min_cluster_size` instead.
    pub fn min_samples(mut self, min_samples: usize) -> HyperParamBuilder {
        let valid_min_samples = HyperParamBuilder::validate_input_left_bound(
            min_samples,
            MIN_SAMPLES_MINIMUM,
            "min_samples",
        );
        self.min_samples = Some(valid_min_samples);
        self
    }

    /// Sets `epsilon`.
    ///
    /// A value below `EPSILON_MINIMUM`, or a NaN, is replaced by `EPSILON_MINIMUM`
    /// and a warning is printed to stdout.
    pub fn epsilon(mut self, epsilon: f64) -> HyperParamBuilder {
        let valid_epsilon =
            HyperParamBuilder::validate_input_left_bound(epsilon, EPSILON_MINIMUM, "epsilon");
        self.epsilon = Some(valid_epsilon);
        self
    }

    /// Sets the distance metric. The value is stored without validation.
    pub fn dist_metric(mut self, dist_metric: DistanceMetric) -> HyperParamBuilder {
        self.dist_metric = Some(dist_metric);
        self
    }

    /// Sets the nearest neighbour algorithm. The value is stored without validation.
    pub fn nn_algorithm(mut self, nn_algorithm: NnAlgorithm) -> HyperParamBuilder {
        self.nn_algo = Some(nn_algorithm);
        self
    }

    /// Consumes the builder and returns the resolved [`HdbscanHyperParams`].
    ///
    /// Every parameter that was not set falls back to its `*_DEFAULT` constant,
    /// except `min_samples`, which falls back to the resolved `min_cluster_size`.
    pub fn build(self) -> HdbscanHyperParams {
        // Resolved first because it doubles as the fallback for `min_samples`.
        let min_cluster_size = self.min_cluster_size.unwrap_or(MIN_CLUSTER_SIZE_DEFAULT);
        HdbscanHyperParams {
            min_cluster_size,
            max_cluster_size: self.max_cluster_size.unwrap_or(MAX_CLUSTER_SIZE_DEFAULT),
            allow_single_cluster: self
                .allow_single_cluster
                .unwrap_or(ALLOW_SINGLE_CLUSTER_DEFAULT),
            min_samples: self.min_samples.unwrap_or(min_cluster_size),
            epsilon: self.epsilon.unwrap_or(EPSILON_DEFAULT),
            dist_metric: self.dist_metric.unwrap_or(DISTANCE_METRIC_DEFAULT),
            nn_algo: self.nn_algo.unwrap_or(NN_ALGORITHM_DEFAULT),
        }
    }

    /// Returns `input_param` if it compares as `>= left_bound`; otherwise prints a
    /// warning naming `param` to stdout and returns `left_bound`.
    ///
    /// Values that are unordered with respect to `left_bound` (a floating point NaN)
    /// are treated as invalid and replaced by `left_bound` as well.
    fn validate_input_left_bound<N>(input_param: N, left_bound: N, param: &str) -> N
    where
        N: Num + PartialOrd + Display,
    {
        // Accept only on a positive `>=` result. Rejecting on `input_param < left_bound`
        // would let NaN through, because every ordering comparison with NaN is false.
        if input_param >= left_bound {
            return input_param;
        }
        println!(
            "HDBSCAN_WARNING: {param} ({input_param}) cannot be lower \
            than {left_bound}. Set to {left_bound}."
        );
        left_bound
    }
}