use crate::tensor::{Device, IntTensor};
use crate::{Backend, TensorData, element::ElementConversion};
use alloc::vec::Vec;
use burn_std::Shape;
/// Collects the coordinates of every `true` element of `data` into an integer tensor.
///
/// `data` is read through `TensorData::iter::<bool>()`. The returned tensor has shape
/// `[count_nonzero, ndims]`:
/// - `count_nonzero` is the number of elements that read as `true`.
/// - `ndims` is the number of entries in `data.shape`.
///
/// Row `k` holds the multi-dimensional index of the `k`-th `true` element, in flat
/// iteration order. Flat positions are unravelled with the last dimension varying
/// fastest (this assumes `data`'s flat order is row-major — confirm for non-contiguous
/// sources). A rank-0 input yields a `[count_nonzero, 0]` tensor with empty data.
pub fn argwhere_data<B: Backend>(data: TensorData, device: &Device<B>) -> IntTensor<B> {
    let dims: &[usize] = &data.shape;
    let ndims = dims.len();

    let mut count_nonzero = 0usize;
    // Flattened `[count_nonzero, ndims]` output, filled one row per `true` element.
    let mut indices: Vec<B::IntElem> = Vec::new();
    // Scratch buffer reused for every element, so unravelling allocates nothing per hit.
    let mut coords: Vec<usize> = Vec::with_capacity(ndims);

    // One pass over the data both counts the set elements and records their coordinates.
    for (flat_index, is_set) in data.iter::<bool>().enumerate() {
        if !is_set {
            continue;
        }
        count_nonzero += 1;

        // Peel off the innermost dimension first, so `coords` is filled in reverse
        // dimension order; it is reversed again when appended below.
        // A zero-sized dimension means no elements exist, so `% size` never sees 0 here.
        coords.clear();
        let mut remainder = flat_index;
        for &size in dims.iter().rev() {
            coords.push(remainder % size);
            remainder /= size;
        }
        indices.extend(
            coords
                .iter()
                .rev()
                .map(|&coord| -> B::IntElem { (coord as i64).elem() }),
        );
    }

    B::int_from_data(
        TensorData::new(indices, Shape::new([count_nonzero, ndims])),
        device,
    )
}