ndarray_glm/link.rs
1//! Traits and utilities for link functions.
2//!
3//! The link function $`g`$ maps the expected response $`\mu`$ to the linear predictor
4//! $`\omega = \mathbf{x}^\mathsf{T}\boldsymbol{\beta}`$. Each family defaults to its canonical
5//! link, but an alternative can be selected via the family's type parameter.
6//!
7//! # Using a provided non-canonical link
8//!
9//! Alternative links are re-exported for convenience: [`exp_link`](crate::exp_link) for
10//! exponential regression and [`logistic_link`](crate::logistic_link) for logistic regression.
11//! Provide a link as the family's type parameter:
12//!
13//! ```
14//! use ndarray_glm::{Exponential, ModelBuilder, array, exp_link::Log};
15//!
16//! fn main() -> ndarray_glm::error::RegressionResult<(), f64> {
17//! let data_y = array![1.0, 2.5, 0.8, 3.1];
18//! let data_x = array![[0.0], [1.0], [0.5], [1.5]];
19//! // Use the log link instead of the default negative-reciprocal canonical link.
20//! let model = ModelBuilder::<Exponential<Log>>::data(&data_y, &data_x).build()?;
21//! let fit = model.fit()?;
22//! Ok(())
23//! }
24//! ```
25//!
26//! # Implementing a custom non-canonical link
27//!
28//! A non-canonical link requires two trait implementations:
29//!
30//! 1. [`Link<M>`] — the forward map $`g(\mu) = \omega`$ ([`Link::func`]) and its inverse
31//! $`g^{-1}(\omega) = \mu`$ ([`Link::func_inv`]).
32//! 2. [`Transform`] — the natural-parameter transformation
33//! $`\eta(\omega) = g_0(g^{-1}(\omega))`$ ([`Transform::nat_param`]) and its derivative
34//! ([`Transform::d_nat_param`]), where $`g_0`$ is the family's canonical link. The derivative
35//! satisfies $`\eta'(\omega) = \frac{1}{g'(\mu)\,V(\mu)}`$ where $`V(\mu)`$ is the family's
36//! variance function evaluated at $`\mu = g^{-1}(\omega)`$.
37//!
38//! Example: a square-root link $`g(\mu) = \sqrt{\mu}`$ for Poisson regression. The canonical
39//! link is $`\log`$ and $`V(\mu) = \mu`$, so
40//! $`\eta(\omega) = \log(\omega^2) = 2\log\omega`$ and $`\eta'(\omega) = 2/\omega`$:
41//!
42//! ```
43//! use ndarray_glm::{Poisson, link::{Link, Transform}, num::Float};
44//! use ndarray::Array1;
45//!
46//! pub struct Sqrt;
47//!
48//! impl Link<Poisson<Sqrt>> for Sqrt {
49//! fn func<F: Float>(mu: F) -> F { num_traits::Float::sqrt(mu) }
50//! fn func_inv<F: Float>(omega: F) -> F { omega * omega }
51//! }
52//!
53//! impl Transform for Sqrt {
54//! fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F> {
55//! lin_pred.mapv(|w| F::two() * num_traits::Float::ln(w))
56//! }
57//! fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F> {
58//! lin_pred.mapv(|w| F::two() / w)
59//! }
60//! }
61//! ```
62//!
63//! # Consistency tests with `TestLink`
64//!
65//! The `TestLink` trait (available only in `#[cfg(test)]` builds) provides canned assertions
66//! that every correct link implementation should satisfy. Call them from your test module:
67//!
68//! ```no_run
69//! #[cfg(test)]
70//! mod tests {
71//! use super::*;
72//! use ndarray_glm::link::TestLink;
73//! use ndarray::array;
74//!
75//! #[test]
76//! fn sqrt_link_checks() {
77//! // Linear-predictor values; must lie in the domain of ω (ω > 0 for sqrt).
78//! let lin_vals = array![0.25, 1.0, 2.0, 4.0, 9.0];
79//!
80//! // Verify g(g⁻¹(ω)) ≈ ω.
81//! Sqrt::check_closure(&lin_vals);
82//!
83//! // Verify g⁻¹(g(μ)) ≈ μ; values must lie in the response domain (μ > 0 for Poisson).
84//! Sqrt::check_closure_y(&array![0.5, 1.0, 3.0, 10.0]);
85//!
86//! // For non-canonical links: verify nat_param(ω) = g₀(g⁻¹(ω)).
87//! // Pass the *canonical* model variant as `Mc`. `Poisson` without a type parameter
88//! // defaults to the canonical log link.
89//! use ndarray_glm::Poisson;
90//! Sqrt::check_nat_par::<Poisson>(&lin_vals);
91//!
92//! // Verify d_nat_param matches the numerical derivative.
93//! Sqrt::check_nat_par_d(&lin_vals);
94//! }
95//! }
96//! ```
97
98use crate::{glm::Glm, num::Float};
99use ndarray::Array1;
100
101/// Describes the link function $`g`$ that maps between the expected response $`\mu`$ and
102/// the linear predictor $`\omega = \mathbf{x}^\mathsf{T}\boldsymbol{\beta}`$:
103///
104/// ```math
105/// g(\mu) = \omega, \qquad \mu = g^{-1}(\omega)
106/// ```
107pub trait Link<M: Glm>: Transform {
108 /// Maps the expectation value of the response variable to the linear
109 /// predictor. In general this is determined by a composition of the inverse
110 /// natural parameter transformation and the canonical link function.
111 fn func<F: Float>(y: F) -> F;
112 /// Maps the linear predictor to the expectation value of the response.
113 fn func_inv<F: Float>(lin_pred: F) -> F;
114}
115
116/// Establishes the relationship between the linear predictor $`\omega =
117/// \mathbf{x}\cdot\boldsymbol\beta`$ and the natural parameter $`\eta`$.
118pub trait Transform {
119 /// The natural parameter of the response distribution as a function
120 /// of the linear predictor: $`\eta(\omega) = g_0(g^{-1}(\omega))`$ where $`g_0`$ is the
121 /// canonical link. For canonical links this is the identity.
122 fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F>;
123 /// The derivative $`\eta'(\omega)`$ of the transformation to the natural parameter.
124 /// If it is zero in a region that the IRLS is in, the algorithm may have difficulty
125 /// converging.
126 /// It is given in terms of the link and variance functions as $`\eta'(\omega_i) =
127 /// \frac{1}{g'(\mu_i) V(\mu_i)}`$.
128 fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F>;
129 /// Adjust the error/residual terms of the likelihood function based on the first derivative of
130 /// the transformation. The linear predictor must be un-transformed, i.e. it must be X*beta
131 /// without the transformation applied.
132 fn adjust_errors<F: Float>(errors: Array1<F>, lin_pred: &Array1<F>) -> Array1<F> {
133 let eta_d = Self::d_nat_param(lin_pred);
134 eta_d * errors
135 }
136 /// Adjust the variance terms of the likelihood function based on the first and second
137 /// derivatives of the transformation. The linear predictor must be un-transformed, i.e. it
138 /// must be X*beta without the transformation applied.
139 fn adjust_variance<F: Float>(variance: Array1<F>, lin_pred: &Array1<F>) -> Array1<F> {
140 let eta_d = Self::d_nat_param(lin_pred);
141 // The second-derivative term in the variance matrix can lead it to not
142 // be positive-definite. In fact, the second term should vanish when
143 // taking the expecation of Y to give the Fisher information.
144 // let var_adj = &eta_d * &variance * eta_d - eta_dd * errors;
145 &eta_d * &variance * eta_d
146 }
147 /// Adjust the error and variance terms of the likelihood function based on
148 /// the first and second derivatives of the transformation. The adjustment
149 /// is performed simultaneously. The linear predictor must be
150 /// un-transformed, i.e. it must be X*beta without the transformation
151 /// applied.
152 fn adjust_errors_variance<F: Float>(
153 errors: Array1<F>,
154 variance: Array1<F>,
155 lin_pred: &Array1<F>,
156 ) -> (Array1<F>, Array1<F>) {
157 let eta_d = Self::d_nat_param(lin_pred);
158 let err_adj = &eta_d * &errors;
159 // The second-derivative term in the variance matrix can lead it to not
160 // be positive-definite. In fact, the second term should vanish when
161 // taking the expecation of Y to give the Fisher information.
162 // let var_adj = &eta_d * &variance * eta_d - eta_dd * errors;
163 let var_adj = &eta_d * &variance * eta_d;
164 (err_adj, var_adj)
165 }
166}
167
/// Marker trait for canonical link functions.
///
/// The canonical transformation by definition equates the linear predictor with
/// the natural parameter of the response distribution. Implementing this trait
/// for a link function automatically defines the trivial transformation
/// functions, via a blanket implementation of [`Transform`] for every type that
/// is `Canonical`. The trait itself has no methods.
pub trait Canonical {}
173impl<T> Transform for T
174where
175 T: Canonical,
176{
177 /// By defintion this function is the identity function for canonical links.
178 #[inline]
179 fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F> {
180 lin_pred
181 }
182 #[inline]
183 fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F> {
184 Array1::<F>::ones(lin_pred.len())
185 }
186 /// The canonical link function requires no transformation of the error and variance terms.
187 #[inline]
188 fn adjust_errors<F: Float>(errors: Array1<F>, _lin_pred: &Array1<F>) -> Array1<F> {
189 errors
190 }
191 #[inline]
192 fn adjust_variance<F: Float>(variance: Array1<F>, _lin_pred: &Array1<F>) -> Array1<F> {
193 variance
194 }
195 #[inline]
196 fn adjust_errors_variance<F: Float>(
197 errors: Array1<F>,
198 variance: Array1<F>,
199 _lin_pred: &Array1<F>,
200 ) -> (Array1<F>, Array1<F>) {
201 (errors, variance)
202 }
203}
204
/// Common consistency checks that every link function should satisfy. Only compiled in test
/// builds.
#[cfg(test)]
pub trait TestLink<M>
where
    M: Glm,
{
    /// Assert that $`g(g^{-1}(\omega)) = \omega`$ for every element of the input array. The
    /// input lives in the domain of the linear predictor, so this should hold for all normal
    /// inputs.
    fn check_closure(xs: &Array1<f64>);

    /// Assert that $`g^{-1}(g(y)) = y`$ for every element of the input array. The input lives
    /// in the domain of the response variable, so every value must be a valid `y`.
    fn check_closure_y(ys: &Array1<f64>);

    /// Assert that $`\eta(\omega) = g_0(g^{-1}(\omega))`$ on the input domain, where $`g_0`$
    /// is the canonical link for `M` (supplied through the type parameter `Mc`) and
    /// $`g^{-1}`$ is the inverse of the link under test. This is the defining property of the
    /// `nat_param` transformation and should hold for all normal linear predictor inputs.
    fn check_nat_par<Mc: Glm>(xs: &Array1<f64>);

    /// Compare the analytical derivative of the natural parameter function against a
    /// numerical difference. The check is made on the ratio of the numerical derivative to
    /// the analytical one, so that a single constant epsilon works across the whole input.
    fn check_nat_par_d(xs: &Array1<f64>);
}
227
// Every link `L` for a model `M` receives the standard consistency checks automatically.
#[cfg(test)]
impl<M: Glm, L: Link<M>> TestLink<M> for L {
    /// Round-trips each linear-predictor value through `func_inv` and then `func`, asserting
    /// the result matches the input to within an absolute tolerance of `1e-6`.
    fn check_closure(x: &Array1<f64>) {
        let round_trip = x.mapv(|omega| L::func(L::func_inv(omega)));
        // The tolerance is relatively generous because some of these back-and-forth
        // evaluations do lose precision.
        approx::assert_abs_diff_eq!(*x, round_trip, epsilon = 1e-6);
    }

