1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
//! Shape manipulation: reshape, permute, expand, pad, shrink, flip.
//!
//! These operations manipulate tensor shapes and layouts without changing
//! the underlying data (except for padding which may add values).
use std::sync::Arc;
use crate::Result;
use crate::op::Op;
use crate::uop::UOp;
// Low-level constructors. They are pub(crate) and currently unused; optimization
// passes are expected to need them, hence the dead_code allowance below.
#[allow(dead_code)]
impl UOp {
    /// Wrap `src` in a RESHAPE node without any validation.
    ///
    /// `new_shape` is itself a UOp describing the target shape (the form compiler
    /// passes work with). The result keeps the dtype of `src`.
    /// Use `try_reshape` for the validated public API.
    pub(crate) fn reshape(src: Arc<Self>, new_shape: Arc<Self>) -> Arc<Self> {
        // Read the dtype before `src` is moved into the op.
        let out_dtype = src.dtype();
        let op = Op::Reshape { src, new_shape };
        Self::new(op, out_dtype)
    }

    /// Wrap `src` in a PERMUTE node without any validation.
    ///
    /// `axes` is stored as given. The result keeps the dtype of `src`.
    /// Use `try_permute` for the validated public API.
    pub(crate) fn permute(src: Arc<Self>, axes: Vec<usize>) -> Arc<Self> {
        let out_dtype = src.dtype();
        let op = Op::Permute { src, axes };
        Self::new(op, out_dtype)
    }

    /// Wrap `src` in an EXPAND (broadcast) node without any validation.
    ///
    /// `new_shape` is a UOp describing the target shape (the form compiler passes
    /// work with). The result keeps the dtype of `src`.
    /// Use `try_expand` for the validated public API.
    pub(crate) fn expand(src: Arc<Self>, new_shape: Arc<Self>) -> Arc<Self> {
        let out_dtype = src.dtype();
        let op = Op::Expand { src, new_shape };
        Self::new(op, out_dtype)
    }

    /// Wrap `src` in a PAD node without any validation.
    ///
    /// `begin_pads` / `end_pads` are UOps carrying the padding amounts (the form
    /// compiler passes work with). The result keeps the dtype of `src`.
    /// Use `try_pad` for the validated public API.
    pub(crate) fn pad(src: Arc<Self>, begin_pads: Arc<Self>, end_pads: Arc<Self>) -> Arc<Self> {
        let out_dtype = src.dtype();
        let op = Op::Pad { src, begin_pads, end_pads };
        Self::new(op, out_dtype)
    }

    /// Wrap `src` in a SHRINK (slice) node without any validation.
    ///
    /// `begins` / `ends` are UOps carrying the range bounds (the form compiler
    /// passes work with). The result keeps the dtype of `src`.
    /// Use `try_shrink` for the validated public API.
    pub(crate) fn shrink(src: Arc<Self>, begins: Arc<Self>, ends: Arc<Self>) -> Arc<Self> {
        let out_dtype = src.dtype();
        let op = Op::Shrink { src, begins, ends };
        Self::new(op, out_dtype)
    }

