Skip to main content

morok_schedule/optimizer/
config.rs

//! Optimizer configuration types.
//!
//! Provides typed configuration for kernel optimization with `bon` builders.
//! Supports both explicit configuration and environment-variable fallbacks.
5
6use std::time::Duration;
7
8use bon::bon;
9
10// ============================================================================
11// OPTIMIZATION STRATEGY
12// ============================================================================
13
/// Optimization strategy for kernel tuning.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum OptStrategy {
    /// Optimization switched off entirely (handy for debugging and regression testing).
    None,

    /// Hand-coded heuristics. This is the default strategy.
    #[default]
    Heuristic,

    /// Beam-search optimization.
    Beam {
        /// Beam width: how many candidates survive each search step.
        width: usize,
    },
}

impl OptStrategy {
    /// Pick a strategy based on the process environment.
    ///
    /// Resolution order:
    ///
    /// 1. `MOROK_NOOPT` is set → [`OptStrategy::None`]. Only presence is checked, so any
    ///    value (including `0`) disables optimization.
    /// 2. `MOROK_BEAM` parses as a `usize` greater than zero → [`OptStrategy::Beam`] with
    ///    that width.
    /// 3. Anything else → [`OptStrategy::Heuristic`]. Unparseable or zero `MOROK_BEAM`
    ///    values are ignored.
    pub fn from_env() -> Self {
        let noopt_requested = std::env::var("MOROK_NOOPT").is_ok();
        if noopt_requested {
            return Self::None;
        }

        std::env::var("MOROK_BEAM")
            .ok()
            .and_then(|raw| raw.parse::<usize>().ok())
            .filter(|&width| width > 0)
            .map_or(Self::Heuristic, |width| Self::Beam { width })
    }

    /// Returns `true` when this strategy turns optimization off.
    pub fn is_none(&self) -> bool {
        *self == Self::None
    }

    /// Returns `true` when this strategy is beam search (of any width).
    pub fn is_beam(&self) -> bool {
        matches!(*self, OptStrategy::Beam { width: _ })
    }
}
63
64// ============================================================================
65// TENSOR CORE SETTINGS
66// ============================================================================
67
/// Tensor core usage level, mirroring the integer `USE_TC` setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum TcUsage {
    /// Tensor cores disabled (`USE_TC=0`).
    Disabled = 0,

    /// Tensor cores enabled (`USE_TC=1`). This is the default.
    #[default]
    Enabled = 1,

    /// Shape-only mode (`USE_TC=2`).
    ShapeOnly = 2,
}

impl TcUsage {
    /// Integer form of this level (`0`, `1` or `2`) for internal APIs.
    pub fn as_usize(&self) -> usize {
        // Each variant's explicit discriminant is exactly its `USE_TC` value.
        *self as usize
    }
}
92
/// Tensor core optimization level, mirroring the integer `TC_OPT` setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum TcOpt {
    /// Strict matching (`TC_OPT=0`).
    Strict = 0,

    /// Relaxed matching (`TC_OPT=1`).
    Relaxed = 1,

    /// Padded matching (`TC_OPT=2`). This is the default.
    #[default]
    Padded = 2,
}

impl TcOpt {
    /// Integer form of this level (`0`, `1` or `2`) for internal APIs.
    pub fn as_usize(&self) -> usize {
        // Discriminants are pinned to the `TC_OPT` values above.
        let level = *self;
        level as usize
    }
}
117
/// Tensor core selection mode.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum TcSelect {
    /// Auto-select the best tensor core (`TC_SELECT=-1`). This is the default.
    #[default]
    Auto,

    /// Use the tensor core at this specific index.
    Index(usize),
}

