use crate::DType;
use numr::error::Result;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// Strategy used to choose the starting centroids of a k-means run.
///
/// The default is [`KMeansInit::KMeansPlusPlus`].
#[derive(Debug, Clone)]
pub enum KMeansInit<R: Runtime<DType = DType>> {
    /// k-means++ seeding; this is the [`Default`] variant.
    KMeansPlusPlus,
    /// Random initialization.
    ///
    /// NOTE(review): presumably draws the initial centroids at random from the
    /// input samples — confirm against the backend implementation.
    Random,
    /// Caller-supplied initial centroids.
    ///
    /// NOTE(review): presumably expected to be shaped `(n_clusters, n_features)`
    /// to match the data — confirm.
    Points(Tensor<R>),
}

impl<R: Runtime<DType = DType>> Default for KMeansInit<R> {
    /// Returns [`KMeansInit::KMeansPlusPlus`].
    fn default() -> Self {
        Self::KMeansPlusPlus
    }
}
/// Iterative refinement algorithm used once the centroids are initialized.
///
/// The default is [`KMeansAlgorithm::Lloyd`]. `Hash` is derived, together with
/// `Eq`, so the selector can be used as a key in hashed collections (for example,
/// caching per-algorithm settings).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum KMeansAlgorithm {
    /// Lloyd's algorithm; this is the [`Default`] variant.
    #[default]
    Lloyd,
    /// Elkan's variant.
    ///
    /// NOTE(review): presumably uses triangle-inequality bounds to skip distance
    /// computations — confirm against the backend implementation.
    Elkan,
}
/// Configuration for a k-means clustering run.
///
/// Use [`Default::default`] together with struct-update syntax to override only
/// the fields you need, e.g. `KMeansOptions { n_clusters: 3, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct KMeansOptions<R: Runtime<DType = DType>> {
    /// Number of clusters (centroids) to compute. Defaults to `8`.
    pub n_clusters: usize,
    /// Upper bound on refinement iterations. Defaults to `300`.
    ///
    /// NOTE(review): presumably applied per initialization — confirm.
    pub max_iter: usize,
    /// Convergence tolerance. Defaults to `1e-4`.
    ///
    /// NOTE(review): which quantity is compared against this (centroid shift,
    /// inertia change, ...) is decided by the backend — confirm.
    pub tol: f64,
    /// Number of initializations to attempt. Defaults to `10`.
    ///
    /// NOTE(review): presumably the best run (e.g. lowest inertia) is kept — confirm.
    pub n_init: usize,
    /// How the starting centroids are chosen. Defaults to k-means++.
    pub init: KMeansInit<R>,
    /// Refinement algorithm. Defaults to Lloyd.
    pub algorithm: KMeansAlgorithm,
}

impl<R: Runtime<DType = DType>> Default for KMeansOptions<R> {
    /// Returns 8 clusters, k-means++ seeding with 10 initializations, Lloyd
    /// refinement capped at 300 iterations, and a tolerance of `1e-4`.
    fn default() -> Self {
        let init = KMeansInit::KMeansPlusPlus;
        let algorithm = KMeansAlgorithm::Lloyd;
        KMeansOptions {
            algorithm,
            init,
            n_init: 10,
            tol: 1.0e-4,
            max_iter: 300,
            n_clusters: 8,
        }
    }
}
/// Output of a k-means fit, as returned by [`KMeansAlgorithms::kmeans`].
#[derive(Debug, Clone)]
pub struct KMeansResult<R: Runtime<DType = DType>> {
    /// Final cluster centers.
    ///
    /// NOTE(review): presumably shaped `(n_clusters, n_features)` — confirm.
    pub centroids: Tensor<R>,
    /// Cluster assignment for each input sample.
    ///
    /// NOTE(review): presumably one centroid index per row of the input; the
    /// dtype is backend-defined — confirm.
    pub labels: Tensor<R>,
    /// Clustering inertia, stored as a tensor rather than a host scalar.
    ///
    /// NOTE(review): presumably the sum of squared distances from each sample to
    /// its assigned centroid, held as a 0-d/1-element tensor — confirm.
    pub inertia: Tensor<R>,
    /// Number of refinement iterations performed.
    ///
    /// NOTE(review): presumably counted for the best of the `n_init` runs — confirm.
    pub n_iter: usize,
}
/// K-means clustering operations, generic over the tensor runtime `R`.
pub trait KMeansAlgorithms<R: Runtime<DType = DType>> {
    /// Clusters `data` according to `options` and returns the fitted centroids,
    /// labels, inertia and iteration count.
    ///
    /// NOTE(review): `data` is presumably a 2-D `(n_samples, n_features)` tensor — confirm.
    ///
    /// # Errors
    ///
    /// Returns a `numr::error` error if the computation fails. Which inputs are
    /// rejected (e.g. `n_clusters` larger than the sample count) is up to the
    /// implementation.
    fn kmeans(&self, data: &Tensor<R>, options: &KMeansOptions<R>) -> Result<KMeansResult<R>>;
    /// Assigns each sample in `data` to one of the given `centroids`.
    ///
    /// Argument order: `centroids` comes first and `data` second.
    ///
    /// NOTE(review): presumably returns, per sample, the index of the nearest
    /// centroid (same convention as [`KMeansResult::labels`]) — confirm.
    ///
    /// # Errors
    ///
    /// Returns a `numr::error` error if the computation fails.
    fn kmeans_predict(&self, centroids: &Tensor<R>, data: &Tensor<R>) -> Result<Tensor<R>>;
}