use anyhow::Result;
#[derive(Debug, Clone)]
pub struct FitResult {
    /// Logistic-regression intercept on the log10(RPK) scale.
    pub beta0: f64,
    /// Logistic-regression slope on the log10(RPK) scale.
    pub beta1: f64,
    /// exp(beta0) — NOTE(review): presumably kept for parity with dupRadar's
    /// reported coefficients; confirm against downstream consumers.
    pub intercept: f64,
    /// exp(beta1) — see note on `intercept`.
    pub slope: f64,
}

impl FitResult {
    /// Predicted duplication rate at a given log10(RPK):
    /// `sigmoid(beta0 + beta1 * log10_rpk)`, always in (0, 1).
    pub fn predict(&self, log10_rpk: f64) -> f64 {
        let eta = self.beta0 + self.beta1 * log10_rpk;
        1.0 / (1.0 + (-eta).exp())
    }

    /// Predicted duplication rate for a raw RPK value.
    ///
    /// For `rpk <= 0` (log10 undefined) the limit as rpk -> 0+ is returned.
    pub fn predict_rpk(&self, rpk: f64) -> f64 {
        if rpk <= 0.0 {
            // log10(rpk) -> -inf as rpk -> 0+. The previous code always used
            // predict(NEG_INFINITY), but with beta1 == 0 that computes
            // 0.0 * -inf = NaN; the correct limit there is sigmoid(beta0).
            if self.beta1 == 0.0 {
                return self.predict(0.0);
            }
            return self.predict(f64::NEG_INFINITY);
        }
        self.predict(rpk.log10())
    }
}
/// Unit deviance of a binomial observation `y` against fitted mean `mu`:
/// 2 * [ y*ln(y/mu) + (1-y)*ln((1-y)/(1-mu)) ], with each term taken as 0
/// when its observed part is 0 (the y*ln(y) -> 0 limit).
///
/// Callers are expected to pass `mu` strictly inside (0, 1).
fn binomial_deviance_unit(y: f64, mu: f64) -> f64 {
    // Both halves have the same shape; compute each only when the observed
    // mass is positive (obs > 0 covers y > 0 and 1 - y > 0 respectively).
    let half = |obs: f64, fit: f64| {
        if obs > 0.0 {
            obs * (obs / fit).ln()
        } else {
            0.0
        }
    };
    2.0 * (half(y, mu) + half(1.0 - y, 1.0 - mu))
}
/// One weighted-least-squares step of IRLS for the two-parameter logistic
/// model mu = sigmoid(beta0 + beta1 * x).
///
/// Uses the standard working response z = eta + (y - mu) / v with weights
/// w = v = mu * (1 - mu), where eta = logit(mu). Points with a degenerate
/// weight (v < 1e-20) carry no information and are skipped. Returns `None`
/// when the 2x2 normal-equation matrix is (near-)singular.
fn wls_step(x: &[f64], y: &[f64], mu: &[f64]) -> Option<(f64, f64)> {
    let (mut sw, mut swx, mut swx2, mut swz, mut swxz) = (0.0, 0.0, 0.0, 0.0, 0.0);
    for i in 0..x.len() {
        let v = mu[i] * (1.0 - mu[i]);
        if v < 1e-20 {
            continue;
        }
        // v >= 1e-20 guarantees mu is strictly inside (0, 1), so logit is finite.
        let eta = (mu[i] / (1.0 - mu[i])).ln();
        let z = eta + (y[i] - mu[i]) / v;
        sw += v;
        swx += v * x[i];
        swx2 += v * x[i] * x[i];
        swz += v * z;
        swxz += v * x[i] * z;
    }
    // Solve the 2x2 weighted normal equations by Cramer's rule.
    let det = sw * swx2 - swx * swx;
    if det.abs() < 1e-30 {
        return None;
    }
    Some((
        (swx2 * swz - swx * swxz) / det,
        (sw * swxz - swx * swz) / det,
    ))
}

/// Total binomial deviance of the data under coefficients (beta0, beta1),
/// with fitted means clamped away from {0, 1} to keep the logs finite.
fn model_deviance(x: &[f64], y: &[f64], beta0: f64, beta1: f64) -> f64 {
    x.iter()
        .zip(y.iter())
        .map(|(&xi, &yi)| {
            let mu = (1.0 / (1.0 + (-(beta0 + beta1 * xi)).exp())).clamp(1e-10, 1.0 - 1e-10);
            binomial_deviance_unit(yi, mu)
        })
        .sum()
}

