// fast_nnt/utils.rs

1use crate::splits::asplit::ASplit; // adjust path if needed
2use fixedbitset::FixedBitSet;
3use ndarray::Array2;
4use rayon::prelude::*;
5
6/// Compute the least-squares fit (%) of the given splits to the distances.
7/// - `distances`: symmetric n×n (0-based indices)
8/// - `splits`: list of ASplit with 1-based bitsets (bit 0 ignored)
9///
10/// Returns a percentage in [0, 100]. Matches Java behavior:
11///   100 * (1 - sum((s_ij - d_ij)^2) / sum(d_ij^2))  over i<j, 1-based.
12pub fn compute_least_squares_fit(distances: &Array2<f64>, splits: &[ASplit]) -> f32 {
13    let n = distances.nrows();
14    assert_eq!(n, distances.ncols(), "distances must be square");
15
16    // Build the split-induced distance matrix S (0-based) in parallel:
17    // S[i,j] = Σ_{splits} weight * [i∈A && j∈B or i∈B && j∈A]
18    let split_dist = splits
19        .par_iter()
20        .map(|s| {
21            let mut m = Array2::<f64>::zeros((n, n));
22            let w = s.get_weight();
23            let a: &FixedBitSet = s.get_a();
24            let b: &FixedBitSet = s.get_b();
25
26            for i1 in a.ones() {
27                if i1 == 0 || i1 > n {
28                    continue;
29                }
30                let ii = i1 - 1; // 0-based
31                for j1 in b.ones() {
32                    if j1 == 0 || j1 > n {
33                        continue;
34                    }
35                    let jj = j1 - 1;
36                    m[[ii, jj]] += w;
37                    m[[jj, ii]] += w; // symmetric
38                }
39            }
40            m
41        })
42        .reduce(
43            || Array2::<f64>::zeros((n, n)),
44            |mut acc, m| {
45                acc.zip_mut_with(&m, |a, b| *a += *b);
46                acc
47            },
48        );
49
50    // Sum over the upper triangle (i<j) in parallel
51    let (sum_diff_sq, sum_d_sq) = (0..n - 1)
52        .into_par_iter()
53        .map(|i| {
54            let mut diff_sum = 0.0;
55            let mut d_sum = 0.0;
56            for j in (i + 1)..n {
57                let sij = split_dist[[i, j]];
58                let dij = distances[[i, j]];
59                let diff = sij - dij;
60                diff_sum += diff * diff;
61                d_sum += dij * dij;
62            }
63            (diff_sum, d_sum)
64        })
65        .reduce(|| (0.0, 0.0), |a, b| (a.0 + b.0, a.1 + b.1));
66
67    let fit = if sum_d_sq > 0.0 {
68        100.0 * (1.0 - (sum_diff_sq / sum_d_sq))
69    } else {
70        0.0
71    };
72
73    fit as f32
74}
75
#[cfg(test)]
mod lsq_tests {
    use crate::weights::active_set_weights::{NNLSParams, compute_asplits};

    use super::*;
    use ndarray::{Array2, arr2};

    /// Build a bitset of `len + 1` bits with the given 1-based taxon ids set
    /// (bit 0 is left unused).
    fn bs_from(indices: &[usize], len: usize) -> FixedBitSet {
        // `with_capacity` already sizes the set to `len + 1` cleared bits,
        // so no extra `grow` is needed.
        let mut bs = FixedBitSet::with_capacity(len + 1);
        for &i in indices {
            bs.set(i, true); // 1-based
        }
        bs
    }

    /// Build a symmetric distance matrix from a set of splits (same logic as the
    /// evaluator, written sequentially so it serves as an independent oracle).
    fn distances_from_splits(n: usize, splits: &[ASplit]) -> Array2<f64> {
        let mut d = Array2::<f64>::zeros((n, n));
        for s in splits {
            let w = s.get_weight();
            let a = s.get_a();
            let b = s.get_b();
            for i1 in a.ones() {
                if i1 == 0 || i1 > n {
                    continue;
                }
                let ii = i1 - 1;
                for j1 in b.ones() {
                    if j1 == 0 || j1 > n {
                        continue;
                    }
                    let jj = j1 - 1;
                    d[[ii, jj]] += w;
                    d[[jj, ii]] += w;
                }
            }
        }
        d
    }

