rust_lstm/loss.rs

use ndarray::{Array1, Array2};

/// Loss function trait for training neural networks.
pub trait LossFunction {
    /// Compute the loss between predictions and targets.
    fn compute_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64;

    /// Compute the gradient of the loss with respect to predictions.
    fn compute_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64>;
}

/// Mean Squared Error loss, averaged over all elements.
pub struct MSELoss;

impl LossFunction for MSELoss {
    fn compute_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64 {
        let diff = predictions - targets;
        let squared_diff = &diff * &diff;
        squared_diff.sum() / (predictions.len() as f64)
    }

    fn compute_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64> {
        // d/dp mean((p - t)^2) = 2 * (p - t) / N, where N is the total number of elements.
        let diff = predictions - targets;
        2.0 * diff / (predictions.len() as f64)
    }
}

/// Mean Absolute Error loss, averaged over all elements.
pub struct MAELoss;

impl LossFunction for MAELoss {
    fn compute_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64 {
        let diff = predictions - targets;
        diff.map(|x| x.abs()).sum() / (predictions.len() as f64)
    }

    fn compute_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64> {
        // Gradient is sign(p - t) / N; the subgradient at zero is taken to be 0.
        let diff = predictions - targets;
        diff.map(|x| if *x > 0.0 { 1.0 } else if *x < 0.0 { -1.0 } else { 0.0 })
            / (predictions.len() as f64)
    }
}

/// Cross-entropy loss with softmax applied to the raw predictions (logits).
///
/// Expects one sample per column, with targets forming a probability
/// distribution over each column (e.g. one-hot labels); the loss is averaged
/// over the batch (number of columns).
pub struct CrossEntropyLoss;

impl LossFunction for CrossEntropyLoss {
    fn compute_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64 {
        let softmax_preds = softmax(predictions);
        // Small epsilon avoids ln(0) for probabilities that underflow to zero.
        let epsilon = 1e-15;
        let log_preds = softmax_preds.map(|x| (x + epsilon).ln());
        -(targets * &log_preds).sum() / (predictions.shape()[1] as f64)
    }

    fn compute_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64> {
        // For softmax followed by cross-entropy, the gradient with respect to the
        // logits simplifies to (softmax(p) - t), averaged over the batch.
        let softmax_preds = softmax(predictions);
        (softmax_preds - targets) / (predictions.shape()[1] as f64)
    }
}

/// Numerically stable softmax, applied independently to each column.
pub fn softmax(x: &Array2<f64>) -> Array2<f64> {
    let mut result = Array2::zeros(x.raw_dim());

    for (i, col) in x.axis_iter(ndarray::Axis(1)).enumerate() {
        // Subtract the column maximum before exponentiating so that exp() cannot overflow.
        let max_val = col.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
        let exp_vals: Array1<f64> = col.map(|&val| (val - max_val).exp());
        let sum_exp = exp_vals.sum();

        for (j, &exp_val) in exp_vals.iter().enumerate() {
            result[[j, i]] = exp_val / sum_exp;
        }
    }

    result
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

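    // Example sketch: the `LossFunction` trait can be used behind `dyn` to swap
    // losses inside a training loop. A single gradient-descent step on MSE should
    // not increase the loss. (The learning rate here is an arbitrary illustrative
    // value, not something the library prescribes.)
    #[test]
    fn test_gradient_descent_step_reduces_mse() {
        let loss_fn: Box<dyn LossFunction> = Box::new(MSELoss);
        let targets = arr2(&[[1.0, 0.0], [0.0, 1.0]]);
        let mut predictions = arr2(&[[0.2, 0.4], [0.6, 0.8]]);

        let before = loss_fn.compute_loss(&predictions, &targets);
        let gradient = loss_fn.compute_gradient(&predictions, &targets);
        let learning_rate = 0.1;
        // Plain gradient descent: p <- p - lr * dL/dp.
        predictions = predictions - learning_rate * gradient;
        let after = loss_fn.compute_loss(&predictions, &targets);

        assert!(after <= before);
    }
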
    #[test]
    fn test_mse_loss() {
        let loss_fn = MSELoss;
        let predictions = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let targets = arr2(&[[1.5, 2.5], [2.5, 3.5]]);

        // Every element differs by 0.5, so the mean squared error is 0.25.
        let loss = loss_fn.compute_loss(&predictions, &targets);
        assert!((loss - 0.25).abs() < 1e-6);

        let gradient = loss_fn.compute_gradient(&predictions, &targets);
        assert_eq!(gradient.shape(), predictions.shape());
    }
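
    // Sketch of a numerical check on the analytic MSE gradient: 2 * (p - t) / N
    // should agree with a central finite-difference estimate of the loss.
    #[test]
    fn test_mse_gradient_matches_finite_differences() {
        let loss_fn = MSELoss;
        let predictions = arr2(&[[1.0, -2.0], [0.5, 3.0]]);
        let targets = arr2(&[[0.0, 1.0], [0.5, 2.0]]);
        let gradient = loss_fn.compute_gradient(&predictions, &targets);

        let eps = 1e-6;
        for i in 0..predictions.shape()[0] {
            for j in 0..predictions.shape()[1] {
                let mut plus = predictions.clone();
                let mut minus = predictions.clone();
                plus[[i, j]] += eps;
                minus[[i, j]] -= eps;
                let numeric = (loss_fn.compute_loss(&plus, &targets)
                    - loss_fn.compute_loss(&minus, &targets))
                    / (2.0 * eps);
                assert!((numeric - gradient[[i, j]]).abs() < 1e-6);
            }
        }
    }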

    #[test]
    fn test_mae_loss() {
        let loss_fn = MAELoss;
        let predictions = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let targets = arr2(&[[1.5, 2.5], [2.5, 3.5]]);

        // Every element differs by 0.5 in absolute value, so the MAE is 0.5.
        let loss = loss_fn.compute_loss(&predictions, &targets);
        assert!((loss - 0.5).abs() < 1e-6);

        let gradient = loss_fn.compute_gradient(&predictions, &targets);
        assert_eq!(gradient.shape(), predictions.shape());
    }
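
    // Sketch of a cross-entropy check, assuming the column-per-sample layout used
    // by `softmax`: with one-hot targets, each gradient column (softmax - target)
    // should sum to roughly zero, and the loss should be positive.
    #[test]
    fn test_cross_entropy_loss() {
        let loss_fn = CrossEntropyLoss;
        // Two samples (columns), three classes (rows); the logits are arbitrary.
        let predictions = arr2(&[[2.0, 0.5], [1.0, 2.5], [0.1, 0.3]]);
        let targets = arr2(&[[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]);

        let loss = loss_fn.compute_loss(&predictions, &targets);
        assert!(loss > 0.0);

        let gradient = loss_fn.compute_gradient(&predictions, &targets);
        assert_eq!(gradient.shape(), predictions.shape());
        for col in gradient.axis_iter(ndarray::Axis(1)) {
            assert!(col.sum().abs() < 1e-6);
        }
    }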

    #[test]
    fn test_softmax() {
        let input = arr2(&[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]);
        let output = softmax(&input);

        // Each column should sum to 1
        for col in output.axis_iter(ndarray::Axis(1)) {
            let sum: f64 = col.sum();
            assert!((sum - 1.0).abs() < 1e-6);
        }
    }
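
    // Sketch of a stability check: with logits around ±1000, a naive softmax would
    // overflow exp() to infinity, while the max-subtracting version should still
    // return a finite, normalized distribution.
    #[test]
    fn test_softmax_numerical_stability() {
        let input = arr2(&[[1000.0, -1000.0], [1001.0, -999.0]]);
        let output = softmax(&input);

        for col in output.axis_iter(ndarray::Axis(1)) {
            assert!(col.iter().all(|v| v.is_finite()));
            assert!((col.sum() - 1.0).abs() < 1e-6);
        }
    }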
}