Skip to main content

lift_tensor/
shape.rs

1use lift_core::types::{Dimension, TensorTypeInfo};
2use crate::ops::TensorOp;
3
/// Stateless namespace for shape inference and cost estimation of [`TensorOp`]s.
///
/// Everything is exposed through associated functions. The unit value carries
/// no data, so it is freely copyable, comparable and has a trivial `Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ShapeInference;
6
7impl ShapeInference {
8    pub fn infer_output_shape(
9        op: &TensorOp,
10        inputs: &[&TensorTypeInfo],
11    ) -> Result<Vec<TensorTypeInfo>, String> {
12        match op {
13            // ── Binary element-wise (broadcast) ──
14            TensorOp::Add | TensorOp::Sub | TensorOp::Mul | TensorOp::Div => {
15                if inputs.len() != 2 {
16                    return Err(format!("{} requires 2 inputs", op.name()));
17                }
18                let result = broadcast_shapes(&inputs[0].shape, &inputs[1].shape)?;
19                Ok(vec![TensorTypeInfo {
20                    shape: result,
21                    dtype: inputs[0].dtype,
22                    layout: inputs[0].layout,
23                }])
24            }
25
26            // ── Unary shape-preserving ──
27            TensorOp::Neg | TensorOp::ReLU | TensorOp::GeLU | TensorOp::SiLU |
28            TensorOp::Sigmoid | TensorOp::Tanh |
29            TensorOp::LeakyReLU | TensorOp::ELU | TensorOp::Mish |
30            TensorOp::HardSwish | TensorOp::HardSigmoid |
31            TensorOp::Softmax | TensorOp::Cumsum |
32            TensorOp::Quantize | TensorOp::Dequantize |
33            TensorOp::QuantizeInt4 | TensorOp::DequantizeInt4 |
34            TensorOp::QuantizeFp8 | TensorOp::DequantizeFp8 |
35            TensorOp::Checkpoint | TensorOp::Offload |
36            TensorOp::GradReLU | TensorOp::GradGeLU | TensorOp::GradSoftmax => {
37                if inputs.is_empty() {
38                    return Err(format!("{} requires at least 1 input", op.name()));
39                }
40                Ok(vec![inputs[0].clone()])
41            }
42
43            // ── Normalisation (shape-preserving) ──
44            TensorOp::LayerNorm | TensorOp::RMSNorm | TensorOp::BatchNorm |
45            TensorOp::GroupNorm | TensorOp::InstanceNorm |
46            TensorOp::GradLayerNorm => {
47                if inputs.is_empty() {
48                    return Err(format!("{} requires at least 1 input", op.name()));
49                }
50                Ok(vec![inputs[0].clone()])
51            }
52
53            // ── MatMul ──
54            TensorOp::MatMul | TensorOp::SparseMatMul => {
55                if inputs.len() != 2 {
56                    return Err("matmul requires 2 inputs".into());
57                }
58                let a = &inputs[0].shape;
59                let b = &inputs[1].shape;
60                if a.len() < 2 || b.len() < 2 {
61                    return Err("matmul inputs must be at least 2D".into());
62                }
63                let m = a[a.len() - 2].clone();
64                let n = b[b.len() - 1].clone();
65
66                let k_a = &a[a.len() - 1];
67                let k_b = &b[b.len() - 2];
68                if let (Some(ka), Some(kb)) = (k_a.static_value(), k_b.static_value()) {
69                    if ka != kb {
70                        return Err(format!(
71                            "matmul inner dimension mismatch: {} vs {}", ka, kb
72                        ));
73                    }
74                }
75
76                let mut result_shape = Vec::new();
77                let batch_a = &a[..a.len() - 2];
78                let batch_b = &b[..b.len() - 2];
79                let batch = broadcast_shapes(batch_a, batch_b)?;
80                result_shape.extend(batch);
81                result_shape.push(m);
82                result_shape.push(n);
83
84                Ok(vec![TensorTypeInfo {
85                    shape: result_shape,
86                    dtype: inputs[0].dtype,
87                    layout: inputs[0].layout,
88                }])
89            }
90
91            // ── Linear ──
92            TensorOp::Linear => {
93                if inputs.len() < 2 {
94                    return Err("linear requires at least 2 inputs (x, W)".into());
95                }
96                let x = &inputs[0].shape;
97                let w = &inputs[1].shape;
98                if x.is_empty() || w.len() != 2 {
99                    return Err("linear: x must be at least 1D, W must be 2D".into());
100                }
101                let mut result_shape = x[..x.len() - 1].to_vec();
102                result_shape.push(w[1].clone());
103
104                Ok(vec![TensorTypeInfo {
105                    shape: result_shape,
106                    dtype: inputs[0].dtype,
107                    layout: inputs[0].layout,
108                }])
109            }
110
111            // ── Conv2D ──
112            TensorOp::Conv2D | TensorOp::DepthwiseConv2D | TensorOp::DilatedConv2D => {
113                if inputs.len() < 2 {
114                    return Err("conv2d requires at least 2 inputs (input, kernel)".into());
115                }
116                let input = &inputs[0].shape;
117                let kernel = &inputs[1].shape;
118                if input.len() != 4 || kernel.len() != 4 {
119                    return Err("conv2d: input and kernel must be 4D (NCHW)".into());
120                }
121
122                let n = input[0].clone();
123                let cout = kernel[0].clone();
124                let h_out = match (&input[2], &kernel[2]) {
125                    (Dimension::Constant(ih), Dimension::Constant(kh)) => {
126                        Dimension::Constant(ih - kh + 1)
127                    }
128                    _ => Dimension::Symbolic("H_out".into()),
129                };
130                let w_out = match (&input[3], &kernel[3]) {
131                    (Dimension::Constant(iw), Dimension::Constant(kw)) => {
132                        Dimension::Constant(iw - kw + 1)
133                    }
134                    _ => Dimension::Symbolic("W_out".into()),
135                };
136
137                Ok(vec![TensorTypeInfo {
138                    shape: vec![n, cout, h_out, w_out],
139                    dtype: inputs[0].dtype,
140                    layout: inputs[0].layout,
141                }])
142            }
143
144            // ── Conv1D ──
145            TensorOp::Conv1D => {
146                if inputs.len() < 2 {
147                    return Err("conv1d requires at least 2 inputs".into());
148                }
149                let input = &inputs[0].shape;
150                let kernel = &inputs[1].shape;
151                if input.len() != 3 || kernel.len() != 3 {
152                    return Err("conv1d: input [N,C,L] and kernel [Cout,Cin,K]".into());
153                }
154                let n = input[0].clone();
155                let cout = kernel[0].clone();
156                let l_out = match (&input[2], &kernel[2]) {
157                    (Dimension::Constant(il), Dimension::Constant(kl)) => {
158                        Dimension::Constant(il - kl + 1)
159                    }
160                    _ => Dimension::Symbolic("L_out".into()),
161                };
162                Ok(vec![TensorTypeInfo {
163                    shape: vec![n, cout, l_out],
164                    dtype: inputs[0].dtype,
165                    layout: inputs[0].layout,
166                }])
167            }
168
169            // ── Conv3D ──
170            TensorOp::Conv3D => {
171                if inputs.len() < 2 {
172                    return Err("conv3d requires at least 2 inputs".into());
173                }
174                let input = &inputs[0].shape;
175                let kernel = &inputs[1].shape;
176                if input.len() != 5 || kernel.len() != 5 {
177                    return Err("conv3d: input [N,C,D,H,W] and kernel [Cout,Cin,Kd,Kh,Kw]".into());
178                }
179                let n = input[0].clone();
180                let cout = kernel[0].clone();
181                let dims: Vec<Dimension> = (2..5).map(|i| {
182                    match (&input[i], &kernel[i]) {
183                        (Dimension::Constant(iv), Dimension::Constant(kv)) => {
184                            Dimension::Constant(iv - kv + 1)
185                        }
186                        _ => Dimension::Symbolic(format!("dim{}_out", i)),
187                    }
188                }).collect();
189                Ok(vec![TensorTypeInfo {
190                    shape: vec![n, cout, dims[0].clone(), dims[1].clone(), dims[2].clone()],
191                    dtype: inputs[0].dtype,
192                    layout: inputs[0].layout,
193                }])
194            }
195
196            // ── Pooling ──
197            TensorOp::MaxPool2D | TensorOp::AvgPool2D => {
198                if inputs.is_empty() {
199                    return Err(format!("{} requires at least 1 input", op.name()));
200                }
201                // Simplified: returns same shape (caller should use attrs for kernel/stride)
202                Ok(vec![inputs[0].clone()])
203            }
204
205            TensorOp::AdaptiveAvgPool2D => {
206                if inputs.is_empty() {
207                    return Err("adaptive_avgpool2d requires 1 input".into());
208                }
209                Ok(vec![inputs[0].clone()])
210            }
211
212            TensorOp::GlobalAvgPool => {
213                if inputs.is_empty() {
214                    return Err("global_avgpool requires 1 input".into());
215                }
216                let shape = &inputs[0].shape;
217                if shape.len() < 3 {
218                    return Err("global_avgpool: input must be at least 3D [N,C,...]".into());
219                }
220                // [N, C, ...] -> [N, C, 1, 1, ...]
221                let mut out = vec![shape[0].clone(), shape[1].clone()];
222                for _ in 2..shape.len() {
223                    out.push(Dimension::Constant(1));
224                }
225                Ok(vec![TensorTypeInfo {
226                    shape: out,
227                    dtype: inputs[0].dtype,
228                    layout: inputs[0].layout,
229                }])
230            }
231
232            // ── Attention variants ──
233            TensorOp::Attention | TensorOp::MultiHeadAttention |
234            TensorOp::MultiQueryAttention | TensorOp::GroupedQueryAttention |
235            TensorOp::FlashAttention | TensorOp::SlidingWindowAttention |
236            TensorOp::CrossAttention | TensorOp::PagedAttention |
237            TensorOp::GradAttention => {
238                if inputs.len() < 3 {
239                    return Err("attention requires at least 3 inputs (Q, K, V)".into());
240                }
241                Ok(vec![inputs[0].clone()])
242            }
243
244            // ── Recurrent ──
245            TensorOp::LSTMCell => {
246                if inputs.len() < 2 {
247                    return Err("lstm_cell requires input and hidden state".into());
248                }
249                // Returns (h_new, c_new) with same shape as hidden
250                Ok(vec![inputs[1].clone(), inputs[1].clone()])
251            }
252
253            TensorOp::GRUCell | TensorOp::RNNCell => {
254                if inputs.len() < 2 {
255                    return Err(format!("{} requires input and hidden state", op.name()));
256                }
257                Ok(vec![inputs[1].clone()])
258            }
259
260            // ── Shape / zero-flop ops ──
261            TensorOp::Reshape | TensorOp::Transpose | TensorOp::Squeeze |
262            TensorOp::Unsqueeze | TensorOp::Permute | TensorOp::Expand |
263            TensorOp::Slice | TensorOp::Pad | TensorOp::Tile => {
264                // These need target shape from attributes; passthrough for now
265                if inputs.is_empty() {
266                    return Err(format!("{} requires at least 1 input", op.name()));
267                }
268                Ok(vec![inputs[0].clone()])
269            }
270
271            // ── Concat ──
272            TensorOp::Concat => {
273                if inputs.is_empty() {
274                    return Err("concat requires at least 1 input".into());
275                }
276                Ok(vec![inputs[0].clone()])
277            }
278
279            // ── TopK / Sort ──
280            TensorOp::TopK | TensorOp::Sort => {
281                if inputs.is_empty() {
282                    return Err(format!("{} requires 1 input", op.name()));
283                }
284                Ok(vec![inputs[0].clone()])
285            }
286
287            // ── FFT / IFFT ──
288            TensorOp::FFT | TensorOp::IFFT => {
289                if inputs.is_empty() {
290                    return Err(format!("{} requires 1 input", op.name()));
291                }
292                Ok(vec![inputs[0].clone()])
293            }
294
295            // ── SVD: returns U, S, V ──
296            TensorOp::SVD => {
297                if inputs.is_empty() {
298                    return Err("svd requires 1 input".into());
299                }
300                Ok(vec![inputs[0].clone()])
301            }
302
303            // ── Where: condition, x, y -> x ──
304            TensorOp::Where | TensorOp::Clamp => {
305                if inputs.len() < 2 {
306                    return Err(format!("{} requires at least 2 inputs", op.name()));
307                }
308                Ok(vec![inputs[0].clone()])
309            }
310
311            _ => {
312                // For ops not yet handled, passthrough first input or empty
313                if !inputs.is_empty() {
314                    Ok(vec![inputs[0].clone()])
315                } else {
316                    Ok(Vec::new())
317                }
318            }
319        }
320    }
321
322    pub fn compute_flops(op: &TensorOp, inputs: &[&TensorTypeInfo]) -> Option<u64> {
323        match op {
324            TensorOp::MatMul | TensorOp::SparseMatMul => {
325                if inputs.len() != 2 { return None; }
326                let a = &inputs[0].shape;
327                let b = &inputs[1].shape;
328                let m = a.get(a.len().checked_sub(2)?)?.static_value()? as u64;
329                let k = a.last()?.static_value()? as u64;
330                let n = b.last()?.static_value()? as u64;
331                let batch: u64 = a[..a.len() - 2].iter()
332                    .filter_map(|d| d.static_value())
333                    .map(|v| v as u64)
334                    .product::<u64>()
335                    .max(1);
336                Some(2 * batch * m * n * k)
337            }
338
339            TensorOp::Add | TensorOp::Sub | TensorOp::Mul | TensorOp::Div => {
340                if inputs.is_empty() { return None; }
341                Some(element_count(&inputs[0].shape)? as u64)
342            }
343
344            TensorOp::ReLU | TensorOp::Sigmoid | TensorOp::Tanh |
345            TensorOp::LeakyReLU | TensorOp::ELU | TensorOp::HardSigmoid => {
346                if inputs.is_empty() { return None; }
347                Some(element_count(&inputs[0].shape)? as u64)
348            }
349
350            TensorOp::GeLU | TensorOp::SiLU | TensorOp::Mish | TensorOp::HardSwish => {
351                if inputs.is_empty() { return None; }
352                let n = element_count(&inputs[0].shape)? as u64;
353                Some(8 * n)
354            }
355
356            TensorOp::Softmax => {
357                if inputs.is_empty() { return None; }
358                let n = element_count(&inputs[0].shape)? as u64;
359                Some(5 * n)
360            }
361
362            TensorOp::LayerNorm | TensorOp::RMSNorm |
363            TensorOp::GroupNorm | TensorOp::InstanceNorm => {
364                if inputs.is_empty() { return None; }
365                let n = element_count(&inputs[0].shape)? as u64;
366                Some(7 * n)
367            }
368
369            TensorOp::BatchNorm => {
370                if inputs.is_empty() { return None; }
371                let n = element_count(&inputs[0].shape)? as u64;
372                Some(5 * n)
373            }
374
375            TensorOp::Linear => {
376                if inputs.len() < 2 { return None; }
377                let x = &inputs[0].shape;
378                let w = &inputs[1].shape;
379                let m: u64 = x[..x.len() - 1].iter()
380                    .filter_map(|d| d.static_value())
381                    .map(|v| v as u64)
382                    .product::<u64>()
383                    .max(1);
384                let k = x.last()?.static_value()? as u64;
385                let n = w.last()?.static_value()? as u64;
386                Some(2 * m * n * k + n)
387            }
388
389            TensorOp::Conv2D | TensorOp::DepthwiseConv2D | TensorOp::DilatedConv2D => {
390                if inputs.len() < 2 { return None; }
391                let kernel = &inputs[1].shape;
392                let cout = kernel[0].static_value()? as u64;
393                let cin = kernel[1].static_value()? as u64;
394                let kh = kernel[2].static_value()? as u64;
395                let kw = kernel[3].static_value()? as u64;
396                let input = &inputs[0].shape;
397                let n = input[0].static_value()? as u64;
398                let ih = input[2].static_value()? as u64;
399                let iw = input[3].static_value()? as u64;
400                let oh = ih.saturating_sub(kh) + 1;
401                let ow = iw.saturating_sub(kw) + 1;
402                Some(2 * n * cout * cin * kh * kw * oh * ow)
403            }
404
405            TensorOp::Conv1D => {
406                if inputs.len() < 2 { return None; }
407                let kernel = &inputs[1].shape;
408                let cout = kernel[0].static_value()? as u64;
409                let cin = kernel[1].static_value()? as u64;
410                let k = kernel[2].static_value()? as u64;
411                let input = &inputs[0].shape;
412                let n = input[0].static_value()? as u64;
413                let il = input[2].static_value()? as u64;
414                let ol = il.saturating_sub(k) + 1;
415                Some(2 * n * cout * cin * k * ol)
416            }
417
418            TensorOp::Conv3D => {
419                if inputs.len() < 2 { return None; }
420                let kernel = &inputs[1].shape;
421                let cout = kernel.get(0)?.static_value()? as u64;
422                let cin = kernel.get(1)?.static_value()? as u64;
423                let kd = kernel.get(2)?.static_value()? as u64;
424                let kh = kernel.get(3)?.static_value()? as u64;
425                let kw = kernel.get(4)?.static_value()? as u64;
426                let input = &inputs[0].shape;
427                let n = input.get(0)?.static_value()? as u64;
428                let id = input.get(2)?.static_value()? as u64;
429                let ih = input.get(3)?.static_value()? as u64;
430                let iw = input.get(4)?.static_value()? as u64;
431                let od = id.saturating_sub(kd) + 1;
432                let oh = ih.saturating_sub(kh) + 1;
433                let ow = iw.saturating_sub(kw) + 1;
434                Some(2 * n * cout * cin * kd * kh * kw * od * oh * ow)
435            }
436
437            // Attention variants: 2*B*H*(S^2*D + S*D^2)
438            TensorOp::Attention | TensorOp::MultiHeadAttention |
439            TensorOp::MultiQueryAttention | TensorOp::GroupedQueryAttention |
440            TensorOp::FlashAttention | TensorOp::SlidingWindowAttention |
441            TensorOp::CrossAttention => {
442                if inputs.is_empty() { return None; }
443                let shape = &inputs[0].shape;
444                if shape.len() < 3 { return None; }
445                let b = shape[0].static_value().unwrap_or(1) as u64;
446                let s = shape[shape.len() - 2].static_value()? as u64;
447                let d = shape.last()?.static_value()? as u64;
448                let h = if shape.len() >= 4 {
449                    shape[1].static_value().unwrap_or(1) as u64
450                } else { 1 };
451                Some(4 * b * h * s * s * d)
452            }
453
454            // Recurrent
455            TensorOp::LSTMCell => {
456                // 4 * (input_size + hidden_size) * hidden_size * 2
457                if inputs.len() < 2 { return None; }
458                let input_size = inputs[0].shape.last()?.static_value()? as u64;
459                let hidden_size = inputs[1].shape.last()?.static_value()? as u64;
460                Some(8 * (input_size + hidden_size) * hidden_size)
461            }
462
463            TensorOp::GRUCell => {
464                if inputs.len() < 2 { return None; }
465                let input_size = inputs[0].shape.last()?.static_value()? as u64;
466                let hidden_size = inputs[1].shape.last()?.static_value()? as u64;
467                Some(6 * (input_size + hidden_size) * hidden_size)
468            }
469
470            TensorOp::RNNCell => {
471                if inputs.len() < 2 { return None; }
472                let input_size = inputs[0].shape.last()?.static_value()? as u64;
473                let hidden_size = inputs[1].shape.last()?.static_value()? as u64;
474                Some(2 * (input_size + hidden_size) * hidden_size)
475            }
476
477            // FFT: 5*N*log2(N)
478            TensorOp::FFT | TensorOp::IFFT => {
479                if inputs.is_empty() { return None; }
480                let n = element_count(&inputs[0].shape)? as u64;
481                if n == 0 { return Some(0); }
482                let log2n = (n as f64).log2().ceil() as u64;
483                Some(5 * n * log2n)
484            }
485
486            // Pooling
487            TensorOp::MaxPool2D | TensorOp::AvgPool2D |
488            TensorOp::AdaptiveAvgPool2D | TensorOp::GlobalAvgPool => {
489                if inputs.is_empty() { return None; }
490                Some(element_count(&inputs[0].shape)? as u64)
491            }
492
493            // Zero-flop ops
494            _ if op.is_zero_flop() => Some(0),
495
496            _ => None,
497        }
498    }
499
500    pub fn compute_memory_bytes(op: &TensorOp, inputs: &[&TensorTypeInfo]) -> Option<u64> {
501        match op {
502            TensorOp::MatMul | TensorOp::SparseMatMul => {
503                if inputs.len() != 2 { return None; }
504                let a_bytes = tensor_bytes(inputs[0])? as u64;
505                let b_bytes = tensor_bytes(inputs[1])? as u64;
506                let out_shape = Self::infer_output_shape(op, inputs).ok()?;
507                let out_bytes = if let Some(out) = out_shape.first() {
508                    tensor_info_bytes(out)? as u64
509                } else { 0 };
510                Some(a_bytes + b_bytes + out_bytes)
511            }
512            _ => {
513                let total: u64 = inputs.iter()
514                    .filter_map(|i| tensor_bytes(i).map(|b| b as u64))
515                    .sum();
516                Some(total)
517            }
518        }
519    }
520}
521
522fn broadcast_shapes(a: &[Dimension], b: &[Dimension]) -> Result<Vec<Dimension>, String> {
523    let max_rank = a.len().max(b.len());
524    let mut result = Vec::with_capacity(max_rank);
525
526    for i in 0..max_rank {
527        let da = if i < a.len() { Some(&a[a.len() - 1 - i]) } else { None };
528        let db = if i < b.len() { Some(&b[b.len() - 1 - i]) } else { None };
529
530        let dim = match (da, db) {
531            (Some(a_dim), Some(b_dim)) => {
532                match (a_dim.static_value(), b_dim.static_value()) {
533                    (Some(a_val), Some(b_val)) => {
534                        if a_val == b_val { Dimension::Constant(a_val) }
535                        else if a_val == 1 { Dimension::Constant(b_val) }
536                        else if b_val == 1 { Dimension::Constant(a_val) }
537                        else { return Err(format!(
538                            "Shape broadcast error: {} vs {}", a_val, b_val
539                        )); }
540                    }
541                    _ => Dimension::Symbolic("broadcast".into()),
542                }
543            }
544            (Some(d), None) | (None, Some(d)) => d.clone(),
545            (None, None) => unreachable!(),
546        };
547        result.push(dim);
548    }
549
550    result.reverse();
551    Ok(result)
552}
553
554fn element_count(shape: &[Dimension]) -> Option<usize> {
555    let mut count = 1usize;
556    for dim in shape {
557        count = count.checked_mul(dim.static_value()?)?;
558    }
559    Some(count)
560}
561
562fn tensor_bytes(info: &TensorTypeInfo) -> Option<usize> {
563    Some(element_count(&info.shape)? * info.dtype.byte_size())
564}
565
566fn tensor_info_bytes(info: &TensorTypeInfo) -> Option<usize> {
567    tensor_bytes(info)
568}
569
#[cfg(test)]
mod tests {
    use super::*;
    use lift_core::types::{DataType, MemoryLayout};

