1use std::time::Duration;
7
8use bon::bon;
9
/// Strategy used to choose optimizations.
///
/// [`OptStrategy::Heuristic`] is the `Default` variant.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum OptStrategy {
    /// No optimization strategy. [`OptStrategy::from_env`] picks this whenever
    /// `MOROK_NOOPT` is set.
    None,

    /// Heuristic strategy (the default).
    #[default]
    Heuristic,

    /// Beam strategy.
    Beam {
        /// Beam width. [`OptStrategy::from_env`] only ever produces widths
        /// greater than zero.
        width: usize,
    },
}

impl OptStrategy {
    /// Derives the strategy from the process environment.
    ///
    /// Precedence:
    /// 1. `MOROK_NOOPT` readable (any value) -> [`OptStrategy::None`].
    /// 2. `MOROK_BEAM` parses as a `usize` greater than zero ->
    ///    [`OptStrategy::Beam`] with that width.
    /// 3. Anything else -> [`OptStrategy::Heuristic`].
    pub fn from_env() -> Self {
        use std::env;

        // Presence alone is the switch: the value of the variable is never inspected.
        if env::var("MOROK_NOOPT").is_ok() {
            return Self::None;
        }

        // Unset, unparsable and zero widths all fall through to the heuristic strategy.
        env::var("MOROK_BEAM")
            .ok()
            .and_then(|raw| raw.parse::<usize>().ok())
            .filter(|&width| width > 0)
            .map_or(Self::Heuristic, |width| Self::Beam { width })
    }

    /// Returns `true` only for [`OptStrategy::None`].
    pub fn is_none(&self) -> bool {
        *self == Self::None
    }

    /// Returns `true` for [`OptStrategy::Beam`], whatever its width.
    pub fn is_beam(&self) -> bool {
        matches!(*self, Self::Beam { .. })
    }
}
63
/// Three-state setting stored in `HeuristicsConfig::tc_enabled`.
///
/// The explicit discriminants are exactly the codes returned by
/// [`TcUsage::as_usize`]. [`TcUsage::Enabled`] is the `Default`.
// NOTE(review): "Tc" presumably abbreviates "tensor core" - confirm.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum TcUsage {
    /// Numeric code `0`.
    Disabled = 0,

    /// Numeric code `1` (the default).
    #[default]
    Enabled = 1,

    /// Numeric code `2`.
    ShapeOnly = 2,
}

impl TcUsage {
    /// Returns the numeric code of this variant: `Disabled` = 0,
    /// `Enabled` = 1, `ShapeOnly` = 2.
    pub fn as_usize(&self) -> usize {
        // Field-less enum with explicit discriminants: the cast yields the code directly.
        *self as usize
    }
}
92
/// Three-level setting stored in `HeuristicsConfig::tc_opt`.
///
/// The explicit discriminants are exactly the codes returned by
/// [`TcOpt::as_usize`]. [`TcOpt::Padded`] is the `Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum TcOpt {
    /// Numeric code `0`.
    Strict = 0,

    /// Numeric code `1`.
    Relaxed = 1,

    /// Numeric code `2` (the default).
    #[default]
    Padded = 2,
}

impl TcOpt {
    /// Returns the numeric code of this variant: `Strict` = 0,
    /// `Relaxed` = 1, `Padded` = 2.
    pub fn as_usize(&self) -> usize {
        // Field-less enum with explicit discriminants: the cast yields the code directly.
        *self as usize
    }
}
117
/// Selection stored in `HeuristicsConfig::tc_select`: either automatic or an
/// explicit index.
///
/// [`TcSelect::Auto`] is the `Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum TcSelect {
    /// Automatic selection; encoded as `-1` by [`TcSelect::as_i32`].
    #[default]
    Auto,

    /// Explicit zero-based index.
    Index(usize),
}

