concision_linear/impls/params/
impl_params.rs1use crate::params::ParamsBase;
6use concision::prelude::{Predict, PredictError};
7use core::ops::Add;
8use nd::linalg::Dot;
9use nd::*;
10use num::complex::ComplexFloat;
11
12impl<A, K, S, D> ParamsBase<S, D, K>
13where
14 D: RemoveAxis,
15 S: RawData<Elem = A>,
16{
17 pub fn activate<F, X, Y>(&mut self, args: &X, f: F) -> Y
18 where
19 F: for<'a> Fn(&'a Y) -> Y,
20 S: Data<Elem = A>,
21 Self: Predict<X, Output = Y>,
22 {
23 f(&self.predict(args).unwrap())
24 }
25}
26
27impl<A, S, D> Clone for ParamsBase<S, D>
28where
29 A: Clone,
30 D: RemoveAxis,
31 S: RawDataClone<Elem = A>,
32{
33 fn clone(&self) -> Self {
34 Self {
35 weight: self.weight.clone(),
36 bias: self.bias.clone(),
37 _mode: self._mode,
38 }
39 }
40}
41
42impl<A, S, D> Copy for ParamsBase<S, D>
43where
44 A: Copy,
45 D: Copy + RemoveAxis,
46 S: Copy + RawDataClone<Elem = A>,
47 <D as Dimension>::Smaller: Copy,
48{
49}
50
51impl<A, S, D> PartialEq for ParamsBase<S, D>
52where
53 A: PartialEq,
54 D: RemoveAxis,
55 S: Data<Elem = A>,
56{
57 fn eq(&self, other: &Self) -> bool {
58 self.weights() == other.weight && self.bias == other.bias
59 }
60}
61
62impl<A, S, D, K> PartialEq<(ArrayBase<S, D>, Option<ArrayBase<S, D::Smaller>>)>
63 for ParamsBase<S, D, K>
64where
65 A: PartialEq,
66 D: RemoveAxis,
67 S: Data<Elem = A>,
68{
69 fn eq(&self, (weights, bias): &(ArrayBase<S, D>, Option<ArrayBase<S, D::Smaller>>)) -> bool {
70 self.weights() == weights && self.bias.as_ref() == bias.as_ref()
71 }
72}
73
74impl<A, S, D, K> PartialEq<(ArrayBase<S, D>, ArrayBase<S, D::Smaller>)> for ParamsBase<S, D, K>
75where
76 A: PartialEq,
77 D: RemoveAxis,
78 S: Data<Elem = A>,
79{
80 fn eq(&self, (weights, bias): &(ArrayBase<S, D>, ArrayBase<S, D::Smaller>)) -> bool {
81 self.weights() == weights && self.bias.as_ref() == Some(bias)
82 }
83}
84
85impl<A, B, T, S, D, K> Predict<A> for ParamsBase<S, D, K>
86where
87 A: Dot<Array<T, D>, Output = B>,
88 B: for<'a> Add<&'a ArrayBase<S, D::Smaller>, Output = B>,
89 D: RemoveAxis,
90 S: Data<Elem = T>,
91 T: ComplexFloat,
92{
93 type Output = B;
94
95 fn predict(&self, input: &A) -> Result<Self::Output, PredictError> {
96 let wt = self.weights().t().to_owned();
97 let mut res = input.dot(&wt);
98 if let Some(bias) = self.bias.as_ref() {
99 res = res + bias;
100 }
101 Ok(res)
102 }
103}
104
105impl<'a, A, B, T, S, D, K> Predict<A> for &'a ParamsBase<S, D, K>
106where
107 A: Dot<Array<T, D>, Output = B>,
108 B: Add<&'a ArrayBase<S, D::Smaller>, Output = B>,
109 D: RemoveAxis,
110 S: Data<Elem = T>,
111 T: ComplexFloat,
112{
113 type Output = B;
114
115 fn predict(&self, input: &A) -> Result<Self::Output, PredictError> {
116 let wt = self.weights().t().to_owned();
117 let mut res = input.dot(&wt);
118 if let Some(bias) = self.bias.as_ref() {
119 res = res + bias;
120 }
121 Ok(res)
122 }
123}