Skip to main content

ndarray_glm/
link.rs

1//! Defines traits for link functions
2
3use crate::{glm::Glm, num::Float};
4use ndarray::Array1;
5
6/// Describes the functions to map to and from the linear predictors and the
7/// expectation of the response. It is constrained mathematically by the
8/// response distribution and the transformation of the linear predictor.
9// TODO: The link function and its inverse are independent of the response
10// distribution. This could be refactored to separate the function itself from
11// the transformation that works with the distribution.
12pub trait Link<M: Glm>: Transform {
13    /// Maps the expectation value of the response variable to the linear
14    /// predictor. In general this is determined by a composition of the inverse
15    /// natural parameter transformation and the canonical link function.
16    fn func<F: Float>(y: F) -> F;
17    // fn func<F: Float>(y: Array1<F>) -> Array1<F>;
18    /// Maps the linear predictor to the expectation value of the response.
19    // TODO: There may not be a point in using Array versions of these functions
20    // since clones are necessary anyway. Perhaps we could simply define the
21    // scalar function and use mapv().
22    fn func_inv<F: Float>(lin_pred: F) -> F;
23    // fn func_inv<F: Float>(lin_pred: Array1<F>) -> Array1<F>;
24}
25
26pub trait Transform {
27    /// The natural parameter(s) of the response distribution as a function
28    /// of the linear predictor. For canonical link functions this is the
29    /// identity. It must be monotonic, invertible, and twice-differentiable.
30    /// For link function g and canonical link function g_0 it is equal to
31    /// g_0 ( g^{-1}(lin_pred) ) .
32    fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F>;
33    /// The derivative of the transformation to the natural parameter. If it is
34    /// zero in a region that the IRLS is in the algorithm may have difficulty
35    /// converging.
36    fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F>;
37    /// Adjust the error/residual terms of the likelihood function based on the first derivative of
38    /// the transformation. The linear predictor must be un-transformed, i.e. it must be X*beta
39    /// without the transformation applied.
40    fn adjust_errors<F: Float>(
41        errors: Array1<F>,
42        lin_pred: &Array1<F>,
43    ) -> Array1<F> {
44        let eta_d = Self::d_nat_param(lin_pred);
45        eta_d * errors
46    }
47    /// Adjust the variance terms of the likelihood function based on the first and second
48    /// derivatives of the transformation. The linear predictor must be un-transformed, i.e. it
49    /// must be X*beta without the transformation applied.
50    fn adjust_variance<F: Float>(
51        variance: Array1<F>,
52        lin_pred: &Array1<F>,
53    ) -> Array1<F> {
54        let eta_d = Self::d_nat_param(lin_pred);
55        // The second-derivative term in the variance matrix can lead it to not
56        // be positive-definite. In fact, the second term should vanish when
57        // taking the expecation of Y to give the Fisher information.
58        // let var_adj = &eta_d * &variance * eta_d - eta_dd * errors;
59        &eta_d * &variance * eta_d
60    }
61    /// Adjust the error and variance terms of the likelihood function based on
62    /// the first and second derivatives of the transformation. The adjustment
63    /// is performed simultaneously. The linear predictor must be
64    /// un-transformed, i.e. it must be X*beta without the transformation
65    /// applied.
66    fn adjust_errors_variance<F: Float>(
67        errors: Array1<F>,
68        variance: Array1<F>,
69        lin_pred: &Array1<F>,
70    ) -> (Array1<F>, Array1<F>) {
71        let eta_d = Self::d_nat_param(lin_pred);
72        let err_adj = &eta_d * &errors;
73        // The second-derivative term in the variance matrix can lead it to not
74        // be positive-definite. In fact, the second term should vanish when
75        // taking the expecation of Y to give the Fisher information.
76        // let var_adj = &eta_d * &variance * eta_d - eta_dd * errors;
77        let var_adj = &eta_d * &variance * eta_d;
78        (err_adj, var_adj)
79    }
80}
81
82/// The canonical transformation by definition equates the linear predictor with
83/// the natural parameter of the response distribution. Implementing this trait
84/// for a link function automatically defines the trivial transformation
85/// functions.
86pub trait Canonical {}
87impl<T> Transform for T
88where
89    T: Canonical,
90{
91    /// By defintion this function is the identity function for canonical links.
92    #[inline]
93    fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F> {
94        lin_pred
95    }
96    #[inline]
97    fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F> {
98        Array1::<F>::ones(lin_pred.len())
99    }
100    /// The canonical link function requires no transformation of the error and variance terms.
101    #[inline]
102    fn adjust_errors<F: Float>(
103        errors: Array1<F>,
104        _lin_pred: &Array1<F>,
105    ) -> Array1<F> {
106        errors
107    }
108    #[inline]
109    fn adjust_variance<F: Float>(
110        variance: Array1<F>,
111        _lin_pred: &Array1<F>,
112    ) -> Array1<F> {
113        variance
114    }
115    #[inline]
116    fn adjust_errors_variance<F: Float>(
117        errors: Array1<F>,
118        variance: Array1<F>,
119        _lin_pred: &Array1<F>,
120    ) -> (Array1<F>, Array1<F>) {
121        (errors, variance)
122    }
123}