morok-schedule 0.1.0-alpha.2

Optimization passes and pattern engine for the Morok ML compiler
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
//! GEP/CAT/VECTORIZE pattern tests.
//!
//! Tests for the gep_ptrcat_patterns which handle:
//! - CAT -> VECTORIZE conversion (multi-element sources)
//! - GEP(VECTORIZE) -> element extraction
//! - GEP(CAT) -> reorder
//! - GEP(PTRCAT) -> reorder pointers
//! - Identity patterns (single-source unwrap)
//! - WHERE devectorization
//!
//! Based on Tinygrad's symbolic.py and devectorizer.py patterns.

use std::sync::Arc;

use morok_dtype::DType;
use morok_ir::types::ConstValue;
use morok_ir::{Op, TernaryOp, UOp};

use super::helpers::*;

// =============================================================================
// CAT -> VECTORIZE Tests
// =============================================================================

/// Test: CAT with multi-element sources converts to VECTORIZE.
///
/// CAT([a<4>, b<4>]) -> VECTORIZE(a.gep(0), ..., a.gep(3), b.gep(0), ..., b.gep(3))
///
/// Whatever form the result takes, it must still carry 8 lanes. A CAT result
/// is only tolerated if its sources have already been flattened to 8 scalars.
#[test]
fn test_cat_vec4_to_vectorize() {
    let a = create_vector_float_iota(4);
    let b = create_vector_float_values(vec![10.0, 11.0, 12.0, 13.0]);

    let cat = UOp::cat().sources(vec![a, b]).call();
    assert_vcount(&cat, 8);

    let result = apply_pm_render(&cat);

    // The rewrite must not change the total lane count.
    assert_vcount(&result, 8);

    // Should become VECTORIZE with 8 elements (extracted via GEP)
    match result.op() {
        Op::Vectorize { elements } => {
            assert_eq!(elements.len(), 8, "Should have 8 elements");
            // Each element should be scalar
            for elem in elements.iter() {
                assert_eq!(elem.dtype().vcount(), 1, "Each element should be scalar");
            }
        }
        // Could remain CAT, but only if its sources are already scalar
        Op::Cat { sources } => {
            assert_eq!(sources.len(), 8, "Should have 8 sources");
            for src in sources.iter() {
                assert_eq!(src.dtype().vcount(), 1, "Each source should be scalar");
            }
        }
        other => panic!("Expected VECTORIZE or CAT, got {:?}", other),
    }
}

/// Test: a CAT whose sources are all scalars is not rewritten by the
/// multi-element CAT -> VECTORIZE pattern.
///
/// CAT([a, b, c, d]) with all scalars -> unchanged (handled by GEP(CAT) reorder)
#[test]
fn test_cat_scalar_unchanged() {
    let scalars: Vec<Arc<UOp>> = [1.0, 2.0, 3.0, 4.0]
        .into_iter()
        .map(|value| create_float_const(value))
        .collect();

    let rendered = apply_pm_render(&UOp::cat().sources(scalars).call());

    // A VECTORIZE produced by some other pattern is still acceptable; the
    // important part is that four lanes survive.
    match rendered.op() {
        Op::Vectorize { elements } => assert_eq!(elements.len(), 4),
        Op::Cat { sources } => {
            assert_eq!(sources.len(), 4);
            sources.iter().for_each(|lane| assert_eq!(lane.dtype().vcount(), 1));
        }
        other => panic!("Expected CAT or VECTORIZE, got {:?}", other),
    }
}

/// Test: a CAT wrapping exactly one source collapses to that source.
///
/// CAT([a]) -> a
#[test]
fn test_cat_single_source_unwrap() {
    let inner = create_vector_float_iota(4);
    let wrapped = UOp::cat().sources(vec![Arc::clone(&inner)]).call();

    let unwrapped = apply_pm_render(&wrapped);

    // The very same node must come back, not a structural copy.
    assert!(Arc::ptr_eq(&unwrapped, &inner), "Single-source CAT should unwrap");
}

// =============================================================================
// GEP(VECTORIZE) Tests
// =============================================================================

/// Test: selecting a single lane of a VECTORIZE yields that lane's element.
///
/// GEP(VECTORIZE([e0, e1, e2]), [1]) -> e1
#[test]
fn test_gep_vectorize_single() {
    let lanes = [0.0, 1.0, 2.0].into_iter().map(|value| create_float_const(value)).collect();
    let picked = UOp::vectorize(lanes).gep(vec![1]);

    let rendered = apply_pm_render(&picked);

    assert_eq!(rendered.dtype().vcount(), 1, "Should be scalar");
    match rendered.op() {
        // Fully simplified: the constant stored in lane 1.
        Op::Const(v) => assert_eq!(v.0, ConstValue::Float(1.0), "Should extract value 1.0"),
        // Not simplified: the GEP must still point at lane 1 only.
        Op::Gep { indices, .. } => {
            assert_eq!(indices.len(), 1, "Should have single index");
            assert_eq!(indices[0], 1, "Index should be 1");
        }
        other => panic!("Expected Const or GEP, got {:?}", other),
    }
}