impl TcSelect {
    /// Encodes the selection as an `i32`.
    ///
    /// * [`TcSelect::Auto`] maps to the sentinel `-1`.
    /// * [`TcSelect::Index`] maps to the index itself, saturating at
    ///   `i32::MAX`.
    ///
    /// An explicit index therefore never encodes as a negative number and can
    /// never be confused with the `Auto` sentinel.
    pub fn as_i32(&self) -> i32 {
        match *self {
            Self::Auto => -1,
            // A bare `idx as i32` truncates: on 64-bit targets `usize::MAX`
            // would become -1 and alias `Auto`. Clamp before casting instead.
            Self::Index(idx) => idx.min(i32::MAX as usize) as i32,
        }
    }
}
138
/// Settings for the beam strategy.
///
/// Values can come from [`Default`], from the builder, or from the
/// environment via `BeamConfig::from_env`.
#[derive(Debug, Clone)]
pub struct BeamConfig {
    /// Beam width. Default `4`; `from_env` reads `MOROK_BEAM`.
    pub beam_width: usize,
    /// Time budget. Default 60 seconds; `from_env` reads whole seconds from
    /// `MOROK_BEAM_TIMEOUT`.
    // NOTE(review): what exactly the timeout bounds is not visible here - confirm.
    pub timeout: Duration,
    /// Upcast limit. Default `256`; `from_env` reads `BEAM_UPCAST_MAX`.
    pub max_upcast: usize,
    /// Local limit. Default `1024`; `from_env` reads `BEAM_LOCAL_MAX`.
    pub max_local: usize,
    /// Uop limit. Default `3000`; `from_env` reads `BEAM_UOPS_MAX`.
    pub max_uops: usize,
    /// Number of runs. Default `3`; `from_env` reads `BEAM_RUNS`.
    pub num_runs: usize,
    /// Cache opt-out flag. Default `false`; `from_env` sets it when
    /// `IGNORE_BEAM_CACHE` is present.
    pub disable_cache: bool,
}

