//! Shape inference for einsum graphs: tracks per-tensor [`TensorShape`]s and
//! propagates them through the nodes of an [`EinsumGraph`].

use std::collections::HashMap;
use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};
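/// The size of one tensor dimension: statically known, unknown until runtime,
/// or referred to by a symbolic name.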
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DimSize {
Static(usize),
Dynamic,
Symbolic(String),
}
impl DimSize {
pub fn is_static(&self) -> bool {
matches!(self, DimSize::Static(_))
}
pub fn as_static(&self) -> Option<usize> {
match self {
DimSize::Static(size) => Some(*size),
_ => None,
}
}
}
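/// A tensor shape: an ordered list of per-dimension [`DimSize`]s.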
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorShape {
pub dims: Vec<DimSize>,
}
impl TensorShape {
pub fn new(dims: Vec<DimSize>) -> Self {
TensorShape { dims }
}
pub fn static_shape(sizes: Vec<usize>) -> Self {
TensorShape {
dims: sizes.into_iter().map(DimSize::Static).collect(),
}
}
pub fn dynamic(rank: usize) -> Self {
TensorShape {
dims: vec![DimSize::Dynamic; rank],
}
}
pub fn rank(&self) -> usize {
self.dims.len()
}
pub fn is_static(&self) -> bool {
self.dims.iter().all(|d| d.is_static())
}
pub fn as_static(&self) -> Option<Vec<usize>> {
self.dims.iter().map(|d| d.as_static()).collect()
}
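    /// Returns `true` if the two shapes have the same rank and every pair of
    /// dimensions could match: static sizes must be equal unless one of them
    /// is 1 (broadcasting), while dynamic and symbolic dimensions are assumed
    /// compatible.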
pub fn compatible_with(&self, other: &TensorShape) -> bool {
if self.rank() != other.rank() {
return false;
}
for (a, b) in self.dims.iter().zip(other.dims.iter()) {
match (a, b) {
(DimSize::Static(size_a), DimSize::Static(size_b))
if size_a != size_b && *size_a != 1 && *size_b != 1 =>
{
return false;
}
_ => {
}
}
}
true
}
}
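/// Maps tensor indices to their inferred [`TensorShape`]s while shapes are
/// propagated through an [`EinsumGraph`].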
pub struct ShapeInferenceContext {
tensor_shapes: HashMap<usize, TensorShape>,
}
impl ShapeInferenceContext {
pub fn new() -> Self {
ShapeInferenceContext {
tensor_shapes: HashMap::new(),
}
}
pub fn set_tensor_shape(&mut self, tensor_idx: usize, shape: TensorShape) {
self.tensor_shapes.insert(tensor_idx, shape);
}
pub fn get_tensor_shape(&self, tensor_idx: usize) -> Option<&TensorShape> {
self.tensor_shapes.get(&tensor_idx)
}
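    /// Seeds the context with `input_shapes`, then walks `graph.nodes` in
    /// order, inferring each node's output shape and recording it at tensor
    /// index `graph.tensors.len() + node_idx`.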
pub fn infer_graph_shapes(
&mut self,
graph: &EinsumGraph,
input_shapes: &HashMap<usize, TensorShape>,
) -> Result<(), String> {
for (idx, shape) in input_shapes {
self.tensor_shapes.insert(*idx, shape.clone());
}
for (node_idx, node) in graph.nodes.iter().enumerate() {
            let output_idx = node_idx + graph.tensors.len();
            let output_shape = self.infer_node_shape(node)?;
self.tensor_shapes.insert(output_idx, output_shape);
}
Ok(())
}
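    /// Infers the output shape of a single node from the shapes already
    /// recorded for its input tensors.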
fn infer_node_shape(&self, node: &EinsumNode) -> Result<TensorShape, String> {
match &node.op {
OpType::Einsum { spec } => {
self.infer_einsum_shape(spec, &node.inputs)
}
            OpType::ElemUnary { op: _ } => {
                // Element-wise unary ops preserve the input shape.
                let input_idx = *node.inputs.first().ok_or("Unary op requires 1 input")?;
                self.get_tensor_shape(input_idx)
                    .cloned()
                    .ok_or_else(|| "Input shape not available for unary op".to_string())
            }
            OpType::ElemBinary { op: _ } => {
                if node.inputs.len() < 2 {
                    return Err("Binary op requires 2 inputs".to_string());
                }
                let shape_a = self
                    .get_tensor_shape(node.inputs[0])
                    .ok_or("Input 0 shape not available")?;
                let shape_b = self
                    .get_tensor_shape(node.inputs[1])
                    .ok_or("Input 1 shape not available")?;
                if !shape_a.compatible_with(shape_b) {
                    return Err(format!(
                        "Incompatible shapes for binary op: {:?} vs {:?}",
                        shape_a, shape_b
                    ));
                }
                // Broadcast dimension-wise: a static size of 1 takes the other
                // operand's size; any remaining dynamic dimension stays dynamic.
                let output_dims = shape_a
                    .dims
                    .iter()
                    .zip(shape_b.dims.iter())
                    .map(|(a, b)| match (a, b) {
                        (DimSize::Static(1), other) | (other, DimSize::Static(1)) => other.clone(),
                        (DimSize::Dynamic, _) | (_, DimSize::Dynamic) => DimSize::Dynamic,
                        _ => a.clone(),
                    })
                    .collect();
                Ok(TensorShape::new(output_dims))
            }
            OpType::Reduce { op: _, axes } => {
                let input_idx = *node.inputs.first().ok_or("Reduce op requires 1 input")?;
                let input_shape = self
                    .get_tensor_shape(input_idx)
                    .ok_or("Input shape not available for reduce op")?;
                // Drop every reduced axis; filtering by index is independent of
                // the order in which `axes` is listed.
                let output_dims = input_shape
                    .dims
                    .iter()
                    .enumerate()
                    .filter(|(i, _)| !axes.contains(i))
                    .map(|(_, d)| d.clone())
                    .collect();
                Ok(TensorShape::new(output_dims))
            }
}
}
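    /// Infers the output shape of an einsum expression such as `"ab,bc->ac"`:
    /// each input label is bound to the corresponding input dimension and
    /// checked for consistency, then the output shape is read from the output
    /// spec (or derived implicitly when no `->` is present).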
fn infer_einsum_shape(&self, spec: &str, inputs: &[usize]) -> Result<TensorShape, String> {
let (input_specs, output_spec) = if let Some(arrow_pos) = spec.find("->") {
let input_part = &spec[..arrow_pos];
let output_part = &spec[arrow_pos + 2..];
            (input_part, Some(output_part.trim()))
} else {
(spec, None)
};
let input_specs: Vec<&str> = input_specs.split(',').map(|s| s.trim()).collect();
if input_specs.len() != inputs.len() {
return Err(format!(
"Einsum spec has {} inputs but {} tensors provided",
input_specs.len(),
inputs.len()
));
}
        let mut dim_sizes: HashMap<char, DimSize> = HashMap::new();
for (spec_idx, &input_idx) in inputs.iter().enumerate() {
let input_shape = self
.get_tensor_shape(input_idx)
.ok_or_else(|| format!("Input {} shape not available", input_idx))?;
let axes = input_specs[spec_idx].chars().collect::<Vec<_>>();
if axes.len() != input_shape.rank() {
return Err(format!(
"Input {} spec '{}' has {} axes but tensor has rank {}",
spec_idx,
input_specs[spec_idx],
axes.len(),
input_shape.rank()
));
}
for (axis_idx, axis_char) in axes.iter().enumerate() {
let dim_size = input_shape.dims[axis_idx].clone();
if let Some(existing) = dim_sizes.get(axis_char) {
if let (DimSize::Static(size1), DimSize::Static(size2)) = (existing, &dim_size)
{
if size1 != size2 {
return Err(format!(
"Dimension '{}' has inconsistent sizes: {} vs {}",
axis_char, size1, size2
));
}
}
} else {
dim_sizes.insert(*axis_char, dim_size);
}
}
}
        let output_dims = if let Some(output_axes) = output_spec {
            output_axes
                .chars()
                .map(|c| dim_sizes.get(&c).cloned().unwrap_or(DimSize::Dynamic))
                .collect()
        } else {
            // Implicit einsum output: labels appearing exactly once across all
            // inputs, in alphabetical order; repeated labels are contracted.
            let mut label_counts: HashMap<char, usize> = HashMap::new();
            for input in &input_specs {
                for c in input.chars() {
                    *label_counts.entry(c).or_insert(0) += 1;
                }
            }
            let mut implicit_axes: Vec<char> = label_counts
                .into_iter()
                .filter(|&(_, count)| count == 1)
                .map(|(c, _)| c)
                .collect();
            implicit_axes.sort();
            implicit_axes
                .into_iter()
                .map(|c| dim_sizes[&c].clone())
                .collect()
        };
Ok(TensorShape::new(output_dims))
}
}
impl Default for ShapeInferenceContext {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tensor_shape_static() {
let shape = TensorShape::static_shape(vec![3, 4, 5]);
assert_eq!(shape.rank(), 3);
assert!(shape.is_static());
assert_eq!(shape.as_static(), Some(vec![3, 4, 5]));
}
#[test]
fn test_tensor_shape_dynamic() {
let shape = TensorShape::dynamic(3);
assert_eq!(shape.rank(), 3);
assert!(!shape.is_static());
assert_eq!(shape.as_static(), None);
}
#[test]
fn test_shape_compatibility() {
let shape1 = TensorShape::static_shape(vec![3, 4]);
let shape2 = TensorShape::static_shape(vec![3, 4]);
assert!(shape1.compatible_with(&shape2));
let shape3 = TensorShape::static_shape(vec![3, 1]);
assert!(shape1.compatible_with(&shape3));
let shape4 = TensorShape::static_shape(vec![3, 5]);
assert!(!shape1.compatible_with(&shape4));
}
#[test]
fn test_shape_inference_context() {
let mut ctx = ShapeInferenceContext::new();
let shape = TensorShape::static_shape(vec![2, 3]);
ctx.set_tensor_shape(0, shape.clone());
assert_eq!(ctx.get_tensor_shape(0), Some(&shape));
assert_eq!(ctx.get_tensor_shape(1), None);
}
#[test]
fn test_einsum_shape_inference() {
let mut ctx = ShapeInferenceContext::new();
ctx.set_tensor_shape(0, TensorShape::static_shape(vec![3, 4]));
ctx.set_tensor_shape(1, TensorShape::static_shape(vec![4, 5]));
let shape = ctx
.infer_einsum_shape("ab,bc->ac", &[0, 1])
.expect("unwrap");
assert_eq!(shape.rank(), 2);
assert_eq!(shape.as_static(), Some(vec![3, 5]));
}
#[test]
fn test_einsum_shape_inference_explicit() {
let mut ctx = ShapeInferenceContext::new();
ctx.set_tensor_shape(0, TensorShape::static_shape(vec![2, 3, 4]));
let shape = ctx.infer_einsum_shape("abc->ab", &[0]).expect("unwrap");
assert_eq!(shape.rank(), 2);
assert_eq!(shape.as_static(), Some(vec![2, 3]));
}
#[test]
fn test_einsum_shape_inference_diagonal() {
let mut ctx = ShapeInferenceContext::new();
ctx.set_tensor_shape(0, TensorShape::static_shape(vec![3, 3]));
let shape = ctx.infer_einsum_shape("aa->a", &[0]).expect("unwrap");
assert_eq!(shape.rank(), 1);
assert_eq!(shape.as_static(), Some(vec![3]));
}
#[test]
fn test_einsum_shape_inference_batch_matmul() {
let mut ctx = ShapeInferenceContext::new();
ctx.set_tensor_shape(0, TensorShape::static_shape(vec![10, 3, 4]));
ctx.set_tensor_shape(1, TensorShape::static_shape(vec![10, 4, 5]));
let shape = ctx
.infer_einsum_shape("bik,bkj->bij", &[0, 1])
            .expect("batched matmul shape inference should succeed");
assert_eq!(shape.rank(), 3);
assert_eq!(shape.as_static(), Some(vec![10, 3, 5]));
}
#[test]
fn test_einsum_shape_inference_inconsistent_dims() {
let mut ctx = ShapeInferenceContext::new();
ctx.set_tensor_shape(0, TensorShape::static_shape(vec![3, 4]));
ctx.set_tensor_shape(1, TensorShape::static_shape(vec![5, 6]));
let result = ctx.infer_einsum_shape("ab,bc->ac", &[0, 1]);
assert!(result.is_err());
        assert!(result.expect_err("expected a shape error").contains("inconsistent"));
}
}