// mlx_ops/shape_inference.rs
1//! Shape inference for graph ops.
2//!
3//! Given an `OpKind` and input shapes, computes the output shape. This is used
4//! by the graph builder to set `TensorMeta` on newly created nodes.
5
6use mlx_core::Shape;
7use mlx_core::graph::OpKind;
8
/// Error returned when shapes are incompatible for an op.
#[derive(Debug, thiserror::Error)]
pub enum ShapeError {
    /// Generic incompatibility, with a human-readable explanation.
    #[error("shape mismatch: {0}")]
    Mismatch(String),

    /// An axis was out of range for a tensor with `ndim` dimensions
    /// (after negative-axis resolution).
    #[error("invalid axis {axis} for ndim {ndim}")]
    InvalidAxis { axis: i32, ndim: usize },

    /// The inner (contraction) dimensions of a matmul did not agree:
    /// `[M, k1] @ [k2, N]` with `k1 != k2`.
    #[error("matmul inner dimensions mismatch: {k1} vs {k2}")]
    MatmulMismatch { k1: i64, k2: i64 },
}
21
22/// Infer the output shape for a given op and input shapes.
23pub fn infer_shape(op: &OpKind, inputs: &[&Shape]) -> Result<Shape, ShapeError> {
24    match op {
25        // Binary elementwise ops: shapes must match (or be broadcastable).
26        OpKind::Add | OpKind::Sub | OpKind::Mul | OpKind::Div => {
27            let a = inputs
28                .first()
29                .ok_or(ShapeError::Mismatch("missing input 0".into()))?;
30            let b = inputs
31                .get(1)
32                .ok_or(ShapeError::Mismatch("missing input 1".into()))?;
33            crate::broadcast_shapes(a, b)
34                .ok_or_else(|| ShapeError::Mismatch(format!("cannot broadcast {a} with {b}")))
35        }
36
37        // Unary ops preserve shape.
38        OpKind::Neg
39        | OpKind::Exp
40        | OpKind::Log
41        | OpKind::Silu
42        | OpKind::Gelu
43        | OpKind::Sqrt
44        | OpKind::Constant
45        | OpKind::Parameter
46        | OpKind::Rope { .. }
47        | OpKind::RoPE { .. } => {
48            let a = inputs
49                .first()
50                .ok_or(ShapeError::Mismatch("missing input".into()))?;
51            Ok((*a).clone())
52        }
53
54        // Normalization preserves shape.
55        OpKind::LayerNorm { .. } | OpKind::RmsNorm { .. } => {
56            let a = inputs
57                .first()
58                .ok_or(ShapeError::Mismatch("missing input".into()))?;
59            Ok((*a).clone())
60        }
61
62        // Softmax preserves shape.
63        OpKind::Softmax { axis } => {
64            let a = inputs
65                .first()
66                .ok_or(ShapeError::Mismatch("missing input".into()))?;
67            validate_axis(*axis, a.ndim())?;
68            Ok((*a).clone())
69        }
70
71        // Reductions remove the specified axis.
72        OpKind::Sum { axis } | OpKind::Mean { axis } | OpKind::Max { axis } => {
73            let a = inputs
74                .first()
75                .ok_or(ShapeError::Mismatch("missing input".into()))?;
76            match axis {
77                None => Ok(Shape::scalar()),
78                Some(ax) => {
79                    let resolved = resolve_axis(*ax, a.ndim())?;
80                    let mut dims = a.0.clone();
81                    dims.remove(resolved);
82                    Ok(Shape::new(dims))
83                }
84            }
85        }
86
87        // MatMul: [M, K] @ [K, N] → [M, N]
88        OpKind::MatMul => {
89            let a = inputs
90                .first()
91                .ok_or(ShapeError::Mismatch("missing input 0".into()))?;
92            let b = inputs
93                .get(1)
94                .ok_or(ShapeError::Mismatch("missing input 1".into()))?;
95            if a.ndim() != 2 || b.ndim() != 2 {
96                return Err(ShapeError::Mismatch("matmul requires 2D tensors".into()));
97            }
98            let k1 = a.0[1];
99            let k2 = b.0[0];
100            if k1 != k2 {
101                return Err(ShapeError::MatmulMismatch { k1, k2 });
102            }
103            Ok(Shape::new(vec![a.0[0], b.0[1]]))
104        }
105
106        // Reshape: output shape is specified in the op.
107        OpKind::Reshape { new_shape } => {
108            let a = inputs
109                .first()
110                .ok_or(ShapeError::Mismatch("missing input".into()))?;
111            if a.numel() != new_shape.numel() {
112                return Err(ShapeError::Mismatch(format!(
113                    "reshape cannot change numel from {} to {}",
114                    a.numel(),
115                    new_shape.numel()
116                )));
117            }
118            Ok(new_shape.clone())
119        }
120
121        // Broadcast: output shape is the target shape.
122        OpKind::Broadcast { target_shape } => {
123            let a = inputs
124                .first()
125                .ok_or(ShapeError::Mismatch("missing input".into()))?;
126            // Validate broadcast compatibility.
127            let result = crate::broadcast_shapes(a, target_shape).ok_or_else(|| {
128                ShapeError::Mismatch(format!("cannot broadcast {a} to {target_shape}"))
129            })?;
130            if &result != target_shape {
131                return Err(ShapeError::Mismatch(format!(
132                    "broadcast result {result} does not match target {target_shape}"
133                )));
134            }
135            Ok(target_shape.clone())
136        }
137
138        // Backward ops: output shape = input shape (must match grad_output shape).
139        OpKind::LayerNormVjp { .. }
140        | OpKind::RmsNormVjp { .. }
141        | OpKind::SoftmaxVjp { .. }
142        | OpKind::SiluVjp
143        | OpKind::GeluVjp => {
144            let grad_output = inputs
145                .first()
146                .ok_or(ShapeError::Mismatch("missing grad_output (input 0)".into()))?;
147            let original_input = inputs.get(1).ok_or(ShapeError::Mismatch(
148                "missing original input (input 1)".into(),
149            ))?;
150            if grad_output.0 != original_input.0 {
151                return Err(ShapeError::Mismatch(
152                    "VJP grad_output and input shapes must match".into(),
153                ));
154            }
155            Ok((*original_input).clone())
156        }
157
158        // ScaledMaskedSoftmax preserves shape (must be 2D).
159        OpKind::ScaledMaskedSoftmax { .. } => {
160            let a = inputs
161                .first()
162                .ok_or(ShapeError::Mismatch("missing input".into()))?;
163            if a.ndim() != 2 {
164                return Err(ShapeError::Mismatch(
165                    "ScaledMaskedSoftmax requires 2D input [Tq, Tk]".into(),
166                ));
167            }
168            Ok((*a).clone())
169        }
170
171        // Attention: [Q, K, V] -> [Tq, Dh]
172        OpKind::Attention { .. } => {
173            let q = inputs
174                .first()
175                .ok_or(ShapeError::Mismatch("missing Q (input 0)".into()))?;
176            let k = inputs
177                .get(1)
178                .ok_or(ShapeError::Mismatch("missing K (input 1)".into()))?;
179            let v = inputs
180                .get(2)
181                .ok_or(ShapeError::Mismatch("missing V (input 2)".into()))?;
182            if q.ndim() != 2 || k.ndim() != 2 || v.ndim() != 2 {
183                return Err(ShapeError::Mismatch("Attention inputs must be 2D".into()));
184            }
185            let tq = q.0[0];
186            let dh = q.0[1];
187            let tk = k.0[0];
188            let dh_k = k.0[1];
189            let tk_v = v.0[0];
190            let dh_v = v.0[1];
191            if dh != dh_k {
192                return Err(ShapeError::Mismatch(format!(
193                    "Q head_dim {} != K head_dim {}",
194                    dh, dh_k
195                )));
196            }
197            if tk != tk_v {
198                return Err(ShapeError::Mismatch(format!(
199                    "K seq_len {} != V seq_len {}",
200                    tk, tk_v
201                )));
202            }
203            if dh != dh_v {
204                return Err(ShapeError::Mismatch(format!(
205                    "Q head_dim {} != V head_dim {}",
206                    dh, dh_v
207                )));
208            }
209            Ok(Shape::new(vec![tq, dh]))
210        }
211
212        // Embedding: [vocab, dim] + [seq_len] -> [seq_len, dim]
213        OpKind::Embedding => {
214            let weight = inputs
215                .first()
216                .ok_or(ShapeError::Mismatch("missing weight (input 0)".into()))?;
217            let indices = inputs
218                .get(1)
219                .ok_or(ShapeError::Mismatch("missing indices (input 1)".into()))?;
220            if weight.ndim() != 2 {
221                return Err(ShapeError::Mismatch(
222                    "Embedding weight must be 2D [vocab, dim]".into(),
223                ));
224            }
225            if indices.ndim() != 1 {
226                return Err(ShapeError::Mismatch(
227                    "Embedding indices must be 1D [seq_len]".into(),
228                ));
229            }
230            let seq_len = indices.0[0];
231            let dim = weight.0[1];
232            Ok(Shape::new(vec![seq_len, dim]))
233        }
234
235        // Narrow: slice along axis
236        OpKind::Narrow {
237            axis,
238            start,
239            length,
240        } => {
241            let a = inputs
242                .first()
243                .ok_or(ShapeError::Mismatch("missing input".into()))?;
244            let resolved = resolve_axis(*axis, a.ndim())?;
245            let dim_size = a.0[resolved];
246            if *start < 0 || start + length > dim_size {
247                return Err(ShapeError::Mismatch(format!(
248                    "Narrow: start {} + length {} exceeds dim size {}",
249                    start, length, dim_size
250                )));
251            }
252            let mut dims = a.0.clone();
253            dims[resolved] = *length;
254            Ok(Shape::new(dims))
255        }
256
257        // Concatenate: join along axis
258        OpKind::Concatenate { axis } => {
259            let first = inputs
260                .first()
261                .ok_or(ShapeError::Mismatch("missing input".into()))?;
262            let resolved = resolve_axis(*axis, first.ndim())?;
263            let mut total_dim: i64 = 0;
264            for inp in inputs {
265                if inp.ndim() != first.ndim() {
266                    return Err(ShapeError::Mismatch(
267                        "Concatenate: all inputs must have same ndim".into(),
268                    ));
269                }
270                for (d, (&a, &b)) in first.0.iter().zip(inp.0.iter()).enumerate() {
271                    if d != resolved && a != b {
272                        return Err(ShapeError::Mismatch(format!(
273                            "Concatenate: mismatch at dim {d}: {a} vs {b}"
274                        )));
275                    }
276                }
277                total_dim += inp.0[resolved];
278            }
279            let mut dims = first.0.clone();
280            dims[resolved] = total_dim;
281            Ok(Shape::new(dims))
282        }
283
284        // Transpose: permute dimensions.
285        OpKind::Transpose { axes } => {
286            let a = inputs
287                .first()
288                .ok_or(ShapeError::Mismatch("missing input".into()))?;
289            let ndim = a.ndim();
290            let perm: Vec<usize> = match axes {
291                Some(ax) => {
292                    if ax.len() != ndim {
293                        return Err(ShapeError::Mismatch(format!(
294                            "transpose axes length {} does not match ndim {}",
295                            ax.len(),
296                            ndim
297                        )));
298                    }
299                    ax.clone()
300                }
301                None => (0..ndim).rev().collect(),
302            };
303
304            // Validate permutation: bounds + uniqueness in a single pass.
305            let mut seen = vec![false; ndim];
306            for &ax in &perm {
307                if ax >= ndim {
308                    return Err(ShapeError::InvalidAxis {
309                        axis: ax as i32,
310                        ndim,
311                    });
312                }
313                if seen[ax] {
314                    return Err(ShapeError::Mismatch(format!(
315                        "duplicate axis {} in transpose",
316                        ax
317                    )));
318                }
319                seen[ax] = true;
320            }
321
322            let new_dims: Vec<i64> = perm.iter().map(|&ax| a.0[ax]).collect();
323            Ok(Shape::new(new_dims))
324        }
325    }
326}
327
328fn validate_axis(axis: i32, ndim: usize) -> Result<usize, ShapeError> {
329    resolve_axis(axis, ndim)
330}
331
332fn resolve_axis(axis: i32, ndim: usize) -> Result<usize, ShapeError> {
333    let ndim_i = ndim as i32;
334    let resolved = if axis < 0 { ndim_i + axis } else { axis };
335    if resolved < 0 || resolved >= ndim_i {
336        return Err(ShapeError::InvalidAxis { axis, ndim });
337    }
338    Ok(resolved as usize)
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `Shape` from a slice of dims.
    fn shape_of(dims: &[i64]) -> Shape {
        Shape::new(dims.to_vec())
    }

    #[test]
    fn test_binary_same_shape() {
        let lhs = shape_of(&[2, 3]);
        let out = infer_shape(&OpKind::Add, &[&lhs, &lhs]).unwrap();
        assert_eq!(out, shape_of(&[2, 3]));
    }

    #[test]
    fn test_binary_broadcast() {
        let lhs = shape_of(&[2, 1]);
        let rhs = shape_of(&[1, 3]);
        assert_eq!(
            infer_shape(&OpKind::Mul, &[&lhs, &rhs]).unwrap(),
            shape_of(&[2, 3])
        );
    }

    #[test]
    fn test_binary_incompatible() {
        let lhs = shape_of(&[2, 3]);
        let rhs = shape_of(&[2, 4]);
        assert!(infer_shape(&OpKind::Add, &[&lhs, &rhs]).is_err());
    }

    #[test]
    fn test_unary_preserves_shape() {
        let input = shape_of(&[3, 4]);
        for op in [OpKind::Neg, OpKind::Silu] {
            assert_eq!(infer_shape(&op, &[&input]).unwrap(), shape_of(&[3, 4]));
        }
    }

    #[test]
    fn test_sum_axis() {
        let input = shape_of(&[2, 3, 4]);
        let out = infer_shape(&OpKind::Sum { axis: Some(1) }, &[&input]).unwrap();
        assert_eq!(out, shape_of(&[2, 4]));
    }

    #[test]
    fn test_sum_all() {
        let input = shape_of(&[2, 3]);
        assert_eq!(
            infer_shape(&OpKind::Sum { axis: None }, &[&input]).unwrap(),
            Shape::scalar()
        );
    }

    #[test]
    fn test_sum_negative_axis() {
        let input = shape_of(&[2, 3, 4]);
        let out = infer_shape(&OpKind::Sum { axis: Some(-1) }, &[&input]).unwrap();
        assert_eq!(out, shape_of(&[2, 3]));
    }

    #[test]
    fn test_matmul() {
        let lhs = shape_of(&[2, 3]);
        let rhs = shape_of(&[3, 4]);
        assert_eq!(
            infer_shape(&OpKind::MatMul, &[&lhs, &rhs]).unwrap(),
            shape_of(&[2, 4])
        );
    }

    #[test]
    fn test_matmul_mismatch() {
        let lhs = shape_of(&[2, 3]);
        let rhs = shape_of(&[4, 5]);
        assert!(infer_shape(&OpKind::MatMul, &[&lhs, &rhs]).is_err());
    }

    #[test]
    fn test_transpose_default() {
        let input = shape_of(&[2, 3]);
        let out = infer_shape(&OpKind::Transpose { axes: None }, &[&input]).unwrap();
        assert_eq!(out, shape_of(&[3, 2]));
    }

    #[test]
    fn test_transpose_custom() {
        let input = shape_of(&[2, 3, 4]);
        let op = OpKind::Transpose {
            axes: Some(vec![2, 0, 1]),
        };
        assert_eq!(infer_shape(&op, &[&input]).unwrap(), shape_of(&[4, 2, 3]));
    }

    #[test]
    fn test_reshape() {
        let input = shape_of(&[2, 3]);
        let op = OpKind::Reshape {
            new_shape: shape_of(&[3, 2]),
        };
        assert_eq!(infer_shape(&op, &[&input]).unwrap(), shape_of(&[3, 2]));
    }

    #[test]
    fn test_softmax_preserves_shape() {
        let input = shape_of(&[2, 3]);
        assert_eq!(
            infer_shape(&OpKind::Softmax { axis: 1 }, &[&input]).unwrap(),
            shape_of(&[2, 3])
        );
    }

    #[test]
    fn test_layer_norm_preserves_shape() {
        let input = shape_of(&[4, 8]);
        assert_eq!(
            infer_shape(&OpKind::LayerNorm { eps: 1e-5 }, &[&input]).unwrap(),
            shape_of(&[4, 8])
        );
    }

    #[test]
    fn test_transpose_validation() {
        let input = shape_of(&[2, 3]);
        // Too few axes, duplicate axis, and out-of-bounds axis are all rejected.
        for axes in [vec![0], vec![0, 0], vec![0, 5]] {
            let op = OpKind::Transpose { axes: Some(axes) };
            assert!(infer_shape(&op, &[&input]).is_err());
        }
    }

    #[test]
    fn test_reshape_validation() {
        // numel 6 -> numel 8 must be rejected.
        let input = shape_of(&[2, 3]);
        let op = OpKind::Reshape {
            new_shape: shape_of(&[2, 4]),
        };
        assert!(infer_shape(&op, &[&input]).is_err());
    }
}