use crate::io::GenoData;
use anyhow::Result;
use itertools::Itertools;
use rand::prelude::IndexedRandom;
use rayon::prelude::*;
use scirs2_integrate::gaussian::gauss_kronrod21;
use scirs2_integrate::quad::{quad, QuadOptions};
use statistical::mean;
use statrs::distribution::{Binomial, Discrete};
use std::f64::consts::PI;
/// Column-wise output unzipped in `pair_gt_to_af`, one entry per SNP:
/// (pop-1 total counts, pop-1 alt counts, pop-1 alt frequencies, pop-2 alt frequencies).
type AlleleDataTuple = (Vec<u64>, Vec<u64>, Vec<f64>, Vec<f64>);
/// Column-wise per-window data unzipped in `xpclr`, one entry per selected SNP:
/// (pop-2 genotype rows, `gdistances` values, pop-1 alt counts, pop-1 total counts,
/// pop-2 alt frequencies).
type RangeTuple<'a> = (Vec<&'a Vec<i8>>, Vec<f64>, Vec<u64>, Vec<u64>, Vec<f64>);
/// Per-SNP allele counts and frequencies produced by `pair_gt_to_af`.
///
/// Each vector has one entry per (pop-1 row, pop-2 row) pair of the genotype matrices.
struct AlleleFreqs {
    /// Population 1: twice the number of called (non-negative) genotype entries.
    pub total_counts1: Vec<u64>,
    /// Population 1: sum of the non-negative genotype values
    /// (presumably alt-allele dosages — confirm against the loader).
    pub alt_counts1: Vec<u64>,
    /// Population 1: `alt_counts1 / total_counts1` (NaN when nothing is called).
    pub alt_freqs1: Vec<f64>,
    /// Population 2: alt frequency; the denominator is doubled only when not phased.
    pub alt_freqs2: Vec<f64>,
}
/// Binary search over a slice sorted in ascending order, for totally ordered `T`.
///
/// `bisect_left`/`bisect_right` return the same insertion points as Python's
/// `bisect.bisect_left`/`bisect.bisect_right`.
pub struct Bisector<'a, T> {
    data: &'a [T],
}

impl<'a, T: Ord> Bisector<'a, T> {
    /// Wrap `data`, which must already be sorted ascending.
    pub fn new(data: &'a [T]) -> Self {
        Self { data }
    }

    /// Leftmost insertion point for `x`: the count of leading elements `< x`.
    pub fn bisect_left(&self, x: &T) -> usize {
        self.bisect_by(|elem| elem < x)
    }

    /// Rightmost insertion point for `x`: the count of leading elements `<= x`.
    pub fn bisect_right(&self, x: &T) -> usize {
        self.bisect_by(|elem| elem <= x)
    }

    /// Shared halving loop: returns the first index whose element fails `before`,
    /// assuming `before` holds on a prefix of `data`.
    fn bisect_by<P: Fn(&T) -> bool>(&self, before: P) -> usize {
        let (mut lo, mut hi) = (0usize, self.data.len());
        while lo < hi {
            // Same midpoint as (lo + hi) / 2, without the possibility of overflow.
            let mid = lo + (hi - lo) / 2;
            if before(&self.data[mid]) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        lo
    }
}
/// Binary search over a slice sorted in ascending order, for partially ordered `T`
/// such as `f64`.
///
/// Incomparable pairs (e.g. involving NaN) never count as "less" in `bisect_left` and
/// never count as "greater" in `bisect_right`.
pub struct PartialBisector<'a, T> {
    data: &'a [T],
}

impl<'a, T: PartialOrd> PartialBisector<'a, T> {
    /// Wrap `data`, which must already be sorted ascending.
    pub fn new(data: &'a [T]) -> Self {
        Self { data }
    }

    /// Count of leading elements that compare strictly less than `x`.
    pub fn bisect_left(&self, x: &T) -> usize {
        self.bisect_by(|elem| matches!(elem.partial_cmp(x), Some(std::cmp::Ordering::Less)))
    }

