// burn_tensor/tensor/api/argwhere.rs

use alloc::vec::Vec;

use crate::{backend::Backend, ops::IntTensor, Device, ElementConversion, Shape, TensorData};

4/// Compute the indices of the elements that are non-zero, grouped by element.
5///
6/// # Arguments
7///
8/// * `data` - The input tensor data.
9///
10/// # Returns
11///
12/// A 2D tensor containing the indices of all non-zero elements of the given tensor.
13/// Each row contains the indices of a non-zero element.
14///
15/// # Remarks
16///
17/// This is a fallback solution that used only when the backend doesn't have the corresponding implementation.
18/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
19/// by static dispatch. It is not designed for direct usage by users, and not recommended to import
20/// or use this function directly.
21pub fn argwhere_data<B: Backend>(data: TensorData, device: &Device<B>) -> IntTensor<B> {
22    let dims = &data.shape;
23    let ndims = dims.len();
24    let count_nonzero = data.iter::<bool>().filter(|&v| v).count();
25
26    /// Converts a flat index into a vector of indices for the specified tensor shape
27    fn unravel_index<B: Backend>(index: usize, shape: &[usize]) -> Vec<B::IntElem> {
28        shape
29            .iter()
30            .rev()
31            .scan(index, |i, size| {
32                let dim_idx = *i % size;
33                *i /= size;
34                Some((dim_idx as i64).elem())
35            })
36            .collect::<Vec<_>>()
37            .into_iter()
38            .rev()
39            .collect()
40    }
41
42    let indices = data
43        .iter::<bool>()
44        .enumerate()
45        .filter_map(|(index, v)| if v { Some(index) } else { None })
46        .map(|index| unravel_index::<B>(index, dims))
47        .collect::<Vec<_>>()
48        .concat();
49
50    B::int_from_data(
51        TensorData::new(indices, Shape::new([count_nonzero, ndims])),
52        device,
53    )
54}