morok-schedule 0.1.0-alpha.2

Optimization passes and pattern engine for the Morok ML compiler
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
//! Test helpers for devectorize.rs tests.
//!
//! Provides builders for creating test UOps and assertion helpers.
//! Mirrors Tinygrad's test patterns for memory access operations.

use std::sync::Arc;

use morok_dtype::{AddrSpace, DType, ScalarDType};
use morok_ir::types::ConstValue;
use morok_ir::{Op, UOp};
use smallvec::SmallVec;

use crate::devectorize::{
    bool_storage_patterns, correct_load_store_patterns, devectorize, load_store_folding_patterns,
    load_store_indexing_patterns, no_vectorized_alu, pm_render,
};
use crate::rewrite::graph_rewrite;

// =============================================================================
// Phase Application Helpers
// =============================================================================

/// Run the complete devectorize pipeline on `uop`.
///
/// The graph first goes through `devectorize`; its result is then rewritten
/// with the `pm_render` patterns, which turn CAT into VECTORIZE (needed for
/// codegen).
pub fn apply_devectorize(uop: &Arc<UOp>) -> Arc<UOp> {
    let lowered = devectorize(uop);
    // Rendering patterns are applied to the devectorized graph, not the input.
    let render_patterns = pm_render();
    graph_rewrite(render_patterns, lowered, &mut ())
}

/// Rewrite `uop` using only the `load_store_folding_patterns` set.
///
/// That set covers expand_index, GEP movement and PTRCAT distribution.
pub fn apply_load_store_folding(uop: &Arc<UOp>) -> Arc<UOp> {
    let patterns = load_store_folding_patterns();
    graph_rewrite(patterns, Arc::clone(uop), &mut ())
}

/// Rewrite `uop` using only the `correct_load_store_patterns` set.
///
/// That set covers split_load and split_store (the CAST(INDEX) patterns).
pub fn apply_correct_load_store(uop: &Arc<UOp>) -> Arc<UOp> {
    let patterns = correct_load_store_patterns();
    graph_rewrite(patterns, Arc::clone(uop), &mut ())
}

/// Rewrite `uop` using only the `bool_storage_patterns` set.
///
/// These patterns convert bool LOAD/STORE to uint8.
pub fn apply_bool_storage(uop: &Arc<UOp>) -> Arc<UOp> {
    let patterns = bool_storage_patterns();
    graph_rewrite(patterns, Arc::clone(uop), &mut ())
}

/// Rewrite `uop` with the `pm_render` patterns (the post-devectorize
/// rendering step).
///
/// Covers CAT→VECTORIZE, multi-index GEP→VECTORIZE and unwrapping of
/// single-element wrappers.
pub fn apply_pm_render(uop: &Arc<UOp>) -> Arc<UOp> {
    let patterns = pm_render();
    graph_rewrite(patterns, Arc::clone(uop), &mut ())
}

/// Rewrite `uop` using only the `no_vectorized_alu` (ALU devectorization)
/// patterns.
pub fn apply_no_vectorized_alu(uop: &Arc<UOp>) -> Arc<UOp> {
    let patterns = no_vectorized_alu();
    graph_rewrite(patterns, Arc::clone(uop), &mut ())
}

/// Apply pm_render patterns for VECTORIZE normalization.
///
/// This is a legacy name kept for compatibility: it forwards directly to
/// [`apply_pm_render`] and returns whatever that returns.
pub fn apply_vectorize_normalize(uop: &Arc<UOp>) -> Arc<UOp> {
    apply_pm_render(uop)
}

/// Rewrite `uop` using only the `load_store_indexing_patterns` set
/// (gate dropping).
pub fn apply_load_store_indexing(uop: &Arc<UOp>) -> Arc<UOp> {
    let patterns = load_store_indexing_patterns();
    graph_rewrite(patterns, Arc::clone(uop), &mut ())
}

/// Rewrite `uop` with the full `devectorize_patterns` set.
///
/// Intended for tests of the cast_after pattern.
/// NOTE(review): presumably cast_after is one of the patterns in that set — confirm.
pub fn apply_cast_after(uop: &Arc<UOp>) -> Arc<UOp> {
    use crate::devectorize::devectorize_patterns;

    let patterns = devectorize_patterns();
    graph_rewrite(patterns, Arc::clone(uop), &mut ())
}

// =============================================================================
// Buffer Builders
// =============================================================================

/// Create a global buffer with float32 element type.
///
/// Shorthand for [`create_buffer_typed`] with `ScalarDType::Float32`; the
/// result is a BUFFER UOp whose dtype is a pointer to `size` float32 elements.
pub fn create_buffer(size: usize) -> Arc<UOp> {
    create_buffer_typed(size, ScalarDType::Float32)
}

