use crate::{
llzkAffineMapOperandsBuilderCreate, llzkAffineMapOperandsBuilderDestroy,
llzkArray_ArrayTypeGetDimensionSizesAt, llzkArray_ArrayTypeGetDimensionSizesCount,
llzkArray_ArrayTypeGetElementType, llzkArray_ArrayTypeGetWithDims,
llzkArray_ArrayTypeGetWithShape, llzkArray_CreateArrayOpBuildWithMapOperands,
llzkArray_CreateArrayOpBuildWithValues, llzkTypeIsA_Array_ArrayType,
mlirGetDialectHandle__llzk__array__, mlirOpBuilderCreate, mlirOpBuilderDestroy,
sanity_tests::{
TestContext, context, load_llzk_dialects,
typing::{IndexType, index_type},
},
};
use mlir_sys::{
MlirContext, MlirOperation, MlirType, mlirAttributeEqual, mlirIdentifierGet, mlirIndexTypeGet,
mlirIntegerAttrGet, mlirLocationUnknownGet, mlirNamedAttributeGet, mlirOperationCreate,
mlirOperationDestroy, mlirOperationGetResult, mlirOperationStateAddAttributes,
mlirOperationStateEnableResultTypeInference, mlirOperationStateGet, mlirOperationVerify,
mlirStringRefCreateFromCString, mlirTypeEqual,
};
use rstest::rstest;
use std::{ffi::CString, ptr::null};
#[test]
fn test_mlir_get_dialect_handle_llzk_array() {
unsafe {
mlirGetDialectHandle__llzk__array__();
}
}
#[rstest]
fn test_llzk_array_type_get(index_type: IndexType) {
unsafe {
let size = mlirIntegerAttrGet(index_type.t, 1);
let dims = [size];
let arr_type = llzkArray_ArrayTypeGetWithDims(
index_type.t,
isize::try_from(dims.len()).expect("dims too large"),
dims.as_ptr(),
);
assert_ne!(arr_type.ptr, null());
}
}
#[rstest]
fn test_llzk_type_isa_array_type_pass(index_type: IndexType) {
unsafe {
let size = mlirIntegerAttrGet(index_type.t, 1);
let dims = [size];
let arr_type = llzkArray_ArrayTypeGetWithDims(
index_type.t,
isize::try_from(dims.len()).expect("dims too large"),
dims.as_ptr(),
);
assert_ne!(arr_type.ptr, null());
assert!(llzkTypeIsA_Array_ArrayType(arr_type));
}
}
#[rstest]
fn test_llzk_type_isa_array_type_fail(index_type: IndexType) {
unsafe {
assert!(!llzkTypeIsA_Array_ArrayType(index_type.t));
}
}
#[rstest]
fn test_llzk_array_type_get_with_numeric_dims(index_type: IndexType) {
unsafe {
let dims = [1, 2];
let arr_type = llzkArray_ArrayTypeGetWithShape(
index_type.t,
isize::try_from(dims.len()).expect("dims too large"),
dims.as_ptr(),
);
assert_ne!(arr_type.ptr, null());
}
}
#[rstest]
fn test_llzk_array_type_get_element_type(index_type: IndexType) {
unsafe {
let dims = [1, 2];
let arr_type = llzkArray_ArrayTypeGetWithShape(
index_type.t,
isize::try_from(dims.len()).expect("dims too large"),
dims.as_ptr(),
);
assert_ne!(arr_type.ptr, null());
let elt_type = llzkArray_ArrayTypeGetElementType(arr_type);
assert!(mlirTypeEqual(index_type.t, elt_type));
}
}
#[rstest]
fn test_llzk_array_type_get_num_dims(index_type: IndexType) {
unsafe {
let dims = [1, 2];
let arr_type = llzkArray_ArrayTypeGetWithShape(
index_type.t,
isize::try_from(dims.len()).expect("dims too large"),
dims.as_ptr(),
);
assert_ne!(arr_type.ptr, null());
let n_dims = llzkArray_ArrayTypeGetDimensionSizesCount(arr_type);
assert_eq!(n_dims, isize::try_from(dims.len()).expect("dims too large"));
}
}
#[rstest]
fn test_llzk_array_type_get_dim(index_type: IndexType) {
unsafe {
let dims = [1, 2];
let arr_type = llzkArray_ArrayTypeGetWithShape(
index_type.t,
isize::try_from(dims.len()).expect("dims too large"),
dims.as_ptr(),
);
assert_ne!(arr_type.ptr, null());
let out_dim = llzkArray_ArrayTypeGetDimensionSizesAt(arr_type, 0);
let dim_as_attr = mlirIntegerAttrGet(index_type.t, dims[0]);
assert!(mlirAttributeEqual(out_dim, dim_as_attr));
}
}
#[rstest]
fn test_llzk_create_array_op_build_with_values(context: TestContext, #[values(&[1])] dims: &[i64]) {
unsafe {
let elt_type = mlirIndexTypeGet(context.ctx);
let test_type = test_array(elt_type, dims);
let n_elements: i64 = dims.iter().product();
let ops = create_n_ops(context.ctx, n_elements, elt_type);
let values = ops
.iter()
.map(|op| mlirOperationGetResult(*op, 0))
.collect::<Vec<_>>();
let builder = mlirOpBuilderCreate(context.ctx);
let location = mlirLocationUnknownGet(context.ctx);
let create_array_op = llzkArray_CreateArrayOpBuildWithValues(
builder,
location,
test_type,
isize::try_from(values.len()).expect("values too large"),
values.as_ptr(),
);
for op in &ops {
assert!(mlirOperationVerify(*op));
}
assert!(mlirOperationVerify(create_array_op));
mlirOperationDestroy(create_array_op);
for op in ops {
mlirOperationDestroy(op);
}
mlirOpBuilderDestroy(builder);
}
}
#[rstest]
fn test_llzk_create_array_op_build_with_map_operands(
context: TestContext,
#[values(&[1])] dims: &[i64],
) {
load_llzk_dialects(&context);
unsafe {
let elt_type = mlirIndexTypeGet(context.ctx);
let test_type = test_array(elt_type, dims);
let builder = mlirOpBuilderCreate(context.ctx);
let location = mlirLocationUnknownGet(context.ctx);
let mut map_operands = llzkAffineMapOperandsBuilderCreate();
let op =
llzkArray_CreateArrayOpBuildWithMapOperands(builder, location, test_type, map_operands);
assert!(mlirOperationVerify(op));
mlirOperationDestroy(op);
llzkAffineMapOperandsBuilderDestroy(&mut map_operands);
mlirOpBuilderDestroy(builder);
}
}
fn test_array(elt: MlirType, dims: &[i64]) -> MlirType {
unsafe {
llzkArray_ArrayTypeGetWithShape(
elt,
isize::try_from(dims.len()).expect("dims too large"),
dims.as_ptr(),
)
}
}
fn create_n_ops(ctx: MlirContext, n_ops: i64, elt_type: MlirType) -> Vec<MlirOperation> {
unsafe {
let arith_constant_op_str = CString::new("arith.constant").unwrap();
let value_str = CString::new("value").unwrap();
let name = mlirStringRefCreateFromCString(arith_constant_op_str.as_ptr());
let attr_name = mlirIdentifierGet(ctx, mlirStringRefCreateFromCString(value_str.as_ptr()));
let location = mlirLocationUnknownGet(ctx);
(0..n_ops)
.map(|n| {
let attr = mlirNamedAttributeGet(attr_name, mlirIntegerAttrGet(elt_type, n));
let mut op_state = mlirOperationStateGet(name, location);
mlirOperationStateAddAttributes(&mut op_state, 1, &attr);
mlirOperationStateEnableResultTypeInference(&mut op_state);
mlirOperationCreate(&mut op_state)
})
.collect()
}
}