morok-schedule 0.1.0-alpha.2

Optimization passes and pattern engine for the Morok ML compiler
Documentation
//! Integration tests for codegen pattern pipeline.
//!
//! Tests verify that all patterns work together correctly in the complete
//! transformation pipeline from BUFFERIZE to KERNEL operations.
//!
//! Adapted from Tinygrad's test_rangeify.py and test_assign.py.

use std::sync::Arc;

use morok_device::DeviceSpec;
use morok_dtype::DType;
use morok_ir::{AxisId, AxisType, Op, UOp};

use crate::rangeify::kernel::{LocalAddBufferContext, split_store};
use crate::rangeify::patterns::rangeify_codegen_patterns;
use crate::rangeify::transforms::find_bufs;

/// Helper to call split_store with the new signature
fn call_split_store(x: &Arc<UOp>) -> Option<Arc<UOp>> {
    let mut uop_list = Vec::new();
    split_store(&mut uop_list, x)
}

/// Helper to apply rangeify_codegen patterns and return result
fn apply_codegen_patterns(uop: &Arc<UOp>) -> Option<Arc<UOp>> {
    let matcher = rangeify_codegen_patterns();
    let mut ctx = LocalAddBufferContext::new();
    match matcher.rewrite(uop, &mut ctx) {
        morok_ir::pattern::RewriteResult::Rewritten(result) => Some(result),
        _ => None,
    }
}

/// Test that remove_noop integrates correctly in pipeline context.
///
/// Verifies that NOOP operations in computation graphs get replaced
/// with zero constants during pattern application.
#[test]
fn test_remove_noop_in_pipeline() {
    // Create a simple NOOP (Void dtype returns None)
    let noop = UOp::noop();

    // Pattern should return None for Void dtype
    let result = apply_codegen_patterns(&noop);
    assert!(result.is_none());

    // In actual pipeline, NOOPs with real dtypes would be replaced
    // This is a structural test confirming the pattern is correctly implemented
}

/// Test that get_contiguous removes CONTIGUOUS markers in pipeline.
///
/// Based on Tinygrad's test_schedule.py:659-676 contiguous tests.
#[test]
fn test_get_contiguous_in_pipeline() {
    // Create a computation with CONTIGUOUS marker
    let value = UOp::native_const(42.0f32);
    let contiguous = value.contiguous();

    // Pattern should remove the CONTIGUOUS marker
    let result = apply_codegen_patterns(&contiguous);
    assert!(result.is_some());

    let unwrapped = result.unwrap();
    assert!(Arc::ptr_eq(&unwrapped, &value));
}

// AFTER(EXPAND) pattern was removed — Tinygrad doesn't have it.

/// Test cycle detection with valid buffer accesses.
///
/// Based on Tinygrad's test_assign.py:155-160 (diamond pattern without cycle).
#[test]
fn test_no_cycle_valid_access_pattern() {
    // Create a valid pattern: LOAD from input buffer, STORE to output buffer
    let in_buf = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let out_buf = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let const_idx = UOp::index_const(0);

    // LOAD from input using INDEX node
    let load_idx = UOp::index().buffer(in_buf.clone()).indices(vec![const_idx.clone()]).call().unwrap();
    let loaded = UOp::load().buffer(in_buf.clone()).index(load_idx).call();

    // Compute something
    let const_val = UOp::native_const(2.0f32);
    let computed = loaded.try_mul(&const_val).unwrap();

    // STORE to output using INDEX node
    let store_idx = UOp::index().buffer(out_buf.clone()).indices(vec![const_idx]).call().unwrap();
    let store = store_idx.store(computed);

    // Should not panic - valid access pattern
    #[allow(clippy::mutable_key_type)]
    let buf_accesses = find_bufs(&store);

    // Verify we tracked both buffers correctly
    assert_eq!(buf_accesses.len(), 2);
}

/// Test split_store integration with simple STORE operation.
///
/// Verifies the complete pipeline: filtering, transformation, validation, KERNEL creation.
#[test]
fn test_split_store_simple_kernel() {
    // Create a simple STORE operation
    let buffer = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let const_idx = UOp::index_const(0);
    let value = UOp::native_const(1.0f32);
    let store_idx = UOp::index().buffer(buffer).indices(vec![const_idx]).call().unwrap();
    let store = store_idx.store(value);

    // split_store may succeed if the STORE has no non-OUTER ranges in scope
    let result = call_split_store(&store);

    // Verify result is a KERNEL operation if successful
    if let Some(kernel) = result {
        assert!(matches!(kernel.op(), Op::Kernel { .. }));
    }
}

