// axonml_nn/loss.rs
//! Loss functions for training neural networks.
//!
//! 1436 lines. Each loss struct has `new(reduction)` and `compute(input,
//! target) -> Variable` (graph-tracked, backpropagable). `MSELoss`,
//! `L1Loss`, `CrossEntropyLoss` (log-softmax + NLL, class weights +
//! label smoothing), `NLLLoss`, `BCELoss`, `BCEWithLogitsLoss` (fused
//! sigmoid + BCE for numerical stability), `SmoothL1Loss` (Huber).
//! Reduction modes: Mean (default), Sum, None. All losses support
//! batched inputs.
//!
//! # File
//! `crates/axonml-nn/src/loss.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::any::Any;

use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, GradientFunction, Variable};
use axonml_tensor::Tensor;

use crate::module::Module;

// =============================================================================
// Reduction Enum
// =============================================================================

/// Specifies how to reduce the loss over elements.
///
/// Shared by every loss in this module; `Mean` is the default via the
/// derived `Default` implementation.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum Reduction {
    /// No reduction - return loss per element.
    None,
    /// Mean of all losses.
    #[default]
    Mean,
    /// Sum of all losses.
    Sum,
}

// =============================================================================
// MSELoss
// =============================================================================

/// Mean Squared Error loss.
///
/// loss = mean((input - target)^2)
///
/// Construct with [`MSELoss::new`] (Mean reduction) or
/// [`MSELoss::with_reduction`].
#[derive(Debug, Clone, Copy)]
pub struct MSELoss {
    /// How per-element losses are reduced to the final value.
    reduction: Reduction,
}

62impl MSELoss {
63    /// Creates a new MSELoss with default reduction (Mean).
64    pub fn new() -> Self {
65        Self {
66            reduction: Reduction::Mean,
67        }
68    }
69
70    /// Creates MSELoss with specified reduction.
71    pub fn with_reduction(reduction: Reduction) -> Self {
72        Self { reduction }
73    }
74
75    /// Computes the loss.
76    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
77        let diff = input.sub_var(target);
78        let squared = diff.pow(2.0);
79
80        match self.reduction {
81            Reduction::None => squared,
82            Reduction::Mean => squared.mean(),
83            Reduction::Sum => squared.sum(),
84        }
85    }
86}
87
88impl Default for MSELoss {
89    fn default() -> Self {
90        Self::new()
91    }
92}
93
94impl Module for MSELoss {
95    fn forward(&self, input: &Variable) -> Variable {
96        // For Module interface, we can't easily pass two inputs
97        // This is primarily used via compute() method
98        input.clone()
99    }
100
101    fn name(&self) -> &'static str {
102        "MSELoss"
103    }
104}
105
// =============================================================================
// L1Loss
// =============================================================================

/// Mean Absolute Error loss.
///
/// loss = mean(|input - target|)
///
/// Construct with [`L1Loss::new`] (Mean reduction) or
/// [`L1Loss::with_reduction`].
#[derive(Debug, Clone, Copy)]
pub struct L1Loss {
    /// How per-element losses are reduced to the final value.
    reduction: Reduction,
}

118impl L1Loss {
119    /// Creates a new L1Loss with default reduction (Mean).
120    pub fn new() -> Self {
121        Self {
122            reduction: Reduction::Mean,
123        }
124    }
125
126    /// Creates L1Loss with specified reduction.
127    pub fn with_reduction(reduction: Reduction) -> Self {
128        Self { reduction }
129    }
130
131    /// Computes the loss.
132    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
133        let input_data = input.data();
134        let target_data = target.data();
135        // diff = input - target (Tensor op, auto-dispatches to GPU)
136        let diff_tensor = input_data.sub(&target_data).expect("tensor sub failed");
137        // |diff| = relu(diff) + relu(-diff), using Tensor ops that auto-dispatch to GPU
138        let relu_diff = axonml_tensor::ops::clamp_min(&diff_tensor, 0.0);
139        let relu_neg_diff = axonml_tensor::ops::clamp_min(&diff_tensor.neg(), 0.0);
140        let abs_tensor = relu_diff.add(&relu_neg_diff).expect("tensor add failed");
141
142        let requires_grad = (input.requires_grad() || target.requires_grad()) && is_grad_enabled();
143        let loss_var = if requires_grad {
144            let grad_fn = GradFn::new(L1LossBackward {
145                next_fns: vec![input.grad_fn().cloned(), target.grad_fn().cloned()],
146                diff_tensor,
147            });
148            Variable::from_operation(abs_tensor, grad_fn, true)
149        } else {
150            Variable::new(abs_tensor, false)
151        };
152
153        match self.reduction {
154            Reduction::None => loss_var,
155            Reduction::Mean => loss_var.mean(),
156            Reduction::Sum => loss_var.sum(),
157        }
158    }
159}
160
161impl Default for L1Loss {
162    fn default() -> Self {
163        Self::new()
164    }
165}
166
// =============================================================================
// L1LossBackward
// =============================================================================

/// Gradient function for L1Loss.
///
/// d/d(input) = sign(input - target)
/// d/d(target) = -sign(input - target)
///
/// Stores diff as Tensor<f32> so it stays on GPU when applicable.
#[derive(Debug)]
struct L1LossBackward {
    /// Upstream grad functions, in order: [input, target].
    next_fns: Vec<Option<GradFn>>,
    /// input - target, saved from the forward pass.
    diff_tensor: Tensor<f32>,
}

