1use lift_core::types::{Dimension, TensorTypeInfo};
2use crate::ops::TensorOp;
3
/// Stateless namespace for static analysis of tensor operations.
///
/// The type carries no data; its associated functions infer output
/// descriptors (`infer_output_shape`) and estimate FLOP and memory costs
/// (`compute_flops`, `compute_memory_bytes`) from an op and its input
/// descriptors.
#[derive(Debug)]
pub struct ShapeInference;
6
7impl ShapeInference {
8 pub fn infer_output_shape(
9 op: &TensorOp,
10 inputs: &[&TensorTypeInfo],
11 ) -> Result<Vec<TensorTypeInfo>, String> {
12 match op {
13 TensorOp::Add | TensorOp::Sub | TensorOp::Mul | TensorOp::Div => {
15 if inputs.len() != 2 {
16 return Err(format!("{} requires 2 inputs", op.name()));
17 }
18 let result = broadcast_shapes(&inputs[0].shape, &inputs[1].shape)?;
19 Ok(vec![TensorTypeInfo {
20 shape: result,
21 dtype: inputs[0].dtype,
22 layout: inputs[0].layout,
23 }])
24 }
25
26 TensorOp::Neg | TensorOp::ReLU | TensorOp::GeLU | TensorOp::SiLU |
28 TensorOp::Sigmoid | TensorOp::Tanh |
29 TensorOp::LeakyReLU | TensorOp::ELU | TensorOp::Mish |
30 TensorOp::HardSwish | TensorOp::HardSigmoid |
31 TensorOp::Softmax | TensorOp::Cumsum |
32 TensorOp::Quantize | TensorOp::Dequantize |
33 TensorOp::QuantizeInt4 | TensorOp::DequantizeInt4 |
34 TensorOp::QuantizeFp8 | TensorOp::DequantizeFp8 |
35 TensorOp::Checkpoint | TensorOp::Offload |
36 TensorOp::GradReLU | TensorOp::GradGeLU | TensorOp::GradSoftmax => {
37 if inputs.is_empty() {
38 return Err(format!("{} requires at least 1 input", op.name()));
39 }
40 Ok(vec![inputs[0].clone()])
41 }
42
43 TensorOp::LayerNorm | TensorOp::RMSNorm | TensorOp::BatchNorm |
45 TensorOp::GroupNorm | TensorOp::InstanceNorm |
46 TensorOp::GradLayerNorm => {
47 if inputs.is_empty() {
48 return Err(format!("{} requires at least 1 input", op.name()));
49 }
50 Ok(vec![inputs[0].clone()])
51 }
52
53 TensorOp::MatMul | TensorOp::SparseMatMul => {
55 if inputs.len() != 2 {
56 return Err("matmul requires 2 inputs".into());
57 }
58 let a = &inputs[0].shape;
59 let b = &inputs[1].shape;
60 if a.len() < 2 || b.len() < 2 {
61 return Err("matmul inputs must be at least 2D".into());
62 }
63 let m = a[a.len() - 2].clone();
64 let n = b[b.len() - 1].clone();
65
66 let k_a = &a[a.len() - 1];
67 let k_b = &b[b.len() - 2];
68 if let (Some(ka), Some(kb)) = (k_a.static_value(), k_b.static_value()) {
69 if ka != kb {
70 return Err(format!(
71 "matmul inner dimension mismatch: {} vs {}", ka, kb
72 ));
73 }
74 }
75
76 let mut result_shape = Vec::new();
77 let batch_a = &a[..a.len() - 2];
78 let batch_b = &b[..b.len() - 2];
79 let batch = broadcast_shapes(batch_a, batch_b)?;
80 result_shape.extend(batch);
81 result_shape.push(m);
82 result_shape.push(n);
83
84 Ok(vec![TensorTypeInfo {
85 shape: result_shape,
86 dtype: inputs[0].dtype,
87 layout: inputs[0].layout,
88 }])
89 }
90
91 TensorOp::Linear => {
93 if inputs.len() < 2 {
94 return Err("linear requires at least 2 inputs (x, W)".into());
95 }
96 let x = &inputs[0].shape;
97 let w = &inputs[1].shape;
98 if x.is_empty() || w.len() != 2 {
99 return Err("linear: x must be at least 1D, W must be 2D".into());
100 }
101 let mut result_shape = x[..x.len() - 1].to_vec();
102 result_shape.push(w[1].clone());
103
104 Ok(vec![TensorTypeInfo {
105 shape: result_shape,
106 dtype: inputs[0].dtype,
107 layout: inputs[0].layout,
108 }])
109 }
110
111 TensorOp::Conv2D | TensorOp::DepthwiseConv2D | TensorOp::DilatedConv2D => {
113 if inputs.len() < 2 {
114 return Err("conv2d requires at least 2 inputs (input, kernel)".into());
115 }
116 let input = &inputs[0].shape;
117 let kernel = &inputs[1].shape;
118 if input.len() != 4 || kernel.len() != 4 {
119 return Err("conv2d: input and kernel must be 4D (NCHW)".into());
120 }
121
122 let n = input[0].clone();
123 let cout = kernel[0].clone();
124 let h_out = match (&input[2], &kernel[2]) {
125 (Dimension::Constant(ih), Dimension::Constant(kh)) => {
126 Dimension::Constant(ih - kh + 1)
127 }
128 _ => Dimension::Symbolic("H_out".into()),
129 };
130 let w_out = match (&input[3], &kernel[3]) {
131 (Dimension::Constant(iw), Dimension::Constant(kw)) => {
132 Dimension::Constant(iw - kw + 1)
133 }
134 _ => Dimension::Symbolic("W_out".into()),
135 };
136
137 Ok(vec![TensorTypeInfo {
138 shape: vec![n, cout, h_out, w_out],
139 dtype: inputs[0].dtype,
140 layout: inputs[0].layout,
141 }])
142 }
143
144 TensorOp::Conv1D => {
146 if inputs.len() < 2 {
147 return Err("conv1d requires at least 2 inputs".into());
148 }
149 let input = &inputs[0].shape;
150 let kernel = &inputs[1].shape;
151 if input.len() != 3 || kernel.len() != 3 {
152 return Err("conv1d: input [N,C,L] and kernel [Cout,Cin,K]".into());
153 }
154 let n = input[0].clone();
155 let cout = kernel[0].clone();
156 let l_out = match (&input[2], &kernel[2]) {
157 (Dimension::Constant(il), Dimension::Constant(kl)) => {
158 Dimension::Constant(il - kl + 1)
159 }
160 _ => Dimension::Symbolic("L_out".into()),
161 };
162 Ok(vec![TensorTypeInfo {
163 shape: vec![n, cout, l_out],
164 dtype: inputs[0].dtype,
165 layout: inputs[0].layout,
166 }])
167 }
168
169 TensorOp::Conv3D => {
171 if inputs.len() < 2 {
172 return Err("conv3d requires at least 2 inputs".into());
173 }
174 let input = &inputs[0].shape;
175 let kernel = &inputs[1].shape;
176 if input.len() != 5 || kernel.len() != 5 {
177 return Err("conv3d: input [N,C,D,H,W] and kernel [Cout,Cin,Kd,Kh,Kw]".into());
178 }
179 let n = input[0].clone();
180 let cout = kernel[0].clone();
181 let dims: Vec<Dimension> = (2..5).map(|i| {
182 match (&input[i], &kernel[i]) {
183 (Dimension::Constant(iv), Dimension::Constant(kv)) => {
184 Dimension::Constant(iv - kv + 1)
185 }
186 _ => Dimension::Symbolic(format!("dim{}_out", i)),
187 }
188 }).collect();
189 Ok(vec![TensorTypeInfo {
190 shape: vec![n, cout, dims[0].clone(), dims[1].clone(), dims[2].clone()],
191 dtype: inputs[0].dtype,
192 layout: inputs[0].layout,
193 }])
194 }
195
196 TensorOp::MaxPool2D | TensorOp::AvgPool2D => {
198 if inputs.is_empty() {
199 return Err(format!("{} requires at least 1 input", op.name()));
200 }
201 Ok(vec![inputs[0].clone()])
203 }
204
205 TensorOp::AdaptiveAvgPool2D => {
206 if inputs.is_empty() {
207 return Err("adaptive_avgpool2d requires 1 input".into());
208 }
209 Ok(vec![inputs[0].clone()])
210 }
211
212 TensorOp::GlobalAvgPool => {
213 if inputs.is_empty() {
214 return Err("global_avgpool requires 1 input".into());
215 }
216 let shape = &inputs[0].shape;
217 if shape.len() < 3 {
218 return Err("global_avgpool: input must be at least 3D [N,C,...]".into());
219 }
220 let mut out = vec![shape[0].clone(), shape[1].clone()];
222 for _ in 2..shape.len() {
223 out.push(Dimension::Constant(1));
224 }
225 Ok(vec![TensorTypeInfo {
226 shape: out,
227 dtype: inputs[0].dtype,
228 layout: inputs[0].layout,
229 }])
230 }
231
232 TensorOp::Attention | TensorOp::MultiHeadAttention |
234 TensorOp::MultiQueryAttention | TensorOp::GroupedQueryAttention |
235 TensorOp::FlashAttention | TensorOp::SlidingWindowAttention |
236 TensorOp::CrossAttention | TensorOp::PagedAttention |
237 TensorOp::GradAttention => {
238 if inputs.len() < 3 {
239 return Err("attention requires at least 3 inputs (Q, K, V)".into());
240 }
241 Ok(vec![inputs[0].clone()])
242 }
243
244 TensorOp::LSTMCell => {
246 if inputs.len() < 2 {
247 return Err("lstm_cell requires input and hidden state".into());
248 }
249 Ok(vec![inputs[1].clone(), inputs[1].clone()])
251 }
252
253 TensorOp::GRUCell | TensorOp::RNNCell => {
254 if inputs.len() < 2 {
255 return Err(format!("{} requires input and hidden state", op.name()));
256 }
257 Ok(vec![inputs[1].clone()])
258 }
259
260 TensorOp::Reshape | TensorOp::Transpose | TensorOp::Squeeze |
262 TensorOp::Unsqueeze | TensorOp::Permute | TensorOp::Expand |
263 TensorOp::Slice | TensorOp::Pad | TensorOp::Tile => {
264 if inputs.is_empty() {
266 return Err(format!("{} requires at least 1 input", op.name()));
267 }
268 Ok(vec![inputs[0].clone()])
269 }
270
271 TensorOp::Concat => {
273 if inputs.is_empty() {
274 return Err("concat requires at least 1 input".into());
275 }
276 Ok(vec![inputs[0].clone()])
277 }
278
279 TensorOp::TopK | TensorOp::Sort => {
281 if inputs.is_empty() {
282 return Err(format!("{} requires 1 input", op.name()));
283 }
284 Ok(vec![inputs[0].clone()])
285 }
286
287 TensorOp::FFT | TensorOp::IFFT => {
289 if inputs.is_empty() {
290 return Err(format!("{} requires 1 input", op.name()));
291 }
292 Ok(vec![inputs[0].clone()])
293 }
294
295 TensorOp::SVD => {
297 if inputs.is_empty() {
298 return Err("svd requires 1 input".into());
299 }
300 Ok(vec![inputs[0].clone()])
301 }
302
303 TensorOp::Where | TensorOp::Clamp => {
305 if inputs.len() < 2 {
306 return Err(format!("{} requires at least 2 inputs", op.name()));
307 }
308 Ok(vec![inputs[0].clone()])
309 }
310
311 _ => {
312 if !inputs.is_empty() {
314 Ok(vec![inputs[0].clone()])
315 } else {
316 Ok(Vec::new())
317 }
318 }
319 }
320 }
321
322 pub fn compute_flops(op: &TensorOp, inputs: &[&TensorTypeInfo]) -> Option<u64> {
323 match op {
324 TensorOp::MatMul | TensorOp::SparseMatMul => {
325 if inputs.len() != 2 { return None; }
326 let a = &inputs[0].shape;
327 let b = &inputs[1].shape;
328 let m = a.get(a.len().checked_sub(2)?)?.static_value()? as u64;
329 let k = a.last()?.static_value()? as u64;
330 let n = b.last()?.static_value()? as u64;
331 let batch: u64 = a[..a.len() - 2].iter()
332 .filter_map(|d| d.static_value())
333 .map(|v| v as u64)
334 .product::<u64>()
335 .max(1);
336 Some(2 * batch * m * n * k)
337 }
338
339 TensorOp::Add | TensorOp::Sub | TensorOp::Mul | TensorOp::Div => {
340 if inputs.is_empty() { return None; }
341 Some(element_count(&inputs[0].shape)? as u64)
342 }
343
344 TensorOp::ReLU | TensorOp::Sigmoid | TensorOp::Tanh |
345 TensorOp::LeakyReLU | TensorOp::ELU | TensorOp::HardSigmoid => {
346 if inputs.is_empty() { return None; }
347 Some(element_count(&inputs[0].shape)? as u64)
348 }
349
350 TensorOp::GeLU | TensorOp::SiLU | TensorOp::Mish | TensorOp::HardSwish => {
351 if inputs.is_empty() { return None; }
352 let n = element_count(&inputs[0].shape)? as u64;
353 Some(8 * n)
354 }
355
356 TensorOp::Softmax => {
357 if inputs.is_empty() { return None; }
358 let n = element_count(&inputs[0].shape)? as u64;
359 Some(5 * n)
360 }
361
362 TensorOp::LayerNorm | TensorOp::RMSNorm |
363 TensorOp::GroupNorm | TensorOp::InstanceNorm => {
364 if inputs.is_empty() { return None; }
365 let n = element_count(&inputs[0].shape)? as u64;
366 Some(7 * n)
367 }
368
369 TensorOp::BatchNorm => {
370 if inputs.is_empty() { return None; }
371 let n = element_count(&inputs[0].shape)? as u64;
372 Some(5 * n)
373 }
374
375 TensorOp::Linear => {
376 if inputs.len() < 2 { return None; }
377 let x = &inputs[0].shape;
378 let w = &inputs[1].shape;
379 let m: u64 = x[..x.len() - 1].iter()
380 .filter_map(|d| d.static_value())
381 .map(|v| v as u64)
382 .product::<u64>()
383 .max(1);
384 let k = x.last()?.static_value()? as u64;
385 let n = w.last()?.static_value()? as u64;
386 Some(2 * m * n * k + n)
387 }
388
389 TensorOp::Conv2D | TensorOp::DepthwiseConv2D | TensorOp::DilatedConv2D => {
390 if inputs.len() < 2 { return None; }
391 let kernel = &inputs[1].shape;
392 let cout = kernel[0].static_value()? as u64;
393 let cin = kernel[1].static_value()? as u64;
394 let kh = kernel[2].static_value()? as u64;
395 let kw = kernel[3].static_value()? as u64;
396 let input = &inputs[0].shape;
397 let n = input[0].static_value()? as u64;
398 let ih = input[2].static_value()? as u64;
399 let iw = input[3].static_value()? as u64;
400 let oh = ih.saturating_sub(kh) + 1;
401 let ow = iw.saturating_sub(kw) + 1;
402 Some(2 * n * cout * cin * kh * kw * oh * ow)
403 }
404
405 TensorOp::Conv1D => {
406 if inputs.len() < 2 { return None; }
407 let kernel = &inputs[1].shape;
408 let cout = kernel[0].static_value()? as u64;
409 let cin = kernel[1].static_value()? as u64;
410 let k = kernel[2].static_value()? as u64;
411 let input = &inputs[0].shape;
412 let n = input[0].static_value()? as u64;
413 let il = input[2].static_value()? as u64;
414 let ol = il.saturating_sub(k) + 1;
415 Some(2 * n * cout * cin * k * ol)
416 }
417
418 TensorOp::Conv3D => {
419 if inputs.len() < 2 { return None; }
420 let kernel = &inputs[1].shape;
421 let cout = kernel.get(0)?.static_value()? as u64;
422 let cin = kernel.get(1)?.static_value()? as u64;
423 let kd = kernel.get(2)?.static_value()? as u64;
424 let kh = kernel.get(3)?.static_value()? as u64;
425 let kw = kernel.get(4)?.static_value()? as u64;
426 let input = &inputs[0].shape;
427 let n = input.get(0)?.static_value()? as u64;
428 let id = input.get(2)?.static_value()? as u64;
429 let ih = input.get(3)?.static_value()? as u64;
430 let iw = input.get(4)?.static_value()? as u64;
431 let od = id.saturating_sub(kd) + 1;
432 let oh = ih.saturating_sub(kh) + 1;
433 let ow = iw.saturating_sub(kw) + 1;
434 Some(2 * n * cout * cin * kd * kh * kw * od * oh * ow)
435 }
436
437 TensorOp::Attention | TensorOp::MultiHeadAttention |
439 TensorOp::MultiQueryAttention | TensorOp::GroupedQueryAttention |
440 TensorOp::FlashAttention | TensorOp::SlidingWindowAttention |
441 TensorOp::CrossAttention => {
442 if inputs.is_empty() { return None; }
443 let shape = &inputs[0].shape;
444 if shape.len() < 3 { return None; }
445 let b = shape[0].static_value().unwrap_or(1) as u64;
446 let s = shape[shape.len() - 2].static_value()? as u64;
447 let d = shape.last()?.static_value()? as u64;
448 let h = if shape.len() >= 4 {
449 shape[1].static_value().unwrap_or(1) as u64
450 } else { 1 };
451 Some(4 * b * h * s * s * d)
452 }
453
454 TensorOp::LSTMCell => {
456 if inputs.len() < 2 { return None; }
458 let input_size = inputs[0].shape.last()?.static_value()? as u64;
459 let hidden_size = inputs[1].shape.last()?.static_value()? as u64;
460 Some(8 * (input_size + hidden_size) * hidden_size)
461 }
462
463 TensorOp::GRUCell => {
464 if inputs.len() < 2 { return None; }
465 let input_size = inputs[0].shape.last()?.static_value()? as u64;
466 let hidden_size = inputs[1].shape.last()?.static_value()? as u64;
467 Some(6 * (input_size + hidden_size) * hidden_size)
468 }
469
470 TensorOp::RNNCell => {
471 if inputs.len() < 2 { return None; }
472 let input_size = inputs[0].shape.last()?.static_value()? as u64;
473 let hidden_size = inputs[1].shape.last()?.static_value()? as u64;
474 Some(2 * (input_size + hidden_size) * hidden_size)
475 }
476
477 TensorOp::FFT | TensorOp::IFFT => {
479 if inputs.is_empty() { return None; }
480 let n = element_count(&inputs[0].shape)? as u64;
481 if n == 0 { return Some(0); }
482 let log2n = (n as f64).log2().ceil() as u64;
483 Some(5 * n * log2n)
484 }
485
486 TensorOp::MaxPool2D | TensorOp::AvgPool2D |
488 TensorOp::AdaptiveAvgPool2D | TensorOp::GlobalAvgPool => {
489 if inputs.is_empty() { return None; }
490 Some(element_count(&inputs[0].shape)? as u64)
491 }
492
493 _ if op.is_zero_flop() => Some(0),
495
496 _ => None,
497 }
498 }
499
500 pub fn compute_memory_bytes(op: &TensorOp, inputs: &[&TensorTypeInfo]) -> Option<u64> {
501 match op {
502 TensorOp::MatMul | TensorOp::SparseMatMul => {
503 if inputs.len() != 2 { return None; }
504 let a_bytes = tensor_bytes(inputs[0])? as u64;
505 let b_bytes = tensor_bytes(inputs[1])? as u64;
506 let out_shape = Self::infer_output_shape(op, inputs).ok()?;
507 let out_bytes = if let Some(out) = out_shape.first() {
508 tensor_info_bytes(out)? as u64
509 } else { 0 };
510 Some(a_bytes + b_bytes + out_bytes)
511 }
512 _ => {
513 let total: u64 = inputs.iter()
514 .filter_map(|i| tensor_bytes(i).map(|b| b as u64))
515 .sum();
516 Some(total)
517 }
518 }
519 }
520}
521
522fn broadcast_shapes(a: &[Dimension], b: &[Dimension]) -> Result<Vec<Dimension>, String> {
523 let max_rank = a.len().max(b.len());
524 let mut result = Vec::with_capacity(max_rank);
525
526 for i in 0..max_rank {
527 let da = if i < a.len() { Some(&a[a.len() - 1 - i]) } else { None };
528 let db = if i < b.len() { Some(&b[b.len() - 1 - i]) } else { None };
529
530 let dim = match (da, db) {
531 (Some(a_dim), Some(b_dim)) => {
532 match (a_dim.static_value(), b_dim.static_value()) {
533 (Some(a_val), Some(b_val)) => {
534 if a_val == b_val { Dimension::Constant(a_val) }
535 else if a_val == 1 { Dimension::Constant(b_val) }
536 else if b_val == 1 { Dimension::Constant(a_val) }
537 else { return Err(format!(
538 "Shape broadcast error: {} vs {}", a_val, b_val
539 )); }
540 }
541 _ => Dimension::Symbolic("broadcast".into()),
542 }
543 }
544 (Some(d), None) | (None, Some(d)) => d.clone(),
545 (None, None) => unreachable!(),
546 };
547 result.push(dim);
548 }
549
550 result.reverse();
551 Ok(result)
552}
553
554fn element_count(shape: &[Dimension]) -> Option<usize> {
555 let mut count = 1usize;
556 for dim in shape {
557 count = count.checked_mul(dim.static_value()?)?;
558 }
559 Some(count)
560}
561
562fn tensor_bytes(info: &TensorTypeInfo) -> Option<usize> {
563 Some(element_count(&info.shape)? * info.dtype.byte_size())
564}
565
/// Size in bytes of the tensor described by `info`.
///
/// Delegates directly to `tensor_bytes` and returns `None` whenever that
/// function does.
fn tensor_info_bytes(info: &TensorTypeInfo) -> Option<usize> {
    tensor_bytes(info)
}
569
#[cfg(test)]
mod tests {
    use super::*;
    use lift_core::types::{DataType, MemoryLayout};

