1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
use crate::{
backend::Backend,
ops::{BoolTensor, IntTensor},
Data, Device, ElementConversion, Shape,
};
use alloc::vec::Vec;
/// Compute the indices of the elements that are non-zero, grouped by element.
///
/// # Arguments
///
/// * `tensor` - The input boolean tensor.
///
/// # Returns
///
/// A 2D integer tensor with one row per `true` element of `tensor`; each row holds the
/// `D` indices (one per dimension of the input) that locate that element.
///
/// # Remarks
///
/// This is a fallback solution that is used only when the backend doesn't provide its own
/// implementation. Ideally, the operation is implemented by the backend and resolved by
/// static dispatch. It is not designed for direct usage by users, and importing or calling
/// this function directly is not recommended.
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
pub fn argwhere<B: Backend, const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, 2> {
    // The number of output rows equals the number of `true` values, which is only known
    // once the data has been read back; that read might cause a sync but is required.
    // The device handle must be taken first because reading consumes the tensor.
    let device = B::bool_device(&tensor);
    argwhere_data::<B, D>(B::bool_into_data(tensor).read(), &device)
}
/// Compute the indices of the elements that are non-zero, grouped by element.
///
/// # Arguments
///
/// * `tensor` - The input boolean tensor.
///
/// # Returns
///
/// A 2D integer tensor with one row per `true` element of `tensor`; each row holds the
/// `D` indices (one per dimension of the input) that locate that element.
///
/// # Remarks
///
/// This is a fallback solution that is used only when the backend doesn't provide its own
/// implementation. Ideally, the operation is implemented by the backend and resolved by
/// static dispatch. It is not designed for direct usage by users, and importing or calling
/// this function directly is not recommended.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn argwhere<B: Backend, const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, 2> {
    // The number of output rows equals the number of `true` values, which is only known
    // once the data has been read back; that read is awaited here.
    // The device handle must be taken first because reading consumes the tensor.
    let device = B::bool_device(&tensor);
    let mask = B::bool_into_data(tensor).read().await;
    argwhere_data::<B, D>(mask, &device)
}
/// Builds the index tensor for the `true` values of boolean data that has already been read.
///
/// # Arguments
///
/// * `data` - The boolean values together with their shape.
/// * `device` - The device on which the resulting tensor is created.
///
/// # Returns
///
/// An integer tensor of shape `[count_nonzero, D]`: one row per `true` value in `data`,
/// in the order the values appear in `data.value`, each row holding the `D` per-dimension
/// indices of that value.
fn argwhere_data<B: Backend, const D: usize>(
    data: Data<bool, D>,
    device: &Device<B>,
) -> IntTensor<B, 2> {
    /// Converts a flat index into per-dimension indices for the specified tensor shape,
    /// treating the last dimension as the fastest-varying one.
    fn unravel_index<const D: usize>(index: usize, shape: &[usize; D]) -> [usize; D] {
        let mut coords = [0; D];
        let mut remainder = index;
        // Walk the dimensions from last to first, peeling off one coordinate at a time.
        for (coord, &size) in coords.iter_mut().zip(shape.iter()).rev() {
            *coord = remainder % size;
            remainder /= size;
        }
        coords
    }

    let dims = data.shape.dims;
    let count_nonzero = data.value.iter().filter(|&&v| v).count();

    // The exact output length is known up front, so fill a single flat buffer instead of
    // allocating a temporary vector per `true` element and concatenating afterwards.
    let mut indices: Vec<B::IntElem> = Vec::with_capacity(count_nonzero * D);
    for (index, _) in data.value.iter().enumerate().filter(|&(_, &v)| v) {
        indices.extend(
            unravel_index(index, &dims)
                .iter()
                .map(|&coord| (coord as i64).elem::<B::IntElem>()),
        );
    }

    B::int_from_data(Data::new(indices, Shape::new([count_nonzero, D])), device)
}