use std::time::Duration;
use bon::bon;
/// Optimization strategy selector.
///
/// The default variant is [`OptStrategy::Heuristic`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum OptStrategy {
    /// Chosen by [`OptStrategy::from_env`] when `MOROK_NOOPT` is set.
    None,
    /// Default strategy; also the fallback of [`OptStrategy::from_env`].
    #[default]
    Heuristic,
    /// Beam strategy carrying the requested width.
    Beam {
        /// Requested beam width. [`OptStrategy::from_env`] only produces
        /// values greater than zero.
        width: usize,
    },
}

impl OptStrategy {
    /// Resolves the strategy from environment variables.
    ///
    /// Precedence:
    /// 1. `MOROK_NOOPT` set to any Unicode value (including the empty string)
    ///    yields [`OptStrategy::None`].
    /// 2. Otherwise, `MOROK_BEAM` holding an integer greater than zero yields
    ///    [`OptStrategy::Beam`] with that width. Surrounding whitespace is
    ///    ignored, so `MOROK_BEAM="4 "` is accepted.
    /// 3. Anything else (unset, unparsable, or `0`) yields
    ///    [`OptStrategy::Heuristic`].
    pub fn from_env() -> Self {
        if std::env::var("MOROK_NOOPT").is_ok() {
            return Self::None;
        }

        std::env::var("MOROK_BEAM")
            .ok()
            // Trim first: `str::parse::<usize>` rejects leading/trailing
            // whitespace, which would otherwise silently select `Heuristic`.
            .and_then(|raw| raw.trim().parse::<usize>().ok())
            .filter(|&width| width > 0)
            .map_or(Self::Heuristic, |width| Self::Beam { width })
    }

    /// Returns `true` only for [`OptStrategy::None`].
    pub fn is_none(&self) -> bool {
        matches!(self, Self::None)
    }

    /// Returns `true` for any [`OptStrategy::Beam`], regardless of width.
    pub fn is_beam(&self) -> bool {
        matches!(self, Self::Beam { .. })
    }
}
/// Three-state `Tc*` usage mode (presumably tensor-core usage; confirm with
/// the consumer of this setting).
///
/// Each variant has a fixed numeric code exposed through
/// [`TcUsage::as_usize`]. The default is [`TcUsage::Enabled`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum TcUsage {
    /// Numeric code `0`.
    Disabled = 0,
    /// Numeric code `1` (default).
    #[default]
    Enabled = 1,
    /// Numeric code `2`.
    ShapeOnly = 2,
}

impl TcUsage {
    /// Returns the numeric code of the variant: `Disabled` is 0, `Enabled`
    /// is 1 and `ShapeOnly` is 2.
    pub fn as_usize(&self) -> usize {
        // The explicit discriminants on the enum are the numeric codes.
        *self as usize
    }
}
/// Three-level `Tc*` option (exact semantics are defined by the consumer of
/// this setting and are not visible here).
///
/// Each variant has a fixed numeric code exposed through
/// [`TcOpt::as_usize`]. The default is [`TcOpt::Padded`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum TcOpt {
    /// Numeric code `0`.
    Strict = 0,
    /// Numeric code `1`.
    Relaxed = 1,
    /// Numeric code `2` (default).
    #[default]
    Padded = 2,
}

impl TcOpt {
    /// Returns the numeric code of the variant: `Strict` is 0, `Relaxed`
    /// is 1 and `Padded` is 2.
    pub fn as_usize(&self) -> usize {
        // The explicit discriminants on the enum are the numeric codes.
        *self as usize
    }
}
/// Selection of a `Tc*` entry: either automatic or a fixed index.
///
/// The default is [`TcSelect::Auto`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum TcSelect {
    /// No explicit index; encoded as `-1` by [`TcSelect::as_i32`].
    #[default]
    Auto,
    /// Explicit zero-based index.
    Index(usize),
}

