1use crate::dist::{ClassificationDistn, Distribution};
2use crate::scores::{CRPScore, LogScore, Scorable};
3use ndarray::{Array1, Array2, Array3, Axis};
4
/// Smallest probability value used in this module: `softmax_axis0` clamps every
/// probability up to at least this value, and it is also used as a floor before
/// taking logarithms so that `ln` never sees zero.
const PROB_EPS: f64 = 1e-10;
/// Largest probability value produced by `softmax_axis0` (`1 - PROB_EPS`), the upper
/// bound of the clamp applied to every probability.
const PROB_MAX: f64 = 1.0 - PROB_EPS;
9
10fn softmax_axis0(logits: &Array2<f64>) -> Array2<f64> {
13 let max_vals = logits.fold_axis(Axis(0), f64::NEG_INFINITY, |&a, &b| a.max(b));
14 let shifted = logits - &max_vals;
15 let exp_vals = shifted.mapv(f64::exp);
16 let sum_exp = exp_vals.sum_axis(Axis(0));
17 let probs = exp_vals / &sum_exp;
18 probs.mapv(|p| p.clamp(PROB_EPS, PROB_MAX))
20}
21
/// Categorical distribution over `K` classes, parameterised by `K - 1` free logits
/// per observation (class 0 acts as the reference class whose logit stays at zero).
#[derive(Debug, Clone)]
pub struct Categorical<const K: usize> {
    /// Logits laid out as `(K, n_obs)` by `from_params`: row 0 is left at zero and
    /// row `j + 1` holds column `j` of the parameter matrix.
    pub logits: Array2<f64>,
    /// Class probabilities obtained from `logits` via `softmax_axis0`, so each entry
    /// is clamped into `[PROB_EPS, PROB_MAX]`. Same `(K, n_obs)` layout as `logits`.
    pub probs: Array2<f64>,
    /// Number of observations, taken from the row count of the parameter matrix.
    n_obs: usize,
    /// Copy of the parameter matrix this distribution was built from; returned by `params()`.
    _params: Array2<f64>,
}
37
38impl<const K: usize> Distribution for Categorical<K> {
39 fn from_params(params: &Array2<f64>) -> Self {
40 let n_obs = params.nrows();
42
43 let mut logits = Array2::zeros((K, n_obs));
45 for i in 0..n_obs {
46 for j in 0..(K - 1) {
47 logits[[j + 1, i]] = params[[i, j]];
48 }
49 }
50
51 let probs = softmax_axis0(&logits);
52
53 Categorical {
54 logits,
55 probs,
56 n_obs,
57 _params: params.clone(),
58 }
59 }
60
61 fn fit(y: &Array1<f64>) -> Array1<f64> {
62 let n = y.len();
64 let mut counts = vec![0usize; K];
65 for &y_i in y.iter() {
66 let class = y_i as usize;
67 if class < K {
68 counts[class] += 1;
69 }
70 }
71
72 let probs: Vec<f64> = counts
74 .iter()
75 .map(|&c| (c as f64 / n as f64).max(PROB_EPS))
76 .collect();
77
78 let log_p0 = probs[0].ln();
80 let mut init_params = Array1::zeros(K - 1);
81 for k in 1..K {
82 init_params[k - 1] = probs[k].ln() - log_p0;
83 }
84
85 init_params
86 }
87
88 fn n_params(&self) -> usize {
89 K - 1
90 }
91
92 fn predict(&self) -> Array1<f64> {
93 let mut predictions = Array1::zeros(self.n_obs);
95 for i in 0..self.n_obs {
96 let mut max_prob = f64::NEG_INFINITY;
97 let mut max_class = 0;
98 for k in 0..K {
99 if self.probs[[k, i]] > max_prob {
100 max_prob = self.probs[[k, i]];
101 max_class = k;
102 }
103 }
104 predictions[i] = max_class as f64;
105 }
106 predictions
107 }
108
109 fn params(&self) -> &Array2<f64> {
110 &self._params
111 }
112}
113
114impl<const K: usize> ClassificationDistn for Categorical<K> {
115 fn class_probs(&self) -> Array2<f64> {
116 self.probs.t().to_owned()
118 }
119}
120
121impl<const K: usize> Scorable<LogScore> for Categorical<K> {
122 fn score(&self, y: &Array1<f64>) -> Array1<f64> {
123 let mut scores = Array1::zeros(y.len());
125 for (i, &y_i) in y.iter().enumerate() {
126 let class = y_i as usize;
127 scores[i] = -self.probs[[class, i]].max(PROB_EPS).ln();
128 }
129 scores
130 }
131
132 fn d_score(&self, y: &Array1<f64>) -> Array2<f64> {
133 let n_obs = y.len();
135 let mut d_params = Array2::zeros((n_obs, K - 1));
136
137 for i in 0..n_obs {
138 let y_i = y[i] as usize;
139 for k in 1..K {
140 let indicator = if y_i == k { 1.0 } else { 0.0 };
142 d_params[[i, k - 1]] = self.probs[[k, i]] - indicator;
143 }
144 }
145
146 d_params
147 }
148
149 fn metric(&self) -> Array3<f64> {
150 let n_obs = self.n_obs;
154 let n_params = K - 1;
155 let mut fi = Array3::zeros((n_obs, n_params, n_params));
156
157 for i in 0..n_obs {
158 for j in 0..n_params {
159 let p_j = self.probs[[j + 1, i]];
160 for k in 0..n_params {
161 let p_k = self.probs[[k + 1, i]];
162 if j == k {
163 fi[[i, j, k]] = p_j * (1.0 - p_j);
164 } else {
165 fi[[i, j, k]] = -p_j * p_k;
166 }
167 }
168 }
169 }
170
171 fi
172 }
173}
174
175impl<const K: usize> Scorable<CRPScore> for Categorical<K> {
176 fn score(&self, y: &Array1<f64>) -> Array1<f64> {
177 let mut scores = Array1::zeros(y.len());
186 for (i, &y_i) in y.iter().enumerate() {
187 let true_class = y_i as usize;
188 let mut brier = 0.0;
189 for k in 0..K {
190 let p_k = self.probs[[k, i]];
191 let indicator = if k == true_class { 1.0 } else { 0.0 };
192 brier += (p_k - indicator).powi(2);
193 }
194 scores[i] = brier;
195 }
196 scores
197 }
198
199 fn d_score(&self, y: &Array1<f64>) -> Array2<f64> {
200 let n_obs = y.len();
215 let mut d_params = Array2::zeros((n_obs, K - 1));
216
217 for i in 0..n_obs {
218 let y_i = y[i] as usize;
219
220 let mut residuals = vec![0.0; K];
222 for k in 0..K {
223 let indicator = if k == y_i { 1.0 } else { 0.0 };
224 residuals[k] = self.probs[[k, i]] - indicator;
225 }
226
227 for j in 1..K {
229 let p_j = self.probs[[j, i]];
230
231 let mut grad = 0.0;
233 for k in 0..K {
234 let p_k = self.probs[[k, i]];
235 let delta_kj = if k == j { 1.0 } else { 0.0 };
236 grad += residuals[k] * p_k * (delta_kj - p_j);
237 }
238 d_params[[i, j - 1]] = 2.0 * grad;
239 }
240 }
241
242 d_params
243 }
244
245 fn metric(&self) -> Array3<f64> {
246 let n_obs = self.n_obs;
250 let n_params = K - 1;
251 let mut fi = Array3::zeros((n_obs, n_params, n_params));
252
253 for i in 0..n_obs {
254 for j in 0..n_params {
255 let p_j = self.probs[[j + 1, i]];
256 for k in 0..n_params {
257 let p_k = self.probs[[k + 1, i]];
258 if j == k {
259 fi[[i, j, k]] = 4.0 * p_j * (1.0 - p_j);
261 } else {
262 fi[[i, j, k]] = -4.0 * p_j * p_k;
264 }
265 }
266 }
267 }
268
269 fi
270 }
271}
272
/// Two-class categorical distribution (`K = 2`), parameterised by a single logit.
pub type Bernoulli = Categorical<2>;

