/// Maps linear-predictor draws `eta` onto the response scale with the inverse
/// of a standard link selected by a string tag.
///
/// The tag is trimmed and ASCII-lowercased before matching; an empty tag is
/// treated as `identity`. Recognized tags are `identity`, `logit`, `probit`,
/// `cloglog` and `log`. Each transform is evaluated directly with no clamping,
/// so `log` yields exactly `exp(eta)` (overflowing to `+inf` / underflowing to
/// `0.0` at the f64 limits).
///
/// # Errors
///
/// Returns `Err` with a descriptive message for any other tag. This includes
/// the parameterized links (`sas`, `mixture`, `latent-cloglog`,
/// `beta-logistic`), which must go through [`apply_inverse_link_spec_vec`]
/// with a typed link instead.
pub fn apply_inverse_link_vec(eta: &[f64], family_kind: &str) -> Result<Vec<f64>, String> {
    // Logistic function that only ever exponentiates a non-positive argument,
    // so neither branch can overflow.
    fn logistic(x: f64) -> f64 {
        if x >= 0.0 {
            1.0 / (1.0 + (-x).exp())
        } else {
            let ex = x.exp();
            ex / (1.0 + ex)
        }
    }

    // Standard normal CDF expressed through erfc, which keeps precision in the
    // lower tail. Kept as `1.0 / SQRT_2` (not `FRAC_1_SQRT_2`) on purpose: the
    // two constants differ in the last bit.
    fn normal_cdf(x: f64) -> f64 {
        const INV_SQRT2: f64 = 1.0 / std::f64::consts::SQRT_2;
        0.5 * statrs::function::erf::erfc(-x * INV_SQRT2)
    }

    // 1 - exp(-exp(x)) written as -expm1(-exp(x)) so tiny exp(x) values are
    // not lost to cancellation.
    fn complementary_log_log(x: f64) -> f64 {
        -(-x.exp()).exp_m1()
    }

    let tag = family_kind.trim().to_ascii_lowercase();
    let inverse: fn(f64) -> f64 = match tag.as_str() {
        "" | "identity" => return Ok(eta.to_vec()),
        "logit" => logistic,
        "probit" => normal_cdf,
        "cloglog" => complementary_log_log,
        "log" => f64::exp,
        other => {
            return Err(format!(
                "posterior fitted-mean draws on response scale are not wired for \
                 family_kind={other:?} from the bare string tag; the parameterized \
                 links (sas, mixture, latent-cloglog, beta-logistic) carry per-fit \
                 state and must be routed through the serialized `link_spec` (see \
                 `apply_inverse_link_spec_vec`). access posterior.predict_draws(...).eta \
                 for link-scale draws or use model.predict(new_data, interval=...) \
                 for class-specific bands."
            ));
        }
    };
    Ok(eta.iter().map(|&x| inverse(x)).collect())
}
/// Maps linear-predictor draws `eta` onto the response scale with a typed
/// inverse link.
///
/// `InverseLink::Standard` variants are delegated to
/// [`apply_inverse_link_vec`] through their string tag, so both entry points
/// return identical values (including the exact, unclamped `exp` for `Log`).
/// Every other variant is evaluated element-wise with the solver's
/// `inverse_link_mu_d1_for_inverse_link`. Only the mean is kept; the
/// derivative is discarded.
///
/// # Errors
///
/// Stops at the first element whose solver evaluation fails and returns its
/// error, annotated with the offending `eta` value.
pub fn apply_inverse_link_spec_vec(
    eta: &[f64],
    link: &crate::types::InverseLink,
) -> Result<Vec<f64>, String> {
    use crate::solver::mixture_link::inverse_link_mu_d1_for_inverse_link;
    use crate::types::{InverseLink, StandardLink};

    match link {
        InverseLink::Standard(standard) => {
            let tag = match standard {
                StandardLink::Identity => "identity",
                StandardLink::Log => "log",
                StandardLink::Logit => "logit",
                StandardLink::Probit => "probit",
                StandardLink::CLogLog => "cloglog",
            };
            apply_inverse_link_vec(eta, tag)
        }
        parameterized => eta
            .iter()
            .map(|&e| {
                inverse_link_mu_d1_for_inverse_link(parameterized, e)
                    .map(|(mu, _d1)| mu)
                    .map_err(|err| {
                        format!("failed to evaluate parameterized inverse link at eta={e}: {err}")
                    })
            })
            .collect(),
    }
}
#[cfg(test)]
mod tests {
    //! Regression tests for the response-scale inverse-link helpers.
    //!
    //! Covered behaviors:
    //! - exact (unclamped) standard transforms on the string-tag path;
    //! - tag normalization;
    //! - refusal of parameterized or unknown tags;
    //! - agreement of the typed `link_spec` path with the solver jet.
    use super::{apply_inverse_link_spec_vec, apply_inverse_link_vec};
    use crate::solver::mixture_link::inverse_link_mu_d1_for_inverse_link;
    use crate::types::{
        InverseLink, LatentCLogLogState, LinkComponent, MixtureLinkSpec, StandardLink,
    };