impl TcSelect {
    /// Encodes the selection as a signed integer: `-1` for
    /// [`TcSelect::Auto`], otherwise the index itself.
    ///
    /// Indices that do not fit in an `i32` saturate at `i32::MAX`. A plain
    /// `as` cast would wrap instead, turning e.g. `usize::MAX` into `-1`,
    /// which is indistinguishable from `Auto`.
    pub fn as_i32(&self) -> i32 {
        match *self {
            Self::Auto => -1,
            Self::Index(idx) => i32::try_from(idx).unwrap_or(i32::MAX),
        }
    }
}
/// Settings for the beam strategy.
///
/// The environment variable named on each field is the one read by
/// `BeamConfig::from_env`. What each limit precisely bounds is decided by the
/// code consuming this struct, which is not visible here.
#[derive(Debug, Clone)]
pub struct BeamConfig {
    /// Beam width. Default: 4. Environment: `MOROK_BEAM`.
    pub beam_width: usize,
    /// Time limit. Default: 60 seconds. Environment: `MOROK_BEAM_TIMEOUT`
    /// (whole seconds).
    pub timeout: Duration,
    /// Upcast limit. Default: 256. Environment: `BEAM_UPCAST_MAX`.
    pub max_upcast: usize,
    /// Local limit. Default: 1024. Environment: `BEAM_LOCAL_MAX`.
    pub max_local: usize,
    /// Uop-count limit. Default: 3000. Environment: `BEAM_UOPS_MAX`.
    pub max_uops: usize,
    /// Number of runs. Default: 3. Environment: `BEAM_RUNS`.
    pub num_runs: usize,
    /// Cache opt-out flag. Default: `false`. Environment: set when
    /// `IGNORE_BEAM_CACHE` is present.
    pub disable_cache: bool,
}

impl Default for BeamConfig {
    /// Returns the stock configuration: width 4, 60 s timeout, limits of
    /// 256 / 1024 / 3000, 3 runs, cache enabled.
    fn default() -> Self {
        let timeout = Duration::from_secs(60);

        BeamConfig {
            beam_width: 4,
            timeout,
            // Limits.
            max_upcast: 256,
            max_local: 1024,
            max_uops: 3000,
            // Run count and caching.
            num_runs: 3,
            disable_cache: false,
        }
    }
}
#[bon]
impl BeamConfig {
    /// Builder entry point (`BeamConfig::builder()…build()`).
    ///
    /// Every parameter is optional. Defaults: `beam_width` 4, `timeout_secs`
    /// 60, `max_upcast` 256, `max_local` 1024, `max_uops` 3000, `num_runs` 3,
    /// `disable_cache` false. `timeout_secs` is given in whole seconds and
    /// stored as a [`Duration`] in `timeout`.
    #[builder]
    pub fn builder(
        #[builder(default = 4)] beam_width: usize,
        #[builder(default = 60)] timeout_secs: u64,
        #[builder(default = 256)] max_upcast: usize,
        #[builder(default = 1024)] max_local: usize,
        #[builder(default = 3000)] max_uops: usize,
        #[builder(default = 3)] num_runs: usize,
        #[builder(default = false)] disable_cache: bool,
    ) -> Self {
        Self {
            beam_width,
            timeout: Duration::from_secs(timeout_secs),
            max_upcast,
            max_local,
            max_uops,
            num_runs,
            disable_cache,
        }
    }

