1use mlx_core::Shape;
7use mlx_core::graph::OpKind;
8
/// Errors produced by shape inference ([`infer_shape`]).
#[derive(Debug, thiserror::Error)]
pub enum ShapeError {
    /// Generic shape incompatibility, carrying a human-readable description.
    #[error("shape mismatch: {0}")]
    Mismatch(String),

    /// `axis` is outside the valid range for a tensor of `ndim` dimensions
    /// (negative axes are resolved relative to `ndim` before this is raised).
    #[error("invalid axis {axis} for ndim {ndim}")]
    InvalidAxis { axis: i32, ndim: usize },

    /// The inner (contraction) dimensions of a matrix multiply disagree.
    #[error("matmul inner dimensions mismatch: {k1} vs {k2}")]
    MatmulMismatch { k1: i64, k2: i64 },
}
21
22pub fn infer_shape(op: &OpKind, inputs: &[&Shape]) -> Result<Shape, ShapeError> {
24 match op {
25 OpKind::Add | OpKind::Sub | OpKind::Mul | OpKind::Div => {
27 let a = inputs
28 .first()
29 .ok_or(ShapeError::Mismatch("missing input 0".into()))?;
30 let b = inputs
31 .get(1)
32 .ok_or(ShapeError::Mismatch("missing input 1".into()))?;
33 crate::broadcast_shapes(a, b)
34 .ok_or_else(|| ShapeError::Mismatch(format!("cannot broadcast {a} with {b}")))
35 }
36
37 OpKind::Neg
39 | OpKind::Exp
40 | OpKind::Log
41 | OpKind::Silu
42 | OpKind::Gelu
43 | OpKind::Sqrt
44 | OpKind::Constant
45 | OpKind::Parameter
46 | OpKind::Rope { .. }
47 | OpKind::RoPE { .. } => {
48 let a = inputs
49 .first()
50 .ok_or(ShapeError::Mismatch("missing input".into()))?;
51 Ok((*a).clone())
52 }
53
54 OpKind::LayerNorm { .. } | OpKind::RmsNorm { .. } => {
56 let a = inputs
57 .first()
58 .ok_or(ShapeError::Mismatch("missing input".into()))?;
59 Ok((*a).clone())
60 }
61
62 OpKind::Softmax { axis } => {
64 let a = inputs
65 .first()
66 .ok_or(ShapeError::Mismatch("missing input".into()))?;
67 validate_axis(*axis, a.ndim())?;
68 Ok((*a).clone())
69 }
70
71 OpKind::Sum { axis } | OpKind::Mean { axis } | OpKind::Max { axis } => {
73 let a = inputs
74 .first()
75 .ok_or(ShapeError::Mismatch("missing input".into()))?;
76 match axis {
77 None => Ok(Shape::scalar()),
78 Some(ax) => {
79 let resolved = resolve_axis(*ax, a.ndim())?;
80 let mut dims = a.0.clone();
81 dims.remove(resolved);
82 Ok(Shape::new(dims))
83 }
84 }
85 }
86
87 OpKind::MatMul => {
89 let a = inputs
90 .first()
91 .ok_or(ShapeError::Mismatch("missing input 0".into()))?;
92 let b = inputs
93 .get(1)
94 .ok_or(ShapeError::Mismatch("missing input 1".into()))?;
95 if a.ndim() != 2 || b.ndim() != 2 {
96 return Err(ShapeError::Mismatch("matmul requires 2D tensors".into()));
97 }
98 let k1 = a.0[1];
99 let k2 = b.0[0];
100 if k1 != k2 {
101 return Err(ShapeError::MatmulMismatch { k1, k2 });
102 }
103 Ok(Shape::new(vec![a.0[0], b.0[1]]))
104 }
105
106 OpKind::Reshape { new_shape } => {
108 let a = inputs
109 .first()
110 .ok_or(ShapeError::Mismatch("missing input".into()))?;
111 if a.numel() != new_shape.numel() {
112 return Err(ShapeError::Mismatch(format!(
113 "reshape cannot change numel from {} to {}",
114 a.numel(),
115 new_shape.numel()
116 )));
117 }
118 Ok(new_shape.clone())
119 }
120
121 OpKind::Broadcast { target_shape } => {
123 let a = inputs
124 .first()
125 .ok_or(ShapeError::Mismatch("missing input".into()))?;
126 let result = crate::broadcast_shapes(a, target_shape).ok_or_else(|| {
128 ShapeError::Mismatch(format!("cannot broadcast {a} to {target_shape}"))
129 })?;
130 if &result != target_shape {
131 return Err(ShapeError::Mismatch(format!(
132 "broadcast result {result} does not match target {target_shape}"
133 )));
134 }
135 Ok(target_shape.clone())
136 }
137
138 OpKind::LayerNormVjp { .. }
140 | OpKind::RmsNormVjp { .. }
141 | OpKind::SoftmaxVjp { .. }
142 | OpKind::SiluVjp
143 | OpKind::GeluVjp => {
144 let grad_output = inputs
145 .first()
146 .ok_or(ShapeError::Mismatch("missing grad_output (input 0)".into()))?;
147 let original_input = inputs.get(1).ok_or(ShapeError::Mismatch(
148 "missing original input (input 1)".into(),
149 ))?;
150 if grad_output.0 != original_input.0 {
151 return Err(ShapeError::Mismatch(
152 "VJP grad_output and input shapes must match".into(),
153 ));
154 }
155 Ok((*original_input).clone())
156 }
157
158 OpKind::ScaledMaskedSoftmax { .. } => {
160 let a = inputs
161 .first()
162 .ok_or(ShapeError::Mismatch("missing input".into()))?;
163 if a.ndim() != 2 {
164 return Err(ShapeError::Mismatch(
165 "ScaledMaskedSoftmax requires 2D input [Tq, Tk]".into(),
166 ));
167 }
168 Ok((*a).clone())
169 }
170
171 OpKind::Attention { .. } => {
173 let q = inputs
174 .first()
175 .ok_or(ShapeError::Mismatch("missing Q (input 0)".into()))?;
176 let k = inputs
177 .get(1)
178 .ok_or(ShapeError::Mismatch("missing K (input 1)".into()))?;
179 let v = inputs
180 .get(2)
181 .ok_or(ShapeError::Mismatch("missing V (input 2)".into()))?;
182 if q.ndim() != 2 || k.ndim() != 2 || v.ndim() != 2 {
183 return Err(ShapeError::Mismatch("Attention inputs must be 2D".into()));
184 }
185 let tq = q.0[0];
186 let dh = q.0[1];
187 let tk = k.0[0];
188 let dh_k = k.0[1];
189 let tk_v = v.0[0];
190 let dh_v = v.0[1];
191 if dh != dh_k {
192 return Err(ShapeError::Mismatch(format!(
193 "Q head_dim {} != K head_dim {}",
194 dh, dh_k
195 )));
196 }
197 if tk != tk_v {
198 return Err(ShapeError::Mismatch(format!(
199 "K seq_len {} != V seq_len {}",
200 tk, tk_v
201 )));
202 }
203 if dh != dh_v {
204 return Err(ShapeError::Mismatch(format!(
205 "Q head_dim {} != V head_dim {}",
206 dh, dh_v
207 )));
208 }
209 Ok(Shape::new(vec![tq, dh]))
210 }
211
212 OpKind::Embedding => {
214 let weight = inputs
215 .first()
216 .ok_or(ShapeError::Mismatch("missing weight (input 0)".into()))?;
217 let indices = inputs
218 .get(1)
219 .ok_or(ShapeError::Mismatch("missing indices (input 1)".into()))?;
220 if weight.ndim() != 2 {
221 return Err(ShapeError::Mismatch(
222 "Embedding weight must be 2D [vocab, dim]".into(),
223 ));
224 }
225 if indices.ndim() != 1 {
226 return Err(ShapeError::Mismatch(
227 "Embedding indices must be 1D [seq_len]".into(),
228 ));
229 }
230 let seq_len = indices.0[0];
231 let dim = weight.0[1];
232 Ok(Shape::new(vec![seq_len, dim]))
233 }
234
235 OpKind::Narrow {
237 axis,
238 start,
239 length,
240 } => {
241 let a = inputs
242 .first()
243 .ok_or(ShapeError::Mismatch("missing input".into()))?;
244 let resolved = resolve_axis(*axis, a.ndim())?;
245 let dim_size = a.0[resolved];
246 if *start < 0 || start + length > dim_size {
247 return Err(ShapeError::Mismatch(format!(
248 "Narrow: start {} + length {} exceeds dim size {}",
249 start, length, dim_size
250 )));
251 }
252 let mut dims = a.0.clone();
253 dims[resolved] = *length;
254 Ok(Shape::new(dims))
255 }
256
257 OpKind::Concatenate { axis } => {
259 let first = inputs
260 .first()
261 .ok_or(ShapeError::Mismatch("missing input".into()))?;
262 let resolved = resolve_axis(*axis, first.ndim())?;
263 let mut total_dim: i64 = 0;
264 for inp in inputs {
265 if inp.ndim() != first.ndim() {
266 return Err(ShapeError::Mismatch(
267 "Concatenate: all inputs must have same ndim".into(),
268 ));
269 }
270 for (d, (&a, &b)) in first.0.iter().zip(inp.0.iter()).enumerate() {
271 if d != resolved && a != b {
272 return Err(ShapeError::Mismatch(format!(
273 "Concatenate: mismatch at dim {d}: {a} vs {b}"
274 )));
275 }
276 }
277 total_dim += inp.0[resolved];
278 }
279 let mut dims = first.0.clone();
280 dims[resolved] = total_dim;
281 Ok(Shape::new(dims))
282 }
283
284 OpKind::Transpose { axes } => {
286 let a = inputs
287 .first()
288 .ok_or(ShapeError::Mismatch("missing input".into()))?;
289 let ndim = a.ndim();
290 let perm: Vec<usize> = match axes {
291 Some(ax) => {
292 if ax.len() != ndim {
293 return Err(ShapeError::Mismatch(format!(
294 "transpose axes length {} does not match ndim {}",
295 ax.len(),
296 ndim
297 )));
298 }
299 ax.clone()
300 }
301 None => (0..ndim).rev().collect(),
302 };
303
304 let mut seen = vec![false; ndim];
306 for &ax in &perm {
307 if ax >= ndim {
308 return Err(ShapeError::InvalidAxis {
309 axis: ax as i32,
310 ndim,
311 });
312 }
313 if seen[ax] {
314 return Err(ShapeError::Mismatch(format!(
315 "duplicate axis {} in transpose",
316 ax
317 )));
318 }
319 seen[ax] = true;
320 }
321
322 let new_dims: Vec<i64> = perm.iter().map(|&ax| a.0[ax]).collect();
323 Ok(Shape::new(new_dims))
324 }
325 }
326}
327
/// Checks that `axis` is a valid (possibly negative) axis for a tensor of
/// `ndim` dimensions. Identical to [`resolve_axis`]; the separate name exists
/// for call sites that only care about validity, not the resolved index.
fn validate_axis(axis: i32, ndim: usize) -> Result<usize, ShapeError> {
    resolve_axis(axis, ndim)
}
331
332fn resolve_axis(axis: i32, ndim: usize) -> Result<usize, ShapeError> {
333 let ndim_i = ndim as i32;
334 let resolved = if axis < 0 { ndim_i + axis } else { axis };
335 if resolved < 0 || resolved >= ndim_i {
336 return Err(ShapeError::InvalidAxis { axis, ndim });
337 }
338 Ok(resolved as usize)
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a [`Shape`] from a slice of dims.
    fn s(dims: &[i64]) -> Shape {
        Shape::new(dims.to_vec())
    }

    #[test]
    fn test_binary_same_shape() {
        let a = s(&[2, 3]);
        let result = infer_shape(&OpKind::Add, &[&a, &a]).unwrap();
        assert_eq!(result, s(&[2, 3]));
    }

    #[test]
    fn test_binary_broadcast() {
        let a = s(&[2, 1]);
        let b = s(&[1, 3]);
        let result = infer_shape(&OpKind::Mul, &[&a, &b]).unwrap();
        assert_eq!(result, s(&[2, 3]));
    }

    #[test]
    fn test_binary_incompatible() {
        // 3 vs 4 on the same axis cannot broadcast.
        let a = s(&[2, 3]);
        let b = s(&[2, 4]);
        assert!(infer_shape(&OpKind::Add, &[&a, &b]).is_err());
    }

    #[test]
    fn test_unary_preserves_shape() {
        let a = s(&[3, 4]);
        assert_eq!(infer_shape(&OpKind::Neg, &[&a]).unwrap(), s(&[3, 4]));
        assert_eq!(infer_shape(&OpKind::Silu, &[&a]).unwrap(), s(&[3, 4]));
    }

    #[test]
    fn test_sum_axis() {
        let a = s(&[2, 3, 4]);
        let result = infer_shape(&OpKind::Sum { axis: Some(1) }, &[&a]).unwrap();
        assert_eq!(result, s(&[2, 4]));
    }

    #[test]
    fn test_sum_all() {
        let a = s(&[2, 3]);
        let result = infer_shape(&OpKind::Sum { axis: None }, &[&a]).unwrap();
        assert_eq!(result, Shape::scalar());
    }

    #[test]
    fn test_sum_negative_axis() {
        // -1 resolves to the last axis.
        let a = s(&[2, 3, 4]);
        let result = infer_shape(&OpKind::Sum { axis: Some(-1) }, &[&a]).unwrap();
        assert_eq!(result, s(&[2, 3]));
    }

    #[test]
    fn test_matmul() {
        let a = s(&[2, 3]);
        let b = s(&[3, 4]);
        let result = infer_shape(&OpKind::MatMul, &[&a, &b]).unwrap();
        assert_eq!(result, s(&[2, 4]));
    }

    #[test]
    fn test_matmul_mismatch() {
        // Inner dims 3 vs 4 disagree.
        let a = s(&[2, 3]);
        let b = s(&[4, 5]);
        assert!(infer_shape(&OpKind::MatMul, &[&a, &b]).is_err());
    }

    #[test]
    fn test_transpose_default() {
        // `axes: None` reverses the dims.
        let a = s(&[2, 3]);
        let result = infer_shape(&OpKind::Transpose { axes: None }, &[&a]).unwrap();
        assert_eq!(result, s(&[3, 2]));
    }

    #[test]
    fn test_transpose_custom() {
        let a = s(&[2, 3, 4]);
        let result = infer_shape(
            &OpKind::Transpose {
                axes: Some(vec![2, 0, 1]),
            },
            &[&a],
        )
        .unwrap();
        assert_eq!(result, s(&[4, 2, 3]));
    }

    #[test]
    fn test_reshape() {
        let a = s(&[2, 3]);
        let result = infer_shape(
            &OpKind::Reshape {
                new_shape: s(&[3, 2]),
            },
            &[&a],
        )
        .unwrap();
        assert_eq!(result, s(&[3, 2]));
    }

    #[test]
    fn test_softmax_preserves_shape() {
        let a = s(&[2, 3]);
        let result = infer_shape(&OpKind::Softmax { axis: 1 }, &[&a]).unwrap();
        assert_eq!(result, s(&[2, 3]));
    }

    #[test]
    fn test_layer_norm_preserves_shape() {
        let a = s(&[4, 8]);
        let result = infer_shape(&OpKind::LayerNorm { eps: 1e-5 }, &[&a]).unwrap();
        assert_eq!(result, s(&[4, 8]));
    }

    #[test]
    fn test_transpose_validation() {
        let a = s(&[2, 3]);

        // Wrong permutation length.
        let res = infer_shape(
            &OpKind::Transpose {
                axes: Some(vec![0]),
            },
            &[&a],
        );
        assert!(res.is_err());

        // Duplicate axis.
        let res = infer_shape(
            &OpKind::Transpose {
                axes: Some(vec![0, 0]),
            },
            &[&a],
        );
        assert!(res.is_err());

        // Axis out of range.
        let res = infer_shape(
            &OpKind::Transpose {
                axes: Some(vec![0, 5]),
            },
            &[&a],
        );
        assert!(res.is_err());
    }

    #[test]
    fn test_reshape_validation() {
        // Reshape must preserve numel: 2*3 != 2*4.
        let a = s(&[2, 3]);
        let new_shape = s(&[2, 4]);
        let res = infer_shape(&OpKind::Reshape { new_shape }, &[&a]);
        assert!(res.is_err());
    }
}