use std::sync::Arc;
use morok_ir::{ContiguousHint, Op, UOp};
use crate::rangeify::kernel::LocalAddBufferContext;
use crate::rangeify::patterns::rangeify_codegen_patterns;
fn apply_patterns(uop: &Arc<UOp>) -> Option<Arc<UOp>> {
let matcher = rangeify_codegen_patterns();
let mut ctx = LocalAddBufferContext::new();
let result = crate::rewrite::graph_rewrite(&matcher, uop.clone(), &mut ctx);
if Arc::ptr_eq(&result, uop) { None } else { Some(result) }
}
/// Runs the rangeify codegen rewrite over `uop` and returns both the rewritten
/// graph and the `LocalAddBufferContext` the rewrite populated, so tests can
/// inspect context state such as `ctx.opts`.
fn apply_patterns_with_ctx(uop: &Arc<UOp>) -> (Arc<UOp>, LocalAddBufferContext) {
    let patterns = rangeify_codegen_patterns();
    let mut context = LocalAddBufferContext::new();
    let rewritten = crate::rewrite::graph_rewrite(&patterns, Arc::clone(uop), &mut context);
    (rewritten, context)
}
/// A bare `NOOP` produced by `UOp::noop()` is left unchanged by the codegen
/// patterns, so the helper reports no rewrite.
#[test]
fn test_remove_noop_void_returns_none() {
    let node = UOp::noop();
    assert!(apply_patterns(&node).is_none());
}
/// Checks that `UOp::noop()` really is an `Op::Noop` and that the codegen
/// patterns leave it unchanged.
///
/// NOTE(review): despite its name, this builds its node with the same
/// argument-less `UOp::noop()` call as `test_remove_noop_void_returns_none`,
/// so it does not appear to exercise a NOOP with a non-void dtype.
/// TODO: construct a typed NOOP here if the IR supports one.
#[test]
fn test_remove_noop_non_void() {
    let node = UOp::noop();
    assert!(matches!(node.op(), Op::Noop));

    let outcome = apply_patterns(&node);
    assert!(outcome.is_none());
}
/// A plain constant is not a `NOOP`, and the codegen patterns leave it as is.
#[test]
fn test_remove_noop_returns_none_for_non_noop() {
    let one = UOp::native_const(1.0f32);
    assert!(apply_patterns(&one).is_none());
}
/// Contrasts a `NOOP` with a non-`NOOP` constant. The op-kind check tells
/// them apart, yet the codegen patterns rewrite neither of them.
#[test]
fn test_remove_noop_pattern_matching() {
    let noop = UOp::noop();
    let zero = UOp::native_const(0.0f32);

    // NOOP side: recognised as `Op::Noop`, not rewritten.
    assert!(matches!(noop.op(), Op::Noop));
    assert!(apply_patterns(&noop).is_none());

    // Constant side: not a NOOP, not rewritten either.
    assert!(!matches!(zero.op(), Op::Noop));
    assert!(apply_patterns(&zero).is_none());
}
/// The codegen patterns strip a `CONTIGUOUS` wrapper. The rewrite must
/// produce a result, and that result must be the wrapped source node itself
/// (same `Arc`).
#[test]
fn test_get_contiguous_removes_marker() {
    let tensor = UOp::native_const(1.0f32);
    let contiguous = tensor.contiguous();
    // `expect` replaces the `is_some()` + `unwrap()` pair and explains failures.
    let unwrapped =
        apply_patterns(&contiguous).expect("CONTIGUOUS marker should be rewritten away");
    assert!(
        Arc::ptr_eq(&unwrapped, &tensor),
        "rewrite should return the original source node"
    );
}
/// A node that is not wrapped in `CONTIGUOUS` (here a plain constant) is not
/// rewritten by the codegen patterns.
#[test]
fn test_get_contiguous_returns_none_for_non_contiguous() {
    let plain = UOp::native_const(1.0f32);
    let outcome = apply_patterns(&plain);
    assert!(outcome.is_none());
}
/// Smoke test: constructing the codegen pattern matcher must not panic.
#[test]
fn test_codegen_patterns_creates_matcher() {
    let _patterns = rangeify_codegen_patterns();
}
/// A `CONTIGUOUS` created without hints must leave `ctx.opts` empty.
#[test]
fn test_contiguous_opts_empty() {
    let source = UOp::native_const(1.0f32);
    let marked = source.contiguous();
    let (_rewritten, context) = apply_patterns_with_ctx(&marked);
    assert!(context.opts.is_empty(), "ctx.opts should be empty when CONTIGUOUS has no hints");
}
/// A single `CONTIGUOUS` hint is forwarded into `ctx.opts` with its `op`,
/// `axis` and `arg` preserved.
#[test]
fn test_contiguous_opts_single_hint() {
    let source = UOp::native_const(1.0f32);
    let hints = smallvec::smallvec![ContiguousHint {
        op: "UPCAST".to_string(),
        axis: Some(0),
        arg: Some(4),
    }];
    let marked = source.contiguous_with_opts(hints);
    let (_rewritten, ctx) = apply_patterns_with_ctx(&marked);

    assert_eq!(ctx.opts.len(), 1, "ctx.opts should have 1 hint");
    let hint = &ctx.opts[0];
    assert_eq!(hint.op, "UPCAST");
    assert_eq!(hint.axis, Some(0));
    assert_eq!(hint.arg, Some(4));
}
/// Several `CONTIGUOUS` hints must all be forwarded into `ctx.opts` in the
/// order they were attached. Every field (`op`, `axis`, `arg`) must be intact.
#[test]
fn test_contiguous_opts_multiple_hints() {
    let tensor = UOp::native_const(1.0f32);
    let opts = smallvec::smallvec![
        ContiguousHint { op: "UPCAST".to_string(), axis: Some(0), arg: Some(4) },
        ContiguousHint { op: "UPCAST".to_string(), axis: Some(1), arg: Some(4) },
    ];
    let contiguous = tensor.contiguous_with_opts(opts);
    let (_result, ctx) = apply_patterns_with_ctx(&contiguous);

    assert_eq!(ctx.opts.len(), 2, "ctx.opts should have 2 hints");
    assert_eq!(ctx.opts[0].op, "UPCAST");
    assert_eq!(ctx.opts[0].axis, Some(0));
    // Previously unchecked: `arg` must survive forwarding, not just `op`/`axis`.
    assert_eq!(ctx.opts[0].arg, Some(4));
    assert_eq!(ctx.opts[1].op, "UPCAST");
    assert_eq!(ctx.opts[1].axis, Some(1));
    assert_eq!(ctx.opts[1].arg, Some(4));
}
/// Hints of different kinds (`UPCAST`, `UNROLL`) are forwarded in order. Each
/// keeps its own `axis` and `arg`, so a mix-up between entries would be caught.
#[test]
fn test_contiguous_opts_mixed_hint_types() {
    let tensor = UOp::native_const(1.0f32);
    let opts = smallvec::smallvec![
        ContiguousHint { op: "UPCAST".to_string(), axis: Some(0), arg: Some(4) },
        ContiguousHint { op: "UNROLL".to_string(), axis: Some(1), arg: Some(4) },
    ];
    let contiguous = tensor.contiguous_with_opts(opts);
    let (_result, ctx) = apply_patterns_with_ctx(&contiguous);

    assert_eq!(ctx.opts.len(), 2);
    assert_eq!(ctx.opts[0].op, "UPCAST");
    assert_eq!(ctx.opts[0].axis, Some(0));
    assert_eq!(ctx.opts[0].arg, Some(4));
    assert_eq!(ctx.opts[1].op, "UNROLL");
    assert_eq!(ctx.opts[1].axis, Some(1));
    assert_eq!(ctx.opts[1].arg, Some(4));
}
/// Four hints (two `UPCAST`, two `UNROLL`) are forwarded in attachment order
/// with `op`, `axis` and `arg` preserved for every entry.
#[test]
fn test_contiguous_opts_four_hints() {
    let tensor = UOp::native_const(1.0f32);
    let opts = smallvec::smallvec![
        ContiguousHint { op: "UPCAST".to_string(), axis: Some(0), arg: Some(4) },
        ContiguousHint { op: "UPCAST".to_string(), axis: Some(1), arg: Some(4) },
        ContiguousHint { op: "UNROLL".to_string(), axis: Some(0), arg: Some(4) },
        ContiguousHint { op: "UNROLL".to_string(), axis: Some(1), arg: Some(4) },
    ];
    let contiguous = tensor.contiguous_with_opts(opts);
    let (_result, ctx) = apply_patterns_with_ctx(&contiguous);

    assert_eq!(ctx.opts.len(), 4, "ctx.opts should have 4 hints");
    assert_eq!(ctx.opts[0].op, "UPCAST");
    assert_eq!(ctx.opts[0].axis, Some(0));
    assert_eq!(ctx.opts[0].arg, Some(4));
    assert_eq!(ctx.opts[1].op, "UPCAST");
    assert_eq!(ctx.opts[1].axis, Some(1));
    assert_eq!(ctx.opts[1].arg, Some(4));
    assert_eq!(ctx.opts[2].op, "UNROLL");
    assert_eq!(ctx.opts[2].axis, Some(0));
    assert_eq!(ctx.opts[2].arg, Some(4));
    assert_eq!(ctx.opts[3].op, "UNROLL");
    assert_eq!(ctx.opts[3].axis, Some(1));
    assert_eq!(ctx.opts[3].arg, Some(4));
}
/// Attaching hints does not change the rewritten graph. The `CONTIGUOUS`
/// wrapper is still stripped and the very same source `Arc` comes back.
#[test]
fn test_contiguous_opts_returns_source() {
    let source = UOp::native_const(42.0f32);
    let hints = smallvec::smallvec![ContiguousHint {
        op: "LOCAL".to_string(),
        axis: Some(2),
        arg: Some(8),
    }];
    let marked = source.contiguous_with_opts(hints);
    let (rewritten, _context) = apply_patterns_with_ctx(&marked);
    assert!(Arc::ptr_eq(&rewritten, &source));
}
/// A hint with neither an axis nor an argument (`NOLOCALS`) is still forwarded
/// into `ctx.opts`, with both optional fields left as `None`.
#[test]
fn test_contiguous_opts_hint_without_axis() {
    let source = UOp::native_const(1.0f32);
    let hints = smallvec::smallvec![ContiguousHint {
        op: "NOLOCALS".to_string(),
        axis: None,
        arg: None,
    }];
    let marked = source.contiguous_with_opts(hints);
    let (_rewritten, ctx) = apply_patterns_with_ctx(&marked);

    assert_eq!(ctx.opts.len(), 1);
    let hint = &ctx.opts[0];
    assert_eq!(hint.op, "NOLOCALS");
    assert_eq!(hint.axis, None);
    assert_eq!(hint.arg, None);
}