//! Loss Functions - Training Objectives
//!
//! # File
//! `crates/axonml-nn/src/loss.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::any::Any;

use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, GradientFunction, Variable};
use axonml_tensor::Tensor;

use crate::module::Module;

// =============================================================================
// Reduction Enum
// =============================================================================

/// How per-element losses are collapsed into the final loss value.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum Reduction {
    /// Keep the element-wise losses; no aggregation is applied.
    None,
    /// Average the element-wise losses (the default).
    #[default]
    Mean,
    /// Add up the element-wise losses.
    Sum,
}

41// =============================================================================
42// MSELoss
43// =============================================================================
44
45/// Mean Squared Error loss.
46///
47/// loss = mean((input - target)^2)
48#[derive(Debug, Clone, Copy)]
49pub struct MSELoss {
50    reduction: Reduction,
51}
52
53impl MSELoss {
54    /// Creates a new MSELoss with default reduction (Mean).
55    pub fn new() -> Self {
56        Self {
57            reduction: Reduction::Mean,
58        }
59    }
60
61    /// Creates MSELoss with specified reduction.
62    pub fn with_reduction(reduction: Reduction) -> Self {
63        Self { reduction }
64    }
65
66    /// Computes the loss.
67    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
68        let diff = input.sub_var(target);
69        let squared = diff.pow(2.0);
70
71        match self.reduction {
72            Reduction::None => squared,
73            Reduction::Mean => squared.mean(),
74            Reduction::Sum => squared.sum(),
75        }
76    }
77}
78
79impl Default for MSELoss {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl Module for MSELoss {
86    fn forward(&self, input: &Variable) -> Variable {
87        // For Module interface, we can't easily pass two inputs
88        // This is primarily used via compute() method
89        input.clone()
90    }
91
92    fn name(&self) -> &'static str {
93        "MSELoss"
94    }
95}
96
97// =============================================================================
98// L1Loss
99// =============================================================================
100
101/// Mean Absolute Error loss.
102///
103/// loss = mean(|input - target|)
104#[derive(Debug, Clone, Copy)]
105pub struct L1Loss {
106    reduction: Reduction,
107}
108
109impl L1Loss {
110    /// Creates a new L1Loss with default reduction (Mean).
111    pub fn new() -> Self {
112        Self {
113            reduction: Reduction::Mean,
114        }
115    }
116
117    /// Creates L1Loss with specified reduction.
118    pub fn with_reduction(reduction: Reduction) -> Self {
119        Self { reduction }
120    }
121
122    /// Computes the loss.
123    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
124        let input_data = input.data();
125        let target_data = target.data();
126        // diff = input - target (Tensor op, auto-dispatches to GPU)
127        let diff_tensor = input_data.sub(&target_data).unwrap();
128        // |diff| = relu(diff) + relu(-diff), using Tensor ops that auto-dispatch to GPU
129        let relu_diff = axonml_tensor::ops::clamp_min(&diff_tensor, 0.0);
130        let relu_neg_diff = axonml_tensor::ops::clamp_min(&diff_tensor.neg(), 0.0);
131        let abs_tensor = relu_diff.add(&relu_neg_diff).unwrap();
132
133        let requires_grad = (input.requires_grad() || target.requires_grad()) && is_grad_enabled();
134        let loss_var = if requires_grad {
135            let grad_fn = GradFn::new(L1LossBackward {
136                next_fns: vec![input.grad_fn().cloned(), target.grad_fn().cloned()],
137                diff_tensor,
138            });
139            Variable::from_operation(abs_tensor, grad_fn, true)
140        } else {
141            Variable::new(abs_tensor, false)
142        };
143
144        match self.reduction {
145            Reduction::None => loss_var,
146            Reduction::Mean => loss_var.mean(),
147            Reduction::Sum => loss_var.sum(),
148        }
149    }
150}
151
152impl Default for L1Loss {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
158// =============================================================================
159// L1LossBackward
160// =============================================================================
161
162/// Gradient function for L1Loss.
163///
164/// d/d(input) = sign(input - target)
165/// d/d(target) = -sign(input - target)
166///
167/// Stores diff as Tensor<f32> so it stays on GPU when applicable.
168#[derive(Debug)]
169struct L1LossBackward {
170    next_fns: Vec<Option<GradFn>>,
171    diff_tensor: Tensor<f32>,
172}
173
174impl GradientFunction for L1LossBackward {
175    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
176        // sign(diff): +1 where diff > 0, -1 where diff < 0, 0 where diff == 0
177        // Compute as: diff / (|diff| + eps) which gives sign and handles GPU
178        let eps_tensor = Tensor::full(self.diff_tensor.shape(), 1e-12);
179        let eps_on_device = if self.diff_tensor.device().is_gpu() {
180            eps_tensor.to_device(self.diff_tensor.device()).unwrap()
181        } else {
182            eps_tensor
183        };
184        // |diff| approximated as sqrt(diff^2 + eps)  — but simpler: diff * diff then sqrt
185        let diff_sq = self.diff_tensor.mul(&self.diff_tensor).unwrap();
186        let diff_sq_eps = diff_sq.add(&eps_on_device).unwrap();
187        // sqrt via exp(0.5 * ln(x))
188        let abs_diff = diff_sq_eps.ln().mul_scalar(0.5).exp();
189        let sign_diff = self.diff_tensor.div(&abs_diff).unwrap();
190
191        // grad_input = sign(diff) * grad_output
192        let gi = sign_diff.mul(grad_output).unwrap();
193        // grad_target = -grad_input
194        let gt = gi.neg();
195        vec![Some(gi), Some(gt)]
196    }
197
198    fn name(&self) -> &'static str {
199        "L1LossBackward"
200    }
201
202    fn next_functions(&self) -> &[Option<GradFn>] {
203        &self.next_fns
204    }
205
206    fn as_any(&self) -> &dyn Any {
207        self
208    }
209}
210
211// =============================================================================
212// CrossEntropyBackward
213// =============================================================================
214
215/// Gradient function for CrossEntropyLoss.
216///
217/// The gradient of CE w.r.t. logits is: softmax(logits) - one_hot(target).
218/// For per-sample losses, each sample's gradient is scaled by the upstream
219/// gradient (from reduction).
220#[derive(Debug)]
221struct CrossEntropyBackward {
222    next_fns: Vec<Option<GradFn>>,
223    /// Softmax probabilities computed during forward pass, shape (N, C).
224    /// Stays on GPU if input was on GPU.
225    softmax_probs: Tensor<f32>,
226    /// Target class indices as f32, shape (N,). Stays on GPU if input was on GPU.
227    targets: Tensor<f32>,
228    batch_size: usize,
229    num_classes: usize,
230}
231
232impl GradientFunction for CrossEntropyBackward {
233    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
234        // GPU fast path: use CUDA cross_entropy_bwd kernel
235        #[cfg(feature = "cuda")]
236        if self.softmax_probs.device().is_gpu() {
237            let grad_out_gpu = if grad_output.device().is_gpu() {
238                grad_output.clone()
239            } else {
240                grad_output.to_device(self.softmax_probs.device()).unwrap()
241            };
242            let grad_tensor = self
243                .softmax_probs
244                .cross_entropy_bwd_cuda(&self.targets, &grad_out_gpu);
245            return vec![Some(grad_tensor)];
246        }
247
248        // CPU path
249        let softmax_vec = self.softmax_probs.to_vec();
250        let target_vec = self.targets.to_vec();
251        let grad_vec = grad_output.to_vec();
252        let mut grad_input = vec![0.0f32; self.batch_size * self.num_classes];
253
254        for b in 0..self.batch_size {
255            let grad_scale = grad_vec[b];
256            let offset = b * self.num_classes;
257            let tc = target_vec[b] as usize;
258            for c in 0..self.num_classes {
259                let mut g = softmax_vec[offset + c];
260                if c == tc {
261                    g -= 1.0;
262                }
263                grad_input[offset + c] = g * grad_scale;
264            }
265        }
266
267        let mut grad_tensor =
268            Tensor::from_vec(grad_input, &[self.batch_size, self.num_classes]).unwrap();
269        if grad_output.device().is_gpu() {
270            grad_tensor = grad_tensor.to_device(grad_output.device()).unwrap();
271        }
272        vec![Some(grad_tensor)]
273    }
274
275    fn name(&self) -> &'static str {
276        "CrossEntropyBackward"
277    }
278
279    fn next_functions(&self) -> &[Option<GradFn>] {
280        &self.next_fns
281    }
282
283    fn as_any(&self) -> &dyn Any {
284        self
285    }
286}
287
288// =============================================================================
289// CrossEntropyLoss
290// =============================================================================
291
292/// Cross entropy loss with log softmax.
293///
294/// This combines LogSoftmax and NLLLoss in a single class.
295///
296/// # Shape
297/// - Input: (N, C) where C = number of classes
298/// - Target: (N,) with class indices
299#[derive(Debug, Clone, Copy)]
300pub struct CrossEntropyLoss {
301    reduction: Reduction,
302}
303
304impl CrossEntropyLoss {
305    /// Creates a new CrossEntropyLoss with default reduction (Mean).
306    pub fn new() -> Self {
307        Self {
308            reduction: Reduction::Mean,
309        }
310    }
311
312    /// Creates CrossEntropyLoss with specified reduction.
313    pub fn with_reduction(reduction: Reduction) -> Self {
314        Self { reduction }
315    }
316
317    /// Computes the loss.
318    ///
319    /// # Arguments
320    /// * `input` - Logits of shape (N, C)
321    /// * `target` - Class indices of shape (N,) as f32 (will be cast to usize)
322    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
323        let input_data = input.data();
324        let target_data = target.data();
325        let shape = input_data.shape().to_vec();
326        let batch_size = shape[0];
327        let num_classes = shape[1];
328
329        // GPU fast path: fused softmax + NLL loss kernel
330        #[cfg(feature = "cuda")]
331        if input_data.device().is_gpu() {
332            // Ensure targets are on GPU
333            let targets_gpu = if target_data.device().is_gpu() {
334                target_data.clone()
335            } else {
336                target_data.to_device(input_data.device()).unwrap()
337            };
338
339            let (loss_tensor, softmax_tensor) = input_data.cross_entropy_fwd_cuda(&targets_gpu);
340
341            let loss_var = if input.requires_grad() {
342                let grad_fn = GradFn::new(CrossEntropyBackward {
343                    next_fns: vec![input.grad_fn().cloned()],
344                    softmax_probs: softmax_tensor,
345                    targets: targets_gpu,
346                    batch_size,
347                    num_classes,
348                });
349                Variable::from_operation(loss_tensor, grad_fn, true)
350            } else {
351                Variable::new(loss_tensor, false)
352            };
353
354            return match self.reduction {
355                Reduction::None => loss_var,
356                Reduction::Mean => loss_var.mean(),
357                Reduction::Sum => loss_var.sum(),
358            };
359        }
360
361        // CPU path
362        let input_vec = input_data.to_vec();
363        let target_vec = target_data.to_vec();
364
365        let mut losses = vec![0.0f32; batch_size];
366        let mut softmax_probs_vec = vec![0.0f32; batch_size * num_classes];
367        let mut target_classes = vec![0usize; batch_size];
368
369        for b in 0..batch_size {
370            let offset = b * num_classes;
371
372            // Numerically stable log-softmax
373            let max_val = (0..num_classes)
374                .map(|c| input_vec[offset + c])
375                .fold(f32::NEG_INFINITY, f32::max);
376
377            let mut sum_exp = 0.0f32;
378            for c in 0..num_classes {
379                let exp_val = (input_vec[offset + c] - max_val).exp();
380                softmax_probs_vec[offset + c] = exp_val;
381                sum_exp += exp_val;
382            }
383
384            for c in 0..num_classes {
385                softmax_probs_vec[offset + c] /= sum_exp;
386            }
387
388            let log_sum_exp = max_val + sum_exp.ln();
389
390            let tc = target_vec[b] as usize;
391            target_classes[b] = tc;
392            losses[b] = log_sum_exp - input_vec[offset + tc];
393        }
394
395        let loss_tensor = Tensor::from_vec(losses, &[batch_size]).unwrap();
396        let softmax_tensor =
397            Tensor::from_vec(softmax_probs_vec, &[batch_size, num_classes]).unwrap();
398        let targets_f32: Vec<f32> = target_classes.iter().map(|&tc| tc as f32).collect();
399        let targets_tensor = Tensor::from_vec(targets_f32, &[batch_size]).unwrap();
400
401        let loss_var = if input.requires_grad() {
402            let grad_fn = GradFn::new(CrossEntropyBackward {
403                next_fns: vec![input.grad_fn().cloned()],
404                softmax_probs: softmax_tensor,
405                targets: targets_tensor,
406                batch_size,
407                num_classes,
408            });
409            Variable::from_operation(loss_tensor, grad_fn, true)
410        } else {
411            Variable::new(loss_tensor, false)
412        };
413
414        match self.reduction {
415            Reduction::None => loss_var,
416            Reduction::Mean => loss_var.mean(),
417            Reduction::Sum => loss_var.sum(),
418        }
419    }
420}
421
422impl Default for CrossEntropyLoss {
423    fn default() -> Self {
424        Self::new()
425    }
426}
427
428// =============================================================================
429// NLLLoss
430// =============================================================================
431
432/// Negative Log Likelihood loss.
433///
434/// Expects input to be log-probabilities.
435#[derive(Debug, Clone, Copy)]
436pub struct NLLLoss {
437    reduction: Reduction,
438}
439
440impl NLLLoss {
441    /// Creates a new NLLLoss with default reduction (Mean).
442    pub fn new() -> Self {
443        Self {
444            reduction: Reduction::Mean,
445        }
446    }
447
448    /// Creates NLLLoss with specified reduction.
449    pub fn with_reduction(reduction: Reduction) -> Self {
450        Self { reduction }
451    }
452
453    /// Computes the loss.
454    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
455        let input_data = input.data();
456        let target_data = target.data();
457        let shape = input_data.shape().to_vec();
458        let batch_size = shape[0];
459        let num_classes = shape[1];
460
461        // NLL forward still needs per-sample gather (index into class dimension).
462        // We pull target indices to CPU for the gather but keep input on device.
463        let target_vec = target_data.to_vec();
464        let input_vec = input_data.to_vec();
465
466        let mut losses = vec![0.0f32; batch_size];
467        for b in 0..batch_size {
468            let tc = target_vec[b] as usize;
469            losses[b] = -input_vec[b * num_classes + tc];
470        }
471
472        let mut loss_tensor = Tensor::from_vec(losses, &[batch_size]).unwrap();
473        if input_data.device().is_gpu() {
474            loss_tensor = loss_tensor.to_device(input_data.device()).unwrap();
475        }
476
477        let requires_grad = input.requires_grad() && is_grad_enabled();
478        let loss_var = if requires_grad {
479            let grad_fn = GradFn::new(NLLLossBackward {
480                next_fns: vec![input.grad_fn().cloned()],
481                target_tensor: target_data.clone(),
482                batch_size,
483                num_classes,
484            });
485            Variable::from_operation(loss_tensor, grad_fn, true)
486        } else {
487            Variable::new(loss_tensor, false)
488        };
489
490        match self.reduction {
491            Reduction::None => loss_var,
492            Reduction::Mean => loss_var.mean(),
493            Reduction::Sum => loss_var.sum(),
494        }
495    }
496}
497
498impl Default for NLLLoss {
499    fn default() -> Self {
500        Self::new()
501    }
502}
503
504// =============================================================================
505// NLLLossBackward
506// =============================================================================
507
508/// Gradient function for NLLLoss.
509///
510/// d/d(input)[b, c] = -1 if c == target[b], else 0
511///
512/// Stores targets as Tensor<f32> (GPU-resident when applicable).
513/// The scatter in backward still uses CPU indexing since it's a sparse write,
514/// but the result is moved to GPU if needed.
515#[derive(Debug)]
516struct NLLLossBackward {
517    next_fns: Vec<Option<GradFn>>,
518    target_tensor: Tensor<f32>,
519    batch_size: usize,
520    num_classes: usize,
521}
522
523impl GradientFunction for NLLLossBackward {
524    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
525        let grad_out_vec = grad_output.to_vec();
526        let target_vec = self.target_tensor.to_vec();
527        let mut grad_input = vec![0.0f32; self.batch_size * self.num_classes];
528
529        for b in 0..self.batch_size {
530            let g = if grad_out_vec.len() == 1 {
531                grad_out_vec[0]
532            } else {
533                grad_out_vec[b]
534            };
535            let tc = target_vec[b] as usize;
536            grad_input[b * self.num_classes + tc] = -g;
537        }
538
539        let mut gi = Tensor::from_vec(grad_input, &[self.batch_size, self.num_classes]).unwrap();
540        if grad_output.device().is_gpu() {
541            gi = gi.to_device(grad_output.device()).unwrap();
542        }
543        vec![Some(gi)]
544    }
545
546    fn name(&self) -> &'static str {
547        "NLLLossBackward"
548    }
549
550    fn next_functions(&self) -> &[Option<GradFn>] {
551        &self.next_fns
552    }
553
554    fn as_any(&self) -> &dyn Any {
555        self
556    }
557}
558
559// =============================================================================
560// BCELoss
561// =============================================================================
562
563/// Binary Cross Entropy loss.
564///
565/// Expects input to be probabilities in [0, 1].
566#[derive(Debug, Clone, Copy)]
567pub struct BCELoss {
568    reduction: Reduction,
569}
570
571impl BCELoss {
572    /// Creates a new BCELoss with default reduction (Mean).
573    pub fn new() -> Self {
574        Self {
575            reduction: Reduction::Mean,
576        }
577    }
578
579    /// Creates BCELoss with specified reduction.
580    pub fn with_reduction(reduction: Reduction) -> Self {
581        Self { reduction }
582    }
583
584    /// Computes the loss.
585    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
586        let input_data = input.data();
587        let target_data = target.data();
588
589        // Clamp predictions to [eps, 1-eps] using Tensor ops
590        let eps = 1e-7f32;
591        let p_clamped = axonml_tensor::ops::clamp(&input_data, eps, 1.0 - eps);
592
593        // loss = -(t * ln(p) + (1 - t) * ln(1 - p))
594        let ln_p = p_clamped.ln();
595        let one_minus_p = p_clamped.neg().add_scalar(1.0);
596        let ln_one_minus_p = one_minus_p.ln();
597        let one_minus_t = target_data.neg().add_scalar(1.0);
598
599        // t * ln(p)
600        let term1 = target_data.mul(&ln_p).unwrap();
601        // (1-t) * ln(1-p)
602        let term2 = one_minus_t.mul(&ln_one_minus_p).unwrap();
603        // -(term1 + term2)
604        let loss_tensor = term1.add(&term2).unwrap().neg();
605
606        let requires_grad = input.requires_grad() && is_grad_enabled();
607        let loss_var = if requires_grad {
608            let grad_fn = GradFn::new(BCELossBackward {
609                next_fns: vec![input.grad_fn().cloned()],
610                input_tensor: input_data,
611                target_tensor: target_data,
612            });
613            Variable::from_operation(loss_tensor, grad_fn, true)
614        } else {
615            Variable::new(loss_tensor, false)
616        };
617
618        match self.reduction {
619            Reduction::None => loss_var,
620            Reduction::Mean => loss_var.mean(),
621            Reduction::Sum => loss_var.sum(),
622        }
623    }
624}
625
626impl Default for BCELoss {
627    fn default() -> Self {
628        Self::new()
629    }
630}
631
632// =============================================================================
633// BCELossBackward
634// =============================================================================
635
636/// Gradient function for BCELoss.
637///
638/// d/dp BCE = (p - y) / (p * (1 - p))
639///
640/// Stores input/target as Tensor<f32> (GPU-resident when applicable).
641#[derive(Debug)]
642struct BCELossBackward {
643    next_fns: Vec<Option<GradFn>>,
644    input_tensor: Tensor<f32>,
645    target_tensor: Tensor<f32>,
646}
647
648impl GradientFunction for BCELossBackward {
649    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
650        let eps = 1e-7f32;
651        // p_clamped = clamp(input, eps, 1-eps)
652        let p_clamped = axonml_tensor::ops::clamp(&self.input_tensor, eps, 1.0 - eps);
653        // (p - y)
654        let p_minus_y = p_clamped.sub(&self.target_tensor).unwrap();
655        // p * (1 - p)
656        let one_minus_p = p_clamped.neg().add_scalar(1.0);
657        let denom = p_clamped.mul(&one_minus_p).unwrap();
658        // grad = grad_output * (p - y) / (p * (1 - p))
659        let ratio = p_minus_y.div(&denom).unwrap();
660        let grad_tensor = grad_output.mul(&ratio).unwrap();
661        vec![Some(grad_tensor)]
662    }
663
664    fn name(&self) -> &'static str {
665        "BCELossBackward"
666    }
667
668    fn next_functions(&self) -> &[Option<GradFn>] {
669        &self.next_fns
670    }
671
672    fn as_any(&self) -> &dyn Any {
673        self
674    }
675}
676
677// =============================================================================
678// BCEWithLogitsBackward
679// =============================================================================
680
681/// Gradient function for BCEWithLogitsLoss.
682///
683/// The gradient of BCE w.r.t. input logits is: sigmoid(input) - target.
684///
685/// Stores input/target as Tensor<f32> (GPU-resident when applicable).
686#[derive(Debug)]
687struct BCEWithLogitsBackward {
688    next_fns: Vec<Option<GradFn>>,
689    input_tensor: Tensor<f32>,
690    target_tensor: Tensor<f32>,
691}
692
693impl GradientFunction for BCEWithLogitsBackward {
694    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
695        // sigmoid(input) - target, all via Tensor ops (auto-dispatch to GPU)
696        let sig = self.input_tensor.sigmoid();
697        let sig_minus_t = sig.sub(&self.target_tensor).unwrap();
698        let grad_tensor = grad_output.mul(&sig_minus_t).unwrap();
699        vec![Some(grad_tensor)]
700    }
701
702    fn name(&self) -> &'static str {
703        "BCEWithLogitsBackward"
704    }
705
706    fn next_functions(&self) -> &[Option<GradFn>] {
707        &self.next_fns
708    }
709
710    fn as_any(&self) -> &dyn Any {
711        self
712    }
713}
714
715// =============================================================================
716// BCEWithLogitsLoss
717// =============================================================================
718
719/// Binary Cross Entropy with Logits.
720///
721/// Combines sigmoid and BCE in a numerically stable way.
722#[derive(Debug, Clone, Copy)]
723pub struct BCEWithLogitsLoss {
724    reduction: Reduction,
725}
726
727impl BCEWithLogitsLoss {
728    /// Creates a new BCEWithLogitsLoss with default reduction (Mean).
729    pub fn new() -> Self {
730        Self {
731            reduction: Reduction::Mean,
732        }
733    }
734
735    /// Creates BCEWithLogitsLoss with specified reduction.
736    pub fn with_reduction(reduction: Reduction) -> Self {
737        Self { reduction }
738    }
739
740    /// Computes the loss.
741    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
742        let input_data = input.data();
743        let target_data = target.data();
744
745        // Numerically stable: max(x, 0) - x*t + log(1 + exp(-|x|))
746        // max(x, 0) = relu(x) = clamp_min(x, 0)
747        let relu_x = axonml_tensor::ops::clamp_min(&input_data, 0.0);
748        // x * t
749        let x_times_t = input_data.mul(&target_data).unwrap();
750        // |x| via clamp trick: max(x, 0) + max(-x, 0) = relu(x) + relu(-x)
751        let neg_x = input_data.neg();
752        let relu_neg_x = axonml_tensor::ops::clamp_min(&neg_x, 0.0);
753        let abs_x = relu_x.add(&relu_neg_x).unwrap();
754        // exp(-|x|)
755        let exp_neg_abs = abs_x.neg().exp();
756        // log(1 + exp(-|x|))
757        let log_term = exp_neg_abs.add_scalar(1.0).ln();
758        // loss = relu(x) - x*t + log(1 + exp(-|x|))
759        let loss_tensor = relu_x.sub(&x_times_t).unwrap().add(&log_term).unwrap();
760
761        let loss_var = if input.requires_grad() {
762            let grad_fn = GradFn::new(BCEWithLogitsBackward {
763                next_fns: vec![input.grad_fn().cloned()],
764                input_tensor: input_data,
765                target_tensor: target_data,
766            });
767            Variable::from_operation(loss_tensor, grad_fn, true)
768        } else {
769            Variable::new(loss_tensor, false)
770        };
771
772        match self.reduction {
773            Reduction::None => loss_var,
774            Reduction::Mean => loss_var.mean(),
775            Reduction::Sum => loss_var.sum(),
776        }
777    }
778}
779
780impl Default for BCEWithLogitsLoss {
781    fn default() -> Self {
782        Self::new()
783    }
784}
785
786// =============================================================================
787// SmoothL1Backward
788// =============================================================================
789
790/// Gradient function for SmoothL1Loss.
791///
792/// The gradient is: diff/beta if |diff| < beta, else sign(diff).
793/// Returns gradients for both input and target (negated for target).
794///
795/// Stores diff as Tensor<f32> (GPU-resident when applicable).
796#[derive(Debug)]
797struct SmoothL1Backward {
798    next_fns: Vec<Option<GradFn>>,
799    diff_tensor: Tensor<f32>,
800    beta: f32,
801    shape: Vec<usize>,
802}
803
804impl GradientFunction for SmoothL1Backward {
805    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
806        // Compute |diff| = sqrt(diff^2 + eps) to stay differentiable and device-agnostic
807        let eps = 1e-12f32;
808        let diff_sq = self.diff_tensor.mul(&self.diff_tensor).unwrap();
809        let diff_sq_eps = diff_sq.add_scalar(eps);
810        let abs_diff = diff_sq_eps.ln().mul_scalar(0.5).exp();
811
812        // sign(diff) = diff / |diff|
813        let sign_diff = self.diff_tensor.div(&abs_diff).unwrap();
814
815        // For the L2 region (|diff| < beta): grad = diff / beta
816        // For the L1 region (|diff| >= beta): grad = sign(diff)
817        // Blend: mask = clamp(|diff| / beta, 0, 1), but we need a hard cutoff.
818        // Use a smooth approximation via where: if |diff| < beta -> diff/beta, else sign
819        // Since we need element-wise branching and don't have a GPU where_cond,
820        // we use: grad = (1 - mask) * (diff / beta) + mask * sign(diff)
821        // where mask = clamp((|diff| - beta) * large_value, 0, 1) approximates step function.
822        // Actually simpler: mask_l2 = clamp(1 - |diff|/beta, 0, 1) gives 1 in L2 region, 0 in L1
823        // BUT this gives a soft transition. For exact correctness, use CPU branching on the mask.
824        //
825        // Practical approach: compute both branches with Tensor ops, build mask on CPU, blend.
826        let grad_l2 = self.diff_tensor.mul_scalar(1.0 / self.beta); // diff / beta
827        let grad_l1 = sign_diff; // sign(diff)
828
829        // Build mask tensor: 1.0 where |d| < beta, 0.0 otherwise
830        let abs_vec = abs_diff.to_vec();
831        let beta = self.beta;
832        let mask_vec: Vec<f32> = abs_vec
833            .iter()
834            .map(|&a| if a < beta { 1.0 } else { 0.0 })
835            .collect();
836        let mut mask = Tensor::from_vec(mask_vec, &self.shape).unwrap();
837        if self.diff_tensor.device().is_gpu() {
838            mask = mask.to_device(self.diff_tensor.device()).unwrap();
839        }
840        let inv_mask = mask.neg().add_scalar(1.0);
841
842        // grad_per_elem = mask * grad_l2 + (1 - mask) * grad_l1
843        let blended = mask
844            .mul(&grad_l2)
845            .unwrap()
846            .add(&inv_mask.mul(&grad_l1).unwrap())
847            .unwrap();
848
849        // gi = blended * grad_output
850        let gi = blended.mul(grad_output).unwrap();
851        let gt = gi.neg();
852        vec![Some(gi), Some(gt)]
853    }
854
855    fn name(&self) -> &'static str {
856        "SmoothL1Backward"
857    }
858
859    fn next_functions(&self) -> &[Option<GradFn>] {
860        &self.next_fns
861    }
862
863    fn as_any(&self) -> &dyn Any {
864        self
865    }
866}
867
868// =============================================================================
869// SmoothL1Loss
870// =============================================================================
871
872/// Smooth L1 loss (Huber loss).
873///
874/// Uses L2 loss when |x| < beta, L1 loss otherwise.
875#[derive(Debug, Clone, Copy)]
876pub struct SmoothL1Loss {
877    reduction: Reduction,
878    beta: f32,
879}
880
881impl SmoothL1Loss {
882    /// Creates a new SmoothL1Loss with default beta (1.0).
883    pub fn new() -> Self {
884        Self {
885            reduction: Reduction::Mean,
886            beta: 1.0,
887        }
888    }
889
890    /// Creates SmoothL1Loss with specified beta.
891    pub fn with_beta(beta: f32) -> Self {
892        Self {
893            reduction: Reduction::Mean,
894            beta,
895        }
896    }
897
898    /// Computes the loss.
899    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
900        let input_data = input.data();
901        let target_data = target.data();
902        let diff_tensor = input_data.sub(&target_data).unwrap();
903        let shape = diff_tensor.shape().to_vec();
904
905        // Compute |diff| via relu(diff) + relu(-diff)
906        let relu_diff = axonml_tensor::ops::clamp_min(&diff_tensor, 0.0);
907        let relu_neg_diff = axonml_tensor::ops::clamp_min(&diff_tensor.neg(), 0.0);
908        let abs_diff = relu_diff.add(&relu_neg_diff).unwrap();
909
910        // L2 branch: 0.5 * diff^2 / beta
911        let diff_sq = diff_tensor.mul(&diff_tensor).unwrap();
912        let l2_loss = diff_sq.mul_scalar(0.5 / self.beta);
913
914        // L1 branch: |diff| - 0.5 * beta
915        let l1_loss = abs_diff.add_scalar(-0.5 * self.beta);
916
917        // Build mask: 1.0 where |diff| < beta, 0.0 otherwise
918        let abs_vec = abs_diff.to_vec();
919        let beta = self.beta;
920        let mask_vec: Vec<f32> = abs_vec
921            .iter()
922            .map(|&a| if a < beta { 1.0 } else { 0.0 })
923            .collect();
924        let mut mask = Tensor::from_vec(mask_vec, &shape).unwrap();
925        if diff_tensor.device().is_gpu() {
926            mask = mask.to_device(diff_tensor.device()).unwrap();
927        }
928        let inv_mask = mask.neg().add_scalar(1.0);
929
930        // loss = mask * l2_loss + (1-mask) * l1_loss
931        let loss_tensor = mask
932            .mul(&l2_loss)
933            .unwrap()
934            .add(&inv_mask.mul(&l1_loss).unwrap())
935            .unwrap();
936
937        let loss_var = if input.requires_grad() || target.requires_grad() {
938            let grad_fn = GradFn::new(SmoothL1Backward {
939                next_fns: vec![input.grad_fn().cloned(), target.grad_fn().cloned()],
940                diff_tensor,
941                beta: self.beta,
942                shape,
943            });
944            Variable::from_operation(loss_tensor, grad_fn, true)
945        } else {
946            Variable::new(loss_tensor, false)
947        };
948
949        match self.reduction {
950            Reduction::None => loss_var,
951            Reduction::Mean => loss_var.mean(),
952            Reduction::Sum => loss_var.sum(),
953        }
954    }
955}
956
957impl Default for SmoothL1Loss {
958    fn default() -> Self {
959        Self::new()
960    }
961}
962
963// =============================================================================
964// Tests
965// =============================================================================
966
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mse_loss() {
        let criterion = MSELoss::new();
        let pred = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let truth = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        // Identical prediction and target -> zero loss.
        let out = criterion.compute(&pred, &truth);
        assert!((out.data().to_vec()[0] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_mse_loss_nonzero() {
        let criterion = MSELoss::new();
        let pred = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let truth = Variable::new(Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).unwrap(), false);
        // Each diff is 1.0, squared is 1.0, mean is 1.0.
        let out = criterion.compute(&pred, &truth);
        assert!((out.data().to_vec()[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cross_entropy_loss() {
        let logits = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], &[2, 3]).unwrap(),
            false,
        );
        let labels = Variable::new(Tensor::from_vec(vec![2.0, 0.0], &[2]).unwrap(), false);
        let out = CrossEntropyLoss::new().compute(&logits, &labels);
        // Imperfect predictions -> strictly positive loss.
        assert!(out.data().to_vec()[0] > 0.0);
    }

    #[test]
    fn test_bce_loss() {
        let probs = Variable::new(Tensor::from_vec(vec![0.5, 0.5], &[2]).unwrap(), false);
        let labels = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let out = BCELoss::new().compute(&probs, &labels);
        // Both elements contribute -ln(0.5); mean = -ln(0.5) ≈ 0.693.
        assert!((out.data().to_vec()[0] - 0.693).abs() < 0.01);
    }

    #[test]
    fn test_cross_entropy_gradient_flow() {
        use axonml_autograd::backward;

        // Logits tracked by autograd.
        let logits = Variable::new(
            Tensor::from_vec(vec![2.0, 1.0, 0.1, 0.5, 2.5, 0.3], &[2, 3]).unwrap(),
            true,
        );
        let labels = Variable::new(Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap(), false);

        let out = CrossEntropyLoss::new().compute(&logits, &labels);

        let out_val = out.data().to_vec()[0];
        assert!(out_val > 0.0, "Loss should be positive, got {}", out_val);

        // Seed gradient of ones and run the backward pass.
        let seed = Tensor::from_vec(vec![1.0], &out.shape()).unwrap();
        backward(&out, &seed);

        let grad = logits
            .grad()
            .expect("Input should have gradient after backward");
        let g = grad.to_vec();

        // Gradient must be non-zero.
        let norm: f32 = g.iter().map(|v| v * v).sum();
        assert!(
            norm > 1e-10,
            "Gradient should be non-zero, got norm {}",
            norm
        );

        // Gradient shape matches the logits shape.
        assert_eq!(grad.shape(), &[2, 3]);

        // Target classes: grad = (softmax - 1) / batch < 0.
        assert!(g[0] < 0.0, "Gradient for correct class should be negative");
        assert!(g[4] < 0.0, "Gradient for correct class should be negative");
        // Non-target classes: grad = softmax / batch > 0.
        assert!(g[1] > 0.0, "Gradient for wrong class should be positive");
        assert!(g[2] > 0.0, "Gradient for wrong class should be positive");
    }

    #[test]
    fn test_cross_entropy_perfect_prediction() {
        // Logits that heavily favor the target class -> near-zero loss.
        let logits = Variable::new(
            Tensor::from_vec(vec![10.0, -10.0, -10.0], &[1, 3]).unwrap(),
            false,
        );
        let labels = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let out = CrossEntropyLoss::new().compute(&logits, &labels);
        assert!(
            out.data().to_vec()[0] < 0.001,
            "Perfect prediction should have near-zero loss"
        );
    }

    #[test]
    fn test_cross_entropy_uniform_prediction() {
        // All-equal logits give the maximum-entropy loss ln(C).
        let num_classes = 16;
        let logits = Variable::new(
            Tensor::from_vec(vec![0.0; num_classes], &[1, num_classes]).unwrap(),
            false,
        );
        let labels = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let out = CrossEntropyLoss::new().compute(&logits, &labels);
        let want = (num_classes as f32).ln(); // ln(16) ≈ 2.77
        let got = out.data().to_vec()[0];
        assert!(
            (got - want).abs() < 0.01,
            "Uniform logits should give ln(C)={}, got {}",
            want,
            got,
        );
    }

    #[test]
    fn test_bce_with_logits_gradient_flow() {
        use axonml_autograd::backward;

        let logits = Variable::new(
            Tensor::from_vec(vec![0.5, -0.5, 1.0, -1.0], &[4]).unwrap(),
            true,
        );
        let labels = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 1.0, 0.0], &[4]).unwrap(),
            false,
        );

        let out = BCEWithLogitsLoss::new().compute(&logits, &labels);
        assert!(out.data().to_vec()[0] > 0.0);

        let seed = Tensor::from_vec(vec![1.0], &out.shape()).unwrap();
        backward(&out, &seed);

        let g = logits.grad().expect("Input should have gradient").to_vec();
        assert_eq!(g.len(), 4);

        // target=1: grad = sigmoid(x) - 1 < 0.
        assert!(g[0] < 0.0);
        // target=0: grad = sigmoid(x) > 0.
        assert!(g[1] > 0.0);
    }

    #[test]
    fn test_smooth_l1_gradient_flow() {
        use axonml_autograd::backward;

        let pred = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 5.0], &[3]).unwrap(), true);
        let truth = Variable::new(Tensor::from_vec(vec![1.5, 1.5, 1.5], &[3]).unwrap(), false);

        let out = SmoothL1Loss::new().compute(&pred, &truth);
        assert!(out.data().to_vec()[0] > 0.0);

        let seed = Tensor::from_vec(vec![1.0], &out.shape()).unwrap();
        backward(&out, &seed);

        let g = pred.grad().expect("Input should have gradient").to_vec();
        assert_eq!(g.len(), 3);
        // Gradient must be non-zero.
        let norm: f32 = g.iter().map(|v| v * v).sum();
        assert!(norm > 1e-10);
    }
}