// concision_core/ops/tensor.rs

/// An affine transformation parameterised by a multiplier and an offset.
///
/// `X` is the type of the multiplier and `Y` (which defaults to `X`) is the type of the
/// offset. The `ndarray::Array` implementation in this file evaluates `self * mul + add`.
pub trait Affine<X, Y = X> {
    /// The value produced by the transformation.
    type Output;

    /// Applies the transformation to `self`, using `mul` as the multiplier and `add` as
    /// the offset.
    fn affine(&self, mul: X, add: Y) -> Self::Output;
}
/// Types that can produce an inverse of themselves.
///
/// The associated `Output` lets an implementation choose how to report a missing inverse;
/// the `Array<T, Ix2>` implementation in this file uses `Option<Self>`.
pub trait Inverse {
    /// The value returned by [`Inverse::inverse`].
    type Output;

    /// Computes the inverse of `self` without consuming it.
    fn inverse(&self) -> Self::Output;
}
/// A product between `self` and a borrowed right-hand operand.
///
/// `Rhs` defaults to `Self`. The implementations in this file delegate to
/// `ndarray::linalg::Dot` for arrays and compute a sum of element-wise products for
/// `Vec<T>` and `[T; N]`.
pub trait MatMul<Rhs = Self> {
    /// The value produced by the product.
    type Output;

    /// Multiplies `self` by `rhs`; neither operand is consumed.
    fn matmul(&self, rhs: &Rhs) -> Self::Output;
}
/// Raising `self` to a power.
///
/// `Rhs` is the exponent type and defaults to `Self`; the implementation in this file
/// uses an `i32` exponent.
pub trait MatPow<Rhs = Self> {
    /// The value produced by the exponentiation.
    type Output;

    /// Raises `self` to the power `rhs`; the exponent is taken by value.
    fn matpow(&self, rhs: Rhs) -> Self::Output;
}

/// Types that can produce a transposed form of themselves.
///
/// The implementation in this file is written for `&ArrayBase` and returns a borrowed
/// `ndarray::ArrayView`.
pub trait Transpose {
    /// The value returned by [`Transpose::transpose`].
    type Output;

    /// Returns the transpose of `self` without consuming it.
    fn transpose(&self) -> Self::Output;
}

use ndarray::linalg::Dot;
use ndarray::{Array, ArrayBase, Data, Dimension, Ix2, LinalgScalar, ScalarOperand};
use num_traits::{Num, NumAssign};

49impl<A, D> Affine<A> for Array<A, D>
50where
51 A: LinalgScalar + ScalarOperand,
52 D: Dimension,
53{
54 type Output = Array<A, D>;
55
56 fn affine(&self, mul: A, add: A) -> Self::Output {
57 self * mul + add
58 }
59}
60
61impl<T> Inverse for Array<T, Ix2>
63where
64 T: Copy + NumAssign + ScalarOperand,
65{
66 type Output = Option<Self>;
67
68 fn inverse(&self) -> Self::Output {
69 crate::inverse(self)
70 }
71}
72
73impl<A, S, D, X, Y> MatMul<X> for ArrayBase<S, D>
74where
75 A: ndarray::LinalgScalar,
76 D: Dimension,
77 S: Data<Elem = A>,
78 ArrayBase<S, D>: Dot<X, Output = Y>,
79{
80 type Output = Y;
81
82 fn matmul(&self, rhs: &X) -> Self::Output {
83 <Self as Dot<X>>::dot(self, rhs)
84 }
85}
86
87impl<T> MatMul<Vec<T>> for Vec<T>
88where
89 T: Copy + num::Num,
90{
91 type Output = T;
92
93 fn matmul(&self, rhs: &Vec<T>) -> T {
94 self.iter()
95 .zip(rhs.iter())
96 .fold(T::zero(), |acc, (&a, &b)| acc + a * b)
97 }
98}
99
100impl<T, const N: usize> MatMul<[T; N]> for [T; N]
101where
102 T: Copy + num::Num,
103{
104 type Output = T;
105
106 fn matmul(&self, rhs: &[T; N]) -> T {
107 self.iter()
108 .zip(rhs.iter())
109 .fold(T::zero(), |acc, (&a, &b)| acc + a * b)
110 }
111}
112impl<A, S> MatPow<i32> for ArrayBase<S, ndarray::Ix2>
113where
114 A: Copy + Num + 'static,
115 S: Data<Elem = A>,
116 ArrayBase<S, Ix2>: Clone + Dot<ArrayBase<S, Ix2>, Output = Array<A, Ix2>>,
117{
118 type Output = Array<A, Ix2>;
119
120 fn matpow(&self, rhs: i32) -> Self::Output {
121 if !self.is_square() {
122 panic!("Matrix must be square to be raised to a power");
123 }
124 let mut res = Array::eye(self.shape()[0]);
125 for _ in 0..rhs {
126 res = res.dot(self);
127 }
128 res
129 }
130}
131
132impl<'a, A, S, D> Transpose for &'a ArrayBase<S, D>
133where
134 A: 'a,
135 D: Dimension,
136 S: Data<Elem = A>,
137{
138 type Output = ndarray::ArrayView<'a, A, D>;
139
140 fn transpose(&self) -> Self::Output {
141 self.t()
142 }
143}