54_pairwise_distance/
54_pairwise_distance.rs1use matten::Tensor;
8
9fn pairwise_euclidean(points: &Tensor) -> Tensor {
10 let n = points.shape()[0];
11
12 let sq = points * points;
14 let row_sq_norms = sq.sum_axis(1); let gram = points.matmul(&points.transpose());
18
19 let col = row_sq_norms.reshape(&[n, 1]);
22 let row = row_sq_norms.reshape(&[1, n]);
23 let dist_sq = &(&col + &row) - &(&gram * 2.0);
24
25 let dists: Vec<f64> = dist_sq
27 .as_slice()
28 .iter()
29 .map(|&v| if v < 0.0 { 0.0 } else { v.sqrt() })
30 .collect();
31 Tensor::new(dists, &[n, n])
32}
33
34fn main() {
35 let points = Tensor::new(vec![0.0, 0.0, 3.0, 4.0, 6.0, 0.0], &[3, 2]);
36
37 let dists = pairwise_euclidean(&points);
38 println!("pairwise distances:");
39 for i in 0..3 {
40 let row = dists.slice().index(i).all().build().unwrap();
41 println!(" row {i}: {:?}", row.as_slice());
42 }
43
44 assert!((dists.get(&[0, 1]).unwrap() - 5.0).abs() < 1e-9);
46 assert!((dists.get(&[1, 2]).unwrap() - 5.0).abs() < 1e-9);
47 assert!((dists.get(&[0, 2]).unwrap() - 6.0).abs() < 1e-9);
48 println!("Pairwise distances: OK");
49}