Skip to main content

gamlss_core/
link.rs

1/// Link-функция, представленная через обратное преобразование.
2///
3/// В likelihood hot path используется `inverse(eta)`, где `eta` находится на
4/// линейной шкале predictor-а, и производная обратной функции для chain rule.
5pub trait Link<S> {
6    /// Переводит значение с predictor-шкалы на шкалу параметра.
7    fn inverse(eta: S) -> S;
8    /// Производная `inverse` по `eta`.
9    fn derivative_inverse(eta: S) -> S;
10}
11
12/// Маркер для link-функций, гарантирующих положительный результат.
13pub trait PositiveLink<S>: Link<S> {}
14
15/// Identity link: `theta = eta`.
16#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
17pub struct Identity;
18
19impl Link<f64> for Identity {
20    #[inline(always)]
21    fn inverse(eta: f64) -> f64 {
22        eta
23    }
24
25    #[inline(always)]
26    fn derivative_inverse(_: f64) -> f64 {
27        1.0
28    }
29}
30
31/// Log link: `theta = exp(eta)`.
32#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
33pub struct Log;
34
35impl Link<f64> for Log {
36    #[inline(always)]
37    fn inverse(eta: f64) -> f64 {
38        eta.exp()
39    }
40
41    #[inline(always)]
42    fn derivative_inverse(eta: f64) -> f64 {
43        eta.exp()
44    }
45}
46
47impl PositiveLink<f64> for Log {}
48
49/// Численно устойчивый positive link: `theta = ln(1 + exp(eta))`.
50#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
51pub struct Softplus;
52
53impl Link<f64> for Softplus {
54    #[inline(always)]
55    fn inverse(eta: f64) -> f64 {
56        if eta > 30.0 {
57            eta
58        } else if eta < -30.0 {
59            eta.exp()
60        } else {
61            eta.exp().ln_1p()
62        }
63    }
64
65    #[inline(always)]
66    fn derivative_inverse(eta: f64) -> f64 {
67        if eta >= 0.0 {
68            1.0 / (1.0 + (-eta).exp())
69        } else {
70            let exp_eta = eta.exp();
71            exp_eta / (1.0 + exp_eta)
72        }
73    }
74}
75
76impl PositiveLink<f64> for Softplus {}
77
78/// Обратный logit link: `theta` лежит в интервале `(0, 1)`.
79#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
80pub struct Logit;
81
82impl Link<f64> for Logit {
83    #[inline(always)]
84    fn inverse(eta: f64) -> f64 {
85        if eta >= 0.0 {
86            let z = (-eta).exp();
87            1.0 / (1.0 + z)
88        } else {
89            let z = eta.exp();
90            z / (1.0 + z)
91        }
92    }
93
94    #[inline(always)]
95    fn derivative_inverse(eta: f64) -> f64 {
96        let p = Self::inverse(eta);
97        p * (1.0 - p)
98    }
99}
100
101impl PositiveLink<f64> for Logit {}
102
103/// Сдвинутый log link: `theta = OFFSET + exp(eta)`.
104#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
105pub struct LogPlus<const OFFSET: i64>;
106
107impl<const OFFSET: i64> Link<f64> for LogPlus<OFFSET> {
108    #[inline(always)]
109    fn inverse(eta: f64) -> f64 {
110        OFFSET as f64 + eta.exp()
111    }
112
113    #[inline(always)]
114    fn derivative_inverse(eta: f64) -> f64 {
115        eta.exp()
116    }
117}
118
119impl<const OFFSET: i64> PositiveLink<f64> for LogPlus<OFFSET> {}
120
121/// Clamped log link: `theta = exp(clamp(eta, MIN, MAX))`.
122///
123/// The derivative is zero outside the active interval, matching a hard clamp
124/// on the predictor scale. Callers should choose `MIN <= MAX`.
125#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
126pub struct ClampedLog<const MIN: i64, const MAX: i64>;
127
128impl<const MIN: i64, const MAX: i64> Link<f64> for ClampedLog<MIN, MAX> {
129    #[inline(always)]
130    fn inverse(eta: f64) -> f64 {
131        let min = MIN as f64;
132        let max = MAX as f64;
133        debug_assert!(min <= max);
134
135        if eta < min {
136            min.exp()
137        } else if eta > max {
138            max.exp()
139        } else {
140            eta.exp()
141        }
142    }
143
144    #[inline(always)]
145    fn derivative_inverse(eta: f64) -> f64 {
146        let min = MIN as f64;
147        let max = MAX as f64;
148        debug_assert!(min <= max);
149
150        if (min..=max).contains(&eta) {
151            eta.exp()
152        } else {
153            0.0
154        }
155    }
156}
157
158impl<const MIN: i64, const MAX: i64> PositiveLink<f64> for ClampedLog<MIN, MAX> {}
159
160#[cfg(test)]
161mod tests {
162    use approx::assert_relative_eq;
163
164    use crate::{ClampedLog, Link};
165
166    #[test]
167    fn clamped_log_clamps_value_and_derivative() {
168        type LinkUnderTest = ClampedLog<-2, 2>;
169
170        assert_relative_eq!(LinkUnderTest::inverse(-3.0), (-2.0_f64).exp());
171        assert_relative_eq!(LinkUnderTest::inverse(1.0), 1.0_f64.exp());
172        assert_relative_eq!(LinkUnderTest::inverse(3.0), 2.0_f64.exp());
173
174        assert_eq!(LinkUnderTest::derivative_inverse(-3.0), 0.0);
175        assert_relative_eq!(LinkUnderTest::derivative_inverse(1.0), 1.0_f64.exp());
176        assert_eq!(LinkUnderTest::derivative_inverse(3.0), 0.0);
177    }
178}