burn_tensor/tensor/api/argwhere.rs
1use crate::{backend::Backend, ops::IntTensor, Device, ElementConversion, Shape, TensorData};
2use alloc::vec::Vec;
3
4/// Compute the indices of the elements that are non-zero, grouped by element.
5///
6/// # Arguments
7///
8/// * `data` - The input tensor data.
9///
10/// # Returns
11///
12/// A 2D tensor containing the indices of all non-zero elements of the given tensor.
13/// Each row contains the indices of a non-zero element.
14///
15/// # Remarks
16///
17/// This is a fallback solution that used only when the backend doesn't have the corresponding implementation.
18/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved
19/// by static dispatch. It is not designed for direct usage by users, and not recommended to import
20/// or use this function directly.
21pub fn argwhere_data<B: Backend>(data: TensorData, device: &Device<B>) -> IntTensor<B> {
22 let dims = &data.shape;
23 let ndims = dims.len();
24 let count_nonzero = data.iter::<bool>().filter(|&v| v).count();
25
26 /// Converts a flat index into a vector of indices for the specified tensor shape
27 fn unravel_index<B: Backend>(index: usize, shape: &[usize]) -> Vec<B::IntElem> {
28 shape
29 .iter()
30 .rev()
31 .scan(index, |i, size| {
32 let dim_idx = *i % size;
33 *i /= size;
34 Some((dim_idx as i64).elem())
35 })
36 .collect::<Vec<_>>()
37 .into_iter()
38 .rev()
39 .collect()
40 }
41
42 let indices = data
43 .iter::<bool>()
44 .enumerate()
45 .filter_map(|(index, v)| if v { Some(index) } else { None })
46 .map(|index| unravel_index::<B>(index, dims))
47 .collect::<Vec<_>>()
48 .concat();
49
50 B::int_from_data(
51 TensorData::new(indices, Shape::new([count_nonzero, ndims])),
52 device,
53 )
54}