1use crate::ops::{BoxTensor, TensorExpr};
6use crate::shape::{Axis, Shape};
7use crate::tensor::TensorBase;
8use strum::{Display, EnumCount, EnumDiscriminants, EnumIs, EnumIter, EnumString, VariantNames};
9
/// A shape-manipulating expression node applied to a receiver tensor.
///
/// Every variant stores its operand as `recv` (a [`BoxTensor`]) together with the
/// parameters of the operation. This type only *records* the operation; no evaluation
/// logic is visible in this file.
///
/// `EnumDiscriminants` generates the fieldless companion enum [`ReshapeOp`] (one unit
/// variant per variant here). It derives `Display`, `EnumCount`, `EnumIs`, `EnumIter`,
/// `EnumString`, `Hash`, `Ord`, `PartialOrd` and `VariantNames`. With the `serde`
/// feature it also derives serde's `Serialize`/`Deserialize`.
///
/// Variant order is load-bearing. The derived `Ord`/`PartialOrd` and the `EnumIter`
/// order of [`ReshapeOp`] follow declaration order, so do not reorder variants casually.
#[derive(Clone, Debug, EnumDiscriminants, Eq, Hash, PartialEq)]
#[repr(C)]
#[strum_discriminants(derive(
    Display,
    EnumCount,
    EnumIs,
    EnumIter,
    EnumString,
    Hash,
    Ord,
    PartialOrd,
    VariantNames
))]
#[cfg_attr(
    feature = "serde",
    derive(serde::Deserialize, serde::Serialize),
    // Serde tags variants in snake_case (e.g. `SwapAxis` -> "swap_axis").
    serde(rename_all = "snake_case"),
    // NOTE(review): this `strum(...)` attribute sits on `ReshapeExpr` itself, which derives
    // no strum string traits, and it only exists when `serde` is enabled. It may have no
    // effect on `ReshapeOp`'s `Display`/`EnumString` (that would normally be spelled
    // `strum_discriminants(strum(serialize_all = ...))`). Confirm the intended behavior.
    strum(serialize_all = "snake_case"),
    strum_discriminants(derive(serde::Deserialize, serde::Serialize))
)]
#[strum_discriminants(name(ReshapeOp))]
pub enum ReshapeExpr<T> {
    /// Broadcast of `recv` to the target `shape`.
    Broadcast {
        /// The operand tensor.
        recv: BoxTensor<T>,
        /// The target shape.
        shape: Shape,
    },
    /// Reshape of `recv` into the target `shape`.
    Reshape {
        /// The operand tensor.
        recv: BoxTensor<T>,
        /// The target shape.
        shape: Shape,
    },
    /// Swap of two positions of `recv`, given as raw `usize` indices.
    // NOTE(review): presumably these are axis indices (cf. `SwapAxis`); confirm against the evaluator.
    Swap {
        /// The operand tensor.
        recv: BoxTensor<T>,
        /// First index to swap.
        a: usize,
        /// Second index to swap.
        b: usize,
    },
    /// Swap of two axes of `recv`, identified by [`Axis`] values.
    ///
    /// Built by the `ReshapeExpr::swap_axes` constructor (note the naming difference).
    SwapAxis {
        /// The operand tensor.
        recv: BoxTensor<T>,
        /// First axis to swap.
        a: Axis,
        /// Second axis to swap.
        b: Axis,
    },
    /// Transpose of `recv`; carries no parameters beyond the receiver.
    Transpose {
        /// The operand tensor.
        recv: BoxTensor<T>,
    },
}
54
55impl<T> ReshapeExpr<T> {
56 pub fn broadcast(recv: TensorBase<T>, shape: Shape) -> Self {
57 Self::Broadcast {
58 recv: recv.boxed(),
59 shape,
60 }
61 }
62
63 pub fn reshape(recv: TensorBase<T>, shape: Shape) -> Self {
64 Self::Reshape {
65 recv: recv.boxed(),
66 shape,
67 }
68 }
69
70 pub fn swap(recv: TensorBase<T>, a: usize, b: usize) -> Self {
71 Self::Swap {
72 recv: recv.boxed(),
73 a,
74 b,
75 }
76 }
77
78 pub fn swap_axes(recv: TensorBase<T>, a: Axis, b: Axis) -> Self {
79 Self::SwapAxis {
80 recv: recv.boxed(),
81 a,
82 b,
83 }
84 }
85
86 pub fn transpose(recv: TensorBase<T>) -> Self {
87 Self::Transpose { recv: recv.boxed() }
88 }
89
90 pub fn recv(&self) -> &BoxTensor<T> {
91 match self {
92 Self::Broadcast { recv, .. } => recv,
93 Self::Reshape { recv, .. } => recv,
94 Self::Swap { recv, .. } => recv,
95 Self::SwapAxis { recv, .. } => recv,
96 Self::Transpose { recv } => recv,
97 }
98 }
99
100 pub fn recv_mut(&mut self) -> &mut BoxTensor<T> {
101 match self {
102 Self::Broadcast { recv, .. } => recv,
103 Self::Reshape { recv, .. } => recv,
104 Self::Swap { recv, .. } => recv,
105 Self::SwapAxis { recv, .. } => recv,
106 Self::Transpose { recv } => recv,
107 }
108 }
109
110 pub fn view(&self) -> ReshapeExpr<&T> {
111 match self {
112 Self::Broadcast { recv, shape } => ReshapeExpr::broadcast(recv.view(), shape.clone()),
113 Self::Reshape { recv, shape } => ReshapeExpr::reshape(recv.view(), shape.clone()),
114 Self::Swap { recv, a, b } => ReshapeExpr::swap(recv.view(), *a, *b),
115 Self::SwapAxis { recv, a, b } => ReshapeExpr::swap_axes(recv.view(), *a, *b),
116 Self::Transpose { recv } => ReshapeExpr::transpose(recv.view()),
117 }
118 }
119
120 pub fn view_mut(&mut self) -> ReshapeExpr<&mut T> {
121 match self {
122 Self::Broadcast { recv, shape } => {
123 ReshapeExpr::broadcast(recv.view_mut(), shape.clone())
124 }
125 Self::Reshape { recv, shape } => ReshapeExpr::reshape(recv.view_mut(), shape.clone()),
126 Self::Swap { recv, a, b } => ReshapeExpr::swap(recv.view_mut(), *a, *b),
127 Self::SwapAxis { recv, a, b } => ReshapeExpr::swap_axes(recv.view_mut(), *a, *b),
128 Self::Transpose { recv } => ReshapeExpr::transpose(recv.view_mut()),
129 }
130 }
131}
132
133impl<A, B> From<ReshapeExpr<A>> for TensorExpr<A, B> {
134 fn from(expr: ReshapeExpr<A>) -> Self {
135 TensorExpr::Shape(expr)
136 }
137}