ngboost_rs/dist/weibull.rs

use crate::dist::{Distribution, DistributionMethods, RegressionDistn};
use crate::scores::{
    CRPScore, CRPScoreCensored, CensoredScorable, LogScore, LogScoreCensored, Scorable,
    SurvivalData,
};
use ndarray::{array, Array1, Array2, Array3};
use rand::prelude::*;
use statrs::distribution::{Continuous, ContinuousCDF, Weibull as WeibullDist};
use statrs::function::gamma::{digamma, gamma, gamma_lr};
10
/// The Weibull distribution, with one shape `k` and one scale `λ` per observation.
///
/// Instances are built by `from_params`, which takes an `(n_obs, 2)` matrix whose column 0
/// is `log(k)` and column 1 is `log(λ)`. The natural-scale values are cached in `shape` and
/// `scale`; the raw log-scale matrix is kept in `_params` and returned by `params()`.
#[derive(Debug, Clone)]
pub struct Weibull {
    /// The shape parameter (k or c), one entry per observation.
    pub shape: Array1<f64>,
    /// The scale parameter (lambda), one entry per observation.
    pub scale: Array1<f64>,
    /// The parameters of the distribution, stored as a 2D array
    /// (column 0 = log shape, column 1 = log scale).
    _params: Array2<f64>,
}
21
22impl Distribution for Weibull {
23    fn from_params(params: &Array2<f64>) -> Self {
24        let shape = params.column(0).mapv(f64::exp);
25        let scale = params.column(1).mapv(f64::exp);
26        Weibull {
27            shape,
28            scale,
29            _params: params.clone(),
30        }
31    }
32
33    fn fit(y: &Array1<f64>) -> Array1<f64> {
34        // Simple method of moments estimation for Weibull
35        // This is approximate; proper MLE requires numerical optimization
36        let n = y.len();
37        if n == 0 {
38            return array![0.0, 0.0];
39        }
40
41        let mean = y.mean().unwrap_or(1.0);
42        let var = y.var(0.0);
43
44        // Coefficient of variation
45        let cv = (var.sqrt() / mean).clamp(0.1, 10.0);
46
47        // Approximate shape from CV (using approximation k ≈ 1.2 / CV)
48        let shape = (1.2 / cv).max(0.1);
49
50        // Scale from mean: mean = scale * Gamma(1 + 1/shape)
51        let gamma_val = gamma(1.0 + 1.0 / shape);
52        let scale = (mean / gamma_val).max(1e-6);
53
54        array![shape.ln(), scale.ln()]
55    }
56
57    fn n_params(&self) -> usize {
58        2
59    }
60
61    fn predict(&self) -> Array1<f64> {
62        // Mean of Weibull is scale * Gamma(1 + 1/shape)
63        let mut means = Array1::zeros(self.shape.len());
64        for i in 0..self.shape.len() {
65            let gamma_val = gamma(1.0 + 1.0 / self.shape[i]);
66            means[i] = self.scale[i] * gamma_val;
67        }
68        means
69    }
70
71    fn params(&self) -> &Array2<f64> {
72        &self._params
73    }
74}
75
/// Opts `Weibull` into the `RegressionDistn` trait; the impl body is empty, so no trait
/// methods are overridden here.
impl RegressionDistn for Weibull {}
77
78impl DistributionMethods for Weibull {
79    fn mean(&self) -> Array1<f64> {
80        // Mean of Weibull is scale * Gamma(1 + 1/shape)
81        let mut means = Array1::zeros(self.shape.len());
82        for i in 0..self.shape.len() {
83            let gamma_val = gamma(1.0 + 1.0 / self.shape[i]);
84            means[i] = self.scale[i] * gamma_val;
85        }
86        means
87    }
88
89    fn variance(&self) -> Array1<f64> {
90        // Var of Weibull is scale^2 * [Gamma(1 + 2/k) - Gamma(1 + 1/k)^2]
91        let mut vars = Array1::zeros(self.shape.len());
92        for i in 0..self.shape.len() {
93            let k = self.shape[i];
94            let lam = self.scale[i];
95            let gamma_1 = gamma(1.0 + 1.0 / k);
96            let gamma_2 = gamma(1.0 + 2.0 / k);
97            vars[i] = lam * lam * (gamma_2 - gamma_1 * gamma_1);
98        }
99        vars
100    }
101
102    fn pdf(&self, y: &Array1<f64>) -> Array1<f64> {
103        let mut result = Array1::zeros(y.len());
104        for i in 0..y.len() {
105            if y[i] >= 0.0 {
106                if let Ok(d) = WeibullDist::new(self.shape[i], self.scale[i]) {
107                    result[i] = d.pdf(y[i]);
108                }
109            }
110        }
111        result
112    }
113
114    fn logpdf(&self, y: &Array1<f64>) -> Array1<f64> {
115        let mut result = Array1::zeros(y.len());
116        for i in 0..y.len() {
117            if y[i] >= 0.0 {
118                if let Ok(d) = WeibullDist::new(self.shape[i], self.scale[i]) {
119                    result[i] = d.ln_pdf(y[i]);
120                }
121            } else {
122                result[i] = f64::NEG_INFINITY;
123            }
124        }
125        result
126    }
127
128    fn cdf(&self, y: &Array1<f64>) -> Array1<f64> {
129        let mut result = Array1::zeros(y.len());
130        for i in 0..y.len() {
131            if y[i] >= 0.0 {
132                if let Ok(d) = WeibullDist::new(self.shape[i], self.scale[i]) {
133                    result[i] = d.cdf(y[i]);
134                }
135            }
136        }
137        result
138    }
139
140    fn ppf(&self, q: &Array1<f64>) -> Array1<f64> {
141        // Inverse CDF for Weibull: scale * (-ln(1 - q))^(1/shape)
142        let mut result = Array1::zeros(q.len());
143        for i in 0..q.len() {
144            let q_clamped = q[i].clamp(1e-15, 1.0 - 1e-15);
145            result[i] = self.scale[i] * (-(1.0 - q_clamped).ln()).powf(1.0 / self.shape[i]);
146        }
147        result
148    }
149
150    fn sample(&self, n_samples: usize) -> Array2<f64> {
151        let n_obs = self.shape.len();
152        let mut samples = Array2::zeros((n_samples, n_obs));
153        let mut rng = rand::rng();
154
155        for i in 0..n_obs {
156            for s in 0..n_samples {
157                // Use inverse CDF method: scale * (-ln(1 - u))^(1/shape)
158                let u: f64 = rng.random();
159                samples[[s, i]] = self.scale[i] * (-(1.0 - u).ln()).powf(1.0 / self.shape[i]);
160            }
161        }
162        samples
163    }
164
165    fn median(&self) -> Array1<f64> {
166        // Median of Weibull is scale * (ln(2))^(1/shape)
167        let mut result = Array1::zeros(self.shape.len());
168        for i in 0..self.shape.len() {
169            result[i] = self.scale[i] * std::f64::consts::LN_2.powf(1.0 / self.shape[i]);
170        }
171        result
172    }
173
174    fn mode(&self) -> Array1<f64> {
175        // Mode of Weibull is scale * ((k-1)/k)^(1/k) for k > 1, else 0
176        let mut result = Array1::zeros(self.shape.len());
177        for i in 0..self.shape.len() {
178            let k = self.shape[i];
179            if k > 1.0 {
180                result[i] = self.scale[i] * ((k - 1.0) / k).powf(1.0 / k);
181            }
182            // For k <= 1, mode is 0 (already initialized)
183        }
184        result
185    }
186}
187
188impl Scorable<LogScore> for Weibull {
189    fn score(&self, y: &Array1<f64>) -> Array1<f64> {
190        let mut scores = Array1::zeros(y.len());
191        for (i, &y_i) in y.iter().enumerate() {
192            // statrs Weibull uses (shape, scale) parameterization
193            let d = WeibullDist::new(self.shape[i], self.scale[i]).unwrap();
194            scores[i] = -d.ln_pdf(y_i);
195        }
196        scores
197    }
198
199    fn d_score(&self, y: &Array1<f64>) -> Array2<f64> {
200        let n_obs = y.len();
201        let mut d_params = Array2::zeros((n_obs, 2));
202
203        for i in 0..n_obs {
204            let k = self.shape[i];
205            let lam = self.scale[i];
206            let y_i = y[i];
207
208            // Ratio y/scale
209            let ratio = y_i / lam;
210            let ratio_k = ratio.powf(k);
211
212            // shared_term = k * ((y/scale)^k - 1)
213            let shared_term = k * (ratio_k - 1.0);
214
215            // d/d(log(shape)) = shape * [shared_term * log(y/scale) - 1]
216            // But we parameterize as log(shape), so multiply by shape
217            d_params[[i, 0]] = shared_term * ratio.ln() - 1.0;
218
219            // d/d(log(scale)) = -shared_term
220            d_params[[i, 1]] = -shared_term;
221        }
222
223        d_params
224    }
225
226    fn metric(&self) -> Array3<f64> {
227        // Fisher Information Matrix for Weibull (from Python implementation)
228        // Uses Euler's constant gamma ≈ 0.5772156649
229        let euler_gamma = 0.5772156649;
230        let n_obs = self.shape.len();
231        let mut fi = Array3::zeros((n_obs, 2, 2));
232
233        for i in 0..n_obs {
234            let k = self.shape[i];
235
236            // FI[0, 0] = (pi^2 / 6) + (1 - gamma)^2
237            let pi = std::f64::consts::PI;
238            let one_minus_gamma = 1.0 - euler_gamma;
239            fi[[i, 0, 0]] = (pi * pi / 6.0) + (one_minus_gamma * one_minus_gamma);
240
241            // FI[1, 0] = FI[0, 1] = -k * (1 - gamma)
242            fi[[i, 0, 1]] = -k * (1.0 - euler_gamma);
243            fi[[i, 1, 0]] = fi[[i, 0, 1]];
244
245            // FI[1, 1] = k^2
246            fi[[i, 1, 1]] = k * k;
247        }
248
249        fi
250    }
251}
252
253// ============================================================================
254// CRPScore for Weibull (uncensored)
255// ============================================================================
256
257impl Scorable<CRPScore> for Weibull {
258    fn score(&self, y: &Array1<f64>) -> Array1<f64> {
259        // CRPS for Weibull distribution
260        // Reference: Gneiting, T., & Raftery, A. E. (2007). Strictly proper scoring rules,
261        // prediction, and estimation. JASA, 102(477), 359-378.
262        //
263        // CRPS = y * (2*F(y) - 1) - λ * Γ(1+1/k) * (2*(1-F(y)) - 1 + 2^(-1/k))
264        let mut scores = Array1::zeros(y.len());
265
266        for i in 0..y.len() {
267            let k = self.shape[i];
268            let lam = self.scale[i];
269            let y_i = y[i].max(1e-10); // Avoid issues with zero
270
271            let z = y_i / lam;
272            let z_k = z.powf(k);
273
274            // CDF at y: F(y) = 1 - exp(-(y/λ)^k)
275            let f_y = 1.0 - (-z_k).exp();
276
277            // Gamma(1 + 1/k)
278            let gamma_term = gamma(1.0 + 1.0 / k);
279
280            // 2^(-1/k)
281            let two_pow = 2.0_f64.powf(-1.0 / k);
282
283            // CRPS formula
284            scores[i] =
285                y_i * (2.0 * f_y - 1.0) - lam * gamma_term * (2.0 * (1.0 - f_y) - 1.0 + two_pow);
286        }
287        scores
288    }
289
290    fn d_score(&self, y: &Array1<f64>) -> Array2<f64> {
291        // Gradient of CRPS with respect to log(shape) and log(scale)
292        let n_obs = y.len();
293        let mut d_params = Array2::zeros((n_obs, 2));
294        let eps = 1e-10;
295
296        for i in 0..n_obs {
297            let k = self.shape[i];
298            let lam = self.scale[i];
299            let y_i = y[i].max(eps);
300
301            let z = y_i / lam;
302            let z_k = z.powf(k);
303            let ln_z = z.ln();
304
305            // F(y) = 1 - exp(-z^k)
306            let s_y = (-z_k).exp(); // Survival function S(y) = 1 - F(y)
307
308            // PDF: f(y) = (k/λ) * z^(k-1) * exp(-z^k)
309            let pdf_y = (k / lam) * z.powf(k - 1.0) * s_y;
310
311            // Gamma terms
312            let gamma_term = gamma(1.0 + 1.0 / k);
313            let dgamma_dk = -gamma_term * digamma(1.0 + 1.0 / k) / (k * k);
314
315            // 2^(-1/k) and its derivative
316            let two_pow = 2.0_f64.powf(-1.0 / k);
317            let dtwo_pow_dk = two_pow * 2.0_f64.ln() / (k * k);
318
319            // dF/dk = f(y) * z^k * ln(z)
320            let df_dk = pdf_y * z_k * ln_z / k;
321
322            // dF/dλ = -f(y) * k * z / λ = -pdf_y * k * z / lam
323            let df_dlam = -pdf_y * k * z / lam;
324
325            // dS/dk = -dF/dk, dS/dlam = -dF/dlam
326            let ds_dk = -df_dk;
327            let ds_dlam = -df_dlam;
328
329            // d(CRPS)/d(log k) = k * d(CRPS)/dk
330            // CRPS = y*(2F - 1) - λ*Γ*(2S - 1 + 2^(-1/k))
331            // d(CRPS)/dk = y*2*dF/dk - λ*dΓ/dk*(2S - 1 + 2^(-1/k)) - λ*Γ*(2*dS/dk + d(2^(-1/k))/dk)
332            let dcrps_dk = y_i * 2.0 * df_dk
333                - lam * dgamma_dk * (2.0 * s_y - 1.0 + two_pow)
334                - lam * gamma_term * (2.0 * ds_dk + dtwo_pow_dk);
335            d_params[[i, 0]] = k * dcrps_dk;
336
337            // d(CRPS)/d(log λ) = λ * d(CRPS)/dλ
338            // d(CRPS)/dλ = y*2*dF/dλ - Γ*(2S - 1 + 2^(-1/k)) - λ*Γ*2*dS/dλ
339            let dcrps_dlam = y_i * 2.0 * df_dlam
340                - gamma_term * (2.0 * s_y - 1.0 + two_pow)
341                - lam * gamma_term * 2.0 * ds_dlam;
342            d_params[[i, 1]] = lam * dcrps_dlam;
343        }
344        d_params
345    }
346
347    fn metric(&self) -> Array3<f64> {
348        // CRPS metric for Weibull
349        // Use an approximation based on the structure from Python
350        let n_obs = self.shape.len();
351        let mut fi = Array3::zeros((n_obs, 2, 2));
352        let eps = 1e-10;
353        let sqrt_pi = std::f64::consts::PI.sqrt();
354
355        for i in 0..n_obs {
356            let k = self.shape[i];
357            let lam = self.scale[i];
358            let gamma_term = gamma(1.0 + 1.0 / k);
359
360            // Approximate metric based on CRPS second moments (from Python)
361            fi[[i, 0, 0]] = (lam * lam * gamma_term * gamma_term / (k * k)) / (2.0 * sqrt_pi) + eps;
362            fi[[i, 1, 1]] = (lam * lam * gamma_term * gamma_term) / (2.0 * sqrt_pi) + eps;
363            fi[[i, 0, 1]] = 0.0;
364            fi[[i, 1, 0]] = 0.0;
365        }
366        fi
367    }
368}
369
370// ============================================================================
371// Censored LogScore for survival analysis
372// ============================================================================
373
374impl CensoredScorable<LogScoreCensored> for Weibull {
375    fn censored_score(&self, y: &SurvivalData) -> Array1<f64> {
376        // For right-censored data:
377        // - Uncensored (E=1): use log-likelihood = log(f(t)) → score = -log(f(t))
378        // - Censored (E=0): use log-survival = log(S(t)) → score = -log(S(t))
379        let mut scores = Array1::zeros(y.len());
380
381        for i in 0..y.len() {
382            let t = y.time[i].max(1e-10);
383            let e = y.event[i];
384            let k = self.shape[i];
385            let lam = self.scale[i];
386
387            let z = t / lam;
388            let z_k = z.powf(k);
389
390            if e {
391                // Uncensored: -log(f(t))
392                // f(t) = (k/λ) * (t/λ)^(k-1) * exp(-(t/λ)^k)
393                // log(f(t)) = log(k) - log(λ) + (k-1)*log(z) - z^k
394                let log_pdf = k.ln() - lam.ln() + (k - 1.0) * z.ln() - z_k;
395                scores[i] = -log_pdf;
396            } else {
397                // Censored: -log(S(t)) where S(t) = exp(-z^k)
398                // -log(S(t)) = z^k
399                scores[i] = z_k;
400            }
401        }
402        scores
403    }
404
405    fn censored_d_score(&self, y: &SurvivalData) -> Array2<f64> {
406        let n_obs = y.len();
407        let mut d_params = Array2::zeros((n_obs, 2));
408        let eps = 1e-10;
409
410        for i in 0..n_obs {
411            let t = y.time[i].max(eps);
412            let e = y.event[i];
413            let k = self.shape[i];
414            let lam = self.scale[i];
415
416            let z = t / lam;
417            let z_k = z.powf(k);
418            let ln_z = z.ln();
419
420            if e {
421                // Uncensored gradient (same as LogScore)
422                // shared_term = k * (z^k - 1)
423                let shared_term = k * (z_k - 1.0);
424                d_params[[i, 0]] = shared_term * ln_z - 1.0;
425                d_params[[i, 1]] = -shared_term;
426            } else {
427                // Censored gradient: d/dθ[z^k] where z = t/λ
428                // d(z^k)/d(log k) = k * z^k * ln(z)
429                // d(z^k)/d(log λ) = -k * z^k
430                d_params[[i, 0]] = k * z_k * ln_z;
431                d_params[[i, 1]] = -k * z_k;
432            }
433        }
434        d_params
435    }
436
437    fn censored_metric(&self) -> Array3<f64> {
438        // Use the uncensored Fisher information as approximation
439        Scorable::<LogScore>::metric(self)
440    }
441}
442
443// ============================================================================
444// Censored CRPScore for survival analysis
445// ============================================================================
446
447impl CensoredScorable<CRPScoreCensored> for Weibull {
448    fn censored_score(&self, y: &SurvivalData) -> Array1<f64> {
449        // CRPS for right-censored Weibull data
450        // For uncensored: standard CRPS
451        // For censored: modified CRPS that accounts for right-censoring
452        let mut scores = Array1::zeros(y.len());
453        let eps = 1e-10;
454
455        for i in 0..y.len() {
456            let t = y.time[i].max(eps);
457            let e = y.event[i];
458            let k = self.shape[i];
459            let lam = self.scale[i];
460
461            let z = t / lam;
462            let z_k = z.powf(k);
463
464            let f_t = 1.0 - (-z_k).exp(); // CDF
465            let s_t = (-z_k).exp(); // Survival function
466
467            let gamma_term = gamma(1.0 + 1.0 / k);
468            let two_pow = 2.0_f64.powf(-1.0 / k);
469
470            if e {
471                // Uncensored CRPS (same as regular CRPS)
472                scores[i] = t * (2.0 * f_t - 1.0) - lam * gamma_term * (2.0 * s_t - 1.0 + two_pow);
473            } else {
474                // Censored CRPS (adapted for right-censoring)
475                // From Python: crps_cens = t*F^2 + 2*λ*Γ*S*F - λ*Γ*2^(-1/k)*(1 - S^2)/2
476                scores[i] = t * f_t * f_t + 2.0 * lam * gamma_term * s_t * f_t
477                    - lam * gamma_term * two_pow * (1.0 - s_t * s_t) / 2.0;
478            }
479        }
480        scores
481    }
482
483    fn censored_d_score(&self, y: &SurvivalData) -> Array2<f64> {
484        let n_obs = y.len();
485        let mut d_params = Array2::zeros((n_obs, 2));
486        let eps = 1e-10;
487
488        for i in 0..n_obs {
489            let t = y.time[i].max(eps);
490            let e = y.event[i];
491            let k = self.shape[i];
492            let lam = self.scale[i];
493
494            let z = t / lam;
495            let z_k = z.powf(k);
496            let ln_z = z.ln();
497
498            let f_t = 1.0 - (-z_k).exp();
499            let s_t = (-z_k).exp();
500
501            // PDF: f(t) = (k/λ) * z^(k-1) * exp(-z^k)
502            let pdf_t = (k / lam) * z.powf(k - 1.0) * s_t;
503
504            let gamma_term = gamma(1.0 + 1.0 / k);
505            let dgamma_dk = -gamma_term * digamma(1.0 + 1.0 / k) / (k * k);
506
507            let two_pow = 2.0_f64.powf(-1.0 / k);
508            let dtwo_pow_dk = two_pow * 2.0_f64.ln() / (k * k);
509
510            // dF/dk and dS/dk
511            let df_dk = pdf_t * z_k * ln_z / k;
512            let ds_dk = -df_dk;
513
514            // dF/dλ and dS/dλ
515            let df_dlam = -pdf_t * k * z / lam;
516            let ds_dlam = -df_dlam;
517
518            if e {
519                // Uncensored derivatives (same as CRPScore)
520                let dcrps_dk = t * 2.0 * df_dk
521                    - lam * dgamma_dk * (2.0 * s_t - 1.0 + two_pow)
522                    - lam * gamma_term * (2.0 * ds_dk + dtwo_pow_dk);
523                d_params[[i, 0]] = k * dcrps_dk;
524
525                let dcrps_dlam = t * 2.0 * df_dlam
526                    - gamma_term * (2.0 * s_t - 1.0 + two_pow)
527                    - lam * gamma_term * 2.0 * ds_dlam;
528                d_params[[i, 1]] = lam * dcrps_dlam;
529            } else {
530                // Censored derivatives
531                // crps_cens = t*F^2 + 2*λ*Γ*S*F - λ*Γ*2^(-1/k)*(1 - S^2)/2
532
533                // d/dk[crps_cens]
534                let dcrps_cens_dk = t * 2.0 * f_t * df_dk
535                    + 2.0 * lam * dgamma_dk * s_t * f_t
536                    + 2.0 * lam * gamma_term * (ds_dk * f_t + s_t * df_dk)
537                    - lam * dgamma_dk * two_pow * (1.0 - s_t * s_t) / 2.0
538                    - lam * gamma_term * dtwo_pow_dk * (1.0 - s_t * s_t) / 2.0
539                    + lam * gamma_term * two_pow * s_t * ds_dk;
540                d_params[[i, 0]] = k * dcrps_cens_dk;
541
542                // d/dλ[crps_cens]
543                let dcrps_cens_dlam = t * 2.0 * f_t * df_dlam
544                    + 2.0 * gamma_term * s_t * f_t
545                    + 2.0 * lam * gamma_term * (ds_dlam * f_t + s_t * df_dlam)
546                    - gamma_term * two_pow * (1.0 - s_t * s_t) / 2.0
547                    + lam * gamma_term * two_pow * s_t * ds_dlam;
548                d_params[[i, 1]] = lam * dcrps_cens_dlam;
549            }
550        }
551        d_params
552    }
553
554    fn censored_metric(&self) -> Array3<f64> {
555        // Use the CRPS metric
556        Scorable::<CRPScore>::metric(self)
557    }
558}
559
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Single-observation Weibull with natural-scale `shape` and `scale`.
    fn single(shape: f64, scale: f64) -> Weibull {
        let params = Array2::from_shape_vec((1, 2), vec![shape.ln(), scale.ln()]).unwrap();
        Weibull::from_params(&params)
    }

    /// Two identical observations with shape 2 and scale 1, used by the censored tests.
    fn two_rayleigh() -> Weibull {
        let params =
            Array2::from_shape_vec((2, 2), vec![2.0_f64.ln(), 0.0, 2.0_f64.ln(), 0.0]).unwrap();
        Weibull::from_params(&params)
    }

    /// An observed event at t = 1 followed by a right-censoring at t = 2.
    fn one_event_one_censored() -> SurvivalData {
        let time = Array1::from_vec(vec![1.0, 2.0]);
        let event = Array1::from_vec(vec![1.0, 0.0]);
        SurvivalData::from_arrays(&time, &event)
    }

    #[test]
    fn test_weibull_distribution_methods() {
        // shape=2, scale=1 (a Rayleigh distribution)
        let dist = single(2.0, 1.0);

        // Mean: scale * Gamma(1 + 1/shape) = Gamma(1.5) = sqrt(pi) / 2 ≈ 0.886
        assert_relative_eq!(dist.mean()[0], std::f64::consts::PI.sqrt() / 2.0, epsilon = 1e-9);

        // Variance: Gamma(2) - Gamma(1.5)^2 = 1 - pi/4
        assert_relative_eq!(
            dist.variance()[0],
            1.0 - std::f64::consts::FRAC_PI_4,
            epsilon = 1e-9
        );

        // Mode: scale * ((k-1)/k)^(1/k) = (0.5)^0.5 ≈ 0.707
        assert_relative_eq!(dist.mode()[0], 0.5_f64.sqrt(), epsilon = 1e-6);
    }

    #[test]
    fn test_weibull_cdf_ppf() {
        // For shape=1, Weibull is Exponential(1)
        let dist = single(1.0, 1.0);

        // CDF at 1 should be 1 - exp(-1) ≈ 0.632
        let cdf = dist.cdf(&Array1::from_vec(vec![1.0]));
        assert_relative_eq!(cdf[0], 1.0 - (-1.0_f64).exp(), epsilon = 1e-6);

        // PPF is the inverse of the CDF
        let ppf = dist.ppf(&Array1::from_vec(vec![0.5]));
        let cdf_of_ppf = dist.cdf(&ppf);
        assert_relative_eq!(cdf_of_ppf[0], 0.5, epsilon = 1e-6);
    }

    #[test]
    fn test_weibull_sample() {
        let dist = single(2.0, 1.0);

        let samples = dist.sample(1000);
        assert_eq!(samples.shape(), &[1000, 1]);

        // All samples should be non-negative
        assert!(samples.iter().all(|&x| x >= 0.0));

        // Sample mean should be close to the theoretical mean
        let sample_mean: f64 = samples.column(0).mean().unwrap();
        let theoretical_mean = dist.mean()[0];
        assert!((sample_mean - theoretical_mean).abs() / theoretical_mean < 0.15);
    }

    #[test]
    fn test_weibull_median() {
        // For shape=1, scale=1, median = ln(2) ≈ 0.693
        let dist = single(1.0, 1.0);
        assert_relative_eq!(dist.median()[0], std::f64::consts::LN_2, epsilon = 1e-10);
    }

    #[test]
    fn test_weibull_fit() {
        let y = Array1::from_vec(vec![0.5, 1.0, 1.5, 2.0, 2.5]);
        let params = Weibull::fit(&y);

        // Returns [log(shape), log(scale)]
        assert_eq!(params.len(), 2);
        assert!(params.iter().all(|p| p.is_finite()));

        // The scale is chosen so the fitted mean reproduces the sample mean (1.5).
        let fitted =
            Weibull::from_params(&Array2::from_shape_vec((1, 2), params.to_vec()).unwrap());
        assert_relative_eq!(fitted.mean()[0], 1.5, epsilon = 1e-9);
    }

    #[test]
    fn test_weibull_logscore() {
        let dist = single(2.0, 1.0);
        let score = Scorable::<LogScore>::score(&dist, &Array1::from_vec(vec![1.0]));

        assert!(score[0].is_finite());
        assert!(score[0] > 0.0);
        // -ln f(1) = -(ln 2 + 0 - 1) = 1 - ln 2 for shape 2, scale 1.
        assert_relative_eq!(score[0], 1.0 - std::f64::consts::LN_2, epsilon = 1e-10);
    }

    #[test]
    fn test_weibull_d_score() {
        let dist = single(2.0, 1.0);
        let d_score = Scorable::<LogScore>::d_score(&dist, &Array1::from_vec(vec![1.0]));

        // Gradients should be finite
        assert!(d_score[[0, 0]].is_finite());
        assert!(d_score[[0, 1]].is_finite());
    }

    #[test]
    fn test_weibull_d_score_matches_finite_difference() {
        let base = [1.7_f64.ln(), 0.9_f64.ln()];
        let y = Array1::from_vec(vec![1.3]);
        let h = 1e-6;

        let score_at = |p: [f64; 2]| {
            let params = Array2::from_shape_vec((1, 2), p.to_vec()).unwrap();
            Scorable::<LogScore>::score(&Weibull::from_params(&params), &y)[0]
        };
        let base_dist = Weibull::from_params(&Array2::from_shape_vec((1, 2), base.to_vec()).unwrap());
        let analytic = Scorable::<LogScore>::d_score(&base_dist, &y);

        for j in 0..2 {
            let mut up = base;
            up[j] += h;
            let mut down = base;
            down[j] -= h;
            let numeric = (score_at(up) - score_at(down)) / (2.0 * h);
            assert_relative_eq!(analytic[[0, j]], numeric, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_weibull_logscore_metric_values() {
        let dist = single(2.0, 1.0);
        let fi = Scorable::<LogScore>::metric(&dist);

        let one_minus_gamma = 1.0 - 0.577_215_664_901_532_9;
        let pi = std::f64::consts::PI;
        assert_relative_eq!(
            fi[[0, 0, 0]],
            pi * pi / 6.0 + one_minus_gamma * one_minus_gamma,
            epsilon = 1e-8
        );
        assert_relative_eq!(fi[[0, 0, 1]], -2.0 * one_minus_gamma, epsilon = 1e-8);
        assert_relative_eq!(fi[[0, 1, 0]], fi[[0, 0, 1]], epsilon = 1e-15);
        assert_relative_eq!(fi[[0, 1, 1]], 4.0, epsilon = 1e-12);
    }

    #[test]
    fn test_weibull_interval() {
        let dist = single(2.0, 1.0);

        let (lower, upper) = dist.interval(0.1);
        assert!(lower[0] > 0.0);
        assert!(upper[0] > lower[0]);
    }

    #[test]
    fn test_weibull_survival_function() {
        // For shape=1, Weibull is Exponential
        let dist = single(1.0, 1.0);

        // SF at 0 should be 1
        let sf = dist.sf(&Array1::from_vec(vec![0.0]));
        assert_relative_eq!(sf[0], 1.0, epsilon = 1e-10);

        // SF + CDF should equal 1
        let y = Array1::from_vec(vec![1.0]);
        let sf = dist.sf(&y);
        let cdf = dist.cdf(&y);
        assert_relative_eq!(sf[0] + cdf[0], 1.0, epsilon = 1e-10);
    }

    // ========================================================================
    // CRPScore tests
    // ========================================================================

    #[test]
    fn test_weibull_crpscore() {
        let dist = single(2.0, 1.0);
        let score = Scorable::<CRPScore>::score(&dist, &Array1::from_vec(vec![1.0]));

        // CRPS is non-negative by definition; this smoke test only checks finiteness.
        assert!(score[0].is_finite());
    }

    #[test]
    fn test_weibull_crpscore_d_score() {
        let dist = single(2.0, 1.0);
        let d_score = Scorable::<CRPScore>::d_score(&dist, &Array1::from_vec(vec![1.0]));

        // Gradients should be finite
        assert!(d_score[[0, 0]].is_finite());
        assert!(d_score[[0, 1]].is_finite());
    }

    #[test]
    fn test_weibull_crpscore_metric() {
        let dist = single(2.0, 1.0);
        let metric = Scorable::<CRPScore>::metric(&dist);

        // Metric should be one 2x2 matrix with a positive diagonal
        assert_eq!(metric.shape(), &[1, 2, 2]);
        assert!(metric[[0, 0, 0]] > 0.0);
        assert!(metric[[0, 1, 1]] > 0.0);
    }

    // ========================================================================
    // Censored LogScore tests
    // ========================================================================

    #[test]
    fn test_weibull_censored_logscore() {
        let dist = two_rayleigh();
        let y = one_event_one_censored();

        let scores = CensoredScorable::<LogScoreCensored>::censored_score(&dist, &y);

        // Scores should be finite
        assert!(scores[0].is_finite());
        assert!(scores[1].is_finite());

        // Censored at t = 2: -ln S(2) = (2 / 1)^2 = 4
        assert_relative_eq!(scores[1], 4.0, epsilon = 1e-12);

        // Uncensored score should match regular LogScore
        let regular_score =
            Scorable::<LogScore>::score(&single(2.0, 1.0), &Array1::from_vec(vec![1.0]));
        assert_relative_eq!(scores[0], regular_score[0], epsilon = 1e-10);
    }

    #[test]
    fn test_weibull_censored_logscore_d_score() {
        let dist = two_rayleigh();
        let y = one_event_one_censored();

        let d_scores = CensoredScorable::<LogScoreCensored>::censored_d_score(&dist, &y);

        // All gradients should be finite
        assert!(d_scores.iter().all(|&x| x.is_finite()));

        // Censored row: [k z^k ln z, -k z^k] with k = 2, z = 2 → [8 ln 2, -8]
        assert_relative_eq!(d_scores[[1, 0]], 8.0 * std::f64::consts::LN_2, epsilon = 1e-10);
        assert_relative_eq!(d_scores[[1, 1]], -8.0, epsilon = 1e-10);
    }

    // ========================================================================
    // Censored CRPScore tests
    // ========================================================================

    #[test]
    fn test_weibull_censored_crpscore() {
        let dist = two_rayleigh();
        let y = one_event_one_censored();

        let scores = CensoredScorable::<CRPScoreCensored>::censored_score(&dist, &y);

        // Scores should be finite
        assert!(scores[0].is_finite());
        assert!(scores[1].is_finite());

        // Uncensored score should match regular CRPScore
        let regular_score =
            Scorable::<CRPScore>::score(&single(2.0, 1.0), &Array1::from_vec(vec![1.0]));
        assert_relative_eq!(scores[0], regular_score[0], epsilon = 1e-10);
    }

    #[test]
    fn test_weibull_censored_crpscore_d_score() {
        let dist = two_rayleigh();
        let y = one_event_one_censored();

        let d_scores = CensoredScorable::<CRPScoreCensored>::censored_d_score(&dist, &y);

        // All gradients should be finite
        assert!(d_scores.iter().all(|&x| x.is_finite()));
    }
}