Softmax

Struct Softmax 

Source
pub struct Softmax<'src, 'dst> { /* private fields */ }

Computes the softmax function over a slice of floats.

The implementation uses a three-pass approach for numerical stability. See https://ogunlao.github.io/2020/04/26/you_dont_really_know_softmax.html and https://arxiv.org/abs/2001.04438.
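As a rough illustration of the three-pass idea, the scalar sketch below assumes the passes are the conventional ones: find the maximum, compute `exp(x - max)` while accumulating the sum, then divide by the sum. The function name `softmax_three_pass` is hypothetical and this is not the crate's vectorized code.

```rust
/// Illustrative scalar three-pass softmax. Hypothetical helper, not part of
/// the crate's API and not its SIMD implementation.
fn softmax_three_pass(input: &[f32]) -> Vec<f32> {
    // Pass 1: find the maximum so that every exponent below is <= 0,
    // which keeps `exp` from overflowing to infinity.
    let max = input.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    // Pass 2: compute exp(x - max) and accumulate the sum.
    let mut out = Vec::with_capacity(input.len());
    let mut sum = 0.0f32;
    for &x in input {
        let e = (x - max).exp();
        sum += e;
        out.push(e);
    }

    // Pass 3: normalize so the outputs sum to 1.
    for y in out.iter_mut() {
        *y /= sum;
    }
    out
}
```

Without the subtraction in pass 1, an input such as `[1000.0, 1000.0]` would overflow, because `exp(1000.0)` is infinite in `f32`. With it, both exponents are zero and each output is 0.5.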

Implementations§

Source§

impl<'src, 'dst> Softmax<'src, 'dst>

Source

pub fn new(input: &'src [f32], output: &'dst mut [MaybeUninit<f32>]) -> Self

Construct a softmax operation which reads input and writes to output.
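The `MaybeUninit<f32>` output lets a caller pass a buffer that has not been initialized yet and receive an initialized `&mut [f32]` back. The sketch below shows that general pattern with a hypothetical scalar function, `softmax_into`. The length check is an assumption made for the example, not a documented behavior of `Softmax::new`.

```rust
use std::mem::MaybeUninit;

/// Hypothetical scalar stand-in showing how an operation can write into an
/// uninitialized buffer and hand back an initialized slice.
fn softmax_into<'dst>(
    input: &[f32],
    output: &'dst mut [MaybeUninit<f32>],
) -> &'dst mut [f32] {
    // Assumption for this sketch: input and output have equal lengths.
    assert_eq!(input.len(), output.len());

    let max = input.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0.0f32;
    for (dst, &x) in output.iter_mut().zip(input) {
        let e = (x - max).exp();
        sum += e;
        // Initialize each output slot exactly once.
        dst.write(e);
    }

    // SAFETY: the loop above wrote every element of `output`, and
    // `MaybeUninit<f32>` has the same layout as `f32`.
    let out: &'dst mut [f32] =
        unsafe { &mut *(output as *mut [MaybeUninit<f32>] as *mut [f32]) };

    for y in out.iter_mut() {
        *y /= sum;
    }
    out
}
```

A caller could allocate the destination with `[MaybeUninit::<f32>::uninit(); N]` and avoid zero-filling memory that is about to be overwritten.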

Source

pub fn new_mut(input: &'dst mut [f32]) -> Self
where 'dst: 'src,

Construct a softmax operation which updates input in place.

Source

pub fn flush_nans_to_zero(self, flush: bool) -> Self

Replace NaN values in the output with zeros.

This option exists to change how the case where all input values are negative infinity is handled. In that case the output would normally be NaN.

In the context of attention operations which use negative infinity to represent masked token positions, it is preferable to produce zeros as the output if all input positions are masked. See https://github.com/pytorch/pytorch/issues/41508.
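To see where the NaN comes from, consider a scalar sketch. When every input is negative infinity the maximum is also negative infinity, so `x - max` evaluates to `-inf - (-inf)`, which is NaN. The function `softmax_flush` below is a hypothetical illustration of the flag's effect, not the crate's implementation.

```rust
/// Hypothetical scalar softmax with an optional NaN-to-zero flush,
/// mirroring the behavior described for `flush_nans_to_zero`.
fn softmax_flush(input: &[f32], flush_nans_to_zero: bool) -> Vec<f32> {
    let max = input.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    // If all inputs are -inf then max is -inf and (x - max) is NaN.
    let mut out: Vec<f32> = input.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = out.iter().sum();

    for y in out.iter_mut() {
        *y /= sum;
        if flush_nans_to_zero && y.is_nan() {
            *y = 0.0;
        }
    }
    out
}
```

With a fully masked row such as `[f32::NEG_INFINITY; 3]`, the unflushed result is all NaN, while the flushed result is all zeros. A partially masked row such as `[f32::NEG_INFINITY, 0.0]` is unaffected by the flag and yields `[0.0, 1.0]`.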

Trait Implementations§

Source§

impl<'dst> SimdOp for Softmax<'_, 'dst>

Source§

type Output = &'dst mut [f32]

The normalized elements.

Source§

fn eval<I: Isa>(self, isa: I) -> Self::Output

Evaluate the operation using the given instruction set.
Source§

fn dispatch(self) -> Self::Output
where Self: Sized,

Dispatch this operation using the preferred ISA for the current platform.

Auto Trait Implementations§

§

impl<'src, 'dst> Freeze for Softmax<'src, 'dst>

§

impl<'src, 'dst> RefUnwindSafe for Softmax<'src, 'dst>

§

impl<'src, 'dst> Send for Softmax<'src, 'dst>

§

impl<'src, 'dst> Sync for Softmax<'src, 'dst>

§

impl<'src, 'dst> Unpin for Softmax<'src, 'dst>

§

impl<'src, 'dst> !UnwindSafe for Softmax<'src, 'dst>

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self.
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value.
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.