    /// Wrap `src` in a FLIP (reverse) node without any validation.
    ///
    /// `axes` is stored as given. The result keeps the dtype of `src`.
    /// Use `try_flip` for the validated public API.
    pub(crate) fn flip(src: Arc<Self>, axes: Vec<bool>) -> Arc<Self> {
        let out_dtype = src.dtype();
        let op = Op::Flip { src, axes };
        Self::new(op, out_dtype)
    }
}
// Primary Movement Operation Constructors (with validation)
impl UOp {
/// Build a RESHAPE of `self` to `new_shape`, failing fast on size mismatches.
///
/// When the source shape is known:
/// - `self` is returned unchanged if `new_shape` already equals it, unless
///   `self` is a `Buffer`, `Param` or `Const` op (those always receive an
///   explicit RESHAPE node);
/// - if the element counts of both shapes are concrete, they must be equal.
///
/// Symbolic element counts, and sources whose shape is unknown, are not checked.
///
/// # Errors
/// Propagates any error from `self.shape()`, and returns a
/// `ReshapeSizeMismatch` error when the concrete element counts differ.
pub fn try_reshape(self: &Arc<Self>, new_shape: &crate::shape::Shape) -> Result<Arc<Self>> {
    use crate::error::ReshapeSizeMismatchSnafu;
    use crate::shape::shape_to_uop;
    use snafu::ensure;

    if let Some(current) = self.shape()? {
        // Identity reshape short-circuit.
        // Buffer, Param and Const sources are exempt: the rangeify pipeline requires
        // RESHAPE(BUFFER) to generate INDEX operations for ASSIGN targets, because a
        // bare BUFFER can't be indexed. Tinygrad avoids this because their BUFFER
        // carries shape natively and their rangeify handles bare BUFFER → INDEX directly.
        let already_shaped = current.as_slice() == new_shape.as_slice();
        let must_wrap = matches!(self.op(), crate::Op::Buffer { .. } | crate::Op::Param { .. } | crate::Op::Const(_));
        if already_shaped && !must_wrap {
            return Ok(self.clone());
        }

        let input_elems = crate::sint_prod(current);
        let output_elems = crate::sint_prod(new_shape);
        // Only concrete totals can be compared here; symbolic ones are left unchecked.
        if let Some((have, want)) = input_elems.as_const().zip(output_elems.as_const()) {
            ensure!(have == want, ReshapeSizeMismatchSnafu { input_size: have, output_size: want });
        }
    }

    let target = shape_to_uop(new_shape);
    let out_dtype = self.dtype();
    let op = Op::Reshape { src: self.clone(), new_shape: target };
    Ok(Self::new(op, out_dtype))
}
/// Build an EXPAND (broadcast) of `self` to `new_shape`, with validation.
///
/// When the source shape is known:
/// - the rank of `new_shape` must equal the source rank;
/// - for every axis where both sizes are concrete, the source size must either
///   equal the target size or be `1`;
/// - if `new_shape` equals the source shape, `self` is returned unchanged.
///
/// Axes with a symbolic size on either side are assumed compatible.
///
/// # Errors
/// Propagates any error from `self.shape()`, and returns
/// `ExpandDimensionMismatch` / `ExpandInvalidDimension` when the checks above fail.
pub fn try_expand(self: &Arc<Self>, new_shape: &crate::shape::Shape) -> Result<Arc<Self>> {
    use crate::error::ExpandDimensionMismatchSnafu;
    use crate::error::ExpandInvalidDimensionSnafu;
    use crate::shape::shape_to_uop;
    use snafu::ensure;

    if let Some(current) = self.shape()? {
        // Rank must agree before axes can be compared pairwise.
        let in_rank = current.len();
        let out_rank = new_shape.len();
        ensure!(in_rank == out_rank, ExpandDimensionMismatchSnafu { input_dims: in_rank, output_dims: out_rank });

        // Per-axis broadcast rule, applied only where both sizes are concrete.
        for (axis, (have, want)) in current.iter().zip(new_shape.iter()).enumerate() {
            if let Some((have, want)) = have.as_const().zip(want.as_const()) {
                ensure!(have == want || have == 1, ExpandInvalidDimensionSnafu { dim: axis, input: have, output: want });
            }
        }

        // Nothing to broadcast: hand back the source node itself.
        if current.as_slice() == new_shape.as_slice() {
            return Ok(self.clone());
        }
    }

    let target = shape_to_uop(new_shape);
    let out_dtype = self.dtype();
    let op = Op::Expand { src: self.clone(), new_shape: target };
    Ok(Self::new(op, out_dtype))
}
/// Build a PERMUTE of `self` with the given axis order, with validation.
///
/// When the source shape is known, `axes` is checked with
/// `validate_permutation` against the source rank, and an identity
/// permutation (`[0, 1, ..., n-1]`) returns `self` unchanged.
/// When the shape is unknown the node is built without any checks.
///
/// # Errors
/// Propagates errors from `self.shape()` and from `validate_permutation`.
pub fn try_permute(self: &Arc<Self>, axes: Vec<usize>) -> Result<Arc<Self>> {
    if let Some(current) = self.shape()? {
        Self::validate_permutation(&axes, current.len())?;

        // `axes` is the identity exactly when it reads 0, 1, 2, ... in order.
        let is_identity = axes.iter().copied().eq(0..axes.len());
        if is_identity {
            return Ok(self.clone());
        }
    }

    let out_dtype = self.dtype();
    let op = Op::Permute { src: self.clone(), axes };
    Ok(Self::new(op, out_dtype))
}
/// Build a PAD of `self` with per-dimension `(begin, end)` amounts, with validation.
///
/// Checks are applied in this order:
/// 1. empty `padding` returns `self` unchanged;
/// 2. every padding value must be concrete (symbolic padding is rejected);
/// 3. all-zero padding returns `self` unchanged;
/// 4. when the source shape is known, the number of padding pairs must equal
///    the source rank.
///
/// # Errors
/// Returns `SymbolicPaddingUnsupported` for a non-constant padding value and
/// `PadDimensionMismatch` for a rank mismatch; propagates errors from `self.shape()`.
pub fn try_pad(self: &Arc<Self>, padding: &[(crate::SInt, crate::SInt)]) -> Result<Arc<Self>> {
    use crate::error::{PadDimensionMismatchSnafu, SymbolicPaddingUnsupportedSnafu};
    use crate::shape::ranges_to_uops;
    use snafu::ensure;

    // No padding entries at all (scalar case): identity.
    if padding.is_empty() {
        return Ok(self.clone());
    }

    // Every begin/end amount has to be a compile-time constant.
    let all_concrete = padding.iter().all(|(lo, hi)| lo.is_const() && hi.is_const());
    ensure!(all_concrete, SymbolicPaddingUnsupportedSnafu);

    // Padding by zero on every side changes nothing.
    let has_nonzero = padding.iter().any(|(lo, hi)| lo.as_const() != Some(0) || hi.as_const() != Some(0));
    if !has_nonzero {
        return Ok(self.clone());
    }

    if let Some(current) = self.shape()? {
        let pad_rank = padding.len();
        let src_rank = current.len();
        ensure!(pad_rank == src_rank, PadDimensionMismatchSnafu { padding_dims: pad_rank, shape_dims: src_rank });
    }

    let (begin_pads, end_pads) = ranges_to_uops(padding);
    let out_dtype = self.dtype();
    let op = Op::Pad { src: self.clone(), begin_pads, end_pads };
    Ok(Self::new(op, out_dtype))
}
/// Shrink (slice) each dimension to a `(begin, end)` range, with validation.
///
/// Behaviour:
/// - Empty `ranges` returns `self` unchanged.
/// - Symbolic bounds are accepted. A dimension is bounds-checked only when
///   `begin`, `end` and the dimension size are all concrete, in which case
///   `begin <= end <= size` must hold.
/// - If every supplied range is `(0, dim)`, the shrink is an identity and
///   `self` is returned.
/// - After the node is built, `self` is also returned when the result's shape
///   and the source shape are both known and equal.
///
/// Ranges are paired with dimensions positionally; a length mismatch between
/// `ranges` and the source shape is not diagnosed by this function.
///
/// # Errors
/// Propagates errors from `self.shape()`, and returns `ShrinkBoundsViolation`
/// when a concrete range violates `begin <= end <= size`.
pub fn try_shrink(self: &Arc<Self>, ranges: &[(crate::SInt, crate::SInt)]) -> Result<Arc<Self>> {
    use crate::error::ShrinkBoundsViolationSnafu;
    use crate::shape::ranges_to_uops;
    use snafu::ensure;

    // Empty ranges (scalar) → identity
    if ranges.is_empty() {
        return Ok(self.clone());
    }

    // Symbolic shrink ranges are allowed — the rangeify pipeline handles
    // symbolic range ends. Only concrete ranges are bounds-checked.
    if let Some(src_shape) = self.shape()? {
        for (dim_idx, ((begin, end), dim_size)) in ranges.iter().zip(src_shape.iter()).enumerate() {
            if let (Some(b), Some(e), Some(s)) = (begin.as_const(), end.as_const(), dim_size.as_const()) {
                ensure!(
                    b <= e && e <= s,
                    ShrinkBoundsViolationSnafu { dim: dim_idx, begin: b, end: e, shape_size: s }
                );
            }
        }

        // Identity shrink: skip if all ranges span the full dimension
        let spans_everything =
            ranges.iter().zip(src_shape.iter()).all(|((b, e), d)| b.as_const() == Some(0) && *e == *d);
        if spans_everything {
            return Ok(self.clone());
        }
    }

    let (begins, ends) = ranges_to_uops(ranges);
    let dtype = self.dtype();
    let result = Self::new(Op::Shrink { src: self.clone(), begins, ends }, dtype);

    // Tinygrad (movement.py:128): return self if ret.shape == self.shape else ret.
    // Collapse only when BOTH shapes are actually known. Comparing the raw
    // `Option`s would treat `None == None` (shape unknown or inference failed on
    // both sides) as "same shape" and silently discard a real shrink.
    let shape_unchanged = match (result.shape().ok().flatten(), self.shape().ok().flatten()) {
        (Some(after), Some(before)) => after == before,
        _ => false,
    };
    if shape_unchanged {
        return Ok(self.clone());
    }
    Ok(result)
}
/// Build a FLIP of `self`, reversing every axis whose flag in `axes` is `true`.
///
/// If no flag is set (including an empty `axes`), `self` is returned unchanged
/// before any other check runs. Otherwise, when the source shape is known,
/// `axes` is checked with `validate_flip_axes` against the source rank.
///
/// # Errors
/// Propagates errors from `self.shape()` and from `validate_flip_axes`.
pub fn try_flip(self: &Arc<Self>, axes: Vec<bool>) -> Result<Arc<Self>> {
    // Nothing requested to be reversed: no-op.
    if !axes.contains(&true) {
        return Ok(self.clone());
    }

    if let Some(current) = self.shape()? {
        Self::validate_flip_axes(&axes, current.len())?;
    }

    let out_dtype = self.dtype();
    let op = Op::Flip { src: self.clone(), axes };
    Ok(Self::new(op, out_dtype))
}
/// Wrap `src` in a MULTI node sharded along `axis`.
///
/// Creates a multi-device tensor where each device holds a shard.
/// Use with MSTACK/MSELECT for distributed tensor operations.
/// No validation is performed here; the result keeps the dtype of `src`.
pub fn multi(src: Arc<Self>, axis: usize) -> Arc<Self> {
    // Read the dtype before `src` is moved into the op.
    let out_dtype = src.dtype();
    let op = Op::Multi { src, axis };
    Self::new(op, out_dtype)
}
}