use crate::tensor::Tensor;
use crate::{Float, Graph};
use crate::tensor_ops::{binary_ops, math_ops, shape};
/// Builds a graph node that combines `a` and `b` with `binary_ops::AddOp`.
///
/// The node is created in the graph that owns `a`. Its output shape is the
/// tensor produced by `infer_bin_opshape` from the shapes of `a` and `b`.
#[inline]
#[allow(dead_code)]
pub fn add<'graph, A, B, F: Float>(a: A, b: B) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    B: AsRef<Tensor<'graph, F>> + Copy,
{
    let lhs: &Tensor<'graph, F> = a.as_ref();
    let rhs: &Tensor<'graph, F> = b.as_ref();
    let graph = lhs.graph();
    Tensor::builder(graph)
        .setshape(&infer_bin_opshape(graph, shape(lhs), shape(rhs)))
        .append_input(lhs, false)
        .append_input(rhs, false)
        .build(binary_ops::AddOp)
}
/// Builds a graph node that combines `a` and `b` with `binary_ops::SubOp`
/// (`a` is the first input, `b` the second).
///
/// The node is created in the graph that owns `a`. Its output shape is the
/// tensor produced by `infer_bin_opshape` from the shapes of `a` and `b`.
#[inline]
#[allow(dead_code)]
pub fn sub<'graph, A, B, F: Float>(a: A, b: B) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    B: AsRef<Tensor<'graph, F>> + Copy,
{
    let lhs: &Tensor<'graph, F> = a.as_ref();
    let rhs: &Tensor<'graph, F> = b.as_ref();
    let graph = lhs.graph();
    Tensor::builder(graph)
        .setshape(&infer_bin_opshape(graph, shape(lhs), shape(rhs)))
        .append_input(lhs, false)
        .append_input(rhs, false)
        .build(binary_ops::SubOp)
}
/// Builds a graph node that combines `a` and `b` with `binary_ops::MulOp`.
///
/// The node is created in the graph that owns `a`. Its output shape is the
/// tensor produced by `infer_bin_opshape` from the shapes of `a` and `b`.
#[inline]
#[allow(dead_code)]
pub fn mul<'graph, A, B, F: Float>(a: A, b: B) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    B: AsRef<Tensor<'graph, F>> + Copy,
{
    let lhs: &Tensor<'graph, F> = a.as_ref();
    let rhs: &Tensor<'graph, F> = b.as_ref();
    let graph = lhs.graph();
    Tensor::builder(graph)
        .setshape(&infer_bin_opshape(graph, shape(lhs), shape(rhs)))
        .append_input(lhs, false)
        .append_input(rhs, false)
        .build(binary_ops::MulOp)
}
/// Builds a graph node that combines `a` and `b` with `binary_ops::DivOp`
/// (`a` is the first input, `b` the second).
///
/// The node is created in the graph that owns `a`. Its output shape is the
/// tensor produced by `infer_bin_opshape` from the shapes of `a` and `b`.
#[inline]
#[allow(dead_code)]
pub fn div<'graph, A, B, F: Float>(a: A, b: B) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    B: AsRef<Tensor<'graph, F>> + Copy,
{
    let lhs: &Tensor<'graph, F> = a.as_ref();
    let rhs: &Tensor<'graph, F> = b.as_ref();
    let graph = lhs.graph();
    Tensor::builder(graph)
        .setshape(&infer_bin_opshape(graph, shape(lhs), shape(rhs)))
        .append_input(lhs, false)
        .append_input(rhs, false)
        .build(binary_ops::DivOp)
}
/// Builds a graph node that applies `math_ops::Sqrt` to `x`.
///
/// The node takes `x` as its only input, lives in the graph that owns `x`,
/// and declares its output shape to be the shape of `x`.
#[inline]
#[allow(dead_code)]
pub fn sqrt<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Sqrt)
}
/// Builds a graph node that applies `math_ops::Pow { a }` to `x`.
///
/// `a` is stored in the op itself (it is not a graph input). The node takes
/// `x` as its only input and declares its output shape to be the shape of `x`.
#[allow(dead_code)]
pub fn pow<'graph, A, F: Float>(x: A, a: F) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Pow { a })
}
/// Squares `x` by delegating to [`mul`] with `x` as both operands.
#[allow(dead_code)]
pub fn square<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let operand: &Tensor<'graph, F> = x.as_ref();
    mul(operand, operand)
}
/// Builds a graph node that applies `math_ops::Abs` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn abs<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Abs)
}
/// Builds a graph node that applies `math_ops::NegOp` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn neg<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::NegOp)
}
/// Builds a graph node that applies `math_ops::Ln` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn ln<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Ln)
}
/// Builds a graph node that applies `math_ops::Log2` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn log2<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Log2)
}
/// Builds a graph node that applies `math_ops::Log10` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn log10<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Log10)
}
/// Builds a graph node that applies `math_ops::Exp` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn exp<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Exp)
}
/// Builds a graph node that applies `math_ops::Exp2` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn exp2<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Exp2)
}
/// Builds a graph node that applies `math_ops::Exp10` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn exp10<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Exp10)
}
/// Builds a graph node that applies `math_ops::Sin` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn sin<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Sin)
}
/// Builds a graph node that applies `math_ops::Cos` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn cos<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Cos)
}
/// Builds a graph node that applies `math_ops::Tan` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn tan<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Tan)
}
/// Builds a graph node that applies `math_ops::Asin` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn asin<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Asin)
}
/// Builds a graph node that applies `math_ops::Acos` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn acos<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Acos)
}
/// Builds a graph node that applies `math_ops::Atan` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn atan<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Atan)
}
/// Builds a graph node that applies `math_ops::Sinh` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn sinh<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Sinh)
}
/// Builds a graph node that applies `math_ops::Cosh` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn cosh<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Cosh)
}
/// Builds a graph node that applies `math_ops::Tanh` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn tanh<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Tanh)
}
/// Builds a graph node that applies `math_ops::Asinh` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn asinh<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Asinh)
}
/// Builds a graph node that applies `math_ops::Acosh` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn acosh<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Acosh)
}
/// Builds a graph node that applies `math_ops::Atanh` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn atanh<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Atanh)
}
/// Builds a graph node that applies `math_ops::Lgamma` to an `f32` tensor.
///
/// The node takes `x` as its only input and lives in the graph that owns `x`.
///
/// NOTE(review): unlike the other unary ops in this file, no output shape is
/// declared on the builder here — confirm this is intentional.
#[allow(dead_code)]
pub fn lgamma_f32<'graph, A>(x: A) -> Tensor<'graph, f32>
where
    A: AsRef<Tensor<'graph, f32>> + Copy,
{
    let input: &Tensor<'graph, f32> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .build(math_ops::Lgamma)
}
/// Builds a graph node that applies `math_ops::Lgamma` to an `f64` tensor.
///
/// The node takes `x` as its only input and lives in the graph that owns `x`.
///
/// NOTE(review): unlike the other unary ops in this file, no output shape is
/// declared on the builder here — confirm this is intentional.
#[allow(dead_code)]
pub fn lgamma_f64<'graph, A>(x: A) -> Tensor<'graph, f64>
where
    A: AsRef<Tensor<'graph, f64>> + Copy,
{
    let input: &Tensor<'graph, f64> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .build(math_ops::Lgamma)
}
/// Builds a graph node that applies `math_ops::Digamma` to an `f32` tensor.
///
/// The node takes `x` as its only input and lives in the graph that owns `x`.
///
/// NOTE(review): unlike the other unary ops in this file, no output shape is
/// declared on the builder here — confirm this is intentional.
#[allow(dead_code)]
pub fn digamma_f32<'graph, A>(x: A) -> Tensor<'graph, f32>
where
    A: AsRef<Tensor<'graph, f32>> + Copy,
{
    let input: &Tensor<'graph, f32> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .build(math_ops::Digamma)
}
/// Builds a graph node that applies `math_ops::Digamma` to an `f64` tensor.
///
/// The node takes `x` as its only input and lives in the graph that owns `x`.
///
/// NOTE(review): unlike the other unary ops in this file, no output shape is
/// declared on the builder here — confirm this is intentional.
#[allow(dead_code)]
pub fn digamma_f64<'graph, A>(x: A) -> Tensor<'graph, f64>
where
    A: AsRef<Tensor<'graph, f64>> + Copy,
{
    let input: &Tensor<'graph, f64> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .build(math_ops::Digamma)
}
/// Builds a graph node that combines `a` and `b` with `math_ops::Maximum`.
///
/// The node is created in the graph that owns `a`, with `a` as the first
/// input and `b` as the second. No output shape is declared on the builder.
#[allow(dead_code)]
pub fn maximum<'graph, A, B, F: Float>(a: A, b: B) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    B: AsRef<Tensor<'graph, F>> + Copy,
{
    let lhs: &Tensor<'graph, F> = a.as_ref();
    let rhs: &Tensor<'graph, F> = b.as_ref();
    Tensor::builder(lhs.graph())
        .append_input(lhs, false)
        .append_input(rhs, false)
        .build(math_ops::Maximum)
}
/// Builds a graph node that combines `a` and `b` with `math_ops::Minimum`.
///
/// The node is created in the graph that owns `a`, with `a` as the first
/// input and `b` as the second. No output shape is declared on the builder.
#[allow(dead_code)]
pub fn minimum<'graph, A, B, F: Float>(a: A, b: B) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    B: AsRef<Tensor<'graph, F>> + Copy,
{
    let lhs: &Tensor<'graph, F> = a.as_ref();
    let rhs: &Tensor<'graph, F> = b.as_ref();
    Tensor::builder(lhs.graph())
        .append_input(lhs, false)
        .append_input(rhs, false)
        .build(math_ops::Minimum)
}
/// Builds a graph node that compares `a` and `b` with `math_ops::Equal`.
///
/// The result is a tensor of the same float type `F` as the inputs. The node
/// is created in the graph that owns `a`; no output shape is declared on the
/// builder.
#[allow(dead_code)]
pub fn equal<'graph, A, B, F: Float>(a: A, b: B) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    B: AsRef<Tensor<'graph, F>> + Copy,
{
    let lhs: &Tensor<'graph, F> = a.as_ref();
    let rhs: &Tensor<'graph, F> = b.as_ref();
    Tensor::builder(lhs.graph())
        .append_input(lhs, false)
        .append_input(rhs, false)
        .build(math_ops::Equal)
}
/// Builds a graph node that compares `a` and `b` with `math_ops::NotEqual`.
///
/// The result is a tensor of the same float type `F` as the inputs. The node
/// is created in the graph that owns `a`; no output shape is declared on the
/// builder.
#[allow(dead_code)]
pub fn not_equal<'graph, A, B, F: Float>(a: A, b: B) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    B: AsRef<Tensor<'graph, F>> + Copy,
{
    let lhs: &Tensor<'graph, F> = a.as_ref();
    let rhs: &Tensor<'graph, F> = b.as_ref();
    Tensor::builder(lhs.graph())
        .append_input(lhs, false)
        .append_input(rhs, false)
        .build(math_ops::NotEqual)
}
/// Builds a graph node that compares `a` and `b` with `math_ops::Greater`
/// (`a` is the first input, `b` the second).
///
/// The result is a tensor of the same float type `F` as the inputs. The node
/// is created in the graph that owns `a`; no output shape is declared on the
/// builder.
#[allow(dead_code)]
pub fn greater<'graph, A, B, F: Float>(a: A, b: B) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    B: AsRef<Tensor<'graph, F>> + Copy,
{
    let lhs: &Tensor<'graph, F> = a.as_ref();
    let rhs: &Tensor<'graph, F> = b.as_ref();
    Tensor::builder(lhs.graph())
        .append_input(lhs, false)
        .append_input(rhs, false)
        .build(math_ops::Greater)
}
/// Builds a graph node that compares `a` and `b` with
/// `math_ops::GreaterEqual` (`a` is the first input, `b` the second).
///
/// The result is a tensor of the same float type `F` as the inputs. The node
/// is created in the graph that owns `a`; no output shape is declared on the
/// builder.
#[allow(dead_code)]
pub fn greater_equal<'graph, A, B, F: Float>(a: A, b: B) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    B: AsRef<Tensor<'graph, F>> + Copy,
{
    let lhs: &Tensor<'graph, F> = a.as_ref();
    let rhs: &Tensor<'graph, F> = b.as_ref();
    Tensor::builder(lhs.graph())
        .append_input(lhs, false)
        .append_input(rhs, false)
        .build(math_ops::GreaterEqual)
}
/// Builds a graph node that compares `a` and `b` with `math_ops::Lesser`
/// (`a` is the first input, `b` the second).
///
/// The result is a tensor of the same float type `F` as the inputs. The node
/// is created in the graph that owns `a`; no output shape is declared on the
/// builder.
#[allow(dead_code)]
pub fn lesser<'graph, A, B, F: Float>(a: A, b: B) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    B: AsRef<Tensor<'graph, F>> + Copy,
{
    let lhs: &Tensor<'graph, F> = a.as_ref();
    let rhs: &Tensor<'graph, F> = b.as_ref();
    Tensor::builder(lhs.graph())
        .append_input(lhs, false)
        .append_input(rhs, false)
        .build(math_ops::Lesser)
}
/// Builds a graph node that compares `a` and `b` with
/// `math_ops::LesserEqual` (`a` is the first input, `b` the second).
///
/// The result is a tensor of the same float type `F` as the inputs. The node
/// is created in the graph that owns `a`; no output shape is declared on the
/// builder.
#[allow(dead_code)]
pub fn lesser_equal<'graph, A, B, F: Float>(a: A, b: B) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
    B: AsRef<Tensor<'graph, F>> + Copy,
{
    let lhs: &Tensor<'graph, F> = a.as_ref();
    let rhs: &Tensor<'graph, F> = b.as_ref();
    Tensor::builder(lhs.graph())
        .append_input(lhs, false)
        .append_input(rhs, false)
        .build(math_ops::LesserEqual)
}
/// Builds a graph node that applies `math_ops::Floor` to `a`.
///
/// The node declares its output shape to be the shape of `a` and takes `a`
/// as its only input.
#[allow(dead_code)]
pub fn floor<'graph, A, F: Float>(a: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = a.as_ref();
    Tensor::builder(input.graph())
        .setshape(&shape(input))
        .append_input(input, false)
        .build(math_ops::Floor)
}
/// Builds a graph node that applies `math_ops::Ceil` to `a`.
///
/// The node declares its output shape to be the shape of `a` and takes `a`
/// as its only input.
#[allow(dead_code)]
pub fn ceil<'graph, A, F: Float>(a: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = a.as_ref();
    Tensor::builder(input.graph())
        .setshape(&shape(input))
        .append_input(input, false)
        .build(math_ops::Ceil)
}
/// Computes `1 / sqrt(x)` as a composition of existing ops.
///
/// A scalar tensor holding `F::one()` is created in the graph that owns `x`
/// and divided (via [`div`]) by the result of [`sqrt`] applied to `x`.
#[allow(dead_code)]
pub fn inv_sqrt<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    // The scalar is created before the sqrt node, matching the order in
    // which the nodes are added to the graph.
    let numerator = crate::tensor_ops::scalar(F::one(), x.as_ref().graph());
    let denominator = sqrt(x);
    div(numerator, denominator)
}
/// Clips `x` to the range given by `min_value` and `maxvalue`.
///
/// Implemented as `maximum(minimum(x, maxvalue), min_value)`, where both
/// bounds are turned into scalar tensors in the graph that owns `x`. The
/// upper bound is applied first, then the lower bound.
#[allow(dead_code)]
pub fn clip<'graph, A, F: Float>(x: A, min_value: F, maxvalue: F) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    let graph = input.graph();
    let lower_bound = crate::tensor_ops::scalar(min_value, graph);
    let upper_bound = crate::tensor_ops::scalar(maxvalue, graph);
    maximum(minimum(input, upper_bound), lower_bound)
}
/// Computes `1 / x` by dividing (via [`div`]) a scalar tensor holding
/// `F::one()`, created in the graph that owns `x`, by `x`.
#[allow(dead_code)]
pub fn inv<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let graph = x.as_ref().graph();
    let numerator = crate::tensor_ops::scalar(F::one(), graph);
    div(numerator, x)
}
/// Builds a graph node that applies `math_ops::Sign` to `x`.
///
/// The node takes `x` as its only input and declares its output shape to be
/// the shape of `x`.
#[allow(dead_code)]
pub fn sign<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let input: &Tensor<'graph, F> = x.as_ref();
    Tensor::builder(input.graph())
        .append_input(input, false)
        .setshape(&shape(input))
        .build(math_ops::Sign)
}
#[inline]
#[allow(dead_code)]
fn infer_bin_opshape<'graph, A, B, F: Float>(
g: &'graph Graph<F>,
shape_a: A,
shape_b: B,
) -> Tensor<'graph, F>
where
A: AsRef<Tensor<'graph, F>> + Copy,
B: AsRef<Tensor<'graph, F>> + Copy,
{
use crate::tensor_ops::array_ops;
Tensor::builder(g)
.append_input(shape_a.as_ref(), false)
.append_input(shape_b.as_ref(), false)
.build(array_ops::InferBinOpShape)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor_ops::convert_to_tensor;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// `add`, `sub`, `mul` and `div` on two 3-element `f32` vectors.
    #[test]
    fn test_arithmetic_operations() {
        crate::run(|g| {
            let lhs = convert_to_tensor(array![1.0_f32, 2.0, 3.0], g);
            let rhs = convert_to_tensor(array![4.0_f32, 5.0, 6.0], g);

            let total = add(lhs, rhs).eval(g).expect("Operation failed");
            assert_eq!(total, array![5.0_f32, 7.0, 9.0].into_dyn());

            let difference = sub(lhs, rhs).eval(g).expect("Operation failed");
            assert_eq!(difference, array![-3.0_f32, -3.0, -3.0].into_dyn());

            let product = mul(lhs, rhs).eval(g).expect("Operation failed");
            assert_eq!(product, array![4.0_f32, 10.0, 18.0].into_dyn());

            // Note the swapped operands: this checks rhs / lhs.
            let quotient = div(rhs, lhs).eval(g).expect("Operation failed");
            assert_eq!(quotient, array![4.0_f32, 2.5, 2.0].into_dyn());
        });
    }

    /// `sqrt` (approximate comparison) and `square` (exact comparison).
    #[test]
    fn test_mathematical_functions() {
        crate::run(|g| {
            let x = convert_to_tensor(array![1.0_f32, 4.0, 9.0], g);

            let roots = sqrt(x).eval(g).expect("Operation failed");
            let expected_roots = array![1.0_f32, 2.0, 3.0];
            roots
                .iter()
                .zip(expected_roots.iter())
                .for_each(|(got, want)| {
                    assert_relative_eq!(got, want, epsilon = 1e-6);
                });

            let squares = square(x).eval(g).expect("Operation failed");
            assert_eq!(squares, array![1.0_f32, 16.0, 81.0].into_dyn());
        });
    }

    /// `sin` and `cos` evaluated at 0, pi/2 and pi.
    #[test]
    fn test_trigonometric_functions() {
        crate::run(|g| {
            let angles = convert_to_tensor(
                array![0.0_f32, std::f32::consts::PI / 2.0, std::f32::consts::PI],
                g,
            );

            let sines = sin(angles).eval(g).expect("Operation failed");
            for (idx, want) in [0.0_f32, 1.0, 0.0].iter().enumerate() {
                assert_relative_eq!(sines[idx], *want, epsilon = 1e-6);
            }

            let cosines = cos(angles).eval(g).expect("Operation failed");
            for (idx, want) in [1.0_f32, 0.0, -1.0].iter().enumerate() {
                assert_relative_eq!(cosines[idx], *want, epsilon = 1e-6);
            }
        });
    }

    /// `equal`, `greater` and `lesser` yield 1.0 where the relation holds
    /// and 0.0 elsewhere.
    #[test]
    fn test_comparison_operations() {
        crate::run(|g| {
            let lhs = convert_to_tensor(array![1.0_f32, 2.0, 3.0], g);
            let rhs = convert_to_tensor(array![3.0_f32, 2.0, 1.0], g);

            let eq_mask = equal(lhs, rhs).eval(g).expect("Operation failed");
            assert_eq!(eq_mask, array![0.0_f32, 1.0, 0.0].into_dyn());

            let gt_mask = greater(lhs, rhs).eval(g).expect("Operation failed");
            assert_eq!(gt_mask, array![0.0_f32, 0.0, 1.0].into_dyn());

            let lt_mask = lesser(lhs, rhs).eval(g).expect("Operation failed");
            assert_eq!(lt_mask, array![1.0_f32, 0.0, 0.0].into_dyn());
        });
    }

    /// `maximum` and `minimum` on two 3-element `f32` vectors.
    #[test]
    fn test_max_min_operations() {
        crate::run(|g| {
            let lhs = convert_to_tensor(array![1.0_f32, 2.0, 3.0], g);
            let rhs = convert_to_tensor(array![3.0_f32, 2.0, 1.0], g);

            let larger = maximum(lhs, rhs).eval(g).expect("Operation failed");
            assert_eq!(larger, array![3.0_f32, 2.0, 3.0].into_dyn());

            let smaller = minimum(lhs, rhs).eval(g).expect("Operation failed");
            assert_eq!(smaller, array![1.0_f32, 2.0, 1.0].into_dyn());
        });
    }
}