// concision_linear/impls/params/impl_params.rs

/*
    Appellation: params <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
use crate::params::ParamsBase;
use concision::prelude::{Predict, PredictError};
use core::ops::Add;
use nd::linalg::Dot;
use nd::*;
use num::complex::ComplexFloat;

12impl<A, K, S, D> ParamsBase<S, D, K>
13where
14    D: RemoveAxis,
15    S: RawData<Elem = A>,
16{
17    pub fn activate<F, X, Y>(&mut self, args: &X, f: F) -> Y
18    where
19        F: for<'a> Fn(&'a Y) -> Y,
20        S: Data<Elem = A>,
21        Self: Predict<X, Output = Y>,
22    {
23        f(&self.predict(args).unwrap())
24    }
25}
26
27impl<A, S, D> Clone for ParamsBase<S, D>
28where
29    A: Clone,
30    D: RemoveAxis,
31    S: RawDataClone<Elem = A>,
32{
33    fn clone(&self) -> Self {
34        Self {
35            weight: self.weight.clone(),
36            bias: self.bias.clone(),
37            _mode: self._mode,
38        }
39    }
40}
41
42impl<A, S, D> Copy for ParamsBase<S, D>
43where
44    A: Copy,
45    D: Copy + RemoveAxis,
46    S: Copy + RawDataClone<Elem = A>,
47    <D as Dimension>::Smaller: Copy,
48{
49}
50
51impl<A, S, D> PartialEq for ParamsBase<S, D>
52where
53    A: PartialEq,
54    D: RemoveAxis,
55    S: Data<Elem = A>,
56{
57    fn eq(&self, other: &Self) -> bool {
58        self.weights() == other.weight && self.bias == other.bias
59    }
60}
61
62impl<A, S, D, K> PartialEq<(ArrayBase<S, D>, Option<ArrayBase<S, D::Smaller>>)>
63    for ParamsBase<S, D, K>
64where
65    A: PartialEq,
66    D: RemoveAxis,
67    S: Data<Elem = A>,
68{
69    fn eq(&self, (weights, bias): &(ArrayBase<S, D>, Option<ArrayBase<S, D::Smaller>>)) -> bool {
70        self.weights() == weights && self.bias.as_ref() == bias.as_ref()
71    }
72}
73
74impl<A, S, D, K> PartialEq<(ArrayBase<S, D>, ArrayBase<S, D::Smaller>)> for ParamsBase<S, D, K>
75where
76    A: PartialEq,
77    D: RemoveAxis,
78    S: Data<Elem = A>,
79{
80    fn eq(&self, (weights, bias): &(ArrayBase<S, D>, ArrayBase<S, D::Smaller>)) -> bool {
81        self.weights() == weights && self.bias.as_ref() == Some(bias)
82    }
83}
84
85impl<A, B, T, S, D, K> Predict<A> for ParamsBase<S, D, K>
86where
87    A: Dot<Array<T, D>, Output = B>,
88    B: for<'a> Add<&'a ArrayBase<S, D::Smaller>, Output = B>,
89    D: RemoveAxis,
90    S: Data<Elem = T>,
91    T: ComplexFloat,
92{
93    type Output = B;
94
95    fn predict(&self, input: &A) -> Result<Self::Output, PredictError> {
96        let wt = self.weights().t().to_owned();
97        let mut res = input.dot(&wt);
98        if let Some(bias) = self.bias.as_ref() {
99            res = res + bias;
100        }
101        Ok(res)
102    }
103}
104
105impl<'a, A, B, T, S, D, K> Predict<A> for &'a ParamsBase<S, D, K>
106where
107    A: Dot<Array<T, D>, Output = B>,
108    B: Add<&'a ArrayBase<S, D::Smaller>, Output = B>,
109    D: RemoveAxis,
110    S: Data<Elem = T>,
111    T: ComplexFloat,
112{
113    type Output = B;
114
115    fn predict(&self, input: &A) -> Result<Self::Output, PredictError> {
116        let wt = self.weights().t().to_owned();
117        let mut res = input.dot(&wt);
118        if let Some(bias) = self.bias.as_ref() {
119            res = res + bias;
120        }
121        Ok(res)
122    }
123}