// axonml_nn — crates/axonml-nn/src/loss.rs
//! Loss Functions - Training Objectives
//!
//! # File
//! `crates/axonml-nn/src/loss.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::any::Any;

use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, GradientFunction, Variable};
use axonml_tensor::Tensor;

use crate::module::Module;
24
// =============================================================================
// Reduction Enum
// =============================================================================

/// Specifies how to reduce the loss over elements.
///
/// Fieldless enum; derives `Eq` and `Hash` alongside `PartialEq` so it can be
/// used as a map key and compared structurally (clippy:
/// `derive_partial_eq_without_eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Reduction {
    /// No reduction - return loss per element.
    None,
    /// Mean of all losses.
    #[default]
    Mean,
    /// Sum of all losses.
    Sum,
}
40
41// =============================================================================
42// MSELoss
43// =============================================================================
44
45/// Mean Squared Error loss.
46///
47/// loss = mean((input - target)^2)
48#[derive(Debug, Clone, Copy)]
49pub struct MSELoss {
50    reduction: Reduction,
51}
52
53impl MSELoss {
54    /// Creates a new MSELoss with default reduction (Mean).
55    pub fn new() -> Self {
56        Self {
57            reduction: Reduction::Mean,
58        }
59    }
60
61    /// Creates MSELoss with specified reduction.
62    pub fn with_reduction(reduction: Reduction) -> Self {
63        Self { reduction }
64    }
65
66    /// Computes the loss.
67    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
68        let diff = input.sub_var(target);
69        let squared = diff.pow(2.0);
70
71        match self.reduction {
72            Reduction::None => squared,
73            Reduction::Mean => squared.mean(),
74            Reduction::Sum => squared.sum(),
75        }
76    }
77}
78
79impl Default for MSELoss {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl Module for MSELoss {
86    fn forward(&self, input: &Variable) -> Variable {
87        // For Module interface, we can't easily pass two inputs
88        // This is primarily used via compute() method
89        input.clone()
90    }
91
92    fn name(&self) -> &'static str {
93        "MSELoss"
94    }
95}
96
97// =============================================================================
98// L1Loss
99// =============================================================================
100
101/// Mean Absolute Error loss.
102///
103/// loss = mean(|input - target|)
104#[derive(Debug, Clone, Copy)]
105pub struct L1Loss {
106    reduction: Reduction,
107}
108
109impl L1Loss {
110    /// Creates a new L1Loss with default reduction (Mean).
111    pub fn new() -> Self {
112        Self {
113            reduction: Reduction::Mean,
114        }
115    }
116
117    /// Creates L1Loss with specified reduction.
118    pub fn with_reduction(reduction: Reduction) -> Self {
119        Self { reduction }
120    }
121
122    /// Computes the loss.
123    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
124        let input_data = input.data();
125        let target_data = target.data();
126        // diff = input - target (Tensor op, auto-dispatches to GPU)
127        let diff_tensor = input_data.sub(&target_data).expect("tensor sub failed");
128        // |diff| = relu(diff) + relu(-diff), using Tensor ops that auto-dispatch to GPU
129        let relu_diff = axonml_tensor::ops::clamp_min(&diff_tensor, 0.0);
130        let relu_neg_diff = axonml_tensor::ops::clamp_min(&diff_tensor.neg(), 0.0);
131        let abs_tensor = relu_diff.add(&relu_neg_diff).expect("tensor add failed");
132
133        let requires_grad = (input.requires_grad() || target.requires_grad()) && is_grad_enabled();
134        let loss_var = if requires_grad {
135            let grad_fn = GradFn::new(L1LossBackward {
136                next_fns: vec![input.grad_fn().cloned(), target.grad_fn().cloned()],
137                diff_tensor,
138            });
139            Variable::from_operation(abs_tensor, grad_fn, true)
140        } else {
141            Variable::new(abs_tensor, false)
142        };
143
144        match self.reduction {
145            Reduction::None => loss_var,
146            Reduction::Mean => loss_var.mean(),
147            Reduction::Sum => loss_var.sum(),
148        }
149    }
150}
151
152impl Default for L1Loss {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
// =============================================================================
// L1LossBackward
// =============================================================================

/// Gradient function for L1Loss.
///
/// d/d(input) = sign(input - target)
/// d/d(target) = -sign(input - target)
///
/// Stores diff as Tensor<f32> so it stays on GPU when applicable.
#[derive(Debug)]
struct L1LossBackward {
    // Upstream grad fns for [input, target], in that order (matches the
    // two-element gradient vec returned by `apply`).
    next_fns: Vec<Option<GradFn>>,
    // input - target, saved from the forward pass.
    diff_tensor: Tensor<f32>,
}
173
174impl GradientFunction for L1LossBackward {
175    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
176        // sign(diff): +1 where diff > 0, -1 where diff < 0, 0 where diff == 0
177        // Compute as: diff / (|diff| + eps) which gives sign and handles GPU
178        let eps_tensor = Tensor::full(self.diff_tensor.shape(), 1e-12);
179        let eps_on_device = if self.diff_tensor.device().is_gpu() {
180            eps_tensor.to_device(self.diff_tensor.device()).unwrap()
181        } else {
182            eps_tensor
183        };
184        // |diff| approximated as sqrt(diff^2 + eps)  — but simpler: diff * diff then sqrt
185        let diff_sq = self
186            .diff_tensor
187            .mul(&self.diff_tensor)
188            .expect("tensor mul failed");
189        let diff_sq_eps = diff_sq.add(&eps_on_device).expect("tensor add failed");
190        // sqrt via exp(0.5 * ln(x))
191        let abs_diff = diff_sq_eps.ln().mul_scalar(0.5).exp();
192        let sign_diff = self.diff_tensor.div(&abs_diff).unwrap();
193
194        // grad_input = sign(diff) * grad_output
195        let gi = sign_diff.mul(grad_output).unwrap();
196        // grad_target = -grad_input
197        let gt = gi.neg();
198        vec![Some(gi), Some(gt)]
199    }
200
201    fn name(&self) -> &'static str {
202        "L1LossBackward"
203    }
204
205    fn next_functions(&self) -> &[Option<GradFn>] {
206        &self.next_fns
207    }
208
209    fn as_any(&self) -> &dyn Any {
210        self
211    }
212}
213
// =============================================================================
// CrossEntropyBackward
// =============================================================================

