// concision_core/tensor/traits/ops.rs

/// An affine transformation parameterised by a multiplier and an offset.
///
/// `X` is the type of the multiplicative term and `Y` the type of the
/// additive term; `Y` defaults to `X` when it is not specified. The trait
/// itself does not prescribe how the two terms are combined; the `Array`
/// implementation in this file evaluates `self * mul + add`.
pub trait Affine<X, Y = X> {
    /// The value produced by the transformation.
    type Output;

    /// Applies the transformation to `self` using the multiplier `mul` and
    /// the offset `add`, without consuming `self`.
    fn affine(&self, mul: X, add: Y) -> Self::Output;
}
/// Computes the inverse of a value.
///
/// The result type is left to the implementor, so a fallible inversion can
/// be expressed through `Output` (for example as an `Option`).
pub trait Inverse {
    /// The value produced by [`Inverse::inverse`].
    type Output;

    /// Returns the inverse of `self` without consuming it.
    fn inverse(&self) -> Self::Output;
}
/// Matrix-style multiplication of `self` by a right-hand operand.
///
/// `Rhs` is the type of the right-hand operand and defaults to `Self`.
pub trait MatMul<Rhs = Self> {
    /// The value produced by the multiplication.
    type Output;

    /// Multiplies `self` by `rhs`; both operands are only borrowed.
    fn matmul(&self, rhs: &Rhs) -> Self::Output;
}
/// Raises `self` to a power.
///
/// `Rhs` is the type of the exponent and defaults to `Self`.
pub trait MatPow<Rhs = Self> {
    /// The value produced by the exponentiation.
    type Output;

    /// Raises `self` to the power `rhs`; the exponent is taken by value.
    fn matpow(&self, rhs: Rhs) -> Self::Output;
}
33
/// Produces the transpose of a value.
pub trait Transpose {
    /// The value produced by [`Transpose::transpose`].
    type Output;

    /// Returns the transpose of `self` without consuming it.
    fn transpose(&self) -> Self::Output;
}
41
42use ndarray::linalg::Dot;
46use ndarray::{Array, Array2, ArrayBase, Data, Dimension, Ix2, LinalgScalar, ScalarOperand, s};
47use num_traits::{Num, NumAssign};
48
49impl<A, D> Affine<A> for Array<A, D>
50where
51 A: LinalgScalar + ScalarOperand,
52 D: Dimension,
53{
54 type Output = Array<A, D>;
55
56 fn affine(&self, mul: A, add: A) -> Self::Output {
57 self * mul + add
58 }
59}
60
61impl<T> Inverse for Array<T, Ix2>
63where
64 T: Copy + NumAssign + ScalarOperand,
65{
66 type Output = Option<Self>;
67
68 fn inverse(&self) -> Self::Output {
69 let (rows, cols) = self.dim();
70
71 if !self.is_square() {
72 return None; }
74
75 let identity = Array2::eye(rows);
76
77 let mut aug = Array2::zeros((rows, 2 * cols));
79 aug.slice_mut(s![.., ..cols]).assign(self);
80 aug.slice_mut(s![.., cols..]).assign(&identity);
81
82 for i in 0..rows {
84 let pivot = aug[[i, i]];
85
86 if pivot == T::zero() {
87 return None; }
89
90 aug.slice_mut(s![i, ..]).mapv_inplace(|x| x / pivot);
91
92 for j in 0..rows {
93 if i != j {
94 let am = aug.clone();
95 let factor = aug[[j, i]];
96 let rhs = am.slice(s![i, ..]);
97 aug.slice_mut(s![j, ..])
98 .zip_mut_with(&rhs, |x, &y| *x -= y * factor);
99 }
100 }
101 }
102
103 let inverted = aug.slice(s![.., cols..]);
105
106 Some(inverted.to_owned())
107 }
108}
109impl<A, S, D, X, Y> MatMul<X> for ArrayBase<S, D>
123where
124 A: ndarray::LinalgScalar,
125 D: Dimension,
126 S: Data<Elem = A>,
127 ArrayBase<S, D>: Dot<X, Output = Y>,
128{
129 type Output = Y;
130
131 fn matmul(&self, rhs: &X) -> Self::Output {
132 <Self as Dot<X>>::dot(self, rhs)
133 }
134}
135
136impl<T> MatMul<Vec<T>> for Vec<T>
137where
138 T: Copy + num::Num,
139{
140 type Output = T;
141
142 fn matmul(&self, rhs: &Vec<T>) -> T {
143 self.iter()
144 .zip(rhs.iter())
145 .fold(T::zero(), |acc, (&a, &b)| acc + a * b)
146 }
147}
148
149impl<T, const N: usize> MatMul<[T; N]> for [T; N]
150where
151 T: Copy + num::Num,
152{
153 type Output = T;
154
155 fn matmul(&self, rhs: &[T; N]) -> T {
156 self.iter()
157 .zip(rhs.iter())
158 .fold(T::zero(), |acc, (&a, &b)| acc + a * b)
159 }
160}
161impl<A, S> MatPow<i32> for ArrayBase<S, ndarray::Ix2>
162where
163 A: Copy + Num + 'static,
164 S: Data<Elem = A>,
165 ArrayBase<S, Ix2>: Clone + Dot<ArrayBase<S, Ix2>, Output = Array<A, Ix2>>,
166{
167 type Output = Array<A, Ix2>;
168
169 fn matpow(&self, rhs: i32) -> Self::Output {
170 if !self.is_square() {
171 panic!("Matrix must be square to be raised to a power");
172 }
173 let mut res = Array::eye(self.shape()[0]);
174 for _ in 0..rhs {
175 res = res.dot(self);
176 }
177 res
178 }
179}
180
181impl<'a, A, S, D> Transpose for &'a ArrayBase<S, D>
182where
183 A: 'a,
184 D: Dimension,
185 S: Data<Elem = A>,
186{
187 type Output = ndarray::ArrayView<'a, A, D>;
188
189 fn transpose(&self) -> Self::Output {
190 self.t()
191 }
192}