concision_params/traits/
raw_params.rs

1/*
2    Appellation: tensor <module>
3    Created At: 2025.12.08:16:03:55
4    Contrib: @FL03
5*/
6
7/// The [`RawParams`] trait is used to denote objects capable of being used as a paramater
8/// within a neural network or machine learning context. More over, it provides us with an
9/// ability to associate some generic element type with the parameter and thus allows us to
10/// consider so-called _parameter spaces_. If we allow a parameter space to simply be a
11/// collection of points then we can refine the definition downstream to consider specific
12/// interpolations, distributions, or manifolds. In other words, we are trying to construct
13/// a tangible configuration space for our models so that we can reason about optimization
14/// and training in a more formal manner.
15///
16/// **Note**: This trait is sealed and cannot be implemented outside of this crate.
17pub trait RawParams {
18    type Elem: ?Sized;
19
20    private! {}
21}
22
23/// The [`ScalarParams`] is a marker trait automatically implemented for
24pub trait ScalarParams: RawParams<Elem = Self> + Sized {
25    private!();
26}
27
28pub trait TensorParams: RawParams {
29    /// returns the number of dimensions of the parameter
30    fn rank(&self) -> usize;
31    /// returns the size of the parameter
32    fn size(&self) -> usize;
33}
34
35pub trait ExactDimParams: TensorParams {
36    type Shape: ?Sized;
37    /// returns a reference to the shape of the parameter
38    fn shape(&self) -> &Self::Shape;
39}
40
41/*
42 ************* Implementations *************
43*/
44use crate::ParamsBase;
45use ndarray::{ArrayBase, Dimension, RawData};
46
47impl<A, T> RawParams for &T
48where
49    T: RawParams<Elem = A>,
50{
51    type Elem = A;
52
53    seal! {}
54}
55
56impl<A, T> RawParams for &mut T
57where
58    T: RawParams<Elem = A>,
59{
60    type Elem = A;
61
62    seal! {}
63}
64
65impl<T> ScalarParams for T
66where
67    T: RawParams<Elem = T>,
68{
69    seal! {}
70}
71
72macro_rules! impl_scalar_param {
73    ($($T:ty),* $(,)?) => {
74        $(impl_scalar_param!(@impl $T);)*
75    };
76    (@impl $T:ty) => {
77        impl RawParams for $T {
78            type Elem = $T;
79
80            seal! {}
81        }
82
83        impl TensorParams for $T {
84            fn rank(&self) -> usize {
85                0
86            }
87
88            fn size(&self) -> usize {
89                1
90            }
91        }
92
93        impl ExactDimParams for $T {
94            type Shape = [usize; 0];
95
96            fn shape(&self) -> &Self::Shape {
97                &[]
98            }
99        }
100    };
101}
102
103impl_scalar_param! {
104    u8, u16, u32, u64, u128, usize,
105    i8, i16, i32, i64, i128, isize,
106    f32, f64,
107    bool, char, str
108}
109
110#[cfg(feature = "alloc")]
111impl RawParams for alloc::string::String {
112    type Elem = alloc::string::String;
113
114    seal! {}
115}
116
117impl<S, D, A> RawParams for ArrayBase<S, D, A>
118where
119    D: Dimension,
120    S: RawData<Elem = A>,
121{
122    type Elem = A;
123
124    seal! {}
125}
126
127impl<S, D, A> TensorParams for ArrayBase<S, D, A>
128where
129    D: Dimension,
130    S: RawData<Elem = A>,
131{
132    fn rank(&self) -> usize {
133        self.ndim()
134    }
135
136    fn size(&self) -> usize {
137        self.len()
138    }
139}
140
141impl<S, D, A> ExactDimParams for ArrayBase<S, D, A>
142where
143    D: Dimension,
144    S: RawData<Elem = A>,
145{
146    type Shape = [usize];
147
148    fn shape(&self) -> &[usize] {
149        self.shape()
150    }
151}
152
153impl<S, D, A> RawParams for ParamsBase<S, D, A>
154where
155    D: Dimension,
156    S: RawData<Elem = A>,
157{
158    type Elem = A;
159
160    seal! {}
161}
162
163impl<S, D, A> TensorParams for ParamsBase<S, D, A>
164where
165    D: Dimension,
166    S: RawData<Elem = A>,
167{
168    fn rank(&self) -> usize {
169        self.weights().ndim()
170    }
171
172    fn size(&self) -> usize {
173        self.weights().len()
174    }
175}
176
177impl<S, D, A> ExactDimParams for ParamsBase<S, D, A>
178where
179    D: Dimension,
180    S: RawData<Elem = A>,
181{
182    type Shape = [usize];
183
184    fn shape(&self) -> &[usize] {
185        self.weights().shape()
186    }
187}
188
189impl<T> RawParams for [T] {
190    type Elem = T;
191
192    seal! {}
193}
194
195impl<T> RawParams for &[T] {
196    type Elem = T;
197
198    seal! {}
199}
200
201impl<T> RawParams for &mut [T] {
202    type Elem = T;
203
204    seal! {}
205}
206
207impl<const N: usize, T> RawParams for [T; N] {
208    type Elem = T;
209
210    seal! {}
211}
212
213impl<const N: usize, T> TensorParams for [T; N] {
214    fn rank(&self) -> usize {
215        1
216    }
217
218    fn size(&self) -> usize {
219        N
220    }
221}
222
223impl<const N: usize, T> ExactDimParams for [T; N] {
224    type Shape = [usize; 1];
225
226    fn shape(&self) -> &Self::Shape {
227        &[N]
228    }
229}
230
231#[cfg(feature = "alloc")]
232mod impl_alloc {
233    use super::*;
234    use alloc::vec::Vec;
235
236    impl<T> RawParams for Vec<T>
237    where
238        T: RawParams,
239    {
240        type Elem = T::Elem;
241
242        seal! {}
243    }
244}