ndarray_glm/link.rs
1//! Defines traits for link functions
2
3use crate::{glm::Glm, num::Float};
4use ndarray::Array1;
5
6/// Describes the link function $`g`$ that maps between the expected response $`\mu`$ and
7/// the linear predictor $`\omega = \mathbf{x}^\mathsf{T}\boldsymbol{\beta}`$:
8///
9/// ```math
10/// g(\mu) = \omega, \qquad \mu = g^{-1}(\omega)
11/// ```
12pub trait Link<M: Glm>: Transform {
13 /// Maps the expectation value of the response variable to the linear
14 /// predictor. In general this is determined by a composition of the inverse
15 /// natural parameter transformation and the canonical link function.
16 fn func<F: Float>(y: F) -> F;
17 /// Maps the linear predictor to the expectation value of the response.
18 fn func_inv<F: Float>(lin_pred: F) -> F;
19}
20
21pub trait Transform {
22 /// The natural parameter of the response distribution as a function
23 /// of the linear predictor: $`\eta(\omega) = g_0(g^{-1}(\omega))`$ where $`g_0`$ is the
24 /// canonical link. For canonical links this is the identity.
25 fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F>;
26 /// The derivative $`\eta'(\omega)`$ of the transformation to the natural parameter.
27 /// If it is zero in a region that the IRLS is in, the algorithm may have difficulty
28 /// converging.
29 fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F>;
30 /// Adjust the error/residual terms of the likelihood function based on the first derivative of
31 /// the transformation. The linear predictor must be un-transformed, i.e. it must be X*beta
32 /// without the transformation applied.
33 fn adjust_errors<F: Float>(errors: Array1<F>, lin_pred: &Array1<F>) -> Array1<F> {
34 let eta_d = Self::d_nat_param(lin_pred);
35 eta_d * errors
36 }
37 /// Adjust the variance terms of the likelihood function based on the first and second
38 /// derivatives of the transformation. The linear predictor must be un-transformed, i.e. it
39 /// must be X*beta without the transformation applied.
40 fn adjust_variance<F: Float>(variance: Array1<F>, lin_pred: &Array1<F>) -> Array1<F> {
41 let eta_d = Self::d_nat_param(lin_pred);
42 // The second-derivative term in the variance matrix can lead it to not
43 // be positive-definite. In fact, the second term should vanish when
44 // taking the expecation of Y to give the Fisher information.
45 // let var_adj = &eta_d * &variance * eta_d - eta_dd * errors;
46 &eta_d * &variance * eta_d
47 }
48 /// Adjust the error and variance terms of the likelihood function based on
49 /// the first and second derivatives of the transformation. The adjustment
50 /// is performed simultaneously. The linear predictor must be
51 /// un-transformed, i.e. it must be X*beta without the transformation
52 /// applied.
53 fn adjust_errors_variance<F: Float>(
54 errors: Array1<F>,
55 variance: Array1<F>,
56 lin_pred: &Array1<F>,
57 ) -> (Array1<F>, Array1<F>) {
58 let eta_d = Self::d_nat_param(lin_pred);
59 let err_adj = &eta_d * &errors;
60 // The second-derivative term in the variance matrix can lead it to not
61 // be positive-definite. In fact, the second term should vanish when
62 // taking the expecation of Y to give the Fisher information.
63 // let var_adj = &eta_d * &variance * eta_d - eta_dd * errors;
64 let var_adj = &eta_d * &variance * eta_d;
65 (err_adj, var_adj)
66 }
67}
68
/// Marker for canonical link functions.
///
/// A canonical transformation is, by definition, one for which the linear predictor equals
/// the natural parameter of the response distribution. Implementing this trait for a link
/// function is enough to obtain the trivial [`Transform`] functions automatically.
pub trait Canonical {}
74impl<T> Transform for T
75where
76 T: Canonical,
77{
78 /// By defintion this function is the identity function for canonical links.
79 #[inline]
80 fn nat_param<F: Float>(lin_pred: Array1<F>) -> Array1<F> {
81 lin_pred
82 }
83 #[inline]
84 fn d_nat_param<F: Float>(lin_pred: &Array1<F>) -> Array1<F> {
85 Array1::<F>::ones(lin_pred.len())
86 }
87 /// The canonical link function requires no transformation of the error and variance terms.
88 #[inline]
89 fn adjust_errors<F: Float>(errors: Array1<F>, _lin_pred: &Array1<F>) -> Array1<F> {
90 errors
91 }
92 #[inline]
93 fn adjust_variance<F: Float>(variance: Array1<F>, _lin_pred: &Array1<F>) -> Array1<F> {
94 variance
95 }
96 #[inline]
97 fn adjust_errors_variance<F: Float>(
98 errors: Array1<F>,
99 variance: Array1<F>,
100 _lin_pred: &Array1<F>,
101 ) -> (Array1<F>, Array1<F>) {
102 (errors, variance)
103 }
104}