/// Gradient function for CrossEntropyLoss.
///
/// The gradient of CE w.r.t. logits is: softmax(logits) - one_hot(target).
/// For per-sample losses, each sample's gradient is scaled by the upstream
/// gradient (from reduction).
#[derive(Debug)]
struct CrossEntropyBackward {
    /// Upstream grad fn for the logits input (single-element vec).
    next_fns: Vec<Option<GradFn>>,
    /// Softmax probabilities computed during forward pass, shape (N, C).
    /// Stays on GPU if input was on GPU.
    softmax_probs: Tensor<f32>,
    /// Target class indices as f32, shape (N,). Stays on GPU if input was on GPU.
    targets: Tensor<f32>,
    /// N — number of samples.
    batch_size: usize,
    /// C — number of classes.
    num_classes: usize,
}
234
impl GradientFunction for CrossEntropyBackward {
    /// Computes grad_input[b, c] = (softmax[b, c] - 1{c == target[b]}) * scale,
    /// where `scale` comes from grad_output: a scalar when the loss was mean-
    /// reduced, or a per-sample value otherwise.
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        // Always use the exact CPU path for correctness, then transfer
        // result to GPU if needed. The CUDA cross_entropy_bwd kernel has
        // precision issues with approximate reciprocal/exp that cause
        // training to stall. The CPU backward is fast enough since CE
        // backward is O(N*C) — negligible vs the forward pass matmuls.
        let softmax_vec = self.softmax_probs.to_vec();
        let target_vec = self.targets.to_vec();
        let grad_vec = grad_output.to_vec();
        let mut grad_input = vec![0.0f32; self.batch_size * self.num_classes];

        // Handle both per-sample grad_output [N] and scalar [1] (from mean reduction)
        let is_scalar_grad = grad_vec.len() == 1;

        for b in 0..self.batch_size {
            let grad_scale = if is_scalar_grad {
                grad_vec[0]
            } else if b < grad_vec.len() {
                grad_vec[b]
            } else {
                // Fallback for a short grad_output: assume mean-style scaling.
                1.0 / self.batch_size as f32
            };
            let offset = b * self.num_classes;
            let tc = target_vec[b] as usize;
            for c in 0..self.num_classes {
                // softmax - one_hot, scaled by the upstream gradient.
                let mut g = softmax_vec[offset + c];
                if c == tc {
                    g -= 1.0;
                }
                grad_input[offset + c] = g * grad_scale;
            }
        }

        let mut grad_tensor = Tensor::from_vec(grad_input, &[self.batch_size, self.num_classes])
            .expect("tensor creation failed");
        // Transfer to GPU if the forward was on GPU
        if self.softmax_probs.device().is_gpu() {
            grad_tensor = grad_tensor.to_device(self.softmax_probs.device()).unwrap();
        }
        vec![Some(grad_tensor)]
    }

    fn name(&self) -> &'static str {
        "CrossEntropyBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
290
// =============================================================================
// CrossEntropyLoss
// =============================================================================

/// Cross entropy loss with log softmax.
///
/// This combines LogSoftmax and NLLLoss in a single class.
///
/// # Shape
/// - Input: (N, C) where C = number of classes
/// - Target: (N,) with class indices
#[derive(Debug, Clone, Copy)]
pub struct CrossEntropyLoss {
    // How per-sample losses are reduced (Mean by default via `new`).
    reduction: Reduction,
}
306
307impl CrossEntropyLoss {
308    /// Creates a new CrossEntropyLoss with default reduction (Mean).
309    pub fn new() -> Self {
310        Self {
311            reduction: Reduction::Mean,
312        }
313    }
314
315    /// Creates CrossEntropyLoss with specified reduction.
316    pub fn with_reduction(reduction: Reduction) -> Self {
317        Self { reduction }
318    }
319
320    /// Computes the loss.
321    ///
322    /// # Arguments
323    /// * `input` - Logits of shape (N, C)
324    /// * `target` - Class indices of shape (N,) as f32 (will be cast to usize)
325    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
326        let input_data = input.data();
327        let target_data = target.data();
328        let shape = input_data.shape().to_vec();
329        let batch_size = shape[0];
330        let num_classes = shape[1];
331
332        // GPU fast path: fused softmax + NLL loss kernel
333        #[cfg(feature = "cuda")]
334        if input_data.device().is_gpu() {
335            // Ensure targets are on GPU
336            let targets_gpu = if target_data.device().is_gpu() {
337                target_data.clone()
338            } else {
339                target_data.to_device(input_data.device()).unwrap()
340            };
341
342            let (loss_tensor, softmax_tensor) = input_data.cross_entropy_fwd_cuda(&targets_gpu);
343
344            let loss_var = if input.requires_grad() {
345                let grad_fn = GradFn::new(CrossEntropyBackward {
346                    next_fns: vec![input.grad_fn().cloned()],
347                    softmax_probs: softmax_tensor,
348                    targets: targets_gpu,
349                    batch_size,
350                    num_classes,
351                });
352                Variable::from_operation(loss_tensor, grad_fn, true)
353            } else {
354                Variable::new(loss_tensor, false)
355            };
356
357            return match self.reduction {
358                Reduction::None => loss_var,
359                Reduction::Mean => loss_var.mean(),
360                Reduction::Sum => loss_var.sum(),
361            };
362        }
363
364        // CPU path
365        let input_vec = input_data.to_vec();
366        let target_vec = target_data.to_vec();
367
368        let mut losses = vec![0.0f32; batch_size];
369        let mut softmax_probs_vec = vec![0.0f32; batch_size * num_classes];
370        let mut target_classes = vec![0usize; batch_size];
371
372        for b in 0..batch_size {
373            let offset = b * num_classes;
374
375            // Numerically stable log-softmax
376            let max_val = (0..num_classes)
377                .map(|c| input_vec[offset + c])
378                .fold(f32::NEG_INFINITY, f32::max);
379
380            let mut sum_exp = 0.0f32;
381            for c in 0..num_classes {
382                let exp_val = (input_vec[offset + c] - max_val).exp();
383                softmax_probs_vec[offset + c] = exp_val;
384                sum_exp += exp_val;
385            }
386
387            for c in 0..num_classes {
388                softmax_probs_vec[offset + c] /= sum_exp;
389            }
390
391            let log_sum_exp = max_val + sum_exp.ln();
392
393            let tc = target_vec[b] as usize;
394            target_classes[b] = tc;
395            losses[b] = log_sum_exp - input_vec[offset + tc];
396        }
397
398        let loss_tensor = Tensor::from_vec(losses, &[batch_size]).expect("tensor creation failed");
399        let softmax_tensor = Tensor::from_vec(softmax_probs_vec, &[batch_size, num_classes])
400            .expect("tensor creation failed");
401        let targets_f32: Vec<f32> = target_classes.iter().map(|&tc| tc as f32).collect();
402        let targets_tensor =
403            Tensor::from_vec(targets_f32, &[batch_size]).expect("tensor creation failed");
404
405        let loss_var = if input.requires_grad() {
406            let grad_fn = GradFn::new(CrossEntropyBackward {
407                next_fns: vec![input.grad_fn().cloned()],
408                softmax_probs: softmax_tensor,
409                targets: targets_tensor,
410                batch_size,
411                num_classes,
412            });
413            Variable::from_operation(loss_tensor, grad_fn, true)
414        } else {
415            Variable::new(loss_tensor, false)
416        };
417
418        match self.reduction {
419            Reduction::None => loss_var,
420            Reduction::Mean => loss_var.mean(),
421            Reduction::Sum => loss_var.sum(),
422        }
423    }
424}
425
426impl Default for CrossEntropyLoss {
427    fn default() -> Self {
428        Self::new()
429    }
430}
431
// =============================================================================
// NLLLoss
// =============================================================================

