// fast_nnt/ordering/ordering_matrix.rs

1use ndarray::{Array1, Array2, Axis, s};
2use rayon::prelude::*;
3
/// Result of the NeighborNet ordering step.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct NeighbourNetResult {
    /// Circular order in terms of original taxa indices (0-based).
    pub ordering: Vec<usize>,
}
8
9type Matrix = Array2<f64>;
10
11fn mean_between_sets(d: &Matrix, a: &[usize], b: &[usize]) -> f64 {
12    let mut sum = 0.0;
13    let mut cnt = 0usize;
14    for &i in a {
15        for &j in b {
16            sum += unsafe { *d.uget((i, j)) };
17            cnt += 1;
18        }
19    }
20    if cnt == 0 { 0.0 } else { sum / cnt as f64 }
21}
22
23/// Update cluster distance matrix `dm` for row/col `j` given full `d` and cluster list `cl`.
24fn update_dm(dm: &mut Matrix, d: &Matrix, cl: &Vec<Vec<usize>>, j: usize) {
25    let l = cl.len();
26    // Compute all means to cluster j in parallel, then assign
27    let col_vals: Vec<f64> = (0..l)
28        .into_par_iter()
29        .map(|i| mean_between_sets(d, &cl[i], &cl[j]))
30        .collect();
31
32    // Assign row j
33    for i in 0..l {
34        dm[[i, j]] = col_vals[i];
35    }
36    // Mirror to column j
37    for i in 0..l {
38        dm[[j, i]] = col_vals[i];
39    }
40    dm[[j, j]] = 0.0;
41}
42
43/// Rx helper from the R code.
44fn rx(d: &Matrix, x: &[usize], cl: &Vec<Vec<usize>>) -> Vec<f64> {
45    let lx = x.len();
46    let mut res = vec![0.0; lx];
47    for (i, &xi) in x.iter().enumerate() {
48        let mut tmp = 0.0;
49        // sum to other x's
50        for (j, &xj) in x.iter().enumerate() {
51            if j != i {
52                tmp += unsafe { *d.uget((xi, xj)) };
53            }
54        }
55        // plus mean to each other cluster
56        for c in cl.iter() {
57            tmp += mean_between_sets(d, std::slice::from_ref(&xi), c);
58        }
59        res[i] = tmp;
60    }
61    res
62}
63
64/// Formula (1) reduction step; mutates `d` in place.
65fn reduc(d: &mut Matrix, x: usize, y: usize, z: usize) {
66    let n = d.nrows();
67    // capture rows before overwriting
68    let row_x = d.row(x).to_owned();
69    let row_y = d.row(y).to_owned();
70    let row_z = d.row(z).to_owned();
71
72    // u = 2/3 * row_x + 1/3 * row_y
73    // v = 2/3 * row_z + 1/3 * row_y
74    let u: Array1<f64> = &(&row_x * (2.0 / 3.0)) + &(&row_y * (1.0 / 3.0));
75    let v: Array1<f64> = &(&row_z * (2.0 / 3.0)) + &(&row_y * (1.0 / 3.0));
76
77    let uv = (row_x[y] + row_x[z] + row_y[z]) / 3.0;
78
79    // write back rows
80    d.row_mut(x).assign(&u);
81    d.row_mut(z).assign(&v);
82    d.row_mut(y).fill(0.0);
83
84    // symmetric columns
85    for j in 0..n {
86        d[[j, x]] = u[j];
87        d[[j, z]] = v[j];
88        d[[j, y]] = 0.0;
89    }
90
91    d[[x, z]] = uv;
92    d[[z, x]] = uv;
93    d[[x, x]] = 0.0;
94    d[[z, z]] = 0.0;
95}
96
97/// Remove row/col `idx` from a square matrix, returning a new (n-1)x(n-1) array.
98fn remove_row_col(m: &Matrix, idx: usize) -> Matrix {
99    let n = m.nrows();
100    debug_assert_eq!(n, m.ncols());
101    if n == 1 {
102        return Array2::zeros((0, 0));
103    }
104    let mut out = Array2::<f64>::zeros((n - 1, n - 1));
105    // Top-left block
106    if idx > 0 {
107        out.slice_mut(s![0..idx, 0..idx])
108            .assign(&m.slice(s![0..idx, 0..idx]));
109    }
110    // Top-right block
111    if idx + 1 < n {
112        out.slice_mut(s![0..idx, idx..])
113            .assign(&m.slice(s![0..idx, (idx + 1)..]));
114    }
115    // Bottom-left block
116    if idx + 1 < n {
117        out.slice_mut(s![idx.., 0..idx])
118            .assign(&m.slice(s![(idx + 1).., 0..idx]));
119    }
120    // Bottom-right block
121    if idx + 1 < n {
122        out.slice_mut(s![idx.., idx..])
123            .assign(&m.slice(s![(idx + 1).., (idx + 1)..]));
124    }
125    out
126}
127
128/// Choose (e1,e2) minimizing DM[i,j] - r[i] - r[j] (i<j), parallelized.
129fn choose_pair(dm: &Matrix, r: &[f64]) -> (usize, usize) {
130    let l = dm.nrows();
131    if l <= 1 {
132        return (0, 0);
133    }
134    // Each worker scans a band of i and returns (best_val, i, j),
135    // then we reduce to the global best.
136    let per_i = (0..l).into_par_iter().map(|i| {
137        let mut best = f64::INFINITY;
138        let mut best_j = i + 1;
139        for j in (i + 1)..l {
140            let q = unsafe { *dm.uget((i, j)) } - r[i] - r[j];
141            if q < best {
142                best = q;
143                best_j = j;
144            }
145        }
146        (best, i, best_j)
147    });
148
149    let (.., bi, bj) = per_i.reduce(
150        || (f64::INFINITY, 0usize, 1usize),
151        |a, b| if a.0 <= b.0 { a } else { b },
152    );
153    (bi, bj)
154}
155
156fn remove_e2(cl: &mut Vec<Vec<usize>>, ord: &mut Vec<Vec<usize>>, dm: &mut Matrix, e2: usize) {
157    cl.remove(e2);
158    ord.remove(e2);
159    *dm = remove_row_col(dm, e2);
160}
161
162/// The main ordering routine.
163pub fn get_ordering_nn(x: &Matrix) -> Vec<usize> {
164    assert_eq!(x.nrows(), x.ncols(), "Distance matrix must be square");
165    let n = x.nrows();
166
167    // Mutable working copy of D (modified by `reduc`)
168    let mut d = x.clone();
169
170    // Clusters & per-cluster linear orders
171    let mut cl: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();
172    let mut ord: Vec<Vec<usize>> = cl.clone();
173
174    // Cluster distance matrix (start with singleton distances)
175    let mut dm = d.clone();
176
177    while cl.len() > 1 {
178        let l = dm.nrows();
179        let (e1, e2) = if l > 2 {
180            // r = rowSums(DM) / (l - 2), parallel row sums
181            let denom = (l as f64) - 2.0;
182            let r: Vec<f64> = dm
183                .axis_iter(Axis(0))
184                .into_par_iter()
185                .map(|row| row.sum() / denom)
186                .collect();
187
188            choose_pair(&dm, &r)
189        } else {
190            (0, 1)
191        };
192
193        let n1 = cl[e1].len();
194        let n2 = cl[e2].len();
195
196        if n1 == 1 && n2 == 1 {
197            // Simple merge of two leaves
198            let mut new_cl = Vec::with_capacity(2);
199            new_cl.extend_from_slice(&cl[e1]);
200            new_cl.extend_from_slice(&cl[e2]);
201            let new_ord = new_cl.clone();
202
203            cl[e1] = new_cl;
204            ord[e1] = new_ord;
205
206            update_dm(&mut dm, &d, &cl, e1);
207            remove_e2(&mut cl, &mut ord, &mut dm, e2);
208        } else {
209            // Build "others" (all clusters except e1,e2)
210            let mut others: Vec<Vec<usize>> = Vec::with_capacity(cl.len().saturating_sub(2));
211            for (idx, c) in cl.iter().enumerate() {
212                if idx != e1 && idx != e2 {
213                    others.push(c.clone());
214                }
215            }
216
217            // cltmp2 = elements of CL[e1] followed by CL[e2]
218            let mut cltmp2: Vec<usize> = Vec::with_capacity(n1 + n2);
219            cltmp2.extend_from_slice(&cl[e1]);
220            cltmp2.extend_from_slice(&cl[e2]);
221
222            let mut rtmp2 = rx(&d, &cltmp2, &others);
223            let ltmp = cl[e1].len() + cl[e2].len() + others.len();
224            if ltmp > 2 {
225                let scale = 1.0 / ((ltmp as f64) - 2.0);
226                for v in rtmp2.iter_mut() {
227                    *v *= scale;
228                }
229            }
230
231            // DM3 = d[cltmp2, cltmp2] - (rtmp2[i] + rtmp2[j])
232            // We only need the cross block rows 0..n1-1, cols n1..n1+n2-1
233            let mut best_val = f64::INFINITY;
234            let mut best_row = 0usize; // 0..n1-1
235            let mut best_col = 0usize; // 0..n2-1
236            for col in 0..n2 {
237                for row in 0..n1 {
238                    let i = cltmp2[row];
239                    let j = cltmp2[n1 + col];
240                    let v = unsafe { *d.uget((i, j)) } - (rtmp2[row] + rtmp2[n1 + col]);
241                    if v < best_val {
242                        best_val = v;
243                        best_row = row;
244                        best_col = col;
245                    }
246                }
247            }
248
249            // Cases with cluster sizes from {1,2}
250            let (new_cl, new_ord) = match (n1, n2) {
251                (2, 1) => {
252                    if best_row == 1 {
253                        // blub == 2
254                        reduc(&mut d, cl[e1][0], cl[e1][1], cl[e2][0]);
255                        let nc = vec![cl[e1][0], cl[e2][0]];
256                        let mut no = ord[e1].clone();
257                        no.extend_from_slice(&ord[e2]);
258                        (nc, no)
259                    } else {
260                        // else
261                        reduc(&mut d, cl[e2][0], cl[e1][0], cl[e1][1]);
262                        let nc = vec![cl[e2][0], cl[e1][1]];
263                        let mut no = ord[e2].clone();
264                        no.extend_from_slice(&ord[e1]);
265                        (nc, no)
266                    }
267                }
268                (1, 2) => {
269                    if best_col == 0 {
270                        // blub == 1
271                        reduc(&mut d, cl[e1][0], cl[e2][0], cl[e2][1]);
272                        let nc = vec![cl[e1][0], cl[e2][1]];
273                        let mut no = ord[e1].clone();
274                        no.extend_from_slice(&ord[e2]);
275                        (nc, no)
276                    } else {
277                        // else
278                        reduc(&mut d, cl[e2][0], cl[e2][1], cl[e1][0]);
279                        let nc = vec![cl[e2][0], cl[e1][0]];
280                        let mut no = ord[e2].clone();
281                        no.extend_from_slice(&ord[e1]);
282                        (nc, no)
283                    }
284                }
285                (2, 2) => match (best_row, best_col) {
286                    (0, 0) => {
287                        // blub == 1
288                        reduc(&mut d, cl[e1][1], cl[e1][0], cl[e2][0]);
289                        reduc(&mut d, cl[e1][1], cl[e2][0], cl[e2][1]);
290                        let nc = vec![cl[e1][1], cl[e2][1]];
291                        let mut no = ord[e1].clone();
292                        no.reverse();
293                        no.extend_from_slice(&ord[e2]);
294                        (nc, no)
295                    }
296                    (1, 0) => {
297                        // blub == 2
298                        reduc(&mut d, cl[e1][0], cl[e1][1], cl[e2][0]);
299                        reduc(&mut d, cl[e1][0], cl[e2][0], cl[e2][1]);
300                        let nc = vec![cl[e1][0], cl[e2][1]];
301                        let mut no = ord[e1].clone();
302                        no.extend_from_slice(&ord[e2]);
303                        (nc, no)
304                    }
305                    (0, 1) => {
306                        // blub == 3
307                        reduc(&mut d, cl[e1][1], cl[e1][0], cl[e2][1]);
308                        reduc(&mut d, cl[e1][1], cl[e2][1], cl[e2][0]);
309                        let nc = vec![cl[e1][1], cl[e2][0]];
310                        let mut no = ord[e1].clone();
311                        no.reverse();
312                        let mut oe2 = ord[e2].clone();
313                        oe2.reverse();
314                        no.extend_from_slice(&oe2);
315                        (nc, no)
316                    }
317                    (1, 1) => {
318                        reduc(&mut d, cl[e1][0], cl[e1][1], cl[e2][1]);
319                        reduc(&mut d, cl[e1][0], cl[e2][1], cl[e2][0]);
320                        let nc = vec![cl[e1][0], cl[e2][0]];
321                        let mut no = ord[e1].clone();
322                        let mut oe2 = ord[e2].clone();
323                        oe2.reverse();
324                        no.extend_from_slice(&oe2);
325                        (nc, no)
326                    }
327                    _ => unreachable!(),
328                },
329                _ => panic!(
330                    "Unhandled cluster sizes in NeighborNet step: n1={}, n2={}",
331                    n1, n2
332                ),
333            };
334
335            ord[e1] = new_ord;
336            cl[e1] = new_cl;
337
338            update_dm(&mut dm, &d, &cl, e1);
339            remove_e2(&mut cl, &mut ord, &mut dm, e2);
340        }
341    }
342
343    ord.into_iter().next().unwrap_or_default()
344}
345
346pub fn neighbor_net_ordering(x: &Matrix) -> NeighbourNetResult {
347    NeighbourNetResult {
348        ordering: get_ordering_nn(x),
349    }
350}
351
// --- Example usage & quick tests ---
#[cfg(test)]
mod tests {
    use super::*;