    /// Builds a contiguous tensor type whose dimensions are all static.
    fn make_tensor(shape: Vec<usize>, dtype: DataType) -> TensorTypeInfo {
        TensorTypeInfo {
            shape: shape.into_iter().map(Dimension::Constant).collect(),
            dtype,
            layout: MemoryLayout::Contiguous,
        }
    }

    /// Static dimension list used to compare against inferred shapes.
    fn dims(values: &[usize]) -> Vec<Dimension> {
        values.iter().copied().map(Dimension::Constant).collect()
    }

    #[test]
    fn test_matmul_shape() {
        let a = make_tensor(vec![2, 3, 4], DataType::FP32);
        let b = make_tensor(vec![2, 4, 5], DataType::FP32);
        let result = ShapeInference::infer_output_shape(
            &TensorOp::MatMul, &[&a, &b]
        ).unwrap();
        assert_eq!(result.len(), 1);
        let shape = &result[0].shape;
        assert_eq!(shape.len(), 3);
        assert_eq!(shape[0].static_value(), Some(2));
        assert_eq!(shape[1].static_value(), Some(3));
        assert_eq!(shape[2].static_value(), Some(5));
    }

    #[test]
    fn test_matmul_broadcast_batch() {
        let a = make_tensor(vec![3, 4], DataType::FP32);
        let b = make_tensor(vec![2, 4, 5], DataType::FP32);
        let result = ShapeInference::infer_output_shape(
            &TensorOp::MatMul, &[&a, &b]
        ).unwrap();
        assert_eq!(result[0].shape, dims(&[2, 3, 5]));
    }

