// mlinrust/model/kmeans.rs

1use crate::{utils::RandGenerator, model::utils::minkowski_distance};
2use super::Model;
3
/// K-Means clustering model: `k` centers fitted by Lloyd-style iteration.
pub struct KMeansClustering {
    /// number of clusters
    pub k: usize,
    /// learned cluster centers, one `Vec<f32>` per cluster
    pub clusters: Vec<Vec<f32>>,
    /// order of the Minkowski distance used as the metric (2.0 = Euclidean)
    p: f32,
}
9
10impl KMeansClustering {
11
12    /// build K-Means model with unsupervised learning
13    /// * k: number of clusters
14    /// * max_iter: max iterations
15    /// * p: decide the distance, parameter of minkowski distance
16    ///     * default: 2, i.e., Euclidean distance
17    /// * init_clusters: you can give a init clusters
18    ///     * default: it will random choose k clusters from the samples
19    /// * early_stop: decide the tolerance error of the distance between the last two iterations
20    ///     * default: 1e-3
21    /// * seed: seed for randomly choosing the init clusters
22    ///     * default: 0
23    /// * features: the samples set
24    pub fn new(k: usize, max_iter: usize, p: Option<usize>, init_clusters: Option<Vec<Vec<f32>>>, early_stop: Option<f32>, seed: Option<usize>, features: &Vec<Vec<f32>>) -> Self {
25        assert!(k > 0 && k <= features.len());
26        let mut rng = RandGenerator::new(seed.unwrap_or(0));
27        let mut clusters = init_clusters.unwrap_or(
28            rng.choice(&features, k, false)
29        );
30        // check whether the dim of init clusters is matching the feature dim
31        assert!(clusters.len() == k && clusters[0].len() == features[0].len());
32        let p = p.unwrap_or(2) as f32;
33        let early_stop = early_stop.unwrap_or(1e-3);
34        
35        // starting iteration
36        let feature_dim = features[0].len();
37        for _ in 0..max_iter {
38            let mut new_clusters = vec![vec![0.0; feature_dim]; k];
39            let mut cnts = vec![0; k];
40            features.iter().for_each(|item| {
41
42                // find the center that is closest to the sample(item)
43                let idx = clusters.iter().enumerate().fold((0, f32::MAX), |s, (i, center)| {
44                    let d = minkowski_distance(center, item, p);
45                    if s.1 < d {
46                        s
47                    } else {
48                        (i, d)
49                    }
50                }).0;
51                
52                // accumulate to the cluster center
53                cnts[idx] += 1;
54                new_clusters[idx].iter_mut().zip(item.iter()).for_each(|(nc, i)| {                   
55                     *nc += i;
56                });
57
58            });
59
60            // take average
61            new_clusters.iter_mut().zip(cnts.into_iter()).for_each(|(cl, c)| {
62                cl.iter_mut().for_each(|i| *i /= f32::max(c as f32, 1e-6));
63            });
64
65            // see if the new clusters are as same as the old_clusters, then we can decide early stop
66            let err = new_clusters.iter().zip(clusters.iter()).fold(0.0, |err, (ni, oi)| {
67                err + minkowski_distance(ni, oi, p)
68            });
69
70            clusters = new_clusters;
71
72            if err < early_stop {
73                if cfg!(test) {
74                    println!("early stop with err {err}");
75                }
76                break;
77            }
78
79        }
80
81        Self { k: k, p: p, clusters: clusters }
82    }
83}
84
85impl Model<usize> for KMeansClustering {
86    /// return the nearest cluster idx, note that it is NOT the classification prediction
87    fn predict(&self, feature: &Vec<f32>) -> usize {
88        self.clusters.iter().enumerate().fold((0, f32::MAX), |s, (i, center)| {
89            let d = minkowski_distance(center, feature, self.p);
90            if s.1 < d {
91                s
92            } else {
93                (i, d)
94            }
95        }).0
96    }
97}
98
#[cfg(test)]
mod test {
    use crate::model::Model;

    use super::KMeansClustering;

    /// Smoke test: fit 3 clusters on a tiny 2-D dataset and print the
    /// cluster id assigned to each sample.
    #[test]
    fn test_kmeans() {
        let datas: Vec<Vec<f32>> = [
            (1.0, 3.0),
            (2.0, 3.0),
            (1.0, 2.0),
            (4.0, 0.0),
            (3.0, 0.0),
            (3.0, -1.0),
            (3.0, 0.5),
        ]
        .into_iter()
        .map(|(x, y)| vec![x, y])
        .collect();
        let model = KMeansClustering::new(3, 100, Some(2), None, None, None, &datas);
        for item in &datas {
            println!("label: {}", model.predict(item));
        }
    }
}