concision_core/ops/
tensor.rs

1/*
2    Appellation: ops <module>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
/// Applies an affine transformation to `self`.
///
/// The transformation scales the receiver by `mul` and then shifts the result
/// by `add`, i.e. `mul * self + add`. The scale type `M` and the shift type `B`
/// may differ; when `B` is not given it defaults to the scale type.
pub trait Affine<M, B = M> {
    /// The value produced by the transformation.
    type Output;

    /// Computes `mul * self + add`.
    fn affine(&self, mul: M, add: B) -> <Self as Affine<M, B>>::Output;
}
/// Computes the inverse of a matrix.
///
/// The result type is chosen by the implementor, so implementations whose
/// inversion can fail are free to return an `Option` or `Result`.
pub trait Inverse {
    /// The value produced by inversion.
    type Output;

    /// Returns the inverse of `self`.
    fn inverse(&self) -> <Self as Inverse>::Output;
}
/// Matrix multiplication between `self` and a right-hand operand.
///
/// The right-hand operand type defaults to `Self`; the product type is chosen
/// by the implementor.
pub trait Matmul<Other = Self> {
    /// The product type.
    type Output;

    /// Multiplies `self` by `rhs` (in that order).
    fn matmul(&self, rhs: &Other) -> <Self as Matmul<Other>>::Output;
}
/// Raising a matrix to a power.
///
/// The exponent type defaults to `Self`; the result type is chosen by the
/// implementor.
pub trait Matpow<Exp = Self> {
    /// The type of the resulting power.
    type Output;

    /// Raises `self` to the power `rhs`.
    fn pow(&self, rhs: Exp) -> <Self as Matpow<Exp>>::Output;
}
30
/// Transposition of a tensor.
///
/// The result type is chosen by the implementor (for example an owned value or
/// a borrowed view).
pub trait Transpose {
    /// The transposed value.
    type Output;

    /// Returns the transpose of `self`.
    fn transpose(&self) -> <Self as Transpose>::Output;
}
37
38/*
39 ********* Implementations *********
40*/
41use ndarray::linalg::Dot;
42use ndarray::{Array, ArrayBase, Data, Dimension, Ix2, LinalgScalar, ScalarOperand};
43use num_traits::{Num, NumAssign};
44
45impl<A, D> Affine<A> for Array<A, D>
46where
47    A: LinalgScalar + ScalarOperand,
48    D: Dimension,
49{
50    type Output = Array<A, D>;
51
52    fn affine(&self, mul: A, add: A) -> Self::Output {
53        self * mul + add
54    }
55}
56
57// #[cfg(feature = "blas")]
58impl<T> Inverse for Array<T, Ix2>
59where
60    T: Copy + NumAssign + ScalarOperand,
61{
62    type Output = Option<Self>;
63
64    fn inverse(&self) -> Self::Output {
65        crate::inverse(self)
66    }
67}
68
69impl<A, S, D, X, Y> Matmul<X> for ArrayBase<S, D>
70where
71    A: ndarray::LinalgScalar,
72    D: Dimension,
73    S: Data<Elem = A>,
74    ArrayBase<S, D>: Dot<X, Output = Y>,
75{
76    type Output = Y;
77
78    fn matmul(&self, rhs: &X) -> Self::Output {
79        <Self as Dot<X>>::dot(self, rhs)
80    }
81}
82
83impl<T> Matmul<Vec<T>> for Vec<T>
84where
85    T: Copy + num::Num,
86{
87    type Output = T;
88
89    fn matmul(&self, rhs: &Vec<T>) -> T {
90        self.iter()
91            .zip(rhs.iter())
92            .fold(T::zero(), |acc, (&a, &b)| acc + a * b)
93    }
94}
95
96impl<T, const N: usize> Matmul<[T; N]> for [T; N]
97where
98    T: Copy + num::Num,
99{
100    type Output = T;
101
102    fn matmul(&self, rhs: &[T; N]) -> T {
103        self.iter()
104            .zip(rhs.iter())
105            .fold(T::zero(), |acc, (&a, &b)| acc + a * b)
106    }
107}
108impl<A, S> Matpow<i32> for ArrayBase<S, ndarray::Ix2>
109where
110    A: Copy + Num + 'static,
111    S: Data<Elem = A>,
112    ArrayBase<S, Ix2>: Clone + Dot<ArrayBase<S, Ix2>, Output = Array<A, Ix2>>,
113{
114    type Output = Array<A, Ix2>;
115
116    fn pow(&self, rhs: i32) -> Self::Output {
117        if !self.is_square() {
118            panic!("Matrix must be square to be raised to a power");
119        }
120        let mut res = Array::eye(self.shape()[0]);
121        for _ in 0..rhs {
122            res = res.dot(self);
123        }
124        res
125    }
126}
127
128impl<'a, A, S, D> Transpose for &'a ArrayBase<S, D>
129where
130    A: 'a,
131    D: Dimension,
132    S: Data<Elem = A>,
133{
134    type Output = ndarray::ArrayView<'a, A, D>;
135
136    fn transpose(&self) -> Self::Output {
137        self.t()
138    }
139}