use crate::error::{ClusterError, ClusterResult};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use torsh_tensor::Tensor;
/// Read-only view over the outcome of a clustering run.
///
/// Implementors must provide `labels` and `n_clusters`; every other accessor
/// ships with a default body, so a result type only overrides what it can
/// actually report.
pub trait ClusteringResult: Clone + std::fmt::Debug {
    /// Borrows the label tensor stored in this result.
    ///
    /// NOTE(review): presumably one label per input sample — confirm against
    /// the concrete implementors.
    fn labels(&self) -> &Tensor;

    /// Number of clusters reported by this result.
    fn n_clusters(&self) -> usize;

    /// Cluster centers, if the implementor exposes any.
    ///
    /// The default body returns `None`.
    fn centers(&self) -> Option<&Tensor> {
        None
    }

    /// Inertia value, if the implementor tracks one.
    ///
    /// The default body returns `None`.
    fn inertia(&self) -> Option<f64> {
        None
    }

    /// Iteration count, if the implementor tracks one.
    ///
    /// The default body returns `None`.
    fn n_iter(&self) -> Option<usize> {
        None
    }

    /// Whether the run converged.
    ///
    /// The default body returns `true`; implementors that can fail to
    /// converge need to override it.
    fn converged(&self) -> bool {
        true
    }

    /// Free-form string key/value metadata attached to the result.
    ///
    /// The default body returns `None`.
    fn metadata(&self) -> Option<&std::collections::HashMap<String, String>> {
        None
    }
}
/// Fitting entry point for a clustering algorithm.
pub trait Fit {
    /// Result type handed back by a successful [`Fit::fit`] call.
    type Result: ClusteringResult;

    /// Runs the algorithm on `data`.
    ///
    /// Takes `&self`, so the outcome is returned as `Self::Result` rather
    /// than written into the receiver through a mutable borrow.
    ///
    /// # Errors
    ///
    /// Returns the implementor's `ClusterResult` error when fitting fails.
    fn fit(&self, data: &Tensor) -> ClusterResult<Self::Result>;
}
/// Combined fit-and-predict entry point for a clustering algorithm.
///
/// NOTE(review): how the returned value differs from what `Fit::fit` yields is
/// not visible in this file — confirm against the implementors.
pub trait FitPredict {
    /// Result type handed back by a successful [`FitPredict::fit_predict`] call.
    type Result: ClusteringResult;

    /// Runs the algorithm on `data` and returns its result.
    ///
    /// # Errors
    ///
    /// Returns the implementor's `ClusterResult` error when the run fails.
    fn fit_predict(&self, data: &Tensor) -> ClusterResult<Self::Result>;
}
/// Maps input data to a new tensor representation.
pub trait Transform {
    /// Transforms `data` and returns the resulting tensor.
    ///
    /// # Errors
    ///
    /// Returns the implementor's `ClusterResult` error when the
    /// transformation fails.
    fn transform(&self, data: &Tensor) -> ClusterResult<Tensor>;

    /// Probability-style prediction for `_data`.
    ///
    /// # Errors
    ///
    /// The default body ignores its argument and always fails with
    /// `ClusterError::NotImplemented`; implementors that support this
    /// operation override it.
    fn predict_proba(&self, _data: &Tensor) -> ClusterResult<Tensor> {
        let message = "predict_proba not implemented for this algorithm".to_string();
        Err(ClusterError::NotImplemented(message))
    }
}
/// Prediction entry point for a clustering model.
pub trait Predict {
    /// Produces a prediction tensor for `data`.
    ///
    /// NOTE(review): presumably cluster labels for each sample — confirm
    /// against the implementors.
    ///
    /// # Errors
    ///
    /// Returns the implementor's `ClusterResult` error when prediction fails.
    fn predict(&self, data: &Tensor) -> ClusterResult<Tensor>;
}
/// Umbrella trait for clustering algorithms that support both [`Fit`] and
/// [`FitPredict`], adding naming, string-based parameter access, input
/// validation and self-description.
pub trait ClusteringAlgorithm: Fit + FitPredict {
    /// Name of the algorithm.
    fn name(&self) -> &str;

