acme_tensor/impls/
linalg.rs

1/*
2    Appellation: linalg <mod>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5//! Implementations for linear algebra operations.
6//!
7//!
8use crate::linalg::{Inverse, Matmul};
9use crate::prelude::{Scalar, ShapeError, TensorError, TensorExpr, TensorResult};
10use crate::tensor::{self, TensorBase};
11use acme::prelude::{nested, UnaryOp};
12use num::traits::{Num, NumAssign};
13
14fn inverse_impl<T>(tensor: &TensorBase<T>) -> TensorResult<TensorBase<T>>
15where
16    T: Copy + NumAssign + PartialOrd,
17{
18    let op = TensorExpr::unary(tensor.clone(), UnaryOp::Inv);
19    let n = tensor.nrows();
20
21    if !tensor.is_square() {
22        return Err(ShapeError::NotSquare.into()); // Matrix must be square for inversion
23    }
24
25    let eye = TensorBase::eye(n);
26
27    // Construct an augmented matrix by concatenating the original matrix with an identity matrix
28    let mut aug = TensorBase::zeros((n, 2 * n));
29    // aug.slice_mut(s![.., ..cols]).assign(matrix);
30    for i in 0..n {
31        for j in 0..n {
32            aug[[i, j]] = tensor[[i, j]];
33        }
34        for j in n..(2 * n) {
35            aug[[i, j]] = eye[[i, j - n]];
36        }
37    }
38
39    // Perform Gaussian elimination to reduce the left half to the identity matrix
40    for i in 0..n {
41        let pivot = aug[[i, i]];
42
43        if pivot == T::zero() {
44            return Err(TensorError::Singular); // Matrix is singular
45        }
46
47        for j in 0..(2 * n) {
48            aug[[i, j]] /= pivot;
49        }
50
51        for j in 0..n {
52            if i != j {
53                let am = aug.clone();
54                let factor = aug[[j, i]];
55                for k in 0..(2 * n) {
56                    aug[[j, k]] -= factor * am[[i, k]];
57                }
58            }
59        }
60    }
61
62    // Extract the inverted matrix from the augmented matrix
63    let mut inv = tensor.zeros_like().with_op(op.into());
64    for i in 0..n {
65        for j in 0..n {
66            inv[[i, j]] = aug[[i, j + n]];
67        }
68    }
69
70    Ok(inv.to_owned())
71}
72
73impl<T> TensorBase<T>
74where
75    T: Copy,
76{
77    /// Creates a new tensor containing only the diagonal elements of the original tensor.
78    pub fn diag(&self) -> Self {
79        let n = self.nrows();
80        Self::from_shape_iter(self.shape().diag(), (0..n).map(|i| self[vec![i; n]]))
81    }
82    /// Find the inverse of the tensor
83    ///
84    /// # Errors
85    ///
86    /// Returns an error if the matrix is not square or if the matrix is singular.
87    ///
88    pub fn inv(&self) -> TensorResult<Self>
89    where
90        T: NumAssign + PartialOrd,
91    {
92        inverse_impl(self)
93    }
94    /// Compute the trace of the matrix.
95    /// The trace of a matrix is the sum of the diagonal elements.
96    pub fn trace(&self) -> TensorResult<T>
97    where
98        T: Num,
99    {
100        if !self.is_square() {
101            return Err(ShapeError::NotSquare.into());
102        }
103        let n = self.nrows();
104        let trace = (0..n).fold(T::zero(), |acc, i| acc + self[[i, i]]);
105        Ok(trace)
106    }
107}
108
109impl<T> Inverse for TensorBase<T>
110where
111    T: Copy + Num + NumAssign + PartialOrd,
112{
113    type Output = TensorResult<Self>;
114
115    fn inv(&self) -> Self::Output {
116        inverse_impl(self)
117    }
118}
119
120impl<T> Matmul<TensorBase<T>> for TensorBase<T>
121where
122    T: Scalar,
123{
124    type Output = Self;
125
126    fn matmul(&self, other: &Self) -> Self {
127        let sc = |m: usize, n: usize| m * self.ncols() + n;
128        let oc = |m: usize, n: usize| m * other.ncols() + n;
129
130        let shape = self.shape().matmul_shape(other.shape()).unwrap();
131        let mut result = vec![T::zero(); shape.size()];
132
133        nested!(i in 0..self.nrows() => j in 0..other.ncols() => k in 0..self.ncols() => {
134            result[oc(i, j)] += self.data[sc(i, k)] * other.data[oc(k, j)]
135        });
136        let op = TensorExpr::matmul(self.clone(), other.clone());
137        tensor::from_vec_with_op(false, op, shape, result)
138    }
139}