rsomics_gradient_trajectory/
gradient.rs1use std::collections::BTreeMap;
2
3use rsomics_common::{Result, RsomicsError};
4
5use crate::anova::f_oneway;
6use crate::io::{Coords, Metadata};
7use crate::natsort::realsorted_by;
8
/// Selects how a group's ordered samples are reduced to a trajectory vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Algorithm {
    /// Norm of the difference between each pair of consecutive samples.
    Trajectory,
    /// Distance of every sample from the centroid of its group.
    Average,
    /// First difference of the consecutive-sample norms.
    FirstDifference,
    /// Difference between each consecutive-sample norm and the mean of the
    /// window of norms that follows it.
    WindowDifference,
}
20
21impl Algorithm {
22 pub fn name(self) -> &'static str {
23 match self {
24 Algorithm::Trajectory => "trajectory",
25 Algorithm::Average => "avg",
26 Algorithm::FirstDifference => "diff",
27 Algorithm::WindowDifference => "wdiff",
28 }
29 }
30}
31
/// Trajectory computed for one group of samples within a category.
pub struct GroupResult {
    /// Metadata value shared by every sample in the group.
    pub name: String,
    /// Trajectory vector produced by the selected algorithm.
    pub trajectory: Vec<f64>,
    /// Arithmetic mean of `trajectory`.
    pub mean: f64,
}
37
38pub struct CategoryResult {
39 pub category: String,
40 pub probability: Option<f64>,
41 pub groups: Option<Vec<GroupResult>>,
42 pub message: Option<String>,
43}
44
45pub struct GradientResult {
46 pub algorithm: &'static str,
47 pub weighted: bool,
48 pub categories: Vec<CategoryResult>,
49}
50
51pub struct Params<'a> {
52 pub algorithm: Algorithm,
53 pub trajectory_categories: &'a [String],
54 pub sort_category: Option<&'a str>,
55 pub axes: usize,
56 pub weighted: bool,
57 pub window_size: usize,
58}
59
60pub fn gradient_anova(
69 coords: &Coords,
70 prop: &[f64],
71 meta: &Metadata,
72 params: &Params,
73) -> Result<GradientResult> {
74 if params.axes == 0 || params.axes > prop.len() {
75 return Err(RsomicsError::InvalidInput(format!(
76 "axes must be between 1 and the number of proportions ({}), got {}",
77 prop.len(),
78 params.axes
79 )));
80 }
81 if coords.naxes < params.axes {
82 return Err(RsomicsError::InvalidInput(format!(
83 "coordinates have {} axes, fewer than the requested {}",
84 coords.naxes, params.axes
85 )));
86 }
87 let sort_col = params
88 .sort_category
89 .map(|c| meta.col_index(c))
90 .transpose()?;
91 let cat_cols: Vec<(String, usize)> = params
92 .trajectory_categories
93 .iter()
94 .map(|c| Ok((c.clone(), meta.col_index(c)?)))
95 .collect::<Result<_>>()?;
96
97 let weight_vec = if params.weighted {
98 let col = sort_col.ok_or_else(|| {
99 RsomicsError::InvalidInput("weighting requires a sort category".into())
100 })?;
101 let mut w = std::collections::HashMap::new();
102 for sid in &coords.ids {
103 if let Some(row) = meta.rows.get(sid) {
104 let v: f64 = row[col].parse().map_err(|_| {
105 RsomicsError::InvalidInput("the sort category must be numeric to weight".into())
106 })?;
107 w.insert(sid.clone(), v);
108 }
109 }
110 Some(w)
111 } else {
112 None
113 };
114
115 let shared: Vec<usize> = coords
117 .ids
118 .iter()
119 .enumerate()
120 .filter(|(_, sid)| meta.rows.contains_key(*sid))
121 .map(|(i, _)| i)
122 .collect();
123 if shared.is_empty() {
124 return Err(RsomicsError::InvalidInput(
125 "coordinates and metadata have no samples in common".into(),
126 ));
127 }
128 let index_of: std::collections::HashMap<&str, usize> = shared
129 .iter()
130 .map(|&i| (coords.ids[i].as_str(), i))
131 .collect();
132
133 let mut categories = Vec::new();
134 for (cat_name, cat_col) in &cat_cols {
135 let mut groups: BTreeMap<String, Vec<String>> = BTreeMap::new();
138 for &i in &shared {
139 let sid = &coords.ids[i];
140 let val = meta.value(sid, *cat_col).to_string();
141 groups.entry(val).or_default().push(sid.clone());
142 }
143 for sids in groups.values_mut() {
144 match sort_col {
145 Some(col) => realsorted_by(sids, |sid| meta.value(sid, col).to_string()),
146 None => realsorted_by(sids, |sid| sid.clone()),
147 }
148 }
149
150 let mut res_by_group = Vec::with_capacity(groups.len());
151 for (gname, sids) in &groups {
152 let traj = group_trajectory(coords, prop, &index_of, sids, params, weight_vec.as_ref());
153 res_by_group.push(GroupResult {
154 name: gname.clone(),
155 trajectory: traj.clone(),
156 mean: mean(&traj),
157 });
158 }
159 categories.push(anova_category(cat_name.clone(), res_by_group));
160 }
161
162 Ok(GradientResult {
163 algorithm: params.algorithm.name(),
164 weighted: params.weighted,
165 categories,
166 })
167}
168
169fn group_trajectory(
170 coords: &Coords,
171 prop: &[f64],
172 index_of: &std::collections::HashMap<&str, usize>,
173 sids: &[String],
174 params: &Params,
175 weight_vec: Option<&std::collections::HashMap<String, f64>>,
176) -> Vec<f64> {
177 let a = params.axes;
178 let mut rows: Vec<Vec<f64>> = sids
180 .iter()
181 .map(|sid| {
182 let r = coords.row(index_of[sid.as_str()]);
183 (0..a).map(|k| r[k] * prop[k]).collect()
184 })
185 .collect();
186
187 if params.weighted && sids.len() > 1 {
188 let w: Vec<f64> = sids.iter().map(|s| weight_vec.unwrap()[s]).collect();
189 if let Some(weighted) = weight_by_vector(&rows, &w) {
190 rows = weighted;
191 }
192 }
193
194 match params.algorithm {
195 Algorithm::Average => average(&rows),
196 Algorithm::Trajectory => trajectory(&rows),
197 Algorithm::FirstDifference => first_difference(&rows),
198 Algorithm::WindowDifference => window_difference(&rows, params.window_size),
199 }
200}
201
/// Rescales trajectory rows so that unevenly spaced gradient values are
/// compensated for.
///
/// `w[i]` is the gradient value of `rows[i]`. The "optimal" step is
/// `(max(w) - min(w)) / (w.len() - 1)`; every row after the first is
/// multiplied by `optimal / |w[i] - w[i - 1]|`. The first row is never scaled.
///
/// Returns `None` when `w` contains two equal values (no usable gradient).
/// With zero or one weight there is nothing to rescale and the rows are
/// returned unchanged. `rows` and `w` are expected to have the same length;
/// rows without a matching weight are left untouched.
fn weight_by_vector(rows: &[Vec<f64>], w: &[f64]) -> Option<Vec<Vec<f64>>> {
    let n = w.len();
    if n <= 1 {
        return Some(rows.to_vec());
    }

    // Duplicate detection in O(n log n): after a total-order sort, any two
    // values that compare equal (including 0.0 and -0.0) end up adjacent.
    // NaN never equals itself, so it is never reported as a duplicate.
    let mut sorted = w.to_vec();
    sorted.sort_unstable_by(f64::total_cmp);
    if sorted.windows(2).any(|pair| pair[0] == pair[1]) {
        return None;
    }

    let (lo, hi) = w
        .iter()
        .fold((w[0], w[0]), |(lo, hi), &v| (lo.min(v), hi.max(v)));
    let optimal = (hi - lo) / (n - 1) as f64;

    let mut out = rows.to_vec();
    // Row i (i >= 1) is paired with the step between w[i - 1] and w[i].
    for (row, step) in out.iter_mut().skip(1).zip(w.windows(2)) {
        let scale = optimal / (step[1] - step[0]).abs();
        for x in row.iter_mut() {
            *x *= scale;
        }
    }
    Some(out)
}
231
/// Euclidean (L2) norm of `v`.
fn norm(v: &[f64]) -> f64 {
    let sum_of_squares: f64 = v.iter().map(|component| component * component).sum();
    sum_of_squares.sqrt()
}
235
/// Euclidean distance between every pair of consecutive rows.
///
/// Returns `rows.len() - 1` values, or an empty vector when there are fewer
/// than two rows. If two neighbouring rows differ in length, only their
/// common prefix is compared.
fn diff_norms(rows: &[Vec<f64>]) -> Vec<f64> {
    // `windows(2)` yields nothing for 0 or 1 rows, so there is no
    // `len() - 1` underflow to guard against.
    rows.windows(2)
        .map(|pair| {
            // Accumulate the squared differences directly instead of
            // materialising the difference vector first.
            pair[1]
                .iter()
                .zip(&pair[0])
                .map(|(b, a)| {
                    let d = b - a;
                    d * d
                })
                .sum::<f64>()
                .sqrt()
        })
        .collect()
}
248
249fn average(rows: &[Vec<f64>]) -> Vec<f64> {
250 let a = rows[0].len();
251 let mut center = vec![0.0; a];
252 for r in rows {
253 for (c, x) in center.iter_mut().zip(r) {
254 *c += x;
255 }
256 }
257 for c in &mut center {
258 *c /= rows.len() as f64;
259 }
260 if rows.len() == 1 {
261 vec![norm(¢er)]
262 } else {
263 rows.iter()
264 .map(|r| {
265 let d: Vec<f64> = r.iter().zip(¢er).map(|(x, c)| x - c).collect();
266 norm(&d)
267 })
268 .collect()
269 }
270}
271
272fn trajectory(rows: &[Vec<f64>]) -> Vec<f64> {
273 if rows.len() == 1 {
274 vec![norm(&rows[0])]
275 } else {
276 diff_norms(rows)
277 }
278}
279
280fn first_difference(rows: &[Vec<f64>]) -> Vec<f64> {
281 match rows.len() {
282 1 => vec![norm(&rows[0])],
283 2 => {
284 let d: Vec<f64> = rows[1].iter().zip(&rows[0]).map(|(b, a)| b - a).collect();
285 vec![norm(&d)]
286 }
287 _ => {
288 let vn = diff_norms(rows);
289 (0..vn.len() - 1).map(|i| vn[i + 1] - vn[i]).collect()
290 }
291 }
292}
293
294fn window_difference(rows: &[Vec<f64>], window: usize) -> Vec<f64> {
295 match rows.len() {
296 1 => vec![norm(&rows[0])],
297 2 => {
298 let d: Vec<f64> = rows[1].iter().zip(&rows[0]).map(|(b, a)| b - a).collect();
299 vec![norm(&d)]
300 }
301 _ => {
302 let mut vn = diff_norms(rows);
303 if vn.len() <= window {
304 return vn;
305 }
306 let last = *vn.last().unwrap();
307 for _ in 0..window {
308 vn.push(last);
309 }
310 let n = vn.len() - window;
311 (0..n)
312 .map(|i| {
313 let m: f64 = vn[i + 1..i + 1 + window].iter().sum::<f64>() / window as f64;
314 m - vn[i]
315 })
316 .collect()
317 }
318 }
319}
320
/// Arithmetic mean of `v`; NaN when `v` is empty.
fn mean(v: &[f64]) -> f64 {
    let total: f64 = v.iter().sum();
    let count = v.len() as f64;
    total / count
}
324
325fn anova_category(category: String, groups: Vec<GroupResult>) -> CategoryResult {
328 if groups.len() == 1 {
329 return CategoryResult {
330 category,
331 probability: None,
332 groups: None,
333 message: Some("Only one value in the group.".into()),
334 };
335 }
336 if groups.iter().any(|g| g.trajectory.len() == 1) {
337 return CategoryResult {
338 category,
339 probability: None,
340 groups: None,
341 message: Some(
342 "This group can not be used. All groups should have more than 1 element.".into(),
343 ),
344 };
345 }
346 let arrays: Vec<Vec<f64>> = groups.iter().map(|g| g.trajectory.clone()).collect();
347 let (_f, p) = f_oneway(&arrays);
348 CategoryResult {
349 category,
350 probability: Some(p),
351 groups: Some(groups),
352 message: None,
353 }
354}