    /// The public `log` inverse is plain `exp`. It reports values beyond the
    /// solver's +/-700 clamp exactly, and follows IEEE overflow/underflow at
    /// the f64 limits.
    #[test]
    fn public_log_inverse_link_is_exact_exp_not_solver_clamp() {
        let out = apply_inverse_link_vec(&[705.0], "log").expect("log inverse link");
        assert_eq!(out.len(), 1);
        let exact = 705.0_f64.exp();
        assert!(exact.is_finite(), "exp(705) must be representable in f64");
        assert_eq!(
            out[0], exact,
            "public log inverse link must be exact exp(705), not the solver clamp"
        );
        let clamped = 700.0_f64.exp();
        assert!(
            out[0] > clamped * 100.0,
            "exact exp(705) must exceed the clamped exp(700) by ~exp(5); got {} vs {}",
            out[0],
            clamped
        );
        let out = apply_inverse_link_vec(&[-720.0], "log").expect("log inverse link");
        let exact = (-720.0_f64).exp();
        assert_eq!(
            out[0], exact,
            "public log inverse link must be exact exp(-720), not the solver clamp"
        );
        let clamped = (-700.0_f64).exp();
        assert!(
            out[0] < clamped,
            "exact exp(-720) must be strictly below the clamped exp(-700); got {} vs {}",
            out[0],
            clamped
        );
        // ln(f64::MAX) ~= 709.78, so exp(710) overflows.
        let over = apply_inverse_link_vec(&[710.0], "log").expect("log inverse link");
        assert!(
            over[0].is_infinite() && over[0] > 0.0,
            "exp(710) overflows to +inf under the exact public transform"
        );
        // exp(-746) lies below half the smallest subnormal, so it rounds to 0.
        let under = apply_inverse_link_vec(&[-746.0], "log").expect("log inverse link");
        assert_eq!(
            under[0], 0.0,
            "exp(-746) underflows to exactly 0.0 under the exact public transform"
        );
    }

    /// Parameterized links cannot be evaluated from a bare tag. The refusal
    /// message must point callers at the typed `link_spec` path.
    #[test]
    fn string_tag_refuses_parameterized_links_and_points_at_link_spec() {
        for tag in ["sas", "mixture", "latent-cloglog", "beta-logistic"] {
            let err = apply_inverse_link_vec(&[0.0, 0.5], tag)
                .expect_err("parameterized link must not be evaluable from the bare tag");
            assert!(
                err.contains("link_spec"),
                "refusal for {tag:?} must mention the typed link_spec seam; got {err}"
            );
        }
    }

    /// Misspelled or unknown tags are refused rather than silently mapped to
    /// some default transform.
    #[test]
    fn string_tag_refuses_unknown_tags() {
        for tag in ["logitt", "inverse", "loglog"] {
            assert!(
                apply_inverse_link_vec(&[0.0], tag).is_err(),
                "unknown tag {tag:?} must be refused"
            );
        }
    }

