morok-schedule 0.1.0-alpha.2

Optimization passes and pattern engine for the Morok ML compiler
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
//! Beam search auto-tuning for kernel optimization.
//!
//! Implements a beam search algorithm that explores the optimization space
//! to find high-performance kernel configurations. This is slower than
//! heuristic-based optimization but can achieve ML-quality performance.
//!
//! # Algorithm
//!
//! 1. Start with base scheduler
//! 2. Generate all valid actions (OptOps applications)
//! 3. Compile and time each candidate
//! 4. Keep top K (beam width) by timing
//! 5. Repeat until no improvement or timeout
//!
//! # Caching
//!
//! Results are cached to disk using sled. The cache key is built from the
//! AST hash, the beam width, and the device name. Caching is skipped when
//! `BeamConfig::disable_cache` is set (documented as controlled by the
//! IGNORE_BEAM_CACHE environment variable).

use std::time::{Duration, Instant};

use once_cell::sync::Lazy;

use morok_ir::{AxisType, ConstValue, Op};

use super::Scheduler;
use super::config::BeamConfig;
use super::error::*;
use super::opts::apply_opt;
use super::types::Opt;

// ============================================================================
// ACTION SPACE
// ============================================================================

/// Thread counts to try for the THREAD optimization.
///
/// Returns, in ascending order, every `t` in `2..=max_threads` that evenly
/// divides at least one of the common tensor sizes 64, 128, 256, 512 or 1024.
/// All of those sizes are powers of two, so in practice the result is the
/// powers of two up to `min(max_threads, 1024)`; it is empty when
/// `max_threads < 2`.
fn thread_action_amounts(max_threads: usize) -> Vec<usize> {
    const COMMON_SIZES: [usize; 5] = [64, 128, 256, 512, 1024];

    let mut amounts = Vec::new();
    for threads in 2..=max_threads {
        // The range is walked in ascending order without repeats, so pushing
        // matches directly already yields a sorted, duplicate-free list.
        if COMMON_SIZES.iter().any(|size| size % threads == 0) {
            amounts.push(threads);
        }
    }
    amounts
}

/// Pre-computed action space for beam search.
///
/// Holds 162 machine-independent actions plus three THREAD actions (axes 0-2)
/// for every amount returned by `thread_action_amounts` for the host's
/// available parallelism. The list is built once, on first access, and the
/// order of entries is fixed: UPCAST, UNROLL, LOCAL, GROUPTOP, GROUP, TC,
/// SWAP, THREAD, NOLOCALS.
///
/// Based on tinygrad's beam search action generation.
pub static BEAM_ACTIONS: Lazy<Vec<Opt>> = Lazy::new(|| {
    let mut actions = Vec::with_capacity(600);

    // UPCAST: axes 0-7, amounts [0, 2, 3, 4, 5, 7]
    // amount=0 means "full size" - handled specially in apply
    actions.extend((0..8).flat_map(|axis| [0, 2, 3, 4, 5, 7].into_iter().map(move |amt| Opt::upcast(axis, amt))));

    // UNROLL: axes 0-4, amounts [0, 4, 7]
    actions.extend((0..5).flat_map(|axis| [0, 4, 7].into_iter().map(move |amt| Opt::unroll(axis, amt))));

    // LOCAL: axes 0-5, amounts [2, 3, 4, 8, 13, 16, 29]
    actions.extend((0..6).flat_map(|axis| [2, 3, 4, 8, 13, 16, 29].into_iter().map(move |amt| Opt::local(axis, amt))));

    // GROUPTOP: axes 0-2, amounts [13, 16, 28, 29, 32, 49, 64, 256]
    actions.extend(
        (0..3).flat_map(|axis| [13, 16, 28, 29, 32, 49, 64, 256].into_iter().map(move |amt| Opt::grouptop(axis, amt))),
    );

    // GROUP: axes 0-2, amounts [0, 4, 8, 16]
    actions.extend((0..3).flat_map(|axis| [0, 4, 8, 16].into_iter().map(move |amt| Opt::group(axis, amt))));

    // TC (tensor cores): one auto-select entry, then one entry per explicit
    // selector value 0-8 (the second argument of `Opt::tc`).
    actions.push(Opt::tc(None, -1, 0, 1));
    actions.extend((0..9).map(|tc_select| Opt::tc(None, tc_select, 2, 1)));

    // SWAP: every unordered pair of distinct axes drawn from 0-4
    actions.extend((0..5).flat_map(|first| ((first + 1)..5).map(move |second| Opt::swap(first, second))));

    // THREAD: axes 0-2, amounts chosen by `thread_action_amounts` for this
    // machine (8 is assumed when the parallelism cannot be queried).
    let max_threads = std::thread::available_parallelism().map_or(8, |p| p.get());
    let thread_amounts = thread_action_amounts(max_threads);
    actions.extend((0..3).flat_map(|axis| thread_amounts.iter().map(move |&amt| Opt::thread(axis, amt))));

    // NOLOCALS
    actions.push(Opt::nolocals());

    actions
});

