use crate::model::{ModelError, ModelResult};
use crate::settings::{ClusteringAlgorithmName, ClusteringSettings};
use crate::{
algorithms::ClusteringAlgorithm,
metrics::{ClusterMetrics, HCVScore},
};
use comfy_table::{
Attribute, Cell, Table, modifiers::UTF8_SOLID_INNER_BORDERS, presets::UTF8_FULL,
};
use smartcore::linalg::basic::arrays::{Array1, Array2};
use smartcore::numbers::{basenum::Number, floatnum::FloatNumber, realnum::RealNumber};
use std::collections::BTreeSet;
use std::fmt::{Display, Formatter};
/// A set of clustering algorithms fitted on one shared training matrix.
///
/// Construct with `new`, fit with `train`, then optionally score against
/// ground-truth labels with `evaluate`. The `Display` impl renders a summary
/// table with one row per trained algorithm.
pub struct ClusteringModel<INPUT, CLUSTER, InputArray, ClusterArray>
where
    INPUT: RealNumber + FloatNumber,
    CLUSTER: Number + Ord,
    InputArray: Array2<INPUT> + Clone,
    ClusterArray: Array1<CLUSTER> + Clone + std::iter::FromIterator<CLUSTER>,
{
    /// Configuration consulted to choose, fit and run the algorithms.
    settings: ClusteringSettings,
    /// Training matrix used for fitting, baseline counts and evaluation.
    x_train: InputArray,
    /// Fitted algorithms, in the order `settings.selected_algorithms()` yields
    /// them; empty until `train` runs. `predict` uses the first entry.
    trained_algorithms: Vec<TrainedClusteringAlgorithm<INPUT, CLUSTER, InputArray, ClusterArray>>,
}
impl<INPUT, CLUSTER, InputArray, ClusterArray>
    ClusteringModel<INPUT, CLUSTER, InputArray, ClusterArray>
