use std::{
ffi::{CString, c_void},
ptr::{null, null_mut},
};
use mlir_sys::{
MlirOperation, MlirType, mlirArrayAttrGet, mlirAttributeEqual, mlirAttributeGetContext,
mlirFlatSymbolRefAttrGet, mlirIndexTypeGet, mlirIntegerAttrGet, mlirIntegerTypeGet,
mlirLocationUnknownGet, mlirStringRefCreateFromCString,
};
use rstest::{fixture, rstest};
use crate::{
llzkArrayAttrTypeParamsUnify, llzkAssertValidAttrForParamOfType, llzkForceIntAttrType,
llzkHasAffineMapAttr, llzkIsConcreteType, llzkIsMoreConcreteUnification,
llzkIsValidArrayElemType, llzkIsValidArrayType, llzkIsValidColumnType,
llzkIsValidConstReadType, llzkIsValidEmitEqType, llzkIsValidGlobalType, llzkIsValidType,
llzkTypeParamsUnify, llzkTypesUnify,
sanity_tests::{TestContext, context},
};
pub struct IndexType {
#[allow(dead_code)]
pub context: TestContext,
pub t: MlirType,
}
#[fixture]
pub fn index_type(context: TestContext) -> IndexType {
unsafe {
let ctx = context.ctx;
IndexType {
context,
t: mlirIndexTypeGet(ctx),
}
}
}
pub struct I16Type {
#[allow(dead_code)]
pub context: TestContext,
pub t: MlirType,
}
#[fixture]
pub fn i16_type(context: TestContext) -> I16Type {
unsafe {
let ctx = context.ctx;
I16Type {
context,
t: mlirIntegerTypeGet(ctx, 16),
}
}
}
#[rstest]
fn test_llzk_assert_valid_attr_for_param_of_type(index_type: IndexType) {
unsafe {
let int_attr = mlirIntegerAttrGet(index_type.t, 0);
llzkAssertValidAttrForParamOfType(int_attr);
}
}
#[rstest]
fn test_llzk_is_valid_type(index_type: IndexType) {
unsafe {
assert!(llzkIsValidType(index_type.t));
}
}
#[rstest]
fn test_llzk_is_valid_column_type(index_type: IndexType) {
unsafe {
let null_op = MlirOperation { ptr: null_mut() };
assert!(!llzkIsValidColumnType(index_type.t, null_op));
}
}
#[rstest]
fn test_llzk_is_valid_emit_eq_type(index_type: IndexType) {
unsafe {
assert!(llzkIsValidEmitEqType(index_type.t));
}
}
#[rstest]
fn test_llzk_is_valid_const_read_type(index_type: IndexType) {
unsafe {
assert!(llzkIsValidConstReadType(index_type.t));
}
}
#[rstest]
fn test_llzk_is_valid_array_elem_type(index_type: IndexType) {
unsafe {
assert!(llzkIsValidArrayElemType(index_type.t));
}
}
#[rstest]
fn test_llzk_is_valid_array_type(index_type: IndexType) {
unsafe {
assert!(!llzkIsValidArrayType(index_type.t));
}
}
#[rstest]
fn test_llzk_is_concrete_type(index_type: IndexType) {
unsafe {
assert!(llzkIsConcreteType(index_type.t, true));
}
}
#[rstest]
fn test_llzk_has_affine_map_attr(index_type: IndexType) {
unsafe {
assert!(!llzkHasAffineMapAttr(index_type.t));
}
}
#[rstest]
fn test_llzk_type_params_unify_empty() {
unsafe {
let lhs = [];
let rhs = [];
assert!(llzkTypeParamsUnify(
isize::try_from(lhs.len()).expect("lhs too large"),
lhs.as_ptr(),
isize::try_from(rhs.len()).expect("rhs too large"),
rhs.as_ptr()
));
}
}
#[rstest]
fn test_llzk_type_params_unify_pass(index_type: IndexType) {
unsafe {
let string = CString::new("N").unwrap();
let string_ref = mlirStringRefCreateFromCString(string.as_ptr());
let lhs = [mlirIntegerAttrGet(index_type.t, 0)];
let rhs = [mlirFlatSymbolRefAttrGet(
mlirAttributeGetContext(lhs[0]),
string_ref,
)];
assert!(llzkTypeParamsUnify(
isize::try_from(lhs.len()).expect("lhs too large"),
lhs.as_ptr(),
isize::try_from(rhs.len()).expect("rhs too large"),
rhs.as_ptr()
));
}
}
#[rstest]
fn test_llzk_type_params_unify_fail(index_type: IndexType) {
unsafe {
let lhs = [mlirIntegerAttrGet(index_type.t, 0)];
let rhs = [mlirIntegerAttrGet(index_type.t, 1)];
assert!(!llzkTypeParamsUnify(
isize::try_from(lhs.len()).expect("lhs too large"),
lhs.as_ptr(),
isize::try_from(rhs.len()).expect("rhs too large"),
rhs.as_ptr()
));
}
}
#[rstest]
fn test_llzk_array_attr_type_params_unify_empty(context: TestContext) {
unsafe {
let lhs = [];
let lhs = mlirArrayAttrGet(
context.ctx,
isize::try_from(lhs.len()).expect("lhs too large"),
lhs.as_ptr(),
);
let rhs = [];
let rhs = mlirArrayAttrGet(
context.ctx,
isize::try_from(rhs.len()).expect("rhs too large"),
rhs.as_ptr(),
);
assert!(llzkArrayAttrTypeParamsUnify(lhs, rhs));
}
}
#[rstest]
fn test_llzk_array_attr_type_params_unify_pass(index_type: IndexType) {
unsafe {
let string = CString::new("N").unwrap();
let string_ref = mlirStringRefCreateFromCString(string.as_ptr());
let lhs = [mlirIntegerAttrGet(index_type.t, 0)];
let lhs = mlirArrayAttrGet(
mlirAttributeGetContext(lhs[0]),
isize::try_from(lhs.len()).expect("lhs too large"),
lhs.as_ptr(),
);
let rhs = [mlirFlatSymbolRefAttrGet(
mlirAttributeGetContext(lhs),
string_ref,
)];
let rhs = mlirArrayAttrGet(
mlirAttributeGetContext(lhs),
isize::try_from(rhs.len()).expect("rhs too large"),
rhs.as_ptr(),
);
assert!(llzkArrayAttrTypeParamsUnify(lhs, rhs));
}
}
#[rstest]
fn test_llzk_array_attr_type_params_unify_fail(index_type: IndexType) {
unsafe {
let lhs = [mlirIntegerAttrGet(index_type.t, 0)];
let lhs = mlirArrayAttrGet(
mlirAttributeGetContext(lhs[0]),
isize::try_from(lhs.len()).expect("lhs too large"),
lhs.as_ptr(),
);
let rhs = [mlirIntegerAttrGet(index_type.t, 1)];
let rhs = mlirArrayAttrGet(
mlirAttributeGetContext(lhs),
isize::try_from(rhs.len()).expect("rhs too large"),
rhs.as_ptr(),
);
assert!(!llzkArrayAttrTypeParamsUnify(lhs, rhs));
}
}
#[rstest]
fn test_llzk_types_unify(index_type: IndexType) {
unsafe {
assert!(llzkTypesUnify(index_type.t, index_type.t, 0, null()));
}
}
#[rstest]
fn test_llzk_is_more_concrete_unification(index_type: IndexType) {
unsafe {
assert!(llzkIsMoreConcreteUnification(
index_type.t,
index_type.t,
Some(test_callback1),
null_mut()
));
}
}
#[rstest]
fn test_llzk_force_int_attr_type(i16_type: I16Type) {
unsafe {
let location = mlirLocationUnknownGet(i16_type.context.ctx);
let in_attr = mlirIntegerAttrGet(i16_type.t, 0);
let out_attr = llzkForceIntAttrType(in_attr, location);
assert!(!mlirAttributeEqual(in_attr, out_attr));
}
}
#[rstest]
fn test_llzk_is_valid_global_type(index_type: IndexType) {
unsafe {
assert!(llzkIsValidGlobalType(index_type.t));
}
}
unsafe extern "C" fn test_callback1(_: MlirType, _: MlirType, _: *mut c_void) -> bool {
true
}