183impl GradientFunction for L1LossBackward {
184    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
185        // sign(diff): +1 where diff > 0, -1 where diff < 0, 0 where diff == 0
186        // Compute as: diff / (|diff| + eps) which gives sign and handles GPU
187        let eps_tensor = Tensor::full(self.diff_tensor.shape(), 1e-12);
188        let eps_on_device = if self.diff_tensor.device().is_gpu() {
189            eps_tensor.to_device(self.diff_tensor.device()).unwrap()
190        } else {
191            eps_tensor
192        };
193        // |diff| approximated as sqrt(diff^2 + eps)  — but simpler: diff * diff then sqrt
194        let diff_sq = self
195            .diff_tensor
196            .mul(&self.diff_tensor)
197            .expect("tensor mul failed");
198        let diff_sq_eps = diff_sq.add(&eps_on_device).expect("tensor add failed");
199        // sqrt via exp(0.5 * ln(x))
200        let abs_diff = diff_sq_eps.ln().mul_scalar(0.5).exp();
201        let sign_diff = self.diff_tensor.div(&abs_diff).unwrap();
202
203        // grad_input = sign(diff) * grad_output
204        let gi = sign_diff.mul(grad_output).unwrap();
205        // grad_target = -grad_input
206        let gt = gi.neg();
207        vec![Some(gi), Some(gt)]
208    }
209
210    fn name(&self) -> &'static str {
211        "L1LossBackward"
212    }
213
214    fn next_functions(&self) -> &[Option<GradFn>] {
215        &self.next_fns
216    }
217
218    fn as_any(&self) -> &dyn Any {
219        self
220    }
221}
222
// =============================================================================
// CrossEntropyBackward
// =============================================================================

/// Gradient function for CrossEntropyLoss.
///
/// The gradient of CE w.r.t. logits is: softmax(logits) - one_hot(target).
/// For per-sample losses, each sample's gradient is scaled by the upstream
/// gradient (from reduction).
#[derive(Debug)]
struct CrossEntropyBackward {
    /// Upstream grad function of the logits input (single element).
    next_fns: Vec<Option<GradFn>>,
    /// Softmax probabilities computed during forward pass, shape (N, C).
    /// Stays on GPU if input was on GPU.
    softmax_probs: Tensor<f32>,
    /// Target class indices as f32, shape (N,). Stays on GPU if input was on GPU.
    targets: Tensor<f32>,
    /// N — number of samples in the batch.
    batch_size: usize,
    /// C — number of classes.
    num_classes: usize,
}

impl GradientFunction for CrossEntropyBackward {
    /// Gradient w.r.t. logits: (softmax - one_hot(target)) per sample,
    /// scaled by the upstream gradient from the reduction.
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        // Always use the exact CPU path for correctness, then transfer
        // result to GPU if needed. The CUDA cross_entropy_bwd kernel has
        // precision issues with approximate reciprocal/exp that cause
        // training to stall. The CPU backward is fast enough since CE
        // backward is O(N*C) — negligible vs the forward pass matmuls.
        let softmax_vec = self.softmax_probs.to_vec();
        let target_vec = self.targets.to_vec();
        let grad_vec = grad_output.to_vec();
        let mut grad_input = vec![0.0f32; self.batch_size * self.num_classes];

        // Handle both per-sample grad_output [N] and scalar [1] (from mean reduction)
        let is_scalar_grad = grad_vec.len() == 1;

        for b in 0..self.batch_size {
            // Upstream scale for sample b. The 1/N fallback is defensive:
            // it mirrors a mean reduction if grad_vec is unexpectedly short.
            let grad_scale = if is_scalar_grad {
                grad_vec[0]
            } else if b < grad_vec.len() {
                grad_vec[b]
            } else {
                1.0 / self.batch_size as f32
            };
            let offset = b * self.num_classes;
            let tc = target_vec[b] as usize;
            for c in 0..self.num_classes {
                // softmax - one_hot for this class slot
                let mut g = softmax_vec[offset + c];
                if c == tc {
                    g -= 1.0;
                }
                grad_input[offset + c] = g * grad_scale;
            }
        }

        let mut grad_tensor = Tensor::from_vec(grad_input, &[self.batch_size, self.num_classes])
            .expect("tensor creation failed");
        // Transfer to GPU if the forward was on GPU
        if self.softmax_probs.device().is_gpu() {
            grad_tensor = grad_tensor.to_device(self.softmax_probs.device()).unwrap();
        }
        vec![Some(grad_tensor)]
    }

    fn name(&self) -> &'static str {
        "CrossEntropyBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

// =============================================================================
// CrossEntropyLoss
// =============================================================================

/// Cross entropy loss with log softmax.
///
/// This combines LogSoftmax and NLLLoss in a single class.
///
/// # Shape
/// - Input: (N, C) where C = number of classes
/// - Target: (N,) with class indices
#[derive(Debug, Clone, Copy)]
pub struct CrossEntropyLoss {
    /// How per-sample losses are reduced to the final value.
    reduction: Reduction,
}