// ============================================================================
// ACTION GENERATION & FILTERING
// ============================================================================

/// Expand one scheduler state into every valid successor state.
///
/// Each entry of `BEAM_ACTIONS` is tried on a fresh clone of `scheduler`.
/// A successor is kept only when:
/// 1. `apply_opt` accepts the action, and
/// 2. the resulting state passes `validate_limits` for `config`.
///
/// Successors are returned in `BEAM_ACTIONS` order; `scheduler` itself is
/// never modified.
fn generate_actions(scheduler: &Scheduler, config: &BeamConfig) -> Vec<Scheduler> {
    let mut successors = Vec::new();

    for action in BEAM_ACTIONS.iter() {
        // Work on a copy so a rejected action leaves no trace.
        let mut trial = scheduler.clone();

        if apply_opt(&mut trial, action, true).is_err() {
            continue;
        }
        if validate_limits(&trial, config) {
            successors.push(trial);
        }
    }

    successors
}

/// Check a scheduler state against the size limits in `config`.
///
/// Returns `true` only when all three hold:
/// - the product of UPCAST/UNROLL axis sizes is at most `config.max_upcast`,
/// - the product of LOCAL/WARP/GROUP_REDUCE axis sizes is at most `config.max_local`,
/// - the number of nodes in the toposorted AST is at most `config.max_uops`.
fn validate_limits(scheduler: &Scheduler, config: &BeamConfig) -> bool {
    let upcast_product = product_of_axes(scheduler, &[AxisType::Upcast, AxisType::Unroll]);
    let local_product = product_of_axes(scheduler, &[AxisType::Local, AxisType::Warp, AxisType::GroupReduce]);
    let node_count = scheduler.ast().toposort().len();

    let upcast_ok = upcast_product <= config.max_upcast;
    let local_ok = local_product <= config.max_local;
    let uops_ok = node_count <= config.max_uops;

    upcast_ok && local_ok && uops_ok
}

/// Multiply together the sizes of all ranges whose axis type is in `types`.
///
/// Only ranges whose end is an integer constant contribute; ranges with any
/// other kind of end are ignored. The multiplication saturates at
/// `usize::MAX` instead of overflowing, and an extent that does not fit in
/// `usize` (for example a negative one) is treated as `usize::MAX`, so an
/// oversized or malformed state fails the limit checks rather than wrapping
/// around to a small number.
///
/// The result is never less than 1: it is 1 when no range matches and also
/// when the product would be 0.
fn product_of_axes(scheduler: &Scheduler, types: &[AxisType]) -> usize {
    scheduler
        .rngs()
        .iter()
        .filter_map(|rng| {
            if let Op::Range { axis_type, end, .. } = rng.op()
                && types.contains(axis_type)
                && let Op::Const(cv) = end.op()
                && let ConstValue::Int(sz) = cv.0
            {
                // Checked conversion: a plain `as` cast would silently
                // reinterpret out-of-range values.
                Some(usize::try_from(sz).unwrap_or(usize::MAX))
            } else {
                None
            }
        })
        .fold(1usize, usize::saturating_mul)
        .max(1)
}

// ============================================================================
// BEAM SEARCH ALGORITHM
// ============================================================================

/// Beam search result containing optimized scheduler and timing.
///
/// Returned by `beam_search`, `beam_search_with_timeout` and
/// `beam_search_cached`.
pub struct BeamResult {
    /// Optimized scheduler state (the best entry of the final beam).
    pub scheduler: Scheduler,
    /// Best timing achieved. `Duration::MAX` is used as a placeholder when the
    /// timing callback returned `None` for the returned scheduler.
    pub timing: Duration,
    /// Number of search-loop iterations performed (0 when the result was
    /// replayed from the cache).
    pub iterations: usize,
    /// Total candidates evaluated, as counted by the search loop (0 when the
    /// result was replayed from the cache).
    pub candidates_evaluated: usize,
}

