easy_ml/
k_means.rs

1/*!
2K-means example
3
4[Overview](https://en.wikipedia.org/wiki/K-means_clustering).
5
6# K means
7
8The following code creates two 2-dimensional gaussian distributions and then draws samples
9from them to create some data which is then assigned to clusters
10
11## Matrix APIs
12
13```
use easy_ml::matrices::Matrix;
use easy_ml::distributions::MultivariateGaussian;

use rand::{Rng, SeedableRng};
use rand::distr::{Iter, StandardUniform};
use rand_chacha::ChaCha8Rng;

use rgb::RGB8;
use textplots::{Chart, ColorPlot, Plot, Shape};

// use a fixed seed random generator from the rand crate so the example is reproducible
let mut random_generator = ChaCha8Rng::seed_from_u64(11);

// define two cluster centres using two 2d gaussians, making sure they overlap a bit
let cluster1 = MultivariateGaussian::new(
    Matrix::column(vec![ 2.0, 3.0 ]),
    Matrix::from(vec![
        vec![ 1.0, 0.1 ],
        vec![ 0.1, 1.0 ]]));

// make the second cluster more spread out so there will be a bit of overlap with the first
// in the (0,0) to (1, 1) area
let cluster2 = MultivariateGaussian::new(
    Matrix::column(vec![ -2.0, -1.0 ]),
    Matrix::from(vec![
        vec![ 2.5, 1.2 ],
        vec![ 1.2, 2.5 ]]));

// Generate 200 points for each cluster
let points = 200;
let mut random_numbers: Iter<StandardUniform, &mut ChaCha8Rng, f64> =
    (&mut random_generator).sample_iter(StandardUniform);
// we can unwrap here because we deliberately constructed a positive definite covariance matrix
// and supplied enough random numbers
let cluster1_points = cluster1.draw(&mut random_numbers, points).unwrap();
let cluster2_points = cluster2.draw(&mut random_numbers, points).unwrap();

// Plot the generated data into a scatter plot
// There are two clear clusters around the means (of cluster1 and cluster2) but
// many points in the middle are ambiguous, this was deliberate in the choice of
// parameters to generate the data with, as if our data was linearly separable we
// wouldn't need to perform clustering on it in the first place. Note that, as an unsupervised
// learning method, k-means does not find or try to find a 'right' clustering for arbitrary data
println!("Generated data points");
// textplots expects a Vec<(f32, f32)> where each tuple is a (x,y) point to plot,
// so we must transform the data from the cluster points slightly to plot
let scatter_points = cluster1_points.column_iter(0)
    // zip is used to merge the x and y columns in the cluster points into a single tuple
    .zip(cluster1_points.column_iter(1))
    // chain then links the two iterators together so after all of cluster1_points
    // are consumed we use all of cluster2_points
    .chain(cluster2_points.column_iter(0).zip(cluster2_points.column_iter(1)))
    // finally we map the tuples of (f64, f64) into (f32, f32) for handing to the library
    .map(|(x, y)| (x as f32, y as f32))
    .collect::<Vec<(f32, f32)>>();
Chart::new(180, 60, -8.0, 8.0)
    .lineplot(&Shape::Points(&scatter_points))
    .display();
72
73
74// pick seeds to start each cluster at, in this case we start the seeds at a fixed position
75// of (1, 0) and (0, 1) which is deliberately where the two clusters overlap
76let mut clusters = Matrix::from(vec![
77    vec![ 1.0, 0.0 ],
78    vec![ 0.0, 1.0 ]]);
79
80// construct a matrix of rows in the format [x, y, cluster] to contain all the points
81let mut points = {
82    let mut points = cluster1_points;
83    // copy each row of cluster2_points into points
84    for row in 0..cluster2_points.rows() {
85        // insert each row from cluster2_points to the end of points
86        points.insert_row_with(points.rows(), cluster2_points.row_iter(row));
87    }
88    // extend points from rows of [x, y] to [x, y, cluster] for use in the update loop
89    points.insert_column(2, -1.0);
90    points
91};
92
93// give a name for the meaning of each column in the points matrix
94const X: usize = 0;
95const Y: usize = 1;
96const CLUSTER: usize = 2;
97
98// set a threshold at which we consider the cluster centres to have converged
99const CHANGE_THRESHOLD: f64 = 0.001;
100
101// track how much the means have changed each update
102let mut absolute_changes = -1.0;
103
104// track where the clusters move over time for plotting
105let mut cluster_center_1_history = Vec::with_capacity(7);
106let mut cluster_center_2_history = Vec::with_capacity(7);
107
108// loop until we go under the CHANGE_THRESHOLD, reassigning points to the nearest
109// cluster then cluster centres to their mean of points
110while absolute_changes == -1.0 || absolute_changes > CHANGE_THRESHOLD {
111    println!("Cluster centres: ({},{}), ({},{})",
112        clusters.get(0, X), clusters.get(0, Y),
113        clusters.get(1, X), clusters.get(1, Y));
114    cluster_center_1_history.push((clusters.get(0, X) as f32, clusters.get(0, Y) as f32));
115    cluster_center_2_history.push((clusters.get(1, X) as f32, clusters.get(1, Y) as f32));
116
117    // assign each point to the nearest cluster centre by euclidean distance
118    for point in 0..points.rows() {
119        let x = points.get(point, X);
120        let y = points.get(point, Y);
121        let mut closest_cluster = -1.0;
122        let mut least_squared_distance = std::f64::MAX;
123        for cluster in 0..clusters.rows() {
124            let cx = clusters.get(cluster, X);
125            let cy = clusters.get(cluster, Y);
126            // we don't actually need to square root the distances for finding
127            // which is least because least squared distance is the same as
128            // least distance
129            let squared_distance = (x - cx).powi(2) + (y - cy).powi(2);
130            if squared_distance < least_squared_distance {
131                closest_cluster = cluster as f64;
132                least_squared_distance = squared_distance;
133            }
134        }
135        // save the cluster that is closest to each point
136        points.set(point, CLUSTER, closest_cluster);
137    }
138    // update cluster centres to the mean of their points
139    absolute_changes = 0.0;
140    for cluster in 0..clusters.rows() {
141        // construct a list of the points this cluster owns
142        let owned = points.column_iter(CLUSTER)
143            // zip together the cluster id in each point with their X, Y points
144            .zip(points.column_reference_iter(X).zip(points.column_reference_iter(Y)))
145            // exclude the points that aren't assigned to this cluster
146            .filter(|(id, (x, y))| (*id as usize) == cluster)
147            // drop the cluster ids from each item
148            .map(|(id, (x, y))| (x, y))
149            // collect into a vector of tuples
150            .collect::<Vec<(&f64, &f64)>>();
151        let total = owned.len() as f64;
152        let mean_x = owned.iter().map(|&(&x, _)| x).sum::<f64>() / total;
153        let mean_y = owned.iter().map(|&(_, &y)| y).sum::<f64>() / total;
154        // track the absolute difference between the new mean and the old one
155        // so we know when to stop updating the clusters
156        absolute_changes += (clusters.get(cluster, X) - mean_x).abs();
157        absolute_changes += (clusters.get(cluster, Y) - mean_y).abs();
158        // set the new mean x and y for this cluster
159        clusters.set(cluster, X, mean_x);
160        clusters.set(cluster, Y, mean_y);
161    }
162}
163println!("Cluster centres: ({},{}), ({},{})",
164    clusters.get(0, X), clusters.get(0, Y),
165    clusters.get(1, X), clusters.get(1, Y));
166cluster_center_1_history.push((clusters.get(0, X) as f32, clusters.get(0, Y) as f32));
167cluster_center_2_history.push((clusters.get(1, X) as f32, clusters.get(1, Y) as f32));
168
169println!("Cluster centre movements");
170Chart::new(180, 60, -8.0, 8.0)
171    .lineplot(&Shape::Points(&scatter_points))
172    .linecolorplot(&Shape::Lines(&cluster_center_1_history), RGB8::new(255, 100, 100))
173    .linecolorplot(&Shape::Lines(&cluster_center_2_history), RGB8::new(100, 100, 255))
174    .display();
175```
176
177## Tensor APIs
178
179```
use easy_ml::tensors::Tensor;
use easy_ml::tensors::views::TensorStack;
use easy_ml::distributions::MultivariateGaussianTensor;

use rand::{Rng, SeedableRng};
use rand::distr::{Iter, StandardUniform};
use rand_chacha::ChaCha8Rng;

use rgb::RGB8;
use textplots::{Chart, ColorPlot, Plot, Shape};

// use a fixed seed random generator from the rand crate so the example is reproducible
let mut random_generator = ChaCha8Rng::seed_from_u64(11);

// define two cluster centres using two 2d gaussians, making sure they overlap a bit
let cluster1 = MultivariateGaussianTensor::new(
    Tensor::from([("means", 2)], vec![ 2.0, 3.0 ]),
    Tensor::from(
        [("rows", 2), ("columns", 2)],
        vec![
            1.0, 0.1,
            0.1, 1.0
        ]
    )
).unwrap(); // we can unwrap here because we know we supplied valid inputs to the Gaussian

// make the second cluster more spread out so there will be a bit of overlap with the first
// in the (0,0) to (1, 1) area
let cluster2 = MultivariateGaussianTensor::new(
    Tensor::from([("means", 2)], vec![ -2.0, -1.0 ]),
    Tensor::from(
        [("rows", 2), ("columns", 2)],
        vec![
            2.5, 1.2,
            1.2, 2.5
        ]
    )
).unwrap(); // we can unwrap here because we know we supplied valid inputs to the Gaussian

// Generate 200 points for each cluster
let points = 200;
let mut random_numbers: Iter<StandardUniform, &mut ChaCha8Rng, f64> =
    (&mut random_generator).sample_iter(StandardUniform);
// we can unwrap here because we deliberately constructed a positive definite covariance matrix
// and supplied enough random numbers
let cluster1_points = cluster1.draw(&mut random_numbers, points, "data", "feature").unwrap();
let cluster2_points = cluster2.draw(&mut random_numbers, points, "data", "feature").unwrap();

// Plot the generated data into a scatter plot
// There are two clear clusters around the means (of cluster1 and cluster2) but
// many points in the middle are ambiguous, this was deliberate in the choice of
// parameters to generate the data with, as if our data was linearly separable we
// wouldn't need to perform clustering on it in the first place. Note that, as an unsupervised
// learning method, k-means does not find or try to find a 'right' clustering for arbitrary data
println!("Generated data points");
// textplots expects a Vec<(f32, f32)> where each tuple is a (x,y) point to plot,
// so we must transform the data from the cluster points slightly to plot
let scatter_points = cluster1_points
    .select([("feature", 0)])
    .iter()
    // zip is used to merge the x and y columns in the cluster points into a single tuple
    .zip(cluster1_points.select([("feature", 1)]).iter())
    // chain then links the two iterators together so after all of cluster1_points
    // are consumed we use all of cluster2_points
    .chain(
        cluster2_points
            .select([("feature", 0)])
            .iter()
            .zip(cluster2_points.select([("feature", 1)]).iter())
    )
    // finally we map the tuples of (f64, f64) into (f32, f32) for handing to the library
    .map(|(x, y)| (x as f32, y as f32))
    .collect::<Vec<(f32, f32)>>();

Chart::new(180, 60, -8.0, 8.0)
    .lineplot(&Shape::Points(&scatter_points))
    .display();
257
258
259// pick seeds to start each cluster at, in this case we start the seeds at a fixed position
260// of (1, 0) and (0, 1) which is deliberately where the two clusters overlap
261let mut clusters = Tensor::from(
262    [("cluster", 2), ("xy", 2)],
263    vec![
264        1.0, 0.0,
265        0.0, 1.0
266    ]
267);
268
269// construct a matrix of rows in the format [x, y, cluster] to contain all the points
270let mut points = {
271    let mut points = Tensor::empty(
272        [("data", 400), ("feature", 3)],
273        -1.0
274    );
275    // copy in the rows of cluster1_points and cluster2_points
276    let mut data = cluster1_points.iter().chain(cluster2_points.iter());
277    for ([_row, feature], x) in points.iter_reference_mut().with_index() {
278        *x = match feature {
279            // x and y come from cluster points
280            0 | 1 => data.next().unwrap(),
281            _ => -1.0,
282        };
283    }
284    points
285};
286
287// give a name for the meaning of each feature in the points matrix
288const X: usize = 0;
289const Y: usize = 1;
290const CLUSTER: usize = 2;
291
292// set a threshold at which we consider the cluster centres to have converged
293const CHANGE_THRESHOLD: f64 = 0.001;
294
295// track how much the means have changed each update
296let mut absolute_changes = -1.0;
297
298// track where the clusters move over time for plotting
299let mut cluster_center_1_history = Vec::with_capacity(7);
300let mut cluster_center_2_history = Vec::with_capacity(7);
301
302// loop until we go under the CHANGE_THRESHOLD, reassigning points to the nearest
303// cluster then cluster centres to their mean of points
304while absolute_changes == -1.0 || absolute_changes > CHANGE_THRESHOLD {
305    let mut clusters = clusters.index_by_mut(["cluster", "xy"]);
306    println!("Cluster centres: ({},{}), ({},{})",
307        clusters.get([0, X]), clusters.get([0, Y]),
308        clusters.get([1, X]), clusters.get([1, Y])
309    );
310    cluster_center_1_history.push((clusters.get([0, X]) as f32, clusters.get([0, Y]) as f32));
311    cluster_center_2_history.push((clusters.get([1, X]) as f32, clusters.get([1, Y]) as f32));
312
313    let number_of_points = points.shape()[0].1;
314    let number_of_clusters = clusters.shape()[0].1;
315    // assign each point to the nearest cluster centre by euclidean distance
316    {
317        let mut points = points.index_by_mut(["data", "feature"]);
318        for point in 0..number_of_points {
319            let x = points.get([point, X]);
320            let y = points.get([point, Y]);
321            let mut closest_cluster = -1.0;
322            let mut least_squared_distance = std::f64::MAX;
323            for cluster in 0..number_of_clusters {
324                let cx = clusters.get([cluster, X]);
325                let cy = clusters.get([cluster, Y]);
326                // we don't actually need to square root the distances for finding
327                // which is least because least squared distance is the same as
328                // least distance
329                let squared_distance = (x - cx).powi(2) + (y - cy).powi(2);
330                if squared_distance < least_squared_distance {
331                    closest_cluster = cluster as f64;
332                    least_squared_distance = squared_distance;
333                }
334            }
335            // save the cluster that is closest to each point
336            *points.get_ref_mut([point, CLUSTER]) = closest_cluster;
337        }
338    } // drop the TensorAccess wrapper on points
339
340    // update cluster centres to the mean of their points
341    absolute_changes = 0.0;
342    for cluster in 0..number_of_clusters {
343        // construct a list of the points this cluster owns
344        let owned = points.select([("feature", CLUSTER)]).iter()
345            // zip together the cluster id in each point with their X, Y points
346            .zip(
347                points.select([("feature", X)]).iter()
348                    .zip(points.select([("feature", Y)]).iter())
349            )
350            // exclude the points that aren't assigned to this cluster
351            .filter(|(id, (x, y))| (*id as usize) == cluster)
352            // drop the cluster ids from each item
353            .map(|(id, (x, y))| (x, y))
354            // collect into a vector of tuples
355            .collect::<Vec<(f64, f64)>>();
356        let total = owned.len() as f64;
357        let mean_x = owned.iter().map(|(x, _)| x).sum::<f64>() / total;
358        let mean_y = owned.iter().map(|(_, y)| y).sum::<f64>() / total;
359        // track the absolute difference between the new mean and the old one
360        // so we know when to stop updating the clusters
361        absolute_changes += (clusters.get([cluster, X]) - mean_x).abs();
362        absolute_changes += (clusters.get([cluster, Y]) - mean_y).abs();
363        // set the new mean x and y for this cluster
364        *clusters.get_ref_mut([cluster, X]) = mean_x;
365        *clusters.get_ref_mut([cluster, Y]) = mean_y;
366    }
367}
368let clusters = clusters.index_by(["cluster", "xy"]);
369println!("Cluster centres: ({},{}), ({},{})",
370    clusters.get([0, X]), clusters.get([0, Y]),
371    clusters.get([1, X]), clusters.get([1, Y]));
372cluster_center_1_history.push((clusters.get([0, X]) as f32, clusters.get([0, Y]) as f32));
373cluster_center_2_history.push((clusters.get([1, X]) as f32, clusters.get([1, Y]) as f32));
374
375println!("Cluster centre movements");
376Chart::new(180, 60, -8.0, 8.0)
377    .lineplot(&Shape::Points(&scatter_points))
378    .linecolorplot(&Shape::Lines(&cluster_center_1_history), RGB8::new(255, 100, 100))
379    .linecolorplot(&Shape::Lines(&cluster_center_2_history), RGB8::new(100, 100, 255))
380    .display();
381```
382
383# 5 Dimensional K-means
384
385See [naive_bayes](super::naive_bayes::three_class)
386*/