candle-nn 0.10.2

Minimalist ML framework.
Documentation
use candle::{Result, Tensor};

/// Draws indices along `dim` by taking the argmax of noise-perturbed logits
/// (Gumbel-Softmax sampling).
///
/// # Arguments
///
/// * `logits` - Unnormalized scores to sample from.
/// * `temperature` - Scale applied to the noise term. A value `<= 0.0` skips
///   the noise entirely and returns the plain `argmax` of `logits`.
/// * `dim` - Dimension over which the argmax is taken.
///
/// # Errors
///
/// Propagates any error returned by the underlying tensor operations (dtype
/// conversion, random generation, arithmetic, `argmax`).
pub fn gumbel_softmax<D: candle::shape::Dim>(
    logits: &Tensor,
    temperature: f64,
    dim: D,
) -> Result<Tensor> {
    // Greedy path: no noise, argmax directly on the input tensor as given.
    if temperature <= 0.0 {
        return logits.argmax(dim);
    }

    // Work in f32: running the Gumbel softmax in bf16 is a bit unstable.
    let logits_f32 = logits.to_dtype(candle::DType::F32)?;

    // With u drawn by `rand_like(1e-7, 0.999)`, `log(-log(u))` is the negated
    // Gumbel noise. NOTE(review): the bounds presumably keep u away from 0 and
    // 1 so that both logarithms stay finite — confirm against `rand_like`.
    let uniform = logits_f32.rand_like(1e-7, 0.999)?;
    let neg_gumbel = uniform.log()?.neg()?.log()?;

    let perturbed = if temperature == 1.0 {
        // Unit temperature: subtract the negated noise as-is, no scaling.
        (logits_f32 - neg_gumbel)?
    } else {
        // Scaling the noise by `temperature` (rather than dividing the logits)
        // leaves the argmax unchanged for positive temperatures.
        (logits_f32 + neg_gumbel * (-temperature))?
    };
    perturbed.argmax(dim)
}