1use crate::Tensor;
4use ndarray::Array1;
5
6use super::LossFn;
7
8pub struct CrossEntropyLoss;
26
27impl CrossEntropyLoss {
28 pub(crate) fn softmax(x: &Array1<f32>) -> Array1<f32> {
30 let max = x.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
31 let exp_x: Array1<f32> = x.mapv(|v| (v - max).exp());
32 let sum: f32 = exp_x.sum();
33 exp_x / sum
34 }
35}
36
37impl LossFn for CrossEntropyLoss {
38 fn forward(&self, predictions: &Tensor, targets: &Tensor) -> Tensor {
39 assert_eq!(
40 predictions.len(),
41 targets.len(),
42 "Predictions and targets must have same length"
43 );
44
45 let probs = Self::softmax(predictions.data());
47
48 let ce: f32 = targets
50 .data()
51 .iter()
52 .zip(probs.iter())
53 .map(|(&t, &p)| -t * (p + 1e-10).max(f32::MIN_POSITIVE).ln())
54 .sum();
55
56 let mut loss = Tensor::from_vec(vec![ce], true);
58
59 let grad = &probs - targets.data();
61
62 use crate::autograd::BackwardOp;
63 use std::rc::Rc;
64
65 struct CEBackward {
66 pred_grad_cell: Rc<std::cell::RefCell<Option<Array1<f32>>>>,
67 grad: Array1<f32>,
68 }
69
70 impl BackwardOp for CEBackward {
71 fn backward(&self) {
72 let mut pred_grad = self.pred_grad_cell.borrow_mut();
73 if let Some(existing) = pred_grad.as_mut() {
74 *existing = &*existing + &self.grad;
75 } else {
76 *pred_grad = Some(self.grad.clone());
77 }
78 }
79 }
80
81 if predictions.requires_grad() {
82 loss.set_backward_op(Rc::new(CEBackward {
83 pred_grad_cell: predictions.grad_cell(),
84 grad,
85 }));
86 }
87
88 loss
89 }
90
91 fn name(&self) -> &'static str {
92 "CrossEntropy"
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Reference softmax evaluated in `f64` with the usual max shift.
    fn reference_softmax_f64(logits: &[f32]) -> Vec<f64> {
        let logits_f64: Vec<f64> = logits.iter().map(|&x| f64::from(x)).collect();
        let max = logits_f64.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let exp_vals: Vec<f64> = logits_f64.iter().map(|&x| (x - max).exp()).collect();
        let sum: f64 = exp_vals.iter().sum();
        exp_vals.iter().map(|&e| e / sum).collect()
    }

    /// Reference cross-entropy `-log softmax(logits)[target_idx]` in `f64`.
    ///
    /// Evaluated in log-sum-exp form, so it stays exact even when the target
    /// probability is far too small to represent.
    fn reference_cross_entropy_f64(logits: &[f32], target_idx: usize) -> f64 {
        let logits_f64: Vec<f64> = logits.iter().map(|&x| f64::from(x)).collect();
        let max = logits_f64.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let log_sum: f64 = logits_f64
            .iter()
            .map(|&x| (x - max).exp())
            .sum::<f64>()
            .ln();
        -((logits_f64[target_idx] - max) - log_sum)
    }

    #[test]
    fn test_cross_entropy_accuracy_matches_reference() {
        let logits = vec![2.0_f32, 1.0, 0.5];
        let target_idx = 0;
        let reference = reference_cross_entropy_f64(&logits, target_idx) as f32;
        let ce = CrossEntropyLoss;
        let pred = Tensor::from_vec(logits, false);
        let mut one_hot = vec![0.0_f32; 3];
        one_hot[target_idx] = 1.0;
        let tgt = Tensor::from_vec(one_hot, false);
        let loss = ce.forward(&pred, &tgt);
        let actual = loss.data()[0];
        let diff = (actual - reference).abs();
        assert!(diff < 1e-5, "CE accuracy: actual={actual}, ref={reference}, diff={diff}");
    }

    #[test]
    fn test_cross_entropy_accuracy_10class() {
        let logits: Vec<f32> = (0..10).map(|i| (i as f32 - 5.0) * 0.5).collect();
        for target_idx in 0..10 {
            let reference = reference_cross_entropy_f64(&logits, target_idx) as f32;
            let ce = CrossEntropyLoss;
            let pred = Tensor::from_vec(logits.clone(), false);
            let mut one_hot = vec![0.0_f32; 10];
            one_hot[target_idx] = 1.0;
            let tgt = Tensor::from_vec(one_hot, false);
            let loss = ce.forward(&pred, &tgt);
            let actual = loss.data()[0];
            let diff = (actual - reference).abs();
            assert!(diff < 1e-4, "CE accuracy 10-class[{target_idx}]: diff={diff}");
        }
    }

    #[test]
    fn test_cross_entropy_loss() {
        let loss_fn = CrossEntropyLoss;
        let logits = Tensor::from_vec(vec![2.0, 1.0, 0.5], true);
        let targets = Tensor::from_vec(vec![1.0, 0.0, 0.0], false);

        let loss = loss_fn.forward(&logits, &targets);

        assert!(loss.data()[0] > 0.0);
        assert!(loss.data()[0].is_finite());
    }

    #[test]
    fn test_softmax() {
        let x = Array1::from(vec![1.0, 2.0, 3.0]);
        let probs = CrossEntropyLoss::softmax(&x);

        let sum: f32 = probs.sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-5);

        for &p in &probs {
            assert!((0.0..=1.0).contains(&p));
        }
    }

    #[test]
    fn test_cross_entropy_gradient() {
        let loss_fn = CrossEntropyLoss;
        let logits = Tensor::from_vec(vec![2.0, 1.0, 0.5], true);
        let targets = Tensor::from_vec(vec![1.0, 0.0, 0.0], false);

        let loss = loss_fn.forward(&logits, &targets);

        if let Some(backward_op) = loss.backward_op() {
            backward_op.backward();
        }

        let grad = logits.grad().expect("gradient should be available");
        for g in &grad {
            assert!(g.is_finite());
        }
        assert!(grad[0] < 0.0);

        // For a one-hot target the analytic gradient is softmax(logits) - target.
        let probs = reference_softmax_f64(&[2.0_f32, 1.0, 0.5]);
        let one_hot = [1.0_f64, 0.0, 0.0];
        for (i, (&p, &t)) in probs.iter().zip(one_hot.iter()).enumerate() {
            assert_relative_eq!(f64::from(grad[i]), p - t, epsilon = 1e-5);
        }
    }

    #[test]
    #[should_panic(expected = "must have same length")]
    fn test_cross_entropy_mismatched_lengths() {
        let loss_fn = CrossEntropyLoss;
        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
        loss_fn.forward(&pred, &target);
    }

    #[test]
    fn test_cross_entropy_no_grad() {
        let loss_fn = CrossEntropyLoss;
        let pred = Tensor::from_vec(vec![2.0, 1.0], false);
        let target = Tensor::from_vec(vec![1.0, 0.0], false);
        let loss = loss_fn.forward(&pred, &target);
        assert!(loss.data()[0] > 0.0);
    }

    #[test]
    fn test_softmax_numerical_stability() {
        let x = Array1::from(vec![1000.0, 1001.0, 1002.0]);
        let probs = CrossEntropyLoss::softmax(&x);

        let sum: f32 = probs.sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-5);

        for &p in &probs {
            assert!(p.is_finite());
            assert!(p >= 0.0);
        }
    }

    #[test]
    fn test_gradient_accumulation_cross_entropy() {
        let logits = Tensor::from_vec(vec![2.0, 1.0], true);
        let targets = Tensor::from_vec(vec![1.0, 0.0], false);

        let loss1 = CrossEntropyLoss.forward(&logits, &targets);
        if let Some(op) = loss1.backward_op() {
            op.backward();
        }

        let loss2 = CrossEntropyLoss.forward(&logits, &targets);
        if let Some(op) = loss2.backward_op() {
            op.backward();
        }

        let grad = logits.grad().expect("gradient should be available");
        assert!(grad[0].is_finite());
        assert!(grad[1].is_finite());

        // Two backward passes must sum: 2 * (softmax(logits) - target).
        let probs = reference_softmax_f64(&[2.0_f32, 1.0]);
        assert_relative_eq!(f64::from(grad[0]), 2.0 * (probs[0] - 1.0), epsilon = 1e-5);
        assert_relative_eq!(f64::from(grad[1]), 2.0 * probs[1], epsilon = 1e-5);
    }
}
254
// Contract tests for the cross-entropy loss. Each `falsify_ce_*` test tries to
// break one numbered property (CE-001, CE-002, ...) with fixed or
// property-based inputs. A failure message starting with "FALSIFIED" names the
// property that did not hold.
#[cfg(test)]
mod ce_contract_tests {
    use super::*;
    use ndarray::Array1;

