use super::core::{get_binary_op_registry, BinaryOpAnalytics};
use super::implementation::binary_op;
use super::operations::{AddOp, DivOp, MaxOp, MinOp, MulOp, PowOp, SubOp};
use crate::{Result, Tensor};
use scirs2_core::numeric::Zero;
use std::ops::{Add as StdAdd, Div as StdDiv, Mul as StdMul, Sub as StdSub};
/// Adds two tensors by handing them to [`binary_op`] together with [`AddOp`].
///
/// How shapes are reconciled (e.g. broadcasting) is decided entirely by
/// `binary_op`; presumably the operation is element-wise — confirm against
/// the `binary_op` implementation.
///
/// # Errors
///
/// Returns whatever error `binary_op` reports for the given operands.
#[inline]
pub fn add<T>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + Zero
        + StdAdd<Output = T>
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    let op = AddOp;
    binary_op(a, b, op)
}
/// Subtracts `b` from `a` by handing both tensors to [`binary_op`] together
/// with [`SubOp`].
///
/// Shape handling is decided entirely by `binary_op`; presumably the
/// operation is element-wise — confirm against the `binary_op` implementation.
///
/// # Errors
///
/// Returns whatever error `binary_op` reports for the given operands.
#[inline]
pub fn sub<T>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + Zero
        + StdSub<Output = T>
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    let op = SubOp;
    binary_op(a, b, op)
}
/// Multiplies two tensors by handing them to [`binary_op`] together with
/// [`MulOp`].
///
/// Shape handling is decided entirely by `binary_op`; presumably the
/// operation is element-wise — confirm against the `binary_op` implementation.
///
/// # Errors
///
/// Returns whatever error `binary_op` reports for the given operands.
#[inline]
pub fn mul<T>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + Zero
        + StdMul<Output = T>
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    let op = MulOp;
    binary_op(a, b, op)
}
/// Divides `a` by `b` by handing both tensors to [`binary_op`] together with
/// [`DivOp`].
///
/// Shape handling and the treatment of zero divisors are decided by
/// `binary_op` / `DivOp`; nothing is checked in this wrapper.
///
/// # Errors
///
/// Returns whatever error `binary_op` reports for the given operands.
#[inline]
pub fn div<T>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + Zero
        + StdDiv<Output = T>
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    let op = DivOp;
    binary_op(a, b, op)
}
/// Raises `a` to the power `b` by handing both tensors to [`binary_op`]
/// together with [`PowOp`].
///
/// Only floating-point element types are accepted (the `Float` bound).
/// Shape handling is decided entirely by `binary_op`.
///
/// # Errors
///
/// Returns whatever error `binary_op` reports for the given operands.
#[inline]
pub fn pow<T>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + Zero
        + scirs2_core::num_traits::Float
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    let op = PowOp;
    binary_op(a, b, op)
}
/// Computes the minimum of two tensors by handing them to [`binary_op`]
/// together with [`MinOp`].
///
/// Element types only need `PartialOrd`; how incomparable values (such as
/// NaN) are treated is up to `MinOp` and is not decided here.
///
/// # Errors
///
/// Returns whatever error `binary_op` reports for the given operands.
#[inline]
pub fn min<T>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + Zero
        + PartialOrd
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    let op = MinOp;
    binary_op(a, b, op)
}
/// Computes the maximum of two tensors by handing them to [`binary_op`]
/// together with [`MaxOp`].
///
/// Element types only need `PartialOrd`; how incomparable values (such as
/// NaN) are treated is up to `MaxOp` and is not decided here.
///
/// # Errors
///
/// Returns whatever error `binary_op` reports for the given operands.
#[inline]
pub fn max<T>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + Zero
        + PartialOrd
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    let op = MaxOp;
    binary_op(a, b, op)
}
/// Adds a single scalar value to a tensor.
///
/// The scalar is wrapped with `Tensor::from_scalar` and the sum is then
/// computed by [`add`], so the result follows whatever rules `add` applies
/// when combining `tensor` with a scalar tensor.
///
/// # Errors
///
/// Returns any error produced by [`add`].
pub fn scalar_add<T>(tensor: &Tensor<T>, scalar: T) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + Zero
        + StdAdd<Output = T>
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    add(tensor, &Tensor::from_scalar(scalar))
}
/// Limits every value of `tensor` to the range bounded by `min_val` and
/// `max_val`, returning a new tensor.
///
/// For each value `v`: if `v < min_val` the result is `min_val`; otherwise if
/// `v > max_val` the result is `max_val`; otherwise `v` is kept. The lower
/// bound is tested first, and no check is made that `min_val <= max_val`.
/// Values for which both comparisons are false (e.g. NaN) pass through
/// unchanged.
///
/// With the `gpu` feature enabled, a GPU-backed tensor is first copied with
/// `to_cpu` and then clamped through this same function.
/// NOTE(review): this assumes `to_cpu` yields CPU storage, and the result is
/// then built by the CPU path rather than moved back to the GPU — confirm
/// that this is intended.
///
/// # Errors
///
/// Propagates the error from `to_cpu` on the GPU path.
pub fn clamp<T>(tensor: &Tensor<T>, min_val: T, max_val: T) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + Zero
        + PartialOrd
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    match &tensor.storage {
        crate::tensor::TensorStorage::Cpu(values) => {
            // Per-value rule: lower bound first, then upper bound.
            let limit = |value: T| -> T {
                if value < min_val {
                    return min_val;
                }
                if value > max_val {
                    return max_val;
                }
                value
            };
            Ok(Tensor::from_array(values.mapv(limit)))
        }
        #[cfg(feature = "gpu")]
        crate::tensor::TensorStorage::Gpu(_) => {
            // No GPU kernel here: copy to the host and reuse the CPU branch.
            let host_copy = tensor.to_cpu()?;
            clamp(&host_copy, min_val, max_val)
        }
    }
}
pub fn ultra_add<T>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>
where
T: Clone
+ Default
+ Zero
+ Send
+ Sync
+ StdAdd<Output = T>
+ bytemuck::Pod
+ bytemuck::Zeroable,
{
let registry = get_binary_op_registry();
let start = std::time::Instant::now();
let result = binary_op(a, b, AddOp);
let duration = start.elapsed();
registry.record_operation("ultra_add", a.shape().size(), duration.as_nanos() as u64);
result
}
pub fn ultra_mul<T>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>
where
T: Clone
+ Default
+ Zero
+ Send
+ Sync
+ StdMul<Output = T>
+ bytemuck::Pod
+ bytemuck::Zeroable,
{
let registry = get_binary_op_registry();
let start = std::time::Instant::now();
let result = binary_op(a, b, MulOp);
let duration = start.elapsed();
registry.record_operation("ultra_mul", a.shape().size(), duration.as_nanos() as u64);
result
}
pub fn ultra_sub<T>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>
where
T: Clone
+ Default
+ Zero
+ Send
+ Sync
+ StdSub<Output = T>
+ bytemuck::Pod
+ bytemuck::Zeroable,
{
let registry = get_binary_op_registry();
let start = std::time::Instant::now();
let result = binary_op(a, b, SubOp);
let duration = start.elapsed();
registry.record_operation("ultra_sub", a.shape().size(), duration.as_nanos() as u64);
result
}
pub fn ultra_div<T>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>
where
T: Clone
+ Default
+ Zero
+ Send
+ Sync
+ StdDiv<Output = T>
+ bytemuck::Pod
+ bytemuck::Zeroable,
{
let registry = get_binary_op_registry();
let start = std::time::Instant::now();
let result = binary_op(a, b, DivOp);
let duration = start.elapsed();
registry.record_operation("ultra_div", a.shape().size(), duration.as_nanos() as u64);
result
}
/// Returns the analytics currently held by the binary-op registry.
///
/// This is a thin accessor: it obtains the registry through
/// `get_binary_op_registry` and returns the result of its `get_analytics`
/// call unchanged.
pub fn get_binary_op_performance_report() -> BinaryOpAnalytics {
    let registry = get_binary_op_registry();
    registry.get_analytics()
}
/// Placeholder for resetting the binary-operation counters.
///
/// NOTE(review): the body is empty, so calling this currently has no effect;
/// data recorded through the registry is left untouched. Presumably this is
/// meant to clear the registry's statistics — confirm and implement, or
/// remove the function.
pub fn reset_binary_op_counters() {
    // No-op: nothing is reset here yet.
}