// concision_core/tensor/traits/ops.rs

/*
    appellation: tensor <module>
    authors: @FL03
*/

/// Apply an affine transformation to a tensor.
///
/// The affine transformation is defined as `mul * self + add`, where `mul` has type `X` and
/// `add` has type `Y` (defaulting to `X`).
pub trait Affine<X, Y = X> {
    /// the output, or result, of the affine transformation
    type Output;

    /// compute `mul * self + add`, producing some [`Output`](Affine::Output)
    fn affine(&self, mul: X, add: Y) -> Self::Output;
}
/// The [`Inverse`] trait generically establishes an interface for computing the inverse of a
/// type, regardless of whether it is a tensor, a scalar, or some other compatible type.
pub trait Inverse {
    /// the output, or result, of the inverse operation
    type Output;
    /// compute the inverse of the current object, producing some [`Output`](Inverse::Output)
    fn inverse(&self) -> Self::Output;
}
/// The [`MatMul`] trait defines an interface for matrix multiplication.
///
/// The right-hand side type `Rhs` defaults to `Self` and is always taken by reference.
pub trait MatMul<Rhs = Self> {
    /// the output, or result, of the multiplication
    type Output;

    /// multiply the current object by `rhs`, producing some [`Output`](MatMul::Output)
    fn matmul(&self, rhs: &Rhs) -> Self::Output;
}
/// The [`MatPow`] trait defines an interface for computing the exponentiation of a matrix.
///
/// The exponent type `Rhs` defaults to `Self` and is taken by value.
pub trait MatPow<Rhs = Self> {
    /// the output, or result, of the exponentiation
    type Output;
    /// raise the tensor to the power of the right-hand side, producing some [`Output`](MatPow::Output)
    fn matpow(&self, rhs: Rhs) -> Self::Output;
}

/// The [`Transpose`] trait generically establishes an interface for transposing a type.
pub trait Transpose {
    /// the output, or result, of the transposition
    type Output;
    /// transpose a reference to the current object, producing some
    /// [`Output`](Transpose::Output)
    fn transpose(&self) -> Self::Output;
}

/*
 ********* Implementations *********
*/
use ndarray::linalg::Dot;
use ndarray::{Array, Array2, ArrayBase, Data, Dimension, Ix2, LinalgScalar, ScalarOperand, s};
use num_traits::{Num, NumAssign};

49impl<A, D> Affine<A> for Array<A, D>
50where
51    A: LinalgScalar + ScalarOperand,
52    D: Dimension,
53{
54    type Output = Array<A, D>;
55
56    fn affine(&self, mul: A, add: A) -> Self::Output {
57        self * mul + add
58    }
59}
60
61// #[cfg(not(feature = "blas"))]
62impl<T> Inverse for Array<T, Ix2>
63where
64    T: Copy + NumAssign + ScalarOperand,
65{
66    type Output = Option<Self>;
67
68    fn inverse(&self) -> Self::Output {
69        let (rows, cols) = self.dim();
70
71        if !self.is_square() {
72            return None; // Matrix must be square for inversion
73        }
74
75        let identity = Array2::eye(rows);
76
77        // Construct an augmented matrix by concatenating the original matrix with an identity matrix
78        let mut aug = Array2::zeros((rows, 2 * cols));
79        aug.slice_mut(s![.., ..cols]).assign(self);
80        aug.slice_mut(s![.., cols..]).assign(&identity);
81
82        // Perform Gaussian elimination to reduce the left half to the identity matrix
83        for i in 0..rows {
84            let pivot = aug[[i, i]];
85
86            if pivot == T::zero() {
87                return None; // Matrix is singular
88            }
89
90            aug.slice_mut(s![i, ..]).mapv_inplace(|x| x / pivot);
91
92            for j in 0..rows {
93                if i != j {
94                    let am = aug.clone();
95                    let factor = aug[[j, i]];
96                    let rhs = am.slice(s![i, ..]);
97                    aug.slice_mut(s![j, ..])
98                        .zip_mut_with(&rhs, |x, &y| *x -= y * factor);
99                }
100            }
101        }
102
103        // Extract the inverted matrix from the augmented matrix
104        let inverted = aug.slice(s![.., cols..]);
105
106        Some(inverted.to_owned())
107    }
108}
// #[cfg(feature = "blas")]
// impl<T> Inverse for Array<T, Ix2>
// where
//     T: Copy + NumAssign + ScalarOperand,
// {
//     type Output = Option<Self>;

//     fn inverse(&self) -> Self::Output {
//         use ndarray_linalg::solve::Inverse;
//         self.inv().ok()
//     }
// }

122impl<A, S, D, X, Y> MatMul<X> for ArrayBase<S, D>
123where
124    A: ndarray::LinalgScalar,
125    D: Dimension,
126    S: Data<Elem = A>,
127    ArrayBase<S, D>: Dot<X, Output = Y>,
128{
129    type Output = Y;
130
131    fn matmul(&self, rhs: &X) -> Self::Output {
132        <Self as Dot<X>>::dot(self, rhs)
133    }
134}
135
136impl<T> MatMul<Vec<T>> for Vec<T>
137where
138    T: Copy + num::Num,
139{
140    type Output = T;
141
142    fn matmul(&self, rhs: &Vec<T>) -> T {
143        self.iter()
144            .zip(rhs.iter())
145            .fold(T::zero(), |acc, (&a, &b)| acc + a * b)
146    }
147}
148
149impl<T, const N: usize> MatMul<[T; N]> for [T; N]
150where
151    T: Copy + num::Num,
152{
153    type Output = T;
154
155    fn matmul(&self, rhs: &[T; N]) -> T {
156        self.iter()
157            .zip(rhs.iter())
158            .fold(T::zero(), |acc, (&a, &b)| acc + a * b)
159    }
160}
161impl<A, S> MatPow<i32> for ArrayBase<S, ndarray::Ix2>
162where
163    A: Copy + Num + 'static,
164    S: Data<Elem = A>,
165    ArrayBase<S, Ix2>: Clone + Dot<ArrayBase<S, Ix2>, Output = Array<A, Ix2>>,
166{
167    type Output = Array<A, Ix2>;
168
169    fn matpow(&self, rhs: i32) -> Self::Output {
170        if !self.is_square() {
171            panic!("Matrix must be square to be raised to a power");
172        }
173        let mut res = Array::eye(self.shape()[0]);
174        for _ in 0..rhs {
175            res = res.dot(self);
176        }
177        res
178    }
179}
180
181impl<'a, A, S, D> Transpose for &'a ArrayBase<S, D>
182where
183    A: 'a,
184    D: Dimension,
185    S: Data<Elem = A>,
186{
187    type Output = ndarray::ArrayView<'a, A, D>;
188
189    fn transpose(&self) -> Self::Output {
190        self.t()
191    }
192}