    /// Builds a one-hot vector of length `len` with `1.0` at position `idx`.
    ///
    /// Panics if `idx >= len` (plain indexing).
    fn one_hot(idx: usize, len: usize) -> Vec<f32> {
        let mut v = vec![0.0; len];
        v[idx] = 1.0;
        v
    }

    /// CE-001: the loss is non-negative (allowing 1e-6 of rounding slack).
    ///
    /// Uses hand-picked logits and one-hot targets, including equal logits and
    /// confidently wrong predictions (e.g. target 0 with logits `[-10, 10]`).
    #[test]
    fn falsify_ce_001_non_negativity() {
        let ce = CrossEntropyLoss;

        let cases: Vec<(Vec<f32>, Vec<f32>)> = vec![
            (vec![2.0, 1.0, 0.5], one_hot(0, 3)),
            (vec![0.0, 0.0, 0.0], one_hot(1, 3)),
            (vec![-10.0, 10.0], one_hot(0, 2)),
            (vec![100.0, -100.0, 0.0], one_hot(2, 3)),
            (vec![0.1, 0.2, 0.3, 0.4], one_hot(3, 4)),
        ];

        for (i, (logits, targets)) in cases.iter().enumerate() {
            let pred = Tensor::from_vec(logits.clone(), false);
            let tgt = Tensor::from_vec(targets.clone(), false);
            let loss = ce.forward(&pred, &tgt);
            let val = loss.data()[0];
            assert!(val >= -1e-6, "FALSIFIED CE-001 case {i}: CE = {val} < 0");
        }
    }