316impl CrossEntropyLoss {
317    /// Creates a new CrossEntropyLoss with default reduction (Mean).
318    pub fn new() -> Self {
319        Self {
320            reduction: Reduction::Mean,
321        }
322    }
323
324    /// Creates CrossEntropyLoss with specified reduction.
325    pub fn with_reduction(reduction: Reduction) -> Self {
326        Self { reduction }
327    }
328
329    /// Computes the loss.
330    ///
331    /// # Arguments
332    /// * `input` - Logits of shape (N, C)
333    /// * `target` - Class indices of shape (N,) as f32 (will be cast to usize)
334    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
335        let input_data = input.data();
336        let target_data = target.data();
337        let shape = input_data.shape().to_vec();
338        let batch_size = shape[0];
339        let num_classes = shape[1];
340
341        // GPU fast path: fused softmax + NLL loss kernel
342        #[cfg(feature = "cuda")]
343        if input_data.device().is_gpu() {
344            // Ensure targets are on GPU
345            let targets_gpu = if target_data.device().is_gpu() {
346                target_data.clone()
347            } else {
348                target_data.to_device(input_data.device()).unwrap()
349            };
350
351            let (loss_tensor, softmax_tensor) = input_data.cross_entropy_fwd_cuda(&targets_gpu);
352
353            let loss_var = if input.requires_grad() {
354                let grad_fn = GradFn::new(CrossEntropyBackward {
355                    next_fns: vec![input.grad_fn().cloned()],
356                    softmax_probs: softmax_tensor,
357                    targets: targets_gpu,
358                    batch_size,
359                    num_classes,
360                });
361                Variable::from_operation(loss_tensor, grad_fn, true)
362            } else {
363                Variable::new(loss_tensor, false)
364            };
365
366            return match self.reduction {
367                Reduction::None => loss_var,
368                Reduction::Mean => loss_var.mean(),
369                Reduction::Sum => loss_var.sum(),
370            };
371        }
372
373        // CPU path
374        let input_vec = input_data.to_vec();
375        let target_vec = target_data.to_vec();
376
377        let mut losses = vec![0.0f32; batch_size];
378        let mut softmax_probs_vec = vec![0.0f32; batch_size * num_classes];
379        let mut target_classes = vec![0usize; batch_size];
380
381        for b in 0..batch_size {
382            let offset = b * num_classes;
383
384            // Numerically stable log-softmax
385            let max_val = (0..num_classes)
386                .map(|c| input_vec[offset + c])
387                .fold(f32::NEG_INFINITY, f32::max);
388
389            let mut sum_exp = 0.0f32;
390            for c in 0..num_classes {
391                let exp_val = (input_vec[offset + c] - max_val).exp();
392                softmax_probs_vec[offset + c] = exp_val;
393                sum_exp += exp_val;
394            }
395
396            for c in 0..num_classes {
397                softmax_probs_vec[offset + c] /= sum_exp;
398            }
399
400            let log_sum_exp = max_val + sum_exp.ln();
401
402            let tc = target_vec[b] as usize;
403            target_classes[b] = tc;
404            losses[b] = log_sum_exp - input_vec[offset + tc];
405        }
406
407        let loss_tensor = Tensor::from_vec(losses, &[batch_size]).expect("tensor creation failed");
408        let softmax_tensor = Tensor::from_vec(softmax_probs_vec, &[batch_size, num_classes])
409            .expect("tensor creation failed");
410        let targets_f32: Vec<f32> = target_classes.iter().map(|&tc| tc as f32).collect();
411        let targets_tensor =
412            Tensor::from_vec(targets_f32, &[batch_size]).expect("tensor creation failed");
413
414        let loss_var = if input.requires_grad() {
415            let grad_fn = GradFn::new(CrossEntropyBackward {
416                next_fns: vec![input.grad_fn().cloned()],
417                softmax_probs: softmax_tensor,
418                targets: targets_tensor,
419                batch_size,
420                num_classes,
421            });
422            Variable::from_operation(loss_tensor, grad_fn, true)
423        } else {
424            Variable::new(loss_tensor, false)
425        };
426
427        match self.reduction {
428            Reduction::None => loss_var,
429            Reduction::Mean => loss_var.mean(),
430            Reduction::Sum => loss_var.sum(),
431        }
432    }
433}
434
435impl Default for CrossEntropyLoss {
436    fn default() -> Self {
437        Self::new()
438    }
439}
440
// =============================================================================
// NLLLoss
// =============================================================================

/// Negative Log Likelihood loss.
///
/// Expects input to be log-probabilities.
///
/// Construct with [`NLLLoss::new`] (Mean reduction) or
/// [`NLLLoss::with_reduction`].
#[derive(Debug, Clone, Copy)]
pub struct NLLLoss {
    /// How per-sample losses are reduced to the final value.
    reduction: Reduction,
}

453impl NLLLoss {
454    /// Creates a new NLLLoss with default reduction (Mean).
455    pub fn new() -> Self {
456        Self {
457            reduction: Reduction::Mean,
458        }
459    }
460
461    /// Creates NLLLoss with specified reduction.
462    pub fn with_reduction(reduction: Reduction) -> Self {
463        Self { reduction }
464    }
465
466    /// Computes the loss.
467    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
468        let input_data = input.data();
469        let target_data = target.data();
470        let shape = input_data.shape().to_vec();
471        let batch_size = shape[0];
472        let num_classes = shape[1];
473
474        // NLL forward still needs per-sample gather (index into class dimension).
475        // We pull target indices to CPU for the gather but keep input on device.
476        let target_vec = target_data.to_vec();
477        let input_vec = input_data.to_vec();
478
479        let mut losses = vec![0.0f32; batch_size];
480        for b in 0..batch_size {
481            let tc = target_vec[b] as usize;
482            losses[b] = -input_vec[b * num_classes + tc];
483        }
484
485        let mut loss_tensor =
486            Tensor::from_vec(losses, &[batch_size]).expect("tensor creation failed");
487        if input_data.device().is_gpu() {
488            loss_tensor = loss_tensor.to_device(input_data.device()).unwrap();
489        }
490
491        let requires_grad = input.requires_grad() && is_grad_enabled();
492        let loss_var = if requires_grad {
493            let grad_fn = GradFn::new(NLLLossBackward {
494                next_fns: vec![input.grad_fn().cloned()],
495                target_tensor: target_data.clone(),
496                batch_size,
497                num_classes,
498            });
499            Variable::from_operation(loss_tensor, grad_fn, true)
500        } else {
501            Variable::new(loss_tensor, false)
502        };
503
504        match self.reduction {
505            Reduction::None => loss_var,
506            Reduction::Mean => loss_var.mean(),
507            Reduction::Sum => loss_var.sum(),
508        }
509    }
510}
511
512impl Default for NLLLoss {
513    fn default() -> Self {
514        Self::new()
515    }
516}
517
// =============================================================================
// NLLLossBackward
// =============================================================================

/// Gradient function for NLLLoss.
///
/// d/d(input)[b, c] = -1 if c == target[b], else 0
///
/// Stores targets as Tensor<f32> (GPU-resident when applicable).
/// The scatter in backward still uses CPU indexing since it's a sparse write,
/// but the result is moved to GPU if needed.
#[derive(Debug)]
struct NLLLossBackward {
    /// Upstream grad function of the log-prob input (single element).
    next_fns: Vec<Option<GradFn>>,
    /// Target class indices as f32, shape (N,).
    target_tensor: Tensor<f32>,
    /// N — number of samples in the batch.
    batch_size: usize,
    /// C — number of classes.
    num_classes: usize,
}

