use std::fmt::Debug;
use rten_base::num::{Identities, IsInt, IsNaN, MinMax};
use rten_tensor::prelude::*;
use rten_tensor::{NdTensorView, Storage, Tensor, TensorBase, TensorView};
use crate::buffer_pool::BufferPool;
use crate::operator::OpError;
use crate::ops::{
PadMode, arg_max, div, matmul, mul, norm::NanHandling, pad, reduce_l2, reduce_max, reduce_mean,
reduce_min, reduce_sum, resize_image, softmax, topk,
};
use crate::threading::thread_pool;
#[cfg(feature = "fft")]
use rten_tensor::NdTensor;
#[cfg(feature = "fft")]
use crate::ops::stft;
/// Tensor operators that can be called as methods on a tensor value.
///
/// Every method states its own element-type bounds in a `where` clause, so a
/// tensor only gains the methods its element type can support. The
/// implementation for [`TensorBase`] in this module forwards each method to the
/// function of the same name in [`crate::ops`].
pub trait Operators {
    /// Element type of the tensor the operators are applied to.
    type Elem;

    /// Runs [`crate::ops::arg_max`] along `axis`, producing a tensor of `i32`.
    ///
    /// `keep_dims` is passed through unchanged to the op.
    fn arg_max(&self, axis: isize, keep_dims: bool) -> Result<Tensor<i32>, OpError>
    where
        Self::Elem: Copy + PartialOrd + IsNaN;

    /// Runs [`crate::ops::div`] with `self` as the first operand and `other` as
    /// the second.
    fn div(&self, other: TensorView<Self::Elem>) -> Result<Tensor<Self::Elem>, OpError>
    where
        Self::Elem: Copy
            + Debug
            + Default
            + std::ops::Mul<Output = Self::Elem>
            + std::ops::Div<Output = Self::Elem>
            + IsInt
            + Identities;

    /// Runs [`crate::ops::mul`] with `self` as the first operand and `other` as
    /// the second.
    fn mul(&self, other: TensorView<Self::Elem>) -> Result<Tensor<Self::Elem>, OpError>
    where
        Self::Elem: Copy + Debug + Default + std::ops::Mul<Output = Self::Elem>;

    /// Runs [`crate::ops::pad`] on `self`.
    ///
    /// `padding` is a 1-D `i32` tensor whose layout is defined by the op. `val`
    /// is the constant used by the [`TensorBase`] implementation, which selects
    /// `PadMode::Constant`.
    fn pad(
        &self,
        padding: NdTensorView<i32, 1>,
        val: Self::Elem,
    ) -> Result<Tensor<Self::Elem>, OpError>
    where
        Self::Elem: Copy + Default + PartialEq;

    /// Runs [`crate::ops::reduce_max`] over `axes`.
    ///
    /// The op defines how `axes == None` is handled (presumably "all axes" —
    /// confirm there). `keep_dims` is passed through unchanged.
    fn reduce_max(
        &self,
        axes: Option<&[i32]>,
        keep_dims: bool,
    ) -> Result<Tensor<Self::Elem>, OpError>
    where
        Self::Elem: Copy + PartialOrd + IsNaN + MinMax;

    /// Runs [`crate::ops::reduce_min`] over `axes`.
    ///
    /// The op defines how `axes == None` is handled (presumably "all axes" —
    /// confirm there). `keep_dims` is passed through unchanged.
    fn reduce_min(
        &self,
        axes: Option<&[i32]>,
        keep_dims: bool,
    ) -> Result<Tensor<Self::Elem>, OpError>
    where
        Self::Elem: Copy + PartialOrd + IsNaN + MinMax;

    /// Runs [`crate::ops::reduce_sum`] over `axes`.
    ///
    /// The op defines how `axes == None` is handled (presumably "all axes" —
    /// confirm there). `keep_dims` is passed through unchanged.
    fn reduce_sum(
        &self,
        axes: Option<&[i32]>,
        keep_dims: bool,
    ) -> Result<Tensor<Self::Elem>, OpError>
    where
        Self::Elem: Copy + Default + std::ops::Add<Self::Elem, Output = Self::Elem>;

