morok_ir/uop/constructors/
shape.rs1use std::sync::Arc;
7
8use crate::Result;
9use crate::op::Op;
10use crate::uop::UOp;
11
12#[allow(dead_code)]
14impl UOp {
15 pub(crate) fn reshape(src: Arc<Self>, new_shape: Arc<Self>) -> Arc<Self> {
20 let dtype = src.dtype();
21 Self::new(Op::Reshape { src, new_shape }, dtype)
22 }
23
24 pub(crate) fn permute(src: Arc<Self>, axes: Vec<usize>) -> Arc<Self> {
28 let dtype = src.dtype();
29 Self::new(Op::Permute { src, axes }, dtype)
30 }
31
32 pub(crate) fn expand(src: Arc<Self>, new_shape: Arc<Self>) -> Arc<Self> {
37 let dtype = src.dtype();
38 Self::new(Op::Expand { src, new_shape }, dtype)
39 }
40
41 pub(crate) fn pad(src: Arc<Self>, begin_pads: Arc<Self>, end_pads: Arc<Self>) -> Arc<Self> {
46 let dtype = src.dtype();
47 Self::new(Op::Pad { src, begin_pads, end_pads }, dtype)
48 }
49
50 pub(crate) fn shrink(src: Arc<Self>, begins: Arc<Self>, ends: Arc<Self>) -> Arc<Self> {
55 let dtype = src.dtype();
56 Self::new(Op::Shrink { src, begins, ends }, dtype)
57 }
58
59 pub(crate) fn flip(src: Arc<Self>, axes: Vec<bool>) -> Arc<Self> {
63 let dtype = src.dtype();
64 Self::new(Op::Flip { src, axes }, dtype)
65 }
66}
67
68impl UOp {
70 pub fn try_reshape(self: &Arc<Self>, new_shape: &crate::shape::Shape) -> Result<Arc<Self>> {
76 use crate::error::ReshapeSizeMismatchSnafu;
77 use crate::shape::shape_to_uop;
78 use snafu::ensure;
79
80 if let Some(src_shape) = self.shape()? {
82 if src_shape.as_slice() == new_shape.as_slice()
88 && !matches!(self.op(), crate::Op::Buffer { .. } | crate::Op::Param { .. } | crate::Op::Const(_))
89 {
90 return Ok(self.clone());
91 }
92
93 let src_product = crate::sint_prod(src_shape);
94 let dst_product = crate::sint_prod(new_shape);
95
96 if let (Some(src_prod), Some(dst_prod)) = (src_product.as_const(), dst_product.as_const()) {
98 ensure!(src_prod == dst_prod, ReshapeSizeMismatchSnafu { input_size: src_prod, output_size: dst_prod });
99 }
100 }
102
103 let shape_uop = shape_to_uop(new_shape);
104 let dtype = self.dtype();
105 Ok(Self::new(Op::Reshape { src: self.clone(), new_shape: shape_uop }, dtype))
106 }
107
108 pub fn try_expand(self: &Arc<Self>, new_shape: &crate::shape::Shape) -> Result<Arc<Self>> {
114 use crate::error::ExpandDimensionMismatchSnafu;
115 use crate::error::ExpandInvalidDimensionSnafu;
116 use crate::shape::shape_to_uop;
117 use snafu::ensure;
118
119 if let Some(src_shape) = self.shape()? {
120 ensure!(
122 src_shape.len() == new_shape.len(),
123 ExpandDimensionMismatchSnafu { input_dims: src_shape.len(), output_dims: new_shape.len() }
124 );
125
126 for (dim_idx, (src_dim, new_dim)) in src_shape.iter().zip(new_shape.iter()).enumerate() {
128 if let (Some(s), Some(ns)) = (src_dim.as_const(), new_dim.as_const()) {
130 ensure!(s == ns || s == 1, ExpandInvalidDimensionSnafu { dim: dim_idx, input: s, output: ns });
131 }
132 }
134
135 if src_shape.as_slice() == new_shape.as_slice() {
137 return Ok(self.clone());
138 }
139 }
140
141 let shape_uop = shape_to_uop(new_shape);
142 let dtype = self.dtype();
143 Ok(Self::new(Op::Expand { src: self.clone(), new_shape: shape_uop }, dtype))
144 }
145
146 pub fn try_permute(self: &Arc<Self>, axes: Vec<usize>) -> Result<Arc<Self>> {
151 if let Some(src_shape) = self.shape()? {
153 Self::validate_permutation(&axes, src_shape.len())?;
154
155 if axes.iter().enumerate().all(|(i, &a)| a == i) {
157 return Ok(self.clone());
158 }
159 }
160
161 let dtype = self.dtype();
162 Ok(Self::new(Op::Permute { src: self.clone(), axes }, dtype))
163 }
164
165 pub fn try_pad(self: &Arc<Self>, padding: &[(crate::SInt, crate::SInt)]) -> Result<Arc<Self>> {
171 use crate::error::{PadDimensionMismatchSnafu, SymbolicPaddingUnsupportedSnafu};
172 use crate::shape::ranges_to_uops;
173 use snafu::ensure;
174
175 if padding.is_empty() {
177 return Ok(self.clone());
178 }
179
180 for (begin, end) in padding {
182 ensure!(begin.is_const(), SymbolicPaddingUnsupportedSnafu);
183 ensure!(end.is_const(), SymbolicPaddingUnsupportedSnafu);
184 }
185
186 if padding.iter().all(|(b, e)| b.as_const() == Some(0) && e.as_const() == Some(0)) {
188 return Ok(self.clone());
189 }
190
191 if let Some(src_shape) = self.shape()? {
192 ensure!(
194 padding.len() == src_shape.len(),
195 PadDimensionMismatchSnafu { padding_dims: padding.len(), shape_dims: src_shape.len() }
196 );
197 }
198
199 let (begin_pads, end_pads) = ranges_to_uops(padding);
200 let dtype = self.dtype();
201 Ok(Self::new(Op::Pad { src: self.clone(), begin_pads, end_pads }, dtype))
202 }
203
204 pub fn try_shrink(self: &Arc<Self>, ranges: &[(crate::SInt, crate::SInt)]) -> Result<Arc<Self>> {
211 use crate::error::ShrinkBoundsViolationSnafu;
212 use crate::shape::ranges_to_uops;
213 use snafu::ensure;
214
215 if ranges.is_empty() {
217 return Ok(self.clone());
218 }
219
220 if let Some(src_shape) = self.shape()? {
223 for (dim_idx, ((begin, end), dim_size)) in ranges.iter().zip(src_shape.iter()).enumerate() {
224 if let (Some(b), Some(e), Some(s)) = (begin.as_const(), end.as_const(), dim_size.as_const()) {
225 ensure!(
226 b <= e && e <= s,
227 ShrinkBoundsViolationSnafu { dim: dim_idx, begin: b, end: e, shape_size: s }
228 );
229 }
230 }
231
232 if ranges.iter().zip(src_shape.iter()).all(|((b, e), d)| b.as_const() == Some(0) && *e == *d) {
234 return Ok(self.clone());
235 }
236 }
237
238 let (begins, ends) = ranges_to_uops(ranges);
239 let dtype = self.dtype();
240 let result = Self::new(Op::Shrink { src: self.clone(), begins, ends }, dtype);
241 if result.shape().ok().flatten() == self.shape().ok().flatten() {
243 return Ok(self.clone());
244 }
245 Ok(result)
246 }
247
248 pub fn try_flip(self: &Arc<Self>, axes: Vec<bool>) -> Result<Arc<Self>> {
253 if !axes.iter().any(|&a| a) {
255 return Ok(self.clone());
256 }
257
258 if let Some(src_shape) = self.shape()? {
259 Self::validate_flip_axes(&axes, src_shape.len())?;
260 }
261
262 let dtype = self.dtype();
263 Ok(Self::new(Op::Flip { src: self.clone(), axes }, dtype))
264 }
265
266 pub fn multi(src: Arc<Self>, axis: usize) -> Arc<Self> {
271 let dtype = src.dtype();
272 Self::new(Op::Multi { src, axis }, dtype)
273 }
274}