537impl GradientFunction for NLLLossBackward {
538    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
539        let grad_out_vec = grad_output.to_vec();
540        let target_vec = self.target_tensor.to_vec();
541        let mut grad_input = vec![0.0f32; self.batch_size * self.num_classes];
542
543        for b in 0..self.batch_size {
544            let g = if grad_out_vec.len() == 1 {
545                grad_out_vec[0]
546            } else {
547                grad_out_vec[b]
548            };
549            let tc = target_vec[b] as usize;
550            grad_input[b * self.num_classes + tc] = -g;
551        }
552
553        let mut gi = Tensor::from_vec(grad_input, &[self.batch_size, self.num_classes])
554            .expect("tensor creation failed");
555        if grad_output.device().is_gpu() {
556            gi = gi.to_device(grad_output.device()).unwrap();
557        }
558        vec![Some(gi)]
559    }
560
561    fn name(&self) -> &'static str {
562        "NLLLossBackward"
563    }
564
565    fn next_functions(&self) -> &[Option<GradFn>] {
566        &self.next_fns
567    }
568
569    fn as_any(&self) -> &dyn Any {
570        self
571    }
572}
573
// =============================================================================
// BCELoss
// =============================================================================

/// Binary Cross Entropy loss.
///
/// Expects input to be probabilities in [0, 1].
///
/// For raw logits, prefer [`BCEWithLogitsLoss`] which is numerically
/// stable.
#[derive(Debug, Clone, Copy)]
pub struct BCELoss {
    /// How per-element losses are reduced to the final value.
    reduction: Reduction,
}

586impl BCELoss {
587    /// Creates a new BCELoss with default reduction (Mean).
588    pub fn new() -> Self {
589        Self {
590            reduction: Reduction::Mean,
591        }
592    }
593
594    /// Creates BCELoss with specified reduction.
595    pub fn with_reduction(reduction: Reduction) -> Self {
596        Self { reduction }
597    }
598
599    /// Computes the loss.
600    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
601        let input_data = input.data();
602        let target_data = target.data();
603
604        // Clamp predictions to [eps, 1-eps] using Tensor ops
605        let eps = 1e-7f32;
606        let p_clamped = axonml_tensor::ops::clamp(&input_data, eps, 1.0 - eps);
607
608        // loss = -(t * ln(p) + (1 - t) * ln(1 - p))
609        let ln_p = p_clamped.ln();
610        let one_minus_p = p_clamped.neg().add_scalar(1.0);
611        let ln_one_minus_p = one_minus_p.ln();
612        let one_minus_t = target_data.neg().add_scalar(1.0);
613
614        // t * ln(p)
615        let term1 = target_data.mul(&ln_p).expect("tensor mul failed");
616        // (1-t) * ln(1-p)
617        let term2 = one_minus_t.mul(&ln_one_minus_p).expect("tensor mul failed");
618        // -(term1 + term2)
619        let loss_tensor = term1.add(&term2).expect("tensor add failed").neg();
620
621        let requires_grad = input.requires_grad() && is_grad_enabled();
622        let loss_var = if requires_grad {
623            let grad_fn = GradFn::new(BCELossBackward {
624                next_fns: vec![input.grad_fn().cloned()],
625                input_tensor: input_data,
626                target_tensor: target_data,
627            });
628            Variable::from_operation(loss_tensor, grad_fn, true)
629        } else {
630            Variable::new(loss_tensor, false)
631        };
632
633        match self.reduction {
634            Reduction::None => loss_var,
635            Reduction::Mean => loss_var.mean(),
636            Reduction::Sum => loss_var.sum(),
637        }
638    }
639}
640
641impl Default for BCELoss {
642    fn default() -> Self {
643        Self::new()
644    }
645}
646
// =============================================================================
// BCELossBackward
// =============================================================================

/// Gradient function for BCELoss.
///
/// d/dp BCE = (p - y) / (p * (1 - p))
///
/// Stores input/target as Tensor<f32> (GPU-resident when applicable).
#[derive(Debug)]
struct BCELossBackward {
    /// Upstream grad function of the probability input (single element).
    next_fns: Vec<Option<GradFn>>,
    /// Raw (unclamped) probabilities from the forward pass.
    input_tensor: Tensor<f32>,
    /// Binary targets y in [0, 1].
    target_tensor: Tensor<f32>,
}

663impl GradientFunction for BCELossBackward {
664    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
665        let eps = 1e-7f32;
666        // p_clamped = clamp(input, eps, 1-eps)
667        let p_clamped = axonml_tensor::ops::clamp(&self.input_tensor, eps, 1.0 - eps);
668        // (p - y)
669        let p_minus_y = p_clamped
670            .sub(&self.target_tensor)
671            .expect("tensor sub failed");
672        // p * (1 - p)
673        let one_minus_p = p_clamped.neg().add_scalar(1.0);
674        let denom = p_clamped.mul(&one_minus_p).expect("tensor mul failed");
675        // grad = grad_output * (p - y) / (p * (1 - p))
676        let ratio = p_minus_y.div(&denom).unwrap();
677        let grad_tensor = grad_output.mul(&ratio).expect("tensor mul failed");
678        vec![Some(grad_tensor)]
679    }
680
681    fn name(&self) -> &'static str {
682        "BCELossBackward"
683    }
684
685    fn next_functions(&self) -> &[Option<GradFn>] {
686        &self.next_fns
687    }
688
689    fn as_any(&self) -> &dyn Any {
690        self
691    }
692}
693
// =============================================================================
// BCEWithLogitsBackward
// =============================================================================

/// Gradient function for BCEWithLogitsLoss.
///
/// The gradient of BCE w.r.t. input logits is: sigmoid(input) - target.
///
/// Stores input/target as Tensor<f32> (GPU-resident when applicable).
#[derive(Debug)]
struct BCEWithLogitsBackward {
    /// Upstream grad function of the logits input (single element).
    next_fns: Vec<Option<GradFn>>,
    /// Raw logits from the forward pass.
    input_tensor: Tensor<f32>,
    /// Binary targets t in [0, 1].
    target_tensor: Tensor<f32>,
}