    /// Count of leading elements that do not compare greater than `x`
    /// (equal and incomparable elements are included).
    pub fn bisect_right(&self, x: &T) -> usize {
        self.bisect_by(|elem| !matches!(elem.partial_cmp(x), Some(std::cmp::Ordering::Greater)))
    }

    /// Shared halving loop: returns the first index whose element fails `before`.
    fn bisect_by<P: Fn(&T) -> bool>(&self, before: P) -> usize {
        let (mut lo, mut hi) = (0usize, self.data.len());
        while lo < hi {
            let mid = lo + (hi - lo) / 2;
            if before(&self.data[mid]) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        lo
    }
}
pub fn est_omega(q1: &[f64], q2: &[f64]) -> Result<f64> {
if q2.contains(&0.0) || q2.contains(&1.0) {
eprintln!("No SNPs in p2 can be fixed.");
std::process::exit(1);
};
let w = mean(
&q1.iter()
.zip(q2)
.map(|(p, q)| ((p - q).powi(2)) / (q * (1f64 - q)))
.collect::<Vec<f64>>(),
);
Ok(w)
}
/// Per-SNP variance used by the likelihood: ω times `q2 * (1 - q2)`.
///
/// Always returns `Ok`.
pub fn var_estimate(w: f64, q2: f64) -> Result<f64> {
    let binomial_var = q2 * (1.0_f64 - q2);
    Ok(w * binomial_var)
}
/// Round `x` to `digits` decimal places (halves round away from zero, as `f64::round`).
fn round_to(x: f64, digits: u32) -> f64 {
    let scale = 10_f64.powi(digits as i32);
    let scaled = x * scale;
    scaled.round() / scale
}
/// Compute `c = 1 - exp(-ln(2 * ne) * max(r, minrd) / s)`, rounded to `sf` decimals.
///
/// Returns `1.0` when `s <= 0`. Defaults: `ne = 20000`, `minrd = 1e-7`, `sf = 5`.
/// Always returns `Ok`.
pub fn compute_c(
    r: f64,
    s: f64,
    ne: Option<u64>,
    minrd: Option<f64>,
    sf: Option<u32>,
) -> Result<f64> {
    if s <= 0.0 {
        return Ok(1.0);
    }
    let pop_size = ne.unwrap_or(20000) as f64;
    let min_dist = minrd.unwrap_or(1e-7);
    let decimals = sf.unwrap_or(5);
    // Distances below `min_dist` are clamped so c never collapses to exactly zero via r = 0.
    let exponent = -((2.0 * pop_size).ln()) * (r.max(min_dist)) / s;
    Ok(round_to(1.0 - exponent.exp(), decimals))
}
/// Evaluate the piecewise density at a single population-1 frequency `p1`.
///
/// With `a = 1 / sqrt(2 * pi * var)` the density is the sum of:
/// - a lower-tail term, when `p1 < c`:
///   `a * (c - p1) / c^2 * exp(-(p1 - c*p2)^2 / (2 c^2 var))`
/// - an upper-tail term, when `p1 > 1 - c`:
///   `a * (p1 + c - 1) / c^2 * exp(-(p1 + c - 1 - c*p2)^2 / (2 c^2 var))`
///
/// Both terms apply when `c > 0.5` and the regions overlap. The density is 0 between the
/// tails, at the boundaries themselves, and when `p1` or `c` is NaN.
///
/// This function is the quadrature integrand, so it works on scalars directly. There are
/// no per-call heap allocations.
fn pdf_scalar(p1: f64, c: f64, p2: f64, var: f64) -> f64 {
    let a_term: f64 = (2f64 * PI * var).sqrt().powf(-1.0);
    let c_sq = c.powf(2f64);
    let gauss_denom = 2f64 * c_sq * var;
    let mut density = 0.0_f64;
    // Lower tail (strict comparison; p1 == c contributes nothing).
    if p1 < c {
        let b_term = (c - p1) / c_sq;
        let c_term = (p1 - (c * p2)).powf(2f64) / gauss_denom;
        density += a_term * b_term * (-c_term).exp();
    }
    // Upper tail (strict comparison; p1 == 1 - c contributes nothing).
    if p1 > 1f64 - c {
        let b_term = (p1 + c - 1.0_f64) / c_sq;
        let c_term = (p1 + c - 1.0_f64 - (c * p2)).powf(2f64) / gauss_denom;
        density += a_term * b_term * (-c_term).exp();
    }
    density
}
/// Integrand of the likelihood numerator: `pdf_scalar(p1)` times the binomial
/// probability of observing `xj` successes out of `nj` trials with success rate `p1`.
///
/// # Panics
/// Panics if `Binomial::new(p1, nj)` rejects `p1`. The integration bounds used by the
/// caller keep `p1` inside `(0, 1)`.
fn pdf_integral_scalar(p1: f64, xj: u64, nj: u64, c: f64, p2: f64, var: f64) -> f64 {
    let dens = pdf_scalar(p1, c, p2, var);
    // The density is exactly zero between its two tails, which covers most of the
    // integration range for small c. The pmf lies in [0, 1], so the product is 0 there;
    // skip building the distribution.
    if dens == 0.0 {
        return 0.0;
    }
    let binom = Binomial::new(p1, nj).expect("Cannot generate binomial distr.");
    // Pass the count straight through. The previous `as i64` round trip could wrap and
    // then panic on `unwrap` for very large counts.
    let pmf = binom.ln_pmf(xj).exp();
    dens * pmf
}
/// Integrate `f` over `[a, b]` and return `(value, estimated absolute error)`.
///
/// - `fast == Some(true)`: a single 21-point Gauss–Kronrod rule; `epsabs`, `epsrel` and
///   `limit` are ignored.
/// - Otherwise: scirs2 `quad` with `abs_tol = epsabs`, `rel_tol = epsrel`,
///   `max_evals = limit` and `use_simpson` enabled. NOTE(review): `limit` is forwarded as
///   an evaluation cap, not a QAGS subdivision limit — confirm against the scirs2 docs.
///
/// # Panics
/// Panics if `quad` returns an error.
fn _integrate_qags_gk_scirs2<F>(
    f: F,
    a: f64,
    b: f64,
    epsabs: f64,
    epsrel: f64,
    limit: usize,
    fast: Option<bool>,
) -> (f64, f64)
where
    F: Fn(f64) -> f64,
{
    match fast {
        Some(true) => {
            let (integral, err_estimate, _est) = gauss_kronrod21(|x: f64| f(x), a, b);
            (integral, err_estimate)
        }
        _ => {
            let opts = QuadOptions {
                use_simpson: true,
                abs_tol: epsabs,
                rel_tol: epsrel,
                max_evals: limit,
                ..Default::default()
            };
            let outcome = quad(|x: f64| f(x), a, b, Some(opts)).expect("Operation failed");
            (outcome.value, outcome.abs_error)
        }
    }
}
/// Log-likelihood ratio for one SNP: `ln(∫ pdf·binom dp1) - ln(∫ pdf dp1)` over
/// `p1 ∈ [0.001, 0.999]`.
///
/// Returns `-1800.0` when either integral is not strictly positive (including NaN),
/// because the logarithm would then be undefined. Always returns `Ok`.
fn _compute_chen_likelihood(
    xj: u64,
    nj: u64,
    c: f64,
    p2: f64,
    var: f64,
    fast: Option<bool>,
) -> Result<f64> {
    const LOWER: f64 = 0.001;
    const UPPER: f64 = 0.999;
    const EPS_ABS: f64 = 0.0;
    const EPS_REL: f64 = 0.001;
    const MAX_EVALS: usize = 50;

    // Numerator: density weighted by the binomial probability of the observed count.
    let (numerator, _num_err) = _integrate_qags_gk_scirs2(
        |p1| pdf_integral_scalar(p1, xj, nj, c, p2, var),
        LOWER,
        UPPER,
        EPS_ABS,
        EPS_REL,
        MAX_EVALS,
        fast,
    );
    // Denominator: the density alone over the same interval.
    let (denominator, _den_err) = _integrate_qags_gk_scirs2(
        |p1| pdf_scalar(p1, c, p2, var),
        LOWER,
        UPPER,
        EPS_ABS,
        EPS_REL,
        MAX_EVALS,
        fast,
    );
    // Written as a conjunction so NaN integrals also fall through to the floor value.
    let both_positive = numerator > 0.0 && denominator > 0.0;
    Ok(if both_positive {
        numerator.ln() - denominator.ln()
    } else {
        -1800.0
    })
}
/// Negative weighted composite log-likelihood of selection coefficient `sc` for a window.
///
/// For each SNP, the steps are:
/// 1. Derive the variance from its ω and pop-2 frequency.
/// 2. Derive `c` from its distance and `sc`.
/// 3. Weight the per-SNP likelihood ratio.
///
/// The weighted terms are summed and the sum is negated. Coefficients outside `[0, 1)`
/// (or NaN) yield `+inf`, so they can never look better (lower) than a valid one.
///
/// # Panics
/// Panics if any per-SNP helper returns an error.
fn compute_complikelihood(
    sc: f64,
    xs: &[u64],
    ns: &[u64],
    (rds, p2freqs, weights, omegas): (&[f64], &[f64], &[f64], &[f64]),
    fast: Option<bool>,
) -> Result<f64> {
    if !(0.0..1.0).contains(&sc) {
        return Ok(f64::INFINITY);
    }
    let weighted_total: f64 = itertools::izip!(xs, ns, rds, p2freqs, weights, omegas)
        .map(|(&xj, &nj, &r, &p2, &weight, &omega)| {
            let var = var_estimate(omega, p2).expect("Cannot compute variance");
            let c = compute_c(r, sc, None, None, Some(5_u32)).expect("Cannot compute C");
            let snp_ratio = _compute_chen_likelihood(xj, nj, c, p2, var, fast)
                .expect("Cannot compute the likelihood");
            snp_ratio * weight
        })
        .sum();
    Ok(-weighted_total)
}
/// Scan `sel_coeffs` in order, keeping the coefficient with the lowest negative composite
/// log-likelihood, and stop at the first coefficient that does not improve on it.
///
/// Returns `(best log-likelihood, log-likelihood at sel_coeffs[0], best coefficient)`.
/// If the first value is not below `+inf` (e.g. NaN), the scan stops immediately and the
/// first component is `-inf`.
///
/// # Panics
/// Panics if `compute_complikelihood` returns an error.
fn compute_xpclr(
    counts: (&[u64], &[u64]),
    rds: &[f64],
    p2freqs: &[f64],
    weights: &[f64],
    omegas: &[f64],
    sel_coeffs: &[f64],
    fast: Option<bool>,
) -> Result<(f64, f64, f64)> {
    let (xs, ns) = counts;
    let snp_data = (rds, p2freqs, weights, omegas);
    let mut maximum_li = f64::INFINITY;
    let mut maxli_sc = 0.0f64;
    // Value at the first coefficient (callers pass 0.0 first: the null model).
    let mut null_model_li = f64::INFINITY;
    for (counter, &sc) in sel_coeffs.iter().enumerate() {
        let ll = compute_complikelihood(sc, xs, ns, snp_data, fast)
            .expect("Cannot infer composite likelihood");
        if counter == 0 {
            null_model_li = ll;
        }
        let improved = ll < maximum_li;
        log::debug!("compute_xpclr_iter {counter} {sc} {ll} {}", improved);
        if !improved {
            break;
        }
        maximum_li = ll;
        maxli_sc = sc;
    }
    log::debug!("compute_xpclr_final {maximum_li} {null_model_li} {maxli_sc}\n\n");
    Ok((-maximum_li, -null_model_li, maxli_sc))
}
/// Select the indices of SNPs whose positions lie in `[start, stop)`.
///
/// `pos` must be sorted ascending. When more than `max_pos_size` SNPs are available, a
/// random subset of that size is drawn without replacement and returned sorted.
///
/// Returns `(selected indices, number of SNPs available in the window)`. An inverted
/// window (`stop < start`) yields no SNPs.
fn get_window(
    pos: &[usize],
    start: usize,
    stop: usize,
    max_pos_size: usize,
) -> Result<(Vec<usize>, usize)> {
    // `partition_point` gives the leftmost insertion point (searchsorted/bisect_left
    // semantics). `binary_search` may return any of several equal positions, which would
    // mis-bound windows whose edges fall on duplicated positions.
    let start_ix = pos.partition_point(|&p| p < start);
    let stop_ix = pos.partition_point(|&p| p < stop);
    // Saturate rather than underflow when stop < start; the range below is then empty.
    let n_avail = stop_ix.saturating_sub(start_ix);
    if n_avail > max_pos_size {
        let mut ix: Vec<usize> = (start_ix..stop_ix)
            .collect::<Vec<usize>>()
            .choose_multiple(&mut rand::rng(), max_pos_size)
            .copied()
            .collect();
        ix.sort_unstable();
        Ok((ix, n_avail))
    } else {
        Ok(((start_ix..stop_ix).collect(), n_avail))
    }
}
/// Compute per-SNP allele counts and frequencies for two populations.
///
/// Rows of `gt1_m` and `gt2_m` are paired (extra rows in the longer matrix are ignored).
/// Negative entries are treated as missing and skipped.
/// - Population 1: total = 2 × called entries; alt = sum of called values.
/// - Population 2: the total is the called count when `phased` is `Some(true)`, and
///   2 × called otherwise.
///
/// Frequencies are NaN for rows with no called entries. Always returns `Ok`.
fn pair_gt_to_af(
    gt1_m: &[Vec<i8>],
    gt2_m: &[Vec<i8>],
    phased: Option<bool>,
) -> Result<AlleleFreqs> {
    // Count the called (non-negative) entries of a row and sum their values.
    fn called_and_alt(row: &[i8]) -> (u64, u64) {
        row.iter()
            .filter(|&&g| g >= 0)
            .fold((0_u64, 0_u64), |(called, alt), &g| (called + 1, alt + g as u64))
    }

    let is_phased = phased.unwrap_or(false);
    let vals: Vec<(u64, u64, f64, f64)> = gt1_m
        .iter()
        .zip(gt2_m)
        .map(|(gts1, gts2)| {
            let (called1, alt1) = called_and_alt(gts1);
            let (called2, alt2) = called_and_alt(gts2);
            // NOTE(review): population 1 always assumes two alleles per entry, while
            // population 2 honours `phased` — confirm gt1 is never haplotype-coded.
            let total1 = 2 * called1;
            let total2 = if is_phased { called2 } else { 2 * called2 };
            (
                total1,
                alt1,
                alt1 as f64 / total1 as f64,
                alt2 as f64 / total2 as f64,
            )
        })
        .collect();
    let (total_counts1, alt_counts1, alt_freqs1, alt_freqs2): AlleleDataTuple =
        Itertools::multiunzip(vals.into_iter());
    Ok(AlleleFreqs {
        total_counts1,
        alt_counts1,
        alt_freqs1,
        alt_freqs2,
    })
}
/// Symmetric `n × n` matrix of squared pairwise correlations (r²) between rows of `gn`.
///
/// The diagonal is left at 0.0. Entries are NaN when `gn_corrcoef_int8` returns NaN
/// (no jointly called entries, or a zero-variance row). Empty input yields an empty
/// matrix. Always returns `Ok`.
fn gn_pairwise_corrcoef_int8(gn: &[Vec<i8>]) -> Result<Vec<Vec<f64>>> {
    let n = gn.len();
    // Squares are precomputed once per row. `wrapping_mul` cannot panic on overflow
    // (e.g. a -128 code). Squares of negative entries are never read, because
    // `gn_corrcoef_int8` skips negative values.
    let gn_sq: Vec<Vec<i8>> = gn
        .iter()
        .map(|row| row.iter().map(|&v| v.wrapping_mul(v)).collect())
        .collect();
    let mut out = vec![vec![0.0_f64; n]; n];
    // `0..n` rather than `0..n - 1`: the inner range is empty for the last row, and this
    // cannot underflow when `gn` is empty.
    for i in 0..n {
        for j in (i + 1)..n {
            let r2 = gn_corrcoef_int8(&gn[i], &gn[j], &gn_sq[i], &gn_sq[j]).powi(2);
            out[i][j] = r2;
            out[j][i] = r2;
        }
    }
    Ok(out)
}
/// Pearson correlation between `a` and `b`, using only positions where both are
/// non-negative. Negative values are treated as missing.
///
/// `a_sq` and `b_sq` must hold the element-wise squares of `a` and `b`.
///
/// Returns NaN when no position is jointly called, or when the summed squares of either
/// input over the used positions are zero.
///
/// # Panics
/// Panics if `b`, `a_sq` or `b_sq` is shorter than `a`.
fn gn_corrcoef_int8(a: &[i8], b: &[i8], a_sq: &[i8], b_sq: &[i8]) -> f64 {
    let len = a.len();
    // Slicing to `a.len()` keeps the length contract explicit (panic if any input is
    // shorter) and lets the loop run without per-element bounds checks.
    let (b, a_sq, b_sq) = (&b[..len], &a_sq[..len], &b_sq[..len]);
    let mut m0: f64 = 0.0;
    let mut m1: f64 = 0.0;
    let mut v0: f64 = 0.0;
    let mut v1: f64 = 0.0;
    let mut cov: f64 = 0.0;
    let mut n: f64 = 0.0;
    for (((&x, &y), &x_sq), &y_sq) in a.iter().zip(b).zip(a_sq).zip(b_sq) {
        if x >= 0 && y >= 0 {
            n += 1.0f64;
            m0 += f64::from(x);
            m1 += f64::from(y);
            v0 += f64::from(x_sq);
            v1 += f64::from(y_sq);
            // Multiply in f64: an i8 product overflows (panicking in debug builds)
            // once x * y exceeds 127.
            cov += f64::from(x) * f64::from(y);
        }
    }
    if n == 0.0 || v0 == 0.0 || v1 == 0.0 {
        return f64::NAN;
    }
    m0 /= n;
    m1 /= n;
    v0 /= n;
    v1 /= n;
    cov /= n;
    cov -= m0 * m1;
    v0 -= m0 * m0;
    v1 -= m1 * m1;
    cov / (v0 * v1).sqrt()
}
/// Element-wise mask of `matrix`: `true` where the value is NaN or strictly above `cutoff`.
fn apply_cutoff(matrix: &[Vec<f64>], cutoff: f64) -> Vec<Vec<bool>> {
    let mut mask = Vec::with_capacity(matrix.len());
    for row in matrix {
        let flags: Vec<bool> = row
            .iter()
            .map(|&val| val.is_nan() || val > cutoff)
            .collect();
        mask.push(flags);
    }
    mask
}
/// Per-SNP weight `1 / (1 + k)`, where `k` counts the other SNPs whose r² with it exceeds
/// `ldcutoff` (NaN r² also counts).
///
/// The LD matrix has a 0.0 diagonal, so for a non-negative cutoff a SNP never counts
/// itself. Always returns `Ok`.
///
/// # Panics
/// Panics if the LD matrix cannot be computed.
fn compute_weights(gt_m: Vec<&Vec<i8>>, ldcutoff: f64) -> Result<Vec<f64>> {
    let owned_rows: Vec<Vec<i8>> = gt_m.into_iter().cloned().collect();
    let r2 = gn_pairwise_corrcoef_int8(&owned_rows).expect("Cannot compute LD");
    let mask = apply_cutoff(&r2, ldcutoff);
    let mut weights = Vec::with_capacity(mask.len());
    for row in &mask {
        let linked = row.iter().filter(|&&flag| flag).count();
        weights.push(1_f64 / ((linked + 1) as f64));
    }
    Ok(weights)
}
/// Outcome of the XP-CLR computation for one window.
#[derive(Debug, Clone, PartialEq)]
pub struct XPCLRResult {
    /// `(start, stop, bpi, bpe, n_snps, n_avail)`:
    /// - `start`, `stop`: the requested interval.
    /// - `bpi`, `bpe`: positions (+1) of the first and last selected SNPs. For skipped
    ///   windows these are `start`/`stop` again.
    /// - `n_snps`, `n_avail`: the number of SNPs used and the number available.
    pub window: (usize, usize, usize, usize, usize, usize),
    /// Maximum composite log-likelihood over the scanned selection coefficients.
    pub ll_sel: f64,
    /// Composite log-likelihood at the first selection coefficient (the null model).
    pub ll_neut: f64,
    /// Selection coefficient at which `ll_sel` was reached.
    pub sel_coeff: f64,
    /// `2 * (ll_sel - ll_neut)`. All numeric fields are NaN for skipped windows.
    pub xpclr: f64,
}
/// Run the XP-CLR scan over `windows`, returning one `(window index, result)` pair per
/// window, sorted by index.
///
/// Parameters:
/// - `g_data`: genotypes for both populations, SNP positions and `gdistances`.
/// - `windows`: `(start, stop)` base-pair intervals passed to `get_window`.
/// - `ldcutoff`: r² threshold passed to `compute_weights` (default 0.95).
/// - `maxsnps`: cap on SNPs per window; larger windows are randomly subsampled.
/// - `minsnps`: windows with fewer selected SNPs (or none at all) get an all-NaN result.
/// - `phased`: forwarded to `pair_gt_to_af`.
/// - `fast`: forwarded to the likelihood integration.
///
/// # Panics
/// Panics (via `expect`) if a helper fails, e.g. when ω cannot be estimated.
pub fn xpclr(
    g_data: GenoData,
    windows: Vec<(usize, usize)>,
    ldcutoff: Option<f64>,
    maxsnps: usize,
    minsnps: usize,
    phased: Option<bool>,
    fast: Option<bool>,
) -> Result<Vec<(usize, XPCLRResult)>> {
    // Selection coefficients tried in increasing order. compute_xpclr stops at the first
    // one that does not improve the likelihood; the leading 0.0 gives the null model.
    let sel_coeffs = vec![
        0.0, 0.00001, 0.00005, 0.0001, 0.0002, 0.0004, 0.0006, 0.0008, 0.001, 0.003, 0.005, 0.01,
        0.05, 0.08, 0.1, 0.15,
    ];
    let ldcutoff = ldcutoff.unwrap_or(0.95f64);
    let af_data: AlleleFreqs = pair_gt_to_af(&g_data.gt1, &g_data.gt2, phased)
        .expect("Failed to compute the AF for pop 1");
    // ω is estimated once from all SNPs and shared by every window.
    let w = est_omega(&af_data.alt_freqs1, &af_data.alt_freqs2).expect("Cannot compute omega");
    log::info!("Omega: {w}");
    let mut results: Vec<(usize, XPCLRResult)> = windows
        .par_iter()
        .enumerate()
        .map(|(n, (start, stop))| {
            let (ix, n_avail) = get_window(&g_data.positions, *start, *stop, maxsnps)
                .expect("Cannot find the window");
            let n_snps = ix.len();
            log::debug!("xpclr Window idx: {n}; Window BP interval: {start}-{stop}; N SNPs selected: {n_snps}; N SNP available: {n_avail}");
            // An empty selection cannot be scored even when `minsnps == 0`; the branch
            // below indexes `ix[0]` and would otherwise panic.
            if n_snps < minsnps || ix.is_empty() {
                let xpclr_win_res = XPCLRResult {
                    window: (*start, *stop, *start, *stop, n_snps, n_avail),
                    ll_sel: f64::NAN,
                    ll_neut: f64::NAN,
                    sel_coeff: f64::NAN,
                    xpclr: f64::NAN,
                };
                (n, xpclr_win_res)
            } else {
                let bpi = g_data.positions[ix[0]] + 1;
                let bpe = g_data.positions[ix[n_snps - 1]] + 1;
                let (gt_range, gd_range, a1_range, t1_range, p2freqs): RangeTuple =
                    Itertools::multiunzip(ix.iter().map(|&i| {
                        (
                            &g_data.gt2[i],
                            &g_data.gdistances[i],
                            &af_data.alt_counts1[i],
                            &af_data.total_counts1[i],
                            &af_data.alt_freqs2[i],
                        )
                    }));
                // Distances are taken relative to the window's mean `gdistances` value.
                let mdist = mean(&gd_range);
                let rds = gd_range.iter().map(|d| (d - mdist).abs()).collect::<Vec<f64>>();
                let weights =
                    compute_weights(gt_range, ldcutoff).expect("Failed to compute the weights");
                let omegas = vec![w; rds.len()];
                log::debug!("P2freqs {start} {stop} {} {p2freqs:?}", p2freqs.len());
                let xpclr_res = compute_xpclr(
                    (&a1_range, &t1_range),
                    &rds,
                    &p2freqs,
                    &weights,
                    &omegas,
                    &sel_coeffs,
                    fast,
                )
                .expect("Failed computing XP-CLR for window");
                let xpclr_v = 2.0_f64 * (xpclr_res.0 - xpclr_res.1);
                let xpclr_win_res = XPCLRResult {
                    window: (*start, *stop, bpi, bpe, n_snps, n_avail),
                    ll_sel: xpclr_res.0,
                    ll_neut: xpclr_res.1,
                    sel_coeff: xpclr_res.2,
                    xpclr: xpclr_v,
                };
                (n, xpclr_win_res)
            }
        })
        .collect();
    // Explicitly order by window index.
    results.sort_by_key(|item| item.0);
    Ok(results)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `a` and `b` differ by at most `tol` (absolute tolerance).
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() <= tol
    }

