burn_tensor/tensor/linalg/
diag.rs1use crate::backend::Backend;
2use crate::check;
3use crate::check::TensorCheck;
4use crate::tensor::{Int, Tensor};
5use crate::{BasicOps, TensorKind};
6
7pub fn diag<B: Backend, const D: usize, const DO: usize, K>(
21 tensor: Tensor<B, D, K>,
22) -> Tensor<B, DO, K>
23where
24 K: TensorKind<B> + BasicOps<B>,
25{
26 check!(TensorCheck::diag::<D, DO>());
27
28 let shape = tensor.shape();
29 let rows = shape[D - 2];
30 let cols = shape[D - 1];
31 let diag_len = rows.min(cols);
32 let device = tensor.device();
33
34 let mut flat_shape = shape.clone();
36 flat_shape[D - 2] = rows * cols;
37 flat_shape[D - 1] = 1;
38 let flat: Tensor<B, D, K> = tensor.reshape(flat_shape);
39
40 let range = Tensor::<B, 1, Int>::arange(0..diag_len as i64, &device);
41 let step_tensor = Tensor::<B, 1, Int>::from_data([cols as i64 + 1], &device);
42 let indices = range * step_tensor;
43 flat.take::<1, D>(D - 2, indices).squeeze_dim(D - 1)
44}