#![allow(dead_code)]
use std::marker::PhantomData;
use std::ops::{Add, Div, Mul, Sub};
use torsh_core::{dtype::TensorElement, error::Result};
use crate::Tensor;
/// A lazily evaluated, element-wise expression over values of type `T`.
///
/// Implementors describe *how* to compute each element; nothing is
/// materialized until [`Expression::eval_vec`] or
/// [`Expression::eval_tensor`] is called, so chained operations evaluate
/// in a single pass per element.
pub trait Expression<T: TensorElement> {
    /// Computes the element at position `index`.
    fn eval_at(&self, index: usize) -> T;

    /// Number of elements this expression produces.
    fn size(&self) -> usize;

    /// Materializes the expression into a `Vec`, evaluating each element
    /// exactly once, in index order.
    fn eval_vec(&self) -> Vec<T> {
        let len = self.size();
        let mut out = Vec::with_capacity(len);
        for i in 0..len {
            out.push(self.eval_at(i));
        }
        out
    }

    /// Materializes the expression into a tensor with the given `shape` on
    /// `device`.
    ///
    /// # Errors
    /// Propagates any error from `Tensor::from_data` (shape validity is not
    /// checked here).
    fn eval_tensor(
        &self,
        shape: Vec<usize>,
        device: torsh_core::device::DeviceType,
    ) -> Result<Tensor<T>>
    where
        T: Copy,
    {
        Tensor::from_data(self.eval_vec(), shape, device)
    }
}
/// Leaf expression holding a host-side snapshot of a tensor's elements.
pub struct TensorExpr<'a, T: TensorElement> {
    // Owned copy of the source tensor's elements (filled by `TensorExpr::new`).
    data: Vec<T>,
    // Cached `data.len()`; set once at construction.
    // NOTE(review): this duplicates `data.len()` — kept because removing the
    // field would break the existing constructor.
    size: usize,
    // NOTE(review): `'a` only ties this value to the source tensor's borrow;
    // the data itself is owned, so the lifetime is arguably vestigial — kept
    // because dropping it would change the public type signature.
    _phantom: PhantomData<&'a T>,
}
impl<'a, T: TensorElement + Copy> TensorExpr<'a, T> {
    /// Builds a leaf expression by copying `tensor`'s elements to host memory.
    ///
    /// # Errors
    /// Propagates any failure from `Tensor::to_vec`.
    pub fn new(tensor: &'a Tensor<T>) -> Result<Self> {
        let data = tensor.to_vec()?;
        Ok(Self {
            size: data.len(),
            data,
            _phantom: PhantomData,
        })
    }
}
impl<'a, T: TensorElement> Expression<T> for TensorExpr<'a, T> {
    /// Direct lookup into the snapshot; panics if `index` is out of bounds.
    fn eval_at(&self, index: usize) -> T {
        self.data[index]
    }

    /// Returns the element count cached at construction.
    fn size(&self) -> usize {
        self.size
    }

    /// Specialized override: the data is already materialized, so simply
    /// copy it rather than evaluating element by element.
    fn eval_vec(&self) -> Vec<T> {
        self.data.to_vec()
    }
}
/// Lazily adds a scalar to every element of the wrapped expression.
///
/// Trait bounds live on the `impl` blocks that need them, not on the struct
/// definition (idiomatic Rust: unbounded struct, bounded impls).
pub struct AddScalarExpr<T, E> {
    expr: E,
    scalar: T,
}

impl<T: TensorElement + Add<Output = T>, E: Expression<T>> Expression<T> for AddScalarExpr<T, E> {
    /// Evaluates the inner expression at `index`, then adds the scalar.
    fn eval_at(&self, index: usize) -> T {
        self.expr.eval_at(index) + self.scalar
    }