/// Build a CPU buffer of `size` elements of type `scalar` in the global
/// address space.
///
/// The buffer's dtype is a pointer (with known size `size`) to the scalar type.
pub fn create_buffer_typed(size: usize, scalar: ScalarDType) -> Arc<UOp> {
    let element = DType::Scalar(scalar);
    let ptr_dtype = element.ptr(Some(size), AddrSpace::Global);
    UOp::new_buffer(morok_dtype::DeviceSpec::Cpu, size, ptr_dtype)
}

/// Build a CPU buffer of `size` elements of type `scalar` in the local
/// (shared) address space.
pub fn create_buffer_local(size: usize, scalar: ScalarDType) -> Arc<UOp> {
    let element = DType::Scalar(scalar);
    let ptr_dtype = element.ptr(Some(size), AddrSpace::Local);
    UOp::new_buffer(morok_dtype::DeviceSpec::Cpu, size, ptr_dtype)
}

/// Create a bool buffer.
///
/// Shorthand for [`create_buffer_typed`] with `ScalarDType::Bool`.
pub fn create_bool_buffer(size: usize) -> Arc<UOp> {
    create_buffer_typed(size, ScalarDType::Bool)
}

// =============================================================================
// Index Builders
// =============================================================================

/// Build `INDEX(buffer, [idx])` where `idx` is a single scalar constant of
/// dtype `Index`.
///
/// # Panics
///
/// Panics if the INDEX builder returns an error.
pub fn create_index(buffer: Arc<UOp>, idx: i64) -> Arc<UOp> {
    let position = ConstValue::Int(idx);
    let indices = vec![UOp::const_(DType::Index, position)];
    UOp::index().buffer(buffer).indices(indices).call().unwrap()
}

/// Create a vector INDEX with iota pattern: [0, 1, 2, ..., count-1].
///
/// Creates INDEX(VECTORIZE([def, def, ...]), VECTORIZE([0, 1, ..., count-1]))
/// which matches Tinygrad's expand_index pattern (devectorizer.py:115).
pub fn create_vector_index_iota(buffer: Arc<UOp>, count: usize) -> Arc<UOp> {
    let indices: SmallVec<[Arc<UOp>; 4]> =
        (0..count).map(|i| UOp::const_(DType::Index, ConstValue::Int(i as i64))).collect();
    let vec_idx = UOp::vectorize(indices);
    let idx_dtype = buffer.dtype().base();

    // Wrap buffer in VECTORIZE to match Tinygrad's expand_index pattern:
    // INDEX(VECTORIZE(Defines.or_after()), vec_idx)
    let define = buffer_to_define(&buffer);
    let buf_vec = define.broadcast(count);

    UOp::new(Op::Index { buffer: buf_vec, indices: smallvec::smallvec![vec_idx], gate: None }, DType::Scalar(idx_dtype))
}

/// Produce a codegen PARAM standing in for `buffer` (test-only shortcut).
///
/// In real code this conversion happens during kernel splitting. Each call
/// takes a fresh id from a process-wide counter. The PARAM reuses the buffer's
/// dtype; its size is the pointer dtype's known size, or 1024 when the dtype
/// is not a sized pointer. The trailing `None` is the device argument.
fn buffer_to_define(buffer: &Arc<UOp>) -> Arc<UOp> {
    use std::sync::atomic::{AtomicUsize, Ordering};

    static NEXT_PARAM_ID: AtomicUsize = AtomicUsize::new(0);

    let id = NEXT_PARAM_ID.fetch_add(1, Ordering::Relaxed);
    let size = if let DType::Ptr { size: Some(len), .. } = buffer.dtype() { len } else { 1024 };
    UOp::param(id, size, buffer.dtype(), None)
}

/// Create a vector INDEX with offset: [offset, offset+1, offset+2, ..., offset+count-1].
pub fn create_vector_index_offset(buffer: Arc<UOp>, count: usize, offset: i64) -> Arc<UOp> {
    let indices: SmallVec<[Arc<UOp>; 4]> =
        (0..count).map(|i| UOp::const_(DType::Index, ConstValue::Int(offset + i as i64))).collect();
    let vec_idx = UOp::vectorize(indices);
    let idx_dtype = buffer.dtype().base();

    let define = buffer_to_define(&buffer);
    let buf_vec = define.broadcast(count);

    UOp::new(Op::Index { buffer: buf_vec, indices: smallvec::smallvec![vec_idx], gate: None }, DType::Scalar(idx_dtype))
}

/// Build a vector INDEX whose lanes are
/// `0*scale, 1*scale, ..., (count-1)*scale`, i.e. a strided access.
///
/// The buffer operand is a fresh PARAM broadcast to `count` lanes; the node's
/// dtype is the buffer's base scalar type.
pub fn create_vector_index_scaled(buffer: Arc<UOp>, count: usize, scale: i64) -> Arc<UOp> {
    let lanes: SmallVec<[Arc<UOp>; 4]> = (0..count)
        .map(|lane| lane as i64 * scale)
        .map(|position| UOp::const_(DType::Index, ConstValue::Int(position)))
        .collect();
    let vec_idx = UOp::vectorize(lanes);
    let result_dtype = DType::Scalar(buffer.dtype().base());

    let buf_vec = buffer_to_define(&buffer).broadcast(count);

    let op = Op::Index { buffer: buf_vec, indices: smallvec::smallvec![vec_idx], gate: None };
    UOp::new(op, result_dtype)
}

