use crate::device::Device;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::int_tensor::{IntElement, IntTensor};
use crate::shape::normalize_axis;
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
/// Splits `shape` around axis `dim` into `(outer, dim_size, inner)`.
///
/// `outer` is the product of the extents before `dim`, `dim_size` is
/// `shape[dim]`, and `inner` is the product of the extents after `dim`, so a
/// contiguous buffer of this shape can be viewed as `[outer, dim_size, inner]`.
///
/// # Panics
/// Panics if `dim >= shape.len()`.
fn factor(shape: &[usize], dim: usize) -> (usize, usize, usize) {
    let dim_size = shape[dim];
    let outer = shape[..dim].iter().product::<usize>();
    let inner = shape[dim + 1..].iter().product::<usize>();
    (outer, dim_size, inner)
}
/// Returns a copy of `shape` with axis `dim` removed.
///
/// # Panics
/// Panics if `dim >= shape.len()`.
fn shape_without(shape: &[usize], dim: usize) -> Vec<usize> {
    // `split_at` panics for dim > len; `tail[1..]` panics for dim == len.
    let (head, tail) = shape.split_at(dim);
    [head, &tail[1..]].concat()
}
/// Index-of-best reduction over the middle axis of a buffer viewed as
/// `[outer, dim_size, inner]`.
///
/// For every `(o, k)` pair the positions `0..dim_size` are scanned in order. A
/// candidate replaces the running best only when `better(candidate, best)`
/// returns `true`, so ties keep the earlier position. The result holds
/// `outer * inner` positions laid out as `[outer, inner]`.
///
/// # Panics
/// Panics on out-of-bounds access if `data` is shorter than
/// `outer * dim_size * inner`, e.g. when `dim_size == 0` but `outer * inner > 0`.
fn arg_reduce_ref<V: Copy>(
    data: &[V],
    outer: usize,
    dim_size: usize,
    inner: usize,
    better: impl Fn(V, V) -> bool,
) -> Vec<i64> {
    let mut winners = Vec::with_capacity(outer * inner);
    for slab in 0..outer {
        let slab_start = slab * dim_size * inner;
        for lane in 0..inner {
            let start = slab_start + lane;
            let (best_pos, _) = (1..dim_size).fold((0usize, data[start]), |(pos, val), j| {
                let candidate = data[start + j * inner];
                if better(candidate, val) {
                    (j, candidate)
                } else {
                    (pos, val)
                }
            });
            winners.push(best_pos as i64);
        }
    }
    winners
}
fn tensor_arg<T: Float>(
input: &Tensor<T>,
dim: Option<isize>,
is_max: bool,
) -> FerrotorchResult<IntTensor<i64>> {
let op = if is_max { "argmax" } else { "argmin" };
if input.numel() == 0 {
return Err(FerrotorchError::InvalidArgument {
message: format!("{op}: cannot reduce an empty tensor"),
});
}
let input = input.contiguous()?;
let (outer, dim_size, inner, out_shape) = match dim {
None => (1usize, input.numel(), 1usize, Vec::new()),
Some(d) => {
let d = normalize_axis(d, input.ndim())?;
let (o, ds, inn) = factor(input.shape(), d);
(o, ds, inn, shape_without(input.shape(), d))
}
};
if input.is_cuda() {
let backend =
crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
let h = input.gpu_handle()?;
let out_h = if is_max {
backend.argmax(h, outer, dim_size, inner)?
} else {
backend.argmin(h, outer, dim_size, inner)?
};
Ok(IntTensor::from_gpu_handle(out_h, out_shape))
} else {
let data = input.data_vec()?;
let out = if is_max {
arg_reduce_ref(&data, outer, dim_size, inner, |c, b| c > b)
} else {
arg_reduce_ref(&data, outer, dim_size, inner, |c, b| c < b)
};
IntTensor::<i64>::from_vec(out, out_shape)
}
}
/// Shared implementation of [`IntTensor::argmax`] / [`IntTensor::argmin`].
///
/// With `dim == None` the tensor is reduced as if flattened and the result has
/// the empty shape `[]`. Otherwise `dim` is resolved by `normalize_axis` and
/// removed from the output shape. On the CPU path values are widened to `i64`
/// before comparison, and ties resolve to the lowest index.
///
/// # Errors
/// - `InvalidArgument` if `input` has no elements.
/// - Any error from `normalize_axis` for an invalid `dim`.
/// - `DeviceUnavailable` if the tensor is on CUDA but no GPU backend is available.
fn inttensor_arg<I: IntElement>(
    input: &IntTensor<I>,
    dim: Option<isize>,
    is_max: bool,
) -> FerrotorchResult<IntTensor<i64>> {
    if input.numel() == 0 {
        let op = if is_max { "argmax" } else { "argmin" };
        return Err(FerrotorchError::InvalidArgument {
            message: format!("{op}: cannot reduce an empty tensor"),
        });
    }
    let (outer, dim_size, inner, out_shape) = if let Some(d) = dim {
        let axis = normalize_axis(d, input.ndim())?;
        let (o, ds, inn) = factor(input.shape(), axis);
        (o, ds, inn, shape_without(input.shape(), axis))
    } else {
        (1usize, input.numel(), 1usize, Vec::new())
    };
    if !input.is_cuda() {
        let widened: Vec<i64> = input.data()?.iter().map(|v| v.to_i64()).collect();
        let cmp: fn(i64, i64) -> bool = if is_max { |c, b| c > b } else { |c, b| c < b };
        let out = arg_reduce_ref(&widened, outer, dim_size, inner, cmp);
        return IntTensor::<i64>::from_vec(out, out_shape);
    }
    let backend = crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let h = input.gpu_handle()?;
    let out_h = if is_max {
        backend.argmax(h, outer, dim_size, inner)?
    } else {
        backend.argmin(h, outer, dim_size, inner)?
    };
    Ok(IntTensor::from_gpu_handle(out_h, out_shape))
}
/// Reference (CPU) kernel for `index_select` on a contiguous buffer viewed as
/// `[outer, in_dim, inner]`.
///
/// For each outer slice, copies the `inner`-long row at every position listed in
/// `indices`, in order. The output is viewed as `[outer, indices.len(), inner]`.
/// `zero` only pre-fills the output before every slot is overwritten.
///
/// Indices are not range-checked here. An index outside `0..in_dim` either
/// panics on out-of-bounds access or reads from a neighbouring slice, so callers
/// should validate first.
fn index_select_ref<V: Copy>(
    data: &[V],
    indices: &[i64],
    outer: usize,
    in_dim: usize,
    inner: usize,
    zero: V,
) -> Vec<V> {
    let picked = indices.len();
    let mut result = vec![zero; outer * picked * inner];
    // Output slots are visited strictly in (slab, row, k) order, so a running
    // cursor replaces the original `(o * out_dim + i) * inner + k` arithmetic.
    let mut dst = 0usize;
    for slab in 0..outer {
        let slab_base = slab * in_dim * inner;
        for &row in indices {
            let src = slab_base + row as usize * inner;
            for k in 0..inner {
                result[dst] = data[src + k];
                dst += 1;
            }
        }
    }
    result
}
/// Reference (CPU) kernel for `gather` on contiguous buffers.
///
/// `data` is viewed as `[outer, in_dim, inner]`. `indices`, like the returned
/// buffer, is viewed as `[outer, out_dim, inner]`. Each output element is
/// `out[o, i, k] = data[o, indices[o, i, k], k]`. `zero` only pre-fills the
/// output before every slot is overwritten.
///
/// # Panics
/// - If an index is negative or `>= in_dim`. Without this check, such an index
///   could silently read from a neighbouring outer slice: index `in_dim` at
///   slice `o` addresses row 0 of slice `o + 1`.
/// - If `indices` holds fewer than `outer * out_dim * inner` values.
fn gather_ref<V: Copy>(
    data: &[V],
    indices: &[i64],
    outer: usize,
    in_dim: usize,
    out_dim: usize,
    inner: usize,
    zero: V,
) -> Vec<V> {
    let mut out = vec![zero; outer * out_dim * inner];
    for o in 0..outer {
        let src_slab = o * in_dim * inner;
        for i in 0..out_dim {
            for k in 0..inner {
                let t = (o * out_dim + i) * inner + k;
                let raw = indices[t];
                // `raw as usize` alone would wrap negatives, and `raw >= in_dim`
                // would spill into the next slice instead of failing.
                if raw < 0 || raw as u64 >= in_dim as u64 {
                    panic!(
                        "gather: index {} is out of bounds for dimension with size {}",
                        raw, in_dim
                    );
                }
                out[t] = data[src_slab + raw as usize * inner + k];
            }
        }
    }
    out
}
/// Reads an index tensor's elements and widens each one to `i64`.
///
/// # Errors
/// Propagates any error from `IntTensor::data`.
fn index_as_i64<I: IntElement>(index: &IntTensor<I>) -> FerrotorchResult<Vec<i64>> {
    let values = index.data()?;
    let widened = values.iter().map(|v| v.to_i64()).collect::<Vec<_>>();
    Ok(widened)
}
impl<T: Float> Tensor<T> {
pub fn argmax(&self, dim: Option<isize>) -> FerrotorchResult<IntTensor<i64>> {
tensor_arg(self, dim, true)
}
pub fn argmin(&self, dim: Option<isize>) -> FerrotorchResult<IntTensor<i64>> {
tensor_arg(self, dim, false)
}
pub fn index_select<I: IntElement>(
&self,
dim: isize,
indices: &IntTensor<I>,
) -> FerrotorchResult<Tensor<T>> {
if indices.ndim() > 1 {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"index_select: indices must be 1-D, got shape {:?}",
indices.shape()
),
});
}
let input = self.contiguous()?;
let d = normalize_axis(dim, input.ndim())?;
let (outer, in_dim, inner) = factor(input.shape(), d);
let out_dim = indices.numel();
let mut out_shape = input.shape().to_vec();
out_shape[d] = out_dim;
if input.is_cuda() {
check_same_device(input.device(), indices.device(), "index_select")?;
let backend =
crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
let h = backend.index_select_intidx(
input.gpu_handle()?,
indices.gpu_handle()?,
outer,
in_dim,
out_dim,
inner,
)?;
Tensor::from_storage(TensorStorage::gpu(h), out_shape, false)
} else {
let data = input.data_vec()?;
let idx = index_as_i64(&indices.to(Device::Cpu)?)?;
let out = index_select_ref(
&data,
&idx,
outer,
in_dim,
inner,
<T as num_traits::Zero>::zero(),
);
Tensor::from_storage(TensorStorage::cpu(out), out_shape, false)
}
}
pub fn gather<I: IntElement>(
&self,
dim: isize,
index: &IntTensor<I>,
) -> FerrotorchResult<Tensor<T>> {
let input = self.contiguous()?;
let d = normalize_axis(dim, input.ndim())?;
gather_check_shapes(input.shape(), index.shape(), d, "gather")?;
let (outer, in_dim, inner) = factor(input.shape(), d);
let out_dim = index.shape()[d];
let out_shape = index.shape().to_vec();
if input.is_cuda() {
check_same_device(input.device(), index.device(), "gather")?;
let backend =
crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
let h = backend.gather_intidx(
input.gpu_handle()?,
index.gpu_handle()?,
outer,
in_dim,
out_dim,
inner,
)?;
Tensor::from_storage(TensorStorage::gpu(h), out_shape, false)
} else {
let data = input.data_vec()?;
let idx = index_as_i64(&index.to(Device::Cpu)?)?;
let out = gather_ref(
&data,
&idx,
outer,
in_dim,
out_dim,
inner,
<T as num_traits::Zero>::zero(),
);
Tensor::from_storage(TensorStorage::cpu(out), out_shape, false)
}
}
pub fn to_int<I: IntElement>(&self) -> FerrotorchResult<IntTensor<I>> {
let input = self.contiguous()?;
let shape = input.shape().to_vec();
if input.is_cuda() {
let backend =
crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
let h = backend.cast_f_to_i(input.gpu_handle()?, I::dtype())?;
Ok(IntTensor::from_gpu_handle(h, shape))
} else {
let data = input.data_vec()?;
let mut out: Vec<I> = Vec::with_capacity(data.len());
for &v in &data {
let truncated = num_traits::Float::trunc(v);
let as_i64 = float_to_i64_trunc(truncated);
out.push(
I::try_from_i64(as_i64).ok_or(FerrotorchError::InvalidArgument {
message: format!("to_int: value out of range for {}", I::dtype_name()),
})?,
);
}
IntTensor::<I>::from_vec(out, shape)
}
}
}
/// Converts a float to `i64` via `f64` using Rust's saturating `as` cast.
///
/// The fractional part is truncated toward zero. NaN, and any value whose
/// `to_f64` conversion fails, becomes 0. Values beyond the `i64` range,
/// including ±inf, clamp to `i64::MIN` / `i64::MAX`. Callers that need
/// out-of-range errors must check for them separately.
fn float_to_i64_trunc<T: Float>(v: T) -> i64 {
    match num_traits::ToPrimitive::to_f64(&v) {
        Some(wide) => wide as i64,
        None => 0,
    }
}
impl<I: IntElement> IntTensor<I> {
    /// Returns the position of the largest element along `dim`, or over the
    /// flattened tensor when `dim` is `None`, as an `i64` tensor.
    ///
    /// # Errors
    /// Returns an error for an empty tensor or an invalid `dim` (see `inttensor_arg`).
    pub fn argmax(&self, dim: Option<isize>) -> FerrotorchResult<IntTensor<i64>> {
        inttensor_arg(self, dim, true)
    }

