Skip to main content

burn_tensor/tensor/linalg/
diag.rs

1use crate::backend::Backend;
2use crate::check;
3use crate::check::TensorCheck;
4use crate::tensor::{Int, Tensor};
5use crate::{BasicOps, TensorKind};
6
7/// Returns the diag of a matrix.
8///
9/// For batched inputs, returns of each matrix in the batch independently.
10///
11/// The diag operation extracts the diagonal elements of the last two dimensions,
12/// treating them as the matrix dimensions, while preserving all leading batch dimensions.
13///
14/// # Arguments
15///
16/// * `tensor` - The input tensor with at least 2 dimensions.
17///
18/// # Returns
19/// A tensor of rank `D - 1`, where the last dimension contains the diagonal elements of the input.
20pub fn diag<B: Backend, const D: usize, const DO: usize, K>(
21    tensor: Tensor<B, D, K>,
22) -> Tensor<B, DO, K>
23where
24    K: TensorKind<B> + BasicOps<B>,
25{
26    check!(TensorCheck::diag::<D, DO>());
27
28    let shape = tensor.shape();
29    let rows = shape[D - 2];
30    let cols = shape[D - 1];
31    let diag_len = rows.min(cols);
32    let device = tensor.device();
33
34    // create the indices for the diag
35    let mut flat_shape = shape.clone();
36    flat_shape[D - 2] = rows * cols;
37    flat_shape[D - 1] = 1;
38    let flat: Tensor<B, D, K> = tensor.reshape(flat_shape);
39
40    let range = Tensor::<B, 1, Int>::arange(0..diag_len as i64, &device);
41    let step_tensor = Tensor::<B, 1, Int>::from_data([cols as i64 + 1], &device);
42    let indices = range * step_tensor;
43    flat.take::<1, D>(D - 2, indices).squeeze_dim(D - 1)
44}