/// Build a vector INDEX whose lanes are exactly the constants in `values`.
///
/// The lane count is `values.len()`. The buffer operand is a fresh PARAM
/// broadcast to that many lanes; the node's dtype is the buffer's base scalar
/// type.
pub fn create_vector_index_values(buffer: Arc<UOp>, values: Vec<i64>) -> Arc<UOp> {
    let lanes: SmallVec<[Arc<UOp>; 4]> =
        values.iter().copied().map(|position| UOp::const_(DType::Index, ConstValue::Int(position))).collect();
    let vec_idx = UOp::vectorize(lanes);
    let result_dtype = DType::Scalar(buffer.dtype().base());
    let lane_count = values.len();

    let buf_vec = buffer_to_define(&buffer).broadcast(lane_count);

    let op = Op::Index { buffer: buf_vec, indices: smallvec::smallvec![vec_idx], gate: None };
    UOp::new(op, result_dtype)
}

/// Build a vector INDEX over lanes `0, 1, ..., count-1` that carries `gate`.
///
/// Identical in shape to the ungated iota index, except that the INDEX op's
/// `gate` field is `Some(gate)`.
pub fn create_vector_index_gated(buffer: Arc<UOp>, count: usize, gate: Arc<UOp>) -> Arc<UOp> {
    let lanes: SmallVec<[Arc<UOp>; 4]> = (0..count)
        .map(|lane| ConstValue::Int(lane as i64))
        .map(|value| UOp::const_(DType::Index, value))
        .collect();
    let vec_idx = UOp::vectorize(lanes);
    let result_dtype = DType::Scalar(buffer.dtype().base());

    let buf_vec = buffer_to_define(&buffer).broadcast(count);

    let op = Op::Index { buffer: buf_vec, indices: smallvec::smallvec![vec_idx], gate: Some(gate) };
    UOp::new(op, result_dtype)
}

/// Build `INDEX(buffer, [range * scale + offset])` over a Loop range.
///
/// The range runs to `bound` on axis `AxisId::Renumbered(axis_id)`. The
/// multiplication is omitted when `scale == 1` and the addition when
/// `offset == 0`, so the index expression contains no identity operations.
/// Used for testing root extraction and grouping.
///
/// # Panics
///
/// Panics if the INDEX builder returns an error.
pub fn create_index_with_range(buffer: Arc<UOp>, axis_id: usize, bound: i64, scale: i64, offset: i64) -> Arc<UOp> {
    use morok_ir::{AxisId, AxisType, BinaryOp};

    // Local shorthand for an Index-typed integer constant.
    let index_const = |value: i64| UOp::const_(DType::Index, ConstValue::Int(value));

    let range_op = Op::Range {
        end: index_const(bound),
        axis_id: AxisId::Renumbered(axis_id),
        axis_type: AxisType::Loop,
        deps: smallvec::SmallVec::new(),
    };
    let mut idx = UOp::new(range_op, DType::Index);

    // Wrap the expression step by step, skipping the no-op cases.
    if scale != 1 {
        idx = UOp::new(Op::Binary(BinaryOp::Mul, idx, index_const(scale)), DType::Index);
    }
    if offset != 0 {
        idx = UOp::new(Op::Binary(BinaryOp::Add, idx, index_const(offset)), DType::Index);
    }

    UOp::index().buffer(buffer).indices(vec![idx]).call().unwrap()
}

// =============================================================================
// Load/Store Builders
// =============================================================================

/// Create a LOAD operation.
///
/// Passes `buffer` and `index` to the `UOp::load()` builder and returns the
/// node it produces.
pub fn create_load(buffer: Arc<UOp>, index: Arc<UOp>) -> Arc<UOp> {
    UOp::load().buffer(buffer).index(index).call()
}

/// Create a STORE operation by calling `store(value)` on `index`.
///
/// Note: `index` must be an INDEX operation that references the buffer;
/// this helper itself performs no check of that.
pub fn create_store(index: Arc<UOp>, value: Arc<UOp>) -> Arc<UOp> {
    index.store(value)
}

/// Create a vector LOAD with iota index.
pub fn create_vector_load_iota(buffer: Arc<UOp>, count: usize) -> Arc<UOp> {
    let index = create_vector_index_iota(buffer.clone(), count);
    UOp::load().buffer(buffer).index(index).call()
}