    /// Builds a configuration from environment variables.
    ///
    /// | Variable             | Field           |
    /// |----------------------|-----------------|
    /// | `MOROK_BEAM`         | `beam_width`    |
    /// | `MOROK_BEAM_TIMEOUT` | `timeout` (s)   |
    /// | `BEAM_UPCAST_MAX`    | `max_upcast`    |
    /// | `BEAM_LOCAL_MAX`     | `max_local`     |
    /// | `BEAM_UOPS_MAX`      | `max_uops`      |
    /// | `BEAM_RUNS`          | `num_runs`      |
    /// | `IGNORE_BEAM_CACHE`  | `disable_cache` |
    ///
    /// Numeric values are trimmed before parsing. A variable that is unset or
    /// unparsable falls back to the matching [`Default`] value. `MOROK_BEAM`
    /// must additionally be greater than zero, mirroring
    /// `OptStrategy::from_env`; `MOROK_BEAM=0` therefore keeps the default
    /// width instead of producing a zero-width beam. `disable_cache` is `true`
    /// whenever `IGNORE_BEAM_CACHE` is set, whatever its value.
    pub fn from_env() -> Self {
        /// Reads `key` and parses its trimmed value; `None` when the variable
        /// is unset, not valid Unicode, or does not parse as `T`.
        fn env_parse<T: std::str::FromStr>(key: &str) -> Option<T> {
            std::env::var(key).ok().and_then(|raw| raw.trim().parse().ok())
        }

        // Single source of truth for fallbacks: the `Default` impl.
        let fallback = Self::default();

        Self {
            beam_width: env_parse::<usize>("MOROK_BEAM")
                .filter(|&width| width > 0)
                .unwrap_or(fallback.beam_width),
            timeout: env_parse::<u64>("MOROK_BEAM_TIMEOUT")
                .map(Duration::from_secs)
                .unwrap_or(fallback.timeout),
            max_upcast: env_parse::<usize>("BEAM_UPCAST_MAX").unwrap_or(fallback.max_upcast),
            max_local: env_parse::<usize>("BEAM_LOCAL_MAX").unwrap_or(fallback.max_local),
            max_uops: env_parse::<usize>("BEAM_UOPS_MAX").unwrap_or(fallback.max_uops),
            num_runs: env_parse::<usize>("BEAM_RUNS").unwrap_or(fallback.num_runs),
            disable_cache: std::env::var("IGNORE_BEAM_CACHE").is_ok(),
        }
    }

