use crate::error::TlBackendResult;
use crate::Scirs2Tensor;
use scirs2_core::ndarray::{Array, ArrayD, Axis};
use scirs2_core::random::arrays::OptimizedArrayRandom;
use scirs2_core::random::prelude::*;
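/// Configuration for straight-through-estimator (STE) thresholding.
/// `threshold` is the binarization cut-off; `clip_gradients` clamps the
/// pass-through gradient to [-1, 1] in the backward pass.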
#[derive(Debug, Clone, Copy)]
pub struct SteConfig {
pub threshold: f64,
pub clip_gradients: bool,
}
impl Default for SteConfig {
fn default() -> Self {
Self {
threshold: 0.5,
clip_gradients: false,
}
}
}
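/// Configuration for Gumbel-softmax sampling. `temperature` controls how
/// peaked the relaxation is, `hard` switches the forward pass to a one-hot
/// arg-max sample, and `seed` makes the Gumbel noise reproducible.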
#[derive(Debug, Clone, Copy)]
pub struct GumbelSoftmaxConfig {
pub temperature: f64,
pub hard: bool,
pub seed: Option<u64>,
}
impl Default for GumbelSoftmaxConfig {
fn default() -> Self {
Self {
temperature: 1.0,
hard: false,
seed: None,
}
}
}
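/// Aggregation strategy used by the soft quantifiers `soft_exists` and
/// `soft_forall`: exact max/min, temperature-smoothed max/min (log-sum-exp),
/// probabilistic (noisy-or / product), or log-space aggregation delegated to
/// the scoring module.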
#[derive(Debug, Clone)]
pub enum QuantifierMode {
Hard,
Smooth { temperature: f64 },
Probabilistic,
LogSpace(crate::scoring::ScoringConfig),
}
impl Default for QuantifierMode {
fn default() -> Self {
Self::Smooth { temperature: 1.0 }
}
}
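/// Forward pass of STE thresholding: 1.0 where the input reaches
/// `config.threshold`, 0.0 elsewhere.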
pub fn ste_threshold(input: &Scirs2Tensor, config: SteConfig) -> TlBackendResult<Scirs2Tensor> {
let output = input.mapv(|x| if x >= config.threshold { 1.0 } else { 0.0 });
Ok(output)
}
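/// Backward pass of STE thresholding. The step function is flat almost
/// everywhere, so the estimator passes the upstream gradient through
/// unchanged, optionally clamped to [-1, 1].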
pub fn ste_threshold_backward(
grad_output: &Scirs2Tensor,
_input: &Scirs2Tensor,
config: SteConfig,
) -> TlBackendResult<Scirs2Tensor> {
if config.clip_gradients {
Ok(grad_output.mapv(|g| g.clamp(-1.0, 1.0)))
} else {
Ok(grad_output.clone())
}
}
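/// Gumbel-softmax (concrete) sampling: softmax((logits + g) / temperature)
/// with g ~ Gumbel(0, 1). When `hard` is set, the forward pass returns a
/// one-hot arg-max sample while the backward pass still differentiates
/// through the soft sample (straight-through).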
pub fn gumbel_softmax(
logits: &Scirs2Tensor,
config: GumbelSoftmaxConfig,
) -> TlBackendResult<Scirs2Tensor> {
let gumbel_noise = sample_gumbel(logits.shape(), config.seed)?;
let noisy_logits = logits + &gumbel_noise;
let soft_samples = softmax_temperature(&noisy_logits, config.temperature)?;
if config.hard {
let hard_samples = argmax_to_onehot(&soft_samples)?;
Ok(hard_samples)
} else {
Ok(soft_samples)
}
}
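/// Backward pass through the tempered softmax: for s = softmax(z / T),
/// dL/dz_i = s_i * (g_i - sum_j g_j s_j) / T.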
pub fn gumbel_softmax_backward(
grad_output: &Scirs2Tensor,
soft_samples: &Scirs2Tensor,
config: GumbelSoftmaxConfig,
) -> TlBackendResult<Scirs2Tensor> {
let last_axis = soft_samples.ndim() - 1;
let dot_product = (soft_samples * grad_output)
.sum_axis(Axis(last_axis))
.insert_axis(Axis(last_axis));
let grad_logits = soft_samples * &(grad_output - &dot_product);
Ok(grad_logits.mapv(|g| g / config.temperature))
}
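/// Soft existential quantifier: a max-like reduction along `axis` (or over
/// the whole tensor when `axis` is `None`), using the strategy in `mode`.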
pub fn soft_exists(
input: &Scirs2Tensor,
axis: Option<usize>,
mode: QuantifierMode,
) -> TlBackendResult<Scirs2Tensor> {
match mode {
QuantifierMode::Hard => {
            // Seed the fold with NEG_INFINITY so the hard max is correct even
            // when every input is negative (a 0.0 seed would clamp the result).
            if let Some(ax) = axis {
                Ok(input.map_axis(Axis(ax), |slice| {
                    slice.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
                }))
            } else {
                let max_val = input.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                Ok(Array::from_elem(vec![], max_val))
}
}
QuantifierMode::Smooth { temperature } => {
smooth_max(input, axis, temperature)
}
QuantifierMode::Probabilistic => {
probabilistic_exists(input, axis)
}
QuantifierMode::LogSpace(config) => {
crate::scoring::LogSpaceAggregator::new(config)
.log_sum_exp(input, axis)
.map_err(|e| {
crate::error::TlBackendError::NumericalError(crate::error::NumericalError {
kind: crate::error::NumericalErrorKind::Overflow,
location: e.to_string(),
values: None,
})
})
}
}
}
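/// Backward pass for `soft_exists`. Note the `LogSpace` branch reuses the
/// smooth-max gradient at the configured temperature as an approximation.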
pub fn soft_exists_backward(
grad_output: &Scirs2Tensor,
input: &Scirs2Tensor,
_output: &Scirs2Tensor,
axis: Option<usize>,
mode: QuantifierMode,
) -> TlBackendResult<Scirs2Tensor> {
match mode {
QuantifierMode::Hard => {
argmax_gradient(grad_output, input, axis)
}
QuantifierMode::Smooth { temperature } => {
smooth_max_gradient(grad_output, input, temperature, axis)
}
QuantifierMode::Probabilistic => {
probabilistic_exists_gradient(grad_output, input, axis)
}
QuantifierMode::LogSpace(config) => {
smooth_max_gradient(grad_output, input, config.temperature, axis)
}
}
}
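/// Soft universal quantifier: the min-like dual of `soft_exists`. In
/// `LogSpace` mode it evaluates -T * logsumexp(-x / T), a smooth minimum.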
pub fn soft_forall(
input: &Scirs2Tensor,
axis: Option<usize>,
mode: QuantifierMode,
) -> TlBackendResult<Scirs2Tensor> {
match mode {
QuantifierMode::Hard => {
            // Seed the fold with INFINITY so the hard min is correct even when
            // every input exceeds 1.0.
            if let Some(ax) = axis {
                Ok(input.map_axis(Axis(ax), |slice| {
                    slice.iter().fold(f64::INFINITY, |a, &b| a.min(b))
                }))
            } else {
                let min_val = input.iter().fold(f64::INFINITY, |a, &b| a.min(b));
Ok(Array::from_elem(vec![], min_val))
}
}
QuantifierMode::Smooth { temperature } => {
smooth_min(input, axis, temperature)
}
QuantifierMode::Probabilistic => {
probabilistic_forall(input, axis)
}
QuantifierMode::LogSpace(config) => {
let neg_input = input.mapv(|x| -x / config.temperature);
let lse = crate::scoring::LogSpaceAggregator::new(config.clone())
.log_sum_exp(&neg_input, axis)
.map_err(|e| {
crate::error::TlBackendError::NumericalError(crate::error::NumericalError {
kind: crate::error::NumericalErrorKind::Overflow,
location: e.to_string(),
values: None,
})
})?;
Ok(lse.mapv(|v| -v * config.temperature))
}
}
}
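/// Backward pass for `soft_forall`; like `soft_exists_backward`, the
/// `LogSpace` branch approximates the gradient with the smooth-min gradient.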
pub fn soft_forall_backward(
grad_output: &Scirs2Tensor,
input: &Scirs2Tensor,
output: &Scirs2Tensor,
axis: Option<usize>,
mode: QuantifierMode,
) -> TlBackendResult<Scirs2Tensor> {
match mode {
QuantifierMode::Hard => {
argmin_gradient(grad_output, input, axis)
}
QuantifierMode::Smooth { temperature } => {
smooth_min_gradient(grad_output, input, temperature, axis)
}
QuantifierMode::Probabilistic => {
probabilistic_forall_gradient(grad_output, input, output, axis)
}
QuantifierMode::LogSpace(config) => {
smooth_min_gradient(grad_output, input, config.temperature, axis)
}
}
}
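/// Draws i.i.d. Gumbel(0, 1) noise by inverse-CDF sampling:
/// G = -ln(-ln(U)) with U ~ Uniform(0, 1), clamped away from {0, 1}.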
fn sample_gumbel(shape: &[usize], seed: Option<u64>) -> TlBackendResult<Scirs2Tensor> {
use scirs2_core::ndarray::IxDyn;
    let uniform_dist =
        Uniform::new(1e-10, 1.0 - 1e-10).expect("range [1e-10, 1 - 1e-10) is valid");
    let dyn_shape = IxDyn(shape);
let gumbel = if let Some(s) = seed {
let mut rng = seeded_rng(s);
ArrayD::random_bulk(dyn_shape, uniform_dist, &mut rng)
} else {
let mut rng = thread_rng();
ArrayD::random_bulk(dyn_shape, uniform_dist, &mut rng)
};
let gumbel = gumbel.mapv(|u: f64| -(-u.ln()).ln());
Ok(gumbel)
}
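/// Numerically stable tempered softmax over the last axis: subtracts the
/// per-row maximum before exponentiating.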
fn softmax_temperature(logits: &Scirs2Tensor, temperature: f64) -> TlBackendResult<Scirs2Tensor> {
let scaled = logits.mapv(|x| x / temperature);
let last_axis = scaled.ndim() - 1;
let max_vals = scaled
.map_axis(Axis(last_axis), |slice| {
slice.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
})
.insert_axis(Axis(last_axis));
let exp_vals = (&scaled - &max_vals).mapv(|x| x.exp());
let sum_exp = exp_vals
.sum_axis(Axis(last_axis))
.insert_axis(Axis(last_axis));
Ok(&exp_vals / &sum_exp)
}
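/// Converts soft class probabilities to one-hot vectors by taking the
/// arg-max along the last axis for every leading index.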
fn argmax_to_onehot(soft_samples: &Scirs2Tensor) -> TlBackendResult<Scirs2Tensor> {
let last_axis = soft_samples.ndim() - 1;
let mut onehot = ArrayD::zeros(soft_samples.raw_dim());
let n_classes = soft_samples.len_of(Axis(last_axis));
for i in 0..soft_samples.len() / n_classes {
let mut flat_idx = i;
let mut indices = vec![0; soft_samples.ndim()];
for dim in (0..last_axis).rev() {
let size = soft_samples.len_of(Axis(dim));
indices[dim] = flat_idx % size;
flat_idx /= size;
}
let mut max_val = f64::NEG_INFINITY;
let mut max_idx = 0;
for j in 0..n_classes {
indices[last_axis] = j;
let val = soft_samples[&indices[..]];
if val > max_val {
max_val = val;
max_idx = j;
}
}
indices[last_axis] = max_idx;
onehot[&indices[..]] = 1.0;
}
Ok(onehot)
}
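/// Smooth maximum via tempered log-sum-exp:
/// smooth_max(x) = T * ln(sum_i exp(x_i / T)), which approaches max(x) as
/// T -> 0. Uses max-subtraction for numerical stability.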
fn smooth_max(
input: &Scirs2Tensor,
axis: Option<usize>,
temperature: f64,
) -> TlBackendResult<Scirs2Tensor> {
let scaled = input.mapv(|x| x / temperature);
if let Some(ax) = axis {
let max_vals = scaled.map_axis(Axis(ax), |slice| {
slice.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
});
let max_vals_broadcast = max_vals.clone().insert_axis(Axis(ax));
let exp_vals = (&scaled - &max_vals_broadcast).mapv(|x| x.exp());
let sum_exp = exp_vals.sum_axis(Axis(ax));
let log_sum_exp = &max_vals + &sum_exp.mapv(|x| x.ln());
Ok(log_sum_exp.mapv(|x| x * temperature))
} else {
let max_val = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let exp_vals = scaled.mapv(|x| (x - max_val).exp());
let sum_exp: f64 = exp_vals.iter().sum();
let result = temperature * (max_val + sum_exp.ln());
Ok(Array::from_elem(vec![], result))
}
}
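/// Gradient of `smooth_max`: the softmax weights of the scaled inputs,
/// multiplied by the upstream gradient.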
fn smooth_max_gradient(
grad_output: &Scirs2Tensor,
input: &Scirs2Tensor,
temperature: f64,
axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
let scaled = input.mapv(|x| x / temperature);
if let Some(ax) = axis {
let max_vals = scaled
.map_axis(Axis(ax), |slice| {
slice.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
})
.insert_axis(Axis(ax));
let exp_vals = (&scaled - &max_vals).mapv(|x| x.exp());
let sum_exp = exp_vals.sum_axis(Axis(ax)).insert_axis(Axis(ax));
let weights = &exp_vals / &sum_exp;
        // grad_output has the reduced shape (axis removed), so re-insert the
        // axis before broadcasting against the full-shape softmax weights.
        let grad_expanded = grad_output.clone().insert_axis(Axis(ax));
        Ok(&weights * &grad_expanded)
} else {
let max_val = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let exp_vals = scaled.mapv(|x| (x - max_val).exp());
let sum_exp: f64 = exp_vals.iter().sum();
let weights = exp_vals.mapv(|x| x / sum_exp);
let grad_scalar = grad_output.iter().next().unwrap_or(&0.0);
Ok(weights.mapv(|w| w * grad_scalar))
}
}
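/// Smooth minimum via the duality smooth_min(x) = -smooth_max(-x).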
fn smooth_min(
input: &Scirs2Tensor,
axis: Option<usize>,
temperature: f64,
) -> TlBackendResult<Scirs2Tensor> {
let negated = input.mapv(|x| -x);
let result = smooth_max(&negated, axis, temperature)?;
Ok(result.mapv(|x| -x))
}
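/// Gradient of `smooth_min`: the softmax weights of the negated inputs.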
fn smooth_min_gradient(
grad_output: &Scirs2Tensor,
input: &Scirs2Tensor,
temperature: f64,
axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
    // smooth_min(x) = -smooth_max(-x); the two sign flips cancel in the chain
    // rule, so the gradient is exactly the smooth-max gradient of the negated
    // input (a convex combination, hence non-negative).
    let negated = input.mapv(|x| -x);
    smooth_max_gradient(grad_output, &negated, temperature, axis)
}
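/// Probabilistic (noisy-or) existential: exists(x) = 1 - prod_i (1 - x_i),
/// treating inputs as independent probabilities in [0, 1].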
fn probabilistic_exists(
input: &Scirs2Tensor,
axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
let one_minus_input = input.mapv(|x| 1.0 - x);
if let Some(ax) = axis {
let product = one_minus_input.map_axis(Axis(ax), |slice| slice.iter().product::<f64>());
Ok(product.mapv(|p| 1.0 - p))
} else {
let product: f64 = one_minus_input.iter().product();
Ok(Array::from_elem(vec![], 1.0 - product))
}
}
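/// Gradient of the noisy-or existential:
/// d/dx_j [1 - prod_i (1 - x_i)] = prod_{i != j} (1 - x_i).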
fn probabilistic_exists_gradient(
grad_output: &Scirs2Tensor,
input: &Scirs2Tensor,
axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
let one_minus_input = input.mapv(|x| 1.0 - x);
    if let Some(ax) = axis {
        // d/dx_j [1 - prod_i (1 - x_i)] = prod_{i != j} (1 - x_i), computed
        // per lane along `ax` as product / (1 - x_j), not over the whole
        // tensor.
        let product = one_minus_input
            .map_axis(Axis(ax), |slice| slice.iter().product::<f64>())
            .insert_axis(Axis(ax));
        // Guard the division: when some x_j == 1 the gradient is approximated
        // as 0, matching the scalar branch below.
        let inv = one_minus_input.mapv(|d| if d.abs() > 1e-10 { 1.0 / d } else { 0.0 });
        let grad = &inv * &product;
        // grad_output has the reduced shape, so re-insert the axis first.
        let grad_expanded = grad_output.clone().insert_axis(Axis(ax));
        Ok(&grad * &grad_expanded)
} else {
let product: f64 = one_minus_input.iter().product();
let grad = input.mapv(|x| {
let denom = 1.0 - x;
if denom.abs() > 1e-10 {
product / denom
} else {
0.0
}
});
let grad_scalar = grad_output.iter().next().unwrap_or(&0.0);
Ok(grad.mapv(|g| g * grad_scalar))
}
}
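/// Probabilistic universal: forall(x) = prod_i x_i for independent
/// probabilities.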
fn probabilistic_forall(
input: &Scirs2Tensor,
axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
if let Some(ax) = axis {
Ok(input.map_axis(Axis(ax), |slice| slice.iter().product::<f64>()))
} else {
let product: f64 = input.iter().product();
Ok(Array::from_elem(vec![], product))
}
}
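/// Gradient of the product universal: d/dx_j prod_i x_i = prod_{i != j} x_i,
/// computed as output / x_j with a guard for x_j near zero.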
fn probabilistic_forall_gradient(
grad_output: &Scirs2Tensor,
input: &Scirs2Tensor,
output: &Scirs2Tensor,
axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
    if let Some(ax) = axis {
        // d(prod_i x_i)/dx_j = prod_{i != j} x_i = output / x_j, with a guard
        // for x_j near zero and the reduced axis re-inserted for broadcasting.
        let inv = input.mapv(|x| if x.abs() > 1e-10 { 1.0 / x } else { 0.0 });
        let output_expanded = output.clone().insert_axis(Axis(ax));
        let grad = &inv * &output_expanded;
        let grad_expanded = grad_output.clone().insert_axis(Axis(ax));
        Ok(&grad * &grad_expanded)
} else {
let output_val = output.iter().next().unwrap_or(&0.0);
let grad = input.mapv(|x| if x.abs() > 1e-10 { output_val / x } else { 0.0 });
let grad_scalar = grad_output.iter().next().unwrap_or(&0.0);
Ok(grad.mapv(|g| g * grad_scalar))
}
}
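/// Subgradient of the hard max: routes the entire upstream gradient to the
/// arg-max element of each lane (ties broken by first occurrence).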
fn argmax_gradient(
grad_output: &Scirs2Tensor,
input: &Scirs2Tensor,
axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
let mut grad = ArrayD::zeros(input.raw_dim());
    if let Some(ax) = axis {
        // Iterate over the lanes along `ax` (one lane per reduced output
        // element) and route each upstream gradient to its arg-max position;
        // indexing along the reduction axis itself walked the wrong slices
        // and mis-shaped grad_output.
        for ((lane, mut lane_grad), &g) in input
            .lanes(Axis(ax))
            .into_iter()
            .zip(grad.lanes_mut(Axis(ax)))
            .zip(grad_output.iter())
        {
            let max_idx = lane
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(idx, _)| idx)
                .unwrap_or(0);
            lane_grad[max_idx] = g;
        }
} else {
let max_idx = input
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(idx, _)| idx)
.unwrap_or(0);
grad.as_slice_mut().expect("grad has contiguous layout")[max_idx] =
*grad_output.iter().next().unwrap_or(&0.0);
}
Ok(grad)
}
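/// Subgradient of the hard min: the arg-min mirror of `argmax_gradient`.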
fn argmin_gradient(
grad_output: &Scirs2Tensor,
input: &Scirs2Tensor,
axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
let mut grad = ArrayD::zeros(input.raw_dim());
    if let Some(ax) = axis {
        // Same lane-wise routing as `argmax_gradient`, targeting the arg-min
        // position of each lane.
        for ((lane, mut lane_grad), &g) in input
            .lanes(Axis(ax))
            .into_iter()
            .zip(grad.lanes_mut(Axis(ax)))
            .zip(grad_output.iter())
        {
            let min_idx = lane
                .iter()
                .enumerate()
                .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(idx, _)| idx)
                .unwrap_or(0);
            lane_grad[min_idx] = g;
        }
} else {
let min_idx = input
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(idx, _)| idx)
.unwrap_or(0);
grad.as_slice_mut().expect("grad has contiguous layout")[min_idx] =
*grad_output.iter().next().unwrap_or(&0.0);
}
Ok(grad)
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
#[test]
fn test_ste_threshold_forward() {
let input = array![[0.2, 0.6], [0.4, 0.8]].into_dyn();
let config = SteConfig::default();
let output = ste_threshold(&input, config).expect("unwrap");
let expected = array![[0.0, 1.0], [0.0, 1.0]].into_dyn();
assert_eq!(output, expected);
}
#[test]
fn test_ste_threshold_backward() {
let grad_output = array![[1.0, 2.0], [3.0, 4.0]].into_dyn();
let input = array![[0.2, 0.6], [0.4, 0.8]].into_dyn();
let config = SteConfig::default();
let grad_input = ste_threshold_backward(&grad_output, &input, config).expect("unwrap");
assert_eq!(grad_input, grad_output);
}
#[test]
fn test_ste_gradient_clipping() {
let grad_output = array![[5.0, -3.0], [0.5, -10.0]].into_dyn();
let input = array![[0.2, 0.6], [0.4, 0.8]].into_dyn();
let config = SteConfig {
threshold: 0.5,
clip_gradients: true,
};
let grad_input = ste_threshold_backward(&grad_output, &input, config).expect("unwrap");
let expected = array![[1.0, -1.0], [0.5, -1.0]].into_dyn();
assert_eq!(grad_input, expected);
}
#[test]
fn test_gumbel_softmax_deterministic() {
let logits = array![[1.0, 2.0, 3.0]].into_dyn();
let config = GumbelSoftmaxConfig {
temperature: 1.0,
hard: false,
seed: Some(42),
};
let samples = gumbel_softmax(&logits, config).expect("unwrap");
assert_eq!(samples.shape(), &[1, 3]);
let sum: f64 = samples.iter().sum();
assert!((sum - 1.0).abs() < 1e-6);
for &val in samples.iter() {
assert!((0.0..=1.0).contains(&val));
}
}
#[test]
fn test_gumbel_softmax_hard_mode() {
let logits = array![[1.0, 5.0, 2.0]].into_dyn();
let config = GumbelSoftmaxConfig {
temperature: 0.1,
hard: true,
seed: Some(123),
};
let samples = gumbel_softmax(&logits, config).expect("unwrap");
let sum: f64 = samples.iter().sum();
assert!((sum - 1.0).abs() < 1e-6);
let max_val = samples.iter().fold(0.0_f64, |a, &b| a.max(b));
assert!(max_val >= 0.9);
}
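    // A hedged invariance check added alongside the original suite: since the
    // soft samples sum to one, the backward formula s * (g - <g, s>) / T must
    // yield logit gradients that sum to (numerically) zero per row.
    #[test]
    fn test_gumbel_softmax_backward_grad_sums_to_zero() {
        let logits = array![[1.0, 2.0, 3.0]].into_dyn();
        let config = GumbelSoftmaxConfig {
            temperature: 1.0,
            hard: false,
            seed: Some(7),
        };
        let soft = gumbel_softmax(&logits, config).expect("forward should succeed");
        let grad_output = array![[0.3, -1.0, 2.0]].into_dyn();
        let grad =
            gumbel_softmax_backward(&grad_output, &soft, config).expect("backward should succeed");
        let total: f64 = grad.iter().sum();
        assert!(total.abs() < 1e-9, "logit gradients should sum to 0, got {total}");
    }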
#[test]
fn test_soft_exists_smooth() {
let input = array![[0.1, 0.3], [0.2, 0.9]].into_dyn();
let mode = QuantifierMode::Smooth { temperature: 1.0 };
let output = soft_exists(&input, Some(1), mode).expect("unwrap");
assert_eq!(output.shape(), &[2]);
assert!(
output[0] >= 0.85 && output[0] <= 0.95,
"output[0] = {} not in [0.85, 0.95]",
output[0]
);
assert!(
output[1] >= 1.25 && output[1] <= 1.35,
"output[1] = {} not in [1.25, 1.35]",
output[1]
);
}
#[test]
fn test_soft_exists_probabilistic() {
let input = array![[0.5, 0.5]].into_dyn();
let mode = QuantifierMode::Probabilistic;
let output = soft_exists(&input, Some(1), mode).expect("unwrap");
assert!((output[0] - 0.75).abs() < 1e-6);
}
#[test]
fn test_soft_forall_probabilistic() {
let input = array![[0.5, 0.5]].into_dyn();
let mode = QuantifierMode::Probabilistic;
let output = soft_forall(&input, Some(1), mode).expect("unwrap");
assert!((output[0] - 0.25).abs() < 1e-6);
}
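    // Hedged consistency check: the probabilistic quantifiers should satisfy
    // the De Morgan relation exists(x) = 1 - forall(1 - x).
    #[test]
    fn test_probabilistic_de_morgan_duality() {
        let input = array![[0.3, 0.6]].into_dyn();
        let exists = soft_exists(&input, Some(1), QuantifierMode::Probabilistic)
            .expect("exists should succeed");
        let complement = input.mapv(|x| 1.0 - x);
        let forall = soft_forall(&complement, Some(1), QuantifierMode::Probabilistic)
            .expect("forall should succeed");
        assert!((exists[0] - (1.0 - forall[0])).abs() < 1e-12);
    }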
#[test]
fn test_probabilistic_forall_gradient() {
let input = array![[0.5, 0.8]].into_dyn();
        let output = array![0.4].into_dyn();
        let grad_output = array![1.0].into_dyn();
let grad_input =
probabilistic_forall_gradient(&grad_output, &input, &output, Some(1)).expect("unwrap");
assert!((grad_input[[0, 0]] - 0.8).abs() < 1e-6);
assert!((grad_input[[0, 1]] - 0.5).abs() < 1e-6);
}
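    // Hedged routing check for the hard backward pass: the entire upstream
    // gradient should land on the arg-max position of each lane.
    #[test]
    fn test_argmax_gradient_routes_to_max() {
        let input = array![[1.0, 3.0, 2.0]].into_dyn();
        let grad_output = array![2.0].into_dyn();
        // The `_output` argument is unused by the Hard branch.
        let grad = soft_exists_backward(&grad_output, &input, &input, Some(1), QuantifierMode::Hard)
            .expect("backward should succeed");
        assert_eq!(grad, array![[0.0, 2.0, 0.0]].into_dyn());
    }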
#[test]
fn test_smooth_max_vs_hard_max() {
let input = array![[1.0, 2.0, 3.0]].into_dyn();
let hard = soft_exists(&input, Some(1), QuantifierMode::Hard).expect("unwrap");
assert!((hard[0] - 3.0).abs() < 1e-6);
let smooth = soft_exists(
&input,
Some(1),
QuantifierMode::Smooth { temperature: 0.01 },
)
.expect("unwrap");
        assert!((smooth[0] - 3.0).abs() < 0.1);
    }
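    // Min-side companion to the max test above, added as a sanity sketch:
    // soft_forall with a small temperature should approach the hard minimum.
    #[test]
    fn test_smooth_min_vs_hard_min() {
        let input = array![[1.0, 2.0, 3.0]].into_dyn();
        let hard = soft_forall(&input, Some(1), QuantifierMode::Hard).expect("unwrap");
        assert!((hard[0] - 1.0).abs() < 1e-6);
        let smooth = soft_forall(
            &input,
            Some(1),
            QuantifierMode::Smooth { temperature: 0.01 },
        )
        .expect("unwrap");
        assert!((smooth[0] - 1.0).abs() < 0.1);
    }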
#[test]
fn test_gumbel_noise_properties() {
let shape = &[1000];
let noise = sample_gumbel(shape, Some(42)).expect("unwrap");
let mean: f64 = noise.iter().sum::<f64>() / noise.len() as f64;
assert!((mean - 0.5772).abs() < 0.1);
for &val in noise.iter() {
assert!(val.is_finite());
}
}
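    // Hedged gradient check: d smooth_min / dx is a softmax over the negated
    // inputs, so its entries are positive and sum to the upstream gradient.
    #[test]
    fn test_smooth_min_gradient_is_convex_combination() {
        let input = array![1.0, 2.0, 3.0].into_dyn();
        let grad_output = array![1.0].into_dyn();
        let grad =
            smooth_min_gradient(&grad_output, &input, 1.0, None).expect("gradient should succeed");
        let total: f64 = grad.iter().sum();
        assert!((total - 1.0).abs() < 1e-9);
        for &g in grad.iter() {
            assert!(g > 0.0);
        }
    }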
}