impl TcSelect {
    /// Convert to the signed integer form used by internal APIs.
    ///
    /// [`TcSelect::Auto`] maps to `-1`. [`TcSelect::Index`] maps to the index itself,
    /// saturated at `i32::MAX`.
    ///
    /// Saturating matters because a plain `as` cast wraps large indices to negative
    /// numbers. For example, `u32::MAX as usize` becomes `-1`, the very value reserved
    /// for `Auto`, so an explicit index could silently turn into auto-selection.
    pub fn as_i32(&self) -> i32 {
        match *self {
            Self::Auto => -1,
            Self::Index(idx) => i32::try_from(idx).unwrap_or(i32::MAX),
        }
    }
}
138
139// ============================================================================
140// BEAM SEARCH CONFIGURATION
141// ============================================================================
142
/// Tuning knobs for the beam-search auto-tuner.
///
/// The baseline values live in this type's `Default` implementation.
#[derive(Debug, Clone)]
pub struct BeamConfig {
    /// How many candidates are retained after each search step.
    pub beam_width: usize,
    /// Time budget for the whole search.
    pub timeout: Duration,
    /// Upper bound on the upcast size (product of all UPCAST/UNROLL dimensions).
    pub max_upcast: usize,
    /// Upper bound on the local size (product of all LOCAL/WARP/GROUP_REDUCE dimensions).
    pub max_local: usize,
    /// Kernels with more UOps than this are rejected.
    pub max_uops: usize,
    /// How many benchmark runs are made for each kernel.
    pub num_runs: usize,
    /// When `true`, the disk cache is bypassed.
    pub disable_cache: bool,
}
161
162impl Default for BeamConfig {
163    fn default() -> Self {
164        Self {
165            beam_width: 4,
166            timeout: Duration::from_secs(60),
167            max_upcast: 256,
168            max_local: 1024,
169            max_uops: 3000,
170            num_runs: 3,
171            disable_cache: false,
172        }
173    }
174}
175
176#[bon]
177impl BeamConfig {
178    /// Create a beam configuration with builder pattern.
179    #[builder]
180    pub fn builder(
181        #[builder(default = 4)] beam_width: usize,
182        #[builder(default = 60)] timeout_secs: u64,
183        #[builder(default = 256)] max_upcast: usize,
184        #[builder(default = 1024)] max_local: usize,
185        #[builder(default = 3000)] max_uops: usize,
186        #[builder(default = 3)] num_runs: usize,
187        #[builder(default = false)] disable_cache: bool,
188    ) -> Self {
189        Self {
190            beam_width,
191            timeout: Duration::from_secs(timeout_secs),
192            max_upcast,
193            max_local,
194            max_uops,
195            num_runs,
196            disable_cache,
197        }
198    }
199
200    /// Create configuration from environment variables.
201    ///
202    /// # Environment Variables
203    ///
204    /// * `MOROK_BEAM` - Beam width (default: 4)
205    /// * `MOROK_BEAM_TIMEOUT` - Max search time in seconds (default: 60)
206    /// * `BEAM_UPCAST_MAX` - Max upcast size (default: 256)
207    /// * `BEAM_LOCAL_MAX` - Max local memory elements (default: 1024)
208    /// * `BEAM_UOPS_MAX` - Max UOps before rejecting (default: 3000)
209    /// * `BEAM_RUNS` - Benchmark runs per kernel (default: 3)
210    /// * `IGNORE_BEAM_CACHE` - Bypass disk cache if set
211    pub fn from_env() -> Self {
212        let beam_width = std::env::var("MOROK_BEAM").ok().and_then(|s| s.parse().ok()).unwrap_or(4);
213        let timeout_secs = std::env::var("MOROK_BEAM_TIMEOUT").ok().and_then(|s| s.parse().ok()).unwrap_or(60);
214        let max_upcast = std::env::var("BEAM_UPCAST_MAX").ok().and_then(|s| s.parse().ok()).unwrap_or(256);
215        let max_local = std::env::var("BEAM_LOCAL_MAX").ok().and_then(|s| s.parse().ok()).unwrap_or(1024);
216        let max_uops = std::env::var("BEAM_UOPS_MAX").ok().and_then(|s| s.parse().ok()).unwrap_or(3000);
217        let num_runs = std::env::var("BEAM_RUNS").ok().and_then(|s| s.parse().ok()).unwrap_or(3);
218        let disable_cache = std::env::var("IGNORE_BEAM_CACHE").is_ok();
219
220        Self {
221            beam_width,
222            timeout: Duration::from_secs(timeout_secs),
223            max_upcast,
224            max_local,
225            max_uops,
226            num_runs,
227            disable_cache,
228        }
229    }
230
231    /// Get beam width from strategy if applicable.
232    pub fn with_strategy_width(mut self, strategy: &OptStrategy) -> Self {
233        if let OptStrategy::Beam { width } = strategy {
234            self.beam_width = *width;
235        }
236        self
237    }
238}
239
240// ============================================================================
241// HEURISTICS CONFIGURATION
242// ============================================================================
243
244/// Configuration for heuristic-based optimization.
245#[derive(Debug, Clone)]
246pub struct HeuristicsConfig {
247    // Tensor cores
248    /// Tensor core usage level.
249    pub tc_enabled: TcUsage,
250    /// Tensor core optimization level.
251    pub tc_opt: TcOpt,
252    /// Tensor core selection mode.
253    pub tc_select: TcSelect,
254
255    // Matrix-vector optimization
256    /// Enable matrix-vector optimization.
257    pub matvec_enabled: bool,
258    /// Matrix-vector block size (rows per workgroup).
259    pub matvec_blocksize: usize,
260
261    // Reduction thresholds
262    /// Threshold for applying grouped reduction.
263    pub grouped_threshold: usize,
264    /// Threshold for applying unroll.
265    pub unroll_threshold: usize,
266
267    // Local memory
268    /// Disable local memory globally.
269    pub disable_locals: bool,
270
271    // Threading
272    /// Maximum thread count for CPU parallelization.
273    /// Default: std::thread::available_parallelism().
274    /// Set to 1 to disable threading.
275    pub thread_count: usize,
276
277    // Vectorization
278    /// Enable K-axis vectorization for matmul.
279    /// When enabled, UPCAST is applied to the reduce (K) axis creating vector accumulators.
280    /// Disabled by default: K-vectorization complicates output tiling and horizontal reduce.
281    /// Tinygrad doesn't use K-vectorization - they rely on output tiling (register blocking).
282    /// Default: false.
283    pub k_vectorize: bool,
284
285    /// Enable output dimension upcasting for matmul (register blocking).
286    /// When enabled, UPCAST is applied to M/N axes creating register tiles.
287    /// Each thread computes an MxN tile instead of a single element.
288    /// Default: false (blocked by vector width mismatch issue in expand.rs).
289    pub output_upcast: bool,
290
291    // Debug
292    /// Debug verbosity level.
293    pub debug_level: u8,
294}
295
/// Thread count to use when none is configured explicitly.
///
/// Asks the OS via [`std::thread::available_parallelism`]. If the platform cannot
/// report it, the fallback is 8. Shared by `Default` and the builder.
fn default_thread_count() -> usize {
    match std::thread::available_parallelism() {
        Ok(cores) => cores.get(),
        Err(_) => 8,
    }
}
300
301impl HeuristicsConfig {
302    /// Create configuration from environment variables.
303    ///
304    /// # Environment Variables
305    ///
306    /// * `MOROK_THREADS` - Maximum thread count (default: available_parallelism)
307    /// * `MOROK_K_VECTORIZE` - Enable K-axis vectorization (default: disabled)
308    /// * `MOROK_NO_OUTPUT_UPCAST` - Disable output dimension upcasting (default: enabled)
309    pub fn from_env() -> Self {
310        let thread_count =
311            std::env::var("MOROK_THREADS").ok().and_then(|s| s.parse().ok()).unwrap_or_else(default_thread_count);
312        let k_vectorize = std::env::var("MOROK_K_VECTORIZE").is_ok();
313        // Default enabled, use MOROK_NO_OUTPUT_UPCAST to disable
314        let output_upcast = std::env::var("MOROK_NO_OUTPUT_UPCAST").is_err();
315
316        Self { thread_count, k_vectorize, output_upcast, ..Default::default() }
317    }
318}
319
320impl Default for HeuristicsConfig {
321    fn default() -> Self {
322        Self {
323            tc_enabled: TcUsage::Enabled,
324            tc_opt: TcOpt::Padded,
325            tc_select: TcSelect::Auto,
326            matvec_enabled: true,
327            matvec_blocksize: 4,
328            grouped_threshold: 256,
329            unroll_threshold: 32,
330            disable_locals: false,
331            thread_count: default_thread_count(),
332            k_vectorize: false,
333            output_upcast: true,
334            debug_level: 0,
335        }
336    }
337}
338
339#[bon]
340impl HeuristicsConfig {
341    /// Create a heuristics configuration with builder pattern.
342    #[builder]
343    pub fn builder(
344        #[builder(default)] tc_enabled: TcUsage,
345        #[builder(default)] tc_opt: TcOpt,
346        #[builder(default)] tc_select: TcSelect,
347        #[builder(default = true)] matvec_enabled: bool,
348        #[builder(default = 4)] matvec_blocksize: usize,
349        #[builder(default = 256)] grouped_threshold: usize,
350        #[builder(default = 32)] unroll_threshold: usize,
351        #[builder(default = false)] disable_locals: bool,
352        #[builder(default = default_thread_count())] thread_count: usize,
353        #[builder(default = false)] k_vectorize: bool,
354        #[builder(default = true)] output_upcast: bool,
355        #[builder(default = 0)] debug_level: u8,
356    ) -> Self {
357        Self {
358            tc_enabled,
359            tc_opt,
360            tc_select,
361            matvec_enabled,
362            matvec_blocksize,
363            grouped_threshold,
364            unroll_threshold,
365            disable_locals,
366            thread_count,
367            k_vectorize,
368            output_upcast,
369            debug_level,
370        }
371    }
372}
373
374// ============================================================================
375// TOP-LEVEL OPTIMIZER CONFIGURATION
376// ============================================================================
377
378/// Top-level optimizer configuration.
379///
380/// Combines strategy selection, beam search settings, and heuristic parameters.
381#[derive(Debug, Clone, Default)]
382pub struct OptimizerConfig {
383    /// Optimization strategy (None, Heuristic, or Beam).
384    pub strategy: OptStrategy,
385    /// Beam search configuration (used when strategy is Beam).
386    pub beam: BeamConfig,
387    /// Heuristics configuration (used when strategy is Heuristic).
388    pub heuristics: HeuristicsConfig,
389}
390
391#[bon]
392impl OptimizerConfig {
393    /// Create an optimizer configuration with builder pattern.
394    #[builder]
395    pub fn builder(
396        #[builder(default)] strategy: OptStrategy,
397        #[builder(default)] beam: BeamConfig,
398        #[builder(default)] heuristics: HeuristicsConfig,
399    ) -> Self {
400        Self { strategy, beam, heuristics }
401    }
402
403    /// Create configuration from environment variables.
404    ///
405    /// Reads strategy from env, then populates beam and heuristics config accordingly.
406    ///
407    /// # Environment Variables
408    ///
409    /// * `MOROK_NOOPT=1` - Disable all optimizations
410    /// * `MOROK_BEAM=N` - Use beam search with width N
411    pub fn from_env() -> Self {
412        let strategy = OptStrategy::from_env();
413        let beam = BeamConfig::from_env().with_strategy_width(&strategy);
414        let heuristics = HeuristicsConfig::from_env();
415
416        Self { strategy, beam, heuristics }
417    }
418}
419
420// ============================================================================
421// TESTS
422// ============================================================================
423
#[cfg(test)]
mod tests {
    //! Unit tests for the optimizer configuration types.