    /// Round-trips each response value through `func` and then `func_inv`, asserting the
    /// result matches the input to within an absolute tolerance of `f32::EPSILON`.
    fn check_closure_y(y: &Array1<f64>) {
        let round_trip = y.mapv(|mu| L::func_inv(L::func(mu)));
        approx::assert_abs_diff_eq!(*y, round_trip, epsilon = f64::from(f32::EPSILON));
    }

    /// Asserts that `L::nat_param` agrees with the composition of the canonical link of `Mc`
    /// and the inverse of this link, to within an absolute tolerance of `1e-6`.
    fn check_nat_par<Mc: Glm>(xs: &Array1<f64>) {
        // By definition nat_param(ω) = g_0(g^{-1}(ω)); build that composition element by
        // element and compare it with the transformation under test.
        let expected = xs.mapv(|omega| Mc::Link::func::<f64>(L::func_inv(omega)));
        let actual = L::nat_param(xs.clone());
        approx::assert_abs_diff_eq!(expected, actual, epsilon = 1e-6);
    }

    /// Asserts that the central finite difference of `L::nat_param`, divided by the analytical
    /// `L::d_nat_param`, equals one to within `f32::EPSILON`.
    fn check_nat_par_d(xs: &Array1<f64>) {
        let delta = f64::from(f32::EPSILON);
        let half_step = delta / 2.;
        let analytic = L::d_nat_param(xs);
        let eta_above = L::nat_param(xs.mapv(|x| x + half_step));
        let eta_below = L::nat_param(xs.mapv(|x| x - half_step));
        // Dividing by the analytical derivative requires it to be non-zero, which should be the
        // case for a good eta function anyway. The ratio is used because some link functions
        // map a modest range of inputs over many orders of magnitude (e.g. 10^5 to 10^{-5}),
        // and a single epsilon has to be meaningful across all of them.
        let ratio = (eta_above - eta_below) / (delta * analytic);
        approx::assert_abs_diff_eq!(ratio, Array1::<f64>::ones(xs.len()), epsilon = delta);
    }
}
271}