/// Run beam search optimization.
///
/// Starting from `scheduler`, each iteration expands every state in the beam
/// with `generate_actions`, times all successors with `compile_and_time`,
/// and keeps the fastest `config.beam_width` of them. The search stops when
/// `config.timeout` has elapsed (checked between iterations only), when no
/// successor can be generated or timed, or when the best new timing is not
/// strictly better than the best timing of the current beam.
///
/// # Arguments
///
/// * `scheduler` - Initial scheduler state
/// * `config` - Beam search configuration. A `beam_width` of 0 is treated as 1
///   so that an improvement, once found, is never thrown away.
/// * `compile_and_time` - Function to compile and time a scheduler state;
///   `None` means the candidate could not be compiled or timed and is skipped
///
/// # Returns
///
/// `BeamResult` containing the best scheduler found and performance metrics.
/// If the initial state cannot be timed its timing is recorded as
/// `Duration::MAX`. `candidates_evaluated` counts successors that were timed
/// successfully.
///
/// # Errors
///
/// The current implementation always returns `Ok`; the `Result` is kept for
/// interface stability.
///
/// # Example
///
/// ```ignore
/// let config = BeamConfig::default();
/// let compile_and_time = |s: &Scheduler| {
///     let ast = s.get_optimized_ast(None);
///     let kernel = compile_kernel(&ast)?;
///     let timing = benchmark_kernel(&kernel)?;
///     Some(timing)
/// };
///
/// let result = beam_search(scheduler, &config, compile_and_time)?;
/// println!("Best time: {:?}", result.timing);
/// ```
pub fn beam_search<F>(scheduler: Scheduler, config: &BeamConfig, compile_and_time: F) -> Result<BeamResult, OptError>
where
    F: Fn(&Scheduler) -> Option<Duration> + Sync,
{
    let start = Instant::now();
    let mut iterations = 0;
    let mut candidates_evaluated = 0;

    // With a width of zero the pruning step would empty the beam and the
    // search would fall back to the unoptimized input; always keep the best.
    let beam_width = config.beam_width.max(1);

    // Initialize beam with starting state
    let initial_timing = compile_and_time(&scheduler).unwrap_or(Duration::MAX);
    let mut beam: Vec<(Scheduler, Duration)> = vec![(scheduler.clone(), initial_timing)];

    while start.elapsed() < config.timeout {
        iterations += 1;

        // 1. EXPAND: Generate all valid next states from current beam (sequential)
        // Note: Scheduler is not Sync due to OnceCell caches, so expansion is sequential
        let candidates: Vec<Scheduler> = beam.iter().flat_map(|(s, _)| generate_actions(s, config)).collect();

        if candidates.is_empty() {
            break;
        }

        // 2. COMPILE & TIME: Evaluate performance
        // The compile_and_time function should handle parallelism internally if needed
        let mut timed: Vec<(Scheduler, Duration)> = candidates
            .into_iter()
            .filter_map(|s| {
                let timing = compile_and_time(&s)?;
                Some((s, timing))
            })
            .collect();

        candidates_evaluated += timed.len();

        if timed.is_empty() {
            break;
        }

        // 3. SORT: Sort by timing (best first). The sort is stable, so ties
        // keep their generation order.
        timed.sort_by_key(|(_, t)| *t);

        // 4. CHECK TERMINATION: No improvement
        let best_new = timed[0].1;
        let best_old = beam.first().map(|(_, t)| *t).unwrap_or(Duration::MAX);

        if best_new >= best_old {
            // No improvement - stop
            break;
        }

        // 5. PRUNE: Keep top K by timing
        beam = timed.into_iter().take(beam_width).collect();

        // Memory management: With weak references in the UOp cache (Tinygrad-aligned),
        // discarded candidates are automatically cleaned up when their Arcs are dropped.
        // No manual GC calls needed.
    }

    // Return best result. The beam always holds at least one entry here; the
    // fallback only exists to avoid a panic path and reports the input as-is.
    let (best_scheduler, best_timing) = beam.into_iter().next().unwrap_or((scheduler, initial_timing));

    Ok(BeamResult { scheduler: best_scheduler, timing: best_timing, iterations, candidates_evaluated })
}

