Skip to main content

ndarray_glm/
link.rs

1//! Traits and utilities for link functions.
2//!
3//! The link function $`g`$ maps the expected response $`\mu`$ to the linear predictor
4//! $`\omega = \mathbf{x}^\mathsf{T}\boldsymbol{\beta}`$. Each family defaults to its canonical
5//! link, but an alternative can be selected via the family's type parameter.
6//!
7//! # Using a provided non-canonical link
8//!
9//! Alternative links are re-exported for convenience: [`exp_link`](crate::exp_link) for
10//! exponential regression and [`logistic_link`](crate::logistic_link) for logistic regression.
11//! Provide a link as the family's type parameter:
12//!
13//! ```
14//! use ndarray_glm::{Exponential, ModelBuilder, array, exp_link::Log};
15//!
16//! fn main() -> ndarray_glm::error::RegressionResult<(), f64> {
17//!     let data_y = array![1.0, 2.5, 0.8, 3.1];
18//!     let data_x = array![[0.0], [1.0], [0.5], [1.5]];
19//!     // Use the log link instead of the default negative-reciprocal canonical link.
20//!     let model = ModelBuilder::<Exponential<Log>>::data(&data_y, &data_x).build()?;
21//!     let fit = model.fit()?;
22//!     Ok(())
23//! }
24//! ```
25//!
26//! # Implementing a custom non-canonical link
27//!
28//! A non-canonical link requires two trait implementations:
29//!
30//! 1. [`Link<M>`] — the forward map $`g(\mu) = \omega`$ ([`Link::func`]) and its inverse
31//!    $`g^{-1}(\omega) = \mu`$ ([`Link::func_inv`]).
32//! 2. [`Transform`] — the natural-parameter transformation
33//!    $`\eta(\omega) = g_0(g^{-1}(\omega))`$ ([`Transform::nat_param`]) and its derivative
34//!    ([`Transform::d_nat_param`]), where $`g_0`$ is the family's canonical link. The derivative
35//!    satisfies $`\eta'(\omega) = \frac{1}{g'(\mu)\,V(\mu)}`$ where $`V(\mu)`$ is the family's
36//!    variance function evaluated at $`\mu = g^{-1}(\omega)`$.
37//!
38//! Example: a square-root link $`g(\mu) = \sqrt{\mu}`$ for Poisson regression. The canonical
39//! link is $`\log`$ and $`V(\mu) = \mu`$, so
40//! $`\eta(\omega) = \log(\omega^2) = 2\log\omega`$ and $`\eta'(\omega) = 2/\omega`$:
41//!
42//! ```
43//! use ndarray_glm::{Poisson, link::{Link, Transform}, num::Float};
44//! use ndarray::Array1;
45//!
46//! pub struct Sqrt;
47//!
48//! impl Link<Poisson<Sqrt>> for Sqrt {
49//!     fn func<F: Float>(mu: F) -> F { num_traits::Float::sqrt(mu) }
50//!     fn func_inv<F: Float>(omega: F) -> F { omega * omega }
51//! }
52//!
53//! impl Transform for Sqrt {
54//!     fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F> {
55//!         lin_pred.mapv(|w| F::two() * num_traits::Float::ln(w))
56//!     }
57//!     fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F> {
58//!         lin_pred.mapv(|w| F::two() / w)
59//!     }
60//! }
61//! ```
62//!
63//! # Consistency tests with `TestLink`
64//!
65//! The `TestLink` trait (available only in `#[cfg(test)]` builds) provides canned assertions
66//! that every correct link implementation should satisfy. Call them from your test module:
67//!
68//! ```no_run
69//! #[cfg(test)]
70//! mod tests {
71//!     use super::*;
72//!     use ndarray_glm::link::TestLink;
73//!     use ndarray::array;
74//!
75//!     #[test]
76//!     fn sqrt_link_checks() {
77//!         // Linear-predictor values; must lie in the domain of ω (ω > 0 for sqrt).
78//!         let lin_vals = array![0.25, 1.0, 2.0, 4.0, 9.0];
79//!
80//!         // Verify g(g⁻¹(ω)) ≈ ω.
81//!         Sqrt::check_closure(&lin_vals);
82//!
83//!         // Verify g⁻¹(g(μ)) ≈ μ; values must lie in the response domain (μ > 0 for Poisson).
84//!         Sqrt::check_closure_y(&array![0.5, 1.0, 3.0, 10.0]);
85//!
86//!         // For non-canonical links: verify nat_param(ω) = g₀(g⁻¹(ω)).
87//!         // Pass the *canonical* model variant as `Mc`. `Poisson` without a type parameter
88//!         // defaults to the canonical log link.
89//!         use ndarray_glm::Poisson;
90//!         Sqrt::check_nat_par::<Poisson>(&lin_vals);
91//!
92//!         // Verify d_nat_param matches the numerical derivative.
93//!         Sqrt::check_nat_par_d(&lin_vals);
94//!     }
95//! }
96//! ```
97
98use crate::{glm::Glm, num::Float};
99use ndarray::Array1;
100
101/// Describes the link function $`g`$ that maps between the expected response $`\mu`$ and
102/// the linear predictor $`\omega = \mathbf{x}^\mathsf{T}\boldsymbol{\beta}`$:
103///
104/// ```math
105/// g(\mu) = \omega, \qquad \mu = g^{-1}(\omega)
106/// ```
107pub trait Link<M: Glm>: Transform {
108    /// Maps the expectation value of the response variable to the linear
109    /// predictor. In general this is determined by a composition of the inverse
110    /// natural parameter transformation and the canonical link function.
111    fn func<F: Float>(y: F) -> F;
112    /// Maps the linear predictor to the expectation value of the response.
113    fn func_inv<F: Float>(lin_pred: F) -> F;
114}
115
116/// Establishes the relationship between the linear predictor $`\omega =
117/// \mathbf{x}\cdot\boldsymbol\beta`$ and the natural parameter $`\eta`$.
118pub trait Transform {
119    /// The natural parameter of the response distribution as a function
120    /// of the linear predictor: $`\eta(\omega) = g_0(g^{-1}(\omega))`$ where $`g_0`$ is the
121    /// canonical link. For canonical links this is the identity.
122    fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F>;
123    /// The derivative $`\eta'(\omega)`$ of the transformation to the natural parameter.
124    /// If it is zero in a region that the IRLS is in, the algorithm may have difficulty
125    /// converging.
126    /// It is given in terms of the link and variance functions as $`\eta'(\omega_i) =
127    /// \frac{1}{g'(\mu_i) V(\mu_i)}`$.
128    fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F>;
129    /// Adjust the error/residual terms of the likelihood function based on the first derivative of
130    /// the transformation. The linear predictor must be un-transformed, i.e. it must be X*beta
131    /// without the transformation applied.
132    fn adjust_errors<F: Float>(errors: Array1<F>, lin_pred: &Array1<F>) -> Array1<F> {
133        let eta_d = Self::d_nat_param(lin_pred);
134        eta_d * errors
135    }
136    /// Adjust the variance terms of the likelihood function based on the first and second
137    /// derivatives of the transformation. The linear predictor must be un-transformed, i.e. it
138    /// must be X*beta without the transformation applied.
139    fn adjust_variance<F: Float>(variance: Array1<F>, lin_pred: &Array1<F>) -> Array1<F> {
140        let eta_d = Self::d_nat_param(lin_pred);
141        // The second-derivative term in the variance matrix can lead it to not
142        // be positive-definite. In fact, the second term should vanish when
143        // taking the expecation of Y to give the Fisher information.
144        // let var_adj = &eta_d * &variance * eta_d - eta_dd * errors;
145        &eta_d * &variance * eta_d
146    }
147    /// Adjust the error and variance terms of the likelihood function based on
148    /// the first and second derivatives of the transformation. The adjustment
149    /// is performed simultaneously. The linear predictor must be
150    /// un-transformed, i.e. it must be X*beta without the transformation
151    /// applied.
152    fn adjust_errors_variance<F: Float>(
153        errors: Array1<F>,
154        variance: Array1<F>,
155        lin_pred: &Array1<F>,
156    ) -> (Array1<F>, Array1<F>) {
157        let eta_d = Self::d_nat_param(lin_pred);
158        let err_adj = &eta_d * &errors;
159        // The second-derivative term in the variance matrix can lead it to not
160        // be positive-definite. In fact, the second term should vanish when
161        // taking the expecation of Y to give the Fisher information.
162        // let var_adj = &eta_d * &variance * eta_d - eta_dd * errors;
163        let var_adj = &eta_d * &variance * eta_d;
164        (err_adj, var_adj)
165    }
166}
167
168/// The canonical transformation by definition equates the linear predictor with
169/// the natural parameter of the response distribution. Implementing this trait
170/// for a link function automatically defines the trivial transformation
171/// functions.
172pub trait Canonical {}
173impl<T> Transform for T
174where
175    T: Canonical,
176{
177    /// By defintion this function is the identity function for canonical links.
178    #[inline]
179    fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F> {
180        lin_pred
181    }
182    #[inline]
183    fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F> {
184        Array1::<F>::ones(lin_pred.len())
185    }
186    /// The canonical link function requires no transformation of the error and variance terms.
187    #[inline]
188    fn adjust_errors<F: Float>(errors: Array1<F>, _lin_pred: &Array1<F>) -> Array1<F> {
189        errors
190    }
191    #[inline]
192    fn adjust_variance<F: Float>(variance: Array1<F>, _lin_pred: &Array1<F>) -> Array1<F> {
193        variance
194    }
195    #[inline]
196    fn adjust_errors_variance<F: Float>(
197        errors: Array1<F>,
198        variance: Array1<F>,
199        _lin_pred: &Array1<F>,
200    ) -> (Array1<F>, Array1<F>) {
201        (errors, variance)
202    }
203}
204
/// Common consistency checks that every link function implementation should satisfy.
///
/// Only compiled in test builds. Each method panics (via an assertion) when the property it
/// checks does not hold for some element of the input array.
#[cfg(test)]
pub trait TestLink<M: Glm> {
    /// Assert that $`g(g^{-1}(\omega)) = \omega`$ for the entire input array, i.e. that
    /// applying the inverse link and then the link recovers the linear predictor. Since the
    /// input domain is that of the linear predictor, this should hold for all normal inputs.
    fn check_closure(xs: &Array1<f64>);