    #[test]
    fn bisector_basic_indices() {
        let sorted = vec![1, 2, 2, 3, 5];
        let bis = Bisector::new(&sorted);
        // (probe, expected insertion point)
        for (probe, expected) in [(2, 1), (4, 4)] {
            assert_eq!(bis.bisect_left(&probe), expected);
        }
        for (probe, expected) in [(2, 3), (0, 0)] {
            assert_eq!(bis.bisect_right(&probe), expected);
        }
    }

    #[test]
    fn partial_bisector_with_floats() {
        let sorted = vec![0.1_f64, 0.2, 0.2, 0.5];
        let bis = PartialBisector::new(&sorted);
        // (probe, expected bisect_left, expected bisect_right)
        for (probe, left, right) in [(0.2_f64, 1, 3), (0.3_f64, 3, 3)] {
            assert_eq!(bis.bisect_left(&probe), left);
            assert_eq!(bis.bisect_right(&probe), right);
        }
    }

    #[test]
    fn omega_estimation_matches_formula() {
        let q1 = vec![0.2_f64, 0.3];
        let q2 = vec![0.4_f64, 0.5];
        let per_snp: Vec<f64> = q1
            .iter()
            .zip(&q2)
            .map(|(p, q)| ((p - q).powi(2)) / (q * (1.0_f64 - q)))
            .collect();
        let expected = mean(&per_snp);
        let w = est_omega(&q1, &q2).expect("omega");
        assert!(approx_eq(w, expected, 1e-12));
    }

