use super::LossFn;
use crate::autograd::BackwardOp;
use crate::Tensor;
use ndarray::Array1;
use std::cell::RefCell;
use std::rc::Rc;
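
/// Cross-entropy loss for causal language modeling over flattened logits.
///
/// `predictions` holds `seq_len * vocab_size` logits in row-major order and
/// `targets` holds `seq_len` token indices stored as `f32` values.
///
/// A minimal usage sketch, mirroring the tests at the bottom of this file
/// (`Tensor::from_vec` takes the raw data and a `requires_grad` flag):
///
/// ```ignore
/// let loss_fn = CausalLMLoss::new(10);
/// let logits = Tensor::from_vec(vec![0.1; 30], true); // 3 positions * 10 tokens
/// let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0], false);
/// let loss = loss_fn.forward(&logits, &targets);
/// ```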
pub struct CausalLMLoss {
    vocab_size: usize,
}

impl CausalLMLoss {
    pub fn new(vocab_size: usize) -> Self {
        Self { vocab_size }
    }

    /// Numerically stable softmax: shift by the max logit before
    /// exponentiating so that `exp` never overflows.
    fn softmax(logits: &[f32]) -> Vec<f32> {
        let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_vals: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
        let sum: f32 = exp_vals.iter().sum();
        exp_vals.iter().map(|&x| x / sum).collect()
    }
}
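
// For softmax cross-entropy the chain rule collapses neatly: with
// p = softmax(logits), d(loss)/d(logit_i) = p_i - [i == target]. The
// `forward` below exploits this identity, so no softmax Jacobian is
// ever materialized.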
impl LossFn for CausalLMLoss {
    fn forward(&self, predictions: &Tensor, targets: &Tensor) -> Tensor {
        let seq_len = targets.len();
        let vocab_size = self.vocab_size;
        assert_eq!(
            predictions.len(),
            seq_len * vocab_size,
            "Predictions must be seq_len * vocab_size"
        );
        let pred_data = predictions.data();
        let target_data = targets.data();
        let mut total_loss = 0.0;
        let mut grads = vec![0.0; predictions.len()];

        for pos in 0..seq_len {
            let start = pos * vocab_size;
            let end = start + vocab_size;
            let logits =
                &pred_data.as_slice().expect("prediction data must be contiguous")[start..end];
            let probs = Self::softmax(logits);
            let target_idx = target_data[pos] as usize;
            // Out-of-range targets are skipped (an implicit ignore index),
            // though they still count toward the sequence-length average below.
            if target_idx < vocab_size {
                // Clamp to avoid ln(0) when the model assigns ~zero probability.
                let prob = probs[target_idx].max(1e-10);
                total_loss -= prob.ln();
                // Softmax cross-entropy gradient: p - 1 at the target, p elsewhere.
                for (i, &p) in probs.iter().enumerate() {
                    grads[start + i] = if i == target_idx { p - 1.0 } else { p };
                }
            }
        }
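        // Note: an empty target sequence would make this 0/0 = NaN; callers
        // are expected to supply at least one target position.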
        let avg_loss = total_loss / seq_len as f32;
        // The loss only participates in autograd when its input does.
        let mut loss = Tensor::from_vec(vec![avg_loss], predictions.requires_grad());
        // The loss is a mean, so each position's gradient is scaled by 1/seq_len.
        let scale = 1.0 / seq_len as f32;
        for g in &mut grads {
            *g *= scale;
        }

        // Backward op closing over the precomputed gradient; defined locally
        // because it is only ever constructed here.
        struct CausalLMBackward {
            pred_grad_cell: Rc<RefCell<Option<Array1<f32>>>>,
            pred_backward_op: Option<Rc<dyn BackwardOp>>,
            grad: Array1<f32>,
        }

        impl BackwardOp for CausalLMBackward {
            fn backward(&self) {
                // Accumulate into any existing gradient rather than overwriting,
                // so multiple backward paths through `predictions` compose.
                let mut pred_grad = self.pred_grad_cell.borrow_mut();
                if let Some(existing) = pred_grad.as_mut() {
                    *existing = &*existing + &self.grad;
                } else {
                    *pred_grad = Some(self.grad.clone());
                }
                // Release the borrow before recursing into upstream ops.
                drop(pred_grad);
                if let Some(ref op) = self.pred_backward_op {
                    op.backward();
                }
            }
        }

        if predictions.requires_grad() {
            loss.set_backward_op(Rc::new(CausalLMBackward {
                pred_grad_cell: predictions.grad_cell(),
                pred_backward_op: predictions.backward_op(),
                grad: Array1::from(grads),
            }));
        }
        loss
    }

    fn name(&self) -> &'static str {
        "CausalLM"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_causal_lm_loss_basic() {
        let loss_fn = CausalLMLoss::new(10);
        let logits = Tensor::from_vec(vec![0.1; 30], true);
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0], false);
        let loss = loss_fn.forward(&logits, &targets);
        assert!(loss.data()[0] > 0.0);
        assert!(loss.data()[0].is_finite());
    }

    #[test]
    fn test_causal_lm_loss_perfect_prediction() {
        let loss_fn = CausalLMLoss::new(3);
        let logits = Tensor::from_vec(
            vec![
                10.0, 0.0, 0.0, // position 0 strongly predicts token 0
                0.0, 10.0, 0.0, // position 1 strongly predicts token 1
            ],
            true,
        );
        let targets = Tensor::from_vec(vec![0.0, 1.0], false);
        let loss = loss_fn.forward(&logits, &targets);
        assert!(loss.data()[0] < 0.1);
    }

    #[test]
    fn test_causal_lm_loss_gradient() {
        let loss_fn = CausalLMLoss::new(4);
        let logits = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], true);
        let targets = Tensor::from_vec(vec![2.0], false);
        let loss = loss_fn.forward(&logits, &targets);
        if let Some(backward_op) = loss.backward_op() {
            backward_op.backward();
        }
        let grad = logits.grad().expect("gradient should be available");
        for g in &grad {
            assert!(g.is_finite());
        }
        // The target logit's gradient is p - 1 < 0; all others are p > 0.
        assert!(grad[2] < 0.0);
    }
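
    // A finite-difference sanity check of the analytic gradient. This is a
    // sketch that relies only on the `Tensor` API already exercised above
    // (`from_vec`, `forward`, `backward_op`, `grad`, `data`).
    #[test]
    fn test_causal_lm_loss_gradient_finite_difference() {
        let loss_fn = CausalLMLoss::new(4);
        let base = vec![1.0_f32, 2.0, 3.0, 4.0];
        let targets = Tensor::from_vec(vec![2.0], false);

        // Analytic gradient via the backward op.
        let logits = Tensor::from_vec(base.clone(), true);
        let loss = loss_fn.forward(&logits, &targets);
        if let Some(op) = loss.backward_op() {
            op.backward();
        }
        let grad = logits.grad().expect("gradient should be available");

        // Central differences on each logit should match the analytic values.
        let eps = 1e-3_f32;
        for i in 0..base.len() {
            let mut plus = base.clone();
            plus[i] += eps;
            let mut minus = base.clone();
            minus[i] -= eps;
            let loss_plus = loss_fn.forward(&Tensor::from_vec(plus, false), &targets).data()[0];
            let loss_minus = loss_fn.forward(&Tensor::from_vec(minus, false), &targets).data()[0];
            let numeric = (loss_plus - loss_minus) / (2.0 * eps);
            assert!((grad[i] - numeric).abs() < 1e-2);
        }
    }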

    #[test]
    fn test_causal_lm_loss_name() {
        let loss_fn = CausalLMLoss::new(10);
        assert_eq!(loss_fn.name(), "CausalLM");
    }

    #[test]
    fn test_causal_lm_loss_longer_sequence() {
        let loss_fn = CausalLMLoss::new(100);
        let seq_len = 10;
        let logits = Tensor::from_vec(vec![0.1; seq_len * 100], true);
        let targets: Vec<f32> = (0..seq_len).map(|i| (i % 100) as f32).collect();
        let targets = Tensor::from_vec(targets, false);
        let loss = loss_fn.forward(&logits, &targets);
        assert!(loss.data()[0] > 0.0);
        assert!(loss.data()[0].is_finite());
    }

    #[test]
    #[should_panic(expected = "seq_len * vocab_size")]
    fn test_causal_lm_loss_mismatched_sizes() {
        let loss_fn = CausalLMLoss::new(10);
        // 20 logits cannot cover 3 positions of vocab 10 (30 expected).
        let logits = Tensor::from_vec(vec![0.1; 20], true);
        let targets = Tensor::from_vec(vec![0.0, 1.0, 2.0], false);
        loss_fn.forward(&logits, &targets);
    }

    #[test]
    fn test_causal_lm_loss_no_grad() {
        let loss_fn = CausalLMLoss::new(5);
        let logits = Tensor::from_vec(vec![0.1; 10], false);
        let targets = Tensor::from_vec(vec![0.0, 1.0], false);
        let loss = loss_fn.forward(&logits, &targets);
        assert!(loss.data()[0] > 0.0);
    }
}