use std::collections::BTreeMap;
use rsomics_common::{Result, RsomicsError};
use crate::anova::f_oneway;
use crate::io::{Coords, Metadata};
use crate::natsort::realsorted_by;
/// Trajectory algorithm used to turn each ordered group of samples into the
/// values that are compared across groups.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum Algorithm {
    /// Norms of the steps between consecutive samples.
    Trajectory,
    /// Distances of each sample from the group centroid.
    Average,
    /// Differences between consecutive step norms.
    FirstDifference,
    /// Mean of the following `window_size` step norms minus the current one.
    WindowDifference,
}

impl Algorithm {
    /// Short identifier reported in [`GradientResult::algorithm`].
    pub fn name(self) -> &'static str {
        let label = match self {
            Self::Trajectory => "trajectory",
            Self::Average => "avg",
            Self::FirstDifference => "diff",
            Self::WindowDifference => "wdiff",
        };
        label
    }
}
/// Trajectory computed for one group, i.e. the samples sharing one value of a
/// category column.
#[derive(Debug, Clone, PartialEq)]
pub struct GroupResult {
    /// The category value shared by the group's samples.
    pub name: String,
    /// Values produced for the group by the selected algorithm.
    pub trajectory: Vec<f64>,
    /// Arithmetic mean of `trajectory`.
    pub mean: f64,
}

/// Outcome of the one-way ANOVA for one category column.
#[derive(Debug, Clone, PartialEq)]
pub struct CategoryResult {
    /// Name of the metadata column that was tested.
    pub category: String,
    /// ANOVA p-value; `None` when the test could not be run.
    pub probability: Option<f64>,
    /// Per-group trajectories; `None` when the test could not be run.
    pub groups: Option<Vec<GroupResult>>,
    /// Explanation of why the test was skipped, if it was.
    pub message: Option<String>,
}

