Skip to main content

entrenar/train/loss/
cross_entropy.rs

1//! Cross Entropy Loss for classification
2
3use crate::Tensor;
4use ndarray::Array1;
5
6use super::LossFn;
7
/// Cross Entropy Loss (for classification)
///
/// L = -sum(targets * log(softmax(predictions)))
///
/// `predictions` are raw (unnormalised) logits; the softmax is applied inside
/// [`LossFn::forward`]. `targets` must have the same length as `predictions` and
/// is expected to be a probability distribution over the classes — typically a
/// one-hot vector, as in the example below. The loss is returned as a
/// single-element tensor.
///
/// # Example
///
/// ```
/// use entrenar::train::{CrossEntropyLoss, LossFn};
/// use entrenar::Tensor;
///
/// let loss_fn = CrossEntropyLoss;
/// let logits = Tensor::from_vec(vec![2.0, 1.0, 0.5], true);
/// let targets = Tensor::from_vec(vec![1.0, 0.0, 0.0], false); // one-hot
///
/// let loss = loss_fn.forward(&logits, &targets);
/// assert!(loss.data()[0] > 0.0);
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct CrossEntropyLoss;
26
27impl CrossEntropyLoss {
28    /// Compute softmax: exp(x_i) / sum(exp(x_j))
29    pub(crate) fn softmax(x: &Array1<f32>) -> Array1<f32> {
30        let max = x.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
31        let exp_x: Array1<f32> = x.mapv(|v| (v - max).exp());
32        let sum: f32 = exp_x.sum();
33        exp_x / sum
34    }
35}
36
37impl LossFn for CrossEntropyLoss {
38    fn forward(&self, predictions: &Tensor, targets: &Tensor) -> Tensor {
39        assert_eq!(
40            predictions.len(),
41            targets.len(),
42            "Predictions and targets must have same length"
43        );
44
45        // Compute softmax
46        let probs = Self::softmax(predictions.data());
47
48        // Compute cross entropy: -sum(targets * log(probs))
49        let ce: f32 = targets
50            .data()
51            .iter()
52            .zip(probs.iter())
53            .map(|(&t, &p)| -t * (p + 1e-10).max(f32::MIN_POSITIVE).ln())
54            .sum();
55
56        // Create loss tensor
57        let mut loss = Tensor::from_vec(vec![ce], true);
58
59        // Set up gradient: d(CE)/d(logits) = probs - targets
60        let grad = &probs - targets.data();
61
62        use crate::autograd::BackwardOp;
63        use std::rc::Rc;
64
65        struct CEBackward {
66            pred_grad_cell: Rc<std::cell::RefCell<Option<Array1<f32>>>>,
67            grad: Array1<f32>,
68        }
69
70        impl BackwardOp for CEBackward {
71            fn backward(&self) {
72                let mut pred_grad = self.pred_grad_cell.borrow_mut();
73                if let Some(existing) = pred_grad.as_mut() {
74                    *existing = &*existing + &self.grad;
75                } else {
76                    *pred_grad = Some(self.grad.clone());
77                }
78            }
79        }
80
81        if predictions.requires_grad() {
82            loss.set_backward_op(Rc::new(CEBackward {
83                pred_grad_cell: predictions.grad_cell(),
84                grad,
85            }));
86        }
87
88        loss
89    }
90
91    fn name(&self) -> &'static str {
92        "CrossEntropy"
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// One-hot target vector of length `len` with the hot entry at `idx`.
    fn one_hot(idx: usize, len: usize) -> Vec<f32> {
        let mut v = vec![0.0_f32; len];
        v[idx] = 1.0;
        v
    }

    /// Reference softmax (f64 precision) for accuracy verification
    fn reference_softmax_f64(logits: &[f32]) -> Vec<f64> {
        let logits_f64: Vec<f64> = logits.iter().map(|&x| f64::from(x)).collect();
        let max = logits_f64.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let exp_vals: Vec<f64> = logits_f64.iter().map(|&x| (x - max).exp()).collect();
        let sum: f64 = exp_vals.iter().sum();
        exp_vals.iter().map(|&e| e / sum).collect()
    }

    /// Reference cross-entropy (f64 precision) against a one-hot target at `target_idx`.
    fn reference_cross_entropy_f64(logits: &[f32], target_idx: usize) -> f64 {
        let probs = reference_softmax_f64(logits);
        -probs[target_idx].max(1e-30).ln()
    }

    /// Reference cross-entropy (f64 precision) against an arbitrary target distribution.
    fn reference_soft_cross_entropy_f64(logits: &[f32], targets: &[f32]) -> f64 {
        reference_softmax_f64(logits)
            .iter()
            .zip(targets)
            .map(|(&p, &t)| -f64::from(t) * p.max(1e-30).ln())
            .sum()
    }

    /// Expected logit gradient `softmax(logits) - targets` (valid when targets sum to 1).
    fn expected_gradient(logits: &[f32], targets: &[f32]) -> Vec<f32> {
        let probs = CrossEntropyLoss::softmax(&Array1::from(logits.to_vec()));
        probs.iter().zip(targets).map(|(&p, &t)| p - t).collect()
    }

    /// Runs `forward` on gradient-free tensors and returns the scalar loss.
    fn loss_value(logits: &[f32], targets: &[f32]) -> f32 {
        let pred = Tensor::from_vec(logits.to_vec(), false);
        let tgt = Tensor::from_vec(targets.to_vec(), false);
        CrossEntropyLoss.forward(&pred, &tgt).data()[0]
    }

    #[test]
    fn test_cross_entropy_accuracy_matches_reference() {
        let logits = [2.0_f32, 1.0, 0.5];
        let target_idx = 0;
        let reference = reference_cross_entropy_f64(&logits, target_idx) as f32;
        let actual = loss_value(&logits, &one_hot(target_idx, logits.len()));
        let diff = (actual - reference).abs();
        assert!(diff < 1e-5, "CE accuracy: actual={actual}, ref={reference}, diff={diff}");
    }

    #[test]
    fn test_cross_entropy_accuracy_10class() {
        let logits: Vec<f32> = (0..10).map(|i| (i as f32 - 5.0) * 0.5).collect();
        for target_idx in 0..logits.len() {
            let reference = reference_cross_entropy_f64(&logits, target_idx) as f32;
            let actual = loss_value(&logits, &one_hot(target_idx, logits.len()));
            let diff = (actual - reference).abs();
            assert!(diff < 1e-4, "CE accuracy 10-class[{target_idx}]: diff={diff}");
        }
    }

    #[test]
    fn test_cross_entropy_soft_targets_match_reference() {
        // Label-smoothed style targets (sum to 1) must match the f64 reference too.
        let logits = [1.5_f32, -0.5, 0.25];
        let soft = [0.7_f32, 0.2, 0.1];
        let reference = reference_soft_cross_entropy_f64(&logits, &soft) as f32;
        let actual = loss_value(&logits, &soft);
        let diff = (actual - reference).abs();
        assert!(diff < 1e-5, "soft-target CE: actual={actual}, ref={reference}, diff={diff}");
    }

    #[test]
    fn test_cross_entropy_loss() {
        let loss_fn = CrossEntropyLoss;
        let logits = Tensor::from_vec(vec![2.0, 1.0, 0.5], true);
        let targets = Tensor::from_vec(vec![1.0, 0.0, 0.0], false);

        let loss = loss_fn.forward(&logits, &targets);

        // Loss should be positive
        assert!(loss.data()[0] > 0.0);
        assert!(loss.data()[0].is_finite());
    }

    #[test]
    fn test_softmax() {
        let x = Array1::from(vec![1.0, 2.0, 3.0]);
        let probs = CrossEntropyLoss::softmax(&x);

        // Probabilities should sum to 1
        let sum: f32 = probs.sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-5);

        // All probabilities should be in [0, 1]
        for &p in &probs {
            assert!((0.0..=1.0).contains(&p));
        }
    }

    #[test]
    fn test_cross_entropy_gradient() {
        let loss_fn = CrossEntropyLoss;
        let logits = Tensor::from_vec(vec![2.0, 1.0, 0.5], true);
        let targets = Tensor::from_vec(vec![1.0, 0.0, 0.0], false);

        let loss = loss_fn.forward(&logits, &targets);

        if let Some(backward_op) = loss.backward_op() {
            backward_op.backward();
        }

        let grad = logits.grad().expect("gradient should be available");
        let expected = expected_gradient(&[2.0, 1.0, 0.5], &[1.0, 0.0, 0.0]);
        assert_eq!(grad.len(), expected.len());
        // The gradient must be exactly softmax(logits) - one_hot, not merely finite.
        for (&g, &e) in grad.iter().zip(&expected) {
            assert!(g.is_finite());
            assert_relative_eq!(g, e, epsilon = 1e-6);
        }
        // Target at index 0 means grad[0] = p0 - 1 < 0.
        assert!(grad[0] < 0.0);
        // Both softmax and the one-hot target sum to 1, so the gradient sums to ~0.
        let total: f32 = grad.iter().sum();
        assert_relative_eq!(total, 0.0, epsilon = 1e-6);
    }

    #[test]
    #[should_panic(expected = "must have same length")]
    fn test_cross_entropy_mismatched_lengths() {
        let loss_fn = CrossEntropyLoss;
        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
        loss_fn.forward(&pred, &target);
    }

    #[test]
    fn test_cross_entropy_no_grad() {
        let loss_fn = CrossEntropyLoss;
        let pred = Tensor::from_vec(vec![2.0, 1.0], false);
        let target = Tensor::from_vec(vec![1.0, 0.0], false);
        let loss = loss_fn.forward(&pred, &target);
        assert!(loss.data()[0] > 0.0);
        // No backward op is attached when the predictions do not require a gradient.
        assert!(loss.backward_op().is_none());
    }

    #[test]
    fn test_softmax_numerical_stability() {
        // Large values that could cause overflow without max subtraction
        let x = Array1::from(vec![1000.0, 1001.0, 1002.0]);
        let probs = CrossEntropyLoss::softmax(&x);

        // Should still sum to 1.0
        let sum: f32 = probs.sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-5);

        // All values should be valid
        for &p in &probs {
            assert!(p.is_finite());
            assert!(p >= 0.0);
        }
    }

    #[test]
    fn test_gradient_accumulation_cross_entropy() {
        let logits = Tensor::from_vec(vec![2.0, 1.0], true);
        let targets = Tensor::from_vec(vec![1.0, 0.0], false);

        for _ in 0..2 {
            let loss = CrossEntropyLoss.forward(&logits, &targets);
            if let Some(op) = loss.backward_op() {
                op.backward();
            }
        }

        // Two backward passes must add their contributions rather than overwrite.
        let grad = logits.grad().expect("gradient should be available");
        let single = expected_gradient(&[2.0, 1.0], &[1.0, 0.0]);
        assert_eq!(grad.len(), single.len());
        for (&g, &s) in grad.iter().zip(&single) {
            assert!(g.is_finite());
            assert_relative_eq!(g, 2.0 * s, epsilon = 1e-6);
        }
    }
}
254
// =========================================================================
// FALSIFY-CE: cross-entropy-kernel-v1.yaml contract (entrenar CrossEntropyLoss)
//
// Five-Whys (PMAT-354):
//   Why 1: entrenar had 7 CE tests but zero FALSIFY-CE-* contract tests
//   Why 2: existing tests verify API shape, not mathematical invariants
//   Why 3: no mapping from cross-entropy-kernel-v1.yaml claims to test names
//   Why 4: entrenar CE predates the provable-contracts YAML convention
//   Why 5: CE was "obviously correct" (standard softmax + NLL)
//
// References:
//   - provable-contracts/contracts/cross-entropy-kernel-v1.yaml
//   - Shannon (1948) "A Mathematical Theory of Communication"
// =========================================================================
#[cfg(test)]
mod ce_contract_tests {
    use super::*;
    use ndarray::Array1;