/// Fit a logistic regression of duplication rate on log10(RPK) via IRLS
/// (iteratively reweighted least squares).
///
/// Points with non-positive RPK or a non-finite duplication rate are dropped;
/// rates are clamped into [1e-10, 1 - 1e-10]. Convergence uses the relative
/// deviance-change criterion |dev - dev_old| / (0.1 + |dev|) < epsilon; if it
/// is not reached within the iteration cap a warning is logged and the last
/// coefficients are returned.
///
/// # Errors
/// Fails when the input slices differ in length, fewer than 2 usable points
/// remain, or a normal-equation matrix becomes singular.
pub fn duprate_exp_fit(rpk: &[f64], dup_rate: &[f64]) -> Result<FitResult> {
    anyhow::ensure!(
        rpk.len() == dup_rate.len(),
        "rpk and dup_rate must have the same length ({} vs {})",
        rpk.len(),
        dup_rate.len()
    );

    // Filter to points where log10(RPK) is defined and the response is usable.
    let mut x_vals: Vec<f64> = Vec::new();
    let mut y_vals: Vec<f64> = Vec::new();
    for (&r, &d) in rpk.iter().zip(dup_rate.iter()) {
        if r > 0.0 && d.is_finite() {
            x_vals.push(r.log10());
            y_vals.push(d.clamp(1e-10, 1.0 - 1e-10));
        }
    }
    anyhow::ensure!(x_vals.len() >= 2, "Need at least 2 valid data points for fitting");

    const MAX_ITER: usize = 25;
    const EPSILON: f64 = 1e-8;

    // Starting means in R-glm style: mu = (y + 0.5) / 2.
    let mu_start: Vec<f64> = y_vals.iter().map(|&y| (y + 0.5) / 2.0).collect();
    let (mut beta0, mut beta1) = wls_step(&x_vals, &y_vals, &mu_start)
        .ok_or_else(|| anyhow::anyhow!("Singular matrix computing initial coefficients"))?;
    let mut dev = model_deviance(&x_vals, &y_vals, beta0, beta1);

    let mut converged = false;
    for _iter in 0..MAX_ITER {
        // Fitted means under the current coefficients.
        let mu: Vec<f64> = x_vals
            .iter()
            .map(|&xi| 1.0 / (1.0 + (-(beta0 + beta1 * xi)).exp()))
            .collect();
        let (b0, b1) = wls_step(&x_vals, &y_vals, &mu)
            .ok_or_else(|| anyhow::anyhow!("Singular matrix in IRLS iteration"))?;
        beta0 = b0;
        beta1 = b1;

        let dev_old = dev;
        dev = model_deviance(&x_vals, &y_vals, beta0, beta1);
        // Relative deviance-change criterion, as in R's glm.
        if (dev - dev_old).abs() / (0.1 + dev.abs()) < EPSILON {
            converged = true;
            break;
        }
    }
    if !converged {
        log::warn!(
            "IRLS logistic regression did not converge within {} iterations (epsilon: {:.0e})",
            MAX_ITER,
            EPSILON
        );
    }

    Ok(FitResult {
        beta0,
        beta1,
        intercept: beta0.exp(),
        slope: beta1.exp(),
    })
}
/// Smallest positive RPK among entries whose RPKM meets `threshold`.
///
/// Entries are paired positionally; an entry qualifies when its RPKM is at
/// least `threshold` and its RPK is strictly positive. Returns `None` when
/// nothing qualifies.
pub fn compute_rpkm_threshold_rpk(rpk: &[f64], rpkm: &[f64], threshold: f64) -> Option<f64> {
    rpk.iter()
        .zip(rpkm.iter())
        .filter(|(&r, &m)| m >= threshold && r > 0.0)
        .map(|(&r, _)| r)
        .fold(None, |best, r| match best {
            // Same comparison as a running-minimum loop: replace only on
            // a strictly smaller candidate.
            Some(current) if r < current => Some(r),
            Some(current) => Some(current),
            None => Some(r),
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    // predict is the logistic curve sigmoid(beta0 + beta1 * x):
    // sigmoid(-2) ~= 0.1192, and a very large x saturates toward 1.
    #[test]
    fn test_predict() {
        let fit = FitResult {
            beta0: -2.0,
            beta1: 1.5,
            intercept: (-2.0f64).exp(),
            slope: 1.5f64.exp(),
        };
        let p = fit.predict(0.0);
        assert!((p - 0.1192).abs() < 0.01);
        let p = fit.predict(100.0);
        assert!(p > 0.99);
    }

    // Noise-free synthetic data generated from known coefficients
    // (beta0 = -3.0, beta1 = 1.5) over log10(RPK) in [-1, 4); the IRLS fit
    // should recover both coefficients to within 0.1.
    #[test]
    fn test_fit_basic() {
        let n = 1000;
        let mut rpk = Vec::with_capacity(n);
        let mut dup_rate = Vec::with_capacity(n);
        for i in 0..n {
            let log_rpk = -1.0 + 5.0 * (i as f64 / n as f64);
            let r = 10.0f64.powf(log_rpk);
            let eta = -3.0 + 1.5 * log_rpk;
            let p = 1.0 / (1.0 + (-eta).exp());
            rpk.push(r);
            dup_rate.push(p);
        }
        let fit = duprate_exp_fit(&rpk, &dup_rate).unwrap();
        assert!(
            (fit.beta0 - (-3.0)).abs() < 0.1,
            "beta0={}, expected -3.0",
            fit.beta0
        );
        assert!(
            (fit.beta1 - 1.5).abs() < 0.1,
            "beta1={}, expected 1.5",
            fit.beta1
        );
    }

    // The NaN response (and its zero-RPK partner) must be filtered out,
    // leaving 4 valid points — enough for the fit to succeed.
    #[test]
    fn test_fit_with_nans() {
        let rpk = vec![10.0, 100.0, 1000.0, 0.0, 50.0];
        let dup_rate = vec![0.1, 0.3, 0.7, f64::NAN, 0.2];
        let fit = duprate_exp_fit(&rpk, &dup_rate);
        assert!(fit.is_ok());
    }

    // Smallest RPK whose paired RPKM >= 0.5 is 100.0 (rpkm = 0.6).
    #[test]
    fn test_rpkm_threshold() {
        let rpk = vec![10.0, 50.0, 100.0, 500.0, 1000.0];
        let rpkm = vec![0.1, 0.4, 0.6, 2.0, 5.0];
        let thresh = compute_rpkm_threshold_rpk(&rpk, &rpkm, 0.5);
        assert_eq!(thresh, Some(100.0));
    }
}