use crate::graph::Node;
use crate::tensor::Tensor;
use std::collections::HashMap;
use crate::optimizer::shape_inference::get_input_shape;
/// Infers the output shape of `ConstantOfShape`.
///
/// The output shape is the *value* of the first input, so it can only be
/// resolved when that input is a constant tensor in `weights`; each element is
/// converted to a dimension with an `as usize` cast. Returns `None` when the
/// input is missing, unnamed, or not a constant. Knowing only the *shape* of the
/// shape tensor (what `known` records) fixes the output rank but not its dims,
/// so `known` is not consulted.
pub(super) fn infer_constant_of_shape(
    node: &Node,
    _known: &HashMap<String, Vec<usize>>,
    weights: &HashMap<String, Tensor>,
) -> Option<Vec<Vec<usize>>> {
    // `_known` is unused but kept so the parameter list matches the other
    // `infer_*(node, known, weights)` helpers.
    let shape_name = node.inputs.first().filter(|name| !name.is_empty())?;
    let shape_tensor = weights.get(shape_name)?;
    let out: Vec<usize> = shape_tensor.data.iter().map(|&v| v as usize).collect();
    Some(vec![out])
}
/// Infers the output shape of a global pooling op.
///
/// The first two axes of input 0 are copied unchanged and every remaining
/// axis becomes 1. Returns `None` when the input shape is unknown or has
/// fewer than two axes.
pub(super) fn infer_global_pool_shape(
    node: &Node,
    known: &HashMap<String, Vec<usize>>,
) -> Option<Vec<Vec<usize>>> {
    let input_shape = get_input_shape(node, 0, known)?;
    if input_shape.len() < 2 {
        return None;
    }
    let pooled: Vec<usize> = input_shape
        .iter()
        .enumerate()
        .map(|(axis, &dim)| if axis < 2 { dim } else { 1 })
        .collect();
    Some(vec![pooled])
}
/// Infers the output shape of a windowed pooling op from input 0 and the
/// node's `kernel_shape`, `strides`, `dilations`, `pads` and `ceil_mode`
/// attributes.
///
/// The first two input axes are copied through unchanged. Each spatial axis
/// `d` becomes `(padded - effective_kernel) / stride + 1`, rounded up when
/// `ceil_mode` is non-zero, where
/// `effective_kernel = (kernel - 1) * dilation + 1` and
/// `padded = input + pad_begin + pad_end`. `pads` is laid out as all begin pads
/// followed by all end pads. Missing `strides`/`dilations` default to 1, and
/// missing `pads` defaults to 0.
///
/// Returns `None` in any of these cases:
/// - the input shape is unknown or has fewer than 3 axes;
/// - `kernel_shape` is missing or has the wrong length;
/// - any attribute list has the wrong length;
/// - a kernel, stride or dilation is not positive, or a pad is negative;
/// - the dilated window does not fit in the padded input.
///
/// NOTE(review): no `auto_pad` attribute is read here, so SAME_* auto-padding
/// is not reflected — confirm callers only reach this with explicit pads.
pub(super) fn infer_pool_shape(
    node: &Node,
    known: &HashMap<String, Vec<usize>>,
) -> Option<Vec<Vec<usize>>> {
    let input_shape = get_input_shape(node, 0, known)?;
    if input_shape.len() < 3 {
        return None;
    }
    let n = input_shape[0];
    let c = input_shape[1];
    let spatial_dims = input_shape.len() - 2;

    // Kernel sizes, strides and dilations must be >= 1. A checked conversion
    // stops a negative attribute from wrapping to a huge `usize`, and rejecting
    // zero avoids the `kernel - 1` underflow and a division by zero below.
    let positive = |v: &i64| usize::try_from(*v).ok().filter(|&x| x > 0);
    // Pads may be zero but never negative.
    let non_negative = |v: &i64| usize::try_from(*v).ok();

    let kernel_shape: Vec<usize> = node
        .attrs
        .ints("kernel_shape")
        .iter()
        .map(positive)
        .collect::<Option<_>>()?;
    // Also rejects an absent (empty) kernel_shape, since spatial_dims >= 1.
    if kernel_shape.len() != spatial_dims {
        return None;
    }

    let strides_attr = node.attrs.ints("strides");
    let strides: Vec<usize> = if strides_attr.is_empty() {
        vec![1; spatial_dims]
    } else {
        strides_attr.iter().map(positive).collect::<Option<_>>()?
    };

    let dilations_attr = node.attrs.ints("dilations");
    let dilations: Vec<usize> = if dilations_attr.is_empty() {
        vec![1; spatial_dims]
    } else {
        dilations_attr.iter().map(positive).collect::<Option<_>>()?
    };

    let pads_attr = node.attrs.ints("pads");
    let pads: Vec<usize> = if pads_attr.is_empty() {
        vec![0; spatial_dims * 2]
    } else {
        pads_attr.iter().map(non_negative).collect::<Option<_>>()?
    };

    // Every per-axis list must line up with the spatial rank; otherwise the
    // indexing in the loop below could go out of bounds.
    if strides.len() != spatial_dims
        || dilations.len() != spatial_dims
        || pads.len() != spatial_dims * 2
    {
        return None;
    }

    let ceil_mode = node.attrs.i("ceil_mode", 0) != 0;
    let mut out_shape = vec![n, c];
    for d in 0..spatial_dims {
        let effective_kernel = (kernel_shape[d] - 1)
            .checked_mul(dilations[d])?
            .checked_add(1)?;
        let padded = input_shape[d + 2]
            .checked_add(pads[d])?
            .checked_add(pads[d + spatial_dims])?;
        // The dilated window must fit inside the padded input at least once.
        let span = padded.checked_sub(effective_kernel)?;
        let out_dim = if ceil_mode {
            span.div_ceil(strides[d]) + 1
        } else {
            span / strides[d] + 1
        };
        out_shape.push(out_dim);
    }
    Some(vec![out_shape])
}
/// Infers the output shape of `Expand`.
///
/// The target shape comes from input 1, which must be a constant tensor in
/// `weights`; its elements are cast `as usize`. The result is whatever
/// `Tensor::broadcast_shape` produces for (input 0 shape, target). Returns
/// `None` when either shape is unavailable or `broadcast_shape` reports an
/// error.
pub(super) fn infer_expand_shape(
    node: &Node,
    known: &HashMap<String, Vec<usize>>,
    weights: &HashMap<String, Tensor>,
) -> Option<Vec<Vec<usize>>> {
    let input_shape = get_input_shape(node, 0, known)?;
    let shape_name = node.inputs.get(1).filter(|name| !name.is_empty())?;
    let shape_tensor = weights.get(shape_name)?;
    let target: Vec<usize> = shape_tensor.data.iter().map(|&v| v as usize).collect();
    Tensor::broadcast_shape(&input_shape, &target)
        .ok()
        .map(|broadcast| vec![broadcast])
}
/// Infers the output shape of `Tile`.
///
/// The repeat counts come from input 1, which must be a constant tensor in
/// `weights` with exactly one entry per axis of input 0. Each output dim is
/// `input_dim * (repeat as usize)`. Returns `None` when the input shape or the
/// repeats are unavailable, or when their lengths differ.
pub(super) fn infer_tile_shape(
    node: &Node,
    known: &HashMap<String, Vec<usize>>,
    weights: &HashMap<String, Tensor>,
) -> Option<Vec<Vec<usize>>> {
    let input_shape = get_input_shape(node, 0, known)?;
    let repeats_name = node.inputs.get(1).filter(|name| !name.is_empty())?;
    let repeats_tensor = weights.get(repeats_name)?;
    if repeats_tensor.data.len() != input_shape.len() {
        return None;
    }
    let tiled: Vec<usize> = input_shape
        .iter()
        .zip(repeats_tensor.data.iter())
        .map(|(&dim, &times)| dim * (times as usize))
        .collect();
    Some(vec![tiled])
}
/// Infers the output shape of `Pad`.
///
/// The pads operand (input 1) must be a constant tensor in `weights`. It is
/// laid out as all begin pads followed by all end pads:
/// `[a1_begin, a2_begin, ..., a1_end, a2_end, ...]`.
///
/// If the optional axes operand (input 3) is present, it must also be
/// constant, and the pads apply only to the listed axes. Negative axes count
/// from the end, and every other dimension is left unchanged. Without axes,
/// the pads cover every dimension in order.
///
/// Returns `None` in any of these cases:
/// - the input shape or a required operand is unknown;
/// - the pad count is not twice the number of padded axes;
/// - an axis is out of range or repeated;
/// - a padded dimension would be negative or overflow.
pub(super) fn infer_pad_shape(
    node: &Node,
    known: &HashMap<String, Vec<usize>>,
    weights: &HashMap<String, Tensor>,
) -> Option<Vec<Vec<usize>>> {
    let input_shape = get_input_shape(node, 0, known)?;
    let pads_name = node.inputs.get(1).filter(|name| !name.is_empty())?;
    let pads_tensor = weights.get(pads_name)?;
    let pads: Vec<i64> = pads_tensor.data.iter().map(|&v| v as i64).collect();
    let rank = input_shape.len();

    // Resolve which axes the pads refer to. Ignoring a present axes operand
    // would apply the pads to the wrong dimensions whenever it is not exactly
    // [0, 1, ..., rank - 1].
    let axes: Vec<usize> = match node.inputs.get(3).filter(|name| !name.is_empty()) {
        None => (0..rank).collect(),
        Some(axes_name) => {
            let axes_tensor = weights.get(axes_name)?;
            let mut axes = Vec::with_capacity(rank);
            for &raw in axes_tensor.data.iter() {
                let axis = raw as i64;
                let axis = if axis < 0 { axis + rank as i64 } else { axis };
                let axis = usize::try_from(axis).ok().filter(|&a| a < rank)?;
                if axes.contains(&axis) {
                    return None;
                }
                axes.push(axis);
            }
            axes
        }
    };

    let count = axes.len();
    if pads.len() != count * 2 {
        return None;
    }

    let mut out: Vec<usize> = input_shape.iter().copied().collect();
    for (k, &axis) in axes.iter().enumerate() {
        let padded = (out[axis] as i64)
            .checked_add(pads[k])?
            .checked_add(pads[k + count])?;
        // A negative result means the (negative) pads cropped past zero.
        out[axis] = usize::try_from(padded).ok()?;
    }
    Some(vec![out])
}
/// Infers the output shape of `Resize`.
///
/// Explicit sizes (input 3) are used first: if they are a constant tensor with
/// one entry per input axis, each element is cast `as usize` and returned.
/// Otherwise scales (input 2) are tried. If they are a constant tensor with one
/// entry per axis, each output dim is `(input_dim as f32 * scale) as usize`.
/// Returns `None` when neither path applies.
///
/// NOTE(review): only inputs 2/3 are read. Attributes such as `axes` or
/// `keep_aspect_ratio_policy`, a region-of-interest input, and a
/// two-input form with scales at index 1 are not consulted. Confirm that such
/// nodes either do not reach this function or that `None`/these results are
/// acceptable for them.
pub(super) fn infer_resize_shape(
    node: &Node,
    known: &HashMap<String, Vec<usize>>,
    weights: &HashMap<String, Tensor>,
) -> Option<Vec<Vec<usize>>> {
    let input_shape = get_input_shape(node, 0, known)?;
    let rank = input_shape.len();

    // Looks up a named, constant input by position.
    let constant_input = |idx: usize| {
        node.inputs
            .get(idx)
            .filter(|name| !name.is_empty())
            .and_then(|name| weights.get(name))
    };

    if let Some(sizes) = constant_input(3) {
        let target: Vec<usize> = sizes.data.iter().map(|&v| v as usize).collect();
        if target.len() == rank {
            return Some(vec![target]);
        }
    }

    let scales = constant_input(2)?;
    if scales.data.len() != rank {
        return None;
    }
    let scaled: Vec<usize> = input_shape
        .iter()
        .zip(scales.data.iter())
        .map(|(&dim, &factor)| (dim as f32 * factor) as usize)
        .collect();
    Some(vec![scaled])
}