use crate::error::{ClusterError, ClusterResult};
use torsh_tensor::Tensor;
/// Interface for types that can validate a tensor.
///
/// NOTE(review): no implementors are visible in this file, so what counts as
/// "valid" is defined by each implementation — confirm against the implementors.
pub trait ClusterValidation {
/// Checks `data`, returning `Ok(())` when it is acceptable and an error
/// wrapped in [`ClusterResult`] otherwise.
fn validate(&self, data: &Tensor) -> ClusterResult<()>;
}
/// Checks that `data` has a usable shape and contains only finite values.
///
/// The first dimension is treated as the sample count and the second as the
/// feature count; when there is no second dimension the feature count is
/// taken to be 1.
///
/// # Errors
///
/// * [`ClusterError::EmptyDataset`] when the tensor has no dimensions or its
///   first dimension is zero.
/// * [`ClusterError::InvalidInput`] when the feature count is zero, or when any
///   element is NaN or infinite.
///
/// If the tensor's contents cannot be extracted with `to_vec`, the element
/// check is skipped and the shape checks alone decide the result.
pub fn validate_cluster_input(data: &Tensor) -> ClusterResult<()> {
    let shape = data.shape();
    let dims = shape.dims();

    // No dimensions at all and zero rows are both reported as an empty dataset.
    if dims.is_empty() || dims[0] == 0 {
        return Err(ClusterError::EmptyDataset);
    }

    let feature_count = dims.get(1).copied().unwrap_or(1);
    if feature_count == 0 {
        return Err(ClusterError::InvalidInput(
            "Data must have at least one feature".to_string(),
        ));
    }

    // Best effort: a failed `to_vec` simply means no element is inspected.
    let has_non_finite = matches!(
        data.to_vec(),
        Ok(values) if values.iter().any(|v| v.is_nan() || v.is_infinite())
    );
    if has_non_finite {
        return Err(ClusterError::InvalidInput(
            "Data contains NaN or infinite values".to_string(),
        ));
    }

    Ok(())
}
/// Checks that the requested number of clusters is usable for `n_samples`.
///
/// A valid count lies in the inclusive range `1..=n_samples`.
///
/// # Errors
///
/// Returns [`ClusterError::InvalidClusters`] carrying `n_clusters` when it is
/// zero or exceeds `n_samples`.
pub fn validate_n_clusters(n_clusters: usize, n_samples: usize) -> ClusterResult<()> {
    // The range is empty when `n_samples` is zero, so every count is rejected then.
    if (1..=n_samples).contains(&n_clusters) {
        Ok(())
    } else {
        Err(ClusterError::InvalidClusters(n_clusters))
    }
}
/// Validates the input tensor together with the common clustering parameters.
///
/// Runs, in order: the input checks of `validate_cluster_input`, the cluster
/// count check of `validate_n_clusters` (using the tensor's first dimension as
/// the sample count), then the optional `max_iters` and `tolerance` checks.
/// Parameters passed as `None` are not checked.
///
/// # Errors
///
/// * Any error produced by `validate_cluster_input` or `validate_n_clusters`.
/// * [`ClusterError::ConfigError`] when `max_iters` is `Some(0)`.
/// * [`ClusterError::ConfigError`] when `tolerance` is negative or NaN.
pub fn validate_clustering_params(
    data: &Tensor,
    n_clusters: usize,
    max_iters: Option<usize>,
    tolerance: Option<f64>,
) -> ClusterResult<()> {
    validate_cluster_input(data)?;

    // Safe to index: `validate_cluster_input` rejects tensors without dimensions.
    let n_samples = data.shape().dims()[0];
    validate_n_clusters(n_clusters, n_samples)?;

    if max_iters == Some(0) {
        return Err(ClusterError::ConfigError(
            "max_iters must be positive".to_string(),
        ));
    }

    if let Some(tolerance) = tolerance {
        // `tolerance < 0.0` is false for NaN, so NaN has to be rejected
        // explicitly; otherwise it would pass as a valid tolerance.
        if tolerance.is_nan() || tolerance < 0.0 {
            return Err(ClusterError::ConfigError(
                "tolerance must be non-negative".to_string(),
            ));
        }
    }

    Ok(())
}
/// Validates the DBSCAN parameters `eps` and `min_samples`.
///
/// `eps` must be strictly positive and must not be NaN; `min_samples` must be
/// at least 1. Positive infinity is accepted for `eps`.
///
/// # Errors
///
/// * [`ClusterError::InvalidEpsilon`] carrying `eps` when it is zero, negative
///   or NaN.
/// * [`ClusterError::InvalidMinSamples`] carrying `min_samples` when it is zero.
pub fn validate_dbscan_params(eps: f64, min_samples: usize) -> ClusterResult<()> {
    // `eps <= 0.0` is false for NaN, so NaN must be rejected explicitly;
    // otherwise it would be accepted as a valid radius.
    if eps.is_nan() || eps <= 0.0 {
        return Err(ClusterError::InvalidEpsilon(eps));
    }
    if min_samples == 0 {
        return Err(ClusterError::InvalidMinSamples(min_samples));
    }
    Ok(())
}
/// Checks that `metric` names one of the supported distance metrics.
///
/// Accepted names (case-sensitive): `"euclidean"`, `"manhattan"`, `"cosine"`,
/// `"hamming"` and `"jaccard"`.
///
/// # Errors
///
/// Returns [`ClusterError::InvalidDistanceMetric`] carrying the rejected name.
pub fn validate_distance_metric(metric: &str) -> ClusterResult<()> {
    const SUPPORTED: [&str; 5] = ["euclidean", "manhattan", "cosine", "hamming", "jaccard"];

    if SUPPORTED.contains(&metric) {
        Ok(())
    } else {
        Err(ClusterError::InvalidDistanceMetric(metric.to_string()))
    }
}
/// Checks that `linkage` names one of the supported linkage criteria.
///
/// Accepted names (case-sensitive): `"ward"`, `"complete"`, `"average"` and
/// `"single"`.
///
/// # Errors
///
/// Returns [`ClusterError::InvalidLinkage`] carrying the rejected name.
pub fn validate_linkage(linkage: &str) -> ClusterResult<()> {
    let is_supported = matches!(linkage, "ward" | "complete" | "average" | "single");
    if !is_supported {
        return Err(ClusterError::InvalidLinkage(linkage.to_string()));
    }
    Ok(())
}
/// Checks that `affinity` is a square, two-dimensional tensor whose entries
/// are all finite and non-negative.
///
/// # Errors
///
/// Returns [`ClusterError::InvalidAffinityMatrix`] when the tensor is not
/// two-dimensional with equal dimensions, or when any entry is NaN, infinite
/// or negative.
///
/// If the tensor's contents cannot be extracted with `to_vec`, the entry check
/// is skipped and only the shape is validated.
pub fn validate_affinity_matrix(affinity: &Tensor) -> ClusterResult<()> {
    let shape = affinity.shape();
    let dims = shape.dims();

    let is_square = dims.len() == 2 && dims[0] == dims[1];
    if !is_square {
        return Err(ClusterError::InvalidAffinityMatrix(
            "Affinity matrix must be square".to_string(),
        ));
    }

    // Best effort: a failed `to_vec` counts as "nothing invalid found".
    let has_invalid_entry = affinity
        .to_vec()
        .map(|entries| {
            entries
                .iter()
                .any(|&entry| entry.is_nan() || entry.is_infinite() || entry < 0.0)
        })
        .unwrap_or(false);

    if has_invalid_entry {
        Err(ClusterError::InvalidAffinityMatrix(
            "Affinity matrix contains invalid values".to_string(),
        ))
    } else {
        Ok(())
    }
}
/// Checks that two tensors have the same feature count.
///
/// The feature count is the size of the second dimension, or 1 when a tensor
/// has no second dimension. `_context` is accepted but currently unused.
///
/// # Errors
///
/// Returns [`ClusterError::InvalidFeatureDimension`] with `expected` set to the
/// feature count of `data1` and `actual` set to that of `data2` when they differ.
pub fn validate_feature_dimensions(
    data1: &Tensor,
    data2: &Tensor,
    _context: &str,
) -> ClusterResult<()> {
    // Size of the second dimension, defaulting to 1 when it is absent.
    let feature_count = |tensor: &Tensor| tensor.shape().dims().get(1).copied().unwrap_or(1);

    let expected = feature_count(data1);
    let actual = feature_count(data2);

    if expected == actual {
        Ok(())
    } else {
        Err(ClusterError::InvalidFeatureDimension { expected, actual })
    }
}
/// Checks that `labels` is a one-dimensional tensor of `n_samples` valid labels.
///
/// A label value is rejected when it is NaN, infinite, or smaller than `-1.0`.
///
/// # Errors
///
/// Returns [`ClusterError::InvalidAssignment`] when the tensor is not
/// one-dimensional with exactly `n_samples` elements, or when any label value
/// is rejected by the rule above.
///
/// If the tensor's contents cannot be extracted with `to_vec`, the value check
/// is skipped and only the shape is validated.
pub fn validate_cluster_labels(labels: &Tensor, n_samples: usize) -> ClusterResult<()> {
    let shape = labels.shape();
    let dims = shape.dims();

    let shape_matches = dims.len() == 1 && dims[0] == n_samples;
    if !shape_matches {
        return Err(ClusterError::InvalidAssignment(format!(
            "Labels must be 1D array with {} elements",
            n_samples
        )));
    }

    match labels.to_vec() {
        Ok(values)
            if values
                .iter()
                .any(|&value| value.is_nan() || value.is_infinite() || value < -1.0) =>
        {
            Err(ClusterError::InvalidAssignment(
                "Labels contain invalid values".to_string(),
            ))
        }
        // Either every label passed, or the contents could not be read.
        _ => Ok(()),
    }
}
/// Scans `values` for numbers that indicate numerical trouble.
///
/// Values are examined in order and the first offending one determines the
/// error. A value is offending when it is NaN, infinite, non-zero with an
/// absolute value below `1e-15`, or has an absolute value above `1e15`.
/// `context` is embedded in the error message to identify the caller.
///
/// # Errors
///
/// Returns [`ClusterError::NumericalInstability`] with a message describing the
/// first offending value.
pub fn check_numerical_stability(values: &[f64], context: &str) -> ClusterResult<()> {
    // Returns the message for a single value, or `None` when it is acceptable.
    let diagnose = |value: f64| -> Option<String> {
        if value.is_nan() {
            return Some(format!("NaN detected in {}", context));
        }
        if value.is_infinite() {
            return Some(format!("Infinite value detected in {}", context));
        }

        let magnitude = value.abs();
        if magnitude < 1e-15 && value != 0.0 {
            Some(format!(
                "Very small value detected in {}: {}",
                context, value
            ))
        } else if magnitude > 1e15 {
            Some(format!(
                "Very large value detected in {}: {}",
                context, value
            ))
        } else {
            None
        }
    };

    match values.iter().copied().find_map(diagnose) {
        Some(message) => Err(ClusterError::NumericalInstability(message)),
        None => Ok(()),
    }
}