Skip to main content

rsomics_gradient_trajectory/
gradient.rs

1use std::collections::BTreeMap;
2
3use rsomics_common::{Result, RsomicsError};
4
5use crate::anova::f_oneway;
6use crate::io::{Coords, Metadata};
7use crate::natsort::realsorted_by;
8
/// Per-group trajectory statistic, one per scikit-bio `GradientANOVA` subclass.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Algorithm {
    /// Norms of the differences between consecutive samples. skbio `trajectory`.
    Trajectory,
    /// Norm of each sample's offset from the group centroid. skbio `avg`.
    Average,
    /// Differences between consecutive step norms. skbio `diff`.
    FirstDifference,
    /// Windowed variant of the first difference. skbio `wdiff`.
    WindowDifference,
}

impl Algorithm {
    /// The scikit-bio identifier for this algorithm, as reported in
    /// `GradientResult::algorithm`.
    pub fn name(self) -> &'static str {
        match self {
            Self::Average => "avg",
            Self::FirstDifference => "diff",
            Self::Trajectory => "trajectory",
            Self::WindowDifference => "wdiff",
        }
    }
}
31
/// Trajectory computed for one group, i.e. the samples sharing one value of a
/// trajectory category.
#[derive(Debug, Clone, PartialEq)]
pub struct GroupResult {
    /// The category value shared by every sample in the group.
    pub name: String,
    /// Values produced by the selected [`Algorithm`] for this group.
    pub trajectory: Vec<f64>,
    /// Arithmetic mean of `trajectory`.
    pub mean: f64,
}

/// ANOVA outcome for one trajectory category.
///
/// Either the ANOVA ran (`probability` and `groups` are `Some`, `message` is
/// `None`), or it was skipped and `message` explains why.
#[derive(Debug, Clone, PartialEq)]
pub struct CategoryResult {
    /// Metadata column the samples were grouped by.
    pub category: String,
    /// p-value of the one-way ANOVA across the group trajectories.
    pub probability: Option<f64>,
    /// Per-group trajectories that entered the ANOVA.
    pub groups: Option<Vec<GroupResult>>,
    /// Reason the ANOVA could not be run.
    pub message: Option<String>,
}

