1use crate::ops::kinds::reshape::*;
6use crate::shape::{Axis, Shape};
7use crate::tensor::TensorBase;
8use acme::prelude::{BinaryOp, UnaryOp};
9use num::Complex;
10
/// A [`TensorBase`] stored behind a [`Box`]; the element type `T` defaults to `f64`.
pub type BoxTensor<T = f64> = Box<TensorBase<T>>;
12
/// A single operation together with the operand(s) it is applied to.
///
/// `A` is the element type of the primary (left-hand) tensor and `B` the element type of the
/// secondary operand; `B` defaults to `A`. Tensor operands are stored boxed (see [`BoxTensor`]).
///
/// The enum is `#[non_exhaustive]`, so matches written outside this crate need a wildcard arm.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[non_exhaustive]
pub enum TensorExpr<A, B = A> {
    /// Two tensor operands (left, right) and the [`BinaryOp`] relating them.
    Binary(BoxTensor<A>, BoxTensor<B>, BinaryOp),
    /// A tensor operand, a bare value of type `B`, and the [`BinaryOp`] relating them.
    BinaryScalar(BoxTensor<A>, B, BinaryOp),
    /// A single tensor operand and the [`UnaryOp`] applied to it.
    Unary(BoxTensor<A>, UnaryOp),
    /// Two tensor operands (left, right); presumably a matrix product — confirm at the use site.
    Matmul(BoxTensor<A>, BoxTensor<B>),
    /// A single tensor operand; presumably the sigmoid activation — confirm at the use site.
    Sigmoid(BoxTensor<A>),
    /// A shape-related operation described by a [`ReshapeExpr`].
    Shape(ReshapeExpr<A>),
}
24
25impl<A, B> TensorExpr<A, B> {
26 pub fn binary(lhs: TensorBase<A>, rhs: TensorBase<B>, op: BinaryOp) -> Self {
27 Self::Binary(lhs.boxed(), rhs.boxed(), op)
28 }
29
30 pub fn binary_scalar(lhs: TensorBase<A>, rhs: B, op: BinaryOp) -> Self {
31 Self::BinaryScalar(Box::new(lhs), rhs, op)
32 }
33
34 pub fn binary_scalar_c(
35 lhs: TensorBase<A>,
36 rhs: Complex<A>,
37 op: BinaryOp,
38 ) -> TensorExpr<A, Complex<A>> {
39 TensorExpr::BinaryScalar(Box::new(lhs), rhs, op)
40 }
41
42 pub fn broadcast(tensor: TensorBase<A>, shape: Shape) -> Self {
43 Self::shape(ReshapeExpr::broadcast(tensor, shape))
44 }
45
46 pub fn matmul(lhs: TensorBase<A>, rhs: TensorBase<B>) -> Self {
47 Self::Matmul(Box::new(lhs), Box::new(rhs))
48 }
49
50 pub fn reshape(tensor: TensorBase<A>, shape: Shape) -> Self {
51 Self::shape(ReshapeExpr::reshape(tensor, shape))
52 }
53
54 pub fn shape(expr: ReshapeExpr<A>) -> Self {
55 Self::Shape(expr)
56 }
57
58 pub fn sigmoid(tensor: TensorBase<A>) -> Self {
59 Self::Sigmoid(Box::new(tensor))
60 }
61
62 pub fn swap_axes(tensor: TensorBase<A>, swap: Axis, with: Axis) -> Self {
63 Self::shape(ReshapeExpr::swap_axes(tensor, swap, with))
64 }
65
66 pub fn transpose(tensor: TensorBase<A>) -> Self {
67 Self::Shape(ReshapeExpr::transpose(tensor))
68 }
69
70 pub fn unary(tensor: TensorBase<A>, op: UnaryOp) -> Self {
71 Self::Unary(Box::new(tensor), op)
72 }
73
74 pub fn lhs(self) -> Option<TensorBase<A>> {
75 match self {
76 Self::Binary(lhs, _, _) => Some(*lhs),
77 Self::BinaryScalar(lhs, _, _) => Some(*lhs),
78 Self::Unary(lhs, _) => Some(*lhs),
79 Self::Matmul(lhs, _) => Some(*lhs),
80 _ => None,
81 }
82 }
83
84 pub fn rhs(self) -> Option<TensorBase<B>> {
85 match self {
86 Self::Binary(_, rhs, _) => Some(*rhs),
87 Self::BinaryScalar(_, scalar, _) => Some(TensorBase::from_scalar(scalar)),
88 Self::Matmul(_, rhs) => Some(*rhs),
89 _ => None,
90 }
91 }
92 pub fn view(&self) -> TensorExpr<&A, &B> {
93 match self {
94 TensorExpr::Binary(lhs, rhs, op) => TensorExpr::binary(lhs.view(), rhs.view(), *op),
95 TensorExpr::BinaryScalar(lhs, rhs, op) => {
96 TensorExpr::binary_scalar(lhs.view(), rhs, *op)
97 }
98 TensorExpr::Unary(tensor, op) => TensorExpr::unary(tensor.view(), *op),
99 TensorExpr::Matmul(lhs, rhs) => TensorExpr::matmul(lhs.view(), rhs.view()),
100 TensorExpr::Sigmoid(tensor) => TensorExpr::sigmoid(tensor.view()),
101 TensorExpr::Shape(inner) => TensorExpr::Shape(inner.view()),
102 }
103 }
104 pub fn view_mut(&mut self) -> TensorExpr<&mut A, &mut B> {
105 match self {
106 TensorExpr::Binary(lhs, rhs, op) => {
107 TensorExpr::binary(lhs.view_mut(), rhs.view_mut(), *op)
108 }
109
110 TensorExpr::BinaryScalar(lhs, rhs, op) => {
111 TensorExpr::binary_scalar(lhs.view_mut(), rhs, *op)
112 }
113 TensorExpr::Unary(tensor, op) => TensorExpr::unary(tensor.view_mut(), *op),
114 TensorExpr::Matmul(lhs, rhs) => TensorExpr::matmul(lhs.view_mut(), rhs.view_mut()),
115 TensorExpr::Sigmoid(tensor) => TensorExpr::sigmoid(tensor.view_mut()),
116 TensorExpr::Shape(inner) => TensorExpr::Shape(inner.view_mut()),
117 }
118 }
119}