/// Negative Log Likelihood loss.
///
/// Expects input to be log-probabilities.
#[derive(Debug, Clone, Copy)]
pub struct NLLLoss {
    // How per-sample losses are reduced (Mean by default via `new`).
    reduction: Reduction,
}
443
444impl NLLLoss {
445    /// Creates a new NLLLoss with default reduction (Mean).
446    pub fn new() -> Self {
447        Self {
448            reduction: Reduction::Mean,
449        }
450    }
451
452    /// Creates NLLLoss with specified reduction.
453    pub fn with_reduction(reduction: Reduction) -> Self {
454        Self { reduction }
455    }
456
457    /// Computes the loss.
458    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
459        let input_data = input.data();
460        let target_data = target.data();
461        let shape = input_data.shape().to_vec();
462        let batch_size = shape[0];
463        let num_classes = shape[1];
464
465        // NLL forward still needs per-sample gather (index into class dimension).
466        // We pull target indices to CPU for the gather but keep input on device.
467        let target_vec = target_data.to_vec();
468        let input_vec = input_data.to_vec();
469
470        let mut losses = vec![0.0f32; batch_size];
471        for b in 0..batch_size {
472            let tc = target_vec[b] as usize;
473            losses[b] = -input_vec[b * num_classes + tc];
474        }
475
476        let mut loss_tensor =
477            Tensor::from_vec(losses, &[batch_size]).expect("tensor creation failed");
478        if input_data.device().is_gpu() {
479            loss_tensor = loss_tensor.to_device(input_data.device()).unwrap();
480        }
481
482        let requires_grad = input.requires_grad() && is_grad_enabled();
483        let loss_var = if requires_grad {
484            let grad_fn = GradFn::new(NLLLossBackward {
485                next_fns: vec![input.grad_fn().cloned()],
486                target_tensor: target_data.clone(),
487                batch_size,
488                num_classes,
489            });
490            Variable::from_operation(loss_tensor, grad_fn, true)
491        } else {
492            Variable::new(loss_tensor, false)
493        };
494
495        match self.reduction {
496            Reduction::None => loss_var,
497            Reduction::Mean => loss_var.mean(),
498            Reduction::Sum => loss_var.sum(),
499        }
500    }
501}
502
503impl Default for NLLLoss {
504    fn default() -> Self {
505        Self::new()
506    }
507}
508
// =============================================================================
// NLLLossBackward
// =============================================================================

/// Gradient function for NLLLoss.
///
/// d/d(input)[b, c] = -1 if c == target[b], else 0
///
/// Stores targets as Tensor<f32> (GPU-resident when applicable).
/// The scatter in backward still uses CPU indexing since it's a sparse write,
/// but the result is moved to GPU if needed.
#[derive(Debug)]
struct NLLLossBackward {
    // Upstream grad fn for the log-probability input (single-element vec).
    next_fns: Vec<Option<GradFn>>,
    // Class indices as f32, shape (N,).
    target_tensor: Tensor<f32>,
    // N — number of samples.
    batch_size: usize,
    // C — number of classes.
    num_classes: usize,
}
527
528impl GradientFunction for NLLLossBackward {
529    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
530        let grad_out_vec = grad_output.to_vec();
531        let target_vec = self.target_tensor.to_vec();
532        let mut grad_input = vec![0.0f32; self.batch_size * self.num_classes];
533
534        for b in 0..self.batch_size {
535            let g = if grad_out_vec.len() == 1 {
536                grad_out_vec[0]
537            } else {
538                grad_out_vec[b]
539            };
540            let tc = target_vec[b] as usize;
541            grad_input[b * self.num_classes + tc] = -g;
542        }
543
544        let mut gi = Tensor::from_vec(grad_input, &[self.batch_size, self.num_classes])
545            .expect("tensor creation failed");
546        if grad_output.device().is_gpu() {
547            gi = gi.to_device(grad_output.device()).unwrap();
548        }
549        vec![Some(gi)]
550    }
551
552    fn name(&self) -> &'static str {
553        "NLLLossBackward"
554    }
555
556    fn next_functions(&self) -> &[Option<GradFn>] {
557        &self.next_fns
558    }
559
560    fn as_any(&self) -> &dyn Any {
561        self
562    }
563}
564
// =============================================================================
// BCELoss
// =============================================================================