    /// Returns the position of the smallest element along `dim`, or over the
    /// flattened tensor when `dim` is `None`, as an `i64` tensor.
    ///
    /// # Errors
    /// Returns an error for an empty tensor or an invalid `dim` (see `inttensor_arg`).
    pub fn argmin(&self, dim: Option<isize>) -> FerrotorchResult<IntTensor<i64>> {
        inttensor_arg(self, dim, false)
    }

    /// Selects the slices at positions `indices` along `dim`.
    ///
    /// The output keeps the input's shape, except that `shape[dim]` becomes
    /// `indices.numel()`. Indices may repeat and need not be sorted.
    ///
    /// # Errors
    /// - `InvalidArgument` if `indices` has more than one dimension, or (CPU
    ///   path) if any index lies outside `0..shape[dim]`.
    /// - `DeviceMismatch` if `self` is on CUDA and `indices` is on another device.
    /// - `DeviceUnavailable` if `self` is on CUDA but no GPU backend is available.
    pub fn index_select<J: IntElement>(
        &self,
        dim: isize,
        indices: &IntTensor<J>,
    ) -> FerrotorchResult<IntTensor<I>> {
        if indices.ndim() > 1 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "index_select: indices must be 1-D, got shape {:?}",
                    indices.shape()
                ),
            });
        }
        let d = normalize_axis(dim, self.ndim())?;
        let (outer, in_dim, inner) = factor(self.shape(), d);
        let out_dim = indices.numel();
        let mut out_shape = self.shape().to_vec();
        out_shape[d] = out_dim;
        if self.is_cuda() {
            check_same_device(self.device(), indices.device(), "index_select")?;
            // NOTE(review): indices are not range-checked on this path; what the
            // backend kernel does with an out-of-range index is not visible here.
            let backend =
                crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let h = backend.index_select_intidx(
                self.gpu_handle()?,
                indices.gpu_handle()?,
                outer,
                in_dim,
                out_dim,
                inner,
            )?;
            Ok(IntTensor::from_gpu_handle(h, out_shape))
        } else {
            let idx = index_as_i64(&indices.to(Device::Cpu)?)?;
            // Validate up front so a bad index becomes an error rather than a
            // panic inside the reference kernel.
            if let Some(&bad) = idx.iter().find(|&&i| i < 0 || i as u64 >= in_dim as u64) {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "index_select: index {bad} is out of bounds for dimension {d} with size {in_dim}"
                    ),
                });
            }
            let data = self.data()?;
            let zero = I::try_from_i64(0).expect("0 is in range for i32/i64");
            let out = index_select_ref(data, &idx, outer, in_dim, inner, zero);
            IntTensor::<I>::from_vec(out, out_shape)
        }
    }

    /// Gathers values along `dim`. The output takes `index`'s shape, and
    /// `out[o, i, k] = self[o, index[o, i, k], k]`, where `o` / `k` range over the
    /// axes before / after `dim`.
    ///
    /// `index` must have as many dimensions as `self`. See `gather_check_shapes`
    /// for the per-axis size rules.
    ///
    /// # Errors
    /// - `ShapeMismatch` from `gather_check_shapes`.
    /// - `InvalidArgument` (CPU path) if any index lies outside `0..shape[dim]`.
    /// - `DeviceMismatch` if `self` is on CUDA and `index` is on another device.
    /// - `DeviceUnavailable` if `self` is on CUDA but no GPU backend is available.
    pub fn gather<J: IntElement>(
        &self,
        dim: isize,
        index: &IntTensor<J>,
    ) -> FerrotorchResult<IntTensor<I>> {
        let d = normalize_axis(dim, self.ndim())?;
        gather_check_shapes(self.shape(), index.shape(), d, "gather")?;
        let (outer, in_dim, inner) = factor(self.shape(), d);
        let out_dim = index.shape()[d];
        let out_shape = index.shape().to_vec();
        if self.is_cuda() {
            check_same_device(self.device(), index.device(), "gather")?;
            // NOTE(review): indices are not range-checked on this path; what the
            // backend kernel does with an out-of-range index is not visible here.
            let backend =
                crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let h = backend.gather_intidx(
                self.gpu_handle()?,
                index.gpu_handle()?,
                outer,
                in_dim,
                out_dim,
                inner,
            )?;
            Ok(IntTensor::from_gpu_handle(h, out_shape))
        } else {
            let idx = index_as_i64(&index.to(Device::Cpu)?)?;
            // An unchecked index >= in_dim would read from the next outer slice.
            if let Some(&bad) = idx.iter().find(|&&i| i < 0 || i as u64 >= in_dim as u64) {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "gather: index {bad} is out of bounds for dimension {d} with size {in_dim}"
                    ),
                });
            }
            let data = self.data()?;
            let zero = I::try_from_i64(0).expect("0 is in range for i32/i64");
            let out = gather_ref(data, &idx, outer, in_dim, out_dim, inner, zero);
            IntTensor::<I>::from_vec(out, out_shape)
        }
    }

    /// Converts to a float tensor with element type `T`, going through `i64` and
    /// `NumCast`.
    ///
    /// # Errors
    /// - `InvalidArgument` (CPU path) if `NumCast` reports a value as
    ///   unrepresentable in `T`.
    /// - `DeviceUnavailable` if the tensor is on CUDA but no GPU backend is available.
    pub fn to_float<T: Float>(&self) -> FerrotorchResult<Tensor<T>> {
        let shape = self.shape().to_vec();
        if self.is_cuda() {
            let backend =
                crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let h = backend.cast_i_to_f(self.gpu_handle()?, T::dtype())?;
            Tensor::from_storage(TensorStorage::gpu(h), shape, false)
        } else {
            let data = self.data()?;
            // `ok_or_else` builds the error message only on failure, instead of
            // allocating one String per element as `ok_or` did.
            let out = data
                .iter()
                .map(|v| {
                    <T as num_traits::NumCast>::from(v.to_i64()).ok_or_else(|| {
                        FerrotorchError::InvalidArgument {
                            message: "to_float: integer not representable in target float".into(),
                        }
                    })
                })
                .collect::<FerrotorchResult<Vec<T>>>()?;
            Tensor::from_storage(TensorStorage::cpu(out), shape, false)
        }
    }

    /// GPU-side integer-to-integer cast to element type `J`.
    ///
    /// Returns `None` when the tensor is not on CUDA, so the caller can fall back
    /// to a CPU implementation. Otherwise returns `Some` with the outcome of the
    /// backend's `cast_i_to_i`, or `Some(Err(DeviceUnavailable))` when no GPU
    /// backend is available.
    pub(crate) fn cast_gpu<J: IntElement>(&self) -> Option<FerrotorchResult<IntTensor<J>>> {
        if !self.is_cuda() {
            return None;
        }
        let shape = self.shape().to_vec();
        let backend = match crate::gpu_dispatch::gpu_backend() {
            Some(b) => b,
            None => return Some(Err(FerrotorchError::DeviceUnavailable)),
        };
        let h = match self.gpu_handle() {
            Ok(h) => h,
            Err(e) => return Some(Err(e)),
        };
        Some(
            backend
                .cast_i_to_i(h, J::dtype())
                .map(|out_h| IntTensor::from_gpu_handle(out_h, shape)),
        )
    }
}
/// Returns `DeviceMismatch { expected: a, got: b }` unless `a == b`.
///
/// `op` names the calling operation for readability at call sites. The error
/// variant has no field to carry it, so it is currently unused.
fn check_same_device(a: Device, b: Device, op: &str) -> FerrotorchResult<()> {
    let _ = op;
    if a == b {
        Ok(())
    } else {
        Err(FerrotorchError::DeviceMismatch {
            expected: a,
            got: b,
        })
    }
}
/// Validates the shape of a `gather` index tensor against its input.
///
/// `index` must have the same number of dimensions as the input, and on every
/// axis other than `dim` it must have the same extent. The extent along `dim`
/// is unconstrained.
///
/// PyTorch also allows smaller extents on the other axes. However, the `gather`
/// paths derive `outer`/`inner` from the input shape and hand the kernels only
/// `(outer, in_dim, out_dim, inner)`. A smaller index would therefore be read
/// past its end, or would produce a buffer whose length disagrees with the
/// index shape. That case is rejected here instead of failing later.
///
/// # Errors
/// `ShapeMismatch` naming `op` and the offending axis.
fn gather_check_shapes(
    input_shape: &[usize],
    index_shape: &[usize],
    dim: usize,
    op: &str,
) -> FerrotorchResult<()> {
    if index_shape.len() != input_shape.len() {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "{op}: index ndim {} != input ndim {}",
                index_shape.len(),
                input_shape.len()
            ),
        });
    }
    for (ax, (&isz, &xsz)) in index_shape.iter().zip(input_shape.iter()).enumerate() {
        if ax != dim && isz != xsz {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "{op}: index dim {ax} size {isz} must equal input size {xsz} (only dim {dim} may differ)"
                ),
            });
        }
    }
    Ok(())
}