use crate::backend::Backend;
use crate::check;
use crate::check::TensorCheck;
use crate::tensor::{Int, Tensor};
use crate::{BasicOps, TensorKind};
/// Extracts the main diagonal from the last two dimensions of `tensor`.
///
/// For an input of shape `[..., rows, cols]` the result has shape
/// `[..., min(rows, cols)]`. The leading dimensions are kept unchanged, and the
/// last two dimensions are replaced by a single dimension holding the elements
/// at positions `(i, i)` for `i in 0..min(rows, cols)`.
///
/// # Type parameters
/// - `D`: rank of the input tensor.
/// - `DO`: rank of the output tensor. The final squeeze removes one dimension,
///   so this must equal `D - 1`.
///
/// # Panics
/// Panics if `TensorCheck::diag::<D, DO>()` reports a failure. NOTE(review):
/// presumably this rejects `D < 2` or `DO != D - 1`; confirm against
/// `TensorCheck::diag`.
pub fn diag<B: Backend, const D: usize, const DO: usize, K>(
    tensor: Tensor<B, D, K>,
) -> Tensor<B, DO, K>
where
    K: TensorKind<B> + BasicOps<B>,
{
    check!(TensorCheck::diag::<D, DO>());

    let shape = tensor.shape();
    let rows = shape[D - 2];
    let cols = shape[D - 1];
    let diag_len = rows.min(cols);
    let device = tensor.device();

    // Collapse the trailing matrix into a single axis of length `rows * cols`.
    // A unit dimension is appended so the rank stays `D`. In row-major order,
    // element `(i, i)` then sits at flat offset `i * cols + i = i * (cols + 1)`.
    // `rows`/`cols` were copied out above, so `shape` can be moved, not cloned.
    let mut flat_shape = shape;
    flat_shape[D - 2] = rows * cols;
    flat_shape[D - 1] = 1;
    let flat: Tensor<B, D, K> = tensor.reshape(flat_shape);

    // Diagonal offsets `0, cols + 1, 2 * (cols + 1), ...`. Scaling by a scalar
    // avoids allocating a one-element tensor on the device and broadcasting it.
    let indices =
        Tensor::<B, 1, Int>::arange(0..diag_len as i64, &device).mul_scalar(cols as i64 + 1);

    // Gather along the collapsed axis -> `[..., diag_len, 1]`, then drop the
    // trailing unit dimension -> `[..., diag_len]`.
    flat.take::<1, D>(D - 2, indices).squeeze_dim(D - 1)
}