Skip to main content

ndarray_glm/
link.rs

1//! Defines traits for link functions
2
3use crate::{glm::Glm, num::Float};
4use ndarray::Array1;
5
6/// Describes the link function $`g`$ that maps between the expected response $`\mu`$ and
7/// the linear predictor $`\omega = \mathbf{x}^\mathsf{T}\boldsymbol{\beta}`$:
8///
9/// ```math
10/// g(\mu) = \omega, \qquad \mu = g^{-1}(\omega)
11/// ```
12// TODO: The link function and its inverse are independent of the response
13// distribution. This could be refactored to separate the function itself from
14// the transformation that works with the distribution.
15pub trait Link<M: Glm>: Transform {
16    /// Maps the expectation value of the response variable to the linear
17    /// predictor. In general this is determined by a composition of the inverse
18    /// natural parameter transformation and the canonical link function.
19    fn func<F: Float>(y: F) -> F;
20    // fn func<F: Float>(y: Array1<F>) -> Array1<F>;
21    /// Maps the linear predictor to the expectation value of the response.
22    // TODO: There may not be a point in using Array versions of these functions
23    // since clones are necessary anyway. Perhaps we could simply define the
24    // scalar function and use mapv().
25    fn func_inv<F: Float>(lin_pred: F) -> F;
26    // fn func_inv<F: Float>(lin_pred: Array1<F>) -> Array1<F>;
27}
28
29pub trait Transform {
30    /// The natural parameter of the response distribution as a function
31    /// of the linear predictor: $`\eta(\omega) = g_0(g^{-1}(\omega))`$ where $`g_0`$ is the
32    /// canonical link. For canonical links this is the identity.
33    fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F>;
34    /// The derivative $`\eta'(\omega)`$ of the transformation to the natural parameter.
35    /// If it is zero in a region that the IRLS is in, the algorithm may have difficulty
36    /// converging.
37    fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F>;
38    /// Adjust the error/residual terms of the likelihood function based on the first derivative of
39    /// the transformation. The linear predictor must be un-transformed, i.e. it must be X*beta
40    /// without the transformation applied.
41    fn adjust_errors<F: Float>(
42        errors: Array1<F>,
43        lin_pred: &Array1<F>,
44    ) -> Array1<F> {
45        let eta_d = Self::d_nat_param(lin_pred);
46        eta_d * errors
47    }
48    /// Adjust the variance terms of the likelihood function based on the first and second
49    /// derivatives of the transformation. The linear predictor must be un-transformed, i.e. it
50    /// must be X*beta without the transformation applied.
51    fn adjust_variance<F: Float>(
52        variance: Array1<F>,
53        lin_pred: &Array1<F>,
54    ) -> Array1<F> {
55        let eta_d = Self::d_nat_param(lin_pred);
56        // The second-derivative term in the variance matrix can lead it to not
57        // be positive-definite. In fact, the second term should vanish when
58        // taking the expecation of Y to give the Fisher information.
59        // let var_adj = &eta_d * &variance * eta_d - eta_dd * errors;
60        &eta_d * &variance * eta_d
61    }
62    /// Adjust the error and variance terms of the likelihood function based on
63    /// the first and second derivatives of the transformation. The adjustment
64    /// is performed simultaneously. The linear predictor must be
65    /// un-transformed, i.e. it must be X*beta without the transformation
66    /// applied.
67    fn adjust_errors_variance<F: Float>(
68        errors: Array1<F>,
69        variance: Array1<F>,
70        lin_pred: &Array1<F>,
71    ) -> (Array1<F>, Array1<F>) {
72        let eta_d = Self::d_nat_param(lin_pred);
73        let err_adj = &eta_d * &errors;
74        // The second-derivative term in the variance matrix can lead it to not
75        // be positive-definite. In fact, the second term should vanish when
76        // taking the expecation of Y to give the Fisher information.
77        // let var_adj = &eta_d * &variance * eta_d - eta_dd * errors;
78        let var_adj = &eta_d * &variance * eta_d;
79        (err_adj, var_adj)
80    }
81}
82
83/// The canonical transformation by definition equates the linear predictor with
84/// the natural parameter of the response distribution. Implementing this trait
85/// for a link function automatically defines the trivial transformation
86/// functions.
87pub trait Canonical {}
88impl<T> Transform for T
89where
90    T: Canonical,
91{
92    /// By defintion this function is the identity function for canonical links.
93    #[inline]
94    fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F> {
95        lin_pred
96    }
97    #[inline]
98    fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F> {
99        Array1::<F>::ones(lin_pred.len())
100    }
101    /// The canonical link function requires no transformation of the error and variance terms.
102    #[inline]
103    fn adjust_errors<F: Float>(
104        errors: Array1<F>,
105        _lin_pred: &Array1<F>,
106    ) -> Array1<F> {
107        errors
108    }
109    #[inline]
110    fn adjust_variance<F: Float>(
111        variance: Array1<F>,
112        _lin_pred: &Array1<F>,
113    ) -> Array1<F> {
114        variance
115    }
116    #[inline]
117    fn adjust_errors_variance<F: Float>(
118        errors: Array1<F>,
119        variance: Array1<F>,
120        _lin_pred: &Array1<F>,
121    ) -> (Array1<F>, Array1<F>) {
122        (errors, variance)
123    }
124}