// rust_lstm/loss.rs

use ndarray::{Array1, Array2};

/// Loss function trait for training neural networks
///
/// Predictions and targets are laid out with one column per sample.
pub trait LossFunction {
    /// Compute the loss between predictions and targets
    fn compute_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64;

    /// Compute the gradient of the loss with respect to predictions
    fn compute_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64>;

    /// Compute batch loss for multiple predictions and targets
    /// Default implementation computes the average loss across the batch
    fn compute_batch_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64 {
        let batch_size = predictions.ncols();
        let mut total_loss = 0.0;

        for i in 0..batch_size {
            // Treat each column as one sample, reshaped to an (n, 1) matrix
            let pred_col = predictions.column(i).to_owned().insert_axis(ndarray::Axis(1));
            let target_col = targets.column(i).to_owned().insert_axis(ndarray::Axis(1));
            total_loss += self.compute_loss(&pred_col, &target_col);
        }

        total_loss / batch_size as f64
    }

    /// Compute batch gradients for multiple predictions and targets
    /// Default implementation computes per-sample gradients and scales them by
    /// 1/batch_size, so they are the gradient of `compute_batch_loss`'s average
    fn compute_batch_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64> {
        let batch_size = predictions.ncols();
        let mut batch_gradients = Array2::zeros(predictions.raw_dim());

        for i in 0..batch_size {
            let pred_col = predictions.column(i).to_owned().insert_axis(ndarray::Axis(1));
            let target_col = targets.column(i).to_owned().insert_axis(ndarray::Axis(1));
            let grad = self.compute_gradient(&pred_col, &target_col);
            batch_gradients.column_mut(i).assign(&grad.column(0));
        }

        // Divide by batch_size to match the averaging in compute_batch_loss;
        // without this, the result is the gradient of the *summed* loss instead
        batch_gradients / batch_size as f64
    }
}
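
// Example (sketch, not part of the original file): how a training step might use
// this trait. `model` and its methods are hypothetical placeholders.
//
//     let loss_fn = MSELoss;
//     let predictions = model.forward(&inputs);                        // hypothetical
//     let loss = loss_fn.compute_batch_loss(&predictions, &targets);
//     let grads = loss_fn.compute_batch_gradient(&predictions, &targets);
//     model.backward(&grads);                                          // hypothetical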

/// Mean Squared Error loss function
///
/// loss = (1/n) * sum((prediction_i - target_i)^2) over all n elements
pub struct MSELoss;

impl LossFunction for MSELoss {
    fn compute_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64 {
        let diff = predictions - targets;
        let squared_diff = &diff * &diff;
        squared_diff.sum() / (predictions.len() as f64)
    }

    fn compute_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64> {
        // d/dp of (1/n) * sum((p - t)^2) is 2 * (p - t) / n
        let diff = predictions - targets;
        2.0 * diff / (predictions.len() as f64)
    }

    // Vectorized overrides: the element-wise mean over the whole batch equals
    // the average of the per-column losses, so these match the trait defaults
    fn compute_batch_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64 {
        let diff = predictions - targets;
        let squared_diff = &diff * &diff;
        squared_diff.sum() / (predictions.len() as f64)
    }

    fn compute_batch_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64> {
        let diff = predictions - targets;
        2.0 * diff / (predictions.len() as f64)
    }
}
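
// Worked example (mirrors the unit test below): with predictions [[1, 2], [3, 4]]
// and targets [[1.5, 2.5], [2.5, 3.5]], every element differs by 0.5, so the
// loss is (4 * 0.25) / 4 = 0.25 and each gradient entry is 2 * (+/-0.5) / 4 = +/-0.25.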

/// Mean Absolute Error loss function
///
/// loss = (1/n) * sum(|prediction_i - target_i|) over all n elements
pub struct MAELoss;

impl LossFunction for MAELoss {
    fn compute_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64 {
        let diff = predictions - targets;
        diff.map(|x| x.abs()).sum() / (predictions.len() as f64)
    }

    fn compute_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64> {
        // Subgradient of |x|: sign(x), choosing 0 at x == 0
        let diff = predictions - targets;
        diff.map(|x| if *x > 0.0 { 1.0 } else if *x < 0.0 { -1.0 } else { 0.0 }) / (predictions.len() as f64)
    }

    fn compute_batch_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64 {
        let diff = predictions - targets;
        diff.map(|x| x.abs()).sum() / (predictions.len() as f64)
    }

    fn compute_batch_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64> {
        let diff = predictions - targets;
        diff.map(|x| if *x > 0.0 { 1.0 } else if *x < 0.0 { -1.0 } else { 0.0 }) / (predictions.len() as f64)
    }
}

/// Cross-Entropy loss with softmax applied to the predictions
///
/// Expects `targets` to hold one-hot class columns (or probability distributions);
/// the loss is averaged over the batch (the number of columns)
pub struct CrossEntropyLoss;

impl LossFunction for CrossEntropyLoss {
    fn compute_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64 {
        let softmax_preds = softmax(predictions);
        // Clamp away from zero so ln() stays finite
        let epsilon = 1e-15;
        let log_preds = softmax_preds.map(|x| (x + epsilon).ln());
        -(targets * &log_preds).sum() / (predictions.ncols() as f64)
    }

    fn compute_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64> {
        let softmax_preds = softmax(predictions);
        (softmax_preds - targets) / (predictions.ncols() as f64)
    }
}
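
// Note (standard derivation, not in the original file): for s = softmax(z) and
// L = -sum(t * ln(s)), the chain rule collapses to dL/dz = s - t whenever t sums
// to 1, which is why compute_gradient returns (softmax_preds - targets) directly
// instead of differentiating through ln and softmax separately; the 1/ncols
// factor matches the averaging in compute_loss.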

/// Numerically stable softmax, applied independently to each column
pub fn softmax(x: &Array2<f64>) -> Array2<f64> {
    let mut result = Array2::zeros(x.raw_dim());

    for (i, col) in x.axis_iter(ndarray::Axis(1)).enumerate() {
        // Subtract the column max before exponentiating to avoid overflow
        let max_val = col.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exp_vals: Array1<f64> = col.map(|&val| (val - max_val).exp());
        let sum_exp = exp_vals.sum();

        for (j, &exp_val) in exp_vals.iter().enumerate() {
            result[[j, i]] = exp_val / sum_exp;
        }
    }

    result
}
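
// Worked example (illustration, not in the original file): for a single column
// [1, 2, 3], subtracting the max (3) gives exp([-2, -1, 0]) ~ [0.135, 0.368, 1.0],
// which normalizes to roughly [0.090, 0.245, 0.665]; large logits would overflow
// exp() without the shift, while the shifted values never exceed exp(0) = 1.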

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_mse_loss() {
        let loss_fn = MSELoss;
        let predictions = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let targets = arr2(&[[1.5, 2.5], [2.5, 3.5]]);

        let loss = loss_fn.compute_loss(&predictions, &targets);
        assert!((loss - 0.25).abs() < 1e-6);

        let gradient = loss_fn.compute_gradient(&predictions, &targets);
        assert_eq!(gradient.shape(), predictions.shape());
    }

    #[test]
    fn test_mae_loss() {
        let loss_fn = MAELoss;
        let predictions = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let targets = arr2(&[[1.5, 2.5], [2.5, 3.5]]);

        let loss = loss_fn.compute_loss(&predictions, &targets);
        assert!((loss - 0.5).abs() < 1e-6);

        let gradient = loss_fn.compute_gradient(&predictions, &targets);
        assert_eq!(gradient.shape(), predictions.shape());
    }
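
    // Sketch of a test the original file lacks: with one-hot targets the loss is
    // the mean negative log of the softmax probability of the correct class, and
    // each gradient column sums to ~0 (softmax and one-hot targets both sum to 1)
    #[test]
    fn test_cross_entropy_loss() {
        let loss_fn = CrossEntropyLoss;
        // Two classes, two samples; columns are samples
        let predictions = arr2(&[[2.0, 0.0], [0.0, 2.0]]);
        let targets = arr2(&[[1.0, 0.0], [0.0, 1.0]]);

        // softmax([2, 0]) ~ [0.8808, 0.1192], so each sample's loss is -ln(0.8808)
        let loss = loss_fn.compute_loss(&predictions, &targets);
        let expected = -(0.880797_f64).ln();
        assert!((loss - expected).abs() < 1e-3);

        let gradient = loss_fn.compute_gradient(&predictions, &targets);
        for col in gradient.axis_iter(ndarray::Axis(1)) {
            assert!(col.sum().abs() < 1e-9);
        }
    }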

    #[test]
    fn test_softmax() {
        let input = arr2(&[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]);
        let output = softmax(&input);

        // Each column should sum to 1
        for col in output.axis_iter(ndarray::Axis(1)) {
            let sum: f64 = col.sum();
            assert!((sum - 1.0).abs() < 1e-6);
        }
    }
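
    // Sketch (not in the original file): a loss that does not override the batch
    // helpers should get defaults that agree with its whole-batch formulas
    #[test]
    fn test_default_batch_methods_match_whole_batch() {
        let loss_fn = CrossEntropyLoss;
        let predictions = arr2(&[[2.0, 0.0], [0.0, 2.0]]);
        let targets = arr2(&[[1.0, 0.0], [0.0, 1.0]]);

        let batch_loss = loss_fn.compute_batch_loss(&predictions, &targets);
        let whole_loss = loss_fn.compute_loss(&predictions, &targets);
        assert!((batch_loss - whole_loss).abs() < 1e-9);

        let batch_grad = loss_fn.compute_batch_gradient(&predictions, &targets);
        let whole_grad = loss_fn.compute_gradient(&predictions, &targets);
        for (a, b) in batch_grad.iter().zip(whole_grad.iter()) {
            assert!((a - b).abs() < 1e-9);
        }
    }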
}