burn_tensor/tensor/linalg/lu_decomposition.rs
use crate::{
    Int, backend::Backend, cast::ToElement, check, check::TensorCheck, linalg::swap_slices, s,
    tensor::Tensor,
};
5/// Performs PLU decomposition of a square matrix.
6///
7/// The function decomposes a given square matrix `A` into three matrices: a permutation vector `p`,
8/// a lower triangular matrix `L`, and an upper triangular matrix `U`, such that `PA = LU`.
9/// The permutation vector `p` represents the row swaps made during the decomposition process.
10/// The lower triangular matrix `L` has ones on its diagonal and contains the multipliers used
11/// during the elimination process below the diagonal. The upper triangular matrix `U` contains
12/// the resulting upper triangular form of the matrix after the elimination process.
13///
14/// # Arguments
15/// * `tensor` - A square matrix to decompose, represented as a 2D tensor.
16///
17/// # Returns
18/// A tuple containing:
19/// - A 2D tensor representing the combined `L` and `U` matrices.
20/// - A 1D tensor representing the permutation vector `p`.
21///
22/// # Panics and numerical issues
23/// - The function will panic if the input matrix is singular or near-singular.
24/// - The function will panic if the input matrix is not square.
25/// # Performance note (synchronization / device transfers)
26/// This function may involve multiple synchronizations and device transfers, especially
27/// when determining pivot elements and performing row swaps. This can impact performance,
28pub fn lu_decomposition<B: Backend>(tensor: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1, Int>) {
29 check!(TensorCheck::is_square::<2>(
30 "lu_decomposition",
31 &tensor.shape()
32 ));
33 let dims = tensor.shape().dims::<2>();
34 let n = dims[0];
35
36 let mut permutations = Tensor::arange(0..n as i64, &tensor.device());
37 let mut tensor = tensor;
38
39 for k in 0..n {
40 // Find the pivot row
41 let p = tensor
42 .clone()
43 .slice(s![k.., k])
44 .abs()
45 .argmax(0)
46 .into_scalar()
47 .to_usize()
48 + k;
49 let max = tensor.clone().slice(s![p, k]).abs();
50
51 // Avoid division by zero
52 let pivot = max.into_scalar();
53 check!(TensorCheck::lu_decomposition_pivot::<B>(pivot));
54
55 if p != k {
56 tensor = swap_slices(tensor, s![k, ..], s![p, ..]);
57 permutations = swap_slices(permutations, s![k], s![p]);
58 }
59
60 // Normalize k-th column under the diagonal
61 if k < n - 1 {
62 let a_kk = tensor.clone().slice(s![k, k]);
63 let column = tensor.clone().slice(s![(k + 1).., k]) / a_kk;
64 tensor = tensor.slice_assign(s![(k + 1).., k], column);
65 }
66
67 // Update the trailing submatrix
68 for i in (k + 1)..n {
69 // a[i, k+1..] -= a[i, k] * a[k, k+1..]
70 let a_ik = tensor.clone().slice(s![i, k]);
71 let row_k = tensor.clone().slice(s![k, (k + 1)..]);
72 let update = a_ik * row_k;
73 let row_i = tensor.clone().slice(s![i, (k + 1)..]);
74 tensor = tensor.slice_assign(s![i, (k + 1)..], row_i - update);
75 }
76 }
77
78 (tensor, permutations)
79}