1use crate::splits::asplit::ASplit; use fixedbitset::FixedBitSet;
3use ndarray::Array2;
4use rayon::prelude::*;
5
6pub fn compute_least_squares_fit(distances: &Array2<f64>, splits: &[ASplit]) -> f32 {
13 let n = distances.nrows();
14 assert_eq!(n, distances.ncols(), "distances must be square");
15
16 let split_dist = splits
19 .par_iter()
20 .map(|s| {
21 let mut m = Array2::<f64>::zeros((n, n));
22 let w = s.get_weight();
23 let a: &FixedBitSet = s.get_a();
24 let b: &FixedBitSet = s.get_b();
25
26 for i1 in a.ones() {
27 if i1 == 0 || i1 > n {
28 continue;
29 }
30 let ii = i1 - 1; for j1 in b.ones() {
32 if j1 == 0 || j1 > n {
33 continue;
34 }
35 let jj = j1 - 1;
36 m[[ii, jj]] += w;
37 m[[jj, ii]] += w; }
39 }
40 m
41 })
42 .reduce(
43 || Array2::<f64>::zeros((n, n)),
44 |mut acc, m| {
45 acc.zip_mut_with(&m, |a, b| *a += *b);
46 acc
47 },
48 );
49
50 let (sum_diff_sq, sum_d_sq) = (0..n - 1)
52 .into_par_iter()
53 .map(|i| {
54 let mut diff_sum = 0.0;
55 let mut d_sum = 0.0;
56 for j in (i + 1)..n {
57 let sij = split_dist[[i, j]];
58 let dij = distances[[i, j]];
59 let diff = sij - dij;
60 diff_sum += diff * diff;
61 d_sum += dij * dij;
62 }
63 (diff_sum, d_sum)
64 })
65 .reduce(|| (0.0, 0.0), |a, b| (a.0 + b.0, a.1 + b.1));
66
67 let fit = if sum_d_sq > 0.0 {
68 100.0 * (1.0 - (sum_diff_sq / sum_d_sq))
69 } else {
70 0.0
71 };
72
73 fit as f32
74}
75
#[cfg(test)]
mod lsq_tests {
    use crate::weights::active_set_weights::{NNLSParams, compute_asplits};

    use super::*;
    use ndarray::{Array2, arr2};

    /// Builds a bit set holding `len + 1` bits (so indices `0..=len` are
    /// addressable) with exactly the bits in `indices` set.
    fn bs_from(indices: &[usize], len: usize) -> FixedBitSet {
        // `with_capacity` already yields `len + 1` cleared bits, so no
        // additional `grow` call is needed.
        let mut bs = FixedBitSet::with_capacity(len + 1);
        for &i in indices {
            bs.set(i, true);
        }
        bs
    }

    /// Sequential reference computation of the split-induced distance matrix:
    /// every split adds its weight to each (a-side, b-side) pair, symmetrically.
    /// Bit `k` maps to row/column `k - 1`; bit `0` and bits above `n` are skipped.
    fn distances_from_splits(n: usize, splits: &[ASplit]) -> Array2<f64> {
        let mut d = Array2::<f64>::zeros((n, n));
        for s in splits {
            let w = s.get_weight();
            for i in s.get_a().ones().filter(|&i| (1..=n).contains(&i)) {
                for j in s.get_b().ones().filter(|&j| (1..=n).contains(&j)) {
                    d[[i - 1, j - 1]] += w;
                    d[[j - 1, i - 1]] += w;
                }
            }
        }
        d
    }

    /// The three weighted splits shared by the synthetic-fit tests.
    fn sample_splits(n: usize) -> Vec<ASplit> {
        vec![
            ASplit::from_a_ntax_with_weight(bs_from(&[1, 2], n), n, 1.0),
            ASplit::from_a_ntax_with_weight(bs_from(&[2, 3, 4], n), n, 0.7),
            ASplit::from_a_ntax_with_weight(bs_from(&[5], n), n, 0.4),
        ]
    }

    /// Distances generated from the splits themselves must be reproduced
    /// exactly, i.e. the fit is 100%.
    #[test]
    fn lsq_perfect_fit() {
        let n = 6;
        let splits = sample_splits(n);
        let d = distances_from_splits(n, &splits);

        let fit = compute_least_squares_fit(&d, &splits);
        assert!((fit - 100.0).abs() < 1e-6, "fit was {fit}");
    }

    /// Perturbing the split-induced distances must lower the fit below 100%
    /// while keeping it positive.
    #[test]
    fn lsq_imperfect_fit_with_noise() {
        let n = 6;
        let splits = sample_splits(n);
        let mut d = distances_from_splits(n, &splits);

        // Add a small deterministic perturbation, keeping the matrix symmetric.
        for i in 0..n {
            for j in (i + 1)..n {
                d[[i, j]] += 0.01 * ((i + j) as f64);
                d[[j, i]] = d[[i, j]];
            }
        }

        let fit = compute_least_squares_fit(&d, &splits);
        assert!(
            fit < 100.0 && fit > 0.0,
            "fit should drop below 100, got {fit}"
        );
    }

    /// Pins the fit obtained for a fixed 10×10 distance matrix and ordering
    /// when the splits come from `compute_asplits`.
    #[test]
    fn smoke_10_1() {
        let d = arr2(&[
            [0.0, 5.0, 12.0, 7.0, 3.0, 9.0, 11.0, 6.0, 4.0, 10.0],
            [5.0, 0.0, 8.0, 2.0, 14.0, 5.0, 13.0, 7.0, 12.0, 1.0],
            [12.0, 8.0, 0.0, 4.0, 9.0, 3.0, 8.0, 2.0, 5.0, 6.0],
            [7.0, 2.0, 4.0, 0.0, 11.0, 7.0, 10.0, 4.0, 6.0, 9.0],
            [3.0, 14.0, 9.0, 11.0, 0.0, 8.0, 1.0, 13.0, 2.0, 7.0],
            [9.0, 5.0, 3.0, 7.0, 8.0, 0.0, 12.0, 5.0, 3.0, 4.0],
            [11.0, 13.0, 8.0, 10.0, 1.0, 12.0, 0.0, 6.0, 2.0, 8.0],
            [6.0, 7.0, 2.0, 4.0, 13.0, 5.0, 6.0, 0.0, 9.0, 7.0],
            [4.0, 12.0, 5.0, 6.0, 2.0, 3.0, 2.0, 9.0, 0.0, 5.0],
            [10.0, 1.0, 6.0, 9.0, 7.0, 4.0, 8.0, 7.0, 5.0, 0.0],
        ]);

        // NOTE(review): the leading 0 looks like a placeholder for 1-based
        // ordering entries; confirm against `compute_asplits`.
        let ord = vec![0, 1, 5, 7, 9, 3, 8, 4, 2, 10, 6];
        let mut params = NNLSParams::default();

        let splits = compute_asplits(&ord, &d, &mut params, None).expect("ASplits solved");
        let fit = compute_least_squares_fit(&d, &splits);

        // The fit is produced by a parallel floating-point reduction, so
        // compare with a tolerance instead of demanding bit-exact equality.
        let expected = 93.477936_f32;
        assert!(
            (fit - expected).abs() < 1e-4,
            "expected a fit of about {expected}, got {fit}"
        );
    }
}