    /// Length-`len` target vector that is 1.0 at `idx` and 0.0 elsewhere.
    fn one_hot(idx: usize, len: usize) -> Vec<f32> {
        let mut encoded = vec![0.0_f32; len];
        encoded[idx] = 1.0;
        encoded
    }

    /// Evaluates `CrossEntropyLoss::forward` on gradient-free tensors built from the
    /// given slices and returns the scalar loss.
    fn ce_value(logits: &[f32], targets: &[f32]) -> f32 {
        let pred = Tensor::from_vec(logits.to_vec(), false);
        let tgt = Tensor::from_vec(targets.to_vec(), false);
        CrossEntropyLoss.forward(&pred, &tgt).data()[0]
    }

    /// FALSIFY-CE-001: Non-negativity — CE(targets, logits) >= 0
    ///
    /// Contract: the cross-entropy of valid probability targets is never negative.
    #[test]
    fn falsify_ce_001_non_negativity() {
        let cases: Vec<(Vec<f32>, Vec<f32>)> = vec![
            (vec![2.0, 1.0, 0.5], one_hot(0, 3)),
            (vec![0.0, 0.0, 0.0], one_hot(1, 3)),
            (vec![-10.0, 10.0], one_hot(0, 2)),
            (vec![100.0, -100.0, 0.0], one_hot(2, 3)),
            (vec![0.1, 0.2, 0.3, 0.4], one_hot(3, 4)),
        ];

        for (i, (logits, targets)) in cases.iter().enumerate() {
            let val = ce_value(logits, targets);
            assert!(val >= -1e-6, "FALSIFIED CE-001 case {i}: CE = {val} < 0");
        }
    }