/// Run beam search with a timeout that is also enforced while timing candidates.
///
/// Works like `beam_search`, with two additions:
///
/// * **Deadline inside an iteration** - before each candidate is timed the
///   elapsed time is compared with `config.timeout`; once the budget is spent
///   the remaining candidates of that iteration are skipped and the search
///   finishes with whatever has been measured so far.
/// * **Cutoff for slow candidates** - a candidate whose timing exceeds three
///   times the best timing known at the start of the iteration is dropped and
///   cannot enter the beam.
///
/// # Arguments
///
/// * `scheduler` - Initial scheduler state
/// * `config` - Beam search configuration. A `beam_width` of 0 is treated as 1.
/// * `compile_and_time` - Function to compile and time a scheduler state;
///   `None` means the candidate is skipped
///
/// # Returns
///
/// `BeamResult` with the best scheduler found. `candidates_evaluated` counts
/// every candidate that was timed successfully, including those later dropped
/// by the cutoff.
///
/// # Errors
///
/// The current implementation always returns `Ok`.
pub fn beam_search_with_timeout<F>(
    scheduler: Scheduler,
    config: &BeamConfig,
    compile_and_time: F,
) -> Result<BeamResult, OptError>
where
    F: Fn(&Scheduler) -> Option<Duration> + Sync,
{
    let start = Instant::now();
    let mut iterations = 0;
    let mut candidates_evaluated = 0;

    // A zero-width beam would discard every improvement; keep at least the best.
    let beam_width = config.beam_width.max(1);

    let initial_timing = compile_and_time(&scheduler).unwrap_or(Duration::MAX);
    let mut beam: Vec<(Scheduler, Duration)> = vec![(scheduler.clone(), initial_timing)];

    // Early termination threshold (3x the best time so far)
    let mut cutoff = initial_timing.saturating_mul(3);

    while start.elapsed() < config.timeout {
        iterations += 1;

        let candidates: Vec<Scheduler> = beam.iter().flat_map(|(s, _)| generate_actions(s, config)).collect();

        if candidates.is_empty() {
            break;
        }

        // Time candidates one by one so the deadline can interrupt a long iteration.
        let mut timed: Vec<(Scheduler, Duration)> = Vec::new();
        for candidate in candidates {
            if start.elapsed() >= config.timeout {
                break;
            }
            let Some(timing) = compile_and_time(&candidate) else {
                continue;
            };
            candidates_evaluated += 1;

            // Skip if clearly worse than cutoff
            if timing <= cutoff {
                timed.push((candidate, timing));
            }
        }

        if timed.is_empty() {
            break;
        }

        // Stable sort: ties keep their generation order.
        timed.sort_by_key(|(_, t)| *t);

        let best_new = timed[0].1;
        let best_old = beam.first().map(|(_, t)| *t).unwrap_or(Duration::MAX);

        if best_new >= best_old {
            break;
        }

        // Update cutoff based on new best
        cutoff = best_new.saturating_mul(3);

        beam = timed.into_iter().take(beam_width).collect();

        // Memory management: With weak refs, discarded candidates are auto-cleaned.
    }

    // The beam always holds at least one entry here; the fallback only avoids a panic path.
    let (best_scheduler, best_timing) = beam.into_iter().next().unwrap_or((scheduler, initial_timing));

    Ok(BeamResult { scheduler: best_scheduler, timing: best_timing, iterations, candidates_evaluated })
}

// ============================================================================
// REPLAY
// ============================================================================

/// Apply a recorded list of optimizations, in order, to a scheduler.
///
/// Used to restore cached beam search results.
///
/// # Errors
///
/// Stops at, and returns, the first error reported by `apply_opt`; the
/// partially updated scheduler is dropped in that case.
pub fn replay_opts(scheduler: Scheduler, opts: &[Opt]) -> Result<Scheduler, OptError> {
    opts.iter().try_fold(scheduler, |mut state, opt| -> Result<Scheduler, OptError> {
        apply_opt(&mut state, opt, true)?;
        Ok(state)
    })
}

/// Get the applied optimizations from a scheduler.
///
/// Returns a borrowed view of `scheduler.applied_opts`. This is the same list
/// that `beam_search_cached` stores in the cache and later feeds to
/// `replay_opts`.
pub fn get_applied_opts(scheduler: &Scheduler) -> &[Opt] {
    &scheduler.applied_opts
}

// ============================================================================
// CACHING
// ============================================================================

/// Process-wide sled database backing the beam search cache.
///
/// Opened lazily on first use at `<user cache dir>/morok/beam_cache`. The
/// value is `None` when the user cache directory is unknown, the `morok`
/// directory cannot be created, or sled fails to open the database; all
/// cache helpers treat `None` as "caching unavailable".
static CACHE_DB: Lazy<Option<sled::Db>> = Lazy::new(|| {
    dirs::cache_dir()
        .map(|base| base.join("morok"))
        .filter(|dir| std::fs::create_dir_all(dir).is_ok())
        .and_then(|dir| sled::open(dir.join("beam_cache")).ok())
});

