Skip to main content

entrenar/train/loss/
causal_lm.rs

1//! Causal Language Modeling Loss
2
3use crate::Tensor;
4use ndarray::Array1;
5
6use super::LossFn;
7
/// Causal Language Modeling Loss
///
/// Computes cross-entropy loss for next-token prediction tasks.
/// This is the standard loss function for autoregressive language models.
///
/// Predictions are `seq_len * vocab_size` logits laid out position-major: position
/// `p` owns logits `p * vocab_size .. (p + 1) * vocab_size`. Targets hold one token
/// id per position, stored as `f32`.
///
/// The loss is computed as:
/// L = -sum(log(softmax(logits)[target_token])) / seq_len
///
/// # Example
///
/// ```
/// use entrenar::train::{CausalLMLoss, LossFn};
/// use entrenar::Tensor;
///
/// let loss_fn = CausalLMLoss::new(10); // vocab_size = 10
/// let logits = Tensor::from_vec(vec![0.1; 3 * 10], true); // seq_len=3, vocab=10
/// let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0], false); // target token IDs
///
/// let loss = loss_fn.forward(&logits, &targets);
/// assert!(loss.data()[0] > 0.0);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CausalLMLoss {
    /// Number of logits per position (the vocabulary size).
    vocab_size: usize,
}
33
34impl CausalLMLoss {
35    /// Create new causal LM loss with given vocabulary size
36    pub fn new(vocab_size: usize) -> Self {
37        Self { vocab_size }
38    }
39
40    /// Compute softmax for a single position
41    fn softmax(logits: &[f32]) -> Vec<f32> {
42        let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
43        let exp_vals: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
44        let sum: f32 = exp_vals.iter().sum();
45        exp_vals.iter().map(|&x| x / sum).collect()
46    }
47}
48
49impl LossFn for CausalLMLoss {
50    fn forward(&self, predictions: &Tensor, targets: &Tensor) -> Tensor {
51        let seq_len = targets.len();
52        let vocab_size = self.vocab_size;
53
54        assert_eq!(
55            predictions.len(),
56            seq_len * vocab_size,
57            "Predictions must be seq_len * vocab_size"
58        );
59
60        let pred_data = predictions.data();
61        let target_data = targets.data();
62
63        // Compute cross-entropy loss for each position
64        let mut total_loss = 0.0;
65        let mut grads = vec![0.0; predictions.len()];
66
67        for pos in 0..seq_len {
68            let start = pos * vocab_size;
69            let end = start + vocab_size;
70            let logits =
71                &pred_data.as_slice().expect("prediction data must be contiguous")[start..end];
72
73            // Softmax
74            let probs = Self::softmax(logits);
75
76            // Get target token ID
77            let target_idx = target_data[pos] as usize;
78            if target_idx < vocab_size {
79                // Cross-entropy: -log(prob of correct token)
80                let prob = probs[target_idx].max(1e-10);
81                total_loss -= prob.ln();
82
83                // Gradient: probs - one_hot(target)
84                for (i, &p) in probs.iter().enumerate() {
85                    grads[start + i] = if i == target_idx { p - 1.0 } else { p };
86                }
87            }
88        }
89
90        // Average loss over sequence
91        let avg_loss = total_loss / seq_len as f32;
92        let mut loss = Tensor::from_vec(vec![avg_loss], true);
93
94        // Scale gradients by 1/seq_len
95        let scale = 1.0 / seq_len as f32;
96        for g in &mut grads {
97            *g *= scale;
98        }
99
100        // Setup backward
101        use crate::autograd::BackwardOp;
102        use std::rc::Rc;
103
104        struct CausalLMBackward {
105            pred_grad_cell: Rc<std::cell::RefCell<Option<Array1<f32>>>>,
106            pred_backward_op: Option<Rc<dyn BackwardOp>>,
107            grad: Array1<f32>,
108        }
109
110        impl BackwardOp for CausalLMBackward {
111            fn backward(&self) {
112                // Set gradient on predictions
113                let mut pred_grad = self.pred_grad_cell.borrow_mut();
114                if let Some(existing) = pred_grad.as_mut() {
115                    *existing = &*existing + &self.grad;
116                } else {
117                    *pred_grad = Some(self.grad.clone());
118                }
119                drop(pred_grad); // Release borrow before recursive call
120
121                // Continue backward propagation through the computational graph
122                if let Some(ref op) = self.pred_backward_op {
123                    op.backward();
124                }
125            }
126        }
127
128        if predictions.requires_grad() {
129            loss.set_backward_op(Rc::new(CausalLMBackward {
130                pred_grad_cell: predictions.grad_cell(),
131                pred_backward_op: predictions.backward_op(),
132                grad: Array1::from(grads),
133            }));
134        }
135
136        loss
137    }
138
139    fn name(&self) -> &'static str {
140        "CausalLM"
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails unless `actual` lies within `tol` of `expected`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!(
            (actual - expected).abs() <= tol,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_causal_lm_loss_basic() {
        let loss_fn = CausalLMLoss::new(10); // vocab_size = 10
        // 3 positions, each with 10 logits
        let logits = Tensor::from_vec(vec![0.1; 30], true);
        // Targets: token 0, 1, 2
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0], false);

        let loss = loss_fn.forward(&logits, &targets);

        // Loss should be positive and finite
        assert!(loss.data()[0] > 0.0);
        assert!(loss.data()[0].is_finite());
        // Uniform logits give every token probability 1/10, so each position
        // contributes -ln(1/10) = ln(10), and so does the average.
        assert_close(loss.data()[0], 10.0_f32.ln(), 1e-5);
    }

    #[test]
    fn test_causal_lm_loss_perfect_prediction() {
        let loss_fn = CausalLMLoss::new(3); // vocab_size = 3
        // Perfect logits: high value at correct index
        let logits = Tensor::from_vec(
            vec![
                10.0, 0.0, 0.0, // position 0: target 0
                0.0, 10.0, 0.0, // position 1: target 1
            ],
            true,
        );
        let targets = Tensor::from_vec(vec![0.0, 1.0], false);

        let loss = loss_fn.forward(&logits, &targets);

        // Loss should be very small (near zero) for correct predictions
        assert!(loss.data()[0] >= 0.0);
        assert!(loss.data()[0] < 0.1);
    }

    #[test]
    fn test_causal_lm_loss_gradient() {
        let loss_fn = CausalLMLoss::new(4); // vocab_size = 4
        let logits = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], true);
        let targets = Tensor::from_vec(vec![2.0], false); // target = token 2

        let loss = loss_fn.forward(&logits, &targets);

        // Reference softmax of [1, 2, 3, 4].
        let exps: Vec<f32> = [1.0_f32, 2.0, 3.0, 4.0].iter().map(|&x| x.exp()).collect();
        let denom: f32 = exps.iter().sum();
        let probs: Vec<f32> = exps.iter().map(|&e| e / denom).collect();
        assert_close(loss.data()[0], -probs[2].ln(), 1e-5);

        if let Some(backward_op) = loss.backward_op() {
            backward_op.backward();
        }

        let grad = logits.grad().expect("gradient should be available");
        // Gradient should be finite
        for g in &grad {
            assert!(g.is_finite());
        }
        // Gradient at correct position should be negative (prob - 1)
        // since target is index 2
        assert!(grad[2] < 0.0);
        // Exact form: softmax - one_hot(2), scaled by 1/seq_len = 1.
        for (i, &p) in probs.iter().enumerate() {
            let expected = if i == 2 { p - 1.0 } else { p };
            assert_close(grad[i], expected, 1e-5);
        }
    }

    #[test]
    fn test_causal_lm_loss_out_of_vocab_target_is_ignored() {
        let loss_fn = CausalLMLoss::new(4);
        let logits = Tensor::from_vec(vec![0.0; 8], true);
        // Position 1's target (9) is outside the vocabulary.
        let targets = Tensor::from_vec(vec![0.0, 9.0], false);

        let loss = loss_fn.forward(&logits, &targets);
        // Only position 0 contributes -ln(1/4), but the mean is over both positions.
        assert_close(loss.data()[0], 4.0_f32.ln() / 2.0, 1e-5);

        if let Some(backward_op) = loss.backward_op() {
            backward_op.backward();
        }
        let grad = logits.grad().expect("gradient should be available");
        // Position 0: (0.25 - one_hot(0)) / 2
        assert_close(grad[0], -0.375, 1e-6);
        for i in 1..4 {
            assert_close(grad[i], 0.125, 1e-6);
        }
        // Position 1 is ignored and receives no gradient.
        for i in 4..8 {
            assert_close(grad[i], 0.0, 0.0);
        }
    }

    #[test]
    fn test_causal_lm_loss_name() {
        let loss_fn = CausalLMLoss::new(10);
        assert_eq!(loss_fn.name(), "CausalLM");
    }

    #[test]
    fn test_causal_lm_loss_longer_sequence() {
        let loss_fn = CausalLMLoss::new(100); // vocab_size = 100
        let seq_len = 10;
        let logits = Tensor::from_vec(vec![0.1; seq_len * 100], true);
        let targets: Vec<f32> = (0..seq_len).map(|i| (i % 100) as f32).collect();
        let targets = Tensor::from_vec(targets, false);

        let loss = loss_fn.forward(&logits, &targets);
        assert!(loss.data()[0] > 0.0);
        assert!(loss.data()[0].is_finite());
        // Uniform logits: every position contributes ln(100).
        assert_close(loss.data()[0], 100.0_f32.ln(), 1e-4);
    }

    #[test]
    #[should_panic(expected = "seq_len * vocab_size")]
    fn test_causal_lm_loss_mismatched_sizes() {
        let loss_fn = CausalLMLoss::new(10);
        let logits = Tensor::from_vec(vec![0.1; 20], true); // Only 2 positions
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0], false); // 3 targets
        loss_fn.forward(&logits, &targets);
    }

    #[test]
    fn test_causal_lm_loss_no_grad() {
        let loss_fn = CausalLMLoss::new(5);
        let logits = Tensor::from_vec(vec![0.1; 10], false); // no grad
        let targets = Tensor::from_vec(vec![0.0, 1.0], false);
        let loss = loss_fn.forward(&logits, &targets);
        assert!(loss.data()[0] > 0.0);
        assert_close(loss.data()[0], 5.0_f32.ln(), 1e-5);
        // Without requires_grad on the predictions no backward op is attached.
        assert!(loss.backward_op().is_none());
    }
}