Skip to main content

morok_ir/uop/constructors/
shape.rs

1//! Shape manipulation: reshape, permute, expand, pad, shrink, flip.
2//!
3//! These operations manipulate tensor shapes and layouts without changing
4//! the underlying data (except for padding which may add values).
5
6use std::sync::Arc;
7
8use crate::Result;
9use crate::op::Op;
10use crate::uop::UOp;
11
12// Low-level constructors (pub(crate) - not yet used but will be needed for optimization passes)
13#[allow(dead_code)]
14impl UOp {
15    /// Reshape tensor to new shape (low-level, UOp-based constructor).
16    ///
17    /// Takes a UOp for the shape parameter (used internally by compiler passes).
18    /// For the public API with validation, use `try_reshape`.
19    pub(crate) fn reshape(src: Arc<Self>, new_shape: Arc<Self>) -> Arc<Self> {
20        let dtype = src.dtype();
21        Self::new(Op::Reshape { src, new_shape }, dtype)
22    }
23
24    /// Permute dimensions (low-level, UOp-based constructor).
25    ///
26    /// For the public API with validation, use `try_permute`.
27    pub(crate) fn permute(src: Arc<Self>, axes: Vec<usize>) -> Arc<Self> {
28        let dtype = src.dtype();
29        Self::new(Op::Permute { src, axes }, dtype)
30    }
31
32    /// Expand (broadcast) dimensions (low-level, UOp-based constructor).
33    ///
34    /// Takes a UOp for the shape parameter (used internally by compiler passes).
35    /// For the public API with validation, use `try_expand`.
36    pub(crate) fn expand(src: Arc<Self>, new_shape: Arc<Self>) -> Arc<Self> {
37        let dtype = src.dtype();
38        Self::new(Op::Expand { src, new_shape }, dtype)
39    }
40
41    /// Pad tensor (low-level, UOp-based constructor).
42    ///
43    /// Takes UOps for padding parameters (used internally by compiler passes).
44    /// For the public API with validation, use `try_pad`.
45    pub(crate) fn pad(src: Arc<Self>, begin_pads: Arc<Self>, end_pads: Arc<Self>) -> Arc<Self> {
46        let dtype = src.dtype();
47        Self::new(Op::Pad { src, begin_pads, end_pads }, dtype)
48    }
49
50    /// Shrink (slice) tensor (low-level, UOp-based constructor).
51    ///
52    /// Takes UOps for range parameters (used internally by compiler passes).
53    /// For the public API with validation, use `try_shrink`.
54    pub(crate) fn shrink(src: Arc<Self>, begins: Arc<Self>, ends: Arc<Self>) -> Arc<Self> {
55        let dtype = src.dtype();
56        Self::new(Op::Shrink { src, begins, ends }, dtype)
57    }
58
59    /// Flip (reverse) axes (low-level, UOp-based constructor).
60    ///
61    /// For the public API with validation, use `try_flip`.
62    pub(crate) fn flip(src: Arc<Self>, axes: Vec<bool>) -> Arc<Self> {
63        let dtype = src.dtype();
64        Self::new(Op::Flip { src, axes }, dtype)
65    }
66}
67
68// Primary Movement Operation Constructors (with validation)
69impl UOp {
70    /// Reshape with strict validation (fail-fast).
71    ///
72    /// Validates:
73    /// - No negative dimensions in new_shape
74    /// - Product of input shape == product of output shape
75    pub fn try_reshape(self: &Arc<Self>, new_shape: &crate::shape::Shape) -> Result<Arc<Self>> {
76        use crate::error::ReshapeSizeMismatchSnafu;
77        use crate::shape::shape_to_uop;
78        use snafu::ensure;
79
80        // Validate product equality if source shape is known
81        if let Some(src_shape) = self.shape()? {
82            // Identity reshape: skip if shapes already match.
83            // Exclude BUFFER and CONST: the rangeify pipeline requires RESHAPE(BUFFER) to
84            // generate INDEX operations for ASSIGN targets. Bare BUFFER can't be indexed.
85            // Tinygrad avoids this because their BUFFER carries shape natively and their
86            // rangeify handles bare BUFFER → INDEX directly.
87            if src_shape.as_slice() == new_shape.as_slice()
88                && !matches!(self.op(), crate::Op::Buffer { .. } | crate::Op::Param { .. } | crate::Op::Const(_))
89            {
90                return Ok(self.clone());
91            }
92
93            let src_product = crate::sint_prod(src_shape);
94            let dst_product = crate::sint_prod(new_shape);
95
96            // If both are concrete, validate equality
97            if let (Some(src_prod), Some(dst_prod)) = (src_product.as_const(), dst_product.as_const()) {
98                ensure!(src_prod == dst_prod, ReshapeSizeMismatchSnafu { input_size: src_prod, output_size: dst_prod });
99            }
100            // Symbolic products can't be validated at compile time
101        }
102
103        let shape_uop = shape_to_uop(new_shape);
104        let dtype = self.dtype();
105        Ok(Self::new(Op::Reshape { src: self.clone(), new_shape: shape_uop }, dtype))
106    }
107
108    /// Expand (broadcast) with strict validation.
109    ///
110    /// Validates:
111    /// - Number of dimensions matches
112    /// - Each dimension either matches or src dimension is 1
113    pub fn try_expand(self: &Arc<Self>, new_shape: &crate::shape::Shape) -> Result<Arc<Self>> {
114        use crate::error::ExpandDimensionMismatchSnafu;
115        use crate::error::ExpandInvalidDimensionSnafu;
116        use crate::shape::shape_to_uop;
117        use snafu::ensure;
118
119        if let Some(src_shape) = self.shape()? {
120            // Check dimension count
121            ensure!(
122                src_shape.len() == new_shape.len(),
123                ExpandDimensionMismatchSnafu { input_dims: src_shape.len(), output_dims: new_shape.len() }
124            );
125
126            // Check each dimension can be expanded
127            for (dim_idx, (src_dim, new_dim)) in src_shape.iter().zip(new_shape.iter()).enumerate() {
128                // If both are concrete, validate expand rule
129                if let (Some(s), Some(ns)) = (src_dim.as_const(), new_dim.as_const()) {
130                    ensure!(s == ns || s == 1, ExpandInvalidDimensionSnafu { dim: dim_idx, input: s, output: ns });
131                }
132                // Symbolic dimensions assumed compatible
133            }
134
135            // Identity expand: skip if shapes already match
136            if src_shape.as_slice() == new_shape.as_slice() {
137                return Ok(self.clone());
138            }
139        }
140
141        let shape_uop = shape_to_uop(new_shape);
142        let dtype = self.dtype();
143        Ok(Self::new(Op::Expand { src: self.clone(), new_shape: shape_uop }, dtype))
144    }
145
146    /// Permute with strict validation.
147    ///
148    /// Validates:
149    /// - Permutation is valid (contains each index 0..n exactly once)
150    pub fn try_permute(self: &Arc<Self>, axes: Vec<usize>) -> Result<Arc<Self>> {
151        // Validate permutation if source shape is known
152        if let Some(src_shape) = self.shape()? {
153            Self::validate_permutation(&axes, src_shape.len())?;
154
155            // Identity permute: skip if axes is [0, 1, 2, ..., n-1]
156            if axes.iter().enumerate().all(|(i, &a)| a == i) {
157                return Ok(self.clone());
158            }
159        }
160
161        let dtype = self.dtype();
162        Ok(Self::new(Op::Permute { src: self.clone(), axes }, dtype))
163    }
164
165    /// Pad with strict validation.
166    ///
167    /// Validates:
168    /// - Padding values are concrete (not symbolic)
169    /// - Number of padding pairs matches dimensions
170    pub fn try_pad(self: &Arc<Self>, padding: &[(crate::SInt, crate::SInt)]) -> Result<Arc<Self>> {
171        use crate::error::{PadDimensionMismatchSnafu, SymbolicPaddingUnsupportedSnafu};
172        use crate::shape::ranges_to_uops;
173        use snafu::ensure;
174
175        // Empty padding (scalar) → identity
176        if padding.is_empty() {
177            return Ok(self.clone());
178        }
179
180        // Check for symbolic padding values
181        for (begin, end) in padding {
182            ensure!(begin.is_const(), SymbolicPaddingUnsupportedSnafu);
183            ensure!(end.is_const(), SymbolicPaddingUnsupportedSnafu);
184        }
185
186        // All-zero padding: no-op
187        if padding.iter().all(|(b, e)| b.as_const() == Some(0) && e.as_const() == Some(0)) {
188            return Ok(self.clone());
189        }
190
191        if let Some(src_shape) = self.shape()? {
192            // Check dimension count
193            ensure!(
194                padding.len() == src_shape.len(),
195                PadDimensionMismatchSnafu { padding_dims: padding.len(), shape_dims: src_shape.len() }
196            );
197        }
198
199        let (begin_pads, end_pads) = ranges_to_uops(padding);
200        let dtype = self.dtype();
201        Ok(Self::new(Op::Pad { src: self.clone(), begin_pads, end_pads }, dtype))
202    }
203
204    /// Shrink (slice) with strict validation.
205    ///
206    /// Validates:
207    /// - Range values are concrete (not symbolic)
208    /// - begin <= end for each dimension
209    /// - 0 <= begin, end <= dimension_size
210    pub fn try_shrink(self: &Arc<Self>, ranges: &[(crate::SInt, crate::SInt)]) -> Result<Arc<Self>> {
211        use crate::error::ShrinkBoundsViolationSnafu;
212        use crate::shape::ranges_to_uops;
213        use snafu::ensure;
214
215        // Empty ranges (scalar) → identity
216        if ranges.is_empty() {
217            return Ok(self.clone());
218        }
219
220        // Symbolic shrink ranges are allowed — the rangeify pipeline handles
221        // symbolic range ends. Only concrete ranges are bounds-checked.
222        if let Some(src_shape) = self.shape()? {
223            for (dim_idx, ((begin, end), dim_size)) in ranges.iter().zip(src_shape.iter()).enumerate() {
224                if let (Some(b), Some(e), Some(s)) = (begin.as_const(), end.as_const(), dim_size.as_const()) {
225                    ensure!(
226                        b <= e && e <= s,
227                        ShrinkBoundsViolationSnafu { dim: dim_idx, begin: b, end: e, shape_size: s }
228                    );
229                }
230            }
231
232            // Identity shrink: skip if all ranges span the full dimension
233            if ranges.iter().zip(src_shape.iter()).all(|((b, e), d)| b.as_const() == Some(0) && *e == *d) {
234                return Ok(self.clone());
235            }
236        }
237
238        let (begins, ends) = ranges_to_uops(ranges);
239        let dtype = self.dtype();
240        let result = Self::new(Op::Shrink { src: self.clone(), begins, ends }, dtype);
241        // Tinygrad (movement.py:128): return self if ret.shape == self.shape else ret
242        if result.shape().ok().flatten() == self.shape().ok().flatten() {
243            return Ok(self.clone());
244        }
245        Ok(result)
246    }
247
248    /// Flip with strict validation.
249    ///
250    /// Validates:
251    /// - Flip specification length matches shape dimensions
252    pub fn try_flip(self: &Arc<Self>, axes: Vec<bool>) -> Result<Arc<Self>> {
253        // All-false flip: no-op
254        if !axes.iter().any(|&a| a) {
255            return Ok(self.clone());
256        }
257
258        if let Some(src_shape) = self.shape()? {
259            Self::validate_flip_axes(&axes, src_shape.len())?;
260        }
261
262        let dtype = self.dtype();
263        Ok(Self::new(Op::Flip { src: self.clone(), axes }, dtype))
264    }
265
266    /// Split tensor across multiple devices along specified axis.
267    ///
268    /// Creates a multi-device tensor where each device holds a shard.
269    /// Use with MSTACK/MSELECT for distributed tensor operations.
270    pub fn multi(src: Arc<Self>, axis: usize) -> Arc<Self> {
271        let dtype = src.dtype();
272        Self::new(Op::Multi { src, axis }, dtype)
273    }
274}