use crate::{
Int, backend::Backend, cast::ToElement, check, check::TensorCheck, linalg::swap_slices, s,
tensor::Tensor,
};
/// Computes the LU decomposition of a square matrix using partial (row) pivoting.
///
/// Both factors are returned packed into a single matrix of the input's shape:
/// - the entries strictly below the diagonal hold the multipliers of `L` (its unit diagonal is
///   implicit and not stored);
/// - the entries on and above the diagonal hold `U`.
///
/// The second return value records the row permutation. It starts as `0..n`, and every row swap
/// applied to the matrix is applied to it as well. Entry `i` is therefore the index of the input
/// row that ended up in row `i`, i.e. `A[permutations] = L * U`.
///
/// # Arguments
///
/// * `tensor` - The square matrix to factorize.
///
/// # Returns
///
/// A tuple `(lu, permutations)` as described above.
///
/// # Panics
///
/// * If `tensor` is not square (via the `is_square` check).
/// * If a pivot is rejected by `TensorCheck::lu_decomposition_pivot`. This is presumably a zero
///   or near-zero pivot, i.e. a singular matrix. NOTE(review): confirm against that check.
pub fn lu_decomposition<B: Backend>(tensor: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1, Int>) {
    check!(TensorCheck::is_square::<2>(
        "lu_decomposition",
        &tensor.shape()
    ));
    let [n, _] = tensor.shape().dims::<2>();
    let mut permutations = Tensor::arange(0..n as i64, &tensor.device());
    let mut tensor = tensor;

    for k in 0..n {
        // Partial pivoting: choose the row at or below `k` whose entry in column `k` has the
        // largest magnitude. `argmax` is relative to the slice, hence the `+ k`.
        let p = tensor
            .clone()
            .slice(s![k.., k])
            .abs()
            .argmax(0)
            .into_scalar()
            .to_usize()
            + k;
        let pivot = tensor.clone().slice(s![p, k]).abs().into_scalar();
        check!(TensorCheck::lu_decomposition_pivot::<B>(pivot));

        if p != k {
            tensor = swap_slices(tensor, s![k, ..], s![p, ..]);
            permutations = swap_slices(permutations, s![k], s![p]);
        }

        // The last column has nothing below the diagonal and no trailing submatrix to update.
        if k + 1 < n {
            // Divide the sub-diagonal part of column `k` by the pivot. These become the `L`
            // multipliers and are stored in place. Shapes: [n-k-1, 1] / [1, 1].
            let a_kk = tensor.clone().slice(s![k, k]);
            let multipliers = tensor.clone().slice(s![(k + 1).., k]) / a_kk;
            tensor = tensor.slice_assign(s![(k + 1).., k], multipliers.clone());

            // Schur complement. Row `k` is not modified by this update, so every trailing row `i`
            // subtracts `multipliers[i] * pivot_row` independently. All rows are updated at once
            // as a rank-1 (outer product) update: [n-k-1, 1] * [1, n-k-1] broadcasts to
            // [n-k-1, n-k-1]. This replaces one slice_assign per row.
            let pivot_row = tensor.clone().slice(s![k, (k + 1)..]);
            let trailing = tensor.clone().slice(s![(k + 1).., (k + 1)..]);
            tensor = tensor.slice_assign(
                s![(k + 1).., (k + 1)..],
                trailing - multipliers * pivot_row,
            );
        }
    }

    (tensor, permutations)
}