/// Full result of a gradient ANOVA run.
#[derive(Debug, Clone, PartialEq)]
pub struct GradientResult {
    /// Short name of the algorithm that produced the trajectories.
    pub algorithm: &'static str,
    /// Whether trajectories were weighted by the sort category.
    pub weighted: bool,
    /// One entry per tested category, in the order the categories were given.
    pub categories: Vec<CategoryResult>,
}
pub struct Params<'a> {
pub algorithm: Algorithm,
pub trajectory_categories: &'a [String],
pub sort_category: Option<&'a str>,
pub axes: usize,
pub weighted: bool,
pub window_size: usize,
}
pub fn gradient_anova(
coords: &Coords,
prop: &[f64],
meta: &Metadata,
params: &Params,
) -> Result<GradientResult> {
if params.axes == 0 || params.axes > prop.len() {
return Err(RsomicsError::InvalidInput(format!(
"axes must be between 1 and the number of proportions ({}), got {}",
prop.len(),
params.axes
)));
}
if coords.naxes < params.axes {
return Err(RsomicsError::InvalidInput(format!(
"coordinates have {} axes, fewer than the requested {}",
coords.naxes, params.axes
)));
}
let sort_col = params
.sort_category
.map(|c| meta.col_index(c))
.transpose()?;
let cat_cols: Vec<(String, usize)> = params
.trajectory_categories
.iter()
.map(|c| Ok((c.clone(), meta.col_index(c)?)))
.collect::<Result<_>>()?;
let weight_vec = if params.weighted {
let col = sort_col.ok_or_else(|| {
RsomicsError::InvalidInput("weighting requires a sort category".into())
})?;
let mut w = std::collections::HashMap::new();
for sid in &coords.ids {
if let Some(row) = meta.rows.get(sid) {
let v: f64 = row[col].parse().map_err(|_| {
RsomicsError::InvalidInput("the sort category must be numeric to weight".into())
})?;
w.insert(sid.clone(), v);
}
}
Some(w)
} else {
None
};
let shared: Vec<usize> = coords
.ids
.iter()
.enumerate()
.filter(|(_, sid)| meta.rows.contains_key(*sid))
.map(|(i, _)| i)
.collect();
if shared.is_empty() {
return Err(RsomicsError::InvalidInput(
"coordinates and metadata have no samples in common".into(),
));
}
let index_of: std::collections::HashMap<&str, usize> = shared
.iter()
.map(|&i| (coords.ids[i].as_str(), i))
.collect();
let mut categories = Vec::new();
for (cat_name, cat_col) in &cat_cols {
let mut groups: BTreeMap<String, Vec<String>> = BTreeMap::new();
for &i in &shared {
let sid = &coords.ids[i];
let val = meta.value(sid, *cat_col).to_string();
groups.entry(val).or_default().push(sid.clone());
}
for sids in groups.values_mut() {
match sort_col {
Some(col) => realsorted_by(sids, |sid| meta.value(sid, col).to_string()),
None => realsorted_by(sids, |sid| sid.clone()),
}
}
let mut res_by_group = Vec::with_capacity(groups.len());
for (gname, sids) in &groups {
let traj = group_trajectory(coords, prop, &index_of, sids, params, weight_vec.as_ref());
res_by_group.push(GroupResult {
name: gname.clone(),
trajectory: traj.clone(),
mean: mean(&traj),
});
}
categories.push(anova_category(cat_name.clone(), res_by_group));
}
Ok(GradientResult {
algorithm: params.algorithm.name(),
weighted: params.weighted,
categories,
})
}
/// Computes the trajectory values for one ordered group of samples.
///
/// Each sample contributes its first `params.axes` coordinates, each multiplied
/// by the matching entry of `prop`.
///
/// When `params.weighted` is set and the group has more than one sample, the
/// rows are rescaled with `weight_by_vector`. If that returns `None`, the
/// unweighted rows are kept. The rows are then passed to the kernel selected by
/// `params.algorithm`.
///
/// # Panics
///
/// - If a sample id in `sids` is missing from `index_of`.
/// - When weighting, if `weight_vec` is `None` or lacks one of the ids.
///
/// Both cases indicate a caller bug.
fn group_trajectory(
    coords: &Coords,
    prop: &[f64],
    index_of: &std::collections::HashMap<&str, usize>,
    sids: &[String],
    params: &Params,
    weight_vec: Option<&std::collections::HashMap<String, f64>>,
) -> Vec<f64> {
    let axes = params.axes;
    let mut rows: Vec<Vec<f64>> = sids
        .iter()
        .map(|sid| {
            let r = coords.row(index_of[sid.as_str()]);
            (0..axes).map(|k| r[k] * prop[k]).collect()
        })
        .collect();

    // Weighting needs at least two points to have any spacing to rescale.
    if params.weighted && sids.len() > 1 {
        let weights_by_sid =
            weight_vec.expect("a weighting vector must be supplied when params.weighted is set");
        let w: Vec<f64> = sids.iter().map(|s| weights_by_sid[s]).collect();
        if let Some(weighted) = weight_by_vector(&rows, &w) {
            rows = weighted;
        }
    }

    match params.algorithm {
        Algorithm::Average => average(&rows),
        Algorithm::Trajectory => trajectory(&rows),
        Algorithm::FirstDifference => first_difference(&rows),
        Algorithm::WindowDifference => window_difference(&rows, params.window_size),
    }
}
/// Rescales trajectory rows so that uneven spacing in the weight vector `w` is
/// evened out.
///
/// Row `i` (for `i >= 1`) is multiplied by `optimal / |w[i] - w[i - 1]|`, where
/// `optimal = (max(w) - min(w)) / (n - 1)`. Row 0 is left unchanged, and a
/// single row is returned as-is.
///
/// Returns `None` when weighting is not possible:
/// - `w` is empty;
/// - `w` and `rows` have different lengths;
/// - `w` contains a repeated value, which would mean a zero step.
///
/// Callers then fall back to the unweighted rows.
fn weight_by_vector(rows: &[Vec<f64>], w: &[f64]) -> Option<Vec<Vec<f64>>> {
    let n = w.len();
    // Empty or mismatched input used to panic (on `w[0]` / `out[i]`) or leave
    // rows silently unscaled; report it as "cannot weight" instead.
    if n == 0 || rows.len() != n {
        return None;
    }
    // Repeated weights give a zero step (division by zero below).
    if w.iter().enumerate().any(|(i, v)| w[..i].contains(v)) {
        return None;
    }
    if n == 1 {
        return Some(rows.to_vec());
    }

    let (lo, hi) = w
        .iter()
        .fold((w[0], w[0]), |(lo, hi), &v| (lo.min(v), hi.max(v)));
    let optimal = (hi - lo) / (n - 1) as f64;

    let mut out = rows.to_vec();
    for (i, row) in out.iter_mut().enumerate().skip(1) {
        let scale = optimal / (w[i] - w[i - 1]).abs();
        for x in row.iter_mut() {
            *x *= scale;
        }
    }
    Some(out)
}
/// Euclidean (L2) norm of `v`.
fn norm(v: &[f64]) -> f64 {
    v.iter().map(|x| x * x).sum::<f64>().sqrt()
}

/// Euclidean distance between `a` and `b`, computed without allocating a
/// difference vector. The slices are zipped, so extra trailing elements of the
/// longer one are ignored.
fn distance(a: &[f64], b: &[f64]) -> f64 {
    a.iter()
        .zip(b)
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum::<f64>()
        .sqrt()
}

