use crate::ops::kinds::reshape::*;
use crate::shape::{Axis, Shape};
use crate::TensorBase;
use acme::prelude::{BinaryOp, UnaryOp};
use num::Complex;
/// A heap-allocated [`TensorBase`]; the element type `T` defaults to `f64`.
///
/// This is the storage type for the tensor operands held inside [`TensorExpr`].
/// NOTE(review): the boxing is presumably required because `TensorBase` can itself
/// contain a `TensorExpr` (a recursive type) — confirm against `TensorBase`'s definition.
pub type BoxTensor<T = f64> = Box<TensorBase<T>>;
/// An expression node describing a tensor operation together with its operand(s).
///
/// `A` is the element type of the (left-hand or sole) tensor operand. `B` is the
/// type of the right-hand side: the element type of the second tensor for
/// `Binary` and `Matmul`, or the bare scalar type for `BinaryScalar`. `B`
/// defaults to `A`.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum TensorExpr<A, B = A> {
/// Binary operation `op` applied to two tensor operands: `(lhs, rhs, op)`.
Binary(BoxTensor<A>, BoxTensor<B>, BinaryOp),
/// Binary operation `op` between a tensor and a scalar of type `B`: `(lhs, scalar, op)`.
BinaryScalar(BoxTensor<A>, B, BinaryOp),
/// Unary operation `op` applied to a single tensor: `(tensor, op)`.
Unary(BoxTensor<A>, UnaryOp),
/// A tensor broadcast to the given shape: `(tensor, target_shape)`.
Broadcast(BoxTensor<A>, Shape),
/// Matrix multiplication of two tensors: `(lhs, rhs)`.
Matmul(BoxTensor<A>, BoxTensor<B>),
/// A tensor reshaped to the given shape: `(tensor, new_shape)`.
Reshape(BoxTensor<A>, Shape),
/// A shape-manipulation expression described by a [`ReshapeExpr`].
Shape(ReshapeExpr<A>),
/// A tensor with two axes exchanged: `(tensor, swap, with)`.
SwapAxes(BoxTensor<A>, Axis, Axis),
/// The transpose of a tensor.
Transpose(BoxTensor<A>),
}
impl<A, B> TensorExpr<A, B> {
    /// Builds a [`TensorExpr::Binary`] node applying `op` to two tensors.
    pub fn binary(lhs: TensorBase<A>, rhs: TensorBase<B>, op: BinaryOp) -> Self {
        Self::Binary(Box::new(lhs), Box::new(rhs), op)
    }

    /// Builds a [`TensorExpr::BinaryScalar`] node applying `op` to a tensor and a scalar.
    pub fn binary_scalar(lhs: TensorBase<A>, rhs: B, op: BinaryOp) -> Self {
        Self::BinaryScalar(Box::new(lhs), rhs, op)
    }

    /// Builds a [`TensorExpr::BinaryScalar`] node whose scalar operand is a
    /// `Complex<A>`.
    ///
    /// The return type does not mention `B`, so `B` cannot be inferred from a
    /// bare `TensorExpr::binary_scalar_c(..)` call. Name it at the call site,
    /// e.g. `TensorExpr::<A, Complex<A>>::binary_scalar_c(..)`.
    pub fn binary_scalar_c(
        lhs: TensorBase<A>,
        rhs: Complex<A>,
        op: BinaryOp,
    ) -> TensorExpr<A, Complex<A>> {
        TensorExpr::BinaryScalar(Box::new(lhs), rhs, op)
    }

    /// Builds a [`TensorExpr::Broadcast`] node targeting `shape`.
    pub fn broadcast(tensor: TensorBase<A>, shape: Shape) -> Self {
        Self::Broadcast(Box::new(tensor), shape)
    }

    /// Builds a [`TensorExpr::Matmul`] node multiplying `lhs` by `rhs`.
    pub fn matmul(lhs: TensorBase<A>, rhs: TensorBase<B>) -> Self {
        Self::Matmul(Box::new(lhs), Box::new(rhs))
    }

    /// Builds a [`TensorExpr::Reshape`] node targeting `shape`.
    pub fn reshape(tensor: TensorBase<A>, shape: Shape) -> Self {
        Self::Reshape(Box::new(tensor), shape)
    }