    /// Overrides `beam_width` with the width carried by `strategy` when it is
    /// [`OptStrategy::Beam`]; any other strategy leaves the config untouched.
    pub fn with_strategy_width(mut self, strategy: &OptStrategy) -> Self {
        if let OptStrategy::Beam { width } = *strategy {
            self.beam_width = width;
        }
        self
    }
}
/// Settings for the heuristic strategy.
///
/// Defaults quoted below are the ones produced by the `Default` impl. The
/// precise effect of each knob is decided by the code consuming this struct,
/// which is not visible here.
#[derive(Debug, Clone)]
pub struct HeuristicsConfig {
    /// `Tc*` usage mode. Default: [`TcUsage::Enabled`].
    pub tc_enabled: TcUsage,
    /// `Tc*` option level. Default: [`TcOpt::Padded`].
    pub tc_opt: TcOpt,
    /// `Tc*` selection. Default: [`TcSelect::Auto`].
    pub tc_select: TcSelect,
    /// Matvec toggle (presumably matrix-vector handling; confirm).
    /// Default: `true`.
    pub matvec_enabled: bool,
    /// Matvec block size. Default: 4.
    pub matvec_blocksize: usize,
    /// Grouped threshold. Default: 256.
    pub grouped_threshold: usize,
    /// Unroll threshold. Default: 32.
    pub unroll_threshold: usize,
    /// Locals opt-out flag. Default: `false`.
    pub disable_locals: bool,
    /// Thread count. Default: the available parallelism reported by the
    /// standard library, or 8 when that query fails. `from_env` reads
    /// `MOROK_THREADS`.
    pub thread_count: usize,
    /// Default: `false`. `from_env` sets it when `MOROK_K_VECTORIZE` is
    /// present.
    pub k_vectorize: bool,
    /// Default: `true`. `from_env` clears it when `MOROK_NO_OUTPUT_UPCAST`
    /// is present.
    pub output_upcast: bool,
    /// Debug verbosity level. Default: 0.
    pub debug_level: u8,
}
/// Returns the parallelism reported by
/// [`std::thread::available_parallelism`], or 8 when that query fails.
fn default_thread_count() -> usize {
    match std::thread::available_parallelism() {
        Ok(cores) => cores.get(),
        Err(_) => 8,
    }
}
impl HeuristicsConfig {
    /// Builds a configuration from environment variables, taking every field
    /// not listed here from [`Default`].
    ///
    /// * `MOROK_THREADS`: parsed as `usize` after trimming surrounding
    ///   whitespace (so `MOROK_THREADS="8 "` is honoured). When unset or
    ///   unparsable, `default_thread_count()` is used.
    /// * `MOROK_K_VECTORIZE`: `k_vectorize` is `true` when the variable is
    ///   set, whatever its value.
    /// * `MOROK_NO_OUTPUT_UPCAST`: `output_upcast` is `false` when the
    ///   variable is set, whatever its value.
    pub fn from_env() -> Self {
        let thread_count = std::env::var("MOROK_THREADS")
            .ok()
            // Trim first: `str::parse::<usize>` rejects padded input, which
            // would otherwise silently discard the user's setting.
            .and_then(|raw| raw.trim().parse::<usize>().ok())
            .unwrap_or_else(default_thread_count);

        Self {
            thread_count,
            k_vectorize: std::env::var("MOROK_K_VECTORIZE").is_ok(),
            // Opt-out flag: upcast stays on unless the variable is present.
            output_upcast: std::env::var("MOROK_NO_OUTPUT_UPCAST").is_err(),
            ..Default::default()
        }
    }
}
impl Default for HeuristicsConfig {
fn default() -> Self {
Self {
tc_enabled: TcUsage::Enabled,
tc_opt: TcOpt::Padded,
tc_select: TcSelect::Auto,
matvec_enabled: true,
matvec_blocksize: 4,
grouped_threshold: 256,
unroll_threshold: 32,
disable_locals: false,
thread_count: default_thread_count(),
k_vectorize: false,
output_upcast: true,
debug_level: 0,
}
}
}
#[bon]
impl HeuristicsConfig {
#[builder]
pub fn builder(
#[builder(default)] tc_enabled: TcUsage,
#[builder(default)] tc_opt: TcOpt,
#[builder(default)] tc_select: TcSelect,
#[builder(default = true)] matvec_enabled: bool,
#[builder(default = 4)] matvec_blocksize: usize,
#[builder(default = 256)] grouped_threshold: usize,
#[builder(default = 32)] unroll_threshold: usize,
#[builder(default = false)] disable_locals: bool,
#[builder(default = default_thread_count())] thread_count: usize,
#[builder(default = false)] k_vectorize: bool,
#[builder(default = true)] output_upcast: bool,
#[builder(default = 0)] debug_level: u8,
) -> Self {
Self {
tc_enabled,
tc_opt,
tc_select,
matvec_enabled,
matvec_blocksize,
grouped_threshold,
unroll_threshold,
disable_locals,
thread_count,
k_vectorize,
output_upcast,
debug_level,
}
}
}
/// Top-level optimizer configuration: the selected strategy plus the
/// settings for the beam and heuristic strategies.
///
/// `Default` composes the defaults of the three parts.
#[derive(Debug, Clone, Default)]
pub struct OptimizerConfig {
    /// Which strategy to use.
    pub strategy: OptStrategy,
    /// Beam settings.
    pub beam: BeamConfig,
    /// Heuristic settings.
    pub heuristics: HeuristicsConfig,
}

#[bon]
impl OptimizerConfig {
    /// Builder entry point (`OptimizerConfig::builder()…build()`).
    ///
    /// Each part is optional and falls back to its own `Default`. The parts
    /// are stored exactly as given: a `Beam { width }` strategy does not
    /// alter `beam.beam_width` here.
    #[builder]
    pub fn builder(
        #[builder(default)] strategy: OptStrategy,
        #[builder(default)] beam: BeamConfig,
        #[builder(default)] heuristics: HeuristicsConfig,
    ) -> Self {
        OptimizerConfig { strategy, beam, heuristics }
    }

