use crate::dtype::DType;
use crate::error::Result;
use crate::ops::traits::{
ActivationOps, BinaryOps, CompareOps, ConvOps, CumulativeOps, IndexingOps, MatmulOps,
NormalizationOps, PaddingMode, ReduceOps, ScalarOps, ShapeOps, TypeConversionOps, UnaryOps,
UtilityOps,
};
use crate::runtime::Runtime;
use crate::tensor::Tensor;
/// Element-pair binary operations, available when the runtime's client
/// implements [`BinaryOps`]. Each method resolves the default client for
/// this tensor's device and forwards both operands to it.
impl<R: Runtime> Tensor<R>
where
    R::Client: BinaryOps<R>,
{
    /// Adds `other` to this tensor via the device's default client.
    pub fn add(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
        R::default_client(self.device()).add(self, other)
    }

    /// Subtracts `other` from this tensor via the device's default client.
    pub fn sub(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
        R::default_client(self.device()).sub(self, other)
    }

    /// Multiplies this tensor by `other` via the device's default client.
    pub fn mul(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
        R::default_client(self.device()).mul(self, other)
    }

    /// Divides this tensor by `other` via the device's default client.
    pub fn div(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
        R::default_client(self.device()).div(self, other)
    }

    /// Raises this tensor to the power `other` via the device's default client.
    pub fn pow(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
        R::default_client(self.device()).pow(self, other)
    }

    /// Takes the larger of each value pair from `self` and `other`.
    pub fn maximum(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
        R::default_client(self.device()).maximum(self, other)
    }

    /// Takes the smaller of each value pair from `self` and `other`.
    pub fn minimum(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
        R::default_client(self.device()).minimum(self, other)
    }
}
/// Unary operations, available when the runtime's client implements
/// [`UnaryOps`]. Every method is a one-call forward to the default client
/// for this tensor's device.
impl<R: Runtime> Tensor<R>
where
    R::Client: UnaryOps<R>,
{
    /// Negates every element.
    pub fn neg(&self) -> Result<Tensor<R>> {
        R::default_client(self.device()).neg(self)
    }

    /// Absolute value of every element.
    pub fn abs(&self) -> Result<Tensor<R>> {
        R::default_client(self.device()).abs(self)
    }

    /// Square root of every element.
    pub fn sqrt(&self) -> Result<Tensor<R>> {
        R::default_client(self.device()).sqrt(self)
    }

    /// Exponential of every element.
    pub fn exp(&self) -> Result<Tensor<R>> {
        R::default_client(self.device()).exp(self)
    }

    /// Logarithm of every element (base determined by the client impl).
    pub fn log(&self) -> Result<Tensor<R>> {
        R::default_client(self.device()).log(self)
    }

    /// Sine of every element.
    pub fn sin(&self) -> Result<Tensor<R>> {
        R::default_client(self.device()).sin(self)
    }

    /// Cosine of every element.
    pub fn cos(&self) -> Result<Tensor<R>> {
        R::default_client(self.device()).cos(self)
    }

    /// Tangent of every element.
    pub fn tan(&self) -> Result<Tensor<R>> {
        R::default_client(self.device()).tan(self)
    }

    /// Hyperbolic tangent of every element.
    pub fn tanh(&self) -> Result<Tensor<R>> {
        R::default_client(self.device()).tanh(self)
    }

    /// Reciprocal of every element.
    pub fn recip(&self) -> Result<Tensor<R>> {
        R::default_client(self.device()).recip(self)
    }

    /// Floor of every element.
    pub fn floor(&self) -> Result<Tensor<R>> {
        R::default_client(self.device()).floor(self)
    }

    /// Ceiling of every element.
    pub fn ceil(&self) -> Result<Tensor<R>> {
        R::default_client(self.device()).ceil(self)
    }

    /// Rounds every element (rounding mode determined by the client impl).
    pub fn round(&self) -> Result<Tensor<R>> {
        R::default_client(self.device()).round(self)
    }
}
/// Tensor-scalar operations, available when the runtime's client
/// implements [`ScalarOps`].
impl<R: Runtime> Tensor<R>
where
    R::Client: ScalarOps<R>,
{
    /// Adds `scalar` to every element.
    pub fn add_scalar(&self, scalar: f64) -> Result<Tensor<R>> {
        R::default_client(self.device()).add_scalar(self, scalar)
    }

    /// Multiplies every element by `scalar`.
    pub fn mul_scalar(&self, scalar: f64) -> Result<Tensor<R>> {
        R::default_client(self.device()).mul_scalar(self, scalar)
    }

    /// Alias for [`Self::mul_scalar`].
    pub fn scale(&self, scalar: f64) -> Result<Tensor<R>> {
        self.mul_scalar(scalar)
    }
}
/// Neural-network activation functions, available when the runtime's
/// client implements [`ActivationOps`].
impl<R: Runtime> Tensor<R>
where
    R::Client: ActivationOps<R>,
{
    /// Rectified linear unit.
    pub fn relu(&self) -> Result<Tensor<R>> {
        R::default_client(self.device()).relu(self)
    }

    /// Logistic sigmoid.
    pub fn sigmoid(&self) -> Result<Tensor<R>> {
        R::default_client(self.device()).sigmoid(self)
    }

    /// Gaussian error linear unit.
    pub fn gelu(&self) -> Result<Tensor<R>> {
        R::default_client(self.device()).gelu(self)
    }

    /// Sigmoid-weighted linear unit (a.k.a. swish).
    pub fn silu(&self) -> Result<Tensor<R>> {
        R::default_client(self.device()).silu(self)
    }

    /// Softmax along `dim`. The signed `dim` presumably allows negative
    /// (from-the-end) indexing — resolved by the client implementation.
    pub fn softmax(&self, dim: isize) -> Result<Tensor<R>> {
        R::default_client(self.device()).softmax(self, dim)
    }

    /// Log-softmax along `dim` (same `dim` semantics as [`Self::softmax`]).
    pub fn log_softmax(&self, dim: isize) -> Result<Tensor<R>> {
        R::default_client(self.device()).log_softmax(self, dim)
    }

    /// Dropout with probability `p`; `training` toggles whether dropout is
    /// actually applied (behavior delegated to the client).
    pub fn dropout(&self, p: f64, training: bool) -> Result<Tensor<R>> {
        R::default_client(self.device()).dropout(self, p, training)
    }
}
/// Reductions over one or more dimensions, available when the runtime's
/// client implements [`ReduceOps`]. `keepdim` controls whether reduced
/// dimensions are retained with size 1 (semantics owned by the client).
impl<R: Runtime> Tensor<R>
where
    R::Client: ReduceOps<R>,
{
    /// Sum over `dims`.
    pub fn sum(&self, dims: &[usize], keepdim: bool) -> Result<Tensor<R>> {
        R::default_client(self.device()).sum(self, dims, keepdim)
    }

    /// Mean over `dims`.
    pub fn mean(&self, dims: &[usize], keepdim: bool) -> Result<Tensor<R>> {
        R::default_client(self.device()).mean(self, dims, keepdim)
    }

    /// Maximum over `dims`.
    pub fn max(&self, dims: &[usize], keepdim: bool) -> Result<Tensor<R>> {
        R::default_client(self.device()).max(self, dims, keepdim)
    }

    /// Minimum over `dims`.
    pub fn min(&self, dims: &[usize], keepdim: bool) -> Result<Tensor<R>> {
        R::default_client(self.device()).min(self, dims, keepdim)
    }
}
/// Matrix multiplication, available when the runtime's client implements
/// [`MatmulOps`].
impl<R: Runtime> Tensor<R>
where
    R::Client: MatmulOps<R>,
{
    /// Matrix-multiplies `self` by `other` on this tensor's device.
    pub fn matmul(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
        R::default_client(self.device()).matmul(self, other)
    }
}
/// Normalization layers, available when the runtime's client implements
/// [`NormalizationOps`].
impl<R: Runtime> Tensor<R>
where
    R::Client: NormalizationOps<R>,
{
    /// RMS normalization scaled by `weight`; `eps` guards the divisor.
    pub fn rms_norm(&self, weight: &Tensor<R>, eps: f32) -> Result<Tensor<R>> {
        R::default_client(self.device()).rms_norm(self, weight, eps)
    }

    /// Layer normalization with affine `weight`/`bias`; `eps` guards the divisor.
    pub fn layer_norm(&self, weight: &Tensor<R>, bias: &Tensor<R>, eps: f32) -> Result<Tensor<R>> {
        R::default_client(self.device()).layer_norm(self, weight, bias, eps)
    }
}
/// Element-pair comparisons, available when the runtime's client
/// implements [`CompareOps`]. Note these inherent methods return a
/// `Result<Tensor<R>>` mask and are unrelated to `PartialEq`/`PartialOrd`.
impl<R: Runtime> Tensor<R>
where
    R::Client: CompareOps<R>,
{
    /// Equality comparison against `other`.
    pub fn eq(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
        R::default_client(self.device()).eq(self, other)
    }

    /// Greater-than comparison against `other`.
    pub fn gt(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
        R::default_client(self.device()).gt(self, other)
    }

    /// Less-than comparison against `other`.
    pub fn lt(&self, other: &Tensor<R>) -> Result<Tensor<R>> {
        R::default_client(self.device()).lt(self, other)
    }
}
/// Indexing, selection, and masked updates, available when the runtime's
/// client implements [`IndexingOps`].
impl<R: Runtime> Tensor<R>
where
    R::Client: IndexingOps<R>,
{
    /// Gathers the entries of `self` along `dim` at positions given by `indices`.
    pub fn index_select(&self, dim: usize, indices: &Tensor<R>) -> Result<Tensor<R>> {
        R::default_client(self.device()).index_select(self, dim, indices)
    }

    /// Index of the maximum along `dim`; `keepdim` retains the reduced axis.
    pub fn argmax(&self, dim: usize, keepdim: bool) -> Result<Tensor<R>> {
        R::default_client(self.device()).argmax(self, dim, keepdim)
    }

    /// Index of the minimum along `dim`; `keepdim` retains the reduced axis.
    pub fn argmin(&self, dim: usize, keepdim: bool) -> Result<Tensor<R>> {
        R::default_client(self.device()).argmin(self, dim, keepdim)
    }

    /// Writes `value` wherever `mask` selects (mask semantics owned by the client).
    pub fn masked_fill(&self, mask: &Tensor<R>, value: f64) -> Result<Tensor<R>> {
        R::default_client(self.device()).masked_fill(self, mask, value)
    }

    /// Assigns `src` into the slice of `self` along `dim` beginning at `start`.
    pub fn slice_assign(&self, src: &Tensor<R>, dim: usize, start: usize) -> Result<Tensor<R>> {
        R::default_client(self.device()).slice_assign(self, src, dim, start)
    }
}
impl<R: Runtime> Tensor<R>
where
R::Client: ShapeOps<R>,
{
pub fn cat(tensors: &[&Tensor<R>], dim: isize) -> Result<Tensor<R>> {
if tensors.is_empty() {
return Err(crate::error::Error::InvalidArgument {
arg: "tensors",
reason: "cannot concatenate empty list".into(),
});
}
let client = R::default_client(tensors[0].device());
client.cat(tensors, dim)
}
pub fn stack(tensors: &[&Tensor<R>], dim: isize) -> Result<Tensor<R>> {
if tensors.is_empty() {
return Err(crate::error::Error::InvalidArgument {
arg: "tensors",
reason: "cannot stack empty list".into(),
});
}
let client = R::default_client(tensors[0].device());
client.stack(tensors, dim)
}
}
/// Cumulative and log-sum-exp operations, available when the runtime's
/// client implements [`CumulativeOps`].
impl<R: Runtime> Tensor<R>
where
    R::Client: CumulativeOps<R>,
{
    /// Running sum along `dim`.
    pub fn cumsum(&self, dim: isize) -> Result<Tensor<R>> {
        R::default_client(self.device()).cumsum(self, dim)
    }

    /// Running product along `dim`.
    pub fn cumprod(&self, dim: isize) -> Result<Tensor<R>> {
        R::default_client(self.device()).cumprod(self, dim)
    }

    /// Log of the sum of exponentials over `dims`; `keepdim` retains the
    /// reduced axes.
    pub fn logsumexp(&self, dims: &[usize], keepdim: bool) -> Result<Tensor<R>> {
        R::default_client(self.device()).logsumexp(self, dims, keepdim)
    }
}
/// Dtype conversion, available when the runtime's client implements
/// [`TypeConversionOps`].
impl<R: Runtime> Tensor<R>
where
    R::Client: TypeConversionOps<R>,
{
    /// Casts this tensor to `dtype` via the client's `cast`.
    pub fn to_dtype(&self, dtype: DType) -> Result<Tensor<R>> {
        R::default_client(self.device()).cast(self, dtype)
    }
}
/// Miscellaneous utilities, available when the runtime's client
/// implements [`UtilityOps`].
impl<R: Runtime> Tensor<R>
where
    R::Client: UtilityOps<R>,
{
    /// Clamps every element into `[min, max]` (validation, if any, is
    /// owned by the client).
    pub fn clamp(&self, min: f64, max: f64) -> Result<Tensor<R>> {
        R::default_client(self.device()).clamp(self, min, max)
    }

    /// One-hot encodes this tensor into `num_classes` classes.
    pub fn one_hot(&self, num_classes: usize) -> Result<Tensor<R>> {
        R::default_client(self.device()).one_hot(self, num_classes)
    }
}
/// Convolutions, available when the runtime's client implements
/// [`ConvOps`].
impl<R: Runtime> Tensor<R>
where
    R::Client: ConvOps<R>,
{
    /// 1-D convolution of `self` with `weight`, plus an optional `bias`.
    ///
    /// `stride`, `padding`, `dilation`, and `groups` follow the standard
    /// convolution hyperparameters; their exact interpretation is owned by
    /// the client implementation.
    pub fn conv1d(
        &self,
        weight: &Tensor<R>,
        bias: Option<&Tensor<R>>,
        stride: usize,
        padding: PaddingMode,
        dilation: usize,
        groups: usize,
    ) -> Result<Tensor<R>> {
        R::default_client(self.device()).conv1d(self, weight, bias, stride, padding, dilation, groups)
    }
}