// burn_tensor/tensor/linalg/lu_decomposition.rs

1use crate::{
2    Int, backend::Backend, cast::ToElement, check, check::TensorCheck, linalg::swap_slices, s,
3    tensor::Tensor,
4};
5/// Performs PLU decomposition of a square matrix.
6///
7/// The function decomposes a given square matrix `A` into three matrices: a permutation vector `p`,
8/// a lower triangular matrix `L`, and an upper triangular matrix `U`, such that `PA = LU`.
9/// The permutation vector `p` represents the row swaps made during the decomposition process.
10/// The lower triangular matrix `L` has ones on its diagonal and contains the multipliers used
11/// during the elimination process below the diagonal. The upper triangular matrix `U` contains
12/// the resulting upper triangular form of the matrix after the elimination process.
13///
14/// # Arguments
15/// * `tensor` - A square matrix to decompose, represented as a 2D tensor.
16///
17/// # Returns
18/// A tuple containing:
19/// - A 2D tensor representing the combined `L` and `U` matrices.
20/// - A 1D tensor representing the permutation vector `p`.
21///
22/// # Panics and numerical issues
23/// - The function will panic if the input matrix is singular or near-singular.
24/// - The function will panic if the input matrix is not square.
25/// # Performance note (synchronization / device transfers)
26/// This function may involve multiple synchronizations and device transfers, especially
27/// when determining pivot elements and performing row swaps. This can impact performance,
28pub fn lu_decomposition<B: Backend>(tensor: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1, Int>) {
29    check!(TensorCheck::is_square::<2>(
30        "lu_decomposition",
31        &tensor.shape()
32    ));
33    let dims = tensor.shape().dims::<2>();
34    let n = dims[0];
35
36    let mut permutations = Tensor::arange(0..n as i64, &tensor.device());
37    let mut tensor = tensor;
38
39    for k in 0..n {
40        // Find the pivot row
41        let p = tensor
42            .clone()
43            .slice(s![k.., k])
44            .abs()
45            .argmax(0)
46            .into_scalar()
47            .to_usize()
48            + k;
49        let max = tensor.clone().slice(s![p, k]).abs();
50
51        // Avoid division by zero
52        let pivot = max.into_scalar();
53        check!(TensorCheck::lu_decomposition_pivot::<B>(pivot));
54
55        if p != k {
56            tensor = swap_slices(tensor, s![k, ..], s![p, ..]);
57            permutations = swap_slices(permutations, s![k], s![p]);
58        }
59
60        // Normalize k-th column under the diagonal
61        if k < n - 1 {
62            let a_kk = tensor.clone().slice(s![k, k]);
63            let column = tensor.clone().slice(s![(k + 1).., k]) / a_kk;
64            tensor = tensor.slice_assign(s![(k + 1).., k], column);
65        }
66
67        // Update the trailing submatrix
68        for i in (k + 1)..n {
69            // a[i, k+1..] -=  a[i, k] * a[k, k+1..]
70            let a_ik = tensor.clone().slice(s![i, k]);
71            let row_k = tensor.clone().slice(s![k, (k + 1)..]);
72            let update = a_ik * row_k;
73            let row_i = tensor.clone().slice(s![i, (k + 1)..]);
74            tensor = tensor.slice_assign(s![i, (k + 1)..], row_i - update);
75        }
76    }
77
78    (tensor, permutations)
79}