Skip to main content

ndarray_glm/
link.rs

1//! Defines traits for link functions
2
3use crate::{glm::Glm, num::Float};
4use ndarray::Array1;
5
6/// Describes the link function $`g`$ that maps between the expected response $`\mu`$ and
7/// the linear predictor $`\omega = \mathbf{x}^\mathsf{T}\boldsymbol{\beta}`$:
8///
9/// ```math
10/// g(\mu) = \omega, \qquad \mu = g^{-1}(\omega)
11/// ```
12pub trait Link<M: Glm>: Transform {
13    /// Maps the expectation value of the response variable to the linear
14    /// predictor. In general this is determined by a composition of the inverse
15    /// natural parameter transformation and the canonical link function.
16    fn func<F: Float>(y: F) -> F;
17    /// Maps the linear predictor to the expectation value of the response.
18    fn func_inv<F: Float>(lin_pred: F) -> F;
19}
20
21pub trait Transform {
22    /// The natural parameter of the response distribution as a function
23    /// of the linear predictor: $`\eta(\omega) = g_0(g^{-1}(\omega))`$ where $`g_0`$ is the
24    /// canonical link. For canonical links this is the identity.
25    fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F>;
26    /// The derivative $`\eta'(\omega)`$ of the transformation to the natural parameter.
27    /// If it is zero in a region that the IRLS is in, the algorithm may have difficulty
28    /// converging.
29    fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F>;
30    /// Adjust the error/residual terms of the likelihood function based on the first derivative of
31    /// the transformation. The linear predictor must be un-transformed, i.e. it must be X*beta
32    /// without the transformation applied.
33    fn adjust_errors<F: Float>(errors: Array1<F>, lin_pred: &Array1<F>) -> Array1<F> {
34        let eta_d = Self::d_nat_param(lin_pred);
35        eta_d * errors
36    }
37    /// Adjust the variance terms of the likelihood function based on the first and second
38    /// derivatives of the transformation. The linear predictor must be un-transformed, i.e. it
39    /// must be X*beta without the transformation applied.
40    fn adjust_variance<F: Float>(variance: Array1<F>, lin_pred: &Array1<F>) -> Array1<F> {
41        let eta_d = Self::d_nat_param(lin_pred);
42        // The second-derivative term in the variance matrix can lead it to not
43        // be positive-definite. In fact, the second term should vanish when
44        // taking the expecation of Y to give the Fisher information.
45        // let var_adj = &eta_d * &variance * eta_d - eta_dd * errors;
46        &eta_d * &variance * eta_d
47    }
48    /// Adjust the error and variance terms of the likelihood function based on
49    /// the first and second derivatives of the transformation. The adjustment
50    /// is performed simultaneously. The linear predictor must be
51    /// un-transformed, i.e. it must be X*beta without the transformation
52    /// applied.
53    fn adjust_errors_variance<F: Float>(
54        errors: Array1<F>,
55        variance: Array1<F>,
56        lin_pred: &Array1<F>,
57    ) -> (Array1<F>, Array1<F>) {
58        let eta_d = Self::d_nat_param(lin_pred);
59        let err_adj = &eta_d * &errors;
60        // The second-derivative term in the variance matrix can lead it to not
61        // be positive-definite. In fact, the second term should vanish when
62        // taking the expecation of Y to give the Fisher information.
63        // let var_adj = &eta_d * &variance * eta_d - eta_dd * errors;
64        let var_adj = &eta_d * &variance * eta_d;
65        (err_adj, var_adj)
66    }
67}
68
69/// The canonical transformation by definition equates the linear predictor with
70/// the natural parameter of the response distribution. Implementing this trait
71/// for a link function automatically defines the trivial transformation
72/// functions.
73pub trait Canonical {}
74impl<T> Transform for T
75where
76    T: Canonical,
77{
78    /// By defintion this function is the identity function for canonical links.
79    #[inline]
80    fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F> {
81        lin_pred
82    }
83    #[inline]
84    fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F> {
85        Array1::<F>::ones(lin_pred.len())
86    }
87    /// The canonical link function requires no transformation of the error and variance terms.
88    #[inline]
89    fn adjust_errors<F: Float>(errors: Array1<F>, _lin_pred: &Array1<F>) -> Array1<F> {
90        errors
91    }
92    #[inline]
93    fn adjust_variance<F: Float>(variance: Array1<F>, _lin_pred: &Array1<F>) -> Array1<F> {
94        variance
95    }
96    #[inline]
97    fn adjust_errors_variance<F: Float>(
98        errors: Array1<F>,
99        variance: Array1<F>,
100        _lin_pred: &Array1<F>,
101    ) -> (Array1<F>, Array1<F>) {
102        (errors, variance)
103    }
104}