morok-schedule 0.1.0-alpha.2

Optimization passes and pattern engine for the Morok ML compiler
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
//! Optimization operation implementations.
//!
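//! Implements: TC (tensor cores, delegated to `tc::apply`), UPCAST (SIMD),
//! LOCAL (shared memory), GROUP/GROUPTOP (two-stage reduction), UNROLL (loop unrolling),
//! SWAP (axis reordering), NOLOCALS (disable local mem), THREAD (CPU parallel dispatch),
//! PADTO (alignment padding).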

use std::collections::HashMap;
use std::sync::Arc;

use morok_ir::{AxisType, ConstValue, Op, UOp, UOpKey};
use smallvec::SmallVec;

use crate::optimizer::{Opt, OptOps, Scheduler, error::*, tc};

// ============================================================================
// DISPATCHER
// ============================================================================

/// Apply an optimization to the scheduler.
pub fn apply_opt(scheduler: &mut Scheduler, opt: &Opt, append_opt: bool) -> Result<(), OptError> {
    let real_axis = scheduler.real_axis(opt.op, opt.axis)?;
    let rng = if real_axis >= 0 { Some(scheduler.rngs()[real_axis as usize].clone()) } else { None };

    match opt.op {
        OptOps::TC => {
            let (tc_select, tc_opt, use_tensor_cores) = opt.arg.tc()?;
            // TODO: propagate TC axes for post-TC upcasts on non-AMX devices
            let _axes = tc::apply(scheduler, tc_select, tc_opt, use_tensor_cores)?;
        }
        OptOps::UPCAST => {
            apply_upcast(scheduler, rng.ok_or_else(|| MissingAxisParameterSnafu.build())?, opt.arg.int()?)?;
        }
        OptOps::LOCAL => {
            apply_local(scheduler, rng.ok_or_else(|| MissingAxisParameterSnafu.build())?, opt.arg.int()?)?;
        }
        OptOps::UNROLL => {
            apply_unroll(scheduler, opt.axis.ok_or_else(|| MissingAxisParameterSnafu.build())?, opt.arg.int()?)?;
        }
        OptOps::NOLOCALS => {
            apply_nolocals(scheduler)?;
        }
        OptOps::SWAP => {
            apply_swap(scheduler, opt.axis.ok_or_else(|| MissingAxisParameterSnafu.build())?, opt.arg.swap()?)?;
        }
        OptOps::GROUP => {
            apply_group(scheduler, rng.ok_or_else(|| MissingAxisParameterSnafu.build())?, opt.arg.int()?, false)?;
        }
        OptOps::GROUPTOP => {
            apply_group(scheduler, rng.ok_or_else(|| MissingAxisParameterSnafu.build())?, opt.arg.int()?, true)?;
        }
        OptOps::THREAD => {
            apply_thread(scheduler, rng.ok_or_else(|| MissingAxisParameterSnafu.build())?, opt.arg.int()?)?;
        }
        OptOps::PADTO => {
            apply_padto(scheduler, rng.ok_or_else(|| MissingAxisParameterSnafu.build())?, opt.arg.int()?)?;
        }
    }

    if append_opt {
        scheduler.applied_opts.push(opt.clone());
    }
    Ok(())
}

// ============================================================================
// UPCAST - Vectorization (SIMD)
// ============================================================================

/// Vectorize an output axis: split `amount` lanes off `rng` into a new UPCAST range.
///
/// Accepted axis kinds are Outer, Global, Local and Loop (axes whose lanes compute
/// different outputs). Reduce-style axes are rejected here; UNROLL is the opt meant
/// for those.
///
/// # Errors
///
/// - `rng` is not a `Range` operation.
/// - The axis kind is not one of the accepted ones.
/// - `amount` exceeds the renderer's `upcast_max`.
/// - `shift_to` fails.
fn apply_upcast(scheduler: &mut Scheduler, rng: Arc<UOp>, amount: usize) -> Result<(), OptError> {
    let Op::Range { axis_type, .. } = rng.op() else {
        return ExpectedRangeOperationSnafu.fail();
    };

    match *axis_type {
        AxisType::Outer | AxisType::Global | AxisType::Local | AxisType::Loop => {}
        _ => {
            return ValidationFailedSnafu { op: "UPCAST", reason: "can only upcast Outer/Global/Local/Loop axes" }
                .fail();
        }
    }

    let upcast_limit = scheduler.ren.upcast_max;
    if amount > upcast_limit {
        return DeviceLimitExceededSnafu { limit_type: "upcast", value: amount, max: upcast_limit }.fail();
    }

    scheduler.shift_to(rng, amount, AxisType::Upcast, false, None)?;
    Ok(())
}

// ============================================================================
// LOCAL - Shared memory (GPU workgroup)
// ============================================================================

/// Carve `amount` iterations out of `rng` into a new LOCAL range (GPU workgroup threads).
///
/// # Errors
///
/// - The renderer reports no local support (`has_local` is false).
/// - NOLOCALS was applied earlier (`dont_use_locals` is set).
/// - `rng` is not a `Range` operation, or its axis kind is neither Global nor Loop.
/// - `shift_to` fails.
fn apply_local(scheduler: &mut Scheduler, rng: Arc<UOp>, amount: usize) -> Result<(), OptError> {
    // Capability and configuration checks come first, before looking at the range itself.
    if !scheduler.ren.has_local {
        return UnsupportedFeatureSnafu { feature: "local memory" }.fail();
    }
    if scheduler.dont_use_locals {
        return ValidationFailedSnafu { op: "LOCAL", reason: "NOLOCALS was applied" }.fail();
    }

    let Op::Range { axis_type, .. } = rng.op() else {
        return ExpectedRangeOperationSnafu.fail();
    };

    match *axis_type {
        AxisType::Global | AxisType::Loop => {}
        _ => return ValidationFailedSnafu { op: "LOCAL", reason: "can only localize Global/Loop axes" }.fail(),
    }

    scheduler.shift_to(rng, amount, AxisType::Local, false, None)?;
    Ok(())
}

// ============================================================================
// GROUP/GROUPTOP - Two-stage reduction
// ============================================================================

/// Split a REDUCE axis into a smaller range plus a GROUP_REDUCE range (two-stage
/// reduction through shared memory).
///
/// `top` is forwarded to `shift_to` and distinguishes GROUPTOP from GROUP.
///
/// The shared-memory estimate is `amount * (product of constant-sized
/// Upcast/Warp/Local/GroupReduce ranges) * (byte size of the reduce dtype)`.
/// It is computed with saturating arithmetic so an oversized request is reported
/// as a limit violation instead of wrapping around and passing the check.
///
/// # Errors
///
/// - A TC opt was already applied.
/// - The renderer lacks local or shared memory support.
/// - `rng` is not a `Range` operation or is not a Reduce axis.
/// - No REDUCE in the scheduler uses `rng`.
/// - The estimate exceeds the renderer's `shared_max`.
/// - The REDUCE using `rng` has another REDUCE in its backward slice.
/// - `shift_to` fails.
fn apply_group(scheduler: &mut Scheduler, rng: Arc<UOp>, amount: usize, top: bool) -> Result<(), OptError> {
    if scheduler.applied_opts.iter().any(|opt| opt.op == OptOps::TC) {
        return ValidationFailedSnafu { op: "GROUP", reason: "no grouping with tensor cores" }.fail();
    }
    if !scheduler.ren.has_local {
        return UnsupportedFeatureSnafu { feature: "local memory" }.fail();
    }
    if !scheduler.ren.has_shared {
        return UnsupportedFeatureSnafu { feature: "shared memory" }.fail();
    }

    let axis_type = match rng.op() {
        Op::Range { axis_type, .. } => *axis_type,
        _ => return ExpectedRangeOperationSnafu.fail(),
    };

    if axis_type != AxisType::Reduce {
        return ValidationFailedSnafu { op: "GROUP", reason: "can only group REDUCE axes" }.fail();
    }

    // Product of the constant sizes of all Upcast/Warp/Local/GroupReduce ranges.
    // Ranges with a non-constant end are skipped. Saturating multiplication keeps
    // the product from wrapping in release builds.
    let upcast_local_sz = scheduler
        .rngs()
        .iter()
        .filter_map(|r| {
            if let Op::Range { axis_type, end, .. } = r.op()
                && matches!(axis_type, AxisType::Upcast | AxisType::Warp | AxisType::Local | AxisType::GroupReduce)
                && let Op::Const(cv) = end.op()
                && let ConstValue::Int(sz) = cv.0
            {
                Some(sz as usize)
            } else {
                None
            }
        })
        .fold(1usize, usize::saturating_mul);

    let reduce_uop = find_reduce_using_range(scheduler, &rng)?;
    let smem_sz = amount.saturating_mul(upcast_local_sz).saturating_mul(reduce_uop.dtype().bytes());

    if smem_sz > scheduler.ren.shared_max {
        return DeviceLimitExceededSnafu { limit_type: "shared memory", value: smem_sz, max: scheduler.ren.shared_max }
            .fail();
    }

    // Check not inside nested reduction: any REDUCE in the backward slice other
    // than the one being grouped (compared by pointer identity) is a rejection.
    let reduce_ptr = Arc::as_ptr(&reduce_uop);
    for node in reduce_uop.backward_slice() {
        if let Op::Reduce { .. } = node.op()
            && Arc::as_ptr(&node) != reduce_ptr
        {
            return ValidationFailedSnafu { op: "GROUP", reason: "cannot apply GROUP inside another reduction" }.fail();
        }
    }

    scheduler.shift_to(rng, amount, AxisType::GroupReduce, top, None)?;
    Ok(())
}

/// Locate the first REDUCE in `scheduler.reduceops()` whose range list contains `rng`.
///
/// Ranges are compared by pointer identity (`Arc::ptr_eq`), not structurally.
///
/// # Errors
///
/// Returns a GROUP validation error when no REDUCE references `rng`.
fn find_reduce_using_range(scheduler: &Scheduler, rng: &Arc<UOp>) -> Result<Arc<UOp>, OptError> {
    for candidate in scheduler.reduceops() {
        let Op::Reduce { ranges, .. } = candidate.op() else {
            continue;
        };
        if ranges.iter().any(|member| Arc::ptr_eq(member, rng)) {
            return Ok(candidate.clone());
        }
    }
    ValidationFailedSnafu { op: "GROUP", reason: "could not find REDUCE using this range" }.fail()
}

// ============================================================================
// UNROLL - Loop unrolling
// ============================================================================

/// Split reduction into smaller range + UNROLL for compile-time expansion.
/// When `amount == 0`, the entire axis is unrolled (full unroll), matching Tinygrad's convention.
fn apply_unroll(scheduler: &mut Scheduler, axis: usize, amount: usize) -> Result<(), OptError> {
    let unrollable = scheduler.unrollable_dims();
    let real_axis =
        *unrollable.get(axis).ok_or_else(|| AxisOutOfBoundsSnafu { axis, max: unrollable.len() }.build())?;
    let rng = scheduler.rngs()[real_axis].clone();

    // Resolve amount=0 to full axis size (full unroll, matching Tinygrad's convention)
    let amount = if amount == 0 {
        if let Op::Range { end, .. } = rng.op()
            && let Op::Const(cv) = end.op()
            && let morok_ir::ConstValue::Int(sz) = cv.0
        {
            sz as usize
        } else {
            return ValidationFailedSnafu { op: "UNROLL", reason: "full unroll requires constant axis size" }.fail();
        }
    } else {
        amount
    };

    const MAX_UNROLL: usize = 32;
    if amount > MAX_UNROLL {
        return DeviceLimitExceededSnafu { limit_type: "unroll", value: amount, max: MAX_UNROLL }.fail();
    }

    scheduler.shift_to(rng, amount, AxisType::Unroll, false, None)?;
    Ok(())
}

// ============================================================================
// SWAP - Axis reordering
// ============================================================================

/// Exchange the `axis_id`s of two GLOBAL ranges, leaving their ends and kinds intact.
///
/// `axis` and `other_axis` index directly into `scheduler.rngs()`. Both ranges are
/// rebuilt with each other's `axis_id` and substituted into the AST.
///
/// # Errors
///
/// - `axis == other_axis`.
/// - Either index is out of bounds (`axis` is checked first).
/// - Either entry is not a `Range` operation.
/// - Either range is not a Global axis.
fn apply_swap(scheduler: &mut Scheduler, axis: usize, other_axis: usize) -> Result<(), OptError> {
    if axis == other_axis {
        return ValidationFailedSnafu { op: "SWAP", reason: "cannot swap axis with itself" }.fail();
    }

    let rngs = scheduler.rngs();
    let count = rngs.len();
    for candidate in [axis, other_axis] {
        if candidate >= count {
            return AxisOutOfBoundsSnafu { axis: candidate, max: count }.fail();
        }
    }

    let first = &rngs[axis];
    let second = &rngs[other_axis];

    let Op::Range { end: first_end, axis_id: first_id, axis_type: first_type, .. } = first.op() else {
        return ExpectedRangeOperationSnafu.fail();
    };
    let Op::Range { end: second_end, axis_id: second_id, axis_type: second_type, .. } = second.op() else {
        return ExpectedRangeOperationSnafu.fail();
    };

    if !(*first_type == AxisType::Global && *second_type == AxisType::Global) {
        return ValidationFailedSnafu { op: "SWAP", reason: "both axes must be GLOBAL" }.fail();
    }

    // Each replacement keeps its own end and kind but takes the other range's id.
    let swapped_first = UOp::range_axis(first_end.clone(), *second_id, *first_type);
    let swapped_second = UOp::range_axis(second_end.clone(), *first_id, *second_type);

    #[allow(clippy::mutable_key_type)]
    let replacements =
        HashMap::from([(UOpKey(first.clone()), swapped_first), (UOpKey(second.clone()), swapped_second)]);

    let rewritten = scheduler.ast().substitute(&replacements);
    scheduler.set_ast(rewritten);

    Ok(())
}

// ============================================================================
// NOLOCALS - Disable local memory
// ============================================================================

/// Mark the scheduler as "no locals" by setting `dont_use_locals`.
///
/// In this file the flag is consulted by `apply_local`, which then refuses to run.
///
/// # Errors
///
/// Fails when any existing range is already a Local, Warp or GroupReduce axis.
fn apply_nolocals(scheduler: &mut Scheduler) -> Result<(), OptError> {
    let already_local = scheduler.rngs().iter().any(|rng| {
        matches!(rng.op(), Op::Range { axis_type: AxisType::Local | AxisType::Warp | AxisType::GroupReduce, .. })
    });

    if already_local {
        return ValidationFailedSnafu { op: "NOLOCALS", reason: "cannot apply after LOCAL/WARP/GROUP_REDUCE exist" }
            .fail();
    }

    scheduler.dont_use_locals = true;
    Ok(())
}

// ============================================================================
// THREAD - CPU parallel dispatch (implemented by `apply_thread`, after the PADTO helpers)
// ============================================================================

// ============================================================================
// PADTO - Tensor core alignment padding
// ============================================================================

/// Pad dimension to alignment for tensor core compatibility.
///
/// PADTO rounds the size of a constant-sized range up to the next multiple of
/// `alignment` and masks the padded tail with a validity condition.
/// Based on Tinygrad's PADTO (kernel.py).
///
/// # Constraints
///
/// - Only pad constant-sized axes with a positive integer size
/// - Cannot pad UPCAST/UNROLL/THREAD axes (already vectorized/expanded)
/// - `alignment` must be non-zero (a zero alignment is rejected instead of panicking
///   in the round-up division)
/// - For REDUCE axes: only with ADD reduction and no unsafe ops before reduce
/// - Don't grow the work to 4x or more (padding 1→4 or 1→5 is rejected)
///
/// # Algorithm
///
/// 1. Round up range size to alignment (no-op if it is already aligned)
/// 2. Create validity condition: idx < old_size
/// 3. Encode that condition into every INDEX op whose indices use this range
///
/// # Errors
///
/// Returns a validation error when any constraint above is violated, when the padded
/// size overflows, or when building the validity condition or the gated INDEX fails.
fn apply_padto(scheduler: &mut Scheduler, rng: Arc<UOp>, alignment: usize) -> Result<(), OptError> {
    use morok_ir::ReduceOp;

    let (end, axis_id, axis_type) = match rng.op() {
        Op::Range { end, axis_id, axis_type, .. } => (end.clone(), *axis_id, *axis_type),
        _ => return ExpectedRangeOperationSnafu.fail(),
    };

    // Constraint 1: only pad constant-sized axes
    let old_sz = match end.op() {
        Op::Const(cv) => match cv.0 {
            ConstValue::Int(v) if v > 0 => v as usize,
            _ => return ValidationFailedSnafu { op: "PADTO", reason: "range end must be positive integer" }.fail(),
        },
        _ => return ValidationFailedSnafu { op: "PADTO", reason: "can only pad constant-sized axes" }.fail(),
    };

    // Constraint 2: cannot pad UPCAST/UNROLL/THREAD axes
    if matches!(axis_type, AxisType::Upcast | AxisType::Unroll | AxisType::Thread) {
        return ValidationFailedSnafu { op: "PADTO", reason: "cannot pad vectorized/unrolled/thread axes" }.fail();
    }

    // `div_ceil` panics on a zero divisor, so reject it as an invalid argument first.
    if alignment == 0 {
        return ValidationFailedSnafu { op: "PADTO", reason: "alignment must be positive" }.fail();
    }

    // Calculate new padded size; the multiplication is checked so a huge alignment
    // cannot wrap around to a small size.
    let Some(new_sz) = old_sz.div_ceil(alignment).checked_mul(alignment) else {
        return ValidationFailedSnafu { op: "PADTO", reason: "padded size overflows" }.fail();
    };

    // No-op if already aligned
    if new_sz == old_sz {
        return Ok(());
    }

    // Constraint 4: don't add more than 4x work
    // Tinygrad: check(rng.vmax+1 > new_sz//4, "pad adds more than quadruple the work")
    // Strict inequality: exactly 4x is also rejected.
    if old_sz.saturating_mul(4) <= new_sz {
        return ValidationFailedSnafu { op: "PADTO", reason: "padding would add more than 4x work" }.fail();
    }

    // Constraint 3: for REDUCE axes, only with ADD and no unsafe ops
    if matches!(axis_type, AxisType::Reduce | AxisType::GroupReduce)
        && let Some(reduce_op) = scheduler.reduceop()
    {
        // Check reduce operation is ADD
        if let Op::Reduce { reduce_op: op, .. } = reduce_op.op()
            && *op != ReduceOp::Add
        {
            return ValidationFailedSnafu { op: "PADTO", reason: "can only pad ADD reductions (not MAX/MUL)" }.fail();
        }
        // Check for unsafe operations before reduce
        if has_unsafe_ops_before_reduce(&reduce_op) {
            return ValidationFailedSnafu {
                op: "PADTO",
                reason: "cannot pad with unsafe ops (EXP, LOG, DIV, comparisons) before reduce",
            }
            .fail();
        }
    }

    // Create new padded range. The conversion is checked because `new_sz` may exceed
    // the original (i64-derived) size.
    let new_end_value = i64::try_from(new_sz)
        .map_err(|_| ValidationFailedSnafu { op: "PADTO", reason: "padded size exceeds index range" }.build())?;
    let new_end = UOp::index_const(new_end_value);
    let new_rng = UOp::range_axis(new_end, axis_id, axis_type);

    // Create validity condition: new_rng < old_size
    // (`old_sz` originated from a positive i64, so casting it back is lossless.)
    let old_sz_const = UOp::index_const(old_sz as i64);
    let valid = new_rng
        .try_cmplt(&old_sz_const)
        .map_err(|_| ValidationFailedSnafu { op: "PADTO", reason: "failed to create validity condition" }.build())?;

    // Build substitution map
    #[allow(clippy::mutable_key_type)]
    let mut subst_map = HashMap::new();
    subst_map.insert(UOpKey(rng.clone()), new_rng.clone());

    // Update INDEX operations that use this range - add validity gate.
    // The replacement INDEX must use the new padded range in its indices
    // (not the original range), since substitute replaces the INDEX node
    // directly without recursing into its children.
    #[allow(clippy::mutable_key_type)]
    let range_subst: HashMap<UOpKey, Arc<UOp>> = [(UOpKey(rng.clone()), new_rng.clone())].into_iter().collect();

    for buf_op in scheduler.bufs() {
        if buf_uses_range(buf_op, &rng)
            && let Op::Index { buffer, indices, gate } = buf_op.op()
        {
            // Substitute old range → new range in index expressions
            let new_indices: SmallVec<[Arc<UOp>; 4]> = indices.iter().map(|idx| idx.substitute(&range_subst)).collect();

            // Encode validity gate using WHERE(cond, idx, Invalid) in the index source
            // instead of the INDEX gate field. This prevents the expander from vectorizing
            // the gate independently (Tinygrad's approach: symbolic.py invalid_gate encoding).
            let new_index = if let Some(first_idx) = new_indices.first() {
                // Extract any existing WHERE-encoded validity from the index
                let existing_valid = first_idx.get_valid();
                let real_idx = first_idx.get_idx();

                // Combine PADTO validity with existing validity and gate field
                let mut combined = valid.clone();
                if let Some(existing_gate) = gate {
                    combined = combined.try_and_op(existing_gate).map_err(|_| {
                        ValidationFailedSnafu { op: "PADTO", reason: "failed to combine gates" }.build()
                    })?;
                }
                // A constant-true validity adds nothing, so it is not AND-ed in.
                if !matches!(existing_valid.op(), Op::Const(cv) if cv.0 == ConstValue::Bool(true)) {
                    combined = combined.try_and_op(&existing_valid).map_err(|_| {
                        ValidationFailedSnafu { op: "PADTO", reason: "failed to combine validity" }.build()
                    })?;
                }

                // Encode combined validity in the index source
                let encoded_idx = real_idx.valid(combined);
                let mut remaining_indices = new_indices.clone();
                remaining_indices[0] = encoded_idx;

                UOp::index().buffer(buffer.clone()).indices(remaining_indices).call().map_err(|_| {
                    ValidationFailedSnafu { op: "PADTO", reason: "failed to create gated INDEX" }.build()
                })?
            } else {
                // No indices (bare buffer reference) — use gate field as fallback
                UOp::index().buffer(buffer.clone()).indices(new_indices).gate(valid.clone()).call().map_err(|_| {
                    ValidationFailedSnafu { op: "PADTO", reason: "failed to create gated INDEX" }.build()
                })?
            };
            subst_map.insert(UOpKey(buf_op.clone()), new_index);
        }
    }

    // Apply substitutions
    let new_ast = scheduler.ast().substitute(&subst_map);
    scheduler.set_ast(new_ast);

    Ok(())
}

/// Report whether `rng` occurs anywhere in the index expressions of `buf_op`.
///
/// Returns `false` when `buf_op` is not an `Index` operation. Otherwise every node in
/// the topological ordering of each index expression is compared with `rng` by
/// pointer identity (`Arc::ptr_eq`).
fn buf_uses_range(buf_op: &Arc<UOp>, rng: &Arc<UOp>) -> bool {
    let Op::Index { indices, .. } = buf_op.op() else {
        return false;
    };

    indices.iter().any(|idx| idx.toposort().into_iter().any(|node| Arc::ptr_eq(&node, rng)))
}

/// Report whether the graph feeding `reduce_op` contains an op that makes padding a
/// reduce axis unsafe.
///
/// Mirrors Tinygrad's UnsafePad group. The scan walks `reduce_op.toposort()` and
/// flags any of:
/// - unary RECIPROCAL, LOG2, EXP2
/// - binary IDIV, POW
/// - binary LT (a comparison over padded zeros would contribute bogus results)
fn has_unsafe_ops_before_reduce(reduce_op: &Arc<UOp>) -> bool {
    use morok_ir::types::{BinaryOp, UnaryOp};

    reduce_op.toposort().into_iter().any(|node| {
        matches!(
            node.op(),
            Op::Unary(UnaryOp::Reciprocal | UnaryOp::Log2 | UnaryOp::Exp2, _)
                | Op::Binary(BinaryOp::Idiv | BinaryOp::Pow | BinaryOp::Lt, _, _)
        )
    })
}

// ============================================================================
// THREAD - CPU parallel dispatch
// ============================================================================

/// Carve `amount` iterations out of `rng` into a new THREAD range for CPU-parallel
/// dispatch.
///
/// The original author describes THREAD as the CPU counterpart of GPU GLOBAL axes:
/// the work partition is baked into the index expressions at optimization time and
/// the runtime only supplies a thread id. The split is placed at the top
/// (`shift_to(.., top = true, ..)`), like Tinygrad's core_id.
///
/// The opt is idempotent: if the scheduler already has a Thread axis it returns
/// `Ok(())` without changing anything, so cached opts can be replayed after
/// threading was pre-applied.
///
/// # Safety
///
/// No `unsafe` code is involved. Per the original documentation, disjointness of the
/// per-thread output indices is expected to come from the `shift_to` transformation
/// (`output[thread_id * chunk_size + local_idx]`) — NOTE(review): not verifiable from
/// this function alone.
///
/// # Errors
///
/// - The renderer reports no thread support.
/// - `amount` exceeds the first entry of the renderer's `global_max`, when present.
/// - `rng` is not a `Range` operation, or its axis kind is not Outer/Global/Loop.
/// - `shift_to` fails.
fn apply_thread(scheduler: &mut Scheduler, rng: Arc<UOp>, amount: usize) -> Result<(), OptError> {
    if !scheduler.ren.has_threads {
        return UnsupportedFeatureSnafu { feature: "CPU threads" }.fail();
    }

    // Already threaded: treat the opt as a no-op rather than an error.
    if !scheduler.axes_of(&[AxisType::Thread]).is_empty() {
        tracing::debug!("THREAD opt skipped: scheduler already has Thread axis");
        return Ok(());
    }

    // The first `global_max` entry, when the renderer declares one, caps the thread count.
    let thread_limit = scheduler.ren.global_max.as_ref().and_then(|dims| dims.first().copied());
    if let Some(max_threads) = thread_limit
        && amount > max_threads
    {
        return DeviceLimitExceededSnafu { limit_type: "thread count", value: amount, max: max_threads }.fail();
    }

    let Op::Range { axis_type, .. } = rng.op() else {
        return ExpectedRangeOperationSnafu.fail();
    };

    // Only Outer, Global and Loop axes may be threaded.
    // Note: Reduce kernels keep Outer axes (convert_outer_to_loop skips them)
    match *axis_type {
        AxisType::Outer | AxisType::Global | AxisType::Loop => {}
        _ => return ValidationFailedSnafu { op: "THREAD", reason: "can only thread Outer/Global/Loop axes" }.fail(),
    }

    scheduler.shift_to(rng, amount, AxisType::Thread, true, None)?;
    Ok(())
}