use burn::tensor::{backend::Backend, Tensor};
use crate::types::{Energy, Representation};
/// A scalar dissimilarity ("energy") between two representations.
///
/// Implementations below return values that are ~0 when `predicted` and
/// `actual` agree and grow as they diverge.
pub trait EnergyFn<B: Backend> {
    /// Computes the energy between `predicted` and `actual`.
    ///
    /// NOTE(review): the implementations in this file assume both inputs
    /// share the same `[batch, seq_len, embed_dim]` shape — confirm callers
    /// guarantee this.
    fn compute(&self, predicted: &Representation<B>, actual: &Representation<B>) -> Energy<B>;
}
/// Mean-squared-error energy: the mean of `(predicted - actual)^2` over
/// every element of the embeddings. Zero iff the embeddings are identical.
#[derive(Debug, Clone, Copy)]
pub struct L2Energy;
impl<B: Backend> EnergyFn<B> for L2Energy {
    /// Mean squared difference over all elements, reshaped to a rank-1 scalar.
    fn compute(&self, predicted: &Representation<B>, actual: &Representation<B>) -> Energy<B> {
        let residual = predicted.embeddings.clone() - actual.embeddings.clone();
        // mean((p - a)^2), always >= 0.
        let mse = (residual.clone() * residual).mean();
        Energy {
            value: mse.unsqueeze(),
        }
    }
}
/// Cosine-distance energy: `1 - mean(cosine similarity)` over all token
/// vectors. Ranges from ~0 (parallel) through 1 (orthogonal) to ~2 (opposite).
#[derive(Debug, Clone, Copy)]
pub struct CosineEnergy;
impl<B: Backend> EnergyFn<B> for CosineEnergy {
    /// `1 - mean(cos_sim)` over all `batch * seq_len` token vectors.
    ///
    /// A small epsilon in the denominator keeps the division finite when a
    /// token vector has (near-)zero norm.
    fn compute(&self, predicted: &Representation<B>, actual: &Representation<B>) -> Energy<B> {
        let [batch, seq_len, embed_dim] = predicted.embeddings.dims();
        let rows = batch * seq_len;
        // Flatten to one row per token so similarity is computed row-wise.
        let lhs = predicted.embeddings.clone().reshape([rows, embed_dim]);
        let rhs = actual.embeddings.clone().reshape([rows, embed_dim]);
        let eps: f64 = 1e-8;
        let dot = (lhs.clone() * rhs.clone()).sum_dim(1);
        let lhs_norm = (lhs.clone() * lhs).sum_dim(1).sqrt();
        let rhs_norm = (rhs.clone() * rhs).sum_dim(1).sqrt();
        let cos_sim = dot / (lhs_norm * rhs_norm + eps);
        // Convert similarity to a distance: identical directions give ~0 energy.
        let one: Tensor<B, 1> = Tensor::ones([1], &cos_sim.device());
        Energy {
            value: one - cos_sim.mean().unsqueeze(),
        }
    }
}
/// Smooth-L1 (Huber-style) energy: quadratic for element-wise differences
/// below `beta`, linear above — L2's smoothness near zero with L1's
/// robustness to outliers.
#[derive(Debug, Clone, Copy)]
pub struct SmoothL1Energy {
    /// Transition point between the quadratic and linear regimes; must be > 0.
    pub beta: f64,
}

