burn_tensor/tensor/linalg/det.rs

use crate::backend::Backend;
use crate::check::TensorCheck;
use crate::{Tensor, check, linalg};
use burn_std::{DType, FloatDType};
#[allow(unused_imports)]
use num_traits::float::Float;
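
/// Computes the determinant of the square matrices stored in the last two
/// dimensions of `tensor`, returning a tensor with those two dimensions
/// squeezed away (`D1` and `D2` are the intermediate and output ranks).
///
/// 1x1, 2x2 and 3x3 matrices use closed-form expressions; larger matrices go
/// through an LU decomposition with partial pivoting. `f16`/`bf16` inputs are
/// upcast to `f32` internally and cast back before returning.
///
/// Minimal usage sketch (not a doctest; the constructor call and the explicit
/// const-generic arguments are assumptions for illustration, only the `det`
/// signature below comes from this file):
///
/// ```ignore
/// // Batch of two 2x2 matrices, shape [2, 2, 2] -> determinants of shape [2].
/// let batch: Tensor<B, 3> = Tensor::from_floats(
///     [[[4.0, 7.0], [2.0, 6.0]], [[1.0, 0.0], [0.0, 1.0]]],
///     &device,
/// );
/// let dets: Tensor<B, 1> = linalg::det::<B, 3, 2, 1>(batch);
/// ```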
pub fn det<B: Backend, const D: usize, const D1: usize, const D2: usize>(
    mut tensor: Tensor<B, D>,
) -> Tensor<B, D2> {
    let dims = tensor.dims();
    let original_dtype = tensor.dtype();
    check!(TensorCheck::det::<D, D1, D2>(dims, original_dtype));
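
    // f16/bf16 inputs are upcast to f32 for the intermediate arithmetic and
    // cast back to the original dtype before returning.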
    let needs_upcast = original_dtype == DType::F16 || original_dtype == DType::BF16;
    let working_float_dtype: FloatDType;
    if needs_upcast {
        working_float_dtype = FloatDType::F32;
        tensor = tensor.cast(working_float_dtype);
    } else {
        working_float_dtype = original_dtype.into()
    };
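
    // Fast paths: closed-form determinants for 1x1, 2x2 and 3x3 matrices.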
    let rank = D as isize;
    if dims[D - 1] == 1 {
        let det_tensor = tensor.squeeze_dims::<D2>(&[rank - 2, rank - 1]);
        if needs_upcast {
            return det_tensor.cast(original_dtype);
        }
        return det_tensor;
    } else if dims[D - 1] == 2 {
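        // 2x2 closed form: det = a*d - b*c.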
        let a = tensor.clone().slice_dim(D - 2, 0).slice_dim(D - 1, 0);
        let b = tensor.clone().slice_dim(D - 2, 0).slice_dim(D - 1, 1);
        let c = tensor.clone().slice_dim(D - 2, 1).slice_dim(D - 1, 0);
        let d = tensor.clone().slice_dim(D - 2, 1).slice_dim(D - 1, 1);
        let det_tensor = (a * d - b * c).squeeze_dims::<D2>(&[rank - 2, rank - 1]);
        if needs_upcast {
            return det_tensor.cast(original_dtype);
        }
        return det_tensor;
    } else if dims[D - 1] == 3 {
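        // 3x3 closed form: cofactor expansion along the first row.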
        let a = tensor.clone().slice_dim(D - 2, 0).slice_dim(D - 1, 0);
        let b = tensor.clone().slice_dim(D - 2, 0).slice_dim(D - 1, 1);
        let c = tensor.clone().slice_dim(D - 2, 0).slice_dim(D - 1, 2);
        let d = tensor.clone().slice_dim(D - 2, 1).slice_dim(D - 1, 0);
        let e = tensor.clone().slice_dim(D - 2, 1).slice_dim(D - 1, 1);
        let f = tensor.clone().slice_dim(D - 2, 1).slice_dim(D - 1, 2);
        let g = tensor.clone().slice_dim(D - 2, 2).slice_dim(D - 1, 0);
        let h = tensor.clone().slice_dim(D - 2, 2).slice_dim(D - 1, 1);
        let i = tensor.clone().slice_dim(D - 2, 2).slice_dim(D - 1, 2);
        let det_tensor = (a * (e.clone() * i.clone() - f.clone() * h.clone())
            - b * (d.clone() * i - f * g.clone())
            + c * (d * h - e * g))
            .squeeze_dims::<D2>(&[rank - 2, rank - 1]);
        if needs_upcast {
            return det_tensor.cast(original_dtype);
        }
        return det_tensor;
    }
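
    // General case: LU decomposition with partial pivoting;
    // det(A) = sign(P) * prod(diag(U)).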
    let (lu, pivots) = linalg::compute_lu_decomposition::<B, D, D1>(tensor.clone());
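
    // Sign of the permutation: each pivot entry that differs from its row
    // index records a row swap, and an odd number of swaps flips the sign.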
    let squeezed_pivots = pivots.squeeze_dim::<D1>(D - 1);
    let n_pivots = squeezed_pivots.dims()[D1 - 1] as i64;
    let range_1d: Tensor<B, 1> =
        Tensor::arange(0..n_pivots, &tensor.device()).cast(working_float_dtype);
    let mut reshape_dims = [1; D1];
    reshape_dims[D1 - 1] = n_pivots;
    let range = range_1d.reshape(reshape_dims);
    let expand_dims: [usize; D1] = squeezed_pivots.dims();
    let batched_range_tensor = range.expand(expand_dims);
    let n_row_swaps = squeezed_pivots
        .not_equal(batched_range_tensor)
        .int()
        .sum_dim(D1 - 1);
    let odd_mask = n_row_swaps.clone().remainder_scalar(2).equal_elem(1);
    let p_det = n_row_swaps
        .cast(working_float_dtype)
        .ones_like()
        .mask_fill(odd_mask, -1.0)
        .squeeze_dim(D1 - 1);
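
    // det(U) is the product of its diagonal entries.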
    let u_diag = linalg::diag::<B, D, D1, _>(lu);
    let mut u_det = u_diag.clone().prod_dim(D1 - 1).squeeze_dim(D1 - 1);
    let eps = tensor
        .dtype()
        .finfo()
        .expect("The input tensor to linalg::det should have float dtype.")
        .epsilon;
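    // A diagonal entry whose magnitude is at most sqrt(n) * eps times the
    // largest diagonal magnitude is treated as zero, and the determinant of
    // such a (numerically singular) matrix is flushed to exactly 0.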
    let n = dims[D - 1];
    let threshold = u_diag.clone().abs().max_dim(D1 - 1) * (n as f64).sqrt() * eps;
    let near_zero = u_diag.abs().lower_equal(threshold);
    let singular_mask = near_zero.any_dim(D1 - 1).squeeze_dim::<D2>(D1 - 1);
    u_det = u_det.mask_fill(singular_mask, 0.0);
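
    // Combine the permutation sign with the product of U's diagonal entries.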
    let final_det = p_det * u_det;
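
    // Cast back to the caller's dtype if the input was upcast.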
    if needs_upcast {
        final_det.cast(original_dtype)
    } else {
        final_det
    }
}