    /// Builds a contiguous tensor descriptor whose dimensions are all constants.
    fn make_tensor(shape: Vec<usize>, dtype: DataType) -> TensorTypeInfo {
        let shape = shape.into_iter().map(Dimension::Constant).collect();
        TensorTypeInfo {
            shape,
            dtype,
            layout: MemoryLayout::Contiguous,
        }
    }

    /// Shorthand for an FP32 tensor with the given constant dimensions.
    fn fp32(dims: &[usize]) -> TensorTypeInfo {
        make_tensor(dims.to_vec(), DataType::FP32)
    }

    /// Static value of every dimension of `info`, in order.
    fn static_dims(info: &TensorTypeInfo) -> Vec<Option<usize>> {
        info.shape.iter().map(|dim| dim.static_value()).collect()
    }

    #[test]
    fn test_matmul_shape() {
        let lhs = fp32(&[2, 3, 4]);
        let rhs = fp32(&[2, 4, 5]);
        let outputs =
            ShapeInference::infer_output_shape(&TensorOp::MatMul, &[&lhs, &rhs]).unwrap();
        assert_eq!(outputs.len(), 1);
        assert_eq!(static_dims(&outputs[0]), vec![Some(2), Some(3), Some(5)]);
    }

    #[test]
    fn test_matmul_dimension_mismatch() {
        let lhs = fp32(&[3, 4]);
        let rhs = fp32(&[5, 6]);
        let outcome = ShapeInference::infer_output_shape(&TensorOp::MatMul, &[&lhs, &rhs]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_matmul_flops() {
        let lhs = fp32(&[2, 3]);
        let rhs = fp32(&[3, 4]);
        // 2 * M * N * K with M = 2, N = 4, K = 3.
        let expected = Some(2 * 2 * 4 * 3);
        assert_eq!(
            ShapeInference::compute_flops(&TensorOp::MatMul, &[&lhs, &rhs]),
            expected
        );
    }

    #[test]
    fn test_relu_shape() {
        let input = fp32(&[2, 3, 4]);
        let outputs = ShapeInference::infer_output_shape(&TensorOp::ReLU, &[&input]).unwrap();
        assert_eq!(outputs[0].shape, input.shape);
    }

    #[test]
    fn test_linear_shape() {
        let x = fp32(&[1, 784]);
        let w = fp32(&[784, 64]);
        let bias = fp32(&[64]);
        let outputs =
            ShapeInference::infer_output_shape(&TensorOp::Linear, &[&x, &w, &bias]).unwrap();
        let dims = static_dims(&outputs[0]);
        assert_eq!(dims[..2], [Some(1), Some(64)]);
    }

    #[test]
    fn test_conv2d_shape() {
        let input = fp32(&[1, 3, 28, 28]);
        let kernel = fp32(&[16, 3, 5, 5]);
        let outputs =
            ShapeInference::infer_output_shape(&TensorOp::Conv2D, &[&input, &kernel]).unwrap();
        let dims = static_dims(&outputs[0]);
        // 28 - 5 + 1 = 24 along both spatial axes.
        assert_eq!(dims[..4], [Some(1), Some(16), Some(24), Some(24)]);
    }
}