where
    INPUT: RealNumber + FloatNumber,
    CLUSTER: Number + Ord,
    InputArray: Array2<INPUT> + Clone,
    ClusterArray: Array1<CLUSTER> + Clone + std::iter::FromIterator<CLUSTER>,
{
    /// Creates an untrained model over the training matrix `x`.
    ///
    /// No algorithm is fitted until [`Self::train`] is called.
    pub fn new(x: InputArray, settings: ClusteringSettings) -> Self {
        Self {
            settings,
            x_train: x,
            trained_algorithms: Vec::new(),
        }
    }

    /// Fits every algorithm returned by `settings.selected_algorithms()` on the
    /// training data, discarding any previously trained algorithms first.
    ///
    /// After fitting, each algorithm's cluster and noise counts on the training
    /// data are recorded so they can be shown by the `Display` impl.
    pub fn train(&mut self) {
        self.trained_algorithms.clear();
        for algorithm_name in self.settings.selected_algorithms() {
            let algorithm = ClusteringAlgorithm::from_name(algorithm_name);
            let fitted = algorithm.fit(&self.x_train, &self.settings);
            let mut trained = TrainedClusteringAlgorithm::new(algorithm_name, fitted);
            trained.compute_baseline(&self.x_train, &self.settings);
            self.trained_algorithms.push(trained);
        }
    }

    /// Names of the trained algorithms, in training order.
    #[must_use]
    pub fn trained_algorithm_names(&self) -> Vec<ClusteringAlgorithmName> {
        self.trained_algorithms
            .iter()
            .map(|entry| entry.algorithm_name)
            .collect()
    }

    /// Predicts cluster labels for `x` using the first trained algorithm.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::NotTrained`] if no algorithm has been trained;
    /// otherwise propagates any error from the algorithm's prediction.
    pub fn predict(&self, x: &InputArray) -> ModelResult<ClusterArray> {
        let algorithm = self
            .trained_algorithms
            .first()
            .ok_or(ModelError::NotTrained)?;
        algorithm.predict(x, &self.settings)
    }

    /// Predicts cluster labels for `x` using the trained algorithm named `algorithm`.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::NotTrained`] if no trained algorithm has that name
    /// (e.g. `train` has not run or the algorithm was not selected); otherwise
    /// propagates any error from the algorithm's prediction.
    pub fn predict_with(
        &self,
        algorithm: ClusteringAlgorithmName,
        x: &InputArray,
    ) -> ModelResult<ClusterArray> {
        let trained = self
            .trained_algorithms
            .iter()
            .find(|entry| entry.algorithm_name == algorithm)
            .ok_or(ModelError::NotTrained)?;
        trained.predict(x, &self.settings)
    }

    /// Scores each trained algorithm's predictions on the training data against
    /// `truth`, storing homogeneity / completeness / V-measure for display.
    ///
    /// Does nothing if the model has not been trained. If an algorithm fails to
    /// predict on the training data, its metrics are cleared to `None` rather
    /// than panicking. This mirrors how a failed prediction clears the baseline.
    pub fn evaluate(&mut self, truth: &ClusterArray) {
        for trained in &mut self.trained_algorithms {
            // A prediction error is recoverable: record "no metrics" (shown as
            // "-" in the table) instead of aborting the whole evaluation.
            trained.metrics = trained
                .predict(&self.x_train, &self.settings)
                .ok()
                .map(|predicted| {
                    let mut scores = ClusterMetrics::<CLUSTER>::hcv_score();
                    scores.compute(truth, &predicted);
                    scores
                });
        }
    }
}
impl<INPUT, CLUSTER, InputArray, ClusterArray> Display
for ClusteringModel<INPUT, CLUSTER, InputArray, ClusterArray>
where
INPUT: RealNumber + FloatNumber,
CLUSTER: Number + Ord,
InputArray: Array2<INPUT> + Clone,
ClusterArray: Array1<CLUSTER> + Clone + std::iter::FromIterator<CLUSTER>,
{
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let mut table = Table::new();
table.load_preset(UTF8_FULL);
table.apply_modifier(UTF8_SOLID_INNER_BORDERS);
table.set_header(vec![
Cell::new("Model").add_attribute(Attribute::Bold),
Cell::new("Clusters").add_attribute(Attribute::Bold),
Cell::new("Noise").add_attribute(Attribute::Bold),
Cell::new("Homogeneity").add_attribute(Attribute::Bold),
Cell::new("Completeness").add_attribute(Attribute::Bold),
Cell::new("V-Measure").add_attribute(Attribute::Bold),
]);
if self.trained_algorithms.is_empty() {
for algorithm_name in self.settings.selected_algorithms() {
table.add_row(vec![
format!("{algorithm_name} (untrained)"),
"-".to_string(),
"-".to_string(),
"-".to_string(),
"-".to_string(),
"-".to_string(),
]);
}
} else {
for entry in &self.trained_algorithms {
table.add_row(entry.display_row());
}
}
write!(f, "{table}")
}
}
/// Label statistics gathered from an algorithm's predictions on its training data.
#[derive(Debug, Clone, Copy)]
struct ClusterBaseline {
    /// Number of distinct labels that were not counted as noise.
    cluster_count: usize,
    /// Number of samples counted as noise.
    noise_count: usize,
}
impl ClusterBaseline {
    /// Bundles the two counts into a baseline; usable in `const` contexts.
    const fn new(cluster_count: usize, noise_count: usize) -> Self {
        Self {
            noise_count,
            cluster_count,
        }
    }
}
/// A fitted clustering algorithm together with the statistics gathered about it.
struct TrainedClusteringAlgorithm<INPUT, CLUSTER, InputArray, ClusterArray>
where
    INPUT: RealNumber + FloatNumber,
    CLUSTER: Number + Ord,
    InputArray: Array2<INPUT> + Clone,
    ClusterArray: Array1<CLUSTER> + Clone + std::iter::FromIterator<CLUSTER>,
{
    /// Which algorithm this is; used for lookup by name and in the display row.
    algorithm_name: ClusteringAlgorithmName,
    /// The fitted algorithm that produces predictions.
    algorithm: ClusteringAlgorithm<INPUT, CLUSTER, InputArray, ClusterArray>,
    /// Homogeneity / completeness / V-measure scores against ground-truth
    /// labels; `None` until scores have been computed.
    metrics: Option<HCVScore<CLUSTER>>,
    /// Cluster and noise counts observed on the training data; `None` until
    /// computed, or when prediction on the training data failed.
    baseline: Option<ClusterBaseline>,
}
impl<INPUT, CLUSTER, InputArray, ClusterArray>
TrainedClusteringAlgorithm<INPUT, CLUSTER, InputArray, ClusterArray>
where
INPUT: RealNumber + FloatNumber,
CLUSTER: Number + Ord,
InputArray: Array2<INPUT> + Clone,
ClusterArray: Array1<CLUSTER> + Clone + std::iter::FromIterator<CLUSTER>,
{
fn new(
algorithm_name: ClusteringAlgorithmName,
algorithm: ClusteringAlgorithm<INPUT, CLUSTER, InputArray, ClusterArray>,
) -> Self {
Self {
algorithm_name,
algorithm,
metrics: None,
baseline: None,
}
}
fn predict(&self, x: &InputArray, settings: &ClusteringSettings) -> ModelResult<ClusterArray> {
self.algorithm.predict(x, settings)
}
fn compute_baseline(&mut self, x: &InputArray, settings: &ClusteringSettings) {
let Ok(predictions) = self.predict(x, settings) else {
self.baseline = None;
return;
};
let mut unique_clusters: BTreeSet<CLUSTER> = BTreeSet::new();
let mut noise_count = 0_usize;
for label in predictions.iterator(0) {
let value = *label;
if self.algorithm_name == ClusteringAlgorithmName::DBSCAN && value == CLUSTER::zero() {
noise_count += 1;
} else {
unique_clusters.insert(value);
}
}
self.baseline = Some(ClusterBaseline::new(unique_clusters.len(), noise_count));
}
fn display_row(&self) -> Vec<String> {
let (homogeneity, completeness, v_measure) = if let Some(scores) = &self.metrics {
let format_score = |s: Option<f64>| match s {
Some(val) => format!("{val:.2}"),
None => "-".to_string(),
};
(
format_score(scores.homogeneity()),
format_score(scores.completeness()),
format_score(scores.v_measure()),
)
} else {
("-".to_string(), "-".to_string(), "-".to_string())
};
let (clusters, noise) = if let Some(baseline) = &self.baseline {
(
baseline.cluster_count.to_string(),
baseline.noise_count.to_string(),
)
} else {
("-".to_string(), "-".to_string())
};
vec![
self.algorithm_name.to_string(),
clusters,
noise,
homogeneity,
completeness,
v_measure,
]
}
}