/// Test split_store filtering with non-OUTER ranges.
///
/// Verifies that kernels are only created at OUTER range boundaries.
///
/// Note: This test may pass if has_non_outer_ranges() is not implemented yet
/// or if the END operation doesn't carry range context in the expected way.
#[test]
fn test_split_store_with_loop_ranges() {
    // Create a STORE with LOOP ranges
    let buffer = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let const_idx = UOp::index_const(0);
    let value = UOp::native_const(1.0f32);
    let store_idx = UOp::index().buffer(buffer).indices(vec![const_idx]).call().unwrap();
    let store = store_idx.store(value);

    // Wrap in END with LOOP range
    let range_end = UOp::index_const(10);
    let loop_range = UOp::range_axis(range_end, AxisId::Renumbered(0), AxisType::Loop);
    let end = store.end(vec![loop_range].into());

    // Try to split - behavior depends on has_non_outer_ranges() implementation
    let result = call_split_store(&end);

    // If successful, verify it's a KERNEL
    if let Some(kernel) = result {
        assert!(matches!(kernel.op(), Op::Kernel { .. }));
    }
}

/// Test that patterns are applied in correct order during transformation.
///
/// This integration test verifies the pipeline applies patterns sequentially:
/// 1. to_param patterns
/// 2. rangeify_codegen patterns (remove_noop, get_contiguous, fix_after_broadcast)
/// 3. Cycle detection
/// 4. SINK/KERNEL creation
#[test]
fn test_pattern_application_order() {
    // Create a computation with patterns to apply
    let value = UOp::native_const(1.0f32);

    // Wrap in CONTIGUOUS (should be removed by get_contiguous)
    let contiguous = value.contiguous();

    // In real pipeline, this would go through split_store
    // For now, verify pattern works via matcher
    let result = apply_codegen_patterns(&contiguous);
    assert!(result.is_some());
}

/// Test integration with multiple buffer accesses.
///
/// Based on Tinygrad's test_schedule.py buffer fusion tests.
#[test]
fn test_multiple_buffer_integration() {
    // Create computation with multiple input buffers
    let buf1 = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let buf2 = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let out_buf = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let const_idx = UOp::index_const(0);

    // LOAD from both inputs using INDEX nodes
    let load1_idx = UOp::index().buffer(buf1.clone()).indices(vec![const_idx.clone()]).call().unwrap();
    let load1 = UOp::load().buffer(buf1.clone()).index(load1_idx).call();
    let load2_idx = UOp::index().buffer(buf2.clone()).indices(vec![const_idx.clone()]).call().unwrap();
    let load2 = UOp::load().buffer(buf2.clone()).index(load2_idx).call();

    // Compute sum
    let sum = load1.try_add(&load2).unwrap();

    // STORE to output using INDEX node
    let store_idx = UOp::index().buffer(out_buf.clone()).indices(vec![const_idx]).call().unwrap();
    let store = store_idx.store(sum);

    // Verify cycle detection works
    #[allow(clippy::mutable_key_type)]
    let buf_accesses = find_bufs(&store);
    assert_eq!(buf_accesses.len(), 3);
}

/// Test that END preserves STORE structure through pipeline.
///
/// Verifies split_store correctly handles END(STORE) wrapper.
#[test]
fn test_end_store_structure() {
    // Create STORE
    let buffer = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let const_idx = UOp::index_const(0);
    let value = UOp::native_const(1.0f32);
    let store_idx = UOp::index().buffer(buffer).indices(vec![const_idx]).call().unwrap();
    let store = store_idx.store(value);

    // Wrap in END (normal pipeline output)
    let range_end = UOp::index_const(10);
    let range = UOp::range_axis(range_end, AxisId::Renumbered(0), AxisType::Loop);
    let end = store.end(vec![range].into());

    // Verify END wraps STORE correctly before transformation
    if let Op::End { computation, .. } = end.op() {
        assert!(matches!(computation.op(), Op::Store { .. }));
    } else {
        panic!("Expected END operation");
    }

    // split_store should handle END(STORE) structure
    let result = call_split_store(&end);

    // If successful, verify it's a KERNEL
    if let Some(kernel) = result {
        assert!(matches!(kernel.op(), Op::Kernel { .. }));
    }
}