use crate::dtype::DType;
use crate::error::Result;
#[cfg(not(feature = "std"))]
use alloc::{string::String, vec::Vec};
#[cfg(feature = "std")]
use std::{string::String, vec::Vec};
/// The MLIR dialects this IR layer models.
///
/// Each variant corresponds to one upstream MLIR dialect namespace
/// (e.g. `tensor`, `linalg`, `llvm`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MlirDialect {
    Tensor,
    Linalg,
    Affine,
    Scf,
    Arith,
    MemRef,
    Gpu,
    Llvm,
    Builtin,
}

impl MlirDialect {
    /// Canonical lowercase dialect namespace as it appears in MLIR text.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Tensor => "tensor",
            Self::Linalg => "linalg",
            Self::Affine => "affine",
            Self::Scf => "scf",
            Self::Arith => "arith",
            Self::MemRef => "memref",
            Self::Gpu => "gpu",
            Self::Llvm => "llvm",
            Self::Builtin => "builtin",
        }
    }

    /// True for the tensor-level dialects at the top of the lowering
    /// pipeline (`tensor`, `linalg`).
    pub fn is_high_level(&self) -> bool {
        matches!(self, Self::Tensor | Self::Linalg)
    }

    /// True for the dialects closest to machine code (`llvm`, `memref`).
    pub fn is_low_level(&self) -> bool {
        matches!(self, Self::Llvm | Self::MemRef)
    }

    /// Next dialect in the canonical lowering chain
    /// (tensor -> linalg -> affine -> scf -> llvm); `None` for dialects
    /// with no further lowering step.
    pub fn lowering_target(&self) -> Option<MlirDialect> {
        match self {
            Self::Tensor => Some(Self::Linalg),
            Self::Linalg => Some(Self::Affine),
            Self::Affine => Some(Self::Scf),
            Self::Scf => Some(Self::Llvm),
            _ => None,
        }
    }
}
/// Opcodes for the fixed set of MLIR operations this layer can represent,
/// grouped by owning dialect.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MlirOpcode {
    // tensor dialect
    TensorEmpty,
    TensorExtract,
    TensorInsert,
    TensorFromElements,
    // linalg dialect
    LinalgMatmul,
    LinalgDot,
    LinalgConv,
    LinalgPooling,
    LinalgBroadcast,
    LinalgTranspose,
    LinalgReduce,
    // arith dialect (`f` suffix = float variant, `i` suffix = integer variant)
    ArithAddf,
    ArithSubf,
    ArithMulf,
    ArithDivf,
    ArithAddi,
    ArithSubi,
    ArithMuli,
    ArithDivi,
    // memref dialect
    MemRefAlloc,
    MemRefDealloc,
    MemRefLoad,
    MemRefStore,
    // scf (structured control flow) dialect
    ScfFor,
    ScfIf,
    ScfWhile,
    ScfParallel,
    // func / builtin structural operations
    FuncFunc,
    FuncReturn,
    FuncCall,
    ModuleOp,
    UnrealizedConversionCast,
}
impl MlirOpcode {
    /// Fully-qualified operation name in MLIR textual form,
    /// e.g. `"linalg.matmul"`.
    pub fn name(&self) -> &'static str {
        match self {
            MlirOpcode::TensorEmpty => "tensor.empty",
            MlirOpcode::TensorExtract => "tensor.extract",
            MlirOpcode::TensorInsert => "tensor.insert",
            MlirOpcode::TensorFromElements => "tensor.from_elements",
            MlirOpcode::LinalgMatmul => "linalg.matmul",
            MlirOpcode::LinalgDot => "linalg.dot",
            // LinalgConv is specialized to the 2-D convolution op name.
            MlirOpcode::LinalgConv => "linalg.conv_2d",
            MlirOpcode::LinalgPooling => "linalg.pooling",
            MlirOpcode::LinalgBroadcast => "linalg.broadcast",
            MlirOpcode::LinalgTranspose => "linalg.transpose",
            MlirOpcode::LinalgReduce => "linalg.reduce",
            MlirOpcode::ArithAddf => "arith.addf",
            MlirOpcode::ArithSubf => "arith.subf",
            MlirOpcode::ArithMulf => "arith.mulf",
            MlirOpcode::ArithDivf => "arith.divf",
            MlirOpcode::ArithAddi => "arith.addi",
            MlirOpcode::ArithSubi => "arith.subi",
            MlirOpcode::ArithMuli => "arith.muli",
            MlirOpcode::ArithDivi => "arith.divi",
            MlirOpcode::MemRefAlloc => "memref.alloc",
            MlirOpcode::MemRefDealloc => "memref.dealloc",
            MlirOpcode::MemRefLoad => "memref.load",
            MlirOpcode::MemRefStore => "memref.store",
            MlirOpcode::ScfFor => "scf.for",
            MlirOpcode::ScfIf => "scf.if",
            MlirOpcode::ScfWhile => "scf.while",
            MlirOpcode::ScfParallel => "scf.parallel",
            MlirOpcode::FuncFunc => "func.func",
            MlirOpcode::FuncReturn => "func.return",
            MlirOpcode::FuncCall => "func.call",
            MlirOpcode::ModuleOp => "builtin.module",
            MlirOpcode::UnrealizedConversionCast => "builtin.unrealized_conversion_cast",
        }
    }
    /// Dialect that owns this operation.
    ///
    /// NOTE(review): the `Func*` ops are reported as `Builtin` because
    /// `MlirDialect` has no dedicated `Func` variant, even though their
    /// textual names use the `func.` namespace — confirm this is intended.
    pub fn dialect(&self) -> MlirDialect {
        match self {
            MlirOpcode::TensorEmpty
            | MlirOpcode::TensorExtract
            | MlirOpcode::TensorInsert
            | MlirOpcode::TensorFromElements => MlirDialect::Tensor,
            MlirOpcode::LinalgMatmul
            | MlirOpcode::LinalgDot
            | MlirOpcode::LinalgConv
            | MlirOpcode::LinalgPooling
            | MlirOpcode::LinalgBroadcast
            | MlirOpcode::LinalgTranspose
            | MlirOpcode::LinalgReduce => MlirDialect::Linalg,
            MlirOpcode::ArithAddf
            | MlirOpcode::ArithSubf
            | MlirOpcode::ArithMulf
            | MlirOpcode::ArithDivf
            | MlirOpcode::ArithAddi
            | MlirOpcode::ArithSubi
            | MlirOpcode::ArithMuli
            | MlirOpcode::ArithDivi => MlirDialect::Arith,
            MlirOpcode::MemRefAlloc
            | MlirOpcode::MemRefDealloc
            | MlirOpcode::MemRefLoad
            | MlirOpcode::MemRefStore => MlirDialect::MemRef,
            MlirOpcode::ScfFor
            | MlirOpcode::ScfIf
            | MlirOpcode::ScfWhile
            | MlirOpcode::ScfParallel => MlirDialect::Scf,
            MlirOpcode::FuncFunc | MlirOpcode::FuncReturn | MlirOpcode::FuncCall => {
                MlirDialect::Builtin
            }
            MlirOpcode::ModuleOp | MlirOpcode::UnrealizedConversionCast => MlirDialect::Builtin,
        }
    }
}
/// Opaque SSA value id. The wrapped index is the allocation order within a
/// `MlirBuilder` and also indexes the builder's `value_types` table.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MlirValue(pub usize);
/// A single IR operation: opcode plus operand/result value ids, attached
/// attributes, and the types of its results (`result_types` is parallel
/// to `results`).
#[derive(Debug, Clone)]
pub struct MlirOp {
    pub opcode: MlirOpcode,
    pub operands: Vec<MlirValue>,
    pub results: Vec<MlirValue>,
    pub attributes: MlirAttributes,
    pub result_types: Vec<MlirType>,
}
/// Types attachable to IR values, mirroring MLIR's builtin type system.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MlirType {
    /// Ranked tensor; `shape` entries are dimension sizes.
    Tensor {
        shape: Vec<i64>,
        element_type: DType,
    },
    /// Ranked memref with an optional numeric memory-space annotation.
    MemRef {
        shape: Vec<i64>,
        element_type: DType,
        memory_space: Option<usize>,
    },
    /// Scalar of a tensor element dtype.
    Scalar(DType),
    /// Integer with explicit bit width and signedness flag.
    Integer { width: u32, signed: bool },
    /// Float with explicit bit width.
    Float { width: u32 },
    /// Platform-sized index type.
    Index,
    /// Function type `(inputs) -> (outputs)`.
    // NOTE(review): `Vec<Box<MlirType>>` boxes each element even though the
    // Vec already heap-allocates; kept as-is for API compatibility.
    Function {
        inputs: Vec<Box<MlirType>>,
        outputs: Vec<Box<MlirType>>,
    },
    /// Absence of a type.
    None,
}
impl MlirType {
    /// Ranked tensor type with the given shape and element dtype.
    pub fn tensor(shape: &[usize], dtype: DType) -> Self {
        MlirType::Tensor {
            shape: shape.iter().map(|&s| s as i64).collect(),
            element_type: dtype,
        }
    }