    /// Current parameters as string key/value pairs.
    fn get_params(&self) -> std::collections::HashMap<String, String>;

    /// Applies the given string key/value parameters to the algorithm.
    ///
    /// # Errors
    ///
    /// Returns the implementor's `ClusterResult` error when the parameters
    /// are rejected.
    fn set_params(
        &mut self,
        params: std::collections::HashMap<String, String>,
    ) -> ClusterResult<()>;

    /// Whether the algorithm has been fitted.
    ///
    /// The default body returns `false`.
    fn is_fitted(&self) -> bool {
        false
    }

    /// Rejects input that carries no samples.
    ///
    /// # Errors
    ///
    /// The default body returns `ClusterError::EmptyDataset` when the shape
    /// of `data` has no dimensions at all, or when its first dimension is
    /// zero. Any other shape is accepted.
    fn validate_input(&self, data: &Tensor) -> ClusterResult<()> {
        let shape = data.shape();
        let dims = shape.dims();
        // The emptiness test comes first so the index below is only evaluated
        // when a first dimension exists.
        if dims.is_empty() || dims[0] == 0 {
            Err(ClusterError::EmptyDataset)
        } else {
            Ok(())
        }
    }

    /// Names of the distance metrics the algorithm accepts.
    ///
    /// The default body returns a single entry, `"euclidean"`.
    fn supported_distance_metrics(&self) -> Vec<&str> {
        vec!["euclidean"]
    }

    /// Descriptive complexity information for the algorithm.
    ///
    /// The default body returns `AlgorithmComplexity::default()`.
    fn complexity_info(&self) -> AlgorithmComplexity {
        AlgorithmComplexity::default()
    }
}
/// Descriptive summary of an algorithm's cost and behavior.
///
/// All fields are plain data. The type derives `PartialEq`, `Eq` and `Hash`
/// so descriptors can be compared directly or used as keys in hashed
/// collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AlgorithmComplexity {
    /// Time complexity as free-form text (the default is `"O(n)"`).
    pub time_complexity: String,
    /// Space complexity as free-form text (the default is `"O(n)"`).
    pub space_complexity: String,
    /// Flag describing whether the algorithm is deterministic.
    pub deterministic: bool,
    /// Flag describing whether the algorithm can run online.
    pub online_capable: bool,
    /// Memory usage category of the algorithm.
    pub memory_pattern: MemoryPattern,
}

impl Default for AlgorithmComplexity {
    /// Builds the baseline descriptor: `"O(n)"` for both time and space,
    /// both boolean flags `false`, and [`MemoryPattern::Linear`].
    fn default() -> Self {
        Self {
            time_complexity: String::from("O(n)"),
            space_complexity: String::from("O(n)"),
            deterministic: false,
            online_capable: false,
            memory_pattern: MemoryPattern::Linear,
        }
    }
}

/// Memory usage category referenced by [`AlgorithmComplexity`].
///
/// Derives `PartialEq`, `Eq` and `Hash` so a value can be tested with `==`
/// instead of requiring a `match` for every comparison.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum MemoryPattern {
    /// Linear memory pattern.
    Linear,
    /// Quadratic memory pattern.
    Quadratic,
    /// Constant memory pattern.
    Constant,
    /// Adaptive memory pattern.
    Adaptive,
}
/// Clustering that is updated in place as new data arrives.
///
/// The updating methods take `&mut self`; the accumulated outcome is read
/// back through [`IncrementalClustering::current_result`].
pub trait IncrementalClustering {
    /// Result type returned by [`IncrementalClustering::current_result`].
    type Result: ClusteringResult;

    /// Updates the model with a single `point`.
    ///
    /// # Errors
    ///
    /// Returns the implementor's `ClusterResult` error when the update fails.
    fn partial_fit_one(&mut self, point: &Tensor) -> ClusterResult<()>;

    /// Updates the model with `data`.
    ///
    /// NOTE(review): presumably a batch of points — confirm the expected
    /// layout against the implementors.
    ///
    /// # Errors
    ///
    /// Returns the implementor's `ClusterResult` error when the update fails.
    fn partial_fit(&mut self, data: &Tensor) -> ClusterResult<()>;

