use llzk::{
builder::OpBuilder,
dialect::poly::{
TemplateExprOpLike, TemplateOpLike, TemplateParamOpLike, applymap, expr, is_applymap_op,
is_expr_op, is_param_op, is_template_op, is_unifiable_cast_op, is_yield_op, param,
template, unifiable_cast, r#yield,
},
prelude::*,
};
use melior::{dialect::arith, ir::Location};
use rstest::rstest;
mod common;
/// Checks the textual form of a `!poly.tvar` type named `A`.
#[test]
fn get_type() {
    common::setup();
    let ctx = LlzkContext::new();
    let tvar = TVarType::new(&ctx, StringRef::new("A"));
    assert_eq!(tvar.to_string(), "!poly.tvar<@A>");
}
/// Checks that the name stored in a `!poly.tvar` type reads back as the plain string `A`.
#[test]
fn get_type_name_ref() {
    common::setup();
    let context = LlzkContext::new();
    let tvar = TVarType::new(&context, StringRef::new("A"));
    // Compare the name itself rather than its `Debug` rendering; `Debug` wraps the text in
    // quotes (and would escape special characters), which obscures what is being checked.
    assert_eq!(
        tvar.name()
            .as_str()
            .expect("tvar name should be valid UTF-8"),
        "A"
    );
}
/// Builds a `poly.read_const @A` of felt type, checks its printed form, and verifies it.
#[test]
fn create_read_const() {
    common::setup();
    let ctx = LlzkContext::new();
    let read = dialect::poly::read_const(
        Location::unknown(&ctx),
        "A",
        FeltType::new(&ctx).into(),
    );
    assert_eq!(read.to_string(), "%0 = poly.read_const @A : !felt.type\n");
    assert!(read.verify());
}
/// Checks that `is_read_const_op` recognizes a `poly.read_const` op when it is viewed through
/// an `OperationRef`.
#[test]
fn is_read_const() {
    common::setup();
    let context = LlzkContext::new();
    let loc = Location::unknown(&context);
    let op = dialect::poly::read_const(loc, "C", IntegerType::new(&context, 64).into());
    // SAFETY: `op.to_raw()` is the handle of `op`, which is owned by this function and is
    // neither moved nor dropped while `op_ref` is in use below, so the handle stays valid.
    let op_ref = unsafe { OperationRef::from_raw(op.to_raw()) };
    assert!(dialect::poly::is_read_const_op(&op_ref));
}
/// Creates a `poly.param` typed as `!poly.tvar<@T>`, inspects its printed form and accessors,
/// then wraps it in a template inside a module and verifies the template.
#[test]
fn create_param() {
    common::setup();
    let ctx = LlzkContext::new();
    let loc = Location::unknown(&ctx);
    let module = llzk_module(loc);
    let param_op = param(
        loc,
        "T",
        Some(TVarType::new(&ctx, StringRef::new("T")).into()),
    )
    .unwrap();

    let printed = param_op.to_string();
    for fragment in [
        "\"poly.param\"",
        "sym_name = \"T\"",
        "type_opt = !poly.tvar<@T>",
    ] {
        assert!(printed.contains(fragment));
    }
    assert!(param_op.type_opt().is_some());
    assert!(is_param_op(&param_op));

    let wrapper = template(loc, "tmpl", [Ok(param_op.into())]).unwrap();
    let wrapper = module.body().append_operation(wrapper.into());
    assert!(wrapper.verify());
}
/// Builds a `poly.template` holding one `poly.param` (`T`) and one `poly.expr` (`N`), checks
/// the template's query helpers, and compares the printed module against the expected IR.
#[test]
fn create_template_with_param_and_expr() {
    common::setup();
    let ctx = LlzkContext::new();
    let module = llzk_module(Location::unknown(&ctx));
    let loc = Location::unknown(&ctx);
    let one = arith::constant(
        &ctx,
        IntegerAttribute::new(Type::index(&ctx), 1).into(),
        loc,
    );
    let one_val = one.result(0).unwrap();
    let tmpl = template(
        loc,
        "tmpl",
        [
            param(loc, "T", None).map(Into::into),
            // The expr's initializer is the constant followed by a yield of its result.
            expr(
                loc,
                "N",
                [Ok(one), r#yield(loc, one_val.into()).map(Into::into)],
            )
            .map(Into::into),
        ],
    )
    .unwrap();

    assert!(is_template_op(&tmpl));
    // Exactly one param named `T` and one expr named `N` should be reported.
    assert!(tmpl.has_const_param_ops());
    assert!(tmpl.has_const_param_named("T"));
    assert_eq!(tmpl.const_param_names().len(), 1);
    assert!(tmpl.has_const_expr_ops());
    assert!(tmpl.has_const_expr_named("N"));
    assert_eq!(tmpl.const_expr_names().len(), 1);

    let tmpl = module.body().append_operation(tmpl.into());
    let expected = r#"module attributes {llzk.lang} {
poly.template @tmpl {
poly.param @T
poly.expr @N {
%c1 = arith.constant 1 : index
poly.yield %c1 : index
}
}
}
"#;
    assert_eq!(module.as_operation().to_string(), expected);
    assert!(tmpl.verify());
}
/// Wraps a struct `empty`, which contains only its compute and constrain functions, in a
/// template with a single param `T`. The result is checked with `assert_test!` against
/// `expected/empty_struct_with_one_param.mlir`.
#[test]
fn empty_struct_with_one_param() {
    common::setup();
    let ctx = LlzkContext::new();
    let module = llzk_module(Location::unknown(&ctx));
    let loc = Location::unknown(&ctx);
    // Symbol `tmpl` with nested `empty`, parameterized by the symbol `T`.
    let struct_ty = StructType::new(
        SymbolRefAttribute::new(&ctx, "tmpl", &["empty"]),
        &[FlatSymbolRefAttribute::new(&ctx, "T").into()],
    );
    let struct_def = dialect::r#struct::def(
        loc,
        "empty",
        [
            dialect::r#struct::helpers::compute_fn(loc, struct_ty, &[], None).map(Into::into),
            dialect::r#struct::helpers::constrain_fn(loc, struct_ty, &[], None).map(Into::into),
        ],
    )
    .unwrap();
    let tmpl = template(
        loc,
        "tmpl",
        [param(loc, "T", None).map(Into::into), Ok(struct_def.into())],
    )
    .unwrap();
    let tmpl = module.body().append_operation(tmpl.into());
    assert_test!(tmpl, module, @file "expected/empty_struct_with_one_param.mlir");
}
/// Creates a `poly.expr` named `Two` that yields an index constant. It checks the reported
/// expr type and that the initializer block takes no arguments, then verifies the expr
/// inside a template.
#[test]
fn create_expr_and_get_type() {
    common::setup();
    let ctx = LlzkContext::new();
    let module = llzk_module(Location::unknown(&ctx));
    let loc = Location::unknown(&ctx);
    let two = arith::constant(
        &ctx,
        IntegerAttribute::new(Type::index(&ctx), 2).into(),
        loc,
    );
    let two_val = two.result(0).unwrap();
    let expr_op = expr(
        loc,
        "Two",
        [Ok(two), r#yield(loc, two_val.into()).map(Into::into)],
    )
    .unwrap();

    assert!(is_expr_op(&expr_op));
    assert_eq!(expr_op.expr_type().to_string(), "index");
    assert_eq!(
        0,
        expr_op
            .initializer_region()
            .first_block()
            .unwrap()
            .argument_count()
    );

    let wrapper = template(loc, "tmpl", [Ok(expr_op.into())]).unwrap();
    let wrapper = module.body().append_operation(wrapper.into());
    assert!(wrapper.verify());
}
/// Appends an index constant and a `poly.yield` of it to a plain module body, and checks
/// the yield predicate and the printed block.
#[test]
fn create_yield() {
    common::setup();
    let ctx = LlzkContext::new();
    let loc = Location::unknown(&ctx);
    let module = Module::new(loc);
    let body = module.body();
    let three = body.append_operation(arith::constant(
        &ctx,
        IntegerAttribute::new(Type::index(&ctx), 3).into(),
        loc,
    ));
    let yield_op = r#yield(loc, three.result(0).unwrap().into()).unwrap();
    let yield_ref = body.append_operation(yield_op.into());
    assert!(is_yield_op(&yield_ref));

    let printed = body.to_string();
    assert!(printed.contains("\"poly.yield\"(%0)"));
    assert!(printed.contains("value = 3 : index"));
}
/// Appends an `arith.constant` of `index` type holding `i` to `block` and returns its value.
///
/// Panics if the appended op does not have exactly one result.
fn create_index_constant<'c>(
    ctx: &'c Context,
    block: &Block<'c>,
    location: Location<'c>,
    i: i64,
) -> Value<'c, 'c> {
    let constant = block.append_operation(arith::constant(
        ctx,
        IntegerAttribute::new(Type::index(ctx), i).into(),
        location,
    ));
    assert_eq!(constant.result_count(), 1);
    constant.result(0).unwrap().into()
}
/// Builds a `poly.applymap` for each affine map and operand list below. Every operand is
/// materialized as an index constant, and the printed block is compared with the expected IR.
#[rstest]
#[case("affine_map<()[] -> (2)>", &[],
r"^bb0:
%0 = poly.applymap () affine_map<() -> (2)>
")]
#[case("affine_map<(i)[] -> (i)>", &[1],
r"^bb0:
%c1 = arith.constant 1 : index
%0 = poly.applymap (%c1) affine_map<(d0) -> (d0)>
")]
#[case("affine_map<()[s0, s1] -> (s0 + s1)>", &[7, 9],
r"^bb0:
%c7 = arith.constant 7 : index
%c9 = arith.constant 9 : index
%0 = poly.applymap ()[%c7, %c9] affine_map<()[s0, s1] -> (s0 + s1)>
")]
#[case("affine_map<(i, j) -> (i + j)>", &[2, 4],
r"^bb0:
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%0 = poly.applymap (%c2, %c4) affine_map<(d0, d1) -> (d0 + d1)>
")]
fn create_applymap(#[case] affine_map: &str, #[case] ops: &[i64], #[case] expected: &str) {
    common::setup();
    let ctx = LlzkContext::new();
    let loc = Location::unknown(&ctx);
    let map_attr =
        Attribute::parse(&ctx, affine_map).expect("could not parse affine_map attribute");
    let module = Module::new(loc);
    let body = module.body();

    // One index constant per requested operand, appended in order.
    let mut operands = Vec::with_capacity(ops.len());
    for &value in ops {
        operands.push(create_index_constant(&ctx, &body, loc, value));
    }

    let applymap_op = applymap(loc, map_attr, &operands);
    assert!(applymap_op.verify(), "op {applymap_op} failed to verify");
    assert!(is_applymap_op(&applymap_op));
    body.append_operation(applymap_op);
    assert_eq!(body.to_string(), expected);
}
/// Casts an array sized by the symbol `@N` to an array sized by an affine map with
/// `poly.unifiable_cast`. The test verifies the cast and compares the printed block.
#[test]
fn create_unifiable_cast() {
    common::setup();
    let ctx = LlzkContext::new();
    let loc = Location::unknown(&ctx);
    let module = Module::new(loc);
    let body = module.body();
    let map_attr = Attribute::parse(&ctx, "affine_map<()[s0, s1] -> (s0 + s1)>")
        .expect("could not parse affine_map attribute");

    // Source array: empty `array.new` whose single dimension is the symbol `@N`.
    let source_ty = ArrayType::new(
        FeltType::new(&ctx).into(),
        &[FlatSymbolRefAttribute::new(&ctx, "N").into()],
    );
    let source = body.append_operation(dialect::array::new(
        &OpBuilder::new(&ctx),
        loc,
        source_ty,
        llzk::dialect::array::ArrayCtor::Values(&[]),
    ));

    // Target array type: same element type, dimension given by the affine map.
    let target_ty = ArrayType::new(FeltType::new(&ctx).into(), &[map_attr]);
    let cast = body.append_operation(unifiable_cast(
        loc,
        source.result(0).unwrap().into(),
        target_ty.into(),
    ));
    assert!(cast.verify(), "op {cast} failed to verify");
    assert!(is_unifiable_cast_op(&cast));

    let expected = r#"^bb0:
%0 = "array.new"() <{mapOpGroupSizes = array<i32>, numDimsPerMap = array<i32>, operandSegmentSizes = array<i32: 0, 0>}> : () -> !array.type<@N x !felt.type>
%1 = "poly.unifiable_cast"(%0) : (!array.type<@N x !felt.type>) -> !array.type<affine_map<()[s0, s1] -> (s0 + s1)> x !felt.type>
"#;
    assert_eq!(body.to_string(), expected);
}