burn_tensor/tensor/api/argwhere.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
use crate::{backend::Backend, ops::IntTensor, Device, ElementConversion, Shape, TensorData};
use alloc::vec::Vec;
/// Compute the indices of the elements that are non-zero, grouped by element.
///
/// # Arguments
///
/// * `data` - The input tensor data.
///
/// # Returns
///
/// A 2D tensor containing the indices of all non-zero elements of the given tensor.
/// Each row contains the indices of a non-zero element.
///
/// # Remarks
///
/// This is a fallback solution that used only when the backend doesn't have the corresponding implementation.
/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
/// by static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
pub fn argwhere_data<B: Backend>(data: TensorData, device: &Device<B>) -> IntTensor<B> {
let dims = &data.shape;
let ndims = dims.len();
let count_nonzero = data.iter::<bool>().filter(|&v| v).count();
/// Converts a flat index into a vector of indices for the specified tensor shape
fn unravel_index<B: Backend>(index: usize, shape: &[usize]) -> Vec<B::IntElem> {
shape
.iter()
.rev()
.scan(index, |i, size| {
let dim_idx = *i % size;
*i /= size;
Some((dim_idx as i64).elem())
})
.collect::<Vec<_>>()
.into_iter()
.rev()
.collect()
}
let indices = data
.iter::<bool>()
.enumerate()
.filter_map(|(index, v)| if v { Some(index) } else { None })
.map(|index| unravel_index::<B>(index, dims))
.collect::<Vec<_>>()
.concat();
B::int_from_data(
TensorData::new(indices, Shape::new([count_nonzero, ndims])),
device,
)
}