    /// A scalar operation does not change the element count.
    fn size(&self) -> usize {
        self.expr.size()
    }
}
/// Lazily multiplies every element of the wrapped expression by a scalar.
///
/// Trait bounds live on the `impl` blocks that need them, not on the struct
/// definition (idiomatic Rust: unbounded struct, bounded impls).
pub struct MulScalarExpr<T, E> {
    expr: E,
    scalar: T,
}

impl<T: TensorElement + Mul<Output = T>, E: Expression<T>> Expression<T> for MulScalarExpr<T, E> {
    /// Evaluates the inner expression at `index`, then multiplies by the scalar.
    fn eval_at(&self, index: usize) -> T {
        self.expr.eval_at(index) * self.scalar
    }

    /// A scalar operation does not change the element count.
    fn size(&self) -> usize {
        self.expr.size()
    }
}
/// Lazily subtracts a scalar from every element of the wrapped expression.
///
/// Trait bounds live on the `impl` blocks that need them, not on the struct
/// definition (idiomatic Rust: unbounded struct, bounded impls).
pub struct SubScalarExpr<T, E> {
    expr: E,
    scalar: T,
}

impl<T: TensorElement + Sub<Output = T>, E: Expression<T>> Expression<T> for SubScalarExpr<T, E> {
    /// Evaluates the inner expression at `index`, then subtracts the scalar.
    fn eval_at(&self, index: usize) -> T {
        self.expr.eval_at(index) - self.scalar
    }

    /// A scalar operation does not change the element count.
    fn size(&self) -> usize {
        self.expr.size()
    }
}
/// Lazily divides every element of the wrapped expression by a scalar.
///
/// Trait bounds live on the `impl` blocks that need them, not on the struct
/// definition (idiomatic Rust: unbounded struct, bounded impls).
pub struct DivScalarExpr<T, E> {
    expr: E,
    scalar: T,
}

impl<T: TensorElement + Div<Output = T>, E: Expression<T>> Expression<T> for DivScalarExpr<T, E> {
    /// Evaluates the inner expression at `index`, then divides by the scalar.
    ///
    /// NOTE(review): a zero scalar follows `T`'s own division semantics
    /// (infinity/NaN for floats; panic for integer types).
    fn eval_at(&self, index: usize) -> T {
        self.expr.eval_at(index) / self.scalar
    }

    /// A scalar operation does not change the element count.
    fn size(&self) -> usize {
        self.expr.size()
    }
}
/// Lazy element-wise addition of two expressions.
///
/// Trait bounds live on the `impl` blocks that need them, not on the struct
/// definition (idiomatic Rust: unbounded struct, bounded impls).
pub struct AddExpr<T, E1, E2> {
    left: E1,
    right: E2,
    // Marks the element type `T`, which no field otherwise mentions.
    _phantom: PhantomData<T>,
}

impl<T: TensorElement + Add<Output = T>, E1: Expression<T>, E2: Expression<T>> Expression<T>
    for AddExpr<T, E1, E2>
{
    fn eval_at(&self, index: usize) -> T {
        self.left.eval_at(index) + self.right.eval_at(index)
    }

    /// The result length is the shorter of the two operands (zip semantics);
    /// trailing elements of the longer operand are silently ignored.
    fn size(&self) -> usize {
        self.left.size().min(self.right.size())
    }
}
/// Lazy element-wise multiplication of two expressions.
///
/// Trait bounds live on the `impl` blocks that need them, not on the struct
/// definition (idiomatic Rust: unbounded struct, bounded impls).
pub struct MulExpr<T, E1, E2> {
    left: E1,
    right: E2,
    // Marks the element type `T`, which no field otherwise mentions.
    _phantom: PhantomData<T>,
}