    #[test]
    fn variance_estimate_simple() {
        let v = var_estimate(2.0_f64, 0.25_f64).expect("variance");
        assert!(approx_eq(v, 0.375_f64, 1e-12));
    }

    #[test]
    fn compute_c_bounds_and_rounding() {
        let neutral = compute_c(0.01, 0.0, Some(20000), Some(1e-7), Some(5)).expect("compute_c");
        assert!(approx_eq(neutral, 1.0_f64, 1e-12));

        let c = compute_c(0.01, 0.1, Some(20000), Some(1e-7), Some(5)).expect("compute_c");
        assert!((0.0_f64..=1.0_f64).contains(&c));

        let exponent = -((2.0_f64 * 20000.0_f64).ln()) * (0.01_f64.max(1e-7)) / 0.1;
        let expected = round_to(1.0 - exponent.exp(), 5);
        assert!(approx_eq(c, expected, 1e-12));
    }

    #[test]
    fn pdf_scalar_interval_behavior() {
        let (c, p2, var) = (0.1_f64, 0.4_f64, 0.02_f64);
        let below_c = pdf_scalar(0.05, c, p2, var);
        let between = pdf_scalar(0.5, c, p2, var);
        let above_one_minus_c = pdf_scalar(0.95, c, p2, var);
        assert!(below_c > 0.0_f64);
        assert!(above_one_minus_c > 0.0_f64);
        assert!(approx_eq(between, 0.0_f64, 1e-12));
    }
}