    /// Tags are trimmed and ASCII-lowercased before dispatch, and a blank tag
    /// means identity.
    #[test]
    fn string_tag_is_trimmed_and_case_insensitive() {
        let eta = [-2.0_f64, -0.25, 0.0, 0.75, 3.0];
        for (messy, canonical) in [
            ("  LoGiT\t", "logit"),
            ("Probit", "probit"),
            (" CLOGLOG", "cloglog"),
            ("Log\n", "log"),
            (" Identity ", "identity"),
        ] {
            let got = apply_inverse_link_vec(&eta, messy).expect("normalized tag");
            let want = apply_inverse_link_vec(&eta, canonical).expect("canonical tag");
            assert_eq!(got, want, "tag {messy:?} must behave like {canonical:?}");
        }
        let blank = apply_inverse_link_vec(&eta, "   ").expect("blank tag is identity");
        assert_eq!(blank, eta.to_vec(), "a blank tag must act as identity");
    }

    /// The binomial standard links keep relative precision deep in the lower
    /// tail instead of collapsing to 0.0. The upper tail stays at or below 1.
    #[test]
    fn binomial_standard_links_keep_lower_tail_precision() {
        let tail = (-40.0_f64).exp();
        for tag in ["logit", "cloglog"] {
            let low = apply_inverse_link_vec(&[-40.0], tag).expect("binomial link");
            assert!(low[0] > 0.0, "{tag}(-40) must not underflow to 0");
            assert!(
                (low[0] / tail - 1.0).abs() < 1e-12,
                "{tag}(-40) must be ~exp(-40); got {}",
                low[0]
            );
            let high = apply_inverse_link_vec(&[40.0], tag).expect("binomial link");
            assert!(
                high[0] > 0.999 && high[0] <= 1.0,
                "{tag}(40) must saturate at or below 1; got {}",
                high[0]
            );
        }
        let probit = apply_inverse_link_vec(&[-10.0], "probit").expect("probit");
        assert!(
            probit[0] > 0.0 && probit[0] < 1e-20,
            "probit(-10) must be a tiny positive tail probability; got {}",
            probit[0]
        );
    }

    /// The typed SAS link must reproduce the solver's mean bit for bit. The
    /// bare `sas` tag must be refused.
    #[test]
    fn spec_path_evaluates_sas_link_bit_identical_to_solver_jet() {
        let state = crate::solver::mixture_link::sas_link_state_from_raw(0.7, -0.4)
            .expect("valid SAS link state");
        let link = InverseLink::Sas(state);
        let eta = [-2.0_f64, -0.5, 0.0, 0.5, 2.0, 4.0];
        assert!(apply_inverse_link_vec(&eta, "sas").is_err());
        let out = apply_inverse_link_spec_vec(&eta, &link).expect("sas spec inverse link");
        assert_eq!(out.len(), eta.len());
        for (i, &e) in eta.iter().enumerate() {
            let (mu, _d1) = inverse_link_mu_d1_for_inverse_link(&link, e).expect("solver jet eval");
            assert_eq!(
                out[i], mu,
                "SAS spec inverse link row {i} must equal the canonical solver mean"
            );
            assert!(
                out[i] > 0.0 && out[i] < 1.0,
                "SAS is a binomial inverse link; mu must lie in (0, 1), got {}",
                out[i]
            );
        }
    }

