#[cfg(any(feature = "cuda", feature = "wgpu"))]
use crate::dtype::DType;
#[cfg(any(feature = "cuda", feature = "wgpu"))]
use crate::error::{Error, Result};
#[cfg(any(feature = "cuda", feature = "wgpu"))]
use crate::ops::{RandomOps, SortingOps};
#[cfg(any(feature = "cuda", feature = "wgpu"))]
use crate::runtime::Runtime;
#[cfg(any(feature = "cuda", feature = "wgpu"))]
use crate::tensor::Tensor;
/// Builds a random permutation of `n` indices as a 1-D tensor.
///
/// Strategy: draw `n` random keys of dtype [`DType::F32`] through
/// [`RandomOps::rand`], then return the indices produced by
/// [`SortingOps::argsort`] over dimension 0 of those keys. The third argument
/// `false` presumably selects ascending order; confirm against `SortingOps`.
///
/// NOTE(review): F32 keys can collide when `n` is large. How colliding keys
/// are ordered depends on whether `argsort` is stable. That may introduce a
/// slight bias toward index order, so confirm whether this matters for callers.
///
/// # Errors
///
/// - Returns [`Error::InvalidArgument`] (with `arg: "n"`) when `n == 0`.
/// - Propagates any error returned by `client.rand` or `client.argsort`.
#[cfg(any(feature = "cuda", feature = "wgpu"))]
pub fn randperm_impl<R, C>(client: &C, n: usize) -> Result<Tensor<R>>
where
    R: Runtime,
    C: RandomOps<R> + SortingOps<R>,
{
    if n > 0 {
        // Random sort keys; their argsort is the permutation.
        let random_keys = client.rand(&[n], DType::F32)?;
        let indices = client.argsort(&random_keys, 0, false)?;
        Ok(indices)
    } else {
        Err(Error::InvalidArgument {
            arg: "n",
            reason: "randperm requires n > 0".to_string(),
        })
    }
}