Skip to main content

pairwise_distance/
pairwise_distance.rs

1//! Pairwise Euclidean distances between rows using broadcasting and matmul.
2//!
3//! Run: cargo run --example pairwise_distance
4//!
5//! Uses the identity: ||a-b||² = ||a||² + ||b||² - 2·aᵀb
6
7use matten::Tensor;
8
9fn pairwise_euclidean(points: &Tensor) -> Tensor {
10    let n = points.shape()[0];
11
12    // Row-wise squared norms: shape [n]
13    let sq = points * points;
14    let row_sq_norms = sq.sum_axis(1); // [n]
15
16    // Gram matrix: G[i,j] = points[i] · points[j], shape [n,n]
17    let gram = points.matmul(&points.transpose());
18
19    // dist²[i,j] = sq_norm[i] + sq_norm[j] - 2·G[i,j]
20    // Broadcast [n] as column [n,1] + row [1,n]
21    let col = row_sq_norms.reshape(&[n, 1]);
22    let row = row_sq_norms.reshape(&[1, n]);
23    let dist_sq = &(&col + &row) - &(&gram * 2.0);
24
25    // Clamp small negatives from floating-point rounding, then sqrt
26    let dists: Vec<f64> = dist_sq
27        .as_slice()
28        .iter()
29        .map(|&v| if v < 0.0 { 0.0 } else { v.sqrt() })
30        .collect();
31    Tensor::new(dists, &[n, n])
32}
33
34fn main() {
35    let points = Tensor::new(vec![0.0, 0.0, 3.0, 4.0, 6.0, 0.0], &[3, 2]);
36
37    let dists = pairwise_euclidean(&points);
38    println!("pairwise distances:");
39    for i in 0..3 {
40        let row = dists.slice().index(i).all().build().unwrap();
41        println!("  row {i}: {:?}", row.as_slice());
42    }
43
44    // (0,0)→(3,4) = 5, (3,4)→(6,0) = 5, (0,0)→(6,0) = 6
45    assert!((dists.get(&[0, 1]).unwrap() - 5.0).abs() < 1e-9);
46    assert!((dists.get(&[1, 2]).unwrap() - 5.0).abs() < 1e-9);
47    assert!((dists.get(&[0, 2]).unwrap() - 6.0).abs() < 1e-9);
48    println!("Pairwise distances: OK");
49}