    /// FALSIFY-CE-002: Log-softmax upper bound — log_softmax(x)_i <= 0
    ///
    /// Contract: every log-softmax value is non-positive.
    #[test]
    fn falsify_ce_002_log_softmax_upper_bound() {
        let cases: [&[f32]; 5] = [
            &[1.0, 2.0, 3.0],
            &[0.0, 0.0, 0.0],
            &[-100.0, 100.0],
            &[1000.0, 1001.0, 999.0],
            &[-500.0, -500.0, -500.0, -500.0],
        ];

        for (i, logits) in cases.iter().enumerate() {
            let probs = CrossEntropyLoss::softmax(&Array1::from(logits.to_vec()));
            for (j, log_p) in probs.iter().map(|p| p.ln()).enumerate() {
                assert!(log_p <= 1e-6, "FALSIFIED CE-002 case {i}[{j}]: log_softmax = {log_p} > 0");
            }
        }
    }

    /// FALSIFY-CE-003: Numerical stability — no NaN/Inf for finite logits
    ///
    /// Contract: CE yields a finite value for every finite input.
    #[test]
    fn falsify_ce_003_numerical_stability() {
        let extreme_cases: Vec<(Vec<f32>, Vec<f32>)> = vec![
            (vec![500.0, -500.0, 0.0], one_hot(0, 3)),
            (vec![-1000.0, -1000.0, -1000.0], one_hot(1, 3)),
            (vec![88.0, 88.0], one_hot(0, 2)), // near f32 exp overflow
            (vec![-88.0, -88.0, -88.0], one_hot(2, 3)), // near f32 exp underflow
        ];

        for (i, (logits, targets)) in extreme_cases.iter().enumerate() {
            let val = ce_value(logits, targets);
            assert!(val.is_finite(), "FALSIFIED CE-003 case {i}: CE = {val} (not finite)");
        }
    }