    /// Returns the clustering result for the current model state.
    ///
    /// # Errors
    ///
    /// Returns the implementor's `ClusterResult` error when no result can be
    /// produced.
    fn current_result(&self) -> ClusterResult<Self::Result>;

    /// Resets the model; this operation cannot report failure.
    fn reset(&mut self);
}
/// Clustering backed by a tree structure from which flat clusterings can be
/// extracted.
pub trait HierarchicalClustering {
    /// Tree representation chosen by the implementor.
    type Tree;

    /// Borrows the tree, if one is available.
    ///
    /// The default body returns `None`.
    fn get_tree(&self) -> Option<&Self::Tree> {
        None
    }

    /// Extracts a flat clustering for the requested `n_clusters`.
    ///
    /// NOTE(review): the returned tensor presumably holds cluster labels —
    /// confirm against the implementors.
    ///
    /// # Errors
    ///
    /// Returns the implementor's `ClusterResult` error when extraction fails.
    fn extract_flat_clustering(&self, n_clusters: usize) -> ClusterResult<Tensor>;

    /// Extracts a clustering using a distance `threshold`.
    ///
    /// # Errors
    ///
    /// Returns the implementor's `ClusterResult` error when extraction fails.
    fn extract_clustering_by_distance(&self, threshold: f64) -> ClusterResult<Tensor>;
}
/// Clustering models that expose probabilistic quantities.
pub trait ProbabilisticClustering {
    /// Membership probabilities for `data`, returned as a tensor.
    ///
    /// # Errors
    ///
    /// Returns the implementor's `ClusterResult` error on failure.
    fn membership_probabilities(&self, data: &Tensor) -> ClusterResult<Tensor>;

    /// Parameters of the model as a list of name-to-tensor maps.
    ///
    /// NOTE(review): presumably one map per cluster — confirm against the
    /// implementors.
    ///
    /// # Errors
    ///
    /// Returns the implementor's `ClusterResult` error on failure.
    fn cluster_parameters(&self) -> ClusterResult<Vec<std::collections::HashMap<String, Tensor>>>;

    /// Log-likelihood of `data` as a single `f64`.
    ///
    /// # Errors
    ///
    /// Returns the implementor's `ClusterResult` error on failure.
    fn log_likelihood(&self, data: &Tensor) -> ClusterResult<f64>;

    /// Draws `n_samples` samples and returns them as a tensor.
    ///
    /// # Errors
    ///
    /// Returns the implementor's `ClusterResult` error on failure.
    fn sample(&self, n_samples: usize) -> ClusterResult<Tensor>;
}
/// Clustering models that reason about point density.
pub trait DensityBasedClustering {
    /// Core points, if the implementor exposes them.
    ///
    /// The default body returns `None`.
    fn core_points(&self) -> Option<&Tensor> {
        None
    }

    /// Noise points, if the implementor exposes them.
    ///
    /// The default body returns `None`.
    fn noise_points(&self) -> Option<&Tensor> {
        None
    }

    /// Density estimates for `data`, returned as a tensor.
    ///
    /// # Errors
    ///
    /// Returns the implementor's `ClusterResult` error on failure.
    fn density_estimates(&self, data: &Tensor) -> ClusterResult<Tensor>;
}
/// Contract for algorithm configuration types: cloneable, debuggable,
/// self-validating and mergeable.
pub trait ClusteringConfig: Clone + std::fmt::Debug {
    /// Checks the configuration.
    ///
    /// # Errors
    ///
    /// Returns the implementor's `ClusterResult` error when the configuration
    /// is rejected.
    fn validate(&self) -> ClusterResult<()>;

    /// Builds the default configuration.
    ///
    /// NOTE: this associated function shares its name with
    /// `std::default::Default::default`. A type implementing both traits must
    /// call it with fully-qualified syntax, e.g.
    /// `<T as ClusteringConfig>::default()`.
    fn default() -> Self;

    /// Merges `other` into `self` in place.
    ///
    /// NOTE(review): which side wins on conflicting settings is decided by
    /// the implementor — confirm before relying on a particular precedence.
    fn merge(&mut self, other: &Self);
}