/// Binary Cross Entropy loss.
///
/// Expects input to be probabilities in [0, 1].
#[derive(Debug, Clone, Copy)]
pub struct BCELoss {
    // How per-element losses are reduced (Mean by default via `new`).
    reduction: Reduction,
}
576
577impl BCELoss {
578    /// Creates a new BCELoss with default reduction (Mean).
579    pub fn new() -> Self {
580        Self {
581            reduction: Reduction::Mean,
582        }
583    }
584
585    /// Creates BCELoss with specified reduction.
586    pub fn with_reduction(reduction: Reduction) -> Self {
587        Self { reduction }
588    }
589
590    /// Computes the loss.
591    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
592        let input_data = input.data();
593        let target_data = target.data();
594
595        // Clamp predictions to [eps, 1-eps] using Tensor ops
596        let eps = 1e-7f32;
597        let p_clamped = axonml_tensor::ops::clamp(&input_data, eps, 1.0 - eps);
598
599        // loss = -(t * ln(p) + (1 - t) * ln(1 - p))
600        let ln_p = p_clamped.ln();
601        let one_minus_p = p_clamped.neg().add_scalar(1.0);
602        let ln_one_minus_p = one_minus_p.ln();
603        let one_minus_t = target_data.neg().add_scalar(1.0);
604
605        // t * ln(p)
606        let term1 = target_data.mul(&ln_p).expect("tensor mul failed");
607        // (1-t) * ln(1-p)
608        let term2 = one_minus_t.mul(&ln_one_minus_p).expect("tensor mul failed");
609        // -(term1 + term2)
610        let loss_tensor = term1.add(&term2).expect("tensor add failed").neg();
611
612        let requires_grad = input.requires_grad() && is_grad_enabled();
613        let loss_var = if requires_grad {
614            let grad_fn = GradFn::new(BCELossBackward {
615                next_fns: vec![input.grad_fn().cloned()],
616                input_tensor: input_data,
617                target_tensor: target_data,
618            });
619            Variable::from_operation(loss_tensor, grad_fn, true)
620        } else {
621            Variable::new(loss_tensor, false)
622        };
623
624        match self.reduction {
625            Reduction::None => loss_var,
626            Reduction::Mean => loss_var.mean(),
627            Reduction::Sum => loss_var.sum(),
628        }
629    }
630}
631
632impl Default for BCELoss {
633    fn default() -> Self {
634        Self::new()
635    }
636}
637
// =============================================================================
// BCELossBackward
// =============================================================================

/// Gradient function for BCELoss.
///
/// d/dp BCE = (p - y) / (p * (1 - p))
///
/// Stores input/target as Tensor<f32> (GPU-resident when applicable).
#[derive(Debug)]
struct BCELossBackward {
    // Upstream grad fn for the probability input (single-element vec).
    next_fns: Vec<Option<GradFn>>,
    // Unclamped predicted probabilities; re-clamped in `apply`.
    input_tensor: Tensor<f32>,
    // Ground-truth labels y.
    target_tensor: Tensor<f32>,
}
653
654impl GradientFunction for BCELossBackward {
655    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
656        let eps = 1e-7f32;
657        // p_clamped = clamp(input, eps, 1-eps)
658        let p_clamped = axonml_tensor::ops::clamp(&self.input_tensor, eps, 1.0 - eps);
659        // (p - y)
660        let p_minus_y = p_clamped
661            .sub(&self.target_tensor)
662            .expect("tensor sub failed");
663        // p * (1 - p)
664        let one_minus_p = p_clamped.neg().add_scalar(1.0);
665        let denom = p_clamped.mul(&one_minus_p).expect("tensor mul failed");
666        // grad = grad_output * (p - y) / (p * (1 - p))
667        let ratio = p_minus_y.div(&denom).unwrap();
668        let grad_tensor = grad_output.mul(&ratio).expect("tensor mul failed");
669        vec![Some(grad_tensor)]
670    }
671
672    fn name(&self) -> &'static str {
673        "BCELossBackward"
674    }
675
676    fn next_functions(&self) -> &[Option<GradFn>] {
677        &self.next_fns
678    }
679
680    fn as_any(&self) -> &dyn Any {
681        self
682    }
683}
684
// =============================================================================
// BCEWithLogitsBackward
// =============================================================================

/// Gradient function for BCEWithLogitsLoss.
///
/// The gradient of BCE w.r.t. input logits is: sigmoid(input) - target.
///
/// Stores input/target as Tensor<f32> (GPU-resident when applicable).
#[derive(Debug)]
struct BCEWithLogitsBackward {
    // Upstream grad fn for the logits input (single-element vec).
    next_fns: Vec<Option<GradFn>>,
    // Raw logits x, saved from the forward pass.
    input_tensor: Tensor<f32>,
    // Ground-truth labels t.
    target_tensor: Tensor<f32>,
}
700
701impl GradientFunction for BCEWithLogitsBackward {
702    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
703        // sigmoid(input) - target, all via Tensor ops (auto-dispatch to GPU)
704        let sig = self.input_tensor.sigmoid();
705        let sig_minus_t = sig.sub(&self.target_tensor).expect("tensor sub failed");
706        let grad_tensor = grad_output.mul(&sig_minus_t).expect("tensor mul failed");
707        vec![Some(grad_tensor)]
708    }
709
710    fn name(&self) -> &'static str {
711        "BCEWithLogitsBackward"
712    }
713
714    fn next_functions(&self) -> &[Option<GradFn>] {
715        &self.next_fns
716    }
717
718    fn as_any(&self) -> &dyn Any {
719        self
720    }
721}
722
// =============================================================================
// BCEWithLogitsLoss
// =============================================================================

