
burn_tensor/tensor/linalg/det.rs

use crate::backend::Backend;
use crate::check::TensorCheck;
use crate::{Tensor, check, linalg};
use burn_std::{DType, FloatDType};
#[allow(unused_imports)]
use num_traits::float::Float;

/// Computes the determinant over the last two dimensions of the input tensor.
///
/// # Arguments
/// - `tensor` - The input tensor of shape `[..., N, N]`.
///
/// # Returns
/// - The determinant tensor of shape `[...]`, whose rank is two less than the input
///   tensor's rank.
///
/// # Generic Parameters
/// - `D`: The rank of the input tensor.
/// - `D1`: Must be set to `D - 1`.
/// - `D2`: Must be set to `D - 2`.
///
/// # Panics
/// This function will panic if:
/// - The generic parameters do not satisfy `D - 1 == D1`.
/// - The generic parameters do not satisfy `D - 2 == D2`.
/// - The input tensor rank `D` is less than 3.
/// - The last two dimensions of the input tensor are not equal.
/// - The input is a quantized tensor with dtype `DType::QFloat`.
///
/// # Performance Note
/// The determinants of 1 by 1, 2 by 2, and 3 by 3 matrices are computed using closed-form
/// expressions. For larger matrices (4 by 4 and up), this function relies on the LU
/// decomposition function under the hood, which is not fully optimized; it will not be as
/// fast as highly tuned specialized libraries, especially for very large matrices or large
/// batch sizes.
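/// For example, the 2 by 2 closed form used here is `det([[a, b], [c, d]]) = a*d - b*c`.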
///
/// # Numerical Behavior
/// - If the input tensor has type `F16` or `BF16`, it is internally upcast to `F32` to
///   perform the computation and cast back to the original data type (`F16` or `BF16`)
///   right before the function returns.
/// - In this case, if the determinant values fall outside the original data type's range,
///   the cast back may overflow to infinity or underflow to zero.
///
/// # Example
/// ```rust,ignore
/// use burn::tensor::Tensor;
/// use burn::tensor::linalg;
/// use burn::tensor::backend::Backend;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 3>::from_data([[[4.0, 3.0], [6.0, 3.0]]], &device);
///
///     // Compute determinant
///     let result = linalg::det::<B, 3, 2, 1>(tensor);
///
///     // Expected Output:
///     // result: [-6.0]
/// }
///
/// fn example2<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 3>::from_data(
///         [
///             [[1.0, 2.0], [3.0, 4.0]],   // det = -2
///             [[2.0, 0.0], [0.0, 3.0]],   // det = 6
///             [[5.0, 6.0], [7.0, 8.0]],   // det = -2
///         ],
///         &device,
///     );
///
///     // Compute determinant
///     let result = linalg::det::<B, 3, 2, 1>(tensor);
///
///     // Expected Output:
///     // result: [-2.0, 6.0, -2.0]
/// }
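///
/// // Illustrative sketch: a rank-4 input of shape [2, 1, 2, 2] uses D = 4, D1 = 3,
/// // D2 = 2 and yields a determinant tensor of shape [2, 1].
/// fn example3<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 4>::from_data(
///         [
///             [[[1.0, 2.0], [3.0, 4.0]]],   // det = -2
///             [[[2.0, 0.0], [0.0, 3.0]]],   // det = 6
///         ],
///         &device,
///     );
///
///     // Compute determinant
///     let result = linalg::det::<B, 4, 3, 2>(tensor);
///
///     // Expected Output:
///     // result: [[-2.0], [6.0]]
/// }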
/// ```
pub fn det<B: Backend, const D: usize, const D1: usize, const D2: usize>(
    mut tensor: Tensor<B, D>,
) -> Tensor<B, D2> {
    // Check whether input tensor has valid shape to compute determinant
    let dims = tensor.dims();
    let original_dtype = tensor.dtype();
    check!(TensorCheck::det::<D, D1, D2>(dims, original_dtype));

    // Upcast f16 and bf16 to f32
    let needs_upcast = original_dtype == DType::F16 || original_dtype == DType::BF16;
    let working_float_dtype: FloatDType;
    if needs_upcast {
        working_float_dtype = FloatDType::F32;
        tensor = tensor.cast(working_float_dtype);
    } else {
        working_float_dtype = original_dtype.into();
    }

    // Compute determinant for base cases (1x1, 2x2, and 3x3 matrices)
    let rank = D as isize;
    if dims[D - 1] == 1 {
        let det_tensor = tensor.squeeze_dims::<D2>(&[rank - 2, rank - 1]);
        if needs_upcast {
            return det_tensor.cast(original_dtype);
        }
        return det_tensor;
    } else if dims[D - 1] == 2 {
        let a = tensor.clone().slice_dim(D - 2, 0).slice_dim(D - 1, 0);
        let b = tensor.clone().slice_dim(D - 2, 0).slice_dim(D - 1, 1);
        let c = tensor.clone().slice_dim(D - 2, 1).slice_dim(D - 1, 0);
        let d = tensor.clone().slice_dim(D - 2, 1).slice_dim(D - 1, 1);
        let det_tensor = (a * d - b * c).squeeze_dims::<D2>(&[rank - 2, rank - 1]);
        if needs_upcast {
            return det_tensor.cast(original_dtype);
        }
        return det_tensor;
    } else if dims[D - 1] == 3 {
        let a = tensor.clone().slice_dim(D - 2, 0).slice_dim(D - 1, 0);
        let b = tensor.clone().slice_dim(D - 2, 0).slice_dim(D - 1, 1);
        let c = tensor.clone().slice_dim(D - 2, 0).slice_dim(D - 1, 2);
        let d = tensor.clone().slice_dim(D - 2, 1).slice_dim(D - 1, 0);
        let e = tensor.clone().slice_dim(D - 2, 1).slice_dim(D - 1, 1);
        let f = tensor.clone().slice_dim(D - 2, 1).slice_dim(D - 1, 2);
        let g = tensor.clone().slice_dim(D - 2, 2).slice_dim(D - 1, 0);
        let h = tensor.clone().slice_dim(D - 2, 2).slice_dim(D - 1, 1);
        let i = tensor.clone().slice_dim(D - 2, 2).slice_dim(D - 1, 2);
        let det_tensor = (a * (e.clone() * i.clone() - f.clone() * h.clone())
            - b * (d.clone() * i - f * g.clone())
            + c * (d * h - e * g))
            .squeeze_dims::<D2>(&[rank - 2, rank - 1]);
        if needs_upcast {
            return det_tensor.cast(original_dtype);
        }
        return det_tensor;
    }

    // Compute determinant for general case
    // det(A) = det(P) * det(L) * det(U)
    // det(A) = det(P) * 1 * det(U)
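    // In this factorization, L has a unit diagonal, so det(L) = 1, and P is a permutation
    // matrix whose determinant is +1 or -1 depending on the parity of the row swaps
    // recorded in `pivots`.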
    let (lu, pivots) = linalg::compute_lu_decomposition::<B, D, D1>(tensor.clone());

    // Compute the determinant of P
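    // The parity is obtained by counting the positions where the pivot index differs from
    // the row index (the identity permutation); an odd count of recorded row swaps gives
    // det(P) = -1, an even count gives det(P) = +1.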
    let squeezed_pivots = pivots.squeeze_dim::<D1>(D - 1);
    let n_pivots = squeezed_pivots.dims()[D1 - 1] as i64;
    let range_1d: Tensor<B, 1> =
        Tensor::arange(0..n_pivots, &tensor.device()).cast(working_float_dtype);
    let mut reshape_dims = [1; D1];
    reshape_dims[D1 - 1] = n_pivots;
    let range = range_1d.reshape(reshape_dims);
    let expand_dims: [usize; D1] = squeezed_pivots.dims();
    let batched_range_tensor = range.expand(expand_dims);
    let n_row_swaps = squeezed_pivots
        .not_equal(batched_range_tensor)
        .int()
        .sum_dim(D1 - 1);
    let odd_mask = n_row_swaps.clone().remainder_scalar(2).equal_elem(1);
    let p_det = n_row_swaps
        .cast(working_float_dtype)
        .ones_like()
        .mask_fill(odd_mask, -1.0)
        .squeeze_dim(D1 - 1);

    // Compute the determinant of U
    let u_diag = linalg::diag::<B, D, D1, _>(lu);
    let mut u_det = u_diag.clone().prod_dim(D1 - 1).squeeze_dim(D1 - 1);
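    // If any diagonal entry of U is negligible relative to the largest one, the matrix is
    // treated as numerically singular and its determinant is clamped to exactly zero.
    // The tolerance sqrt(n) * eps * max|u_ii| below is a common heuristic for this check.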
    let eps = tensor
        .dtype()
        .finfo()
        .expect("The input tensor to linalg::det should have float dtype.")
        .epsilon;
    let n = dims[D - 1]; // The input tensor contains n by n matrices
    let threshold = u_diag.clone().abs().max_dim(D1 - 1) * (n as f64).sqrt() * eps;
    let near_zero = u_diag.abs().lower_equal(threshold);
    let singular_mask = near_zero.any_dim(D1 - 1).squeeze_dim::<D2>(D1 - 1);
    u_det = u_det.mask_fill(singular_mask, 0.0);

    let final_det = p_det * u_det;

    // Cast back to the original dtype
    if needs_upcast {
        final_det.cast(original_dtype)
    } else {
        final_det
    }
}