impl SmoothL1Energy {
    /// Creates a smooth-L1 energy with the given quadratic/linear transition.
    ///
    /// # Panics
    ///
    /// Panics if `beta` is not strictly positive: `compute` divides by
    /// `beta`, so `beta <= 0` would otherwise produce NaN/inf energies
    /// silently.
    pub fn new(beta: f64) -> Self {
        assert!(beta > 0.0, "SmoothL1Energy beta must be > 0, got {beta}");
        Self { beta }
    }
}
impl<B: Backend> EnergyFn<B> for SmoothL1Energy {
    /// Element-wise smooth-L1 between the embeddings, averaged to a scalar.
    ///
    /// Per element, with `d = |predicted - actual|`:
    /// `0.5 * d^2 / beta` when `d < beta`, else `d - 0.5 * beta`.
    fn compute(&self, predicted: &Representation<B>, actual: &Representation<B>) -> Energy<B> {
        let abs_err = (predicted.embeddings.clone() - actual.embeddings.clone()).abs();
        // Float-valued selector: 1.0 where the quadratic branch applies.
        let threshold: Tensor<B, 3> = Tensor::full(abs_err.dims(), self.beta, &abs_err.device());
        let in_quadratic = abs_err.clone().lower(threshold).float();
        let in_linear = in_quadratic.clone().neg() + 1.0;
        let quadratic_branch = abs_err.clone() * abs_err.clone() * 0.5 / self.beta;
        let linear_branch = abs_err - 0.5 * self.beta;
        // Blend the two branches via the complementary masks, then average.
        let per_element = quadratic_branch * in_quadratic + linear_branch * in_linear;
        Energy {
            value: per_element.mean().unsqueeze(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::ElementConversion;
    use burn_ndarray::NdArray;
    use proptest::prelude::*;
    // `random::<f32>()` is provided by the `Rng` extension trait; rand has no
    // `RngExt` trait, so the previous `use rand::RngExt as _;` did not compile.
    use rand::Rng as _;
    use rand::SeedableRng;

    type TestBackend = NdArray<f32>;

    fn device() -> burn_ndarray::NdArrayDevice {
        burn_ndarray::NdArrayDevice::Cpu
    }

    /// Builds a `Representation` of the given shape from a flat f32 slice.
    fn make_repr(data: &[f32], shape: [usize; 3]) -> Representation<TestBackend> {
        Representation::new(Tensor::from_floats(
            burn::tensor::TensorData::new(data.to_vec(), shape),
            &device(),
        ))
    }

    #[test]
    fn test_l2_energy_identical_representations_is_zero() {
        let data: Vec<f32> = (0..24).map(|i| i as f32 * 0.1).collect();
        let repr = make_repr(&data, [2, 3, 4]);
        let energy = L2Energy.compute(&repr, &repr);
        let val: f32 = energy.value.into_scalar().elem();
        assert!(val.abs() < 1e-6, "expected ~0, got {val}");
    }

    #[test]
    fn test_l2_energy_different_representations_is_positive() {
        let a_data: Vec<f32> = (0..24).map(|i| i as f32 * 0.1).collect();
        let b_data: Vec<f32> = (0..24).map(|i| (i as f32 + 1.0) * 0.1).collect();
        let a = make_repr(&a_data, [2, 3, 4]);
        let b = make_repr(&b_data, [2, 3, 4]);
        let energy = L2Energy.compute(&a, &b);
        let val: f32 = energy.value.into_scalar().elem();
        assert!(val > 0.0, "expected positive, got {val}");
    }

    #[test]
    fn test_l2_energy_is_symmetric() {
        let a_data: Vec<f32> = (0..24).map(|i| i as f32 * 0.1).collect();
        let b_data: Vec<f32> = (0..24).map(|i| (i as f32 + 5.0) * 0.3).collect();
        let a = make_repr(&a_data, [2, 3, 4]);
        let b = make_repr(&b_data, [2, 3, 4]);
        let e_ab: f32 = L2Energy.compute(&a, &b).value.into_scalar().elem();
        let e_ba: f32 = L2Energy.compute(&b, &a).value.into_scalar().elem();
        assert!(
            (e_ab - e_ba).abs() < 1e-6,
            "L2 energy not symmetric: {e_ab} vs {e_ba}"
        );
    }

    #[test]
    fn test_cosine_energy_identical_is_near_zero() {
        let data: Vec<f32> = (1..25).map(|i| i as f32).collect();
        let repr = make_repr(&data, [2, 3, 4]);
        let energy = CosineEnergy.compute(&repr, &repr);
        let val: f32 = energy.value.into_scalar().elem();
        assert!(val.abs() < 1e-5, "expected ~0, got {val}");
    }

    #[test]
    fn test_cosine_energy_orthogonal_is_near_one() {
        let a = make_repr(&[1.0, 0.0], [1, 1, 2]);
        let b = make_repr(&[0.0, 1.0], [1, 1, 2]);
        let energy = CosineEnergy.compute(&a, &b);
        let val: f32 = energy.value.into_scalar().elem();
        assert!(
            (val - 1.0).abs() < 1e-5,
            "expected ~1.0 for orthogonal, got {val}"
        );
    }

    #[test]
    fn test_smooth_l1_identical_is_zero() {
        let data: Vec<f32> = (0..12).map(|i| i as f32 * 0.1).collect();
        let repr = make_repr(&data, [1, 3, 4]);
        let energy = SmoothL1Energy::new(1.0).compute(&repr, &repr);
        let val: f32 = energy.value.into_scalar().elem();
        assert!(val.abs() < 1e-6, "expected ~0, got {val}");
    }

    #[test]
    fn test_smooth_l1_is_non_negative() {
        let a_data: Vec<f32> = (0..12).map(|i| i as f32 * 0.5).collect();
        let b_data: Vec<f32> = (0..12).map(|i| (i as f32 + 2.0) * 0.3).collect();
        let a = make_repr(&a_data, [1, 3, 4]);
        let b = make_repr(&b_data, [1, 3, 4]);
        let energy = SmoothL1Energy::new(1.0).compute(&a, &b);
        let val: f32 = energy.value.into_scalar().elem();
        assert!(val >= 0.0, "expected non-negative, got {val}");
    }

    #[test]
    fn test_cosine_energy_with_near_zero_vectors() {
        let a = make_repr(&[1e-7, 1e-7, 1e-7, 1e-7], [1, 1, 4]);
        let b = make_repr(&[1e-7, 1e-7, 1e-7, 1e-7], [1, 1, 4]);
        let energy = CosineEnergy.compute(&a, &b);
        let val: f32 = energy.value.into_scalar().elem();
        assert!(
            val.is_finite(),
            "cosine energy should be finite for near-zero vectors, got {val}"
        );
    }

    #[test]
    fn test_cosine_energy_antiparallel_is_near_two() {
        let a = make_repr(&[1.0, 0.0], [1, 1, 2]);
        let b = make_repr(&[-1.0, 0.0], [1, 1, 2]);
        let energy = CosineEnergy.compute(&a, &b);
        let val: f32 = energy.value.into_scalar().elem();
        assert!(
            (val - 2.0).abs() < 1e-4,
            "expected ~2.0 for antiparallel, got {val}"
        );
    }

    #[test]
    fn test_smooth_l1_small_differences_are_quadratic() {
        // d = 0.5 < beta = 2.0, so each element is 0.5 * 0.25 / 2.0 = 0.0625.
        let beta = 2.0;
        let a = make_repr(&[0.0; 4], [1, 1, 4]);
        let b = make_repr(&[0.5; 4], [1, 1, 4]);
        let energy = SmoothL1Energy::new(beta).compute(&a, &b);
        let val: f32 = energy.value.into_scalar().elem();
        assert!((val - 0.0625).abs() < 1e-4, "expected ~0.0625, got {val}");
    }

    #[test]
    fn test_smooth_l1_large_differences_are_linear() {
        // d = 5.0 >= beta = 0.1, so each element is 5.0 - 0.05 = 4.95.
        let beta = 0.1;
        let a = make_repr(&[0.0; 4], [1, 1, 4]);
        let b = make_repr(&[5.0; 4], [1, 1, 4]);
        let energy = SmoothL1Energy::new(beta).compute(&a, &b);
        let val: f32 = energy.value.into_scalar().elem();
        assert!((val - 4.95).abs() < 1e-3, "expected ~4.95, got {val}");
    }

    #[test]
    fn test_l2_energy_large_values_stays_finite() {
        let data: Vec<f32> = (0..24).map(|i| i as f32 * 1000.0).collect();
        let a = make_repr(&data, [2, 3, 4]);
        let zeros = make_repr(&[0.0; 24], [2, 3, 4]);
        let val: f32 = L2Energy.compute(&a, &zeros).value.into_scalar().elem();
        assert!(
            val.is_finite(),
            "L2 energy should stay finite for large values, got {val}"
        );
        assert!(val > 0.0);
    }

    proptest! {
        #[test]
        fn prop_l2_energy_never_negative(seed in 0u64..10000) {
            let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
            let a_data: Vec<f32> = (0..24).map(|_| (rng.random::<f32>() - 0.5) * 10.0).collect();
            let b_data: Vec<f32> = (0..24).map(|_| (rng.random::<f32>() - 0.5) * 10.0).collect();
            let a = make_repr(&a_data, [2, 3, 4]);
            let b = make_repr(&b_data, [2, 3, 4]);
            let val: f32 = L2Energy.compute(&a, &b).value.into_scalar().elem();
            prop_assert!(val >= 0.0, "L2 energy was negative: {val}");
            prop_assert!(val.is_finite(), "L2 energy was not finite: {val}");
        }

        #[test]
        fn prop_l2_energy_is_symmetric(seed in 0u64..10000) {
            let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
            let a_data: Vec<f32> = (0..24).map(|_| rng.random::<f32>() * 5.0).collect();
            let b_data: Vec<f32> = (0..24).map(|_| rng.random::<f32>() * 5.0).collect();
            let a = make_repr(&a_data, [2, 3, 4]);
            let b = make_repr(&b_data, [2, 3, 4]);
            let e_ab: f32 = L2Energy.compute(&a, &b).value.into_scalar().elem();
            let e_ba: f32 = L2Energy.compute(&b, &a).value.into_scalar().elem();
            prop_assert!((e_ab - e_ba).abs() < 1e-5, "not symmetric: {e_ab} vs {e_ba}");
        }

        #[test]
        fn prop_l2_energy_zero_for_identical(seed in 0u64..10000) {
            let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
            let data: Vec<f32> = (0..24).map(|_| rng.random::<f32>() * 10.0).collect();
            let repr = make_repr(&data, [2, 3, 4]);
            let val: f32 = L2Energy.compute(&repr, &repr).value.into_scalar().elem();
            prop_assert!(val.abs() < 1e-6, "expected ~0 for identical, got {val}");
        }

        #[test]
        fn prop_smooth_l1_never_negative(seed in 0u64..10000) {
            let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
            let a_data: Vec<f32> = (0..12).map(|_| rng.random::<f32>() * 10.0).collect();
            let b_data: Vec<f32> = (0..12).map(|_| rng.random::<f32>() * 10.0).collect();
            let a = make_repr(&a_data, [1, 3, 4]);
            let b = make_repr(&b_data, [1, 3, 4]);
            let val: f32 = SmoothL1Energy::new(1.0).compute(&a, &b).value.into_scalar().elem();
            prop_assert!(val >= 0.0, "SmoothL1 was negative: {val}");
        }
    }
}