pub struct Linear<const I: usize, const O: usize> {
    pub weight: Tensor2D<O, I, NoTape>,
    pub bias: Tensor1D<O, NoTape>,
}
A linear transformation of the form x * transpose(W) + b, where W is a matrix, x is a vector or matrix, and b is a vector. If x is a matrix, this performs a matrix multiplication in which each row of x is treated as one input vector.
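For a single input vector, the formula reduces to y[o] = b[o] + sum over i of W[o][i] * x[i], because W is stored with shape (O, I). The sketch below spells this out with plain fixed-size arrays standing in for Tensor1D and Tensor2D; the function name linear_forward is hypothetical and is not part of the library.

```rust
/// Hypothetical helper: computes x * transpose(W) + b for one input vector.
/// `w` has shape (O, I), matching the layout of `Linear::weight`.
fn linear_forward<const I: usize, const O: usize>(
    w: &[[f32; I]; O],
    b: &[f32; O],
    x: &[f32; I],
) -> [f32; O] {
    let mut y = *b; // start from the bias
    for (y_o, row) in y.iter_mut().zip(w.iter()) {
        // Row `o` of W dotted with x gives the o-th output element.
        *y_o += row.iter().zip(x.iter()).map(|(w, x)| w * x).sum::<f32>();
    }
    y
}
```

With W = [[1, 0, 2], [0, 1, -1]], b = [0.5, -0.5] and x = [1, 2, 3], the outputs are 1 + 6 + 0.5 = 7.5 and 2 - 3 - 0.5 = -1.5.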
Implements:
- Module for vectors like Tensor1D
- Module for matrices like Tensor2D<B, I>, where B is the batch size
- ResetParams to set weight & bias to uniform random numbers from a distribution based on I
- CanUpdateWithGradients
- SaveToNpz
- LoadFromNpz
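For a batch input of shape (B, I), every row is transformed independently and the same bias is added to each row, producing a (B, O) output. The following is a minimal sketch of that batched computation, again assuming plain arrays in place of Tensor2D; linear_forward_batch is a hypothetical name used only for illustration.

```rust
/// Hypothetical helper: computes x * transpose(W) + b for a batch of B inputs.
/// `x` has shape (B, I), `w` has shape (O, I), and the result has shape (B, O).
fn linear_forward_batch<const B: usize, const I: usize, const O: usize>(
    w: &[[f32; I]; O],
    b: &[f32; O],
    x: &[[f32; I]; B],
) -> [[f32; O]; B] {
    let mut y = [*b; B]; // broadcast the bias to every row of the batch
    for (y_row, x_row) in y.iter_mut().zip(x.iter()) {
        for (y_o, w_row) in y_row.iter_mut().zip(w.iter()) {
            // (x * transpose(W))[r][o] = sum over i of x[r][i] * W[o][i]
            *y_o += x_row.iter().zip(w_row.iter()).map(|(a, c)| a * c).sum::<f32>();
        }
    }
    y
}
```

Each output row matches what the single-vector formula would give for the corresponding input row.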
Generics:
- I: The input size of vectors & matrices.
- O: The output size of vectors & matrices.
Example usage:
Linear<5, 2> acts on vectors with 5 elements and produces vectors with 2 elements.
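use dfdx::prelude::*;

// Default::default() creates a model whose weight and bias are all zeros.
let model: Linear<5, 2> = Default::default();
assert_eq!(model.weight.data(), &[[0.0; 5]; 2]);
assert_eq!(model.bias.data(), &[0.0; 2]);

// A 5-element input maps to a 2-element output; with zero parameters it is all zeros.
let x: Tensor1D<5> = Default::default();
let y: Tensor1D<2> = model.forward(x);
assert_eq!(y.data(), &[0.0; 2]);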
Fields
weight: Tensor2D<O, I, NoTape>
Transposed weight matrix, shape (O, I)
bias: Tensor1D<O, NoTape>
Bias vector, shape (O, )
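Because the weight has shape (O, I) and the bias has shape (O,), a Linear<I, O> holds O * I + O trainable values; for Linear<5, 2> that is 2 * 5 + 2 = 12. The tiny helper below encodes this arithmetic; linear_param_count is a hypothetical name, not a library function.

```rust
/// Hypothetical helper: number of scalars stored in a Linear<I, O>,
/// i.e. an (O, I) weight matrix plus a length-O bias vector.
const fn linear_param_count<const I: usize, const O: usize>() -> usize {
    O * I + O
}
```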
Trait Implementations
impl<const I: usize, const O: usize> CanUpdateWithGradients for Linear<I, O>
fn update<G: GradientProvider>(&mut self, grads: &mut G)
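The signature shows only that Linear receives a GradientProvider; the actual update rule is determined by whatever supplies the gradients, not by Linear itself. Purely as an illustration, and assuming a plain stochastic gradient descent step (param -= lr * grad), the effect on the weight and bias looks like the sketch below. The name sgd_step and the learning-rate parameter lr are hypothetical, and plain arrays stand in for the tensors.

```rust
/// Hypothetical illustration of one plain SGD step on a Linear<I, O>'s
/// parameters. The real rule depends on the GradientProvider/optimizer.
fn sgd_step<const I: usize, const O: usize>(
    weight: &mut [[f32; I]; O],
    bias: &mut [f32; O],
    grad_weight: &[[f32; I]; O],
    grad_bias: &[f32; O],
    lr: f32,
) {
    // Move each weight against its gradient.
    for (w_row, g_row) in weight.iter_mut().zip(grad_weight.iter()) {
        for (w, g) in w_row.iter_mut().zip(g_row.iter()) {
            *w -= lr * g;
        }
    }
    // Same step for the bias vector.
    for (b, g) in bias.iter_mut().zip(grad_bias.iter()) {
        *b -= lr * g;
    }
}
```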
impl<const I: usize, const O: usize> LoadFromNpz for Linear<I, O>
impl<const B: usize, const I: usize, const O: usize, H: Tape> Module<Tensor2D<B, I, H>> for Linear<I, O>
impl<const I: usize, const O: usize> ResetParams for Linear<I, O>
fn reset_params<R: Rng>(&mut self, rng: &mut R)
Initializes self.weight and self.bias from a uniform distribution over [-1 / sqrt(I), 1 / sqrt(I)]. This uses Randomize::randomize() to set the values of each tensor.
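The bound shrinks as the input size grows: for I = 4 every parameter lands in [-0.5, 0.5]. The sketch below reproduces the sampling range with plain arrays. To stay dependency-free it uses a small hand-written xorshift generator (XorShift64, hypothetical) instead of a real `rand::Rng`, so it illustrates the distribution bounds only and is not the library's Randomize implementation; reset_params_sketch is likewise a hypothetical name.

```rust
/// Minimal xorshift64 generator, a dependency-free stand-in for `rand::Rng`.
/// The seed must be nonzero. Suitable for illustration only.
struct XorShift64(u64);

impl XorShift64 {
    /// Returns a float in [0, 1) built from the top 24 bits of the state.
    fn next_f32(&mut self) -> f32 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        (self.0 >> 40) as f32 / (1u64 << 24) as f32
    }
}

/// Hypothetical helper: fills an (O, I) weight and a length-O bias with
/// samples from a uniform distribution over [-1/sqrt(I), 1/sqrt(I)).
fn reset_params_sketch<const I: usize, const O: usize>(
    weight: &mut [[f32; I]; O],
    bias: &mut [f32; O],
    rng: &mut XorShift64,
) {
    let bound = 1.0 / (I as f32).sqrt();
    // Map a [0, 1) sample to [-bound, bound).
    let mut sample = || bound * (2.0 * rng.next_f32() - 1.0);
    for row in weight.iter_mut() {
        for w in row.iter_mut() {
            *w = sample();
        }
    }
    for b in bias.iter_mut() {
        *b = sample();
    }
}
```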
Auto Trait Implementations
impl<const I: usize, const O: usize> RefUnwindSafe for Linear<I, O>
impl<const I: usize, const O: usize> Send for Linear<I, O>
impl<const I: usize, const O: usize> Sync for Linear<I, O>
impl<const I: usize, const O: usize> Unpin for Linear<I, O>
impl<const I: usize, const O: usize> UnwindSafe for Linear<I, O>
Blanket Implementations
impl<T> BorrowMut<T> for T where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.