/// Cache key for beam search results.
///
/// Built with `CacheKey::from_scheduler` and turned into the raw database key
/// with `CacheKey::to_bytes`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct CacheKey {
    /// Hash of the AST structure.
    ast_hash: u64,
    /// Beam width used for search (taken from `BeamConfig::beam_width`).
    beam_width: usize,
    /// Device name (e.g., "cpu", "cuda"), copied from the scheduler's renderer.
    device: String,
}

impl CacheKey {
    /// Derive the cache key for `scheduler` under `config`.
    ///
    /// The AST is fed through `DefaultHasher`. The original author notes that
    /// the `Hash` impl for UOp is structural (it walks dtype and op across the
    /// DAG), so equal AST structures hash equally regardless of process-local
    /// ids.
    ///
    /// NOTE(review): the standard library leaves the `DefaultHasher`
    /// algorithm unspecified across Rust releases, so keys may change after a
    /// toolchain upgrade; old entries would then simply never be hit.
    fn from_scheduler(scheduler: &Scheduler, config: &BeamConfig) -> Self {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut state = DefaultHasher::new();
        scheduler.ast().hash(&mut state);

        Self { ast_hash: state.finish(), beam_width: config.beam_width, device: scheduler.ren.device.clone() }
    }

    /// Encode the key as the byte string used for database lookups.
    ///
    /// Layout: little-endian `ast_hash`, then little-endian `beam_width`
    /// (native `usize` width), then the UTF-8 bytes of `device`.
    fn to_bytes(&self) -> Vec<u8> {
        let hash_bytes = self.ast_hash.to_le_bytes();
        let width_bytes = self.beam_width.to_le_bytes();

        // `concat` sizes the buffer for all three parts in one allocation.
        [hash_bytes.as_slice(), width_bytes.as_slice(), self.device.as_bytes()].concat()
    }
}

/// Encode a list of opts with bincode for storage in the cache.
///
/// # Panics
///
/// Panics if bincode reports an error, which is treated as a bug.
fn serialize_opts(opts: &[Opt]) -> Vec<u8> {
    let encoded = bincode::serialize(opts);
    encoded.expect("Opt serialization should not fail")
}

/// Decode a bincode-encoded list of opts read back from the cache.
///
/// Returns `None` when the bytes cannot be decoded.
fn deserialize_opts(bytes: &[u8]) -> Option<Vec<Opt>> {
    match bincode::deserialize(bytes) {
        Ok(opts) => Some(opts),
        Err(_) => None,
    }
}

/// Look up the opts cached under `key`.
///
/// Returns `None` when the cache is unavailable, the lookup fails, no entry
/// exists, or the stored bytes cannot be decoded.
fn cache_get(key: &CacheKey) -> Option<Vec<Opt>> {
    let db = CACHE_DB.as_ref()?;

    match db.get(key.to_bytes()) {
        Ok(Some(raw)) => deserialize_opts(&raw),
        Ok(None) | Err(_) => None,
    }
}

/// Store `opts` under `key`, best effort.
///
/// Does nothing when the cache is unavailable. Insert and flush failures are
/// deliberately ignored: a missing cache entry only costs a later re-search.
fn cache_put(key: &CacheKey, opts: &[Opt]) {
    let Some(db) = CACHE_DB.as_ref() else {
        return;
    };

    if db.insert(key.to_bytes(), serialize_opts(opts)).is_err() {
        return;
    }

    // Flush to disk to ensure persistence across runs
    let _ = db.flush();
}

/// Drop the entry stored under `key`, best effort.
///
/// Does nothing when the cache is unavailable; removal and flush errors are
/// ignored.
fn cache_invalidate(key: &CacheKey) {
    let Some(db) = CACHE_DB.as_ref() else {
        return;
    };

    let _ = db.remove(key.to_bytes());
    let _ = db.flush();
}