    #[test]
    fn test_matmul_dimension_mismatch() {
        let a = make_tensor(vec![3, 4], DataType::FP32);
        let b = make_tensor(vec![5, 6], DataType::FP32);
        let result = ShapeInference::infer_output_shape(
            &TensorOp::MatMul, &[&a, &b]
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_matmul_flops() {
        let a = make_tensor(vec![2, 3], DataType::FP32);
        let b = make_tensor(vec![3, 4], DataType::FP32);
        let flops = ShapeInference::compute_flops(&TensorOp::MatMul, &[&a, &b]);
        assert_eq!(flops, Some(2 * 2 * 4 * 3)); // 2*M*N*K
    }

    #[test]
    fn test_add_broadcast_shape() {
        let a = make_tensor(vec![2, 1, 4], DataType::FP32);
        let b = make_tensor(vec![3, 1], DataType::FP32);
        let result = ShapeInference::infer_output_shape(
            &TensorOp::Add, &[&a, &b]
        ).unwrap();
        assert_eq!(result[0].shape, dims(&[2, 3, 4]));
    }

    #[test]
    fn test_add_broadcast_mismatch() {
        let a = make_tensor(vec![2, 3], DataType::FP32);
        let b = make_tensor(vec![4, 3], DataType::FP32);
        assert!(ShapeInference::infer_output_shape(&TensorOp::Add, &[&a, &b]).is_err());
    }

    #[test]
    fn test_binary_op_requires_two_inputs() {
        let a = make_tensor(vec![2], DataType::FP32);
        assert!(ShapeInference::infer_output_shape(&TensorOp::Add, &[&a]).is_err());
    }

    #[test]
    fn test_relu_shape() {
        let a = make_tensor(vec![2, 3, 4], DataType::FP32);
        let result = ShapeInference::infer_output_shape(
            &TensorOp::ReLU, &[&a]
        ).unwrap();
        assert_eq!(result[0].shape, a.shape);
    }

    #[test]
    fn test_linear_shape() {
        let x = make_tensor(vec![1, 784], DataType::FP32);
        let w = make_tensor(vec![784, 64], DataType::FP32);
        let b = make_tensor(vec![64], DataType::FP32);
        let result = ShapeInference::infer_output_shape(
            &TensorOp::Linear, &[&x, &w, &b]
        ).unwrap();
        assert_eq!(result[0].shape[0].static_value(), Some(1));
        assert_eq!(result[0].shape[1].static_value(), Some(64));
    }

    #[test]
    fn test_conv1d_shape() {
        let input = make_tensor(vec![1, 3, 10], DataType::FP32);
        let kernel = make_tensor(vec![4, 3, 3], DataType::FP32);
        let result = ShapeInference::infer_output_shape(
            &TensorOp::Conv1D, &[&input, &kernel]
        ).unwrap();
        assert_eq!(result[0].shape, dims(&[1, 4, 8])); // 10-3+1
    }

    #[test]
    fn test_conv2d_shape() {
        let input = make_tensor(vec![1, 3, 28, 28], DataType::FP32);
        let kernel = make_tensor(vec![16, 3, 5, 5], DataType::FP32);
        let result = ShapeInference::infer_output_shape(
            &TensorOp::Conv2D, &[&input, &kernel]
        ).unwrap();
        assert_eq!(result[0].shape[0].static_value(), Some(1));
        assert_eq!(result[0].shape[1].static_value(), Some(16));
        assert_eq!(result[0].shape[2].static_value(), Some(24)); // 28-5+1
        assert_eq!(result[0].shape[3].static_value(), Some(24));
    }

    #[test]
    fn test_conv2d_requires_4d() {
        let input = make_tensor(vec![3, 28, 28], DataType::FP32);
        let kernel = make_tensor(vec![16, 3, 5, 5], DataType::FP32);
        assert!(ShapeInference::infer_output_shape(
            &TensorOp::Conv2D, &[&input, &kernel]
        ).is_err());
    }

    #[test]
    fn test_conv2d_flops() {
        let input = make_tensor(vec![1, 3, 28, 28], DataType::FP32);
        let kernel = make_tensor(vec![16, 3, 5, 5], DataType::FP32);
        let flops = ShapeInference::compute_flops(&TensorOp::Conv2D, &[&input, &kernel]);
        // 2 * N * Cout * Cin * Kh * Kw * Oh * Ow
        assert_eq!(flops, Some(2 * 16 * 3 * 5 * 5 * 24 * 24));
    }

    #[test]
    fn test_global_avgpool_shape() {
        let input = make_tensor(vec![2, 8, 7, 7], DataType::FP32);
        let result = ShapeInference::infer_output_shape(
            &TensorOp::GlobalAvgPool, &[&input]
        ).unwrap();
        assert_eq!(result[0].shape, dims(&[2, 8, 1, 1]));
    }
}
645        assert_eq!(result[0].shape[1].static_value(), Some(16));
646        assert_eq!(result[0].shape[2].static_value(), Some(24)); // 28-5+1
647        assert_eq!(result[0].shape[3].static_value(), Some(24));
648    }
649}