use crate::tensor::{AsTensor, Tensor};
use crate::Float;
use crate::tensor_ops::{array_ops, math_ops, reduction_ops, shape};
/// Sums the elements of `x` over the axes given by `axes`.
///
/// When `keep_dims` is true the reduced axes are retained with size 1.
#[allow(dead_code)]
pub fn reduce_sum<'graph, A, AT, F: Float>(x: A, axes: &AT, keep_dims: bool) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    AT: AsTensor<'graph, F>,
{
    let input = x.as_ref();
    let graph = input.graph();
    let op = reduction_ops::ReduceSum {
        keep_dims,
        sparse_axes: false,
    };
    Tensor::builder(graph)
        .append_input(input.as_ref(), false)
        .append_input(axes.as_tensor(graph), false)
        .build(op)
}
/// Arithmetic mean of `x` over the axes given by `axes`.
///
/// When `keep_dims` is true the reduced axes are retained with size 1.
#[allow(dead_code)]
pub fn reduce_mean<'graph, A, AT, F: Float>(x: A, axes: &AT, keep_dims: bool) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    AT: AsTensor<'graph, F>,
{
    let input = x.as_ref();
    let graph = input.graph();
    let op = reduction_ops::ReduceMean {
        keep_dims,
        sparse_axes: false,
    };
    Tensor::builder(graph)
        .append_input(input.as_ref(), false)
        .append_input(axes.as_tensor(graph), false)
        .build(op)
}
/// Product of the elements of `x` over the axes given by `axes`.
///
/// When `keep_dims` is true the reduced axes are retained with size 1.
#[allow(dead_code)]
pub fn reduce_prod<'graph, A, AT, F: Float>(x: A, axes: &AT, keep_dims: bool) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    AT: AsTensor<'graph, F>,
{
    let input = x.as_ref();
    let graph = input.graph();
    let op = reduction_ops::ReduceProd {
        keep_dims,
        sparse_axes: false,
    };
    Tensor::builder(graph)
        .append_input(input.as_ref(), false)
        .append_input(axes.as_tensor(graph), false)
        .build(op)
}
/// Minimum of `x` over the axes given by `axes`.
///
/// When `keep_dims` is true the reduced axes are retained with size 1.
#[allow(dead_code)]
pub fn reduce_min<'graph, A, AT, F: Float>(x: A, axes: &AT, keep_dims: bool) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    AT: AsTensor<'graph, F>,
{
    let input = x.as_ref();
    let graph = input.graph();
    let op = reduction_ops::ReduceMin {
        keep_dims,
        sparse_axes: false,
    };
    Tensor::builder(graph)
        .append_input(input.as_ref(), false)
        .append_input(axes.as_tensor(graph), false)
        .build(op)
}
/// Maximum of `x` over the axes given by `axes`.
///
/// When `keep_dims` is true the reduced axes are retained with size 1.
#[allow(dead_code)]
pub fn reduce_max<'graph, A, AT, F: Float>(x: A, axes: &AT, keep_dims: bool) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    AT: AsTensor<'graph, F>,
{
    let input = x.as_ref();
    let graph = input.graph();
    let op = reduction_ops::ReduceMax {
        keep_dims,
        sparse_axes: false,
    };
    Tensor::builder(graph)
        .append_input(input.as_ref(), false)
        .append_input(axes.as_tensor(graph), false)
        .build(op)
}
/// Variance of `x` over the axes given by `axes` (delegated to the
/// `ReduceVariance` op).
///
/// When `keep_dims` is true the reduced axes are retained with size 1.
#[allow(dead_code)]
pub fn reduce_variance<'graph, A, AT, F: Float>(
    x: A,
    axes: &AT,
    keep_dims: bool,
) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    AT: AsTensor<'graph, F>,
{
    let input = x.as_ref();
    let graph = input.graph();
    let op = reduction_ops::ReduceVariance {
        keep_dims,
        sparse_axes: false,
    };
    Tensor::builder(graph)
        .append_input(input.as_ref(), false)
        .append_input(axes.as_tensor(graph), false)
        .build(op)
}
/// Standard deviation of `x` over `axes`: the square root of
/// [`reduce_variance`].
#[allow(dead_code)]
pub fn reduce_std<'graph, A, AT, F: Float>(x: A, axes: &AT, keep_dims: bool) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    AT: AsTensor<'graph, F>,
{
    crate::tensor_ops::arithmetic::sqrt(reduce_variance(x, axes, keep_dims))
}
/// Reduces `x` to a scalar by summing every element.
#[allow(dead_code)]
pub fn sum_all<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input.as_ref(), false)
        .build(reduction_ops::ReduceSumAll)
}
/// Reduces `x` to a scalar holding the mean of every element.
#[allow(dead_code)]
pub fn mean_all<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input.as_ref(), false)
        .build(reduction_ops::ReduceMeanAll)
}
/// Index of the maximum element of `x` along `axis`.
///
/// When `keep_dims` is true the reduced axis is retained with size 1.
#[allow(dead_code)]
pub fn argmax<'graph, A, F: Float>(x: A, axis: isize, keep_dims: bool) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input = x.as_ref();
    // The ArgMax op spells the flag `keep_dim` (singular).
    let op = reduction_ops::ArgMax {
        axis,
        keep_dim: keep_dims,
    };
    Tensor::builder(input.graph())
        .append_input(input.as_ref(), false)
        .build(op)
}
/// Index of the minimum element of `x` along `axis`.
///
/// When `keep_dims` is true the reduced axis is retained with size 1.
#[allow(dead_code)]
pub fn argmin<'graph, A, F: Float>(x: A, axis: isize, keep_dims: bool) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input = x.as_ref();
    // The ArgMin op spells the flag `keep_dim` (singular).
    let op = reduction_ops::ArgMin {
        axis,
        keep_dim: keep_dims,
    };
    Tensor::builder(input.graph())
        .append_input(input.as_ref(), false)
        .build(op)
}
/// Log-sum-exp reduction of `x` along `axis`, delegated to the
/// `LogSumExp` op.
///
/// When `keep_dims` is true the reduced axis is retained with size 1.
///
/// Note: the parameter was renamed from `keep_dim` to `keep_dims` for
/// consistency with the other reductions in this module; Rust has no
/// named arguments, so callers are unaffected.
#[allow(dead_code)]
pub fn reduce_logsumexp<'graph, A, F: Float>(x: A, axis: isize, keep_dims: bool) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let x = x.as_ref();
    let g = x.graph();
    let op = math_ops::LogSumExp { axis, keep_dims };
    Tensor::builder(g).append_input(x.as_ref(), false).build(op)
}
/// Element-wise sum of a non-empty slice of tensors.
///
/// A single-element slice is returned as-is; otherwise one `AddN` node
/// taking every tensor as input is built, with the output shape taken
/// from the first input.
///
/// # Panics
/// Panics when `xs` is empty.
#[allow(dead_code)]
pub fn add_n<'graph, A, F: Float>(xs: &[A]) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    assert_ne!(xs.len(), 0);
    match xs {
        [only] => *only.as_ref(),
        _ => {
            let g = xs[0].as_ref().graph();
            let builder = xs
                .iter()
                .fold(Tensor::builder(g), |b, x| b.append_input(x.as_ref(), false));
            builder.setshape(&shape(xs[0])).build(array_ops::AddN)
        }
    }
}
/// L1 norm of `x` over `axes`: the sum of absolute values.
#[allow(dead_code)]
pub fn l1_norm<'graph, A, AT, F: Float>(x: A, axes: &AT, keep_dims: bool) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    AT: AsTensor<'graph, F>,
{
    reduce_sum(crate::tensor_ops::arithmetic::abs(x), axes, keep_dims)
}
/// L2 norm of `x` over `axes`: the square root of the sum of squares.
#[allow(dead_code)]
pub fn l2_norm<'graph, A, AT, F: Float>(x: A, axes: &AT, keep_dims: bool) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    AT: AsTensor<'graph, F>,
{
    let squared = crate::tensor_ops::arithmetic::square(x);
    crate::tensor_ops::arithmetic::sqrt(reduce_sum(squared, axes, keep_dims))
}
/// Lp norm of `x` over `axes`: `(sum(|x|^p))^(1/p)`.
#[allow(dead_code)]
pub fn lp_norm<'graph, A, AT, F: Float>(x: A, p: F, axes: &AT, keep_dims: bool) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    AT: AsTensor<'graph, F>,
{
    let p_inv = F::one() / p;
    let magnitudes = crate::tensor_ops::arithmetic::abs(x);
    let powered = crate::tensor_ops::arithmetic::pow(magnitudes, p);
    crate::tensor_ops::arithmetic::pow(reduce_sum(powered, axes, keep_dims), p_inv)
}
/// Frobenius norm of `x`: the square root of the sum of all squared
/// elements, reduced to a scalar.
#[allow(dead_code)]
pub fn frobenius_norm<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input = x.as_ref();
    let total = sum_all(crate::tensor_ops::arithmetic::square(input));
    crate::tensor_ops::arithmetic::sqrt(total)
}
/// Logical-AND style reduction of `x` over `axes` (delegated to the
/// `ReduceAll` op).
///
/// When `keep_dims` is true the reduced axes are retained with size 1.
#[allow(dead_code)]
pub fn reduce_all<'graph, A, AT, F: Float>(x: A, axes: &AT, keep_dims: bool) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    AT: AsTensor<'graph, F>,
{
    let input = x.as_ref();
    let graph = input.graph();
    Tensor::builder(graph)
        .append_input(input.as_ref(), false)
        .append_input(axes.as_tensor(graph), false)
        .build(reduction_ops::ReduceAll { keep_dims })
}
/// Logical-OR style reduction of `x` over `axes` (delegated to the
/// `ReduceAny` op).
///
/// When `keep_dims` is true the reduced axes are retained with size 1.
#[allow(dead_code)]
pub fn reduce_any<'graph, A, AT, F: Float>(x: A, axes: &AT, keep_dims: bool) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    AT: AsTensor<'graph, F>,
{
    let input = x.as_ref();
    let graph = input.graph();
    Tensor::builder(graph)
        .append_input(input.as_ref(), false)
        .append_input(axes.as_tensor(graph), false)
        .build(reduction_ops::ReduceAny { keep_dims })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor_ops::convert_to_tensor;
    #[allow(unused_imports)]
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_basic_reductions() {
        crate::run(|g| {
            let input = convert_to_tensor(array![[1.0_f32, 2.0], [3.0, 4.0]], g);
            // Column-wise sum, mean, and product (axis 0, dims dropped).
            assert_eq!(
                reduce_sum(input, &[0], false)
                    .eval(g)
                    .expect("Operation failed"),
                array![4.0_f32, 6.0].into_dyn()
            );
            assert_eq!(
                reduce_mean(input, &[0], false)
                    .eval(g)
                    .expect("Operation failed"),
                array![2.0_f32, 3.0].into_dyn()
            );
            assert_eq!(
                reduce_prod(input, &[0], false)
                    .eval(g)
                    .expect("Operation failed"),
                array![3.0_f32, 8.0].into_dyn()
            );
        });
    }

    #[test]
    fn test_min_max_reductions() {
        crate::run(|g| {
            let input = convert_to_tensor(array![[1.0_f32, 4.0], [2.0, 3.0]], g);
            assert_eq!(
                reduce_min(input, &[0], false)
                    .eval(g)
                    .expect("Operation failed"),
                array![1.0_f32, 3.0].into_dyn()
            );
            assert_eq!(
                reduce_max(input, &[0], false)
                    .eval(g)
                    .expect("Operation failed"),
                array![2.0_f32, 4.0].into_dyn()
            );
        });
    }

    #[test]
    fn test_statistical_reductions() {
        crate::run(|g| {
            let input = convert_to_tensor(array![[1.0_f32, 2.0], [3.0, 4.0]], g);
            // Per-column variance and standard deviation are both exactly 1.
            assert_eq!(
                reduce_variance(input, &[0], false)
                    .eval(g)
                    .expect("Operation failed"),
                array![1.0_f32, 1.0].into_dyn()
            );
            assert_eq!(
                reduce_std(input, &[0], false)
                    .eval(g)
                    .expect("Operation failed"),
                array![1.0_f32, 1.0].into_dyn()
            );
        });
    }

    #[test]
    fn test_global_reductions() {
        crate::run(|g| {
            let input = convert_to_tensor(array![[1.0_f32, 2.0], [3.0, 4.0]], g);
            // Full reductions collapse the matrix to 0-d arrays.
            assert_eq!(
                sum_all(input).eval(g).expect("Operation failed"),
                scirs2_core::ndarray::arr0(10.0).into_dyn()
            );
            assert_eq!(
                mean_all(input).eval(g).expect("Operation failed"),
                scirs2_core::ndarray::arr0(2.5).into_dyn()
            );
        });
    }

    #[test]
    fn test_norm_operations() {
        crate::run(|g| {
            let input = convert_to_tensor(array![[3.0_f32, 4.0], [0.0, 0.0]], g);
            // 3-4-5 triangle: the whole-matrix Frobenius norm is 5.
            assert_eq!(
                frobenius_norm(input).eval(g).expect("Operation failed"),
                scirs2_core::ndarray::arr0(5.0).into_dyn()
            );
            assert_eq!(
                l1_norm(input, &[0], false)
                    .eval(g)
                    .expect("Operation failed"),
                array![3.0_f32, 4.0].into_dyn()
            );
            assert_eq!(
                l2_norm(input, &[0], false)
                    .eval(g)
                    .expect("Operation failed"),
                array![3.0_f32, 4.0].into_dyn()
            );
        });
    }

    #[test]
    fn test_add_n() {
        crate::run(|g| {
            let terms = [
                convert_to_tensor(array![1.0_f32, 2.0], g),
                convert_to_tensor(array![3.0_f32, 4.0], g),
                convert_to_tensor(array![5.0_f32, 6.0], g),
            ];
            assert_eq!(
                add_n(&terms).eval(g).expect("Operation failed"),
                array![9.0_f32, 12.0].into_dyn()
            );
        });
    }

    #[test]
    fn test_keep_dims() {
        crate::run(|g| {
            let input = convert_to_tensor(array![[1.0_f32, 2.0], [3.0, 4.0]], g);
            // keep_dims = true retains the reduced axis with length 1...
            let kept = reduce_sum(input, &[0], true);
            assert_eq!(kept.eval(g).expect("Operation failed").shape(), &[1, 2]);
            // ...while keep_dims = false drops it entirely.
            let dropped = reduce_sum(input, &[0], false);
            assert_eq!(dropped.eval(g).expect("Operation failed").shape(), &[2]);
        });
    }
}