1use ndarray::{Array1, Array2, Axis, s};
2use rayon::prelude::*;
3
/// Outcome of a NeighborNet ordering run.
///
/// `Default` yields an empty ordering, and `PartialEq`/`Eq` let callers and
/// tests compare results directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct NeighbourNetResult {
    /// Row/column indices of the input distance matrix, in the order computed
    /// by the ordering routine.
    pub ordering: Vec<usize>,
}
8
9type Matrix = Array2<f64>;
10
11fn mean_between_sets(d: &Matrix, a: &[usize], b: &[usize]) -> f64 {
12 let mut sum = 0.0;
13 let mut cnt = 0usize;
14 for &i in a {
15 for &j in b {
16 sum += unsafe { *d.uget((i, j)) };
17 cnt += 1;
18 }
19 }
20 if cnt == 0 { 0.0 } else { sum / cnt as f64 }
21}
22
23fn update_dm(dm: &mut Matrix, d: &Matrix, cl: &Vec<Vec<usize>>, j: usize) {
25 let l = cl.len();
26 let col_vals: Vec<f64> = (0..l)
28 .into_par_iter()
29 .map(|i| mean_between_sets(d, &cl[i], &cl[j]))
30 .collect();
31
32 for i in 0..l {
34 dm[[i, j]] = col_vals[i];
35 }
36 for i in 0..l {
38 dm[[j, i]] = col_vals[i];
39 }
40 dm[[j, j]] = 0.0;
41}
42
43fn rx(d: &Matrix, x: &[usize], cl: &Vec<Vec<usize>>) -> Vec<f64> {
45 let lx = x.len();
46 let mut res = vec![0.0; lx];
47 for (i, &xi) in x.iter().enumerate() {
48 let mut tmp = 0.0;
49 for (j, &xj) in x.iter().enumerate() {
51 if j != i {
52 tmp += unsafe { *d.uget((xi, xj)) };
53 }
54 }
55 for c in cl.iter() {
57 tmp += mean_between_sets(d, std::slice::from_ref(&xi), c);
58 }
59 res[i] = tmp;
60 }
61 res
62}
63
64fn reduc(d: &mut Matrix, x: usize, y: usize, z: usize) {
66 let n = d.nrows();
67 let row_x = d.row(x).to_owned();
69 let row_y = d.row(y).to_owned();
70 let row_z = d.row(z).to_owned();
71
72 let u: Array1<f64> = &(&row_x * (2.0 / 3.0)) + &(&row_y * (1.0 / 3.0));
75 let v: Array1<f64> = &(&row_z * (2.0 / 3.0)) + &(&row_y * (1.0 / 3.0));
76
77 let uv = (row_x[y] + row_x[z] + row_y[z]) / 3.0;
78
79 d.row_mut(x).assign(&u);
81 d.row_mut(z).assign(&v);
82 d.row_mut(y).fill(0.0);
83
84 for j in 0..n {
86 d[[j, x]] = u[j];
87 d[[j, z]] = v[j];
88 d[[j, y]] = 0.0;
89 }
90
91 d[[x, z]] = uv;
92 d[[z, x]] = uv;
93 d[[x, x]] = 0.0;
94 d[[z, z]] = 0.0;
95}
96
97fn remove_row_col(m: &Matrix, idx: usize) -> Matrix {
99 let n = m.nrows();
100 debug_assert_eq!(n, m.ncols());
101 if n == 1 {
102 return Array2::zeros((0, 0));
103 }
104 let mut out = Array2::<f64>::zeros((n - 1, n - 1));
105 if idx > 0 {
107 out.slice_mut(s![0..idx, 0..idx])
108 .assign(&m.slice(s![0..idx, 0..idx]));
109 }
110 if idx + 1 < n {
112 out.slice_mut(s![0..idx, idx..])
113 .assign(&m.slice(s![0..idx, (idx + 1)..]));
114 }
115 if idx + 1 < n {
117 out.slice_mut(s![idx.., 0..idx])
118 .assign(&m.slice(s![(idx + 1).., 0..idx]));
119 }
120 if idx + 1 < n {
122 out.slice_mut(s![idx.., idx..])
123 .assign(&m.slice(s![(idx + 1).., (idx + 1)..]));
124 }
125 out
126}
127
128fn choose_pair(dm: &Matrix, r: &[f64]) -> (usize, usize) {
130 let l = dm.nrows();
131 if l <= 1 {
132 return (0, 0);
133 }
134 let per_i = (0..l).into_par_iter().map(|i| {
137 let mut best = f64::INFINITY;
138 let mut best_j = i + 1;
139 for j in (i + 1)..l {
140 let q = unsafe { *dm.uget((i, j)) } - r[i] - r[j];
141 if q < best {
142 best = q;
143 best_j = j;
144 }
145 }
146 (best, i, best_j)
147 });
148
149 let (.., bi, bj) = per_i.reduce(
150 || (f64::INFINITY, 0usize, 1usize),
151 |a, b| if a.0 <= b.0 { a } else { b },
152 );
153 (bi, bj)
154}
155
156fn remove_e2(cl: &mut Vec<Vec<usize>>, ord: &mut Vec<Vec<usize>>, dm: &mut Matrix, e2: usize) {
157 cl.remove(e2);
158 ord.remove(e2);
159 *dm = remove_row_col(dm, e2);
160}
161
162pub fn get_ordering_nn(x: &Matrix) -> Vec<usize> {
164 assert_eq!(x.nrows(), x.ncols(), "Distance matrix must be square");
165 let n = x.nrows();
166
167 let mut d = x.clone();
169
170 let mut cl: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();
172 let mut ord: Vec<Vec<usize>> = cl.clone();
173
174 let mut dm = d.clone();
176
177 while cl.len() > 1 {
178 let l = dm.nrows();
179 let (e1, e2) = if l > 2 {
180 let denom = (l as f64) - 2.0;
182 let r: Vec<f64> = dm
183 .axis_iter(Axis(0))
184 .into_par_iter()
185 .map(|row| row.sum() / denom)
186 .collect();
187
188 choose_pair(&dm, &r)
189 } else {
190 (0, 1)
191 };
192
193 let n1 = cl[e1].len();
194 let n2 = cl[e2].len();
195
196 if n1 == 1 && n2 == 1 {
197 let mut new_cl = Vec::with_capacity(2);
199 new_cl.extend_from_slice(&cl[e1]);
200 new_cl.extend_from_slice(&cl[e2]);
201 let new_ord = new_cl.clone();
202
203 cl[e1] = new_cl;
204 ord[e1] = new_ord;
205
206 update_dm(&mut dm, &d, &cl, e1);
207 remove_e2(&mut cl, &mut ord, &mut dm, e2);
208 } else {
209 let mut others: Vec<Vec<usize>> = Vec::with_capacity(cl.len().saturating_sub(2));
211 for (idx, c) in cl.iter().enumerate() {
212 if idx != e1 && idx != e2 {
213 others.push(c.clone());
214 }
215 }
216
217 let mut cltmp2: Vec<usize> = Vec::with_capacity(n1 + n2);
219 cltmp2.extend_from_slice(&cl[e1]);
220 cltmp2.extend_from_slice(&cl[e2]);
221
222 let mut rtmp2 = rx(&d, &cltmp2, &others);
223 let ltmp = cl[e1].len() + cl[e2].len() + others.len();
224 if ltmp > 2 {
225 let scale = 1.0 / ((ltmp as f64) - 2.0);
226 for v in rtmp2.iter_mut() {
227 *v *= scale;
228 }
229 }
230
231 let mut best_val = f64::INFINITY;
234 let mut best_row = 0usize; let mut best_col = 0usize; for col in 0..n2 {
237 for row in 0..n1 {
238 let i = cltmp2[row];
239 let j = cltmp2[n1 + col];
240 let v = unsafe { *d.uget((i, j)) } - (rtmp2[row] + rtmp2[n1 + col]);
241 if v < best_val {
242 best_val = v;
243 best_row = row;
244 best_col = col;
245 }
246 }
247 }
248
249 let (new_cl, new_ord) = match (n1, n2) {
251 (2, 1) => {
252 if best_row == 1 {
253 reduc(&mut d, cl[e1][0], cl[e1][1], cl[e2][0]);
255 let nc = vec![cl[e1][0], cl[e2][0]];
256 let mut no = ord[e1].clone();
257 no.extend_from_slice(&ord[e2]);
258 (nc, no)
259 } else {
260 reduc(&mut d, cl[e2][0], cl[e1][0], cl[e1][1]);
262 let nc = vec![cl[e2][0], cl[e1][1]];
263 let mut no = ord[e2].clone();
264 no.extend_from_slice(&ord[e1]);
265 (nc, no)
266 }
267 }
268 (1, 2) => {
269 if best_col == 0 {
270 reduc(&mut d, cl[e1][0], cl[e2][0], cl[e2][1]);
272 let nc = vec![cl[e1][0], cl[e2][1]];
273 let mut no = ord[e1].clone();
274 no.extend_from_slice(&ord[e2]);
275 (nc, no)
276 } else {
277 reduc(&mut d, cl[e2][0], cl[e2][1], cl[e1][0]);
279 let nc = vec![cl[e2][0], cl[e1][0]];
280 let mut no = ord[e2].clone();
281 no.extend_from_slice(&ord[e1]);
282 (nc, no)
283 }
284 }
285 (2, 2) => match (best_row, best_col) {
286 (0, 0) => {
287 reduc(&mut d, cl[e1][1], cl[e1][0], cl[e2][0]);
289 reduc(&mut d, cl[e1][1], cl[e2][0], cl[e2][1]);
290 let nc = vec![cl[e1][1], cl[e2][1]];
291 let mut no = ord[e1].clone();
292 no.reverse();
293 no.extend_from_slice(&ord[e2]);
294 (nc, no)
295 }
296 (1, 0) => {
297 reduc(&mut d, cl[e1][0], cl[e1][1], cl[e2][0]);
299 reduc(&mut d, cl[e1][0], cl[e2][0], cl[e2][1]);
300 let nc = vec![cl[e1][0], cl[e2][1]];
301 let mut no = ord[e1].clone();
302 no.extend_from_slice(&ord[e2]);
303 (nc, no)
304 }
305 (0, 1) => {
306 reduc(&mut d, cl[e1][1], cl[e1][0], cl[e2][1]);
308 reduc(&mut d, cl[e1][1], cl[e2][1], cl[e2][0]);
309 let nc = vec![cl[e1][1], cl[e2][0]];
310 let mut no = ord[e1].clone();
311 no.reverse();
312 let mut oe2 = ord[e2].clone();
313 oe2.reverse();
314 no.extend_from_slice(&oe2);
315 (nc, no)
316 }
317 (1, 1) => {
318 reduc(&mut d, cl[e1][0], cl[e1][1], cl[e2][1]);
319 reduc(&mut d, cl[e1][0], cl[e2][1], cl[e2][0]);
320 let nc = vec![cl[e1][0], cl[e2][0]];
321 let mut no = ord[e1].clone();
322 let mut oe2 = ord[e2].clone();
323 oe2.reverse();
324 no.extend_from_slice(&oe2);
325 (nc, no)
326 }
327 _ => unreachable!(),
328 },
329 _ => panic!(
330 "Unhandled cluster sizes in NeighborNet step: n1={}, n2={}",
331 n1, n2
332 ),
333 };
334
335 ord[e1] = new_ord;
336 cl[e1] = new_cl;
337
338 update_dm(&mut dm, &d, &cl, e1);
339 remove_e2(&mut cl, &mut ord, &mut dm, e2);
340 }
341 }
342
343 ord.into_iter().next().unwrap_or_default()
344}
345
346pub fn neighbor_net_ordering(x: &Matrix) -> NeighbourNetResult {
347 NeighbourNetResult {
348 ordering: get_ordering_nn(x),
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    use ndarray::array;

    /// Asserts that `ordering` contains each index in `0..n` exactly once.
    fn assert_is_permutation(ordering: &[usize], n: usize) {
        let mut sorted = ordering.to_vec();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..n).collect::<Vec<_>>());
    }

    #[test]
    fn smoke() {
        let d = array![
            [0.0, 5.0, 9.0, 9.0, 8.0],
            [5.0, 0.0, 10.0, 10.0, 9.0],
            [9.0, 10.0, 0.0, 8.0, 7.0],
            [9.0, 10.0, 8.0, 0.0, 3.0],
            [8.0, 9.0, 7.0, 3.0, 0.0],
        ];
        let res = neighbor_net_ordering(&d);
        assert_eq!(res.ordering.len(), 5);
        // A correct length alone would not catch duplicated or missing indices.
        assert_is_permutation(&res.ordering, 5);
    }

    #[test]
    fn trivial_inputs() {
        // No taxa: the merge loop never runs and the ordering is empty.
        let empty = Matrix::zeros((0, 0));
        assert!(neighbor_net_ordering(&empty).ordering.is_empty());

        // A single taxon is its own ordering.
        let single = array![[0.0]];
        assert_eq!(neighbor_net_ordering(&single).ordering, vec![0]);

        // Two taxa are merged directly as (0, 1).
        let pair = array![[0.0, 1.0], [1.0, 0.0]];
        assert_eq!(neighbor_net_ordering(&pair).ordering, vec![0, 1]);
    }
}