pub struct Linear<const I: usize, const O: usize> {
    pub weight: Tensor2D<I, O, NoTape>,
    pub bias: Tensor1D<O, NoTape>,
}
A linear transformation of the form xW + b, where W is a matrix, x is a vector or matrix, and b is a vector. If x is a matrix (a batch of row vectors), xW is a matrix-matrix product and b is added to every row of the result.
Implements Module for both vectors of size [I] and batches of vectors of size [B, I].
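As a rough illustration of the vector case, the sketch below computes y = xW + b with plain Rust arrays instead of dfdx tensors. PlainLinear and its forward method are hypothetical names introduced here; the layout only mirrors the documented field shapes (weight is (I, O), bias is (O,)), so each output element is y[o] = b[o] + sum over i of x[i] * W[i][o].

```rust
/// Hypothetical stand-in for Linear<I, O> using plain arrays:
/// weight has shape (I, O) and bias has shape (O,), as in the fields below.
struct PlainLinear<const I: usize, const O: usize> {
    weight: [[f32; O]; I],
    bias: [f32; O],
}

impl<const I: usize, const O: usize> PlainLinear<I, O> {
    /// Computes y = xW + b for one input vector x of length I.
    fn forward(&self, x: &[f32; I]) -> [f32; O] {
        // Start from the bias, then accumulate x[i] * W[i][o] into y[o].
        let mut y = self.bias;
        for (xi, w_row) in x.iter().zip(self.weight.iter()) {
            for (yo, wio) in y.iter_mut().zip(w_row.iter()) {
                *yo += xi * wio;
            }
        }
        y
    }
}

fn main() {
    let model = PlainLinear::<3, 2> {
        weight: [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
        bias: [0.5, 0.0],
    };
    let y = model.forward(&[1.0, 2.0, 3.0]);
    println!("{:?}", y); // prints [4.5, 5.0]
}
```

Storing W as (I, O) means the input is treated as a row vector on the left of W, which is why the loop walks x and the rows of W together.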
Implements Randomize to set weight & bias to random numbers drawn from a distribution.
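A minimal sketch of what filling the parameters with "random numbers drawn from a distribution" can look like, assuming the caller supplies the sampler as a closure. The randomize function, the XorShift32 generator and the U(-1/sqrt(I), 1/sqrt(I)) bound are all assumptions for illustration, not the Randomize trait's actual signature or default distribution; the toy generator only keeps the example free of external crates and is not suitable for real use.

```rust
/// Toy xorshift32 generator, used only to keep this sketch free of external
/// crates. Not a statistically strong RNG. The seed must be nonzero.
struct XorShift32(u32);

impl XorShift32 {
    /// Returns a value in [-bound, bound).
    fn next_uniform(&mut self, bound: f32) -> f32 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 17;
        x ^= x << 5;
        self.0 = x;
        // Keep the top 24 bits so the conversion to f32 is exact: unit is in [0, 1).
        let unit = (x >> 8) as f32 / (1u32 << 24) as f32;
        (unit * 2.0 - 1.0) * bound
    }
}

/// Hypothetical helper: overwrites every weight and bias entry with a value
/// produced by `sample`, which stands in for the chosen distribution.
fn randomize<const I: usize, const O: usize>(
    weight: &mut [[f32; O]; I],
    bias: &mut [f32; O],
    mut sample: impl FnMut() -> f32,
) {
    // Weight entries are visited row by row, then the bias entries.
    for w in weight.iter_mut().flatten() {
        *w = sample();
    }
    for b in bias.iter_mut() {
        *b = sample();
    }
}

fn main() {
    let mut weight = [[0.0f32; 2]; 5];
    let mut bias = [0.0f32; 2];
    let mut rng = XorShift32(0x1234_5678);
    // Assumed bound: U(-1/sqrt(I), 1/sqrt(I)) with I = 5.
    let bound = 1.0 / 5.0f32.sqrt();
    randomize(&mut weight, &mut bias, || rng.next_uniform(bound));
    assert!(weight.iter().flatten().chain(bias.iter()).all(|v| v.abs() <= bound));
    println!("{:?} {:?}", weight, bias);
}
```

Passing the distribution in as a closure keeps the fill logic independent of which distribution or RNG is used.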
Example usage:
Linear<5, 2> acts on vectors with 5 elements and produces vectors with 2 elements.
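use dfdx::prelude::*;

// Default::default() initializes weight and bias to zeros.
let model: Linear<5, 2> = Default::default();
assert_eq!(model.weight.data(), &[[0.0; 2]; 5]);
assert_eq!(model.bias.data(), &[0.0; 2]);
// Forward a single 5-element vector to get a 2-element vector.
let x: Tensor1D<5> = Default::default();
let y: Tensor1D<2> = model.forward(x);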
Fields
weight: Tensor2D<I, O, NoTape>
Weight matrix, shape (I, O)
bias: Tensor1D<O, NoTape>
Bias vector, shape (O,)
Trait Implementations
impl<const I: usize, const O: usize> CanUpdateWithGradients for Linear<I, O>
fn update<G: GradientProvider>(&mut self, grads: &mut G)
impl<const I: usize, const O: usize> LoadFromNpz for Linear<I, O>
fn read<R>(
&mut self,
filename_prefix: &String,
r: &mut ZipArchive<R>
) -> Result<(), NpzError> where
R: Read + Seek,
Reads self.weight from {filename_prefix}weight.npy and self.bias from {filename_prefix}bias.npy using numpy::read().
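The naming rule for read concatenates the prefix and the file name directly, so any separator (for example a trailing '.') has to be part of filename_prefix itself. A hypothetical helper that only spells out the two entry names, without opening any archive (the prefix "layer1." is an arbitrary example):

```rust
/// Hypothetical helper reproducing the documented naming rule: the
/// prefix is prepended verbatim, with no separator inserted.
fn linear_npz_entries(filename_prefix: &str) -> (String, String) {
    (
        format!("{filename_prefix}weight.npy"),
        format!("{filename_prefix}bias.npy"),
    )
}

fn main() {
    let (weight_entry, bias_entry) = linear_npz_entries("layer1.");
    assert_eq!(weight_entry, "layer1.weight.npy");
    assert_eq!(bias_entry, "layer1.bias.npy");
}
```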
impl<const B: usize, const I: usize, const O: usize, H: Tape> Module<Tensor2D<B, I, H>> for Linear<I, O>
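For the batched Module implementation, the same computation is applied row by row: an input of shape (B, I) produces an output of shape (B, O), with the bias added to every row. The std-only sketch below (forward_batch is a hypothetical name, not part of dfdx) shows that shape bookkeeping with plain arrays:

```rust
/// Hypothetical batched forward: x has shape (B, I), weight has shape (I, O),
/// bias has shape (O,), and the result has shape (B, O).
fn forward_batch<const B: usize, const I: usize, const O: usize>(
    x: &[[f32; I]; B],
    weight: &[[f32; O]; I],
    bias: &[f32; O],
) -> [[f32; O]; B] {
    // Every output row starts as a copy of the bias (broadcast over the batch).
    let mut y = [*bias; B];
    for (x_row, y_row) in x.iter().zip(y.iter_mut()) {
        // y_row += x_row * W, accumulated one input element at a time.
        for (xi, w_row) in x_row.iter().zip(weight.iter()) {
            for (yo, wio) in y_row.iter_mut().zip(w_row.iter()) {
                *yo += xi * wio;
            }
        }
    }
    y
}

fn main() {
    let x = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]; // B = 3, I = 2
    let weight = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]; // I = 2, O = 3
    let bias = [1.0, 0.0, -1.0]; // O = 3
    let y = forward_batch(&x, &weight, &bias);
    assert_eq!(y, [[2.0, 2.0, -1.0], [4.0, 4.0, 1.0], [1.0, 0.0, -1.0]]);
}
```

The all-zero input row in the example comes out equal to the bias, which makes the broadcast visible.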
Auto Trait Implementations
impl<const I: usize, const O: usize> RefUnwindSafe for Linear<I, O>
impl<const I: usize, const O: usize> Send for Linear<I, O>
impl<const I: usize, const O: usize> Sync for Linear<I, O>
impl<const I: usize, const O: usize> Unpin for Linear<I, O>
impl<const I: usize, const O: usize> UnwindSafe for Linear<I, O>
Blanket Implementations
impl<T> BorrowMut<T> for T where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.