use super::fpca::fpca_tolerance_band;
use super::helpers::valid_band_params;
use super::{BandType, ExponentialFamily, ToleranceBand};
use crate::error::FdarError;
use crate::matrix::FdMatrix;
/// Maps a response-scale value onto the linear-predictor (link) scale.
///
/// * `Gaussian` — identity link.
/// * `Binomial` — logit link `ln(p / (1 - p))`; `p` is first clamped into
///   `[1e-10, 1 - 1e-10]` so values at or beyond 0 and 1 yield finite results.
/// * `Poisson` — log link `ln(mu)`; `mu` is floored at `1e-10` so zero or
///   negative values yield a finite result.
///
/// NaN inputs are propagated as NaN for every family.
fn apply_link(value: f64, family: ExponentialFamily) -> f64 {
    // Smallest distance kept from the boundary of each link's domain.
    const EPS: f64 = 1e-10;
    match family {
        ExponentialFamily::Gaussian => value,
        ExponentialFamily::Binomial => {
            // `f64::clamp` returns NaN for a NaN input, so NaN propagates here.
            let p = value.clamp(EPS, 1.0 - EPS);
            (p / (1.0 - p)).ln()
        }
        // `f64::max` returns the non-NaN operand. Without this guard a missing
        // (NaN) observation would silently become ln(1e-10) ≈ -23 instead of
        // staying NaN as it does for the other families.
        ExponentialFamily::Poisson if value.is_nan() => value,
        ExponentialFamily::Poisson => value.max(EPS).ln(),
    }
}
/// Maps a linear-predictor (link-scale) value back to the response scale,
/// undoing `apply_link`.
///
/// * `Gaussian` — identity.
/// * `Binomial` — logistic function `1 / (1 + e^{-x})`.
/// * `Poisson` — exponential `e^x`.
fn apply_inverse_link(value: f64, family: ExponentialFamily) -> f64 {
    match family {
        ExponentialFamily::Gaussian => value,
        ExponentialFamily::Binomial => (1.0 + (-value).exp()).recip(),
        ExponentialFamily::Poisson => value.exp(),
    }
}
/// Returns a new matrix of the same shape as `data`, holding `apply_link`
/// applied to every element for the given `family`.
///
/// `data` itself is not modified. Elements are visited column by column.
fn transform_data(data: &FdMatrix, family: ExponentialFamily) -> FdMatrix {
    let (rows, cols) = data.shape();
    let mut linked = FdMatrix::zeros(rows, cols);
    // Outer index is the column and inner index is the row.
    let cells = (0..cols).flat_map(|col| (0..rows).map(move |row| (row, col)));
    for (row, col) in cells {
        linked[(row, col)] = apply_link(data[(row, col)], family);
    }
    linked
}
/// Maps a tolerance band computed on the link scale back to the response
/// scale by applying `apply_inverse_link` to every `lower`, `upper` and
/// `center` value.
///
/// `half_width` is not transformed. It is recomputed on the response scale as
/// `(upper - lower) / 2` at each point. The binomial and Poisson inverse links
/// are non-linear, so the transformed `center` need not be the midpoint of
/// the transformed bounds.
fn inverse_link_band(band: &ToleranceBand, family: ExponentialFamily) -> ToleranceBand {
    let back_transform = |values: &[f64]| -> Vec<f64> {
        values
            .iter()
            .map(|&x| apply_inverse_link(x, family))
            .collect()
    };

    let lower = back_transform(&band.lower);
    let upper = back_transform(&band.upper);
    let center = back_transform(&band.center);

    let mut half_width = Vec::with_capacity(upper.len());
    for (hi, lo) in upper.iter().zip(&lower) {
        half_width.push((hi - lo) / 2.0);
    }

    ToleranceBand {
        lower,
        upper,
        center,
        half_width,
    }
}
/// Computes a simultaneous tolerance band for exponential-family functional
/// data.
///
/// The data are first mapped onto the link scale of `family`. An FPCA-based
/// band with [`BandType::Simultaneous`] is then built there with
/// `fpca_tolerance_band`. Finally, the band is mapped back to the response
/// scale with the inverse link.
///
/// # Arguments
/// * `data` — an `n x m` matrix; at least 3 rows and 1 column are required
///   (presumably one curve per row — confirm against `FdMatrix` conventions).
/// * `family` — exponential family whose link is used for the transform.
/// * `ncomp`, `nb`, `coverage`, `seed` — forwarded unchanged to
///   `fpca_tolerance_band`.
///
/// # Errors
/// * [`FdarError::InvalidDimension`] if `data` has fewer than 3 rows or no
///   columns.
/// * [`FdarError::InvalidParameter`] if `valid_band_params` rejects the
///   combination of shape, `ncomp`, `nb` and `coverage`.
/// * Any error returned by `fpca_tolerance_band`.
pub fn exponential_family_tolerance_band(
    data: &FdMatrix,
    family: ExponentialFamily,
    ncomp: usize,
    nb: usize,
    coverage: f64,
    seed: u64,
) -> Result<ToleranceBand, FdarError> {
    let (n, m) = data.shape();

    let has_enough_data = n >= 3 && m > 0;
    if !has_enough_data {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: "at least 3 rows and 1 column".to_string(),
            actual: format!("{n} x {m}"),
        });
    }

    let params_ok = valid_band_params(n, m, ncomp, nb, coverage);
    if !params_ok {
        return Err(FdarError::InvalidParameter {
            parameter: "ncomp/nb/coverage",
            message: "ncomp and nb must be >= 1, coverage must be in (0, 1)".to_string(),
        });
    }

    // Build the band on the link scale, where an additive FPCA model applies.
    let link_scale_band = fpca_tolerance_band(
        &transform_data(data, family),
        ncomp,
        nb,
        coverage,
        BandType::Simultaneous,
        seed,
    )?;

    Ok(inverse_link_band(&link_scale_band, family))
}