    /// Ranked memref type (no memory space) with the given shape and dtype.
    pub fn memref(shape: &[usize], dtype: DType) -> Self {
        MlirType::MemRef {
            shape: shape.iter().map(|&s| s as i64).collect(),
            element_type: dtype,
            memory_space: None,
        }
    }

    /// Scalar type wrapping a tensor element dtype.
    pub fn scalar(dtype: DType) -> Self {
        MlirType::Scalar(dtype)
    }

    /// Function type `(inputs) -> (outputs)`.
    pub fn function(inputs: Vec<MlirType>, outputs: Vec<MlirType>) -> Self {
        MlirType::Function {
            inputs: inputs.into_iter().map(Box::new).collect(),
            outputs: outputs.into_iter().map(Box::new).collect(),
        }
    }

    /// True if this is a `Tensor` type.
    pub fn is_tensor(&self) -> bool {
        matches!(self, MlirType::Tensor { .. })
    }

    /// True if this is a `MemRef` type.
    pub fn is_memref(&self) -> bool {
        matches!(self, MlirType::MemRef { .. })
    }

    /// `"128x256x"`-style dimension prefix for shaped types. Empty for rank
    /// 0, so `tensor<f32>` / `memref<f32>` render without a stray `x`
    /// (previously rank-0 produced the malformed `tensor<xf32>`).
    fn dims_prefix(shape: &[i64]) -> String {
        shape.iter().map(|d| format!("{}x", d)).collect()
    }

