use crate::prelude_dev::*;
/// Moves the axes listed in `source` to the positions listed in `destination`, consuming
/// `tensor` (fallible version).
///
/// Axes that are not listed in `source` keep their original relative order. The storage of
/// `tensor` is passed through untouched; only the layout is replaced by its transposed
/// counterpart.
///
/// # Errors
///
/// Returns an error when:
/// - `source` or `destination` cannot be converted into `AxesIndex<isize>`;
/// - `normalize_axes_index` rejects either set of axes for this tensor's `ndim`;
/// - `source` and `destination` do not contain the same number of axes;
/// - the layout's `transpose` rejects the computed axis permutation.
pub fn into_moveaxis_f<IS, ID, S, D>(tensor: TensorBase<S, D>, source: IS, destination: ID) -> Result<TensorBase<S, D>>
where
    D: DimAPI,
    IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
{
    // Convert the user-facing arguments first, so conversion errors surface before any
    // axis validation is attempted.
    let source = source.try_into().map_err(Into::into)?;
    let destination = destination.try_into().map_err(Into::into)?;

    let ndim = tensor.ndim();
    let source = normalize_axes_index(source, ndim, false, false)?;
    let destination = normalize_axes_index(destination, ndim, false, false)?;
    rstsr_assert_eq!(
        source.len(),
        destination.len(),
        InvalidValue,
        "`source` and `destination` arguments must have the same number of elements"
    )?;

    // Begin with every axis that is NOT being moved, in ascending order.
    let mut permutation: Vec<isize> = Vec::new();
    for axis in 0..ndim as isize {
        if !source.contains(&axis) {
            permutation.push(axis);
        }
    }

    // Insert the moved axes at their target positions. Processing the targets in ascending
    // order ensures an earlier insertion never shifts a later target position.
    let mut moves: Vec<(isize, isize)> =
        destination.iter().zip(source.iter()).map(|(target, moved)| (*target, *moved)).collect();
    moves.sort_by_key(|pair| pair.0);
    for (target, moved) in moves {
        permutation.insert(target as usize, moved);
    }

    let (storage, layout) = tensor.into_raw_parts();
    let permuted = layout.transpose(&permutation)?;
    // NOTE(review): `new_unchecked` skips validation. This presumably relies on `transpose`
    // only permuting axes of a layout that was already valid for `storage` -- confirm
    // against the `new_unchecked` safety contract.
    Ok(unsafe { TensorBase::new_unchecked(storage, permuted) })
}
/// Returns a view of `tensor` with the axes in `source` moved to the positions in
/// `destination`.
///
/// This takes a view of `tensor`, forwards it to [`into_moveaxis_f`], and unwraps the result
/// with `rstsr_unwrap`. Use [`moveaxis_f`] to receive the error as a `Result` instead.
///
/// # Panics
///
/// NOTE(review): `rstsr_unwrap` presumably panics when [`into_moveaxis_f`] returns an error
/// (invalid axes, or a mismatched number of axes) -- confirm against its definition.
pub fn moveaxis<IS, ID, R, T, B, D>(
    tensor: &TensorAny<R, T, B, D>,
    source: IS,
    destination: ID,
) -> TensorView<'_, T, B, D>
where
    D: DimAPI,
    IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
{
    let view = tensor.view();
    let moved = into_moveaxis_f(view, source, destination);
    moved.rstsr_unwrap()
}
/// Fallible counterpart of [`moveaxis`]: returns a view of `tensor` with the axes in `source`
/// moved to the positions in `destination`.
///
/// # Errors
///
/// Propagates any error produced by [`into_moveaxis_f`] when it is applied to a view of
/// `tensor`.
pub fn moveaxis_f<IS, ID, R, T, B, D>(
    tensor: &TensorAny<R, T, B, D>,
    source: IS,
    destination: ID,
) -> Result<TensorView<'_, T, B, D>>
where
    D: DimAPI,
    IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
{
    let view = tensor.view();
    into_moveaxis_f(view, source, destination)
}
/// Consumes `tensor` and returns it with the axes in `source` moved to the positions in
/// `destination`.
///
/// This forwards to [`into_moveaxis_f`] and unwraps the result with `rstsr_unwrap`.
///
/// # Panics
///
/// NOTE(review): `rstsr_unwrap` presumably panics when [`into_moveaxis_f`] returns an
/// error -- confirm against its definition.
pub fn into_moveaxis<IS, ID, S, D>(tensor: TensorBase<S, D>, source: IS, destination: ID) -> TensorBase<S, D>
where
    D: DimAPI,
    IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
{
    let moved = into_moveaxis_f(tensor, source, destination);
    moved.rstsr_unwrap()
}
/// Method forms of the `moveaxis` family of free functions.
impl<R, T, B, D> TensorAny<R, T, B, D>
where
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
    D: DimAPI,
{
    /// Returns a view of this tensor with the axes in `source` moved to the positions in
    /// `destination`.
    ///
    /// The result is unwrapped with `rstsr_unwrap`; use [`Self::moveaxis_f`] to handle the
    /// error yourself.
    pub fn moveaxis<IS, ID>(&self, source: IS, destination: ID) -> TensorView<'_, T, B, D>
    where
        IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
        ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        into_moveaxis_f(self.view(), source, destination).rstsr_unwrap()
    }

    /// Fallible version of [`Self::moveaxis`]: returns the moved-axis view, or the error
    /// reported by the free function `into_moveaxis_f`.
    pub fn moveaxis_f<IS, ID>(&self, source: IS, destination: ID) -> Result<TensorView<'_, T, B, D>>
    where
        IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
        ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        into_moveaxis_f(self.view(), source, destination)
    }

    /// Consumes this tensor and returns it with the axes in `source` moved to the positions
    /// in `destination`.
    ///
    /// The result is unwrapped with `rstsr_unwrap`; use [`Self::into_moveaxis_f`] to handle
    /// the error yourself.
    pub fn into_moveaxis<IS, ID>(self, source: IS, destination: ID) -> TensorAny<R, T, B, D>
    where
        IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
        ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        into_moveaxis_f(self, source, destination).rstsr_unwrap()
    }

    /// Fallible version of [`Self::into_moveaxis`]: consumes this tensor and returns the
    /// moved-axis tensor, or the error reported by the free function `into_moveaxis_f`.
    pub fn into_moveaxis_f<IS, ID>(self, source: IS, destination: ID) -> Result<TensorAny<R, T, B, D>>
    where
        IS: TryInto<AxesIndex<isize>, Error: Into<Error>>,
        ID: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        let moved = into_moveaxis_f(self, source, destination);
        moved
    }
}