/// Fits two clusters on the 1-D samples 0, 2, 10, 12 and checks that one
/// centroid ends up within 1.0 of 1.0 and one within 1.0 of 11.0.
#[test]
fn test_update_centroids_division() {
    let samples = vec![0.0, 2.0, 10.0, 12.0];
    let data = Matrix::from_vec(4, 1, samples).expect("Matrix creation should succeed");

    let mut kmeans = KMeans::new(2).with_random_state(42);
    kmeans.fit(&data).expect("KMeans fit should succeed");

    let fitted = kmeans.centroids();
    let (c0, c1) = (fitted.get(0, 0), fitted.get(1, 0));

    // Cluster order is not asserted: either centroid may be the low or the high one.
    let c0_is_low = (c0 - 1.0).abs() < 1.0;
    let c1_is_low = (c1 - 1.0).abs() < 1.0;
    let c0_is_high = (c0 - 11.0).abs() < 1.0;
    let c1_is_high = (c1 - 11.0).abs() < 1.0;

    assert!(
        c0_is_low || c1_is_low,
        "Should have centroid near 1.0, got {c0} and {c1}"
    );
    assert!(
        c0_is_high || c1_is_high,
        "Should have centroid near 11.0, got {c0} and {c1}"
    );
}
/// Calls `centroids_converged` with `tol = 1.0` on single-element matrices.
///
/// A shift of 0.5 must be reported as converged and a shift of 2.0 must not;
/// per the assertion messages, the second case fails if the distance were
/// computed as `diff / diff` instead of `diff * diff`.
#[test]
fn test_centroids_converged_squaring_not_division() {
    let scalar = |value: f32| {
        Matrix::from_vec(1, 1, vec![value]).expect("Matrix creation should succeed")
    };

    let kmeans = KMeans::new(1).with_tol(1.0);
    let origin = scalar(0.0);

    let small_shift = scalar(0.5);
    let small_converged = kmeans.centroids_converged(&origin, &small_shift);
    assert!(
        small_converged,
        "With diff=0.5, diff²=0.25 < tol²=1.0, should converge"
    );

    let large_shift = scalar(2.0);
    let large_converged = kmeans.centroids_converged(&origin, &large_shift);
    assert!(
        !large_converged,
        "With diff=2.0, diff²=4.0 > tol²=1.0, must NOT converge. If using diff/diff=1.0, test fails."
    );
}
/// Calls `centroids_converged` with `tol = 0.6` on two-element matrices.
///
/// Moving from (0, 0) to (0.3, 0.4) must count as converged, while moving to
/// (0.5, 0.5) must not; the messages state the expected squared distances
/// (0.25 and 0.5) against tol² = 0.36.
#[test]
fn test_centroids_converged_sum_not_multiply() {
    let pair = |x: f32, y: f32| {
        Matrix::from_vec(1, 2, vec![x, y]).expect("Matrix creation should succeed")
    };

    let kmeans = KMeans::new(1).with_tol(0.6);
    let origin = pair(0.0, 0.0);

    let near = pair(0.3, 0.4);
    let near_converged = kmeans.centroids_converged(&origin, &near);
    assert!(near_converged, "dist²=0.25 < tol²=0.36, should converge");

    let far = pair(0.5, 0.5);
    let far_converged = kmeans.centroids_converged(&origin, &far);
    assert!(!far_converged, "dist²=0.5 > tol²=0.36, must NOT converge");
}
/// With `tol = 1.0`, a single-element shift of 0.6 must be reported as
/// converged. Per the assertion message, computing `diff + diff` (1.2)
/// instead of `diff²` (0.36) would make this test fail.
#[test]
fn test_centroids_converged_addition_not_squaring() {
    let model = KMeans::new(1).with_tol(1.0);

    let before = Matrix::from_vec(1, 1, vec![0.0_f32]).expect("Matrix creation should succeed");
    let after = Matrix::from_vec(1, 1, vec![0.6_f32]).expect("Matrix creation should succeed");

    let converged = model.centroids_converged(&before, &after);
    assert!(
        converged,
        "diff²=0.36 < tol²=1.0, must converge. If using diff+diff=1.2, test fails."
    );
}
/// With `tol = 1.0`, a single-element shift of 1.3 (squared distance 1.69 per
/// the assertion message) must NOT be reported as converged.
#[test]
fn test_centroids_converged_greater_not_equal() {
    let model = KMeans::new(1).with_tol(1.0);

    // Build the "before" and "after" centroid matrices from their single values.
    let [before, after] = [0.0_f32, 1.3_f32]
        .map(|value| Matrix::from_vec(1, 1, vec![value]).expect("Matrix creation should succeed"));

    let converged = model.centroids_converged(&before, &after);
    assert!(!converged, "dist²=1.69 > tol²=1.0, must NOT converge");
}
/// With `tol = 1.0`, a single-element shift of 0.7 (squared distance 0.49 per
/// the assertion message) must be reported as converged; the message notes
/// that a flipped comparison would make this test fail.
#[test]
fn test_centroids_converged_greater_not_less() {
    let tolerance = 1.0;
    let model = KMeans::new(1).with_tol(tolerance);

    let start = Matrix::from_vec(1, 1, vec![0.0_f32]).expect("Matrix creation should succeed");
    let moved = Matrix::from_vec(1, 1, vec![0.7_f32]).expect("Matrix creation should succeed");

    let converged = model.centroids_converged(&start, &moved);
    assert!(
        converged,
        "dist²=0.49 < tol²=1.0, must converge. If using <, test fails."
    );
}
/// Round-trips a fitted model through `save_safetensors` / `load_safetensors`.
///
/// After loading, the test compares the configuration getters (`n_clusters`,
/// `max_iter`, `tol`, `random_state`), the fit statistics (`inertia`,
/// `n_iter`), every centroid entry, and the labels predicted for the
/// training data.
#[test]
fn test_kmeans_save_safetensors_roundtrip() {
    use tempfile::tempdir;

    let points = vec![1.0, 2.0, 1.5, 1.8, 5.0, 8.0, 8.0, 8.0, 1.0, 0.6, 9.0, 11.0];
    let data = Matrix::from_vec(6, 2, points).expect("Matrix creation should succeed");

    let mut kmeans = KMeans::new(2).with_random_state(42).with_max_iter(100);
    kmeans.fit(&data).expect("KMeans fit should succeed");

    let dir = tempdir().expect("tempdir creation should succeed");
    let path = dir.path().join("kmeans_model.safetensors");
    kmeans
        .save_safetensors(&path)
        .expect("SafeTensors save should succeed");
    assert!(path.exists(), "SafeTensors file should exist");

    let loaded = KMeans::load_safetensors(&path).expect("SafeTensors load should succeed");

    // Configuration and fit statistics; float fields are compared with a tolerance.
    assert_eq!(loaded.n_clusters(), kmeans.n_clusters());
    assert_eq!(loaded.max_iter(), kmeans.max_iter());
    let tol_gap = (loaded.tol() - kmeans.tol()).abs();
    assert!(tol_gap < 1e-6);
    assert_eq!(loaded.random_state(), kmeans.random_state());
    let inertia_gap = (loaded.inertia() - kmeans.inertia()).abs();
    assert!(inertia_gap < 1e-3);
    assert_eq!(loaded.n_iter(), kmeans.n_iter());

    // Centroids must agree in shape and entry by entry.
    let (expected, actual) = (kmeans.centroids(), loaded.centroids());
    assert_eq!(expected.shape(), actual.shape());
    let (rows, cols) = (expected.n_rows(), expected.n_cols());
    for (i, j) in (0..rows).flat_map(|i| (0..cols).map(move |j| (i, j))) {
        let gap = (expected.get(i, j) - actual.get(i, j)).abs();
        assert!(gap < 1e-6, "Centroid mismatch at ({}, {})", i, j);
    }

    // Both models must assign identical labels to the training data.
    assert_eq!(kmeans.predict(&data), loaded.predict(&data));

    // Best-effort cleanup: a failure to remove the file is ignored.
    let _ = std::fs::remove_file(&path);
}
/// Saves and reloads a model built without `with_random_state` and checks
/// that the loaded model reports `random_state()` as `None`.
#[test]
fn test_kmeans_save_safetensors_without_random_state() {
    use tempfile::tempdir;

    let points = vec![1.0, 2.0, 1.5, 1.8, 5.0, 8.0, 8.0, 8.0, 1.0, 0.6, 9.0, 11.0];
    let data = Matrix::from_vec(6, 2, points).expect("Matrix creation should succeed");

    let mut model = KMeans::new(2);
    model.fit(&data).expect("KMeans fit should succeed");

    let scratch = tempdir().expect("tempdir creation should succeed");
    let file = scratch.path().join("kmeans_no_random.safetensors");
    model
        .save_safetensors(&file)
        .expect("SafeTensors save should succeed");

    let restored = KMeans::load_safetensors(&file).expect("SafeTensors load should succeed");
    let seed = restored.random_state();
    assert!(seed.is_none());

    // Best-effort cleanup: a failure to remove the file is ignored.
    let _ = std::fs::remove_file(&file);
}
/// Saving a model on which `fit` was never called must return an error whose
/// text contains "Cannot save unfitted model".
#[test]
fn test_kmeans_save_safetensors_unfitted_error() {
    use tempfile::tempdir;

    let unfitted = KMeans::new(2);

    let scratch = tempdir().expect("tempdir creation should succeed");
    let target = scratch.path().join("kmeans_unfitted.safetensors");

    let outcome = unfitted.save_safetensors(&target);
    assert!(outcome.is_err());

    let message = outcome.unwrap_err();
    assert!(message.contains("Cannot save unfitted model"));
}
/// Loading from a path that is expected not to exist must return an error.
#[test]
fn test_kmeans_load_safetensors_missing_file() {
    let missing = "/nonexistent/path/model.safetensors";
    assert!(KMeans::load_safetensors(missing).is_err());
}
/// Writes arbitrary non-SafeTensors bytes to a file and checks that
/// `load_safetensors` rejects it with an error.
#[test]
fn test_kmeans_load_safetensors_invalid_format() {
    use tempfile::tempdir;

    let scratch = tempdir().expect("tempdir creation should succeed");
    let bogus = scratch.path().join("invalid.safetensors");

    let payload = b"not a valid safetensors file";
    std::fs::write(&bogus, payload).expect("write should succeed");

    let outcome = KMeans::load_safetensors(&bogus);
    assert!(outcome.is_err());

    // Best-effort cleanup: a failure to remove the file is ignored.
    let _ = std::fs::remove_file(&bogus);
}
/// Hand-builds a SafeTensors file whose `centroids` tensor has the
/// one-dimensional shape `[4]` and checks that `load_safetensors` fails with
/// an error containing "Invalid centroids tensor shape".
#[test]
fn test_kmeans_load_safetensors_invalid_centroids_shape() {
    use crate::serialization::safetensors;
    use std::collections::BTreeMap;
    use tempfile::tempdir;

    let scratch = tempdir().expect("tempdir creation should succeed");
    let file = scratch.path().join("invalid_shape.safetensors");

    let mut tensors = BTreeMap::new();

    // The offending entry: four values declared with a single dimension.
    tensors.insert(
        "centroids".to_string(),
        (vec![1.0_f32, 2.0, 3.0, 4.0], vec![4]),
    );

    // Remaining entries are single-value tensors of shape [1].
    let scalars = [
        ("n_clusters", 2.0_f32),
        ("max_iter", 100.0_f32),
        ("tol", 1e-4_f32),
        ("random_state", -1.0_f32),
        ("inertia", 0.0_f32),
        ("n_iter", 10.0_f32),
    ];
    for (key, value) in scalars {
        tensors.insert(key.to_string(), (vec![value], vec![1]));
    }

    safetensors::save_safetensors(&file, &tensors).expect("save should succeed");

    let outcome = KMeans::load_safetensors(&file);
    assert!(outcome.is_err());

    let message = outcome.unwrap_err();
    assert!(
        message.contains("Invalid centroids tensor shape"),
        "Error message should indicate invalid shape"
    );

    // Best-effort cleanup: a failure to remove the file is ignored.
    let _ = std::fs::remove_file(&file);
}