/// Test: GEP on VECTORIZE extracts multiple elements.
///
/// GEP(VECTORIZE([e0, e1, e2, e3]), [0, 2]) -> VECTORIZE([e0, e2])
///
/// The lane count alone does not catch an index mix-up, so each extracted
/// lane is also checked. It must be the expected constant (0.0, then 2.0), or,
/// if not yet simplified, a single-index GEP selecting the matching source lane.
#[test]
fn test_gep_vectorize_multi() {
    let elements: smallvec::SmallVec<[Arc<UOp>; 4]> =
        (0..4).map(|i| create_float_const(i as f64)).collect();

    let vec = UOp::vectorize(elements);
    let gep = vec.gep(vec![0, 2]);

    let result = apply_pm_render(&gep);

    // Should extract elements 0 and 2
    assert_vcount(&result, 2);
    match result.op() {
        Op::Vectorize { elements } => {
            assert_eq!(elements.len(), 2);
            // (source lane index, constant stored in that lane)
            for (elem, (lane, value)) in elements.iter().zip([(0, 0.0), (2, 2.0)]) {
                match elem.op() {
                    Op::Const(v) => {
                        assert_eq!(v.0, ConstValue::Float(value), "Wrong element extracted");
                    }
                    Op::Gep { indices, .. } => {
                        assert_eq!(indices.len(), 1, "Each GEP should be single-index");
                        assert_eq!(indices[0], lane, "GEP selects the wrong lane");
                    }
                    other => panic!("Expected Const or GEP element, got {:?}", other),
                }
            }
        }
        other => panic!("Expected VECTORIZE, got {:?}", other),
    }
}

/// Test: GEP on broadcast VECTORIZE extracts single element.
///
/// GEP(VECTORIZE([x, x, x, x]), [i]) -> x
#[test]
fn test_gep_broadcast_extraction() {
    let x = create_float_const(42.0);
    let vec = x.broadcast(4);
    let gep = vec.gep(vec![2]);

    let result = apply_pm_render(&gep);

    // Should extract to just x
    assert_eq!(result.dtype().vcount(), 1, "Should be scalar");
    match result.op() {
        Op::Const(v) => {
            assert_eq!(v.0, ConstValue::Float(42.0), "Should extract value 42.0");
        }
        Op::Gep { indices, .. } => {
            // If not simplified, the GEP must still select lane 2 only
            assert_eq!(indices.len(), 1, "Should have single index");
            assert_eq!(indices[0], 2, "Index should be 2");
        }
        other => panic!("Expected Const or GEP, got {:?}", other),
    }
}

// =============================================================================
// GEP(CAT) Tests
// =============================================================================

/// Test: GEP on CAT reorders sources.
///
/// GEP(CAT([a, b, c]), [1, 2]) -> CAT([b, c])
///
/// Besides the lane count, the selected lanes must be `b` (2.0) followed by
/// `c` (3.0), whether the result is expressed as a CAT or a VECTORIZE.
#[test]
fn test_gep_cat_reorder() {
    /// Asserts that `lane` is a float constant equal to `expected`.
    fn assert_float_lane(lane: &Arc<UOp>, expected: f64) {
        match lane.op() {
            Op::Const(v) => {
                assert_eq!(v.0, ConstValue::Float(expected), "Wrong source selected")
            }
            other => panic!("Expected Const lane, got {:?}", other),
        }
    }

    let a = create_float_const(1.0);
    let b = create_float_const(2.0);
    let c = create_float_const(3.0);

    let cat = UOp::cat().sources(vec![a, b, c]).call();
    let gep = cat.gep(vec![1, 2]);

    let result = apply_pm_render(&gep);

    // Should produce CAT([b, c]) or VECTORIZE([b, c])
    assert_vcount(&result, 2);
    match result.op() {
        Op::Cat { sources } => {
            assert_eq!(sources.len(), 2);
            assert_float_lane(&sources[0], 2.0);
            assert_float_lane(&sources[1], 3.0);
        }
        Op::Vectorize { elements } => {
            assert_eq!(elements.len(), 2);
            assert_float_lane(&elements[0], 2.0);
            assert_float_lane(&elements[1], 3.0);
        }
        other => panic!("Expected CAT or VECTORIZE, got {:?}", other),
    }
}

