ngboost_rs/dist/categorical.rs

use crate::dist::{ClassificationDistn, Distribution};
use crate::scores::{CRPScore, LogScore, Scorable};
use ndarray::{Array1, Array2, Array3, Axis};

/// Minimum probability value to avoid log(0) and division issues.
const PROB_EPS: f64 = 1e-10;
/// Maximum probability value (1 - PROB_EPS) to maintain numerical stability.
const PROB_MAX: f64 = 1.0 - PROB_EPS;

/// Softmax applied along axis 0, using the standard max-subtraction trick for
/// numerical stability. Returns probabilities clamped to [PROB_EPS, PROB_MAX]
/// so that downstream log computations never see exactly 0 or 1.
fn softmax_axis0(logits: &Array2<f64>) -> Array2<f64> {
    // Subtract each column's maximum before exponentiating to avoid overflow.
    let max_vals = logits.fold_axis(Axis(0), f64::NEG_INFINITY, |&a, &b| a.max(b));
    let shifted = logits - &max_vals;
    let exp_vals = shifted.mapv(f64::exp);
    let sum_exp = exp_vals.sum_axis(Axis(0));
    let probs = exp_vals / &sum_exp;
    // Clamp probabilities to avoid numerical issues in log computations.
    probs.mapv(|p| p.clamp(PROB_EPS, PROB_MAX))
}
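
// A minimal sanity check for `softmax_axis0`. This is a sketch: the input
// values and tolerance are illustrative assumptions, not part of the original
// API. Each column should sum to ~1 (up to the clamping) and every entry
// should lie inside [PROB_EPS, PROB_MAX].
#[test]
fn softmax_axis0_is_a_valid_distribution_per_column() {
    // 3 classes x 2 observations of arbitrary logits.
    let logits = Array2::from_shape_vec((3, 2), vec![1.0, -2.0, 0.5, 3.0, 0.0, -1.0]).unwrap();
    let probs = softmax_axis0(&logits);
    for col in probs.axis_iter(Axis(1)) {
        let s: f64 = col.sum();
        assert!((s - 1.0).abs() < 1e-6);
        assert!(col.iter().all(|&p| (PROB_EPS..=PROB_MAX).contains(&p)));
    }
}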

/// A K-class categorical distribution for classification.
///
/// This is a generic struct that can represent any K-class categorical.
/// It has K - 1 free parameters: the class-0 logit is fixed at 0, which
/// removes the softmax's translation invariance.
#[derive(Debug, Clone)]
pub struct Categorical<const K: usize> {
    /// The logits (K x N), where K is the number of classes and N is the number of observations.
    pub logits: Array2<f64>,
    /// The probabilities (K x N), computed via softmax.
    pub probs: Array2<f64>,
    /// Number of observations.
    n_obs: usize,
    /// The parameters of the distribution (N x K-1).
    _params: Array2<f64>,
}

impl<const K: usize> Distribution for Categorical<K> {
    fn from_params(params: &Array2<f64>) -> Self {
        // params is (N, K-1): each row holds one observation's free logits.
        let n_obs = params.nrows();

        // Build logits as (K, N), with the first row (class 0) fixed at zero.
        let mut logits = Array2::zeros((K, n_obs));
        for i in 0..n_obs {
            for j in 0..(K - 1) {
                logits[[j + 1, i]] = params[[i, j]];
            }
        }

        let probs = softmax_axis0(&logits);

        Categorical {
            logits,
            probs,
            n_obs,
            _params: params.clone(),
        }
    }

    fn fit(y: &Array1<f64>) -> Array1<f64> {
        // Count occurrences of each class; labels outside 0..K are ignored.
        let n = y.len();
        let mut counts = vec![0usize; K];
        for &y_i in y.iter() {
            let class = y_i as usize;
            if class < K {
                counts[class] += 1;
            }
        }

        // Convert to empirical probabilities, floored at PROB_EPS to avoid log(0).
        let probs: Vec<f64> = counts
            .iter()
            .map(|&c| (c as f64 / n as f64).max(PROB_EPS))
            .collect();

        // Return logits relative to class 0: log(p_k) - log(p_0).
        let log_p0 = probs[0].ln();
        let mut init_params = Array1::zeros(K - 1);
        for k in 1..K {
            init_params[k - 1] = probs[k].ln() - log_p0;
        }

        init_params
    }

    fn n_params(&self) -> usize {
        K - 1
    }

    fn predict(&self) -> Array1<f64> {
        // Return the most likely class (argmax over probabilities) for each observation.
        let mut predictions = Array1::zeros(self.n_obs);
        for i in 0..self.n_obs {
            let mut max_prob = f64::NEG_INFINITY;
            let mut max_class = 0;
            for k in 0..K {
                if self.probs[[k, i]] > max_prob {
                    max_prob = self.probs[[k, i]];
                    max_class = k;
                }
            }
            predictions[i] = max_class as f64;
        }
        predictions
    }

    fn params(&self) -> &Array2<f64> {
        &self._params
    }
}

impl<const K: usize> ClassificationDistn for Categorical<K> {
    fn class_probs(&self) -> Array2<f64> {
        // Return (N, K) probabilities: one row per observation.
        self.probs.t().to_owned()
    }
}
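
// A small end-to-end sketch of the intended flow, with illustrative values:
// build a 3-class distribution from an (N, K-1) parameter array as documented
// in `from_params`, then check the output shapes and the argmax prediction.
#[test]
fn categorical_roundtrip_shapes_and_argmax() {
    // One observation; class 2 gets the largest logit, so it should be predicted.
    let params = Array2::from_shape_vec((1, 2), vec![0.5, 2.0]).unwrap();
    let dist = Categorical::<3>::from_params(&params);

    let probs = dist.class_probs();
    assert_eq!(probs.dim(), (1, 3));
    assert!((probs.row(0).sum() - 1.0).abs() < 1e-6);

    let pred = dist.predict();
    assert_eq!(pred[0], 2.0);
}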

impl<const K: usize> Scorable<LogScore> for Categorical<K> {
    fn score(&self, y: &Array1<f64>) -> Array1<f64> {
        // Negative log-likelihood: -log(p[y_i]) for each observation.
        let mut scores = Array1::zeros(y.len());
        for (i, &y_i) in y.iter().enumerate() {
            let class = y_i as usize;
            scores[i] = -self.probs[[class, i]].max(PROB_EPS).ln();
        }
        scores
    }

    fn d_score(&self, y: &Array1<f64>) -> Array2<f64> {
        // Gradient of the NLL: probs - one_hot(y), taken over classes 1..K
        // only, since the class-0 logit carries no free parameter.
        // (A finite-difference check of this gradient follows the impl.)
        let n_obs = y.len();
        let mut d_params = Array2::zeros((n_obs, K - 1));

        for i in 0..n_obs {
            let y_i = y[i] as usize;
            for k in 1..K {
                // d/d(logit_k) = p_k - 1{y == k}
                let indicator = if y_i == k { 1.0 } else { 0.0 };
                d_params[[i, k - 1]] = self.probs[[k, i]] - indicator;
            }
        }

        d_params
    }

    fn metric(&self) -> Array3<f64> {
        // Fisher information matrix of the categorical in this parameterization:
        //   FI[j,k] = -p_j * p_k       for j != k
        //   FI[j,j] = p_j * (1 - p_j)
        // (indices are shifted by one because class 0 carries no free parameter).
        let n_obs = self.n_obs;
        let n_params = K - 1;
        let mut fi = Array3::zeros((n_obs, n_params, n_params));

        for i in 0..n_obs {
            for j in 0..n_params {
                let p_j = self.probs[[j + 1, i]];
                for k in 0..n_params {
                    let p_k = self.probs[[k + 1, i]];
                    if j == k {
                        fi[[i, j, k]] = p_j * (1.0 - p_j);
                    } else {
                        fi[[i, j, k]] = -p_j * p_k;
                    }
                }
            }
        }

        fi
    }
}
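
// A numerical check of the LogScore gradient. This is a sketch: the
// parameter values, step size, and tolerance are assumptions. Each free
// logit is perturbed and the central-difference slope of the score is
// compared against `d_score`.
#[test]
fn logscore_d_score_matches_finite_differences() {
    let params = Array2::from_shape_vec((1, 2), vec![0.3, -0.7]).unwrap();
    let y = Array1::from_vec(vec![1.0]);
    let dist = Categorical::<3>::from_params(&params);
    let grad = Scorable::<LogScore>::d_score(&dist, &y);

    let h = 1e-6;
    for j in 0..2 {
        let mut p_plus = params.clone();
        p_plus[[0, j]] += h;
        let mut p_minus = params.clone();
        p_minus[[0, j]] -= h;
        let s_plus = Scorable::<LogScore>::score(&Categorical::<3>::from_params(&p_plus), &y)[0];
        let s_minus = Scorable::<LogScore>::score(&Categorical::<3>::from_params(&p_minus), &y)[0];
        let fd = (s_plus - s_minus) / (2.0 * h);
        assert!((fd - grad[[0, j]]).abs() < 1e-4);
    }
}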

impl<const K: usize> Scorable<CRPScore> for Categorical<K> {
    fn score(&self, y: &Array1<f64>) -> Array1<f64> {
        // For categorical distributions, the CRPS analogue used here is the
        // Brier score:
        //   BS = Σ_k (p_k - 1{y == k})^2,
        // the sum of squared errors between the predicted probabilities and
        // the one-hot encoding of the label.
        //
        // Note: for ordinal categories one would use the Ranked Probability
        // Score (RPS), which is the CRPS applied to the cumulative
        // distribution. Brier is used here since categorical typically
        // implies unordered classes.

        let mut scores = Array1::zeros(y.len());
        for (i, &y_i) in y.iter().enumerate() {
            let true_class = y_i as usize;
            let mut brier = 0.0;
            for k in 0..K {
                let p_k = self.probs[[k, i]];
                let indicator = if k == true_class { 1.0 } else { 0.0 };
                brier += (p_k - indicator).powi(2);
            }
            scores[i] = brier;
        }
        scores
    }

    fn d_score(&self, y: &Array1<f64>) -> Array2<f64> {
        // Gradient of the Brier score w.r.t. the logits. Writing the
        // residuals as r_k = p_k - 1{y == k} and using the softmax Jacobian
        //   d(p_k)/d(logit_j) = p_k * (1{k==j} - p_j),
        // the chain rule gives
        //   d(BS)/d(logit_j) = 2 * Σ_k r_k * p_k * (1{k==j} - p_j)
        //                    = 2 * p_j * (r_j - Σ_k r_k * p_k).
        // The loop below evaluates the unsimplified sum directly.
        // (A finite-difference check of this gradient follows the impl.)

        let n_obs = y.len();
        let mut d_params = Array2::zeros((n_obs, K - 1));

        for i in 0..n_obs {
            let y_i = y[i] as usize;

            // Compute residuals: r_k = p_k - 1{y == k}
            let mut residuals = vec![0.0; K];
            for k in 0..K {
                let indicator = if k == y_i { 1.0 } else { 0.0 };
                residuals[k] = self.probs[[k, i]] - indicator;
            }

            // For each free parameter (logit_j for j = 1..K)
            for j in 1..K {
                let p_j = self.probs[[j, i]];

                // d(BS)/d(logit_j) = 2 * Σ_k r_k * p_k * (1{k==j} - p_j)
                let mut grad = 0.0;
                for k in 0..K {
                    let p_k = self.probs[[k, i]];
                    let delta_kj = if k == j { 1.0 } else { 0.0 };
                    grad += residuals[k] * p_k * (delta_kj - p_j);
                }
                d_params[[i, j - 1]] = 2.0 * grad;
            }
        }

        d_params
    }

    fn metric(&self) -> Array3<f64> {
        // Metric for the Brier score. We use the Fisher information matrix
        // as an approximation: the same structure as the LogScore metric,
        // scaled by a factor of 4.
        let n_obs = self.n_obs;
        let n_params = K - 1;
        let mut fi = Array3::zeros((n_obs, n_params, n_params));

        for i in 0..n_obs {
            for j in 0..n_params {
                let p_j = self.probs[[j + 1, i]];
                for k in 0..n_params {
                    let p_k = self.probs[[k + 1, i]];
                    if j == k {
                        // Diagonal: scaled by 4 for the Brier score.
                        fi[[i, j, k]] = 4.0 * p_j * (1.0 - p_j);
                    } else {
                        // Off-diagonal
                        fi[[i, j, k]] = -4.0 * p_j * p_k;
                    }
                }
            }
        }

        fi
    }
}
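
// The same finite-difference sketch for the Brier gradient, with illustrative
// values: the simplified formula in the comment above should agree with both
// the looped sum and this numerical slope.
#[test]
fn crpscore_d_score_matches_finite_differences() {
    let params = Array2::from_shape_vec((1, 2), vec![-0.2, 0.9]).unwrap();
    let y = Array1::from_vec(vec![0.0]);
    let dist = Categorical::<3>::from_params(&params);
    let grad = Scorable::<CRPScore>::d_score(&dist, &y);

    let h = 1e-6;
    for j in 0..2 {
        let mut p_plus = params.clone();
        p_plus[[0, j]] += h;
        let mut p_minus = params.clone();
        p_minus[[0, j]] -= h;
        let s_plus = Scorable::<CRPScore>::score(&Categorical::<3>::from_params(&p_plus), &y)[0];
        let s_minus = Scorable::<CRPScore>::score(&Categorical::<3>::from_params(&p_minus), &y)[0];
        let fd = (s_plus - s_minus) / (2.0 * h);
        assert!((fd - grad[[0, j]]).abs() < 1e-4);
    }
}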

/// Type alias for binary classification (Bernoulli distribution).
pub type Bernoulli = Categorical<2>;

/// Type alias for 3-class classification.
pub type Categorical3 = Categorical<3>;

/// Type alias for 4-class classification.
pub type Categorical4 = Categorical<4>;

/// Type alias for 5-class classification.
pub type Categorical5 = Categorical<5>;

/// Type alias for 10-class classification (e.g., digit recognition).
pub type Categorical10 = Categorical<10>;
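
// A quick check that `fit` recovers the empirical class frequencies.
// A sketch with made-up labels: pushing the returned relative logits
// back through a softmax should reproduce the observed proportions.
#[test]
fn fit_recovers_empirical_frequencies() {
    // 2 zeros, 1 one, 1 two -> frequencies [0.5, 0.25, 0.25].
    let y = Array1::from_vec(vec![0.0, 0.0, 1.0, 2.0]);
    let init = Categorical::<3>::fit(&y);

    // Reconstruct probabilities from the relative logits [0, init...].
    let z: Vec<f64> = std::iter::once(0.0).chain(init.iter().copied()).collect();
    let denom: f64 = z.iter().map(|v| v.exp()).sum();
    let p: Vec<f64> = z.iter().map(|v| v.exp() / denom).collect();

    assert!((p[0] - 0.5).abs() < 1e-9);
    assert!((p[1] - 0.25).abs() < 1e-9);
    assert!((p[2] - 0.25).abs() < 1e-9);
}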

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_categorical_crpscore_bernoulli() {
        // Binary classification with Bernoulli.
        // params: logit for class 1 (the class-0 logit is fixed at 0).
        let params = Array2::from_shape_vec((1, 1), vec![0.0]).unwrap(); // p = [0.5, 0.5]
        let dist = Bernoulli::from_params(&params);

        // With equal probs [0.5, 0.5], the Brier score for y=0 is
        // (0.5 - 1)^2 + (0.5 - 0)^2 = 0.25 + 0.25 = 0.5.
        let y = Array1::from_vec(vec![0.0]);
        let score = Scorable::<CRPScore>::score(&dist, &y);
        assert_relative_eq!(score[0], 0.5, epsilon = 1e-6);

        // Same for y=1.
        let y = Array1::from_vec(vec![1.0]);
        let score = Scorable::<CRPScore>::score(&dist, &y);
        assert_relative_eq!(score[0], 0.5, epsilon = 1e-6);
    }

    #[test]
    fn test_categorical_crpscore_perfect_prediction() {
        // A perfect prediction should have a Brier score near 0.
        // Use a large logit to push the probability close to 1.
        let params = Array2::from_shape_vec((1, 1), vec![10.0]).unwrap(); // p ≈ [0, 1]
        let dist = Bernoulli::from_params(&params);

        // Predicting class 1 when the true class is 1.
        let y = Array1::from_vec(vec![1.0]);
        let score = Scorable::<CRPScore>::score(&dist, &y);
        assert!(score[0] < 0.01); // Should be very small
    }

    #[test]
    fn test_categorical_crpscore_worst_prediction() {
        // Worst prediction (a confident wrong answer).
        let params = Array2::from_shape_vec((1, 1), vec![10.0]).unwrap(); // p ≈ [0, 1]
        let dist = Bernoulli::from_params(&params);

        // Predicting class 1 when the true class is 0.
        let y = Array1::from_vec(vec![0.0]);
        let score = Scorable::<CRPScore>::score(&dist, &y);
        // Brier score ≈ (0 - 1)^2 + (1 - 0)^2 = 2
        assert!(score[0] > 1.9);
    }

    #[test]
    fn test_categorical_crpscore_multiclass() {
        // 3-class classification with equal probabilities.
        let params = Array2::from_shape_vec((1, 2), vec![0.0, 0.0]).unwrap();
        let dist = Categorical3::from_params(&params);

        let y = Array1::from_vec(vec![0.0]);
        let score = Scorable::<CRPScore>::score(&dist, &y);

        // With equal probs [1/3, 1/3, 1/3], the Brier score for y=0 is
        // (1/3 - 1)^2 + (1/3 - 0)^2 + (1/3 - 0)^2 = 4/9 + 1/9 + 1/9 = 2/3.
        assert_relative_eq!(score[0], 2.0 / 3.0, epsilon = 1e-6);
    }

    #[test]
    fn test_categorical_crpscore_d_score() {
        let params = Array2::from_shape_vec((1, 1), vec![0.0]).unwrap();
        let dist = Bernoulli::from_params(&params);

        let y = Array1::from_vec(vec![1.0]);
        let d_score = Scorable::<CRPScore>::d_score(&dist, &y);

        // The gradient should be finite.
        assert!(d_score[[0, 0]].is_finite());
    }

    #[test]
    fn test_categorical_crpscore_metric() {
        let params = Array2::from_shape_vec((1, 1), vec![0.0]).unwrap();
        let dist = Bernoulli::from_params(&params);

        let metric = Scorable::<CRPScore>::metric(&dist);

        // The diagonal of the metric should be positive.
        assert!(metric[[0, 0, 0]] > 0.0);
    }

    #[test]
    fn test_categorical_logscore() {
        // Basic LogScore test.
        let params = Array2::from_shape_vec((1, 1), vec![0.0]).unwrap(); // p = [0.5, 0.5]
        let dist = Bernoulli::from_params(&params);

        let y = Array1::from_vec(vec![0.0]);
        let score = Scorable::<LogScore>::score(&dist, &y);

        // -log(0.5) = ln(2) ≈ 0.693
        assert_relative_eq!(score[0], std::f64::consts::LN_2, epsilon = 1e-6);
    }
}