    use super::*;

    // ---- OptStrategy --------------------------------------------------------

    #[test]
    fn test_opt_strategy_default_is_heuristic() {
        let strategy = OptStrategy::default();
        assert_eq!(strategy, OptStrategy::Heuristic);
    }

    #[test]
    fn test_opt_strategy_is_none() {
        let cases = [
            (OptStrategy::None, true),
            (OptStrategy::Heuristic, false),
            (OptStrategy::Beam { width: 4 }, false),
        ];
        for (strategy, expected) in cases {
            assert_eq!(strategy.is_none(), expected, "is_none() for {strategy:?}");
        }
    }

    #[test]
    fn test_opt_strategy_is_beam() {
        let cases = [
            (OptStrategy::None, false),
            (OptStrategy::Heuristic, false),
            (OptStrategy::Beam { width: 4 }, true),
        ];
        for (strategy, expected) in cases {
            assert_eq!(strategy.is_beam(), expected, "is_beam() for {strategy:?}");
        }
    }

    // ---- BeamConfig ---------------------------------------------------------

    #[test]
    fn test_beam_config_default() {
        let BeamConfig { beam_width, timeout, max_upcast, max_local, .. } = BeamConfig::default();
        assert_eq!(beam_width, 4);
        assert_eq!(timeout, Duration::from_secs(60));
        assert_eq!(max_upcast, 256);
        assert_eq!(max_local, 1024);
    }

