ngboost_rs/dist/mod.rs
1pub mod categorical;
2pub mod cauchy;
3pub mod exponential;
4pub mod gamma;
5pub mod halfnormal;
6pub mod laplace;
7pub mod lognormal;
8pub mod multivariate_normal;
9pub mod normal;
10pub mod poisson;
11pub mod studentt;
12pub mod weibull;
13
14// Re-export all distributions for convenience
15pub use categorical::{
16 Bernoulli, Categorical, Categorical10, Categorical3, Categorical4, Categorical5,
17};
18pub use cauchy::{Cauchy, CauchyFixedVar};
19pub use exponential::Exponential;
20pub use gamma::Gamma;
21pub use halfnormal::HalfNormal;
22pub use laplace::Laplace;
23pub use lognormal::LogNormal;
24pub use multivariate_normal::{
25 MultivariateNormal, MultivariateNormal2, MultivariateNormal3, MultivariateNormal4,
26};
27pub use normal::{Normal, NormalFixedMean, NormalFixedVar};
28pub use poisson::Poisson;
29pub use studentt::{StudentT, TFixedDf, TFixedDfFixedVar};
30pub use weibull::Weibull;
31
32use crate::scores::{Scorable, Score};
33use ndarray::{Array1, Array2};
34use std::fmt::Debug;
35
36/// A trait for probability distributions.
37pub trait Distribution: Sized + Clone + Debug {
38 /// Creates a new distribution from a set of parameters.
39 fn from_params(params: &Array2<f64>) -> Self;
40
41 /// Fits the distribution to the data `y` and returns the initial parameters.
42 fn fit(y: &Array1<f64>) -> Array1<f64>;
43
44 /// Returns the number of parameters for this distribution.
45 fn n_params(&self) -> usize;
46
47 /// Returns a point prediction.
48 fn predict(&self) -> Array1<f64>;
49
50 /// Returns the parameters of the distribution.
51 fn params(&self) -> &Array2<f64>;
52
53 /// Calculates the gradient of the score with respect to the distribution's parameters.
54 fn grad<S: Score>(&self, y: &Array1<f64>, _score: S, natural: bool) -> Array2<f64>
55 where
56 Self: Scorable<S>,
57 {
58 Scorable::grad(self, y, natural)
59 }
60
61 fn total_score<S: Score>(&self, y: &Array1<f64>, _score: S) -> f64
62 where
63 Self: Scorable<S>,
64 {
65 Scorable::total_score(self, y, None)
66 }
67}
68
69/// A sub-trait for distributions used in regression.
70pub trait RegressionDistn: Distribution {}
71
72/// A sub-trait for distributions used in classification.
73pub trait ClassificationDistn: Distribution {
74 fn class_probs(&self) -> Array2<f64>;
75}
76
77/// A trait providing common distribution helper methods.
78///
79/// This trait provides scipy-like methods for distributions:
80/// - `mean()`, `std()`, `var()` - moments
81/// - `pdf()`, `logpdf()` - probability density functions
82/// - `cdf()` - cumulative distribution function
83/// - `ppf()` - percent point function (inverse CDF / quantile function)
84/// - `sample()` - random sampling
85/// - `interval()` - confidence intervals
86pub trait DistributionMethods: Distribution {
87 /// Returns the mean of the distribution for each observation.
88 fn mean(&self) -> Array1<f64>;
89
90 /// Returns the variance of the distribution for each observation.
91 fn variance(&self) -> Array1<f64>;
92
93 /// Returns the standard deviation of the distribution for each observation.
94 fn std(&self) -> Array1<f64> {
95 self.variance().mapv(f64::sqrt)
96 }
97
98 /// Evaluates the probability density function at point y for each observation.
99 fn pdf(&self, y: &Array1<f64>) -> Array1<f64>;
100
101 /// Evaluates the log probability density function at point y for each observation.
102 fn logpdf(&self, y: &Array1<f64>) -> Array1<f64> {
103 self.pdf(y).mapv(|p| p.ln())
104 }
105
106 /// Evaluates the cumulative distribution function at point y for each observation.
107 fn cdf(&self, y: &Array1<f64>) -> Array1<f64>;
108
109 /// Evaluates the percent point function (inverse CDF / quantile function).
110 /// Returns the value y such that P(Y <= y) = q.
111 fn ppf(&self, q: &Array1<f64>) -> Array1<f64>;
112
113 /// Generates random samples from the distribution.
114 ///
115 /// # Arguments
116 /// * `n_samples` - Number of samples to generate per observation
117 ///
118 /// # Returns
119 /// Array of shape (n_samples, n_observations)
120 fn sample(&self, n_samples: usize) -> Array2<f64>;
121
122 /// Returns the confidence interval for each observation.
123 ///
124 /// # Arguments
125 /// * `alpha` - Significance level (e.g., 0.05 for 95% CI)
126 ///
127 /// # Returns
128 /// Tuple of (lower bounds, upper bounds)
129 fn interval(&self, alpha: f64) -> (Array1<f64>, Array1<f64>) {
130 let lower_q = Array1::from_elem(self.mean().len(), alpha / 2.0);
131 let upper_q = Array1::from_elem(self.mean().len(), 1.0 - alpha / 2.0);
132 (self.ppf(&lower_q), self.ppf(&upper_q))
133 }
134
135 /// Returns the survival function (1 - CDF) at point y for each observation.
136 fn sf(&self, y: &Array1<f64>) -> Array1<f64> {
137 1.0 - self.cdf(y)
138 }
139
140 /// Returns the median of the distribution for each observation.
141 fn median(&self) -> Array1<f64> {
142 let q = Array1::from_elem(self.mean().len(), 0.5);
143 self.ppf(&q)
144 }
145
146 /// Returns the mode of the distribution for each observation (if well-defined).
147 /// Default implementation returns the mean; override for distributions where mode != mean.
148 fn mode(&self) -> Array1<f64> {
149 self.mean()
150 }
151}
152
153/// A trait for multivariate distributions with additional methods.
154pub trait MultivariateDistributionMethods: Distribution {
155 /// Returns the mean vector of the distribution for each observation.
156 /// Shape: (n_observations, n_dimensions)
157 fn mean(&self) -> Array2<f64>;
158
159 /// Returns the covariance matrix of the distribution for each observation.
160 /// Shape: (n_observations, n_dimensions, n_dimensions)
161 fn covariance(&self) -> ndarray::Array3<f64>;
162
163 /// Generates random samples from the multivariate distribution.
164 ///
165 /// # Arguments
166 /// * `n_samples` - Number of samples to generate per observation
167 ///
168 /// # Returns
169 /// Array of shape (n_samples, n_observations, n_dimensions)
170 fn sample(&self, n_samples: usize) -> ndarray::Array3<f64>;
171}