1use std::io::{BufRead, Write};
2
3use rayon::prelude::*;
4use rsomics_common::{Result, RsomicsError};
5
6mod dm;
7mod env;
8
9pub use dm::DistanceMatrix;
10pub use env::{EnvTable, Standardized};
11
/// Best-correlated subset of environmental variables for one subset size.
#[derive(Debug, Clone, PartialEq)]
pub struct SizeBest {
    /// Number of variables in the subset.
    pub size: usize,
    /// Spearman rank correlation between the distance matrix and the Euclidean
    /// distances computed over `vars`.
    pub correlation: f64,
    /// Names of the variables in the subset, in the order of the environment
    /// table's variables.
    pub vars: Vec<String>,
}
17
18pub fn bioenv(dm_flat: &[f64], env: &Standardized) -> Vec<SizeBest> {
22 let p = env.vars.len();
23 let n = env.n;
24 let m = dm_flat.len();
25
26 let dm_ranks = rankdata(dm_flat);
30 let dm_centered = centered_unit(&dm_ranks);
31
32 (1..=p)
33 .map(|size| {
34 let best = (0..count_combinations(p, size))
35 .into_par_iter()
36 .map(|combo_idx| {
37 let subset = nth_combination(p, size, combo_idx);
38 let dists = euclidean_condensed(&env.cols, n, m, &subset);
39 let var_ranks = rankdata(&dists);
40 let rho = spearman_against(&var_ranks, &dm_centered);
41 Candidate { rho, combo_idx }
42 })
43 .reduce(
44 || Candidate {
45 rho: f64::NEG_INFINITY,
46 combo_idx: usize::MAX,
47 },
48 Candidate::better,
49 );
50 let subset = nth_combination(p, size, best.combo_idx);
51 SizeBest {
52 size,
53 correlation: best.rho,
54 vars: subset.iter().map(|&i| env.vars[i].clone()).collect(),
55 }
56 })
57 .collect()
58}
59
/// One evaluated variable subset: its correlation and its lexicographic
/// combination index (as decoded by `nth_combination`).
#[derive(Clone, Copy)]
struct Candidate {
    rho: f64,
    combo_idx: usize,
}

