1use crate::{utils::RandGenerator, model::utils::minkowski_distance};
2use super::Model;
3
/// A fitted k-means clustering model.
///
/// Instances are produced by `KMeansClustering::new`, which runs the iterative
/// centroid refinement and stores the resulting centroids.
#[derive(Debug, Clone, PartialEq)]
pub struct KMeansClustering {
    /// Number of clusters.
    pub k: usize,
    /// Cluster centroids: one coordinate vector per cluster.
    pub clusters: Vec<Vec<f32>>,
    /// Exponent handed to `minkowski_distance` when comparing points to centroids.
    p: f32,
}
9
10impl KMeansClustering {
11
12 pub fn new(k: usize, max_iter: usize, p: Option<usize>, init_clusters: Option<Vec<Vec<f32>>>, early_stop: Option<f32>, seed: Option<usize>, features: &Vec<Vec<f32>>) -> Self {
25 assert!(k > 0 && k <= features.len());
26 let mut rng = RandGenerator::new(seed.unwrap_or(0));
27 let mut clusters = init_clusters.unwrap_or(
28 rng.choice(&features, k, false)
29 );
30 assert!(clusters.len() == k && clusters[0].len() == features[0].len());
32 let p = p.unwrap_or(2) as f32;
33 let early_stop = early_stop.unwrap_or(1e-3);
34
35 let feature_dim = features[0].len();
37 for _ in 0..max_iter {
38 let mut new_clusters = vec![vec![0.0; feature_dim]; k];
39 let mut cnts = vec![0; k];
40 features.iter().for_each(|item| {
41
42 let idx = clusters.iter().enumerate().fold((0, f32::MAX), |s, (i, center)| {
44 let d = minkowski_distance(center, item, p);
45 if s.1 < d {
46 s
47 } else {
48 (i, d)
49 }
50 }).0;
51
52 cnts[idx] += 1;
54 new_clusters[idx].iter_mut().zip(item.iter()).for_each(|(nc, i)| {
55 *nc += i;
56 });
57
58 });
59
60 new_clusters.iter_mut().zip(cnts.into_iter()).for_each(|(cl, c)| {
62 cl.iter_mut().for_each(|i| *i /= f32::max(c as f32, 1e-6));
63 });
64
65 let err = new_clusters.iter().zip(clusters.iter()).fold(0.0, |err, (ni, oi)| {
67 err + minkowski_distance(ni, oi, p)
68 });
69
70 clusters = new_clusters;
71
72 if err < early_stop {
73 if cfg!(test) {
74 println!("early stop with err {err}");
75 }
76 break;
77 }
78
79 }
80
81 Self { k: k, p: p, clusters: clusters }
82 }
83}
84
85impl Model<usize> for KMeansClustering {
86 fn predict(&self, feature: &Vec<f32>) -> usize {
88 self.clusters.iter().enumerate().fold((0, f32::MAX), |s, (i, center)| {
89 let d = minkowski_distance(center, feature, self.p);
90 if s.1 < d {
91 s
92 } else {
93 (i, d)
94 }
95 }).0
96 }
97}
98
#[cfg(test)]
mod test {
    use crate::model::Model;

    use super::KMeansClustering;

    /// Seven 2-D points forming a group around (1.3, 2.7) and a group around (3.2, -0.1).
    fn sample_points() -> Vec<Vec<f32>> {
        vec![
            vec![1.0, 3.0],
            vec![2.0, 3.0],
            vec![1.0, 2.0],
            vec![4.0, 0.0],
            vec![3.0, 0.0],
            vec![3.0, -1.0],
            vec![3.0, 0.5],
        ]
    }

    /// Randomly initialised fit: checks the structural invariants of the model.
    #[test]
    fn test_kmeans() {
        let datas = sample_points();
        let k = 3;
        let model = KMeansClustering::new(k, 100, Some(2), None, None, None, &datas);

        assert_eq!(model.k, k);
        assert_eq!(model.clusters.len(), k);
        assert!(model.clusters.iter().all(|c| c.len() == 2));

        for item in datas.iter() {
            let label = model.predict(item);
            println!("label: {}", label);
            // A label is an index into `model.clusters`.
            assert!(label < k);
        }
    }

    /// Fit seeded with explicit centroids, so no randomness is involved and the
    /// two well-separated groups must end up in clusters 0 and 1 respectively.
    #[test]
    fn test_kmeans_with_initial_clusters() {
        let datas = sample_points();
        let init = vec![vec![1.0, 3.0], vec![4.0, 0.0]];
        let model = KMeansClustering::new(2, 100, Some(2), Some(init), None, None, &datas);

        let labels: Vec<usize> = datas.iter().map(|item| model.predict(item)).collect();
        assert_eq!(labels, vec![0, 0, 0, 1, 1, 1, 1]);
    }
}