/// Run beam search with disk caching.
///
/// Checks the cache before running beam search. If a cached result exists,
/// replays the optimizations instead of searching. If the replay fails (for
/// example a stale entry after code changes) the entry is removed and a fresh
/// search is run. Results of a fresh search are written back to the cache.
///
/// When `config.disable_cache` is set the cache is neither read nor written,
/// and the cache key (a hash over the whole AST) is not computed at all.
///
/// # Arguments
///
/// * `scheduler` - Initial scheduler state
/// * `config` - Beam search configuration (includes disable_cache flag)
/// * `compile_and_time` - Function to compile and time a scheduler state
///
/// # Returns
///
/// `BeamResult` containing the best scheduler found. On a cache hit
/// `iterations` and `candidates_evaluated` are 0 and `timing` is the timing
/// of the replayed schedule (`Duration::MAX` if it could not be timed).
///
/// # Errors
///
/// Propagates any error returned by `beam_search`.
pub fn beam_search_cached<F>(
    scheduler: Scheduler,
    config: &BeamConfig,
    compile_and_time: F,
) -> Result<BeamResult, OptError>
where
    F: Fn(&Scheduler) -> Option<Duration> + Sync,
{
    // Hashing walks the entire AST, so only build the key when the cache is in use.
    let key = (!config.disable_cache).then(|| CacheKey::from_scheduler(&scheduler, config));

    // Check cache (unless disabled)
    if let Some(key) = key.as_ref()
        && let Some(cached_opts) = cache_get(key)
    {
        // Replay cached optimizations. If replay fails (stale entry from code changes),
        // invalidate and fall through to fresh search.
        tracing::info!(opts_count = cached_opts.len(), "Beam cache HIT - replaying opts");
        match replay_opts(scheduler.clone(), &cached_opts) {
            Ok(replayed) => {
                let timing = compile_and_time(&replayed).unwrap_or(Duration::MAX);
                return Ok(BeamResult { scheduler: replayed, timing, iterations: 0, candidates_evaluated: 0 });
            }
            Err(e) => {
                tracing::warn!(?e, "Beam cache replay failed (stale entry?) - invalidating");
                cache_invalidate(key);
            }
        }
    }

    tracing::info!("Beam cache MISS - running search");
    // Run beam search
    let result = beam_search(scheduler, config, compile_and_time)?;

    // Cache result (unless disabled)
    if let Some(key) = key.as_ref() {
        cache_put(key, &result.scheduler.applied_opts);
    }

    Ok(result)
}

/// Remove every entry from the beam search cache.
///
/// Useful for testing or when invalidating cached results. Does nothing when
/// the cache is unavailable; an error from the database is ignored.
pub fn clear_cache() {
    let Some(db) = CACHE_DB.as_ref() else {
        return;
    };

    let _ = db.clear();
}

// Unit tests for the action space, limit validation, replay, opt
// (de)serialization and THREAD handling.
#[cfg(test)]
mod tests {
    use super::super::types::OptOps;
    use super::*;

    /// Pins the default values of `BeamConfig` that this module relies on.
    #[test]
    fn test_beam_config_default() {
        let config = BeamConfig::default();
        assert_eq!(config.beam_width, 4);
        assert_eq!(config.timeout, Duration::from_secs(60));
        assert_eq!(config.max_upcast, 256);
        assert_eq!(config.max_local, 1024);
    }

    /// Sanity-checks the size of the pre-computed action space.
    #[test]
    fn test_beam_actions_not_empty() {
        assert!(!BEAM_ACTIONS.is_empty());
        // Should have a reasonable number of actions
        // UPCAST: 8 axes * 6 amounts = 48
        // UNROLL: 5 axes * 3 amounts = 15
        // LOCAL: 6 axes * 7 amounts = 42
        // GROUPTOP: 3 axes * 8 amounts = 24
        // GROUP: 3 axes * 4 amounts = 12
        // TC: 1 + 9 = 10
        // SWAP: 10 pairs
        // NOLOCALS: 1
        // Total: 162 fixed actions, plus THREAD: 3 axes * (machine-dependent amounts)
        assert!(BEAM_ACTIONS.len() > 100, "Expected >100 actions, got {}", BEAM_ACTIONS.len());
        assert!(BEAM_ACTIONS.len() < 500, "Expected <500 actions, got {}", BEAM_ACTIONS.len());
    }

    /// Every non-THREAD opt family checked here must appear at least once.
    #[test]
    fn test_beam_actions_contains_expected_types() {
        let has_upcast = BEAM_ACTIONS.iter().any(|a| a.op == OptOps::UPCAST);
        let has_local = BEAM_ACTIONS.iter().any(|a| a.op == OptOps::LOCAL);
        let has_unroll = BEAM_ACTIONS.iter().any(|a| a.op == OptOps::UNROLL);
        let has_tc = BEAM_ACTIONS.iter().any(|a| a.op == OptOps::TC);
        let has_swap = BEAM_ACTIONS.iter().any(|a| a.op == OptOps::SWAP);
        let has_nolocals = BEAM_ACTIONS.iter().any(|a| a.op == OptOps::NOLOCALS);

        assert!(has_upcast);
        assert!(has_local);
        assert!(has_unroll);
        assert!(has_tc);
        assert!(has_swap);
        assert!(has_nolocals);
    }

