use crate::{LocalFit, LocalRegressionConfig, LocfitError, PredictionMethod};
/// A fitted dispersion trend that maps a mean to a predicted dispersion.
///
/// Values are created through [`Deseq2LocalDispersionTrend::fit`] or
/// [`Deseq2LocalDispersionTrend::fit_from_logs`]; the single field is private,
/// so the representation cannot be constructed or inspected from outside this
/// module.
///
/// NOTE(review): the name refers to DESeq2's local dispersion fit; that the
/// numerical results match DESeq2 is presumed, not established by this type.
#[derive(Clone, Debug)]
pub struct Deseq2LocalDispersionTrend {
/// Which representation backs this trend: a constant value or a local fit.
kind: TrendKind,
}
/// Internal representation of a fitted trend.
#[derive(Clone, Debug)]
enum TrendKind {
/// The same dispersion `value` is predicted for every mean. Chosen when no
/// observation passes the dispersion threshold during fitting, in which case
/// `value` is the caller's `min_disp`.
Constant { value: f64 },
/// Predictions are delegated to `fit`, which is fitted with log means as the
/// predictor and log dispersions as the response.
Local { fit: LocalFit },
}
impl Deseq2LocalDispersionTrend {
    /// Builds a trend from raw means and dispersion estimates.
    ///
    /// Observations whose dispersion is at least `min_disp * 10.0` are kept.
    /// For each kept observation the natural log of the mean is the predictor,
    /// the natural log of the dispersion is the response, and the raw mean is
    /// the weight. When no observation is kept, the result is a constant trend
    /// equal to `min_disp`.
    ///
    /// # Errors
    ///
    /// Checks run in this order:
    /// - [`LocfitError::LengthMismatch`] when `means` and `disps` differ in length;
    /// - [`LocfitError::EmptyInput`] when they are empty;
    /// - [`LocfitError::InvalidInput`] when `min_disp` is not finite and
    ///   strictly positive;
    /// - [`LocfitError::InvalidInput`] for the first index (mean checked before
    ///   dispersion) holding a value that is not finite and strictly positive,
    ///   even if that observation would have been filtered out;
    /// - any error returned by [`LocalFit::fit`].
    pub fn fit(means: &[f64], disps: &[f64], min_disp: f64) -> Result<Self, LocfitError> {
        let (mean_count, disp_count) = (means.len(), disps.len());
        if mean_count != disp_count {
            return Err(LocfitError::LengthMismatch {
                x: mean_count,
                y: disp_count,
                weights: None,
            });
        }
        if mean_count == 0 {
            return Err(LocfitError::EmptyInput);
        }
        Self::require_valid_min_disp(min_disp)?;

        let cutoff = min_disp * 10.0;
        let mut kept_log_means = Vec::new();
        let mut kept_log_disps = Vec::new();
        let mut kept_weights = Vec::new();

        let observations = means.iter().copied().zip(disps.iter().copied());
        for (index, (mean, disp)) in observations.enumerate() {
            // Every observation is validated, including ones dropped by the cutoff.
            Self::require_positive_finite("mean", index, mean)?;
            Self::require_positive_finite("dispersion", index, disp)?;

            if disp >= cutoff {
                kept_log_means.push(mean.ln());
                kept_log_disps.push(disp.ln());
                kept_weights.push(mean);
            }
        }

        Self::fit_validated_logs(&kept_log_means, &kept_log_disps, &kept_weights, min_disp)
    }

