ndarray_glm/link.rs
1//! Defines traits for link functions
2
3use crate::{glm::Glm, num::Float};
4use ndarray::Array1;
5
6/// Describes the link function $`g`$ that maps between the expected response $`\mu`$ and
7/// the linear predictor $`\omega = \mathbf{x}^\mathsf{T}\boldsymbol{\beta}`$:
8///
9/// ```math
10/// g(\mu) = \omega, \qquad \mu = g^{-1}(\omega)
11/// ```
12// TODO: The link function and its inverse are independent of the response
13// distribution. This could be refactored to separate the function itself from
14// the transformation that works with the distribution.
15pub trait Link<M: Glm>: Transform {
16 /// Maps the expectation value of the response variable to the linear
17 /// predictor. In general this is determined by a composition of the inverse
18 /// natural parameter transformation and the canonical link function.
19 fn func<F: Float>(y: F) -> F;
20 // fn func<F: Float>(y: Array1<F>) -> Array1<F>;
21 /// Maps the linear predictor to the expectation value of the response.
22 // TODO: There may not be a point in using Array versions of these functions
23 // since clones are necessary anyway. Perhaps we could simply define the
24 // scalar function and use mapv().
25 fn func_inv<F: Float>(lin_pred: F) -> F;
26 // fn func_inv<F: Float>(lin_pred: Array1<F>) -> Array1<F>;
27}
28
29pub trait Transform {
30 /// The natural parameter of the response distribution as a function
31 /// of the linear predictor: $`\eta(\omega) = g_0(g^{-1}(\omega))`$ where $`g_0`$ is the
32 /// canonical link. For canonical links this is the identity.
33 fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F>;
34 /// The derivative $`\eta'(\omega)`$ of the transformation to the natural parameter.
35 /// If it is zero in a region that the IRLS is in, the algorithm may have difficulty
36 /// converging.
37 fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F>;
38 /// Adjust the error/residual terms of the likelihood function based on the first derivative of
39 /// the transformation. The linear predictor must be un-transformed, i.e. it must be X*beta
40 /// without the transformation applied.
41 fn adjust_errors<F: Float>(
42 errors: Array1<F>,
43 lin_pred: &Array1<F>,
44 ) -> Array1<F> {
45 let eta_d = Self::d_nat_param(lin_pred);
46 eta_d * errors
47 }
48 /// Adjust the variance terms of the likelihood function based on the first and second
49 /// derivatives of the transformation. The linear predictor must be un-transformed, i.e. it
50 /// must be X*beta without the transformation applied.
51 fn adjust_variance<F: Float>(
52 variance: Array1<F>,
53 lin_pred: &Array1<F>,
54 ) -> Array1<F> {
55 let eta_d = Self::d_nat_param(lin_pred);
56 // The second-derivative term in the variance matrix can lead it to not
57 // be positive-definite. In fact, the second term should vanish when
58 // taking the expecation of Y to give the Fisher information.
59 // let var_adj = &eta_d * &variance * eta_d - eta_dd * errors;
60 &eta_d * &variance * eta_d
61 }
62 /// Adjust the error and variance terms of the likelihood function based on
63 /// the first and second derivatives of the transformation. The adjustment
64 /// is performed simultaneously. The linear predictor must be
65 /// un-transformed, i.e. it must be X*beta without the transformation
66 /// applied.
67 fn adjust_errors_variance<F: Float>(
68 errors: Array1<F>,
69 variance: Array1<F>,
70 lin_pred: &Array1<F>,
71 ) -> (Array1<F>, Array1<F>) {
72 let eta_d = Self::d_nat_param(lin_pred);
73 let err_adj = &eta_d * &errors;
74 // The second-derivative term in the variance matrix can lead it to not
75 // be positive-definite. In fact, the second term should vanish when
76 // taking the expecation of Y to give the Fisher information.
77 // let var_adj = &eta_d * &variance * eta_d - eta_dd * errors;
78 let var_adj = &eta_d * &variance * eta_d;
79 (err_adj, var_adj)
80 }
81}
82
83/// The canonical transformation by definition equates the linear predictor with
84/// the natural parameter of the response distribution. Implementing this trait
85/// for a link function automatically defines the trivial transformation
86/// functions.
87pub trait Canonical {}
88impl<T> Transform for T
89where
90 T: Canonical,
91{
92 /// By defintion this function is the identity function for canonical links.
93 #[inline]
94 fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F> {
95 lin_pred
96 }
97 #[inline]
98 fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F> {
99 Array1::<F>::ones(lin_pred.len())
100 }
101 /// The canonical link function requires no transformation of the error and variance terms.
102 #[inline]
103 fn adjust_errors<F: Float>(
104 errors: Array1<F>,
105 _lin_pred: &Array1<F>,
106 ) -> Array1<F> {
107 errors
108 }
109 #[inline]
110 fn adjust_variance<F: Float>(
111 variance: Array1<F>,
112 _lin_pred: &Array1<F>,
113 ) -> Array1<F> {
114 variance
115 }
116 #[inline]
117 fn adjust_errors_variance<F: Float>(
118 errors: Array1<F>,
119 variance: Array1<F>,
120 _lin_pred: &Array1<F>,
121 ) -> (Array1<F>, Array1<F>) {
122 (errors, variance)
123 }
124}