concision_linear/model/
layer.rs

1/*
2    Appellation: model <module>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use super::{Config, Layout};
6use crate::{Biased, LinearParams, ParamMode, ParamsBase, Unbiased};
7use concision::prelude::{Predict, Result};
8use nd::prelude::*;
9use nd::{DataOwned, OwnedRepr, RawData, RemoveAxis};
10
11/// An implementation of a linear model.
12///
13/// In an effort to streamline the api, the [Linear] model relies upon a [ParamMode] type ([Biased] or [Unbiased](crate::params::mode::Unbiased))
14/// which enables the model to automatically determine whether or not to include a bias term. Doing so allows the model to inherit several methods
15/// familar to the underlying [ndarray](https://docs.rs/ndarray) crate.
16pub struct Linear<A = f64, K = Biased, D = Ix2, S = OwnedRepr<A>>
17where
18    D: Dimension,
19    S: RawData<Elem = A>,
20{
21    pub(crate) config: Config<K, D>,
22    pub(crate) params: ParamsBase<S, D, K>,
23}
24
25impl<A, K> Linear<A, K, Ix2, OwnedRepr<A>>
26where
27    K: ParamMode,
28{
29    pub fn std(inputs: usize, outputs: usize) -> Self
30    where
31        A: Default,
32    {
33        let config = Config::<K, Ix2>::new().with_shape((inputs, outputs));
34        let params = ParamsBase::new(config.features());
35        Linear { config, params }
36    }
37}
38
39impl<A, S, D, K> Linear<A, K, D, S>
40where
41    D: RemoveAxis,
42    K: ParamMode,
43    S: RawData<Elem = A>,
44{
45    mbuilder!(new where A: Default, S: DataOwned);
46    mbuilder!(ones where A: Clone + num::One, S: DataOwned);
47    mbuilder!(zeros where A: Clone + num::Zero, S: DataOwned);
48
49    pub fn from_config(config: Config<K, D>) -> Self
50    where
51        A: Clone + Default,
52        K: ParamMode,
53        S: DataOwned,
54    {
55        let params = ParamsBase::new(config.dim());
56        Self { config, params }
57    }
58
59    pub fn from_layout(layout: Layout<D>) -> Self
60    where
61        A: Clone + Default,
62        K: ParamMode,
63        S: DataOwned,
64    {
65        let config = Config::<K, D>::new().with_layout(layout);
66        let params = ParamsBase::new(config.dim());
67        Self { config, params }
68    }
69
70    pub fn from_params(params: ParamsBase<S, D, K>) -> Self {
71        let config = Config::<K, D>::new().with_shape(params.raw_dim());
72        Self { config, params }
73    }
74
75    /// Applies an activcation function onto the prediction of the model.
76    pub fn activate<X, Y, F>(&self, args: &X, func: F) -> Result<Y>
77    where
78        F: Fn(Y) -> Y,
79        Self: Predict<X, Output = Y>,
80    {
81        Ok(func(self.predict(args)?))
82    }
83
84    pub const fn config(&self) -> &Config<K, D> {
85        &self.config
86    }
87
88    pub fn weights(&self) -> &ArrayBase<S, D> {
89        self.params.weights()
90    }
91
92    pub fn weights_mut(&mut self) -> &mut ArrayBase<S, D> {
93        self.params.weights_mut()
94    }
95
96    pub const fn params(&self) -> &ParamsBase<S, D, K> {
97        &self.params
98    }
99
100    pub fn params_mut(&mut self) -> &mut ParamsBase<S, D, K> {
101        &mut self.params
102    }
103
104    pub fn into_biased(self) -> Linear<A, Biased, D, S>
105    where
106        A: Default,
107        K: 'static,
108        S: DataOwned,
109    {
110        Linear {
111            config: self.config.into_biased(),
112            params: self.params.into_biased(),
113        }
114    }
115
116    pub fn into_unbiased(self) -> Linear<A, Unbiased, D, S>
117    where
118        A: Default,
119        K: 'static,
120        S: DataOwned,
121    {
122        Linear {
123            config: self.config.into_unbiased(),
124            params: self.params.into_unbiased(),
125        }
126    }
127
128    pub fn is_biased(&self) -> bool
129    where
130        K: 'static,
131    {
132        self.config().is_biased()
133    }
134
135    pub fn with_params<E>(self, params: LinearParams<A, K, E>) -> Linear<A, K, E>
136    where
137        E: RemoveAxis,
138    {
139        let config = self.config.into_dimensionality(params.raw_dim()).unwrap();
140        Linear { config, params }
141    }
142
143    pub fn with_name(self, name: impl ToString) -> Self {
144        Self {
145            config: self.config.with_name(name),
146            ..self
147        }
148    }
149
150    concision::dimensional!(params());
151}
152
153impl<A, S, D> Linear<A, Biased, D, S>
154where
155    D: RemoveAxis,
156    S: RawData<Elem = A>,
157{
158    pub fn biased<Sh>(shape: Sh) -> Self
159    where
160        A: Default,
161        S: DataOwned,
162        Sh: ShapeBuilder<Dim = D>,
163    {
164        let config = Config::<Biased, D>::new().with_shape(shape);
165        let params = ParamsBase::biased(config.dim());
166        Linear { config, params }
167    }
168
169    pub fn bias(&self) -> &ArrayBase<S, D::Smaller> {
170        self.params().bias()
171    }
172
173    pub fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
174        self.params_mut().bias_mut()
175    }
176}
177
178impl<A, S, D> Linear<A, Unbiased, D, S>
179where
180    D: RemoveAxis,
181    S: RawData<Elem = A>,
182{
183    pub fn unbiased<Sh>(shape: Sh) -> Self
184    where
185        A: Default,
186        S: DataOwned,
187        Sh: ShapeBuilder<Dim = D>,
188    {
189        let config = Config::<Unbiased, D>::new().with_shape(shape);
190        let params = ParamsBase::unbiased(config.dim());
191        Linear { config, params }
192    }
193}