impl<T: TensorElement + Mul<Output = T>, E1: Expression<T>, E2: Expression<T>> Expression<T>
    for MulExpr<T, E1, E2>
{
    fn eval_at(&self, index: usize) -> T {
        self.left.eval_at(index) * self.right.eval_at(index)
    }

    /// The result length is the shorter of the two operands (zip semantics);
    /// trailing elements of the longer operand are silently ignored.
    fn size(&self) -> usize {
        self.left.size().min(self.right.size())
    }
}
/// Lazy element-wise subtraction of two expressions (`left - right`).
///
/// Trait bounds live on the `impl` blocks that need them, not on the struct
/// definition (idiomatic Rust: unbounded struct, bounded impls).
pub struct SubExpr<T, E1, E2> {
    left: E1,
    right: E2,
    // Marks the element type `T`, which no field otherwise mentions.
    _phantom: PhantomData<T>,
}

impl<T: TensorElement + Sub<Output = T>, E1: Expression<T>, E2: Expression<T>> Expression<T>
    for SubExpr<T, E1, E2>
{
    fn eval_at(&self, index: usize) -> T {
        self.left.eval_at(index) - self.right.eval_at(index)
    }

    /// The result length is the shorter of the two operands (zip semantics);
    /// trailing elements of the longer operand are silently ignored.
    fn size(&self) -> usize {
        self.left.size().min(self.right.size())
    }
}
/// Lazy element-wise division of two expressions (`left / right`).
///
/// Trait bounds live on the `impl` blocks that need them, not on the struct
/// definition (idiomatic Rust: unbounded struct, bounded impls).
pub struct DivExpr<T, E1, E2> {
    left: E1,
    right: E2,
    // Marks the element type `T`, which no field otherwise mentions.
    _phantom: PhantomData<T>,
}

impl<T: TensorElement + Div<Output = T>, E1: Expression<T>, E2: Expression<T>> Expression<T>
    for DivExpr<T, E1, E2>
{
    /// NOTE(review): a zero divisor follows `T`'s own division semantics
    /// (infinity/NaN for floats; panic for integer types).
    fn eval_at(&self, index: usize) -> T {
        self.left.eval_at(index) / self.right.eval_at(index)
    }

    /// The result length is the shorter of the two operands (zip semantics);
    /// trailing elements of the longer operand are silently ignored.
    fn size(&self) -> usize {
        self.left.size().min(self.right.size())
    }
}
/// Lazy element-wise negation of an expression.
///
/// Trait bounds live on the `impl` blocks that need them, not on the struct
/// definition (idiomatic Rust: unbounded struct, bounded impls).
pub struct NegExpr<T, E> {
    expr: E,
    // Marks the element type `T`, which no field otherwise mentions.
    _phantom: PhantomData<T>,
}

