concision_core/ops/
/// An affine transformation parameterised by a multiplicative term of type
/// `X` and an additive term of type `Y` (which defaults to `X`).
///
/// The trait itself does not prescribe how the two terms are combined; each
/// implementor chooses its own `Output` and semantics.
pub trait Affine<X, Y = X> {
    /// The value produced by [`Affine::affine`].
    type Output;

    /// Applies the transformation described by `mul` and `add` to `self`,
    /// borrowing `self` immutably.
    fn affine(&self, mul: X, add: Y) -> Self::Output;
}
/// Types for which an inverse can be requested.
///
/// The result type is left to the implementor, so a fallible inverse can be
/// expressed by choosing e.g. an `Option` as `Output`.
pub trait Inverse {
    /// The value produced by [`Inverse::inverse`].
    type Output;

    /// Computes the inverse of `self` without consuming it.
    fn inverse(&self) -> Self::Output;
}
/// A product between `self` and a right-hand operand of type `Rhs`
/// (defaulting to `Self`), taken by reference on both sides.
pub trait Matmul<Rhs = Self> {
    /// The value produced by [`Matmul::matmul`].
    type Output;

    /// Multiplies `self` by `rhs`; neither operand is consumed.
    fn matmul(&self, rhs: &Rhs) -> Self::Output;
}
/// Raising `self` to a power given by an exponent of type `Rhs`
/// (defaulting to `Self`).
pub trait Matpow<Rhs = Self> {
    /// The value produced by [`Matpow::pow`].
    type Output;

    /// Raises `self` to the power `rhs`; the exponent is taken by value.
    fn pow(&self, rhs: Rhs) -> Self::Output;
}
30
/// Types that can produce a transposed form of themselves.
pub trait Transpose {
    /// The value produced by [`Transpose::transpose`].
    type Output;

    /// Returns the transpose of `self` without consuming it.
    fn transpose(&self) -> Self::Output;
}
37
38use ndarray::linalg::Dot;
42use ndarray::{Array, ArrayBase, Data, Dimension, Ix2, LinalgScalar, ScalarOperand};
43use num_traits::{Num, NumAssign};
44
45impl<A, D> Affine<A> for Array<A, D>
46where
47 A: LinalgScalar + ScalarOperand,
48 D: Dimension,
49{
50 type Output = Array<A, D>;
51
52 fn affine(&self, mul: A, add: A) -> Self::Output {
53 self * mul + add
54 }
55}
56
57impl<T> Inverse for Array<T, Ix2>
59where
60 T: Copy + NumAssign + ScalarOperand,
61{
62 type Output = Option<Self>;
63
64 fn inverse(&self) -> Self::Output {
65 crate::inverse(self)
66 }
67}
68
69impl<A, S, D, X, Y> Matmul<X> for ArrayBase<S, D>
70where
71 A: ndarray::LinalgScalar,
72 D: Dimension,
73 S: Data<Elem = A>,
74 ArrayBase<S, D>: Dot<X, Output = Y>,
75{
76 type Output = Y;
77
78 fn matmul(&self, rhs: &X) -> Self::Output {
79 <Self as Dot<X>>::dot(self, rhs)
80 }
81}
82
83impl<T> Matmul<Vec<T>> for Vec<T>
84where
85 T: Copy + num::Num,
86{
87 type Output = T;
88
89 fn matmul(&self, rhs: &Vec<T>) -> T {
90 self.iter()
91 .zip(rhs.iter())
92 .fold(T::zero(), |acc, (&a, &b)| acc + a * b)
93 }
94}
95
96impl<T, const N: usize> Matmul<[T; N]> for [T; N]
97where
98 T: Copy + num::Num,
99{
100 type Output = T;
101
102 fn matmul(&self, rhs: &[T; N]) -> T {
103 self.iter()
104 .zip(rhs.iter())
105 .fold(T::zero(), |acc, (&a, &b)| acc + a * b)
106 }
107}
108impl<A, S> Matpow<i32> for ArrayBase<S, ndarray::Ix2>
109where
110 A: Copy + Num + 'static,
111 S: Data<Elem = A>,
112 ArrayBase<S, Ix2>: Clone + Dot<ArrayBase<S, Ix2>, Output = Array<A, Ix2>>,
113{
114 type Output = Array<A, Ix2>;
115
116 fn pow(&self, rhs: i32) -> Self::Output {
117 if !self.is_square() {
118 panic!("Matrix must be square to be raised to a power");
119 }
120 let mut res = Array::eye(self.shape()[0]);
121 for _ in 0..rhs {
122 res = res.dot(self);
123 }
124 res
125 }
126}
127
128impl<'a, A, S, D> Transpose for &'a ArrayBase<S, D>
129where
130 A: 'a,
131 D: Dimension,
132 S: Data<Elem = A>,
133{
134 type Output = ndarray::ArrayView<'a, A, D>;
135
136 fn transpose(&self) -> Self::Output {
137 self.t()
138 }
139}