    #[test]
    fn test_beam_config_builder() {
        let config = BeamConfig::builder()
            .beam_width(8)
            .timeout_secs(120)
            .max_upcast(512)
            .build();

        assert_eq!(config.beam_width, 8);
        assert_eq!(config.timeout, Duration::from_secs(120));
        assert_eq!(config.max_upcast, 512);
        // Not set explicitly, so the builder default applies.
        assert_eq!(config.max_local, 1024);
    }

    // ---- HeuristicsConfig ---------------------------------------------------

    #[test]
    fn test_heuristics_config_default() {
        let HeuristicsConfig { tc_enabled, tc_opt, matvec_enabled, grouped_threshold, .. } =
            HeuristicsConfig::default();
        assert_eq!(tc_enabled, TcUsage::Enabled);
        assert_eq!(tc_opt, TcOpt::Padded);
        assert!(matvec_enabled);
        assert_eq!(grouped_threshold, 256);
    }

    #[test]
    fn test_heuristics_config_builder() {
        let config = HeuristicsConfig::builder()
            .tc_enabled(TcUsage::Disabled)
            .matvec_enabled(false)
            .grouped_threshold(128)
            .build();

        assert_eq!(config.tc_enabled, TcUsage::Disabled);
        assert!(!config.matvec_enabled);
        assert_eq!(config.grouped_threshold, 128);
    }

