concision_params/traits/
wnb.rs

1/*
2    Appellation: wnb <module>
3    Created At: 2025.11.28:21:21:42
4    Contrib: @FL03
5*/
6use ndarray::iter as nditer;
7use ndarray::{ArrayBase, Data, DataMut, Dimension, RawData};
8
9/// A trait denoting an implementor with weights and associated methods
10pub trait Weighted<S, D, A = <S as RawData>::Elem>: Sized
11where
12    D: Dimension,
13    S: RawData<Elem = A>,
14{
15    type Tensor<_S, _D, _A>
16    where
17        _D: Dimension,
18        _S: RawData<Elem = _A>;
19    /// returns the weights of the model
20    fn weights(&self) -> &Self::Tensor<S, D, A>;
21    /// returns a mutable reference to the weights of the model
22    fn weights_mut(&mut self) -> &mut Self::Tensor<S, D, A>;
23    /// replaces the current weights with the given weights
24    fn replace_weights(&mut self, weights: Self::Tensor<S, D, A>) -> Self::Tensor<S, D, A> {
25        core::mem::replace(self.weights_mut(), weights)
26    }
27    /// sets the weights of the model
28    fn set_weights(&mut self, weights: Self::Tensor<S, D, A>) -> &mut Self {
29        *self.weights_mut() = weights;
30        self
31    }
32}
33
34pub trait Biased<S, D, A = <S as RawData>::Elem>: Weighted<S, D, A>
35where
36    D: Dimension,
37    S: RawData<Elem = A>,
38{
39    /// returns the bias of the model
40    fn bias(&self) -> &ArrayBase<S, D::Smaller, A>;
41    /// returns a mutable reference to the bias of the model
42    fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller, A>;
43    /// assigns the given bias to the current bias
44    fn assign_bias(&mut self, bias: &ArrayBase<S, D::Smaller, A>) -> &mut Self
45    where
46        S: DataMut,
47        S::Elem: Clone,
48    {
49        self.bias_mut().assign(bias);
50        self
51    }
52    /// replaces the current bias with the given bias
53    fn replace_bias(&mut self, bias: ArrayBase<S, D::Smaller, A>) -> ArrayBase<S, D::Smaller, A> {
54        core::mem::replace(self.bias_mut(), bias)
55    }
56    /// sets the bias of the model
57    fn set_bias(&mut self, bias: ArrayBase<S, D::Smaller, A>) -> &mut Self {
58        *self.bias_mut() = bias;
59        self
60    }
61    /// returns an iterator over the bias
62    fn iter_bias<'a>(&'a self) -> nditer::Iter<'a, S::Elem, D::Smaller>
63    where
64        S: Data + 'a,
65        D: 'a,
66    {
67        self.bias().iter()
68    }
69    /// returns a mutable iterator over the bias
70    fn iter_bias_mut<'a>(&'a mut self) -> nditer::IterMut<'a, S::Elem, D::Smaller>
71    where
72        S: DataMut + 'a,
73        D: 'a,
74    {
75        self.bias_mut().iter_mut()
76    }
77}