use hashbrown::HashMap;
use ndarray::{Array1, Axis};
use std::hash::Hash;
use super::math::weighted_mean;
pub fn unique_indices<T: Eq + Hash + Clone>(vec: &[T]) -> HashMap<T, Vec<usize>> {
let mut map = HashMap::new();
for (i, x) in vec.iter().enumerate() {
map.entry(x.clone()).or_insert(Vec::new()).push(i);
}
map
}
pub fn weighted_fold_change<T: Eq + Hash + Clone>(
fold_change: &Array1<f64>,
pvalues: &Array1<f64>,
map: &HashMap<T, Vec<usize>>,
) -> HashMap<T, f64> {
map.iter()
.map(|(k, v)| {
let fc = fold_change.select(Axis(0), v);
let weights = 1.0 - pvalues.select(Axis(0), v);
let weighted_fc = weighted_mean(&fc, &weights);
(k.clone(), weighted_fc)
})
.collect()
}
pub fn aggregate_fold_changes(
gene_names: &[String],
fold_changes: &Array1<f64>,
pvalues: &Array1<f64>,
) -> HashMap<String, f64> {
let map = unique_indices(gene_names);
weighted_fold_change(fold_changes, pvalues, &map)
}
#[cfg(test)]
mod testing {
use hashbrown::HashMap;
use ndarray::Array1;
#[test]
fn test_unique_indices() {
let vec = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
let mut map = HashMap::new();
map.insert(1, vec![0, 10]);
map.insert(2, vec![1, 11]);
map.insert(3, vec![2, 12]);
map.insert(4, vec![3, 13]);
map.insert(5, vec![4, 14]);
map.insert(6, vec![5, 15]);
map.insert(7, vec![6, 16]);
map.insert(8, vec![7, 17]);
map.insert(9, vec![8, 18]);
map.insert(10, vec![9, 19]);
assert_eq!(map, super::unique_indices(&vec));
}
#[test]
fn test_weighted_fold_change() {
use hashbrown::HashMap;
use ndarray::Array1;
let fc = Array1::from(vec![1., 2., 3., 4.]);
let pvalues = Array1::from(vec![0.1, 0.2, 0.3, 0.4]);
let mut map = HashMap::new();
map.insert(1, vec![0, 2]);
map.insert(2, vec![1, 3]);
let mut expected = HashMap::new();
expected.insert(1, ((1.0 * 0.9) + (3.0 * 0.7)) / 1.6);
expected.insert(2, ((2.0 * 0.8) + (4.0 * 0.6)) / 1.4);
assert_eq!(expected, super::weighted_fold_change(&fc, &pvalues, &map));
}
#[test]
fn test_aggregate_fold_changes() {
let gene_names = vec![
"A".to_string(),
"B".to_string(),
"C".to_string(),
"D".to_string(),
"A".to_string(),
"B".to_string(),
"C".to_string(),
"D".to_string(),
];
let fc = Array1::from(vec![1., 2., 3., 4., 5., 6., 7., 8.]);
let pvalues = Array1::from(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]);
let mut expected = HashMap::new();
expected.insert("A".to_string(), ((1.0 * 0.9) + (5.0 * 0.5)) / 1.4);
expected.insert("B".to_string(), ((2.0 * 0.8) + (6.0 * 0.4)) / 1.2);
expected.insert("C".to_string(), ((3.0 * 0.7) + (7.0 * 0.3)) / 1.0);
expected.insert("D".to_string(), ((4.0 * 0.6) + (8.0 * 0.2)) / 0.8);
let result = super::aggregate_fold_changes(&gene_names, &fc, &pvalues);
for (k, v) in expected.iter() {
let v_hat = result.get(k).unwrap();
assert!((v - v_hat).abs() < 1e-8);
}
}
}