use crate::{
DType, RT, Tensor, ZyxError,
kernel::BOp,
shape::{Dim, UAxis, into_axes},
tensor::Axis,
};
use paste::paste;
/// Reduction performed by `Tensor::reduce_impl`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ReduceOp {
    /// Sum of the reduced elements.
    Sum,
    /// Arithmetic mean: the sum divided by the number of reduced elements.
    Mean,
    /// Variance with a configurable degrees-of-freedom correction.
    Var,
    /// Standard deviation: the square root of `Var`.
    Std,
    /// Largest reduced element.
    Max,
    /// Smallest reduced element.
    Min,
    /// Product of the reduced elements.
    Prod,
}
impl Tensor {
fn inverse(&self) -> Tensor {
let dtype = self.dtype();
if dtype.is_float() {
-self
} else if dtype.is_int() {
self.bitnot()
} else {
!self
}
}
pub(crate) fn reduce_impl<const KEEPDIM: bool>(
&self,
op: ReduceOp,
axes: impl IntoIterator<Item = Axis>,
dtype: Option<DType>,
correction: Dim,
) -> Result<Tensor, ZyxError> {
fn reduce_acc_dtype(dtype: DType) -> DType {
if dtype.is_uint() {
return dtype.least_upper_dtype(DType::U32);
}
if dtype.is_int() || dtype == DType::Bool {
return dtype.least_upper_dtype(DType::I32);
}
dtype.least_upper_dtype(DType::F32)
}
let mut shape = self.shape();
let rank = shape.len();
let x_dtype = self.dtype();
let axes: Vec<_> = axes.into_iter().collect();
let axes_vec: Vec<UAxis> = into_axes(axes.clone(), rank)?;
let mut tensor = match op {
ReduceOp::Sum => {
let x = if let Some(dtype) = dtype {
self.cast(dtype)
} else {
self.cast(reduce_acc_dtype(x_dtype))
};
Tensor { id: RT.lock().reduce(x.id, axes_vec.clone(), BOp::Add) }
}
ReduceOp::Max => {
let x = if let Some(dtype) = dtype {
self.cast(dtype)
} else {
self.cast(reduce_acc_dtype(x_dtype))
};
Tensor { id: RT.lock().reduce(x.id, axes_vec.clone(), BOp::Max) }
}
ReduceOp::Prod => {
let x = if let Some(dtype) = dtype {
self.cast(dtype)
} else {
self.cast(reduce_acc_dtype(x_dtype))
};
Tensor { id: RT.lock().reduce(x.id, axes_vec.clone(), BOp::Mul) }
}
ReduceOp::Min => {
if let Some(dtype) = dtype {
self.inverse().max_dtype(axes, dtype)?.inverse()
} else {
self.inverse().max(axes)?.inverse()
}
}
ReduceOp::Mean => {
let n: i64 = axes_vec.iter().map(|&a| shape[a]).product::<Dim>().try_into().unwrap();
let x = if let Some(dtype) = dtype {
self.sum_dtype(axes, dtype)?
} else {
self.sum(axes)?
};
x / Tensor::from(n).cast(x_dtype)
}
ReduceOp::Var => {
if let Some(dtype) = dtype {
let x = self - self.mean_keepdim_dtype(axes.clone(), dtype)?;
let shape_dims: Vec<u64> = axes_vec.iter().map(|&a| shape[a as usize]).collect();
let d = Axis::try_from(shape_dims.iter().product::<u64>()).unwrap() - Axis::try_from(correction).unwrap();
(x.clone() * x).sum_dtype(axes, dtype)? / Tensor::from(d).cast(x_dtype)
} else {
let x = self - self.mean_keepdim(axes.clone())?;
let shape_dims: Vec<u64> = axes_vec.iter().map(|&a| shape[a as usize]).collect();
let d = Axis::try_from(shape_dims.iter().product::<u64>()).unwrap() - Axis::try_from(correction).unwrap();
(x.clone() * x).sum(axes)? / Tensor::from(d).cast(x_dtype)
}
}
ReduceOp::Std => {
if let Some(dtype) = dtype {
self.var_dtype(axes, dtype)?.sqrt()
} else {
self.var(axes)?.sqrt()
}
}
};
if dtype.is_none() && x_dtype != tensor.dtype() {
tensor = tensor.cast(x_dtype);
}
if KEEPDIM {
for a in axes_vec {
shape[a] = 1;
}
tensor = tensor.reshape(shape)?;
}
Ok(tensor)
}
}
/// Generates the public API for a reduction without a correction term.
///
/// For `$name` this defines `$name`, `$name_keepdim`, `$name_all`, `$name_all_keepdim`
/// and their `_dtype` counterparts. All of them forward to `Tensor::reduce_impl`.
macro_rules! define_reduce_op {
    ($name:ident, $op_variant:expr) => {
        paste! {
            impl Tensor {
                #[doc = concat!(
                    "Computes the `", stringify!($name), "` reduction over all elements.\n\n",
                    "Small dtypes are accumulated in a wider dtype; the result is cast back to ",
                    "the dtype of `self`.\n\n",
                    "# Examples\n",
                    "```\n",
                    "use zyx::Tensor;\n",
                    "let t = Tensor::from([1.0, 2.0, 3.0]);\n",
                    "let result = t.", stringify!($name), "_all();\n",
                    "```\n",
                )]
                #[must_use]
                pub fn [<$name _all>](&self) -> Tensor {
                    self.reduce_impl::<false>($op_variant, [], None, 1).unwrap()
                }

                #[doc = concat!(
                    "Computes the `", stringify!($name), "` reduction over all elements, keeping reduced dimensions.\n\n",
                    "Reduced axes are retained with length 1.\n\n",
                    "# Examples\n",
                    "```\n",
                    "use zyx::Tensor;\n",
                    "let t = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);\n",
                    "let result = t.", stringify!($name), "_all_keepdim();\n",
                    "```\n",
                )]
                #[must_use]
                pub fn [<$name _all_keepdim>](&self) -> Tensor {
                    self.reduce_impl::<true>($op_variant, [], None, 1).unwrap()
                }

                #[doc = concat!(
                    "Computes the `", stringify!($name), "` reduction along the specified `axes`.\n\n",
                    "# Arguments\n",
                    "* `axes` — Iterable of axes to reduce over.\n\n",
                    "# Examples\n",
                    "```\n",
                    "use zyx::Tensor;\n",
                    "let t = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);\n",
                    "let result = t.", stringify!($name), "([0]).unwrap();\n",
                    "```\n",
                    "\n",
                    "# Errors\n",
                    "When axes are out of range\n"
                )]
                pub fn $name(&self, axes: impl IntoIterator<Item = Axis>) -> Result<Tensor, ZyxError> {
                    self.reduce_impl::<false>($op_variant, axes, None, 1)
                }

                #[doc = concat!(
                    "Computes the `", stringify!($name), "` reduction along the specified `axes`, keeping reduced dimensions.\n\n",
                    "Reduced axes are retained with length 1.\n\n",
                    "# Arguments\n",
                    "* `axes` — Iterable of axes to reduce over.\n\n",
                    "# Examples\n",
                    "```\n",
                    "use zyx::Tensor;\n",
                    "let t = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);\n",
                    "let result = t.", stringify!($name), "_keepdim([1]).unwrap();\n",
                    "```\n",
                    "\n",
                    "# Errors\n",
                    "When axes are out of range\n"
                )]
                pub fn [<$name _keepdim>](&self, axes: impl IntoIterator<Item = Axis>) -> Result<Tensor, ZyxError> {
                    self.reduce_impl::<true>($op_variant, axes, None, 1)
                }

                #[doc = concat!(
                    "Computes the `", stringify!($name), "` reduction over all elements in `dtype`.\n\n",
                    "The input is converted to `dtype` before it is reduced (not afterwards), so ",
                    "`dtype` also determines the precision and range of the computation.\n\n",
                    "# Arguments\n",
                    "* `dtype` — Data type in which the reduction is computed.\n\n",
                    "# Examples\n",
                    "```\n",
                    "use zyx::{Tensor, DType};\n",
                    "let t = Tensor::from([1.0, 2.0, 3.0]);\n",
                    "let result = t.", stringify!($name), "_all_dtype(DType::F64);\n",
                    "```\n",
                )]
                #[must_use]
                pub fn [<$name _all_dtype>](&self, dtype: DType) -> Tensor {
                    self.reduce_impl::<false>($op_variant, [], Some(dtype), 1).unwrap()
                }

                #[doc = concat!(
                    "Computes the `", stringify!($name), "` reduction over all elements in `dtype`, ",
                    "keeping reduced dimensions.\n\n",
                    "The input is converted to `dtype` before it is reduced; reduced axes are ",
                    "retained with length 1.\n\n",
                    "# Arguments\n",
                    "* `dtype` — Data type in which the reduction is computed.\n\n",
                    "# Examples\n",
                    "```\n",
                    "use zyx::{Tensor, DType};\n",
                    "let t = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);\n",
                    "let result = t.", stringify!($name), "_all_keepdim_dtype(DType::F64);\n",
                    "```\n",
                )]
                #[must_use]
                pub fn [<$name _all_keepdim_dtype>](&self, dtype: DType) -> Tensor {
                    self.reduce_impl::<true>($op_variant, [], Some(dtype), 1).unwrap()
                }

                #[doc = concat!(
                    "Computes the `", stringify!($name), "` reduction along the specified `axes` in `dtype`.\n\n",
                    "The input is converted to `dtype` before it is reduced (not afterwards).\n\n",
                    "# Arguments\n",
                    "* `axes` — Iterable of axes to reduce over.\n",
                    "* `dtype` — Data type in which the reduction is computed.\n\n",
                    "# Examples\n",
                    "```\n",
                    "use zyx::{Tensor, DType};\n",
                    "let t = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);\n",
                    "let result = t.", stringify!($name), "_dtype([0], DType::F64).unwrap();\n",
                    "```\n",
                    "\n",
                    "# Errors\n",
                    "When axes are out of range\n"
                )]
                pub fn [<$name _dtype>](
                    &self,
                    axes: impl IntoIterator<Item = Axis>,
                    dtype: DType
                ) -> Result<Tensor, ZyxError> {
                    self.reduce_impl::<false>($op_variant, axes, Some(dtype), 1)
                }

                #[doc = concat!(
                    "Computes the `", stringify!($name), "` reduction along the specified `axes` in `dtype`, ",
                    "keeping reduced dimensions.\n\n",
                    "The input is converted to `dtype` before it is reduced; reduced axes are ",
                    "retained with length 1.\n\n",
                    "# Arguments\n",
                    "* `axes` — Iterable of axes to reduce over.\n",
                    "* `dtype` — Data type in which the reduction is computed.\n\n",
                    "# Examples\n",
                    "```\n",
                    "use zyx::{Tensor, DType};\n",
                    "let t = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);\n",
                    "let result = t.", stringify!($name), "_keepdim_dtype([1], DType::F64).unwrap();\n",
                    "```\n",
                    "\n",
                    "# Errors\n",
                    "When axes are out of range\n"
                )]
                pub fn [<$name _keepdim_dtype>](
                    &self,
                    axes: impl IntoIterator<Item = Axis>,
                    dtype: DType
                ) -> Result<Tensor, ZyxError> {
                    self.reduce_impl::<true>($op_variant, axes, Some(dtype), 1)
                }
            }
        }
    };
}
define_reduce_op!(sum, ReduceOp::Sum);
define_reduce_op!(mean, ReduceOp::Mean);
define_reduce_op!(max, ReduceOp::Max);
define_reduce_op!(min, ReduceOp::Min);
define_reduce_op!(prod, ReduceOp::Prod);
macro_rules! define_reduce_op_with_correction {
($name:ident, $op_variant:expr) => {
paste! {
impl Tensor {
#[must_use]
pub fn [<$name _all>](&self) -> Tensor {
self.reduce_impl::<false>($op_variant, [], None, 1).unwrap()
}
#[must_use]
pub fn [<$name _all_keepdim>](&self) -> Tensor {
self.reduce_impl::<true>($op_variant, [], None, 1).unwrap()
}
#[must_use]
pub fn [<$name _all_dtype>](&self, dtype: DType) -> Tensor {
self.reduce_impl::<false>($op_variant, [], Some(dtype), 1).unwrap()
}
pub fn [<$name >](&self, axes: impl IntoIterator<Item = Axis>) -> Result<Tensor, ZyxError> {
self.reduce_impl::<false>($op_variant, axes, None, 1)
}
pub fn [<$name _keepdim>](&self, axes: impl IntoIterator<Item = Axis>) -> Result<Tensor, ZyxError> {
self.reduce_impl::<true>($op_variant, axes, None, 1)
}
pub fn [<$name _dtype>](&self, axes: impl IntoIterator<Item = Axis>, dtype: DType) -> Result<Tensor, ZyxError> {
self.reduce_impl::<false>($op_variant, axes, Some(dtype), 1)
}
pub fn [<$name _correction>](&self, axes: impl IntoIterator<Item = Axis>, correction: Dim) -> Result<Tensor, ZyxError> {
self.reduce_impl::<false>($op_variant, axes, None, correction)
}
pub fn [<$name _all_correction>](&self, correction: Dim) -> Result<Tensor, ZyxError> {
self.reduce_impl::<false>($op_variant, [], None, correction)
}
#[must_use]
pub fn [<$name _keepdim_dtype>](&self, dtype: DType) -> Tensor {
self.reduce_impl::<true>($op_variant, [], Some(dtype), 1).unwrap()
}
#[must_use]
pub fn [<$name _all_keepdim_correction>](&self, correction: Dim) -> Tensor {
self.reduce_impl::<true>($op_variant, [], None, correction).unwrap()
}
#[must_use]
pub fn [<$name _all_dtype_correction>](&self, dtype: DType, correction: Dim) -> Tensor {
self.reduce_impl::<false>($op_variant, [], Some(dtype), correction).unwrap()
}
pub fn [<$name _axes_keepdim_dtype>](&self, axes: impl IntoIterator<Item = Axis>, dtype: DType) -> Result<Tensor, ZyxError> {
self.reduce_impl::<true>($op_variant, axes, Some(dtype), 1)
}
pub fn [<$name _keepdim_correction>](&self, axes: impl IntoIterator<Item = Axis>, correction: Dim) -> Result<Tensor, ZyxError> {
self.reduce_impl::<true>($op_variant, axes, None, correction)
}
pub fn [<$name _dtype_correction>](&self, axes: impl IntoIterator<Item = Axis>, dtype: DType, correction: Dim) -> Result<Tensor, ZyxError> {
self.reduce_impl::<false>($op_variant, axes, Some(dtype), correction)
}
#[must_use]
pub fn [<$name _all_keepdim_dtype_correction>](&self, dtype: DType, correction: Dim) -> Tensor {
self.reduce_impl::<true>($op_variant, [], Some(dtype), correction).unwrap()
}
pub fn [<$name _keepdim_dtype_correction>](&self, axes: impl IntoIterator<Item = Axis>, dtype: DType, correction: Dim) -> Result<Tensor, ZyxError> {
self.reduce_impl::<true>($op_variant, axes, Some(dtype), correction)
}
}
}
};
}
define_reduce_op_with_correction!(var, ReduceOp::Var);
define_reduce_op_with_correction!(std, ReduceOp::Std);