use crate::{ElementConversion, Tensor, check, check::TensorCheck};
use alloc::vec;
use alloc::vec::Vec;
use burn_backend::{
Backend, Slice,
tensor::{Bool, IndexingUpdateOp, Int},
};
use burn_std::FloatDType;
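/// Computes the LU decomposition of a batch of matrices using partial
/// pivoting, returning `(p, l, u)` where `p` is a permutation matrix, `l` is
/// unit lower triangular, and `u` is upper triangular, so that the input
/// equals `p.matmul(l).matmul(u)`.
///
/// Rectangular inputs are supported: for an `m x n` matrix, `l` has shape
/// `m x min(m, n)` and `u` has shape `min(m, n) x n`.
///
/// A minimal usage sketch (not compiled here), assuming a concrete backend
/// `B` and a `device` are in scope, and that `D1 = D - 1` is the relationship
/// validated by `TensorCheck::lu_generic_param`:
///
/// ```ignore
/// let a = Tensor::<B, 2>::from_floats([[4.0, 3.0], [6.0, 3.0]], &device);
/// let (p, l, u) = lu::<B, 2, 1>(a.clone());
/// // Up to floating-point error, `a` is recovered as p * l * u.
/// let reconstructed = p.matmul(l).matmul(u);
/// ```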
pub fn lu<B: Backend, const D: usize, const D1: usize>(
tensor: Tensor<B, D>,
) -> (Tensor<B, D>, Tensor<B, D>, Tensor<B, D>) {
let dims = tensor.dims();
check!(TensorCheck::lu_generic_param::<D, D1>("linalg::lu"));
check!(TensorCheck::lu_input_tensor::<D>("linalg::lu", &dims));
let device = tensor.device();
let n_rows = dims[D - 2];
let n_cols = dims[D - 1];
let (lu_tensor, p_compact) = compute_lu_decomposition::<B, D, D1>(tensor);
// Split the packed factorization into L and U. For a wide matrix the extra
// columns belong to U; for a tall or square matrix the extra rows belong to L.
let (temp_l, u) = if n_rows < n_cols {
let l = lu_tensor.clone().slice_dim(D - 1, 0..n_rows).tril(0);
(l, lu_tensor.triu(0))
} else {
let l = lu_tensor.clone().tril(0);
(l, lu_tensor.slice_dim(D - 2, 0..n_cols).triu(0))
};
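// In the packed format the diagonal belongs to U, so L's stored diagonal
// must be overwritten with ones.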
let mask = Tensor::<B, D, Bool>::diag_mask(temp_l.shape(), 0, &device).bool_not();
let l = temp_l.mask_fill(mask, 1.0);
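// Expand the compact pivot swaps into a permutation matrix Q (with
// Q * A = L * U), then transpose so the result satisfies A = P * L * U.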
let p = construct_full_permutation_tensor(p_compact, n_rows, &device).transpose();
(p, l, u)
}
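/// Computes the LU factorization with partial pivoting, returning the packed
/// `LU` tensor (the strictly lower part holds `L` without its unit diagonal,
/// the upper part holds `U`) and the compact pivot vector of row swaps.
///
/// Small problems (`min(n_rows, n_cols) < 256`) use the unblocked algorithm;
/// larger ones use the blocked variant, which spends most of its time in
/// matmuls.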
pub(super) fn compute_lu_decomposition<B: Backend, const D: usize, const D1: usize>(
tensor: Tensor<B, D>,
) -> (Tensor<B, D>, Tensor<B, D>) {
let device = tensor.device();
let dims = tensor.dims();
let n_rows = dims[D - 2];
let n_cols = dims[D - 1];
let size = n_rows.min(n_cols);
if size < 256 {
return standard_lu_with_partial_piv::<B, D, D1>(tensor, &device);
}
block_lu_with_partial_piv::<B, D, D1>(tensor)
}
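/// Blocked (right-looking) LU factorization with partial pivoting.
///
/// The matrix is processed in panels of 128 columns. For each panel: factor
/// it with the unblocked routine, fold its local pivots into the global pivot
/// vector, replay the row swaps on the columns to the left and right of the
/// panel, solve a unit-lower-triangular system for the panel's block row of
/// `U`, and subtract the Schur complement `L21 * U12` from the trailing
/// submatrix.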
fn block_lu_with_partial_piv<B: Backend, const D: usize, const D1: usize>(
mut tensor: Tensor<B, D>,
) -> (Tensor<B, D>, Tensor<B, D>) {
let device = tensor.device();
let dims = tensor.dims();
let n_rows = dims[D - 2];
let n_cols = dims[D - 1];
let piv_nums = n_rows.min(n_cols);
let dtype = tensor.dtype().into();
let mut global_piv = create_permutation_tensor::<B, D>(piv_nums, dims, dtype, &device);
let block_size = 128;
let n_blocks = piv_nums.div_ceil(block_size);
let mut slices = vec![Slice::full(); D];
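// Right-looking sweep: factor one panel of `block_size` columns per iteration.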
for block_k in 0..n_blocks {
let k_start = block_k * block_size;
let k_end = (k_start + block_size).min(piv_nums);
let current_block_size = k_end - k_start;
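// Factor the current panel (rows k_start.., columns k_start..k_end) with the
// unblocked routine and write it back.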
let sub_tensor = tensor
.clone()
.slice_dim(D - 2, k_start..)
.slice_dim(D - 1, k_start..k_end);
let (block_column, local_piv) =
standard_lu_with_partial_piv::<B, D, D1>(sub_tensor, &device);
slices[D - 2] = Slice::from(k_start..);
slices[D - 1] = Slice::from(k_start..k_end);
tensor = tensor.slice_assign(&slices, block_column);
global_piv =
update_permutations_to_global_idx(global_piv.clone(), local_piv.clone(), k_start);
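// Replay the panel's row swaps on the already-factored columns to its left.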
if block_k != 0 {
let left_sub_tensor = tensor
.clone()
.slice_dim(D - 2, k_start..)
.slice_dim(D - 1, ..k_start);
let permuted_left_sub_tensor =
apply_permutations_to_tensor(left_sub_tensor, local_piv.clone(), &device);
slices[D - 2] = Slice::from(k_start..);
slices[D - 1] = Slice::from(..k_start);
tensor = tensor.slice_assign(&slices, permuted_left_sub_tensor);
}
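// Replay the row swaps on the columns to the right of the panel, then compute
// the panel's block row of U.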
if k_end < n_cols {
let right_sub_tensor = tensor
.clone()
.slice_dim(D - 2, k_start..)
.slice_dim(D - 1, k_end..);
let permuted_right_sub_tensor =
apply_permutations_to_tensor(right_sub_tensor, local_piv, &device);
slices[D - 2] = Slice::from(k_start..);
slices[D - 1] = Slice::from(k_end..);
tensor = tensor.slice_assign(&slices, permuted_right_sub_tensor);
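// Solve L11 * U12 = A12 for U12, where L11 is the panel's unit lower
// triangular diagonal block.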
let diagonal_l_block = tensor
.clone()
.slice_dim(D - 2, k_start..k_end)
.slice_dim(D - 1, k_start..k_end);
let row_blocks = tensor
.clone()
.slice_dim(D - 2, k_start..k_end)
.slice_dim(D - 1, k_end..);
let updated_row_blocks =
solve_for_u_blocks(diagonal_l_block, row_blocks, current_block_size);
slices[D - 2] = Slice::from(k_start..k_end);
slices[D - 1] = Slice::from(k_end..);
tensor = tensor.slice_assign(&slices, updated_row_blocks.clone());
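// Schur complement update of the trailing submatrix: A22 -= L21 * U12.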
if k_end < n_rows {
let trailing_a_blocks = tensor
.clone()
.slice_dim(D - 2, k_end..)
.slice_dim(D - 1, k_end..);
let l_col_blocks = tensor
.clone()
.slice_dim(D - 2, k_end..)
.slice_dim(D - 1, k_start..k_end);
let outer_prod = l_col_blocks.matmul(updated_row_blocks);
let new_trailing_a_blocks = trailing_a_blocks - outer_prod;
slices[D - 2] = Slice::from(k_end..);
slices[D - 1] = Slice::from(k_end..);
tensor = tensor.slice_assign(&slices, new_trailing_a_blocks);
}
}
}
(tensor, global_piv)
}
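/// Unblocked LU factorization with partial pivoting, batched over the leading
/// dimensions.
///
/// For each column `k`: select the largest-magnitude entry on or below the
/// diagonal as the pivot, swap it into row `k`, divide the subdiagonal
/// entries by the pivot to form the `L` column, and apply a rank-1 update to
/// the trailing submatrix. `L` and `U` are returned packed into one tensor,
/// together with the compact pivot vector.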
fn standard_lu_with_partial_piv<B: Backend, const D: usize, const D1: usize>(
mut tensor: Tensor<B, D>,
device: &B::Device,
) -> (Tensor<B, D>, Tensor<B, D>) {
let dims = tensor.dims();
let n_rows = dims[D - 2];
let n_cols = dims[D - 1];
let piv_nums = n_rows.min(n_cols);
let dtype = tensor.dtype().into();
let mut permutations = create_permutation_tensor::<B, D>(piv_nums, dims, dtype, device);
for k in 0..piv_nums {
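// Pivot search: argmax of |column k| runs over the slice starting at row k,
// so shift the result by k to get absolute row indices.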
let max_row_indices = tensor
.clone()
.slice_dim(D - 2, k..)
.slice_dim(D - 1, k)
.abs()
.argmax(D - 2)
+ (k as i64);
tensor = swap_tensor_rows(tensor, max_row_indices.clone(), k, device);
permutations = update_permutations(permutations, max_row_indices, k, dtype);
if k < n_rows - 1 {
tensor = update_kth_column(tensor, k);
if k < piv_nums - 1 {
tensor = update_trailing_submatrix::<B, D, D1>(tensor, k);
}
}
}
(tensor, permutations)
}
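/// Expands a compact pivot vector into a full `n_rows x n_rows` permutation
/// matrix per batch element by replaying the recorded row swaps on an
/// identity matrix.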
fn construct_full_permutation_tensor<B: Backend, const D: usize>(
piv: Tensor<B, D>,
n_rows: usize,
device: &B::Device,
) -> Tensor<B, D> {
let dims = piv.dims();
let identity_2d = Tensor::eye(n_rows, device);
let mut reshape_dims = [1; D];
reshape_dims[D - 2] = n_rows;
reshape_dims[D - 1] = n_rows;
let reshaped_identity = identity_2d.reshape(reshape_dims);
let mut expand_dims = [n_rows; D];
expand_dims[..(D - 2)].copy_from_slice(&dims[..(D - 2)]);
let identity = reshaped_identity.expand(expand_dims);
apply_permutations_to_tensor(identity, piv, device)
}
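/// Builds the initial pivot vector `[0, 1, ..., piv_nums - 1]` as a column of
/// shape `(..., piv_nums, 1)`, broadcast across the batch dimensions.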
fn create_permutation_tensor<B: Backend, const D: usize>(
piv_nums: usize,
dims: [usize; D],
dtype: FloatDType,
device: &B::Device,
) -> Tensor<B, D> {
let piv = Tensor::arange(0..piv_nums as i64, device).cast(dtype);
let mut reshape_dims = [1; D];
reshape_dims[D - 2] = piv_nums;
let reshaped = piv.reshape(reshape_dims);
let mut expand_dims = [piv_nums; D];
expand_dims[..(D - 2)].copy_from_slice(&dims[..(D - 2)]);
expand_dims[D - 1] = 1;
reshaped.expand(expand_dims)
}
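/// Swaps row `k` with a (possibly batch-dependent) target row using gather
/// plus additive scatter: row `k` gains `row_r - row_k` and row `r` gains
/// `row_k - row_r`, which exchanges the two rows and is a no-op when
/// `r == k`. This keeps the swap fully on-device for every batch element.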
fn swap_tensor_rows<B: Backend, const D: usize>(
tensor: Tensor<B, D>,
mut swap_target_row_tensor: Tensor<B, D, Int>,
k: usize,
device: &B::Device,
) -> Tensor<B, D> {
let mut expand_dims = tensor.dims();
expand_dims[D - 2] = 1;
swap_target_row_tensor = swap_target_row_tensor.expand(expand_dims);
let k_index_tensor =
Tensor::<B, D, Int>::full(swap_target_row_tensor.shape(), k as i32, device);
let val_k = tensor.clone().gather(D - 2, k_index_tensor.clone());
let val_r = tensor.clone().gather(D - 2, swap_target_row_tensor.clone());
let val_k_minus_r = val_k.clone() - val_r.clone();
let val_r_minus_k = val_r - val_k;
let tensor = tensor.scatter(D - 2, k_index_tensor, val_r_minus_k, IndexingUpdateOp::Add);
tensor.scatter(
D - 2,
swap_target_row_tensor,
val_k_minus_r,
IndexingUpdateOp::Add,
)
}
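/// Records the pivot row chosen at elimination step `k` in the compact pivot
/// vector.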
fn update_permutations<B: Backend, const D: usize>(
mut permutations: Tensor<B, D>,
max_row_index_tensor: Tensor<B, D, Int>,
k: usize,
dtype: FloatDType,
) -> Tensor<B, D> {
let mut slices = vec![Slice::full(); D];
slices[D - 2] = Slice::from(k);
let float_max_row_indices = max_row_index_tensor.cast(dtype);
permutations = permutations.slice_assign(&slices, float_max_row_indices);
permutations
}
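/// Forms the `k`-th column of `L` by dividing the subdiagonal entries by the
/// pivot `a[k][k]`. A zero pivot (a singular input) is replaced by 1.0 so the
/// division stays finite instead of producing `inf`/`NaN`.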
fn update_kth_column<B: Backend, const D: usize>(tensor: Tensor<B, D>, k: usize) -> Tensor<B, D> {
let a_kk = tensor.clone().slice_dim(D - 2, k).slice_dim(D - 1, k);
let a_rho_k = tensor.clone().slice_dim(D - 2, k + 1..).slice_dim(D - 1, k);
let is_zero_mask = a_kk.clone().equal_elem(0.0);
let safe_a_kk = a_kk.mask_fill(is_zero_mask, 1.0);
let updated_column = a_rho_k / safe_a_kk;
let mut slices = vec![Slice::full(); D];
slices[D - 2] = Slice::from((k + 1)..);
slices[D - 1] = Slice::from(k..(k + 1));
tensor.slice_assign(&slices, updated_column)
}
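/// Applies the rank-1 update `A[k+1.., k+1..] -= L[k+1.., k] * U[k, k+1..]`
/// after elimination step `k`.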
fn update_trailing_submatrix<B: Backend, const D: usize, const D1: usize>(
tensor: Tensor<B, D>,
k: usize,
) -> Tensor<B, D> {
let a_rho_k = tensor.clone().slice_dim(D - 2, k + 1..).slice_dim(D - 1, k);
let a_k_rho = tensor.clone().slice_dim(D - 2, k).slice_dim(D - 1, k + 1..);
let outer_product = a_rho_k.matmul(a_k_rho);
let a_rho_rho = tensor
.clone()
.slice_dim(D - 2, k + 1..)
.slice_dim(D - 1, k + 1..);
let updated_a_rho_rho = a_rho_rho - outer_product;
let mut slices = vec![Slice::full(); D];
slices[D - 2] = Slice::from((k + 1)..);
slices[D - 1] = Slice::from((k + 1)..);
tensor.slice_assign(&slices, updated_a_rho_rho)
}
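/// Writes a panel's local pivot indices, shifted by the panel offset
/// `k_start`, into the matching slice of the global pivot vector.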
fn update_permutations_to_global_idx<B: Backend, const D: usize>(
global_piv: Tensor<B, D>,
local_piv: Tensor<B, D>,
k_start: usize,
) -> Tensor<B, D> {
let n = local_piv.dims()[D - 2];
let mut slices = vec![Slice::full(); D];
slices[D - 2] = Slice::from(k_start..(n + k_start));
let global_val = local_piv.add_scalar(k_start as f32);
global_piv.slice_assign(&slices, global_val)
}
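/// Applies a compact vector of row swaps (LAPACK-style `ipiv`) to `tensor`.
/// The swaps are replayed on the host to build an explicit row permutation,
/// which is then applied with `select`. Batched inputs are handled one batch
/// element at a time, since each element can carry different pivots.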
fn apply_permutations_to_tensor<B: Backend, const D: usize>(
tensor: Tensor<B, D>,
piv: Tensor<B, D>,
device: &B::Device,
) -> Tensor<B, D> {
let tensor_dims = tensor.dims();
let n_rows = tensor_dims[D - 2];
let n_pivots = piv.dims()[D - 2];
let piv_data: Vec<f32> = piv.into_data().convert::<f32>().into_vec::<f32>().unwrap();
let batch_size: usize = tensor_dims[..D - 2].iter().product();
if batch_size <= 1 {
let mut perm: Vec<i64> = (0..n_rows as i64).collect();
for (i, piv_val) in piv_data.iter().enumerate().take(n_pivots) {
let j = piv_val.elem::<u32>() as usize;
perm.swap(i, j);
}
let perm_tensor = Tensor::<B, 1, Int>::from_data(&perm[..], device);
return tensor.select(D - 2, perm_tensor);
}
let n_cols = tensor_dims[D - 1];
let flat_tensor: Tensor<B, 3> = tensor.reshape([batch_size, n_rows, n_cols]);
let mut results: Vec<Tensor<B, 3>> = Vec::with_capacity(batch_size);
for b in 0..batch_size {
let mut perm: Vec<i64> = (0..n_rows as i64).collect();
let offset = b * n_pivots;
for i in 0..n_pivots {
let j = piv_data[offset + i].elem::<u32>() as usize;
perm.swap(i, j);
}
let perm_tensor = Tensor::<B, 1, Int>::from_data(&perm[..], device);
let batch_elem = flat_tensor.clone().slice_dim(0, b);
let permuted = batch_elem.select(1, perm_tensor);
results.push(permuted);
}
let concatenated: Tensor<B, 3> = Tensor::cat(results, 0);
concatenated.reshape(tensor_dims)
}
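/// Solves `L11 * U12 = A12` for `U12` by forward substitution, where `L11` is
/// the panel's unit lower triangular diagonal block: each row `i` subtracts
/// `L11[i, ..i] * U12[..i, ..]` from itself. No division is needed because
/// the diagonal of `L11` is all ones.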
fn solve_for_u_blocks<B: Backend, const D: usize>(
diagonal_l_block: Tensor<B, D>,
mut a_row_blocks: Tensor<B, D>,
block_size: usize,
) -> Tensor<B, D> {
let mut slices = vec![Slice::full(); D];
for i in 1..block_size {
let l_multipliers = diagonal_l_block
.clone()
.slice_dim(D - 2, i)
.slice_dim(D - 1, 0..i);
let u_computed = a_row_blocks.clone().slice_dim(D - 2, 0..i);
let prod = l_multipliers.matmul(u_computed);
let current_rows = a_row_blocks.clone().slice_dim(D - 2, i);
let updated_rows = current_rows - prod;
slices[D - 2] = Slice::from(i);
a_row_blocks = a_row_blocks.slice_assign(&slices, updated_rows);
}
a_row_blocks
}