/// Norms of the steps between consecutive rows: element `i` is the distance
/// from `rows[i]` to `rows[i + 1]`. Empty when there are fewer than two rows.
fn diff_norms(rows: &[Vec<f64>]) -> Vec<f64> {
    rows.windows(2)
        .map(|pair| distance(&pair[1], &pair[0]))
        .collect()
}

/// Values for the average algorithm.
///
/// Returns the distance of every row from the centroid of all rows. A single
/// row yields its norm instead (its distance from the origin), and an empty
/// input yields an empty vector.
fn average(rows: &[Vec<f64>]) -> Vec<f64> {
    if rows.is_empty() {
        return Vec::new();
    }
    let count = rows.len() as f64;
    let mut center = vec![0.0; rows[0].len()];
    for row in rows {
        for (c, x) in center.iter_mut().zip(row) {
            *c += x;
        }
    }
    for c in &mut center {
        *c /= count;
    }

    if rows.len() == 1 {
        vec![norm(&center)]
    } else {
        rows.iter().map(|row| distance(row, &center)).collect()
    }
}
/// Values for the plain trajectory algorithm.
///
/// Returns the norms of the steps between consecutive rows. A group with a
/// single row yields that row's norm.
fn trajectory(rows: &[Vec<f64>]) -> Vec<f64> {
    match rows {
        [single] => vec![norm(single)],
        _ => diff_norms(rows),
    }
}
/// Values for the first-difference algorithm: the change between consecutive
/// step norms.
///
/// Edge cases:
/// - one row yields that row's norm;
/// - two rows yield their single step norm, since there is nothing to difference;
/// - an empty input yields an empty vector instead of underflowing.
fn first_difference(rows: &[Vec<f64>]) -> Vec<f64> {
    match rows {
        [] => Vec::new(),
        [single] => vec![norm(single)],
        [_, _] => diff_norms(rows),
        _ => diff_norms(rows)
            .windows(2)
            .map(|pair| pair[1] - pair[0])
            .collect(),
    }
}
/// Values for the window-difference algorithm.
///
/// For each step `i`, returns the mean of the next `window` step norms minus
/// the norm of step `i`. The last step norm is repeated so that the trailing
/// steps have a full window.
///
/// Edge cases:
/// - when there are no more steps than `window`, the raw step norms are returned;
/// - one row yields that row's norm;
/// - two rows yield their single step norm;
/// - an empty input yields an empty vector.
///
/// `window` must be at least 1: with 0 every value would be NaN, because the
/// empty window sum is divided by zero.
fn window_difference(rows: &[Vec<f64>], window: usize) -> Vec<f64> {
    match rows {
        [] => Vec::new(),
        [single] => vec![norm(single)],
        [_, _] => diff_norms(rows),
        _ => {
            let mut steps = diff_norms(rows);
            if steps.len() <= window {
                return steps;
            }
            let n = steps.len();
            // Pad with copies of the last step so every window is full.
            let last = steps[n - 1];
            steps.resize(n + window, last);
            (0..n)
                .map(|i| {
                    let ahead: f64 =
                        steps[i + 1..i + 1 + window].iter().sum::<f64>() / window as f64;
                    ahead - steps[i]
                })
                .collect()
        }
    }
}
/// Arithmetic mean of `v`; NaN for an empty slice (0 / 0).
fn mean(v: &[f64]) -> f64 {
    let count = v.len() as f64;
    let total: f64 = v.iter().sum();
    total / count
}
/// Runs a one-way ANOVA across the trajectories of `groups`.
///
/// The test is skipped when the category has a single group, or when any
/// group's trajectory holds just one value. In that case `probability` and
/// `groups` are `None` and `message` explains why.
fn anova_category(category: String, groups: Vec<GroupResult>) -> CategoryResult {
    let skipped = |category: String, reason: &str| CategoryResult {
        category,
        probability: None,
        groups: None,
        message: Some(reason.to_string()),
    };

    if groups.len() == 1 {
        return skipped(category, "Only one value in the group.");
    }
    let has_singleton = groups.iter().any(|g| g.trajectory.len() == 1);
    if has_singleton {
        return skipped(
            category,
            "This group can not be used. All groups should have more than 1 element.",
        );
    }

    let samples: Vec<Vec<f64>> = groups.iter().map(|g| g.trajectory.clone()).collect();
    let (_, p_value) = f_oneway(&samples);
    CategoryResult {
        category,
        probability: Some(p_value),
        groups: Some(groups),
        message: None,
    }
}