    /// Assert that $`g^{-1}(g(y)) = y`$ for the entire input array, i.e. that applying the
    /// link and then the inverse link recovers the response value. Since the input domain is
    /// that of the response variable, the input array must be in the domain of y.
    fn check_closure_y(ys: &Array1<f64>);

    /// Check that $`\eta(\omega) = g_0(g^{-1}(\omega))`$ on the input domain, where $`g_0`$ is
    /// the canonical link for `M` (supplied as the type parameter `Mc`) and $`g^{-1}`$ is the
    /// inverse of the link under test. This is the defining property of the `nat_param`
    /// transformation and should hold for all normal linear predictor inputs.
    fn check_nat_par<Mc: Glm>(xs: &Array1<f64>);

    /// Check the analytical derivative of the natural parameter function against a numerical
    /// difference. The check compares the ratio of the numerical derivative to the analytical
    /// one with unity, so that a single constant epsilon can be used regardless of the
    /// magnitude of the derivative. The analytical derivative must therefore be non-zero at
    /// every input.
    fn check_nat_par_d(xs: &Array1<f64>);
}
227
/// Blanket implementation: every link `L` for a model `M` gets the consistency checks.
#[cfg(test)]
impl<L, M> TestLink<M> for L
where
    M: Glm,
    L: Link<M>,
{
    /// Round-trips each linear-predictor value through the inverse link and back, and asserts
    /// the result matches the input to within an absolute tolerance of `1e-6`.
    fn check_closure(x: &Array1<f64>) {
        let round_trip = x.mapv(|omega| L::func(L::func_inv(omega)));
        // The tolerance is relatively generous because some of these back-and-forth
        // evaluations lose precision.
        approx::assert_abs_diff_eq!(*x, round_trip, epsilon = 1e-6);
    }

    /// Round-trips each response value through the link and back, and asserts the result
    /// matches the input to within an absolute tolerance of `f32::EPSILON`.
    fn check_closure_y(y: &Array1<f64>) {
        let round_trip = y.mapv(|mu| L::func_inv(L::func(mu)));
        approx::assert_abs_diff_eq!(*y, round_trip, epsilon = f32::EPSILON as f64);
    }

    /// Asserts that `L::nat_param` agrees with the canonical link of `Mc` applied to the
    /// inverse of the link under test, to within an absolute tolerance of `1e-6`.
    fn check_nat_par<Mc: Glm>(xs: &Array1<f64>) {
        // By definition nat_param(ω) = g_0(g^{-1}(ω)); build the right-hand side explicitly
        // from the canonical model's link and this link's inverse.
        let expected = xs.mapv(|omega| Mc::Link::func::<f64>(L::func_inv(omega)));
        let actual = L::nat_param(xs.clone());
        approx::assert_abs_diff_eq!(expected, actual, epsilon = 1e-6);
    }

    /// Asserts that the central-difference derivative of `L::nat_param` divided by
    /// `L::d_nat_param` equals one to within an absolute tolerance of `f32::EPSILON`.
    fn check_nat_par_d(xs: &Array1<f64>) {
        let step = f32::EPSILON as f64;
        let half_step = step / 2.;
        let analytic = L::d_nat_param(xs);
        let eta_above = L::nat_param(xs.mapv(|x| x + half_step));
        let eta_below = L::nat_param(xs.mapv(|x| x - half_step));
        // Compare the ratio to one rather than the raw difference: for some links a modest
        // range of inputs maps over many orders of magnitude (e.g. 10^5 to 10^{-5}) and a
        // single epsilon must work across all of them. This requires the analytic derivative
        // to be non-zero, which should hold for any well-behaved eta function.
        let ratio = (eta_above - eta_below) / (step * analytic);
        approx::assert_abs_diff_eq!(ratio, Array1::<f64>::ones(xs.len()), epsilon = step);
    }
}
271}