    // ---- OptimizerConfig ----------------------------------------------------

    #[test]
    fn test_optimizer_config_default() {
        let OptimizerConfig { strategy, beam, .. } = OptimizerConfig::default();
        assert_eq!(strategy, OptStrategy::Heuristic);
        assert_eq!(beam.beam_width, 4);
    }

    #[test]
    fn test_optimizer_config_builder() {
        let beam = BeamConfig::builder().timeout_secs(120).build();
        let config = OptimizerConfig::builder()
            .strategy(OptStrategy::Beam { width: 8 })
            .beam(beam)
            .build();

        assert_eq!(config.strategy, OptStrategy::Beam { width: 8 });
        assert_eq!(config.beam.timeout, Duration::from_secs(120));
    }

    // ---- Tensor core enums --------------------------------------------------

    #[test]
    fn test_tc_usage_as_usize() {
        let expected = [(TcUsage::Disabled, 0), (TcUsage::Enabled, 1), (TcUsage::ShapeOnly, 2)];
        for (usage, value) in expected {
            assert_eq!(usage.as_usize(), value, "{usage:?}");
        }
    }

    #[test]
    fn test_tc_opt_as_usize() {
        let expected = [(TcOpt::Strict, 0), (TcOpt::Relaxed, 1), (TcOpt::Padded, 2)];
        for (opt, value) in expected {
            assert_eq!(opt.as_usize(), value, "{opt:?}");
        }
    }

    #[test]
    fn test_tc_select_as_i32() {
        let expected = [(TcSelect::Auto, -1), (TcSelect::Index(5), 5)];
        for (select, value) in expected {
            assert_eq!(select.as_i32(), value, "{select:?}");
        }
    }
}