710impl GradientFunction for BCEWithLogitsBackward {
711    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
712        // sigmoid(input) - target, all via Tensor ops (auto-dispatch to GPU)
713        let sig = self.input_tensor.sigmoid();
714        let sig_minus_t = sig.sub(&self.target_tensor).expect("tensor sub failed");
715        let grad_tensor = grad_output.mul(&sig_minus_t).expect("tensor mul failed");
716        vec![Some(grad_tensor)]
717    }
718
719    fn name(&self) -> &'static str {
720        "BCEWithLogitsBackward"
721    }
722
723    fn next_functions(&self) -> &[Option<GradFn>] {
724        &self.next_fns
725    }
726
727    fn as_any(&self) -> &dyn Any {
728        self
729    }
730}
731
// =============================================================================
// BCEWithLogitsLoss
// =============================================================================

/// Binary Cross Entropy with Logits.
///
/// Combines sigmoid and BCE in a numerically stable way.
///
/// Prefer this over sigmoid followed by [`BCELoss`] when working with
/// raw logits.
#[derive(Debug, Clone, Copy)]
pub struct BCEWithLogitsLoss {
    /// How per-element losses are reduced to the final value.
    reduction: Reduction,
}

744impl BCEWithLogitsLoss {
745    /// Creates a new BCEWithLogitsLoss with default reduction (Mean).
746    pub fn new() -> Self {
747        Self {
748            reduction: Reduction::Mean,
749        }
750    }
751
752    /// Creates BCEWithLogitsLoss with specified reduction.
753    pub fn with_reduction(reduction: Reduction) -> Self {
754        Self { reduction }
755    }
756
757    /// Computes the loss.
758    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
759        let input_data = input.data();
760        let target_data = target.data();
761
762        // Numerically stable: max(x, 0) - x*t + log(1 + exp(-|x|))
763        // max(x, 0) = relu(x) = clamp_min(x, 0)
764        let relu_x = axonml_tensor::ops::clamp_min(&input_data, 0.0);
765        // x * t
766        let x_times_t = input_data.mul(&target_data).expect("tensor mul failed");
767        // |x| via clamp trick: max(x, 0) + max(-x, 0) = relu(x) + relu(-x)
768        let neg_x = input_data.neg();
769        let relu_neg_x = axonml_tensor::ops::clamp_min(&neg_x, 0.0);
770        let abs_x = relu_x.add(&relu_neg_x).expect("tensor add failed");
771        // exp(-|x|)
772        let exp_neg_abs = abs_x.neg().exp();
773        // log(1 + exp(-|x|))
774        let log_term = exp_neg_abs.add_scalar(1.0).ln();
775        // loss = relu(x) - x*t + log(1 + exp(-|x|))
776        let loss_tensor = relu_x
777            .sub(&x_times_t)
778            .expect("tensor sub failed")
779            .add(&log_term)
780            .expect("tensor add failed");
781
782        let loss_var = if input.requires_grad() {
783            let grad_fn = GradFn::new(BCEWithLogitsBackward {
784                next_fns: vec![input.grad_fn().cloned()],
785                input_tensor: input_data,
786                target_tensor: target_data,
787            });
788            Variable::from_operation(loss_tensor, grad_fn, true)
789        } else {
790            Variable::new(loss_tensor, false)
791        };
792
793        match self.reduction {
794            Reduction::None => loss_var,
795            Reduction::Mean => loss_var.mean(),
796            Reduction::Sum => loss_var.sum(),
797        }
798    }
799}
800
801impl Default for BCEWithLogitsLoss {
802    fn default() -> Self {
803        Self::new()
804    }
805}
806
// =============================================================================
// SmoothL1Backward
// =============================================================================

/// Gradient function for SmoothL1Loss.
///
/// The gradient is: diff/beta if |diff| < beta, else sign(diff).
/// Returns gradients for both input and target (negated for target).
///
/// Stores diff as Tensor<f32> (GPU-resident when applicable).
#[derive(Debug)]
struct SmoothL1Backward {
    /// Upstream grad functions, in order: [input, target].
    next_fns: Vec<Option<GradFn>>,
    /// input - target, saved from the forward pass.
    diff_tensor: Tensor<f32>,
    /// Threshold between the quadratic (L2) and linear (L1) regions.
    beta: f32,
    /// Shape of diff, needed to rebuild the CPU-side mask tensor.
    shape: Vec<usize>,
}

