1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
/*
    Appellation: kinds <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
use crate::ops::kinds::reshape::*;
use crate::shape::{Axis, Shape};
use crate::TensorBase;
use acme::prelude::{BinaryOp, UnaryOp};
use num::Complex;

/// A heap-allocated [`TensorBase`]; this is how tensor operands are stored inside
/// the [`TensorExpr`] variants. The element type `T` defaults to `f64`.
pub type BoxTensor<T = f64> = Box<TensorBase<T>>;

/// An expression describing an operation on tensors together with its operand(s).
///
/// `A` is the element type of the primary (left-hand / sole) tensor operand and `B`
/// is the element type of the right-hand tensor operand or scalar; `B` defaults to `A`.
/// Tensor operands are stored as [`BoxTensor`]s.
///
/// The enum is `#[non_exhaustive]`: `match` expressions written outside this crate
/// must include a wildcard arm.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum TensorExpr<A, B = A> {
    /// Binary operator `op` applied to two tensors: `(lhs, rhs, op)`.
    Binary(BoxTensor<A>, BoxTensor<B>, BinaryOp),
    /// Binary operator `op` applied to a tensor and a scalar: `(lhs, rhs, op)`.
    BinaryScalar(BoxTensor<A>, B, BinaryOp),
    /// Unary operator `op` applied to a tensor: `(tensor, op)`.
    Unary(BoxTensor<A>, UnaryOp),
    /// Broadcast of a tensor to a shape: `(tensor, shape)`.
    Broadcast(BoxTensor<A>, Shape),
    /// Matrix multiplication of two tensors: `(lhs, rhs)`.
    Matmul(BoxTensor<A>, BoxTensor<B>),
    /// Reshape of a tensor to a shape: `(tensor, shape)`.
    Reshape(BoxTensor<A>, Shape),
    /// A shape expression wrapped in a `ReshapeExpr`.
    /// NOTE(review): `ReshapeExpr` comes from the glob import of
    /// `crate::ops::kinds::reshape`; its exact semantics are not visible here.
    Shape(ReshapeExpr<A>),
    /// Exchange of two axes of a tensor: `(tensor, swap, with)`.
    SwapAxes(BoxTensor<A>, Axis, Axis),
    /// Transpose of a tensor.
    Transpose(BoxTensor<A>),
}

impl<A, B> TensorExpr<A, B> {
    /// Builds a [`TensorExpr::Binary`] node applying `op` to the tensors `lhs` and `rhs`.
    pub fn binary(lhs: TensorBase<A>, rhs: TensorBase<B>, op: BinaryOp) -> Self {
        Self::Binary(Box::new(lhs), Box::new(rhs), op)
    }

    /// Builds a [`TensorExpr::BinaryScalar`] node applying `op` to the tensor `lhs`
    /// and the scalar `rhs`.
    pub fn binary_scalar(lhs: TensorBase<A>, rhs: B, op: BinaryOp) -> Self {
        Self::BinaryScalar(Box::new(lhs), rhs, op)
    }

    /// Like [`TensorExpr::binary_scalar`], but with a complex scalar `rhs`.
    ///
    /// The result is always a `TensorExpr<A, Complex<A>>`, independent of `B`.
    pub fn binary_scalar_c(
        lhs: TensorBase<A>,
        rhs: Complex<A>,
        op: BinaryOp,
    ) -> TensorExpr<A, Complex<A>> {
        TensorExpr::BinaryScalar(Box::new(lhs), rhs, op)
    }

    /// Builds a [`TensorExpr::Broadcast`] node for `tensor` and the target `shape`.
    pub fn broadcast(tensor: TensorBase<A>, shape: Shape) -> Self {
        Self::Broadcast(Box::new(tensor), shape)
    }

    /// Builds a [`TensorExpr::Matmul`] node for the tensors `lhs` and `rhs`.
    pub fn matmul(lhs: TensorBase<A>, rhs: TensorBase<B>) -> Self {
        Self::Matmul(Box::new(lhs), Box::new(rhs))
    }

    /// Builds a [`TensorExpr::Reshape`] node for `tensor` and the target `shape`.
    pub fn reshape(tensor: TensorBase<A>, shape: Shape) -> Self {
        Self::Reshape(Box::new(tensor), shape)
    }

    /// Wraps `expr` in a [`TensorExpr::Shape`] node.
    pub fn shape(expr: ReshapeExpr<A>) -> Self {
        Self::Shape(expr)
    }

    /// Builds a [`TensorExpr::SwapAxes`] node exchanging axis `swap` with axis `with`.
    pub fn swap_axes(tensor: TensorBase<A>, swap: Axis, with: Axis) -> Self {
        Self::SwapAxes(Box::new(tensor), swap, with)
    }

    /// Builds a [`TensorExpr::Transpose`] node for the tensor `scope`.
    pub fn transpose(scope: TensorBase<A>) -> Self {
        Self::Transpose(Box::new(scope))
    }

    /// Builds a [`TensorExpr::Unary`] node applying `op` to `tensor`.
    pub fn unary(tensor: TensorBase<A>, op: UnaryOp) -> Self {
        Self::Unary(Box::new(tensor), op)
    }

    /// Consumes the expression and returns its primary tensor operand.
    ///
    /// For binary-style nodes this is the left-hand operand. For single-operand
    /// nodes it is the operand itself. Returns `None` only for
    /// [`TensorExpr::Shape`], whose operand lives inside a `ReshapeExpr` rather
    /// than being stored directly as a tensor.
    pub fn lhs(self) -> Option<TensorBase<A>> {
        match self {
            Self::Binary(lhs, _, _) | Self::BinaryScalar(lhs, _, _) | Self::Matmul(lhs, _) => {
                Some(*lhs)
            }
            // Every single-operand node exposes its operand uniformly; `Reshape` and
            // `SwapAxes` were previously (inconsistently) reported as `None`.
            Self::Unary(tensor, _)
            | Self::Broadcast(tensor, _)
            | Self::Reshape(tensor, _)
            | Self::SwapAxes(tensor, _, _)
            | Self::Transpose(tensor) => Some(*tensor),
            Self::Shape(_) => None,
        }
    }

    /// Consumes the expression and returns its right-hand operand, if any.
    ///
    /// For [`TensorExpr::BinaryScalar`] the scalar is lifted into a tensor with
    /// `TensorBase::from_scalar`. Single-operand nodes return `None`.
    pub fn rhs(self) -> Option<TensorBase<B>> {
        match self {
            Self::Binary(_, rhs, _) | Self::Matmul(_, rhs) => Some(*rhs),
            Self::BinaryScalar(_, scalar, _) => Some(TensorBase::from_scalar(scalar)),
            Self::Unary(..)
            | Self::Broadcast(..)
            | Self::Reshape(..)
            | Self::Shape(..)
            | Self::SwapAxes(..)
            | Self::Transpose(..) => None,
        }
    }

    /// Returns a borrowed view of this expression, viewing every tensor operand and
    /// borrowing any scalar operand.
    ///
    /// # Panics
    ///
    /// Panics for [`TensorExpr::Shape`], which does not support views yet.
    pub fn view<'a>(&'a self) -> TensorExpr<&'a A, &'a B> {
        match self {
            Self::Binary(lhs, rhs, op) => TensorExpr::binary(lhs.view(), rhs.view(), *op),
            Self::BinaryScalar(lhs, rhs, op) => TensorExpr::binary_scalar(lhs.view(), rhs, *op),
            Self::Unary(tensor, op) => TensorExpr::unary(tensor.view(), *op),
            Self::Broadcast(tensor, shape) => TensorExpr::broadcast(tensor.view(), shape.clone()),
            Self::Matmul(lhs, rhs) => TensorExpr::matmul(lhs.view(), rhs.view()),
            Self::Reshape(tensor, shape) => TensorExpr::reshape(tensor.view(), shape.clone()),
            // `Axis` is guaranteed `Clone` because this enum derives `Clone`.
            Self::SwapAxes(tensor, swap, with) => {
                TensorExpr::swap_axes(tensor.view(), swap.clone(), with.clone())
            }
            Self::Transpose(tensor) => TensorExpr::transpose(tensor.view()),
            Self::Shape(_) => unimplemented!("TensorExpr::view is not implemented for Shape"),
        }
    }
}