1use ndarray::{Array1, Array2};
9
10pub trait Density: Clone + Send + Sync {
15 fn log_lik(&self, y: &Array1<f64>) -> Array1<f64>;
17
18 fn score_and_der(&self, y: &Array2<f64>) -> (Array2<f64>, Array2<f64>);
22}
23
/// Density model whose score function is `tanh(alpha * y)`.
#[derive(Debug, Clone)]
pub struct Tanh {
    /// Scale applied to the input inside `tanh`; also used as a divisor by
    /// `log_lik`, so a value of zero yields non-finite results there.
    pub alpha: f64,
}
35
36impl Default for Tanh {
37 fn default() -> Self {
38 Self { alpha: 1.0 }
39 }
40}
41
42impl Tanh {
43 pub fn new(alpha: f64) -> Self {
45 Self { alpha }
46 }
47}
48
49impl Density for Tanh {
50 fn log_lik(&self, y: &Array1<f64>) -> Array1<f64> {
51 let alpha = self.alpha;
52 y.mapv(|v| {
53 let abs_y = v.abs();
54 abs_y + (1.0 + (-2.0 * alpha * abs_y).exp()).ln() / alpha
55 })
56 }
57
58 fn score_and_der(&self, y: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
59 let alpha = self.alpha;
60 let score = y.mapv(|v| (alpha * v).tanh());
61 let score_der = score.mapv(|s| alpha * (1.0 - s * s));
62 (score, score_der)
63 }
64}
65
/// Density model whose score function is `y * exp(-alpha * y^2 / 2)`.
#[derive(Debug, Clone)]
pub struct Exp {
    /// Scale of the squared input inside the exponential; also used as a
    /// divisor by `log_lik`, so a value of zero yields non-finite results
    /// there.
    pub alpha: f64,
}
76
77impl Default for Exp {
78 fn default() -> Self {
79 Self { alpha: 1.0 }
80 }
81}
82
83impl Exp {
84 pub fn new(alpha: f64) -> Self {
86 Self { alpha }
87 }
88}
89
90impl Density for Exp {
91 fn log_lik(&self, y: &Array1<f64>) -> Array1<f64> {
92 let a = self.alpha;
93 y.mapv(|v| -(-a * v * v / 2.0).exp() / a)
94 }
95
96 fn score_and_der(&self, y: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
97 let a = self.alpha;
98 let y_sq = y.mapv(|v| v * v);
99 let k = y_sq.mapv(|v| (-a / 2.0 * v).exp());
100 let score = y * &k;
101 let score_der = (1.0 - a * &y_sq) * k;
102 (score, score_der)
103 }
104}
105
/// Parameter-free density model whose score function is `y^3`.
#[derive(Default, Debug, Clone)]
pub struct Cube;
113
114impl Cube {
115 pub fn new() -> Self {
117 Self
118 }
119}
120
121impl Density for Cube {
122 fn log_lik(&self, y: &Array1<f64>) -> Array1<f64> {
123 y.mapv(|v| v.powi(4) / 4.0)
124 }
125
126 fn score_and_der(&self, y: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
127 let score = y.mapv(|v| v.powi(3));
128 let score_der = y.mapv(|v| 3.0 * v * v);
129 (score, score_der)
130 }
131}
132
133#[derive(Clone, Debug)]
137pub enum DensityType {
138 Tanh(Tanh),
140 Exp(Exp),
142 Cube(Cube),
144}
145
146impl Default for DensityType {
147 fn default() -> Self {
148 DensityType::Tanh(Tanh::default())
149 }
150}
151
152impl DensityType {
153 pub fn tanh() -> Self {
155 DensityType::Tanh(Tanh::default())
156 }
157
158 pub fn tanh_with_alpha(alpha: f64) -> Self {
160 DensityType::Tanh(Tanh::new(alpha))
161 }
162
163 pub fn exp() -> Self {
165 DensityType::Exp(Exp::default())
166 }
167
168 pub fn exp_with_alpha(alpha: f64) -> Self {
170 DensityType::Exp(Exp::new(alpha))
171 }
172
173 pub fn cube() -> Self {
175 DensityType::Cube(Cube::new())
176 }
177
178 pub fn log_lik(&self, y: &Array1<f64>) -> Array1<f64> {
180 match self {
181 DensityType::Tanh(d) => d.log_lik(y),
182 DensityType::Exp(d) => d.log_lik(y),
183 DensityType::Cube(d) => d.log_lik(y),
184 }
185 }
186
187 pub fn score_and_der(&self, y: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
189 match self {
190 DensityType::Tanh(d) => d.score_and_der(y),
191 DensityType::Exp(d) => d.score_and_der(y),
192 DensityType::Cube(d) => d.score_and_der(y),
193 }
194 }
195}