use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::random::{rngs::StdRng, Rng, RngExt, SeedableRng};
use scirs2_core::random::{Distribution, Normal};
use std::fmt::Debug;
use std::str::FromStr;
use super::{euclidean_distance, vq};
use crate::error::{ClusteringError, Result};
use scirs2_core::validation::{clustering::*, parameters::*};
/// Strategy used by `kmeans2` to pick the starting centroids.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MinitMethod {
    /// Draw every centroid coordinate from a normal distribution fitted to the
    /// matching feature column of the data.
    Random,
    /// Use `k` rows of the data, selected at random without replacement.
    Points,
    /// k-means++ seeding: later centroids favour rows far from those already chosen.
    PlusPlus,
}
impl MinitMethod {
    /// Parse an initialization-method name, ignoring letter case.
    ///
    /// Accepted spellings: `"random"`, `"points"`, and any of `"k-means++"`,
    /// `"kmeans++"`, `"plusplus"` for k-means++.
    ///
    /// # Errors
    /// `ClusteringError::InvalidInput` naming the rejected input when `s` matches none
    /// of the accepted spellings.
    pub fn parse_method(s: &str) -> Result<Self> {
        let normalized = s.to_lowercase();
        let method = match normalized.as_str() {
            "random" => Self::Random,
            "points" => Self::Points,
            "k-means++" | "kmeans++" | "plusplus" => Self::PlusPlus,
            _ => {
                return Err(ClusteringError::InvalidInput(format!(
                    "Unknown initialization method: '{}'. Valid options are: 'random', 'points', 'k-means++'",
                    s
                )))
            }
        };
        Ok(method)
    }
}
impl FromStr for MinitMethod {
type Err = ClusteringError;
fn from_str(s: &str) -> Result<Self> {
Self::parse_method(s)
}
}
/// What `kmeans2` does when an iteration leaves a cluster with no members.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MissingMethod {
    /// Print a message to stderr and reseed the empty centroid from a data row.
    Warn,
    /// Abort with `ClusteringError::EmptyCluster`.
    Raise,
}
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn kmeans2<F>(
data: ArrayView2<F>,
k: usize,
iter: Option<usize>,
thresh: Option<F>,
minit: Option<MinitMethod>,
missing: Option<MissingMethod>,
check_finite: Option<bool>,
randomseed: Option<u64>,
) -> Result<(Array2<F>, Array1<usize>)>
where
F: Float + FromPrimitive + Debug + std::iter::Sum + std::fmt::Display,
{
let n_samples = data.shape()[0];
let n_features = data.shape()[1];
let iterations = iter.unwrap_or(10);
let threshold = thresh.unwrap_or(F::from(1e-5).expect("Failed to convert constant to float"));
let missing_method = missing.unwrap_or(MissingMethod::Warn);
let check_finite_flag = check_finite.unwrap_or(true);
validate_clustering_data(&data, "K-means", check_finite_flag, Some(k))
.map_err(|e| ClusteringError::InvalidInput(format!("K-means: {}", e)))?;
check_n_clusters_bounds(&data, k, "K-means")
.map_err(|e| ClusteringError::InvalidInput(format!("{}", e)))?;
check_iteration_params(iterations, threshold, "K-means")
.map_err(|e| ClusteringError::InvalidInput(format!("{}", e)))?;
let init_method = minit.unwrap_or(MinitMethod::PlusPlus); let mut centroids = match init_method {
MinitMethod::Random => krandinit(data, k, randomseed)?,
MinitMethod::Points => kpoints(data, k, randomseed)?,
MinitMethod::PlusPlus => kmeans_plus_plus(data, k, randomseed)?,
};
let mut labels;
for _iteration in 0..iterations {
let prev_centroids = centroids.clone();
let (new_labels, _distances) = vq(data, centroids.view())?;
labels = new_labels;
let mut new_centroids = Array2::zeros((k, n_features));
let mut counts = Array1::zeros(k);
for i in 0..n_samples {
let cluster = labels[i];
let point = data.slice(s![i, ..]);
for j in 0..n_features {
new_centroids[[cluster, j]] = new_centroids[[cluster, j]] + point[j];
}
counts[cluster] += 1;
}
for i in 0..k {
if counts[i] == 0 {
match missing_method {
MissingMethod::Warn => {
eprintln!("One of the clusters is empty. Re-run kmeans with a different initialization.");
let mut max_dist = F::zero();
let mut far_idx = 0;
for j in 0..n_samples {
let cluster_j = labels[j];
let dist = euclidean_distance(
data.slice(s![j, ..]),
centroids.slice(s![cluster_j, ..]),
);
if dist > max_dist {
max_dist = dist;
far_idx = j;
}
}
for j in 0..n_features {
new_centroids[[i, j]] = data[[far_idx, j]];
}
counts[i] = 1;
}
MissingMethod::Raise => {
return Err(ClusteringError::EmptyCluster(
"One of the clusters is empty. Re-run kmeans with a different initialization.".to_string()
));
}
}
} else {
for j in 0..n_features {
new_centroids[[i, j]] = new_centroids[[i, j]]
/ F::from(counts[i]).expect("Failed to convert to float");
}
}
}
centroids = new_centroids;
let mut max_centroid_shift = F::zero();
for i in 0..k {
for j in 0..n_features {
let shift = (centroids[[i, j]] - prev_centroids[[i, j]]).abs();
if shift > max_centroid_shift {
max_centroid_shift = shift;
}
}
}
if max_centroid_shift < threshold {
break;
}
}
let (final_labels, _distances) = vq(data, centroids.view())?;
Ok((centroids, final_labels))
}
/// String-configured front end to [`kmeans2`].
///
/// `minit` accepts any name understood by `MinitMethod::parse_method`; `missing` accepts
/// `"warn"` or `"raise"` in any letter case. Passing `None` selects k-means++ seeding and
/// warn-on-empty-cluster respectively. Every other argument is forwarded unchanged.
///
/// # Errors
/// `ClusteringError::InvalidInput` for an unrecognised `minit` or `missing` name (the
/// `minit` name is checked first), plus anything [`kmeans2`] returns.
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn kmeans2_str<F>(
    data: ArrayView2<F>,
    k: usize,
    iter: Option<usize>,
    thresh: Option<F>,
    minit: Option<&str>,
    missing: Option<&str>,
    check_finite: Option<bool>,
    randomseed: Option<u64>,
) -> Result<(Array2<F>, Array1<usize>)>
where
    F: Float + FromPrimitive + Debug + std::iter::Sum + std::fmt::Display,
{
    let minit_method = match minit {
        Some(name) => MinitMethod::from_str(name)?,
        None => MinitMethod::PlusPlus,
    };

    let missing_method = match missing {
        None => MissingMethod::Warn,
        Some(name) => match name.to_lowercase().as_str() {
            "warn" => MissingMethod::Warn,
            "raise" => MissingMethod::Raise,
            _ => {
                return Err(ClusteringError::InvalidInput(format!(
                    "Unknown missing method: '{}'. Valid options are: 'warn', 'raise'",
                    name
                )))
            }
        },
    };

    kmeans2(
        data,
        k,
        iter,
        thresh,
        Some(minit_method),
        Some(missing_method),
        check_finite,
        randomseed,
    )
}
/// Draw `k` random initial centroids from independent per-feature normal distributions.
///
/// For every feature (column) the sample mean and population standard deviation of
/// `data` are computed once. Each centroid coordinate is then drawn from
/// `Normal(mean, std)`. A feature whose standard deviation is not positive is set to its
/// mean instead of being sampled. Features are sampled independently; no covariance
/// between them is modelled.
///
/// `randomseed` makes the draw reproducible; `None` uses an unseeded generator.
#[allow(dead_code)]
fn krandinit<F>(data: ArrayView2<F>, k: usize, randomseed: Option<u64>) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + std::iter::Sum,
{
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];
    let n = F::from(n_samples).expect("Failed to convert to float");

    // (mean, sampler) per feature; the sampler is `None` when the spread is not positive.
    // Built once here instead of being re-derived k times inside the sampling loop.
    let features: Vec<_> = (0..n_features)
        .map(|j| {
            let column = data.column(j);
            let mean = column.iter().fold(F::zero(), |acc, &x| acc + x) / n;
            let var = column.iter().fold(F::zero(), |acc, &x| {
                let diff = x - mean;
                acc + diff * diff
            }) / n;
            let mean_f64 = mean.to_f64().expect("Operation failed");
            let std_f64 = var.sqrt().to_f64().expect("Operation failed");
            let sampler = if std_f64 > 0.0 {
                Some(Normal::new(mean_f64, std_f64).expect("Operation failed"))
            } else {
                None
            };
            (mean, sampler)
        })
        .collect();

    let mut rng: Box<dyn Rng> = if let Some(seed) = randomseed {
        Box::new(StdRng::seed_from_u64(seed))
    } else {
        Box::new(scirs2_core::random::rng())
    };

    // Row-major fill: the RNG is consumed in centroid-then-feature order.
    let mut centroids = Array2::<F>::zeros((k, n_features));
    for i in 0..k {
        for (j, (mean, sampler)) in features.iter().enumerate() {
            centroids[[i, j]] = match sampler {
                Some(normal) => {
                    let value = normal.sample(&mut rng);
                    F::from(value).expect("Failed to convert to float")
                }
                None => *mean,
            };
        }
    }
    Ok(centroids)
}
#[allow(dead_code)]
fn kpoints<F>(data: ArrayView2<F>, k: usize, randomseed: Option<u64>) -> Result<Array2<F>>
where
F: Float + FromPrimitive + Debug,
{
let n_samples = data.shape()[0];
let n_features = data.shape()[1];
let mut rng: Box<dyn Rng> = if let Some(_seed) = randomseed {
Box::new(StdRng::seed_from_u64(_seed))
} else {
Box::new(scirs2_core::random::rng())
};
let mut indices: Vec<usize> = (0..n_samples).collect();
for i in 0..k {
let j = rng.random_range(i..n_samples);
indices.swap(i, j);
}
let mut centroids = Array2::zeros((k, n_features));
for i in 0..k {
let idx = indices[i];
for j in 0..n_features {
centroids[[i, j]] = data[[idx, j]];
}
}
Ok(centroids)
}
#[allow(dead_code)]
fn kmeans_plus_plus<F>(data: ArrayView2<F>, k: usize, randomseed: Option<u64>) -> Result<Array2<F>>
where
F: Float + FromPrimitive + Debug + std::iter::Sum,
{
let n_samples = data.shape()[0];
let n_features = data.shape()[1];
let mut rng: Box<dyn Rng> = if let Some(_seed) = randomseed {
Box::new(StdRng::seed_from_u64(_seed))
} else {
Box::new(scirs2_core::random::rng())
};
let mut centroids = Array2::zeros((k, n_features));
let first_idx = rng.random_range(0..n_samples);
for j in 0..n_features {
centroids[[0, j]] = data[[first_idx, j]];
}
for i in 1..k {
let mut distances = Array1::<F>::zeros(n_samples);
for j in 0..n_samples {
let mut min_dist = F::infinity();
for c in 0..i {
let dist = euclidean_distance(data.slice(s![j, ..]), centroids.slice(s![c, ..]));
if dist < min_dist {
min_dist = dist;
}
}
distances[j] = min_dist * min_dist;
}
let total = distances.iter().fold(F::zero(), |a, &b| a + b);
let mut probabilities = Array1::<F>::zeros(n_samples);
for j in 0..n_samples {
probabilities[j] = distances[j] / total;
}
let mut cumsum = F::zero();
let r = F::from(rng.random::<f64>()).expect("Operation failed");
let mut next_idx = n_samples - 1;
for j in 0..n_samples {
cumsum = cumsum + probabilities[j];
if cumsum > r {
next_idx = j;
break;
}
}
for j in 0..n_features {
centroids[[i, j]] = data[[next_idx, j]];
}
}
Ok(centroids)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{array, Array2};
    use std::collections::HashSet;

    /// Two well-separated groups of three 2-D points: rows 0-2 near (1, 1), rows 3-5 near (8, 8).
    fn two_blobs() -> Array2<f64> {
        array![
            [1.0, 1.0],
            [1.5, 1.5],
            [0.8, 0.9],
            [8.0, 8.0],
            [8.2, 8.1],
            [7.8, 7.9],
        ]
    }

    /// Number of distinct labels in `labels`.
    fn distinct_labels(labels: &Array1<usize>) -> usize {
        labels.iter().collect::<HashSet<_>>().len()
    }

    #[test]
    fn test_kmeans2_basic_functionality() {
        let data = two_blobs();
        let (centroids, labels) = kmeans2(
            data.view(),
            2,
            Some(50),
            Some(1e-6),
            Some(MinitMethod::PlusPlus),
            Some(MissingMethod::Warn),
            Some(true),
            Some(42),
        )
        .expect("Test: operation failed");
        assert_eq!(centroids.shape(), [2, 2]);
        assert_eq!(labels.len(), 6);
        assert!(labels.iter().all(|&l| l == 0 || l == 1));
        assert_eq!(distinct_labels(&labels), 2);
    }

    #[test]
    fn test_kmeans2_parameter_validation() {
        let data = array![[1.0, 1.0], [2.0, 2.0]];
        let result = kmeans2(
            data.view(),
            0,
            None,
            None,
            Some(MinitMethod::Random),
            None,
            None,
            None,
        );
        assert!(result.is_err());
        let result = kmeans2(
            data.view(),
            5,
            None,
            None,
            Some(MinitMethod::Random),
            None,
            None,
            None,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_kmeans2_initialization_methods() {
        let data = two_blobs();
        for &method in &[
            MinitMethod::Random,
            MinitMethod::Points,
            MinitMethod::PlusPlus,
        ] {
            let result = kmeans2(
                data.view(),
                2,
                Some(10),
                None,
                Some(method),
                Some(MissingMethod::Warn),
                None,
                Some(42),
            );
            assert!(result.is_ok(), "Failed with method: {:?}", method);
            let (centroids, labels) = result.expect("Test: operation failed");
            assert_eq!(centroids.shape(), [2, 2]);
            assert_eq!(labels.len(), 6);
        }
    }

    #[test]
    fn test_kmeans2_reproducibility_with_seed() {
        let data = two_blobs();
        let run = || {
            kmeans2(
                data.view(),
                2,
                Some(10),
                None,
                Some(MinitMethod::Random),
                None,
                None,
                Some(42),
            )
            .expect("Test: operation failed")
        };
        let (centroids1, labels1) = run();
        let (centroids2, labels2) = run();
        assert_eq!(labels1, labels2);
        for (a, b) in centroids1.iter().zip(centroids2.iter()) {
            assert_abs_diff_eq!(*a, *b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_kmeans2_single_cluster() {
        let data = array![[1.0, 1.0], [1.1, 1.1], [0.9, 0.9],];
        let (centroids, labels) = kmeans2(
            data.view(),
            1,
            Some(10),
            None,
            Some(MinitMethod::Points),
            None,
            None,
            Some(42),
        )
        .expect("Test: operation failed");
        assert_eq!(centroids.shape(), [1, 2]);
        assert!(labels.iter().all(|&l| l == 0));
    }

    #[test]
    fn test_kmeans2_identical_points() {
        let data = array![[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0],];
        let (centroids, labels) = kmeans2(
            data.view(),
            2,
            Some(10),
            None,
            Some(MinitMethod::Points),
            Some(MissingMethod::Warn),
            None,
            Some(42),
        )
        .expect("Test: operation failed");
        assert_eq!(centroids.shape(), [2, 2]);
        assert_eq!(labels.len(), 4);
        assert!(labels.iter().all(|&l| l == 0 || l == 1));
    }

    #[test]
    fn test_kmeans2_missing_method_warn() {
        let data = array![[0.0, 0.0], [0.1, 0.1], [10.0, 10.0],];
        let result = kmeans2(
            data.view(),
            2,
            Some(5),
            None,
            Some(MinitMethod::Random),
            Some(MissingMethod::Warn),
            None,
            Some(123),
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_kmeans2_convergence_behavior() {
        let data = array![
            [1.0, 1.0],
            [1.1, 1.1],
            [0.9, 0.9],
            [10.0, 10.0],
            [10.1, 10.1],
            [9.9, 9.9],
        ];
        let (centroids_few, _) = kmeans2(
            data.view(),
            2,
            Some(1),
            None,
            Some(MinitMethod::PlusPlus),
            None,
            None,
            Some(42),
        )
        .expect("Test: operation failed");
        let (centroids_many, _) = kmeans2(
            data.view(),
            2,
            Some(100),
            None,
            Some(MinitMethod::PlusPlus),
            None,
            None,
            Some(42),
        )
        .expect("Test: operation failed");
        assert_eq!(centroids_few.shape(), [2, 2]);
        assert_eq!(centroids_many.shape(), [2, 2]);
    }

    #[test]
    fn test_kmeans2_high_k() {
        let data = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0], [5.0, 5.0],];
        let (centroids, labels) = kmeans2(
            data.view(),
            5,
            Some(10),
            None,
            Some(MinitMethod::Points),
            None,
            None,
            Some(42),
        )
        .expect("Test: operation failed");
        assert_eq!(centroids.shape(), [5, 2]);
        assert_eq!(labels.len(), 5);
        assert_eq!(distinct_labels(&labels), 5);
    }

    #[test]
    fn test_kmeans2_different_thresholds() {
        let data = array![[1.0, 1.0], [1.5, 1.5], [8.0, 8.0], [8.5, 8.5],];
        for &thresh in &[1e-10, 1e-1] {
            let result = kmeans2(
                data.view(),
                2,
                Some(100),
                Some(thresh),
                Some(MinitMethod::PlusPlus),
                None,
                None,
                Some(42),
            );
            assert!(result.is_ok(), "Failed with threshold: {}", thresh);
        }
    }

    #[test]
    fn test_kmeans2_convergence_threshold() {
        let data = array![
            [1.0, 1.0],
            [1.1, 1.1],
            [0.9, 0.9],
            [10.0, 10.0],
            [10.1, 10.1],
            [9.9, 9.9],
        ];
        for &thresh in &[1e-10, 1e-1] {
            let result = kmeans2(
                data.view(),
                2,
                Some(100),
                Some(thresh),
                Some(MinitMethod::PlusPlus),
                None,
                None,
                Some(42),
            );
            assert!(result.is_ok(), "Failed with threshold: {}", thresh);
            let (centroids, labels) = result.expect("Test: operation failed");
            assert_eq!(centroids.shape(), [2, 2]);
            assert_eq!(labels.len(), 6);
        }
    }

    #[test]
    fn test_kmeans2_check_finite() {
        let data = array![[1.0, 2.0], [1.5, 1.5], [8.0, 8.0],];
        for &flag in &[true, false] {
            let result = kmeans2(
                data.view(),
                2,
                Some(10),
                None,
                Some(MinitMethod::Random),
                None,
                Some(flag),
                Some(42),
            );
            assert!(result.is_ok(), "Failed with check_finite: {}", flag);
        }
    }

    #[test]
    fn test_kmeans2_large_dataset() {
        // Three groups interleaved by row index (i % 3), each drifting slowly with i.
        let mut data = Array2::zeros((100, 3));
        for i in 0..100 {
            let base = match i % 3 {
                0 => 1.0,
                1 => 5.0,
                _ => 10.0,
            };
            let value = base + (i as f64) * 0.01;
            for j in 0..3 {
                data[[i, j]] = value;
            }
        }
        let (centroids, labels) = kmeans2(
            data.view(),
            3,
            Some(50),
            None,
            Some(MinitMethod::PlusPlus),
            None,
            None,
            Some(42),
        )
        .expect("Test: operation failed");
        assert_eq!(centroids.shape(), [3, 3]);
        assert_eq!(labels.len(), 100);
        assert_eq!(distinct_labels(&labels), 3);
    }

    #[test]
    fn test_kmeans2_str_basic_functionality() {
        let data = two_blobs();
        let (centroids, labels) = kmeans2_str(
            data.view(),
            2,
            Some(50),
            Some(1e-6),
            Some("k-means++"),
            Some("warn"),
            Some(true),
            Some(42),
        )
        .expect("Test: operation failed");
        assert_eq!(centroids.shape(), [2, 2]);
        assert_eq!(labels.len(), 6);
        assert!(labels.iter().all(|&l| l == 0 || l == 1));
        assert_eq!(distinct_labels(&labels), 2);
    }

    #[test]
    fn test_kmeans2_str_all_init_methods() {
        let data = two_blobs();
        for &method in &["random", "points", "k-means++", "kmeans++", "plusplus"] {
            let result = kmeans2_str(
                data.view(),
                2,
                Some(10),
                None,
                Some(method),
                Some("warn"),
                None,
                Some(42),
            );
            assert!(result.is_ok(), "Failed with method: '{}'", method);
            let (centroids, labels) = result.expect("Test: operation failed");
            assert_eq!(centroids.shape(), [2, 2]);
            assert_eq!(labels.len(), 6);
        }
    }

    #[test]
    fn test_kmeans2_str_case_insensitive() {
        let data = array![[1.0, 1.0], [2.0, 2.0], [8.0, 8.0], [9.0, 9.0],];
        for &method in &[
            "RANDOM",
            "Random",
            "random",
            "POINTS",
            "Points",
            "points",
            "K-MEANS++",
            "K-Means++",
            "k-means++",
        ] {
            let result = kmeans2_str(
                data.view(),
                2,
                Some(10),
                None,
                Some(method),
                Some("warn"),
                None,
                Some(42),
            );
            assert!(result.is_ok(), "Failed with method: '{}'", method);
        }
    }

    #[test]
    fn test_kmeans2_str_missing_methods() {
        let data = array![[1.0, 1.0], [2.0, 2.0], [8.0, 8.0],];
        for &missing_method in &["warn", "raise", "WARN", "RAISE"] {
            let result = kmeans2_str(
                data.view(),
                2,
                Some(5),
                None,
                Some("points"),
                Some(missing_method),
                None,
                Some(42),
            );
            assert!(
                result.is_ok(),
                "Failed with missing method: '{}'",
                missing_method
            );
        }
    }

    #[test]
    fn test_kmeans2_str_invalid_method() {
        let data = array![[1.0, 1.0], [2.0, 2.0]];
        let result = kmeans2_str(
            data.view(),
            2,
            Some(10),
            None,
            Some("invalid_method"),
            Some("warn"),
            None,
            None,
        );
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Unknown initialization method"));
    }

    #[test]
    fn test_kmeans2_str_invalid_missing_method() {
        let data = array![[1.0, 1.0], [2.0, 2.0]];
        let result = kmeans2_str(
            data.view(),
            2,
            Some(10),
            None,
            Some("points"),
            Some("invalid_missing"),
            None,
            None,
        );
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Unknown missing method"));
    }

    #[test]
    fn test_kmeans2_str_defaults() {
        let data = array![[1.0, 1.0], [1.5, 1.5], [8.0, 8.0], [8.5, 8.5],];
        let result = kmeans2_str(
            data.view(),
            2,
            Some(10),
            None,
            None,
            None,
            None,
            Some(42),
        );
        assert!(result.is_ok());
        let (centroids, labels) = result.expect("Test: operation failed");
        assert_eq!(centroids.shape(), [2, 2]);
        assert_eq!(labels.len(), 4);
    }

    #[test]
    fn test_kmeans2_str_equivalence_with_enum() {
        let data = two_blobs();
        let (centroids_enum, labels_enum) = kmeans2(
            data.view(),
            2,
            Some(50),
            Some(1e-6),
            Some(MinitMethod::PlusPlus),
            Some(MissingMethod::Warn),
            Some(true),
            Some(42),
        )
        .expect("Test: operation failed");
        let (centroids_str, labels_str) = kmeans2_str(
            data.view(),
            2,
            Some(50),
            Some(1e-6),
            Some("k-means++"),
            Some("warn"),
            Some(true),
            Some(42),
        )
        .expect("Test: operation failed");
        assert_eq!(labels_enum, labels_str);
        for (a, b) in centroids_enum.iter().zip(centroids_str.iter()) {
            assert_abs_diff_eq!(*a, *b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_minit_method_from_str() {
        let cases = [
            ("random", MinitMethod::Random),
            ("RANDOM", MinitMethod::Random),
            ("points", MinitMethod::Points),
            ("POINTS", MinitMethod::Points),
            ("k-means++", MinitMethod::PlusPlus),
            ("kmeans++", MinitMethod::PlusPlus),
            ("plusplus", MinitMethod::PlusPlus),
            ("K-MEANS++", MinitMethod::PlusPlus),
        ];
        for &(name, expected) in &cases {
            assert_eq!(
                MinitMethod::from_str(name).expect("Operation failed"),
                expected
            );
        }
        assert!(MinitMethod::from_str("invalid").is_err());
    }
}