use super::{
utils::{filter_zeros, num_unique, set_alpha_threshold},
AggregationResult, GeneAggregation,
};
use crate::{
enrich::EnrichmentResult,
utils::{agg::aggregate_fold_changes, logging::Logger},
};
use adjustp::Procedure;
use alpha_rra::AlphaRRA;
use intc::{fdr::Direction, Inc};
use ndarray::Array1;
/// Method-agnostic container for gene-level aggregation output.
///
/// Both `run_rra` and `run_inc` build this struct, and `compute_aggregation`
/// unpacks it into the public `AggregationResult`. The `_low` / `_high`
/// suffixes refer to the runs performed on the low-side and high-side sgRNA
/// p-values respectively.
struct InternalAggregationResult {
/// Gene names as reported by the low-side aggregation run.
genes: Vec<String>,
/// Per-gene log fold change; `compute_aggregation` applies `exp2` to it,
/// so the values are treated as base-2 logarithms.
logfc: Array1<f64>,
/// Scores from the low-side run (`scores()` for RRA, `u_scores()` for INC).
scores_low: Array1<f64>,
/// P-values from the low-side run (`pvalues()` for RRA, `u_pvalues()` for INC).
pvalues_low: Array1<f64>,
/// Corrected values from the low-side run (`adj_pvalues()` for RRA, `fdr()` for INC).
correction_low: Array1<f64>,
/// Scores from the high-side run (`scores()` for RRA, `u_scores()` for INC).
scores_high: Array1<f64>,
/// P-values from the high-side run (`pvalues()` for RRA, `u_pvalues()` for INC).
pvalues_high: Array1<f64>,
/// Corrected values from the high-side run (`adj_pvalues()` for RRA, `fdr()` for INC).
correction_high: Array1<f64>,
/// `threshold()` of the low-side INC fit; `None` when produced by `run_rra`.
threshold_low: Option<f64>,
/// `threshold()` of the high-side INC fit; `None` when produced by `run_rra`.
threshold_high: Option<f64>,
}
impl InternalAggregationResult {
pub fn new(
genes: Vec<String>,
logfc: Array1<f64>,
scores_low: Array1<f64>,
pvalues_low: Array1<f64>,
correction_low: Array1<f64>,
scores_high: Array1<f64>,
pvalues_high: Array1<f64>,
correction_high: Array1<f64>,
threshold_low: Option<f64>,
threshold_high: Option<f64>,
) -> Self {
Self {
genes,
logfc,
scores_low,
pvalues_low,
correction_low,
scores_high,
pvalues_high,
correction_high,
threshold_low,
threshold_high,
}
}
}
/// Runs alpha-RRA once on the low-side and once on the high-side sgRNA
/// p-values and packs both outcomes into an [`InternalAggregationResult`].
///
/// Steps, in order:
/// 1. `set_alpha_threshold` yields one alpha per side, which is logged.
/// 2. For each side an `AlphaRRA` is built, the keys of its permutation
///    vectors are logged, and the fit is run.
/// 3. Per-gene fold changes come from `aggregate_fold_changes`; a gene
///    missing from that map is assigned `0.0`.
///
/// Both thresholds in the returned value are `None`.
///
/// # Panics
///
/// Panics if either `AlphaRRA::run` call returns an error (the messages
/// label the low side as "depleted" and the high side as "enriched").
fn run_rra(
    pvalue_low: &Array1<f64>,
    pvalue_high: &Array1<f64>,
    logfc: &Array1<f64>,
    gene_names: &[String],
    alpha: f64,
    adjust_alpha: bool,
    npermutations: usize,
    correction: Procedure,
    logger: &Logger,
) -> InternalAggregationResult {
    let (alpha_low, alpha_high) =
        set_alpha_threshold(pvalue_low, pvalue_high, alpha, adjust_alpha);
    logger.report_rra_alpha(alpha_low, alpha_high);

    // Low side.
    let rra_low = AlphaRRA::new(gene_names, alpha_low, npermutations, correction);
    let sizes_low: Vec<usize> = rra_low.permutation_vectors().keys().copied().collect();
    logger.permutation_sizes(&sizes_low);
    let low = rra_low
        .run(pvalue_low)
        .expect("Error in RRA fit for depleted pvalues");

    // High side.
    let rra_high = AlphaRRA::new(gene_names, alpha_high, npermutations, correction);
    let sizes_high: Vec<usize> = rra_high.permutation_vectors().keys().copied().collect();
    logger.permutation_sizes(&sizes_high);
    let high = rra_high
        .run(pvalue_high)
        .expect("Error in RRA fit for enriched pvalues");

    // Fold changes are looked up in the gene order of the low-side result.
    let fc_by_gene = aggregate_fold_changes(gene_names, logfc);
    let gene_fc: Array1<f64> = low
        .names()
        .iter()
        .map(|name| fc_by_gene.get(name).copied().unwrap_or(0.0))
        .collect();

    // NOTE(review): the high-side vectors are stored next to the gene names
    // of the low-side result; this assumes both runs report genes in the
    // same order — TODO confirm against `alpha_rra`.
    InternalAggregationResult::new(
        low.names().to_vec(),
        gene_fc,
        low.scores().to_owned(),
        low.pvalues().to_owned(),
        low.adj_pvalues().to_owned(),
        high.scores().to_owned(),
        high.pvalues().to_owned(),
        high.adj_pvalues().to_owned(),
        None,
        None,
    )
}
/// Runs the INC aggregation once on the low-side and once on the high-side
/// sgRNA p-values and packs both fits into an [`InternalAggregationResult`].
///
/// When `use_product` is set, the low-side fit receives
/// `Some(Direction::Less)` and the high-side fit `Some(Direction::Greater)`;
/// otherwise both receive `None`. Both fits use `Alternative::Less` and the
/// same `seed`. Gene names and log fold changes in the returned value are
/// taken from the low-side fit, and each side's `threshold()` is logged and
/// returned as `Some(..)`.
///
/// # Panics
///
/// Panics if either `Inc::fit` call returns an error.
fn run_inc(
    pvalue_low: &Array1<f64>,
    pvalue_high: &Array1<f64>,
    log2_fold_change: &Array1<f64>,
    gene_names: &[String],
    token: &str,
    fdr: f64,
    group_size: usize,
    num_genes: usize,
    use_product: bool,
    seed: u64,
    logger: &Logger,
) -> InternalAggregationResult {
    logger.report_inc_params(token, num_genes, fdr, group_size);

    // A direction is only supplied when the product-based mode is requested.
    let dir_low = use_product.then_some(Direction::Less);
    let dir_high = use_product.then_some(Direction::Greater);

    let low_fit = Inc::new(
        pvalue_low,
        log2_fold_change,
        gene_names,
        token,
        num_genes,
        group_size,
        fdr,
        intc::mwu::Alternative::Less,
        // NOTE(review): meaning of this flag is not visible here — confirm
        // against `intc::Inc::new`.
        true,
        dir_low,
        Some(seed),
    )
    .fit()
    .expect("Error calculating INC on low pvalues");
    let threshold_low = low_fit.threshold();
    logger.report_inc_low_threshold(threshold_low, use_product);

    let high_fit = Inc::new(
        pvalue_high,
        log2_fold_change,
        gene_names,
        token,
        num_genes,
        group_size,
        fdr,
        intc::mwu::Alternative::Less,
        true,
        dir_high,
        Some(seed),
    )
    .fit()
    .expect("Error calculating INC on high pvalues");
    let threshold_high = high_fit.threshold();
    logger.report_inc_high_threshold(threshold_high, use_product);

    InternalAggregationResult::new(
        low_fit.genes().to_vec(),
        low_fit.logfc().to_owned(),
        low_fit.u_scores().to_owned(),
        low_fit.u_pvalues().to_owned(),
        low_fit.fdr().to_owned(),
        high_fit.u_scores().to_owned(),
        high_fit.u_pvalues().to_owned(),
        high_fit.fdr().to_owned(),
        Some(threshold_low),
        Some(threshold_high),
    )
}
pub fn compute_aggregation(
agg: &GeneAggregation,
sgrna_results: &EnrichmentResult,
gene_names: &[String],
logger: &Logger,
correction: Procedure,
seed: u64,
) -> AggregationResult {
logger.start_gene_aggregation();
let num_genes = num_unique(gene_names);
let (
passing_gene_names,
passing_sgrna_pvalues_low,
passing_sgrna_pvalues_high,
passing_sgrna_logfc,
) = filter_zeros(
sgrna_results.base_means(),
gene_names,
sgrna_results.pvalues_low(),
sgrna_results.pvalues_high(),
sgrna_results.log_fold_change(),
logger,
);
let agg_result = match agg {
GeneAggregation::AlpaRRA {
alpha,
npermutations,
adjust_alpha,
fdr: _,
} => run_rra(
&passing_sgrna_pvalues_low,
&passing_sgrna_pvalues_high,
&passing_sgrna_logfc,
&passing_gene_names,
*alpha,
*adjust_alpha,
*npermutations,
correction,
logger,
),
GeneAggregation::Inc {
token,
fdr,
group_size,
use_product,
} => run_inc(
&passing_sgrna_pvalues_low,
&passing_sgrna_pvalues_high,
&passing_sgrna_logfc,
&passing_gene_names,
token,
*fdr,
*group_size,
num_genes,
*use_product,
seed,
logger,
),
};
let fold_change = agg_result
.logfc
.iter()
.map(|x| x.exp2())
.collect::<Array1<f64>>();
AggregationResult::new(
agg_result.genes,
fold_change,
agg_result.pvalues_low,
agg_result.pvalues_high,
agg_result.correction_low,
agg_result.correction_high,
agg_result.scores_low,
agg_result.scores_high,
agg_result.threshold_low,
agg_result.threshold_high,
)
}