/// Categorical distribution over 3 classes.
pub type Categorical3 = Categorical<3>;

/// Categorical distribution over 4 classes.
pub type Categorical4 = Categorical<4>;

/// Categorical distribution over 5 classes.
pub type Categorical5 = Categorical<5>;

/// Categorical distribution over 10 classes.
pub type Categorical10 = Categorical<10>;
287
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// A zero logit gives p = (0.5, 0.5), so the Brier score is 0.25 + 0.25 for either label.
    #[test]
    fn test_categorical_crpscore_bernoulli() {
        let params = Array2::from_shape_vec((1, 1), vec![0.0]).unwrap();
        let dist = Bernoulli::from_params(&params);

        let y = Array1::from_vec(vec![0.0]);
        let score = Scorable::<CRPScore>::score(&dist, &y);
        assert_relative_eq!(score[0], 0.5, epsilon = 1e-6);

        let y = Array1::from_vec(vec![1.0]);
        let score = Scorable::<CRPScore>::score(&dist, &y);
        assert_relative_eq!(score[0], 0.5, epsilon = 1e-6);
    }

    /// A large positive logit puts nearly all mass on class 1; observing class 1 scores near 0.
    #[test]
    fn test_categorical_crpscore_perfect_prediction() {
        let params = Array2::from_shape_vec((1, 1), vec![10.0]).unwrap();
        let dist = Bernoulli::from_params(&params);

        let y = Array1::from_vec(vec![1.0]);
        let score = Scorable::<CRPScore>::score(&dist, &y);
        assert!(score[0] < 0.01);
    }

    /// Same confident prediction, but class 0 is observed: the score approaches its maximum of 2.
    #[test]
    fn test_categorical_crpscore_worst_prediction() {
        let params = Array2::from_shape_vec((1, 1), vec![10.0]).unwrap();
        let dist = Bernoulli::from_params(&params);

        let y = Array1::from_vec(vec![0.0]);
        let score = Scorable::<CRPScore>::score(&dist, &y);
        assert!(score[0] > 1.9);
    }

    /// Uniform three-class prediction: (1/3 - 1)^2 + 2 * (1/3)^2 = 2/3.
    #[test]
    fn test_categorical_crpscore_multiclass() {
        let params = Array2::from_shape_vec((1, 2), vec![0.0, 0.0]).unwrap();
        let dist = Categorical3::from_params(&params);

        let y = Array1::from_vec(vec![0.0]);
        let score = Scorable::<CRPScore>::score(&dist, &y);

        assert_relative_eq!(score[0], 2.0 / 3.0, epsilon = 1e-6);
    }

    /// For p = 0.5 and y = 1 the score is 2 * (1 - p)^2, whose derivative with respect
    /// to the logit is -4 * (1 - p) * p * (1 - p) = -0.5.
    #[test]
    fn test_categorical_crpscore_d_score() {
        let params = Array2::from_shape_vec((1, 1), vec![0.0]).unwrap();
        let dist = Bernoulli::from_params(&params);

        let y = Array1::from_vec(vec![1.0]);
        let d_score = Scorable::<CRPScore>::d_score(&dist, &y);

        assert!(d_score[[0, 0]].is_finite());
        assert_relative_eq!(d_score[[0, 0]], -0.5, epsilon = 1e-6);
    }

    /// For p = 0.5 the single metric entry is 4 * p * (1 - p) = 1.
    #[test]
    fn test_categorical_crpscore_metric() {
        let params = Array2::from_shape_vec((1, 1), vec![0.0]).unwrap();
        let dist = Bernoulli::from_params(&params);

        let metric = Scorable::<CRPScore>::metric(&dist);

        assert!(metric[[0, 0, 0]] > 0.0);
        assert_relative_eq!(metric[[0, 0, 0]], 1.0, epsilon = 1e-6);
    }

    /// For p = 0.5 the log score is -ln(0.5) = ln(2).
    #[test]
    fn test_categorical_logscore() {
        let params = Array2::from_shape_vec((1, 1), vec![0.0]).unwrap();
        let dist = Bernoulli::from_params(&params);

        let y = Array1::from_vec(vec![0.0]);
        let score = Scorable::<LogScore>::score(&dist, &y);

        assert_relative_eq!(score[0], 0.5_f64.ln().abs(), epsilon = 1e-6);
    }
}