    /// CE-002: `ln(softmax(x)[j]) <= 0` (within 1e-6) for every entry, i.e. no
    /// softmax output exceeds 1.
    ///
    /// Includes logits around 1000, whose unshifted `exp` would overflow f32.
    /// An entry that underflows to 0 gives `ln = -inf`, which still satisfies
    /// the bound.
    #[test]
    fn falsify_ce_002_log_softmax_upper_bound() {
        let cases: Vec<Vec<f32>> = vec![
            vec![1.0, 2.0, 3.0],
            vec![0.0, 0.0, 0.0],
            vec![-100.0, 100.0],
            vec![1000.0, 1001.0, 999.0],
            vec![-500.0, -500.0, -500.0, -500.0],
        ];

        for (i, logits) in cases.iter().enumerate() {
            let x = Array1::from(logits.clone());
            let probs = CrossEntropyLoss::softmax(&x);
            for (j, &p) in probs.iter().enumerate() {
                let log_p = p.ln();
                assert!(log_p <= 1e-6, "FALSIFIED CE-002 case {i}[{j}]: log_softmax = {log_p} > 0");
            }
        }
    }

    /// CE-003: the loss stays finite for extreme logits.
    ///
    /// Magnitudes range from 88, just below ln(f32::MAX) ≈ 88.7, up to 1000.
    /// Only finiteness is checked here, not accuracy.
    #[test]
    fn falsify_ce_003_numerical_stability() {
        let ce = CrossEntropyLoss;

        let extreme_cases: Vec<(Vec<f32>, Vec<f32>)> = vec![
            (vec![500.0, -500.0, 0.0], one_hot(0, 3)),
            (vec![-1000.0, -1000.0, -1000.0], one_hot(1, 3)),
            (vec![88.0, 88.0], one_hot(0, 2)),
            (vec![-88.0, -88.0, -88.0], one_hot(2, 3)),
        ];

        for (i, (logits, targets)) in extreme_cases.iter().enumerate() {
            let pred = Tensor::from_vec(logits.clone(), false);
            let tgt = Tensor::from_vec(targets.clone(), false);
            let loss = ce.forward(&pred, &tgt);
            let val = loss.data()[0];
            assert!(val.is_finite(), "FALSIFIED CE-003 case {i}: CE = {val} (not finite)");
        }
    }

    /// CE-006: when the target class has a dominant logit (50 vs -50 for the
    /// other classes), the loss is approximately zero (< 1e-3).
    #[test]
    fn falsify_ce_006_perfect_prediction() {
        let ce = CrossEntropyLoss;

        for &target in &[0, 1, 2] {
            let mut logits = vec![-50.0; 3];
            logits[target] = 50.0;
            let pred = Tensor::from_vec(logits, false);
            let tgt = Tensor::from_vec(one_hot(target, 3), false);
            let loss = ce.forward(&pred, &tgt);
            let val = loss.data()[0];
            assert!(
                val < 1e-3,
                "FALSIFIED CE-006: CE(one_hot({target}), dominant) = {val}, expected ≈ 0"
            );
        }
    }