    /// Smoke test: beam search on a trivial kernel with a constant-time scorer
    /// must return `Ok`.
    #[test]
    fn test_beam_search_with_mock_scoring() {
        use super::super::renderer::Renderer;
        use morok_ir::UOp;

        // Create a simple scheduler
        let val = UOp::native_const(1.0f32);
        let sink = UOp::sink(vec![val]);
        let renderer = Renderer::cpu();
        let scheduler = Scheduler::new(sink, renderer);

        let config = BeamConfig { beam_width: 2, timeout: Duration::from_millis(100), ..Default::default() };

        // Mock scoring: just return a constant time
        let mock_score = |_s: &Scheduler| Some(Duration::from_micros(100));

        let result = beam_search(scheduler, &config, mock_score);
        assert!(result.is_ok());

        let result = result.unwrap();
        // Weak invariant: candidates can only have been evaluated inside an iteration.
        assert!(result.iterations > 0 || result.candidates_evaluated == 0);
    }

    /// A trivial kernel passes the default limits; the strict configuration is
    /// only exercised for "does not panic".
    #[test]
    fn test_validate_limits() {
        use super::super::renderer::Renderer;
        use morok_ir::UOp;

        let val = UOp::native_const(1.0f32);
        let sink = UOp::sink(vec![val]);
        let renderer = Renderer::cpu();
        let scheduler = Scheduler::new(sink, renderer);

        let config = BeamConfig::default();

        // Simple scheduler should pass limits
        assert!(validate_limits(&scheduler, &config));

        // With very restrictive limits
        let strict_config = BeamConfig { max_upcast: 1, max_local: 1, max_uops: 1, ..Default::default() };

        // May or may not pass depending on UOp count
        let _result = validate_limits(&scheduler, &strict_config);
    }

    /// Replaying an empty opt list is a no-op that succeeds.
    #[test]
    fn test_replay_opts_empty() {
        use super::super::renderer::Renderer;
        use morok_ir::UOp;

        let val = UOp::native_const(1.0f32);
        let sink = UOp::sink(vec![val]);
        let renderer = Renderer::cpu();
        let scheduler = Scheduler::new(sink, renderer);

        // Empty replay should succeed
        let result = replay_opts(scheduler, &[]);
        assert!(result.is_ok());
    }

    /// Round-trips an empty opt list through the cache encoding.
    #[test]
    fn test_serialize_deserialize_opts_empty() {
        let opts: Vec<Opt> = vec![];
        let serialized = serialize_opts(&opts);
        let deserialized = deserialize_opts(&serialized);

        assert!(deserialized.is_some());
        assert!(deserialized.unwrap().is_empty());
    }