    /// FALSIFY-CE-006: Perfect prediction — CE approaches 0 as the dominant logit grows
    ///
    /// Contract: CE(one_hot(k), logits) → 0 when logits_k >> logits_j for j≠k
    #[test]
    fn falsify_ce_006_perfect_prediction() {
        for target in 0..3 {
            let logits: Vec<f32> = (0..3)
                .map(|k| if k == target { 50.0 } else { -50.0 })
                .collect();
            let val = ce_value(&logits, &one_hot(target, 3));
            assert!(
                val < 1e-3,
                "FALSIFIED CE-006: CE(one_hot({target}), dominant) = {val}, expected ≈ 0"
            );
        }
    }

    /// FALSIFY-CE-001b: Uniform logits — CE = log(C)
    ///
    /// Contract: with all logits equal the softmax is uniform (1/C each), hence
    /// CE = -log(1/C) = log(C).
    #[test]
    fn falsify_ce_001b_uniform_logits() {
        for nc in [2_usize, 3, 5, 10] {
            let val = ce_value(&vec![1.0_f32; nc], &one_hot(0, nc));
            let expected = (nc as f32).ln();
            let diff = (val - expected).abs();
            assert!(
                diff < 1e-4,
                "FALSIFIED CE-001b: CE(uniform, C={nc}) = {val}, expected log({nc}) = {expected}"
            );
        }
    }

