/// Selects which activation function to use.
///
/// Every variant is fieldless, so no per-activation parameters are stored
/// here (for example, `LeakyReLU` carries no slope value).
// Variant spellings follow the conventional names of the activations, so
// the all-caps `GELU` is intentional rather than an acronym-style slip.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ActivationKind {
    /// Rectified linear unit.
    ReLU,
    /// Leaky rectified linear unit.
    LeakyReLU,
    /// Logistic sigmoid.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// Gaussian error linear unit.
    GELU,
    /// Sigmoid linear unit.
    SiLU,
    /// Softplus.
    Softplus,
    /// Mish.
    Mish,
}
/// Converts a possibly-negative dimension index into an absolute index in `0..ndim`.
///
/// Non-negative `dim` is used as-is and must be `< ndim`. Negative `dim`
/// counts from the end, so `-1` maps to `ndim - 1` and `-ndim` maps to `0`.
///
/// Returns `None` when `dim` falls outside `-ndim..ndim`. In particular,
/// every `dim` yields `None` when `ndim == 0`.
pub fn normalize_softmax_dim(ndim: usize, dim: isize) -> Option<usize> {
    match usize::try_from(dim) {
        // Non-negative: index from the front; valid only if in bounds.
        Ok(d) => (d < ndim).then_some(d),
        // Negative: count back from the end. `unsigned_abs` handles
        // `isize::MIN` without overflow, and `checked_sub` avoids casting
        // `ndim` to `isize`, which would wrap for `ndim > isize::MAX`.
        Err(_) => ndim.checked_sub(dim.unsigned_abs()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks in-range, out-of-range, and boundary indices for both signs.
    #[test]
    fn test_normalize_softmax_dim() {
        // (ndim, dim, expected)
        let cases: &[(usize, isize, Option<usize>)] = &[
            (3, 0, Some(0)),
            (3, 1, Some(1)),
            (3, 2, Some(2)),
            (3, 3, None),
            (3, -1, Some(2)),
            (3, -3, Some(0)),
            (3, -4, None),
            // A zero-rank input has no valid dimension at all.
            (0, 0, None),
            (0, -1, None),
            (3, isize::MAX, None),
            (3, isize::MIN, None),
        ];
        for &(ndim, dim, expected) in cases {
            assert_eq!(
                normalize_softmax_dim(ndim, dim),
                expected,
                "ndim={ndim}, dim={dim}"
            );
        }
    }
}