impl<T: TensorElement + std::ops::Neg<Output = T>, E: Expression<T>> Expression<T>
    for NegExpr<T, E>
{
    fn eval_at(&self, index: usize) -> T {
        -self.expr.eval_at(index)
    }

    /// Negation does not change the element count.
    fn size(&self) -> usize {
        self.expr.size()
    }
}
/// Fluent builder for composing lazy tensor expressions.
///
/// Wraps an [`Expression`] and offers chainable scalar and element-wise
/// combinators. Trait bounds live on the `impl` blocks that need them, not
/// on the struct definition (idiomatic Rust: unbounded struct, bounded impls).
pub struct ExprBuilder<T, E> {
    expr: E,
    // Marks the element type `T`, which no field otherwise mentions.
    _phantom: PhantomData<T>,
}
impl<T: TensorElement, E: Expression<T>> ExprBuilder<T, E> {
pub fn new(expr: E) -> Self {
Self {
expr,
_phantom: PhantomData,
}
}
pub fn add_scalar(self, scalar: T) -> ExprBuilder<T, AddScalarExpr<T, E>>
where
T: Add<Output = T>,
{
ExprBuilder::new(AddScalarExpr {
expr: self.expr,
scalar,
})
}
pub fn mul_scalar(self, scalar: T) -> ExprBuilder<T, MulScalarExpr<T, E>>
where
T: Mul<Output = T>,
{
ExprBuilder::new(MulScalarExpr {
expr: self.expr,
scalar,
})
}
pub fn sub_scalar(self, scalar: T) -> ExprBuilder<T, SubScalarExpr<T, E>>
where
T: Sub<Output = T>,
{
ExprBuilder::new(SubScalarExpr {
expr: self.expr,
scalar,
})
}
pub fn div_scalar(self, scalar: T) -> ExprBuilder<T, DivScalarExpr<T, E>>
where
T: Div<Output = T>,
{
ExprBuilder::new(DivScalarExpr {
expr: self.expr,
scalar,
})
}
pub fn add<E2: Expression<T>>(
self,
other: ExprBuilder<T, E2>,
) -> ExprBuilder<T, AddExpr<T, E, E2>>
where
T: Add<Output = T>,
{
ExprBuilder::new(AddExpr {
left: self.expr,
right: other.expr,
_phantom: PhantomData,
})
}
pub fn mul<E2: Expression<T>>(
self,
other: ExprBuilder<T, E2>,
) -> ExprBuilder<T, MulExpr<T, E, E2>>
where
T: Mul<Output = T>,
{
ExprBuilder::new(MulExpr {
left: self.expr,
right: other.expr,
_phantom: PhantomData,
})
}
pub fn sub<E2: Expression<T>>(
self,
other: ExprBuilder<T, E2>,
) -> ExprBuilder<T, SubExpr<T, E, E2>>
where
T: Sub<Output = T>,
{
ExprBuilder::new(SubExpr {
left: self.expr,
right: other.expr,
_phantom: PhantomData,
})
}
pub fn div<E2: Expression<T>>(
self,
other: ExprBuilder<T, E2>,
) -> ExprBuilder<T, DivExpr<T, E, E2>>
where
T: Div<Output = T>,
{
ExprBuilder::new(DivExpr {
left: self.expr,
right: other.expr,
_phantom: PhantomData,
})
}
pub fn neg(self) -> ExprBuilder<T, NegExpr<T, E>>
where
T: std::ops::Neg<Output = T>,
{
ExprBuilder::new(NegExpr {
expr: self.expr,
_phantom: PhantomData,
})
}
pub fn eval_vec(&self) -> Vec<T> {
self.expr.eval_vec()
}
pub fn eval_tensor(
&self,
shape: Vec<usize>,
device: torsh_core::device::DeviceType,
) -> Result<Tensor<T>>
where
T: Copy,
{
self.expr.eval_tensor(shape, device)
}
}
/// Starts a lazy expression chain by snapshotting `tensor`'s data.
///
/// The explicit lifetime parameter has been elided (`'_`), per the
/// `needless_lifetimes` lint; semantics are unchanged.
///
/// # Errors
/// Propagates any failure from [`TensorExpr::new`] (i.e. `Tensor::to_vec`).
pub fn expr<T: TensorElement + Copy>(
    tensor: &Tensor<T>,
) -> Result<ExprBuilder<T, TensorExpr<'_, T>>> {
    let tensor_expr = TensorExpr::new(tensor)?;
    Ok(ExprBuilder::new(tensor_expr))
}
/// Extension trait providing an `.expr()` entry point directly on tensors.
pub trait TensorExprExt<T: TensorElement> {
    /// Starts a lazy expression chain from this tensor.
    ///
    /// # Errors
    /// Returns an error if the tensor's data cannot be copied out (see
    /// `TensorExpr::new`).
    fn expr(&self) -> Result<ExprBuilder<T, TensorExpr<'_, T>>>
    where
        T: Copy;
}
impl<T: TensorElement + Copy> TensorExprExt<T> for Tensor<T> {
fn expr(&self) -> Result<ExprBuilder<T, TensorExpr<'_, T>>> {
expr(self)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::creation::*;
    use torsh_core::device::DeviceType;

    // NOTE: several `.expect(...)` messages were copy-pasted from the tensor
    // creation call onto unrelated operations; they have been corrected so a
    // test failure reports the step that actually failed.

