use crate::backend::Backend;
use crate::check::TensorCheck;
use crate::{Tensor, check, linalg};
use burn_std::{DType, FloatDType};
#[allow(unused_imports)]
use num_traits::float::Float;
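/// Computes the determinant of one or more square matrices.
///
/// The input has shape `[..., n, n]`; the determinant is taken over the last
/// two dimensions, producing a result of shape `[...]`.
///
/// Matrices of size 1x1, 2x2, and 3x3 are handled with closed-form cofactor
/// expansions; larger matrices fall back to an LU decomposition with partial
/// pivoting, where `det(A) = det(P) * prod(diag(U))`. Inputs in `f16`/`bf16`
/// are upcast to `f32` for the computation and the result is cast back to the
/// original dtype.
///
/// # Example
///
/// A minimal sketch, assuming a backend `B` and a batch of two 2x2 matrices
/// (the exact construction API may differ):
///
/// ```rust,ignore
/// let batch = Tensor::<B, 3>::from_floats(
///     [[[1.0, 2.0], [3.0, 4.0]], [[2.0, 0.0], [0.0, 2.0]]],
///     &device,
/// );
/// // D = 3, D1 = 2, D2 = 1: one determinant per matrix in the batch.
/// let dets = linalg::det::<B, 3, 2, 1>(batch); // approximately [-2.0, 4.0]
/// ```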
pub fn det<B: Backend, const D: usize, const D1: usize, const D2: usize>(
    mut tensor: Tensor<B, D>,
) -> Tensor<B, D2> {
    let dims = tensor.dims();
    let original_dtype = tensor.dtype();
    check!(TensorCheck::det::<D, D1, D2>(dims, original_dtype));
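    // f16/bf16 lack the precision for a stable computation, so work in f32
    // and cast the result back to the original dtype before returning.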
    let needs_upcast = original_dtype == DType::F16 || original_dtype == DType::BF16;
    let working_float_dtype: FloatDType = if needs_upcast {
        FloatDType::F32
    } else {
        original_dtype.into()
    };
    if needs_upcast {
        tensor = tensor.cast(working_float_dtype);
    }
    let rank = D as isize;
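    // Small matrices use closed-form expressions instead of an LU factorization.
    // 1x1: the determinant is the single matrix element.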
    if dims[D - 1] == 1 {
        let det_tensor = tensor.squeeze_dims::<D2>(&[rank - 2, rank - 1]);
        if needs_upcast {
            return det_tensor.cast(original_dtype);
        }
        return det_tensor;
    } else if dims[D - 1] == 2 {
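        // 2x2: det = a*d - b*c.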
        let a = tensor.clone().slice_dim(D - 2, 0).slice_dim(D - 1, 0);
        let b = tensor.clone().slice_dim(D - 2, 0).slice_dim(D - 1, 1);
        let c = tensor.clone().slice_dim(D - 2, 1).slice_dim(D - 1, 0);
        let d = tensor.clone().slice_dim(D - 2, 1).slice_dim(D - 1, 1);
        let det_tensor = (a * d - b * c).squeeze_dims::<D2>(&[rank - 2, rank - 1]);
        if needs_upcast {
            return det_tensor.cast(original_dtype);
        }
        return det_tensor;
    } else if dims[D - 1] == 3 {
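        // 3x3: cofactor expansion along the first row.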
        let a = tensor.clone().slice_dim(D - 2, 0).slice_dim(D - 1, 0);
        let b = tensor.clone().slice_dim(D - 2, 0).slice_dim(D - 1, 1);
        let c = tensor.clone().slice_dim(D - 2, 0).slice_dim(D - 1, 2);
        let d = tensor.clone().slice_dim(D - 2, 1).slice_dim(D - 1, 0);
        let e = tensor.clone().slice_dim(D - 2, 1).slice_dim(D - 1, 1);
        let f = tensor.clone().slice_dim(D - 2, 1).slice_dim(D - 1, 2);
        let g = tensor.clone().slice_dim(D - 2, 2).slice_dim(D - 1, 0);
        let h = tensor.clone().slice_dim(D - 2, 2).slice_dim(D - 1, 1);
        let i = tensor.clone().slice_dim(D - 2, 2).slice_dim(D - 1, 2);
        let det_tensor = (a * (e.clone() * i.clone() - f.clone() * h.clone())
            - b * (d.clone() * i - f * g.clone())
            + c * (d * h - e * g))
            .squeeze_dims::<D2>(&[rank - 2, rank - 1]);
        if needs_upcast {
            return det_tensor.cast(original_dtype);
        }
        return det_tensor;
    }
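    // General case: factor A = P * L * U with partial pivoting. Since L is
    // unit lower triangular, det(L) = 1 and det(A) = det(P) * det(U).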
    let (lu, pivots) = linalg::compute_lu_decomposition::<B, D, D1>(tensor.clone());
    let squeezed_pivots = pivots.squeeze_dim::<D1>(D - 1);
    let n_pivots = squeezed_pivots.dims()[D1 - 1] as i64;
    let range_1d: Tensor<B, 1> =
        Tensor::arange(0..n_pivots, &tensor.device()).cast(working_float_dtype);
    let mut reshape_dims = [1; D1];
    reshape_dims[D1 - 1] = n_pivots;
    let range = range_1d.reshape(reshape_dims);
    let expand_dims: [usize; D1] = squeezed_pivots.dims();
    let batched_range_tensor = range.expand(expand_dims);
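    // Each pivot entry that differs from its row index records one row swap;
    // the permutation sign is det(P) = (-1)^(number of swaps).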
    let n_row_swaps = squeezed_pivots
        .not_equal(batched_range_tensor)
        .int()
        .sum_dim(D1 - 1);
    let odd_mask = n_row_swaps.clone().remainder_scalar(2).equal_elem(1);
    let p_det = n_row_swaps
        .cast(working_float_dtype)
        .ones_like()
        .mask_fill(odd_mask, -1.0)
        .squeeze_dim(D1 - 1);
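    // det(U) is the product of the diagonal entries of the upper factor.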
    let u_diag = linalg::diag::<B, D, D1, _>(lu);
    let mut u_det = u_diag.clone().prod_dim(D1 - 1).squeeze_dim(D1 - 1);
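    // An exactly singular input should produce det = 0, but rounding can leave
    // a tiny non-zero diagonal entry in U instead. Flag entries at or below a
    // scaled machine-epsilon threshold and force those determinants to zero.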
    let eps = tensor
        .dtype()
        .finfo()
        .expect("The input tensor to linalg::det should have a float dtype.")
        .epsilon;
    let n = dims[D - 1];
    let threshold = u_diag.clone().abs().max_dim(D1 - 1) * (n as f64).sqrt() * eps;
    let near_zero = u_diag.abs().lower_equal(threshold);
    let singular_mask = near_zero.any_dim(D1 - 1).squeeze_dim::<D2>(D1 - 1);
    u_det = u_det.mask_fill(singular_mask, 0.0);
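    // Combine the permutation sign with the diagonal product.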
    let final_det = p_det * u_det;
    if needs_upcast {
        final_det.cast(original_dtype)
    } else {
        final_det
    }
}