llzk-sys 30.1.0

Rust bindings to the LLZK C API.
use std::ptr::null_mut;

use mlir_sys::{
    MlirAttribute, MlirContext, MlirNamedAttribute, MlirOperation, MlirType,
    mlirAttributeGetContext, mlirFlatSymbolRefAttrGet, mlirIdentifierGet, mlirIndexTypeGet,
    mlirIntegerAttrGet, mlirLocationUnknownGet, mlirNamedAttributeGet, mlirOperationCreate,
    mlirOperationDestroy, mlirOperationStateAddAttributes, mlirOperationStateGet, mlirStringAttrGet,
    mlirTypeAttrGet, mlirUnitAttrGet,
};
use rstest::rstest;

use crate::{
    llzkGlobal_GlobalDefOpIsConstant, llzkOperationIsA_Global_GlobalDefOp,
    mlirGetDialectHandle__llzk__global__,
    sanity_tests::{TestContext, context, str_ref},
};

/// Smoke test: the `global` dialect handle getter is linked in and can be called.
#[test]
fn test_mlir_get_dialect_handle_llzk_global() {
    // SAFETY: the getter takes no arguments; the test only checks that the call returns.
    let _handle = unsafe { mlirGetDialectHandle__llzk__global__() };
}

/// Wraps `attr` in a named attribute whose identifier is `s`.
///
/// The identifier is created in the same context that owns `attr`, so callers only
/// need to pass the attribute itself.
fn named_attr(s: &'static str, attr: MlirAttribute) -> MlirNamedAttribute {
    unsafe {
        let owning_ctx = mlirAttributeGetContext(attr);
        let ident = mlirIdentifierGet(owning_ctx, str_ref(s));
        mlirNamedAttributeGet(ident, attr)
    }
}

fn create_global_def_op(
    ctx: MlirContext,
    sym_name: &'static str,
    constant: bool,
    r#type: MlirType,
    initial_value: Option<MlirAttribute>,
) -> MlirOperation {
    unsafe {
        let sym_name = mlirFlatSymbolRefAttrGet(ctx, str_ref(sym_name));
        let mut attrs = vec![
            named_attr("sym_name", sym_name),
            named_attr("type", mlirTypeAttrGet(r#type)),
        ];
        if constant {
            attrs.push(named_attr("constant", mlirUnitAttrGet(ctx)));
        }
        if let Some(value) = initial_value {
            attrs.push(named_attr("initial_value", value));
        }
        let name = str_ref("global.def");
        let mut state = mlirOperationStateGet(name, mlirLocationUnknownGet(ctx));
        mlirOperationStateAddAttributes(
            &mut state,
            isize::try_from(attrs.len()).expect("attrs too large"),
            attrs.as_ptr(),
        );

        mlirOperationCreate(&mut state)
    }
}

/// A freshly built `global.def` op is recognised by the `GlobalDefOp` type check.
#[rstest]
fn test_llzk_operation_is_a_global_def_op(context: TestContext) {
    // SAFETY: `context.ctx` comes from the test fixture and is alive for the whole test.
    let index_ty = unsafe { mlirIndexTypeGet(context.ctx) };
    let op = create_global_def_op(context.ctx, "G", false, index_ty, None);
    assert_ne!(op.ptr, null_mut());

    // SAFETY: `op` was just checked to be non-null and has not been destroyed yet.
    let is_global_def = unsafe { llzkOperationIsA_Global_GlobalDefOp(op) };
    assert!(is_global_def);

    // SAFETY: `op` is owned by this test and is not used after this point.
    unsafe { mlirOperationDestroy(op) };
}

/// A `global.def` built without the `constant` attribute is reported as non-constant.
#[rstest]
fn test_llzk_global_def_op_get_is_constant_1(context: TestContext) {
    // SAFETY: `context.ctx` comes from the test fixture and is alive for the whole test.
    let index_ty = unsafe { mlirIndexTypeGet(context.ctx) };
    let op = create_global_def_op(context.ctx, "G", false, index_ty, None);
    assert_ne!(op.ptr, null_mut());

    // SAFETY: `op` was just checked to be non-null and has not been destroyed yet.
    let is_constant = unsafe { llzkGlobal_GlobalDefOpIsConstant(op) };
    assert!(!is_constant);

    // SAFETY: `op` is owned by this test and is not used after this point.
    unsafe { mlirOperationDestroy(op) };
}

/// A `global.def` built with the `constant` attribute (and an index initial value of 1)
/// is reported as constant.
#[rstest]
fn test_llzk_global_def_op_get_is_constant_2(context: TestContext) {
    let ctx = context.ctx;
    // SAFETY: `ctx` comes from the test fixture and is alive for the whole test.
    let (index_ty, one) = unsafe {
        let index_ty = mlirIndexTypeGet(ctx);
        (index_ty, mlirIntegerAttrGet(index_ty, 1))
    };
    let op = create_global_def_op(ctx, "G", true, index_ty, Some(one));
    assert_ne!(op.ptr, null_mut());

    // SAFETY: `op` was just checked to be non-null and has not been destroyed yet.
    let is_constant = unsafe { llzkGlobal_GlobalDefOpIsConstant(op) };
    assert!(is_constant);

    // SAFETY: `op` is owned by this test and is not used after this point.
    unsafe { mlirOperationDestroy(op) };
}