use crate::error::{NeuralError, Result};
use crate::capsule::network::Capsule;
#[derive(Debug, Clone)]
pub struct MarginLoss {
pub m_plus: f32,
pub m_minus: f32,
pub lambda: f32,
}
impl Default for MarginLoss {
fn default() -> Self {
Self {
m_plus: 0.9,
m_minus: 0.1,
lambda: 0.5,
}
}
}
impl MarginLoss {
pub fn new() -> Self {
Self::default()
}
pub fn with_params(m_plus: f32, m_minus: f32, lambda: f32) -> Result<Self> {
if m_plus <= m_minus {
return Err(NeuralError::InvalidArgument(
"m_plus must be > m_minus".into(),
));
}
if lambda < 0.0 {
return Err(NeuralError::InvalidArgument("lambda must be ≥ 0".into()));
}
if !(0.0..=1.0).contains(&m_plus) || !(0.0..=1.0).contains(&m_minus) {
return Err(NeuralError::InvalidArgument(
"margins must be in [0, 1]".into(),
));
}
Ok(Self {
m_plus,
m_minus,
lambda,
})
}
pub fn per_class_loss(&self, norm: f32, present: bool) -> f32 {
if present {
let diff = (self.m_plus - norm).max(0.0);
diff * diff
} else {
let diff = (norm - self.m_minus).max(0.0);
self.lambda * diff * diff
}
}
pub fn compute(capsules: &[Capsule], labels: &[bool]) -> f32 {
let loss_fn = Self::default();
loss_fn.total_loss(capsules, labels).unwrap_or(0.0)
}
pub fn total_loss(&self, capsules: &[Capsule], labels: &[bool]) -> Result<f32> {
if capsules.len() != labels.len() {
return Err(NeuralError::DimensionMismatch(format!(
"capsules length {} != labels length {}",
capsules.len(),
labels.len()
)));
}
let total = capsules
.iter()
.zip(labels.iter())
.map(|(cap, &present)| self.per_class_loss(cap.activation, present))
.sum();
Ok(total)
}
pub fn per_class_losses(&self, capsules: &[Capsule], labels: &[bool]) -> Result<Vec<f32>> {
if capsules.len() != labels.len() {
return Err(NeuralError::DimensionMismatch(format!(
"capsules length {} != labels length {}",
capsules.len(),
labels.len()
)));
}
Ok(capsules
.iter()
.zip(labels.iter())
.map(|(cap, &present)| self.per_class_loss(cap.activation, present))
.collect())
}
pub fn predict_class(capsules: &[Capsule]) -> Result<usize> {
capsules
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.activation.partial_cmp(&b.activation).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i)
.ok_or_else(|| NeuralError::InvalidArgument("capsules is empty".into()))
}
}
#[derive(Debug, Clone)]
pub struct SpreadLoss {
pub m: f32,
pub m_min: f32,
pub m_max: f32,
}
impl SpreadLoss {
pub fn new() -> Self {
Self {
m: 0.2,
m_min: 0.2,
m_max: 0.9,
}
}
pub fn update_margin(&mut self, epoch: usize, total_epochs: usize) {
let t = epoch as f32 / total_epochs.max(1) as f32;
self.m = self.m_min + (self.m_max - self.m_min) * t;
}
pub fn compute(&self, activations: &[f32], target: usize) -> Result<f32> {
if activations.is_empty() {
return Err(NeuralError::InvalidArgument(
"activations must be non-empty".into(),
));
}
if target >= activations.len() {
return Err(NeuralError::InvalidArgument(format!(
"target {target} out of bounds ({})",
activations.len()
)));
}
let a_t = activations[target];
let loss: f32 = activations
.iter()
.enumerate()
.filter(|&(k, _)| k != target)
.map(|(_, &a_k)| {
let diff = (self.m - (a_t - a_k)).max(0.0);
diff * diff
})
.sum();
Ok(loss)
}
}
impl Default for SpreadLoss {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::capsule::network::Capsule;
fn make_caps(norms: &[f32]) -> Vec<Capsule> {
norms
.iter()
.map(|&n| Capsule {
pose: vec![n, 0.0],
activation: n,
})
.collect()
}
#[test]
fn margin_loss_present_correct() {
let loss = MarginLoss::new();
assert_eq!(loss.per_class_loss(0.95, true), 0.0);
}
#[test]
fn margin_loss_absent_correct() {
let loss = MarginLoss::new();
assert_eq!(loss.per_class_loss(0.05, false), 0.0);
}
#[test]
fn margin_loss_present_wrong() {
let loss = MarginLoss::new();
let l = loss.per_class_loss(0.5, true);
assert!((l - 0.16).abs() < 1e-5, "expected 0.16, got {l}");
}
#[test]
fn margin_loss_total() {
let caps = make_caps(&[0.95, 0.05, 0.05]);
let labels = [true, false, false];
let l = MarginLoss::compute(&caps, &labels);
assert!(l < 0.01, "All correct → near-zero loss, got {l}");
}
#[test]
fn margin_loss_mismatch() {
let loss = MarginLoss::new();
let caps = make_caps(&[0.9, 0.1]);
assert!(loss.total_loss(&caps, &[true]).is_err());
}
#[test]
fn predict_class_returns_highest() {
let caps = make_caps(&[0.2, 0.8, 0.5]);
let pred = MarginLoss::predict_class(&caps).expect("operation should succeed");
assert_eq!(pred, 1);
}
#[test]
fn spread_loss_zero_for_large_margin() {
let loss = SpreadLoss::new();
let acts = vec![0.1, 0.9, 0.1];
let l = loss.compute(&acts, 1).expect("operation should succeed");
assert_eq!(l, 0.0);
}
#[test]
fn spread_loss_nonzero_when_close() {
let loss = SpreadLoss::new();
let acts = vec![0.25, 0.3, 0.25];
let l = loss.compute(&acts, 1).expect("operation should succeed");
assert!(l > 0.0, "Should have positive loss when activations are close");
}
}