use std::{
ffi::CString,
ptr::{null, null_mut},
};
use crate::sanity_tests::dialect::{TestOp, first_op, parse_module, test_op};
use mlir_sys::{
MlirOperation, MlirStringRef, mlirAffineConstantExprGet, mlirAffineMapGet, mlirArrayAttrGet,
mlirAttributeEqual, mlirFlatSymbolRefAttrGet, mlirIdentifierStr, mlirIndexTypeGet,
mlirLocationUnknownGet, mlirModuleGetBody, mlirOperationCreate, mlirOperationDestroy,
mlirOperationGetContext, mlirOperationGetName, mlirOperationGetResult,
mlirOperationStateAddResults, mlirOperationStateGet, mlirStringRefCreateFromCString,
};
use rstest::rstest;
use std::alloc::{Layout, alloc, dealloc};
use crate::{
MlirValueRange, llzkFunction_FuncDefOpNameIsProduct, llzkOperationIsA_Struct_MemberDefOp,
llzkOperationIsA_Struct_StructDefOp, llzkStruct_MemberDefOpHasPublicAttr,
llzkStruct_MemberDefOpSetPublicAttr, llzkStruct_MemberReadOpBuild,
llzkStruct_MemberReadOpBuildWithAffineMapDistance,
llzkStruct_MemberReadOpBuildWithLiteralDistance,
llzkStruct_MemberReadOpBuildWithTemplateSymbolDistance, llzkStruct_StructDefOpGetBody,
llzkStruct_StructDefOpGetBodyRegion,
llzkStruct_StructDefOpGetComputeFuncOp, llzkStruct_StructDefOpGetConstrainFuncOp,
llzkStruct_StructDefOpGetFullyQualifiedName, llzkStruct_StructDefOpGetHeaderString,
llzkStruct_StructDefOpGetMemberDef, llzkStruct_StructDefOpGetMemberDefs,
llzkStruct_StructDefOpGetNumMemberDefs, llzkStruct_StructDefOpGetProductFuncOp,
llzkStruct_StructDefOpGetType, llzkStruct_StructDefOpGetTypeWithParams,
llzkStruct_StructDefOpHasColumns, llzkStruct_StructDefOpIsMainComponent,
llzkStruct_StructTypeGet, llzkStruct_StructTypeGetNameRef, llzkStruct_StructTypeGetParams,
llzkStruct_StructTypeGetWithArrayAttr, llzkStruct_StructTypeGetWithAttrs,
llzkTypeIsA_Struct_StructType, mlirGetDialectHandle__llzk__component__, mlirOpBuilderCreate,
mlirOpBuilderDestroy, sanity_tests::{TestContext, context, identifier, str_ref},
};
/// Returns `true` when `value` holds exactly the bytes of `expected`.
///
/// A zero-length `MlirStringRef` is treated as the empty string regardless of its `data`
/// pointer (an empty ref may carry a null `data`), so it compares equal to `""`.
/// A non-empty ref with a null `data` pointer never compares equal.
fn string_ref_eq(value: MlirStringRef, expected: &str) -> bool {
    if value.length != expected.len() {
        return false;
    }
    if value.length == 0 {
        // Nothing to read; never build a slice from a possibly-null pointer.
        return true;
    }
    if value.data.is_null() {
        return false;
    }
    // SAFETY: `data` is non-null and, per the `MlirStringRef` contract, points to at least
    // `length` readable bytes that remain valid for the duration of this call.
    let bytes = unsafe { std::slice::from_raw_parts(value.data.cast::<u8>(), value.length) };
    bytes == expected.as_bytes()
}
/// Checks whether the registered operation name of `op` (e.g. `"struct.def"`) is `expected`.
fn op_name_eq(op: MlirOperation, expected: &str) -> bool {
    let name = unsafe { mlirIdentifierStr(mlirOperationGetName(op)) };
    string_ref_eq(name, expected)
}
#[test]
fn test_mlir_get_dialect_handle_llzk_component() {
    // Calling the accessor alone proves nothing; a usable dialect handle is expected to be
    // non-null.
    let handle = unsafe { mlirGetDialectHandle__llzk__component__() };
    assert!(!handle.ptr.is_null());
}
#[rstest]
fn test_llzk_struct_type_get(context: TestContext) {
    // Build `!struct.type<@T>` from a bare flat symbol reference.
    let struct_type = unsafe {
        let name_attr = mlirFlatSymbolRefAttrGet(context.ctx, str_ref("T"));
        llzkStruct_StructTypeGet(name_attr)
    };
    assert_ne!(struct_type.ptr, null());
}
#[rstest]
fn test_llzk_struct_type_get_with_array_attr(context: TestContext) {
    let ctx = context.ctx;
    // Parameters are passed as one prebuilt array attribute: `!struct.type<@T<[@A]>>`.
    let struct_type = unsafe {
        let name_attr = mlirFlatSymbolRefAttrGet(ctx, str_ref("T"));
        let params = [mlirFlatSymbolRefAttrGet(ctx, str_ref("A"))];
        let param_count = isize::try_from(params.len()).expect("attrs too large");
        let params_attr = mlirArrayAttrGet(ctx, param_count, params.as_ptr());
        llzkStruct_StructTypeGetWithArrayAttr(name_attr, params_attr)
    };
    assert_ne!(struct_type.ptr, null());
}
#[rstest]
fn test_llzk_struct_type_get_with_attrs(context: TestContext) {
    let ctx = context.ctx;
    // Parameters are passed as a raw (count, pointer) pair rather than an array attribute.
    let struct_type = unsafe {
        let name_attr = mlirFlatSymbolRefAttrGet(ctx, str_ref("T"));
        let params = [mlirFlatSymbolRefAttrGet(ctx, str_ref("A"))];
        let param_count = isize::try_from(params.len()).expect("attrs too large");
        llzkStruct_StructTypeGetWithAttrs(name_attr, param_count, params.as_ptr())
    };
    assert_ne!(struct_type.ptr, null());
}
#[rstest]
fn test_llzk_type_is_a_struct_type(context: TestContext) {
    // A type built through the struct-type constructor must be recognised as one.
    unsafe {
        let struct_type =
            llzkStruct_StructTypeGet(mlirFlatSymbolRefAttrGet(context.ctx, str_ref("T")));
        assert_ne!(struct_type.ptr, null());
        let recognised = llzkTypeIsA_Struct_StructType(struct_type);
        assert!(recognised);
    }
}
#[rstest]
fn test_llzk_struct_type_get_name(context: TestContext) {
    // The name reference stored in the type must round-trip to the attribute it was built from.
    unsafe {
        let name_attr = mlirFlatSymbolRefAttrGet(context.ctx, str_ref("T"));
        let struct_type = llzkStruct_StructTypeGet(name_attr);
        assert_ne!(struct_type.ptr, null());
        let stored_name = llzkStruct_StructTypeGetNameRef(struct_type);
        assert!(mlirAttributeEqual(name_attr, stored_name));
    }
}
#[rstest]
fn test_llzk_struct_type_get_params(context: TestContext) {
    let ctx = context.ctx;
    // The parameter array read back from the type must equal the one it was built with.
    unsafe {
        let name_attr = mlirFlatSymbolRefAttrGet(ctx, str_ref("T"));
        let params = [mlirFlatSymbolRefAttrGet(ctx, str_ref("A"))];
        let param_count = isize::try_from(params.len()).expect("attrs too large");
        let params_attr = mlirArrayAttrGet(ctx, param_count, params.as_ptr());
        let struct_type = llzkStruct_StructTypeGetWithArrayAttr(name_attr, params_attr);
        assert_ne!(struct_type.ptr, null());
        let stored_params = llzkStruct_StructTypeGetParams(struct_type);
        assert!(mlirAttributeEqual(params_attr, stored_params));
    }
}
#[rstest]
fn test_llzk_operation_is_a_struct_def_op(test_op: TestOp) {
    // The generic fixture operation is not a `struct.def`.
    let is_struct_def = unsafe { llzkOperationIsA_Struct_StructDefOp(test_op.op) };
    assert!(!is_struct_def);
}
#[rstest]
fn test_llzk_struct_def_op_get_body_region(test_op: TestOp) {
    let op = test_op.op;
    // The accessor is only valid on a `struct.def`; for any other op this test is a no-op.
    unsafe {
        if !llzkOperationIsA_Struct_StructDefOp(op) {
            return;
        }
        llzkStruct_StructDefOpGetBodyRegion(op);
    }
}
#[rstest]
fn test_llzk_struct_def_op_get_body(test_op: TestOp) {
    let op = test_op.op;
    // The accessor is only valid on a `struct.def`; for any other op this test is a no-op.
    unsafe {
        if !llzkOperationIsA_Struct_StructDefOp(op) {
            return;
        }
        llzkStruct_StructDefOpGetBody(op);
    }
}
#[rstest]
fn test_llzk_struct_def_op_get_type(test_op: TestOp) {
    let op = test_op.op;
    // The accessor is only valid on a `struct.def`; for any other op this test is a no-op.
    unsafe {
        if !llzkOperationIsA_Struct_StructDefOp(op) {
            return;
        }
        llzkStruct_StructDefOpGetType(op);
    }
}
#[rstest]
fn test_llzk_struct_def_op_get_type_with_params(test_op: TestOp) {
    let op = test_op.op;
    // Only meaningful on a `struct.def`; queries the type with an empty parameter list.
    unsafe {
        if !llzkOperationIsA_Struct_StructDefOp(op) {
            return;
        }
        let ctx = mlirOperationGetContext(op);
        let no_params = mlirArrayAttrGet(ctx, 0, null());
        llzkStruct_StructDefOpGetTypeWithParams(op, no_params);
    }
}
#[rstest]
fn test_llzk_struct_def_op_get_field_def(test_op: TestOp) {
    let op = test_op.op;
    // Only meaningful on a `struct.def`; looks up a member named `p`.
    unsafe {
        if !llzkOperationIsA_Struct_StructDefOp(op) {
            return;
        }
        let member_name = identifier(test_op.context.as_ref(), "p");
        llzkStruct_StructDefOpGetMemberDef(op, member_name);
    }
}
#[rstest]
fn test_llzk_struct_def_op_get_field_defs(test_op: TestOp) {
unsafe {
if llzkOperationIsA_Struct_StructDefOp(test_op.op) {
llzkStruct_StructDefOpGetMemberDefs(test_op.op, null_mut());
}
}
}
#[rstest]
fn test_llzk_struct_def_op_get_num_field_defs(test_op: TestOp) {
    let op = test_op.op;
    // The accessor is only valid on a `struct.def`; for any other op this test is a no-op.
    unsafe {
        if !llzkOperationIsA_Struct_StructDefOp(op) {
            return;
        }
        llzkStruct_StructDefOpGetNumMemberDefs(op);
    }
}
#[rstest]
fn test_llzk_struct_def_op_get_has_columns(test_op: TestOp) {
    let op = test_op.op;
    // The accessor is only valid on a `struct.def`; for any other op this test is a no-op.
    unsafe {
        if !llzkOperationIsA_Struct_StructDefOp(op) {
            return;
        }
        llzkStruct_StructDefOpHasColumns(op);
    }
}
#[rstest]
fn test_llzk_struct_def_op_get_compute_func_op(test_op: TestOp) {
    let op = test_op.op;
    // The accessor is only valid on a `struct.def`; for any other op this test is a no-op.
    unsafe {
        if !llzkOperationIsA_Struct_StructDefOp(op) {
            return;
        }
        llzkStruct_StructDefOpGetComputeFuncOp(op);
    }
}
#[rstest]
fn test_llzk_struct_def_op_get_constrain_func_op(test_op: TestOp) {
    let op = test_op.op;
    // The accessor is only valid on a `struct.def`; for any other op this test is a no-op.
    unsafe {
        if !llzkOperationIsA_Struct_StructDefOp(op) {
            return;
        }
        llzkStruct_StructDefOpGetConstrainFuncOp(op);
    }
}
#[rstest]
fn test_llzk_struct_def_op_get_product_func_op(test_op: TestOp) {
    let op = test_op.op;
    // The accessor is only valid on a `struct.def`; for any other op this test is a no-op.
    unsafe {
        if !llzkOperationIsA_Struct_StructDefOp(op) {
            return;
        }
        llzkStruct_StructDefOpGetProductFuncOp(op);
    }
}
#[rstest]
fn test_llzk_struct_def_op_get_product_func_op_positive(context: TestContext) {
let module = parse_module(
context.ctx,
r#"
module attributes {llzk.lang} {
struct.def @StructProd {
function.def @product() -> !struct.type<@StructProd<[]>> attributes {function.allow_constraint, function.allow_non_native_field_ops, function.allow_witness} {
%self = struct.new : <@StructProd<[]>>
function.return %self : !struct.type<@StructProd<[]>>
}
}
}
"#,
);
unsafe {
let struct_def = first_op(mlirModuleGetBody(module.module));
assert!(op_name_eq(struct_def, "struct.def"));
let got = llzkStruct_StructDefOpGetProductFuncOp(struct_def);
assert!(!got.ptr.is_null());
assert!(op_name_eq(got, "function.def"));
assert!(llzkFunction_FuncDefOpNameIsProduct(got));
}
}
#[rstest]
fn test_llzk_struct_def_op_get_header_string(test_op: TestOp) {
    use core::ffi::c_char;

    /// Allocation callback handed to the C API.
    ///
    /// A zero-byte request is rounded up to one byte, because `alloc` with a zero-sized layout
    /// is undefined behavior. Allocation failure goes through `handle_alloc_error` instead of
    /// handing a null buffer back to the C side.
    extern "C" fn allocator(size: usize) -> *mut c_char {
        let layout =
            Layout::array::<c_char>(size.max(1)).expect("failed to define string layout");
        // SAFETY: `layout` has a non-zero size.
        let ptr = unsafe { alloc(layout) };
        if ptr.is_null() {
            std::alloc::handle_alloc_error(layout);
        }
        ptr.cast::<c_char>()
    }

    unsafe {
        if llzkOperationIsA_Struct_StructDefOp(test_op.op) {
            let mut size = 0;
            let header =
                llzkStruct_StructDefOpGetHeaderString(test_op.op, &mut size, Some(allocator));
            assert!(!header.is_null());
            let size = usize::try_from(size).expect("string size is negative or too large");
            // This layout must mirror the one `allocator` used. NOTE(review): this assumes the
            // reported `size` equals the size the C side requested from the callback — confirm.
            let layout =
                Layout::array::<c_char>(size.max(1)).expect("failed to define string layout");
            dealloc(header as *mut u8, layout);
        }
    }
}
#[rstest]
fn test_llzk_struct_def_op_get_fully_qualified_name(test_op: TestOp) {
    let op = test_op.op;
    // The accessor is only valid on a `struct.def`; for any other op this test is a no-op.
    unsafe {
        if !llzkOperationIsA_Struct_StructDefOp(op) {
            return;
        }
        llzkStruct_StructDefOpGetFullyQualifiedName(op);
    }
}
#[rstest]
fn test_llzk_struct_def_op_get_is_main_component(test_op: TestOp) {
    let op = test_op.op;
    // The accessor is only valid on a `struct.def`; for any other op this test is a no-op.
    unsafe {
        if !llzkOperationIsA_Struct_StructDefOp(op) {
            return;
        }
        llzkStruct_StructDefOpIsMainComponent(op);
    }
}
#[rstest]
fn test_llzk_operation_is_a_field_def_op(test_op: TestOp) {
    // The generic fixture operation is not a struct member definition.
    let is_member_def = unsafe { llzkOperationIsA_Struct_MemberDefOp(test_op.op) };
    assert!(!is_member_def);
}
#[rstest]
fn test_llzk_field_def_op_get_has_public_attr(test_op: TestOp) {
    let op = test_op.op;
    // Only valid on a member definition; for any other op this test is a no-op.
    unsafe {
        if !llzkOperationIsA_Struct_MemberDefOp(op) {
            return;
        }
        llzkStruct_MemberDefOpHasPublicAttr(op);
    }
}
#[rstest]
fn test_llzk_field_def_op_set_public_attr(test_op: TestOp) {
    let op = test_op.op;
    // Only valid on a member definition; marks the member public when applicable.
    unsafe {
        if !llzkOperationIsA_Struct_MemberDefOp(op) {
            return;
        }
        llzkStruct_MemberDefOpSetPublicAttr(op, true);
    }
}
/// Creates a `struct.new` operation whose single result has type `!struct.type<@S>`.
///
/// The op is not inserted into any block here; the callers in this file destroy it
/// themselves with `mlirOperationDestroy`.
fn new_struct(context: &TestContext) -> MlirOperation {
    let ctx = context.ctx;
    // Kept alive until the end of the function so the borrowed name outlives op creation.
    let op_name = CString::new("struct.new").unwrap();
    unsafe {
        let struct_type = llzkStruct_StructTypeGet(mlirFlatSymbolRefAttrGet(ctx, str_ref("S")));
        let mut state = mlirOperationStateGet(
            mlirStringRefCreateFromCString(op_name.as_ptr()),
            mlirLocationUnknownGet(ctx),
        );
        mlirOperationStateAddResults(&mut state, 1, &struct_type);
        mlirOperationCreate(&mut state)
    }
}
#[rstest]
fn test_llzk_field_read_op_build(context: TestContext) {
    let ctx = context.ctx;
    unsafe {
        let builder = mlirOpBuilderCreate(ctx);
        let loc = mlirLocationUnknownGet(ctx);
        let result_type = mlirIndexTypeGet(ctx);
        let base_op = new_struct(&context);
        let base_value = mlirOperationGetResult(base_op, 0);
        let member = identifier(context.as_ref(), "f");
        let read = llzkStruct_MemberReadOpBuild(builder, loc, result_type, base_value, member);
        // Tear down users before producers: the read consumes the struct's result.
        mlirOperationDestroy(read);
        mlirOperationDestroy(base_op);
        mlirOpBuilderDestroy(builder);
    }
}
#[rstest]
fn test_llzk_field_read_op_build_with_affine_map_distance(context: TestContext) {
    let ctx = context.ctx;
    unsafe {
        let builder = mlirOpBuilderCreate(ctx);
        let loc = mlirLocationUnknownGet(ctx);
        let result_type = mlirIndexTypeGet(ctx);
        let base_op = new_struct(&context);
        let base_value = mlirOperationGetResult(base_op, 0);

        // Distance given as the constant affine map `() -> (1)`, which takes no operands.
        let mut exprs = [mlirAffineConstantExprGet(ctx, 1)];
        let expr_count = isize::try_from(exprs.len()).expect("exprs too large");
        let distance_map = mlirAffineMapGet(ctx, 0, 0, expr_count, exprs.as_mut_ptr());
        let map_operands = &[];

        let member = identifier(context.as_ref(), "f");
        let read = llzkStruct_MemberReadOpBuildWithAffineMapDistance(
            builder,
            loc,
            result_type,
            base_value,
            member,
            distance_map,
            MlirValueRange {
                values: map_operands.as_ptr(),
                size: isize::try_from(map_operands.len()).expect("values too large"),
            },
        );
        // Tear down users before producers: the read consumes the struct's result.
        mlirOperationDestroy(read);
        mlirOperationDestroy(base_op);
        mlirOpBuilderDestroy(builder);
    }
}
#[rstest]
fn test_llzk_field_read_op_builder_with_template_symbol_distance(context: TestContext) {
    let ctx = context.ctx;
    unsafe {
        let builder = mlirOpBuilderCreate(ctx);
        let loc = mlirLocationUnknownGet(ctx);
        let result_type = mlirIndexTypeGet(ctx);
        let base_op = new_struct(&context);
        let base_value = mlirOperationGetResult(base_op, 0);
        let member = identifier(context.as_ref(), "f");
        // Distance given by the template parameter symbol `N`.
        let distance_symbol = str_ref("N");
        let read = llzkStruct_MemberReadOpBuildWithTemplateSymbolDistance(
            builder,
            loc,
            result_type,
            base_value,
            member,
            distance_symbol,
        );
        // Tear down users before producers: the read consumes the struct's result.
        mlirOperationDestroy(read);
        mlirOperationDestroy(base_op);
        mlirOpBuilderDestroy(builder);
    }
}
#[rstest]
fn test_llzk_field_read_op_build_with_literal_distance(context: TestContext) {
    let ctx = context.ctx;
    unsafe {
        let builder = mlirOpBuilderCreate(ctx);
        let loc = mlirLocationUnknownGet(ctx);
        let result_type = mlirIndexTypeGet(ctx);
        let base_op = new_struct(&context);
        let base_value = mlirOperationGetResult(base_op, 0);
        let member = identifier(context.as_ref(), "f");
        // Distance given as the integer literal 1.
        let read = llzkStruct_MemberReadOpBuildWithLiteralDistance(
            builder,
            loc,
            result_type,
            base_value,
            member,
            1,
        );
        // Tear down users before producers: the read consumes the struct's result.
        mlirOperationDestroy(read);
        mlirOperationDestroy(base_op);
        mlirOpBuilderDestroy(builder);
    }
}