/// Test: GEP on CAT extracts single element.
///
/// GEP(CAT([a, b, c]), [1]) -> b
///
/// The result must be `b`'s constant (2.0). If the extraction was not fully
/// simplified, it must at least be a GEP that still selects lane 1.
#[test]
fn test_gep_cat_single() {
    let a = create_float_const(1.0);
    let b = create_float_const(2.0);
    let c = create_float_const(3.0);

    let cat = UOp::cat().sources(vec![a, b, c]).call();
    let gep = cat.gep(vec![1]);

    let result = apply_pm_render(&gep);

    // Should extract b directly
    assert_eq!(result.dtype().vcount(), 1, "Should be scalar");
    match result.op() {
        Op::Const(v) => {
            assert_eq!(v.0, ConstValue::Float(2.0), "Should extract value 2.0 (b)");
        }
        Op::Gep { indices, .. } => {
            assert_eq!(indices.len(), 1, "Should have single index");
            assert_eq!(indices[0], 1, "Index should be 1");
        }
        other => panic!("Expected Const or GEP, got {:?}", other),
    }
}

/// Test: a PTRCAT holding a single pointer collapses to that pointer.
///
/// PTRCAT([p]) -> p
#[test]
fn test_ptrcat_single_unwrap() {
    let pointer = create_index(create_buffer(64), 0);
    let ptrcat = UOp::ptrcat().sources(vec![pointer]).call();

    let unwrapped = apply_pm_render(&ptrcat);

    // Once unwrapped, the node is the INDEX itself rather than a PTRCAT.
    assert_is_index(&unwrapped);
}

// =============================================================================
// Identity Reconstruction Test
// =============================================================================

/// Test: concatenating every lane of `x`, in order, reconstructs `x`.
///
/// CAT(GEP(x,[0]), GEP(x,[1]), ..., GEP(x,[n-1])) -> x
#[test]
fn test_cat_gep_identity() {
    let x = create_vector_float_iota(4);

    // One single-lane GEP per lane of x: GEP(x,[0]) .. GEP(x,[3]).
    let mut lanes: Vec<Arc<UOp>> = Vec::with_capacity(4);
    for lane in 0..4 {
        lanes.push(x.gep(vec![lane]));
    }

    let rebuilt = apply_pm_render(&UOp::cat().sources(lanes).call());

    // Ideally the identity reconstruction pattern fires and hands back x;
    // at minimum all four lanes must survive the rewrite.
    assert_vcount(&rebuilt, 4);
}

// =============================================================================
// WHERE Devectorization Tests
// =============================================================================

/// Test: a WHERE over 4-wide operands is split into per-lane scalar WHEREs,
/// or is left intact with all operands still 4-wide.
///
/// WHERE(<4 x i1>, <4 x T>, <4 x T>) -> VECTORIZE(WHERE(i1, T, T), ...)
#[test]
fn test_where_devectorize() {
    let mask = create_vector_bool(vec![true, false, true, false]);
    let on_true = create_vector_float_iota(4);
    let on_false = create_vector_float_values(vec![10.0, 11.0, 12.0, 13.0]);

    let select = UOp::new(
        Op::Ternary(TernaryOp::Where, mask, on_true, on_false),
        DType::Float32.vec(4),
    );

    let rewritten = apply_pm_render(&select);

    // Devectorized or not, the result still spans four lanes.
    assert_eq!(rewritten.dtype().vcount(), 4, "Result vcount should be 4");

    if let Op::Vectorize { elements } = rewritten.op() {
        assert_eq!(elements.len(), 4, "Should have 4 scalar WHEREs");
        elements.iter().for_each(|lane| {
            assert!(matches!(lane.op(), Op::Ternary(TernaryOp::Where, _, _, _)), "Each element should be WHERE");
            assert_eq!(lane.dtype().vcount(), 1, "Each WHERE should be scalar");
        });
    } else if let Op::Ternary(TernaryOp::Where, c, t, f) = rewritten.op() {
        // Left vectorized: every operand keeps its 4 lanes.
        assert_eq!(c.dtype().vcount(), 4, "Condition should be vec4");
        assert_eq!(t.dtype().vcount(), 4, "True value should be vec4");
        assert_eq!(f.dtype().vcount(), 4, "False value should be vec4");
    } else {
        panic!("Expected VECTORIZE or WHERE, got {:?}", rewritten.op());
    }
}

