1use crate::helpers::{l2_distance, simpsons_weights};
10use crate::iter_maybe_parallel;
11use crate::matrix::FdMatrix;
12#[cfg(feature = "parallel")]
13use rayon::iter::ParallelIterator;
14
15#[must_use]
31pub fn pairwise_distance_matrix<F>(n: usize, dist_fn: F) -> FdMatrix
32where
33 F: Fn(usize, usize) -> f64 + Sync,
34{
35 let pairs: Vec<(usize, usize)> = (0..n)
36 .flat_map(|i| ((i + 1)..n).map(move |j| (i, j)))
37 .collect();
38
39 let pair_dists: Vec<(usize, usize, f64)> = iter_maybe_parallel!(pairs)
40 .map(|(i, j)| (i, j, dist_fn(i, j)))
41 .collect();
42
43 let mut mat = FdMatrix::zeros(n, n);
44 for (i, j, d) in pair_dists {
45 mat[(i, j)] = d;
46 mat[(j, i)] = d;
47 }
48 mat
49}
50
51#[must_use]
67pub fn l2_distance_matrix(data: &FdMatrix, argvals: &[f64]) -> FdMatrix {
68 let weights = simpsons_weights(argvals);
69 let n = data.nrows();
70 pairwise_distance_matrix(n, |i, j| l2_distance(&data.row(i), &data.row(j), &weights))
71}
72
73#[must_use]
90pub fn euclidean_distance_matrix(data: &FdMatrix) -> FdMatrix {
91 let n = data.nrows();
92 let p = data.ncols();
93 pairwise_distance_matrix(n, |i, j| {
94 let mut d2 = 0.0;
95 for k in 0..p {
96 let diff = data[(i, k)] - data[(j, k)];
97 d2 += diff * diff;
98 }
99 d2.sqrt()
100 })
101}
102
103#[must_use]
119pub fn cross_distance_matrix<F>(n_new: usize, n_train: usize, dist_fn: F) -> FdMatrix
120where
121 F: Fn(usize, usize) -> f64 + Sync,
122{
123 let pairs: Vec<(usize, usize)> = (0..n_new)
124 .flat_map(|i| (0..n_train).map(move |j| (i, j)))
125 .collect();
126
127 let pair_dists: Vec<(usize, usize, f64)> = iter_maybe_parallel!(pairs)
128 .map(|(i, j)| (i, j, dist_fn(i, j)))
129 .collect();
130
131 let mut mat = FdMatrix::zeros(n_new, n_train);
132 for (i, j, d) in pair_dists {
133 mat[(i, j)] = d;
134 }
135 mat
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pairwise_symmetric() {
        let mat = pairwise_distance_matrix(4, |i, j| (i as f64 - j as f64).abs());
        assert_eq!(mat.shape(), (4, 4));
        for i in 0..4 {
            assert!(mat[(i, i)].abs() < 1e-15);
            for j in 0..4 {
                assert!((mat[(i, j)] - mat[(j, i)]).abs() < 1e-15);
            }
        }
        assert!((mat[(0, 3)] - 3.0).abs() < 1e-15);
    }

    #[test]
    fn pairwise_evaluates_each_unordered_pair_once() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let calls = AtomicUsize::new(0);
        let mat = pairwise_distance_matrix(5, |i, j| {
            // Only the strict upper triangle should ever be requested.
            assert!(i < j);
            calls.fetch_add(1, Ordering::SeqCst);
            (10 * i + j) as f64
        });
        // 5 choose 2 = 10 evaluations, never the diagonal.
        assert_eq!(calls.load(Ordering::SeqCst), 10);
        for i in 0..5 {
            assert!(mat[(i, i)].abs() < 1e-15);
            for j in (i + 1)..5 {
                let expected = (10 * i + j) as f64;
                assert!((mat[(i, j)] - expected).abs() < 1e-15);
                assert!((mat[(j, i)] - expected).abs() < 1e-15);
            }
        }
    }

    #[test]
    fn l2_matrix_smoke() {
        let data = FdMatrix::zeros(5, 10);
        let t: Vec<f64> = (0..10).map(|i| i as f64 / 9.0).collect();
        let mat = l2_distance_matrix(&data, &t);
        assert_eq!(mat.shape(), (5, 5));
        // Identical (all-zero) rows must give an all-zero distance matrix.
        for i in 0..5 {
            for j in 0..5 {
                assert!(mat[(i, j)].abs() < 1e-15);
            }
        }
    }

    #[test]
    fn l2_matrix_nonzero() {
        let mut data = FdMatrix::zeros(2, 10);
        // Row 0 is the zero curve, row 1 the constant-one curve on [0, 1].
        for j in 0..10 {
            data[(1, j)] = 1.0;
        }
        let t: Vec<f64> = (0..10).map(|i| i as f64 / 9.0).collect();
        let mat = l2_distance_matrix(&data, &t);
        assert!((mat[(0, 1)] - 1.0).abs() < 0.1);
        assert!((mat[(1, 0)] - mat[(0, 1)]).abs() < 1e-15);
    }

    #[test]
    fn l2_matrix_matches_direct_l2_distance() {
        let (n, m) = (4, 11);
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for k in 0..m {
                data[(i, k)] = (i as f64 + 1.0) * (k as f64 / 10.0).powi(i as i32);
            }
        }
        let t: Vec<f64> = (0..m).map(|k| k as f64 / 10.0).collect();
        let weights = simpsons_weights(&t);
        let mat = l2_distance_matrix(&data, &t);
        assert_eq!(mat.shape(), (n, n));
        for i in 0..n {
            for j in 0..n {
                let expected = if i == j {
                    0.0
                } else {
                    l2_distance(&data.row(i), &data.row(j), &weights)
                };
                assert!((mat[(i, j)] - expected).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn euclidean_matrix_smoke() {
        let mut data = FdMatrix::zeros(3, 2);
        data[(0, 0)] = 0.0;
        data[(0, 1)] = 0.0;
        data[(1, 0)] = 3.0;
        data[(1, 1)] = 4.0;
        data[(2, 0)] = 0.0;
        data[(2, 1)] = 0.0;
        let mat = euclidean_distance_matrix(&data);
        assert!((mat[(0, 1)] - 5.0).abs() < 1e-12);
        assert!((mat[(0, 2)]).abs() < 1e-12);
        assert!((mat[(1, 2)] - 5.0).abs() < 1e-12);
    }

    #[test]
    fn euclidean_matrix_symmetric_zero_diagonal() {
        let mut data = FdMatrix::zeros(4, 3);
        for i in 0..4 {
            for k in 0..3 {
                data[(i, k)] = (i * 3 + k) as f64 * 0.5;
            }
        }
        let mat = euclidean_distance_matrix(&data);
        assert_eq!(mat.shape(), (4, 4));
        for i in 0..4 {
            assert!(mat[(i, i)].abs() < 1e-15);
            for j in 0..4 {
                assert!((mat[(i, j)] - mat[(j, i)]).abs() < 1e-15);
            }
        }
        // Consecutive rows differ by 1.5 in each of the 3 coordinates.
        assert!((mat[(0, 1)] - 1.5 * 3.0_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn cross_distance_dims() {
        let mat = cross_distance_matrix(3, 5, |i, j| (i + j) as f64);
        assert_eq!(mat.shape(), (3, 5));
        assert!((mat[(0, 0)]).abs() < 1e-15);
        assert!((mat[(2, 4)] - 6.0).abs() < 1e-15);
    }

    #[test]
    fn cross_distance_fills_every_cell() {
        let mat = cross_distance_matrix(3, 4, |i, j| (10 * i + j) as f64);
        assert_eq!(mat.shape(), (3, 4));
        for i in 0..3 {
            for j in 0..4 {
                assert!((mat[(i, j)] - (10 * i + j) as f64).abs() < 1e-15);
            }
        }
    }

    #[test]
    fn cross_distance_no_new_rows() {
        let mat = cross_distance_matrix(0, 4, |_i, _j| 1.0);
        assert_eq!(mat.shape(), (0, 4));
    }

    #[test]
    fn pairwise_n_zero() {
        let mat = pairwise_distance_matrix(0, |_i, _j| 1.0);
        assert_eq!(mat.shape(), (0, 0));
    }

    #[test]
    fn pairwise_n_one() {
        // A single observation has no pairs, so `dist_fn` is never used.
        let mat = pairwise_distance_matrix(1, |_i, _j| 1.0);
        assert_eq!(mat.shape(), (1, 1));
        assert!(mat[(0, 0)].abs() < 1e-15);
    }
}