/// Create a vector STORE with iota index.
pub fn create_vector_store_iota(buffer: Arc<UOp>, count: usize, value: Arc<UOp>) -> Arc<UOp> {
    let index = create_vector_index_iota(buffer, count);
    index.store(value)
}

// =============================================================================
// Value Builders
// =============================================================================

/// Build a scalar constant of dtype `Float32` holding `value`.
pub fn create_float_const(value: f64) -> Arc<UOp> {
    let payload = ConstValue::Float(value);
    UOp::const_(DType::Float32, payload)
}

/// Build a scalar constant of dtype `Int64` holding `value`.
pub fn create_int_const(value: i64) -> Arc<UOp> {
    let payload = ConstValue::Int(value);
    UOp::const_(DType::Int64, payload)
}

/// Build a scalar constant of dtype `Bool` holding `value`.
pub fn create_bool_const(value: bool) -> Arc<UOp> {
    let payload = ConstValue::Bool(value);
    UOp::const_(DType::Bool, payload)
}

/// Create a vector float constant with iota pattern.
pub fn create_vector_float_iota(count: usize) -> Arc<UOp> {
    let elements: SmallVec<[Arc<UOp>; 4]> =
        (0..count).map(|i| UOp::const_(DType::Float32, ConstValue::Float(i as f64))).collect();
    UOp::vectorize(elements)
}

/// Create a vector int constant with iota pattern.
pub fn create_vector_int_iota(count: usize) -> Arc<UOp> {
    let elements: SmallVec<[Arc<UOp>; 4]> =
        (0..count).map(|i| UOp::const_(DType::Int64, ConstValue::Int(i as i64))).collect();
    UOp::vectorize(elements)
}

/// Create a vector constant from explicit float values.
pub fn create_vector_float_values(values: Vec<f64>) -> Arc<UOp> {
    let elements: SmallVec<[Arc<UOp>; 4]> =
        values.into_iter().map(|v| UOp::const_(DType::Float32, ConstValue::Float(v))).collect();
    UOp::vectorize(elements)
}

/// Create a vector constant from explicit int values.
pub fn create_vector_int_values(values: Vec<i64>) -> Arc<UOp> {
    let elements: SmallVec<[Arc<UOp>; 4]> =
        values.into_iter().map(|v| UOp::const_(DType::Int64, ConstValue::Int(v))).collect();
    UOp::vectorize(elements)
}

/// Create a vector bool constant.
pub fn create_vector_bool(values: Vec<bool>) -> Arc<UOp> {
    let elements: SmallVec<[Arc<UOp>; 4]> =
        values.into_iter().map(|v| UOp::const_(DType::Bool, ConstValue::Bool(v))).collect();
    UOp::vectorize(elements)
}

// =============================================================================
// Assertion Helpers
// =============================================================================

/// Panics unless `uop` is a PTRCAT with exactly `expected_count` sources.
pub fn assert_is_ptrcat(uop: &Arc<UOp>, expected_count: usize) {
    let actual = match uop.op() {
        Op::PtrCat { sources } => sources.len(),
        other => panic!("Expected PTRCAT, got {:?}", other),
    };
    assert_eq!(
        actual, expected_count,
        "PTRCAT source count mismatch: expected {}, got {}",
        expected_count, actual
    );
}

/// Panics unless `uop` is a CAT with exactly `expected_count` sources.
pub fn assert_is_cat(uop: &Arc<UOp>, expected_count: usize) {
    let actual = match uop.op() {
        Op::Cat { sources } => sources.len(),
        other => panic!("Expected CAT, got {:?}", other),
    };
    assert_eq!(
        actual, expected_count,
        "CAT source count mismatch: expected {}, got {}",
        expected_count, actual
    );
}

/// Panics unless `uop` is a VECTORIZE with exactly `expected_count` elements.
pub fn assert_is_vectorize(uop: &Arc<UOp>, expected_count: usize) {
    let actual = match uop.op() {
        Op::Vectorize { elements } => elements.len(),
        other => panic!("Expected VECTORIZE, got {:?}", other),
    };
    assert_eq!(
        actual, expected_count,
        "VECTORIZE element count mismatch: expected {}, got {}",
        expected_count, actual
    );
}

/// Panics unless the vcount (vector width) of `uop`'s dtype equals `expected`.
pub fn assert_vcount(uop: &Arc<UOp>, expected: usize) {
    let actual = uop.dtype().vcount();
    assert_eq!(actual, expected, "vcount mismatch: expected {}, got {}", expected, actual);
}

/// Panics unless `uop`'s dtype is equal to `expected`.
pub fn assert_dtype(uop: &Arc<UOp>, expected: DType) {
    let actual = uop.dtype();
    assert_eq!(actual, expected, "dtype mismatch");
}

/// Panics unless the base scalar type of `uop`'s dtype equals `expected`.
pub fn assert_base_dtype(uop: &Arc<UOp>, expected: ScalarDType) {
    let actual = uop.dtype().base();
    assert_eq!(actual, expected, "base dtype mismatch");
}