    /// Render this type in MLIR's textual syntax.
    pub fn to_mlir_string(&self) -> String {
        match self {
            MlirType::Tensor {
                shape,
                element_type,
            } => format!("tensor<{}{}>", Self::dims_prefix(shape), element_type.name()),
            MlirType::MemRef {
                shape,
                element_type,
                memory_space,
            } => match memory_space {
                // The memory space prints as a trailing attribute,
                // e.g. `memref<2x3xf32, 1>`.
                Some(space) => format!(
                    "memref<{}{}, {}>",
                    Self::dims_prefix(shape),
                    element_type.name(),
                    space
                ),
                None => format!("memref<{}{}>", Self::dims_prefix(shape), element_type.name()),
            },
            MlirType::Scalar(dtype) => dtype.name().to_string(),
            // NOTE(review): upstream MLIR spells signed integers `si<width>`
            // and uses `i<width>` for signless; the existing mapping
            // (signed -> `i`, unsigned -> `ui`) is preserved — confirm intent.
            MlirType::Integer { width, signed } => {
                if *signed {
                    format!("i{}", width)
                } else {
                    format!("ui{}", width)
                }
            }
            MlirType::Float { width } => format!("f{}", width),
            MlirType::Index => "index".to_string(),
            MlirType::Function { inputs, outputs } => {
                let inputs_str: Vec<String> = inputs.iter().map(|t| t.to_mlir_string()).collect();
                let outputs_str: Vec<String> = outputs.iter().map(|t| t.to_mlir_string()).collect();
                format!(
                    "({}) -> ({})",
                    inputs_str.join(", "),
                    outputs_str.join(", ")
                )
            }
            MlirType::None => "none".to_string(),
        }
    }
}
/// Attribute dictionary for an operation, bucketed by value kind.
/// Stored as `(name, value)` pairs so insertion order is preserved;
/// duplicate names are not rejected.
#[derive(Debug, Clone, Default)]
pub struct MlirAttributes {
    pub strings: Vec<(String, String)>,
    pub integers: Vec<(String, i64)>,
    pub bools: Vec<(String, bool)>,
    pub types: Vec<(String, MlirType)>,
}
impl MlirAttributes {
    /// Empty attribute set; identical to `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Append a string attribute under `name`.
    pub fn add_string(&mut self, name: impl Into<String>, value: impl Into<String>) {
        let entry = (name.into(), value.into());
        self.strings.push(entry);
    }