    /// CE-001b: with uniform logits over `C` classes the softmax is uniform, so
    /// the loss must equal `ln(C)` (within 1e-4) for C in {2, 3, 5, 10}.
    #[test]
    fn falsify_ce_001b_uniform_logits() {
        let ce = CrossEntropyLoss;

        for &nc in &[2_usize, 3, 5, 10] {
            let logits = vec![1.0; nc];
            let targets = one_hot(0, nc);
            let pred = Tensor::from_vec(logits, false);
            let tgt = Tensor::from_vec(targets, false);
            let loss = ce.forward(&pred, &tgt);
            let val = loss.data()[0];
            let expected = (nc as f32).ln();
            let diff = (val - expected).abs();
            assert!(
                diff < 1e-4,
                "FALSIFIED CE-001b: CE(uniform, C={nc}) = {val}, expected log({nc}) = {expected}"
            );
        }
    }

    // Property-based versions of the contracts above. Logits are derived
    // deterministically from a sampled `seed` via sin/cos, and the sampled
    // `target` is reduced modulo the class count `nc`, so every case is valid.
    mod ce_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            // 200 generated cases for this property.
            #![proptest_config(ProptestConfig::with_cases(200))]

            // CE-001 (property): loss >= -1e-6 for logits in [-10, 10].
            #[test]
            fn falsify_ce_001_prop_non_negativity(
                nc in 2..=10usize,
                target in 0..10usize,
                seed in 0..1000u32,
            ) {
                let target = target % nc;
                let logits: Vec<f32> = (0..nc)
                    .map(|i| ((i as f32 + seed as f32) * 0.37).sin() * 10.0)
                    .collect();

                let ce = CrossEntropyLoss;
                let pred = Tensor::from_vec(logits, false);
                let tgt = Tensor::from_vec(one_hot(target, nc), false);
                let loss = ce.forward(&pred, &tgt);
                let val = loss.data()[0];
                prop_assert!(
                    val >= -1e-6,
                    "FALSIFIED CE-001-prop: CE = {} < 0 (nc={}, target={})",
                    val, nc, target
                );
            }
        }

        proptest! {
            // 200 generated cases for this property.
            #![proptest_config(ProptestConfig::with_cases(200))]

            // CE-003 (property): loss is finite for logits scaled up to ±100.
            #[test]
            fn falsify_ce_003_prop_finite_output(
                nc in 2..=10usize,
                target in 0..10usize,
                scale in 0.1f32..100.0,
                seed in 0..1000u32,
            ) {
                let target = target % nc;
                let logits: Vec<f32> = (0..nc)
                    .map(|i| ((i as f32 + seed as f32) * 0.73).cos() * scale)
                    .collect();

                let ce = CrossEntropyLoss;
                let pred = Tensor::from_vec(logits, false);
                let tgt = Tensor::from_vec(one_hot(target, nc), false);
                let loss = ce.forward(&pred, &tgt);
                let val = loss.data()[0];
                prop_assert!(
                    val.is_finite(),
                    "FALSIFIED CE-003-prop: CE = {} (not finite) for nc={}, scale={}",
                    val, nc, scale
                );
            }
        }

        proptest! {
            // 200 generated cases for this property.
            #![proptest_config(ProptestConfig::with_cases(200))]

            // CE-002 (property): every softmax output lies in [0, 1 + 1e-6]
            // and its log is <= 1e-6.
            #[test]
            fn falsify_ce_002_prop_log_softmax_bound(
                nc in 2..=10usize,
                scale in 0.1f32..100.0,
                seed in 0..1000u32,
            ) {
                let logits: Vec<f32> = (0..nc)
                    .map(|i| ((i as f32 + seed as f32) * 0.37).sin() * scale)
                    .collect();
                let x = Array1::from(logits);
                let probs = CrossEntropyLoss::softmax(&x);
                for (j, &p) in probs.iter().enumerate() {
                    prop_assert!(
                        (0.0..=1.0 + 1e-6).contains(&p),
                        "FALSIFIED CE-002-prop: softmax[{}] = {} outside [0,1]",
                        j, p
                    );
                    let log_p = p.ln();
                    prop_assert!(
                        log_p <= 1e-6,
                        "FALSIFIED CE-002-prop: log(softmax[{}]) = {} > 0",
                        j, log_p
                    );
                }
            }
        }
    }
}