    #[test]
    fn lsq_perfect_fit() {
        let n = 6;

        // Construct a few splits on 1-based taxa
        let s1 = ASplit::from_a_ntax_with_weight(bs_from(&[1, 2], n), n, 1.0);
        let s2 = ASplit::from_a_ntax_with_weight(bs_from(&[2, 3, 4], n), n, 0.7);
        let s3 = ASplit::from_a_ntax_with_weight(bs_from(&[5], n), n, 0.4); // trivial

        let splits = vec![s1, s2, s3];
        let d = distances_from_splits(n, &splits);

        let fit = compute_least_squares_fit(&d, &splits);
        assert!((fit - 100.0).abs() < 1e-6, "fit was {fit}");
    }

    #[test]
    fn lsq_imperfect_fit_with_noise() {
        let n = 6;
        let s1 = ASplit::from_a_ntax_with_weight(bs_from(&[1, 2], n), n, 1.0);
        let s2 = ASplit::from_a_ntax_with_weight(bs_from(&[2, 3, 4], n), n, 0.7);
        let s3 = ASplit::from_a_ntax_with_weight(bs_from(&[5], n), n, 0.4);

        let splits = vec![s1, s2, s3];
        let mut d = distances_from_splits(n, &splits);

        // Add small noise on the upper triangle and mirror it to keep d symmetric.
        for i in 0..n {
            for j in (i + 1)..n {
                d[[i, j]] += 0.01 * ((i + j) as f64);
                d[[j, i]] = d[[i, j]];
            }
        }

        let fit = compute_least_squares_fit(&d, &splits);
        assert!(
            fit < 100.0 && fit > 0.0,
            "fit should drop below 100, got {fit}"
        );
    }

    #[test]
    fn smoke_10_1() {
        let d = arr2(&[
            [0.0, 5.0, 12.0, 7.0, 3.0, 9.0, 11.0, 6.0, 4.0, 10.0],
            [5.0, 0.0, 8.0, 2.0, 14.0, 5.0, 13.0, 7.0, 12.0, 1.0],
            [12.0, 8.0, 0.0, 4.0, 9.0, 3.0, 8.0, 2.0, 5.0, 6.0],
            [7.0, 2.0, 4.0, 0.0, 11.0, 7.0, 10.0, 4.0, 6.0, 9.0],
            [3.0, 14.0, 9.0, 11.0, 0.0, 8.0, 1.0, 13.0, 2.0, 7.0],
            [9.0, 5.0, 3.0, 7.0, 8.0, 0.0, 12.0, 5.0, 3.0, 4.0],
            [11.0, 13.0, 8.0, 10.0, 1.0, 12.0, 0.0, 6.0, 2.0, 8.0],
            [6.0, 7.0, 2.0, 4.0, 13.0, 5.0, 6.0, 0.0, 9.0, 7.0],
            [4.0, 12.0, 5.0, 6.0, 2.0, 3.0, 2.0, 9.0, 0.0, 5.0],
            [10.0, 1.0, 6.0, 9.0, 7.0, 4.0, 8.0, 7.0, 5.0, 0.0],
        ]);

        // 11 entries: a leading 0 followed by a permutation of taxa 1..=10
        // (1-based convention).
        let ord = vec![0, 1, 5, 7, 9, 3, 8, 4, 2, 10, 6];
        let mut params = NNLSParams::default();

        let splits = compute_asplits(&ord, &d, &mut params, None).expect("ASplits solved");
        let fit = compute_least_squares_fit(&d, &splits);

        println!("Fit: {}", fit);

        // Regression value for this example (not a perfect fit). Compared with a
        // tolerance because the parallel reductions do not fix summation order.
        let expected = 93.477_936_f32;
        assert!(
            (fit - expected).abs() < 1e-4,
            "expected fit ≈ {expected}, got {fit}"
        );
    }
}