/// Panics unless `uop`'s op is a LOAD.
pub fn assert_is_load(uop: &Arc<UOp>) {
    match uop.op() {
        Op::Load { .. } => {}
        other => panic!("Expected LOAD, got {:?}", other),
    }
}

/// Panics unless `uop`'s op is a STORE.
pub fn assert_is_store(uop: &Arc<UOp>) {
    match uop.op() {
        Op::Store { .. } => {}
        other => panic!("Expected STORE, got {:?}", other),
    }
}

/// Panics unless `uop` is a GEP whose index list equals `expected_indices`
/// (same values in the same order).
pub fn assert_is_gep(uop: &Arc<UOp>, expected_indices: &[usize]) {
    let actual = match uop.op() {
        Op::Gep { indices, .. } => indices.as_slice(),
        other => panic!("Expected GEP, got {:?}", other),
    };
    assert_eq!(
        actual, expected_indices,
        "GEP indices mismatch: expected {:?}, got {:?}",
        expected_indices, actual
    );
}

/// Panics unless `uop`'s op is a CAST.
pub fn assert_is_cast(uop: &Arc<UOp>) {
    match uop.op() {
        Op::Cast { .. } => {}
        other => panic!("Expected CAST, got {:?}", other),
    }
}

/// Panics unless `uop` is a GROUP with exactly `expected_count` sources.
pub fn assert_is_group(uop: &Arc<UOp>, expected_count: usize) {
    let actual = match uop.op() {
        Op::Group { sources } => sources.len(),
        other => panic!("Expected GROUP, got {:?}", other),
    };
    assert_eq!(
        actual, expected_count,
        "GROUP source count mismatch: expected {}, got {}",
        expected_count, actual
    );
}

/// Panics unless `uop`'s op is an INDEX.
pub fn assert_is_index(uop: &Arc<UOp>) {
    match uop.op() {
        Op::Index { .. } => {}
        other => panic!("Expected INDEX, got {:?}", other),
    }
}

// =============================================================================
// Op Counting Helpers
// =============================================================================

/// Count the nodes reachable from `uop` (including `uop` itself) for which
/// `predicate` returns `true`.
///
/// The walk follows `children()` recursively without tracking visited nodes,
/// so a node reachable along several paths is counted once per path.
pub fn count_ops<F>(uop: &Arc<UOp>, predicate: F) -> usize
where
    F: Fn(&Arc<UOp>) -> bool,
{
    let mut total: usize = 0;
    count_ops_recursive(uop, &predicate, &mut total);
    total
}

/// Pre-order walk behind `count_ops`: tests `uop` itself, then recurses into
/// each child, adding one to `count` for every match.
fn count_ops_recursive<F>(uop: &Arc<UOp>, predicate: &F, count: &mut usize)
where
    F: Fn(&Arc<UOp>) -> bool,
{
    // `true` contributes 1, `false` contributes 0.
    *count += usize::from(predicate(uop));
    for child in uop.op().children() {
        count_ops_recursive(child, predicate, count);
    }
}

/// Number of LOAD nodes found by `count_ops` when walking from `uop`.
pub fn count_loads(uop: &Arc<UOp>) -> usize {
    fn is_load(node: &Arc<UOp>) -> bool {
        matches!(node.op(), Op::Load { .. })
    }
    count_ops(uop, is_load)
}

/// Number of STORE nodes found by `count_ops` when walking from `uop`.
pub fn count_stores(uop: &Arc<UOp>) -> usize {
    fn is_store(node: &Arc<UOp>) -> bool {
        matches!(node.op(), Op::Store { .. })
    }
    count_ops(uop, is_store)
}

/// Number of INDEX nodes found by `count_ops` when walking from `uop`.
pub fn count_indices(uop: &Arc<UOp>) -> usize {
    fn is_index(node: &Arc<UOp>) -> bool {
        matches!(node.op(), Op::Index { .. })
    }
    count_ops(uop, is_index)
}

/// Number of PTRCAT nodes found by `count_ops` when walking from `uop`.
pub fn count_ptrcats(uop: &Arc<UOp>) -> usize {
    fn is_ptrcat(node: &Arc<UOp>) -> bool {
        matches!(node.op(), Op::PtrCat { .. })
    }
    count_ops(uop, is_ptrcat)
}

/// Number of CAT nodes found by `count_ops` when walking from `uop`.
pub fn count_cats(uop: &Arc<UOp>) -> usize {
    fn is_cat(node: &Arc<UOp>) -> bool {
        matches!(node.op(), Op::Cat { .. })
    }
    count_ops(uop, is_cat)
}

/// Number of VECTORIZE nodes found by `count_ops` when walking from `uop`.
pub fn count_vectorizes(uop: &Arc<UOp>) -> usize {
    fn is_vectorize(node: &Arc<UOp>) -> bool {
        matches!(node.op(), Op::Vectorize { .. })
    }
    count_ops(uop, is_vectorize)
}