    /// Append an integer attribute under `name`.
    pub fn add_integer(&mut self, name: impl Into<String>, value: i64) {
        let entry = (name.into(), value);
        self.integers.push(entry);
    }

    /// Append a boolean attribute under `name`.
    pub fn add_bool(&mut self, name: impl Into<String>, value: bool) {
        let entry = (name.into(), value);
        self.bools.push(entry);
    }

    /// Append a type attribute under `name`.
    pub fn add_type(&mut self, name: impl Into<String>, ty: MlirType) {
        let entry = (name.into(), ty);
        self.types.push(entry);
    }

    /// True when no attribute of any kind has been added.
    pub fn is_empty(&self) -> bool {
        let populated = !self.strings.is_empty()
            || !self.integers.is_empty()
            || !self.bools.is_empty()
            || !self.types.is_empty();
        !populated
    }
}
/// Incrementally builds a list of operations and tracks the type of every
/// SSA value it allocates; finished off with `MlirBuilder::build`.
pub struct MlirBuilder {
    operations: Vec<MlirOp>,
    // Id for the next MlirValue to allocate.
    next_value: usize,
    // Module name passed through to the resulting MlirModule.
    name: String,
    // Indexed by MlirValue.0; grows in lockstep with next_value.
    value_types: Vec<MlirType>,
}
impl MlirBuilder {
pub fn new() -> Self {
Self {
operations: Vec::new(),
next_value: 0,
name: "main".to_string(),
value_types: Vec::new(),
}
}
pub fn with_name(name: impl Into<String>) -> Self {
Self {
operations: Vec::new(),
next_value: 0,
name: name.into(),
value_types: Vec::new(),
}
}
fn allocate_value(&mut self, ty: MlirType) -> MlirValue {
let id = MlirValue(self.next_value);
self.value_types.push(ty);
self.next_value += 1;
id
}
fn get_value_type(&self, value: MlirValue) -> Option<&MlirType> {
self.value_types.get(value.0)
}
pub fn add_tensor_input(&mut self, shape: &[usize], dtype: DType) -> Result<MlirValue> {
let ty = MlirType::tensor(shape, dtype);
let result = self.allocate_value(ty.clone());
self.operations.push(MlirOp {
opcode: MlirOpcode::TensorEmpty,
operands: Vec::new(),
results: vec![result],
attributes: MlirAttributes::new(),
result_types: vec![ty],
});
Ok(result)
}
pub fn add_tensor_constant(&mut self, shape: &[usize], dtype: DType) -> Result<MlirValue> {
let ty = MlirType::tensor(shape, dtype);
let result = self.allocate_value(ty.clone());
let mut attrs = MlirAttributes::new();
attrs.add_string("value", "constant");
self.operations.push(MlirOp {
opcode: MlirOpcode::TensorFromElements,
operands: Vec::new(),
results: vec![result],
attributes: attrs,
result_types: vec![ty],
});
Ok(result)
}
pub fn add_matmul(&mut self, lhs: MlirValue, rhs: MlirValue) -> Result<MlirValue> {
let result_type = match (self.get_value_type(lhs), self.get_value_type(rhs)) {
(
Some(MlirType::Tensor {
shape: lhs_shape,
element_type: lhs_dtype,
}),
Some(MlirType::Tensor {
shape: rhs_shape,
element_type: _,
}),
) => {
if lhs_shape.len() != 2 || rhs_shape.len() != 2 {
return Err(crate::error::TorshError::dimension_error(
"Matrix multiplication requires 2D tensors",
"matmul",
));
}
if lhs_shape[1] != rhs_shape[0] {
return Err(crate::error::TorshError::dimension_error(
&format!(
"Incompatible dimensions for matmul: {}x{} and {}x{}",
lhs_shape[0], lhs_shape[1], rhs_shape[0], rhs_shape[1]
),
"matmul",
));
}
MlirType::Tensor {
shape: vec![lhs_shape[0], rhs_shape[1]],
element_type: *lhs_dtype,
}
}
_ => {
return Err(crate::error::TorshError::dimension_error(
"Matmul operands must be tensors",
"matmul",
));
}
};
let result = self.allocate_value(result_type.clone());
self.operations.push(MlirOp {
opcode: MlirOpcode::LinalgMatmul,
operands: vec![lhs, rhs],
results: vec![result],
attributes: MlirAttributes::new(),
result_types: vec![result_type],
});
Ok(result)
}
pub fn add_add(&mut self, lhs: MlirValue, rhs: MlirValue, dtype: DType) -> Result<MlirValue> {
let result_type = MlirType::scalar(dtype);
let result = self.allocate_value(result_type.clone());
let opcode = match dtype {
DType::F16 | DType::F32 | DType::F64 | DType::BF16 => MlirOpcode::ArithAddf,
_ => MlirOpcode::ArithAddi,
};
self.operations.push(MlirOp {
opcode,
operands: vec![lhs, rhs],
results: vec![result],
attributes: MlirAttributes::new(),
result_types: vec![result_type],
});
Ok(result)
}
pub fn add_transpose(
&mut self,
operand: MlirValue,
permutation: &[usize],
) -> Result<MlirValue> {
let result_type = match self.get_value_type(operand) {
Some(MlirType::Tensor {
shape,
element_type,
}) => {
if permutation.len() != shape.len() {
return Err(crate::error::TorshError::dimension_error(
"Transpose permutation must match tensor rank",
"transpose",
));
}
let new_shape: Vec<i64> = permutation.iter().map(|&i| shape[i]).collect();
MlirType::Tensor {
shape: new_shape,
element_type: *element_type,
}
}
_ => {
return Err(crate::error::TorshError::dimension_error(
"Transpose operand must be a tensor",
"transpose",
));
}
};
let result = self.allocate_value(result_type.clone());
let mut attrs = MlirAttributes::new();
let perm_str: Vec<String> = permutation.iter().map(|p| p.to_string()).collect();
attrs.add_string("permutation", perm_str.join(","));
self.operations.push(MlirOp {
opcode: MlirOpcode::LinalgTranspose,
operands: vec![operand],
results: vec![result],
attributes: attrs,
result_types: vec![result_type],
});
Ok(result)
}
pub fn build(self) -> Result<MlirModule> {
Ok(MlirModule {
name: self.name,
operations: self.operations,
})
}
pub fn num_operations(&self) -> usize {
self.operations.len()
}
}
impl Default for MlirBuilder {
    /// Same as `MlirBuilder::new()`: an empty builder named `"main"`.
    fn default() -> Self {
        Self::new()
    }
}
/// A finished, immutable sequence of operations produced by `MlirBuilder`.
#[derive(Debug, Clone)]
pub struct MlirModule {
    pub name: String,
    pub operations: Vec<MlirOp>,
}
impl MlirModule {
    /// Module name (printed as the `@name` symbol in `to_mlir_text`).
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Operations in insertion order.
    pub fn operations(&self) -> &[MlirOp] {
        &self.operations
    }