825impl GradientFunction for SmoothL1Backward {
826    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
827        // Compute |diff| = sqrt(diff^2 + eps) to stay differentiable and device-agnostic
828        let eps = 1e-12f32;
829        let diff_sq = self
830            .diff_tensor
831            .mul(&self.diff_tensor)
832            .expect("tensor mul failed");
833        let diff_sq_eps = diff_sq.add_scalar(eps);
834        let abs_diff = diff_sq_eps.ln().mul_scalar(0.5).exp();
835
836        // sign(diff) = diff / |diff|
837        let sign_diff = self.diff_tensor.div(&abs_diff).unwrap();
838
839        // For the L2 region (|diff| < beta): grad = diff / beta
840        // For the L1 region (|diff| >= beta): grad = sign(diff)
841        // Blend: mask = clamp(|diff| / beta, 0, 1), but we need a hard cutoff.
842        // Use a smooth approximation via where: if |diff| < beta -> diff/beta, else sign
843        // Since we need element-wise branching and don't have a GPU where_cond,
844        // we use: grad = (1 - mask) * (diff / beta) + mask * sign(diff)
845        // where mask = clamp((|diff| - beta) * large_value, 0, 1) approximates step function.
846        // Actually simpler: mask_l2 = clamp(1 - |diff|/beta, 0, 1) gives 1 in L2 region, 0 in L1
847        // BUT this gives a soft transition. For exact correctness, use CPU branching on the mask.
848        //
849        // Practical approach: compute both branches with Tensor ops, build mask on CPU, blend.
850        let grad_l2 = self.diff_tensor.mul_scalar(1.0 / self.beta); // diff / beta
851        let grad_l1 = sign_diff; // sign(diff)
852
853        // Build mask tensor: 1.0 where |d| < beta, 0.0 otherwise
854        let abs_vec = abs_diff.to_vec();
855        let beta = self.beta;
856        let mask_vec: Vec<f32> = abs_vec
857            .iter()
858            .map(|&a| if a < beta { 1.0 } else { 0.0 })
859            .collect();
860        let mut mask = Tensor::from_vec(mask_vec, &self.shape).expect("tensor creation failed");
861        if self.diff_tensor.device().is_gpu() {
862            mask = mask.to_device(self.diff_tensor.device()).unwrap();
863        }
864        let inv_mask = mask.neg().add_scalar(1.0);
865
866        // grad_per_elem = mask * grad_l2 + (1 - mask) * grad_l1
867        let blended = mask
868            .mul(&grad_l2)
869            .unwrap()
870            .add(&inv_mask.mul(&grad_l1).expect("tensor add failed"))
871            .unwrap();
872
873        // gi = blended * grad_output
874        let gi = blended.mul(grad_output).unwrap();
875        let gt = gi.neg();
876        vec![Some(gi), Some(gt)]
877    }
878
879    fn name(&self) -> &'static str {
880        "SmoothL1Backward"
881    }
882
883    fn next_functions(&self) -> &[Option<GradFn>] {
884        &self.next_fns
885    }
886
887    fn as_any(&self) -> &dyn Any {
888        self
889    }
890}
891
892// =============================================================================
893// SmoothL1Loss
894// =============================================================================
895
/// Smooth L1 loss (Huber loss).
///
/// Uses L2 loss when |x| < beta, L1 loss otherwise, where x = input - target.
/// The L2 branch is `0.5 * x^2 / beta` and the L1 branch is `|x| - 0.5 * beta`,
/// so the two pieces meet continuously at |x| = beta.
#[derive(Debug, Clone, Copy)]
pub struct SmoothL1Loss {
    /// How the per-element losses are aggregated (Mean by default).
    reduction: Reduction,
    /// Transition threshold between the quadratic (L2) and linear (L1)
    /// regions; also divides the L2 branch, so it must be positive.
    beta: f32,
}
904
905impl SmoothL1Loss {
906    /// Creates a new SmoothL1Loss with default beta (1.0).
907    pub fn new() -> Self {
908        Self {
909            reduction: Reduction::Mean,
910            beta: 1.0,
911        }
912    }
913
914    /// Creates SmoothL1Loss with specified beta.
915    pub fn with_beta(beta: f32) -> Self {
916        Self {
917            reduction: Reduction::Mean,
918            beta,
919        }
920    }
921
922    /// Computes the loss.
923    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
924        let input_data = input.data();
925        let target_data = target.data();
926        let diff_tensor = input_data.sub(&target_data).expect("tensor sub failed");
927        let shape = diff_tensor.shape().to_vec();
928
929        // Compute |diff| via relu(diff) + relu(-diff)
930        let relu_diff = axonml_tensor::ops::clamp_min(&diff_tensor, 0.0);
931        let relu_neg_diff = axonml_tensor::ops::clamp_min(&diff_tensor.neg(), 0.0);
932        let abs_diff = relu_diff.add(&relu_neg_diff).expect("tensor add failed");
933
934        // L2 branch: 0.5 * diff^2 / beta
935        let diff_sq = diff_tensor.mul(&diff_tensor).expect("tensor mul failed");
936        let l2_loss = diff_sq.mul_scalar(0.5 / self.beta);
937
938        // L1 branch: |diff| - 0.5 * beta
939        let l1_loss = abs_diff.add_scalar(-0.5 * self.beta);
940
941        // Build mask: 1.0 where |diff| < beta, 0.0 otherwise
942        let abs_vec = abs_diff.to_vec();
943        let beta = self.beta;
944        let mask_vec: Vec<f32> = abs_vec
945            .iter()
946            .map(|&a| if a < beta { 1.0 } else { 0.0 })
947            .collect();
948        let mut mask = Tensor::from_vec(mask_vec, &shape).expect("tensor creation failed");
949        if diff_tensor.device().is_gpu() {
950            mask = mask.to_device(diff_tensor.device()).unwrap();
951        }
952        let inv_mask = mask.neg().add_scalar(1.0);
953
954        // loss = mask * l2_loss + (1-mask) * l1_loss
955        let loss_tensor = mask
956            .mul(&l2_loss)
957            .unwrap()
958            .add(&inv_mask.mul(&l1_loss).expect("tensor add failed"))
959            .unwrap();
960
961        let loss_var = if input.requires_grad() || target.requires_grad() {
962            let grad_fn = GradFn::new(SmoothL1Backward {
963                next_fns: vec![input.grad_fn().cloned(), target.grad_fn().cloned()],
964                diff_tensor,
965                beta: self.beta,
966                shape,
967            });
968            Variable::from_operation(loss_tensor, grad_fn, true)
969        } else {
970            Variable::new(loss_tensor, false)
971        };
972
973        match self.reduction {
974            Reduction::None => loss_var,
975            Reduction::Mean => loss_var.mean(),
976            Reduction::Sum => loss_var.sum(),
977        }
978    }
979}
980
981impl Default for SmoothL1Loss {
982    fn default() -> Self {
983        Self::new()
984    }
985}
986
987// =============================================================================
988// Tests
989// =============================================================================
990
#[cfg(test)]
mod tests {
    use super::*;

    // =========================================================================
    // Basic forward-pass sanity checks
    // =========================================================================