/// Number of GEP nodes found by `count_ops` when walking from `uop`.
pub fn count_geps(uop: &Arc<UOp>) -> usize {
    fn is_gep(node: &Arc<UOp>) -> bool {
        matches!(node.op(), Op::Gep { .. })
    }
    count_ops(uop, is_gep)
}

/// Number of CAST nodes found by `count_ops` when walking from `uop`.
pub fn count_casts(uop: &Arc<UOp>) -> usize {
    fn is_cast(node: &Arc<UOp>) -> bool {
        matches!(node.op(), Op::Cast { .. })
    }
    count_ops(uop, is_cast)
}

// =============================================================================
// Unwrap Helpers
// =============================================================================

/// Return a clone of the sources of a PTRCAT.
///
/// # Panics
///
/// Panics if `uop` is not a PTRCAT.
pub fn unwrap_ptrcat(uop: &Arc<UOp>) -> SmallVec<[Arc<UOp>; 4]> {
    let op = uop.op();
    if let Op::PtrCat { sources } = op {
        sources.clone()
    } else {
        panic!("Expected PTRCAT, got {:?}", op)
    }
}

/// Return a clone of the sources of a CAT.
///
/// # Panics
///
/// Panics if `uop` is not a CAT.
pub fn unwrap_cat(uop: &Arc<UOp>) -> SmallVec<[Arc<UOp>; 4]> {
    let op = uop.op();
    if let Op::Cat { sources } = op {
        sources.clone()
    } else {
        panic!("Expected CAT, got {:?}", op)
    }
}

/// Return a clone of the elements of a VECTORIZE.
///
/// # Panics
///
/// Panics if `uop` is not a VECTORIZE.
pub fn unwrap_vectorize(uop: &Arc<UOp>) -> SmallVec<[Arc<UOp>; 4]> {
    let op = uop.op();
    if let Op::Vectorize { elements } = op {
        elements.clone()
    } else {
        panic!("Expected VECTORIZE, got {:?}", op)
    }
}

/// Return `(buffer, index)` cloned out of a LOAD.
///
/// # Panics
///
/// Panics if `uop` is not a LOAD.
pub fn unwrap_load(uop: &Arc<UOp>) -> (Arc<UOp>, Arc<UOp>) {
    let op = uop.op();
    if let Op::Load { buffer, index, .. } = op {
        (Arc::clone(buffer), Arc::clone(index))
    } else {
        panic!("Expected LOAD, got {:?}", op)
    }
}

/// Return `(index, value)` cloned out of a STORE.
///
/// The buffer is not returned separately; it is reachable through the
/// returned index, which is expected to be an INDEX op.
///
/// # Panics
///
/// Panics if `uop` is not a STORE.
pub fn unwrap_store(uop: &Arc<UOp>) -> (Arc<UOp>, Arc<UOp>) {
    let op = uop.op();
    if let Op::Store { index, value, .. } = op {
        (Arc::clone(index), Arc::clone(value))
    } else {
        panic!("Expected STORE, got {:?}", op)
    }
}

/// Return `(vector, indices)` cloned out of a GEP.
///
/// # Panics
///
/// Panics if `uop` is not a GEP.
pub fn unwrap_gep(uop: &Arc<UOp>) -> (Arc<UOp>, Vec<usize>) {
    let op = uop.op();
    if let Op::Gep { vector, indices } = op {
        (Arc::clone(vector), indices.clone())
    } else {
        panic!("Expected GEP, got {:?}", op)
    }
}

/// Return `(src, dtype)` cloned out of a CAST.
///
/// # Panics
///
/// Panics if `uop` is not a CAST.
pub fn unwrap_cast(uop: &Arc<UOp>) -> (Arc<UOp>, DType) {
    let op = uop.op();
    if let Op::Cast { src, dtype } = op {
        (Arc::clone(src), dtype.clone())
    } else {
        panic!("Expected CAST, got {:?}", op)
    }
}

/// Return `(buffer, indices, gate)` cloned out of an INDEX.
///
/// # Panics
///
/// Panics if `uop` is not an INDEX.
#[allow(clippy::type_complexity)]
pub fn unwrap_index(uop: &Arc<UOp>) -> (Arc<UOp>, SmallVec<[Arc<UOp>; 4]>, Option<Arc<UOp>>) {
    let op = uop.op();
    if let Op::Index { buffer, indices, gate } = op {
        (Arc::clone(buffer), indices.clone(), gate.clone())
    } else {
        panic!("Expected INDEX, got {:?}", op)
    }
}