    /// Round-trips UPCAST opts; checks op and axis of each entry.
    #[test]
    fn test_serialize_deserialize_opts_upcast() {
        let opts = vec![Opt::upcast(0, 4), Opt::upcast(1, 8)];
        let serialized = serialize_opts(&opts);
        let deserialized = deserialize_opts(&serialized);

        assert!(deserialized.is_some());
        let result = deserialized.unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].op, OptOps::UPCAST);
        assert_eq!(result[0].axis, Some(0));
        assert_eq!(result[1].op, OptOps::UPCAST);
        assert_eq!(result[1].axis, Some(1));
    }

    /// Round-trips a TC opt; checks that all tensor-core arguments survive.
    #[test]
    fn test_serialize_deserialize_opts_tc() {
        use super::super::types::OptArg;

        let opts = vec![Opt::tc(None, -1, 2, 1)];
        let serialized = serialize_opts(&opts);
        let deserialized = deserialize_opts(&serialized);

        assert!(deserialized.is_some());
        let result = deserialized.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].op, OptOps::TC);
        assert_eq!(result[0].axis, None);
        if let OptArg::TensorCore { tc_select, opt_level, use_tc } = &result[0].arg {
            assert_eq!(*tc_select, -1);
            assert_eq!(*opt_level, 2);
            assert_eq!(*use_tc, 1);
        } else {
            panic!("Expected TensorCore arg");
        }
    }

    /// Round-trips a SWAP opt; checks both axes.
    #[test]
    fn test_serialize_deserialize_opts_swap() {
        use super::super::types::OptArg;

        let opts = vec![Opt::swap(0, 2)];
        let serialized = serialize_opts(&opts);
        let deserialized = deserialize_opts(&serialized);

        assert!(deserialized.is_some());
        let result = deserialized.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].op, OptOps::SWAP);
        assert_eq!(result[0].axis, Some(0));
        if let OptArg::Swap { other_axis } = &result[0].arg {
            assert_eq!(*other_axis, 2);
        } else {
            panic!("Expected Swap arg");
        }
    }

    /// Round-trips a mixed list; checks that order and op kinds are preserved.
    #[test]
    fn test_serialize_deserialize_opts_mixed() {
        let opts = vec![Opt::upcast(0, 4), Opt::local(1, 16), Opt::unroll(0, 8), Opt::nolocals()];
        let serialized = serialize_opts(&opts);
        let deserialized = deserialize_opts(&serialized);

        assert!(deserialized.is_some());
        let result = deserialized.unwrap();
        assert_eq!(result.len(), 4);
        assert_eq!(result[0].op, OptOps::UPCAST);
        assert_eq!(result[1].op, OptOps::LOCAL);
        assert_eq!(result[2].op, OptOps::UNROLL);
        assert_eq!(result[3].op, OptOps::NOLOCALS);
    }

    /// The action space must offer THREAD actions.
    /// NOTE(review): the `>= 6` bound needs at least two thread amounts, i.e.
    /// a host reporting 4 or more available threads - confirm for CI machines.
    #[test]
    fn test_beam_actions_contains_thread() {
        let has_thread = BEAM_ACTIONS.iter().any(|a| a.op == OptOps::THREAD);
        assert!(has_thread, "BEAM_ACTIONS should contain THREAD actions");

        // Count thread actions
        let thread_count = BEAM_ACTIONS.iter().filter(|a| a.op == OptOps::THREAD).count();
        assert!(thread_count >= 6, "Expected at least 6 THREAD actions (3 axes × 2+ amounts), got {}", thread_count);
    }

    /// Applying THREAD to an Outer axis of size 512 must create a Thread axis.
    #[test]
    fn test_thread_action_applied_to_outer_axis() {
        use super::super::renderer::Renderer;
        use morok_ir::{AxisId, AxisType, UOp};

        // Create a kernel with Outer axis (like matmul reduce kernels)
        let end_512 = UOp::index_const(512);
        let r_outer = UOp::range_axis(end_512, AxisId::Renumbered(0), AxisType::Outer);
        let compute = UOp::native_const(1.0f32);
        let sink = UOp::sink(vec![compute, r_outer]);

        let renderer = Renderer::cpu();
        let scheduler = Scheduler::new(sink, renderer);

        // Verify renderer supports threading
        assert!(scheduler.renderer().has_threads, "CPU renderer should have has_threads=true");

        // Try to apply THREAD opt - use available parallelism to work on machines with few cores
        let thread_count = std::thread::available_parallelism().map(|p| p.get()).unwrap_or(4);
        let mut test_scheduler = scheduler.clone();
        let result = apply_opt(&mut test_scheduler, &Opt::thread(0, thread_count), true);
        assert!(result.is_ok(), "THREAD(0, {}) should succeed on Outer axis: {:?}", thread_count, result);

        // Verify Thread axis was created
        let thread_axes = test_scheduler.axes_of(&[AxisType::Thread]);
        assert!(!thread_axes.is_empty(), "Should have Thread axis after THREAD opt");
    }

    /// `generate_actions` must yield at least one threaded candidate for a CPU
    /// kernel with an Outer axis.
    #[test]
    fn test_generate_actions_includes_thread_for_cpu() {
        use super::super::renderer::Renderer;
        use morok_ir::{AxisId, AxisType, UOp};

        // Create a kernel with Outer axis
        let end_512 = UOp::index_const(512);
        let r_outer = UOp::range_axis(end_512, AxisId::Renumbered(0), AxisType::Outer);
        let compute = UOp::native_const(1.0f32);
        let sink = UOp::sink(vec![compute, r_outer]);

        let renderer = Renderer::cpu();
        let scheduler = Scheduler::new(sink, renderer);

        let config = BeamConfig::default();
        let candidates = generate_actions(&scheduler, &config);

        // Check if any candidate has a Thread axis
        let has_threaded = candidates.iter().any(|s| !s.axes_of(&[AxisType::Thread]).is_empty());
        assert!(has_threaded, "generate_actions should produce candidates with Thread axes for CPU");
    }
}