/// Binary Cross Entropy with Logits.
///
/// Combines sigmoid and BCE in a numerically stable way.
#[derive(Debug, Clone, Copy)]
pub struct BCEWithLogitsLoss {
    // How per-element losses are reduced (Mean by default via `new`).
    reduction: Reduction,
}
734
735impl BCEWithLogitsLoss {
736    /// Creates a new BCEWithLogitsLoss with default reduction (Mean).
737    pub fn new() -> Self {
738        Self {
739            reduction: Reduction::Mean,
740        }
741    }
742
743    /// Creates BCEWithLogitsLoss with specified reduction.
744    pub fn with_reduction(reduction: Reduction) -> Self {
745        Self { reduction }
746    }
747
748    /// Computes the loss.
749    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
750        let input_data = input.data();
751        let target_data = target.data();
752
753        // Numerically stable: max(x, 0) - x*t + log(1 + exp(-|x|))
754        // max(x, 0) = relu(x) = clamp_min(x, 0)
755        let relu_x = axonml_tensor::ops::clamp_min(&input_data, 0.0);
756        // x * t
757        let x_times_t = input_data.mul(&target_data).expect("tensor mul failed");
758        // |x| via clamp trick: max(x, 0) + max(-x, 0) = relu(x) + relu(-x)
759        let neg_x = input_data.neg();
760        let relu_neg_x = axonml_tensor::ops::clamp_min(&neg_x, 0.0);
761        let abs_x = relu_x.add(&relu_neg_x).expect("tensor add failed");
762        // exp(-|x|)
763        let exp_neg_abs = abs_x.neg().exp();
764        // log(1 + exp(-|x|))
765        let log_term = exp_neg_abs.add_scalar(1.0).ln();
766        // loss = relu(x) - x*t + log(1 + exp(-|x|))
767        let loss_tensor = relu_x
768            .sub(&x_times_t)
769            .expect("tensor sub failed")
770            .add(&log_term)
771            .expect("tensor add failed");
772
773        let loss_var = if input.requires_grad() {
774            let grad_fn = GradFn::new(BCEWithLogitsBackward {
775                next_fns: vec![input.grad_fn().cloned()],
776                input_tensor: input_data,
777                target_tensor: target_data,
778            });
779            Variable::from_operation(loss_tensor, grad_fn, true)
780        } else {
781            Variable::new(loss_tensor, false)
782        };
783
784        match self.reduction {
785            Reduction::None => loss_var,
786            Reduction::Mean => loss_var.mean(),
787            Reduction::Sum => loss_var.sum(),
788        }
789    }
790}
791
792impl Default for BCEWithLogitsLoss {
793    fn default() -> Self {
794        Self::new()
795    }
796}
797
// =============================================================================
// SmoothL1Backward
// =============================================================================

/// Gradient function for SmoothL1Loss.
///
/// The gradient is: diff/beta if |diff| < beta, else sign(diff).
/// Returns gradients for both input and target (negated for target).
///
/// Stores diff as Tensor<f32> (GPU-resident when applicable).
#[derive(Debug)]
struct SmoothL1Backward {
    // Upstream grad fns for [input, target], in that order.
    next_fns: Vec<Option<GradFn>>,
    // input - target, saved from the forward pass.
    diff_tensor: Tensor<f32>,
    // Transition point between the L2 and L1 regions.
    beta: f32,
    // Shape of diff, used to rebuild the CPU-built mask tensor in `apply`.
    shape: Vec<usize>,
}
815
816impl GradientFunction for SmoothL1Backward {
817    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
818        // Compute |diff| = sqrt(diff^2 + eps) to stay differentiable and device-agnostic
819        let eps = 1e-12f32;
820        let diff_sq = self
821            .diff_tensor
822            .mul(&self.diff_tensor)
823            .expect("tensor mul failed");
824        let diff_sq_eps = diff_sq.add_scalar(eps);
825        let abs_diff = diff_sq_eps.ln().mul_scalar(0.5).exp();
826
827        // sign(diff) = diff / |diff|
828        let sign_diff = self.diff_tensor.div(&abs_diff).unwrap();
829
830        // For the L2 region (|diff| < beta): grad = diff / beta
831        // For the L1 region (|diff| >= beta): grad = sign(diff)
832        // Blend: mask = clamp(|diff| / beta, 0, 1), but we need a hard cutoff.
833        // Use a smooth approximation via where: if |diff| < beta -> diff/beta, else sign
834        // Since we need element-wise branching and don't have a GPU where_cond,
835        // we use: grad = (1 - mask) * (diff / beta) + mask * sign(diff)
836        // where mask = clamp((|diff| - beta) * large_value, 0, 1) approximates step function.
837        // Actually simpler: mask_l2 = clamp(1 - |diff|/beta, 0, 1) gives 1 in L2 region, 0 in L1
838        // BUT this gives a soft transition. For exact correctness, use CPU branching on the mask.
839        //
840        // Practical approach: compute both branches with Tensor ops, build mask on CPU, blend.
841        let grad_l2 = self.diff_tensor.mul_scalar(1.0 / self.beta); // diff / beta
842        let grad_l1 = sign_diff; // sign(diff)
843
844        // Build mask tensor: 1.0 where |d| < beta, 0.0 otherwise
845        let abs_vec = abs_diff.to_vec();
846        let beta = self.beta;
847        let mask_vec: Vec<f32> = abs_vec
848            .iter()
849            .map(|&a| if a < beta { 1.0 } else { 0.0 })
850            .collect();
851        let mut mask = Tensor::from_vec(mask_vec, &self.shape).expect("tensor creation failed");
852        if self.diff_tensor.device().is_gpu() {
853            mask = mask.to_device(self.diff_tensor.device()).unwrap();
854        }
855        let inv_mask = mask.neg().add_scalar(1.0);
856
857        // grad_per_elem = mask * grad_l2 + (1 - mask) * grad_l1
858        let blended = mask
859            .mul(&grad_l2)
860            .unwrap()
861            .add(&inv_mask.mul(&grad_l1).expect("tensor add failed"))
862            .unwrap();
863
864        // gi = blended * grad_output
865        let gi = blended.mul(grad_output).unwrap();
866        let gt = gi.neg();
867        vec![Some(gi), Some(gt)]
868    }
869
870    fn name(&self) -> &'static str {
871        "SmoothL1Backward"
872    }
873
874    fn next_functions(&self) -> &[Option<GradFn>] {
875        &self.next_fns
876    }
877
878    fn as_any(&self) -> &dyn Any {
879        self
880    }
881}
882
// =============================================================================
// SmoothL1Loss
// =============================================================================