    /// Count operations per dialect, in order of first appearance.
    pub fn operations_by_dialect(&self) -> Vec<(MlirDialect, usize)> {
        let mut counts: Vec<(MlirDialect, usize)> = Vec::new();
        for op in &self.operations {
            let dialect = op.opcode.dialect();
            // Search `counts` directly instead of maintaining a redundant
            // parallel vector of dialects; the distinct-dialect count is tiny.
            if let Some(entry) = counts.iter_mut().find(|entry| entry.0 == dialect) {
                entry.1 += 1;
            } else {
                counts.push((dialect, 1));
            }
        }
        counts
    }

    /// Pretty-print the module in a simplified MLIR-like textual form.
    /// This is a debugging aid, not round-trippable MLIR syntax (types and
    /// attributes are omitted).
    pub fn to_mlir_text(&self) -> String {
        let mut text = format!("module @{} {{\n", self.name);
        for op in self.operations.iter() {
            let result_refs: Vec<String> = op.results.iter().map(|v| format!("%{}", v.0)).collect();
            let operand_refs: Vec<String> =
                op.operands.iter().map(|v| format!("%{}", v.0)).collect();
            let line = if result_refs.is_empty() {
                format!("  {}({})\n", op.opcode.name(), operand_refs.join(", "))
            } else {
                format!(
                    "  {} = {}({})\n",
                    result_refs.join(", "),
                    op.opcode.name(),
                    operand_refs.join(", ")
                )
            };
            text.push_str(&line);
        }
        text.push_str("}\n");
        text
    }

