1use crate::dist::{Distribution, DistributionMethods, RegressionDistn};
2use crate::scores::{
3 CRPScore, CRPScoreCensored, CensoredScorable, LogScore, LogScoreCensored, Scorable,
4 SurvivalData,
5};
6use ndarray::{array, Array1, Array2, Array3};
7use rand::prelude::*;
8use statrs::distribution::{ContinuousCDF, Exp};
9
10#[derive(Debug, Clone)]
12pub struct Exponential {
13 pub rate: Array1<f64>,
15 pub scale: Array1<f64>,
17 _params: Array2<f64>,
18}
19
20impl Distribution for Exponential {
21 fn from_params(params: &Array2<f64>) -> Self {
22 let scale = params.column(0).mapv(f64::exp);
24 let rate = 1.0 / &scale;
25 Exponential {
26 rate,
27 scale,
28 _params: params.clone(),
29 }
30 }
31
32 fn fit(y: &Array1<f64>) -> Array1<f64> {
33 let mean = y.mean().unwrap_or(1.0);
34 array![mean.ln()]
36 }
37
38 fn n_params(&self) -> usize {
39 1
40 }
41
42 fn predict(&self) -> Array1<f64> {
43 1.0 / &self.rate
45 }
46
47 fn params(&self) -> &Array2<f64> {
48 &self._params
49 }
50}
51
52impl RegressionDistn for Exponential {}
53
54impl DistributionMethods for Exponential {
55 fn mean(&self) -> Array1<f64> {
56 self.scale.clone()
58 }
59
60 fn variance(&self) -> Array1<f64> {
61 &self.scale * &self.scale
63 }
64
65 fn std(&self) -> Array1<f64> {
66 self.scale.clone()
68 }
69
70 fn pdf(&self, y: &Array1<f64>) -> Array1<f64> {
71 let mut result = Array1::zeros(y.len());
72 for i in 0..y.len() {
73 if y[i] >= 0.0 {
74 result[i] = self.rate[i] * (-self.rate[i] * y[i]).exp();
76 }
77 }
78 result
79 }
80
81 fn logpdf(&self, y: &Array1<f64>) -> Array1<f64> {
82 let mut result = Array1::zeros(y.len());
83 for i in 0..y.len() {
84 if y[i] >= 0.0 {
85 result[i] = self.rate[i].ln() - self.rate[i] * y[i];
87 } else {
88 result[i] = f64::NEG_INFINITY;
89 }
90 }
91 result
92 }
93
94 fn cdf(&self, y: &Array1<f64>) -> Array1<f64> {
95 let mut result = Array1::zeros(y.len());
96 for i in 0..y.len() {
97 if y[i] >= 0.0 {
98 result[i] = 1.0 - (-self.rate[i] * y[i]).exp();
100 }
101 }
102 result
103 }
104
105 fn ppf(&self, q: &Array1<f64>) -> Array1<f64> {
106 let mut result = Array1::zeros(q.len());
107 for i in 0..q.len() {
108 let q_clamped = q[i].clamp(1e-15, 1.0 - 1e-15);
110 result[i] = -(1.0 - q_clamped).ln() / self.rate[i];
111 }
112 result
113 }
114
115 fn sample(&self, n_samples: usize) -> Array2<f64> {
116 let n_obs = self.scale.len();
117 let mut samples = Array2::zeros((n_samples, n_obs));
118 let mut rng = rand::rng();
119
120 for i in 0..n_obs {
121 if let Ok(d) = Exp::new(self.rate[i]) {
122 for s in 0..n_samples {
123 let u: f64 = rng.random();
124 samples[[s, i]] = d.inverse_cdf(u);
125 }
126 }
127 }
128 samples
129 }
130
131 fn median(&self) -> Array1<f64> {
132 self.scale.mapv(|s| s * std::f64::consts::LN_2)
134 }
135
136 fn mode(&self) -> Array1<f64> {
137 Array1::zeros(self.scale.len())
139 }
140}
141
142impl Scorable<LogScore> for Exponential {
143 fn score(&self, y: &Array1<f64>) -> Array1<f64> {
144 let mut scores = Array1::zeros(y.len());
146 for (i, &y_i) in y.iter().enumerate() {
147 scores[i] = self.scale[i].ln() + y_i / self.scale[i];
148 }
149 scores
150 }
151
152 fn d_score(&self, y: &Array1<f64>) -> Array2<f64> {
153 let n_obs = y.len();
158 let mut d_params = Array2::zeros((n_obs, 1));
159
160 for i in 0..n_obs {
161 d_params[[i, 0]] = 1.0 - y[i] / self.scale[i];
162 }
163
164 d_params
165 }
166
167 fn metric(&self) -> Array3<f64> {
168 let n_obs = self.scale.len();
169 let mut fi = Array3::zeros((n_obs, 1, 1));
170
171 for i in 0..n_obs {
172 fi[[i, 0, 0]] = 1.0;
173 }
174
175 fi
176 }
177}
178
179impl Scorable<CRPScore> for Exponential {
180 fn score(&self, y: &Array1<f64>) -> Array1<f64> {
181 let mut scores = Array1::zeros(y.len());
185 for i in 0..y.len() {
186 let exp_term = (-y[i] / self.scale[i]).exp();
187 scores[i] = y[i] + self.scale[i] * (2.0 * exp_term - 1.5);
188 }
189 scores
190 }
191
192 fn d_score(&self, y: &Array1<f64>) -> Array2<f64> {
193 let n_obs = y.len();
200 let mut d_params = Array2::zeros((n_obs, 1));
201
202 for i in 0..n_obs {
203 let exp_term = (-y[i] / self.scale[i]).exp();
204 d_params[[i, 0]] = 2.0 * exp_term * (y[i] + self.scale[i]) - 1.5 * self.scale[i];
205 }
206
207 d_params
208 }
209
210 fn metric(&self) -> Array3<f64> {
211 let n_obs = self.scale.len();
213 let mut fi = Array3::zeros((n_obs, 1, 1));
214
215 for i in 0..n_obs {
216 fi[[i, 0, 0]] = 0.5 * self.scale[i];
217 }
218
219 fi
220 }
221}
222
223impl CensoredScorable<LogScoreCensored> for Exponential {
228 fn censored_score(&self, y: &SurvivalData) -> Array1<f64> {
229 let eps = 1e-10;
230 let mut scores = Array1::zeros(y.len());
231
232 for i in 0..y.len() {
233 let t = y.time[i];
234 let e = y.event[i];
235 let d = Exp::new(self.rate[i]).unwrap();
237
238 if e {
239 scores[i] = self.scale[i].ln() + t / self.scale[i];
241 } else {
242 let survival = 1.0 - d.cdf(t) + eps;
244 scores[i] = -survival.ln();
245 }
246 }
247 scores
248 }
249
250 fn censored_d_score(&self, y: &SurvivalData) -> Array2<f64> {
251 let n_obs = y.len();
252 let mut d_params = Array2::zeros((n_obs, 1));
253
254 for i in 0..n_obs {
255 let t = y.time[i];
256 let e = y.event[i];
257
258 if e {
259 d_params[[i, 0]] = 1.0 - t / self.scale[i];
261 } else {
262 d_params[[i, 0]] = t / self.scale[i];
265 }
266 d_params[[i, 0]] = -d_params[[i, 0]];
268 }
269 d_params
270 }
271
272 fn censored_metric(&self) -> Array3<f64> {
273 let n_obs = self.scale.len();
274 let mut fi = Array3::zeros((n_obs, 1, 1));
275
276 for i in 0..n_obs {
277 fi[[i, 0, 0]] = 1.0;
278 }
279
280 fi
281 }
282}
283
284impl CensoredScorable<CRPScoreCensored> for Exponential {
289 fn censored_score(&self, y: &SurvivalData) -> Array1<f64> {
290 let mut scores = Array1::zeros(y.len());
291
292 for i in 0..y.len() {
293 let t = y.time[i];
294 let e = y.event[i];
295 let exp_term = (-t / self.scale[i]).exp();
296
297 scores[i] = t + self.scale[i] * (2.0 * exp_term - 1.5);
299
300 if e {
301 let exp_2t = (-2.0 * t / self.scale[i]).exp();
303 scores[i] -= 0.5 * self.scale[i] * exp_2t;
304 }
305 }
306 scores
307 }
308
309 fn censored_d_score(&self, y: &SurvivalData) -> Array2<f64> {
310 let n_obs = y.len();
311 let mut d_params = Array2::zeros((n_obs, 1));
312
313 for i in 0..n_obs {
314 let t = y.time[i];
315 let e = y.event[i];
316 let exp_term = (-t / self.scale[i]).exp();
317
318 d_params[[i, 0]] = 2.0 * exp_term * (t + self.scale[i]) - 1.5 * self.scale[i];
320
321 if e {
322 let exp_2t = (-2.0 * t / self.scale[i]).exp();
324 d_params[[i, 0]] -= exp_2t * (0.5 * self.scale[i] - t);
325 }
326 }
327 d_params
328 }
329
330 fn censored_metric(&self) -> Array3<f64> {
331 let n_obs = self.scale.len();
332 let mut fi = Array3::zeros((n_obs, 1, 1));
333
334 for i in 0..n_obs {
335 fi[[i, 0, 0]] = 0.5 * self.scale[i];
336 }
337
338 fi
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Central finite-difference approximation of `f'(x)` with step `h`.
    fn central_diff(f: impl Fn(f64) -> f64, x: f64, h: f64) -> f64 {
        (f(x + h) - f(x - h)) / (2.0 * h)
    }

    /// Builds a single-observation distribution with log-scale `theta`.
    fn single(theta: f64) -> Exponential {
        Exponential::from_params(&Array2::from_elem((1, 1), theta))
    }

    #[test]
    fn test_exponential_distribution_methods() {
        let params = Array2::from_shape_vec((2, 1), vec![0.0, 1.0_f64.ln()]).unwrap();
        let dist = Exponential::from_params(&params);

        let mean = dist.mean();
        assert_relative_eq!(mean[0], 1.0, epsilon = 1e-10);
        assert_relative_eq!(mean[1], 1.0, epsilon = 1e-10);

        let var = dist.variance();
        assert_relative_eq!(var[0], 1.0, epsilon = 1e-10);
        assert_relative_eq!(var[1], 1.0, epsilon = 1e-10);

        let mode = dist.mode();
        assert_relative_eq!(mode[0], 0.0, epsilon = 1e-10);
        assert_relative_eq!(mode[1], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_exponential_cdf_ppf() {
        let params = Array2::from_shape_vec((1, 1), vec![0.0]).unwrap();
        let dist = Exponential::from_params(&params);

        let y = Array1::from_vec(vec![1.0]);
        let cdf = dist.cdf(&y);
        assert_relative_eq!(cdf[0], 1.0 - (-1.0_f64).exp(), epsilon = 1e-10);

        let q = Array1::from_vec(vec![0.5]);
        let ppf = dist.ppf(&q);
        let cdf_of_ppf = dist.cdf(&ppf);
        assert_relative_eq!(cdf_of_ppf[0], 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_exponential_sample() {
        let params = Array2::from_shape_vec((1, 1), vec![2.0_f64.ln()]).unwrap();
        let dist = Exponential::from_params(&params);

        let samples = dist.sample(1000);
        assert_eq!(samples.shape(), &[1000, 1]);

        assert!(samples.iter().all(|&x| x >= 0.0));

        let sample_mean: f64 = samples.column(0).mean().unwrap();
        assert!((sample_mean - 2.0).abs() < 0.3);
    }

    #[test]
    fn test_exponential_median() {
        let params = Array2::from_shape_vec((1, 1), vec![0.0]).unwrap();
        let dist = Exponential::from_params(&params);

        let median = dist.median();
        assert_relative_eq!(median[0], std::f64::consts::LN_2, epsilon = 1e-10);
    }

    #[test]
    fn test_exponential_fit() {
        let y = Array1::from_vec(vec![0.5, 1.0, 1.5, 2.0, 2.5]);
        let params = Exponential::fit(&y);
        assert_eq!(params.len(), 1);
        assert_relative_eq!(params[0], 1.5_f64.ln(), epsilon = 1e-10);
    }

    #[test]
    fn test_log_score_gradient_matches_finite_difference() {
        let y = Array1::from_vec(vec![0.7]);
        let theta = 0.3_f64;
        let score_at =
            |t: f64| <Exponential as Scorable<LogScore>>::score(&single(t), &y)[0];
        let analytic =
            <Exponential as Scorable<LogScore>>::d_score(&single(theta), &y)[[0, 0]];
        assert_relative_eq!(analytic, central_diff(score_at, theta, 1e-6), epsilon = 1e-6);
    }

    #[test]
    fn test_crps_gradient_matches_finite_difference() {
        let y = Array1::from_vec(vec![0.7]);
        let theta = 0.3_f64;
        let score_at =
            |t: f64| <Exponential as Scorable<CRPScore>>::score(&single(t), &y)[0];
        let analytic =
            <Exponential as Scorable<CRPScore>>::d_score(&single(theta), &y)[[0, 0]];
        assert_relative_eq!(analytic, central_diff(score_at, theta, 1e-6), epsilon = 1e-6);
    }
}