/// Numerically stable softmax over a slice of raw scores.
///
/// The maximum score is subtracted before exponentiation so that large
/// inputs do not overflow `exp`; this leaves the result unchanged because
/// softmax is invariant to adding a constant to every score.
#[inline]
pub fn softmax(raw_scores: &[f32]) -> Vec<f32> {
    if raw_scores.is_empty() {
        return vec![];
    }
    let max_score = raw_scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exp_scores: Vec<f32> = raw_scores.iter().map(|&x| (x - max_score).exp()).collect();
    let sum_exp: f32 = exp_scores.iter().sum();
    exp_scores.iter().map(|&e| e / sum_exp).collect()
}

/// Multi-class logistic (softmax cross-entropy) loss for gradient boosting.
///
/// Produces per-class gradients and diagonal Hessian entries, with the
/// Hessian floored at `eps` to keep leaf-weight updates well conditioned.
#[derive(Debug, Clone)]
pub struct MultiClassLogLoss {
    pub num_classes: usize,
    eps: f32,
}

impl MultiClassLogLoss {
    /// Creates a loss for `num_classes` classes with the default Hessian floor.
    pub fn new(num_classes: usize) -> Self {
        Self {
            num_classes,
            eps: 1e-7,
        }
    }

    /// Gradient and Hessian of the loss with respect to the raw score of
    /// `class_idx`, for a sample whose true label is `target_class`.
    ///
    /// For softmax cross-entropy the gradient is `p - y` and the diagonal
    /// Hessian is `p * (1 - p)`, where `p` is the predicted probability of
    /// `class_idx` and `y` is the one-hot indicator of the true class.
    #[inline]
    pub fn gradient_hessian_for_class(
        &self,
        target_class: usize,
        class_idx: usize,
        raw_predictions: &[f32],
    ) -> (f32, f32) {
        let probs = softmax(raw_predictions);
        let p = probs[class_idx];
        let y = if target_class == class_idx { 1.0 } else { 0.0 };
        let gradient = p - y;
        let hessian = (p * (1.0 - p)).max(self.eps);
        (gradient, hessian)
    }

    /// Gradients and Hessians for all classes of a single sample, computing
    /// the softmax once instead of once per class.
    #[inline]
    pub fn gradient_hessian_all_classes(
        &self,
        target_class: usize,
        raw_predictions: &[f32],
    ) -> (Vec<f32>, Vec<f32>) {
        let probs = softmax(raw_predictions);
        let mut gradients = Vec::with_capacity(self.num_classes);
        let mut hessians = Vec::with_capacity(self.num_classes);
        for (k, &p) in probs.iter().enumerate() {
            let y = if target_class == k { 1.0 } else { 0.0 };
            gradients.push(p - y);
            hessians.push((p * (1.0 - p)).max(self.eps));
        }
        (gradients, hessians)
    }

    /// Fills `gradients` and `hessians` for one class over a batch of samples.
    ///
    /// `predictions` is expected in row-major layout: `num_classes` raw scores
    /// per sample, so sample `idx` occupies `predictions[idx * num_classes..]`.
    /// `targets` holds class labels stored as `f32`; the output slices are
    /// indexed by sample index.
    pub fn compute_gradients_batch(
        &self,
        class_idx: usize,
        targets: &[f32],
        predictions: &[f32],
        sample_indices: &[usize],
        gradients: &mut [f32],
        hessians: &mut [f32],
    ) {
        let num_classes = self.num_classes;
        let eps = self.eps;
        for &idx in sample_indices {
            let target_class = targets[idx] as usize;
            let row_start = idx * num_classes;
            let row_preds = &predictions[row_start..row_start + num_classes];
            let probs = softmax(row_preds);
            let p = probs[class_idx];
            let y = if target_class == class_idx { 1.0 } else { 0.0 };
            gradients[idx] = p - y;
            hessians[idx] = (p * (1.0 - p)).max(eps);
        }
    }

    /// Initial raw score per class: the log of the empirical class frequency,
    /// clamped away from 0 and 1 so the logarithm stays finite. Labels outside
    /// `0..num_classes` are ignored.
    pub fn initial_predictions(&self, targets: &[f32]) -> Vec<f32> {
        let mut class_counts = vec![0usize; self.num_classes];
        for &t in targets {
            let class_idx = t as usize;
            if class_idx < self.num_classes {
                class_counts[class_idx] += 1;
            }
        }
        let n = targets.len() as f32;
        class_counts
            .iter()
            .map(|&count| {
                let p = (count as f32 / n).clamp(self.eps, 1.0 - self.eps);
                p.ln()
            })
            .collect()
    }

    /// Number of classes this loss was configured with.
    pub fn num_classes(&self) -> usize {
        self.num_classes
    }
}
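
// A minimal usage sketch (not part of the public API): a boosting round would
// typically call `compute_gradients_batch` once per class. The 3-class setup,
// toy buffers, and the row-major `predictions` layout are illustrative
// assumptions consistent with the indexing inside the method above.
#[allow(dead_code)]
fn example_per_class_gradient_pass() {
    let loss = MultiClassLogLoss::new(3);
    let targets = vec![0.0, 1.0, 2.0]; // one sample per class
    let predictions = vec![0.0; 3 * 3]; // uniform raw scores, row-major
    let sample_indices: Vec<usize> = (0..3).collect();
    let mut gradients = vec![0.0; 3];
    let mut hessians = vec![0.0; 3];
    for class_idx in 0..loss.num_classes() {
        loss.compute_gradients_batch(
            class_idx,
            &targets,
            &predictions,
            &sample_indices,
            &mut gradients,
            &mut hessians,
        );
        // A tree learner would consume `gradients` / `hessians` here.
    }
}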

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_softmax_basic() {
        let scores = vec![1.0, 2.0, 3.0];
        let probs = softmax(&scores);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
        assert!(probs[2] > probs[1]);
        assert!(probs[1] > probs[0]);
    }
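
    // Added illustrative test: exercises the empty-input early return, which
    // the original tests do not cover.
    #[test]
    fn test_softmax_empty() {
        let probs = softmax(&[]);
        assert!(probs.is_empty());
    }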

    #[test]
    fn test_softmax_numerical_stability() {
        let scores = vec![1000.0, 1001.0, 1002.0];
        let probs = softmax(&scores);
        for p in &probs {
            assert!(p.is_finite());
            assert!(*p >= 0.0 && *p <= 1.0);
        }
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }
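
    // Added illustrative test: equal scores should give a uniform
    // distribution, since softmax depends only on score differences.
    #[test]
    fn test_softmax_uniform_for_equal_scores() {
        let probs = softmax(&[5.0, 5.0, 5.0, 5.0]);
        for &p in &probs {
            assert!((p - 0.25).abs() < 1e-6);
        }
    }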

    #[test]
    fn test_gradient_for_true_class() {
        let loss = MultiClassLogLoss::new(3);
        let raw = vec![1.0, 2.0, 0.5];
        let (grad, hess) = loss.gradient_hessian_for_class(1, 1, &raw);
        let probs = softmax(&raw);
        assert!((grad - (probs[1] - 1.0)).abs() < 1e-6);
        assert!(hess > 0.0);
    }

    #[test]
    fn test_gradient_for_other_class() {
        let loss = MultiClassLogLoss::new(3);
        let raw = vec![1.0, 2.0, 0.5];
        let (grad, _hess) = loss.gradient_hessian_for_class(1, 0, &raw);
        let probs = softmax(&raw);
        assert!((grad - probs[0]).abs() < 1e-6);
    }
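
    // Added illustrative test: checks `compute_gradients_batch` against the
    // per-sample path, assuming the row-major prediction layout documented on
    // the method. The toy values here are arbitrary.
    #[test]
    fn test_batch_matches_single_sample_path() {
        let loss = MultiClassLogLoss::new(3);
        let targets = vec![0.0, 2.0];
        // Two samples, three raw scores each, row-major.
        let predictions = vec![0.5, -0.5, 0.0, 1.0, 2.0, 0.5];
        let sample_indices = vec![0, 1];
        let mut gradients = vec![0.0; 2];
        let mut hessians = vec![0.0; 2];
        let class_idx = 1;
        loss.compute_gradients_batch(
            class_idx,
            &targets,
            &predictions,
            &sample_indices,
            &mut gradients,
            &mut hessians,
        );
        for idx in 0..2 {
            let row = &predictions[idx * 3..idx * 3 + 3];
            let (g, h) =
                loss.gradient_hessian_for_class(targets[idx] as usize, class_idx, row);
            assert!((gradients[idx] - g).abs() < 1e-6);
            assert!((hessians[idx] - h).abs() < 1e-6);
        }
    }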

    #[test]
    fn test_all_classes_gradient() {
        let loss = MultiClassLogLoss::new(3);
        let raw = vec![1.0, 2.0, 0.5];
        let target_class = 1;
        let (grads, hess) = loss.gradient_hessian_all_classes(target_class, &raw);
        assert_eq!(grads.len(), 3);
        assert_eq!(hess.len(), 3);
        let grad_sum: f32 = grads.iter().sum();
        assert!(grad_sum.abs() < 1e-6);
    }
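
    // Added illustrative test: the diagonal Hessian p * (1 - p) is floored at
    // eps, so every entry should be strictly positive and at most 0.25.
    #[test]
    fn test_hessian_bounds() {
        let loss = MultiClassLogLoss::new(3);
        let raw = vec![-10.0, 0.0, 10.0];
        let (_, hessians) = loss.gradient_hessian_all_classes(0, &raw);
        for &h in &hessians {
            assert!(h > 0.0);
            assert!(h <= 0.25 + 1e-6);
        }
    }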

    #[test]
    fn test_initial_predictions() {
        let loss = MultiClassLogLoss::new(3);
        let targets: Vec<f32> = vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0];
        let init = loss.initial_predictions(&targets);
        assert_eq!(init.len(), 3);
        assert!(init[0] > init[1]);
        assert!(init[1] > init[2]);
    }
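
    // Added illustrative test: with equal class counts the log-prior initial
    // scores should all equal ln(1/3).
    #[test]
    fn test_initial_predictions_uniform() {
        let loss = MultiClassLogLoss::new(3);
        let targets = vec![0.0, 1.0, 2.0, 0.0, 1.0, 2.0];
        let init = loss.initial_predictions(&targets);
        let expected = (1.0f32 / 3.0).ln();
        for &v in &init {
            assert!((v - expected).abs() < 1e-6);
        }
    }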
}