// File: ngboost_rs/dist/mod.rs

1pub mod categorical;
2pub mod cauchy;
3pub mod exponential;
4pub mod gamma;
5pub mod halfnormal;
6pub mod laplace;
7pub mod lognormal;
8pub mod multivariate_normal;
9pub mod normal;
10pub mod poisson;
11pub mod studentt;
12pub mod weibull;
13
14// Re-export all distributions for convenience
15pub use categorical::{
16    Bernoulli, Categorical, Categorical10, Categorical3, Categorical4, Categorical5,
17};
18pub use cauchy::{Cauchy, CauchyFixedVar};
19pub use exponential::Exponential;
20pub use gamma::Gamma;
21pub use halfnormal::HalfNormal;
22pub use laplace::Laplace;
23pub use lognormal::LogNormal;
24pub use multivariate_normal::{
25    MultivariateNormal, MultivariateNormal2, MultivariateNormal3, MultivariateNormal4,
26};
27pub use normal::{Normal, NormalFixedMean, NormalFixedVar};
28pub use poisson::Poisson;
29pub use studentt::{StudentT, TFixedDf, TFixedDfFixedVar};
30pub use weibull::Weibull;
31
32use crate::scores::{Scorable, Score};
33use ndarray::{Array1, Array2};
34use std::fmt::Debug;
35
36/// A trait for probability distributions.
37pub trait Distribution: Sized + Clone + Debug {
38    /// Creates a new distribution from a set of parameters.
39    fn from_params(params: &Array2<f64>) -> Self;
40
41    /// Fits the distribution to the data `y` and returns the initial parameters.
42    fn fit(y: &Array1<f64>) -> Array1<f64>;
43
44    /// Returns the number of parameters for this distribution.
45    fn n_params(&self) -> usize;
46
47    /// Returns a point prediction.
48    fn predict(&self) -> Array1<f64>;
49
50    /// Returns the parameters of the distribution.
51    fn params(&self) -> &Array2<f64>;
52
53    /// Calculates the gradient of the score with respect to the distribution's parameters.
54    fn grad<S: Score>(&self, y: &Array1<f64>, _score: S, natural: bool) -> Array2<f64>
55    where
56        Self: Scorable<S>,
57    {
58        Scorable::grad(self, y, natural)
59    }
60
61    fn total_score<S: Score>(&self, y: &Array1<f64>, _score: S) -> f64
62    where
63        Self: Scorable<S>,
64    {
65        Scorable::total_score(self, y, None)
66    }
67}
68
69/// A sub-trait for distributions used in regression.
70pub trait RegressionDistn: Distribution {}
71
72/// A sub-trait for distributions used in classification.
73pub trait ClassificationDistn: Distribution {
74    fn class_probs(&self) -> Array2<f64>;
75}
76
77/// A trait providing common distribution helper methods.
78///
79/// This trait provides scipy-like methods for distributions:
80/// - `mean()`, `std()`, `var()` - moments
81/// - `pdf()`, `logpdf()` - probability density functions
82/// - `cdf()` - cumulative distribution function
83/// - `ppf()` - percent point function (inverse CDF / quantile function)
84/// - `sample()` - random sampling
85/// - `interval()` - confidence intervals
86pub trait DistributionMethods: Distribution {
87    /// Returns the mean of the distribution for each observation.
88    fn mean(&self) -> Array1<f64>;
89
90    /// Returns the variance of the distribution for each observation.
91    fn variance(&self) -> Array1<f64>;
92
93    /// Returns the standard deviation of the distribution for each observation.
94    fn std(&self) -> Array1<f64> {
95        self.variance().mapv(f64::sqrt)
96    }
97
98    /// Evaluates the probability density function at point y for each observation.
99    fn pdf(&self, y: &Array1<f64>) -> Array1<f64>;
100
101    /// Evaluates the log probability density function at point y for each observation.
102    fn logpdf(&self, y: &Array1<f64>) -> Array1<f64> {
103        self.pdf(y).mapv(|p| p.ln())
104    }
105
106    /// Evaluates the cumulative distribution function at point y for each observation.
107    fn cdf(&self, y: &Array1<f64>) -> Array1<f64>;
108
109    /// Evaluates the percent point function (inverse CDF / quantile function).
110    /// Returns the value y such that P(Y <= y) = q.
111    fn ppf(&self, q: &Array1<f64>) -> Array1<f64>;
112
113    /// Generates random samples from the distribution.
114    ///
115    /// # Arguments
116    /// * `n_samples` - Number of samples to generate per observation
117    ///
118    /// # Returns
119    /// Array of shape (n_samples, n_observations)
120    fn sample(&self, n_samples: usize) -> Array2<f64>;
121
122    /// Returns the confidence interval for each observation.
123    ///
124    /// # Arguments
125    /// * `alpha` - Significance level (e.g., 0.05 for 95% CI)
126    ///
127    /// # Returns
128    /// Tuple of (lower bounds, upper bounds)
129    fn interval(&self, alpha: f64) -> (Array1<f64>, Array1<f64>) {
130        let lower_q = Array1::from_elem(self.mean().len(), alpha / 2.0);
131        let upper_q = Array1::from_elem(self.mean().len(), 1.0 - alpha / 2.0);
132        (self.ppf(&lower_q), self.ppf(&upper_q))
133    }
134
135    /// Returns the survival function (1 - CDF) at point y for each observation.
136    fn sf(&self, y: &Array1<f64>) -> Array1<f64> {
137        1.0 - self.cdf(y)
138    }
139
140    /// Returns the median of the distribution for each observation.
141    fn median(&self) -> Array1<f64> {
142        let q = Array1::from_elem(self.mean().len(), 0.5);
143        self.ppf(&q)
144    }
145
146    /// Returns the mode of the distribution for each observation (if well-defined).
147    /// Default implementation returns the mean; override for distributions where mode != mean.
148    fn mode(&self) -> Array1<f64> {
149        self.mean()
150    }
151}
152
153/// A trait for multivariate distributions with additional methods.
154pub trait MultivariateDistributionMethods: Distribution {
155    /// Returns the mean vector of the distribution for each observation.
156    /// Shape: (n_observations, n_dimensions)
157    fn mean(&self) -> Array2<f64>;
158
159    /// Returns the covariance matrix of the distribution for each observation.
160    /// Shape: (n_observations, n_dimensions, n_dimensions)
161    fn covariance(&self) -> ndarray::Array3<f64>;
162
163    /// Generates random samples from the multivariate distribution.
164    ///
165    /// # Arguments
166    /// * `n_samples` - Number of samples to generate per observation
167    ///
168    /// # Returns
169    /// Array of shape (n_samples, n_observations, n_dimensions)
170    fn sample(&self, n_samples: usize) -> ndarray::Array3<f64>;
171}