    /// Wraps a [`ReshapeExpr`] in a [`TensorExpr::Shape`] node.
    pub fn shape(expr: ReshapeExpr<A>) -> Self {
        Self::Shape(expr)
    }

    /// Builds a [`TensorExpr::SwapAxes`] node exchanging axes `swap` and `with`.
    pub fn swap_axes(tensor: TensorBase<A>, swap: Axis, with: Axis) -> Self {
        Self::SwapAxes(Box::new(tensor), swap, with)
    }

    /// Builds a [`TensorExpr::Transpose`] node.
    pub fn transpose(tensor: TensorBase<A>) -> Self {
        Self::Transpose(Box::new(tensor))
    }

    /// Builds a [`TensorExpr::Unary`] node applying `op` to `tensor`.
    pub fn unary(tensor: TensorBase<A>, op: UnaryOp) -> Self {
        Self::Unary(Box::new(tensor), op)
    }

    /// Consumes the expression and returns its left-hand (or sole) tensor operand.
    ///
    /// Every variant that stores its operand as a boxed tensor yields it,
    /// including `Reshape` and `SwapAxes`. Returns `None` only for `Shape`,
    /// whose operand lives inside an opaque [`ReshapeExpr`].
    pub fn lhs(self) -> Option<TensorBase<A>> {
        match self {
            Self::Binary(lhs, _, _) | Self::BinaryScalar(lhs, _, _) | Self::Matmul(lhs, _) => {
                Some(*lhs)
            }
            Self::Unary(tensor, _)
            | Self::Broadcast(tensor, _)
            | Self::Reshape(tensor, _)
            | Self::SwapAxes(tensor, _, _)
            | Self::Transpose(tensor) => Some(*tensor),
            // The operand is held by the ReshapeExpr, which is not inspected here.
            Self::Shape(_) => None,
        }
    }

    /// Consumes the expression and returns its right-hand operand as a tensor.
    ///
    /// For `BinaryScalar` the scalar is lifted with `TensorBase::from_scalar`.
    /// Returns `None` for every variant that has a single operand.
    pub fn rhs(self) -> Option<TensorBase<B>> {
        match self {
            Self::Binary(_, rhs, _) | Self::Matmul(_, rhs) => Some(*rhs),
            Self::BinaryScalar(_, scalar, _) => Some(TensorBase::from_scalar(scalar)),
            // Listed explicitly (no `_`) so new variants force a decision here.
            Self::Unary(..)
            | Self::Broadcast(..)
            | Self::Reshape(..)
            | Self::Shape(..)
            | Self::SwapAxes(..)
            | Self::Transpose(..) => None,
        }
    }

    /// Borrows the expression, producing an equivalent expression whose tensor
    /// operands are views (`TensorBase::view`) and whose scalar is borrowed.
    ///
    /// # Panics
    ///
    /// Panics for [`TensorExpr::Shape`], which does not support viewing yet.
    pub fn view<'a>(&'a self) -> TensorExpr<&'a A, &'a B> {
        match self {
            TensorExpr::Binary(lhs, rhs, op) => TensorExpr::binary(lhs.view(), rhs.view(), *op),
            TensorExpr::BinaryScalar(lhs, rhs, op) => {
                TensorExpr::binary_scalar(lhs.view(), rhs, *op)
            }
            TensorExpr::Unary(tensor, op) => TensorExpr::unary(tensor.view(), *op),
            TensorExpr::Broadcast(tensor, shape) => {
                TensorExpr::broadcast(tensor.view(), shape.clone())
            }
            TensorExpr::Matmul(lhs, rhs) => TensorExpr::matmul(lhs.view(), rhs.view()),
            TensorExpr::Reshape(tensor, shape) => TensorExpr::reshape(tensor.view(), shape.clone()),
            // `Axis: Clone` is guaranteed by this enum's `#[derive(Clone)]`.
            TensorExpr::SwapAxes(tensor, swap, with) => {
                TensorExpr::swap_axes(tensor.view(), swap.clone(), with.clone())
            }
            TensorExpr::Transpose(tensor) => TensorExpr::transpose(tensor.view()),
            TensorExpr::Shape(_) => {
                unimplemented!("`TensorExpr::view` does not support the `Shape` variant yet")
            }
        }
    }
}