use std::sync::Arc;
use proptest::prelude::*;
use morok_dtype::DType;
use morok_ir::UOp;
use crate::rewrite::graph_rewrite;
use crate::symbolic::symbolic_simple;
use morok_ir::test::property::generators::*;
use morok_ir::test::property::shrinking::{uop_depth, uop_op_count};
proptest! {
    #![proptest_config(ProptestConfig::with_cases(500))]

    /// Rewriting an already-rewritten arithmetic tree with `symbolic_simple`
    /// must be a no-op.
    ///
    /// Sameness is checked with `Arc::ptr_eq`, i.e. node identity rather than
    /// structural equality. NOTE(review): this presumably relies on the rewriter
    /// returning the identical node when no rule fires (e.g. via hash-consing in
    /// `UOp::new`) — confirm.
    #[test]
    fn symbolic_idempotent(graph in arb_arithmetic_tree_up_to(DType::Int32, 4)) {
        let matcher = symbolic_simple();
        // `graph` is not needed after the first rewrite, so move it rather than clone it.
        let once = graph_rewrite(&matcher, graph, &mut ());
        let twice = graph_rewrite(&matcher, once.clone(), &mut ());
        // Print both results so a failing case shows what the second pass changed.
        prop_assert!(Arc::ptr_eq(&once, &twice),
            "Optimizing twice should give same result as optimizing once\nonce:  {:?}\ntwice: {:?}",
            once, twice);
    }

    /// Same idempotence property, checked on graphs built from
    /// `arb_known_property_graph` instead of random arithmetic trees.
    #[test]
    fn symbolic_idempotent_known_props(kpg in arb_known_property_graph()) {
        let graph = kpg.build();
        let matcher = symbolic_simple();
        let once = graph_rewrite(&matcher, graph, &mut ());
        let twice = graph_rewrite(&matcher, once.clone(), &mut ());
        prop_assert!(Arc::ptr_eq(&once, &twice),
            "Known property graphs should be idempotent\nonce:  {:?}\ntwice: {:?}",
            once, twice);
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(500))]

    /// Symbolic simplification may grow a random Int32 arithmetic tree by at
    /// most two operations, as measured by `uop_op_count`.
    #[test]
    fn cost_monotonic_op_count(graph in arb_arithmetic_tree_up_to(DType::Int32, 4)) {
        let ops_before = uop_op_count(&graph);
        let rules = symbolic_simple();
        let simplified = graph_rewrite(&rules, graph, &mut ());
        let ops_after = uop_op_count(&simplified);

        // Small slack: some rules trade one op for a couple of cheaper ones.
        let within_budget = ops_after <= ops_before + 2;
        prop_assert!(
            within_budget,
            "Optimization should not significantly increase op count: {} -> {}",
            ops_before,
            ops_after
        );
    }

    /// Symbolic simplification may deepen a random Int32 arithmetic tree by at
    /// most one level, as measured by `uop_depth`.
    #[test]
    fn cost_depth_bounded(graph in arb_arithmetic_tree_up_to(DType::Int32, 4)) {
        let depth_before = uop_depth(&graph);
        let rules = symbolic_simple();
        let simplified = graph_rewrite(&rules, graph, &mut ());
        let depth_after = uop_depth(&simplified);

        let within_budget = depth_after <= depth_before + 1;
        prop_assert!(
            within_budget,
            "Optimization should not significantly increase depth: {} -> {}",
            depth_before,
            depth_after
        );
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(500))]

    /// The root dtype of a random Int32 tree must be unchanged by rewriting.
    #[test]
    fn preserves_dtype(graph in arb_arithmetic_tree_up_to(DType::Int32, 4)) {
        let dtype_before = graph.dtype().clone();
        let rules = symbolic_simple();
        let rewritten = graph_rewrite(&rules, graph, &mut ());
        let dtype_after = rewritten.dtype().clone();
        prop_assert_eq!(dtype_before, dtype_after,
            "Optimization must preserve dtype");
    }

    /// Every constant reachable through unary/binary/ternary operands of the
    /// rewritten tree must hold a value whose dtype family matches its node.
    #[test]
    fn constants_properly_typed(graph in arb_arithmetic_tree_up_to(DType::Int32, 4)) {
        let rules = symbolic_simple();
        let rewritten = graph_rewrite(&rules, graph, &mut ());
        verify_constant_dtypes(&rewritten)?;
    }

    /// Rewriting must not introduce cycles. This test asserts nothing itself:
    /// it only requires that `toposort` on the result completes without
    /// panicking. NOTE(review): confirm that `toposort` actually detects cycles.
    #[test]
    fn no_cycles_created(graph in arb_arithmetic_tree_up_to(DType::Int32, 4)) {
        let rules = symbolic_simple();
        let rewritten = graph_rewrite(&rules, graph, &mut ());
        let _order = rewritten.toposort();
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(300))]

    /// Simplifying two operands separately and then simplifying their
    /// combination should be nearly as effective as simplifying the combined
    /// expression in one go. The compositional route may cost at most one
    /// extra op.
    #[test]
    #[ignore = "Distribution patterns conflict with compositional optimization"]
    fn compositional_subexpr_optimization(
        a in arb_arithmetic_tree_up_to(DType::Int32, 2),
        b in arb_arithmetic_tree_up_to(DType::Int32, 2),
        op in arb_arithmetic_binary_op(),
    ) {
        use morok_ir::Op;

        let rules = symbolic_simple();

        // Route 1: rewrite each operand on its own, combine, then rewrite again.
        let lhs_simplified = graph_rewrite(&rules, a.clone(), &mut ());
        let rhs_simplified = graph_rewrite(&rules, b.clone(), &mut ());
        let recombined = UOp::new(Op::Binary(op, lhs_simplified, rhs_simplified), DType::Int32);
        let compositional = graph_rewrite(&rules, recombined, &mut ());
        let compositional_ops = uop_op_count(&compositional);

        // Route 2: combine the raw operands and rewrite the whole expression once.
        let combined = UOp::new(Op::Binary(op, a, b), DType::Int32);
        let direct = graph_rewrite(&rules, combined, &mut ());
        let direct_ops = uop_op_count(&direct);

        prop_assert!(
            compositional_ops <= direct_ops + 1,
            "Compositional optimization should be nearly as good as direct: {} vs {}",
            compositional_ops,
            direct_ops
        );
    }
}
/// Recursively checks the constants reachable from `uop`.
///
/// Only the operands of `Unary`, `Binary` and `Ternary` ops are followed, left
/// to right. Constants under any other op kind are not inspected.
///
/// A constant passes in either of these cases:
/// - its node's dtype has no scalar component;
/// - the value's dtype equals that scalar dtype;
/// - both are integer, or both are non-integer.
///
/// # Errors
/// Returns the `TestCaseError` produced by `prop_assert!` for the first
/// constant whose dtype family mismatches.
fn verify_constant_dtypes(uop: &Arc<UOp>) -> Result<(), TestCaseError> {
    use morok_ir::Op;

    match uop.op() {
        Op::Unary(_, src) => verify_constant_dtypes(src),
        Op::Binary(_, lhs, rhs) => {
            for child in [lhs, rhs] {
                verify_constant_dtypes(child)?;
            }
            Ok(())
        }
        Op::Ternary(_, a, b, c) => {
            for child in [a, b, c] {
                verify_constant_dtypes(child)?;
            }
            Ok(())
        }
        Op::Const(cv) => {
            let value_dtype = cv.0.dtype();
            let holder_dtype = uop.dtype();

            // Nodes without a scalar dtype are not checked.
            let scalar = match holder_dtype.scalar() {
                Some(s) => s,
                None => return Ok(()),
            };
            let expected_dtype = DType::Scalar(scalar);
            if value_dtype == expected_dtype {
                return Ok(());
            }

            // An exact mismatch is tolerated as long as both sides are in the
            // same (integer vs. non-integer) family.
            let value_is_int = matches!(value_dtype.scalar(), Some(dt) if dt.is_int());
            let holder_is_int = matches!(holder_dtype.scalar(), Some(dt) if dt.is_int());
            prop_assert!(
                value_is_int == holder_is_int,
                "Constant dtype family mismatch: {:?} vs {:?}",
                value_dtype,
                expected_dtype
            );
            Ok(())
        }
        _ => Ok(()),
    }
}