morok_ir/shape.rs
1//! Shape utilities for UOps with symbolic shape support.
2//!
3//! This module provides shape-related types and functions following Tinygrad's approach:
4//! - Shapes can contain both concrete integers and symbolic UOp expressions
5//! - Shape inference with validation
6//! - Broadcasting utilities (explicit, non-automatic)
7//!
8//! Key differences from Tinygrad:
9//! - Uses Rust's type system (SInt enum vs Python Union)
10//! - Explicit Result types instead of exceptions
11//! - Non-automatic broadcasting (must be explicit)
12
13use std::sync::Arc;
14
15use smallvec::{SmallVec, smallvec};
16use snafu::ensure;
17
18use crate::{ConstValue, Op, Result, SInt, UOp, error::*};
19
/// Shape type: one [`SInt`] per dimension.
///
/// Each dimension may be a concrete integer or a symbolic UOp expression, and
/// both kinds may be mixed within a single shape. An empty shape denotes a
/// scalar.
///
/// Backed by a `SmallVec` with inline capacity 4, so shapes of rank 4 or less
/// are stored without a heap allocation; higher ranks spill to the heap.
pub type Shape = SmallVec<[SInt; 4]>;
27
28// =========================================================================
29// Shape Utilities
30// =========================================================================
31
32/// Check if shape is fully concrete (all dimensions are constants).
33///
34/// # Examples
35///
36/// ```rust
37/// # use morok_ir::{SInt, shape::is_static};
38/// # use smallvec::smallvec;
39/// let shape = smallvec![SInt::from(3), SInt::from(4), SInt::from(5)];
40/// assert!(is_static(&shape));
41/// ```
42pub fn is_static(shape: &Shape) -> bool {
43 shape.iter().all(|dim| dim.is_const())
44}
45
46/// Convert shape to concrete Vec<usize> if fully static, None otherwise.
47///
48/// # Examples
49///
50/// ```rust
51/// # use morok_ir::{SInt, shape::to_static};
52/// # use smallvec::smallvec;
53/// let shape = smallvec![SInt::from(3), SInt::from(4)];
54/// assert_eq!(to_static(&shape), Some(smallvec![3, 4]));
55/// ```
56pub fn to_static(shape: &Shape) -> Option<SmallVec<[usize; 4]>> {
57 is_static(shape).then_some(shape.iter().map(|dim| dim.as_const().unwrap()).collect())
58}
59
60// =========================================================================
61// Shape Validation
62// =========================================================================
63
64/// Validate that a shape specification is valid (all positive, no zeros).
65///
66/// # Errors
67/// Returns error if any dimension is negative or zero.
68///
69/// # Examples
70/// ```rust
71/// # use morok_ir::shape::validate_shape;
72/// let valid = vec![1, 2, 3];
73/// assert!(validate_shape(&valid).is_ok());
74/// let invalid = vec![1, -2, 3];
75/// assert!(validate_shape(&invalid).is_err());
76/// ```
77pub fn validate_shape(shape: &[isize]) -> Result<SmallVec<[usize; 4]>> {
78 ensure!(shape.iter().all(|&s| s >= 0), ReshapeNegativeDimensionSnafu { shape });
79 Ok(shape.iter().map(|&s| s as usize).collect())
80}
81
82/// Check if two shapes are equal.
83///
84/// Uses pointer equality for symbolic dimensions (consistent with hash consing).
85pub fn shapes_equal(lhs: &Shape, rhs: &Shape) -> bool {
86 lhs == rhs
87}
88
89/// Check if all shapes in a slice are equal.
90///
91/// # Examples
92///
93/// ```rust
94/// # use morok_ir::{SInt, shape::all_shapes_equal};
95/// # use smallvec::smallvec;
96/// let shape1 = smallvec![SInt::from(3), SInt::from(4)];
97/// let shape2 = smallvec![SInt::from(3), SInt::from(4)];
98/// let shape3 = smallvec![SInt::from(3), SInt::from(4)];
99/// assert!(all_shapes_equal(&[shape1, shape2, shape3]));
100/// ```
101pub fn all_shapes_equal(shapes: &[Shape]) -> bool {
102 (!shapes.is_empty()) && shapes.iter().all(|s| shapes_equal(s, &shapes[0]))
103}
104
105// =========================================================================
106// Broadcasting Utilities (Explicit, Non-automatic)
107// =========================================================================
108
109/// Align shapes to the left by prepending 1s.
110///
111/// Makes all shapes have the same number of dimensions by adding dimensions
112/// of size 1 on the left.
113///
114/// # Examples
115///
116/// ```rust
117/// # use morok_ir::{SInt, shape::align_shapes_left};
118/// # use smallvec::smallvec;
119/// let shape1 = smallvec![SInt::from(5)];
120/// let shape2 = smallvec![SInt::from(3), SInt::from(5)];
121/// let aligned = align_shapes_left(&[shape1, shape2]);
122/// assert_eq!(aligned.len(), 2);
123/// assert_eq!(aligned[0].len(), 2); // [1, 5]
124/// assert_eq!(aligned[1].len(), 2); // [3, 5]
125/// ```
126pub fn align_shapes_left(shapes: &[Shape]) -> Vec<Shape> {
127 if shapes.is_empty() {
128 return Vec::new();
129 }
130
131 let max_dims = shapes.iter().map(|s| s.len()).max().unwrap();
132
133 shapes
134 .iter()
135 .map(|shape| {
136 let padding = max_dims - shape.len();
137 let mut aligned = SmallVec::with_capacity(max_dims);
138 aligned.extend(std::iter::repeat_n(SInt::from(1), padding));
139 aligned.extend(shape.iter().cloned());
140 aligned
141 })
142 .collect()
143}
144
145/// Check if two shapes can be broadcast together (NumPy-style broadcasting).
146///
147/// Two shapes are broadcastable if:
148/// - They have the same number of dimensions
149/// - For each dimension, either the dimensions match or one of them is 1
150///
151/// # Examples
152///
153/// ```rust
154/// # use morok_ir::{SInt, shape::can_broadcast};
155/// # use smallvec::smallvec;
156/// let shape1 = smallvec![SInt::from(1), SInt::from(5)];
157/// let shape2 = smallvec![SInt::from(3), SInt::from(5)];
158/// assert!(can_broadcast(&shape1, &shape2));
159///
160/// let shape3 = smallvec![SInt::from(3), SInt::from(4)];
161/// assert!(!can_broadcast(&shape1, &shape3));
162/// ```
163pub fn can_broadcast(lhs: &Shape, rhs: &Shape) -> bool {
164 if lhs.len() != rhs.len() {
165 return false;
166 }
167
168 lhs.iter().zip(rhs.iter()).all(|(l, r)| {
169 // If both are concrete, check broadcast rule
170 if let (Some(lv), Some(rv)) = (l.as_const(), r.as_const()) {
171 lv == rv || lv == 1 || rv == 1
172 } else if l == r {
173 // Same symbolic expression
174 true
175 } else {
176 // Different symbolic expressions - conservatively assume compatible
177 // (runtime check would be needed)
178 true
179 }
180 })
181}
182
183/// Compute the broadcast result shape for two shapes.
184///
185/// Returns the shape that results from broadcasting the two input shapes.
186/// Both shapes must be broadcastable (checked with can_broadcast).
187///
188/// # Errors
189/// Returns error if shapes are not broadcastable.
190///
191/// # Examples
192///
193/// ```rust
194/// # use morok_ir::{SInt, shape::broadcast_shape};
195/// # use smallvec::smallvec;
196/// let shape1 = smallvec![SInt::from(1), SInt::from(5)];
197/// let shape2 = smallvec![SInt::from(3), SInt::from(5)];
198/// let result = broadcast_shape(&shape1, &shape2).unwrap();
199/// assert_eq!(result[0].as_const(), Some(3));
200/// assert_eq!(result[1].as_const(), Some(5));
201/// ```
202pub fn broadcast_shape(lhs: &Shape, rhs: &Shape) -> Result<Shape> {
203 use crate::error::BroadcastShapeMismatchSnafu;
204 use snafu::ensure;
205
206 ensure!(lhs.len() == rhs.len(), BroadcastShapeMismatchSnafu { lhs: lhs.clone(), rhs: rhs.clone() });
207
208 let mut result = SmallVec::with_capacity(lhs.len());
209
210 for (l, r) in lhs.iter().zip(rhs.iter()) {
211 if l == r {
212 // Same dimension (concrete value or symbolic expression)
213 result.push(l.clone());
214 } else if let (Some(lv), Some(rv)) = (l.as_const(), r.as_const()) {
215 // Both concrete - apply broadcast rule
216 if lv == 1 {
217 result.push(r.clone());
218 } else if rv == 1 || lv == rv {
219 result.push(l.clone());
220 } else {
221 return BroadcastShapeMismatchSnafu { lhs: lhs.clone(), rhs: rhs.clone() }.fail();
222 }
223 } else {
224 // At least one is symbolic - use max (conservatively)
225 result.push(crate::sint_max(&[l.clone(), r.clone()]));
226 }
227 }
228
229 Ok(result)
230}
231
232/// Compute broadcast result for multiple shapes.
233///
234/// # Errors
235/// Returns error if any pair of shapes is not broadcastable.
236pub fn broadcast_shapes(shapes: &[Shape]) -> Result<Shape> {
237 if shapes.is_empty() {
238 return Ok(SmallVec::new());
239 }
240
241 // Align all shapes to same number of dimensions
242 let aligned = align_shapes_left(shapes);
243
244 // Successively broadcast pairs
245 let mut result = aligned[0].clone();
246 for shape in &aligned[1..] {
247 result = broadcast_shape(&result, shape)?;
248 }
249
250 Ok(result)
251}
252
253/// Convert shape to Vec<usize>, ensuring all dimensions are concrete.
254///
255/// This is a helper function to reduce boilerplate when converting shapes
256/// for operations that require concrete (non-symbolic) dimensions.
257///
258/// # Errors
259///
260/// Returns error if any dimension contains a symbolic (non-const) value.
261pub fn to_vec_usize(shape: &Shape) -> Result<Vec<usize>> {
262 shape
263 .iter()
264 .map(|dim| {
265 dim.as_const().ok_or_else(|| Error::SymbolicShapeUnsupported { operation: "shape conversion".to_string() })
266 })
267 .collect()
268}
269
270/// Convert shape to Vec<isize>, ensuring all dimensions are concrete.
271///
272/// # Errors
273///
274/// Returns error if any dimension contains a symbolic (non-const) value.
275pub fn to_vec_isize(shape: &Shape) -> Result<Vec<isize>> {
276 shape
277 .iter()
278 .map(|dim| {
279 dim.as_const()
280 .map(|v| v as isize)
281 .ok_or_else(|| Error::SymbolicShapeUnsupported { operation: "shape conversion".to_string() })
282 })
283 .collect()
284}
285
286// =========================================================================
287// Movement Op Argument Extraction (marg equivalent)
288// =========================================================================
289
290/// Extract shape dimensions from a VECTORIZE or CONST UOp.
291///
292/// Following Tinygrad's `marg` pattern, this extracts concrete or symbolic
293/// dimensions from the UOp used to store shape information.
294///
295/// Returns None if the UOp is not in the expected format.
296fn extract_shape_from_uop(shape_uop: &Arc<UOp>) -> Option<Shape> {
297 match shape_uop.op() {
298 // VECTORIZE with Index-typed elements
299 Op::Vectorize { elements } => Some(elements.into_iter().cloned().map(SInt::from).collect()),
300
301 // Single CONST value (for 1D shapes)
302 Op::Const(const_hash) => match const_hash.0 {
303 ConstValue::Int(v) if v >= 0 => Some(smallvec![SInt::from(v as usize)]),
304 ConstValue::UInt(v) => Some(smallvec![SInt::from(v as usize)]),
305 _ => None,
306 },
307
308 // VConst for multiple concrete dimensions
309 Op::VConst { values } => {
310 let mut dims = SmallVec::with_capacity(values.len());
311 for val in values {
312 match val {
313 ConstValue::Int(v) if *v >= 0 => dims.push(SInt::from(*v as usize)),
314 ConstValue::UInt(v) => dims.push(SInt::from(*v as usize)),
315 _ => return None,
316 }
317 }
318 Some(dims)
319 }
320
321 _ => None,
322 }
323}
324
325/// Extract padding/shrink ranges from UOps.
326///
327/// Returns pairs of (begin, end) for each dimension.
328fn extract_ranges_from_uops(begins_uop: &Arc<UOp>, ends_uop: &Arc<UOp>) -> Option<Vec<(SInt, SInt)>> {
329 let begins = extract_shape_from_uop(begins_uop)?;
330 let ends = extract_shape_from_uop(ends_uop)?;
331
332 if begins.len() != ends.len() {
333 return None;
334 }
335
336 Some(begins.into_iter().zip(ends).collect())
337}
338
339/// Convert a Shape to a VECTORIZE UOp for use in movement operations.
340///
341/// This creates a UOp that encodes the shape dimensions, suitable for
342/// passing to Reshape, Expand, etc.
343///
344/// # Examples
345///
346/// ```rust
347/// # use morok_ir::{SInt, shape::shape_to_uop};
348/// # use morok_dtype::DType;
349/// # use smallvec::smallvec;
350/// let shape = smallvec![SInt::from(3), SInt::from(4), SInt::from(5)];
351/// let shape_uop = shape_to_uop(&shape);
352/// assert_eq!(shape_uop.dtype(), DType::Index.vec(3));
353///
354/// // Scalar (empty shape) is supported
355/// let scalar_shape: smallvec::SmallVec<[SInt; 4]> = smallvec![];
356/// let scalar_uop = shape_to_uop(&scalar_shape);
357/// // VConst with empty values represents scalar
358/// ```
359pub fn shape_to_uop(shape: &Shape) -> Arc<UOp> {
360 use morok_dtype::DType;
361 use smallvec::SmallVec;
362
363 // Empty shape = scalar: use VConst with empty values
364 // extract_shape_from_uop will decode this back to empty Shape
365 if shape.is_empty() {
366 return UOp::vconst(vec![], DType::Index);
367 }
368
369 let elements: SmallVec<[Arc<UOp>; 4]> = shape.iter().map(|dim| dim.to_uop(DType::Index)).collect();
370 UOp::vectorize(elements)
371}
372
373/// Convert a vector of (begin, end) ranges to two UOps for Pad/Shrink operations.
374///
375/// Returns (begins_uop, ends_uop) as VECTORIZE UOps.
376///
377/// # Panics
378/// Panics if `ranges` is empty; handle scalars at the callsite.
379pub fn ranges_to_uops(ranges: &[(SInt, SInt)]) -> (Arc<UOp>, Arc<UOp>) {
380 use morok_dtype::DType;
381 use smallvec::SmallVec;
382
383 assert!(!ranges.is_empty(), "ranges_to_uops does not support empty ranges (scalars); handle at callsite");
384
385 let begins: SmallVec<[Arc<UOp>; 4]> = ranges.iter().map(|(begin, _)| begin.to_uop(DType::Index)).collect();
386 let ends: SmallVec<[Arc<UOp>; 4]> = ranges.iter().map(|(_, end)| end.to_uop(DType::Index)).collect();
387
388 (UOp::vectorize(begins), UOp::vectorize(ends))
389}
390
391// =========================================================================
392// Shape Inference (Tinygrad-style)
393// =========================================================================
394
/// Infer shape from a UOp's operation.
///
/// This is the core shape inference function, following Tinygrad's approach.
/// Operand shapes are read through `UOp::shape()`, so inference builds on the
/// shapes already computed for the sources.
/// Returns `Ok(None)` for operations without a well-defined tensor shape
/// (kernel-level vector ops, control flow, etc.).
///
/// # Shape Inference Rules
///
/// - **Const, Gep, DefineVar, PointerIndex**: scalar (empty shape).
/// - **VConst, Vectorize**: `None` (no tensor shape).
/// - **Unary ops, Cast, Flip, Copy, AllReduce, Detach/Contiguous/...**:
///   preserve the source shape.
/// - **Binary ops**: when both operand shapes are known they must be equal;
///   if only one side has a shape, that shape is used.
/// - **Ternary ops**: the same rule, applied to the two value branches (the
///   condition's shape is not inspected).
/// - **BitCast**: same shape when item sizes match; otherwise the last
///   dimension is rescaled to `last * src_bytes / dst_bytes`.
/// - **Movement ops**: computed from the op's arguments.
/// - **ReduceAxis**: reduced axes become size 1 (rank is preserved).
/// - **Buffer, Param, BufferView, DefineReg, DefineLocal**: rank-1 `[size]`.
/// - **Late/control-flow/graph ops**: `None`, except pass-throughs such as
///   `End`, `After`, `MSelect` and `Group`.
///
/// # Errors
///
/// - `BinaryShapeMismatch` / `TernaryBranchShapeMismatch` when both operand
///   shapes are known and differ.
/// - `BufferDefRequiresPtrDType` when a `DefineLocal` does not carry a pointer dtype.
/// - `VoidTypeInOp` when a Permute/Pad/Shrink/ReduceAxis source has no shape,
///   or when Pad/Shrink range operands cannot be decoded.
/// - Any error propagated from computing a source's shape.
pub fn infer_shape_from_op(uop: &UOp) -> crate::Result<Option<Shape>> {
    Ok(match uop.op() {
        // =====================================================================
        // Nullary operations
        // =====================================================================
        Op::Const(_) => Some(SmallVec::new()), // Scalar has empty shape

        // Vector constants carry no tensor shape.
        Op::VConst { .. } => None,

        Op::Unique(_) | Op::Device(_) | Op::Noop | Op::Invalid => None,

        // DefineLocal: shape is [size] from the pointer dtype; an unsized pointer
        // yields a symbolic -1 dimension.
        Op::DefineLocal(_id) => {
            use morok_dtype::DType;
            match uop.dtype() {
                DType::Ptr { size: Some(s), .. } => Some(smallvec![SInt::from(s)]),
                DType::Ptr { size: None, .. } => {
                    let neg_one = UOp::index_const(-1);
                    Some(smallvec![SInt::from(neg_one)])
                }
                dtype => {
                    return crate::error::BufferDefRequiresPtrDTypeSnafu { op: "DefineLocal", dtype: dtype.clone() }
                        .fail();
                }
            }
        }

        // =====================================================================
        // Unary operations - preserve shape
        // =====================================================================
        Op::Unary(_, input) => input.shape()?.cloned(),

        // =====================================================================
        // Binary operations - validate shapes match
        // =====================================================================
        Op::Binary(op, lhs, rhs) => {
            match (lhs.shape()?, rhs.shape()?) {
                (Some(lhs_shape), Some(rhs_shape)) if !shapes_equal(lhs_shape, rhs_shape) => {
                    // Both have shapes but they differ - ERROR (no implicit broadcasting)
                    return BinaryShapeMismatchSnafu {
                        op: *op,
                        lhs: Box::new(lhs_shape.clone()),
                        rhs: Box::new(rhs_shape.clone()),
                    }
                    .fail();
                }
                // Equal shapes, or only one side has a shape: use it.
                (Some(s), _) | (_, Some(s)) => Some(s.clone()),
                (None, None) => None, // Both shapeless - valid (RANGE + RANGE)
            }
        }

        // =====================================================================
        // Ternary operations
        // =====================================================================
        Op::Ternary(_, _condition, true_val, false_val) => {
            // Result has shape of value branches - they must match
            let true_shape = true_val.shape()?;
            let false_shape = false_val.shape()?;

            match (true_shape, false_shape) {
                (Some(ts), Some(fs)) if !shapes_equal(ts, fs) => {
                    return crate::error::TernaryBranchShapeMismatchSnafu {
                        true_branch: Box::new(ts.clone()),
                        false_branch: Box::new(fs.clone()),
                    }
                    .fail();
                }
                (Some(s), _) | (_, Some(s)) => Some(s.clone()),
                (None, None) => None,
            }
        }

        // =====================================================================
        // Type operations
        // =====================================================================
        Op::Cast { src, .. } => src.shape()?.cloned(),
        // BitCast: byte-reinterpretation (Tinygrad ops.py:240-245, tensor.py:3549-3568).
        // Same itemsize → same shape. Different itemsize → adjust last dimension.
        // Scalars and shapeless sources pass through unchanged.
        Op::BitCast { src, dtype } => {
            let src_shape = src.shape()?;
            match src_shape {
                Some(shape) if !shape.is_empty() => {
                    let src_bytes = src.dtype().bytes();
                    let dst_bytes = dtype.bytes();
                    if src_bytes == dst_bytes {
                        Some(shape.clone())
                    } else {
                        // Adjust last dimension: (last * src_bytes) / dst_bytes.
                        // NOTE(review): assumes the byte count divides evenly;
                        // divisibility is not checked here.
                        let mut new_shape = shape.clone();
                        let last = new_shape.last().unwrap().clone();
                        let new_last = (last * SInt::Const(src_bytes)) / SInt::Const(dst_bytes);
                        *new_shape.last_mut().unwrap() = new_last;
                        Some(new_shape)
                    }
                }
                other => other.cloned(),
            }
        }

        // =====================================================================
        // Vector operations (kernel-level, no tensor shape)
        // =====================================================================
        Op::Vectorize { .. } => None,

        Op::Gep { .. } => Some(SmallVec::new()), // Extract element from vector -> scalar

        // =====================================================================
        // Movement operations
        // =====================================================================
        Op::Reshape { new_shape, .. } => {
            // Decode target shape from its VECTORIZE/CONST/VCONST operand
            extract_shape_from_uop(new_shape)
        }

        Op::Permute { axes, src } => {
            let src_shape = src.shape()?.ok_or_else(|| crate::Error::VoidTypeInOp)?;
            // Reorder dimensions according to permutation.
            // NOTE(review): `src_shape[i]` panics if an axis is >= the source
            // rank; presumably axes are validated when the Permute is built — confirm.
            Some(axes.iter().map(|&i| src_shape[i].clone()).collect())
        }

        Op::Expand { new_shape, .. } => {
            // Decode target shape from its VECTORIZE/CONST/VCONST operand
            extract_shape_from_uop(new_shape)
        }

        Op::Pad { src, begin_pads, end_pads } => {
            let src_shape = src.shape()?.ok_or_else(|| crate::Error::VoidTypeInOp)?;
            let ranges = extract_ranges_from_uops(begin_pads, end_pads).ok_or_else(|| crate::Error::VoidTypeInOp)?;

            // Rank mismatch between source and pads: no shape (not an error).
            if src_shape.len() != ranges.len() {
                return Ok(None);
            }

            // New shape = src_shape + begin_pads + end_pads for each dimension
            Some(
                src_shape
                    .iter()
                    .zip(ranges.iter())
                    .map(|(dim, (begin, end))| Ok(dim + begin + end))
                    .collect::<crate::Result<Shape>>()?,
            )
        }

        Op::Shrink { src, begins, ends } => {
            let src_shape = src.shape()?.ok_or_else(|| crate::Error::VoidTypeInOp)?;
            let ranges = extract_ranges_from_uops(begins, ends).ok_or_else(|| crate::Error::VoidTypeInOp)?;

            // Rank mismatch between source and ranges: no shape (not an error).
            if src_shape.len() != ranges.len() {
                return Ok(None);
            }

            // New shape = end - begin for each dimension
            Some(
                ranges
                    .iter()
                    .zip(src_shape.iter())
                    .map(|((begin, end), dim)| {
                        // Identity range (0, dim_size) → preserve dim (supports symbolic batch)
                        if begin.as_const() == Some(0) && end == dim {
                            return Ok(dim.clone());
                        }
                        // end - begin (works for both concrete and symbolic)
                        Ok(end - begin)
                    })
                    .collect::<crate::Result<Shape>>()?,
            )
        }

        Op::Flip { src, .. } => {
            // Flip preserves shape
            src.shape()?.cloned()
        }

        Op::Multi { src, .. } => {
            // Multi scales the specified axis by device count
            // TODO: Need device count from somewhere - for now preserve shape
            // Tinygrad: tuple(s*len(self.device) if a == self.axis else s for a,s in enumerate(ps))
            src.shape()?.cloned()
        }

        // =====================================================================
        // Reduction operations
        // =====================================================================
        Op::ReduceAxis { axes, src, .. } => {
            let src_shape = src.shape()?.ok_or_else(|| crate::Error::VoidTypeInOp)?;
            // Set reduced axes to 1 (don't remove them - matches Tinygrad)
            Some(
                src_shape
                    .iter()
                    .enumerate()
                    .map(|(i, dim)| if axes.contains(&i) { SInt::from(1) } else { dim.clone() })
                    .collect(),
            )
        }

        Op::Reduce { .. } => {
            // Reduce with ranges - context dependent, not inferred here
            None
        }

        Op::AllReduce { src, .. } => {
            // AllReduce preserves shape
            src.shape()?.cloned()
        }

        // =====================================================================
        // Buffer and memory operations - shape depends on buffer
        // =====================================================================
        // Buffer operations have shape (size,)
        Op::Buffer { size, .. } => Some(smallvec![SInt::from(*size)]),
        Op::Param { size, .. } => Some(smallvec![SInt::from(*size)]),
        Op::BufferView { size, .. } => Some(smallvec![SInt::from(*size)]),

        // Passthrough operations
        Op::Copy { src, .. } => src.shape()?.cloned(),
        // MStack: shape of the first buffer (others are not checked)
        Op::MStack { buffers } => match buffers.first() {
            Some(b) => b.shape()?.cloned(),
            None => None,
        },

        // BUFFERIZE shape is derived from ranges (like Tinygrad)
        // Shape = [end_0, end_1, ...] where end_i is the size of each range
        Op::Bufferize { ranges, .. } => {
            let mut dims: Shape = SmallVec::new();
            for range in ranges.iter() {
                match range.op() {
                    // Range: shape dim = end (the upper bound)
                    Op::Range { end, .. } => {
                        // Non-negative integer constant end → concrete dim
                        if let Op::Const(val) = end.op() {
                            match val.0 {
                                ConstValue::Int(v) if v >= 0 => {
                                    dims.push(SInt::Const(v as usize));
                                    continue;
                                }
                                ConstValue::UInt(v) => {
                                    dims.push(SInt::Const(v as usize));
                                    continue;
                                }
                                _ => {}
                            }
                        }
                        // Fall back to the symbolic end expression
                        dims.push(SInt::Symbolic(end.clone()));
                    }
                    // CONST range (already dead axis) has size from vmax+1.
                    // NOTE(review): `v + 1` would overflow at the integer type's max.
                    Op::Const(val) => {
                        match val.0 {
                            ConstValue::Int(v) if v >= 0 => {
                                dims.push(SInt::Const((v + 1) as usize)); // vmax+1 for shape
                            }
                            ConstValue::UInt(v) => {
                                dims.push(SInt::Const((v + 1) as usize)); // vmax+1 for shape
                            }
                            _ => return Ok(None), // Can't determine shape
                        }
                    }
                    // Other range types: use the range UOp itself as a symbolic dim
                    _ => {
                        dims.push(SInt::Symbolic(range.clone()));
                    }
                }
            }
            Some(dims)
        }

        // These have no shape
        Op::Index { .. } | Op::Load { .. } | Op::Store { .. } => None,

        // =====================================================================
        // Control flow - no static shape
        // =====================================================================
        Op::If { .. } | Op::EndIf { .. } | Op::Range { .. } | Op::Barrier { .. } => None,

        // End passes through the computation shape
        Op::End { computation, .. } => computation.shape()?.cloned(),

        // =====================================================================
        // Special operations
        // =====================================================================
        // MSelect passes through buffer shape
        Op::MSelect { buffer, .. } => buffer.shape()?.cloned(),

        Op::Special { .. } => None,

        Op::DefineVar { .. } => Some(SmallVec::new()), // Variable is scalar

        Op::Bind { value, .. } => value.shape()?.cloned(),

        Op::DefineReg { size, .. } => Some(smallvec![SInt::from(*size)]),

        // =====================================================================
        // Advanced operations
        // =====================================================================
        Op::Wmma { .. } | Op::Contract { .. } | Op::Unroll { .. } => {
            // These require more complex shape computation (not implemented)
            None
        }

        Op::Kernel { .. } => None,

        // Assign takes the shape of its target
        Op::Assign { target, .. } => target.shape()?.cloned(),

        Op::Detach { src } | Op::Contiguous { src, .. } | Op::ContiguousBackward { src } | Op::Precast { src } => {
            src.shape()?.cloned()
        }

        Op::After { passthrough, .. } => passthrough.shape()?.cloned(),

        Op::Custom { .. } | Op::CustomI { .. } => None,

        // Graph organization operations: Sink has no shape; Group takes the
        // shape of its first source (others are not checked)
        Op::Sink { .. } => None,
        Op::Group { sources } => match sources.first() {
            Some(src) => src.shape()?.cloned(),
            None => None,
        },

        // PointerIndex is a scalar index operation (empty shape)
        Op::PointerIndex { .. } => Some(smallvec![]),

        // Cat and PtrCat are kernel-level vector ops (no tensor shape)
        Op::Cat { .. } | Op::PtrCat { .. } => None,
    })
}