use crate::graph::Node;
use crate::tensor::Tensor;
use std::collections::HashMap;
use crate::optimizer::shape_inference::get_input_shape;
/// Infers the single output shape of a `GatherND` node.
///
/// With `q` the rank of `indices` (input 1), `k = indices_shape[q - 1]` and
/// `b = batch_dims` (attribute, default 0), the output shape is
/// `indices_shape[..q - 1]` followed by `data_shape[b + k..]`. This presumably mirrors
/// the ONNX `GatherND` rule; confirm if the node can come from another frontend.
///
/// Returns `None` in these cases:
/// - either input shape cannot be resolved via `get_input_shape`;
/// - `indices` has an empty shape;
/// - `batch_dims` is negative;
/// - `batch_dims` is not smaller than the `indices` rank;
/// - `b + k` exceeds the `data` rank.
pub(super) fn infer_gather_nd_shape(
    node: &Node,
    known: &HashMap<String, Vec<usize>>,
) -> Option<Vec<Vec<usize>>> {
    let data_shape = get_input_shape(node, 0, known)?;
    let indices_shape = get_input_shape(node, 1, known)?;

    // Reject negative values explicitly. `as usize` would wrap them to a huge number,
    // which overflows the addition below (panic in debug, bogus shape in release).
    let batch_dims = node.attrs.i("batch_dims", 0);
    if batch_dims < 0 {
        return None;
    }
    let batch_dims = batch_dims as usize;

    // The last `indices` dimension is the length of each index tuple. The remaining
    // leading dimensions (batch dims included) carry straight through to the output.
    let (&last_idx_dim, leading) = indices_shape.split_last()?;
    // Batch dimensions must be strictly fewer than the `indices` rank.
    if batch_dims > leading.len() {
        return None;
    }

    // `get` yields `None` when the start index exceeds the data rank.
    let data_tail = data_shape.get(batch_dims.checked_add(last_idx_dim)?..)?;
    Some(vec![[leading, data_tail].concat()])
}
/// Infers the single output shape of a `OneHot` node.
///
/// The output is the `indices` shape (input 0) with one extra dimension of size `depth`
/// inserted at position `axis`. `axis` defaults to `-1`. Negative values count back from
/// the output rank (`indices` rank + 1), so `-1` appends the new dimension last.
///
/// `depth` (input 1) must be a named tensor present in `weights`. Only its first element
/// is read.
///
/// Returns `None` in these cases:
/// - the `indices` shape cannot be resolved;
/// - the depth input is missing, unnamed, absent from `weights`, empty or negative;
/// - the normalized `axis` falls outside `0..=indices_rank`.
pub(super) fn infer_onehot_shape(
    node: &Node,
    known: &HashMap<String, Vec<usize>>,
    weights: &HashMap<String, Tensor>,
) -> Option<Vec<Vec<usize>>> {
    let indices_shape = get_input_shape(node, 0, known)?;

    let depth_name = node.inputs.get(1).filter(|name| !name.is_empty())?;
    let depth_tensor = weights.get(depth_name)?;
    if depth_tensor.data.is_empty() {
        return None;
    }
    // Convert through `i64` so a negative depth is rejected. A direct `as usize` would
    // turn it into 0 (float elements) or a huge wrapped value (integer elements).
    let raw_depth = depth_tensor.data[0] as i64;
    if raw_depth < 0 {
        return None;
    }
    let depth = raw_depth as usize;

    // The new dimension can go before the first through after the last `indices`
    // dimension, i.e. the normalized axis must lie in `0..rank`.
    let rank = indices_shape.len() as i64 + 1;
    let axis = node.attrs.i("axis", -1);
    let norm_axis = if axis < 0 { axis + rank } else { axis };
    if !(0..rank).contains(&norm_axis) {
        return None;
    }

    let mut out = indices_shape;
    out.insert(norm_axis as usize, depth);
    Some(vec![out])
}
/// Infers the output shapes of a `TopK` node.
///
/// Two identical shapes are returned, presumably for the values and indices outputs.
/// Each is the input shape (input 0) with the dimension at `axis` replaced by `k`.
/// `axis` defaults to `-1`, and negative values count back from the input rank.
/// `k` (input 1) must be a named tensor present in `weights`. Only its first element
/// is read.
///
/// Returns `None` in these cases:
/// - the input shape cannot be resolved;
/// - the normalized `axis` is out of range, which includes every axis for a scalar input;
/// - the `k` input is missing, unnamed, absent from `weights`, empty or negative.
pub(super) fn infer_topk_shape(
    node: &Node,
    known: &HashMap<String, Vec<usize>>,
    weights: &HashMap<String, Tensor>,
) -> Option<Vec<Vec<usize>>> {
    let input_shape = get_input_shape(node, 0, known)?;

    let rank = input_shape.len() as i64;
    let axis_raw = node.attrs.i("axis", -1);
    let axis = if axis_raw < 0 { axis_raw + rank } else { axis_raw };
    // An empty range for rank 0 means scalar inputs are rejected here as well.
    if !(0..rank).contains(&axis) {
        return None;
    }
    let axis = axis as usize;

    let k_name = node.inputs.get(1).filter(|name| !name.is_empty())?;
    let k_tensor = weights.get(k_name)?;
    if k_tensor.data.is_empty() {
        return None;
    }
    // Convert through `i64` so a negative K is rejected. A direct `as usize` would
    // turn it into 0 (float elements) or a huge wrapped value (integer elements).
    let raw_k = k_tensor.data[0] as i64;
    if raw_k < 0 {
        return None;
    }

    let mut out = input_shape;
    out[axis] = raw_k as usize;
    Some(vec![out.clone(), out])
}