    /// Runs [`crate::ops::topk`] with the given `k`, `axis`, `largest` and
    /// `sorted` arguments.
    ///
    /// Returns a tensor of elements together with a tensor of `i32` (presumably
    /// the selected values and their indices — confirm against the op).
    fn topk(
        &self,
        k: usize,
        axis: Option<isize>,
        largest: bool,
        sorted: bool,
    ) -> Result<(Tensor<Self::Elem>, Tensor<i32>), OpError>
    where
        Self::Elem: Copy + Default + PartialOrd + IsNaN;
}
/// Operators offered only for floating-point tensors.
///
/// This complements [`Operators`], whose methods are available for any element
/// type meeting their bounds. In this module the trait is implemented for
/// tensors whose storage element type is `f32`. `Tensor` and `TensorView` are
/// written here without an element type, so they use the default element type
/// parameter declared in `rten_tensor` (presumably `f32` — confirm there).
pub trait FloatOperators {
    /// Runs [`crate::ops::matmul`] with `self` as the left operand and `other`
    /// as the right operand.
    fn matmul(&self, other: TensorView) -> Result<Tensor, OpError>;

    /// Runs [`crate::ops::softmax`] along `axis`.
    fn softmax(&self, axis: isize) -> Result<Tensor, OpError>;

    /// Runs [`crate::ops::reduce_l2`] over `axes`; `keep_dims` is passed
    /// through unchanged.
    fn reduce_l2(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError>;

    /// Runs [`crate::ops::reduce_mean`] over `axes`; `keep_dims` is passed
    /// through unchanged.
    fn reduce_mean(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError>;

    /// Runs [`crate::ops::resize_image`] with the target `size`.
    ///
    /// The order of the two entries (height/width) is defined by the op.
    fn resize_image(&self, size: [usize; 2]) -> Result<Tensor, OpError>;

    /// Runs the `stft` operator on `self` and returns a rank-4 `f32` tensor.
    ///
    /// Only available when the `fft` feature is enabled. `window` and
    /// `frame_length` are optional; the op decides what it does when they are
    /// absent.
    #[cfg(feature = "fft")]
    fn stft(
        &self,
        frame_step: i32,
        window: Option<NdTensorView<f32, 1>>,
        frame_length: Option<i32>,
        onesided: bool,
    ) -> Result<NdTensor<f32, 4>, OpError>;
}
/// Executes `op` via the pool returned by [`thread_pool`] and returns its result.
///
/// A new [`BufferPool`] is constructed for every call and lent to `op` for its
/// allocations. It is dropped when this function returns, so it is never
/// shared between calls. Both `op` and its result must be `Send` because they
/// are handed to the thread pool's `run`.
fn run_operator<R, F>(op: F) -> R
where
    R: Send,
    F: Send + FnOnce(&BufferPool) -> R,
{
    let buffers = BufferPool::new();
    thread_pool().run(|| op(&buffers))
}
/// Operator methods for any tensor, whatever its storage and layout.
///
/// Each method forwards to the function of the same name in [`crate::ops`] and
/// runs it through [`run_operator`]. That means it executes via [`thread_pool`]
/// with a new [`BufferPool`] created per call.
///
/// `T` is used throughout in place of `Self::Elem` (the two are the same type
/// here) so that every signature and bound reads the same way.
// NOTE(review): the `Sync` bounds on `S` and `L` are presumably needed because
// the closures below capture `self` by reference and must be `Send` for
// `run_operator` — confirm.
impl<T: Send, S: Storage<Elem = T> + Sync, L: Layout + Clone + Sync> Operators
    for TensorBase<S, L>
{
    type Elem = T;

    fn arg_max(&self, axis: isize, keep_dims: bool) -> Result<Tensor<i32>, OpError>
    where
        T: Copy + PartialOrd + IsNaN,
    {
        run_operator(|pool| arg_max(pool, self.as_dyn(), axis, keep_dims))
    }

    fn div(&self, other: TensorView<T>) -> Result<Tensor<T>, OpError>
    where
        T: Copy
            + Debug
            + Default
            + std::ops::Mul<Output = T>
            + std::ops::Div<Output = T>
            + IsInt
            + Identities,
    {
        run_operator(|pool| div(pool, self.as_dyn(), other))
    }

    fn mul(&self, other: TensorView<T>) -> Result<Tensor<T>, OpError>
    where
        T: Copy + Debug + Default + std::ops::Mul<Output = T>,
    {
        run_operator(|pool| mul(pool, self.as_dyn(), other))
    }

    fn reduce_max(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor<T>, OpError>
    where
        T: Copy + PartialOrd + IsNaN + MinMax,
    {
        run_operator(|pool| reduce_max(pool, self.as_dyn(), axes, keep_dims))
    }

    fn reduce_min(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor<T>, OpError>
    where
        T: Copy + PartialOrd + IsNaN + MinMax,
    {
        run_operator(|pool| reduce_min(pool, self.as_dyn(), axes, keep_dims))
    }

    fn reduce_sum(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor<T>, OpError>
    where
        T: Copy + Default + std::ops::Add<T, Output = T>,
    {
        run_operator(|pool| reduce_sum(pool, self.as_dyn(), axes, keep_dims))
    }

    /// Pads `self` using `PadMode::Constant` with `val` as the constant.
    fn pad(&self, padding: NdTensorView<i32, 1>, val: T) -> Result<Tensor<T>, OpError>
    where
        T: Copy + Default + PartialEq,
    {
        // `move` is required here. `T` is `Copy`, so a non-move closure would
        // capture `val` by reference. That reference is only `Send` if
        // `T: Sync`, but `run_operator` demands a `Send` closure and this impl
        // only guarantees `T: Send`.
        run_operator(move |pool| pad(pool, self.as_dyn(), &padding, PadMode::Constant, val))
    }

    fn topk(
        &self,
        k: usize,
        axis: Option<isize>,
        largest: bool,
        sorted: bool,
    ) -> Result<(Tensor<T>, Tensor<i32>), OpError>
    where
        T: Copy + Default + PartialOrd + IsNaN,
    {
        run_operator(|pool| topk(pool, self.as_dyn(), k, axis, largest, sorted))
    }
}
impl<S: Storage<Elem = f32> + Sync, L: Layout + Clone + Sync> FloatOperators for TensorBase<S, L> {
fn matmul(&self, other: TensorView) -> Result<Tensor, OpError> {
run_operator(|pool| matmul(pool, self.as_dyn(), other, None))
}
fn reduce_l2(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError> {
run_operator(|pool| reduce_l2(pool, self.as_dyn(), axes, keep_dims))
}
fn reduce_mean(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result<Tensor, OpError> {
run_operator(|pool| reduce_mean(pool, self.as_dyn(), axes, keep_dims))
}
fn resize_image(&self, size: [usize; 2]) -> Result<Tensor, OpError> {
run_operator(|_pool| resize_image(self.as_dyn(), size))
}
fn softmax(&self, axis: isize) -> Result<Tensor, OpError> {
run_operator(|pool| softmax(pool, self.as_dyn(), axis, NanHandling::KeepNans))
}
#[cfg(feature = "fft")]
fn stft(
&self,
frame_step: i32,
window: Option<NdTensorView<f32, 1>>,
frame_length: Option<i32>,
onesided: bool,
) -> Result<NdTensor<f32, 4>, OpError> {
run_operator(|pool| {
stft(
pool,
self.as_dyn(),
frame_step,
window,
frame_length,
onesided,
)
})
}
}