    /// The typed mixture link must reproduce the solver's mean bit for bit
    /// and be strictly increasing in `eta`.
    #[test]
    fn spec_path_evaluates_mixture_link_bit_identical_to_solver_jet() {
        let spec = MixtureLinkSpec {
            components: vec![LinkComponent::Logit, LinkComponent::Probit],
            initial_rho: ndarray::array![0.3],
        };
        let state = crate::solver::mixture_link::state_fromspec(&spec).expect("mixture state");
        let link = InverseLink::Mixture(state);
        let eta = [-3.0_f64, -1.0, 0.0, 1.0, 3.0];
        assert!(apply_inverse_link_vec(&eta, "mixture").is_err());
        let out = apply_inverse_link_spec_vec(&eta, &link).expect("mixture spec inverse link");
        assert_eq!(out.len(), eta.len(), "one mean per eta draw");
        for (i, &e) in eta.iter().enumerate() {
            let (mu, _d1) = inverse_link_mu_d1_for_inverse_link(&link, e).expect("solver jet eval");
            assert_eq!(out[i], mu, "mixture spec inverse link row {i} mismatch");
        }
        for w in out.windows(2) {
            assert!(
                w[1] > w[0],
                "mixture inverse link must be strictly increasing"
            );
        }
    }

    /// The latent-cloglog link must honor the fitted latent SD. Each fit must
    /// match the solver jet, and different SDs must give different means.
    #[test]
    fn spec_path_evaluates_latent_cloglog_with_fitted_latent_sd() {
        let eta = [-1.0_f64, 0.0, 1.0];
        let link_a =
            InverseLink::LatentCLogLog(LatentCLogLogState::new(0.5).expect("valid latent SD"));
        let link_b =
            InverseLink::LatentCLogLog(LatentCLogLogState::new(1.5).expect("valid latent SD"));
        assert!(apply_inverse_link_vec(&eta, "latent-cloglog").is_err());
        let out_a = apply_inverse_link_spec_vec(&eta, &link_a).expect("latent-cloglog a");
        let out_b = apply_inverse_link_spec_vec(&eta, &link_b).expect("latent-cloglog b");
        assert_eq!(out_a.len(), eta.len(), "one mean per eta draw (a)");
        assert_eq!(out_b.len(), eta.len(), "one mean per eta draw (b)");
        for (i, &e) in eta.iter().enumerate() {
            let (mu_a, _) = inverse_link_mu_d1_for_inverse_link(&link_a, e).expect("jet a");
            assert_eq!(
                out_a[i], mu_a,
                "latent-cloglog row {i} must match solver jet"
            );
            assert!(out_a[i] > 0.0 && out_a[i] < 1.0, "mu in (0,1)");
            let (mu_b, _) = inverse_link_mu_d1_for_inverse_link(&link_b, e).expect("jet b");
            assert_eq!(
                out_b[i], mu_b,
                "latent-cloglog (sd=1.5) row {i} must match solver jet"
            );
        }
        assert!(
            out_a
                .iter()
                .zip(out_b.iter())
                .any(|(a, b)| (a - b).abs() > 1e-6),
            "different latent SDs must yield different response-scale means"
        );
    }

    /// `Standard` links on the spec path must agree exactly with the string
    /// path. That includes the unclamped `exp` for `Log`.
    #[test]
    fn spec_path_matches_string_path_for_standard_links_incl_exact_log() {
        let eta = [-1.5_f64, 0.0, 0.8];
        for (link, tag) in [
            (InverseLink::Standard(StandardLink::Identity), "identity"),
            (InverseLink::Standard(StandardLink::Logit), "logit"),
            (InverseLink::Standard(StandardLink::Probit), "probit"),
            (InverseLink::Standard(StandardLink::CLogLog), "cloglog"),
            (InverseLink::Standard(StandardLink::Log), "log"),
        ] {
            let via_spec = apply_inverse_link_spec_vec(&eta, &link).expect("spec");
            let via_tag = apply_inverse_link_vec(&eta, tag).expect("tag");
            assert_eq!(via_spec, via_tag, "spec vs tag mismatch for {tag}");
        }
        let log_link = InverseLink::Standard(StandardLink::Log);
        let out = apply_inverse_link_spec_vec(&[705.0], &log_link).expect("log spec");
        assert_eq!(
            out[0],
            705.0_f64.exp(),
            "Standard(Log) spec path must report exact exp(705), not the solver clamp"
        );
    }
}