    /// Optimization stub: currently returns an unchanged clone.
    pub fn optimize(&self) -> Result<MlirModule> {
        Ok(self.clone())
    }

    /// Lowering stub: currently returns an unchanged clone regardless of
    /// `_target`.
    pub fn lower_to(&self, _target: MlirDialect) -> Result<MlirModule> {
        Ok(self.clone())
    }
}
/// Named optimization / lowering passes that can be applied to a module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MlirPass {
    Canonicalize,
    CSE,
    DCE,
    LoopFusion,
    LoopTiling,
    BufferAllocation,
    LowerToLLVM,
}

impl MlirPass {
    /// Kebab-case pass name, as understood by `mlir-opt`-style drivers.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Canonicalize => "canonicalize",
            Self::CSE => "cse",
            Self::DCE => "dce",
            Self::LoopFusion => "loop-fusion",
            Self::LoopTiling => "loop-tiling",
            Self::BufferAllocation => "buffer-allocation",
            Self::LowerToLLVM => "lower-to-llvm",
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- MlirDialect ---

    #[test]
    fn test_mlir_dialect_name() {
        assert_eq!(MlirDialect::Tensor.name(), "tensor");
        assert_eq!(MlirDialect::Linalg.name(), "linalg");
    }

    #[test]
    fn test_mlir_dialect_levels() {
        assert!(MlirDialect::Tensor.is_high_level());
        assert!(!MlirDialect::Llvm.is_high_level());
        assert!(MlirDialect::Llvm.is_low_level());
        assert!(!MlirDialect::Tensor.is_low_level());
    }

    #[test]
    fn test_mlir_dialect_lowering() {
        assert_eq!(
            MlirDialect::Tensor.lowering_target(),
            Some(MlirDialect::Linalg)
        );
        assert_eq!(
            MlirDialect::Linalg.lowering_target(),
            Some(MlirDialect::Affine)
        );
    }

    // --- MlirOpcode ---

    #[test]
    fn test_mlir_opcode_name() {
        assert_eq!(MlirOpcode::LinalgMatmul.name(), "linalg.matmul");
        assert_eq!(MlirOpcode::ArithAddf.name(), "arith.addf");
    }

    #[test]
    fn test_mlir_opcode_dialect() {
        assert_eq!(MlirOpcode::LinalgMatmul.dialect(), MlirDialect::Linalg);
        assert_eq!(MlirOpcode::ArithAddf.dialect(), MlirDialect::Arith);
    }

    // --- MlirType ---

    #[test]
    fn test_mlir_type_tensor() {
        let ty = MlirType::tensor(&[128, 256], DType::F32);
        assert!(ty.is_tensor());
        assert!(!ty.is_memref());
    }

    #[test]
    fn test_mlir_type_to_string() {
        let ty = MlirType::tensor(&[128, 256], DType::F32);
        let str_repr = ty.to_mlir_string();
        assert!(str_repr.contains("tensor"));
        assert!(str_repr.contains("128"));
        assert!(str_repr.contains("256"));
    }

    // --- MlirBuilder: value ids are allocated in insertion order, so the
    // expected MlirValue(n) below encode that ordering. ---

    #[test]
    fn test_mlir_builder_input() {
        let mut builder = MlirBuilder::new();
        let input = builder
            .add_tensor_input(&[10, 20], DType::F32)
            .expect("add_tensor_input should succeed");
        assert_eq!(input, MlirValue(0));
        assert_eq!(builder.num_operations(), 1);
    }

