use crate::graph::Node;
use crate::tensor::Tensor;
use std::collections::HashMap;
use super::helpers::get_input_shape;
/// Infers the output shape of a reshape node whose target shape is a constant tensor.
///
/// The input shape comes from `get_input_shape(node, 0, known)`; the target shape is read
/// from the entry in `weights` named by the node's second input. Each target value is cast
/// to `i64` and interpreted as follows:
///
/// * `-1`: placeholder that is resolved from the remaining element count. At most one
///   placeholder is allowed.
/// * `0`: copies the input dimension at the same index.
/// * positive: used as the output dimension directly.
/// * any other negative value: rejected.
///
/// Returns `None` when the input shape or the shape tensor is unavailable, when the target
/// shape is malformed, or when the `-1` dimension cannot be resolved to a whole number
/// (the product of the other dimensions is zero or does not evenly divide the input's
/// element count). On success the result holds exactly one output shape.
pub(super) fn infer_reshape_shape(
    node: &Node,
    known: &HashMap<String, Vec<usize>>,
    weights: &HashMap<String, Tensor>,
) -> Option<Vec<Vec<usize>>> {
    let input_shape = get_input_shape(node, 0, known)?;
    let total_elements: usize = input_shape.iter().product();

    let shape_name = node.inputs.get(1)?;
    if shape_name.is_empty() {
        return None;
    }
    let shape_tensor = weights.get(shape_name)?;

    let mut out_shape: Vec<usize> = Vec::with_capacity(shape_tensor.data.len());
    let mut neg_one_idx: Option<usize> = None;
    for (i, &val) in shape_tensor.data.iter().enumerate() {
        let dim = val as i64;
        match dim {
            -1 => {
                if neg_one_idx.is_some() {
                    // More than one inferred dimension is ambiguous.
                    return None;
                }
                neg_one_idx = Some(i);
                // Neutral placeholder so the product below covers only the known dims.
                out_shape.push(1);
            }
            // A zero means "keep the input dimension at this position".
            0 => out_shape.push(*input_shape.get(i)?),
            d if d > 0 => out_shape.push(d as usize),
            _ => return None,
        }
    }

    if let Some(idx) = neg_one_idx {
        let known_product: usize = out_shape.iter().product();
        // Refuse to invent a dimension: the remaining elements must split evenly,
        // otherwise integer division would silently yield a wrong shape.
        if known_product == 0 || total_elements % known_product != 0 {
            return None;
        }
        out_shape[idx] = total_elements / known_product;
    }

    Some(vec![out_shape])
}
/// Infers the output shape of a transpose node.
///
/// The permutation is read from the `perm` integer-list attribute. When the attribute is
/// missing or empty, the dimensions are reversed. Output dimension `i` is the input
/// dimension at index `perm[i]`.
///
/// Returns `None` when the input shape is unavailable or when `perm` is not a valid
/// permutation of `0..rank`: wrong length, an index out of range (negative values wrap to
/// out-of-range indices when cast), or a repeated index. On success the result holds
/// exactly one output shape.
pub(super) fn infer_transpose_shape(
    node: &Node,
    known: &HashMap<String, Vec<usize>>,
) -> Option<Vec<Vec<usize>>> {
    let input_shape = get_input_shape(node, 0, known)?;
    let rank = input_shape.len();

    let perm: Vec<usize> = match node.attrs.int_lists.get("perm") {
        Some(p) if !p.is_empty() => p.iter().map(|&v| v as usize).collect(),
        // Missing or empty attribute: reverse the dimension order.
        _ => (0..rank).rev().collect(),
    };
    if perm.len() != rank {
        return None;
    }

    // Track which source axes were already used so duplicates are rejected;
    // the bounds-checked lookup also rejects out-of-range indices.
    let mut seen = vec![false; rank];
    let mut out: Vec<usize> = Vec::with_capacity(rank);
    for &axis in &perm {
        let used = seen.get_mut(axis)?;
        if *used {
            return None;
        }
        *used = true;
        out.push(input_shape[axis]);
    }

    Some(vec![out])
}
/// Infers the output shape of a squeeze node.
///
/// Axes are read from the `axes` integer attribute. With no axes, every dimension equal
/// to 1 is removed. Otherwise exactly the listed axes are removed; negative axes count
/// from the end of the input shape. The size of a listed dimension is not checked.
///
/// Returns `None` when the input shape is unavailable or when any axis falls outside
/// `[-rank, rank)`. On success the result holds exactly one output shape.
pub(super) fn infer_squeeze_shape(
    node: &Node,
    known: &HashMap<String, Vec<usize>>,
) -> Option<Vec<Vec<usize>>> {
    let input_shape = get_input_shape(node, 0, known)?;
    let rank = input_shape.len() as i64;
    let axes: Vec<i64> = node.attrs.ints("axes").to_vec();

    if axes.is_empty() {
        let out: Vec<usize> = input_shape.iter().copied().filter(|&d| d != 1).collect();
        return Some(vec![out]);
    }

    // One flag per input dimension; avoids rescanning the axis list for every dimension.
    let mut removed = vec![false; input_shape.len()];
    for &a in &axes {
        let idx = if a < 0 { a + rank } else { a };
        // Out-of-range axes are an error rather than something to ignore silently.
        if idx < 0 || idx >= rank {
            return None;
        }
        removed[idx as usize] = true;
    }

    let out: Vec<usize> = input_shape
        .iter()
        .enumerate()
        .filter(|(i, _)| !removed[*i])
        .map(|(_, &d)| d)
        .collect();
    Some(vec![out])
}
/// Infers the output shape of an unsqueeze node.
///
/// Axes are read from the `axes` integer attribute and refer to positions in the output
/// shape, whose rank is the input rank plus the number of axes. Negative axes count from
/// the end of the output shape. A dimension of size 1 is placed at every listed position
/// and the input dimensions fill the remaining positions in order.
///
/// With no axes the input shape is returned unchanged. Returns `None` when the input
/// shape is unavailable or when the input dimensions run out before every output position
/// is filled (which happens when axes repeat or fall outside the output rank).
pub(super) fn infer_unsqueeze_shape(
    node: &Node,
    known: &HashMap<String, Vec<usize>>,
) -> Option<Vec<Vec<usize>>> {
    let source = get_input_shape(node, 0, known)?;
    let axes: Vec<i64> = node.attrs.ints("axes").to_vec();
    if axes.is_empty() {
        return Some(vec![source]);
    }

    let target_rank = source.len() + axes.len();

    // Resolve negative axes against the output rank, then sort for binary search.
    let mut inserted: Vec<usize> = Vec::with_capacity(axes.len());
    for &axis in &axes {
        let resolved = if axis < 0 {
            axis + target_rank as i64
        } else {
            axis
        };
        inserted.push(resolved as usize);
    }
    inserted.sort();

    let mut remaining = source.iter();
    let mut result: Vec<usize> = Vec::with_capacity(target_rank);
    for position in 0..target_rank {
        if inserted.binary_search(&position).is_ok() {
            result.push(1);
        } else {
            // Running out of source dimensions means the axes were inconsistent.
            result.push(*remaining.next()?);
        }
    }

    Some(vec![result])
}
/// Infers the two-dimensional output shape of a flatten node.
///
/// The split point is the `axis` integer attribute (default 1); a negative value counts
/// from the end of the input shape. The first output dimension is the product of the
/// input dimensions before the split point and the second is the product of the rest.
/// An empty side yields 1.
///
/// Returns `None` when the input shape is unavailable or when the resolved axis is
/// greater than the input rank.
pub(super) fn infer_flatten_shape(
    node: &Node,
    known: &HashMap<String, Vec<usize>>,
) -> Option<Vec<Vec<usize>>> {
    let dims = get_input_shape(node, 0, known)?;
    let requested = node.attrs.i("axis", 1);

    let split = match requested {
        r if r < 0 => (r + dims.len() as i64) as usize,
        r => r as usize,
    };
    if split > dims.len() {
        return None;
    }

    let (outer, inner) = dims.split_at(split);
    let leading: usize = outer.iter().product();
    let trailing: usize = inner.iter().product();
    Some(vec![vec![leading, trailing]])
}