use super::*;
/// Comma-separated list of scalar family names that `resolve_family` splices
/// into its "unknown family" error message.
///
/// The trailing `\` on each line is a string continuation: the newline and
/// the next line's leading whitespace are not part of the resulting text.
///
/// NOTE(review): `resolve_family` also accepts the link-pinned aliases
/// `negative-binomial-log`, `beta-logit`, `beta-regression-logit` and
/// `tweedie-log`, which are not listed here — confirm whether that omission
/// is intentional.
const SCALAR_FAMILY_NAMES_HELP: &str = "auto, gaussian, binomial/bernoulli, \
binomial-logit/bernoulli-logit, binomial-probit/bernoulli-probit, \
binomial-cloglog/bernoulli-cloglog, latent-cloglog-binomial, poisson, gamma, \
beta/beta-regression, tweedie/tw, negative-binomial/negbin/nb, \
royston-parmar, transformation-normal";
/// Classifies the response column at index `y_col` of `data`.
///
/// The decision is driven by `data.column_kinds[y_col]`:
/// * no entry at that index, or a continuous tag -> `ResponseColumnKind::Numeric`;
/// * a binary tag -> `ResponseColumnKind::Binary`;
/// * a categorical tag -> `ResponseColumnKind::Categorical`, carrying a clone
///   of the levels stored in `data.schema.columns[y_col]` (or the default
///   value of the levels type when the schema has no entry at that index).
pub(crate) fn response_column_kind(data: &Dataset, y_col: usize) -> ResponseColumnKind {
    // An out-of-range response index is treated like a plain numeric column.
    let Some(tag) = data.column_kinds.get(y_col) else {
        return ResponseColumnKind::Numeric;
    };
    match tag {
        ColumnKindTag::Continuous => ResponseColumnKind::Numeric,
        ColumnKindTag::Binary => ResponseColumnKind::Binary,
        ColumnKindTag::Categorical => {
            let levels = match data.schema.columns.get(y_col) {
                Some(schema_column) => schema_column.levels.clone(),
                None => Default::default(),
            };
            ResponseColumnKind::Categorical { levels }
        }
    }
}
/// Reports whether `link` may be combined with the response family `response`.
///
/// Each family maps to a fixed allow-list of link functions; the link is
/// legal exactly when it appears in that list. Royston-Parmar has an empty
/// list, so every link is rejected for it.
fn link_legal_for_family(response: &ResponseFamily, link: LinkFunction) -> bool {
    // The match stays exhaustive over `ResponseFamily` so that adding a new
    // family forces a decision about its legal links.
    let allowed: &[LinkFunction] = match response {
        ResponseFamily::Gaussian => &[LinkFunction::Identity],
        ResponseFamily::Poisson
        | ResponseFamily::Gamma
        | ResponseFamily::Tweedie { .. }
        | ResponseFamily::NegativeBinomial { .. } => &[LinkFunction::Log],
        ResponseFamily::Beta { .. } => &[LinkFunction::Logit],
        ResponseFamily::Binomial => &[
            LinkFunction::Logit,
            LinkFunction::Probit,
            LinkFunction::CLogLog,
            LinkFunction::Sas,
            LinkFunction::BetaLogistic,
        ],
        ResponseFamily::RoystonParmar => &[],
    };
    allowed.contains(&link)
}
/// Applies the link written in parentheses after a family name
/// (the `link` in `family(link)`) to an already-resolved base family.
///
/// # Parameters
/// * `base` - the likelihood resolved from the family head, paired with a flag
///   saying whether that head pins its link.
/// * `link_str` - the text found between the parentheses.
/// * `name` - the family string as originally supplied; used only in the
///   "unknown link" error message.
///
/// # Returns
/// The likelihood with the requested link, always flagged as link-pinned.
/// When the base family already pins its link and the requested link matches
/// it, the base likelihood is returned unchanged so that any state carried by
/// its inverse link (for example the latent cloglog state) is preserved
/// instead of being replaced by a freshly built link.
///
/// # Errors
/// Returns a `String` error when
/// * `link_str` is not a recognised link name,
/// * the link is not legal for the base response family,
/// * the base family pins a different link, or
/// * the initial state of a SAS / Beta-Logistic link cannot be constructed,
///   or the link has no `StandardLink` representation.
fn apply_paren_link(
    base: (LikelihoodSpec, bool),
    link_str: &str,
    name: &str,
) -> Result<(LikelihoodSpec, bool), String> {
    // Every configuration error goes through `WorkflowError::InvalidConfig`
    // before being rendered as the `String` error this function returns.
    let invalid_config =
        |reason: String| -> String { WorkflowError::InvalidConfig { reason }.into() };

    let (base_spec, base_pinned) = base;

    let link = crate::inference::formula_dsl::parse_linkname(link_str).map_err(|_| {
        invalid_config(format!(
            "family '{name}' names an unknown link '{link_str}'; \
             use one of identity|log|logit|probit|cloglog|sas|beta-logistic"
        ))
    })?;

    if !link_legal_for_family(&base_spec.response, link) {
        return Err(invalid_config(format!(
            "link '{}' is not supported for family '{}'",
            link.name(),
            base_spec.response.name()
        )));
    }

    if base_pinned {
        if base_spec.link.link_function() != link {
            return Err(invalid_config(format!(
                "family '{}' pins link '{}', which conflicts with requested link '{}'",
                base_spec.name(),
                base_spec.link.link_function().name(),
                link.name(),
            )));
        }
        // The parenthesised link merely restates the pinned one. Keep the
        // base inverse link as-is: rebuilding it from the bare link function
        // would discard any state it carries.
        return Ok((base_spec, true));
    }

    let inverse_link = match link {
        LinkFunction::Sas => {
            let state = state_from_sasspec(SasLinkSpec {
                initial_epsilon: 0.0,
                initial_log_delta: 0.0,
            })
            .map_err(|err| format!("SAS link initial state: {err}"))?;
            InverseLink::Sas(state)
        }
        LinkFunction::BetaLogistic => {
            let state = state_from_beta_logisticspec(SasLinkSpec {
                initial_epsilon: 0.0,
                initial_log_delta: 0.0,
            })
            .map_err(|err| format!("Beta-Logistic link initial state: {err}"))?;
            InverseLink::BetaLogistic(state)
        }
        standard => InverseLink::Standard(StandardLink::try_from(standard).map_err(|err| {
            invalid_config(format!(
                "link '{}' has no state-less representation: {err}",
                standard.name()
            ))
        })?),
    };

    Ok((LikelihoodSpec::new(base_spec.response, inverse_link), true))
}
pub fn resolve_family(
family: Option<&str>,
negative_binomial_theta: Option<f64>,
link_choice: Option<&LinkChoice>,
y: ArrayView1<'_, f64>,
y_kind: ResponseColumnKind,
response_name: &str,
) -> Result<LikelihoodSpec, String> {
let nb_theta = negative_binomial_theta.unwrap_or(1.0);
if !nb_theta.is_finite() || nb_theta <= 0.0 {
return Err(format!(
"negative-binomial theta must be finite and > 0; got {nb_theta}"
));
}
let explicit: Option<(LikelihoodSpec, bool)> = match family {
Some(name) => {
let lowered = name.to_ascii_lowercase().replace('_', "-");
let (head_name, paren_link): (&str, Option<&str>) = if let Some(open) =
lowered.find('(')
&& lowered.ends_with(')')
{
let head = lowered[..open].trim_end_matches('-').trim();
let inner = lowered[open + 1..lowered.len() - 1].trim();
if head.is_empty() || inner.is_empty() {
(lowered.as_str(), None)
} else {
(head, Some(inner))
}
} else {
(lowered.as_str(), None)
};
let resolved = match head_name {
"gaussian" => (
LikelihoodSpec::new(
ResponseFamily::Gaussian,
InverseLink::Standard(StandardLink::Identity),
),
false,
),
"binomial" | "bernoulli" => (
LikelihoodSpec::new(
ResponseFamily::Binomial,
InverseLink::Standard(StandardLink::Logit),
),
false,
),
"binomial-logit" | "bernoulli-logit" => (
LikelihoodSpec::new(
ResponseFamily::Binomial,
InverseLink::Standard(StandardLink::Logit),
),
true,
),
"binomial-probit" | "bernoulli-probit" => (
LikelihoodSpec::new(
ResponseFamily::Binomial,
InverseLink::Standard(StandardLink::Probit),
),
true,
),
"binomial-cloglog" | "bernoulli-cloglog" => (
LikelihoodSpec::new(
ResponseFamily::Binomial,
InverseLink::Standard(StandardLink::CLogLog),
),
true,
),
"latent-cloglog-binomial" => (
LikelihoodSpec::new(
ResponseFamily::Binomial,
InverseLink::LatentCLogLog(
LatentCLogLogState::new(1.0)
.map_err(|err| format!("latent cloglog default state: {err}"))?,
),
),
true,
),
"poisson" => (
LikelihoodSpec::new(
ResponseFamily::Poisson,
InverseLink::Standard(StandardLink::Log),
),
false,
),
"nb" | "negbin" | "negative-binomial" => (
LikelihoodSpec::new(
ResponseFamily::NegativeBinomial {
theta: nb_theta,
theta_fixed: negative_binomial_theta.is_some(),
},
InverseLink::Standard(StandardLink::Log),
),
false,
),
"negative-binomial-log" => (
LikelihoodSpec::new(
ResponseFamily::NegativeBinomial {
theta: nb_theta,
theta_fixed: negative_binomial_theta.is_some(),
},
InverseLink::Standard(StandardLink::Log),
),
true,
),
"beta" | "beta-regression" => (
LikelihoodSpec::new(
ResponseFamily::Beta { phi: 1.0 },
InverseLink::Standard(StandardLink::Logit),
),
false,
),
"beta-logit" | "beta-regression-logit" => (
LikelihoodSpec::new(
ResponseFamily::Beta { phi: 1.0 },
InverseLink::Standard(StandardLink::Logit),
),
true,
),
"gamma" => (
LikelihoodSpec::new(
ResponseFamily::Gamma,
InverseLink::Standard(StandardLink::Log),
),
false,
),
"royston-parmar" => (LikelihoodSpec::royston_parmar(), true),
"transformation-normal" => (
LikelihoodSpec::new(
ResponseFamily::Gaussian,
InverseLink::Standard(StandardLink::Identity),
),
true,
),
"tweedie" | "tw" => (
LikelihoodSpec::new(
ResponseFamily::Tweedie { p: 1.5 },
InverseLink::Standard(StandardLink::Log),
),
false,
),
"tweedie-log" => (
LikelihoodSpec::new(
ResponseFamily::Tweedie { p: 1.5 },
InverseLink::Standard(StandardLink::Log),
),
true,
),
"multinomial" | "multinomial-logit" | "categorical" | "categorical-logit"
| "softmax" => {
return Err(WorkflowError::InvalidConfig {
reason: format!(
"family '{name}' is a vector-response family; use \
the dedicated multinomial entry point \
(`crate::families::multinomial::fit_penalized_multinomial` \
in Rust, or `gamfit.fit_multinomial(...)` in Python) \
rather than the scalar `fit(family=...)` path"
),
}
.into());
}
_ => {
return Err(WorkflowError::InvalidConfig {
reason: format!(
"unknown family '{name}'; expected one of: {SCALAR_FAMILY_NAMES_HELP}"
),
}
.into());
}
};
let resolved = match paren_link {
Some(link_str) => apply_paren_link(resolved, link_str, name)?,
None => resolved,
};
Some(resolved)
}
None => {
if negative_binomial_theta.is_some() {
return Err(WorkflowError::InvalidConfig {
reason: "negative_binomial_theta requires family='negative-binomial'"
.to_string(),
}
.into());
}
None
}
};
if let Some(choice) = link_choice {
let from_link: LikelihoodSpec = if let Some(components) = choice.mixture_components.as_ref()
{
let n = components.len();
let free = n.saturating_sub(1);
let mix_spec = MixtureLinkSpec {
components: components.clone(),
initial_rho: Array1::<f64>::zeros(free),
};
let state = state_fromspec(&mix_spec)
.map_err(|err| format!("mixture link initial state: {err}"))?;
LikelihoodSpec::new(ResponseFamily::Binomial, InverseLink::Mixture(state))
} else {
match choice.link {
LinkFunction::Identity => LikelihoodSpec::new(
ResponseFamily::Gaussian,
InverseLink::Standard(StandardLink::Identity),
),
LinkFunction::Log => {
if y.iter()
.all(|&yi| yi.is_finite() && yi >= 0.0 && (yi - yi.round()).abs() <= 1e-9)
{
LikelihoodSpec::new(
ResponseFamily::Poisson,
InverseLink::Standard(StandardLink::Log),
)
} else {
LikelihoodSpec::new(
ResponseFamily::Gamma,
InverseLink::Standard(StandardLink::Log),
)
}
}
LinkFunction::Logit => LikelihoodSpec::new(
ResponseFamily::Binomial,
InverseLink::Standard(StandardLink::Logit),
),
LinkFunction::Probit => LikelihoodSpec::new(
ResponseFamily::Binomial,
InverseLink::Standard(StandardLink::Probit),
),
LinkFunction::CLogLog => LikelihoodSpec::new(
ResponseFamily::Binomial,
InverseLink::Standard(StandardLink::CLogLog),
),
LinkFunction::Sas => {
let state = state_from_sasspec(SasLinkSpec {
initial_epsilon: 0.0,
initial_log_delta: 0.0,
})
.map_err(|err| format!("SAS link initial state: {err}"))?;
LikelihoodSpec::new(ResponseFamily::Binomial, InverseLink::Sas(state))
}
LinkFunction::BetaLogistic => {
let state = state_from_beta_logisticspec(SasLinkSpec {
initial_epsilon: 0.0,
initial_log_delta: 0.0,
})
.map_err(|err| format!("Beta-Logistic link initial state: {err}"))?;
LikelihoodSpec::new(ResponseFamily::Binomial, InverseLink::BetaLogistic(state))
}
}
};
if let Some((explicit_spec, link_pinned)) = explicit.as_ref() {
if matches!(
choice.mode,
crate::inference::formula_dsl::LinkMode::Flexible
) && !matches!(explicit_spec.response, ResponseFamily::Binomial)
{
return Err(WorkflowError::InvalidConfig {
reason: format!(
"flexible(...) links (the jointly-fit anchored spline link offset) are \
implemented only for a binomial response; the resolved family is {} (a \
non-binomial family), for which the link offset has no solver and would \
otherwise be silently discarded. Use the plain base link, or fit a binomial \
response.",
explicit_spec.pretty_name()
),
}
.into());
}
let mixture_requested = choice.mixture_components.is_some();
let legal = if mixture_requested {
matches!(explicit_spec.response, ResponseFamily::Binomial)
} else {
link_legal_for_family(&explicit_spec.response, choice.link)
};
if !legal {
return Err(WorkflowError::InvalidConfig {
reason: format!(
"link '{}' is not supported for family '{}'",
choice.link.name(),
explicit_spec.response.name()
),
}
.into());
}
if *link_pinned && explicit_spec.link.link_function() != from_link.link.link_function()
{
return Err(WorkflowError::InvalidConfig {
reason: format!(
"family '{}' pins link '{}', which conflicts with requested link '{}'",
explicit_spec.name(),
explicit_spec.link.link_function().name(),
choice.link.name(),
),
}
.into());
}
return Ok(LikelihoodSpec::new(
explicit_spec.response.clone(),
from_link.link,
));
}
return Ok(from_link);
}
if let Some((spec, _)) = explicit {
return Ok(spec);
}
let response = ResponseFamily::infer_from_response(y, y_kind).map_err(|refusal| {
let err: String = WorkflowError::InvalidConfig {
reason: refusal.message_for(response_name),
}
.into();
err
})?;
let link = match response {
ResponseFamily::Binomial => InverseLink::Standard(StandardLink::Logit),
ResponseFamily::Poisson => InverseLink::Standard(StandardLink::Log),
_ => InverseLink::Standard(StandardLink::Identity),
};
Ok(LikelihoodSpec::new(response, link))
}