/// Return the sources of a GROUP, copied into a `Vec`.
///
/// # Panics
///
/// Panics if `uop` is not a GROUP.
pub fn unwrap_group(uop: &Arc<UOp>) -> Vec<Arc<UOp>> {
    let op = uop.op();
    if let Op::Group { sources } = op {
        sources.to_vec()
    } else {
        panic!("Expected GROUP, got {:?}", op)
    }
}

// =============================================================================
// REDUCE/GEP Test Helpers
// =============================================================================

use morok_ir::{AxisId, AxisType, ReduceOp};

/// Rewrite `uop` with the `pm_reduce` patterns, using a fresh default
/// `ReduceContext`.
///
/// This is the REDUCE → accumulator transformation (reduce_to_acc).
pub fn apply_pm_reduce(uop: &Arc<UOp>) -> Arc<UOp> {
    use crate::devectorize::{ReduceContext, pm_reduce};

    let mut ctx = ReduceContext::default();
    let patterns = pm_reduce();
    graph_rewrite(&patterns, Arc::clone(uop), &mut ctx)
}

/// Apply GEP movement and related load/store folding patterns.
///
/// This is an alias for [`apply_load_store_folding`]; there is no separate
/// GEP-only pattern set here. The folding set includes:
/// - expand_index patterns
/// - gep_movement patterns (move_gep_after_load, move_gep_on_store)
/// - ptrcat_distribution patterns
///
/// For isolated GEP movement testing the other patterns are expected not to
/// fire on inputs they do not match.
pub fn apply_gep_movement(uop: &Arc<UOp>) -> Arc<UOp> {
    apply_load_store_folding(uop)
}

// =============================================================================
// REDUCE Builders
// =============================================================================

/// Build a REDUCE of `src` over `ranges` using `reduce_op`.
///
/// `ranges` is collected into whatever collection type `reduce` expects.
pub fn create_reduce(src: Arc<UOp>, ranges: Vec<Arc<UOp>>, reduce_op: ReduceOp) -> Arc<UOp> {
    let reduce_ranges = ranges.into_iter().collect();
    src.reduce(reduce_ranges, reduce_op)
}

/// Build a Range of `AxisType::Loop` ending at the Index constant `end`, on
/// axis `AxisId::Renumbered(axis_id)`.
pub fn create_range_loop(end: i64, axis_id: u32) -> Arc<UOp> {
    let axis = AxisId::Renumbered(axis_id as usize);
    let bound = UOp::const_(DType::Index, ConstValue::Int(end));
    UOp::range_axis(bound, axis, AxisType::Loop)
}

/// Build a Range of `AxisType::Reduce` ending at the Index constant `end`, on
/// axis `AxisId::Renumbered(axis_id)`.
pub fn create_range_reduce(end: i64, axis_id: u32) -> Arc<UOp> {
    let axis = AxisId::Renumbered(axis_id as usize);
    let bound = UOp::const_(DType::Index, ConstValue::Int(end));
    UOp::range_axis(bound, axis, AxisType::Reduce)
}

/// Build a Range of `AxisType::Thread` (a parallel axis) ending at the Index
/// constant `end`, on axis `AxisId::Renumbered(axis_id)`.
pub fn create_range_thread(end: i64, axis_id: u32) -> Arc<UOp> {
    let axis = AxisId::Renumbered(axis_id as usize);
    let bound = UOp::const_(DType::Index, ConstValue::Int(end));
    UOp::range_axis(bound, axis, AxisType::Thread)
}

/// Build a Range of `AxisType::Global` (a parallel axis) ending at the Index
/// constant `end`, on axis `AxisId::Renumbered(axis_id)`.
pub fn create_range_global(end: i64, axis_id: u32) -> Arc<UOp> {
    let axis = AxisId::Renumbered(axis_id as usize);
    let bound = UOp::const_(DType::Index, ConstValue::Int(end));
    UOp::range_axis(bound, axis, AxisType::Global)
}

/// Build a Range of `AxisType::Local` (a parallel axis) ending at the Index
/// constant `end`, on axis `AxisId::Renumbered(axis_id)`.
pub fn create_range_local(end: i64, axis_id: u32) -> Arc<UOp> {
    let axis = AxisId::Renumbered(axis_id as usize);
    let bound = UOp::const_(DType::Index, ConstValue::Int(end));
    UOp::range_axis(bound, axis, AxisType::Local)
}

// =============================================================================
// REDUCE Assertion Helpers
// =============================================================================

/// Panics unless `uop`'s op is a DEFINE_REG.
pub fn assert_is_define_reg(uop: &Arc<UOp>) {
    match uop.op() {
        Op::DefineReg { .. } => {}
        other => panic!("Expected DEFINE_REG, got {:?}", other),
    }
}

/// Panics unless `uop` is an AFTER with exactly `count` dependencies.
pub fn assert_has_after_deps(uop: &Arc<UOp>, count: usize) {
    let actual = match uop.op() {
        Op::After { deps, .. } => deps.len(),
        other => panic!("Expected AFTER, got {:?}", other),
    };
    assert_eq!(actual, count, "Expected {} AFTER deps, got {}", count, actual);
}

