Skip to main content

morok_ir/
shape.rs

1//! Shape utilities for UOps with symbolic shape support.
2//!
3//! This module provides shape-related types and functions following Tinygrad's approach:
4//! - Shapes can contain both concrete integers and symbolic UOp expressions
5//! - Shape inference with validation
6//! - Broadcasting utilities (explicit, non-automatic)
7//!
8//! Key differences from Tinygrad:
9//! - Uses Rust's type system (SInt enum vs Python Union)
10//! - Explicit Result types instead of exceptions
11//! - Non-automatic broadcasting (must be explicit)
12
13use std::sync::Arc;
14
15use smallvec::{SmallVec, smallvec};
16use snafu::ensure;
17
18use crate::{ConstValue, Op, Result, SInt, UOp, error::*};
19
/// Shape type - sequence of symbolic integers.
///
/// Uses SmallVec with inline capacity of 4 to avoid heap allocation for
/// common tensor ranks (1D-4D), which covers 99% of ML workloads.
/// Shapes with more than 4 dimensions still work; they spill to the heap.
///
/// Each dimension is an [`SInt`], so a single shape can contain a mix of
/// concrete and symbolic dimensions. An empty shape denotes a scalar.
pub type Shape = SmallVec<[SInt; 4]>;
27
28// =========================================================================
29// Shape Utilities
30// =========================================================================
31
32/// Check if shape is fully concrete (all dimensions are constants).
33///
34/// # Examples
35///
36/// ```rust
37/// # use morok_ir::{SInt, shape::is_static};
38/// # use smallvec::smallvec;
39/// let shape = smallvec![SInt::from(3), SInt::from(4), SInt::from(5)];
40/// assert!(is_static(&shape));
41/// ```
42pub fn is_static(shape: &Shape) -> bool {
43    shape.iter().all(|dim| dim.is_const())
44}
45
46/// Convert shape to concrete Vec<usize> if fully static, None otherwise.
47///
48/// # Examples
49///
50/// ```rust
51/// # use morok_ir::{SInt, shape::to_static};
52/// # use smallvec::smallvec;
53/// let shape = smallvec![SInt::from(3), SInt::from(4)];
54/// assert_eq!(to_static(&shape), Some(smallvec![3, 4]));
55/// ```
56pub fn to_static(shape: &Shape) -> Option<SmallVec<[usize; 4]>> {
57    is_static(shape).then_some(shape.iter().map(|dim| dim.as_const().unwrap()).collect())
58}
59
60// =========================================================================
61// Shape Validation
62// =========================================================================
63
64/// Validate that a shape specification is valid (all positive, no zeros).
65///
66/// # Errors
67/// Returns error if any dimension is negative or zero.
68///
69/// # Examples
70/// ```rust
71/// # use morok_ir::shape::validate_shape;
72/// let valid = vec![1, 2, 3];
73/// assert!(validate_shape(&valid).is_ok());
74/// let invalid = vec![1, -2, 3];
75/// assert!(validate_shape(&invalid).is_err());
76/// ```
77pub fn validate_shape(shape: &[isize]) -> Result<SmallVec<[usize; 4]>> {
78    ensure!(shape.iter().all(|&s| s >= 0), ReshapeNegativeDimensionSnafu { shape });
79    Ok(shape.iter().map(|&s| s as usize).collect())
80}
81
82/// Check if two shapes are equal.
83///
84/// Uses pointer equality for symbolic dimensions (consistent with hash consing).
85pub fn shapes_equal(lhs: &Shape, rhs: &Shape) -> bool {
86    lhs == rhs
87}
88
89/// Check if all shapes in a slice are equal.
90///
91/// # Examples
92///
93/// ```rust
94/// # use morok_ir::{SInt, shape::all_shapes_equal};
95/// # use smallvec::smallvec;
96/// let shape1 = smallvec![SInt::from(3), SInt::from(4)];
97/// let shape2 = smallvec![SInt::from(3), SInt::from(4)];
98/// let shape3 = smallvec![SInt::from(3), SInt::from(4)];
99/// assert!(all_shapes_equal(&[shape1, shape2, shape3]));
100/// ```
101pub fn all_shapes_equal(shapes: &[Shape]) -> bool {
102    (!shapes.is_empty()) && shapes.iter().all(|s| shapes_equal(s, &shapes[0]))
103}
104
105// =========================================================================
106// Broadcasting Utilities (Explicit, Non-automatic)
107// =========================================================================
108
109/// Align shapes to the left by prepending 1s.
110///
111/// Makes all shapes have the same number of dimensions by adding dimensions
112/// of size 1 on the left.
113///
114/// # Examples
115///
116/// ```rust
117/// # use morok_ir::{SInt, shape::align_shapes_left};
118/// # use smallvec::smallvec;
119/// let shape1 = smallvec![SInt::from(5)];
120/// let shape2 = smallvec![SInt::from(3), SInt::from(5)];
121/// let aligned = align_shapes_left(&[shape1, shape2]);
122/// assert_eq!(aligned.len(), 2);
123/// assert_eq!(aligned[0].len(), 2); // [1, 5]
124/// assert_eq!(aligned[1].len(), 2); // [3, 5]
125/// ```
126pub fn align_shapes_left(shapes: &[Shape]) -> Vec<Shape> {
127    if shapes.is_empty() {
128        return Vec::new();
129    }
130
131    let max_dims = shapes.iter().map(|s| s.len()).max().unwrap();
132
133    shapes
134        .iter()
135        .map(|shape| {
136            let padding = max_dims - shape.len();
137            let mut aligned = SmallVec::with_capacity(max_dims);
138            aligned.extend(std::iter::repeat_n(SInt::from(1), padding));
139            aligned.extend(shape.iter().cloned());
140            aligned
141        })
142        .collect()
143}
144
145/// Check if two shapes can be broadcast together (NumPy-style broadcasting).
146///
147/// Two shapes are broadcastable if:
148/// - They have the same number of dimensions
149/// - For each dimension, either the dimensions match or one of them is 1
150///
151/// # Examples
152///
153/// ```rust
154/// # use morok_ir::{SInt, shape::can_broadcast};
155/// # use smallvec::smallvec;
156/// let shape1 = smallvec![SInt::from(1), SInt::from(5)];
157/// let shape2 = smallvec![SInt::from(3), SInt::from(5)];
158/// assert!(can_broadcast(&shape1, &shape2));
159///
160/// let shape3 = smallvec![SInt::from(3), SInt::from(4)];
161/// assert!(!can_broadcast(&shape1, &shape3));
162/// ```
163pub fn can_broadcast(lhs: &Shape, rhs: &Shape) -> bool {
164    if lhs.len() != rhs.len() {
165        return false;
166    }
167
168    lhs.iter().zip(rhs.iter()).all(|(l, r)| {
169        // If both are concrete, check broadcast rule
170        if let (Some(lv), Some(rv)) = (l.as_const(), r.as_const()) {
171            lv == rv || lv == 1 || rv == 1
172        } else if l == r {
173            // Same symbolic expression
174            true
175        } else {
176            // Different symbolic expressions - conservatively assume compatible
177            // (runtime check would be needed)
178            true
179        }
180    })
181}
182
183/// Compute the broadcast result shape for two shapes.
184///
185/// Returns the shape that results from broadcasting the two input shapes.
186/// Both shapes must be broadcastable (checked with can_broadcast).
187///
188/// # Errors
189/// Returns error if shapes are not broadcastable.
190///
191/// # Examples
192///
193/// ```rust
194/// # use morok_ir::{SInt, shape::broadcast_shape};
195/// # use smallvec::smallvec;
196/// let shape1 = smallvec![SInt::from(1), SInt::from(5)];
197/// let shape2 = smallvec![SInt::from(3), SInt::from(5)];
198/// let result = broadcast_shape(&shape1, &shape2).unwrap();
199/// assert_eq!(result[0].as_const(), Some(3));
200/// assert_eq!(result[1].as_const(), Some(5));
201/// ```
202pub fn broadcast_shape(lhs: &Shape, rhs: &Shape) -> Result<Shape> {
203    use crate::error::BroadcastShapeMismatchSnafu;
204    use snafu::ensure;
205
206    ensure!(lhs.len() == rhs.len(), BroadcastShapeMismatchSnafu { lhs: lhs.clone(), rhs: rhs.clone() });
207
208    let mut result = SmallVec::with_capacity(lhs.len());
209
210    for (l, r) in lhs.iter().zip(rhs.iter()) {
211        if l == r {
212            // Same dimension (concrete value or symbolic expression)
213            result.push(l.clone());
214        } else if let (Some(lv), Some(rv)) = (l.as_const(), r.as_const()) {
215            // Both concrete - apply broadcast rule
216            if lv == 1 {
217                result.push(r.clone());
218            } else if rv == 1 || lv == rv {
219                result.push(l.clone());
220            } else {
221                return BroadcastShapeMismatchSnafu { lhs: lhs.clone(), rhs: rhs.clone() }.fail();
222            }
223        } else {
224            // At least one is symbolic - use max (conservatively)
225            result.push(crate::sint_max(&[l.clone(), r.clone()]));
226        }
227    }
228
229    Ok(result)
230}
231
232/// Compute broadcast result for multiple shapes.
233///
234/// # Errors
235/// Returns error if any pair of shapes is not broadcastable.
236pub fn broadcast_shapes(shapes: &[Shape]) -> Result<Shape> {
237    if shapes.is_empty() {
238        return Ok(SmallVec::new());
239    }
240
241    // Align all shapes to same number of dimensions
242    let aligned = align_shapes_left(shapes);
243
244    // Successively broadcast pairs
245    let mut result = aligned[0].clone();
246    for shape in &aligned[1..] {
247        result = broadcast_shape(&result, shape)?;
248    }
249
250    Ok(result)
251}
252
253/// Convert shape to Vec<usize>, ensuring all dimensions are concrete.
254///
255/// This is a helper function to reduce boilerplate when converting shapes
256/// for operations that require concrete (non-symbolic) dimensions.
257///
258/// # Errors
259///
260/// Returns error if any dimension contains a symbolic (non-const) value.
261pub fn to_vec_usize(shape: &Shape) -> Result<Vec<usize>> {
262    shape
263        .iter()
264        .map(|dim| {
265            dim.as_const().ok_or_else(|| Error::SymbolicShapeUnsupported { operation: "shape conversion".to_string() })
266        })
267        .collect()
268}
269
270/// Convert shape to Vec<isize>, ensuring all dimensions are concrete.
271///
272/// # Errors
273///
274/// Returns error if any dimension contains a symbolic (non-const) value.
275pub fn to_vec_isize(shape: &Shape) -> Result<Vec<isize>> {
276    shape
277        .iter()
278        .map(|dim| {
279            dim.as_const()
280                .map(|v| v as isize)
281                .ok_or_else(|| Error::SymbolicShapeUnsupported { operation: "shape conversion".to_string() })
282        })
283        .collect()
284}
285
286// =========================================================================
287// Movement Op Argument Extraction (marg equivalent)
288// =========================================================================
289
290/// Extract shape dimensions from a VECTORIZE or CONST UOp.
291///
292/// Following Tinygrad's `marg` pattern, this extracts concrete or symbolic
293/// dimensions from the UOp used to store shape information.
294///
295/// Returns None if the UOp is not in the expected format.
296fn extract_shape_from_uop(shape_uop: &Arc<UOp>) -> Option<Shape> {
297    match shape_uop.op() {
298        // VECTORIZE with Index-typed elements
299        Op::Vectorize { elements } => Some(elements.into_iter().cloned().map(SInt::from).collect()),
300
301        // Single CONST value (for 1D shapes)
302        Op::Const(const_hash) => match const_hash.0 {
303            ConstValue::Int(v) if v >= 0 => Some(smallvec![SInt::from(v as usize)]),
304            ConstValue::UInt(v) => Some(smallvec![SInt::from(v as usize)]),
305            _ => None,
306        },
307
308        // VConst for multiple concrete dimensions
309        Op::VConst { values } => {
310            let mut dims = SmallVec::with_capacity(values.len());
311            for val in values {
312                match val {
313                    ConstValue::Int(v) if *v >= 0 => dims.push(SInt::from(*v as usize)),
314                    ConstValue::UInt(v) => dims.push(SInt::from(*v as usize)),
315                    _ => return None,
316                }
317            }
318            Some(dims)
319        }
320
321        _ => None,
322    }
323}
324
325/// Extract padding/shrink ranges from UOps.
326///
327/// Returns pairs of (begin, end) for each dimension.
328fn extract_ranges_from_uops(begins_uop: &Arc<UOp>, ends_uop: &Arc<UOp>) -> Option<Vec<(SInt, SInt)>> {
329    let begins = extract_shape_from_uop(begins_uop)?;
330    let ends = extract_shape_from_uop(ends_uop)?;
331
332    if begins.len() != ends.len() {
333        return None;
334    }
335
336    Some(begins.into_iter().zip(ends).collect())
337}
338
339/// Convert a Shape to a VECTORIZE UOp for use in movement operations.
340///
341/// This creates a UOp that encodes the shape dimensions, suitable for
342/// passing to Reshape, Expand, etc.
343///
344/// # Examples
345///
346/// ```rust
347/// # use morok_ir::{SInt, shape::shape_to_uop};
348/// # use morok_dtype::DType;
349/// # use smallvec::smallvec;
350/// let shape = smallvec![SInt::from(3), SInt::from(4), SInt::from(5)];
351/// let shape_uop = shape_to_uop(&shape);
352/// assert_eq!(shape_uop.dtype(), DType::Index.vec(3));
353///
354/// // Scalar (empty shape) is supported
355/// let scalar_shape: smallvec::SmallVec<[SInt; 4]> = smallvec![];
356/// let scalar_uop = shape_to_uop(&scalar_shape);
357/// // VConst with empty values represents scalar
358/// ```
359pub fn shape_to_uop(shape: &Shape) -> Arc<UOp> {
360    use morok_dtype::DType;
361    use smallvec::SmallVec;
362
363    // Empty shape = scalar: use VConst with empty values
364    // extract_shape_from_uop will decode this back to empty Shape
365    if shape.is_empty() {
366        return UOp::vconst(vec![], DType::Index);
367    }
368
369    let elements: SmallVec<[Arc<UOp>; 4]> = shape.iter().map(|dim| dim.to_uop(DType::Index)).collect();
370    UOp::vectorize(elements)
371}
372
373/// Convert a vector of (begin, end) ranges to two UOps for Pad/Shrink operations.
374///
375/// Returns (begins_uop, ends_uop) as VECTORIZE UOps.
376///
377/// # Panics
378/// Panics if `ranges` is empty; handle scalars at the callsite.
379pub fn ranges_to_uops(ranges: &[(SInt, SInt)]) -> (Arc<UOp>, Arc<UOp>) {
380    use morok_dtype::DType;
381    use smallvec::SmallVec;
382
383    assert!(!ranges.is_empty(), "ranges_to_uops does not support empty ranges (scalars); handle at callsite");
384
385    let begins: SmallVec<[Arc<UOp>; 4]> = ranges.iter().map(|(begin, _)| begin.to_uop(DType::Index)).collect();
386    let ends: SmallVec<[Arc<UOp>; 4]> = ranges.iter().map(|(_, end)| end.to_uop(DType::Index)).collect();
387
388    (UOp::vectorize(begins), UOp::vectorize(ends))
389}
390
391// =========================================================================
392// Shape Inference (Tinygrad-style)
393// =========================================================================
394
/// Infer shape from a UOp's operation.
///
/// This is the core shape inference function, following Tinygrad's approach.
/// Returns `Ok(None)` for operations without a well-defined shape (control flow, etc.).
///
/// # Shape Inference Rules
///
/// - **Const, DefineVar, Gep, PointerIndex**: scalar, i.e. the empty shape
/// - **VConst, Vectorize, Cat, PtrCat**: `None` (kernel-level vector values)
/// - **Unary / Cast / passthrough ops**: preserve the source's shape
/// - **Binary ops**: if both inputs have a shape they must match; the result
///   is whichever input shape is present
/// - **Ternary ops**: shape of the value branches (the condition is ignored)
/// - **BitCast**: same shape, with the last dimension rescaled when the item
///   sizes differ
/// - **Movement ops**: compute shape from the operation's arguments
/// - **ReduceAxis**: reduced axes become 1 (rank is preserved)
/// - **Buffer-like ops**: 1D shape `(size,)`
/// - **Bufferize**: one dimension per range
/// - **Late/control flow ops**: `None`
///
/// # Errors
///
/// - `BufferDefRequiresPtrDType` if a `DefineLocal` does not have a pointer dtype
/// - `BinaryShapeMismatch` if both binary inputs have shapes and they differ
/// - `TernaryBranchShapeMismatch` if both ternary branches have shapes and they differ
/// - `VoidTypeInOp` if the source of `Permute`, `Pad`, `Shrink` or `ReduceAxis`
///   has no shape, or if the `Pad`/`Shrink` range UOps cannot be decoded
/// - any error returned while computing the shape of a source UOp
pub fn infer_shape_from_op(uop: &UOp) -> crate::Result<Option<Shape>> {
    Ok(match uop.op() {
        // =====================================================================
        // Nullary operations
        // =====================================================================
        Op::Const(_) => Some(SmallVec::new()), // Scalar has empty shape

        Op::VConst { .. } => None,

        Op::Unique(_) | Op::Device(_) | Op::Noop | Op::Invalid => None,

        // DefineLocal: 1D shape taken from the pointer dtype's `size`
        Op::DefineLocal(_id) => {
            use morok_dtype::DType;
            match uop.dtype() {
                DType::Ptr { size: Some(s), .. } => Some(smallvec![SInt::from(s)]),
                // Pointer without a size: the dimension is the symbolic index constant -1
                DType::Ptr { size: None, .. } => {
                    let neg_one = UOp::index_const(-1);
                    Some(smallvec![SInt::from(neg_one)])
                }
                // Any non-pointer dtype is an error for a buffer definition
                dtype => {
                    return crate::error::BufferDefRequiresPtrDTypeSnafu { op: "DefineLocal", dtype: dtype.clone() }
                        .fail();
                }
            }
        }

        // =====================================================================
        // Unary operations - preserve shape
        // =====================================================================
        Op::Unary(_, input) => input.shape()?.cloned(),

        // =====================================================================
        // Binary operations - validate shapes match
        // =====================================================================
        Op::Binary(op, lhs, rhs) => {
            match (lhs.shape()?, rhs.shape()?) {
                (Some(lhs_shape), Some(rhs_shape)) if !shapes_equal(lhs_shape, rhs_shape) => {
                    // Both have shapes but they differ - ERROR
                    return BinaryShapeMismatchSnafu {
                        op: *op,
                        lhs: Box::new(lhs_shape.clone()),
                        rhs: Box::new(rhs_shape.clone()),
                    }
                    .fail();
                }
                // Shapes agree, or only one side has a shape: use the one that exists
                (Some(s), _) | (_, Some(s)) => Some(s.clone()),
                (None, None) => None, // Both shapeless - valid (RANGE + RANGE)
            }
        }

        // =====================================================================
        // Ternary operations
        // =====================================================================
        Op::Ternary(_, _condition, true_val, false_val) => {
            // Result has shape of value branches - they must match
            let true_shape = true_val.shape()?;
            let false_shape = false_val.shape()?;

            match (true_shape, false_shape) {
                (Some(ts), Some(fs)) if !shapes_equal(ts, fs) => {
                    return crate::error::TernaryBranchShapeMismatchSnafu {
                        true_branch: Box::new(ts.clone()),
                        false_branch: Box::new(fs.clone()),
                    }
                    .fail();
                }
                // Branches agree, or only one branch has a shape: use the one that exists
                (Some(s), _) | (_, Some(s)) => Some(s.clone()),
                (None, None) => None,
            }
        }

        // =====================================================================
        // Type operations
        // =====================================================================
        Op::Cast { src, .. } => src.shape()?.cloned(),
        // BitCast: byte-reinterpretation (Tinygrad ops.py:240-245, tensor.py:3549-3568).
        // Same itemsize → same shape. Different itemsize → adjust last dimension.
        // A scalar (empty shape) or shapeless source is passed through unchanged.
        Op::BitCast { src, dtype } => {
            let src_shape = src.shape()?;
            match src_shape {
                Some(shape) if !shape.is_empty() => {
                    let src_bytes = src.dtype().bytes();
                    let dst_bytes = dtype.bytes();
                    if src_bytes == dst_bytes {
                        Some(shape.clone())
                    } else {
                        // Adjust last dimension: (last * src_bytes) / dst_bytes
                        let mut new_shape = shape.clone();
                        let last = new_shape.last().unwrap().clone();
                        let new_last = (last * SInt::Const(src_bytes)) / SInt::Const(dst_bytes);
                        *new_shape.last_mut().unwrap() = new_last;
                        Some(new_shape)
                    }
                }
                other => other.cloned(),
            }
        }

        // =====================================================================
        // Vector operations (kernel-level, no tensor shape)
        // =====================================================================
        Op::Vectorize { .. } => None,

        Op::Gep { .. } => Some(SmallVec::new()), // Extract element from vector -> scalar

        // =====================================================================
        // Movement operations
        // =====================================================================
        Op::Reshape { new_shape, .. } => {
            // Extract shape from VECTORIZE/CONST UOp
            extract_shape_from_uop(new_shape)
        }

        Op::Permute { axes, src } => {
            let src_shape = src.shape()?.ok_or_else(|| crate::Error::VoidTypeInOp)?;
            // Reorder dimensions according to permutation
            // NOTE(review): `src_shape[i]` panics if an axis is out of range for the
            // source rank; this assumes `axes` was validated when the op was built.
            Some(axes.iter().map(|&i| src_shape[i].clone()).collect())
        }

        Op::Expand { new_shape, .. } => {
            // Extract shape from VECTORIZE/CONST UOp
            extract_shape_from_uop(new_shape)
        }

        Op::Pad { src, begin_pads, end_pads } => {
            let src_shape = src.shape()?.ok_or_else(|| crate::Error::VoidTypeInOp)?;
            let ranges = extract_ranges_from_uops(begin_pads, end_pads).ok_or_else(|| crate::Error::VoidTypeInOp)?;

            // Rank mismatch between source and pad arguments: no shape, but not an error
            if src_shape.len() != ranges.len() {
                return Ok(None);
            }

            // New shape = src_shape + begin_pads + end_pads for each dimension
            Some(
                src_shape
                    .iter()
                    .zip(ranges.iter())
                    .map(|(dim, (begin, end))| Ok(dim + begin + end))
                    .collect::<crate::Result<Shape>>()?,
            )
        }

        Op::Shrink { src, begins, ends } => {
            let src_shape = src.shape()?.ok_or_else(|| crate::Error::VoidTypeInOp)?;
            let ranges = extract_ranges_from_uops(begins, ends).ok_or_else(|| crate::Error::VoidTypeInOp)?;

            // Rank mismatch between source and shrink arguments: no shape, but not an error
            if src_shape.len() != ranges.len() {
                return Ok(None);
            }

            // New shape = end - begin for each dimension
            Some(
                ranges
                    .iter()
                    .zip(src_shape.iter())
                    .map(|((begin, end), dim)| {
                        // Identity range (0, dim_size) → preserve dim (supports symbolic batch)
                        if begin.as_const() == Some(0) && end == dim {
                            return Ok(dim.clone());
                        }
                        // end - begin (works for both concrete and symbolic)
                        Ok(end - begin)
                    })
                    .collect::<crate::Result<Shape>>()?,
            )
        }

        Op::Flip { src, .. } => {
            // Flip preserves shape
            src.shape()?.cloned()
        }

        Op::Multi { src, .. } => {
            // Multi scales the specified axis by device count
            // TODO: Need device count from somewhere - for now preserve shape
            // Tinygrad: tuple(s*len(self.device) if a == self.axis else s for a,s in enumerate(ps))
            src.shape()?.cloned()
        }

        // =====================================================================
        // Reduction operations
        // =====================================================================
        Op::ReduceAxis { axes, src, .. } => {
            let src_shape = src.shape()?.ok_or_else(|| crate::Error::VoidTypeInOp)?;
            // Set reduced axes to 1 (don't remove them - matches Tinygrad)
            Some(
                src_shape
                    .iter()
                    .enumerate()
                    .map(|(i, dim)| if axes.contains(&i) { SInt::from(1) } else { dim.clone() })
                    .collect(),
            )
        }

        Op::Reduce { .. } => {
            // Reduce with ranges - context dependent
            None
        }

        Op::AllReduce { src, .. } => {
            // AllReduce preserves shape
            src.shape()?.cloned()
        }

        // =====================================================================
        // Buffer and memory operations - shape depends on buffer
        // =====================================================================
        // Buffer operations have shape (size,)
        Op::Buffer { size, .. } => Some(smallvec![SInt::from(*size)]),
        Op::Param { size, .. } => Some(smallvec![SInt::from(*size)]),
        Op::BufferView { size, .. } => Some(smallvec![SInt::from(*size)]),

        // Passthrough operations
        Op::Copy { src, .. } => src.shape()?.cloned(),
        // MStack takes the shape of its first buffer; with no buffers there is no shape
        Op::MStack { buffers } => match buffers.first() {
            Some(b) => b.shape()?.cloned(),
            None => None,
        },

        // BUFFERIZE shape is derived from ranges (like Tinygrad)
        // Shape = [end_0, end_1, ...] where end_i is the size of each range
        Op::Bufferize { ranges, .. } => {
            let mut dims: Shape = SmallVec::new();
            for range in ranges.iter() {
                match range.op() {
                    // Range: shape dim = end (the upper bound)
                    Op::Range { end, .. } => {
                        // Try to get constant value from end
                        if let Op::Const(val) = end.op() {
                            match val.0 {
                                ConstValue::Int(v) if v >= 0 => {
                                    dims.push(SInt::Const(v as usize));
                                    continue;
                                }
                                ConstValue::UInt(v) => {
                                    dims.push(SInt::Const(v as usize));
                                    continue;
                                }
                                // Negative or non-integer constant: handled by the symbolic fallback below
                                _ => {}
                            }
                        }
                        // Fall back to symbolic
                        dims.push(SInt::Symbolic(end.clone()));
                    }
                    // CONST range (already dead axis) has size from vmax+1
                    Op::Const(val) => {
                        match val.0 {
                            ConstValue::Int(v) if v >= 0 => {
                                dims.push(SInt::Const((v + 1) as usize)); // vmax+1 for shape
                            }
                            ConstValue::UInt(v) => {
                                dims.push(SInt::Const((v + 1) as usize)); // vmax+1 for shape
                            }
                            _ => return Ok(None), // Can't determine shape
                        }
                    }
                    // Other range types: use symbolic
                    _ => {
                        dims.push(SInt::Symbolic(range.clone()));
                    }
                }
            }
            Some(dims)
        }

        // These have no shape
        Op::Index { .. } | Op::Load { .. } | Op::Store { .. } => None,

        // =====================================================================
        // Control flow - no static shape
        // =====================================================================
        Op::If { .. } | Op::EndIf { .. } | Op::Range { .. } | Op::Barrier { .. } => None,

        // End passes through the computation shape
        Op::End { computation, .. } => computation.shape()?.cloned(),

        // =====================================================================
        // Special operations
        // =====================================================================
        // MSelect passes through buffer shape
        Op::MSelect { buffer, .. } => buffer.shape()?.cloned(),

        Op::Special { .. } => None,

        Op::DefineVar { .. } => Some(SmallVec::new()), // Variable is scalar

        // Bind passes through the shape of the bound value
        Op::Bind { value, .. } => value.shape()?.cloned(),

        // DefineReg has shape (size,)
        Op::DefineReg { size, .. } => Some(smallvec![SInt::from(*size)]),

        // =====================================================================
        // Advanced operations
        // =====================================================================
        Op::Wmma { .. } | Op::Contract { .. } | Op::Unroll { .. } => {
            // These require more complex shape computation
            None
        }

        Op::Kernel { .. } => None,

        // Assign passes through the shape of its target
        Op::Assign { target, .. } => target.shape()?.cloned(),

        // Shape-preserving wrappers around a single source
        Op::Detach { src } | Op::Contiguous { src, .. } | Op::ContiguousBackward { src } | Op::Precast { src } => {
            src.shape()?.cloned()
        }

        Op::After { passthrough, .. } => passthrough.shape()?.cloned(),

        Op::Custom { .. } | Op::CustomI { .. } => None,

        // Graph organization operations: Sink has no shape; Group takes the
        // shape of its first source (None when it has no sources)
        Op::Sink { .. } => None,
        Op::Group { sources } => match sources.first() {
            Some(src) => src.shape()?.cloned(),
            None => None,
        },

        // PointerIndex is a scalar index operation (empty shape)
        Op::PointerIndex { .. } => Some(smallvec![]),

        // Cat and PtrCat are kernel-level vector ops (no tensor shape)
        Op::Cat { .. } | Op::PtrCat { .. } => None,
    })
}