    #[test]
    fn test_scalar_operations() {
        let tensor =
            tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor_1d creation should succeed");
        // (x + 1) * 2
        let result = tensor
            .expr()
            .expect("building expression should succeed")
            .add_scalar(1.0)
            .mul_scalar(2.0)
            .eval_vec();
        assert_eq!(result, vec![4.0, 6.0, 8.0, 10.0]);
    }

    #[test]
    fn test_element_wise_operations() {
        let a = tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor_1d creation should succeed");
        let b = tensor_1d(&[2.0f32, 2.0, 2.0, 2.0]).expect("tensor_1d creation should succeed");
        // a + b
        let result = a
            .expr()
            .expect("building expression should succeed")
            .add(b.expr().expect("building expression should succeed"))
            .eval_vec();
        assert_eq!(result, vec![3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_complex_expression() {
        let a = tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor_1d creation should succeed");
        let b = tensor_1d(&[2.0f32, 2.0, 2.0, 2.0]).expect("tensor_1d creation should succeed");
        // (a + b) * 2 + 1
        let result = a
            .expr()
            .expect("building expression should succeed")
            .add(b.expr().expect("building expression should succeed"))
            .mul_scalar(2.0)
            .add_scalar(1.0)
            .eval_vec();
        assert_eq!(result, vec![7.0, 9.0, 11.0, 13.0]);
    }

    #[test]
    fn test_negation() {
        let tensor =
            tensor_1d(&[1.0f32, 2.0, -3.0, 4.0]).expect("tensor_1d creation should succeed");
        let result = tensor
            .expr()
            .expect("building expression should succeed")
            .neg()
            .eval_vec();
        assert_eq!(result, vec![-1.0, -2.0, 3.0, -4.0]);
    }

    #[test]
    fn test_eval_tensor() {
        let tensor =
            tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor_1d creation should succeed");
        // x * 2, materialized back into a tensor on the CPU.
        let result = tensor
            .expr()
            .expect("building expression should succeed")
            .mul_scalar(2.0)
            .eval_tensor(vec![4], DeviceType::Cpu)
            .expect("eval_tensor should succeed");
        let data = result.to_vec().expect("to_vec conversion should succeed");
        assert_eq!(data, vec![2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_division() {
        let a = tensor_1d(&[10.0f32, 20.0, 30.0, 40.0]).expect("tensor_1d creation should succeed");
        let b = tensor_1d(&[2.0f32, 4.0, 5.0, 8.0]).expect("tensor_1d creation should succeed");
        // a / b
        let result = a
            .expr()
            .expect("building expression should succeed")
            .div(b.expr().expect("building expression should succeed"))
            .eval_vec();
        assert_eq!(result, vec![5.0, 5.0, 6.0, 5.0]);
    }

    #[test]
    fn test_subtraction() {
        let a = tensor_1d(&[10.0f32, 20.0, 30.0, 40.0]).expect("tensor_1d creation should succeed");
        let b = tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor_1d creation should succeed");
        // a - b
        let result = a
            .expr()
            .expect("building expression should succeed")
            .sub(b.expr().expect("building expression should succeed"))
            .eval_vec();
        assert_eq!(result, vec![9.0, 18.0, 27.0, 36.0]);
    }

    #[test]
    fn test_multiple_operations_chain() {
        let a = tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor_1d creation should succeed");
        let b = tensor_1d(&[2.0f32, 2.0, 2.0, 2.0]).expect("tensor_1d creation should succeed");
        let c = tensor_1d(&[3.0f32, 3.0, 3.0, 3.0]).expect("tensor_1d creation should succeed");
        // ((a + b) * c) / 2 + 1
        let result = a
            .expr()
            .expect("building expression should succeed")
            .add(b.expr().expect("building expression should succeed"))
            .mul(c.expr().expect("building expression should succeed"))
            .div_scalar(2.0)
            .add_scalar(1.0)
            .eval_vec();
        assert_eq!(result, vec![5.5, 7.0, 8.5, 10.0]);
    }
}