// acme_tensor/ops/kinds/reshape.rs
/*
    Appellation: reshape <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use crate::ops::{BoxTensor, TensorExpr};
6use crate::shape::{Axis, Shape};
7use crate::tensor::TensorBase;
8use strum::{Display, EnumCount, EnumDiscriminants, EnumIs, EnumIter, EnumString, VariantNames};
9
/// An expression node describing a reshape / re-layout operation on a tensor.
///
/// Every variant stores the tensor it operates on in `recv` (a [`BoxTensor<T>`]);
/// the remaining fields are that operation's arguments. This type only records the
/// operation — how each variant is evaluated is not defined in this file.
///
/// `EnumDiscriminants` generates a fieldless companion enum, named [`ReshapeOp`]
/// via `strum_discriminants(name(...))`, with one variant per operation here.
#[derive(Clone, Debug, EnumDiscriminants, Eq, Hash, PartialEq)]
#[repr(C)]
#[strum_discriminants(derive(
    Display,
    EnumCount,
    EnumIs,
    EnumIter,
    EnumString,
    Hash,
    Ord,
    PartialOrd,
    VariantNames
))]
#[cfg_attr(
    feature = "serde",
    derive(serde::Deserialize, serde::Serialize),
    serde(rename_all = "snake_case"),
    // NOTE(review): this `strum(...)` attribute is attached to `ReshapeExpr` itself,
    // which derives no strum trait that reads `serialize_all`. As far as I can tell,
    // `EnumDiscriminants` does not forward it to `ReshapeOp`, so `ReshapeOp`'s
    // Display/EnumString keep the default variant names. Likewise, `ReshapeOp`'s serde
    // derive below receives no `rename_all`, unlike `ReshapeExpr`. If snake_case was
    // meant for `ReshapeOp`, it probably needs `strum_discriminants(strum(...))` /
    // `strum_discriminants(serde(...))` — TODO confirm against the strum docs.
    strum(serialize_all = "snake_case"),
    strum_discriminants(derive(serde::Deserialize, serde::Serialize))
)]
#[strum_discriminants(name(ReshapeOp))]
pub enum ReshapeExpr<T> {
    /// Broadcast of `recv` against a target `shape`.
    Broadcast {
        /// Tensor being broadcast.
        recv: BoxTensor<T>,
        /// Target shape (presumably the broadcast result shape — TODO confirm).
        shape: Shape,
    },
    /// Reshape of `recv` into `shape`.
    Reshape {
        /// Tensor being reshaped.
        recv: BoxTensor<T>,
        /// Requested shape.
        shape: Shape,
    },
    /// Swap of two positions given as raw `usize` indices.
    // NOTE(review): what `a`/`b` index (dimensions vs. elements) and how this differs
    // from `SwapAxis` is not visible here — TODO confirm with the evaluator.
    Swap {
        /// Tensor the swap applies to.
        recv: BoxTensor<T>,
        /// First index.
        a: usize,
        /// Second index.
        b: usize,
    },
    /// Swap of two axes given as [`Axis`] values.
    SwapAxis {
        /// Tensor the swap applies to.
        recv: BoxTensor<T>,
        /// First axis.
        a: Axis,
        /// Second axis.
        b: Axis,
    },
    /// Transpose of `recv`; takes no further arguments.
    Transpose {
        /// Tensor being transposed.
        recv: BoxTensor<T>,
    },
}
54
55impl<T> ReshapeExpr<T> {
56    pub fn broadcast(recv: TensorBase<T>, shape: Shape) -> Self {
57        Self::Broadcast {
58            recv: recv.boxed(),
59            shape,
60        }
61    }
62
63    pub fn reshape(recv: TensorBase<T>, shape: Shape) -> Self {
64        Self::Reshape {
65            recv: recv.boxed(),
66            shape,
67        }
68    }
69
70    pub fn swap(recv: TensorBase<T>, a: usize, b: usize) -> Self {
71        Self::Swap {
72            recv: recv.boxed(),
73            a,
74            b,
75        }
76    }
77
78    pub fn swap_axes(recv: TensorBase<T>, a: Axis, b: Axis) -> Self {
79        Self::SwapAxis {
80            recv: recv.boxed(),
81            a,
82            b,
83        }
84    }
85
86    pub fn transpose(recv: TensorBase<T>) -> Self {
87        Self::Transpose { recv: recv.boxed() }
88    }
89
90    pub fn recv(&self) -> &BoxTensor<T> {
91        match self {
92            Self::Broadcast { recv, .. } => recv,
93            Self::Reshape { recv, .. } => recv,
94            Self::Swap { recv, .. } => recv,
95            Self::SwapAxis { recv, .. } => recv,
96            Self::Transpose { recv } => recv,
97        }
98    }
99
100    pub fn recv_mut(&mut self) -> &mut BoxTensor<T> {
101        match self {
102            Self::Broadcast { recv, .. } => recv,
103            Self::Reshape { recv, .. } => recv,
104            Self::Swap { recv, .. } => recv,
105            Self::SwapAxis { recv, .. } => recv,
106            Self::Transpose { recv } => recv,
107        }
108    }
109
110    pub fn view(&self) -> ReshapeExpr<&T> {
111        match self {
112            Self::Broadcast { recv, shape } => ReshapeExpr::broadcast(recv.view(), shape.clone()),
113            Self::Reshape { recv, shape } => ReshapeExpr::reshape(recv.view(), shape.clone()),
114            Self::Swap { recv, a, b } => ReshapeExpr::swap(recv.view(), *a, *b),
115            Self::SwapAxis { recv, a, b } => ReshapeExpr::swap_axes(recv.view(), *a, *b),
116            Self::Transpose { recv } => ReshapeExpr::transpose(recv.view()),
117        }
118    }
119
120    pub fn view_mut(&mut self) -> ReshapeExpr<&mut T> {
121        match self {
122            Self::Broadcast { recv, shape } => {
123                ReshapeExpr::broadcast(recv.view_mut(), shape.clone())
124            }
125            Self::Reshape { recv, shape } => ReshapeExpr::reshape(recv.view_mut(), shape.clone()),
126            Self::Swap { recv, a, b } => ReshapeExpr::swap(recv.view_mut(), *a, *b),
127            Self::SwapAxis { recv, a, b } => ReshapeExpr::swap_axes(recv.view_mut(), *a, *b),
128            Self::Transpose { recv } => ReshapeExpr::transpose(recv.view_mut()),
129        }
130    }
131}
132
133impl<A, B> From<ReshapeExpr<A>> for TensorExpr<A, B> {
134    fn from(expr: ReshapeExpr<A>) -> Self {
135        TensorExpr::Shape(expr)
136    }
137}