concision_params/params_base.rs
1/*
2 Appellation: params <module>
3 Contrib: @FL03
4*/
5#[cfg(feature = "alloc")]
6use alloc::boxed::Box;
7use ndarray::{
8 ArrayBase, CowRepr, Dimension, Ix2, OwnedArcRepr, OwnedRepr, RawData, RawRef, RawViewRepr,
9 ViewRepr,
10};
11
12/// A type alias for a [`ParamsBase`] with an owned internal layout
13pub type Params<A = f32, D = Ix2> = ParamsBase<OwnedRepr<A>, D, A>;
14/// A type alias for shared parameters
15pub type ArcParams<A = f32, D = Ix2> = ParamsBase<OwnedArcRepr<A>, D, A>;
16/// A type alias for an immutable view of the parameters
17pub type ParamsView<'a, A = f32, D = Ix2> = ParamsBase<ViewRepr<&'a A>, D, A>;
18/// A type alias for a mutable view of the parameters
19pub type ParamsViewMut<'a, A = f32, D = Ix2> = ParamsBase<ViewRepr<&'a mut A>, D, A>;
20/// A type alias for a [`ParamsBase`] with a _borrowed_ internal layout
21pub type CowParams<'a, A = f32, D = Ix2> = ParamsBase<CowRepr<'a, A>, D, A>;
22/// A type alias for the [`ParamsBase`] whose elements are of type `*const A` using a
23/// [`RawViewRepr`] layout
24pub type RawViewParams<A = f32, D = Ix2> = ParamsBase<RawViewRepr<*const A>, D, A>;
25/// A type alias for the [`ParamsBase`] whose elements are of type `*mut A` using a
26/// [`RawViewRepr`] layout
27pub type RawMutParams<A = f32, D = Ix2> = ParamsBase<RawViewRepr<*mut A>, D, A>;
28
/// A weight/bias pair expressed as [`RawRef`] tensors, mirroring the layout of [`ParamsBase`]:
/// the bias uses the dimension type `D::Smaller`, one rank below the weights' `D`.
///
/// Only compiled when the `alloc` feature is enabled (the bias field is a [`Box`]).
///
/// NOTE(review): the bias is boxed while the weights are stored inline; confirm this asymmetry
/// is intentional.
#[cfg(feature = "alloc")]
pub struct ParamsRef<A, D>
where
    D: Dimension,
{
    /// The bias tensor, stored behind a `Box`; its dimension type is `D::Smaller`.
    pub bias: Box<RawRef<A, D::Smaller>>,
    /// The weight tensor, stored inline; its dimension type is `D`.
    pub weights: RawRef<A, D>,
}
34
35/// The [`ParamsBase`] implementation aims to provide a generic, n-dimensional weight and bias
36/// pair for a model (or layer). The object requires the bias tensor to be a single dimension
37/// smaller than the weights tensor.
38///
39/// Therefore, we allow the weight tensor to be the _shape_ of the parameters, using the shape
40/// as the basis for the bias tensor by removing the first axis.
41/// Consequently, this constrains the [`ParamsBase`] implementation to only support dimensions
42/// that can be reduced by one axis, typically the "zero-th" axis: $\text{rank}(D)$.
43pub struct ParamsBase<S, D = ndarray::Ix2, A = <S as RawData>::Elem>
44where
45 D: Dimension,
46 S: RawData<Elem = A>,
47{
48 pub bias: ArrayBase<S, D::Smaller, A>,
49 pub weights: ArrayBase<S, D, A>,
50}