/// Test: a WHERE that is already scalar passes through the render pass as a
/// scalar WHERE.
#[test]
fn test_where_scalar_unchanged() {
    let select = UOp::new(
        Op::Ternary(
            TernaryOp::Where,
            create_bool_const(true),
            create_float_const(1.0),
            create_float_const(0.0),
        ),
        DType::Float32,
    );

    let rewritten = apply_pm_render(&select);

    let still_where = matches!(rewritten.op(), Op::Ternary(TernaryOp::Where, _, _, _));
    assert!(still_where, "Scalar WHERE should remain unchanged");
    assert_eq!(rewritten.dtype().vcount(), 1);
}

// =============================================================================
// GEP Through Cast Tests
// =============================================================================

/// Test: extracting one lane from a vector CAST yields a scalar, whichever
/// side of the CAST the GEP ends up on.
#[test]
fn test_gep_through_cast() {
    let as_int = create_vector_float_iota(4).cast(DType::Int64.vec(4));

    let lane = apply_pm_render(&as_int.gep(vec![1]));

    assert_eq!(lane.dtype().vcount(), 1);
}

// =============================================================================
// VECTORIZE Normalization Tests
// =============================================================================

/// Test: a GEP selecting several lanes is normalized into a VECTORIZE of
/// single-lane GEPs.
///
/// GEP(x, [0, 1, 2, 3]) -> VECTORIZE(GEP(x, [0]), GEP(x, [1]), GEP(x, [2]), GEP(x, [3]))
#[test]
fn test_multi_index_gep_normalizes() {
    let source = create_vector_float_iota(8);
    let normalized = apply_vectorize_normalize(&source.gep(vec![0, 1, 2, 3]));

    match normalized.op() {
        Op::Vectorize { elements } => {
            assert_eq!(elements.len(), 4);
            // Lanes that are no longer GEPs (e.g. already folded) are not inspected.
            let gep_widths = elements.iter().filter_map(|lane| match lane.op() {
                Op::Gep { indices, .. } => Some(indices.len()),
                _ => None,
            });
            for width in gep_widths {
                assert_eq!(width, 1, "Each GEP should be single-index");
            }
        }
        other => panic!("Expected VECTORIZE, got {:?}", other),
    }
}

/// Test: GEP on scalar with index 0 is identity.
///
/// GEP(scalar, [0]) -> scalar
///
/// A single-index GEP already has vcount 1, so checking only the vcount would
/// pass even if no rewrite happened. The GEP must actually be eliminated. The
/// result is either the original node or an equal float constant.
#[test]
fn test_gep_scalar_identity() {
    let scalar = create_float_const(42.0);
    let gep = scalar.gep(vec![0]);

    let result = apply_vectorize_normalize(&gep);

    let is_original = Arc::ptr_eq(&result, &scalar);
    let is_same_const = matches!(result.op(), Op::Const(v) if v.0 == ConstValue::Float(42.0));
    assert!(
        is_original || is_same_const,
        "GEP(scalar, [0]) should simplify to the scalar, got {:?}",
        result.op()
    );
}

/// Test: a one-lane VECTORIZE collapses to its only element.
///
/// VECTORIZE([x]) -> x
#[test]
fn test_single_element_vectorize_unwrap() {
    let x = create_float_const(42.0);
    let wrapped = UOp::vectorize(std::iter::once(Arc::clone(&x)).collect());

    let unwrapped = apply_vectorize_normalize(&wrapped);

    // Identity, not just structural equality: the original node comes back.
    assert!(Arc::ptr_eq(&unwrapped, &x), "Single-element VECTORIZE should unwrap");
}

// =============================================================================
// Edge Cases
// =============================================================================

/// Test: building a PTRCAT with no sources panics.
///
/// An empty PTRCAT is an invalid node, so construction itself is expected to
/// panic rather than produce a node that later passes would have to handle.
#[test]
#[should_panic]
fn test_empty_ptrcat_panics() {
    // PTRCAT requires at least one source
    let _ptrcat = UOp::ptrcat().sources(vec![]).call();
}

/// Test: building a CAT with no sources panics.
///
/// An empty CAT is an invalid node, so construction itself is expected to
/// panic rather than produce a node that later passes would have to handle.
#[test]
#[should_panic]
fn test_empty_cat_panics() {
    // CAT requires at least one source
    let _cat = UOp::cat().sources(vec![]).call();
}

/// Test: a GEP whose index lies outside the vector does not crash the
/// render pass.
///
/// Only the absence of a panic is checked; the shape of the rewritten node is
/// deliberately left unspecified.
#[test]
fn test_gep_out_of_bounds() {
    // Lane 10 does not exist in a 4-wide vector.
    let bad_gep = create_vector_float_iota(4).gep(vec![10]);

    let _ = apply_pm_render(&bad_gep);
}