/// Smooth L1 loss (Huber loss).
///
/// Uses L2 loss when |x| < beta, L1 loss otherwise.
#[derive(Debug, Clone, Copy)]
pub struct SmoothL1Loss {
    // How per-element losses are reduced (Mean by default via `new`).
    reduction: Reduction,
    // Transition point between the L2 and L1 regions (1.0 by default).
    beta: f32,
}
895
896impl SmoothL1Loss {
897    /// Creates a new SmoothL1Loss with default beta (1.0).
898    pub fn new() -> Self {
899        Self {
900            reduction: Reduction::Mean,
901            beta: 1.0,
902        }
903    }
904
905    /// Creates SmoothL1Loss with specified beta.
906    pub fn with_beta(beta: f32) -> Self {
907        Self {
908            reduction: Reduction::Mean,
909            beta,
910        }
911    }
912
913    /// Computes the loss.
914    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
915        let input_data = input.data();
916        let target_data = target.data();
917        let diff_tensor = input_data.sub(&target_data).expect("tensor sub failed");
918        let shape = diff_tensor.shape().to_vec();
919
920        // Compute |diff| via relu(diff) + relu(-diff)
921        let relu_diff = axonml_tensor::ops::clamp_min(&diff_tensor, 0.0);
922        let relu_neg_diff = axonml_tensor::ops::clamp_min(&diff_tensor.neg(), 0.0);
923        let abs_diff = relu_diff.add(&relu_neg_diff).expect("tensor add failed");
924
925        // L2 branch: 0.5 * diff^2 / beta
926        let diff_sq = diff_tensor.mul(&diff_tensor).expect("tensor mul failed");
927        let l2_loss = diff_sq.mul_scalar(0.5 / self.beta);
928
929        // L1 branch: |diff| - 0.5 * beta
930        let l1_loss = abs_diff.add_scalar(-0.5 * self.beta);
931
932        // Build mask: 1.0 where |diff| < beta, 0.0 otherwise
933        let abs_vec = abs_diff.to_vec();
934        let beta = self.beta;
935        let mask_vec: Vec<f32> = abs_vec
936            .iter()
937            .map(|&a| if a < beta { 1.0 } else { 0.0 })
938            .collect();
939        let mut mask = Tensor::from_vec(mask_vec, &shape).expect("tensor creation failed");
940        if diff_tensor.device().is_gpu() {
941            mask = mask.to_device(diff_tensor.device()).unwrap();
942        }
943        let inv_mask = mask.neg().add_scalar(1.0);
944
945        // loss = mask * l2_loss + (1-mask) * l1_loss
946        let loss_tensor = mask
947            .mul(&l2_loss)
948            .unwrap()
949            .add(&inv_mask.mul(&l1_loss).expect("tensor add failed"))
950            .unwrap();
951
952        let loss_var = if input.requires_grad() || target.requires_grad() {
953            let grad_fn = GradFn::new(SmoothL1Backward {
954                next_fns: vec![input.grad_fn().cloned(), target.grad_fn().cloned()],
955                diff_tensor,
956                beta: self.beta,
957                shape,
958            });
959            Variable::from_operation(loss_tensor, grad_fn, true)
960        } else {
961            Variable::new(loss_tensor, false)
962        };
963
964        match self.reduction {
965            Reduction::None => loss_var,
966            Reduction::Mean => loss_var.mean(),
967            Reduction::Sum => loss_var.sum(),
968        }
969    }
970}
971
impl Default for SmoothL1Loss {
    /// Equivalent to [`SmoothL1Loss::new`]: beta = 1.0, Mean reduction.
    fn default() -> Self {
        Self::new()
    }
}
977
978// =============================================================================
979// Tests
980// =============================================================================
981
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mse_loss() {
        let loss_fn = MSELoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_mse_loss_nonzero() {
        let loss_fn = MSELoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        // Each diff is 1.0, squared is 1.0, mean is 1.0
        assert!((loss.data().to_vec()[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cross_entropy_loss() {
        let loss_fn = CrossEntropyLoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], &[2, 3])
                .expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![2.0, 0.0], &[2]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        assert!(loss.data().to_vec()[0] > 0.0);
    }

    #[test]
    fn test_bce_loss() {
        let loss_fn = BCELoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![0.5, 0.5], &[2]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0], &[2]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        // -[1*ln(0.5) + 0*ln(0.5)] - [0*ln(0.5) + 1*ln(0.5)] = -2*ln(0.5) / 2 = -ln(0.5) = 0.693
        assert!((loss.data().to_vec()[0] - 0.693).abs() < 0.01);
    }

    #[test]
    fn test_cross_entropy_gradient_flow() {
        use axonml_autograd::backward;

        // Create input logits with requires_grad=true
        let input = Variable::new(
            Tensor::from_vec(vec![2.0, 1.0, 0.1, 0.5, 2.5, 0.3], &[2, 3])
                .expect("tensor creation failed"),
            true,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![0.0, 1.0], &[2]).expect("tensor creation failed"),
            false,
        );

        let loss_fn = CrossEntropyLoss::new();
        let loss = loss_fn.compute(&input, &target);

        // Loss should be positive
        let loss_val = loss.data().to_vec()[0];
        assert!(loss_val > 0.0, "Loss should be positive, got {}", loss_val);

        // Backward pass
        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        // Input should have gradient
        let grad = input
            .grad()
            .expect("Input should have gradient after backward");
        let grad_vec = grad.to_vec();

        // Gradient should be non-zero
        let grad_norm: f32 = grad_vec.iter().map(|g| g * g).sum();
        assert!(
            grad_norm > 1e-10,
            "Gradient should be non-zero, got norm {}",
            grad_norm
        );

        // Gradient shape should match input shape
        assert_eq!(grad.shape(), &[2, 3]);

        // For the correct class, gradient should be negative (softmax - 1 < 0)
        // Sample 0, class 0 (target): grad should be (softmax[0,0] - 1) / 2
        assert!(
            grad_vec[0] < 0.0,
            "Gradient for correct class should be negative"
        );
        // Sample 1, class 1 (target): grad should be (softmax[1,1] - 1) / 2
        assert!(
            grad_vec[4] < 0.0,
            "Gradient for correct class should be negative"
        );

        // For wrong classes, gradient should be positive (softmax > 0)
        assert!(
            grad_vec[1] > 0.0,
            "Gradient for wrong class should be positive"
        );
        assert!(
            grad_vec[2] > 0.0,
            "Gradient for wrong class should be positive"
        );
    }

    #[test]
    fn test_cross_entropy_perfect_prediction() {
        // When logits strongly favor the correct class, loss should be near zero
        let loss_fn = CrossEntropyLoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![10.0, -10.0, -10.0], &[1, 3]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![0.0], &[1]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        assert!(
            loss.data().to_vec()[0] < 0.001,
            "Perfect prediction should have near-zero loss"
        );
    }

    #[test]
    fn test_cross_entropy_uniform_prediction() {
        // When logits are all equal, loss should be ln(num_classes)
        let loss_fn = CrossEntropyLoss::new();
        let num_classes = 16;
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; num_classes], &[1, num_classes])
                .expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![0.0], &[1]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        let expected = (num_classes as f32).ln(); // ln(16) ≈ 2.77
        let actual = loss.data().to_vec()[0];
        assert!(
            (actual - expected).abs() < 0.01,
            "Uniform logits should give ln(C)={}, got {}",
            expected,
            actual,
        );
    }

    #[test]
    fn test_bce_with_logits_gradient_flow() {
        use axonml_autograd::backward;

        let input = Variable::new(
            Tensor::from_vec(vec![0.5, -0.5, 1.0, -1.0], &[4]).expect("tensor creation failed"),
            true,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 1.0, 0.0], &[4]).expect("tensor creation failed"),
            false,
        );

        let loss_fn = BCEWithLogitsLoss::new();
        let loss = loss_fn.compute(&input, &target);
        assert!(loss.data().to_vec()[0] > 0.0);

        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input.grad().expect("Input should have gradient");
        let grad_vec = grad.to_vec();
        assert_eq!(grad_vec.len(), 4);

        // For target=1, grad = sigmoid(x) - 1 < 0
        assert!(grad_vec[0] < 0.0);
        // For target=0, grad = sigmoid(x) > 0
        assert!(grad_vec[1] > 0.0);
    }

    #[test]
    fn test_smooth_l1_gradient_flow() {
        use axonml_autograd::backward;

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 5.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.5, 1.5, 1.5], &[3]).expect("tensor creation failed"),
            false,
        );

        let loss_fn = SmoothL1Loss::new();
        let loss = loss_fn.compute(&input, &target);
        assert!(loss.data().to_vec()[0] > 0.0);

        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input.grad().expect("Input should have gradient");
        let grad_vec = grad.to_vec();
        assert_eq!(grad_vec.len(), 3);
        // Gradients should be non-zero
        let grad_norm: f32 = grad_vec.iter().map(|g| g * g).sum();
        assert!(grad_norm > 1e-10);
    }

    // =========================================================================
    // MSE Loss Comprehensive
    // =========================================================================

    #[test]
    fn test_mse_loss_gradient_correctness() {
        use axonml_autograd::backward;

        // MSE gradient = 2*(input - target) / N
        let input = Variable::new(Tensor::from_vec(vec![3.0, 1.0], &[2]).unwrap(), true);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap(), false);

        let loss = MSELoss::new().compute(&input, &target);
        // loss = ((3-1)^2 + (1-1)^2) / 2 = 4/2 = 2.0
        assert!(
            (loss.data().to_vec()[0] - 2.0).abs() < 1e-5,
            "MSE should be 2.0"
        );

        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input.grad().expect("Should have gradient");
        let gv = grad.to_vec();
        // dL/dx = 2*(x-t)/N = 2*(3-1)/2 = 2.0 for first, 0.0 for second
        assert!(
            (gv[0] - 2.0).abs() < 0.1,
            "Grad[0] should be ~2.0, got {}",
            gv[0]
        );
        assert!(gv[1].abs() < 0.1, "Grad[1] should be ~0.0, got {}", gv[1]);
    }

    #[test]
    fn test_mse_loss_reduction_sum() {
        let input = Variable::new(Tensor::from_vec(vec![2.0, 4.0], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap(), false);
        let loss = MSELoss::with_reduction(Reduction::Sum).compute(&input, &target);
        // sum = (2-1)^2 + (4-1)^2 = 1 + 9 = 10
        assert!((loss.data().to_vec()[0] - 10.0).abs() < 1e-5);
    }

    // =========================================================================
    // L1 Loss
    // =========================================================================

    #[test]
    fn test_l1_loss_basic() {
        let input = Variable::new(Tensor::from_vec(vec![1.0, 5.0, 3.0], &[3]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 4.0], &[3]).unwrap(), false);
        let loss = L1Loss::new().compute(&input, &target);
        // mean(|0| + |3| + |-1|) = 4/3 ≈ 1.333
        assert!((loss.data().to_vec()[0] - 4.0 / 3.0).abs() < 1e-4);
    }

    #[test]
    fn test_l1_loss_zero() {
        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), false);
        let loss = L1Loss::new().compute(&input, &target);
        assert!(
            loss.data().to_vec()[0].abs() < 1e-6,
            "Identical inputs should give 0 loss"
        );
    }

    // =========================================================================
    // BCE Loss Edge Cases
    // =========================================================================

    #[test]
    fn test_bce_loss_perfect_prediction() {
        let loss_fn = BCELoss::new();
        // Near-perfect predictions
        let input = Variable::new(Tensor::from_vec(vec![0.999, 0.001], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!(
            loss.data().to_vec()[0] < 0.01,
            "Perfect prediction should have near-zero loss"
        );
    }

    #[test]
    fn test_bce_loss_worst_prediction() {
        let loss_fn = BCELoss::new();
        // Worst predictions (inverted)
        let input = Variable::new(Tensor::from_vec(vec![0.001, 0.999], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // Should be high
        assert!(
            loss.data().to_vec()[0] > 3.0,
            "Worst prediction should have high loss"
        );
    }

    // =========================================================================
    // BCEWithLogitsLoss Comprehensive
    // =========================================================================

    #[test]
    fn test_bce_with_logits_numerical_stability() {
        let loss_fn = BCEWithLogitsLoss::new();
        // Very large logits should not produce NaN/Inf
        let input = Variable::new(
            Tensor::from_vec(vec![100.0, -100.0, 50.0, -50.0], &[4]).unwrap(),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 1.0, 0.0], &[4]).unwrap(),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        let val = loss.data().to_vec()[0];
        assert!(
            val.is_finite(),
            "Loss should be finite for large logits, got {}",
            val
        );
        assert!(val >= 0.0, "BCE loss should be non-negative");
    }

    #[test]
    fn test_bce_with_logits_zero_logits() {
        let loss_fn = BCEWithLogitsLoss::new();
        // Zero logits → sigmoid(0) = 0.5 → random prediction
        let input = Variable::new(Tensor::from_vec(vec![0.0, 0.0], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // Should be ln(2) ≈ 0.693
        assert!((loss.data().to_vec()[0] - 0.693).abs() < 0.01);
    }

    #[test]
    fn test_bce_with_logits_reduction_none() {
        let loss_fn = BCEWithLogitsLoss::with_reduction(Reduction::None);
        let input = Variable::new(Tensor::from_vec(vec![0.0, 0.0, 0.0], &[3]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0, 1.0], &[3]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // Should return per-element losses, not reduced
        assert_eq!(loss.shape().len(), 1);
        assert_eq!(loss.shape()[0], 3);
    }

    // =========================================================================
    // SmoothL1 Loss
    // =========================================================================

    #[test]
    fn test_smooth_l1_small_error() {
        // For |diff| < beta=1.0: loss = 0.5 * diff^2 / beta
        let loss_fn = SmoothL1Loss::new();
        let input = Variable::new(Tensor::from_vec(vec![1.0], &[1]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.3], &[1]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // diff=0.3, |0.3| < 1.0, loss = 0.5 * 0.09 / 1.0 = 0.045
        assert!((loss.data().to_vec()[0] - 0.045).abs() < 0.01);
    }

    #[test]
    fn test_smooth_l1_large_error() {
        // For |diff| >= beta=1.0: loss = |diff| - 0.5*beta
        let loss_fn = SmoothL1Loss::new();
        let input = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![5.0], &[1]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // diff=5.0, |5| >= 1.0, loss = 5.0 - 0.5 = 4.5
        assert!((loss.data().to_vec()[0] - 4.5).abs() < 0.1);
    }

    // =========================================================================
    // Cross-Entropy Batch Consistency
    // =========================================================================

    #[test]
    fn test_cross_entropy_batch_independence() {
        let loss_fn = CrossEntropyLoss::new();

        // Single sample
        let input1 = Variable::new(
            Tensor::from_vec(vec![2.0, 1.0, 0.1], &[1, 3]).unwrap(),
            false,
        );
        let target1 = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let loss1 = loss_fn.compute(&input1, &target1).data().to_vec()[0];

        // Same sample duplicated in batch
        let input2 = Variable::new(
            Tensor::from_vec(vec![2.0, 1.0, 0.1, 2.0, 1.0, 0.1], &[2, 3]).unwrap(),
            false,
        );
        let target2 = Variable::new(Tensor::from_vec(vec![0.0, 0.0], &[2]).unwrap(), false);
        let loss2 = loss_fn.compute(&input2, &target2).data().to_vec()[0];

        // Mean reduction should give same result for duplicated batch
        assert!(
            (loss1 - loss2).abs() < 1e-5,
            "Duplicated batch should give same loss: {} vs {}",
            loss1,
            loss2
        );
    }

    #[test]
    fn test_cross_entropy_high_class_count() {
        // Test with many classes (like BirdCLEF 234 species)
        let n_classes = 100;
        let mut logits = vec![0.0f32; n_classes];
        logits[42] = 5.0; // Correct class has high logit

        let loss_fn = CrossEntropyLoss::new();
        let input = Variable::new(Tensor::from_vec(logits, &[1, n_classes]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![42.0], &[1]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        let val = loss.data().to_vec()[0];
        assert!(val.is_finite(), "Should handle 100 classes");
        assert!(val < 1.0, "Correct class should have low loss, got {}", val);
    }
}