    use ndarray::array;

    /// Assert that `ordering` visits each taxon in `0..n` exactly once.
    fn assert_is_permutation(ordering: &[usize], n: usize) {
        let mut sorted = ordering.to_vec();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..n).collect::<Vec<_>>());
    }

    #[test]
    fn smoke() {
        let d = array![
            [0.0, 5.0, 9.0, 9.0, 8.0],
            [5.0, 0.0, 10.0, 10.0, 9.0],
            [9.0, 10.0, 0.0, 8.0, 7.0],
            [9.0, 10.0, 8.0, 0.0, 3.0],
            [8.0, 9.0, 7.0, 3.0, 0.0],
        ];
        let res = neighbor_net_ordering(&d);
        assert_eq!(res.ordering.len(), 5);
        assert_is_permutation(&res.ordering, 5);
    }

    #[test]
    fn trivial_sizes() {
        assert!(get_ordering_nn(&Matrix::zeros((0, 0))).is_empty());
        assert_eq!(get_ordering_nn(&Matrix::zeros((1, 1))), vec![0]);
        assert_eq!(get_ordering_nn(&array![[0.0, 1.0], [1.0, 0.0]]), vec![0, 1]);
    }

    #[test]
    fn three_taxa_regression() {
        // All selection scores tie exactly, so tie-breaking fully determines the result.
        let d = array![[0.0, 2.0, 3.0], [2.0, 0.0, 4.0], [3.0, 4.0, 0.0]];
        let ordering = get_ordering_nn(&d);
        assert_is_permutation(&ordering, 3);
        assert_eq!(ordering, vec![2, 0, 1]);
    }

    #[test]
    #[should_panic(expected = "Distance matrix must be square")]
    fn rejects_non_square() {
        get_ordering_nn(&Matrix::zeros((2, 3)));
    }
}