/// A link function `g`, described through its inverse `g⁻¹` and the
/// derivative of that inverse.
///
/// Both methods are associated functions (they take no `self`), so a link is
/// chosen purely at the type level, e.g. `Log::inverse(eta)`.
pub trait Link<S> {
    /// Evaluates the inverse link, returning `g⁻¹(eta)`.
    fn inverse(eta: S) -> S;

    /// Evaluates `d g⁻¹(eta) / d eta`, the slope of the inverse link at `eta`.
    fn derivative_inverse(eta: S) -> S;
}

/// Marker trait for [`Link`]s whose inverse is intended to yield positive
/// values.
///
/// It adds no methods, and the compiler does not verify the property; each
/// implementor vouches for it.
pub trait PositiveLink<S>: Link<S> {}
14
15#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
17pub struct Identity;
18
19impl Link<f64> for Identity {
20 #[inline(always)]
21 fn inverse(eta: f64) -> f64 {
22 eta
23 }
24
25 #[inline(always)]
26 fn derivative_inverse(_: f64) -> f64 {
27 1.0
28 }
29}
30
31#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
33pub struct Log;
34
35impl Link<f64> for Log {
36 #[inline(always)]
37 fn inverse(eta: f64) -> f64 {
38 eta.exp()
39 }
40
41 #[inline(always)]
42 fn derivative_inverse(eta: f64) -> f64 {
43 eta.exp()
44 }
45}
46
47impl PositiveLink<f64> for Log {}
48
49#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
51pub struct Softplus;
52
53impl Link<f64> for Softplus {
54 #[inline(always)]
55 fn inverse(eta: f64) -> f64 {
56 if eta > 30.0 {
57 eta
58 } else if eta < -30.0 {
59 eta.exp()
60 } else {
61 eta.exp().ln_1p()
62 }
63 }
64
65 #[inline(always)]
66 fn derivative_inverse(eta: f64) -> f64 {
67 if eta >= 0.0 {
68 1.0 / (1.0 + (-eta).exp())
69 } else {
70 let exp_eta = eta.exp();
71 exp_eta / (1.0 + exp_eta)
72 }
73 }
74}
75
76impl PositiveLink<f64> for Softplus {}
77
78#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
80pub struct Logit;
81
82impl Link<f64> for Logit {
83 #[inline(always)]
84 fn inverse(eta: f64) -> f64 {
85 if eta >= 0.0 {
86 let z = (-eta).exp();
87 1.0 / (1.0 + z)
88 } else {
89 let z = eta.exp();
90 z / (1.0 + z)
91 }
92 }
93
94 #[inline(always)]
95 fn derivative_inverse(eta: f64) -> f64 {
96 let p = Self::inverse(eta);
97 p * (1.0 - p)
98 }
99}
100
101impl PositiveLink<f64> for Logit {}
102
103#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
105pub struct LogPlus<const OFFSET: i64>;
106
107impl<const OFFSET: i64> Link<f64> for LogPlus<OFFSET> {
108 #[inline(always)]
109 fn inverse(eta: f64) -> f64 {
110 OFFSET as f64 + eta.exp()
111 }
112
113 #[inline(always)]
114 fn derivative_inverse(eta: f64) -> f64 {
115 eta.exp()
116 }
117}
118
119impl<const OFFSET: i64> PositiveLink<f64> for LogPlus<OFFSET> {}
120
121#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
126pub struct ClampedLog<const MIN: i64, const MAX: i64>;
127
128impl<const MIN: i64, const MAX: i64> Link<f64> for ClampedLog<MIN, MAX> {
129 #[inline(always)]
130 fn inverse(eta: f64) -> f64 {
131 let min = MIN as f64;
132 let max = MAX as f64;
133 debug_assert!(min <= max);
134
135 if eta < min {
136 min.exp()
137 } else if eta > max {
138 max.exp()
139 } else {
140 eta.exp()
141 }
142 }
143
144 #[inline(always)]
145 fn derivative_inverse(eta: f64) -> f64 {
146 let min = MIN as f64;
147 let max = MAX as f64;
148 debug_assert!(min <= max);
149
150 if (min..=max).contains(&eta) {
151 eta.exp()
152 } else {
153 0.0
154 }
155 }
156}
157
158impl<const MIN: i64, const MAX: i64> PositiveLink<f64> for ClampedLog<MIN, MAX> {}
159
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use crate::{ClampedLog, Link};

    /// Clamp window `[-2, 2]` exercised by the checks below.
    type LinkUnderTest = ClampedLog<-2, 2>;

    #[test]
    fn clamped_log_clamps_value_and_derivative() {
        // (eta, the eta that should actually be exponentiated after clamping)
        let value_cases = [(-3.0, -2.0_f64), (1.0, 1.0), (3.0, 2.0)];
        for &(eta, effective_eta) in value_cases.iter() {
            assert_relative_eq!(LinkUnderTest::inverse(eta), effective_eta.exp());
        }

        // Outside the window the inverse is flat, so the slope is exactly zero.
        for &eta in [-3.0, 3.0].iter() {
            assert_eq!(LinkUnderTest::derivative_inverse(eta), 0.0);
        }

        // Inside the window the slope is that of plain `exp`.
        assert_relative_eq!(LinkUnderTest::derivative_inverse(1.0), 1.0_f64.exp());
    }
}
178}