pub struct ParamsBase<S, D = Ix2>{ /* private fields */ }
Expand description
The `ParamsBase` struct is a generic container for a set of weights and biases for a model. The implementation is designed around the `ArrayBase` type from the `ndarray` crate, which allows for flexible and efficient storage of multi-dimensional arrays.
Implementations§
Source§impl<A, S, D> ParamsBase<S, D>
impl<A, S, D> ParamsBase<S, D>
Sourcepub const fn new(
bias: ArrayBase<S, D::Smaller>,
weights: ArrayBase<S, D>,
) -> Self
pub const fn new( bias: ArrayBase<S, D::Smaller>, weights: ArrayBase<S, D>, ) -> Self
create a new instance of the `ParamsBase` with the given bias and weights
Sourcepub fn from_elems<Sh>(shape: Sh, elem: A) -> Self
pub fn from_elems<Sh>(shape: Sh, elem: A) -> Self
create a new instance of the `ParamsBase` from the given shape and element
Sourcepub fn default<Sh>(shape: Sh) -> Self
pub fn default<Sh>(shape: Sh) -> Self
create an instance of the parameters with all values set to the default value
Sourcepub fn zeros<Sh>(shape: Sh) -> Self
pub fn zeros<Sh>(shape: Sh) -> Self
create an instance of the parameters with all values set to zero
Sourcepub const fn bias(&self) -> &ArrayBase<S, D::Smaller>
pub const fn bias(&self) -> &ArrayBase<S, D::Smaller>
returns an immutable reference to the bias
Sourcepub const fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller>
pub const fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller>
returns a mutable reference to the bias
Sourcepub const fn weights_mut(&mut self) -> &mut ArrayBase<S, D>
pub const fn weights_mut(&mut self) -> &mut ArrayBase<S, D>
returns a mutable reference to the weights
Sourcepub fn assign_bias(&mut self, bias: &ArrayBase<S, D::Smaller>) -> &mut Self
pub fn assign_bias(&mut self, bias: &ArrayBase<S, D::Smaller>) -> &mut Self
assign the bias
Sourcepub fn assign_weights(&mut self, weights: &ArrayBase<S, D>) -> &mut Self
pub fn assign_weights(&mut self, weights: &ArrayBase<S, D>) -> &mut Self
assign the weights
Sourcepub fn replace_bias(
&mut self,
bias: ArrayBase<S, D::Smaller>,
) -> ArrayBase<S, D::Smaller>
pub fn replace_bias( &mut self, bias: ArrayBase<S, D::Smaller>, ) -> ArrayBase<S, D::Smaller>
replace the bias and return the previous state; uses replace
Sourcepub fn replace_weights(&mut self, weights: ArrayBase<S, D>) -> ArrayBase<S, D>
pub fn replace_weights(&mut self, weights: ArrayBase<S, D>) -> ArrayBase<S, D>
replace the weights and return the previous state; uses replace
Sourcepub fn set_weights(&mut self, weights: ArrayBase<S, D>) -> &mut Self
pub fn set_weights(&mut self, weights: ArrayBase<S, D>) -> &mut Self
set the weights
Sourcepub fn backward<X, Y, Z>(&mut self, input: &X, grad: &Y, lr: A) -> Result<Z>
pub fn backward<X, Y, Z>(&mut self, input: &X, grad: &Y, lr: A) -> Result<Z>
perform a single backpropagation step
Sourcepub fn iter(&self) -> Iter<'_, A, D> ⓘwhere
D: RemoveAxis,
S: Data,
pub fn iter(&self) -> Iter<'_, A, D> ⓘwhere
D: RemoveAxis,
S: Data,
an iterator of the parameters; the created iterator zips together an axis iterator over the columns of the weights and an iterator over the bias
Sourcepub fn iter_mut(
&mut self,
) -> Zip<AxisIterMut<'_, A, D::Smaller>, IterMut<'_, A, D::Smaller>>where
D: RemoveAxis,
S: DataMut,
pub fn iter_mut(
&mut self,
) -> Zip<AxisIterMut<'_, A, D::Smaller>, IterMut<'_, A, D::Smaller>>where
D: RemoveAxis,
S: DataMut,
a mutable iterator of the parameters
Sourcepub fn iter_bias(&self) -> Iter<'_, A, D::Smaller>where
S: Data,
pub fn iter_bias(&self) -> Iter<'_, A, D::Smaller>where
S: Data,
returns an iterator over the bias
Sourcepub fn iter_bias_mut(&mut self) -> IterMut<'_, A, D::Smaller>where
S: DataMut,
pub fn iter_bias_mut(&mut self) -> IterMut<'_, A, D::Smaller>where
S: DataMut,
returns a mutable iterator over the bias
Sourcepub fn iter_weights(&self) -> Iter<'_, A, D>where
S: Data,
pub fn iter_weights(&self) -> Iter<'_, A, D>where
S: Data,
returns an iterator over the weights
Sourcepub fn iter_weights_mut(&mut self) -> IterMut<'_, A, D>where
S: DataMut,
pub fn iter_weights_mut(&mut self) -> IterMut<'_, A, D>where
S: DataMut,
returns a mutable iterator over the weights; see `iter_mut` for more
Sourcepub fn is_empty(&self) -> bool
pub fn is_empty(&self) -> bool
returns true if both the weights and bias are empty; uses is_empty
Sourcepub fn is_weights_empty(&self) -> bool
pub fn is_weights_empty(&self) -> bool
returns true if the weights are empty
Sourcepub fn is_bias_empty(&self) -> bool
pub fn is_bias_empty(&self) -> bool
returns true if the bias is empty
Sourcepub fn count_weight(&self) -> usize
pub fn count_weight(&self) -> usize
the total number of elements within the weight tensor
Sourcepub fn count_bias(&self) -> usize
pub fn count_bias(&self) -> usize
the total number of elements within the bias tensor
Sourcepub fn shape(&self) -> &[usize]
pub fn shape(&self) -> &[usize]
returns the shape of the parameters; uses the shape of the weight tensor
Sourcepub fn shape_bias(&self) -> &[usize]
pub fn shape_bias(&self) -> &[usize]
returns the shape of the bias tensor; the shape should be equivalent to that of the weight tensor minus the “zero-th” axis
Sourcepub fn to_owned(&self) -> ParamsBase<OwnedRepr<A>, D>
pub fn to_owned(&self) -> ParamsBase<OwnedRepr<A>, D>
returns an owned instance of the parameters
Sourcepub fn to_shape<Sh>(
&self,
shape: Sh,
) -> Result<ParamsBase<CowRepr<'_, A>, Sh::Dim>>
pub fn to_shape<Sh>( &self, shape: Sh, ) -> Result<ParamsBase<CowRepr<'_, A>, Sh::Dim>>
change the shape of the parameters; the shape of the bias parameters is determined by removing the “zero-th” axis of the given shape
Source§impl<A, S> ParamsBase<S, Ix1>where
S: RawData<Elem = A>,
impl<A, S> ParamsBase<S, Ix1>where
S: RawData<Elem = A>,
Source§impl<A, S> ParamsBase<S, Ix2>where
S: RawData<Elem = A>,
impl<A, S> ParamsBase<S, Ix2>where
S: RawData<Elem = A>,
Source§impl<A, S, D> ParamsBase<S, D>
impl<A, S, D> ParamsBase<S, D>
pub fn init_rand<G, Dst, Sh>(shape: Sh, distr: G) -> Selfwhere
D: RemoveAxis,
S: DataOwned,
Sh: ShapeBuilder<Dim = D>,
Dst: Clone + Distribution<A>,
G: Fn(&Sh) -> Dst,
Source§impl<A, S, D> ParamsBase<S, D>
impl<A, S, D> ParamsBase<S, D>
Sourcepub fn apply_gradient<Delta, Z>(&mut self, grad: &Delta, lr: A) -> Result<Z>where
S: DataMut,
Self: ApplyGradient<Delta, A, Output = Z>,
pub fn apply_gradient<Delta, Z>(&mut self, grad: &Delta, lr: A) -> Result<Z>where
S: DataMut,
Self: ApplyGradient<Delta, A, Output = Z>,
a convenience method used to apply a gradient to the parameters using the given learning rate.
pub fn apply_gradient_with_decay<Grad, Z>(
&mut self,
grad: &Grad,
lr: A,
decay: A,
) -> Result<Z>where
S: DataMut,
Self: ApplyGradient<Grad, A, Output = Z>,
pub fn apply_gradient_with_momentum<Grad, V, Z>(
&mut self,
grad: &Grad,
lr: A,
momentum: A,
velocity: &mut V,
) -> Result<Z>where
S: DataMut,
Self: ApplyGradientExt<Grad, A, Output = Z, Velocity = V>,
pub fn apply_gradient_with_decay_and_momentum<Grad, V, Z>(
&mut self,
grad: &Grad,
lr: A,
decay: A,
momentum: A,
velocity: &mut V,
) -> Result<Z>where
S: DataMut,
Self: ApplyGradientExt<Grad, A, Output = Z, Velocity = V>,
Trait Implementations§
Source§impl<A, S, T, D> ApplyGradient<ParamsBase<T, D>, A> for ParamsBase<S, D>where
A: Float + FromPrimitive + ScalarOperand,
S: DataMut<Elem = A>,
T: Data<Elem = A>,
D: Dimension,
impl<A, S, T, D> ApplyGradient<ParamsBase<T, D>, A> for ParamsBase<S, D>where
A: Float + FromPrimitive + ScalarOperand,
S: DataMut<Elem = A>,
T: Data<Elem = A>,
D: Dimension,
type Output = ()
fn apply_gradient( &mut self, grad: &ParamsBase<T, D>, lr: A, ) -> Result<Self::Output>
fn apply_gradient_with_decay( &mut self, grad: &ParamsBase<T, D>, lr: A, decay: A, ) -> Result<Self::Output>
Source§impl<A, S, T, D> ApplyGradientExt<ParamsBase<T, D>, A> for ParamsBase<S, D>where
A: Float + FromPrimitive + ScalarOperand,
S: DataMut<Elem = A>,
T: Data<Elem = A>,
D: Dimension,
impl<A, S, T, D> ApplyGradientExt<ParamsBase<T, D>, A> for ParamsBase<S, D>where
A: Float + FromPrimitive + ScalarOperand,
S: DataMut<Elem = A>,
T: Data<Elem = A>,
D: Dimension,
type Velocity = ParamsBase<OwnedRepr<A>, D>
fn apply_gradient_with_momentum( &mut self, grad: &ParamsBase<T, D>, lr: A, momentum: A, velocity: &mut Self::Velocity, ) -> Result<()>
fn apply_gradient_with_decay_and_momentum( &mut self, grad: &ParamsBase<T, D>, lr: A, decay: A, momentum: A, velocity: &mut Self::Velocity, ) -> Result<()>
Source§impl<S, D> Biased<S, D> for ParamsBase<S, D>
impl<S, D> Biased<S, D> for ParamsBase<S, D>
Source§fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller>
fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller>
Source§fn assign_bias(&mut self, bias: &ArrayBase<S, D::Smaller>) -> &mut Self
fn assign_bias(&mut self, bias: &ArrayBase<S, D::Smaller>) -> &mut Self
Source§fn replace_bias(
&mut self,
bias: ArrayBase<S, D::Smaller>,
) -> ArrayBase<S, D::Smaller>
fn replace_bias( &mut self, bias: ArrayBase<S, D::Smaller>, ) -> ArrayBase<S, D::Smaller>
Source§fn set_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> &mut Self
fn set_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> &mut Self
Source§impl<A, S, D> Clone for ParamsBase<S, D>
impl<A, S, D> Clone for ParamsBase<S, D>
Source§impl<A, S, D> Debug for ParamsBase<S, D>
impl<A, S, D> Debug for ParamsBase<S, D>
Source§impl<A, X, Y, Z, S, D> Forward<X> for ParamsBase<S, D>
impl<A, X, Y, Z, S, D> Forward<X> for ParamsBase<S, D>
type Output = Z
Source§impl<A, S, D> Initialize<S, D> for ParamsBase<S, D>where
D: RemoveAxis,
S: RawData<Elem = A>,
impl<A, S, D> Initialize<S, D> for ParamsBase<S, D>where
D: RemoveAxis,
S: RawData<Elem = A>,
fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
fn bernoulli<Sh>(shape: Sh, p: f64) -> Result<Self, BernoulliError>
Source§fn glorot_normal<Sh: ShapeBuilder<Dim = D>>(shape: Sh) -> Self
fn glorot_normal<Sh: ShapeBuilder<Dim = D>>(shape: Sh) -> Self
Source§fn glorot_uniform<Sh>(shape: Sh) -> UniformResult<Self>where
S: DataOwned,
Sh: ShapeBuilder<Dim = D>,
S::Elem: Float + FromPrimitive + SampleUniform,
<S::Elem as SampleUniform>::Sampler: Clone,
fn glorot_uniform<Sh>(shape: Sh) -> UniformResult<Self>where
S: DataOwned,
Sh: ShapeBuilder<Dim = D>,
S::Elem: Float + FromPrimitive + SampleUniform,
<S::Elem as SampleUniform>::Sampler: Clone,
Source§fn lecun_normal<Sh>(shape: Sh) -> Selfwhere
StandardNormal: Distribution<S::Elem>,
S: DataOwned,
Sh: ShapeBuilder<Dim = D>,
S::Elem: Float,
fn lecun_normal<Sh>(shape: Sh) -> Selfwhere
StandardNormal: Distribution<S::Elem>,
S: DataOwned,
Sh: ShapeBuilder<Dim = D>,
S::Elem: Float,
Source§fn normal<Sh>(
shape: Sh,
mean: S::Elem,
std: S::Elem,
) -> Result<Self, NormalError>where
StandardNormal: Distribution<S::Elem>,
S: DataOwned,
Sh: ShapeBuilder<Dim = D>,
S::Elem: Float,
fn normal<Sh>(
shape: Sh,
mean: S::Elem,
std: S::Elem,
) -> Result<Self, NormalError>where
StandardNormal: Distribution<S::Elem>,
S: DataOwned,
Sh: ShapeBuilder<Dim = D>,
S::Elem: Float,
fn randc<Sh>(shape: Sh, re: S::Elem, im: S::Elem) -> Selfwhere
S: DataOwned,
Sh: ShapeBuilder<Dim = D>,
ComplexDistribution<S::Elem, S::Elem>: Distribution<S::Elem>,
Source§fn stdnorm<Sh>(shape: Sh) -> Self
fn stdnorm<Sh>(shape: Sh) -> Self
Source§fn stdnorm_from_seed<Sh>(shape: Sh, seed: u64) -> Self
fn stdnorm_from_seed<Sh>(shape: Sh, seed: u64) -> Self
Source§fn truncnorm<Sh>(
shape: Sh,
mean: S::Elem,
std: S::Elem,
) -> Result<Self, NormalError>where
StandardNormal: Distribution<S::Elem>,
S: DataOwned,
Sh: ShapeBuilder<Dim = D>,
S::Elem: Float,
fn truncnorm<Sh>(
shape: Sh,
mean: S::Elem,
std: S::Elem,
) -> Result<Self, NormalError>where
StandardNormal: Distribution<S::Elem>,
S: DataOwned,
Sh: ShapeBuilder<Dim = D>,
S::Elem: Float,
Source§fn uniform<Sh>(shape: Sh, dk: S::Elem) -> UniformResult<Self>where
S: DataOwned,
Sh: ShapeBuilder<Dim = D>,
S::Elem: Clone + Neg<Output = S::Elem> + SampleUniform,
<S::Elem as SampleUniform>::Sampler: Clone,
fn uniform<Sh>(shape: Sh, dk: S::Elem) -> UniformResult<Self>where
S: DataOwned,
Sh: ShapeBuilder<Dim = D>,
S::Elem: Clone + Neg<Output = S::Elem> + SampleUniform,
<S::Elem as SampleUniform>::Sampler: Clone,
`Uniform` distribution with values bounded by +/- `dk`
Source§fn uniform_from_seed<Sh>(
shape: Sh,
start: S::Elem,
stop: S::Elem,
key: u64,
) -> UniformResult<Self>where
S: DataOwned,
Sh: ShapeBuilder<Dim = D>,
S::Elem: Clone + SampleUniform,
<S::Elem as SampleUniform>::Sampler: Clone,
fn uniform_from_seed<Sh>(
shape: Sh,
start: S::Elem,
stop: S::Elem,
key: u64,
) -> UniformResult<Self>where
S: DataOwned,
Sh: ShapeBuilder<Dim = D>,
S::Elem: Clone + SampleUniform,
<S::Elem as SampleUniform>::Sampler: Clone,
`Uniform` distribution with values between the `start` and `stop` params, sampled using the random seed `key`.
Source§fn uniform_along<Sh>(shape: Sh, axis: usize) -> UniformResult<Self>where
Sh: ShapeBuilder<Dim = D>,
S: DataOwned,
S::Elem: Float + FromPrimitive + SampleUniform,
<S::Elem as SampleUniform>::Sampler: Clone,
fn uniform_along<Sh>(shape: Sh, axis: usize) -> UniformResult<Self>where
Sh: ShapeBuilder<Dim = D>,
S: DataOwned,
S::Elem: Float + FromPrimitive + SampleUniform,
<S::Elem as SampleUniform>::Sampler: Clone,
`Uniform` distribution with values bounded by the size of the specified axis. The values are bounded by +/- `dk`, where `dk = 1 / size(axis)`.
Source§fn uniform_between<Sh>(shape: Sh, a: S::Elem, b: S::Elem) -> UniformResult<Self>where
Sh: ShapeBuilder<Dim = D>,
S: DataOwned,
S::Elem: Clone + SampleUniform,
<S::Elem as SampleUniform>::Sampler: Clone,
fn uniform_between<Sh>(shape: Sh, a: S::Elem, b: S::Elem) -> UniformResult<Self>where
Sh: ShapeBuilder<Dim = D>,
S: DataOwned,
S::Elem: Clone + SampleUniform,
<S::Elem as SampleUniform>::Sampler: Clone,
`Uniform` distribution with values between the given bounds, `a` and `b`.