//! Causal language-modeling loss: next-token cross-entropy over per-position logits.
//!
//! Path: `entrenar/train/loss/causal_lm.rs`.

use ndarray::Array1;

use super::LossFn;
use crate::Tensor;
/// Next-token cross-entropy loss for causal language modeling.
///
/// Expects one row of `vocab_size` logits per sequence position and one target token id per
/// position, and scores each target by the negative log of its softmax probability.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CausalLMLoss {
    /// Number of logits per sequence position (the vocabulary size).
    vocab_size: usize,
}

impl CausalLMLoss {
    /// Creates a causal-LM loss for a vocabulary of `vocab_size` tokens.
    pub fn new(vocab_size: usize) -> Self {
        Self { vocab_size }
    }

    /// Numerically stable softmax over a single row of logits.
    ///
    /// The row maximum is subtracted before exponentiating, so large logits do not overflow
    /// `f32`. An empty row yields an empty vector.
    fn softmax(logits: &[f32]) -> Vec<f32> {
        let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut probs: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
        let sum: f32 = probs.iter().sum();
        // Normalise in place rather than collecting into a second vector.
        for p in &mut probs {
            *p /= sum;
        }
        probs
    }
}
48
49impl LossFn for CausalLMLoss {
50 fn forward(&self, predictions: &Tensor, targets: &Tensor) -> Tensor {
51 let seq_len = targets.len();
52 let vocab_size = self.vocab_size;
53
54 assert_eq!(
55 predictions.len(),
56 seq_len * vocab_size,
57 "Predictions must be seq_len * vocab_size"
58 );
59
60 let pred_data = predictions.data();
61 let target_data = targets.data();
62
63 let mut total_loss = 0.0;
65 let mut grads = vec![0.0; predictions.len()];
66
67 for pos in 0..seq_len {
68 let start = pos * vocab_size;
69 let end = start + vocab_size;
70 let logits =
71 &pred_data.as_slice().expect("prediction data must be contiguous")[start..end];
72
73 let probs = Self::softmax(logits);
75
76 let target_idx = target_data[pos] as usize;
78 if target_idx < vocab_size {
79 let prob = probs[target_idx].max(1e-10);
81 total_loss -= prob.ln();
82
83 for (i, &p) in probs.iter().enumerate() {
85 grads[start + i] = if i == target_idx { p - 1.0 } else { p };
86 }
87 }
88 }
89
90 let avg_loss = total_loss / seq_len as f32;
92 let mut loss = Tensor::from_vec(vec![avg_loss], true);
93
94 let scale = 1.0 / seq_len as f32;
96 for g in &mut grads {
97 *g *= scale;
98 }
99
100 use crate::autograd::BackwardOp;
102 use std::rc::Rc;
103
104 struct CausalLMBackward {
105 pred_grad_cell: Rc<std::cell::RefCell<Option<Array1<f32>>>>,
106 pred_backward_op: Option<Rc<dyn BackwardOp>>,
107 grad: Array1<f32>,
108 }
109
110 impl BackwardOp for CausalLMBackward {
111 fn backward(&self) {
112 let mut pred_grad = self.pred_grad_cell.borrow_mut();
114 if let Some(existing) = pred_grad.as_mut() {
115 *existing = &*existing + &self.grad;
116 } else {
117 *pred_grad = Some(self.grad.clone());
118 }
119 drop(pred_grad); if let Some(ref op) = self.pred_backward_op {
123 op.backward();
124 }
125 }
126 }
127
128 if predictions.requires_grad() {
129 loss.set_backward_op(Rc::new(CausalLMBackward {
130 pred_grad_cell: predictions.grad_cell(),
131 pred_backward_op: predictions.backward_op(),
132 grad: Array1::from(grads),
133 }));
134 }
135
136 loss
137 }
138
139 fn name(&self) -> &'static str {
140 "CausalLM"
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_causal_lm_loss_basic() {
        let loss_fn = CausalLMLoss::new(10);
        let logits = Tensor::from_vec(vec![0.1; 30], true);
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0], false);

        let loss = loss_fn.forward(&logits, &targets);

        let value = loss.data()[0];
        assert!(value > 0.0);
        assert!(value.is_finite());
        // Uniform logits give every token probability 1/10, so the mean loss is ln(10).
        assert!((value - 10.0_f32.ln()).abs() < 1e-5);
    }

    #[test]
    fn test_causal_lm_loss_perfect_prediction() {
        let loss_fn = CausalLMLoss::new(3);
        // Each row puts a large logit on its own target token.
        let logits = Tensor::from_vec(
            vec![
                10.0, 0.0, 0.0, // position 0 -> token 0
                0.0, 10.0, 0.0, // position 1 -> token 1
            ],
            true,
        );
        let targets = Tensor::from_vec(vec![0.0, 1.0], false);

        let loss = loss_fn.forward(&logits, &targets);

        assert!(loss.data()[0] < 0.1);
    }

    #[test]
    fn test_causal_lm_loss_gradient() {
        let loss_fn = CausalLMLoss::new(4);
        let raw_logits = [1.0_f32, 2.0, 3.0, 4.0];
        let logits = Tensor::from_vec(raw_logits.to_vec(), true);
        let targets = Tensor::from_vec(vec![2.0], false);

        let loss = loss_fn.forward(&logits, &targets);
        if let Some(backward_op) = loss.backward_op() {
            backward_op.backward();
        }

        let grad = logits.grad().expect("gradient should be available");

        // With a single position, dL/dz_i = softmax(z)_i - [i == target].
        let max = 4.0_f32;
        let exps: Vec<f32> = raw_logits.iter().map(|&z| (z - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        for (i, &g) in grad.iter().enumerate() {
            assert!(g.is_finite());
            let one_hot = if i == 2 { 1.0 } else { 0.0 };
            let expected = exps[i] / sum - one_hot;
            assert!(
                (g - expected).abs() < 1e-6,
                "grad[{}] = {}, expected {}",
                i,
                g,
                expected
            );
        }
        assert!(grad[2] < 0.0);

        // Softmax-minus-one-hot gradients sum to zero.
        let total: f32 = grad.iter().sum();
        assert!(total.abs() < 1e-6);
    }

    #[test]
    fn test_causal_lm_loss_name() {
        let loss_fn = CausalLMLoss::new(10);
        assert_eq!(loss_fn.name(), "CausalLM");
    }

    #[test]
    fn test_causal_lm_loss_longer_sequence() {
        let loss_fn = CausalLMLoss::new(100);
        let seq_len = 10;
        let logits = Tensor::from_vec(vec![0.1; seq_len * 100], true);
        let targets: Vec<f32> = (0..seq_len).map(|i| (i % 100) as f32).collect();
        let targets = Tensor::from_vec(targets, false);

        let loss = loss_fn.forward(&logits, &targets);

        let value = loss.data()[0];
        assert!(value > 0.0);
        assert!(value.is_finite());
        // Uniform logits over 100 tokens: the mean loss is ln(100).
        assert!((value - 100.0_f32.ln()).abs() < 1e-4);
    }

    #[test]
    #[should_panic(expected = "seq_len * vocab_size")]
    fn test_causal_lm_loss_mismatched_sizes() {
        let loss_fn = CausalLMLoss::new(10);
        // 20 logits cannot cover 3 positions of 10 tokens each.
        let logits = Tensor::from_vec(vec![0.1; 20], true);
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0], false);
        loss_fn.forward(&logits, &targets);
    }

    #[test]
    fn test_causal_lm_loss_no_grad() {
        let loss_fn = CausalLMLoss::new(5);
        let logits = Tensor::from_vec(vec![0.1; 10], false);
        let targets = Tensor::from_vec(vec![0.0, 1.0], false);

        let loss = loss_fn.forward(&logits, &targets);

        let value = loss.data()[0];
        assert!(value > 0.0);
        // Uniform logits over 5 tokens: the mean loss is ln(5).
        assert!((value - 5.0_f32.ln()).abs() < 1e-5);
    }
}