impl Default for BeamConfig {
    /// Width 4, 60 second timeout, limits of 256 / 1024 / 3000, three runs,
    /// cache enabled.
    fn default() -> Self {
        let timeout = Duration::from_secs(60);

        BeamConfig {
            beam_width: 4,
            timeout,
            max_upcast: 256,
            max_local: 1024,
            max_uops: 3000,
            num_runs: 3,
            disable_cache: false,
        }
    }
}
175
176#[bon]
177impl BeamConfig {
178 #[builder]
180 pub fn builder(
181 #[builder(default = 4)] beam_width: usize,
182 #[builder(default = 60)] timeout_secs: u64,
183 #[builder(default = 256)] max_upcast: usize,
184 #[builder(default = 1024)] max_local: usize,
185 #[builder(default = 3000)] max_uops: usize,
186 #[builder(default = 3)] num_runs: usize,
187 #[builder(default = false)] disable_cache: bool,
188 ) -> Self {
189 Self {
190 beam_width,
191 timeout: Duration::from_secs(timeout_secs),
192 max_upcast,
193 max_local,
194 max_uops,
195 num_runs,
196 disable_cache,
197 }
198 }
199
200 pub fn from_env() -> Self {
212 let beam_width = std::env::var("MOROK_BEAM").ok().and_then(|s| s.parse().ok()).unwrap_or(4);
213 let timeout_secs = std::env::var("MOROK_BEAM_TIMEOUT").ok().and_then(|s| s.parse().ok()).unwrap_or(60);
214 let max_upcast = std::env::var("BEAM_UPCAST_MAX").ok().and_then(|s| s.parse().ok()).unwrap_or(256);
215 let max_local = std::env::var("BEAM_LOCAL_MAX").ok().and_then(|s| s.parse().ok()).unwrap_or(1024);
216 let max_uops = std::env::var("BEAM_UOPS_MAX").ok().and_then(|s| s.parse().ok()).unwrap_or(3000);
217 let num_runs = std::env::var("BEAM_RUNS").ok().and_then(|s| s.parse().ok()).unwrap_or(3);
218 let disable_cache = std::env::var("IGNORE_BEAM_CACHE").is_ok();
219
220 Self {
221 beam_width,
222 timeout: Duration::from_secs(timeout_secs),
223 max_upcast,
224 max_local,
225 max_uops,
226 num_runs,
227 disable_cache,
228 }
229 }
230
231 pub fn with_strategy_width(mut self, strategy: &OptStrategy) -> Self {
233 if let OptStrategy::Beam { width } = strategy {
234 self.beam_width = *width;
235 }
236 self
237 }
238}
239
240#[derive(Debug, Clone)]
246pub struct HeuristicsConfig {
247 pub tc_enabled: TcUsage,
250 pub tc_opt: TcOpt,
252 pub tc_select: TcSelect,
254
255 pub matvec_enabled: bool,
258 pub matvec_blocksize: usize,
260
261 pub grouped_threshold: usize,
264 pub unroll_threshold: usize,
266
267 pub disable_locals: bool,
270
271 pub thread_count: usize,
276
277 pub k_vectorize: bool,
284
285 pub output_upcast: bool,
290
291 pub debug_level: u8,
294}
295
/// Thread count used when none is configured: the value reported by
/// `std::thread::available_parallelism`, or `8` if that query fails.
fn default_thread_count() -> usize {
    match std::thread::available_parallelism() {
        Ok(parallelism) => parallelism.get(),
        Err(_) => 8,
    }
}
300
301impl HeuristicsConfig {
302 pub fn from_env() -> Self {
310 let thread_count =
311 std::env::var("MOROK_THREADS").ok().and_then(|s| s.parse().ok()).unwrap_or_else(default_thread_count);
312 let k_vectorize = std::env::var("MOROK_K_VECTORIZE").is_ok();
313 let output_upcast = std::env::var("MOROK_NO_OUTPUT_UPCAST").is_err();
315
316 Self { thread_count, k_vectorize, output_upcast, ..Default::default() }
317 }
318}
319
320impl Default for HeuristicsConfig {
321 fn default() -> Self {
322 Self {
323 tc_enabled: TcUsage::Enabled,
324 tc_opt: TcOpt::Padded,
325 tc_select: TcSelect::Auto,
326 matvec_enabled: true,
327 matvec_blocksize: 4,
328 grouped_threshold: 256,
329 unroll_threshold: 32,
330 disable_locals: false,
331 thread_count: default_thread_count(),
332 k_vectorize: false,
333 output_upcast: true,
334 debug_level: 0,
335 }
336 }
337}
338
339#[bon]
340impl HeuristicsConfig {
341 #[builder]
343 pub fn builder(
344 #[builder(default)] tc_enabled: TcUsage,
345 #[builder(default)] tc_opt: TcOpt,
346 #[builder(default)] tc_select: TcSelect,
347 #[builder(default = true)] matvec_enabled: bool,
348 #[builder(default = 4)] matvec_blocksize: usize,
349 #[builder(default = 256)] grouped_threshold: usize,
350 #[builder(default = 32)] unroll_threshold: usize,
351 #[builder(default = false)] disable_locals: bool,
352 #[builder(default = default_thread_count())] thread_count: usize,
353 #[builder(default = false)] k_vectorize: bool,
354 #[builder(default = true)] output_upcast: bool,
355 #[builder(default = 0)] debug_level: u8,
356 ) -> Self {
357 Self {
358 tc_enabled,
359 tc_opt,
360 tc_select,
361 matvec_enabled,
362 matvec_blocksize,
363 grouped_threshold,
364 unroll_threshold,
365 disable_locals,
366 thread_count,
367 k_vectorize,
368 output_upcast,
369 debug_level,
370 }
371 }
372}
373
374#[derive(Debug, Clone, Default)]
382pub struct OptimizerConfig {
383 pub strategy: OptStrategy,
385 pub beam: BeamConfig,
387 pub heuristics: HeuristicsConfig,
389}
390
391#[bon]
392impl OptimizerConfig {
393 #[builder]
395 pub fn builder(
396 #[builder(default)] strategy: OptStrategy,
397 #[builder(default)] beam: BeamConfig,
398 #[builder(default)] heuristics: HeuristicsConfig,
399 ) -> Self {
400 Self { strategy, beam, heuristics }
401 }
402
403 pub fn from_env() -> Self {
412 let strategy = OptStrategy::from_env();
413 let beam = BeamConfig::from_env().with_strategy_width(&strategy);
414 let heuristics = HeuristicsConfig::from_env();
415
416 Self { strategy, beam, heuristics }
417 }
418}
419
/// Unit tests for the configuration types.
///
/// None of these tests touch the process environment, so they are safe to run
/// in parallel.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_opt_strategy_default_is_heuristic() {
        assert_eq!(OptStrategy::default(), OptStrategy::Heuristic);
    }

    #[test]
    fn test_opt_strategy_is_none() {
        assert!(OptStrategy::None.is_none());
        assert!(!OptStrategy::Heuristic.is_none());
        assert!(!OptStrategy::Beam { width: 4 }.is_none());
    }

    #[test]
    fn test_opt_strategy_is_beam() {
        assert!(!OptStrategy::None.is_beam());
        assert!(!OptStrategy::Heuristic.is_beam());
        assert!(OptStrategy::Beam { width: 4 }.is_beam());
    }

    #[test]
    fn test_beam_config_default() {
        let config = BeamConfig::default();
        assert_eq!(config.beam_width, 4);
        assert_eq!(config.timeout, Duration::from_secs(60));
        assert_eq!(config.max_upcast, 256);
        assert_eq!(config.max_local, 1024);
        assert_eq!(config.max_uops, 3000);
        assert_eq!(config.num_runs, 3);
        assert!(!config.disable_cache);
    }

    #[test]
    fn test_beam_config_builder() {
        let config = BeamConfig::builder().beam_width(8).timeout_secs(120).max_upcast(512).build();

        assert_eq!(config.beam_width, 8);
        assert_eq!(config.timeout, Duration::from_secs(120));
        assert_eq!(config.max_upcast, 512);
        // Untouched arguments keep their defaults.
        assert_eq!(config.max_local, 1024);
    }

    /// Guards against the builder defaults drifting away from `Default`.
    #[test]
    fn test_beam_config_builder_matches_default() {
        let built = BeamConfig::builder().build();
        let defaults = BeamConfig::default();

        assert_eq!(built.beam_width, defaults.beam_width);
        assert_eq!(built.timeout, defaults.timeout);
        assert_eq!(built.max_upcast, defaults.max_upcast);
        assert_eq!(built.max_local, defaults.max_local);
        assert_eq!(built.max_uops, defaults.max_uops);
        assert_eq!(built.num_runs, defaults.num_runs);
        assert_eq!(built.disable_cache, defaults.disable_cache);
    }

    #[test]
    fn test_beam_config_with_strategy_width() {
        let beam = BeamConfig::default().with_strategy_width(&OptStrategy::Beam { width: 16 });
        assert_eq!(beam.beam_width, 16);

        // Non-beam strategies leave the width alone.
        let heuristic = BeamConfig::default().with_strategy_width(&OptStrategy::Heuristic);
        assert_eq!(heuristic.beam_width, 4);

        let none = BeamConfig::default().with_strategy_width(&OptStrategy::None);
        assert_eq!(none.beam_width, 4);
    }

    #[test]
    fn test_heuristics_config_default() {
        let config = HeuristicsConfig::default();
        assert_eq!(config.tc_enabled, TcUsage::Enabled);
        assert_eq!(config.tc_opt, TcOpt::Padded);
        assert!(config.matvec_enabled);
        assert_eq!(config.grouped_threshold, 256);
    }

    #[test]
    fn test_heuristics_config_builder() {
        let config = HeuristicsConfig::builder()
            .tc_enabled(TcUsage::Disabled)
            .matvec_enabled(false)
            .grouped_threshold(128)
            .build();

        assert_eq!(config.tc_enabled, TcUsage::Disabled);
        assert!(!config.matvec_enabled);
        assert_eq!(config.grouped_threshold, 128);
    }

    /// Guards against the builder defaults drifting away from `Default`.
    #[test]
    fn test_heuristics_config_builder_matches_default() {
        let built = HeuristicsConfig::builder().build();
        let defaults = HeuristicsConfig::default();

        assert_eq!(built.tc_enabled, defaults.tc_enabled);
        assert_eq!(built.tc_opt, defaults.tc_opt);
        assert_eq!(built.tc_select, defaults.tc_select);
        assert_eq!(built.matvec_enabled, defaults.matvec_enabled);
        assert_eq!(built.matvec_blocksize, defaults.matvec_blocksize);
        assert_eq!(built.grouped_threshold, defaults.grouped_threshold);
        assert_eq!(built.unroll_threshold, defaults.unroll_threshold);
        assert_eq!(built.disable_locals, defaults.disable_locals);
        assert_eq!(built.thread_count, defaults.thread_count);
        assert_eq!(built.k_vectorize, defaults.k_vectorize);
        assert_eq!(built.output_upcast, defaults.output_upcast);
        assert_eq!(built.debug_level, defaults.debug_level);
    }

    #[test]
    fn test_optimizer_config_default() {
        let config = OptimizerConfig::default();
        assert_eq!(config.strategy, OptStrategy::Heuristic);
        assert_eq!(config.beam.beam_width, 4);
    }

    #[test]
    fn test_optimizer_config_builder() {
        let config = OptimizerConfig::builder()
            .strategy(OptStrategy::Beam { width: 8 })
            .beam(BeamConfig::builder().timeout_secs(120).build())
            .build();

        assert_eq!(config.strategy, OptStrategy::Beam { width: 8 });
        assert_eq!(config.beam.timeout, Duration::from_secs(120));
    }

    #[test]
    fn test_tc_usage_as_usize() {
        assert_eq!(TcUsage::Disabled.as_usize(), 0);
        assert_eq!(TcUsage::Enabled.as_usize(), 1);
        assert_eq!(TcUsage::ShapeOnly.as_usize(), 2);
    }

    #[test]
    fn test_tc_opt_as_usize() {
        assert_eq!(TcOpt::Strict.as_usize(), 0);
        assert_eq!(TcOpt::Relaxed.as_usize(), 1);
        assert_eq!(TcOpt::Padded.as_usize(), 2);
    }

    #[test]
    fn test_tc_select_as_i32() {
        assert_eq!(TcSelect::Auto.as_i32(), -1);
        assert_eq!(TcSelect::Index(5).as_i32(), 5);
    }
}