//! SIMD-vectorized implementations of operations used in neural networks.
//!
//! These implementations are used as kernels for operations in the
//! [rten](https://crates.io/crates/rten) crate.
//!
//! ## Constructing and dispatching operations
//!
//! The operations are implemented by structs which implement the SIMD operation
//! traits from [rten-simd](rten_simd). To apply an operation to data, first
//! construct the operation using the struct from this crate, then use a
//! dispatch method from the [`SimdOp`](rten_simd::SimdOp) or
//! [`SimdUnaryOp`](rten_simd::SimdUnaryOp) traits to execute
//! the operation.
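//!
//! For example, a reduction such as [`Sum`] is constructed with its input and
//! then executed with its `dispatch` method. This is a minimal sketch; the
//! Examples section below has complete versions of each pattern.
//!
//! ```
//! use rten_simd::SimdOp;
//! use rten_vecmath::Sum;
//!
//! // Construct the operation, then execute it via a dispatch method.
//! let data = [1., 0.5, 2.0];
//! let sum = Sum::new(&data).dispatch();
//! assert!(sum > 0.);
//! ```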
//!
//! ## In-place and out-of-place operations
//!
//! Some operations support either updating data in place or reading input from
//! one slice and writing the result to another. For unary operations this is
//! controlled by dispatching with either [`map`](rten_simd::SimdUnaryOp::map)
//! or [`map_mut`](rten_simd::SimdUnaryOp::map_mut). For other operations this
//! is handled by separate constructors for the out-of-place and in-place cases,
//! such as [`Softmax::new`] and [`Softmax::new_mut`].
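//!
//! For instance, the same [`Erf`] operation can be dispatched either in place
//! or with separate input and output slices (a short sketch; the Examples
//! section below shows both forms in full):
//!
//! ```
//! use std::mem::MaybeUninit;
//!
//! use rten_simd::SimdUnaryOp;
//! use rten_vecmath::Erf;
//!
//! // In place: the input slice is overwritten with the results.
//! let mut data = [1., 0.5, 2.0];
//! Erf {}.map_mut(&mut data);
//!
//! // Out of place: read from `src`, write into an uninitialized `dest`.
//! let src = [1., 0.5, 2.0];
//! let mut dest = [MaybeUninit::uninit(); 3];
//! Erf {}.map(&src, &mut dest);
//! ```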
//!
//! For operations which use a separate source and destination, the destination
//! is expected to be an uninitialized slice (`[MaybeUninit<T>]`). This allows
//! the caller to control allocation of the buffer and avoid the overhead of
//! initializing elements which the operation will overwrite. The [`ExtendInit`]
//! trait provides a safe API for the common task of filling a new `Vec` with
//! the result of the operation; see the softmax example below.
//!
//! ## Examples
//!
//! ### Applying a vectorized unary function
//!
//! ```
//! use std::mem::MaybeUninit;
//!
//! use rten_simd::SimdUnaryOp;
//! use rten_vecmath::Erf;
//!
//! // Apply the error function to each element of `data`.
//! let mut data = [1., 0.5, 2.0];
//! let erf_op = Erf {};
//! erf_op.map_mut(&mut data);
//!
//! // Apply the error function to each element of `src`, writing to `dest`.
//! let src = [1., 0.5, 2.0];
//! let mut dest = [MaybeUninit::uninit(); 3];
//! erf_op.map(&src, &mut dest);
//! ```
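//!
//! Other unary functions exported by this crate, such as [`Sigmoid`] and
//! [`Tanh`], are presumed to follow the same pattern. The sketch below assumes
//! they are field-less structs implementing
//! [`SimdUnaryOp`](rten_simd::SimdUnaryOp), like [`Erf`] above.
//!
//! ```
//! use rten_simd::SimdUnaryOp;
//! use rten_vecmath::{Sigmoid, Tanh};
//!
//! let mut data = [1., 0.5, 2.0];
//! // Assumption: `Sigmoid` and `Tanh` dispatch the same way as `Erf`.
//! Sigmoid {}.map_mut(&mut data);
//! Tanh {}.map_mut(&mut data);
//! ```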
//!
//! ### Applying softmax in place
//!
//! This example applies the softmax function in-place to a mutable slice.
//!
//! ```
//! use rten_simd::SimdOp;
//! use rten_vecmath::Softmax;
//!
//! let mut data = [1., 0.5, 2.0];
//! Softmax::new_mut(&mut data).dispatch();
//! ```
//!
//! ### Applying softmax with separate input and output buffers
//!
//! This example reads data from an input and writes to an uninitialized output
//! buffer (`&mut [MaybeUninit<f32>]`), obtained from the uninitialized portion
//! of a `Vec<f32>`. To update the length of the `Vec<f32>` after it is
//! initialized, the helper `ExtendInit` trait is used.
//!
//! ```
//! use rten_simd::SimdOp;
//! use rten_vecmath::{Softmax, ExtendInit};
//!
//! let data = [1., 0.5, 2.0];
//! let mut output = Vec::with_capacity(data.len());
//! output.extend_init(|output_uninit| {
//!     // `output_uninit` is the uninitialized part of `output`, as returned by
//!     // `output.spare_capacity_mut()`.
//!     //
//!     // The `dispatch` call initializes it and returns the initialized slice.
//!     Softmax::new(&data, output_uninit).dispatch()
//! });
//! assert_eq!(output.len(), 3);
//! ```
//!
//! ### Computing the sum of a list of floats
//!
//! ```
//! use rten_simd::SimdOp;
//! use rten_vecmath::Sum;
//!
//! let data = [1., 0.5, 2.0];
//! let sum = Sum::new(&data).dispatch();
//! ```
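//!
//! The dispatched reduction returns an ordinary scalar which can be combined
//! with plain Rust code, for example to compute a mean. This sketch assumes
//! the elements are `f32`.
//!
//! ```
//! use rten_simd::SimdOp;
//! use rten_vecmath::Sum;
//!
//! let data = [1., 0.5, 2.0];
//! // Assumption: `Sum` operates on `f32` slices here.
//! let mean = Sum::new(&data).dispatch() / data.len() as f32;
//! assert!((mean - 3.5 / 3.0).abs() < 1e-6);
//! ```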

mod erf;
mod exp;
mod min_max;
mod normalize;
mod quantize;
mod relu;
mod sin_cos;
mod softmax;
mod sum;
mod tanh;

#[cfg(test)]
mod ulp;

#[cfg(test)]
mod testing;

mod extend_init;

// Unary functions.
pub use erf::{ApproxGelu, Erf, Gelu};
pub use exp::{Elu, Exp, Sigmoid, Silu, Swish};
pub use quantize::Quantize;
pub use relu::LeakyRelu;
pub use sin_cos::{Cos, Sin};
pub use tanh::Tanh;

// Normalization and reduction functions.
pub use min_max::{MaxNum, MinMax, MinNum};
pub use normalize::{Normalize, NormalizeOptions};
pub use softmax::Softmax;
pub use sum::{Sum, SumSquare, SumSquareSub};

// Utilities.
pub use extend_init::ExtendInit;