    /// Builds a trend from values that are already on the log scale.
    ///
    /// `log_means` and `log_disps` are used exactly as given; `means` supplies
    /// the weights. Nothing here checks that `log_means[i]` is actually the log
    /// of `means[i]`. Observations are kept when
    /// `log_disps[i] >= (min_disp * 10.0).ln()`; when none are kept the result
    /// is a constant trend equal to `min_disp`.
    ///
    /// # Errors
    ///
    /// Checks run in this order:
    /// - [`LocfitError::LengthMismatch`] when the three slices do not all have
    ///   the same length;
    /// - [`LocfitError::EmptyInput`] when they are empty;
    /// - [`LocfitError::InvalidInput`] when `min_disp` is not finite and
    ///   strictly positive;
    /// - [`LocfitError::InvalidInput`] for the first index whose log mean is
    ///   not finite, whose log dispersion is not finite, or whose mean is not
    ///   finite and strictly positive (checked in that order);
    /// - any error returned by [`LocalFit::fit`].
    pub fn fit_from_logs(
        log_means: &[f64],
        log_disps: &[f64],
        means: &[f64],
        min_disp: f64,
    ) -> Result<Self, LocfitError> {
        let count = log_means.len();
        let lengths_agree = count == log_disps.len() && means.len() == count;
        if !lengths_agree {
            return Err(LocfitError::LengthMismatch {
                x: count,
                y: log_disps.len(),
                weights: Some(means.len()),
            });
        }
        if count == 0 {
            return Err(LocfitError::EmptyInput);
        }
        Self::require_valid_min_disp(min_disp)?;

        // The cutoff is moved to the log scale so it can be compared with `log_disps`.
        let log_cutoff = (min_disp * 10.0).ln();
        let mut kept_log_means = Vec::new();
        let mut kept_log_disps = Vec::new();
        let mut kept_weights = Vec::new();

        let observations = log_means
            .iter()
            .copied()
            .zip(log_disps.iter().copied())
            .zip(means.iter().copied());
        for (index, ((log_mean, log_disp), mean)) in observations.enumerate() {
            Self::require_finite("log mean", index, log_mean)?;
            Self::require_finite("log dispersion", index, log_disp)?;
            Self::require_positive_finite("mean", index, mean)?;

            if log_disp >= log_cutoff {
                kept_log_means.push(log_mean);
                kept_log_disps.push(log_disp);
                kept_weights.push(mean);
            }
        }

        Self::fit_validated_logs(&kept_log_means, &kept_log_disps, &kept_weights, min_disp)
    }

    /// Predicts the dispersion at a single raw `mean`.
    ///
    /// A constant trend returns its stored value. A local trend evaluates the
    /// log-scale prediction at `mean.ln()` and exponentiates it.
    ///
    /// # Errors
    ///
    /// Returns [`LocfitError::InvalidInput`] when `mean` is not finite and
    /// strictly positive, or forwards the error from the underlying prediction.
    pub fn predict_one(&self, mean: f64) -> Result<f64, LocfitError> {
        validate_prediction_mean(mean)?;
        if let TrendKind::Constant { value } = &self.kind {
            return Ok(*value);
        }
        let log_dispersion = self.predict_log_dispersion_one(mean.ln())?;
        Ok(log_dispersion.exp())
    }

    /// Predicts the natural log of the dispersion at a single `log_mean`.
    ///
    /// A constant trend returns the log of its stored value; a local trend
    /// delegates to the wrapped fit.
    ///
    /// # Errors
    ///
    /// Returns [`LocfitError::InvalidInput`] when `log_mean` is not finite, or
    /// forwards the error from the wrapped fit.
    pub fn predict_log_dispersion_one(&self, log_mean: f64) -> Result<f64, LocfitError> {
        validate_prediction_log_mean(log_mean)?;
        match &self.kind {
            TrendKind::Local { fit } => fit.predict_one(log_mean),
            TrendKind::Constant { value } => Ok(value.ln()),
        }
    }

    /// Predicts the dispersion for every entry of `means`, in order.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by
    /// [`Self::predict_one`].
    pub fn predict(&self, means: &[f64]) -> Result<Vec<f64>, LocfitError> {
        let mut dispersions = Vec::with_capacity(means.len());
        for &mean in means {
            dispersions.push(self.predict_one(mean)?);
        }
        Ok(dispersions)
    }

    /// Predicts the log dispersion for every entry of `log_means`, in order.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by
    /// [`Self::predict_log_dispersion_one`].
    pub fn predict_log_dispersion(&self, log_means: &[f64]) -> Result<Vec<f64>, LocfitError> {
        let mut log_dispersions = Vec::with_capacity(log_means.len());
        for &log_mean in log_means {
            log_dispersions.push(self.predict_log_dispersion_one(log_mean)?);
        }
        Ok(log_dispersions)
    }

