38_nearest_neighbor_classification/38_nearest_neighbor_classification.rs
1//! # Example: Nearest-Neighbor Classification
2//!
3//! Run: cargo run --example 38_nearest_neighbor_classification
4//!
5//! ## Problem
6//! Given a small set of labeled training points, classify a new query point by the
7//! label of its single closest training point (1-nearest-neighbor).
8//!
9//! ## Math idea
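//! Compute the squared Euclidean distance from the query to every training point and
//! return the label of the training point at the minimum distance (a 1-NN
//! classifier — no training step, no fitted parameters). The square root is
//! omitted because it does not change which distance is smallest.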
13//!
14//! ## Tensor representation
15//! The training set is a `[samples, features]` `Tensor`; labels are a parallel slice.
16//! A query is a length-`features` slice.
17//!
18//! ## What this demonstrates
19//! - using a `Tensor` as a labeled `[samples, features]` data matrix;
20//! - a nearest-point search via `Tensor::argmin` (RFC-038);
21//! - composing `Tensor` row access with plain Rust arithmetic.
22//!
23//! ## Expected output
24//! ```text
25//! query [1.5, 1.5] -> class 0
26//! query [8.5, 8.5] -> class 1
27//! Nearest-neighbor classification: OK
28//! ```
29//!
30//! This is an algorithm demonstration, not an ML framework.
31
32use matten::Tensor;
33
/// Squared Euclidean distance between two points: the sum of `(a[i] - b[i])²`.
///
/// No square root is taken: it is monotonic on non-negative values, so it would
/// not change which of several points is nearest.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths. Without this check `zip` would
/// stop at the shorter slice and silently ignore the extra coordinates.
fn sq_dist(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "points must have the same dimension");
    a.iter()
        .zip(b)
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum()
}
38
39/// Label of the single nearest training point (1-NN), via `Tensor::argmin` over
40/// the per-training-point distances.
41fn classify(query: &[f64], train: &Tensor, labels: &[u8]) -> u8 {
42 let dim = train.shape()[1];
43 let data = train.as_slice();
44 let dists: Vec<f64> = (0..train.shape()[0])
45 .map(|i| sq_dist(query, &data[i * dim..(i + 1) * dim]))
46 .collect();
47 labels[Tensor::from_vec(dists).argmin()]
48}
49
50fn main() {
51 // Labeled training set: 4 points (rows), 2 features (columns).
52 let train = Tensor::new(
53 vec![
54 1.0, 1.0, //
55 2.0, 2.0, //
56 8.0, 8.0, //
57 9.0, 9.0, //
58 ],
59 &[4, 2],
60 );
61 let labels = [0u8, 0, 1, 1];
62
63 let queries = [[1.5, 1.5], [8.5, 8.5]];
64 for q in &queries {
65 let label = classify(q, &train, &labels);
66 println!("query [{:.1}, {:.1}] -> class {label}", q[0], q[1]);
67 }
68
69 assert_eq!(classify(&[1.5, 1.5], &train, &labels), 0);
70 assert_eq!(classify(&[8.5, 8.5], &train, &labels), 1);
71 println!("Nearest-neighbor classification: OK");
72}