    /// Builds the full configuration from environment variables by delegating
    /// to the `from_env` constructor of each part.
    ///
    /// When the resolved strategy is `Beam { width }`, that width replaces
    /// `beam.beam_width` so both parts agree.
    pub fn from_env() -> Self {
        let strategy = OptStrategy::from_env();

        OptimizerConfig {
            beam: BeamConfig::from_env().with_strategy_width(&strategy),
            heuristics: HeuristicsConfig::from_env(),
            strategy,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Heuristic` is the variant marked `#[default]`.
    #[test]
    fn test_opt_strategy_default_is_heuristic() {
        let strategy = OptStrategy::default();
        assert_eq!(strategy, OptStrategy::Heuristic);
    }

    /// Only `None` reports `is_none()`.
    #[test]
    fn test_opt_strategy_is_none() {
        assert!(OptStrategy::None.is_none());
        for other in [OptStrategy::Heuristic, OptStrategy::Beam { width: 4 }] {
            assert!(!other.is_none());
        }
    }

    /// Only `Beam { .. }` reports `is_beam()`.
    #[test]
    fn test_opt_strategy_is_beam() {
        for other in [OptStrategy::None, OptStrategy::Heuristic] {
            assert!(!other.is_beam());
        }
        assert!(OptStrategy::Beam { width: 4 }.is_beam());
    }

    /// Spot-checks the `Default` values of `BeamConfig`.
    #[test]
    fn test_beam_config_default() {
        let BeamConfig { beam_width, timeout, max_upcast, max_local, .. } = BeamConfig::default();
        assert_eq!(beam_width, 4);
        assert_eq!(timeout, Duration::from_secs(60));
        assert_eq!(max_upcast, 256);
        assert_eq!(max_local, 1024);
    }

    /// Explicit builder values win; untouched parameters keep their default.
    #[test]
    fn test_beam_config_builder() {
        let built = BeamConfig::builder()
            .beam_width(8)
            .timeout_secs(120)
            .max_upcast(512)
            .build();
        assert_eq!(built.beam_width, 8);
        assert_eq!(built.timeout, Duration::from_secs(120));
        assert_eq!(built.max_upcast, 512);
        // Not set above, so the builder default applies.
        assert_eq!(built.max_local, 1024);
    }

    /// Spot-checks the `Default` values of `HeuristicsConfig`.
    #[test]
    fn test_heuristics_config_default() {
        let defaults = HeuristicsConfig::default();
        assert_eq!(defaults.tc_enabled, TcUsage::Enabled);
        assert_eq!(defaults.tc_opt, TcOpt::Padded);
        assert!(defaults.matvec_enabled);
        assert_eq!(defaults.grouped_threshold, 256);
    }

    /// Explicit builder values are stored as given.
    #[test]
    fn test_heuristics_config_builder() {
        let built = HeuristicsConfig::builder()
            .tc_enabled(TcUsage::Disabled)
            .matvec_enabled(false)
            .grouped_threshold(128)
            .build();
        assert_eq!(built.tc_enabled, TcUsage::Disabled);
        assert!(!built.matvec_enabled);
        assert_eq!(built.grouped_threshold, 128);
    }

    /// The derived `Default` composes the defaults of the parts.
    #[test]
    fn test_optimizer_config_default() {
        let defaults = OptimizerConfig::default();
        assert_eq!(defaults.strategy, OptStrategy::Heuristic);
        assert_eq!(defaults.beam.beam_width, 4);
    }

    /// The builder stores the strategy and beam settings it is given.
    #[test]
    fn test_optimizer_config_builder() {
        let beam = BeamConfig::builder().timeout_secs(120).build();
        let built = OptimizerConfig::builder()
            .strategy(OptStrategy::Beam { width: 8 })
            .beam(beam)
            .build();
        assert_eq!(built.strategy, OptStrategy::Beam { width: 8 });
        assert_eq!(built.beam.timeout, Duration::from_secs(120));
    }

    /// Numeric codes of `TcUsage`.
    #[test]
    fn test_tc_usage_as_usize() {
        let expected = [(TcUsage::Disabled, 0), (TcUsage::Enabled, 1), (TcUsage::ShapeOnly, 2)];
        for (variant, code) in expected {
            assert_eq!(variant.as_usize(), code);
        }
    }

    /// Numeric codes of `TcOpt`.
    #[test]
    fn test_tc_opt_as_usize() {
        let expected = [(TcOpt::Strict, 0), (TcOpt::Relaxed, 1), (TcOpt::Padded, 2)];
        for (variant, code) in expected {
            assert_eq!(variant.as_usize(), code);
        }
    }

    /// `Auto` encodes as -1; an index encodes as itself.
    #[test]
    fn test_tc_select_as_i32() {
        assert_eq!(TcSelect::Auto.as_i32(), -1);
        assert_eq!(TcSelect::Index(5).as_i32(), 5);
    }
}