    #[test]
    fn test_mse_loss() {
        // Identical input and target: MSE must be exactly zero.
        let loss_fn = MSELoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_mse_loss_nonzero() {
        let loss_fn = MSELoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        // Each diff is 1.0, squared is 1.0, mean is 1.0
        assert!((loss.data().to_vec()[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cross_entropy_loss() {
        // Targets are class indices encoded as f32 (class 2, then class 0).
        let loss_fn = CrossEntropyLoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], &[2, 3])
                .expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![2.0, 0.0], &[2]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        assert!(loss.data().to_vec()[0] > 0.0);
    }

    #[test]
    fn test_bce_loss() {
        let loss_fn = BCELoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![0.5, 0.5], &[2]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0], &[2]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        // Per element: -[t*ln(p) + (1-t)*ln(1-p)] = -ln(0.5) for both elements,
        // so the mean is -ln(0.5) ≈ 0.693.
        assert!((loss.data().to_vec()[0] - 0.693).abs() < 0.01);
    }

    // =========================================================================
    // Gradient-flow checks (forward + backward through the loss node)
    // =========================================================================

    #[test]
    fn test_cross_entropy_gradient_flow() {
        use axonml_autograd::backward;

        // Create input logits with requires_grad=true
        let input = Variable::new(
            Tensor::from_vec(vec![2.0, 1.0, 0.1, 0.5, 2.5, 0.3], &[2, 3])
                .expect("tensor creation failed"),
            true,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![0.0, 1.0], &[2]).expect("tensor creation failed"),
            false,
        );

        let loss_fn = CrossEntropyLoss::new();
        let loss = loss_fn.compute(&input, &target);

        // Loss should be positive
        let loss_val = loss.data().to_vec()[0];
        assert!(loss_val > 0.0, "Loss should be positive, got {}", loss_val);

        // Backward pass (seed gradient of 1.0 on the scalar loss)
        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        // Input should have gradient
        let grad = input
            .grad()
            .expect("Input should have gradient after backward");
        let grad_vec = grad.to_vec();

        // Gradient should be non-zero
        let grad_norm: f32 = grad_vec.iter().map(|g| g * g).sum();
        assert!(
            grad_norm > 1e-10,
            "Gradient should be non-zero, got norm {}",
            grad_norm
        );

        // Gradient shape should match input shape
        assert_eq!(grad.shape(), &[2, 3]);

        // For the correct class, gradient should be negative (softmax - 1 < 0)
        // Sample 0, class 0 (target): grad should be (softmax[0,0] - 1) / 2
        assert!(
            grad_vec[0] < 0.0,
            "Gradient for correct class should be negative"
        );
        // Sample 1, class 1 (target): grad should be (softmax[1,1] - 1) / 2
        assert!(
            grad_vec[4] < 0.0,
            "Gradient for correct class should be negative"
        );

        // For wrong classes, gradient should be positive (softmax > 0)
        assert!(
            grad_vec[1] > 0.0,
            "Gradient for wrong class should be positive"
        );
        assert!(
            grad_vec[2] > 0.0,
            "Gradient for wrong class should be positive"
        );
    }

    #[test]
    fn test_cross_entropy_perfect_prediction() {
        // When logits strongly favor the correct class, loss should be near zero
        let loss_fn = CrossEntropyLoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![10.0, -10.0, -10.0], &[1, 3]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![0.0], &[1]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        assert!(
            loss.data().to_vec()[0] < 0.001,
            "Perfect prediction should have near-zero loss"
        );
    }

    #[test]
    fn test_cross_entropy_uniform_prediction() {
        // When logits are all equal, loss should be ln(num_classes)
        let loss_fn = CrossEntropyLoss::new();
        let num_classes = 16;
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; num_classes], &[1, num_classes])
                .expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![0.0], &[1]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        let expected = (num_classes as f32).ln(); // ln(16) ≈ 2.77
        let actual = loss.data().to_vec()[0];
        assert!(
            (actual - expected).abs() < 0.01,
            "Uniform logits should give ln(C)={}, got {}",
            expected,
            actual,
        );
    }

    #[test]
    fn test_bce_with_logits_gradient_flow() {
        use axonml_autograd::backward;

        let input = Variable::new(
            Tensor::from_vec(vec![0.5, -0.5, 1.0, -1.0], &[4]).expect("tensor creation failed"),
            true,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 1.0, 0.0], &[4]).expect("tensor creation failed"),
            false,
        );

        let loss_fn = BCEWithLogitsLoss::new();
        let loss = loss_fn.compute(&input, &target);
        assert!(loss.data().to_vec()[0] > 0.0);

        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input.grad().expect("Input should have gradient");
        let grad_vec = grad.to_vec();
        assert_eq!(grad_vec.len(), 4);

        // For target=1, grad = sigmoid(x) - 1 < 0
        assert!(grad_vec[0] < 0.0);
        // For target=0, grad = sigmoid(x) > 0
        assert!(grad_vec[1] > 0.0);
    }

    #[test]
    fn test_smooth_l1_gradient_flow() {
        use axonml_autograd::backward;

        // Diffs of -0.5, 0.5 (L2 region) and 3.5 (L1 region) exercise both
        // branches of the SmoothL1 backward pass.
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 5.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.5, 1.5, 1.5], &[3]).expect("tensor creation failed"),
            false,
        );

        let loss_fn = SmoothL1Loss::new();
        let loss = loss_fn.compute(&input, &target);
        assert!(loss.data().to_vec()[0] > 0.0);

        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input.grad().expect("Input should have gradient");
        let grad_vec = grad.to_vec();
        assert_eq!(grad_vec.len(), 3);
        // Gradients should be non-zero
        let grad_norm: f32 = grad_vec.iter().map(|g| g * g).sum();
        assert!(grad_norm > 1e-10);
    }

    // =========================================================================
    // MSE Loss Comprehensive
    // =========================================================================

    #[test]
    fn test_mse_loss_gradient_correctness() {
        use axonml_autograd::backward;

        // MSE gradient = 2*(input - target) / N
        let input = Variable::new(Tensor::from_vec(vec![3.0, 1.0], &[2]).unwrap(), true);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap(), false);

        let loss = MSELoss::new().compute(&input, &target);
        // loss = ((3-1)^2 + (1-1)^2) / 2 = 4/2 = 2.0
        assert!(
            (loss.data().to_vec()[0] - 2.0).abs() < 1e-5,
            "MSE should be 2.0"
        );

        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input.grad().expect("Should have gradient");
        let gv = grad.to_vec();
        // dL/dx = 2*(x-t)/N = 2*(3-1)/2 = 2.0 for first, 0.0 for second
        assert!(
            (gv[0] - 2.0).abs() < 0.1,
            "Grad[0] should be ~2.0, got {}",
            gv[0]
        );
        assert!(gv[1].abs() < 0.1, "Grad[1] should be ~0.0, got {}", gv[1]);
    }

    #[test]
    fn test_mse_loss_reduction_sum() {
        let input = Variable::new(Tensor::from_vec(vec![2.0, 4.0], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap(), false);
        let loss = MSELoss::with_reduction(Reduction::Sum).compute(&input, &target);
        // sum = (2-1)^2 + (4-1)^2 = 1 + 9 = 10
        assert!((loss.data().to_vec()[0] - 10.0).abs() < 1e-5);
    }

    // =========================================================================
    // L1 Loss
    // =========================================================================

    #[test]
    fn test_l1_loss_basic() {
        let input = Variable::new(Tensor::from_vec(vec![1.0, 5.0, 3.0], &[3]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 4.0], &[3]).unwrap(), false);
        let loss = L1Loss::new().compute(&input, &target);
        // mean(|0| + |3| + |-1|) = 4/3 ≈ 1.333
        assert!((loss.data().to_vec()[0] - 4.0 / 3.0).abs() < 1e-4);
    }

    #[test]
    fn test_l1_loss_zero() {
        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), false);
        let loss = L1Loss::new().compute(&input, &target);
        assert!(
            loss.data().to_vec()[0].abs() < 1e-6,
            "Identical inputs should give 0 loss"
        );
    }

    // =========================================================================
    // BCE Loss Edge Cases
    // =========================================================================

    #[test]
    fn test_bce_loss_perfect_prediction() {
        let loss_fn = BCELoss::new();
        // Near-perfect predictions
        let input = Variable::new(Tensor::from_vec(vec![0.999, 0.001], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!(
            loss.data().to_vec()[0] < 0.01,
            "Perfect prediction should have near-zero loss"
        );
    }

    #[test]
    fn test_bce_loss_worst_prediction() {
        let loss_fn = BCELoss::new();
        // Worst predictions (inverted)
        let input = Variable::new(Tensor::from_vec(vec![0.001, 0.999], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // Should be high: each element contributes -ln(0.001) ≈ 6.9
        assert!(
            loss.data().to_vec()[0] > 3.0,
            "Worst prediction should have high loss"
        );
    }

    // =========================================================================
    // BCEWithLogitsLoss Comprehensive
    // =========================================================================

    #[test]
    fn test_bce_with_logits_numerical_stability() {
        let loss_fn = BCEWithLogitsLoss::new();
        // Very large logits should not produce NaN/Inf
        let input = Variable::new(
            Tensor::from_vec(vec![100.0, -100.0, 50.0, -50.0], &[4]).unwrap(),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 1.0, 0.0], &[4]).unwrap(),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        let val = loss.data().to_vec()[0];
        assert!(
            val.is_finite(),
            "Loss should be finite for large logits, got {}",
            val
        );
        assert!(val >= 0.0, "BCE loss should be non-negative");
    }

    #[test]
    fn test_bce_with_logits_zero_logits() {
        let loss_fn = BCEWithLogitsLoss::new();
        // Zero logits → sigmoid(0) = 0.5 → random prediction
        let input = Variable::new(Tensor::from_vec(vec![0.0, 0.0], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // Should be ln(2) ≈ 0.693
        assert!((loss.data().to_vec()[0] - 0.693).abs() < 0.01);
    }

    #[test]
    fn test_bce_with_logits_reduction_none() {
        let loss_fn = BCEWithLogitsLoss::with_reduction(Reduction::None);
        let input = Variable::new(Tensor::from_vec(vec![0.0, 0.0, 0.0], &[3]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0, 1.0], &[3]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // Should return per-element losses, not reduced
        assert_eq!(loss.shape().len(), 1);
        assert_eq!(loss.shape()[0], 3);
    }

    // =========================================================================
    // SmoothL1 Loss
    // =========================================================================

    #[test]
    fn test_smooth_l1_small_error() {
        // For |diff| < beta=1.0: loss = 0.5 * diff^2 / beta
        let loss_fn = SmoothL1Loss::new();
        let input = Variable::new(Tensor::from_vec(vec![1.0], &[1]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.3], &[1]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // diff = 1.0 - 1.3 = -0.3, |-0.3| < 1.0, loss = 0.5 * 0.09 / 1.0 = 0.045
        assert!((loss.data().to_vec()[0] - 0.045).abs() < 0.01);
    }

    #[test]
    fn test_smooth_l1_large_error() {
        // For |diff| >= beta=1.0: loss = |diff| - 0.5*beta
        let loss_fn = SmoothL1Loss::new();
        let input = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![5.0], &[1]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // diff = 0.0 - 5.0 = -5.0, |-5| >= 1.0, loss = 5.0 - 0.5 = 4.5
        assert!((loss.data().to_vec()[0] - 4.5).abs() < 0.1);
    }

    // =========================================================================
    // Cross-Entropy Batch Consistency
    // =========================================================================

    #[test]
    fn test_cross_entropy_batch_independence() {
        let loss_fn = CrossEntropyLoss::new();

        // Single sample
        let input1 = Variable::new(
            Tensor::from_vec(vec![2.0, 1.0, 0.1], &[1, 3]).unwrap(),
            false,
        );
        let target1 = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let loss1 = loss_fn.compute(&input1, &target1).data().to_vec()[0];

        // Same sample duplicated in batch
        let input2 = Variable::new(
            Tensor::from_vec(vec![2.0, 1.0, 0.1, 2.0, 1.0, 0.1], &[2, 3]).unwrap(),
            false,
        );
        let target2 = Variable::new(Tensor::from_vec(vec![0.0, 0.0], &[2]).unwrap(), false);
        let loss2 = loss_fn.compute(&input2, &target2).data().to_vec()[0];

        // Mean reduction should give same result for duplicated batch
        assert!(
            (loss1 - loss2).abs() < 1e-5,
            "Duplicated batch should give same loss: {} vs {}",
            loss1,
            loss2
        );
    }

    #[test]
    fn test_cross_entropy_high_class_count() {
        // Test with many classes (like BirdCLEF 234 species)
        let n_classes = 100;
        let mut logits = vec![0.0f32; n_classes];
        logits[42] = 5.0; // Correct class has high logit

        let loss_fn = CrossEntropyLoss::new();
        let input = Variable::new(Tensor::from_vec(logits, &[1, n_classes]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![42.0], &[1]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        let val = loss.data().to_vec()[0];
        assert!(val.is_finite(), "Should handle 100 classes");
        assert!(val < 1.0, "Correct class should have low loss, got {}", val);
    }
}