    #[test]
    fn test_mlir_builder_matmul() {
        let mut builder = MlirBuilder::new();
        let lhs = builder
            .add_tensor_input(&[128, 256], DType::F32)
            .expect("add_tensor_input should succeed");
        let rhs = builder
            .add_tensor_input(&[256, 512], DType::F32)
            .expect("add_tensor_input should succeed");
        let result = builder
            .add_matmul(lhs, rhs)
            .expect("add_matmul should succeed");
        assert_eq!(result, MlirValue(2));
        assert_eq!(builder.num_operations(), 3);
    }

    #[test]
    fn test_mlir_builder_add() {
        let mut builder = MlirBuilder::new();
        let lhs = builder
            .add_tensor_input(&[10], DType::F32)
            .expect("add_tensor_input should succeed");
        let rhs = builder
            .add_tensor_input(&[10], DType::F32)
            .expect("add_tensor_input should succeed");
        let result = builder
            .add_add(lhs, rhs, DType::F32)
            .expect("add_add should succeed");
        assert_eq!(result, MlirValue(2));
    }

    // --- MlirModule ---

    #[test]
    fn test_mlir_module_build() {
        let mut builder = MlirBuilder::new();
        let lhs = builder
            .add_tensor_input(&[128, 256], DType::F32)
            .expect("add_tensor_input should succeed");
        let rhs = builder
            .add_tensor_input(&[256, 512], DType::F32)
            .expect("add_tensor_input should succeed");
        let _result = builder
            .add_matmul(lhs, rhs)
            .expect("add_matmul should succeed");
        let module = builder.build().expect("build should succeed");
        assert_eq!(module.name(), "main");
        assert_eq!(module.operations().len(), 3);
    }

    #[test]
    fn test_mlir_module_operations_by_dialect() {
        let mut builder = MlirBuilder::new();
        let lhs = builder
            .add_tensor_input(&[128, 256], DType::F32)
            .expect("add_tensor_input should succeed");
        let rhs = builder
            .add_tensor_input(&[256, 512], DType::F32)
            .expect("add_tensor_input should succeed");
        let _result = builder
            .add_matmul(lhs, rhs)
            .expect("add_matmul should succeed");
        let module = builder.build().expect("build should succeed");
        let counts = module.operations_by_dialect();
        // Two tensor.empty ops (Tensor) + one matmul (Linalg) = 2 dialects.
        assert_eq!(counts.len(), 2);
    }

    #[test]
    fn test_mlir_module_to_text() {
        let mut builder = MlirBuilder::with_name("test");
        let lhs = builder
            .add_tensor_input(&[10, 20], DType::F32)
            .expect("add_tensor_input should succeed");
        let rhs = builder
            .add_tensor_input(&[10, 20], DType::F32)
            .expect("add_tensor_input should succeed");
        let _result = builder
            .add_add(lhs, rhs, DType::F32)
            .expect("add_add should succeed");
        let module = builder.build().expect("build should succeed");
        let text = module.to_mlir_text();
        assert!(text.contains("module @test"));
        assert!(text.contains("tensor.empty"));
        assert!(text.contains("arith.addf"));
    }

    // --- MlirAttributes ---

    #[test]
    fn test_mlir_attributes() {
        let mut attrs = MlirAttributes::new();
        assert!(attrs.is_empty());
        attrs.add_string("name", "test");
        attrs.add_integer("value", 42);
        attrs.add_bool("flag", true);
        assert!(!attrs.is_empty());
        assert_eq!(attrs.strings.len(), 1);
        assert_eq!(attrs.integers.len(), 1);
        assert_eq!(attrs.bools.len(), 1);
    }

    // --- MlirPass ---

    #[test]
    fn test_mlir_pass_name() {
        assert_eq!(MlirPass::Canonicalize.name(), "canonicalize");
        assert_eq!(MlirPass::CSE.name(), "cse");
    }

    #[test]
    fn test_mlir_transpose() {
        let mut builder = MlirBuilder::new();
        let input = builder
            .add_tensor_input(&[128, 256], DType::F32)
            .expect("add_tensor_input should succeed");
        let transposed = builder
            .add_transpose(input, &[1, 0])
            .expect("add_transpose should succeed");
        assert_eq!(transposed, MlirValue(1));
        assert_eq!(builder.num_operations(), 2);
    }