/// Panics unless `uop`'s op is an END.
pub fn assert_is_end(uop: &Arc<UOp>) {
    match uop.op() {
        Op::End { .. } => {}
        other => panic!("Expected END, got {:?}", other),
    }
}

/// Panics unless `uop`'s op is a REDUCE.
pub fn assert_is_reduce(uop: &Arc<UOp>) {
    match uop.op() {
        Op::Reduce { .. } => {}
        other => panic!("Expected REDUCE, got {:?}", other),
    }
}

/// Return `(src, ranges, reduce_op)` copied out of a REDUCE.
///
/// # Panics
///
/// Panics if `uop` is not a REDUCE.
pub fn unwrap_reduce(uop: &Arc<UOp>) -> (Arc<UOp>, SmallVec<[Arc<UOp>; 4]>, ReduceOp) {
    let op = uop.op();
    if let Op::Reduce { src, ranges, reduce_op } = op {
        (Arc::clone(src), ranges.clone(), *reduce_op)
    } else {
        panic!("Expected REDUCE, got {:?}", op)
    }
}

// =============================================================================
// GEP Builders
// =============================================================================

/// Create a GEP operation with explicit indices.
///
/// Forwards to `vector.gep(indices)`; `indices` is passed through unchanged.
pub fn create_gep(vector: Arc<UOp>, indices: Vec<usize>) -> Arc<UOp> {
    vector.gep(indices)
}

/// Create a LOAD with GEP on the index.
///
/// LOAD(buffer, GEP(index, indices))
pub fn create_load_with_gep_index(buffer: Arc<UOp>, index: Arc<UOp>, gep_indices: Vec<usize>) -> Arc<UOp> {
    let gep_index = index.gep(gep_indices);
    UOp::load().buffer(buffer).index(gep_index).call()
}

/// Build `STORE(GEP(index, gep_indices), value, ranges)`: the GEP is applied
/// to the index operand, and the store is created with the given `ranges`.
pub fn create_store_with_gep_index(
    index: Arc<UOp>,
    gep_indices: Vec<usize>,
    value: Arc<UOp>,
    ranges: SmallVec<[Arc<UOp>; 4]>,
) -> Arc<UOp> {
    index.gep(gep_indices).store_with_ranges(value, ranges)
}

/// Compute the inverse of a permutation given as GEP indices.
///
/// For `[2, 0, 1]` the result is `[1, 2, 0]`: entry `k` of the output is the
/// position at which `k` appears in `indices`, so applying the result to a
/// vector reordered by `indices` undoes the reorder.
///
/// Implemented as a stable argsort of `indices`, so the input is not checked
/// for being a permutation; repeated values keep their original relative
/// order.
pub fn compute_inverse_permutation(indices: &[usize]) -> Vec<usize> {
    let mut positions: Vec<usize> = (0..indices.len()).collect();
    // Stable sort of positions by the value stored there.
    positions.sort_by_key(|&pos| indices[pos]);
    positions
}

// =============================================================================
// Op Counting Helpers (REDUCE/GEP specific)
// =============================================================================

/// Number of REDUCE nodes found by `count_ops` when walking from `uop`.
pub fn count_reduces(uop: &Arc<UOp>) -> usize {
    fn is_reduce(node: &Arc<UOp>) -> bool {
        matches!(node.op(), Op::Reduce { .. })
    }
    count_ops(uop, is_reduce)
}

/// Number of DEFINE_REG nodes found by `count_ops` when walking from `uop`.
pub fn count_define_regs(uop: &Arc<UOp>) -> usize {
    fn is_define_reg(node: &Arc<UOp>) -> bool {
        matches!(node.op(), Op::DefineReg { .. })
    }
    count_ops(uop, is_define_reg)
}

/// Number of END nodes found by `count_ops` when walking from `uop`.
pub fn count_ends(uop: &Arc<UOp>) -> usize {
    fn is_end(node: &Arc<UOp>) -> bool {
        matches!(node.op(), Op::End { .. })
    }
    count_ops(uop, is_end)
}

/// Number of Range nodes (of any axis type) found by `count_ops` when
/// walking from `uop`.
pub fn count_ranges(uop: &Arc<UOp>) -> usize {
    fn is_range(node: &Arc<UOp>) -> bool {
        matches!(node.op(), Op::Range { .. })
    }
    count_ops(uop, is_range)
}

/// Number of Range nodes whose axis type equals `target_type`, as found by
/// `count_ops` when walking from `uop`.
pub fn count_ranges_by_type(uop: &Arc<UOp>, target_type: AxisType) -> usize {
    count_ops(uop, |node: &Arc<UOp>| match node.op() {
        Op::Range { axis_type, .. } => *axis_type == target_type,
        _ => false,
    })
}