impl Candidate {
    /// Returns the preferred of two candidates.
    ///
    /// A strictly higher `rho` wins. Equal `rho` goes to the smaller `combo_idx`,
    /// which makes the choice independent of evaluation order. If the two values
    /// are incomparable (either is NaN), `a` is kept.
    fn better(a: Candidate, b: Candidate) -> Candidate {
        use std::cmp::Ordering;
        let prefer_b = match b.rho.partial_cmp(&a.rho) {
            Some(Ordering::Greater) => true,
            Some(Ordering::Equal) => b.combo_idx < a.combo_idx,
            Some(Ordering::Less) | None => false,
        };
        if prefer_b {
            b
        } else {
            a
        }
    }
}
77
/// Euclidean distances between the `n` samples over the variables in `subset`,
/// returned in condensed form: pairs `(a, b)` with `a < b`, in row-major
/// upper-triangle order.
///
/// `cols` stores each variable contiguously: sample `s` of variable `c` is
/// `cols[c * n + s]`. `m` is used only as the output's initial capacity.
fn euclidean_condensed(cols: &[f64], n: usize, m: usize, subset: &[usize]) -> Vec<f64> {
    let mut dists = Vec::with_capacity(m);
    for a in 0..n {
        dists.extend(((a + 1)..n).map(|b| {
            subset
                .iter()
                .map(|&var| {
                    let diff = cols[var * n + a] - cols[var * n + b];
                    diff * diff
                })
                // Left fold from 0.0 keeps the exact summation order.
                .fold(0.0, |acc, sq| acc + sq)
                .sqrt()
        }));
    }
    dists
}
95
/// Correlation of `var_ranks` with `dm_centered`.
///
/// `dm_centered` is expected to be already mean-centred and unit-norm (as
/// `centered_unit` produces). Only `var_ranks` is centred and normalised here.
/// Since both inputs are rank vectors, the result is Spearman's rho. Returns NaN
/// when `var_ranks` has zero spread (all ranks equal).
fn spearman_against(var_ranks: &[f64], dm_centered: &[f64]) -> f64 {
    let count = var_ranks.len() as f64;
    let mean = var_ranks.iter().sum::<f64>() / count;
    let (cross, spread_sq) = var_ranks.iter().zip(dm_centered).fold(
        (0.0f64, 0.0f64),
        |(cross, spread_sq), (&rank, &other)| {
            let dev = rank - mean;
            (cross + dev * other, spread_sq + dev * dev)
        },
    );
    if spread_sq == 0.0 {
        f64::NAN
    } else {
        cross / spread_sq.sqrt()
    }
}
113
/// Subtracts the mean from `v` and scales the result to unit Euclidean length.
///
/// A constant input has zero length, so every output element becomes NaN (0/0).
fn centered_unit(v: &[f64]) -> Vec<f64> {
    let mean = v.iter().sum::<f64>() / v.len() as f64;
    let deviations: Vec<f64> = v.iter().map(|&x| x - mean).collect();
    let length = deviations.iter().map(|&d| d * d).sum::<f64>().sqrt();
    deviations.into_iter().map(|d| d / length).collect()
}
125
/// Ranks `v` in ascending order (1-based), giving tied values the average of the
/// ranks they span, as SciPy's `rankdata(method="average")` does.
///
/// If any element is NaN the ranking is undefined, so every returned rank is NaN
/// (the same as SciPy's `nan_policy="propagate"`). Before this change, a NaN here
/// panicked inside the sort comparator.
fn rankdata(v: &[f64]) -> Vec<f64> {
    if v.iter().any(|x| x.is_nan()) {
        return vec![f64::NAN; v.len()];
    }

    let mut order: Vec<usize> = (0..v.len()).collect();
    // With NaN excluded, `total_cmp` agrees with `<` except that -0.0 sorts just
    // before 0.0. They stay adjacent, so the `==` tie scan below still groups them.
    // A stable sort is unnecessary because tied entries share one rank.
    order.sort_unstable_by(|&a, &b| v[a].total_cmp(&v[b]));

    let mut ranks = vec![0.0f64; v.len()];
    let mut start = 0;
    while start < order.len() {
        let pivot = v[order[start]];
        let end = start + order[start..].iter().take_while(|&&i| v[i] == pivot).count();
        // Sorted positions start..end carry 1-based ranks start+1..=end, whose
        // mean is (start + 1 + end) / 2.
        let avg = (start + 1 + end) as f64 / 2.0;
        for &idx in &order[start..end] {
            ranks[idx] = avg;
        }
        start = end;
    }
    ranks
}
145
/// Binomial coefficient C(n, k): the number of `k`-element subsets of `n` items.
/// Returns 0 when `k > n`.
fn count_combinations(n: usize, k: usize) -> usize {
    if k > n {
        0
    } else {
        // C(n, k) == C(n, n - k); loop over the smaller side. After step `i` the
        // accumulator equals C(n, i + 1), so every division is exact.
        (0..k.min(n - k)).fold(1usize, |acc, i| acc * (n - i) / (i + 1))
    }
}
157
158fn nth_combination(n: usize, k: usize, mut idx: usize) -> Vec<usize> {
160 let mut out = Vec::with_capacity(k);
161 let mut c = 0usize;
162 let mut remaining = k;
163 while remaining > 0 {
164 let block = count_combinations(n - c - 1, remaining - 1);
165 if idx < block {
166 out.push(c);
167 remaining -= 1;
168 } else {
169 idx -= block;
170 }
171 c += 1;
172 }
173 out
174}
175
176pub fn write_result<W: Write>(out: &mut W, best: &[SizeBest]) -> Result<()> {
177 writeln!(out, "size\tcorrelation\tvars").map_err(RsomicsError::Io)?;
178 for b in best {
179 writeln!(
180 out,
181 "{}\t{:.12}\t{}",
182 b.size,
183 b.correlation,
184 b.vars.join(", ")
185 )
186 .map_err(RsomicsError::Io)?;
187 }
188 Ok(())
189}
190
191pub fn read_matrix<R: BufRead>(reader: R, source: &str) -> Result<DistanceMatrix> {
192 DistanceMatrix::read(reader, source)
193}
194
195pub fn read_env<R: BufRead>(reader: R, source: &str) -> Result<EnvTable> {
196 EnvTable::read(reader, source)
197}
198
199pub fn run_bioenv(
201 dm: &DistanceMatrix,
202 env: &EnvTable,
203 columns: Option<&[String]>,
204 env_source: &str,
205) -> Result<Vec<SizeBest>> {
206 let selected = env.select(columns, env_source)?;
207 let std = selected.standardized_for(&dm.ids, env_source)?;
208 Ok(bioenv(&dm.condensed(), &std))
209}
210
// Unit tests for the BIO-ENV pipeline and its ranking/combination helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    // Four-sample distance matrix: tab-separated, with a header row of sample ids
    // and each row starting with its id.
    const DM: &str = "\tA\tB\tC\tD\n\
                      A\t0.0\t0.5\t0.25\t0.75\n\
                      B\t0.5\t0.0\t0.1\t0.42\n\
                      C\t0.25\t0.1\t0.0\t0.33\n\
                      D\t0.75\t0.42\t0.33\t0.0\n";

    // Environment table for the same samples: an `id` column followed by two
    // numeric variables.
    const ENV: &str = "id\tpH\tElevation\n\
                       A\t7.0\t400\n\
                       B\t8.0\t530\n\
                       C\t7.5\t450\n\
                       D\t8.5\t810\n";

    // End-to-end check against reference values (per the test name, scikit-bio's
    // `bioenv` documentation example). The best single variable is pH and the
    // best pair is pH + Elevation.
    #[test]
    fn skbio_doc_example() {
        let dm = read_matrix(Cursor::new(DM), "dm").unwrap();
        let env = read_env(Cursor::new(ENV), "env").unwrap();
        let best = run_bioenv(&dm, &env, None, "env").unwrap();
        assert_eq!(best.len(), 2);
        assert_eq!(best[0].vars, vec!["pH"]);
        assert!(
            (best[0].correlation - 0.771517).abs() < 1e-6,
            "{}",
            best[0].correlation
        );
        assert_eq!(best[1].vars, vec!["pH", "Elevation"]);
        assert!(
            (best[1].correlation - 0.714286).abs() < 1e-6,
            "{}",
            best[1].correlation
        );
    }

    // Tied values share the mean of the ranks they occupy.
    #[test]
    fn rankdata_ties_averaged() {
        assert_eq!(rankdata(&[1.0, 2.0, 2.0, 3.0]), vec![1.0, 2.5, 2.5, 4.0]);
    }

    // Combination indices 0..C(4, 2) decode to the 2-subsets of 0..4 in
    // lexicographic order; the tie-break in the search relies on this ordering.
    #[test]
    fn nth_combination_is_lexicographic() {
        let all: Vec<_> = (0..count_combinations(4, 2))
            .map(|i| nth_combination(4, 2, i))
            .collect();
        assert_eq!(
            all,
            vec![
                vec![0, 1],
                vec![0, 2],
                vec![0, 3],
                vec![1, 2],
                vec![1, 3],
                vec![2, 3],
            ]
        );
    }

    // `condensed()` yields the upper triangle row by row:
    // A-B, A-C, A-D, B-C, B-D, C-D.
    #[test]
    fn condensed_upper_triangle() {
        let dm = read_matrix(Cursor::new(DM), "dm").unwrap();
        assert_eq!(dm.condensed(), vec![0.5, 0.25, 0.75, 0.1, 0.42, 0.33]);
    }
}