    #[test]
    fn test_mlir_type_function() {
        let input_types = vec![
            MlirType::tensor(&[128, 256], DType::F32),
            MlirType::tensor(&[256, 512], DType::F32),
        ];
        let output_types = vec![MlirType::tensor(&[128, 512], DType::F32)];
        let func_type = MlirType::function(input_types, output_types);
        let str_repr = func_type.to_mlir_string();
        assert!(str_repr.contains("->"));
    }

    #[test]
    fn test_complex_module() {
        let mut builder = MlirBuilder::with_name("complex");
        let a = builder
            .add_tensor_input(&[128, 256], DType::F32)
            .expect("add_tensor_input should succeed");
        let b = builder
            .add_tensor_input(&[256, 512], DType::F32)
            .expect("add_tensor_input should succeed");
        let d = builder
            .add_tensor_input(&[128, 512], DType::F32)
            .expect("add_tensor_input should succeed");
        let matmul = builder.add_matmul(a, b).expect("add_matmul should succeed");
        let _result = builder
            .add_add(matmul, d, DType::F32)
            .expect("add_add should succeed");
        let module = builder.build().expect("build should succeed");
        // 3 inputs + matmul + add = 5 operations.
        assert_eq!(module.operations().len(), 5);
        let dialect_counts = module.operations_by_dialect();
        assert!(dialect_counts.len() >= 2);
    }

    // --- Validation failures ---

    #[test]
    fn test_mlir_type_validation_invalid_matmul_dims() {
        let mut builder = MlirBuilder::new();
        let a = builder
            .add_tensor_input(&[128, 256], DType::F32)
            .expect("add_tensor_input should succeed");
        let b = builder
            .add_tensor_input(&[128, 512], DType::F32)
            .expect("add_tensor_input should succeed");
        // Inner dimensions 256 vs 128 do not match.
        let result = builder.add_matmul(a, b);
        assert!(result.is_err());
    }

    #[test]
    fn test_mlir_type_validation_invalid_matmul_rank() {
        let mut builder = MlirBuilder::new();
        let a = builder
            .add_tensor_input(&[128], DType::F32)
            .expect("add_tensor_input should succeed");
        let b = builder
            .add_tensor_input(&[128, 512], DType::F32)
            .expect("add_tensor_input should succeed");
        // 1-D lhs is rejected: matmul requires rank 2.
        let result = builder.add_matmul(a, b);
        assert!(result.is_err());
    }

    #[test]
    fn test_mlir_type_validation_invalid_transpose() {
        let mut builder = MlirBuilder::new();
        let input = builder
            .add_tensor_input(&[10, 20, 30], DType::F32)
            .expect("add_tensor_input should succeed");
        // Permutation of length 2 on a rank-3 tensor must fail.
        let result = builder.add_transpose(input, &[1, 0]);
        assert!(result.is_err());
    }

    // --- Type inference ---

    #[test]
    fn test_mlir_type_inference_matmul() {
        let mut builder = MlirBuilder::new();
        let a = builder
            .add_tensor_input(&[128, 256], DType::F32)
            .expect("add_tensor_input should succeed");
        let b = builder
            .add_tensor_input(&[256, 512], DType::F64)
            .expect("add_tensor_input should succeed");
        let result = builder.add_matmul(a, b).expect("add_matmul should succeed");
        let result_type = builder
            .get_value_type(result)
            .expect("get_value_type should succeed");
        match result_type {
            MlirType::Tensor {
                shape,
                element_type,
            } => {
                assert_eq!(shape, &[128, 512]);
                // The result element type follows the lhs (F32), not rhs.
                assert_eq!(*element_type, DType::F32);
            }
            _ => panic!("Expected tensor type"),
        }
    }

    #[test]
    fn test_mlir_type_inference_transpose() {
        let mut builder = MlirBuilder::new();
        let input = builder
            .add_tensor_input(&[10, 20, 30], DType::F32)
            .expect("add_tensor_input should succeed");
        let result = builder
            .add_transpose(input, &[2, 0, 1])
            .expect("add_transpose should succeed");
        let result_type = builder
            .get_value_type(result)
            .expect("get_value_type should succeed");
        match result_type {
            MlirType::Tensor {
                shape,
                element_type,
            } => {
                assert_eq!(shape, &[30, 10, 20]);
                assert_eq!(*element_type, DType::F32);
            }
            _ => panic!("Expected tensor type"),
        }
    }
}