//! Softmax-family activation functions: `softmax`, `log_softmax`, `softmin`,
//! and `gumbel_softmax`.

use torsh_core::{dtype::FloatElement, Result as TorshResult};
use torsh_tensor::Tensor;
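
/// Applies the softmax function along dimension `dim`:
/// `softmax(x)_i = exp(x_i) / sum_j exp(x_j)`.
///
/// The result along `dim` is a probability distribution: entries lie in (0, 1)
/// and sum to 1. `dtype` is accepted for API compatibility but is currently
/// ignored (see the note in the body).
///
/// A minimal usage sketch, assuming a CPU `f32` tensor built with
/// `torsh_tensor::creation::from_vec` as in the tests below:
///
/// ```ignore
/// let x = from_vec(vec![1.0f32, 2.0, 3.0], &[3], DeviceType::Cpu)?;
/// let p = softmax(&x, 0, None)?; // entries are positive and sum to 1
/// ```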
pub fn softmax<T: FloatElement>(
input: &Tensor<T>,
dim: i64,
dtype: Option<torsh_core::DType>,
) -> TorshResult<Tensor<T>> {
    // NOTE: dtype conversion is not yet implemented; the argument is accepted
    // for API compatibility and currently has no effect.
    let _ = dtype;
    input.clone().softmax(dim as i32)
}
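
/// Applies the log-softmax function along dimension `dim`:
/// `log_softmax(x)_i = x_i - log(sum_j exp(x_j))`.
///
/// Mathematically equivalent to `softmax(x).ln()` element-wise, so every output
/// is <= 0. `dtype` is accepted for API compatibility but is currently ignored.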
pub fn log_softmax<T: FloatElement>(
input: &Tensor<T>,
dim: i64,
dtype: Option<torsh_core::DType>,
) -> TorshResult<Tensor<T>> {
    // NOTE: dtype conversion is not yet implemented; the argument is accepted
    // for API compatibility and currently has no effect.
    let _ = dtype;
    input.clone().log_softmax(dim as i32)
}
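
/// Applies the softmin function along dimension `dim`, using the identity
/// `softmin(x) = softmax(-x)`: smaller inputs receive larger probabilities,
/// and the outputs along `dim` still sum to 1.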
pub fn softmin<T: FloatElement>(
input: &Tensor<T>,
dim: i64,
dtype: Option<torsh_core::DType>,
) -> TorshResult<Tensor<T>>
where
T: Default,
{
let neg_input = input.neg()?;
softmax(&neg_input, dim, dtype)
}
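
/// Draws a sample from the Gumbel-softmax (concrete) relaxation of a
/// categorical distribution over `logits`.
///
/// Gumbel noise `g = -ln(-ln(u + eps))` with `u ~ U(0, 1)` is added to the
/// logits, the result is divided by the temperature `tau`, and a softmax is
/// taken along `dim`. Lower temperatures yield distributions closer to one-hot.
///
/// Current limitations (see the notes in the body): the uniform samples come
/// from a deterministic placeholder sequence rather than a real RNG, and the
/// `hard` (straight-through) variant is not yet implemented, so the soft sample
/// is returned regardless of `hard`.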
pub fn gumbel_softmax<T: FloatElement>(
logits: &Tensor<T>,
tau: f64,
hard: bool,
eps: f64,
dim: i64,
) -> TorshResult<Tensor<T>>
where
T: Copy + PartialOrd + From<f32> + Default,
{
    let eps_val = <T as From<f32>>::from(eps as f32);
    let tau_val = <T as From<f32>>::from(tau as f32);
    let shape = logits.shape().dims().to_vec();

    // NOTE: proper Gumbel noise requires uniform samples u ~ U(0, 1) and
    // g = -ln(-ln(u + eps)). No RNG is wired up here yet, so a deterministic
    // placeholder sequence in (0, 1) stands in for the uniform samples; the
    // output is therefore reproducible but not truly random.
    let noise_data: Vec<T> = (0..logits.shape().numel())
        .map(|i| {
            let u = <T as From<f32>>::from(0.5 + 0.1 * ((i as f32 * 0.123) % 1.0));
            -(-(u + eps_val).ln()).ln()
        })
        .collect();
    let gumbel_noise = Tensor::from_data(noise_data, shape, logits.device())?;

    // Perturb the logits with the noise, scale by the temperature, and
    // normalize with a softmax along `dim`.
    let noisy_logits = logits.add(&gumbel_noise)?;
    let scaled_logits = noisy_logits.div_scalar(tau_val)?;
    let y_soft = scaled_logits.softmax(dim as i32)?;

    // TODO: when `hard` is true, the straight-through estimator should return a
    // one-hot sample at the argmax while keeping `y_soft` for the backward pass.
    // That path is not implemented yet, so the soft sample is returned either way.
    let _ = hard;
    Ok(y_soft)
}
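
// A straight-through "hard" path would conventionally pick the argmax along
// `dim`, emit a one-hot tensor in the forward pass, and route gradients through
// `y_soft`. A rough sketch follows; `argmax`, `one_hot`, `detach`, and
// `num_classes` are hypothetical placeholders and may not match the actual
// torsh_tensor API:
//
//     let index = y_soft.argmax(dim as i32)?;
//     let y_hard = one_hot(&index, num_classes)?;
//     // forward: y_hard; backward: gradients flow through y_soft
//     let output = y_hard.sub(&y_soft.detach()?)?.add(&y_soft)?;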
#[cfg(test)]
mod tests {
use super::*;
use torsh_core::device::DeviceType;
use torsh_tensor::creation::from_vec;
#[test]
fn test_softmax_properties() -> TorshResult<()> {
let input = from_vec(vec![1.0, 2.0, 3.0], &[3], DeviceType::Cpu)?;
let output = softmax(&input, 0, None)?;
let output_data = output.data()?;
let sum: f32 = output_data.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-6,
"Softmax sum should be 1, got {}",
sum
);
for &val in output_data.iter() {
assert!(
val > 0.0 && val < 1.0,
"Softmax output {} not in (0,1)",
val
);
}
assert!(output_data[0] < output_data[1]);
assert!(output_data[1] < output_data[2]);
Ok(())
}
#[test]
fn test_log_softmax_properties() -> TorshResult<()> {
let input = from_vec(vec![1.0, 2.0, 3.0], &[3], DeviceType::Cpu)?;
let output = log_softmax(&input, 0, None)?;
let output_data = output.data()?;
for &val in output_data.iter() {
assert!(val <= 0.0, "Log softmax output {} should be ≤ 0", val);
}
let exp_output = output.exp()?;
let softmax_output = softmax(&input, 0, None)?;
let exp_data = exp_output.data()?;
let softmax_data = softmax_output.data()?;
for (i, (&exp_val, &soft_val)) in exp_data.iter().zip(softmax_data.iter()).enumerate() {
assert!(
((exp_val - soft_val) as f32).abs() < 1e-5,
"exp(log_softmax) != softmax at index {}: {} vs {}",
i,
exp_val,
soft_val
);
}
Ok(())
}
#[test]
fn test_softmin_properties() -> TorshResult<()> {
let input = from_vec(vec![1.0, 2.0, 3.0], &[3], DeviceType::Cpu)?;
let output = softmin(&input, 0, None)?;
let output_data = output.data()?;
let sum: f32 = output_data.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-6,
"Softmin sum should be 1, got {}",
sum
);
for &val in output_data.iter() {
assert!(
val > 0.0 && val < 1.0,
"Softmin output {} not in (0,1)",
val
);
}
assert!(output_data[0] > output_data[1]);
assert!(output_data[1] > output_data[2]);
let neg_input = input.neg()?;
let softmax_neg = softmax(&neg_input, 0, None)?;
let softmax_neg_data = softmax_neg.data()?;
for (i, (&softmin_val, &softmax_neg_val)) in
output_data.iter().zip(softmax_neg_data.iter()).enumerate()
{
assert!(
(softmin_val - softmax_neg_val).abs() < 1e-5,
"softmin(x) != softmax(-x) at index {}: {} vs {}",
i,
softmin_val,
softmax_neg_val
);
}
Ok(())
}
#[test]
fn test_gumbel_softmax_properties() -> TorshResult<()> {
let logits = from_vec(vec![1.0, 2.0, 3.0], &[3], DeviceType::Cpu)?;
let result = gumbel_softmax(&logits, 1.0, false, 1e-10, 0)?;
let result_data = result.data()?;
let sum: f32 = result_data.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-5,
"Gumbel softmax sum should be 1, got {}",
sum
);
for &val in result_data.iter() {
assert!(
val > 0.0 && val < 1.0,
"Gumbel softmax output {} not in (0,1)",
val
);
}
Ok(())
}
#[test]
fn test_gumbel_softmax_temperature() -> TorshResult<()> {
let logits = from_vec(vec![1.0, 2.0, 3.0], &[3], DeviceType::Cpu)?;
let result_high_temp = gumbel_softmax(&logits, 2.0, false, 1e-10, 0)?;
let result_low_temp = gumbel_softmax(&logits, 0.5, false, 1e-10, 0)?;
        let high_temp_data = result_high_temp.data()?;
        let low_temp_data = result_low_temp.data()?;
        let sum_high: f32 = high_temp_data.iter().sum();
        let sum_low: f32 = low_temp_data.iter().sum();
        assert!((sum_high - 1.0).abs() < 1e-5);
        assert!((sum_low - 1.0).abs() < 1e-5);
        for &val in low_temp_data.iter() {
            assert!(val >= 0.0 && val <= 1.0);
        }
        for &val in high_temp_data.iter() {
            assert!(val >= 0.0 && val <= 1.0);
        }
        // Because the noise sequence is deterministic, both calls perturb the
        // logits identically, so a lower temperature must yield a more peaked
        // (higher-maximum) distribution than a higher temperature.
        let max_low = low_temp_data.iter().cloned().fold(f32::MIN, f32::max);
        let max_high = high_temp_data.iter().cloned().fold(f32::MIN, f32::max);
        assert!(
            max_low > max_high,
            "Lower temperature should be more peaked: max {} vs {}",
            max_low,
            max_high
        );
Ok(())
}
#[test]
fn test_softmax_numerical_stability() -> TorshResult<()> {
let input = from_vec(vec![100.0, 101.0, 102.0], &[3], DeviceType::Cpu)?;
let output = softmax(&input, 0, None)?;
let output_data = output.data()?;
let sum: f32 = output_data.iter().sum();
assert!((sum - 1.0).abs() < 1e-5);
for &val in output_data.iter() {
assert!(
val.is_finite(),
"Softmax produced non-finite value: {}",
val
);
assert!(val >= 0.0 && val <= 1.0);
}
Ok(())
}
#[test]
fn test_log_softmax_numerical_stability() -> TorshResult<()> {
let input = from_vec(vec![100.0, 101.0, 102.0], &[3], DeviceType::Cpu)?;
let output = log_softmax(&input, 0, None)?;
let output_data = output.data()?;
for &val in output_data.iter() {
assert!(
val.is_finite(),
"Log softmax produced non-finite value: {}",
val
);
assert!(val <= 0.0, "Log softmax should be ≤ 0, got {}", val);
}
let max_idx = output_data
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("numeric comparison should succeed"))
.map(|(i, _)| i)
.expect("operation should succeed");
assert_eq!(
max_idx, 2,
"Largest input should have largest log probability"
);
Ok(())
}
}