    /// Turns already-validated, already-filtered log-scale data into a trend.
    ///
    /// An empty selection yields a constant trend at `min_disp`. Otherwise a
    /// weighted [`LocalFit`] is fitted with the `LocfitHermiteApprox`
    /// prediction method and every other setting left at its default.
    fn fit_validated_logs(
        log_means: &[f64],
        log_disps: &[f64],
        weights: &[f64],
        min_disp: f64,
    ) -> Result<Self, LocfitError> {
        let kind = if log_means.is_empty() {
            TrendKind::Constant { value: min_disp }
        } else {
            let config = LocalRegressionConfig {
                prediction_method: PredictionMethod::LocfitHermiteApprox,
                ..LocalRegressionConfig::default()
            };
            let fit = LocalFit::fit(log_means, log_disps, Some(weights), config)?;
            TrendKind::Local { fit }
        };
        Ok(Self { kind })
    }

    /// Accepts `min_disp` only when it is finite and strictly positive.
    fn require_valid_min_disp(min_disp: f64) -> Result<(), LocfitError> {
        if min_disp.is_finite() && min_disp > 0.0 {
            return Ok(());
        }
        Err(LocfitError::InvalidInput(
            "min_disp must be finite and greater than zero".to_string(),
        ))
    }

    /// Accepts `value` only when it is finite and strictly positive; `label`
    /// and `index` identify the offending element in the error message.
    fn require_positive_finite(label: &str, index: usize, value: f64) -> Result<(), LocfitError> {
        if value.is_finite() && value > 0.0 {
            return Ok(());
        }
        Err(LocfitError::InvalidInput(format!(
            "{label} at index {index} must be finite and greater than zero"
        )))
    }

    /// Accepts `value` only when it is finite; `label` and `index` identify
    /// the offending element in the error message.
    fn require_finite(label: &str, index: usize, value: f64) -> Result<(), LocfitError> {
        if value.is_finite() {
            return Ok(());
        }
        Err(LocfitError::InvalidInput(format!(
            "{label} at index {index} must be finite"
        )))
    }
}
/// Free-function form of [`Deseq2LocalDispersionTrend::fit`].
///
/// Forwards `means`, `disps` and `min_disp` unchanged and returns whatever
/// that constructor returns, including its errors.
pub fn fit_deseq2_local_dispersion_trend(
means: &[f64],
disps: &[f64],
min_disp: f64,
) -> Result<Deseq2LocalDispersionTrend, LocfitError> {
Deseq2LocalDispersionTrend::fit(means, disps, min_disp)
}
/// Free-function form of [`Deseq2LocalDispersionTrend::fit_from_logs`].
///
/// Forwards `log_means`, `log_disps`, `means` and `min_disp` unchanged and
/// returns whatever that constructor returns, including its errors.
pub fn fit_deseq2_local_dispersion_trend_from_logs(
log_means: &[f64],
log_disps: &[f64],
means: &[f64],
min_disp: f64,
) -> Result<Deseq2LocalDispersionTrend, LocfitError> {
Deseq2LocalDispersionTrend::fit_from_logs(log_means, log_disps, means, min_disp)
}
/// Checks that a mean supplied for prediction is finite and strictly positive.
///
/// # Errors
///
/// Returns `LocfitError::InvalidInput` for NaN, infinities, zero and negative
/// values.
fn validate_prediction_mean(mean: f64) -> Result<(), LocfitError> {
    let usable = mean.is_finite() && mean > 0.0;
    if usable {
        return Ok(());
    }
    Err(LocfitError::InvalidInput(String::from(
        "prediction mean must be finite and greater than zero",
    )))
}
/// Checks that a log-scale mean supplied for prediction is finite.
///
/// # Errors
///
/// Returns `LocfitError::InvalidInput` for NaN and infinities.
fn validate_prediction_log_mean(log_mean: f64) -> Result<(), LocfitError> {
    if log_mean.is_finite() {
        return Ok(());
    }
    Err(LocfitError::InvalidInput(String::from(
        "prediction log mean must be finite",
    )))
}