/// Complete output of `gradient_anova`.
#[derive(Debug, Clone, PartialEq)]
pub struct GradientResult {
    /// scikit-bio name of the algorithm used (see [`Algorithm::name`]).
    pub algorithm: &'static str,
    /// Whether trajectories were weighted by the sort category.
    pub weighted: bool,
    /// One entry per requested trajectory category, in request order.
    pub categories: Vec<CategoryResult>,
}
50
51pub struct Params<'a> {
52    pub algorithm: Algorithm,
53    pub trajectory_categories: &'a [String],
54    pub sort_category: Option<&'a str>,
55    pub axes: usize,
56    pub weighted: bool,
57    pub window_size: usize,
58}
59
60/// QIIME-style gradient/trajectory ANOVA over a precomputed ordination, the
61/// scikit-bio `GradientANOVA` family. For each requested metadata category it
62/// builds per-group trajectories through the ordination space and runs one-way
63/// ANOVA across the groups.
64///
65/// # Errors
66/// Unknown category or sort category, `axes` out of `1..=prop.len()`, weighting
67/// without a numeric sort category, or no samples shared by coords and metadata.
68pub fn gradient_anova(
69    coords: &Coords,
70    prop: &[f64],
71    meta: &Metadata,
72    params: &Params,
73) -> Result<GradientResult> {
74    if params.axes == 0 || params.axes > prop.len() {
75        return Err(RsomicsError::InvalidInput(format!(
76            "axes must be between 1 and the number of proportions ({}), got {}",
77            prop.len(),
78            params.axes
79        )));
80    }
81    if coords.naxes < params.axes {
82        return Err(RsomicsError::InvalidInput(format!(
83            "coordinates have {} axes, fewer than the requested {}",
84            coords.naxes, params.axes
85        )));
86    }
87    let sort_col = params
88        .sort_category
89        .map(|c| meta.col_index(c))
90        .transpose()?;
91    let cat_cols: Vec<(String, usize)> = params
92        .trajectory_categories
93        .iter()
94        .map(|c| Ok((c.clone(), meta.col_index(c)?)))
95        .collect::<Result<_>>()?;
96
97    let weight_vec = if params.weighted {
98        let col = sort_col.ok_or_else(|| {
99            RsomicsError::InvalidInput("weighting requires a sort category".into())
100        })?;
101        let mut w = std::collections::HashMap::new();
102        for sid in &coords.ids {
103            if let Some(row) = meta.rows.get(sid) {
104                let v: f64 = row[col].parse().map_err(|_| {
105                    RsomicsError::InvalidInput("the sort category must be numeric to weight".into())
106                })?;
107                w.insert(sid.clone(), v);
108            }
109        }
110        Some(w)
111    } else {
112        None
113    };
114
115    // _normalize_samples: only samples present in both coords and metadata.
116    let shared: Vec<usize> = coords
117        .ids
118        .iter()
119        .enumerate()
120        .filter(|(_, sid)| meta.rows.contains_key(*sid))
121        .map(|(i, _)| i)
122        .collect();
123    if shared.is_empty() {
124        return Err(RsomicsError::InvalidInput(
125            "coordinates and metadata have no samples in common".into(),
126        ));
127    }
128    let index_of: std::collections::HashMap<&str, usize> = shared
129        .iter()
130        .map(|&i| (coords.ids[i].as_str(), i))
131        .collect();
132
133    let mut categories = Vec::new();
134    for (cat_name, cat_col) in &cat_cols {
135        // _make_groups: bucket shared samples by category value, then realsort
136        // each bucket by the sort_category value (or sample id if none).
137        let mut groups: BTreeMap<String, Vec<String>> = BTreeMap::new();
138        for &i in &shared {
139            let sid = &coords.ids[i];
140            let val = meta.value(sid, *cat_col).to_string();
141            groups.entry(val).or_default().push(sid.clone());
142        }
143        for sids in groups.values_mut() {
144            match sort_col {
145                Some(col) => realsorted_by(sids, |sid| meta.value(sid, col).to_string()),
146                None => realsorted_by(sids, |sid| sid.clone()),
147            }
148        }
149
150        let mut res_by_group = Vec::with_capacity(groups.len());
151        for (gname, sids) in &groups {
152            let traj = group_trajectory(coords, prop, &index_of, sids, params, weight_vec.as_ref());
153            res_by_group.push(GroupResult {
154                name: gname.clone(),
155                trajectory: traj.clone(),
156                mean: mean(&traj),
157            });
158        }
159        categories.push(anova_category(cat_name.clone(), res_by_group));
160    }
161
162    Ok(GradientResult {
163        algorithm: params.algorithm.name(),
164        weighted: params.weighted,
165        categories,
166    })
167}
168
169fn group_trajectory(
170    coords: &Coords,
171    prop: &[f64],
172    index_of: &std::collections::HashMap<&str, usize>,
173    sids: &[String],
174    params: &Params,
175    weight_vec: Option<&std::collections::HashMap<String, f64>>,
176) -> Vec<f64> {
177    let a = params.axes;
178    // trajectories = coords[sids][:axes] * prop[:axes]
179    let mut rows: Vec<Vec<f64>> = sids
180        .iter()
181        .map(|sid| {
182            let r = coords.row(index_of[sid.as_str()]);
183            (0..a).map(|k| r[k] * prop[k]).collect()
184        })
185        .collect();
186
187    if params.weighted && sids.len() > 1 {
188        let w: Vec<f64> = sids.iter().map(|s| weight_vec.unwrap()[s]).collect();
189        if let Some(weighted) = weight_by_vector(&rows, &w) {
190            rows = weighted;
191        }
192    }
193
194    match params.algorithm {
195        Algorithm::Average => average(&rows),
196        Algorithm::Trajectory => trajectory(&rows),
197        Algorithm::FirstDifference => first_difference(&rows),
198        Algorithm::WindowDifference => window_difference(&rows, params.window_size),
199    }
200}
201
/// skbio `_weight_by_vector`: rescale each row `i > 0` by
/// `optimal / |w[i] - w[i - 1]|`, where `optimal = (max(w) - min(w)) / (n - 1)`
/// is the step an evenly spaced gradient over the same range would take.
/// Row 0 is never rescaled.
///
/// Returns `None` (the caller then keeps the rows unweighted) when `w` is not a
/// usable gradient: its length differs from the number of rows, or it contains
/// a repeated value. With zero or one row the rows are returned unchanged.
fn weight_by_vector(rows: &[Vec<f64>], w: &[f64]) -> Option<Vec<Vec<f64>>> {
    let n = w.len();
    // Treat a size mismatch like a non-gradient vector (leave the rows
    // unweighted) instead of indexing past the end of `rows` or silently
    // leaving trailing rows unscaled.
    if rows.len() != n {
        return None;
    }
    // Any repeated value disqualifies the vector; adjacent repeats would also
    // divide by zero below. `==` treats 0.0 and -0.0 as equal, which is what
    // that division needs. Groups are small, so the quadratic scan is fine.
    if (1..n).any(|i| w[..i].contains(&w[i])) {
        return None;
    }
    if n <= 1 {
        return Some(rows.to_vec());
    }
    let (lo, hi) = w
        .iter()
        .fold((w[0], w[0]), |(lo, hi), &v| (lo.min(v), hi.max(v)));
    let optimal = (hi - lo) / (n - 1) as f64;
    let mut out = rows.to_vec();
    // Row i (i >= 1) pairs with the window (w[i - 1], w[i]).
    for (row, pair) in out.iter_mut().skip(1).zip(w.windows(2)) {
        let scale = optimal / (pair[1] - pair[0]).abs();
        for x in row.iter_mut() {
            *x *= scale;
        }
    }
    Some(out)
}
231
/// Euclidean (L2) length of `v`.
fn norm(v: &[f64]) -> f64 {
    let sum_of_squares = v.iter().map(|&x| x * x).sum::<f64>();
    sum_of_squares.sqrt()
}
235
/// Euclidean distance between each pair of consecutive rows:
/// `out[i] = |rows[i + 1] - rows[i]|`, giving `rows.len() - 1` values for a
/// non-empty `rows`. Fewer than two rows yield an empty vector.
///
/// Rows are expected to have equal length; as with `zip`, trailing entries of
/// a longer row are ignored.
fn diff_norms(rows: &[Vec<f64>]) -> Vec<f64> {
    rows.windows(2)
        .map(|pair| {
            // Same summation order as `norm` applied to the difference vector,
            // without allocating that vector for every pair.
            pair[1]
                .iter()
                .zip(&pair[0])
                .map(|(b, a)| {
                    let d = b - a;
                    d * d
                })
                .sum::<f64>()
                .sqrt()
        })
        .collect()
}
248
249fn average(rows: &[Vec<f64>]) -> Vec<f64> {
250    let a = rows[0].len();
251    let mut center = vec![0.0; a];
252    for r in rows {
253        for (c, x) in center.iter_mut().zip(r) {
254            *c += x;
255        }
256    }
257    for c in &mut center {
258        *c /= rows.len() as f64;
259    }
260    if rows.len() == 1 {
261        vec![norm(&center)]
262    } else {
263        rows.iter()
264            .map(|r| {
265                let d: Vec<f64> = r.iter().zip(&center).map(|(x, c)| x - c).collect();
266                norm(&d)
267            })
268            .collect()
269    }
270}
271
272fn trajectory(rows: &[Vec<f64>]) -> Vec<f64> {
273    if rows.len() == 1 {
274        vec![norm(&rows[0])]
275    } else {
276        diff_norms(rows)
277    }
278}
279
280fn first_difference(rows: &[Vec<f64>]) -> Vec<f64> {
281    match rows.len() {
282        1 => vec![norm(&rows[0])],
283        2 => {
284            let d: Vec<f64> = rows[1].iter().zip(&rows[0]).map(|(b, a)| b - a).collect();
285            vec![norm(&d)]
286        }
287        _ => {
288            let vn = diff_norms(rows);
289            (0..vn.len() - 1).map(|i| vn[i + 1] - vn[i]).collect()
290        }
291    }
292}
293
294fn window_difference(rows: &[Vec<f64>], window: usize) -> Vec<f64> {
295    match rows.len() {
296        1 => vec![norm(&rows[0])],
297        2 => {
298            let d: Vec<f64> = rows[1].iter().zip(&rows[0]).map(|(b, a)| b - a).collect();
299            vec![norm(&d)]
300        }
301        _ => {
302            let mut vn = diff_norms(rows);
303            if vn.len() <= window {
304                return vn;
305            }
306            let last = *vn.last().unwrap();
307            for _ in 0..window {
308                vn.push(last);
309            }
310            let n = vn.len() - window;
311            (0..n)
312                .map(|i| {
313                    let m: f64 = vn[i + 1..i + 1 + window].iter().sum::<f64>() / window as f64;
314                    m - vn[i]
315                })
316                .collect()
317        }
318    }
319}
320
/// Arithmetic mean of `v`; `NaN` when `v` is empty (0 / 0).
fn mean(v: &[f64]) -> f64 {
    let total: f64 = v.iter().copied().sum();
    let count = v.len() as f64;
    total / count
}
324
325/// skbio `_ANOVA_trajectories`: run ANOVA across the group trajectories, or
326/// record why it could not (one group, or a group whose trajectory has length 1).
327fn anova_category(category: String, groups: Vec<GroupResult>) -> CategoryResult {
328    if groups.len() == 1 {
329        return CategoryResult {
330            category,
331            probability: None,
332            groups: None,
333            message: Some("Only one value in the group.".into()),
334        };
335    }
336    if groups.iter().any(|g| g.trajectory.len() == 1) {
337        return CategoryResult {
338            category,
339            probability: None,
340            groups: None,
341            message: Some(
342                "This group can not be used. All groups should have more than 1 element.".into(),
343            ),
344        };
345    }
346    let arrays: Vec<Vec<f64>> = groups.iter().map(|g| g.trajectory.clone()).collect();
347    let (_f, p) = f_oneway(&arrays);
348    CategoryResult {
349        category,
350        probability: Some(p),
351        groups: Some(groups),
352        message: None,
353    }
354}