    mod ce_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        // FALSIFY-CE-001-prop: Non-negativity for random one-hot targets
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]

            #[test]
            fn falsify_ce_001_prop_non_negativity(
                nc in 2..=10usize,
                target in 0..10usize,
                seed in 0..1000u32,
            ) {
                let target = target % nc;
                let offset = seed as f32;
                let logits: Vec<f32> = (0..nc)
                    .map(|i| ((i as f32 + offset) * 0.37).sin() * 10.0)
                    .collect();

                let val = ce_value(&logits, &one_hot(target, nc));
                prop_assert!(
                    val >= -1e-6,
                    "FALSIFIED CE-001-prop: CE = {} < 0 (nc={}, target={})",
                    val, nc, target
                );
            }
        }

        // FALSIFY-CE-003-prop: Numerical stability for random logits
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]

            #[test]
            fn falsify_ce_003_prop_finite_output(
                nc in 2..=10usize,
                target in 0..10usize,
                scale in 0.1f32..100.0,
                seed in 0..1000u32,
            ) {
                let target = target % nc;
                let offset = seed as f32;
                let logits: Vec<f32> = (0..nc)
                    .map(|i| ((i as f32 + offset) * 0.73).cos() * scale)
                    .collect();

                let val = ce_value(&logits, &one_hot(target, nc));
                prop_assert!(
                    val.is_finite(),
                    "FALSIFIED CE-003-prop: CE = {} (not finite) for nc={}, scale={}",
                    val, nc, scale
                );
            }
        }

        // FALSIFY-CE-002-prop: Log-softmax upper bound for random inputs
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]

            #[test]
            fn falsify_ce_002_prop_log_softmax_bound(
                nc in 2..=10usize,
                scale in 0.1f32..100.0,
                seed in 0..1000u32,
            ) {
                let offset = seed as f32;
                let logits: Vec<f32> = (0..nc)
                    .map(|i| ((i as f32 + offset) * 0.37).sin() * scale)
                    .collect();
                let probs = CrossEntropyLoss::softmax(&Array1::from(logits));
                let upper = 1.0 + 1e-6;

                for (j, &p) in probs.iter().enumerate() {
                    prop_assert!(
                        (0.0..=upper).contains(&p),
                        "FALSIFIED CE-002-prop: softmax[{}] = {} outside [0,1]",
                        j, p
                    );
                    let log_p = p.ln();
                    prop_assert!(
                        log_p <= 1e-6,
                        "FALSIFIED CE-002-prop: log(softmax[{}]) = {} > 0",
                        j, log_p
                    );
                }
            }
        }
    }
}